Description
Here is a minimal example to reproduce the issue:
import torch

class Net(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, I):
        H = I.size(2)  # H is symbolic (dynamic_shape)
        x = torch.nn.functional.pad(I, (0, 0, 0, H), mode="replicate")
        return x

torch.onnx.export(Net(),
                  (torch.rand(1,1,1,1),),
                  "net.onnx",
                  dynamo=True,
                  dynamic_shapes={"I" : {2 : torch.export.Dim("H")}},
                  report=True)
I'm very much not a Python dev; I have actually never done anything in Python, so take anything I say with a grain of salt =^).
I ran into this issue because in U-Nets for image denoising we sometimes pad up to an alignment, to ensure that the following concat nodes see two tensors with identical spatial extent. At least in this field, that makes dynamic padding a very commonly used operation.
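As a rough sketch of what "pad up to an alignment" means here: the helper name `pad_amount`, the alignment of 8, and the image size are all made up for illustration and not taken from a real model. When the height is a dynamic dimension, the computed pad is symbolic too, which is the situation that hits the dynamic path of the exporter.

```python
def pad_amount(size: int, alignment: int) -> int:
    """Trailing elements needed to round `size` up to a multiple of `alignment`."""
    if alignment <= 0:
        raise ValueError("alignment must be positive")
    return (-size) % alignment


# Hypothetical spatial extent and alignment, chosen only for illustration
height, width = 37, 64
alignment = 8

pad_h = pad_amount(height, alignment)
pad_w = pad_amount(width, alignment)

# torch.nn.functional.pad expects the pads for the last dimension first,
# i.e. (left, right, top, bottom) for an NCHW tensor.
pad_arg = (0, pad_w, 0, pad_h)

padded_height = height + pad_h
padded_width = width + pad_w
print(pad_arg, padded_height, padded_width)
```

In a real model, `pad_arg` would be passed to `torch.nn.functional.pad`, with `height` and `width` read from the input tensor's shape instead of being hard-coded.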
Fix:
I looked through the error logs a bit (see below), and I think what's going on is that the `_process_padding` function at /onnxscript/blob/main/onnxscript/function_libs/torch_lib/ops/nn.py is missing a splat operator in the dynamic value path (I don't know if it's actually called that, but it's a good name =^)). Maybe I'm wrong here, not being a Python dev after all, but this at least fixed the issue for me:
# Assumed imports, mirroring the names used inside onnxscript's nn.py
from onnxscript import ir
from onnxscript.onnx_opset import opset18 as op
from onnxscript.function_libs.torch_lib.ops import nn as torch_nn_ops

def _process_padding(padding, rank):
    """Convert PyTorch padding for ONNX Pad."""
    assert isinstance(padding, (list, tuple))
    if all(isinstance(pad, int) for pad in padding):
        paddings = padding
        zeros = [0] * (rank * 2 - len(paddings))
        paddings = [*paddings, *zeros]
        paddings = paddings[-2::-2] + paddings[-1::-2]
        return op.Constant(value=ir.tensor(paddings, dtype=ir.DataType.INT64))
    else:
        paddings = []
        for pad in padding:
            if isinstance(pad, int):
                paddings.append(op.Constant(value_ints=[pad]))
            else:
                # Dynamic value
                paddings.append(op.Reshape(pad, [-1]))
        # Create a series of 1d zero tensors
        zero = op.Constant(value_ints=[0])
        zeros = [zero] * (rank * 2 - len(paddings))
        paddings = [*paddings, *zeros]
        # Interleave the padding values
        paddings = paddings[-2::-2] + paddings[-1::-2]
        return op.Concat(*paddings, axis=0)  # <- ONLY LINE THAT'S CHANGED!

# Monkey-patch the library function with the fixed version
torch_nn_ops._process_padding = _process_padding
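To show why the missing splat matters, here is a toy sketch in plain Python. `concat` is a made-up stand-in that only mimics a variadic `*inputs` signature like the one `Concat` appears to have in the traceback below; it is not the real onnxscript op, and the string values merely stand in for the 1-D padding tensors.

```python
def concat(*inputs, axis=0):
    """Hypothetical stand-in for a variadic op: reports what it received."""
    return len(inputs), [type(x).__name__ for x in inputs]


# Stand-ins for the 1-D padding tensors built in the dynamic path
paddings = ["pad_a", "pad_b", "pad_c", "pad_d"]

# Without the splat, the whole list arrives as ONE positional input
without_splat = concat(paddings, axis=0)

# With the splat, each element arrives as its own input
with_splat = concat(*paddings, axis=0)

print(without_splat)  # (1, ['list'])
print(with_splat)     # (4, ['str', 'str', 'str', 'str'])
```

This matches the error messages, where the arguments to `Concat` show up as a one-element tuple containing a list of `SymbolicTensor` values rather than as separate inputs.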
Markdown Report:
PyTorch ONNX Conversion Report
✅ Obtain model graph with `torch.export.export(..., strict=False)`
⚪ Obtain model graph with `torch.export.export(..., strict=True)`
⚪ Obtain model graph with `torch.export._draft_export.draft_export`
✅ Decompose operators for ONNX compatibility
❌ Translate the graph into ONNX
⚪ Run `onnx.checker` on the ONNX model
⚪ Execute the model with ONNX Runtime
⚪ Validate model output accuracy
Error messages
Traceback (most recent call last):
File "/home/kistenklaus/Documents/kit/hiwi/vkcnn/.venv/lib/python3.13/site-packages/torch/onnx/_internal/exporter/_building.py", line 565, in _call_op
converted_named_inputs = _process_python_sequences(
op_signature,
...<3 lines>...
self.opset,
)
File "/home/kistenklaus/Documents/kit/hiwi/vkcnn/.venv/lib/python3.13/site-packages/torch/onnx/_internal/exporter/_building.py", line 415, in _process_python_sequences
dtype = _determine_input_dtype(param, arg, type_binding)
File "/home/kistenklaus/Documents/kit/hiwi/vkcnn/.venv/lib/python3.13/site-packages/torch/onnx/_internal/exporter/_building.py", line 234, in _determine_input_dtype
raise ValueError(
...<3 lines>...
)
ValueError: Could not determine the dtype for the input 'inputs'. param=inputs: T, arg=([SymbolicTensor(name='anonymous:139994164589296', producer=anonymous_node:139994164995296, index=0), SymbolicTensor(name='anonymous:139994164589296', producer=anonymous_node:139994164995296, index=0), SymbolicTensor(name='anonymous:139994164589456', producer=anonymous_node:139994164994864, index=0), SymbolicTensor(name='anonymous:139994164589776', producer=anonymous_node:139994164994432, index=0), SymbolicTensor(name='anonymous:139994164589296', producer=anonymous_node:139994164995296, index=0), SymbolicTensor(name='anonymous:139994164589296', producer=anonymous_node:139994164995296, index=0), SymbolicTensor(name='anonymous:139994164588816', producer=anonymous_node:139994164995152, index=0), SymbolicTensor(name='anonymous:139994164589616', producer=anonymous_node:139994164994720, index=0)],), param_type_constraint=T=INT16 | UINT16 | INT64 | FLOAT16 | COMPLEX64 | BOOL | STRING | COMPLEX128 | UINT8 | INT32 | DOUBLE | BFLOAT16 | UINT64 | FLOAT | INT8 | UINT32, type_binding={}
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/home/kistenklaus/Documents/kit/hiwi/vkcnn/.venv/lib/python3.13/site-packages/torch/onnx/_internal/exporter/_building.py", line 630, in eval
outputs = self._call_op(
op_signature, named_inputs, named_attrs, num_outputs
)
File "/home/kistenklaus/Documents/kit/hiwi/vkcnn/.venv/lib/python3.13/site-packages/torch/onnx/_internal/exporter/_building.py", line 574, in _call_op
raise _errors.GraphConstructionError(
...<2 lines>...
) from e
torch.onnx._internal.exporter._errors.GraphConstructionError: Error processing Python constants for operator '::Concat'. named_inputs={'inputs': ([SymbolicTensor(name='anonymous:139994164589296', producer=anonymous_node:139994164995296, index=0), SymbolicTensor(name='anonymous:139994164589296', producer=anonymous_node:139994164995296, index=0), SymbolicTensor(name='anonymous:139994164589456', producer=anonymous_node:139994164994864, index=0), SymbolicTensor(name='anonymous:139994164589776', producer=anonymous_node:139994164994432, index=0), SymbolicTensor(name='anonymous:139994164589296', producer=anonymous_node:139994164995296, index=0), SymbolicTensor(name='anonymous:139994164589296', producer=anonymous_node:139994164995296, index=0), SymbolicTensor(name='anonymous:139994164588816', producer=anonymous_node:139994164995152, index=0), SymbolicTensor(name='anonymous:139994164589616', producer=anonymous_node:139994164994720, index=0)],)}, named_attrs={'axis': 0}, opset=, op_signature=''::Concat(inputs: T, axis: INT = None) -> (T) where T=INT16 | UINT16 | INT64 | FLOAT16 | COMPLEX64 | BOOL | STRING | COMPLEX128 | UINT8 | INT32 | DOUBLE | BFLOAT16 | UINT64 | FLOAT | INT8 | UINT32.
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/home/kistenklaus/Documents/kit/hiwi/vkcnn/.venv/lib/python3.13/site-packages/torch/onnx/_internal/exporter/_core.py", line 591, in _handle_call_function_node_with_lowering
outputs = onnx_function(*onnx_args, **onnx_kwargs)
File "/home/kistenklaus/Documents/kit/hiwi/vkcnn/.venv/lib/python3.13/site-packages/onnxscript/values.py", line 625, in __call__
return self.func(*args, **kwargs)
~~~~~~~~~^^^^^^^^^^^^^^^^^
File "/home/kistenklaus/Documents/kit/hiwi/vkcnn/.venv/lib/python3.13/site-packages/onnxscript/function_libs/torch_lib/ops/nn.py", line 1519, in aten_pad
paddings = _process_padding(pad, rank)
File "/home/kistenklaus/Documents/kit/hiwi/vkcnn/.venv/lib/python3.13/site-packages/onnxscript/function_libs/torch_lib/ops/nn.py", line 1506, in _process_padding
return op.Concat(paddings, axis=0)
~~~~~~~~~^^^^^^^^^^^^^^^^^^
File "/home/kistenklaus/Documents/kit/hiwi/vkcnn/.venv/lib/python3.13/site-packages/onnxscript/onnx_opset/_impl/opset13.py", line 389, in Concat
return op(*self._prepare_inputs(schema, *inputs), axis=axis)
File "/home/kistenklaus/Documents/kit/hiwi/vkcnn/.venv/lib/python3.13/site-packages/onnxscript/values.py", line 309, in __call__
return evaluator.default().eval(schema, args, kwargs)
~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^
File "/home/kistenklaus/Documents/kit/hiwi/vkcnn/.venv/lib/python3.13/site-packages/torch/onnx/_internal/exporter/_building.py", line 637, in eval
raise _errors.GraphConstructionError(
f"Error calling operator '{schema.name}' with args {args} and kwargs {kwargs}."
) from e
torch.onnx._internal.exporter._errors.GraphConstructionError: Error calling operator 'Concat' with args ([SymbolicTensor(name='anonymous:139994164589296', producer=anonymous_node:139994164995296, index=0), SymbolicTensor(name='anonymous:139994164589296', producer=anonymous_node:139994164995296, index=0), SymbolicTensor(name='anonymous:139994164589456', producer=anonymous_node:139994164994864, index=0), SymbolicTensor(name='anonymous:139994164589776', producer=anonymous_node:139994164994432, index=0), SymbolicTensor(name='anonymous:139994164589296', producer=anonymous_node:139994164995296, index=0), SymbolicTensor(name='anonymous:139994164589296', producer=anonymous_node:139994164995296, index=0), SymbolicTensor(name='anonymous:139994164588816', producer=anonymous_node:139994164995152, index=0), SymbolicTensor(name='anonymous:139994164589616', producer=anonymous_node:139994164994720, index=0)],) and kwargs {'axis': 0}.
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/home/kistenklaus/Documents/kit/hiwi/vkcnn/.venv/lib/python3.13/site-packages/torch/onnx/_internal/exporter/_core.py", line 790, in _translate_fx_graph
_handle_call_function_node_with_lowering(
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^
model,
^^^^^^
...<6 lines>...
node_name_to_local_functions=node_name_to_local_functions,
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
)
^
File "/home/kistenklaus/Documents/kit/hiwi/vkcnn/.venv/lib/python3.13/site-packages/torch/onnx/_internal/exporter/_core.py", line 593, in _handle_call_function_node_with_lowering
raise _errors.GraphConstructionError(
f"Error when calling function '{onnx_function}' with args '{onnx_args}' and kwargs '{onnx_kwargs}'"
) from e
torch.onnx._internal.exporter._errors.GraphConstructionError: Error when calling function 'TracedOnnxFunction(<function aten_pad at 0x7f52f0ba9ee0>)' with args '[SymbolicTensor(name='i', type=Tensor(FLOAT), shape=Shape([1, 1, SymbolicDim(s95), 1])), [0, 0, 0, SymbolicTensor(name='sym_size_int_2', type=Tensor(INT64), shape=Shape([]), producer='node_sym_size_int_2', index=0)], 'replicate']' and kwargs '{}'
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/home/kistenklaus/Documents/kit/hiwi/vkcnn/.venv/lib/python3.13/site-packages/torch/onnx/_internal/exporter/_core.py", line 1459, in export
onnx_program = _exported_program_to_onnx_program(
decomposed_program, registry=registry
)
File "/home/kistenklaus/Documents/kit/hiwi/vkcnn/.venv/lib/python3.13/site-packages/torch/onnx/_internal/exporter/_core.py", line 1090, in _exported_program_to_onnx_program
values = _translate_fx_graph(
fx_graph,
...<4 lines>...
registry=registry,
)
File "/home/kistenklaus/Documents/kit/hiwi/vkcnn/.venv/lib/python3.13/site-packages/torch/onnx/_internal/exporter/_core.py", line 816, in _translate_fx_graph
raise _errors.ConversionError(
f"Error when translating node {node.format_node()}. See the stack trace for more information."
) from e
torch.onnx._internal.exporter._errors.ConversionError: Error when translating node %pad : [num_users=1] = call_function[target=torch.ops.aten.pad.default](args = (%i, [0, 0, 0, %sym_size_int_2], replicate), kwargs = {}). See the stack trace for more information.
Exported program
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, i: "f32[1, 1, s95, 1]"):
#
sym_size_int_2: "Sym(s95)" = torch.ops.aten.sym_size.int(i, 2)
# File: /home/kistenklaus/Documents/kit/hiwi/vkcnn/fuckyou_dynamo.py:47 in forward, code: x = torch.nn.functional.pad(I, (0, 0, 0, H), mode="replicate")
pad: "f32[1, 1, 2*s95, 1]" = torch.ops.aten.pad.default(i, [0, 0, 0, sym_size_int_2], 'replicate'); i = sym_size_int_2 = None
return (pad,)
Graph signature:
# inputs
i: USER_INPUT
# outputs
pad: USER_OUTPUT
Range constraints: {s95: VR[0, int_oo]}
Analysis
PyTorch ONNX Conversion Analysis
Model Information
The model has 0 parameters and 0 buffers (non-trainable parameters).
Number of parameters per dtype:
defaultdict(<class 'int'>, {})
Number of buffers per dtype:
defaultdict(<class 'int'>, {})
Inputs:
i: TensorMetadata(shape=torch.Size([1, 1, s95, 1]), dtype=torch.float32, requires_grad=False, stride=(s95, s95, 1, 1), memory_format=torch.channels_last, is_quantized=False, qparams={})
Outputs:
pad: TensorMetadata(shape=torch.Size([1, 1, 2*s95, 1]), dtype=torch.float32, requires_grad=False, stride=(2*s95, 2*s95, 1, 1), memory_format=torch.channels_last, is_quantized=False, qparams={})
The FX graph has 4 nodes in total. Number of FX nodes per op:
placeholder: 1
call_function: 2
output: 1
Of the call_function nodes, the counts of operators used are:
aten.sym_size.int: 1
aten.pad.default: 1
ONNX Conversion Information
All operators in the model have registered ONNX decompositions.
Decomposition comparison
Ops exist only in the ExportedProgram before decomposition: []
Ops exist only in the ExportedProgram after decomposition: []