Running locally: "Building semantic search with SentenceTransformers on AMD"

Building semantic search with SentenceTransformers on AMD — ROCm Blogs

This blog explains how to train a SentenceTransformers model on the Sentence Compression dataset to perform semantic search. It uses the BERT base (uncased) model as the underlying transformer and builds on Hugging Face libraries with PyTorch.
The goal of training this custom model is to use it for semantic search. Semantic search is an information-retrieval approach that understands the intent and context of a search query, rather than merely matching keywords. For example, searching for "apple pie recipe" (the query) returns results about how to make apple pie (the documents), not just pages that happen to contain the words "apple" and "pie".
The files related to this blog post can be found in this [GitHub folder](https://github.com/ROCm/rocm-blogs/tree/release/blogs/artificial-intelligence/sentence_transformers_amd/).

Introducing SentenceTransformers

Training a SentenceTransformers model from scratch means teaching the model to understand and encode sentences as meaningful, high-dimensional vectors. This blog focuses on a dataset of pairs of equivalent sentences. Overall, the goal of training is for the model to learn to map semantically similar sentences close together in the vector space while pushing dissimilar sentences apart. Compared with a generic pretrained model, which may not capture the particularities of a given domain or use case, a custom-trained model can be tuned precisely to the context and semantics relevant to that domain or application.
The interest here is in performing asymmetric semantic search. In this approach, the model acknowledges that queries and documents can be inherently different, for example a short query against a long document. Asymmetric semantic search uses encodings that make search effective even when text types or lengths do not match. This is useful for applications that retrieve information from large documents or databases, where queries are typically shorter and less detailed than the content they search. Here is an example of how semantic search works:

Query: Is Paris located in France?
Most similar sentences in corpus:
The capital of France is Paris (Score: 0.6829)
Paris, which is a city in Europe with traditions and remarkable food, is the capital of France (Score: 0.6044)
Australia is known for its traditions and remarkable food (Score: -0.0159)
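
Scores like these come from encoding the query and each corpus sentence, then ranking by cosine similarity. Below is a minimal sketch of that flow using an off-the-shelf pretrained model; the model name 'all-MiniLM-L6-v2' is an assumption for illustration, whereas the rest of this post trains a custom model instead:

# Minimal semantic-search sketch with a pretrained model (model name is an assumption).
from sentence_transformers import SentenceTransformer, util

pretrained = SentenceTransformer('all-MiniLM-L6-v2')  # any sentence-embedding model works here

corpus = ['The capital of France is Paris',
          'Paris, which is a city in Europe with traditions and remarkable food, is the capital of France',
          'Australia is known for its traditions and remarkable food']
corpus_embeddings = pretrained.encode(corpus, convert_to_tensor=True)

query_embedding = pretrained.encode('Is Paris located in France?', convert_to_tensor=True)
scores = util.cos_sim(query_embedding, corpus_embeddings)[0]  # one cosine score per corpus sentence
for sentence, score in zip(corpus, scores):
    print(f"{sentence} (Score: {score:.4f})")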

Implementation on AMD GPUs

This example was run on Ubuntu 22.04.3 LTS with ROCm 6.0.2 and PyTorch 2.3.0.

$ python
Python 3.12.1 | packaged by Anaconda, Inc. | (main, Jan 19 2024, 15:51:05) [GCC 11.2.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import torch
>>> torch.__version__
'2.3.0+rocm6.0'
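
Before training, it is worth confirming that PyTorch can see the GPU. On ROCm builds, AMD GPUs are exposed through the usual torch.cuda API, so a minimal check looks like this:

import torch

# ROCm builds of PyTorch expose AMD GPUs through the CUDA API,
# so this should report True and the name of the GPU.
print(torch.cuda.is_available())
if torch.cuda.is_available():
    print(torch.cuda.get_device_name(0))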

Install the following Python packages:

pip install -U datasets ipywidgets transformers sentence-transformers

Download the Sentence Compression dataset from HF-Mirror (a Hugging Face mirror site):

./hfd.sh embedding-data/sentence-compression --dataset --tool aria2c -x 4

The Sentence Compression dataset contains 180,000 pairs of equivalent sentences. Each pair demonstrates how a longer sentence can be compressed into a shorter one while keeping the same meaning.

Import the Python packages:

from datasets import load_dataset  # load_dataset for loading the dataset
from sentence_transformers import InputExample, util  # InputExample wraps training pairs; util provides similarity helpers
from torch.utils.data import DataLoader  # DataLoader batches and iterates over training examples
from torch import nn  # neural-network building blocks (used for the dense layer's activation)
from sentence_transformers import losses  # loss functions for training
from sentence_transformers import SentenceTransformer, models  # model wrapper and its building blocks

Prepare the dataset:

dataset_id = "./sentence-compression"
dataset = load_dataset(dataset_id)

Inspect a sample from the dataset:

# Explore one sample
print(dataset['train']['set'][1])
['Major League Baseball Commissioner Bud Selig will be speaking at St. Norbert College next month.',
'Bud Selig to speak at St. Norbert College']

The SentenceTransformers library requires the dataset to be in a specific format, ensuring compatibility with the model architecture.

Create a list of training examples (using half of the dataset for illustration). This reduces the computational load and speeds up the training process.

# Convert the dataset to the required format
train_examples = []
train_data = dataset['train']['set']
n_examples = dataset['train'].num_rows // 2  # use half of the dataset for training

for example in train_data[:n_examples]:
    original_sentence = example[0]
    compressed_sentence = example[1]
    input_example = InputExample(texts=[original_sentence, compressed_sentence])
    train_examples.append(input_example)

Instantiate the `DataLoader` class. This class provides an efficient way to iterate over the dataset.

# Instantiate the DataLoader with the training examples
train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=16)

Implementation

In a SentenceTransformers model, the goal is to map a variable-length input sentence to a fixed-size vector. The input sentence is first passed to a transformer model; in this example, the BERT base (uncased) model serves as the base transformer and outputs a contextualized embedding for each token in the input sentence. A pooling layer then aggregates these token embeddings into a single sentence embedding. Finally, a fully connected layer (a dense layer with a tanh activation) applies an additional transformation: it reduces the dimensionality of the pooled sentence embedding, while the nonlinear activation lets the model capture more complex patterns in the data.

# Create a custom model
# Use an existing embedding model
word_embedding_model = models.Transformer('bert-base-uncased', max_seq_length=256)

# Apply a pooling function over the token embeddings
pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension())

# Dense layer
dense_model = models.Dense(in_features=pooling_model.get_sentence_embedding_dimension(), out_features=256, activation_function=nn.Tanh())

# Define the overall model
model = SentenceTransformer(modules=[word_embedding_model, pooling_model, dense_model])
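
As a quick sanity check (a minimal sketch), encoding one sentence and inspecting the result confirms that the pipeline produces fixed-size vectors; with the dense layer above, the dimension should be 256:

# Encode one sentence and inspect the embedding size.
embedding = model.encode("A quick test sentence")
print(embedding.shape)  # expected: (256,), matching out_features=256 above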

Training

During training, choosing an appropriate loss function is crucial and depends on the specific application and the structure of the dataset. Here, the `MultipleNegativesRankingLoss` function is used. This loss is particularly useful for semantic-search applications, where the model must rank sentences by relevance. It works by contrasting a pair of semantically similar sentences (a positive pair) against multiple semantically dissimilar sentences. It is well suited to the Sentence Compression dataset because it learns to distinguish semantically similar sentences from dissimilar ones.

# Given a dataset of equivalent sentence pairs, choose MultipleNegativesRankingLoss
train_loss = losses.MultipleNegativesRankingLoss(model=model)
model.fit(train_objectives=[(train_dataloader, train_loss)], epochs=5)
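
For intuition, here is a minimal sketch of the idea behind `MultipleNegativesRankingLoss` (in-batch negatives). This illustrates the concept and is not the library's exact implementation; the scale factor of 20 mirrors the library's default:

import torch
import torch.nn.functional as F

def mnr_loss_sketch(anchor_emb, positive_emb, scale=20.0):
    # Pairwise cosine similarities: row i scores anchor i against every
    # positive in the batch.
    a = F.normalize(anchor_emb, dim=1)
    p = F.normalize(positive_emb, dim=1)
    scores = a @ p.T * scale
    # For anchor i, the matching positive sits at column i; all other
    # columns in that row act as negatives.
    labels = torch.arange(scores.size(0))
    return F.cross_entropy(scores, labels)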

Inference

Evaluate the model.

# from sentence_transformers import SentenceTransformer, util
import torch

# Sentences to encode (documents/corpus)
sentences = ['Paris, which is a city in Europe with traditions and remarkable food, is the capital of France',
             'The capital of France is Paris',
             'Australia is known for its traditions and remarkable food',
             """Despite the heavy rains that lasted for most of the week, the outdoor music festival,
             which featured several renowned international artists, was able to proceed as scheduled,
             much to the delight of fans who had traveled from all over the country""",
             """Photosynthesis, a process used by plants and other organisms to convert light into
             chemical energy, plays a crucial role in maintaining the balance of oxygen and carbon
             dioxide in the Earth's atmosphere."""
             ]

# Encode the sentences
sentences_embeddings = model.encode(sentences, convert_to_tensor=True)

# Query sentences:
queries = ['Is Paris located in France?',
           'Tell me something about Australia',
           'music festival proceeding despite heavy rains',
           'what is the process that some organisms use to transform light into chemical energy?']

# Find the closest sentences in the corpus for each query using cosine similarity
for query in queries:
    # Encode the current query
    query_embedding = model.encode(query, convert_to_tensor=True)

    # Cosine similarity scores between the query and each document
    cos_scores = util.cos_sim(query_embedding, sentences_embeddings)[0]
    top_results = torch.argsort(cos_scores, descending=True)  # sort scores in descending order

    print("\n\n======================\n\n")
    print("Query:", query)
    print("\nSimilar sentences in corpus:")
    # Print the sentences most similar to the query, with their scores
    for idx in top_results:
        print(sentences[idx], "(Score: {:.4f})".format(cos_scores[idx]))

Testing the model on several new examples demonstrates its effectiveness.

======================

Query: Is Paris located in France?

Similar sentences in corpus:
The capital of France is Paris (Score: 0.7907)
Paris, which is a city in Europe with traditions and remarkable food, is the capital of France (Score: 0.7081)
Photosynthesis, a process used by plants and other organisms to convert light into
chemical energy, plays a crucial role in maintaining the balance of oxygen and carbon
dioxide in the Earth's atmosphere. (Score: 0.0657)
Australia is known for its traditions and remarkable food (Score: 0.0162)
Despite the heavy rains that lasted for most of the week, the outdoor music festival,
which featured several renowned international artists, was able to proceed as scheduled,
much to the delight of fans who had traveled from all over the country (Score: -0.0934)

======================

Query: Tell me something about Australia

Similar sentences in corpus:
Australia is known for its traditions and remarkable food (Score: 0.6730)
Paris, which is a city in Europe with traditions and remarkable food, is the capital of France (Score: 0.1489)
The capital of France is Paris (Score: 0.1146)
Despite the heavy rains that lasted for most of the week, the outdoor music festival,
which featured several renowned international artists, was able to proceed as scheduled,
much to the delight of fans who had traveled from all over the country (Score: 0.0694)
Photosynthesis, a process used by plants and other organisms to convert light into
chemical energy, plays a crucial role in maintaining the balance of oxygen and carbon
dioxide in the Earth's atmosphere. (Score: -0.0241)

======================

Query: music festival proceeding despite heavy rains

Similar sentences in corpus:
Despite the heavy rains that lasted for most of the week, the outdoor music festival,
which featured several renowned international artists, was able to proceed as scheduled,
much to the delight of fans who had traveled from all over the country (Score: 0.7855)
Paris, which is a city in Europe with traditions and remarkable food, is the capital of France (Score: 0.0700)
Photosynthesis, a process used by plants and other organisms to convert light into
chemical energy, plays a crucial role in maintaining the balance of oxygen and carbon
dioxide in the Earth's atmosphere. (Score: 0.0351)
The capital of France is Paris (Score: 0.0037)
Australia is known for its traditions and remarkable food (Score: -0.0552)

======================

Query: what is the process that some organisms use to transform light into chemical energy?

Similar sentences in corpus:
Photosynthesis, a process used by plants and other organisms to convert light into
chemical energy, plays a crucial role in maintaining the balance of oxygen and carbon
dioxide in the Earth's atmosphere. (Score: 0.6085)
Despite the heavy rains that lasted for most of the week, the outdoor music festival,
which featured several renowned international artists, was able to proceed as scheduled,
much to the delight of fans who had traveled from all over the country (Score: 0.1370)
Paris, which is a city in Europe with traditions and remarkable food, is the capital of France (Score: 0.0141)
Australia is known for its traditions and remarkable food (Score: 0.0102)
The capital of France is Paris (Score: -0.0128)
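
Once the results look good, the trained model can be persisted and reloaded later for inference (a minimal sketch; the output path below is an arbitrary choice):

# Save the trained model to disk and reload it for later use.
model.save('./sentence-compression-model')  # arbitrary output path
loaded_model = SentenceTransformer('./sentence-compression-model')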

Complete code

from datasets import load_dataset
from sentence_transformers import InputExample, util
from torch.utils.data import DataLoader
from torch import nn
from sentence_transformers import losses
from sentence_transformers import SentenceTransformer, models

dataset_id = "./sentence-compression"
dataset = load_dataset(dataset_id)

# Explore one sample
print(dataset['train']['set'][1])

# Convert the dataset to the required format
train_examples = []  # list of training examples
train_data = dataset['train']['set']  # the training split
n_examples = dataset['train'].num_rows // 2  # use half of the dataset for training

# Iterate over the selected samples and create input examples
for example in train_data[:n_examples]:
    original_sentence = example[0]  # the original sentence
    compressed_sentence = example[1]  # the compressed sentence
    input_example = InputExample(texts=[original_sentence, compressed_sentence])
    train_examples.append(input_example)

# Instantiate the DataLoader with the training examples
train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=16)

# Create a custom model
# Use an existing embedding model
word_embedding_model = models.Transformer('bert-base-uncased', max_seq_length=256)

# Apply a pooling function over the token embeddings
pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension())

# Dense layer
dense_model = models.Dense(in_features=pooling_model.get_sentence_embedding_dimension(), out_features=256, activation_function=nn.Tanh())

# Define the overall model
model = SentenceTransformer(modules=[word_embedding_model, pooling_model, dense_model])

# Given a dataset of equivalent sentence pairs, choose MultipleNegativesRankingLoss
train_loss = losses.MultipleNegativesRankingLoss(model=model)
model.fit(train_objectives=[(train_dataloader, train_loss)], epochs=5)

# from sentence_transformers import SentenceTransformer, util
import torch

# Sentences to encode (documents/corpus)
sentences = ['Paris, which is a city in Europe with traditions and remarkable food, is the capital of France',
             'The capital of France is Paris',
             'Australia is known for its traditions and remarkable food',
             """Despite the heavy rains that lasted for most of the week, the outdoor music festival,
             which featured several renowned international artists, was able to proceed as scheduled,
             much to the delight of fans who had traveled from all over the country""",
             """Photosynthesis, a process used by plants and other organisms to convert light into
             chemical energy, plays a crucial role in maintaining the balance of oxygen and carbon
             dioxide in the Earth's atmosphere."""
             ]

# Encode the sentences
sentences_embeddings = model.encode(sentences, convert_to_tensor=True)

# Query sentences:
queries = ['Is Paris located in France?',
           'Tell me something about Australia',
           'music festival proceeding despite heavy rains',
           'what is the process that some organisms use to transform light into chemical energy?']

# Find the closest sentences in the corpus for each query using cosine similarity
for query in queries:
    # Encode the current query
    query_embedding = model.encode(query, convert_to_tensor=True)

    # Cosine similarity scores between the query and each document
    cos_scores = util.cos_sim(query_embedding, sentences_embeddings)[0]
    top_results = torch.argsort(cos_scores, descending=True)  # sort scores in descending order

    print("\n\n======================\n\n")
    print("Query:", query)
    print("\nSimilar sentences in corpus:")
    # Print the sentences most similar to the query, with their scores
    for idx in top_results:
        print(sentences[idx], "(Score: {:.4f})".format(cos_scores[idx]))
