bert 相似度任务训练简单版本,faiss 寻找相似 topk

目录

任务

代码

train.py

predit.py

faiss 最相似的 topk


任务

使用 bert-base-chinese 训练相似度任务,参考:微调BERT模型实现相似性判断 - 知乎

参考他上面代码,他使用的是 BertForNextSentencePrediction 模型,BertForNextSentencePrediction 原本是设计用于下一个句子预测任务的。在BERT的原始训练中,模型会接收到一对句子,并试图预测第二个句子是否紧跟在第一个句子之后;所以使用这个模型标签(label)只能是 0,1,相当于二分类任务了

但其实在相似度任务中,我们每一条数据都是【text1\ttext2\tlabel】的形式,其中 label 代表相似度,可以给两个文本打分表示相似度,也可以映射为分类任务,0 代表不相似,1 代表相似,他这篇文章利用了这种思想,对新手还挺有用的。

现在我搞了一个招聘数据,里面有办公区域列,处理过了,每一行代表【地址1\t地址2\t相似度】

只要两文本中有一个地址相似我就作为相似,标签为 1,否则 0

利用这数据微调,没有使用验证数据集,就最后使用测试集来看看效果。

代码

train.py

import json
import torch
from transformers import BertTokenizer, BertForNextSentencePrediction
from torch.utils.data import DataLoader, Dataset# 能用gpu就用gpu
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device = torch.device("cpu")bacth_size = 32
epoch = 3
auto_save_batch = 5000
learning_rate = 2e-5# 准备数据集
class MyDataset(Dataset):def __init__(self, data_file_paths):self.texts = []self.labels = []# 分词器用默认的self.tokenizer = BertTokenizer.from_pretrained('../bert-base-chinese')# 自己实现对数据集的解析with open(data_file_paths, 'r', encoding='utf-8') as f:for line in f:text1, text2, label = line.split('\t')self.texts.append((text1, text2))self.labels.append(int(label))def __len__(self):return len(self.texts)def __getitem__(self, idx):text1, text2 = self.texts[idx]label = self.labels[idx]encoded_text = self.tokenizer(text1, text2, padding='max_length', truncation=True, max_length=128, return_tensors='pt')return encoded_text, label# 训练数据文件路径
train_dataset = MyDataset('../data/train.txt')# 定义模型
# num_labels=5 定义相似度评分有几个
model = BertForNextSentencePrediction.from_pretrained('../bert-base-chinese', num_labels=6)
model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)# 训练模型
train_loader = DataLoader(train_dataset, batch_size=bacth_size, shuffle=True)
trained_data = 0
batch_after_last_save = 0
total_batch = 0
total_epoch = 0for epoch in range(epoch):trained_data = 0for batch in train_loader:inputs, labels = batch# 不知道为啥,出来的数据维度是 (batch_size, 1, 128),需要把第二维去掉inputs['input_ids'] = inputs['input_ids'].squeeze(1)inputs['token_type_ids'] = inputs['token_type_ids'].squeeze(1)inputs['attention_mask'] = inputs['attention_mask'].squeeze(1)# 因为要用GPU,将数据传输到gpu上inputs = inputs.to(device)labels = labels.to(device)optimizer.zero_grad()outputs = model(**inputs, labels=labels)loss, logits = outputs[:2]loss.backward()optimizer.step()trained_data += len(labels)trained_process = float(trained_data) / len(train_dataset)batch_after_last_save += 1total_batch += 1# 每训练 auto_save_batch 个 batch,保存一次模型if batch_after_last_save >= auto_save_batch:batch_after_last_save = 0model.save_pretrained(f'../output/cn_equal_model_{total_epoch}_{total_batch}.pth')print("保存模型:cn_equal_model_{}_{}.pth".format(total_epoch, total_batch))print("训练进度:{:.2f}%, loss={:.4f}".format(trained_process * 100, loss.item()))total_epoch += 1model.save_pretrained(f'../output/cn_equal_model_{total_epoch}_{total_batch}.pth')print("保存模型:cn_equal_model_{}_{}.pth".format(total_epoch, total_batch))

训练好后的文件,输出的最后一个文件夹才是效果最好的模型:

predit.py

import torch
from transformers import BertTokenizer, BertForNextSentencePredictiontokenizer = BertTokenizer.from_pretrained('../bert-base-chinese')
model = BertForNextSentencePrediction.from_pretrained('../output/cn_equal_model_3_171.pth')with torch.no_grad():with open('../data/test.txt', 'r', encoding='utf8') as f:lines = f.readlines()correct = 0for i, line in enumerate(lines):text1, text2, label = line.split('\t')encoded_text = tokenizer(text1, text2, padding='max_length', truncation=True, max_length=128, return_tensors='pt')outputs = model(**encoded_text)res = torch.argmax(outputs.logits, dim=1).item()print(text1, text2, label, res)if str(res) == label.strip('\n'):correct += 1print(f'{i + 1}/{len(lines)}')print(f'acc:{correct / len(lines)}')

可以看到还是较好的学习了我数据特征:只要两文本中有一个地址相似我就作为相似,标签为 1,否则 0

faiss 最相似的 topk

使用 faiss 寻找 topk 相似的,从结果上看最相似的基本都还是找到排到较为靠前的位置

import torch
import faiss
import pandas as pd
import numpy as np
from transformers import BertTokenizer, BertModel# 假设有一个数据集df,其中包含'index'列和'text'列
df = pd.read_csv('../data/DataAnalyst.csv', encoding='gbk')  # 根据实际情况加载数据集
df = df.dropna().drop_duplicates().reset_index()
df['index'] = df.index
df = df[['index', '公司所在商区']]  # 保留所需列
df['公司所在商区'] = df['公司所在商区'].map(lambda row: ','.join(eval(row)))# device = torch.device('gpu' if torch.cuda.is_available() else 'cpu')
device = torch.device('cpu')# 加载微调好的模型和tokenizer
tokenizer = BertTokenizer.from_pretrained('../bert-base-chinese')
model = BertModel.from_pretrained('../output/cn_equal_model_3_171.pth')
model.eval()# 将数据集转化为模型所需的格式并计算所有样本的向量表示
def encode_texts(df):text_vectors = []for index, row in df.iterrows():text = row['公司所在商区']inputs = tokenizer(text, padding='max_length', truncation=True, max_length=128, return_tensors='pt')with torch.no_grad():embeddings = model(**inputs.to(device))['last_hidden_state'][:, 0]text_vectors.append(embeddings.cpu().numpy())print(f'{index + 1}/{len(df)}')return np.vstack(text_vectors)# 加载数据集并计算所有样本的向量
print('enbedding all data...')
all_embeddings = encode_texts(df)# 初始化Faiss索引
print('init faiss all embedding...')
index = faiss.IndexFlatIP(all_embeddings.shape[1])  # 使用内积空间,适用于余弦相似度
index.add(all_embeddings)
print('init faiss all embedding finish~~~')# 定义查找最相似样本的函数
def find_top_k_similar(query_text, k=100):print('当前 query_text embedding.')query_embedding = encode_single_text(query_text)print('begin to search topk....')D, I = index.search(query_embedding, k)  # 返回距离和索引top_k_indices = df.iloc[I[0]].index.tolist()  # 将索引转换为原始数据集的索引return top_k_indices# 编码单个文本的函数
def encode_single_text(text):inputs = tokenizer(text, padding='max_length', truncation=True, max_length=128, return_tensors='pt')with torch.no_grad():embedding = model(**inputs.to(device))['last_hidden_state'][:, 0].cpu().numpy()print('当前 query_text embedding finish!')return embedding# 示例:找一个query_text的top10相似样本
query_text = "左家庄,国展,西坝河"
top10_indices = find_top_k_similar(query_text)
# 获取与查询文本最相似的前10条原始文本
top10_texts = [df.loc[index, '公司所在商区'] for index in top10_indices]print(f"与'{query_text}'最相似的前100条样本及其文本:")
for i, (idx, text) in enumerate(zip(top10_indices, top10_texts)):print(f"{i+1}. 索引:{idx},文本:{text}")

数据

链接:https://pan.baidu.com/s/1Cpr-ZD9Neakt73naGdsVTw 
提取码:eryw 
链接:https://pan.baidu.com/s/1qHYjXC7UCeUsXVnYTQIPCg 
提取码:o8py 
链接:https://pan.baidu.com/s/1CTntG1Z6AIhiPt6i8Ad97Q 
提取码:x6sz 
 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/511431.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

基于springboot+vue实现民宿管理系统项目【项目源码+论文说明】

基于springbootvue民宿管理系统演示 摘要 伴随着我国旅游业的快速发展,民宿已成为最受欢迎的住宿方式之一。民宿借助互联网和移动设备的发展,展现出强大的生命力和市场潜力。民宿主要通过各种平台如携程、去哪儿、淘宝等在网络上销售线下住宿服务&#…

伪创新的迷惑手法-UMLChina建模知识竞赛第5赛季第6轮

DDD领域驱动设计批评文集 做强化自测题获得“软件方法建模师”称号 《软件方法》各章合集 参考潘加宇在《软件方法》和UMLChina公众号文章中发表的内容作答。在本文下留言回答。 只要最先答对前3题,即可获得本轮优胜。 如果有第4题,第4题为附加题&am…

Git入门学习笔记

Git 是一个非常强大的分布式版本控制工具! 在下载好Git之后,鼠标右击就可以显示 Git Bash 和 Git GUI,Git Bash 就像是在电脑上安装了一个小型的 Linux 系统! 1. 打开 Git Bash 2. 设置用户信息(这是非常重要的&…

JS逆向进阶篇【去哪儿旅行登录】【下篇-逆向Bella参数JS加密逻辑Python生成】

目录: 每篇前言:引子——本篇目的1、 代码混淆和还原(1)单独替换:(2)整个js文件替换: 2、算法入口分析3、 深入分析(0)整体分析:(1&am…

在K8S集群中部署SkyWalking

1. 环境准备 K8S 集群kubectlhelm 2. 为什么要部署SkyWalking? 我也不道啊,老板说要咱就得上啊。咦,好像可以看到服务的各项指标,像SLA,Apdex这些,主要是能够进行请求的链路追踪,bug排查的利…

stressapptest源码剖析:主函数main解析和sat类头文件分析

主函数main解析和sat类头文件分析 一、简介二、入口函数main.cc剖析三、SAT压力测试对象接口和数据结构总结 一、简介 stressapptest(简称SAT)是一种用于在Linux系统上测试系统稳定性和可靠性的工具,通过产生CPU、内存、磁盘等各种负载来测试…

制造业数字化赋能:1核心2关键3层面4方向

随着科技的飞速发展,制造业正站在数字化转型的风口浪尖。数字化转型不仅关乎企业效率与利润,更决定了制造业在全球竞争中的地位。那么,在这场波澜壮阔的数字化浪潮中,制造业如何抓住机遇,乘风破浪?本文将从…

【3GPP】【核心网】【5G】5G核心网协议解析(一)(超详细)

1. 5G核心网概念 5G核心网是支撑5G移动通信系统的关键组成部分,是实现5G移动通信的重要基础设施,它负责管理和控制移动网络中的各种功能和服务。它提供了丰富的功能和服务,支持高速、低时延、高可靠性的通信体验,并为不同行业和应…

专科生去华为面试,后续来了。。。

专科生去华为面试,后续来了。。。 大家好,我是銘,全栈开发程序员。 今天我正上班呢,一个之前的同事给我发信息,说他去华为面试了,我听到这个消息有点懵逼,我和他是同一年毕业的,我…

基于粒子群优化算法的图象聚类识别matlab仿真

目录 1.程序功能描述 2.测试软件版本以及运行结果展示 3.核心程序 4.本算法原理 5.完整程序 1.程序功能描述 基于粒子群优化算法的图象聚类识别。通过PSO优化方法,将数字图片的特征进行聚类,从而识别出数字0~9. 2.测试软件版本以及运行结果展示 M…

StarRocks——中信建投统一查询服务平台构建

目录 一、需求背景 1.1 数据加工链路复杂 1.2 大数据量下性能不足,查询响应慢 1.3 大量实时数据分散在各个业务系统,无法进行联合分析 1.4 缺少与预计算能力加速一些固定查询 二、构建统一查询服务平台 三、落地后的效果与价值 四、项目经验总结…

用docker部署后端项目

一、搭建局域网 1.1、介绍前后端项目搭建 需要4台服务器,在同一个局域网中 1.2、操作 # 搭建net-ry局域网,用于部署若依项目 net-ry:名字 docker network create net-ry --subnet172.68.0.0/16 --gateway172.68.0.1#查看 docker network ls…