copy is all you need前向绘图 和疑惑标记

疑惑的起因
在这里插入图片描述
简化前向图
在这里插入图片描述
GPT4解释

这段代码实现了一个神经网络模型,包含了BERT、GPT-2和MLP等模块。主要功能是给定一个文本序列和一个查询序列,预测查询序列中的起始和结束位置,使其对应文本序列中的一个短语。具体实现细节如下:

首先,使用BERT将文本序列编码成一个张量作为输入,并从GPT-2模型中提取查询序列的表示。然后,将查询序列的表示拆分为起始和结束两部分,并使用MLP将短语的起始和结束点表示投影到另一个空间中。
接下来,利用生成式模型计算每个查询起始和结束位置之间的所有令牌的概率分布,并根据此生成预测位置。使用交叉熵损失函数来计算这些预测和真实值之间的差距。
最后,计算起始和结束位置的准确性和整个短语的准确性,并返回损失和准确性的值来优化模型的性能。

import torch
from transformers import AutoModel, AutoTokenizer, GPT2LMHeadModel
from torch import nn
from torch.nn import functional as Fclass Copyisallyouneed(nn.Module):def __init__(self, **args):super(Copyisallyouneed, self).__init__()self.args = args# bert-encoder modelself.phrase_encoder = AutoModel.from_pretrained(self.args['phrase_encoder_model'][self.args['lang']])self.bert_tokenizer = AutoTokenizer.from_pretrained(self.args['phrase_encoder_tokenizer'][self.args['lang']])self.bert_tokenizer.add_tokens(['<|endoftext|>', '[PREFIX]'])self.prefix_token_id = self.bert_tokenizer.convert_tokens_to_ids('[PREFIX]')self.phrase_encoder.resize_token_embeddings(self.phrase_encoder.config.vocab_size + 2)# model and tokenizerself.tokenizer = AutoTokenizer.from_pretrained(self.args['prefix_encoder_tokenizer'][self.args['lang']])self.vocab_size = len(self.tokenizer)self.pad = self.tokenizer.pad_token_id if self.args['lang'] == 'zh' else self.tokenizer.bos_token_idself.model = GPT2LMHeadModel.from_pretrained(self.args['prefix_encoder_model'][self.args['lang']])self.token_embeddings = nn.Parameter(list(self.model.lm_head.parameters())[0])# MLP: mapping bert phrase start representationsself.s_proj = nn.Sequential(nn.Dropout(p=args['dropout']),nn.Tanh(),nn.Linear(self.model.config.hidden_size, self.model.config.hidden_size // 2))# MLP: mapping bert phrase end representationsself.e_proj = nn.Sequential(nn.Dropout(p=args['dropout']),nn.Tanh(),nn.Linear(self.model.config.hidden_size, self.model.config.hidden_size // 2))self.gen_loss_fct = nn.CrossEntropyLoss(ignore_index=self.pad)@torch.no_grad()def get_query_rep(self, ids):self.eval()output = self.model(input_ids=ids, output_hidden_states=True)['hidden_states'][-1][:, -1, :]return outputdef get_token_loss(self, ids, hs, ids_mask):# no pad tokenlabel = ids[:, 1:]logits = torch.matmul(hs[:, :-1, :],self.token_embeddings.t())# TODO: inner loss function remove the temperature factorlogits /= self.args['temp']loss = self.gen_loss_fct(logits.view(-1, logits.size(-1)), label.reshape(-1))chosen_tokens = torch.max(logits, dim=-1)[1]gen_acc = (chosen_tokens.reshape(-1) == label.reshape(-1)).to(torch.long)valid_mask = (label != self.pad).reshape(-1)valid_tokens = gen_acc & valid_maskgen_acc = valid_tokens.sum().item() / valid_mask.sum().item()return loss, gen_accdef forward(self, batch):## gpt2 query encoderids, ids_mask = batch['gpt2_ids'], batch['gpt2_mask']last_hidden_states = \self.model(input_ids=ids, attention_mask=ids_mask, output_hidden_states=True).hidden_states[-1]# get token lossloss_0, acc_0 = self.get_token_loss(ids, last_hidden_states, ids_mask)## encode the document with the BERT encoder modeldids, dids_mask = batch['bert_ids'], batch['bert_mask']output = self.phrase_encoder(dids, dids_mask, output_hidden_states=True)['hidden_states'][-1]  # [B, S, E]# collect the phrase start representations and phrase end representationss_rep = self.s_proj(output)e_rep = self.e_proj(output)s_rep = s_rep.reshape(-1, s_rep.size(-1))e_rep = e_rep.reshape(-1, e_rep.size(-1))  # [B_doc*S_doc, 768//2]# collect the query representationsquery = last_hidden_states[:, :-1].reshape(-1, last_hidden_states.size(-1))query_start = query[:, :self.model.config.hidden_size // 2]query_end = query[:, self.model.config.hidden_size // 2:]# training the representations of the start tokenscandidate_reps = torch.cat([self.token_embeddings[:, :self.model.config.hidden_size // 2],s_rep], dim=0)logits = torch.matmul(query_start, candidate_reps.t())logits /= self.args['temp']# build the padding mask for query sidequery_padding_mask = ids_mask[:, :-1].reshape(-1).to(torch.bool)# build the padding mask: 1 for valid and 0 for maskattention_mask = (dids_mask.reshape(1, -1).to(torch.bool)).to(torch.long)padding_mask = torch.ones_like(logits).to(torch.long)# Santiy check overpadding_mask[:, self.vocab_size:] = attention_mask# build the position mask: 1 for valid and 0 for maskpos_mask = batch['pos_mask']start_labels, end_labels = batch['start_labels'][:, 1:].reshape(-1), batch['end_labels'][:, 1:].reshape(-1)position_mask = torch.ones_like(logits).to(torch.long)query_pos = start_labels > self.vocab_size# ignore the padding maskposition_mask[query_pos, self.vocab_size:] = pos_maskassert padding_mask.shape == position_mask.shape# overall maskoverall_mask = padding_mask * position_mask## remove the position mask# overall_mask = padding_masknew_logits = torch.where(overall_mask.to(torch.bool), logits, torch.tensor(-1e4).to(torch.half).cuda())mask = torch.zeros_like(new_logits)mask[range(len(new_logits)), start_labels] = 1.loss_ = F.log_softmax(new_logits[query_padding_mask], dim=-1) * mask[query_padding_mask]loss_1 = (-loss_.sum(dim=-1)).mean()## split the token accuaracy and phrase accuracyphrase_indexes = start_labels > self.vocab_sizephrase_indexes_ = phrase_indexes & query_padding_maskphrase_start_acc = new_logits[phrase_indexes_].max(dim=-1)[1] == start_labels[phrase_indexes_]phrase_start_acc = phrase_start_acc.to(torch.float).mean().item()phrase_indexes_ = ~phrase_indexes & query_padding_masktoken_start_acc = new_logits[phrase_indexes_].max(dim=-1)[1] == start_labels[phrase_indexes_]token_start_acc = token_start_acc.to(torch.float).mean().item()# training the representations of the end tokenscandidate_reps = torch.cat([self.token_embeddings[:, self.model.config.hidden_size // 2:],e_rep], dim=0)logits = torch.matmul(query_end, candidate_reps.t())  # [Q, B*]  logits /= self.args['temp']new_logits = torch.where(overall_mask.to(torch.bool), logits, torch.tensor(-1e4).to(torch.half).cuda())mask = torch.zeros_like(new_logits)mask[range(len(new_logits)), end_labels] = 1.loss_ = F.log_softmax(new_logits[query_padding_mask], dim=-1) * mask[query_padding_mask]loss_2 = (-loss_.sum(dim=-1)).mean()# split the phrase and token accuracyphrase_indexes = end_labels > self.vocab_sizephrase_indexes_ = phrase_indexes & query_padding_maskphrase_end_acc = new_logits[phrase_indexes_].max(dim=-1)[1] == end_labels[phrase_indexes_]phrase_end_acc = phrase_end_acc.to(torch.float).mean().item()phrase_indexes_ = ~phrase_indexes & query_padding_masktoken_end_acc = new_logits[phrase_indexes_].max(dim=-1)[1] == end_labels[phrase_indexes_]token_end_acc = token_end_acc.to(torch.float).mean().item()return (loss_0,  # token lossloss_1,  # token-head lossloss_2,  # token-tail lossacc_0,  # token accuracyphrase_start_acc,phrase_end_acc,token_start_acc,token_end_acc)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/83678.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Java异常

异常体系结构Java异常的分类异常处理异常抛出异常声明异常捕获异常处理流程自定义异常 异常体系结构 .Error&#xff1a;指的是Java虚拟机无法解决的严重问题&#xff0c;比如&#xff1a;JVM的内部错误、资源耗尽等&#xff0c;典型代表&#xff1a;StackOverflowError&#…

使用Burp Suite进行Web应用渗透测试

使用Burp Suite进行Web应用渗透测试是一种常见的方法&#xff0c;可以帮助发现Web应用程序中的安全漏洞和弱点。 步骤&#xff1a; 准备工作&#xff1a; 首先&#xff0c;确保已经安装了Burp Suite&#xff0c;并配置浏览器以使用Burp Suite作为代理。 配置代理&#xff1a;…

cortex-A7 SPI实验 --- STM32MP157

实验目的&#xff1a; 1、数码管显示相同的值 0000 1111 ......9999 2、数码管显示不同的值 1234 一&#xff0c;SPI概念 1&#xff0c;SPI总线是 全双工三线 / 四线同步串行总线&#xff0c;有两根单向数据线(MOSI &#xff0c;MISO)&#xff0c;一根设备片选线&#xff0…

凉而不冷 柔而不弱 三菱重工海尔舒适风科技助您整夜安眠

古人云&#xff1a;安寝乃人生乐事。可随着夏天的到来&#xff0c;昼长夜短&#xff0c;家里的老人、儿童、父母都存在不同的入睡苦恼。对于儿童来说&#xff0c;空调温度调的太低容易踢被子着凉&#xff0c;温度调的高又怕孩子满头大汗&#xff1b;父母自身也会因为半夜帮孩子…

C++学习记录——이십유 C++11(2)

文章目录 1、类的新功能1、移动构造和移动赋值2、default、delete 2、可变参数模板3、STL容器的emplace 1、类的新功能 1、移动构造和移动赋值 逐成员按字节拷贝就是浅拷贝。一个类中&#xff0c;如果达成默认移动构造的要求&#xff0c;那么传右值就会使用移动构造了&#xf…

电脑显示“Operating System not found”该怎么办?

“Operating System not found”是一种常见的电脑错误提示&#xff0c;这类错误会导致你无法成功启动Windows。那么电脑显示“Operating System not found”该怎么办呢&#xff1f; 方法1. 检查硬盘 首先&#xff0c;您可以测试硬盘是否存在问题。为此&#xff0c;您可以采取以…

SpringCloud/SpringBoot多模块项目中配置公共AOP模块实现打印子模块Controller所有请求参数与日志

项目中遇到多个模块需要打印Controller请求日志&#xff0c;在每个模块里面加AOP并且配置单独的切面笔者认为代码冗余&#xff0c;于是乎就打算把AOP日志打印抽离成一个公共模块&#xff0c;谁想用就引入Maven坐标就行。 定义公共AOP模块 并编写AOP工具 AOP模块pom.xml如下 &…

如何保障Facebook账号登录稳定

当谈到保障Facebook账号的稳定性时&#xff0c;我们不得不提到那些令人头疼的情况——Facebook账号被封。尽管我们已经踏入数字化的未来&#xff0c;但是被封号似乎是一个时常困扰着社交媒体用户的问题。那么&#xff0c;让我们来看看一些常见的Facebook账号被封的原因&#xf…

Eureka:服务注册-信息配置-自我保护机制

首先在提供者服务下&#xff0c;添加一个依赖 <!-- Eureka --><dependency><groupId>org.springframework.cloud</groupId><artifactId>spring-cloud-starter-eureka</artifactId><version>1.4.6.RELEASE</version><…

ES6中promise的使用

ES6中promise的使用 本文目录 ES6中promise的使用基础介绍箭头函数function函数状态 原型方法Promise.prototype.then()Promise.prototype.catch() 静态方法Promise.all()Promise.race()Promise.any() 链式回调 基础介绍 官网&#xff1a;https://promisesaplus.com/ window.…

Android进阶之路 - EditText输入字体自适应

遇到这么一个需求&#xff1a;“控件宽度有限&#xff0c;随着输入内容&#xff0c;动态修改字体大小”&#xff0c;如果是你&#xff0c;只如何来实现&#xff1f;又有几种方式&#xff1f; 嗯&#xff0c;就是这么一个简单的需求&#xff0c;让我记录了俩篇blog Android进阶…

象小朋友学识字一样建构战略

战略认知派: 就像小孩子学识字的过程一样【安志强趣讲268期】 趣讲大白话&#xff1a;提高认知中长大 **************************** 基于认知心理学 战略的形成是认知构建的过程 最找连战略这个词都不一定知道 慢慢有些概念 慢慢形成整体的认知 战略认知派分两个分支 第一分支…