BERT 微调实战

news/2024/11/8 15:04:02/文章来源:https://www.cnblogs.com/bingohuang/p/18535089

带着问题来学习

  1. BERT 的预训练过程是如何完成的,在预训练过程中,采用了哪两种任务?

  2. 本次实战是用 SQuAD 数据集微调 BERT, 来完成我们的问答任务,你能否用 IMDB 影评数据集来微调 BERT,改进 BERT 的结果准确率?

文章最后会公布问题的参考答案~

一、BERT 简介

BERT 全称 Bidirectional Encoder Representations from Transformers,是 Google 在2018 年提出来的,核心架构是多层 Transformer 编码器,引入了 Masked Language Model(MLM)和 Next Sentence Prediction(NSP)两个任务来训练模型。对于每个任务,可以通过在预训练模型的顶部添加一些额外的层来微调模型。

微调 BERT 需要用到 HuggingFace 组件,建议先学习参考:HuggingFace 核心组件及应用实战

二、BERT 实战:原生 BERT 完成问答任务

我们用 Google 原生发布的 BERT 去做问答任务,看看它效果如何。完成问答任务步骤如下:

  1. 准备问题和问答类型任务

  2. 下载含有问题任务头的原始版 BERT

  3. 直接用 BERT 做推理,回答问题

代码(BERT_for_QA.py)如下:

from transformers import BertTokenizer, BertForQuestionAnswering
import torch
import numpy as np# Set the random seed for PyTorch and NumPy
torch.manual_seed(0)
np.random.seed(0)tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertForQuestionAnswering.from_pretrained('bert-base-uncased')question, text = "What is the capital of China?", "The capital of China is Beijing."inputs = tokenizer.encode_plus(question, text, add_special_tokens=True, return_tensors="pt")
with torch.no_grad():outputs = model(**inputs)answer_start_index = torch.argmax(outputs.start_logits)
answer_end_index = torch.argmax(outputs.end_logits) + 1predict_answer_tokens = inputs['input_ids'][0][answer_start_index:answer_end_index]
predicted_answer = tokenizer.decode(predict_answer_tokens)print("What is the capital of China?", predicted_answer)

运行结果如下:

C:\Users\Lenovo\anaconda3\envs\pytorch211\python.exe "BERT_for_QA.py" 
Some weights of BertForQuestionAnswering were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['qa_outputs.bias', 'qa_outputs.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
What is the capital of China? the

其中,我们的问题是:What is China's capital?

并提供的上下文: The capital of China is Beijing.

BERT 回答是: the

😓可以看到,原生 BERT 的回答有点无厘头,因为它没有通过问答任务进行训练,效果很差。接下来我们就来微调BERT看看。

三、BERT 实战:微调 BERT 完成问答任务

用 SQuAD 数据集去微调BERT,之后再尝试让它完成问答任务,再看看效果如何,这里首先介绍下 SQuAD数据集。

1、SQuAD数据集

SQuAD(Stanford Question Answering Dataset)数据集是 Stanford 大学发布的用于问答任务的标准数据集,它从维基百科中抽取出来了很多问题和答案。

SQuAD 数据集中每个问题(Question)接一个上下文(Context),答案(Answering)必须包含在上下文中,它本质上是一个抽取类型的任务,从一个大段的文本中,抽取几个相邻的文字,代表这个问题的答案。

2、BERT 微调流程

BERT 预训练+微调的核心流程图如下:

对 BERT 进行微调,一般有两种方式:

  1. 只微调分类输出头,保持预训练 BERT 大量的参数都不变,只聚焦于我们自己的分类输出头。相当于是 BERT 已经有很多人类的自然语言处理知识了,我们只是告诉它要干什么就够了。

  2. 在微调的过程中把 BERT 的参数整体也进行微调。

通常用第1种方式就够了,除非你的任务比较复杂和特殊,才回采用第2中方式。

梳理下微调 BERT 步骤,方便更好的编写代码:

  1. 准备问题和问答类型任务

  2. 下载含有问答任务头的原始版 BERT

  3. 转换 SQuAD 特征数据集

  4. 用 SQuAD 微调 BERT

  5. 用微调后的 BERT 做推理,回答问题

为了提高训练的速度,其中第3步可以提前完成。

2、转换为 BERT 输入特征

这里是将 SQuAD 2.0 数据集下载下来,并通过程序先转换为BERT输入特征,代码(squad_feature_creation.py)如下:

import pickle
from transformers.data.processors.squad import SquadV2Processor, squad_convert_examples_to_features
from transformers import BertTokenizer# 初始化SQuAD Processor, 数据集, 和分词器
processor = SquadV2Processor()
train_examples = processor.get_train_examples('SQuAD')
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')if __name__ == '__main__':# 将SQuAD 2.0示例转换为BERT输入特征train_features = squad_convert_examples_to_features(examples=train_examples,tokenizer=tokenizer,max_seq_length=384,doc_stride=128,max_query_length=64,is_training=True,return_dataset=False,threads=1)# 将特征保存到磁盘上with open('SQuAD_train_features.pkl', 'wb') as f:pickle.dump(train_features, f)

本地运行,会生成一份SQuAD_train_features.pkl文件,运行结果如下:

C:\Users\Lenovo\anaconda3\envs\pytorch211\python.exe "squad_feature_creation.py" 
100%|██████████| 442/442 [00:18<00:00, 24.29it/s]
100%|██████████| 442/442 [00:18<00:00, 23.87it/s]
convert squad examples to features: 100%|██████████| 130319/130319 [08:20<00:00, 260.32it/s]
add example index and unique id: 100%|██████████| 130319/130319 [00:00<00:00, 1472740.97it/s]

3、BERT 微调并回答问题

接下来就是核心的 BERT 微调环节,代码(BERT_SQuAD_Finetuned.py)如下(完整代码见附件):

from transformers import BertForQuestionAnswering, BertTokenizer, BertForQuestionAnswering, AdamW
import torch
from torch.utils.data import TensorDataset# 是否有GPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')# 下载未经微调的BERT
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertForQuestionAnswering.from_pretrained('bert-base-uncased').to(device)# 评估未经微调的BERT的性能
def china_capital():question, text = "What is the capital of China?", "The capital of China is Beijing."inputs = tokenizer.encode_plus(question, text, add_special_tokens=True, return_tensors="pt")with torch.no_grad():outputs = model(**inputs.to(device))answer_start_index = torch.argmax(outputs.start_logits)answer_end_index = torch.argmax(outputs.end_logits) + 1predict_answer_tokens = inputs['input_ids'][0][answer_start_index:answer_end_index]predicted_answer = tokenizer.decode(predict_answer_tokens)print("What is the capital of China?", predicted_answer)
china_capital()    from transformers import BertTokenizer, BertForQuestionAnswering, AdamW
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from transformers.data.processors.squad import SquadV2Processor, SquadExample, squad_convert_examples_to_features# 加载SQuAD 2.0数据集的特征
import pickle
with open('SQuAD_train_features.pkl', 'rb') as f:train_features = pickle.load(f)# 定义训练参数
train_batch_size = 8
num_epochs = 3
learning_rate = 3e-5# 将特征转换为PyTorch张量
all_input_ids = torch.tensor([f.input_ids for f in train_features], dtype=torch.long)
all_attention_mask = torch.tensor([f.attention_mask for f in train_features], dtype=torch.long)
all_token_type_ids = torch.tensor([f.token_type_ids for f in train_features], dtype=torch.long)
all_start_positions = torch.tensor([f.start_position for f in train_features], dtype=torch.long)
all_end_positions = torch.tensor([f.end_position for f in train_features], dtype=torch.long)train_dataset = TensorDataset(all_input_ids, all_attention_mask, all_token_type_ids, all_start_positions, all_end_positions)
num_samples = 100
train_dataset = TensorDataset(all_input_ids[:num_samples], all_attention_mask[:num_samples], all_token_type_ids[:num_samples], all_start_positions[:num_samples], all_end_positions[:num_samples])
train_sampler = RandomSampler(train_dataset)
train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=train_batch_size)# 加载BERT模型和优化器
model = BertForQuestionAnswering.from_pretrained('bert-base-uncased').to(device)
optimizer = AdamW(model.parameters(), lr=5e-5)# 微调BERT
for epoch in range(num_epochs):for step, batch in enumerate(train_dataloader):model.train()optimizer.zero_grad()input_ids, attention_mask, token_type_ids, start_positions, end_positions = tuple(t.to(device) for t in batch)outputs = model(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids, start_positions=start_positions, end_positions=end_positions)loss = outputs.lossloss.backward()optimizer.step()# Print the training loss every 500 stepsif step % 5 == 0:print(f"Epoch [{epoch+1}/{num_epochs}], Step [{step+1}/{len(train_dataloader)}], Loss: {loss.item():.4f}")china_capital() # 保存微调后的模型
model.save_pretrained("SQuAD_finetuned_bert")

运行结果如下:

C:\Users\Lenovo\anaconda3\envs\pytorch211\python.exe "BERT_SQuAD_Finetuned.py" 
Some weights of BertForQuestionAnswering were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['qa_outputs.bias', 'qa_outputs.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
What is the capital of China? Some weights of BertForQuestionAnswering were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['qa_outputs.bias', 'qa_outputs.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
C:\Users\Lenovo\anaconda3\envs\pytorch211\Lib\site-packages\transformers\optimization.py:429: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warningwarnings.warn(Epoch [1/3], Step [1/13], Loss: 5.9771
Epoch [1/3], Step [6/13], Loss: 5.2767
Epoch [1/3], Step [11/13], Loss: 4.7677
Epoch [2/3], Step [1/13], Loss: 3.9339
Epoch [2/3], Step [6/13], Loss: 4.0097
Epoch [2/3], Step [11/13], Loss: 4.0402
Epoch [3/3], Step [1/13], Loss: 2.9276
Epoch [3/3], Step [6/13], Loss: 2.9915
Epoch [3/3], Step [11/13], Loss: 2.4421
What is the capital of China? beijing

其中,我们的问题是:What is China's capital?

提供的上下文是: The capital of China is Beijing.

微调 BERT 前回答:空

微调 BERT 后回答: beijing

👍这次回答对了!

微调前让 BERT 回答 问题,它给的答案是空,再经过8批次3轮的训练之后,它就能准确的回答出是 beijing。真赞!

四、总结

我们首先学习了 BERT 的原理和架构,它是通过 MLM 和 NSP 两种预训练模式增加了其对语言的全局理解能力,接着学习了 BERT 的预训练和微调过程。接着,我们进行了两个实战,实战一使用原始 BERT 完成问答任务,实战二通过 SQuAD 数据集微调后的 BERT 来完成问答任务。从结果可以看到,经过具体任务微调之后的 BERT 能够给出更好的答案。

建议没玩过 BERT 的同学跑一遍代码,整体感受下模型的推理和微调的过程。

五、参考及附件

开头问题参考:

  1. BERT 的预训练过程是如何完成的,在预训练过程中,采用了哪两种任务?

    • 参考 BERT 预训练流程图,采用了 MLM 和 NSP 两种任务。
  2. 本次实战是用 SQuAD 数据集微调 BERT, 来完成我们的问答任务,你能否用 IMDB 影评数据集来微调 BERT,改进 BERT 的结果准确率?

    • 这个任务就由同学们自行完成,可先学习 HuggingFace 核心组件及应用实战,了解如何通过 HuggingFace 和 IMDB 数据集来完成影评情绪判断任务。

内容参考:黄佳老师的《ChatGPT和预训练模型课》

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/828825.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

视频智能分析网关视频分析网关区域人数统计检测算法探析

随着城市化进程的加快和公共安全管理需求的提升,对公共场所、工业区域等人流量密集场所的监控和管理变得尤为重要。传统的视频监控系统已经无法满足现代智能化管理的需求,市场迫切需要一种能够实现实时监控、智能分析和自动报警的高效解决方案。基于此,区域人数统计视频分析…

SDN实验报告

SDN上机实验 实验目的能够使用Mininet的实现网络拓扑构建;熟悉Open vSwitch交换机的基本配置;熟悉OpenFlow协议的通信原理掌握pox控制器的基本使用方法;掌握Ryu控制期的基本使用方法;掌握北向应用的基本开发方法实验环境 基础环境选择ubuntu-20.04.6-desktop-amd64 实验内容…

双11买ToDesk远程控制&云电脑,看这一篇就够了!

今年双十一各大商家实在是太卷了,预售定金满减凑单一堆花活。但小编发现ToDesk远程控制&云电脑的双十一活动不一般。 囊括了远程控制各种会员版本的年包优惠,云电脑的计时机包时机活动,充值还送钱,优惠力度大,而且直接减钱,不费脑子就拿下超值价格。小编给大家简单整理了…

Java 面试用什么项目?全是商场秒杀 RPC,我吐了

看了几百份简历,真的超过 90% 的小伙伴的项目是商城、RPC、秒杀、论坛、外卖、点评等等烂大街的项目,人人都知道这些项目烂大街了,但大部分同学还是得硬着头皮做,没办法,网络上能找到的、教程比较完善的就这些项目了,做的话好歹有个项目,不做那就真能写学校做的垃圾学生…

极狐GitLab 签约某清洁能源高科技企业,助力零碳技术开创更加美好的零碳世界

客户背景 该客户是一家全球领先的清洁能源高科技公司,总部位于江苏省。公司自成立之初就致力于为全球客户提供清洁、高效、安全的能源解决方案,希望能用高科技技术让新能源发挥更大价值,让世界变得更加美好。当前,该客户在多个能源领域都有领先的产品和成熟的解决方案,也一…

DAC8568IAPWR 数据手册 具有 2.5V、2ppm/C 内部基准电压的 DAC7568、DAC8168、DAC8568 12/14/16 位、8 通道、超低毛刺、电压输出数模转换器芯片

DAC7568、DAC8168 和 DAC8568 分别为 12 位、14位和 16 位低功耗、电压输出、八通道数模转换器(DAC)。这些器件包括一个 2.5V、2ppm/C 内部基准电压(默认禁用),可提供 2.5V 或 5V 的满量程输出电压范围。内部基准电压初始精度为 0.004%,而且可在 VREFIN/VREFOUT 引脚上提供…

ue4资产序列化从入门到精通: 第一章 初识序列化

一、写作目的:(全文字数4926,阅读大约需25min) 首先,我有一个相关的需求要做,然后在拜读了网络上各大UE4序列化解析的文章后,发现大都讲的很模糊,对新入序列化大门的小白非常不友好。有的直接贴上一大段代码(好似直接糊脸上的不解释连招),也有的讲着讲着嘎然而止,也有的…

Hadoop及Spark环境配置与运行实例

本文章为Hadoop与Spark环境配置及Hadoop环境下使用mapreduce进行wordcount、Spark环境下使用KMeans进行鸢尾花数据集聚类实例运行实验记录。一、参考资料重要说明本文章为大数据分析课程实验之Hadoop与Spark平台配置记录及示例演示,其中Hadoop配置部分绝大多数内容源自参考资料…

salesforce零基础学习(一百四十一)刷新dev sandbox需要强制group

本篇参考:https://help.salesforce.com/s/articleView?id=sf.data_sandbox_selective_access.htm&type=5 背景:最近同事刷新sandbox发现点击create不生效,并且无任何提示(后续可能优化)。习惯了直接创建或者刷新的老司机们可能看不出来Sandbox Access标红提示来着,恰…

TPS26600PWPR 数据手册 一款集成反向输入极性保护的 工业电子保险丝芯片 浪涌保护器

TPS2660x 器件是一系列功能丰富的紧凑型高电压电子保险丝,具有一整套保护 功能)。4.2V 至 60V 的宽电源输入范围可实现对众多常用直流总线电压的控制。器件可以承受并保护由高达 60V 的正负电源供电的负载。集成的背靠背 FET 提供反向电流阻断功能,因此器件非常适合在电源故…

因为采购同行,造成的一次Java heap space 堆内存溢出

Caused by: java.sql.SQLException: Java heap space不多说了,没见过这样的。 报错原因是JVM内存XMX超了 Xms512m -Xmx2048m下班记得打卡