在深度学习领域,PyTorch因其灵活性和用户友好的特性而广受欢迎。然而,在生产环境中,我们可能需要将PyTorch模型转换为其他格式,如ONNX或Caffe,以便利用其他框架的优势,比如更高效的推理速度。本文将介绍如何将PyTorch模型转换为ONNX格式,并进一步转换为Caffe模型,同时指出转换过程中需要注意的关键点。
1. PyTorch转ONNX
ONNX(开放神经网络交换格式)是一个开放的生态系统,允许AI开发者在不同的框架、工具之间移植模型。首先,我们需要安装必要的库:
pip install torch torchvision onnx
示例:转换一个简单的PyTorch模型
假设我们有一个简单的PyTorch模型:
import torch
import torch.nn as nn
class SimpleModel(nn.Module):
def __init__(self):
super(SimpleModel, self).__init__()
self.conv1 = nn.Conv2d(1, 20, 5)
self.pool = nn.MaxPool2d(2, 2)
self.conv2 = nn.Conv2d(20, 50, 5)
self.fc1 = nn.Linear(50 * 4 * 4, 500)
self.fc2 = nn.Linear(500, 10)
def forward(self, x):
x = self.pool(torch.relu(self.conv1(x)))
x = self.pool(torch.relu(self.conv2(x)))
x = x.view(-1, 50 * 4 * 4)
x = torch.relu(self.fc1(x))
x = self.fc2(x)
return x
为了转换这个模型,我们需要定义输入张量,并使用torch.onnx.export
函数:
model = SimpleModel()
model.eval()
# 定义一个输入张量
x = torch.randn(1, 1, 28, 28)
# 导出到ONNX
torch.onnx.export(model, # 运行的模型
x, # 模型输入 (或一个元组对于多个输入)
"simple_model.onnx", # 保存模型的路径
export_params=True, # 带有模型参数的存储
opset_version=10, # ONNX版本
do_constant_folding=True, # 是否执行常量折叠优化
input_names = ['input'], # 输入名
output_names = ['output'], # 输出名
dynamic_axes={'input' : {0 : 'batch_size'}, # 批量大小动态轴
'output' : {0 : 'batch_size'}})
需要注意的关键点
- 确保模型处于
eval
模式,以禁用BatchNorm和Dropout层中的训练特定行为。 - 定义的输入张量
x
应该与实际模型输入的形状相匹配。 opset_version
:根据模型使用的操作选择合适的版本。较新的版本支持更多的操作,但可能不被所有框架支持。
2. 随着工具的不断完善,大多数情况下只需要经过上面的步骤即可完成模型转换需求。但是,有时候有一些特殊的层onnx并不支持,或者是要根据部署需求对模型调整,以下通过一个简单的示例进行展示。
2.1 对已经转换完成的onnx模型进行调整,比如需要删除onnx结构中的某些节点,然后将剩余的节点重新连接。
import onnx
onnx_model = onnx.load('model.onnx')
graph = onnx_model.graph
nodes = graph.node
input = graph.input
output = graph.output
print('input:\n', input)
print('output:\n', output)
print('----------------------------')
print('nodes[0]:\n', nodes[0])
print('nodes[1]:\n', nodes[1].input)
graph.node.remove(nodes[0])
graph.node.remove(nodes[0])
graph.node.remove(nodes[0])
graph.node.remove(nodes[0]) # 重复四次是因为每次删除后节点都会重新编号排序
graph.node[0].input[0] = 'input' # 将剩余节点和input连接起来
onnx.checker.check_graph(graph)
onnx.checker.check_model(onnx_model)
onnx.save(onnx_model, 'modify.onnx')
2.2 在转换之前变换模型层。比如我之前碰到一个在forward中对最后的输出作x*std+mean操作的模型,这样导出的onnx就会多出乘法和加法节点,为了省去这两个节点,可以使用卷积实现此乘加操作;还有想要将图像超分辨率模型输出的RGB图像转换为YUV400格式,同样可以使用增加卷积的方式实现。
class ma_conv(nn.Module): # 卷积实现乘加操作
def __init__(self, c, mean, std):
super().__init__()
self.conv = nn.Conv2d(c, c, 1, 1, 0, 1, c, bias=True)
print(self.conv.bias.data, self.conv.weight.data)
self.conv.weight.data[:] = nn.Parameter(std.permute(1, 0, 2, 3))
self.conv.bias.data[:] = nn.Parameter(mean.squeeze())
print(self.conv.bias.data, self.conv.weight.data)
def forward(self, x):
x = self.conv(x)
return x
3. ONNX转Caffe
转换为ONNX格式后,我们可以使用各种工具将ONNX模型转换为其他格式,如Caffe。这一步通常需要使用专门的转换工具或库。以ONNX to Caffe2为例,虽然Caffe2已经与PyTorch合并,但转换工具仍然可以用于演示转换过程。
示例:使用onnx-caffe2转换
首先,确保安装了onnx-caffe2
:
pip install onnx-caffe2
然后,使用以下代码进行转换:
import onnx
from onnx_caffe2.backend import Caffe2Backend
# 加载ONNX模型
model = onnx.load("simple_model.onnx")
# 转换
caffe2_model = Caffe2Backend.prepare(model)
# 保存转换后的模型
with open("simple_model.pb", "wb") as f:
f.write(caffe2_model.SerializeToString())
需要注意的关键点
- 在转换过程中,确保ONNX模型使用的操作在目标框架中有对应的实现。
- 转换工具可能不支持所有ONNX操作,需要检查转换后的模型是否完整。
- 在转换大型或复杂模型时,可能需要根据目标框架进行手动调整或优化。
Caffe框架目前基本已经退出历史舞台,它的安装等也非常复杂,很多时候都会卡在环境配置这一步,可以使用一个不需要安装Caffe的PytorchToCaffe开源库作模型转换,直接将pytorch模型转换为caffe。