From 3cb493f871acfe164bae5b6b7790219936860348 Mon Sep 17 00:00:00 2001 From: xadupre Date: Mon, 6 Oct 2025 17:55:29 +0200 Subject: [PATCH 1/7] Implements aten_index_put if inputs are SymbolicTensor --- .../function_libs/torch_lib/ops/core.py | 105 +++++++++++++++++- .../function_libs/torch_lib/e2e_ops_tests.py | 43 +++++++ 2 files changed, 143 insertions(+), 5 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index 1a688a4277..08d4fab280 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -4383,6 +4383,11 @@ def aten_index_put( See implementation of `torch.onnx.symbolic_opset11.index_put `_. """ + if any( + isinstance(indice, torch.onnx._internal.exporter._tensors.SymbolicTensor) + for indice in indices + ): + return _aten_index_put_dynamic(self, indices, values, accumulate=accumulate) def _make_reshape_list_broadcastable(reshape_list, values_shape): # Remove ones until the rank of reshape_list matches values_shape. @@ -4452,14 +4457,104 @@ def _make_reshape_list_broadcastable(reshape_list, values_shape): # Flatten values to match the indices flat_values = op.Reshape(values, [-1]) - if accumulate: - result = op.ScatterND(self, new_index, flat_values, reduction="add") - else: - result = op.ScatterND(self, new_index, flat_values) - + scatter_kwargs = dict(reduction="add") if accumulate else {} + result = op.ScatterND(self, new_index, flat_values, **scatter_kwargs) return result +def _aten_index_put_dynamic( + x: TReal, + indices: Sequence[INT64], + values: TReal, + accumulate: bool = False, +) -> TReal: + def _make_range_or_cast(ind, shape_x, static_shape: bool, dim: int): + if ind is not None: + return op.Cast(ind, to=INT64.dtype), False + return ( + op.Cast( + op.Range( # Range does not return a typed result + 0, + op.Squeeze(op.Shape(x, start=dim, end=dim + 1)), + 1, + ), + to=INT64.dtype, + ), + True, + ) + + rk1s = [(ind is None or len(ind.shape) == 1) for ind in indices] + assert all(rk1s) and len(rk1s) == len(x.shape), ( + f"input_put not implemented for indices={indices}, " + f"where rk1s={rk1s}, rank(x)={len(x.shape)}" + ) + shape_x = op.Shape(x) + exped = [] + fixed = [] + reshape_value_shape2 = [] + expand_value_shape = [] + for i, ind in enumerate(indices): + if isinstance(ind, torch.onnx._internal.exporter._tensors.SymbolicTensor): + ind.dtype = ir.DataType.INT64 + ind, expanded = _make_range_or_cast(ind, shape_x, False, i) + if expanded: + exped.append((i, ind)) + expand_value_shape.append(op.Shape(x, start=i, end=i + 1)) + reshape_value_shape2.append([1]) + else: + expand_value_shape.append([1]) + reshape_value_shape2.append(op.Shape(ind)) + fixed.append((i, ind)) + + reshape_value_shape1 = [1] * len(indices) + if len(fixed) <= 1: + reshape_value_shape1 = None + elif fixed: + reshape_value_shape1[fixed[-1][0]] = -1 + + def _mkstride(x, i): + if i >= len(x.shape) - 1: + return [1] + if i == len(x.shape) - 2: + return op.Shape(x, start=i + 1) + return op.ReduceProd(op.Shape(x, start=i + 1), keepdims=1) + + shape = [1] * (len(x.shape) + 1) + mfixed = [] + if fixed: + new_shape = shape.copy() + new_shape[-1] = -1 + mfixed = [op.Reshape(op.Mul(_mkstride(x, i), f), new_shape) for i, f in fixed] + + mexped = [] + for i, e in exped: + new_shape = shape.copy() + new_shape[i] = -1 + mexped.append(op.Reshape(op.Mul(_mkstride(x, i), e), new_shape)) + + # final sum + unflat = None + for a in [*mfixed, *mexped]: + if unflat is None: + unflat = a + continue + unflat = op.Add(unflat, a) + + # value_shape + expanded_values = values + if reshape_value_shape1 is not None: + expanded_values = op.Reshape(expanded_values, op.Concat(*reshape_value_shape1, axis=0)) + # Bug here: Error calling operator 'Concat' with args + # (SymbolicTensor(name='anonymous:124529632436112', producer=anonymous_node:124529631522416, index=0), [1], [1]) + expanded_values = op.Expand(expanded_values, op.Concat(*expand_value_shape, axis=0)) + flat_ind = op.Reshape(unflat, [-1]) + expanded_values = op.Reshape(expanded_values, [-1]) + flat_x = op.Reshape(x, [-1]) + scat_kwargs = {"reduction": "add"} if accumulate else {} + flat_up_x = op.ScatterElements(flat_x, flat_ind, expanded_values, **scat_kwargs) + return op.Reshape(flat_up_x, op.Shape(x)) + + @torch_op("aten::index_put", trace_only=True) def aten_index_put_bool( self: TReal, diff --git a/tests/function_libs/torch_lib/e2e_ops_tests.py b/tests/function_libs/torch_lib/e2e_ops_tests.py index 1b0410c27f..437017af97 100644 --- a/tests/function_libs/torch_lib/e2e_ops_tests.py +++ b/tests/function_libs/torch_lib/e2e_ops_tests.py @@ -225,6 +225,49 @@ def forward(self, q, k, v): ) _testing.assert_onnx_program(onnx_program) + def test_index_put_dynamic(self): + for dimension in [3, 4, 2]: + with self.subTest(dimension=dimension): + + class Model(torch.nn.Module): + def __init__(self, dimension): + super().__init__() + self.params = torch.zeros( + (4, 5) + if dimension == 2 + else ((2, 4, 5) if dimension == 3 else (1, 1, 4, 5)) + ) + self.dimension = dimension + + def forward(self, update, index1, index2): + copy = self.params.clone() + if self.dimension == 2: + copy[index1, index2] = update + elif self.dimension == 3: + copy[:, index1, index2] = update + else: + copy[:, :, index1, index2] = update + return copy + + update = (torch.arange(2) + 10).reshape((2,)).to(torch.float32) + index1 = torch.tensor([1, 2], dtype=torch.int64) + index2 = torch.tensor([3, 4], dtype=torch.int64) + feeds = dict(zip(["update", "index1", "index2"], (update, index1, index2))) + onnx_program = torch.onnx.export( + Model(dimension), + tuple(feeds.values()), + input_names=["update", "index1", "index2"], + output_names=["output"], + opset_version=18, + dynamo=True, + dynamic_shapes={ + "update": {0: "dn"}, + "index1": {0: "dn"}, + "index2": {0: "dn"}, + }, + ) + _testing.assert_onnx_program(onnx_program) + if __name__ == "__main__": unittest.main() From c9472f3bfcf9eb42aa7818de1deb86851fde23f1 Mon Sep 17 00:00:00 2001 From: xadupre Date: Mon, 6 Oct 2025 18:27:21 +0200 Subject: [PATCH 2/7] disable case with one index --- onnxscript/function_libs/torch_lib/ops/core.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index 08d4fab280..aefce5b038 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -4383,7 +4383,7 @@ def aten_index_put( See implementation of `torch.onnx.symbolic_opset11.index_put `_. """ - if any( + if len(indices) > 1 and any( isinstance(indice, torch.onnx._internal.exporter._tensors.SymbolicTensor) for indice in indices ): From 434fcfb0130ae099a78964fca1d83f38220b8254 Mon Sep 17 00:00:00 2001 From: xadupre Date: Mon, 6 Oct 2025 19:15:44 +0200 Subject: [PATCH 3/7] type constant --- .../function_libs/torch_lib/ops/core.py | 26 ++++++++++++------- 1 file changed, 16 insertions(+), 10 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index aefce5b038..e434529d81 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -4468,15 +4468,21 @@ def _aten_index_put_dynamic( values: TReal, accumulate: bool = False, ) -> TReal: + def _1dint(i: int): + return op.Constant(value_ints=ir.AttrInt64s("value_ints", [i])) + + def _0dint(i: int): + return op.Constant(value_int=ir.AttrInt64("value_int", i)) + def _make_range_or_cast(ind, shape_x, static_shape: bool, dim: int): if ind is not None: return op.Cast(ind, to=INT64.dtype), False return ( op.Cast( op.Range( # Range does not return a typed result - 0, + _0dint(0), op.Squeeze(op.Shape(x, start=dim, end=dim + 1)), - 1, + _0dint(1), ), to=INT64.dtype, ), @@ -4500,21 +4506,21 @@ def _make_range_or_cast(ind, shape_x, static_shape: bool, dim: int): if expanded: exped.append((i, ind)) expand_value_shape.append(op.Shape(x, start=i, end=i + 1)) - reshape_value_shape2.append([1]) + reshape_value_shape2.append(_1dint(1)) else: - expand_value_shape.append([1]) + expand_value_shape.append(_1dint(1)) reshape_value_shape2.append(op.Shape(ind)) fixed.append((i, ind)) - reshape_value_shape1 = [1] * len(indices) + reshape_value_shape1 = [_1dint(1)] * len(indices) if len(fixed) <= 1: reshape_value_shape1 = None elif fixed: - reshape_value_shape1[fixed[-1][0]] = -1 + reshape_value_shape1[fixed[-1][0]] = _1dint(-1) def _mkstride(x, i): if i >= len(x.shape) - 1: - return [1] + return _1dint(1) if i == len(x.shape) - 2: return op.Shape(x, start=i + 1) return op.ReduceProd(op.Shape(x, start=i + 1), keepdims=1) @@ -4547,9 +4553,9 @@ def _mkstride(x, i): # Bug here: Error calling operator 'Concat' with args # (SymbolicTensor(name='anonymous:124529632436112', producer=anonymous_node:124529631522416, index=0), [1], [1]) expanded_values = op.Expand(expanded_values, op.Concat(*expand_value_shape, axis=0)) - flat_ind = op.Reshape(unflat, [-1]) - expanded_values = op.Reshape(expanded_values, [-1]) - flat_x = op.Reshape(x, [-1]) + flat_ind = op.Reshape(unflat, _1dint(-1)) + expanded_values = op.Reshape(expanded_values, _1dint(-1)) + flat_x = op.Reshape(x, _1dint(-1)) scat_kwargs = {"reduction": "add"} if accumulate else {} flat_up_x = op.ScatterElements(flat_x, flat_ind, expanded_values, **scat_kwargs) return op.Reshape(flat_up_x, op.Shape(x)) From ae6adca76ea8a6286053a07534ecb4dbd1a840e4 Mon Sep 17 00:00:00 2001 From: xadupre Date: Tue, 7 Oct 2025 10:15:39 +0200 Subject: [PATCH 4/7] another fix --- onnxscript/function_libs/torch_lib/ops/core.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index ecca0de8e5..45adf39df3 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -4386,7 +4386,7 @@ def aten_index_put( if len(indices) > 1 and any( isinstance(indice, torch.onnx._internal.exporter._tensors.SymbolicTensor) for indice in indices - ): + ) and len(values.shape) == 1: return _aten_index_put_dynamic(self, indices, values, accumulate=accumulate) def _make_reshape_list_broadcastable(reshape_list, values_shape): From 851036421aa9373039cd019405f5a9fc179c7a83 Mon Sep 17 00:00:00 2001 From: xadupre Date: Tue, 7 Oct 2025 10:31:58 +0200 Subject: [PATCH 5/7] rename --- .../function_libs/torch_lib/ops/core.py | 24 ++++++++++--------- 1 file changed, 13 insertions(+), 11 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index 45adf39df3..330cf02ef0 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -4383,10 +4383,14 @@ def aten_index_put( See implementation of `torch.onnx.symbolic_opset11.index_put `_. """ - if len(indices) > 1 and any( - isinstance(indice, torch.onnx._internal.exporter._tensors.SymbolicTensor) - for indice in indices - ) and len(values.shape) == 1: + if ( + len(indices) > 1 + and any( + isinstance(indice, torch.onnx._internal.exporter._tensors.SymbolicTensor) + for indice in indices + ) + and len(values.shape) == 1 + ): return _aten_index_put_dynamic(self, indices, values, accumulate=accumulate) def _make_reshape_list_broadcastable(reshape_list, values_shape): @@ -4526,21 +4530,21 @@ def _mkstride(x, i): return op.ReduceProd(op.Shape(x, start=i + 1), keepdims=1) shape = [1] * (len(x.shape) + 1) - mfixed = [] + r_fixed = [] if fixed: new_shape = shape.copy() new_shape[-1] = -1 - mfixed = [op.Reshape(op.Mul(_mkstride(x, i), f), new_shape) for i, f in fixed] + r_fixed = [op.Reshape(op.Mul(_mkstride(x, i), f), new_shape) for i, f in fixed] - mexped = [] + r_exped = [] for i, e in exped: new_shape = shape.copy() new_shape[i] = -1 - mexped.append(op.Reshape(op.Mul(_mkstride(x, i), e), new_shape)) + r_exped.append(op.Reshape(op.Mul(_mkstride(x, i), e), new_shape)) # final sum unflat = None - for a in [*mfixed, *mexped]: + for a in [*r_fixed, *r_exped]: if unflat is None: unflat = a continue @@ -4550,8 +4554,6 @@ def _mkstride(x, i): expanded_values = values if reshape_value_shape1 is not None: expanded_values = op.Reshape(expanded_values, op.Concat(*reshape_value_shape1, axis=0)) - # Bug here: Error calling operator 'Concat' with args - # (SymbolicTensor(name='anonymous:124529632436112', producer=anonymous_node:124529631522416, index=0), [1], [1]) expanded_values = op.Expand(expanded_values, op.Concat(*expand_value_shape, axis=0)) flat_ind = op.Reshape(unflat, _1dint(-1)) expanded_values = op.Reshape(expanded_values, _1dint(-1)) From e4d574a4025a5e4137985c5963164b0e785f3d27 Mon Sep 17 00:00:00 2001 From: xadupre Date: Tue, 7 Oct 2025 11:27:44 +0200 Subject: [PATCH 6/7] style --- onnxscript/function_libs/torch_lib/ops/core.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index 330cf02ef0..3d38c7d806 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -4494,10 +4494,6 @@ def _make_range_or_cast(ind, shape_x, static_shape: bool, dim: int): ) rk1s = [(ind is None or len(ind.shape) == 1) for ind in indices] - assert all(rk1s) and len(rk1s) == len(x.shape), ( - f"input_put not implemented for indices={indices}, " - f"where rk1s={rk1s}, rank(x)={len(x.shape)}" - ) shape_x = op.Shape(x) exped = [] fixed = [] From 86d482d588ac58e5c0df8a3902cb7ee255296537 Mon Sep 17 00:00:00 2001 From: xadupre Date: Tue, 7 Oct 2025 11:47:22 +0200 Subject: [PATCH 7/7] lint --- onnxscript/function_libs/torch_lib/ops/core.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index 3d38c7d806..078b4207de 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -4386,7 +4386,7 @@ def aten_index_put( if ( len(indices) > 1 and any( - isinstance(indice, torch.onnx._internal.exporter._tensors.SymbolicTensor) + isinstance(indice, torch.onnx._internal.exporter._tensors.SymbolicTensor) # pylint: disable=protected-access for indice in indices ) and len(values.shape) == 1 @@ -4493,14 +4493,13 @@ def _make_range_or_cast(ind, shape_x, static_shape: bool, dim: int): True, ) - rk1s = [(ind is None or len(ind.shape) == 1) for ind in indices] shape_x = op.Shape(x) exped = [] fixed = [] reshape_value_shape2 = [] expand_value_shape = [] for i, ind in enumerate(indices): - if isinstance(ind, torch.onnx._internal.exporter._tensors.SymbolicTensor): + if isinstance(ind, torch.onnx._internal.exporter._tensors.SymbolicTensor): # pylint: disable=protected-access ind.dtype = ir.DataType.INT64 ind, expanded = _make_range_or_cast(ind, shape_x, False, i) if expanded: