【从零开始实现意图识别】中文对话意图识别详解

前言

意图识别(Intent Recognition)是自然语言处理(NLP)中的一个重要任务,它旨在确定用户输入的语句中所表达的意图或目的。简单来说,意图识别就是对用户的话语进行语义理解,以便更好地回答用户的问题或提供相关的服务。

在NLP中,意图识别通常被视为一个分类问题,即通过将输入语句分类到预定义的意图类别中来识别其意图。这些类别可以是各种不同的任务、查询、请求等,例如搜索、购买、咨询、命令等。

下面是一个简单的例子来说明意图识别的概念:

用户输入: "我想订一张从北京到上海的机票。

意图识别:预订机票。

在这个例子中,通过将用户输入的语句分类到“预订机票”这个意图类别中,系统可以理解用户的意图并为其提供相关的服务。

意图识别是NLP中的一项重要任务,它可以帮助我们更好地理解用户的需求和意图,从而为用户提供更加智能和高效的服务。

在智能对话任务中,意图识别是一种非常重要的技术,它可以帮助系统理解用户的输入,从而提供更加准确和个性化的回答和服务。

模型

意图识别和槽位填充是对话系统中的基础任务。本仓库实现了一个基于BERT的意图(intent)和槽位(slots)联合预测模块。想法上实际与JoinBERT类似(GitHub:BERT for Joint Intent Classification and Slot Filling),利用 [CLS] token对应的last hidden state去预测整句话的intent,并利用句子tokens的last hidden states做序列标注,找出包含slot values的tokens。你可以自定义自己的意图和槽位标签,并提供自己的数据,通过下述流程训练自己的模型,并在JointIntentSlotDetector类中加载训练好的模型直接进行意图和槽值预测。

源GitHub:https://github.com/Linear95/bert-intent-slot-detector

在本文使用的模型中对数据进行了扩充、对代码进行注释、对部分代码进行了修改

Bert模型下载

Bert模型下载地址:https://huggingface.co/bert-base-chinese/tree/main

下载下方红框内的模型即可。

数据集介绍

训练数据以json格式给出,每条数据包括三个关键词:text表示待检测的文本,intent代表文本的类别标签,slots是文本中包括的所有槽位以及对应的槽值,以字典形式给出。

{

"text": "搜索西红柿的做法。",

"domain": "cookbook",

"intent": "QUERY",

"slots": {"ingredient": "西红柿"}

}

原始数据集:https://conference.cipsc.org.cn/smp2019/

本项目中在原始数据集中新增了部分数据,用来平衡数据。

模型训练

python train.py

# -----------training-------------
max_acc = 0
for epoch in range(args.train_epochs):total_loss = 0model.train()for step, batch in enumerate(train_dataloader):input_ids, intent_labels, slot_labels = batchoutputs = model(input_ids=torch.tensor(input_ids).long().to(device),intent_labels=torch.tensor(intent_labels).long().to(device),slot_labels=torch.tensor(slot_labels).long().to(device))loss = outputs['loss']total_loss += loss.item()if args.gradient_accumulation_steps > 1:loss = loss / args.gradient_accumulation_stepsloss.backward()if step % args.gradient_accumulation_steps == 0:# 用于对梯度进行裁剪,以防止在神经网络训练过程中出现梯度爆炸的问题。torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)optimizer.step()scheduler.step()model.zero_grad()train_loss = total_loss / len(train_dataloader)dev_acc, intent_avg, slot_avg = dev(model, val_dataloader, device, slot_dict)flag = Falseif max_acc < dev_acc:max_acc = dev_accflag = Truesave_module(model, model_save_dir)print(f"[{epoch}/{args.train_epochs}] train loss: {train_loss}  dev intent_avg: {intent_avg} "f"def slot_avg: {slot_avg} save best model: {'*' if flag else ''}")dev_acc, intent_avg, slot_avg = dev(model, val_dataloader, device, slot_dict)
print("last model dev intent_avg: {} def slot_avg: {}".format(intent_avg, slot_avg))

运行过程:

模型推理

python predict.py
 
def detect(self, text, str_lower_case=True):"""text : list of string, each string is a utterance from user"""list_input = Trueif isinstance(text, str):text = [text]list_input = Falseif str_lower_case:text = [t.lower() for t in text]batch_size = len(text)inputs = self.tokenizer(text, padding=True)with torch.no_grad():outputs = self.model(input_ids=torch.tensor(inputs['input_ids']).long().to(self.device))intent_logits = outputs['intent_logits']slot_logits = outputs['slot_logits']intent_probs = torch.softmax(intent_logits, dim=-1).detach().cpu().numpy()slot_probs = torch.softmax(slot_logits, dim=-1).detach().cpu().numpy()slot_labels = self._predict_slot_labels(slot_probs)intent_labels = self._predict_intent_labels(intent_probs)slot_values = self._extract_slots_from_labels(inputs['input_ids'], slot_labels, inputs['attention_mask'])outputs = [{'text': text[i], 'intent': intent_labels[i], 'slots': slot_values[i]}for i in range(batch_size)]if not list_input:return outputs[0]return outputs

推理结果:

模型检测相关代码

将概率值转换为实际标注值

def _predict_slot_labels(self, slot_probs):"""slot_probs : probability of a batch of tokens into slot labels, [batch, seq_len, slot_label_num], numpy array"""slot_ids = np.argmax(slot_probs, axis=-1)return self.slot_dict[slot_ids.tolist()]def _predict_intent_labels(self, intent_probs):"""intent_labels : probability of a batch of intent ids into intent labels, [batch, intent_label_num], numpy array"""intent_ids = np.argmax(intent_probs, axis=-1)return self.intent_dict[intent_ids.tolist()]

槽位验证(确保检测结果的正确性)

def _extract_slots_from_labels_for_one_seq(self, input_ids, slot_labels, mask=None):results = {}unfinished_slots = {}  # dict of {slot_name: slot_value} pairsif mask is None:mask = [1 for _ in range(len(input_ids))]def add_new_slot_value(results, slot_name, slot_value):if slot_name == "" or slot_value == "":return resultsif slot_name in results:results[slot_name].append(slot_value)else:results[slot_name] = [slot_value]return resultsfor i, slot_label in enumerate(slot_labels):if mask[i] == 0:continue# 检测槽位的第一字符(B_)开头if slot_label[:2] == 'B_':slot_name = slot_label[2:]  # 槽位名称 (B_ 后面)if slot_name in unfinished_slots:results = add_new_slot_value(results, slot_name, unfinished_slots[slot_name])unfinished_slots[slot_name] = self.tokenizer.decode(input_ids[i])# 检测槽位的后面字符(I_)开头elif slot_label[:2] == 'I_':slot_name = slot_label[2:]if slot_name in unfinished_slots and len(unfinished_slots[slot_name]) > 0:unfinished_slots[slot_name] += self.tokenizer.decode(input_ids[i])for slot_name, slot_value in unfinished_slots.items():if len(slot_value) > 0:results = add_new_slot_value(results, slot_name, slot_value)return results

源码获取

NLP/bert-intent-slot at main · mzc421/NLP (github.com)icon-default.png?t=N7T8https://github.com/mzc421/NLP/tree/main/bert-intent-slot

链接作者

欢迎关注我的公众号:@AI算法与电子竞赛

硬性的标准其实限制不了无限可能的我们,所以啊!少年们加油吧!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/212046.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【RtpRtcp】1: webrtc m79:audio的ChannelReceive 创建并使用

m79中,RtpRtcp::Create 的调用很少 不知道谁负责创建ChannelReceiveclass ChannelReceive : public ChannelReceiveInterface,public MediaTransportAudioSinkInterface {接收编码后的音频帧:接收rtcp包:

VTK表面画贴合线条避免陷入局部最小值

通常画线&#xff0c;由于距离太远会陷入局部最小值&#xff0c;通过算法解决这一问题。 1、陷入局部最小值 2、算法解决陷入局部最小值问题

大模型增量预训练参数说明

在增量预训练过程中通常需要设置三类或四类参数,模型参数,数据参数,训练参数,额外参数。 下面分别针对这四种参数进行说明。 欢迎关注公众号 模型参数 model_type模型类型,例如bloom,llama,baichuan,qwen等。 model_name_or_path模型名称或者路径。 tokenizer_name_or…

奥特曼重返CEO之位!AI发展成硅谷巨头与保守科学家权力之争

OpenAI&#xff0c;一家因ChatGPT而闻名的人工智能公司&#xff0c;近日陷入了一场激烈的权力之争。创始人兼CEO山姆奥特曼&#xff08;Sam Altman&#xff09;突然离职&#xff0c;引发了一系列连锁反应。然而&#xff0c;在经过一番波折和谈判后&#xff0c;奥特曼最终重返Op…

Python入门02 算术运算符及优先级

目录 1 REPL2 启动3 算术运算符4 算术运算符的优先级5 清除屏幕总结 上一节我们安装了Python的开发环境&#xff0c;本节我们介绍一下REPL的概念 1 REPL 首先解释一下python执行代码的一个交互环境的定义&#xff1a; Python REPL&#xff08;Read-Eval-Print Loop&#xff0c…

Springboot整合MybatisPlus及分页功能

1 引入pom <dependency><groupId>org.springframework.boot</groupId><artifactId>spring-boot</artifactId><version>2.7.14</version> </dependency> <dependency><groupId>com.baomidou</groupId><a…

AQS源码解析

AQS源码解析 文章目录 AQS源码解析一、AQS二、共享资源 state三、FIFO 阻塞队列四、独占模式 acquire 获取资源五、独占模式 release 释放资源六、共享模式 acquireShared 获取资源七、共享模式 releaseShared 释放资源八、总结 一、AQS AQS 是 AbstractQueuedSynchronizer 的…

微型计算机原理MOOC题

一、8254 1.掉坑了&#xff0c;AL传到端口不意味着一定传到的是低位&#xff0c;要看控制字D5和D4&#xff0c;10是只写高位&#xff0c;所以是0A00.。。 2. 3. 4.待解决&#xff1a;

【腾讯云TDSQL-C Serverless产品体验】与云函数一起来一次无服务器体验

写在前面&#xff1a;博主是一只经过实战开发历练后投身培训事业的“小山猪”&#xff0c;昵称取自动画片《狮子王》中的“彭彭”&#xff0c;总是以乐观、积极的心态对待周边的事物。本人的技术路线从Java全栈工程师一路奔向大数据开发、数据挖掘领域&#xff0c;如今终有小成…

UML建模图文详解教程06——顺序图

版权声明 本文原创作者&#xff1a;谷哥的小弟作者博客地址&#xff1a;http://blog.csdn.net/lfdfhl本文参考资料&#xff1a;《UML面向对象分析、建模与设计&#xff08;第2版&#xff09;》吕云翔&#xff0c;赵天宇 著 顺序图概述 顺序图(sequence diagram&#xff0c;也…

网络图简单计算规则

单代号进度网络图&#xff08;节点法&#xff09; 概念 计算规则 &#xff08;顺时针计算法&#xff09; &#xff08;TF取之差&#xff09; &#xff08;T&#xff1a;持续时间&#xff09; ES → EF (ES取大EF加T) ↑ T ↑ &#xff08;TF&#xff1a;总时差&…

Leetcode算法系列| 1. 两数之和(四种解法)

目录 1.题目2.题解解法一&#xff1a;暴力枚举解法二&#xff1a;哈希表解法解法三&#xff1a;双指针(有序状态)解法四&#xff1a;二分查找(有序状态) 1.题目 给定一个整数数组 nums 和一个整数目标值 target&#xff0c;请你在该数组中找出 和为目标值 target 的那 两个 整数…