在深度学习领域,PyTorch因其灵活性和用户友好的特性而广受欢迎。然而,在生产环境中,我们可能需要将PyTorch模型转换为其他格式,如ONNX或Caffe,以便利用其他框架的优势,比如更高效的推理速度。本文将介绍如何将PyTorch模型转换为ONNX格式,并进一步转换为Caffe模型,同时指出转换过程中需要注意的关键点。

1. PyTorch转ONNX

ONNX(开放神经网络交换格式)是一个开放的生态系统,允许AI开发者在不同的框架、工具之间移植模型。首先,我们需要安装必要的库:

pip install torch torchvision onnx

示例:转换一个简单的PyTorch模型

假设我们有一个简单的PyTorch模型:

import torch
import torch.nn as nn

class SimpleModel(nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(20, 50, 5)
        self.fc1 = nn.Linear(50 * 4 * 4, 500)
        self.fc2 = nn.Linear(500, 10)

    def forward(self, x):
        x = self.pool(torch.relu(self.conv1(x)))
        x = self.pool(torch.relu(self.conv2(x)))
        x = x.view(-1, 50 * 4 * 4)
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x

为了转换这个模型,我们需要定义输入张量,并使用torch.onnx.export函数:

model = SimpleModel()
model.eval()

# 定义一个输入张量
x = torch.randn(1, 1, 28, 28)

# 导出到ONNX
torch.onnx.export(model,               # 运行的模型
                  x,                   # 模型输入 (或一个元组对于多个输入)
                  "simple_model.onnx", # 保存模型的路径
                  export_params=True,  # 带有模型参数的存储
                  opset_version=10,    # ONNX版本
                  do_constant_folding=True,  # 是否执行常量折叠优化
                  input_names = ['input'],   # 输入名
                  output_names = ['output'], # 输出名
                  dynamic_axes={'input' : {0 : 'batch_size'},    # 批量大小动态轴
                                'output' : {0 : 'batch_size'}})

需要注意的关键点

  • 确保模型处于eval模式,以禁用BatchNorm和Dropout层中的训练特定行为。
  • 定义的输入张量x应该与实际模型输入的形状相匹配。
  • opset_version:根据模型使用的操作选择合适的版本。较新的版本支持更多的操作,但可能不被所有框架支持。

2. 随着工具的不断完善,大多数情况下只需要经过上面的步骤即可完成模型转换需求。但是,有时候有一些特殊的层onnx并不支持,或者是要根据部署需求对模型调整,以下通过一个简单的示例进行展示。

2.1 对已经转换完成的onnx模型进行调整,比如需要删除onnx结构中的某些节点,然后将剩余的节点重新连接。

import onnx

onnx_model = onnx.load('model.onnx')
graph = onnx_model.graph
nodes = graph.node
input = graph.input
output = graph.output

print('input:\n', input)
print('output:\n', output)
print('----------------------------')
print('nodes[0]:\n', nodes[0])
print('nodes[1]:\n', nodes[1].input)

graph.node.remove(nodes[0])
graph.node.remove(nodes[0])
graph.node.remove(nodes[0])
graph.node.remove(nodes[0]) # 重复四次是因为每次删除后节点都会重新编号排序
graph.node[0].input[0] = 'input' # 将剩余节点和input连接起来

onnx.checker.check_graph(graph)
onnx.checker.check_model(onnx_model)
onnx.save(onnx_model, 'modify.onnx')

2.2 在转换之前变换模型层。比如我之前碰到一个在forward中对最后的输出作x*std+mean操作的模型,这样导出的onnx就会多出乘法和加法节点,为了省去这两个节点,可以使用卷积实现此乘加操作;还有想要将图像超分辨率模型输出的RGB图像转换为YUV400格式,同样可以使用增加卷积的方式实现。

class ma_conv(nn.Module):  # 卷积实现乘加操作
    def __init__(self, c, mean, std):
        super().__init__()
        self.conv = nn.Conv2d(c, c, 1, 1, 0, 1, c, bias=True)
        print(self.conv.bias.data, self.conv.weight.data)
        self.conv.weight.data[:] = nn.Parameter(std.permute(1, 0, 2, 3))
        self.conv.bias.data[:] = nn.Parameter(mean.squeeze())
        print(self.conv.bias.data, self.conv.weight.data)

    def forward(self, x):
        x = self.conv(x)
        return x

3. ONNX转Caffe

转换为ONNX格式后,我们可以使用各种工具将ONNX模型转换为其他格式,如Caffe。这一步通常需要使用专门的转换工具或库。以ONNX to Caffe2为例,虽然Caffe2已经与PyTorch合并,但转换工具仍然可以用于演示转换过程。

示例:使用onnx-caffe2转换

首先,确保安装了onnx-caffe2

pip install onnx-caffe2

然后,使用以下代码进行转换:

import onnx
from onnx_caffe2.backend import Caffe2Backend

# 加载ONNX模型
model = onnx.load("simple_model.onnx")

# 转换
caffe2_model = Caffe2Backend.prepare(model)

# 保存转换后的模型
with open("simple_model.pb", "wb") as f:
    f.write(caffe2_model.SerializeToString())

需要注意的关键点

  • 在转换过程中,确保ONNX模型使用的操作在目标框架中有对应的实现。
  • 转换工具可能不支持所有ONNX操作,需要检查转换后的模型是否完整。
  • 在转换大型或复杂模型时,可能需要根据目标框架进行手动调整或优化。

Caffe框架目前基本已经退出历史舞台,它的安装等也非常复杂,很多时候都会卡在环境配置这一步,可以使用一个不需要安装Caffe的PytorchToCaffe开源库作模型转换,直接将pytorch模型转换为caffe。

最后修改:2024 年 03 月 13 日
如果觉得我的文章对你有用,请随意赞赏