# Method 1: prepend  y = input_image * 2.0 + (-1.0)  to the model, supplying
# the scale and offset through Constant nodes inside the graph
# (e.g. this maps inputs in [0, 1] to [-1, 1]).
import numpy as np
import onnx
import onnxruntime
import onnxruntime.backend as backend

INPUT_NAME = "input_image"  # graph input that receives the preprocessing
# Names of the tensors this script adds. Prefixed names are used instead of
# bare "1"/"2" because numeric tensor names are common in exported graphs,
# and reusing an existing name would break ONNX's single-assignment rule.
SCALE_NAME = "preprocess_scale"
SCALED_NAME = "preprocess_scaled"
OFFSET_NAME = "preprocess_offset"
OUTPUT_NAME = "preprocess_output"

model = onnx.load("test.onnx")
graph = model.graph

# Refuse to shadow any tensor that already exists in the graph.
used_names = {value.name for value in graph.input}
used_names.update(value.name for value in graph.output)
used_names.update(init.name for init in graph.initializer)
used_names.update(out for existing in graph.node for out in existing.output)
clashes = used_names & {SCALE_NAME, SCALED_NAME, OFFSET_NAME, OUTPUT_NAME}
if clashes:
    raise ValueError(f"tensor name(s) already used in the graph: {sorted(clashes)}")

# Point every existing consumer of the raw input at the preprocessed tensor.
# This must run before the new Mul node (which reads the raw input) is added.
rewired = 0
for existing in graph.node:
    for idx, name in enumerate(existing.input):
        if name == INPUT_NAME:
            existing.input[idx] = OUTPUT_NAME
            rewired += 1
if not rewired:
    raise ValueError(f"no node consumes graph input {INPUT_NAME!r}")

scale_const = onnx.helper.make_node(
    "Constant",
    inputs=[],
    outputs=[SCALE_NAME],
    value=onnx.helper.make_tensor("value", onnx.TensorProto.FLOAT, [], [2.0]),
)
mul_node = onnx.helper.make_node(
    "Mul", inputs=[INPUT_NAME, SCALE_NAME], outputs=[SCALED_NAME]
)
offset_const = onnx.helper.make_node(
    "Constant",
    inputs=[],
    outputs=[OFFSET_NAME],
    value=onnx.helper.make_tensor("value", onnx.TensorProto.FLOAT, [], [-1.0]),
)
add_node = onnx.helper.make_node(
    "Add", inputs=[SCALED_NAME, OFFSET_NAME], outputs=[OUTPUT_NAME]
)

# ONNX requires nodes in topological order, so the new nodes go to the front,
# ahead of every (already rewired) original node.
for position, new_node in enumerate((scale_const, mul_node, offset_const, add_node)):
    graph.node.insert(position, new_node)

onnx.checker.check_model(model)
onnx.save(model, "out.onnx")

print(onnxruntime.get_device())
rt = backend.prepare(model, "CPU")
# Smoke-test with an all-ones input. This assumes the model takes float32 of
# shape (1, 1, 128, 128) -- TODO confirm against the model's input signature.
out = rt.run(np.ones([1, 1, 128, 128], dtype=np.float32))
print(out)
# 第二种：使用可供训练的初始化参数 (Method 2: use trainable initializer parameters)
# Method 2: prepend  y = input_image * 2.0 + (-1.0)  to the model, storing the
# scale and offset as graph initializers (which training frameworks can treat
# as learnable parameters) instead of Constant nodes.
import numpy as np
import onnx
import onnxruntime
import onnxruntime.backend as backend

INPUT_NAME = "input_image"  # graph input that receives the preprocessing
# Names of the tensors this script adds. Prefixed names are used instead of
# bare "1"/"2" because numeric tensor names are common in exported graphs,
# and reusing an existing name would break ONNX's single-assignment rule.
SCALE_NAME = "preprocess_scale"
SCALED_NAME = "preprocess_scaled"
OFFSET_NAME = "preprocess_offset"
OUTPUT_NAME = "preprocess_output"

model = onnx.load("test.onnx")
graph = model.graph

# Refuse to shadow any tensor that already exists in the graph.
used_names = {value.name for value in graph.input}
used_names.update(value.name for value in graph.output)
used_names.update(init.name for init in graph.initializer)
used_names.update(out for existing in graph.node for out in existing.output)
clashes = used_names & {SCALE_NAME, SCALED_NAME, OFFSET_NAME, OUTPUT_NAME}
if clashes:
    raise ValueError(f"tensor name(s) already used in the graph: {sorted(clashes)}")

# Point every existing consumer of the raw input at the preprocessed tensor.
# This must run before the new Mul node (which reads the raw input) is added.
rewired = 0
for existing in graph.node:
    for idx, name in enumerate(existing.input):
        if name == INPUT_NAME:
            existing.input[idx] = OUTPUT_NAME
            rewired += 1
if not rewired:
    raise ValueError(f"no node consumes graph input {INPUT_NAME!r}")

# One-element float32 tensors; they broadcast against the image in Mul/Add.
# NOTE: initializers that are not also listed in graph.input require ONNX
# IR version >= 4; older models would need matching graph inputs as well.
scale_init = onnx.helper.make_tensor(
    name=SCALE_NAME,
    data_type=onnx.TensorProto.FLOAT,
    dims=[1],
    vals=np.array([2.0], dtype=np.float32),
)
offset_init = onnx.helper.make_tensor(
    name=OFFSET_NAME,
    data_type=onnx.TensorProto.FLOAT,
    dims=[1],
    vals=np.array([-1.0], dtype=np.float32),
)
graph.initializer.extend([scale_init, offset_init])

mul_node = onnx.helper.make_node(
    "Mul", inputs=[INPUT_NAME, SCALE_NAME], outputs=[SCALED_NAME]
)
add_node = onnx.helper.make_node(
    "Add", inputs=[SCALED_NAME, OFFSET_NAME], outputs=[OUTPUT_NAME]
)

# ONNX requires nodes in topological order, so the new nodes go to the front,
# ahead of every (already rewired) original node.
for position, new_node in enumerate((mul_node, add_node)):
    graph.node.insert(position, new_node)

onnx.checker.check_model(model)
onnx.save(model, "out.onnx")

print(onnxruntime.get_device())
rt = backend.prepare(model, "CPU")
# Smoke-test with an all-ones input. This assumes the model takes float32 of
# shape (1, 1, 128, 128) -- TODO confirm against the model's input signature.
out = rt.run(np.ones([1, 1, 128, 128], dtype=np.float32))
print(out)