BERT-CRF 微调中文 NER 模型

文章目录

  • 数据集
  • 模型定义
  • 数据集预处理
    • BIO 标签转换
    • 自定义Dataset
    • 拆分训练、测试集
  • 训练
  • 验证、测试
  • 指标计算
  • 推理
  • 其它
    • 相关参数
    • CRF 模块

数据集

  • CLUE-NER数据集:https://github.com/CLUEbenchmark/CLUENER2020/blob/master/pytorch_version/README.md
    （此处原文为 CLUE-NER 数据集示例图片）

模型定义

import torch
import torch.nn as nn
from pytorch_crf import CRF
from transformers import BertPreTrainedModel, BertModel


class BertCrfForNer(BertPreTrainedModel):
    """BERT encoder + linear emission layer + CRF head for token-level NER."""

    def __init__(self, config):
        super(BertCrfForNer, self).__init__(config)
        # Contextual encoder producing one hidden vector per word piece.
        self.bert = BertModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        # Maps hidden states to per-token emission scores over the tag set.
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
        # CRF scores/decodes whole tag sequences; inputs are batch-first.
        self.crf = CRF(num_tags=config.num_labels, batch_first=True)
        self.num_labels = config.num_labels
        self.init_weights()

    def forward(self, input_ids, token_type_ids=None, attention_mask=None,
                labels=None, input_lens=None):
        """Return ``(loss, logits)`` when ``labels`` is given, else ``(logits,)``.

        ``input_lens`` is accepted for interface compatibility but unused.
        """
        encoder_out = self.bert(input_ids=input_ids,
                                attention_mask=attention_mask,
                                token_type_ids=token_type_ids)
        hidden_states = self.dropout(encoder_out[0])
        emissions = self.classifier(hidden_states)
        result = (emissions,)
        if labels is not None:
            # The CRF returns a log-likelihood; negate it to get a loss.
            log_likelihood = self.crf(emissions=emissions, tags=labels,
                                      mask=attention_mask)
            result = (-1 * log_likelihood,) + result
        return result  # (loss), scores

其中 CRF 模块 pytorch_crf.py 见后文。

数据集预处理

BIO 标签转换

ALLOW_LABEL = ["name", "organization", "address", "company", "government"]


def generate_bio_tags(tokenizer, text_json, allowed_type=ALLOW_LABEL):
    """Tokenize ``text_json["text"]`` and derive a BIO tag per token.

    Entity types not listed in ``allowed_type`` are ignored (tagged "O").
    Returns ``(tokens, bio_tags)`` of equal length.
    """

    def _tokens_with_offsets(tok, raw_text):
        # Encode once with character offsets, then pair each decoded
        # word piece with its (start, end) span in the original text.
        enc = tok.encode_plus(raw_text, return_offsets_mapping=True)
        decoded = [tok.decode(tid) for tid in enc.input_ids]
        return list(zip(decoded, enc.offset_mapping))

    def _tag_for_span(label_map, span_start, span_end):
        # Special tokens ([CLS]/[SEP]) have empty (0, 0) spans -> "O".
        if span_start >= span_end:
            return "O"
        for ent_type, mentions in label_map.items():
            if ent_type not in allowed_type:
                continue
            for _mention_text, spans in mentions.items():
                for span in spans:
                    ent_start, ent_end = span
                    # CLUE-NER end offsets are inclusive, hence the +1.
                    if span_start >= ent_start and span_end <= ent_end + 1:
                        prefix = "B" if span_start == ent_start else "I"
                        return f"{prefix}-{ent_type}"
        return "O"

    label_map = text_json["label"]
    located = _tokens_with_offsets(tokenizer, text_json["text"])
    tokens = [tok for tok, _ in located]
    bio_tags = [_tag_for_span(label_map, s, e) for _, (s, e) in located]
    return tokens, bio_tags
# 输入JSON数据
# Example: "game" is not in ALLOW_LABEL, so every token is tagged "O".
input_json = {"text": "你们是最棒的!#英雄联盟d学sanchez创作的原声王", "label": {"game": {"英雄联盟": [[8, 11]]}}}
generate_bio_tags(tokenizer, input_json)
# Sample output (tokens include BERT special tokens and word pieces):
"""
(['[CLS]','你','们','是','最','棒','的','!','#','英','雄','联','盟','d','学','san','##che','##z','创','作','的','原','声','王','[SEP]'],['O','O','O','O','O','O','O','O','O','O','O','O','O','O','O','O','O','O','O','O','O','O','O','O','O'])"""

自定义Dataset

from tqdm.notebook import tqdm
import json
import pickle
import os

cached_dataset = 'train.dataset.pkl'
train_file = 'train.json'

if os.path.exists(cached_dataset):
    # Fast path: reuse the pickled, already-preprocessed dataset.
    with open(cached_dataset, 'rb') as fh:
        dataset = pickle.load(fh)
else:
    # Build {"text", "tokens", "tags"} records from the raw JSONL file,
    # keeping only sentences with at least one non-"O" tag, then cache.
    dataset = []
    with open(train_file, 'r') as fh:
        for raw_line in tqdm(fh.readlines()):
            record = json.loads(raw_line.strip())
            tokens, bio_tags = generate_bio_tags(tokenizer, record)
            if len(set(bio_tags)) > 1:
                dataset.append({"text": record["text"],
                                "tokens": tokens,
                                "tags": bio_tags})
    with open(cached_dataset, 'wb') as fh:
        pickle.dump(dataset, fh)

先把原始数据 {“text”: …, “label”: … } 转换成 {“text”: … , “tokens”: …, “tags”: …}

from itertools import product
from torch.utils.data import Dataset, DataLoader

# Tag inventory: "O" plus B-/I- prefixed variants of each allowed entity type.
labels = ["O"] + [f"{i}-{j}" for i, j in product(['B', 'I'], ALLOW_LABEL)]
label2id = {tag: idx for idx, tag in enumerate(labels)}
id2label = {idx: tag for idx, tag in enumerate(labels)}


class BertDataset(Dataset):
    """Wraps preprocessed token/tag records into fixed-length tensor samples."""

    def __init__(self, dataset, tokenizer, max_len):
        self.len = len(dataset)
        self.data = dataset
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __getitem__(self, index):
        record = self.data[index]
        toks = record["tokens"]
        tags = record["tags"]
        limit = self.max_len
        if len(toks) > limit:
            # Truncate over-long sequences to the fixed length.
            toks = toks[:limit]
            tags = tags[:limit]
        else:
            # Pad to the fixed length; padded positions get the "O" tag.
            toks = toks + ['[PAD]'] * (limit - len(toks))
            tags = tags + ["O"] * (limit - len(tags))
        # Attention mask: 1 for real tokens, 0 for padding.
        attn_mask = [0 if tok == '[PAD]' else 1 for tok in toks]
        token_ids = self.tokenizer.convert_tokens_to_ids(toks)
        tag_ids = [label2id[t] for t in tags]
        return {
            'ids': torch.tensor(token_ids, dtype=torch.long),
            'mask': torch.tensor(attn_mask, dtype=torch.long),
            'targets': torch.tensor(tag_ids, dtype=torch.long),
        }

    def __len__(self):
        return self.len

拆分训练、测试集

import numpy as np
import random


def split_train_test_valid(dataset, train_size=0.9, test_size=0.1):
    """Randomly partition ``dataset`` into (train, test, valid) arrays.

    The valid split receives whatever remains after train and test.
    NOTE(review): uses the global ``random`` state — seed it for
    reproducible splits.
    """
    pool = np.array(dataset)
    total = len(pool)
    n_train = int(total * train_size)
    n_test = int(total * test_size)
    order = list(range(total))
    random.shuffle(order)  # shuffle indices in place
    data_train = pool[order[:n_train]]
    data_test = pool[order[n_train:n_train + n_test]]
    data_valid = pool[order[n_train + n_test:]]  # remainder is valid
    return data_train, data_test, data_valid


data_train, data_test, data_valid = split_train_test_valid(dataset)
print("FULL Dataset: {}".format(len(dataset)))
print("TRAIN Dataset: {}".format(data_train.shape))
print("TEST Dataset: {}".format(data_test.shape))

# Wrap the splits into tensor-producing datasets.
training_set = BertDataset(data_train, tokenizer, MAX_LEN)
testing_set = BertDataset(data_test, tokenizer, MAX_LEN)

# NOTE(review): shuffle=True on the test loader is harmless for the metrics
# computed here but unusual — confirm it is intended.
train_params = {'batch_size': TRAIN_BATCH_SIZE, 'shuffle': True, 'num_workers': 0}
test_params = {'batch_size': VALID_BATCH_SIZE, 'shuffle': True, 'num_workers': 0}

training_loader = DataLoader(training_set, **train_params)
testing_loader = DataLoader(testing_set, **test_params)

训练

# Load the pretrained Chinese BERT weights into the BERT-CRF model; the
# label maps are stored in the config so the inference pipeline can use them.
model = BertCrfForNer.from_pretrained(
    'models/bert-base-chinese',
    num_labels=len(id2label),
    id2label=id2label,
    label2id=label2id,
)
if MULTI_GPU:
    model = torch.nn.DataParallel(model, )
model.to(device)
optimizer = torch.optim.Adam(params=model.parameters(), lr=LEARNING_RATE)
from sklearn.metrics import accuracy_score
import warnings

warnings.filterwarnings('ignore')


def train(epoch):
    """Run one training epoch over ``training_loader``.

    Prints the running loss every 100 steps and the epoch-level loss and
    masked token accuracy at the end. Relies on module globals:
    ``model``, ``optimizer``, ``training_loader``, ``device``,
    ``MULTI_GPU``, ``MAX_GRAD_NORM``.
    """
    tr_loss, tr_accuracy = 0, 0
    nb_tr_examples, nb_tr_steps = 0, 0
    tr_preds, tr_labels = [], []
    # put model in training mode
    model.train()
    for idx, batch in enumerate(training_loader):
        ids = batch['ids'].to(device, dtype=torch.long)
        mask = batch['mask'].to(device, dtype=torch.long)
        targets = batch['targets'].to(device, dtype=torch.long)

        outputs = model(input_ids=ids, attention_mask=mask, labels=targets)
        loss, tr_logits = outputs[0], outputs[1]
        if MULTI_GPU:
            loss = loss.mean()
        tr_loss += loss.item()
        nb_tr_steps += 1
        nb_tr_examples += targets.size(0)

        if idx % 100 == 0:
            loss_step = tr_loss / nb_tr_steps
            print(f"Training loss per 100 training steps: {loss_step}")

        # compute training accuracy on non-padding positions only
        flattened_targets = targets.view(-1)  # (batch_size * seq_len,)
        num_labels = model.module.num_labels if MULTI_GPU else model.num_labels
        active_logits = tr_logits.view(-1, num_labels)  # (batch_size * seq_len, num_labels)
        flattened_predictions = torch.argmax(active_logits, axis=1)
        # mask selects real tokens (includes [CLS]/[SEP] predictions)
        active_accuracy = mask.view(-1) == 1
        targets = torch.masked_select(flattened_targets, active_accuracy)
        predictions = torch.masked_select(flattened_predictions, active_accuracy)
        tr_preds.extend(predictions)
        tr_labels.extend(targets)
        tmp_tr_accuracy = accuracy_score(targets.cpu().numpy(), predictions.cpu().numpy())
        tr_accuracy += tmp_tr_accuracy

        # BUGFIX: compute fresh gradients BEFORE clipping. The original code
        # called clip_grad_norm_ first, then zero_grad/backward/step, so it
        # clipped the PREVIOUS step's gradients and immediately zeroed them —
        # clipping never affected the applied update.
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(parameters=model.parameters(), max_norm=MAX_GRAD_NORM)
        optimizer.step()

    epoch_loss = tr_loss / nb_tr_steps
    tr_accuracy = tr_accuracy / nb_tr_steps
    print(f"Training loss epoch: {epoch_loss}")
    print(f"Training accuracy epoch: {tr_accuracy}")


for epoch in range(EPOCHS):
    print(f"Training epoch: {epoch + 1}")
    train(epoch)
"""
Training epoch: 1
Training loss per 100 training steps: 76.82186126708984
Training loss per 100 training steps: 26.512494955912675
Training loss per 100 training steps: 18.23713019356799
Training loss per 100 training steps: 14.71561597431221
Training loss per 100 training steps: 12.793566083075698
Training loss epoch: 12.138352865534845
Training accuracy epoch: 0.9093487211512798
"""

验证、测试

def valid(model, testing_loader):
    """Evaluate ``model`` on ``testing_loader``.

    Prints running/final loss and masked token-level accuracy; returns the
    flattened (gold, predicted) tag-name lists for seqeval metrics.
    Relies on module globals: ``device``, ``MULTI_GPU``, ``id2label``,
    ``accuracy_score``.
    """
    # put model in evaluation mode
    model.eval()
    eval_loss, eval_accuracy = 0, 0
    nb_eval_examples, nb_eval_steps = 0, 0
    eval_preds, eval_labels = [], []
    with torch.no_grad():
        for idx, batch in enumerate(testing_loader):
            ids = batch['ids'].to(device, dtype=torch.long)
            mask = batch['mask'].to(device, dtype=torch.long)
            targets = batch['targets'].to(device, dtype=torch.long)
            outputs = model(input_ids=ids, attention_mask=mask, labels=targets)
            loss, eval_logits = outputs[0], outputs[1]
            if MULTI_GPU:
                loss = loss.mean()
            eval_loss += loss.item()
            nb_eval_steps += 1
            nb_eval_examples += targets.size(0)
            if idx % 100 == 0:
                loss_step = eval_loss / nb_eval_steps
                print(f"Validation loss per 100 evaluation steps: {loss_step}")
            # compute evaluation accuracy on non-padding positions only
            flattened_targets = targets.view(-1)  # (batch_size * seq_len,)
            num_labels = model.module.num_labels if MULTI_GPU else model.num_labels
            active_logits = eval_logits.view(-1, num_labels)  # (batch_size * seq_len, num_labels)
            flattened_predictions = torch.argmax(active_logits, axis=1)
            # mask selects real tokens (includes [CLS]/[SEP] predictions)
            active_accuracy = mask.view(-1) == 1
            targets = torch.masked_select(flattened_targets, active_accuracy)
            predictions = torch.masked_select(flattened_predictions, active_accuracy)
            eval_labels.extend(targets)
            eval_preds.extend(predictions)
            tmp_eval_accuracy = accuracy_score(targets.cpu().numpy(), predictions.cpu().numpy())
            eval_accuracy += tmp_eval_accuracy
    # map accumulated label ids back to tag names for seqeval
    labels = [id2label[id.item()] for id in eval_labels]
    predictions = [id2label[id.item()] for id in eval_preds]
    eval_loss = eval_loss / nb_eval_steps
    eval_accuracy = eval_accuracy / nb_eval_steps
    print(f"Validation Loss: {eval_loss}")
    print(f"Validation Accuracy: {eval_accuracy}")
    return labels, predictions


labels, predictions = valid(model, testing_loader)
"""
Validation loss per 100 evaluation steps: 5.371463775634766
Validation Loss: 5.623965330123902
Validation Accuracy: 0.9547014622783095
"""

指标计算

from seqeval.metrics import classification_report

# seqeval expects a list of sequences, hence the extra nesting.
print(classification_report([labels], [predictions]))
# Sample output:
"""precision    recall  f1-score   supportaddress       0.50      0.62      0.55       316company       0.65      0.77      0.70       270government       0.69      0.85      0.76       208name       0.87      0.87      0.87       374
organization       0.76      0.82      0.79       343micro avg       0.69      0.79      0.74      1511macro avg       0.69      0.79      0.73      1511
weighted avg       0.70      0.79      0.74      1511
"""

推理

from transformers import pipeline

# Unwrap DataParallel (if used) so the pipeline receives the bare model.
model_to_test = (model.module if hasattr(model, "module") else model
)
# Run inference on CPU; aggregation_strategy="simple" merges word-piece
# predictions into whole-entity spans.
pipe = pipeline(task="token-classification", model=model_to_test.to("cpu"), tokenizer=tokenizer, aggregation_strategy="simple")
pipe("我的名字是michal johnson,我的手机号是13425456344,我家住在东北松花江上8幢7单元6楼5号房")
# Sample output:
"""
[{'entity_group': 'name','score': 0.83746755,'word': 'michal johnson','start': 5,'end': 19},{'entity_group': 'address','score': 0.924768,'word': '东 北 松 花 江 上 8 幢 7 单 元 6 楼 5 号 房','start': 42,'end': 58}]
"""

其它

相关参数

import torch
import os

# Restrict this process to the listed physical GPUs.
os.environ['CUDA_VISIBLE_DEVICES'] = '1,3'
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

MAX_LEN = 128            # fixed sequence length after truncation/padding
TRAIN_BATCH_SIZE = 16
VALID_BATCH_SIZE = 32
EPOCHS = 1
LEARNING_RATE = 1e-05
MAX_GRAD_NORM = 10       # gradient-clipping threshold
MULTI_GPU = False        # wrap model in nn.DataParallel when True
# Entity types kept from CLUE-NER; all others are tagged "O".
ALLOW_LABEL = ["name", "organization", "address","company","government"]

CRF 模块

参考:https://github.com/CLUEbenchmark/CLUENER2020/blob/master/pytorch_version/models/crf.py

import torch
import torch.nn as nn
from typing import List, Optional


class CRF(nn.Module):
    """Conditional random field.

    This module implements a conditional random field [LMP01]_. The forward
    computation of this class computes the log likelihood of the given sequence
    of tags and emission score tensor. This class also has `~CRF.decode` method
    which finds the best tag sequence given an emission score tensor using
    `Viterbi algorithm`_.

    Args:
        num_tags: Number of tags.
        batch_first: Whether the first dimension corresponds to the size of a minibatch.

    Attributes:
        start_transitions (`~torch.nn.Parameter`): Start transition score tensor of size
            ``(num_tags,)``.
        end_transitions (`~torch.nn.Parameter`): End transition score tensor of size
            ``(num_tags,)``.
        transitions (`~torch.nn.Parameter`): Transition score tensor of size
            ``(num_tags, num_tags)``.

    .. [LMP01] Lafferty, J., McCallum, A., Pereira, F. (2001).
       "Conditional random fields: Probabilistic models for segmenting and
       labeling sequence data". *Proc. 18th International Conf. on Machine
       Learning*. Morgan Kaufmann. pp. 282–289.

    .. _Viterbi algorithm: https://en.wikipedia.org/wiki/Viterbi_algorithm
    """

    def __init__(self, num_tags: int, batch_first: bool = False) -> None:
        if num_tags <= 0:
            raise ValueError(f'invalid number of tags: {num_tags}')
        super().__init__()
        self.num_tags = num_tags
        self.batch_first = batch_first
        self.start_transitions = nn.Parameter(torch.empty(num_tags))
        self.end_transitions = nn.Parameter(torch.empty(num_tags))
        self.transitions = nn.Parameter(torch.empty(num_tags, num_tags))
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Initialize the transition parameters.

        The parameters will be initialized randomly from a uniform distribution
        between -0.1 and 0.1.
        """
        nn.init.uniform_(self.start_transitions, -0.1, 0.1)
        nn.init.uniform_(self.end_transitions, -0.1, 0.1)
        nn.init.uniform_(self.transitions, -0.1, 0.1)

    def __repr__(self) -> str:
        return f'{self.__class__.__name__}(num_tags={self.num_tags})'

    def forward(self, emissions: torch.Tensor,
                tags: torch.LongTensor,
                mask: Optional[torch.ByteTensor] = None,
                reduction: str = 'mean') -> torch.Tensor:
        """Compute the conditional log likelihood of a sequence of tags given emission scores.

        Args:
            emissions (`~torch.Tensor`): Emission score tensor of size
                ``(seq_length, batch_size, num_tags)`` if ``batch_first`` is ``False``,
                ``(batch_size, seq_length, num_tags)`` otherwise.
            tags (`~torch.LongTensor`): Sequence of tags tensor of size
                ``(seq_length, batch_size)`` if ``batch_first`` is ``False``,
                ``(batch_size, seq_length)`` otherwise.
            mask (`~torch.ByteTensor`): Mask tensor of size ``(seq_length, batch_size)``
                if ``batch_first`` is ``False``, ``(batch_size, seq_length)`` otherwise.
            reduction: Specifies the reduction to apply to the output:
                ``none|sum|mean|token_mean``. ``none``: no reduction will be applied.
                ``sum``: the output will be summed over batches. ``mean``: the output will be
                averaged over batches. ``token_mean``: the output will be averaged over tokens.

        Returns:
            `~torch.Tensor`: The log likelihood. This will have size ``(batch_size,)`` if
            reduction is ``none``, ``()`` otherwise.
        """
        if reduction not in ('none', 'sum', 'mean', 'token_mean'):
            raise ValueError(f'invalid reduction: {reduction}')
        if mask is None:
            mask = torch.ones_like(tags, dtype=torch.uint8, device=tags.device)
        if mask.dtype != torch.uint8:
            mask = mask.byte()
        self._validate(emissions, tags=tags, mask=mask)
        if self.batch_first:
            # internal computation is time-major: (seq_length, batch_size, ...)
            emissions = emissions.transpose(0, 1)
            tags = tags.transpose(0, 1)
            mask = mask.transpose(0, 1)

        # shape: (batch_size,)
        numerator = self._compute_score(emissions, tags, mask)
        # shape: (batch_size,)
        denominator = self._compute_normalizer(emissions, mask)
        # shape: (batch_size,)
        llh = numerator - denominator

        if reduction == 'none':
            return llh
        if reduction == 'sum':
            return llh.sum()
        if reduction == 'mean':
            return llh.mean()
        return llh.sum() / mask.float().sum()

    def decode(self, emissions: torch.Tensor,
               mask: Optional[torch.ByteTensor] = None,
               nbest: Optional[int] = None,
               pad_tag: Optional[int] = None) -> List[List[List[int]]]:
        """Find the most likely tag sequence using Viterbi algorithm.

        Args:
            emissions (`~torch.Tensor`): Emission score tensor of size
                ``(seq_length, batch_size, num_tags)`` if ``batch_first`` is ``False``,
                ``(batch_size, seq_length, num_tags)`` otherwise.
            mask (`~torch.ByteTensor`): Mask tensor of size ``(seq_length, batch_size)``
                if ``batch_first`` is ``False``, ``(batch_size, seq_length)`` otherwise.
            nbest (`int`): Number of most probable paths for each sequence
            pad_tag (`int`): Tag at padded positions. Often input varies in length and
                the length will be padded to the maximum length in the batch. Tags at
                the padded positions will be assigned with a padding tag, i.e. `pad_tag`

        Returns:
            A PyTorch tensor of the best tag sequence for each batch of shape
            (nbest, batch_size, seq_length)
        """
        if nbest is None:
            nbest = 1
        if mask is None:
            mask = torch.ones(emissions.shape[:2], dtype=torch.uint8,
                              device=emissions.device)
        if mask.dtype != torch.uint8:
            mask = mask.byte()
        self._validate(emissions, mask=mask)
        if self.batch_first:
            emissions = emissions.transpose(0, 1)
            mask = mask.transpose(0, 1)

        if nbest == 1:
            return self._viterbi_decode(emissions, mask, pad_tag).unsqueeze(0)
        return self._viterbi_decode_nbest(emissions, mask, nbest, pad_tag)

    def _validate(self, emissions: torch.Tensor,
                  tags: Optional[torch.LongTensor] = None,
                  mask: Optional[torch.ByteTensor] = None) -> None:
        # Sanity checks on tensor shapes; raises ValueError on mismatch.
        if emissions.dim() != 3:
            raise ValueError(f'emissions must have dimension of 3, got {emissions.dim()}')
        if emissions.size(2) != self.num_tags:
            raise ValueError(
                f'expected last dimension of emissions is {self.num_tags}, '
                f'got {emissions.size(2)}')
        if tags is not None:
            if emissions.shape[:2] != tags.shape:
                raise ValueError(
                    'the first two dimensions of emissions and tags must match, '
                    f'got {tuple(emissions.shape[:2])} and {tuple(tags.shape)}')
        if mask is not None:
            if emissions.shape[:2] != mask.shape:
                raise ValueError(
                    'the first two dimensions of emissions and mask must match, '
                    f'got {tuple(emissions.shape[:2])} and {tuple(mask.shape)}')
            no_empty_seq = not self.batch_first and mask[0].all()
            no_empty_seq_bf = self.batch_first and mask[:, 0].all()
            if not no_empty_seq and not no_empty_seq_bf:
                raise ValueError('mask of the first timestep must all be on')

    def _compute_score(self, emissions: torch.Tensor,
                       tags: torch.LongTensor,
                       mask: torch.ByteTensor) -> torch.Tensor:
        # emissions: (seq_length, batch_size, num_tags)
        # tags: (seq_length, batch_size)
        # mask: (seq_length, batch_size)
        seq_length, batch_size = tags.shape
        mask = mask.float()

        # Start transition score and first emission
        # shape: (batch_size,)
        score = self.start_transitions[tags[0]]
        score += emissions[0, torch.arange(batch_size), tags[0]]

        for i in range(1, seq_length):
            # Transition score to next tag, only added if next timestep is valid (mask == 1)
            # shape: (batch_size,)
            score += self.transitions[tags[i - 1], tags[i]] * mask[i]
            # Emission score for next tag, only added if next timestep is valid (mask == 1)
            # shape: (batch_size,)
            score += emissions[i, torch.arange(batch_size), tags[i]] * mask[i]

        # End transition score
        # shape: (batch_size,)
        seq_ends = mask.long().sum(dim=0) - 1
        # shape: (batch_size,)
        last_tags = tags[seq_ends, torch.arange(batch_size)]
        # shape: (batch_size,)
        score += self.end_transitions[last_tags]

        return score

    def _compute_normalizer(self, emissions: torch.Tensor,
                            mask: torch.ByteTensor) -> torch.Tensor:
        # emissions: (seq_length, batch_size, num_tags)
        # mask: (seq_length, batch_size)
        seq_length = emissions.size(0)

        # Start transition score and first emission; score has size of
        # (batch_size, num_tags) where for each batch, the j-th column stores
        # the score that the first timestep has tag j
        # shape: (batch_size, num_tags)
        score = self.start_transitions + emissions[0]

        for i in range(1, seq_length):
            # Broadcast score for every possible next tag
            # shape: (batch_size, num_tags, 1)
            broadcast_score = score.unsqueeze(2)

            # Broadcast emission score for every possible current tag
            # shape: (batch_size, 1, num_tags)
            broadcast_emissions = emissions[i].unsqueeze(1)

            # Compute the score tensor of size (batch_size, num_tags, num_tags) where
            # for each sample, entry at row i and column j stores the sum of scores of all
            # possible tag sequences so far that end with transitioning from tag i to tag j
            # and emitting
            # shape: (batch_size, num_tags, num_tags)
            next_score = broadcast_score + self.transitions + broadcast_emissions

            # Sum over all possible current tags, but we're in score space, so a sum
            # becomes a log-sum-exp: for each sample, entry i stores the sum of scores of
            # all possible tag sequences so far, that end in tag i
            # shape: (batch_size, num_tags)
            next_score = torch.logsumexp(next_score, dim=1)

            # Set score to the next score if this timestep is valid (mask == 1)
            # shape: (batch_size, num_tags)
            score = torch.where(mask[i].unsqueeze(1), next_score, score)

        # End transition score
        # shape: (batch_size, num_tags)
        score += self.end_transitions

        # Sum (log-sum-exp) over all possible tags
        # shape: (batch_size,)
        return torch.logsumexp(score, dim=1)

    def _viterbi_decode(self, emissions: torch.FloatTensor,
                        mask: torch.ByteTensor,
                        pad_tag: Optional[int] = None) -> List[List[int]]:
        # emissions: (seq_length, batch_size, num_tags)
        # mask: (seq_length, batch_size)
        # return: (batch_size, seq_length)
        if pad_tag is None:
            pad_tag = 0

        device = emissions.device
        seq_length, batch_size = mask.shape

        # Start transition and first emission
        # shape: (batch_size, num_tags)
        score = self.start_transitions + emissions[0]
        history_idx = torch.zeros((seq_length, batch_size, self.num_tags),
                                  dtype=torch.long, device=device)
        oor_idx = torch.zeros((batch_size, self.num_tags),
                              dtype=torch.long, device=device)
        oor_tag = torch.full((seq_length, batch_size), pad_tag,
                             dtype=torch.long, device=device)

        # - score is a tensor of size (batch_size, num_tags) where for every batch,
        #   value at column j stores the score of the best tag sequence so far that ends
        #   with tag j
        # - history_idx saves where the best tags candidate transitioned from; this is used
        #   when we trace back the best tag sequence
        # - oor_idx saves the best tags candidate transitioned from at the positions
        #   where mask is 0, i.e. out of range (oor)

        # Viterbi algorithm recursive case: we compute the score of the best tag sequence
        # for every possible next tag
        for i in range(1, seq_length):
            # Broadcast viterbi score for every possible next tag
            # shape: (batch_size, num_tags, 1)
            broadcast_score = score.unsqueeze(2)

            # Broadcast emission score for every possible current tag
            # shape: (batch_size, 1, num_tags)
            broadcast_emission = emissions[i].unsqueeze(1)

            # Compute the score tensor of size (batch_size, num_tags, num_tags) where
            # for each sample, entry at row i and column j stores the score of the best
            # tag sequence so far that ends with transitioning from tag i to tag j and emitting
            # shape: (batch_size, num_tags, num_tags)
            next_score = broadcast_score + self.transitions + broadcast_emission

            # Find the maximum score over all possible current tag
            # shape: (batch_size, num_tags)
            next_score, indices = next_score.max(dim=1)

            # Set score to the next score if this timestep is valid (mask == 1)
            # and save the index that produces the next score
            # shape: (batch_size, num_tags)
            score = torch.where(mask[i].unsqueeze(-1), next_score, score)
            indices = torch.where(mask[i].unsqueeze(-1), indices, oor_idx)
            history_idx[i - 1] = indices

        # End transition score
        # shape: (batch_size, num_tags)
        end_score = score + self.end_transitions
        _, end_tag = end_score.max(dim=1)

        # shape: (batch_size,)
        seq_ends = mask.long().sum(dim=0) - 1

        # insert the best tag at each sequence end (last position with mask == 1)
        history_idx = history_idx.transpose(1, 0).contiguous()
        history_idx.scatter_(1, seq_ends.view(-1, 1, 1).expand(-1, 1, self.num_tags),
                             end_tag.view(-1, 1, 1).expand(-1, 1, self.num_tags))
        history_idx = history_idx.transpose(1, 0).contiguous()

        # The most probable path for each sequence
        best_tags_arr = torch.zeros((seq_length, batch_size),
                                    dtype=torch.long, device=device)
        best_tags = torch.zeros(batch_size, 1, dtype=torch.long, device=device)
        for idx in range(seq_length - 1, -1, -1):
            best_tags = torch.gather(history_idx[idx], 1, best_tags)
            best_tags_arr[idx] = best_tags.data.view(batch_size)

        return torch.where(mask, best_tags_arr, oor_tag).transpose(0, 1)

    def _viterbi_decode_nbest(self, emissions: torch.FloatTensor,
                              mask: torch.ByteTensor,
                              nbest: int,
                              pad_tag: Optional[int] = None) -> List[List[List[int]]]:
        # emissions: (seq_length, batch_size, num_tags)
        # mask: (seq_length, batch_size)
        # return: (nbest, batch_size, seq_length)
        if pad_tag is None:
            pad_tag = 0

        device = emissions.device
        seq_length, batch_size = mask.shape

        # Start transition and first emission
        # shape: (batch_size, num_tags)
        score = self.start_transitions + emissions[0]
        history_idx = torch.zeros((seq_length, batch_size, self.num_tags, nbest),
                                  dtype=torch.long, device=device)
        oor_idx = torch.zeros((batch_size, self.num_tags, nbest),
                              dtype=torch.long, device=device)
        oor_tag = torch.full((seq_length, batch_size, nbest), pad_tag,
                             dtype=torch.long, device=device)

        # + score is a tensor of size (batch_size, num_tags) where for every batch,
        #   value at column j stores the score of the best tag sequence so far that ends
        #   with tag j
        # + history_idx saves where the best tags candidate transitioned from; this is used
        #   when we trace back the best tag sequence
        # - oor_idx saves the best tags candidate transitioned from at the positions
        #   where mask is 0, i.e. out of range (oor)

        # Viterbi algorithm recursive case: we compute the score of the best tag sequence
        # for every possible next tag
        for i in range(1, seq_length):
            if i == 1:
                broadcast_score = score.unsqueeze(-1)
                broadcast_emission = emissions[i].unsqueeze(1)
                # shape: (batch_size, num_tags, num_tags)
                next_score = broadcast_score + self.transitions + broadcast_emission
            else:
                broadcast_score = score.unsqueeze(-1)
                broadcast_emission = emissions[i].unsqueeze(1).unsqueeze(2)
                # shape: (batch_size, num_tags, nbest, num_tags)
                next_score = broadcast_score + self.transitions.unsqueeze(1) + broadcast_emission

            # Find the top `nbest` maximum score over all possible current tag
            # shape: (batch_size, nbest, num_tags)
            next_score, indices = next_score.view(batch_size, -1, self.num_tags).topk(nbest, dim=1)

            if i == 1:
                score = score.unsqueeze(-1).expand(-1, -1, nbest)
                indices = indices * nbest

            # convert to shape: (batch_size, num_tags, nbest)
            next_score = next_score.transpose(2, 1)
            indices = indices.transpose(2, 1)

            # Set score to the next score if this timestep is valid (mask == 1)
            # and save the index that produces the next score
            # shape: (batch_size, num_tags, nbest)
            score = torch.where(mask[i].unsqueeze(-1).unsqueeze(-1), next_score, score)
            indices = torch.where(mask[i].unsqueeze(-1).unsqueeze(-1), indices, oor_idx)
            history_idx[i - 1] = indices

        # End transition score shape: (batch_size, num_tags, nbest)
        end_score = score + self.end_transitions.unsqueeze(-1)
        _, end_tag = end_score.view(batch_size, -1).topk(nbest, dim=1)

        # shape: (batch_size,)
        seq_ends = mask.long().sum(dim=0) - 1

        # insert the best tag at each sequence end (last position with mask == 1)
        history_idx = history_idx.transpose(1, 0).contiguous()
        history_idx.scatter_(1, seq_ends.view(-1, 1, 1, 1).expand(-1, 1, self.num_tags, nbest),
                             end_tag.view(-1, 1, 1, nbest).expand(-1, 1, self.num_tags, nbest))
        history_idx = history_idx.transpose(1, 0).contiguous()

        # The most probable path for each sequence
        best_tags_arr = torch.zeros((seq_length, batch_size, nbest),
                                    dtype=torch.long, device=device)
        best_tags = torch.arange(nbest, dtype=torch.long, device=device) \
            .view(1, -1).expand(batch_size, -1)
        for idx in range(seq_length - 1, -1, -1):
            best_tags = torch.gather(history_idx[idx].view(batch_size, -1), 1, best_tags)
            best_tags_arr[idx] = best_tags.data.view(batch_size, -1) // nbest

        return torch.where(mask.unsqueeze(-1), best_tags_arr, oor_tag).permute(2, 1, 0)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/643746.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

VulnHub靶机 DC-8 打靶实战 详细渗透过程

VulnHub靶机 DC-8 打靶 详细渗透过程 目录 VulnHub靶机 DC-8 打靶 详细渗透过程一、将靶机配置导入到虚拟机当中二、渗透测试流程主机发现端口扫描Web渗透SQL注入登录后台反弹shell提权 一、将靶机配置导入到虚拟机当中 靶机地址&#xff1a; https://www.vulnhub.com/entry/…

人工智能时代的关键技术:深入探索向量数据库及其在AI中的应用

文章目录 1. 理解向量数据库&#xff1a;二维模型示例2. 向量数据库中的数据存储与检索3. 向量数据库如何工作&#xff1f;4. 向量数据库如何知道哪些向量相似&#xff1f; 在人工智能技术日益成熟的当下&#xff0c;向量数据库作为处理和检索高维数据的关键工具&#xff0c;对…

使用新版ESLint,搭配Prettier使用的配置方式

概述 ESLint重大更新(9.0.0版本)后,将不再支持非扁平化配置文件,并且移除了与Prettier冲突的规则,也就是说与Prettier搭配使用,不再需要使用插件“eslint-config-prettier”来处理冲突问题。 注:使用新版的前提条件是Node.js版本必须是18.18.0、20.9.0,或者是>=21.1…

鸿蒙官网学习3

鸿蒙官网学习3 每日小提示项目的模块类型跨设备预览调试阶段应用的替换方式有两种 打开老的demo工程报错UIAbility 每日小提示 项目的模块类型 moduleType分为三种&#xff0c;只有1&#xff0c;2的模块支持直接调试和运行 entryfeaturehar 跨设备预览 需要手动在config.j…

Tensorflow2.0笔记 - BatchNormalization

本笔记记录BN层相关的代码。关于BatchNormalization&#xff0c;可以自行百度&#xff0c;或参考这里&#xff1a; 一文读懂Batch Normalization - 知乎神经网络基础系列&#xff1a; 《深度学习中常见激活函数的原理和特点》《过拟合: dropout原理和在模型中的多种应用》深度…

代码随想录算法训练营DAY32|C++贪心算法Part.2|122.买卖股票的最佳时机II、55.跳跃游戏、45.跳跃游戏II

文章目录 122.买卖股票的最佳时机II思路CPP代码 55.跳跃游戏思路CPP代码 45.跳跃游戏II思路方法一代码改善 CPP代码 122.买卖股票的最佳时机II 力扣题目链接 文章讲解&#xff1a;122.买卖股票的最佳时机II 视频讲解&#xff1a; 状态&#xff1a;本题可以用动态规划&#xff0…

更易使用,OceanBase开发者工具 ODC 4.2.4 版本升级

亲爱的朋友们&#xff0c;大家好&#xff01;我们的ODC&#xff08;OceanBase Developer Center &#xff09;再次迎来了重要的升级V 4.2.4&#xff0c;这次我们诚意满满&#xff0c;从五个方面为大家精心打造了一个更加易用、贴心&#xff0c;且功能更强的新版本&#xff0c;相…

基础SQL DQL语句

基础查询 select * from 表名; 查询所有字段 create table emp(id int comment 编号,workno varchar(10) comment 工号,name varchar(10) comment 姓名,gender char(1) comment 性别,age tinyint unsigned comment 年龄,idcard char(18) comment 身份证号,worka…

javascript(第三篇)原型、原型链、继承问题,使用 es5、es6实现继承,一网打尽所有面试题

没错这是一道【去哪儿】的面试题目&#xff0c;手写一个 es5 的继承&#xff0c;我又没有回答上来&#xff0c;很惭愧&#xff0c;我就只知道 es5 中可以使用原型链实现继承&#xff0c;但是代码一行也写不出来。 关于 js 的继承&#xff0c;是在面试中除了【 this 指针、命名提…

Golang基础1

基本类型 bool 整数&#xff1a;byte(相当于uint8), rune(相当于int32), int/uint ,int8/uint8 ,int16/uint16 ,int32/uint32 ,int64/uint64 浮点数: float32 ,float64, complex64 ,complex128 array&#xff08;值类型&#xff09;、slice、map、chan&#xff08;引用类型…

Vitis HLS 学习笔记--对于启动时间间隔(II)的理解

目录 1. II的重要性 2. 案例分析 3. 总结 1. II的重要性 在Vitis HLS&#xff08;High-Level Synthesis&#xff09;中&#xff0c;启动时间间隔&#xff08;II&#xff0c;Iteration Interval&#xff09;是一个非常关键的概念&#xff0c;对于实现高性能的硬件加速器设计…

DS进阶:AVL树和红黑树

一、AVL树 1.1 AVL树的概念 二叉搜索树&#xff08;BST&#xff09;虽可以缩短查找的效率&#xff0c;但如果数据有序或接近有序二叉搜索树将退化为单支树&#xff0c;查找元素相当于在顺序表中搜索元素&#xff0c;效率低下。因此&#xff0c;两位俄罗斯的数学家G.M.Adelson-…