From PyTorch to ONNX
ONNX is an open format built to represent machine learning models. The existing vs-mlrt runtimes support only this format for inference.
Given a PyTorch model loaded in Python as a `torch.nn.Module` (different models may have to be loaded in different ways), conversion to the ONNX format is sometimes as simple as:
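```python
import torch

# `model` is the torch.nn.Module to convert; switch it to inference mode first.
model.eval()

# https://github.com/onnx/onnx/issues/654
# Mark batch and spatial dimensions as dynamic (NCHW layout: dim 2 = height, dim 3 = width).
dynamic_axes = {
    'input': {0: 'batch_size', 2: 'height', 3: 'width'},
    'output': {0: 'batch_size', 2: 'height', 3: 'width'}
}

channels = 3
# Dummy input used to trace the model; only its shape and dtype matter.
dummy_input = torch.ones(1, channels, 64, 64)

torch.onnx.export(
    model, dummy_input, "output.onnx",
    input_names=["input"],
    output_names=["output"],  # must match the 'output' key in dynamic_axes
    dynamic_axes=dynamic_axes,
    opset_version=14
)
```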
However, the export sometimes raises errors, and the code defining the network structure has to be modified. Common fixes include:
- `torch.nn.functional.pad(x, (-4, -4, -4, -4))` => `x[..., 4:-4, 4:-4]`
- remove uses of `x.shape` / `x.size()`
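Both rewrites can be checked numerically before exporting. The sketch below uses two hypothetical toy modules: `Before` crops with negative `F.pad` and derives an upsampling size from `x.shape`, while `After` uses slicing and a fixed `scale_factor`. Replacing a shape-derived size with `scale_factor` is only one illustrative way to drop `x.shape`; whether it applies depends on the model.

```python
import torch
import torch.nn.functional as F

# Hypothetical toy blocks used only to illustrate the two rewrites.
class Before(torch.nn.Module):
    def forward(self, x):
        # Reads the spatial size from the tensor at run time.
        h, w = x.shape[2:]
        y = F.interpolate(x, size=(h * 2, w * 2), mode="nearest")
        # Negative padding crops 4 pixels from every border of the last two dims.
        return F.pad(y, (-4, -4, -4, -4))

class After(torch.nn.Module):
    def forward(self, x):
        # Same 2x resize expressed without querying x.shape.
        y = F.interpolate(x, scale_factor=2, mode="nearest")
        # Same crop expressed as plain slicing.
        return y[..., 4:-4, 4:-4]

x = torch.rand(1, 3, 16, 24)
out_before = Before()(x)
out_after = After()(x)

assert torch.equal(out_before, out_after)
print(tuple(out_after.shape))  # (1, 3, 24, 40)
```

With nearest-neighbour upsampling by an integer factor, both modules compute the same source indices, so the outputs match exactly.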
Some models are hard to convert, usually because an operator is not supported (e.g. `GridSample` in common video models, deformable convolution in BasicVSR++). In other cases the backend itself may produce wrong results (e.g. TensorRT when handling slicing).
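`torch.nn.functional.grid_sample` can be converted for the ort backend (ONNX Runtime) by registering a custom symbolic function that maps it to ONNX Runtime's `com.microsoft::GridSample` contrib operator: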
```python
import torch

# https://github.com/microsoft/onnxruntime/blob/c1cf16ed5d078b69db865c633e446a6038d22cea/onnxruntime/python/tools/pytorch_export_contrib_ops.py
def _reg(symbolic_fn):
    from torch.onnx import register_custom_op_symbolic
    # "::grid_sampler" (empty domain) resolves to aten::grid_sampler, which grid_sample calls.
    name = f"::{symbolic_fn.__name__}"
    register_custom_op_symbolic(name, symbolic_fn, 1)

def grid_sampler(g, input, grid, mode, padding_mode, align_corners):
    import torch.onnx.symbolic_helper as sym_help
    # mode / padding_mode arrive as integer enums; map them to the string
    # attributes expected by the com.microsoft::GridSample contrib op.
    mode = sym_help._maybe_get_const(mode, "i")
    padding_mode = sym_help._maybe_get_const(padding_mode, "i")
    mode_str = ['bilinear', 'nearest', 'bicubic'][mode]
    padding_mode_str = ['zeros', 'border', 'reflection'][padding_mode]
    align_corners = int(sym_help._maybe_get_const(align_corners, "b"))
    return g.op("com.microsoft::GridSample", input, grid,
                mode_s=mode_str,
                padding_mode_s=padding_mode_str,
                align_corners_i=align_corners)

_reg(grid_sampler)

# The exported graph references the com.microsoft domain, so declare its opset version.
torch.onnx.export(..., custom_opsets={"com.microsoft": 1})
```