bert-NER 转化成 onnx 模型

保存模型

加载模型

from transformers import AutoTokenizer, AutoModel, AutoConfigNER_MODEL_PATH = './save_model'
ner_tokenizer = AutoTokenizer.from_pretrained(NER_MODEL_PATH)
ner_config = AutoConfig.from_pretrained(NER_MODEL_PATH)
ner_model = AutoModelForTokenClassification.from_pretrained(NER_MODEL_PATH)
ner_model.eval()

测试ner效果

在这里插入图片描述

测试速度

在这里插入图片描述

导出到onnx

# !pip install onnx onnxruntime -i https://pypi.tuna.tsinghua.edu.cn/simple/# 导出 onnx 模型
import onnxruntime
from itertools import chain
from transformers.onnx.features import FeaturesManagerconfig = ner_config
tokenizer = ner_tokenizer
model = ner_model
output_onnx_path = "bert-ner.onnx"onnx_config = FeaturesManager._SUPPORTED_MODEL_TYPE['bert']['sequence-classification'](config)
dummy_inputs = onnx_config.generate_dummy_inputs(tokenizer, framework='pt')torch.onnx.export(model,(dummy_inputs,),f=output_onnx_path,input_names=list(onnx_config.inputs.keys()),output_names=list(onnx_config.outputs.keys()),dynamic_axes={name: axes for name, axes in chain(onnx_config.inputs.items(), onnx_config.outputs.items())},do_constant_folding=True,use_external_data_format=onnx_config.use_external_data_format(model.num_parameters()),enable_onnx_checker=True,opset_version=onnx_config.default_onnx_opset,
)

加载ONNX模型

自定义pipeline

from onnxruntime import SessionOptions, GraphOptimizationLevel, InferenceSessionclass PipeLineOnnx:def __init__(self, tokenizer, onnx_path, config):self.tokenizer = tokenizerself.config = config  # label2id, id2labeloptions = SessionOptions() # initialize session optionsoptions.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL# 设置线程数
#         options.intra_op_num_threads = 4# 这里的路径传上一节保存的onnx模型地址self.session = InferenceSession(onnx_path, sess_options=options, providers=["CPUExecutionProvider"])# disable session.run() fallback mechanism, it prevents for a reset of the execution providerself.session.disable_fallback() def __call__(self, text):inputs = self.tokenizer(text, padding=True, truncation=True, return_tensors='pt')ids = inputs["input_ids"]inputs_offset = self.tokenizer.encode_plus(text, return_offsets_mapping=True).offset_mappinginputs_detach = {k: v.detach().cpu().numpy() for k, v in inputs.items()}# 运行 ONNX 模型# 这里的logits要有export的时候output_names相对应output = self.session.run(output_names=['logits'], input_feed=inputs_detach)[0]logits = torch.tensor(output)num_labels = len(self.config.label2id)active_logits = logits.view(-1, num_labels) # shape (batch_size * seq_len, num_labels)softmax = torch.softmax(active_logits, axis=1)scores = torch.max(softmax, axis=1).values.cpu().detach().numpy()flattened_predictions = torch.argmax(active_logits, axis=1) # shape (batch_size*seq_len,) - predictions at the token leveltokens = self.tokenizer.convert_ids_to_tokens(ids.squeeze().tolist())token_predictions = [self.config.id2label[i] for i in flattened_predictions.cpu().numpy()]wp_preds = list(zip(tokens, token_predictions)) # list of tuples. Each tuple = (wordpiece, prediction)ner_result = [{"index": idx, "word":i,"entity":j, "start": k[0], "end": k[1], "score": s} for idx, (i,j,k,s) in enumerate(zip(tokens, token_predictions, inputs_offset, scores)) if j != 'O']return post_process(ner_result)def allow_merge(a, b):a_flag, a_type = a.split('-')b_flag, b_type = b.split('-')if b_flag == 'B' or a_flag == 'E':return Falseif a_type != b_type:return Falseif (a_flag, b_flag) in [("B", "I"),("B", "E"),("I", "I"),("I", "E")]:return Truereturn Falsedef divide_entities(ner_results):divided_entities = []current_entity = []for item in sorted(ner_results, key=lambda x: x['index']):if not current_entity:current_entity.append(item)elif allow_merge(current_entity[-1]['entity'], item['entity']):current_entity.append(item)else:divided_entities.append(current_entity)current_entity = [item]divided_entities.append(current_entity)return divided_entitiesdef merge_entities(same_entities):def avg(scores):return sum(scores)/len(scores)return {'entity': same_entities[0]['entity'].split("-")[1],'score': avg([e['score'] for e in same_entities]),'word': ''.join(e['word'].replace('##', '') for e in same_entities),'start': same_entities[0]['start'],'end': same_entities[-1]['end']}def post_process(ner_results):return [merge_entities(i) for i in divide_entities(ner_results)]

加载模型

from transformers import AutoTokenizer, AutoConfigNER_MODEL_PATH = './save_model'
ner_tokenizer = AutoTokenizer.from_pretrained(NER_MODEL_PATH)
ner_config = AutoConfig.from_pretrained(NER_MODEL_PATH)pipe2 = PipeLineOnnx(ner_tokenizer, "bert-ner.onnx", config=ner_config)

测试效果

在这里插入图片描述

测试速度

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/682145.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

5月9(信息差)

🌍 可再生能源发电量首次占全球电力供应的三成 🎄马斯克脑机接口公司 Neuralink 计划将 Link 功能扩展至现实世界,实现控制机械臂、轮椅等 马斯克脑机接口公司 Neuralink 计划将 Link 功能扩展至现实世界,实现控制机械臂、轮椅等…

python代码无法点击进入,如何破???

python代码无法点击进入,如何破??? 举个栗子: model.chat是无法进入的,这时可以使用如下的命令进行操作: ?model.chat

谷歌CEO最新访谈:AI浪潮仍处于早期阶段,公司未来最大威胁是执行力不足

作为搜索领域无可争议的霸主,谷歌改变了我们生活的方方面面,从日常琐事到工作事务,再到我们的沟通方式。多年来,谷歌一直是互联网的窗口,为我们提供大量知识和信息,但如今,随着其他类似平台的崛…

python面向函数

组织好的,可重复利用的,用来实现单一,或相关联功能的代码段,避免重复造轮子,增加程序复用性。 定义方法为def 函数名 (参数) 参数可动态传参,即使用*args代表元组形式**kwargs代表字典形式,代替…

使用 SSH 连接 GitHub Action 服务器

前言 Github Actions 是 GitHub 推出的持续集成 (Continuous integration,简称 CI) 服务它提供了整套虚拟服务器环境,基于它可以进行构建、测试、打包、部署项目,如果你的项目是开源项目,可以不限时使用服务器硬件规格&#xff1…

QT--4

QT 使用定时器完成闹钟 #include "widget.h" #include "ui_widget.h"void Widget::timestart() {timer.start(1000); }void Widget::timeend() {timer.stop(); }Widget::Widget(QWidget *parent): QWidget(parent), ui(new Ui::Widget) {ui->setupUi(t…

ntfs文件系统的优势 NTFS文件系统的特性有哪些 ntfs和fat32有什么区别 苹果电脑怎么管理硬盘

对于数码科技宅在新购得磁盘之后,出于某种原因会在新的磁盘安装操作系统。在安装操作系统时,首先要对磁盘进行分区和格式化,而在此过程中,操作者们需要选择文件系统。文件系统也决定了之后操作的流程程度,一般文件系统…

10大排序方法,其中这里只介绍前7种(第4种C语言,其它C++语言)

排序方法有十种,分别是:一、冒泡排序;二、选择排序;三、插入排序;四、希尔排序;五、归并排序;六、快速排序;七、堆排序;八、计数排序;九、桶排序;…

ICode国际青少年编程竞赛- Python-2级训练场-坐标与列表练习

ICode国际青少年编程竞赛- Python-2级训练场-坐标与列表练习 1、 for i in range(6):Spaceship.step(Item[i].x - Spaceship.x)Dev.step(Item[i].y - Dev.y)Dev.step(Spaceship.y - Dev.y)2、 for i in range(5):Spaceship.step(Item[i].x - Spaceship.x)Flyer[i].step(Item[…

气象多要素百叶箱

气象多要素百叶箱(485型) 该一体式百叶箱可广泛适用于环境监测,即噪声采集、PM2.5和PM10、温湿度、大气压力、光照于一体,设备采用标准MODBUS-RTU通信协议,RS485信号输出,通信距离最大可达2000米&#xff0…

Cmake编译源代码生成库文件以及使用

在项目实战中,通过模块化设计能够使整个工程更加简洁明了。简单的示例如下: 1、项目结构 project_folder/├── CMakeLists.txt├── src/│ ├── my_library.cpp│ └── my_library.h└── app/└── main.cpp2、CMakeList文件 # CMake …

【go项目01_学习记录08】

学习记录 1 模板文件1.1 articlesStoreHandler() 使用模板文件1.2 统一模板 1 模板文件 重构 articlesCreateHandler() 和 articlesStoreHandler() 函数,将 HTML 抽离并放置于独立的模板文件中。 1.1 articlesStoreHandler() 使用模板文件 . . . func articlesSt…