bert 相似度任务训练,简单版本

目录

任务

代码

train.py

predit.py

数据


任务

使用 bert-base-chinese 训练相似度任务,参考:微调BERT模型实现相似性判断 - 知乎

参考他上面代码,他使用的是 BertForNextSentencePrediction 模型,BertForNextSentencePrediction 原本是设计用于下一个句子预测任务的。在BERT的原始训练中,模型会接收到一对句子,并试图预测第二个句子是否紧跟在第一个句子之后;所以使用这个模型标签(label)只能是 0,1,相当于二分类任务了

但其实在相似度任务中,我们每一条数据都是【text1\ttext2\tlabel】的形式,其中 label 代表相似度,可以给两个文本打分表示相似度,也可以映射为分类任务,0 代表不相似,1 代表相似,他这篇文章利用了这种思想,对新手还挺有用的。

现在我搞了一个招聘数据,里面有办公区域列,处理过了,每一行代表【地址1\t地址2\t相似度】

只要两文本中有一个地址相似我就作为相似,标签为 1,否则 0

利用这数据微调,没有使用验证数据集,就最后使用测试集来看看效果。

代码

train.py

import json
import torch
from transformers import BertTokenizer, BertForNextSentencePrediction
from torch.utils.data import DataLoader, Dataset# 能用gpu就用gpu
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device = torch.device("cpu")bacth_size = 32
epoch = 3
auto_save_batch = 5000
learning_rate = 2e-5# 准备数据集
class MyDataset(Dataset):def __init__(self, data_file_paths):self.texts = []self.labels = []# 分词器用默认的self.tokenizer = BertTokenizer.from_pretrained('../bert-base-chinese')# 自己实现对数据集的解析with open(data_file_paths, 'r', encoding='utf-8') as f:for line in f:text1, text2, label = line.split('\t')self.texts.append((text1, text2))self.labels.append(int(label))def __len__(self):return len(self.texts)def __getitem__(self, idx):text1, text2 = self.texts[idx]label = self.labels[idx]encoded_text = self.tokenizer(text1, text2, padding='max_length', truncation=True, max_length=128, return_tensors='pt')return encoded_text, label# 训练数据文件路径
train_dataset = MyDataset('../data/train.txt')# 定义模型
# num_labels=5 定义相似度评分有几个
model = BertForNextSentencePrediction.from_pretrained('../bert-base-chinese', num_labels=6)
model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)# 训练模型
train_loader = DataLoader(train_dataset, batch_size=bacth_size, shuffle=True)
trained_data = 0
batch_after_last_save = 0
total_batch = 0
total_epoch = 0for epoch in range(epoch):trained_data = 0for batch in train_loader:inputs, labels = batch# 不知道为啥,出来的数据维度是 (batch_size, 1, 128),需要把第二维去掉inputs['input_ids'] = inputs['input_ids'].squeeze(1)inputs['token_type_ids'] = inputs['token_type_ids'].squeeze(1)inputs['attention_mask'] = inputs['attention_mask'].squeeze(1)# 因为要用GPU,将数据传输到gpu上inputs = inputs.to(device)labels = labels.to(device)optimizer.zero_grad()outputs = model(**inputs, labels=labels)loss, logits = outputs[:2]loss.backward()optimizer.step()trained_data += len(labels)trained_process = float(trained_data) / len(train_dataset)batch_after_last_save += 1total_batch += 1# 每训练 auto_save_batch 个 batch,保存一次模型if batch_after_last_save >= auto_save_batch:batch_after_last_save = 0model.save_pretrained(f'../output/cn_equal_model_{total_epoch}_{total_batch}.pth')print("保存模型:cn_equal_model_{}_{}.pth".format(total_epoch, total_batch))print("训练进度:{:.2f}%, loss={:.4f}".format(trained_process * 100, loss.item()))total_epoch += 1model.save_pretrained(f'../output/cn_equal_model_{total_epoch}_{total_batch}.pth')print("保存模型:cn_equal_model_{}_{}.pth".format(total_epoch, total_batch))

训练好后的文件,输出的最后一个文件夹才是效果最好的模型:

predit.py

import torch
from transformers import BertTokenizer, BertForNextSentencePredictiontokenizer = BertTokenizer.from_pretrained('../bert-base-chinese')
model = BertForNextSentencePrediction.from_pretrained('../output/cn_equal_model_3_171.pth')with torch.no_grad():with open('../data/test.txt', 'r', encoding='utf8') as f:lines = f.readlines()correct = 0for i, line in enumerate(lines):text1, text2, label = line.split('\t')encoded_text = tokenizer(text1, text2, padding='max_length', truncation=True, max_length=128, return_tensors='pt')outputs = model(**encoded_text)res = torch.argmax(outputs.logits, dim=1).item()print(text1, text2, label, res)if str(res) == label.strip('\n'):correct += 1print(f'{i + 1}/{len(lines)}')print(f'acc:{correct / len(lines)}')

可以看到还是较好的学习了我数据特征:只要两文本中有一个地址相似我就作为相似,标签为 1,否则 0

数据

链接:https://pan.baidu.com/s/1Cpr-ZD9Neakt73naGdsVTw 
提取码:eryw 
链接:https://pan.baidu.com/s/1qHYjXC7UCeUsXVnYTQIPCg 
提取码:o8py 
 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/502383.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

AI也来打掼蛋,难道人工智能也能当领导?

引言:探索AI在复杂卡牌游戏中的决策能力 在人工智能(AI)的研究领域中,游戏被视为现实世界的简化模型,常常是研究的首选平台。这些研究主要关注游戏代理的决策过程。例如,中国的传统卡牌游戏“掼蛋”&#…

Windows Docker 部署 Jenkins

一、简介 今天介绍一下在 Windows Docker 中部署 Jenkins 软件。在 Windows Docker 中,分为两种情况 Linux 容器和 Windows 容器。Linux 容器是通常大多数使用的方式,Windows 容器用于 CI/CD 依赖 Windows 环境的情况。 二、Linux 容器 Linux 容器内部…

LeetCode -- 79.单词搜索

1. 问题描述 给定一个 m x n 二维字符网格 board 和一个字符串单词 word 。如果 word 存在于网格中,返回 true ;否则,返回 false 。 单词必须按照字母顺序,通过相邻的单元格内的字母构成,其中“相邻”单元格是那些水…

【前沿热点视觉算法|Sora|GPT4一键升级】一种新的图像分割方法:具有边界注意的两级解码网络

计算机视觉算法分享。问题或建议,请文章私信或者文章末尾扫码加微信留言。sora 具体介绍和使用方法:OpenAI Sora 下一代生产力:最新小白必看教程 | 解剖Sora的前世今生 | Sora核心源码目前 openai 官方还未开放 sora 灰度,不过根据…

黑马JavaWeb课程中安装vue脚手架出现的问题

1 安装node.js 要想前端工程化,必须安装node.js,前端工程化的环境。 在成功安装node.js后, 修改全局包安装路径为Node.js安装目录, 修改npm镜像源为淘宝镜像源,这里出现第一个问题,视频中给的淘宝镜像为&…

抽象类、模板方法模式

抽象类概述 在Java中abstract是抽象的意思,如果一个类中的某个方法的具体实现不能确定,就可以申明成abstract修饰的抽象方法(不能写方法体了),这个类必须用abstract修饰,被称为抽象类。 抽象方法定义&…

WSL2编译RV1126 SDK

接上一篇《WSL2部署RV1126 SDK编译环境》 1 编译配置 ./build.sh device/rockchip/rv1126_rv1109/aio-rv1126-jd4.mk 2 关闭Qt(可选) vim buildroot/configs/firefly_rv1126_rv1109_defconfig 3 启用ROS(可选) vim buildroot/conf…

链表基础知识详解(非常详细简单易懂)

概述: 链表作为 C 语言中一种基础的数据结构,在平时写程序的时候用的并不多,但在操作系统里面使用的非常多。不管是RTOS还是Linux等使用非常广泛,所以必须要搞懂链表,链表分为单向链表和双向链表,单向链表很…

Rocky Linux 安装部署 Zabbix 6.4

一、Zabbix的简介 Zabbix是一种开源的企业级监控解决方案,用于实时监测服务器、网络设备和应用程序的性能和可用性。它提供了强大的数据收集、处理和可视化功能,同时支持事件触发、报警通知和自动化任务等功能。Zabbix易于安装和配置,支持跨平…

HTTPS是什么,详解它的加密过程

目录 1.前言 2.两种加密解密方式 2.1对称加密 2.2非对称加密 3.HTTPS的加密过程 3.1针对明文的对称加密 3.2针对密钥的非对称加密 3.3证书的作用 1.前言 我们知道HTTP协议是超文本传输协议,它被广泛的应用在客户端服务器上,用来传输文字,图片,视频,js,html等.但是这种传…

【牛客】VL60 使用握手信号实现跨时钟域数据传输

题目描述 分别编写一个数据发送模块和一个数据接收模块,模块的时钟信号分别为clk_a,clk_b。两个时钟的频率不相同。数据发送模块循环发送0-7,在每个数据传输完成之后,间隔5个时钟,发送下一个数据。请在两个模块之间添加…

Harbor高可用(haproxy和keepalived)

Harbor高可用(haproxy和keepalived) 文章目录 Harbor高可用(haproxy和keepalived)1.Harbor高可用集群部署架构1.1 主机初始化1.1.1 设置网卡名和ip地址1.1.2 设置主机名1.1.3 配置镜像源1.1.4 关闭防火墙1.1.5 禁用SELinux1.1.6 设…