4.2 文本相似度(三)

换个思路,再训练一次。

1 基本框架

试想,如果有一个语句需要从预料库中匹配,每一次匹配都会伴随着大量的耗时:

一次匹配20ms, 1 000 000次呢,1 000 000 *20/ 1000 = 20 000S ~5.56H。效率极其的低:

使用如下策略解决:候选文本与匹配文本分别进入模型(并行)然后输出两个向量,通过cos相似度得到是否匹配结果;

2  使用方法介绍

2.1 CosineSimilarity

   pytorch.nn.CosineSimilarity(余弦相似度)是一种常用的相似度度量方法,用于衡量两个向量之间的相似程度。在自然语言处理和信息检索等领域,余弦相似度常用于比较文本、向量表示或特征之间的相似性。

2.2 CosineEmbeddingLoss

3 代码

from transformers import AutoTokenizer, AutoModelForSequenceClassification, Trainer, TrainingArguments
from datasets import load_dataset,load_from_disk
import traceback
import torchfrom sklearn.model_selection import train_test_split#dataset = load_dataset("json", data_files="../data/train_pair_1w.json", split="train")
dataset = load_dataset("csv", data_files="/Users/user/studyFile/2024/nlp/text_similar/data/Chinese_Text_Similarity.csv", split="train")
datasets = dataset.train_test_split(test_size=0.2,shuffle=True)tokenizer = AutoTokenizer.from_pretrained("../chinese_macbert_base")
def process_function(examples):sentences = []labels = []for sen1, sen2, label in zip(examples["sentence1"], examples["sentence2"], examples["label"]):sentences.append(sen1)sentences.append(sen2)labels.append(1 if int(label) == 1 else -1)# input_ids, attention_mask, token_type_idstokenized_examples = tokenizer(sentences, max_length=128, truncation=True, padding="max_length")# tokenized_examples format likes {'input_ids':[[],[],[]], 'token_type_ids':[[],[],[]],'attention_mask':[[],[],[]]}# k : input_ids, v: [[],[],[],[]]tokenized_examples = {k: [v[i: i + 2] for i in range(0, len(v), 2)] for k, v in tokenized_examples.items()}tokenized_examples["labels"] = labelsreturn tokenized_examplestokenized_datasets = datasets.map(process_function, batched=True, remove_columns=datasets["train"].column_names)
tokenized_datasets#搭建模型
from transformers import BertForSequenceClassification, BertPreTrainedModel, BertModel
from typing import Optional
from transformers.configuration_utils import PretrainedConfig
from torch.nn import CosineSimilarity, CosineEmbeddingLoss
import torch
class DualModel(BertPreTrainedModel):def __init__(self, config: PretrainedConfig, *inputs, **kwargs):super().__init__(config, *inputs, **kwargs)self.bert = BertModel(config)self.post_init()def forward(self,input_ids: Optional[torch.Tensor] = None,attention_mask: Optional[torch.Tensor] = None,token_type_ids: Optional[torch.Tensor] = None,position_ids: Optional[torch.Tensor] = None,head_mask: Optional[torch.Tensor] = None,inputs_embeds: Optional[torch.Tensor] = None,labels: Optional[torch.Tensor] = None,output_attentions: Optional[bool] = None,output_hidden_states: Optional[bool] = None,return_dict: Optional[bool] = None,):return_dict = return_dict if return_dict is not None else self.config.use_return_dict# Step1 分别获取sentenceA 和 sentenceB的输入senA_input_ids, senB_input_ids = input_ids[:, 0], input_ids[:, 1]senA_attention_mask, senB_attention_mask = attention_mask[:, 0], attention_mask[:, 1]senA_token_type_ids, senB_token_type_ids = token_type_ids[:, 0], token_type_ids[:, 1]# Step2 分别获取sentenceA 和 sentenceB的向量表示senA_outputs = self.bert(senA_input_ids,attention_mask=senA_attention_mask,token_type_ids=senA_token_type_ids,position_ids=position_ids,head_mask=head_mask,inputs_embeds=inputs_embeds,output_attentions=output_attentions,output_hidden_states=output_hidden_states,return_dict=return_dict,)senA_pooled_output = senA_outputs[1]    # [batch, hidden]senB_outputs = self.bert(senB_input_ids,attention_mask=senB_attention_mask,token_type_ids=senB_token_type_ids,position_ids=position_ids,head_mask=head_mask,inputs_embeds=inputs_embeds,output_attentions=output_attentions,output_hidden_states=output_hidden_states,return_dict=return_dict,)senB_pooled_output = senB_outputs[1]    # [batch, hidden]# step3 计算相似度cos = CosineSimilarity()(senA_pooled_output, senB_pooled_output)    # [batch, ]# step4 计算lossloss = Noneif labels is not None:loss_fct = CosineEmbeddingLoss(0.3)loss = loss_fct(senA_pooled_output, senB_pooled_output, labels)output = (cos,)return ((loss,) + output) if loss is not None else outputmodel = DualModel.from_pretrained("../chinese_macbert_base")import evaluate
acc_metric = evaluate.load("./metric_accuracy.py")
f1_metirc = evaluate.load("./metric_f1.py")def eval_metric(eval_predict):predictions, labels = eval_predictpredictions = [int(p > 0.7) for p in predictions]labels = [int(l > 0) for l in labels]# predictions = predictions.argmax(axis=-1)acc = acc_metric.compute(predictions=predictions, references=labels)f1 = f1_metirc.compute(predictions=predictions, references=labels)acc.update(f1)return acc
# 模型参数
train_args = TrainingArguments(output_dir="./dual_model",      # 输出文件夹per_device_train_batch_size=32,  # 训练时的batch_sizeper_device_eval_batch_size=32,  # 验证时的batch_sizelogging_steps=10,                # log 打印的频率evaluation_strategy="epoch",     # 评估策略save_strategy="epoch",           # 保存策略save_total_limit=3,              # 最大保存数learning_rate=2e-5,              # 学习率weight_decay=0.01,               # weight_decaymetric_for_best_model="f1",      # 设定评估指标load_best_model_at_end=True)     # 训练完成后加载最优模型trainer = Trainer(model=model, args=train_args, train_dataset=tokenized_datasets["train"], eval_dataset=tokenized_datasets["test"], compute_metrics=eval_metric)
trainer.train()

        太耗时了,没有GPU.....

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/697763.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【PB案例学习笔记】-01创建应用、窗口与控件

写在前面 这是PB案例学习笔记系列文章的第一篇,也是最基础的一篇。后续文章中【创建程序基本框架】部分操作都跟这篇文章一样, 将不再重复。该系列文章是针对具有一定PB基础的读者,通过一个个由浅入深的编程实战案例学习,提高编…

今天开发了一款软件,我竟然只用敲了一个字母(文末揭晓)

软件课题:Python实现打印100内数学试题软件及开发过程 一、需求管理: 1.实现语言:Python 2.打印纸张:A4 3.铺满整张纸 4.打包成exe 先看效果: 1. 2.电脑打印预览 3.打印到A4纸效果(晚上拍的&#x…

JavaEE初阶-多线程5

文章目录 一、线程池1.1 线程池相关概念1.2 线程池标准类1.3 线程池工厂类1.4 实现自己的线程池 二、定时器2.1 java标准库中的定时器使用2.2 实现一个自己的定时器2.2.1 定义任务类2.2.2 定义定时器 一、线程池 1.1 线程池相关概念 池这个概念在计算机中比较常见&#xff0c…

AI网络爬虫:用kimichat自动批量提取网页内容

首先,在网页中按下F12键,查看定位网页元素: 然后在kimi中输入提示词: 你是一个Python编程专家,要完成一个爬取网页内容的Python脚本,具体步骤如下: 在F盘新建一个Excel文件:提示词…

MySQL基础使用指南

难度就是价值所在。大家好,今天给大家分享一下关于MySQL的基础使用,MySQL 是一个流行的关系型数据库管理系统,被广泛应用于各种类型的应用程序开发中。本文中将介绍 MySQL 的基础使用方法,包括创建数据库、创建表格以及进行增删改…

Hive的join操作

假设有三张表,结构和数据如下:-- 创建表 test_a,test_b,test_c CREATE TABLE test_a( id int, name string ) ROW FORMAT DELIMITED FIELDS TERMINATED BY \t;--分别导入数据到三个表中 --test_a 1 a1 2 a2 4 a4 --test_b 1 b1 3 b3 4 b4 --…

LeetCode 力扣题目:买卖股票的最佳时机 IV

❤️❤️❤️ 欢迎来到我的博客。希望您能在这里找到既有价值又有趣的内容,和我一起探索、学习和成长。欢迎评论区畅所欲言、享受知识的乐趣! 推荐:数据分析螺丝钉的首页 格物致知 终身学习 期待您的关注 导航: LeetCode解锁100…

针对关键 PuTTY 私钥恢复漏洞的 PoC 发布

安全研究人员针对广泛使用的 PuTTY SSH 和 Telnet 客户端中的一个关键漏洞发布了概念验证 (PoC) 漏洞利用。 该漏洞CVE-2024-31497允许攻击者恢复 PuTTY 版本 0.68 至 0.80 中使用 NIST P-521 椭圆曲线生成的私钥。 该漏洞源于 PuTTY在使用 P-521 曲线时偏向生成ECDSA随机数。…

ATA-308C功率放大器的基本原理和性能参数是什么

功率放大器是一种用于放大电信号功率的电子器件。它将输入的小信号电压或电流经过放大后,输出一个较大的电信号功率,以驱动负载或其他设备。功率放大器在各个领域中都有广泛应用,例如音频放大器、无线通信系统、工业控制等。 功率放大器的基本…

Elasticsearch解决字段膨胀问题

文章目录 背景Flattened类型的产生Flattened类型的定义基于Flattened类型插入数据更新Flattened字段并添加数据Flattened类型检索 Flattened类型的不足 背景 Elasticsearch映射如果不进行特殊设置,则默认为dynamic:true。dynamic:true实际上支持不加约束地动态添加…

PCIE协议-2-事务层规范-Completion Rules

2.2.9 完成规则 所有Read、Non-Posted Write和AtomicOp请求都需要完成(Completion)。完成包含一个完成头标,对于某些类型的完成,完成头标之后会跟随一定数量的DWs数据。完成头标的每个字段的规则在以下各节中定义。 完成通过ID路…

打破边界:Facebook的社交实验与未来愿景

数字化时代,社交媒体已经成为人们日常生活的重要组成部分,而Facebook作为其中的佼佼者,一直在积极探索社交领域的新可能性。本文将探讨Facebook在社交实验和未来愿景方面的努力,以及其如何打破传统边界,开拓社交的新领…