last_hidden_state vs pooler_output的区别

一、问题来源:

from transformers import AutoTokenizer, AutoModel
import torch
# Load model from HuggingFace Hub
MODEL_NAME_PATH = 'xxxx/model/bge-large-zh'
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME_PATH)
model = AutoModel.from_pretrained(MODEL_NAME_PATH)

模型结构如下:

BertModel((embeddings): BertEmbeddings((word_embeddings): Embedding(21128, 1024, padding_idx=0)(position_embeddings): Embedding(512, 1024)(token_type_embeddings): Embedding(2, 1024)(LayerNorm): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)(dropout): Dropout(p=0.1, inplace=False))(encoder): BertEncoder((layer): ModuleList((0-23): 24 x BertLayer((attention): BertAttention((self): BertSelfAttention((query): Linear(in_features=1024, out_features=1024, bias=True)(key): Linear(in_features=1024, out_features=1024, bias=True)(value): Linear(in_features=1024, out_features=1024, bias=True)(dropout): Dropout(p=0.1, inplace=False))(output): BertSelfOutput((dense): Linear(in_features=1024, out_features=1024, bias=True)(LayerNorm): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)(dropout): Dropout(p=0.1, inplace=False)))(intermediate): BertIntermediate((dense): Linear(in_features=1024, out_features=4096, bias=True)(intermediate_act_fn): GELUActivation())(output): BertOutput((dense): Linear(in_features=4096, out_features=1024, bias=True)(LayerNorm): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)(dropout): Dropout(p=0.1, inplace=False)))))(pooler): BertPooler((dense): Linear(in_features=1024, out_features=1024, bias=True)(activation): Tanh())
)

Q1、cls的值和pooler的值是一样的吗?
Q2、最后的pooler层和hidden层是什么关系?

二、实验证明:

Q1、cls的值和pooler的值是一样的吗?

# Sentences we want sentence embeddings for
sentences = ["开心", "快乐", "难过", "天气", "今天会有大大的台风吗?"]
# Tokenize sentences
encoded_input = tokenizer(sentences, padding=True, truncation=True, return_tensors='pt', max_length=200)
# for retrieval task, add an instruction to query
# encoded_input = tokenizer([instruction + q for q in queries], padding=True, truncation=True, return_tensors='pt')
# Compute token embeddings
with torch.no_grad():model_output = model(**encoded_input)# Perform pooling. In this case, cls pooling.sentence_embeddings = model_output[0][:, 0]
# normalize embeddings
sentence_embeddings = torch.nn.functional.normalize(sentence_embeddings, p=2, dim=1)

print(‘cls:’, model_output[0][:, 0, :])

cls: tensor([[ 0.3269, -0.6412, -0.2382,  ...,  0.0255, -0.1801, -0.3025],[ 0.1351, -0.5155, -0.1700,  ...,  0.1093, -0.3750, -0.1323],[ 0.2752, -0.1703, -0.2730,  ...,  0.0376, -0.0339, -0.3541],[ 0.1346, -0.0378, -0.5070,  ...,  0.0078,  0.0472, -0.1815],[-0.4051,  0.1123, -0.3873,  ...,  0.3585,  0.4913,  0.3192]])

print(‘pooler:’, model_output[1])

pooler: tensor([[ 0.3888, -0.2329, -0.1749,  ...,  0.1678,  0.3938, -0.3191],[ 0.3949, -0.2882, -0.0945,  ...,  0.1802,  0.2705, -0.1891],[ 0.4765, -0.1235, -0.2330,  ...,  0.3005,  0.3487, -0.1290],[ 0.3851, -0.1853, -0.3189,  ...,  0.2757,  0.3601, -0.3220],[ 0.3008, -0.3742, -0.4550,  ...,  0.4318,  0.2130, -0.1575]])

cls的值和pooler的值不一样

Q2、最后的pooler层和hidden层是什么关系?

理论层面:

transformers.models.bert.modeling_bert.BertModel.forward方法中这么一行代码:

sequence_output = encoder_outputs[0]
pooled_output = self.pooler(sequence_output) if self.pooler is not None else None

pooler的定义:

self.pooler = BertPooler(config) if add_pooling_layer else None

BertPooler的定义:

class BertPooler(nn.Module):def __init__(self, config):super().__init__()self.dense = nn.Linear(config.hidden_size, config.hidden_size)self.activation = nn.Tanh()def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:# We "pool" the model by simply taking the hidden state corresponding# to the first token.first_token_tensor = hidden_states[:, 0]pooled_output = self.dense(first_token_tensor)pooled_output = self.activation(pooled_output)return pooled_output

从上面的源码可以看出,pooler_output 就是[CLS]embedding又经历了一次全连接层的输出

数据层面:
model.pooler(model_output[0])
tensor([[ 0.3888, -0.2329, -0.1749,  ...,  0.1678,  0.3938, -0.3191],[ 0.3949, -0.2882, -0.0945,  ...,  0.1802,  0.2705, -0.1891],[ 0.4765, -0.1235, -0.2330,  ...,  0.3005,  0.3487, -0.1290],[ 0.3851, -0.1853, -0.3189,  ...,  0.2757,  0.3601, -0.3220],[ 0.3008, -0.3742, -0.4550,  ...,  0.4318,  0.2130, -0.1575]],grad_fn=<TanhBackward0>)

在这里插入图片描述
pooler_output 就是[CLS]embedding又经历了一次全连接层的输出

三、结论:

pooler就是将[CLS]这个token再过一下全连接层+Tanh激活函数,作为该句子的特征向量

四、Bert的Pooler_output的由来

我们知道,BERT的训练包含两个任务:MLM和NSP任务(Next Sentence Prediction)。 对这两个任务不熟悉的朋友可以参考:BERT源码实现与解读(Pytorch) 和 【论文阅读】BERT 两篇文章。

其中MLM就是挖空,然后让bert预测这个空是什么。做该任务是使用token embedding进行预测。

而Next Sentence Prediction就是预测bert接受的两句话是否为一对。例如:窗前明月光,疑是地上霜 为 True,窗前明月光,李白打开窗为False。

所以,NSP任务需要句子的语义信息来预测,但是我们看下源码是怎么做的。

class BertForNextSentencePrediction(BertPreTrainedModel):def __init__(self, config):super().__init__(config)self.bert = BertModel(config)self.cls = BertOnlyNSPHead(config)	# 这个就是一个 nn.Linear(config.hidden_size, 2)...def forward(...):...outputs = self.bert(...)pooled_output = outputs[1] # 取pooler_outputseq_relationship_scores = self.cls(pooled_output)	# 使用pooler_ouput送给后续的全连接层进行预测...从上面的源码可以看出,在NSP任务训练时,并不是直接使用[CLS]token的embedding作为句子特征传给后续分类头的,而是使用的是pooler_output。个人原因可能是因为直接使用[CLS]的embedding效果不够好。
但在MLM任务时,是直接使用的是last_hidden_state,有兴趣可以看一下

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/60564.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

DROP USER c##xyt CASCADE > ORA-01940: 无法删除当前连接的用户

多创建了一个用户&#xff0c;想要给它删除掉 一 上执行过程&#xff0c;确实删除成功了 Oracle Database 12c Enterprise Edition Release 12.1.0.2.0 - 64bit Production With the Partitioning, OLAP, Advanced Analytics and Real Application Testing optionsSQL> DR…

在Go语言单元测试中如何解决Redis存储依赖问题

登录程序示例 在 Web 开发中&#xff0c;登录需求是一个较为常见的功能。假设我们有一个 Login 函数&#xff0c;可以实现用户登录功能。它接收用户手机号 短信验证码&#xff0c;然后根据手机号从 Redis 中获取保存的验证码&#xff08;验证码通常是在发送验证码这一操作时保…

[C#] 简单的俄罗斯方块实现

一个控制台俄罗斯方块游戏的简单实现. 已在 github.com/SlimeNull/Tetris 开源. 思路 很简单, 一个二维数组存储当前游戏的方块地图, 用 bool 即可, true 表示当前块被填充, false 表示没有. 然后, 抽一个 “形状” 类, 形状表示当前玩家正在操作的一个形状, 例如方块, 直线…

Pycharm如何打断点进行调试?

断点调试&#xff0c;是编写程序中一个很重要的步骤&#xff0c;有些简单的程序使用print语句就可看出问题&#xff0c;而比较复杂的程序&#xff0c;函数和变量较多的情况下&#xff0c;这时候就需要打断点了&#xff0c;更容易定位问题。 一、添加断点 在代码的行标前面&…

吐血整理,Jenkins配置邮件发送测试报告持续集成,看这一篇就够了...

目录&#xff1a;导读 前言一、Python编程入门到精通二、接口自动化项目实战三、Web自动化项目实战四、App自动化项目实战五、一线大厂简历六、测试开发DevOps体系七、常用自动化测试工具八、JMeter性能测试九、总结&#xff08;尾部小惊喜&#xff09; 前言 开启SMTP服务 这…

Linux下进程的特点与环境变量

目录 进程的特点 进程特点的介绍 进程时如何实现并发性的 进程间如何切换 概念铺设 PC指针 上下文 环境变量 PATH 修改PATH HOME SHELL env 命令行参数 什么是命令行参数&#xff1f; 打印命令行参数 通过函数获得环境变量 getenv 命令行参数 env 修改环境变…

C 语言的逻辑运算符

C 语言的逻辑运算符包括三种&#xff1a; 逻辑运算符可以将两个关系表达式连接起来. Suppose exp1 and exp2 are two simple relational expressions, such as cat > rat and debt 1000 . Then you can state the following: ■ exp1 && exp2 is true only if bo…

百度chatgpt内测版

搜索AI伙伴 申请到了百度的chatgpt&#xff1a; 完整的窗口布局&#xff1a; 三个哲学问题&#xff1a; 灵感中心&#xff1a; 请做一副画&#xff0c;一个渔夫&#xff0c;冬天&#xff0c;下着大雪&#xff0c;在船上为了一家的生计在钓鱼&#xff0c;远处的山上也都是白雪&a…

北京多铁克FPGA笔试题目

1、使用D触发器来实现二分频 2、序列检测器&#xff0c;检测101&#xff0c;输出1&#xff0c;其余情况输出0 module Detect_101(input clk,input rst_n,input data, //输入的序列output reg flag_101 //检测到101序列的输出标志 );parameter S0 2d0;S1 2d1;S2 2d2;S4 …

W6100-EVB-PICO作为TCP Client 进行数据回环测试(五)

前言 上一章我们用W6100-EVB-PICO开发板通过DNS解析www.baidu.com&#xff08;百度域名&#xff09;成功得到其IP地址&#xff0c;那么本章我们将用我们的开发板作为客户端去连接服务器&#xff0c;并做数据回环测试&#xff1a;收到服务器发送的数据&#xff0c;并回传给服务器…

android开发之Android 自定义滑动解锁View

自定义滑动解锁View 需求如下&#xff1a; 近期需要做一个类似屏幕滑动解锁的功能&#xff0c;右划开始&#xff0c;左划暂停。 需求效果图如下 实现效果展示 自定义view如下 /** Desc 自定义滑动解锁View Author ZY Mail sunnyfor98gmail.com Date 2021/5/17 11:52 *…

【LeetCode】打家劫舍||

打家劫舍|| 题目描述算法分析编程代码 链接: 打家劫舍|| 在做这个题之前&#xff0c;建议大家做一下这个链接: 按摩师 我的博客里也有这个题的讲解&#xff0c;名字是按摩师 题目描述 算法分析 编程代码 class Solution { public:int maxrob(vector<int>nums,int left,…