用pytorch进行BERT文本分类

BERT 是一个强大的语言模型,至少有两个原因:

  1. 它使用从 BooksCorpus (有 8 亿字)和 Wikipedia(有 25 亿字)中提取的未标记数据进行预训练。
  2. 顾名思义,它是通过利用编码器堆栈的双向特性进行预训练的。这意味着 BERT 不仅从左到右,而且从右到左从单词序列中学习信息。

BERT 模型需要一系列 tokens (words) 作为输入。在每个token序列中,BERT 期望输入有两个特殊标记:

  • [CLS] :这是每个sequence的第一个token,代表分类token。
  • [SEP] :这是让BERT知道哪个token属于哪个序列的token。这一特殊表征法主要用于下一个句子预测任务或问答任务。如果我们只有一个sequence,那么这个token将被附加到序列的末尾。

就像Transformer的普通编码器一样,BERT 将一系列单词作为输入,这些单词不断向上流动。每一层都应用自我注意,并将其结果通过前馈网络传递,然后将其传递给下一个编码器。

 

BERT 输出

每个位置输出一个大小为 hidden_ size的向量(BERT Base 中为 768)。对于我们在上面看到的句子分类示例,我们只关注第一个位置的输出(将特殊的 [CLS] token 传递到该位置)。

该向量现在可以用作我们选择的分类器的输入。该论文仅使用单层神经网络作为分类器就取得了很好的效果。

使用 BERT 进行文本分类

本文的主题是用 BERT 对文本进行分类。

在这篇文章中,我们将使用kaggle上的BBC 新闻分类数据集。

数据集已经是 CSV 格式,它有 2126 个不同的文本,每个文本都标记在 5 个类别中的一个之下:

sport(体育),business(商业),politics(政治),tech(科技),entertainment(娱乐)。


模型下载 https://huggingface.co/bert-base-cased/tree/main

数据集下载 bbc-news https://huggingface.co/datasets/SetFit/bbc-news/tree/main

 有4个400多MB的文件,pytorch的模型对应的是436MB的那个文件。

需要安装transforms库

pip install transforms 

全部的流程代码:

# 全部流程代码
import numpy as np
import torch
from transformers import BertTokenizer
tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
labels = {'business':0,'entertainment':1,'sport':2,'tech':3,'politics':4}class Dataset(torch.utils.data.Dataset):def __init__(self, df):self.labels = [labels[label] for label in df['category']]self.texts = [tokenizer(text, padding='max_length', max_length = 512, truncation=True,return_tensors="pt") for text in df['text']]def classes(self):return self.labelsdef __len__(self):return len(self.labels)def get_batch_labels(self, idx):# Fetch a batch of labelsreturn np.array(self.labels[idx])def get_batch_texts(self, idx):# Fetch a batch of inputsreturn self.texts[idx]def __getitem__(self, idx):batch_texts = self.get_batch_texts(idx)batch_y = self.get_batch_labels(idx)return batch_texts, batch_y# 数据集准备
# 拆分训练集、验证集和测试集 8:1:1
import pandas as pd
bbc_text_df = pd.read_csv('./bbc-news/bbc-text.csv')
# bbc_text_df.head()
df = pd.DataFrame(bbc_text_df)
np.random.seed(112)
df_train, df_val, df_test = np.split(df.sample(frac=1, random_state=42), [int(.8*len(df)), int(.9*len(df))])
print(len(df_train),len(df_val), len(df_test))
# 1780 222 223# 构建模型
from torch import nn
from transformers import BertModelclass BertClassifier(nn.Module):def __init__(self, dropout=0.5):super(BertClassifier, self).__init__()self.bert = BertModel.from_pretrained('bert-base-cased')self.dropout = nn.Dropout(dropout)self.linear = nn.Linear(768, 5)self.relu = nn.ReLU()def forward(self, input_id, mask):_, pooled_output = self.bert(input_ids= input_id, attention_mask=mask,return_dict=False)dropout_output = self.dropout(pooled_output)linear_output = self.linear(dropout_output)final_layer = self.relu(linear_output)return final_layer#从上面的代码可以看出,BERT Classifier 模型输出了两个变量:
#1. 在上面的代码中命名的第一个变量_包含sequence中所有 token 的 Embedding 向量层。
#2. 命名的第二个变量pooled_output包含 [CLS] token 的 Embedding 向量。对于文本分类任务,使用这个 Embedding 作为分类器的输入就足够了。
# 然后将pooled_output变量传递到具有ReLU激活函数的线性层。在线性层中输出一个维度大小为 5 的向量,每个向量对应于标签类别(运动、商业、政治、 娱乐和科技)。from torch.optim import Adam
from tqdm import tqdmdef train(model, train_data, val_data, learning_rate, epochs):# 通过Dataset类获取训练和验证集train, val = Dataset(train_data), Dataset(val_data)# DataLoader根据batch_size获取数据,训练时选择打乱样本train_dataloader = torch.utils.data.DataLoader(train, batch_size=2, shuffle=True)val_dataloader = torch.utils.data.DataLoader(val, batch_size=2)# 判断是否使用GPUuse_cuda = torch.cuda.is_available()device = torch.device("cuda" if use_cuda else "cpu")# 定义损失函数和优化器criterion = nn.CrossEntropyLoss()optimizer = Adam(model.parameters(), lr=learning_rate)if use_cuda:model = model.cuda()criterion = criterion.cuda()# 开始进入训练循环for epoch_num in range(epochs):# 定义两个变量,用于存储训练集的准确率和损失total_acc_train = 0total_loss_train = 0# 进度条函数tqdmfor train_input, train_label in tqdm(train_dataloader):train_label = train_label.to(device)mask = train_input['attention_mask'].to(device)input_id = train_input['input_ids'].squeeze(1).to(device)# 通过模型得到输出output = model(input_id, mask)# 计算损失batch_loss = criterion(output, train_label)total_loss_train += batch_loss.item()# 计算精度acc = (output.argmax(dim=1) == train_label).sum().item()total_acc_train += acc# 模型更新model.zero_grad()batch_loss.backward()optimizer.step()# ------ 验证模型 -----------# 定义两个变量,用于存储验证集的准确率和损失total_acc_val = 0total_loss_val = 0# 不需要计算梯度with torch.no_grad():# 循环获取数据集,并用训练好的模型进行验证for val_input, val_label in val_dataloader:# 如果有GPU,则使用GPU,接下来的操作同训练val_label = val_label.to(device)mask = val_input['attention_mask'].to(device)input_id = val_input['input_ids'].squeeze(1).to(device)output = model(input_id, mask)batch_loss = criterion(output, val_label)total_loss_val += batch_loss.item()acc = (output.argmax(dim=1) == val_label).sum().item()total_acc_val += accprint(f'''Epochs: {epoch_num + 1} | Train Loss: {total_loss_train / len(train_data): .3f} | Train Accuracy: {total_acc_train / len(train_data): .3f} | Val Loss: {total_loss_val / len(val_data): .3f} | Val Accuracy: {total_acc_val / len(val_data): .3f}''') #我们对模型进行了 5 个 epoch 的训练,我们使用 Adam 作为优化器,而学习率设置为1e-6。
#因为本案例中是处理多类分类问题,则使用分类交叉熵作为我们的损失函数。
EPOCHS = 5
model = BertClassifier()
LR = 1e-6
train(model, df_train, df_val, LR, EPOCHS)# 测试模型
def evaluate(model, test_data):test = Dataset(test_data)test_dataloader = torch.utils.data.DataLoader(test, batch_size=2)use_cuda = torch.cuda.is_available()device = torch.device("cuda" if use_cuda else "cpu")if use_cuda:model = model.cuda()total_acc_test = 0with torch.no_grad():for test_input, test_label in test_dataloader:test_label = test_label.to(device)mask = test_input['attention_mask'].to(device)input_id = test_input['input_ids'].squeeze(1).to(device)output = model(input_id, mask)acc = (output.argmax(dim=1) == test_label).sum().item()total_acc_test += accprint(f'Test Accuracy: {total_acc_test / len(test_data): .3f}')# 用测试数据集进行测试    
evaluate(model, df_test)

 

EPOCHS = 5
model = BertClassifier()
LR = 1e-6
train(model, df_train, df_val, LR, EPOCHS)

在训练的这一步会非常耗时间,用GPU加速了,也需要大概39分钟.

因为BERT模型本身就是一个比较大的模型,参数非常多。

最后一步测试的时候,测试的准确率还是比较高的。达到 99.6%

 

模型的保存。这个在原文里面是没有提到的。

我们花了很多时间训练的模型如果不保存一下,下次还要重新训练岂不是费时费力?

# 保存模型
torch.save(model.state_dict(),"bertMy.pth")
# load 加载模型
model = BertClassifier()
model.load_state_dict(torch.load("bertMy.pth"))

 可以用下面的代码查看model里面的模型。

for param_tensor in model.state_dict():print(param_tensor, "\t", model.state_dict()[param_tensor].size())

 也可以将保存的模型文件 bertMy.pth上传到netron网站进行模型可视化。Netronhttps://netron.app/

 

原文位于这个地址,但是原文中的代码缺了读csv那一段 :

保姆级教程,用PyTorch和BERT进行文本分类 - 知乎

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/8821.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Vulkan Tutorial 10 重采样

目录 30 多重采样 获得可用的样本数 设置一个渲染目标 添加新的附件 30 多重采样 我们的程序现在可以为纹理加载多层次的细节,这修复了在渲染离观众较远的物体时出现的假象。现在的图像平滑了许多,然而仔细观察,你会发现在绘制的几何图形…

【C++】复杂的菱形继承 及 菱形虚拟继承的底层原理

文章目录 1. 单继承2. 多继承3. 菱形继承3.1 菱形继承的问题——数据冗余和二义性3.2 解决方法——虚拟继承3.3 虚拟继承的原理 4. 继承和组合5. 继承的反思和总结 1. 单继承 在上一篇文章中,我们给大家演示的其实都是单继承。 单继承的概念: 单继承&a…

Flutter如何获取屏幕的分辨率和实际画布的分辨率

Flutter如何获取分辨率 在Flutter中,你可以使用MediaQuery来获取屏幕的分辨率和实际画布的分辨率。 要获取屏幕的分辨率,你可以使用MediaQuery.of(context).size属性,它返回一个Size对象,其中包含屏幕的宽度和高度。下面是一个获…

POSTGRESQL SQL 执行用 IN 还是 EXISTS 还是 ANY

开头还是介绍一下群,如果感兴趣polardb ,mongodb ,mysql ,postgresql ,redis 等有问题,有需求都可以加群群内有各大数据库行业大咖,CTO,可以解决你的问题。加群请联系 liuaustin3 ,在新加的朋友会分到3群(共…

SQL频率低但笔试会遇到: 触发器、索引、外键约束

一. 前言 在SQL面笔试中,对于表的连接方式,过滤条件,窗口函数等肯定是考察的重中之重,但是有一些偶尔会出现,频率比较低但是至少几乎会遇见一两次的题目,就比如触发器,索引和外键约束&#xff0…

Spring源码解析(二):bean容器的创建、默认后置处理器、扫描包路径bean

Spring源码系列文章 Spring源码解析(一):环境搭建 Spring源码解析(二): 目录 一、Spring源码基础组件1、bean定义接口体系2、bean工厂接口体系3、ApplicationContext上下文体系 二、AnnotationConfigApplicationContext注解容器1、创建bean工厂-beanFa…

【Nginx05】Nginx学习:HTTP核心模块(二)Server

Nginx学习:HTTP核心模块(二)Server 第一个重要的子模块就是这个 Server 相关的模块。Server 代表服务的意思,其实就是这个 Nginx 的 HTTP 服务端所能提供的服务。或者更直白点说,就是虚拟主机的配置。通过 Server &…

iview切换Select时选项丢失,重置Seletc时选项丢失

分析原因 在旧版本的iview中如果和filterable一起使用时,当值清空选项或者使用重置按钮清空时选项会丢失。 解决方式一 把去掉filterable 解决方式二 使用ref,调用clearSingleSelect()方法清空 ref"perfSelect" this.$refs.perfSelect.c…

Java链式编程

一、链式编程 1.1.释义 链式编程,也叫级联式编程,调用对象的函数时返回一个this对象指向对象本身,达到链式效果,可以级联调用。 1.2.特点 可以通过一个方法调用多个方法,将多个方法调用链接起来,形成一…

UE4中创建的瞄准偏移或者混合空间无法拖入动画

UE4系列文章目录 文章目录 UE4系列文章目录前言一、解决办法 前言 UE4 AimOffset(瞄准偏移)动画融合时,AimOffse动画拖入不了融合框的解决办法,你会发现动画无法拖入到融合框,ue4编辑器提示“Invalid Additive animation Type”,…

C#核心知识回顾——10.List、Dictionary、数据结构

1.List List<int> list new List<int>(); List<String> strings new List<String>();//增list.Add(0);list.Add(1);List<int> ints new List<int>();ints.Add(0);list.AddRange(ints);//插入list.Insert(0, 1);// 位置0插入1//删//1.移…

使用GPIO来模拟UART

前言 最近在看一些秋招的笔试和面试题&#xff0c;刚好看到一个老哥的经验贴&#xff0c;他面试的时候被问到了如果芯片串口资源不够了该怎么办&#xff1f;其实可以用IO口来模拟串口&#xff0c;但我之前也没有具体用代码实现过&#xff0c;借此机会用32开发板上的两个IO口来…