From PyTorch to ONNX

WolframRhodium edited this page Mar 1, 2022 · 4 revisions

ONNX is an open format built to represent machine learning models. Existing vs-mlrt runtimes only support this format for inference.

Given a PyTorch model loaded in Python as a torch.nn.Module (different models may have to be loaded in different ways), converting it to the ONNX format is sometimes as simple as:

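import torch

# https://github.com/onnx/onnx/issues/654
# Mark the batch and spatial axes as dynamic. Inputs are NCHW,
# so axis 2 is height and axis 3 is width.
dynamic_axes = {
    'input': {0: 'batch_size', 2: 'height', 3: 'width'},
    'output': {0: 'batch_size', 2: 'height', 3: 'width'}
}

channels = 3
# Dummy input, used only to trace the model.
input = torch.ones(1, channels, 64, 64)

# `model` is the torch.nn.Module loaded beforehand.
torch.onnx.export(
    model, input, "output.onnx",
    input_names=["input"],
    output_names=["output"],  # must match the 'output' key in dynamic_axes
    dynamic_axes=dynamic_axes,
    opset_version=14
)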
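
As an illustration of what a dynamic_axes mapping expresses, the sketch below lists the shape signature an exported tensor would carry: axes named in the mapping become symbolic, and all other axes keep the size of the dummy input. The helper `symbolic_shape` is hypothetical (it is not part of PyTorch or ONNX), and the example assumes an NCHW layout, in which axis 2 is height and axis 3 is width.

```python
# Hypothetical helper for illustration only; not a PyTorch or ONNX API.
def symbolic_shape(shape, axes):
    """Replace each axis listed in `axes` with its symbolic name."""
    return [axes.get(i, size) for i, size in enumerate(shape)]

# Same mapping style as the one passed to torch.onnx.export.
dynamic_axes = {
    'input': {0: 'batch_size', 2: 'height', 3: 'width'},
    'output': {0: 'batch_size', 2: 'height', 3: 'width'},
}

dummy_shape = (1, 3, 64, 64)  # shape of the dummy input used for tracing
input_signature = symbolic_shape(dummy_shape, dynamic_axes['input'])
print(input_signature)  # ['batch_size', 3, 'height', 'width']
```

Only the channel count stays fixed here, so the exported model should accept other batch sizes and resolutions, provided the network itself does not bake in a fixed size.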

However, the export sometimes raises errors, in which case the code defining the network structure has to be modified.

Common fixes include:

  • torch.nn.functional.pad(x, (-4, -4, -4, -4)) => x[..., 4:-4, 4:-4]
  • remove uses of x.shape / x.size()
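
The first fix relies on the fact that negative padding in torch.nn.functional.pad crops the tensor, which is the same as slicing. A minimal check, assuming an arbitrary NCHW tensor `x` chosen only for illustration:

```python
import torch
import torch.nn.functional as F

# Arbitrary NCHW tensor used only for this check.
x = torch.arange(1 * 3 * 16 * 16, dtype=torch.float32).reshape(1, 3, 16, 16)

# Negative padding crops 4 pixels from each side of the last two axes.
# The pad tuple is ordered (left, right, top, bottom).
padded = F.pad(x, (-4, -4, -4, -4))

# Equivalent slicing, which is usually easier for the exporter to handle.
sliced = x[..., 4:-4, 4:-4]

same = torch.equal(padded, sliced)
print(padded.shape, same)
```

Because both forms produce the same tensor, replacing the padding call with the slice should not change the model's output.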

Some models are hard to convert. This may be due to missing operator support (e.g. GridSample in common video models, deformable convolution in BasicVSR++) or to bugs in the backend itself (e.g. TensorRT when dealing with slicing).


For the ort backend, torch.nn.functional.grid_sample can be exported by registering a custom symbolic function that maps it to the com.microsoft::GridSample op:

# https://github.com/microsoft/onnxruntime/blob/c1cf16ed5d078b69db865c633e446a6038d22cea/onnxruntime/python/tools/pytorch_export_contrib_ops.py

import torch
import torch.onnx.symbolic_helper as sym_help
from torch.onnx import register_custom_op_symbolic

def _reg(symbolic_fn):
    # Register symbolic_fn under the name of the op it replaces,
    # e.g. "::grid_sampler"; the last argument is the opset version.
    name = f"::{symbolic_fn.__name__}"
    register_custom_op_symbolic(name, symbolic_fn, 1)

def grid_sampler(g, input, grid, mode, padding_mode, align_corners):
    # mode and padding_mode arrive as integer constants; map them back to
    # the string attributes used by com.microsoft::GridSample.
    mode = sym_help._maybe_get_const(mode, "i")
    padding_mode = sym_help._maybe_get_const(padding_mode, "i")
    mode_str = ['bilinear', 'nearest', 'bicubic'][mode]
    padding_mode_str = ['zeros', 'border', 'reflection'][padding_mode]
    align_corners = int(sym_help._maybe_get_const(align_corners, "b"))

    return g.op("com.microsoft::GridSample", input, grid,
                mode_s=mode_str,
                padding_mode_s=padding_mode_str,
                align_corners_i=align_corners)

_reg(grid_sampler)

# Declare the com.microsoft opset when exporting.
torch.onnx.export(..., custom_opsets={"com.microsoft": 1})