Deeplearning4j 实战 (22):基于DSSM的语义匹配建模

Deeplearning4j 实战 (22):基于DSSM的语义匹配建模

Eclipse Deeplearning4j GitChat课程:Deeplearning4j 快速入门_专栏
Eclipse Deeplearning4j 系列博客:万宫玺的专栏_wangongxi_CSDN博客
Eclipse Deeplearning4j Github:https://github.com/eclipse/deeplearning4j
Eclipse Deeplearning4j 社区:https://community.konduit.ai/

DSSM是微软在2013年提出的,最早用于搜索引擎语义召回的双塔模型。目前在工业界也广泛用于推荐召回、搜索相关性排序、语义召回等环节。DSSM是一个轻量级模型,在线上serving的时候,可以通过对query向量和doc向量计算内积,得到的相似值用来衡量query和doc的相似度,从而进行进一步的排序。下面就分别从DSSM模型结构、基于DL4J的DSSM建模、对开源数据集LCQMC的建模等几个环节来介绍如何使用DSSM模型。当然,由于DSSM模型的论文发表时间较早,发表时给出的模型结构比较简单,在我们具体实现的时候,会做一些调整,具体在介绍模型搭建的部分会提到。

1. DSSM模型简述

在论文中,query和doc分别通过各自独立的神经网络映射成一个语义向量。需要注意的是,原论文中doc是一个包含正样本和负样本的集合。正样本取1个,负样本取4个。论文中有提到,正样本是搜素后被点击的样本,负样本则是随机选取的搜索未被点击的样本集合。通过分别计算query的语义向量和正负doc样本的语义向量的余弦相似度,再通过softmax函数得到正负样本的概率分布后,和label计算交叉熵损失。这就是DSSM模型的大致的idea。下面先看下论文中对于DSSM描绘的架构图:
在这里插入图片描述
通过模型架构图可以看到,论文中是使用最简单的MLP对输入进行映射。这里需要提一下word hashing的操作。由于2013年时候word embedding技术还不是较广泛的使用,因此论文中的word hashing是在n-gram语言模型的基础上,通过hash操作将接近50W的词表计算每个词的索引值。这在当时是一种比较高效的做法,目前由于硬件的进步以及embedding技术的进一步成熟,可以直接使用预训练的embedding向量或者做端到端的建模。因此,在第三部分中构建DSSM模型的过程中,我们也是使用的端到端的方案。
在这里插入图片描述

上面这张截图中模型训练的有关描述。就像在本节开始时候提到的,通过softmax计算query和每个doc的余弦相似度的值归一化概率分布。由于softmax函数与cosine相似度的一致性,因此相似度越高的query-doc pair,其softmax值也会越接近于1。在损失函数部分,使用的是经典的log loss。这部分没啥说的。
另外需要说明的是,从ranking loss的角度,论文中的loss应当属于list-wise loss。当然,如果将负样本减少到一个或者doc集合中只有一个正样本或负样本(softmax更改为sigmoid函数),那就退化成pair-wise loss或者point-wise loss。为了方便起见,在第三部分的建模过程中,我们会使用point-wise loss。
对于搜索场景来说,双塔的输入分别是query和doc。对于推荐场景来说,双塔的输入可以是user和item或者item和item,用于U2I的召回或者I2I的召回。

2. LCQMC数据集

LCQMC是哈工大和阿里共同开源的用于QA的数据集,详情可参见论文。下载链接为:地址。压缩包中共有三个文件,三个文件都是以制表符作为分隔符。我们先来看下用于训练的部分数据的截图:
在这里插入图片描述
文件中有三列。最后一列用1或者0来代表 text_a 和 text_b两列文本的是否相关。如果把text_a列文本看作是query,那text_b列可以看作是doc。用于验证的文件中的内容也和训练文件中的数据格式相似,这里就不做另外截图了。

最后提一下,训练样本数量是:238767,验证的样本数量是:12501。

3. 基于DL4J的DSSM模型构建

在第一部分中,我们提到DSSM的论文中双塔内部是使用MLP结构。考虑到MLP结构的单一性,我们使用Embedding+LSTM+MLP的结构作为双塔的内部结构。虽然query和doc对应的塔结构相同,但是不做参数的共享。另外,由于LCQMC数据集中label是1或者0,因此我们将DSSM的输出层改为sigmoid + binary cross entropy loss。具体我们先给出代码片段:

private static ComputationGraph getDSSM(final int QUERY_VOCAB_SIZE, final int DOC_VOCAB_SIZE, final int VECTOR_SIZE) {ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().updater(new Adam(5 * 1e-3)).weightInit(WeightInit.XAVIER).seed(12345L).graphBuilder().addInputs("query", "doc").setInputTypes(InputType.recurrent(QUERY_VOCAB_SIZE), InputType.recurrent(DOC_VOCAB_SIZE)).addLayer("query-embedding", new EmbeddingSequenceLayer.Builder().nIn(QUERY_VOCAB_SIZE + 1).nOut(VECTOR_SIZE).build(), "query").addLayer("query-embedding-lstm", new LSTM.Builder().nIn(VECTOR_SIZE).nOut(VECTOR_SIZE).activation(Activation.TANH).build(), "query-embedding").addLayer("doc-embedding", new EmbeddingSequenceLayer.Builder().nIn(DOC_VOCAB_SIZE + 1).nOut(VECTOR_SIZE).build(), "doc").addLayer("doc-embedding-lstm", new LSTM.Builder().nIn(VECTOR_SIZE).nOut(VECTOR_SIZE).activation(Activation.TANH).build(), "doc-embedding").addVertex("query-embedding-lstm-last-output", new LastTimeStepVertex("query"), "query-embedding-lstm").addVertex("doc-embedding-lstm-last-output", new LastTimeStepVertex("doc"), "doc-embedding-lstm").addLayer("query-output", new DenseLayer.Builder().nIn(VECTOR_SIZE).nOut(VECTOR_SIZE / 2).activation(Activation.LEAKYRELU).build(), "query-embedding-lstm-last-output").addLayer("doc-output", new DenseLayer.Builder().nIn(VECTOR_SIZE).nOut(VECTOR_SIZE / 2).activation(Activation.LEAKYRELU).build(), "doc-embedding-lstm-last-output").addVertex("query-output-l2-norm", new L2NormalizeVertex(), "query-output").addVertex("doc-output-l2-norm", new L2NormalizeVertex(), "doc-output").addVertex("cosing-similar", new ElementWiseVertex(ElementWiseVertex.Op.Product), "query-output-l2-norm", "doc-output-l2-norm").addLayer("out", new OutputLayer.Builder().lossFunction(LossFunctions.LossFunction.XENT)	//bce.nIn(VECTOR_SIZE / 2).nOut(1).activation(Activation.SIGMOID).build(), "cosing-similar").setOutputs("out").build();ComputationGraph net = new ComputationGraph(conf);net.setListeners(new ScoreIterationListener(1));net.init();return net;
}

由于存在两个输入,因此使用DL4J中的ComputationGraph。这里需要说明的有几点:

  • LastTimeStepVertex的作用:获取LSTM最后一个time step输出的张量
  • L2NormalizeVertex的作用:L2归一化,将query和doc的向量转化为单位向量
  • ElementWiseVertex的作用:通过设置Op为点积,实际为计算query和doc单位向量的内积,因此L2NormalizeVertex + ElementWiseVertex联合起来的作用是计算向量间的余弦相似度值
  • 输出端使用sigmoid + bce 作point-wise的损失函数
    在这里插入图片描述

上面的截图中通过summary接口打印的模型结构和待训练参数。可见待训练参数68W。

另外,对于该静态方法,输入的几个参数QUERY_VOCAB_SIZE,DOC_VOCAB_SIZE,VECTOR_SIZE分别代表LCQMC数据集中text_a的词表大小和text_b的词表大小,以及词向量的大小。

需要指出的是,在第四部分进行建模的操作中,我们使用中文单字作为query和doc的最小粒度特征,而不做分词的处理。

4. DSSM模型训练和评估

首先介绍下数据处理的部分:

  • 读取训练文件,构建中文单字和单字的索引,存储在map结构中。同时记录最长的文本长度,用于后续的padding操作。
  • 再次读取文件,对每条记录构建MultiDataSet对象,并存储在LinkedList对象中。MultiDataSet对象中会存储query和doc作为输入,label作为输出,此外还有query和doc的mask张量,用于统一变长文本的处理。

我们看下具体的实现逻辑:

class DataSetInfo{public Map<String,Integer> queryDict = new TreeMap<>();public Map<String,Integer> docDict = new TreeMap<>();public int queryMaxLen = 0;public int docMaxLen = 0;
}private static DataSetInfo preprocess(String filePath) {DataSetInfo info = new DataSetInfo();try(BufferedReader br = Files.newReader(new File(filePath), Charset.forName("UTF-8"))){String line = null;int lineIndex = 0;while( (line = br.readLine()) != null ) {if( lineIndex == 0 ) {lineIndex++;continue;}String[] splits = line.split("\t");if( null == splits || splits.length != 3 )continue;String query = splits[0];String doc = splits[1];if( query != null && query.length() > 0 ) {info.queryMaxLen = Math.max(query.length(), info.queryMaxLen);for( char c : query.toCharArray() ) {String charStr = String.valueOf(c);if( !info.queryDict.containsKey(charStr) ) {int curIndex = info.queryDict.size();info.queryDict.put(charStr, curIndex);}}}if( doc != null && doc.length() > 0 ) {info.docMaxLen = Math.max(doc.length(), info.docMaxLen);for( char c : doc.toCharArray() ) {String charStr = String.valueOf(c);if( !info.docDict.containsKey(charStr) ) {int curIndex = info.docDict.size();info.docDict.put(charStr, curIndex);}}}}}catch(Exception ex) {ex.printStackTrace();}finally {int curIndex = info.queryDict.size();info.queryDict.put("UNK", curIndex);//curIndex = info.docDict.size();info.docDict.put("UNK", curIndex);}return info;
}

这部分处理逻辑比较清晰,主要是先定义个DataSetInfo的类,里面包含了单字和单字索引的映射关系,还有最大文本长度。在finally部分,我们使用UNK代表所有未登录词。接着看下MultiDataSet的构造:

private static List<org.nd4j.linalg.dataset.api.MultiDataSet> getMultiDataIter(String filePath, DataSetInfo dataInfo) {List<org.nd4j.linalg.dataset.api.MultiDataSet> list = new LinkedList<>();try(BufferedReader br = Files.newReader(new File(filePath), Charset.forName("UTF-8"))){String line = null;int lineIndex = 0;while( (line = br.readLine()) != null ) {if( lineIndex == 0 ) {lineIndex++;continue;}String[] splits = line.split("\t");String query = splits[0];String doc = splits[1];String label = splits[2];if( query == null || query.isEmpty() ||doc == null || doc.isEmpty() || label == null)continue;//double[][] queryIndexArray = new double[1][dataInfo.queryMaxLen];double[][] docIndexArray = new double[1][dataInfo.docMaxLen];double[][] queryIndexMaskArray = new double[1][dataInfo.queryMaxLen];double[][] docIndexMaskArray = new double[1][dataInfo.docMaxLen];double[][] labelIndexArray = new double[1][1];//for( int i = 0; i < query.length(); ++i ) {queryIndexArray[0][i] = dataInfo.queryDict.getOrDefault(String.valueOf(query.charAt(i)),dataInfo.queryDict.get("UNK"));queryIndexMaskArray[0][i] = 1.0;}for( int i = 0; i < doc.length(); ++i ) {docIndexArray[0][i] = dataInfo.docDict.getOrDefault(String.valueOf(doc.charAt(i)),dataInfo.docDict.get("UNK"));docIndexMaskArray[0][i] = 1.0;}labelIndexArray[0][0] = Double.parseDouble(label);//org.nd4j.linalg.dataset.api.MultiDataSet mds = new MultiDataSet(new INDArray[] {Nd4j.create(queryIndexArray), Nd4j.create(docIndexArray)},new INDArray[] {Nd4j.create(labelIndexArray)},new INDArray[] {Nd4j.create(queryIndexMaskArray), Nd4j.create(docIndexMaskArray)},null);list.add(mds);}}catch(Exception ex) {ex.printStackTrace();}return list;
}

该部分逻辑主要是通过一个静态方法来读取训练文本中的每一行数据,并且针对text_a和text_b以及label转换成一个MultiDataSet对象,并存储在一个LinkedList对象中。需要注意的是Mask部分的处理。Mask张量中用1.0代表有效,0.0代表无效的部分。下面我们看下训练建模和评估的部分。

final int batchSize = 256;
final int embedding_size = 64;
DataSetInfo dataInfo = preprocess("data/lcqmc/train.tsv");
ComputationGraph dssm = getDSSM(dataInfo.queryDict.size(), dataInfo.docDict.size(), embedding_size);
System.out.println(dssm.summary());
List<org.nd4j.linalg.dataset.api.MultiDataSet> trainDataList = getMultiDataIter("data/lcqmc/train.tsv", dataInfo);
List<org.nd4j.linalg.dataset.api.MultiDataSet> testDataList = getMultiDataIter("data/lcqmc/test.tsv", dataInfo);
System.out.println("Finish Loading Train Data");
for(int epoch = 0; epoch < 5; ++epoch) {Collections.shuffle(trainDataList);MultiDataSetIterator trainIter = new IteratorMultiDataSetIterator(trainDataList.iterator(), batchSize);dssm.fit(trainIter);Evaluation eval = dssm.evaluate(new IteratorMultiDataSetIterator(testDataList.iterator(), batchSize));System.out.println(eval);
}

通过10个epoch的训练,我们最终在验证集上得到70%左右的准确率, loss值在0.4左右。
在这里插入图片描述

5. 总结

DSSM是一个经典的双塔模型,但其也有明显的缺点,就是两个塔之间是独立的,没有信息的交叉。这种信息的交叉对应推荐场景来说是很重要的。DSSM论文中的结构比较简单,是MLP为主,且输入层使用词袋模型进行处理,这其实忽略的上下文的语义信息,因此我们在实现的时候,使用LSTM模型来捕获序列的完整语义信息。当然,由于时间原因,我们这边并没有做分词处理,相信经过分词处理,在LCQMC数据集上的准确率可以进一步得到提升。另外,双塔的结构可以很灵活,内部可以直接上BERT来做,这里变体就太多,不做过多陈述了。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/312919.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

记录 Docker 中安装 ROS2

目录 1 安装 Docker 2 安装 ROS2 3 启动 Docker 4 测试 ROS2 环境 1 安装 Docker 1. 更新软件包sudo apt updatesudo apt upgrade2. 安装 docker 依赖sudo apt-get install ca-certificates curl gnupg lsb-release3. 添加 docker 官方 GPG 密钥curl -fsSL http://mirror…

Jetpack Compose中使用Android View

使用AndroidView创建日历 Composable fun AndroidViewPage() {AndroidView(factory {CalendarView(it)},modifier Modifier.fillMaxWidth(),update {it.setOnDateChangeListener { view, year, month, day ->Toast.makeText(view.context, "${year}年${month 1}月$…

Oracle开发经验总结

文章目录 1. 加注释2. 增加索引3. nvl(BOARDCODE&#xff0c;100)>004. 去掉distinct可以避免hash比较&#xff0c;提高性能5. like模糊查询优化(转化为instr()函数)6. SQL计算除数为0时&#xff0c;增加nullif判断7. 分页8. 查看执行计划9. <if test"productCode !…

【2023】hadoop基础介绍

&#x1f4bb;目录 Hadoop组成HDFSHDFS操作HDFS分布式文件存储NameNode元数据数据读写流程 YARN和MapReduceMapReduce&#xff1a;分布式计算YARN&#xff1a;资源管控调度YARN架构提交任务到**YARN中运行** Hadoop组成 hadoop安装教程可以看我这篇文章> &#x1f345;hado…

C/C++面向对象(OOP)编程-回调函数详解(回调函数、C/C++异步回调、函数指针)

本文主要介绍回调函数的使用&#xff0c;包括函数指针、异步回调编程、主要通过详细的例子来指导在异步编程和事件编程中如何使用回调函数来实现。 &#x1f3ac;个人简介&#xff1a;一个全栈工程师的升级之路&#xff01; &#x1f4cb;个人专栏&#xff1a;C/C精进之路 &…

JavaScript 基础通关

快速熟悉 JavaScript 的基础语法&#xff0c;比较高级的比如事件放在后面的笔记中。 JavaScript 1. JavaScript 介绍 1.1 JavaScript 基本介绍 JavaScript 是一门运行在客户端&#xff08;浏览器&#xff09;的编程语言&#xff0c;实现人机交互的效果。实现网页特效、表单验…

海康visionmaster-渲染结果:控件颜色:控件颜色修改的方法

描述 环境&#xff1a;VM4.0.0 VS2015 及以上 现象&#xff1a;简易修改 VM 控件的颜色&#xff1f; 解答 对二次开发中嵌入控件的颜色进行修改&#xff0c;具体代码如下&#xff1a; C# string colorinfo “ColorStyle3”; AppColorService.CurColorDefine colorinfo; “Co…

OpenWrt 编译入门(小白版)

编译环境 示例编译所用系统为 Ubuntu 22.04&#xff0c;信息如下 编译时由于网络问题&#xff0c;部分软件包可能出现下载问题&#xff0c;还请自备网络工具或尝试重新运行命令 编译步骤 下图为官网指示 编译环境设置&#xff08;Build system setup&#xff09; 这里根据我…

springboot实现用户操作日志记录

springboot实现用户操作日志记录 简介&#xff1a;之前写了《aop实现日志持久化记录》一文&#xff0c;主要介绍自定义aop标注方法上&#xff0c;通过切面方法对用户操作插入mysql。思路正确但是实际操作上存在一些小问题&#xff0c;本文将从项目出发&#xff0c;对细节进行补…

Nginx 代理静态资源,解决跨域问题

&#x1f602; 背景&#xff1a;移动端 H5 项目&#xff0c;依赖了一个外部的 JS 文件。访问时&#xff0c;出现跨域&#xff0c;导致请求被 block。 当前域名&#xff1a;https://tmcopss.test.com要访问的 JS 文件&#xff1a;https://tm.test.com/public/scripts/y-jssdk.j…

C++每日一练(8):图像相似度

题目描述 给出两幅相同大小的黑白图像&#xff08;用0-1矩阵&#xff09;表示&#xff0c;求它们的相似度。 说明&#xff1a;若两幅图像在相同位置上的像素点颜色相同&#xff0c;则称它们在该位置具有相同的像素点。两幅图像的相似度定义为相同像素点数占总像素点数的百分比。…

【华为机试】2023年真题B卷(python)-猴子爬山

一、题目 题目描述&#xff1a; 一天一只顽猴想去从山脚爬到山顶&#xff0c;途中经过一个有个N个台阶的阶梯&#xff0c;但是这猴子有一个习惯&#xff1a; 每一次只能跳1步或跳3步&#xff0c;试问猴子通过这个阶梯有多少种不同的跳跃方式&#xff1f; 二、输入输出 输入描述…