Pytorch复习笔记--导出Onnx模型为动态输入和静态输入

目录

1--动态输入和静态输入

2--Pytorch API

3--完整代码演示

4--模型可视化

5--测试动态导出的Onnx模型


1--动态输入和静态输入

        当使用 Pytorch 将网络导出为 Onnx 模型格式时,可以导出为动态输入和静态输入两种方式。动态输入即模型输入数据的部分维度是动态的,可以由用户在使用模型时自主设定;静态输入即模型输入数据的维度是静态的,不能够改变,当用户使用模型时只能输入指定维度的数据进行推理。

        显然,动态输入的通用性比静态输入更强。

2--Pytorch API

        在 Pytorch 中,通过 torch.onnx.export() 的 dynamic_axes 参数来指定动态输入和静态输入,dynamic_axes 的默认值为 None,即默认为静态输入。

        以下展示动态导出的用法,通过定义 dynamic_axes 参数来设置动态导出输入。dynamic_axes 中的 0、2、3 表示相应的维度设置为动态值;

# 导出为动态输入
input_name = 'input'
output_name = 'output'
torch.onnx.export(model, input_data, "Dynamics_InputNet.onnx",opset_version=11,input_names=[input_name],output_names=[output_name],dynamic_axes={input_name: {0: 'batch_size', 2: 'input_height', 3: 'input_width'},output_name: {0: 'batch_size', 2: 'output_height', 3: 'output_width'}})

3--完整代码演示

        在以下代码中,定义了一个网络,并使用动态导出和静态导出两种方式,将网络导出为 Onnx 模型格式。

import torch
import torch.nn as nnclass Model_Net(nn.Module):def __init__(self):super(Model_Net, self).__init__()self.layer1 = nn.Sequential(nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, stride=1, padding=1),nn.BatchNorm2d(64),nn.ReLU(inplace=True),nn.Conv2d(in_channels=64, out_channels=256, kernel_size=3, stride=1, padding=1),nn.BatchNorm2d(256),nn.ReLU(inplace=True),)def forward(self, data):data = self.layer1(data)return dataif __name__ == "__main__":# 设置输入参数Batch_size = 8Channel = 3Height = 256Width = 256input_data = torch.rand((Batch_size, Channel, Height, Width))# 实例化模型model = Model_Net()# 导出为静态输入input_name = 'input'output_name = 'output'torch.onnx.export(model, input_data, "Static_InputNet.onnx", verbose=True, input_names=[input_name], output_names=[output_name])# 导出为动态输入torch.onnx.export(model, input_data, "Dynamics_InputNet.onnx",opset_version=11,input_names=[input_name],output_names=[output_name],dynamic_axes={input_name: {0: 'batch_size', 2: 'input_height', 3: 'input_width'},output_name: {0: 'batch_size', 2: 'output_height', 3: 'output_width'}})

4--模型可视化

        通过 netron 库可视化导出的静态模型和动态模型,代码如下:

import netronnetron.start("./Dynamics_InputNet.onnx")

        静态模型可视化:

         动态模型可视化:

5--测试动态导出的Onnx模型

import numpy as np
import onnx
import onnxruntimeif __name__ == "__main__":input_data1 = np.random.rand(4, 3, 256, 256).astype(np.float32)input_data2 = np.random.rand(8, 3, 512, 512).astype(np.float32)# 导入 Onnx 模型Onnx_file = "./Dynamics_InputNet.onnx"Model = onnx.load(Onnx_file)onnx.checker.check_model(Model) # 验证Onnx模型是否准确# 使用 onnxruntime 推理model = onnxruntime.InferenceSession(Onnx_file, providers=['TensorrtExecutionProvider', 'CUDAExecutionProvider', 'CPUExecutionProvider'])input_name = model.get_inputs()[0].nameoutput_name = model.get_outputs()[0].nameoutput1 = model.run([output_name], {input_name:input_data1})output2 = model.run([output_name], {input_name:input_data2})print('output1.shape: ', np.squeeze(np.array(output1), 0).shape)print('output2.shape: ', np.squeeze(np.array(output2), 0).shape)

         由输出结果可知,对应动态输入 Onnx 模型,其输出维度也是动态的,并且为对应关系,则表明导出的 Onnx 模型无误。


本文来自互联网用户投稿,文章观点仅代表作者本人,不代表本站立场,不承担相关法律责任。如若转载,请注明出处。 如若内容造成侵权/违法违规/事实不符,请点击【内容举报】进行投诉反馈!

相关文章

立即
投稿

微信公众账号

微信扫一扫加关注

返回
顶部