深度学习的onnx模型插入新节点构建新模型

在这里插入图片描述

import numpy as np
import onnx
import onnxruntime
import onnxruntime.backend as backendmodel = onnx.load('test.onnx')
node = model.graph.node
graph = model.graph# 1.2搜索目标节点
# for i in range(len(node)):
#     if node[i].op_type == 'Conv':
#         node_rise = node[i]
#         if node_rise.output[0] == '203':
#             print(i)
# print(node[159])new_node_0 = onnx.helper.make_node("Mul",inputs=["input_image","1"],outputs=["mutiply"],
)mutiply_node = onnx.helper.make_node("Constant",inputs=[],outputs=["1"],value=onnx.helper.make_tensor('value', onnx.TensorProto.FLOAT, [], [2.0])
)new_node_1 = onnx.helper.make_node("Add",inputs=["mutiply","2"],outputs=["add"],
)add_node = onnx.helper.make_node("Constant",inputs=[],outputs=["2"],value=onnx.helper.make_tensor('value', onnx.TensorProto.FLOAT, [], [-1.0])
)#删除老节点 
old_squeeze_node = model.graph.node[0]
old_squeeze_node.input[0] = "add"
model.graph.node.remove(old_squeeze_node)graph.node.insert(0, mutiply_node)
graph.node.insert(1, new_node_0)
graph.node.insert(2, add_node)
graph.node.insert(3, new_node_1)
graph.node.insert(4, old_squeeze_node)
onnx.checker.check_model(model)
onnx.save(model, 'out.onnx')# session = onnxruntime.InferenceSession("out.onnx", providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
# out = session.run([session.get_outputs()[0].name], {session.get_inputs()[0].name: np.ones([1, 1, 128, 128], dtype=np.float32)})[0]
# print(out)print(onnxruntime.get_device())
rt = backend.prepare(model, "CPU")
out = rt.run(np.ones([1, 1, 128, 128], dtype=np.float32))
print(out)

在这里插入图片描述

第二种使用可供训练的初始化参数

import numpy as np
import onnx
import onnxruntime
import onnxruntime.backend as backendmodel = onnx.load('test.onnx')
node = model.graph.node
graph = model.graph# 1.2搜索目标节点
# for i in range(len(node)):
#     if node[i].op_type == 'Conv':
#         node_rise = node[i]
#         if node_rise.output[0] == '203':
#             print(i)
# print(node[159])mutiply_node = onnx.helper.make_tensor(name='1',data_type=onnx.TensorProto.FLOAT,dims= [1],vals = np.array([2.0], dtype=np.float32))graph.initializer.append(mutiply_node)new_node_0 = onnx.helper.make_node("Mul",inputs=["input_image","1"],outputs=["mutiply"],
)add_node = onnx.helper.make_tensor(name='2',data_type=onnx.TensorProto.FLOAT,dims= [1],vals = np.array([-1.], dtype=np.float32))graph.initializer.append(add_node)new_node_1 = onnx.helper.make_node("Add",inputs=["mutiply","2"],outputs=["add"],
)#删除老节点 
old_squeeze_node = model.graph.node[0]
old_squeeze_node.input[0] = "add"
model.graph.node.remove(old_squeeze_node)graph.node.insert(0, new_node_0)
graph.node.insert(1, new_node_1)
graph.node.insert(2, old_squeeze_node)
onnx.checker.check_model(model)
onnx.save(model, 'out.onnx')# session = onnxruntime.InferenceSession("out.onnx", providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
# out = session.run([session.get_outputs()[0].name], {session.get_inputs()[0].name: np.ones([1, 1, 128, 128], dtype=np.float32)})[0]
# print(out)print(onnxruntime.get_device())
rt = backend.prepare(model, "CPU")
out = rt.run(np.ones([1, 1, 128, 128], dtype=np.float32))
print(out)

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/504813.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

JAVA面向对象高级部分—多态

面向对象高级部分—多态 认识多态 对象多态,对象既可以指向老师对象,也可以指向学生对象。 注意事项: 成员变量不谈多态,编译看左边,运行看左边 成员变量编译的是父类People,所以编译的是左边的People&a…

STM32------分析GPIO寄存器

一、初始LED原理图 共阴极led LED发光二极管,需要有电流通过才能点亮,当有电压差就会产生电流 二极管两端的电压差超过2.7v就会有电流通过 电阻的作用 由于公式IV/R 不加电阻容易造成瞬间电流无穷大 发光二极管工作电流为10-20MA 3.3v / 1kΩ 3.…

【风格迁移】URST:解决超高分辨率图像的风格迁移问题

URST:解决超高分辨率图像的风格迁移问题 提出背景URST框架的整体架构 提出背景 论文:https://arxiv.org/pdf/2103.11784.pdf 代码:https://github.com/czczup/URST?v1 有一张高分辨率的风景照片,分辨率为1000010000像素&#…

枚举类、泛型、API

枚举类 枚举类可以实现单例设计模式。 枚举的常见应用场景:用来表示一组信息,然后作为参数进行传输。 泛型 API

Benchmark学习笔记

小记一篇Benchmark的学习笔记 1.什么是benchmark 在维基百科中,是这样子讲的 “As computer architecture advanced, it became more difficult to compare the performance of various computer systems simply by looking at their specifications.Therefore, te…

Onlyfans怎么绑定虚拟卡订阅,视频图文教学!!!

前言 onlyfans软件是一个创立于2016年的订阅式社交媒体平台,创作者可以在自己的账号发布原创的照片或视频,并需要注意的是,网络上可能存在非法或不道德的应用将其设置成付费模式,若用户想查看则需要每月交费订阅。 图文视频教学&a…

steam++加速问题:出现显示443端口被 vmware-hostd(9860)占用的错误。

目录 前言: 正文: 前言: 使用Steam对GitHub进行加速处理时,建议使用2.8.6版本。 下载地址如下:Release 2.8.6 BeyondDimension/SteamTools GitHub 下载时注意自己的系统位数 正文: 使用GitHub时会使…

RocketMQ学习笔记一

课程来源:002-MQ简介_哔哩哔哩_bilibili (尚硅谷老雷,时长19h) 第1章 RocketMQ概述 1. MQ是什么? 2. MQ用途有哪些? 限流削峰;异步解耦;数据收集。 3. 常见MQ产品有哪些&对比…

如何根据PalWorldSettings.ini重新生成定制的WorldOption.sav文件?

这个过程涉及到将PalWorldSettings.ini 文件中的设置与WorldOption.sav 文件进行匹配和替换。具体的操作步骤可能包括检查PalWorldSettings.ini 文件中的设置是否与WorldOption.sav 文件中的设置相匹配,然后根据这些设置重新生成或修改WorldOption.sav 文件&#xf…

还在用Jenkins?快来试试这款简而轻的自动部署软件!

最近发现了一个比 Jenkins 使用更简单的项目构建和部署工具,完全可以满足个人以及一些小企业的需求,分享一下。 Jpom 是一款 Java 开发的简单轻量的低侵入式在线构建、自动部署、日常运维、项目监控软件。 日常开发中,Jpom 可以解决下面这些…

深入了解Java虚拟机(JVM)

Java虚拟机(JVM)是Java程序运行的核心组件,它负责解释执行Java字节码,并在各种平台上执行。JVM的设计使得Java具有跨平台性,开发人员只需编写一次代码,就可以在任何支持Java的系统上运行。我们刚开始学习Ja…

简易内存池2 - 华为OD统一考试(C卷)

OD统一考试(C卷) 分值: 200分 题解: Java / Python / C 题目描述 请实现一个简易内存池,根据请求命令完成内存分配和释放。 内存池支持两种操作命令,REQUEST和RELEASE,其格式为: REQUEST请求的内存大小 …