
Dynamic Padding bug, with a minimal fix! #2539

@kistenklaus

Description

Here is a minimal example to reproduce the issue:

import torch
class Net(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, I):
        H = I.size(2) # H is symbolic (dynamic_shape)
        x = torch.nn.functional.pad(I, (0, 0, 0, H), mode="replicate")
        return x

torch.onnx.export(Net(), 
                  (torch.rand(1,1,1,1),), 
                  "net.onnx",
                  dynamo=True,
                  dynamic_shapes={"I" : {2 : torch.export.Dim("H")}},
                  report=True)
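
For contrast, a static export of the same model presumably never hits this, because with concrete shapes every pad value is a plain Python int (a sketch, only dropping dynamic_shapes and writing to a different file name):

# same Net as above; without dynamic_shapes, H specializes to a concrete int,
# so all padding amounts reaching the exporter are constants
torch.onnx.export(Net(),
                  (torch.rand(1,1,1,1),),
                  "net_static.onnx",
                  dynamo=True,
                  report=True)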

Now, I'm very much not a Python dev; I've actually never done anything in Python before.
So take anything I say with a grain of salt =^).

But I ran into this issue because in U-Nets for image denoising we sometimes pad up to an alignment, to make sure that the following concat nodes see two tensors with identical spatial extent (rough sketch below), which means that at least in this field dynamic padding is an incredibly common operation.
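
A stripped-down sketch of the pattern (pad_to_match and decoder_step are just illustrative names, the real network does more, and it assumes skip is at least as large as x):

import torch
import torch.nn.functional as F

def pad_to_match(x: torch.Tensor, skip: torch.Tensor) -> torch.Tensor:
    # dh/dw are symbolic under dynamic shapes, just like H in the repro above
    dh = skip.size(2) - x.size(2)
    dw = skip.size(3) - x.size(3)
    return F.pad(x, (0, dw, 0, dh), mode="replicate")

def decoder_step(x: torch.Tensor, skip: torch.Tensor) -> torch.Tensor:
    # upsample, align spatial extents, then concat with the skip connection
    x = F.interpolate(x, scale_factor=2, mode="nearest")
    x = pad_to_match(x, skip)
    return torch.cat([x, skip], dim=1)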

Fix:

I looked through the error logs a bit (see below), and I think what's going on is that the
_process_padding function in onnxscript/function_libs/torch_lib/ops/nn.py is missing a splat operator in the dynamic-value path (I don't know if that's actually what it's called, but it's a good name =^). Maybe I'm also wrong here, not a Python dev after all, but this at least fixed the issue for me:

# `op` and `ir` below are the same handles nn.py itself uses; grabbing them off
# the module avoids hard-coding its import paths in this user-side snippet.
from onnxscript.function_libs.torch_lib.ops import nn as torch_nn_ops

op = torch_nn_ops.op
ir = torch_nn_ops.ir

def _process_padding(padding, rank):
    """Convert PyTorch padding for ONNX Pad."""
    assert isinstance(padding, (list, tuple))
    if all(isinstance(pad, int) for pad in padding):
        paddings = padding
        zeros = [0] * (rank * 2 - len(paddings))
        paddings = [*paddings, *zeros]
        paddings = paddings[-2::-2] + paddings[-1::-2]
        return op.Constant(value=ir.tensor(paddings, dtype=ir.DataType.INT64))
    else:
        paddings = []
        for pad in padding:
            if isinstance(pad, int):
                paddings.append(op.Constant(value_ints=[pad]))
            else:
                # Dynamic value
                paddings.append(op.Reshape(pad, [-1]))
        # Create a series of 1d zero tensors
        zero = op.Constant(value_ints=[0])
        zeros = [zero] * (rank * 2 - len(paddings))
        paddings = [*paddings, *zeros]
        # Interleave the padding values
        paddings = paddings[-2::-2] + paddings[-1::-2]
        return op.Concat(*paddings, axis=0) # <- ONLY LINE THAT'S CHANGED!

# Monkeypatch: swap the fixed function into onnxscript's torch_lib so that
# aten_pad picks it up on the next export.
torch_nn_ops._process_padding = _process_padding
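
To be clear, the assignment has to run before torch.onnx.export is called. With the patch in place I just re-run the repro's export and, as a quick sanity check (a sketch; it's also the next unchecked step in the report below), run onnx.checker on the result:

import onnx

torch.onnx.export(Net(),
                  (torch.rand(1,1,1,1),),
                  "net.onnx",
                  dynamo=True,
                  dynamic_shapes={"I" : {2 : torch.export.Dim("H")}},
                  report=True)

onnx.checker.check_model("net.onnx")  # raises if the exported graph is malformed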

Markdown Report:

PyTorch ONNX Conversion Report

✅ Obtain model graph with `torch.export.export(..., strict=False)`
⚪ Obtain model graph with `torch.export.export(..., strict=True)`
⚪ Obtain model graph with `torch.export._draft_export.draft_export`
✅ Decompose operators for ONNX compatibility
❌ Translate the graph into ONNX
⚪ Run `onnx.checker` on the ONNX model
⚪ Execute the model with ONNX Runtime
⚪ Validate model output accuracy

Error messages

Traceback (most recent call last):

  File "/home/kistenklaus/Documents/kit/hiwi/vkcnn/.venv/lib/python3.13/site-packages/torch/onnx/_internal/exporter/_building.py", line 565, in _call_op
    converted_named_inputs = _process_python_sequences(
        op_signature,
    ...<3 lines>...
        self.opset,
    )

  File "/home/kistenklaus/Documents/kit/hiwi/vkcnn/.venv/lib/python3.13/site-packages/torch/onnx/_internal/exporter/_building.py", line 415, in _process_python_sequences
    dtype = _determine_input_dtype(param, arg, type_binding)

  File "/home/kistenklaus/Documents/kit/hiwi/vkcnn/.venv/lib/python3.13/site-packages/torch/onnx/_internal/exporter/_building.py", line 234, in _determine_input_dtype
    raise ValueError(
    ...<3 lines>...
    )

ValueError: Could not determine the dtype for the input 'inputs'. param=inputs: T, arg=([SymbolicTensor(name='anonymous:139994164589296', producer=anonymous_node:139994164995296, index=0), SymbolicTensor(name='anonymous:139994164589296', producer=anonymous_node:139994164995296, index=0), SymbolicTensor(name='anonymous:139994164589456', producer=anonymous_node:139994164994864, index=0), SymbolicTensor(name='anonymous:139994164589776', producer=anonymous_node:139994164994432, index=0), SymbolicTensor(name='anonymous:139994164589296', producer=anonymous_node:139994164995296, index=0), SymbolicTensor(name='anonymous:139994164589296', producer=anonymous_node:139994164995296, index=0), SymbolicTensor(name='anonymous:139994164588816', producer=anonymous_node:139994164995152, index=0), SymbolicTensor(name='anonymous:139994164589616', producer=anonymous_node:139994164994720, index=0)],), param_type_constraint=T=INT16 | UINT16 | INT64 | FLOAT16 | COMPLEX64 | BOOL | STRING | COMPLEX128 | UINT8 | INT32 | DOUBLE | BFLOAT16 | UINT64 | FLOAT | INT8 | UINT32, type_binding={}


The above exception was the direct cause of the following exception:


Traceback (most recent call last):

  File "/home/kistenklaus/Documents/kit/hiwi/vkcnn/.venv/lib/python3.13/site-packages/torch/onnx/_internal/exporter/_building.py", line 630, in eval
    outputs = self._call_op(
        op_signature, named_inputs, named_attrs, num_outputs
    )

  File "/home/kistenklaus/Documents/kit/hiwi/vkcnn/.venv/lib/python3.13/site-packages/torch/onnx/_internal/exporter/_building.py", line 574, in _call_op
    raise _errors.GraphConstructionError(
    ...<2 lines>...
    ) from e

torch.onnx._internal.exporter._errors.GraphConstructionError: Error processing Python constants for operator '::Concat'. named_inputs={'inputs': ([SymbolicTensor(name='anonymous:139994164589296', producer=anonymous_node:139994164995296, index=0), SymbolicTensor(name='anonymous:139994164589296', producer=anonymous_node:139994164995296, index=0), SymbolicTensor(name='anonymous:139994164589456', producer=anonymous_node:139994164994864, index=0), SymbolicTensor(name='anonymous:139994164589776', producer=anonymous_node:139994164994432, index=0), SymbolicTensor(name='anonymous:139994164589296', producer=anonymous_node:139994164995296, index=0), SymbolicTensor(name='anonymous:139994164589296', producer=anonymous_node:139994164995296, index=0), SymbolicTensor(name='anonymous:139994164588816', producer=anonymous_node:139994164995152, index=0), SymbolicTensor(name='anonymous:139994164589616', producer=anonymous_node:139994164994720, index=0)],)}, named_attrs={'axis': 0}, opset=, op_signature=''::Concat(inputs: T, axis: INT = None) -> (T) where T=INT16 | UINT16 | INT64 | FLOAT16 | COMPLEX64 | BOOL | STRING | COMPLEX128 | UINT8 | INT32 | DOUBLE | BFLOAT16 | UINT64 | FLOAT | INT8 | UINT32.


The above exception was the direct cause of the following exception:


Traceback (most recent call last):

  File "/home/kistenklaus/Documents/kit/hiwi/vkcnn/.venv/lib/python3.13/site-packages/torch/onnx/_internal/exporter/_core.py", line 591, in _handle_call_function_node_with_lowering
    outputs = onnx_function(*onnx_args, **onnx_kwargs)

  File "/home/kistenklaus/Documents/kit/hiwi/vkcnn/.venv/lib/python3.13/site-packages/onnxscript/values.py", line 625, in __call__
    return self.func(*args, **kwargs)
           ~~~~~~~~~^^^^^^^^^^^^^^^^^

  File "/home/kistenklaus/Documents/kit/hiwi/vkcnn/.venv/lib/python3.13/site-packages/onnxscript/function_libs/torch_lib/ops/nn.py", line 1519, in aten_pad
    paddings = _process_padding(pad, rank)

  File "/home/kistenklaus/Documents/kit/hiwi/vkcnn/.venv/lib/python3.13/site-packages/onnxscript/function_libs/torch_lib/ops/nn.py", line 1506, in _process_padding
    return op.Concat(paddings, axis=0)
           ~~~~~~~~~^^^^^^^^^^^^^^^^^^

  File "/home/kistenklaus/Documents/kit/hiwi/vkcnn/.venv/lib/python3.13/site-packages/onnxscript/onnx_opset/_impl/opset13.py", line 389, in Concat
    return op(*self._prepare_inputs(schema, *inputs), axis=axis)

  File "/home/kistenklaus/Documents/kit/hiwi/vkcnn/.venv/lib/python3.13/site-packages/onnxscript/values.py", line 309, in __call__
    return evaluator.default().eval(schema, args, kwargs)
           ~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^

  File "/home/kistenklaus/Documents/kit/hiwi/vkcnn/.venv/lib/python3.13/site-packages/torch/onnx/_internal/exporter/_building.py", line 637, in eval
    raise _errors.GraphConstructionError(
        f"Error calling operator '{schema.name}' with args {args} and kwargs {kwargs}."
    ) from e

torch.onnx._internal.exporter._errors.GraphConstructionError: Error calling operator 'Concat' with args ([SymbolicTensor(name='anonymous:139994164589296', producer=anonymous_node:139994164995296, index=0), SymbolicTensor(name='anonymous:139994164589296', producer=anonymous_node:139994164995296, index=0), SymbolicTensor(name='anonymous:139994164589456', producer=anonymous_node:139994164994864, index=0), SymbolicTensor(name='anonymous:139994164589776', producer=anonymous_node:139994164994432, index=0), SymbolicTensor(name='anonymous:139994164589296', producer=anonymous_node:139994164995296, index=0), SymbolicTensor(name='anonymous:139994164589296', producer=anonymous_node:139994164995296, index=0), SymbolicTensor(name='anonymous:139994164588816', producer=anonymous_node:139994164995152, index=0), SymbolicTensor(name='anonymous:139994164589616', producer=anonymous_node:139994164994720, index=0)],) and kwargs {'axis': 0}.


The above exception was the direct cause of the following exception:


Traceback (most recent call last):

  File "/home/kistenklaus/Documents/kit/hiwi/vkcnn/.venv/lib/python3.13/site-packages/torch/onnx/_internal/exporter/_core.py", line 790, in _translate_fx_graph
    _handle_call_function_node_with_lowering(
    ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^
        model,
        ^^^^^^
    ...<6 lines>...
        node_name_to_local_functions=node_name_to_local_functions,
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    )
    ^

  File "/home/kistenklaus/Documents/kit/hiwi/vkcnn/.venv/lib/python3.13/site-packages/torch/onnx/_internal/exporter/_core.py", line 593, in _handle_call_function_node_with_lowering
    raise _errors.GraphConstructionError(
        f"Error when calling function '{onnx_function}' with args '{onnx_args}' and kwargs '{onnx_kwargs}'"
    ) from e

torch.onnx._internal.exporter._errors.GraphConstructionError: Error when calling function 'TracedOnnxFunction(<function aten_pad at 0x7f52f0ba9ee0>)' with args '[SymbolicTensor(name='i', type=Tensor(FLOAT), shape=Shape([1, 1, SymbolicDim(s95), 1])), [0, 0, 0, SymbolicTensor(name='sym_size_int_2', type=Tensor(INT64), shape=Shape([]), producer='node_sym_size_int_2', index=0)], 'replicate']' and kwargs '{}'


The above exception was the direct cause of the following exception:


Traceback (most recent call last):

  File "/home/kistenklaus/Documents/kit/hiwi/vkcnn/.venv/lib/python3.13/site-packages/torch/onnx/_internal/exporter/_core.py", line 1459, in export
    onnx_program = _exported_program_to_onnx_program(
        decomposed_program, registry=registry
    )

  File "/home/kistenklaus/Documents/kit/hiwi/vkcnn/.venv/lib/python3.13/site-packages/torch/onnx/_internal/exporter/_core.py", line 1090, in _exported_program_to_onnx_program
    values = _translate_fx_graph(
        fx_graph,
    ...<4 lines>...
        registry=registry,
    )

  File "/home/kistenklaus/Documents/kit/hiwi/vkcnn/.venv/lib/python3.13/site-packages/torch/onnx/_internal/exporter/_core.py", line 816, in _translate_fx_graph
    raise _errors.ConversionError(
        f"Error when translating node {node.format_node()}. See the stack trace for more information."
    ) from e

torch.onnx._internal.exporter._errors.ConversionError: Error when translating node %pad : [num_users=1] = call_function[target=torch.ops.aten.pad.default](args = (%i, [0, 0, 0, %sym_size_int_2], replicate), kwargs = {}). See the stack trace for more information.

Exported program

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, i: "f32[1, 1, s95, 1]"):
             # 
            sym_size_int_2: "Sym(s95)" = torch.ops.aten.sym_size.int(i, 2)
            
             # File: /home/kistenklaus/Documents/kit/hiwi/vkcnn/fuckyou_dynamo.py:47 in forward, code: x = torch.nn.functional.pad(I, (0, 0, 0, H), mode="replicate")
            pad: "f32[1, 1, 2*s95, 1]" = torch.ops.aten.pad.default(i, [0, 0, 0, sym_size_int_2], 'replicate');  i = sym_size_int_2 = None
            return (pad,)
            
Graph signature: 
    # inputs
    i: USER_INPUT
    
    # outputs
    pad: USER_OUTPUT
    
Range constraints: {s95: VR[0, int_oo]}

Analysis

PyTorch ONNX Conversion Analysis

Model Information

The model has 0 parameters and 0 buffers (non-trainable parameters).
Number of parameters per dtype:

defaultdict(<class 'int'>, {})

Number of buffers per dtype:

defaultdict(<class 'int'>, {})

Inputs:

  • i: TensorMetadata(shape=torch.Size([1, 1, s95, 1]), dtype=torch.float32, requires_grad=False, stride=(s95, s95, 1, 1), memory_format=torch.channels_last, is_quantized=False, qparams={})

Outputs:

  • pad: TensorMetadata(shape=torch.Size([1, 1, 2*s95, 1]), dtype=torch.float32, requires_grad=False, stride=(2*s95, 2*s95, 1, 1), memory_format=torch.channels_last, is_quantized=False, qparams={})

The FX graph has 4 nodes in total. Number of FX nodes per op:

  • placeholder: 1
  • call_function: 2
  • output: 1

Of the call_function nodes, the counts of operators used are:

  • aten.sym_size.int: 1
  • aten.pad.default: 1

ONNX Conversion Information

All operators in the model have registered ONNX decompositions.

Decomposition comparison

Ops exist only in the ExportedProgram before decomposition: []

Ops exist only in the ExportedProgram after decomposition: []
