diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py
index 5127f3f9f..76227800e 100644
--- a/onnxscript/function_libs/torch_lib/ops/core.py
+++ b/onnxscript/function_libs/torch_lib/ops/core.py
@@ -4229,6 +4229,15 @@ def aten_index_put(
     See implementation of `torch.onnx.symbolic_opset11.index_put
     `_.
     """
+    if (
+        len(indices) > 1
+        and any(
+            isinstance(indice, torch.onnx._internal.exporter._tensors.SymbolicTensor)  # pylint: disable=protected-access
+            for indice in indices
+        )
+        and len(values.shape) == 1
+    ):
+        return _aten_index_put_dynamic(self, indices, values, accumulate=accumulate)
 
     def _make_reshape_list_broadcastable(reshape_list, values_shape):
         # Remove ones until the rank of reshape_list matches values_shape.
@@ -4298,14 +4307,103 @@ def _make_reshape_list_broadcastable(reshape_list, values_shape):
     # Flatten values to match the indices
     flat_values = op.Reshape(values, [-1])
 
-    if accumulate:
-        result = op.ScatterND(self, new_index, flat_values, reduction="add")
-    else:
-        result = op.ScatterND(self, new_index, flat_values)
-
+    scatter_kwargs = dict(reduction="add") if accumulate else {}
+    result = op.ScatterND(self, new_index, flat_values, **scatter_kwargs)
     return result
 
 
+def _aten_index_put_dynamic(
+    x: TReal,
+    indices: Sequence[INT64],
+    values: TReal,
+    accumulate: bool = False,
+) -> TReal:
+    def _1dint(i: int):
+        return op.Constant(value_ints=ir.AttrInt64s("value_ints", [i]))
+
+    def _0dint(i: int):
+        return op.Constant(value_int=ir.AttrInt64("value_int", i))
+
+    def _make_range_or_cast(ind, shape_x, static_shape: bool, dim: int):
+        if ind is not None:
+            return op.Cast(ind, to=INT64.dtype), False
+        return (
+            op.Cast(
+                op.Range(  # Range does not return a typed result
+                    _0dint(0),
+                    op.Squeeze(op.Shape(x, start=dim, end=dim + 1)),
+                    _0dint(1),
+                ),
+                to=INT64.dtype,
+            ),
+            True,
+        )
+
+    shape_x = op.Shape(x)
+    exped = []
+    fixed = []
+    reshape_value_shape2 = []
+    expand_value_shape = []
+    for i, ind in enumerate(indices):
+        if isinstance(ind, torch.onnx._internal.exporter._tensors.SymbolicTensor):  # pylint: disable=protected-access
+            ind.dtype = ir.DataType.INT64
+        ind, expanded = _make_range_or_cast(ind, shape_x, False, i)
+        if expanded:
+            exped.append((i, ind))
+            expand_value_shape.append(op.Shape(x, start=i, end=i + 1))
+            reshape_value_shape2.append(_1dint(1))
+        else:
+            expand_value_shape.append(_1dint(1))
+            reshape_value_shape2.append(op.Shape(ind))
+            fixed.append((i, ind))
+
+    reshape_value_shape1 = [_1dint(1)] * len(indices)
+    if len(fixed) <= 1:
+        reshape_value_shape1 = None
+    elif fixed:
+        reshape_value_shape1[fixed[-1][0]] = _1dint(-1)
+
+    def _mkstride(x, i):
+        if i >= len(x.shape) - 1:
+            return _1dint(1)
+        if i == len(x.shape) - 2:
+            return op.Shape(x, start=i + 1)
+        return op.ReduceProd(op.Shape(x, start=i + 1), keepdims=1)
+
+    shape = [1] * (len(x.shape) + 1)
+    r_fixed = []
+    if fixed:
+        new_shape = shape.copy()
+        new_shape[-1] = -1
+        r_fixed = [op.Reshape(op.Mul(_mkstride(x, i), f), new_shape) for i, f in fixed]
+
+    r_exped = []
+    for i, e in exped:
+        new_shape = shape.copy()
+        new_shape[i] = -1
+        r_exped.append(op.Reshape(op.Mul(_mkstride(x, i), e), new_shape))
+
+    # final sum
+    unflat = None
+    for a in [*r_fixed, *r_exped]:
+        if unflat is None:
+            unflat = a
+            continue
+        unflat = op.Add(unflat, a)
+
+    # value_shape
+    expanded_values = values
+    if reshape_value_shape1 is not None:
+        expanded_values = op.Reshape(expanded_values, op.Concat(*reshape_value_shape1, axis=0))
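+    # Expand broadcasts the values across every axis that was materialized as a
+    # full Range: expand_value_shape holds the axis extent for those axes and a
+    # constant 1 for axes indexed by a user-supplied tensor.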
+    expanded_values = op.Expand(expanded_values, op.Concat(*expand_value_shape, axis=0))
+    flat_ind = op.Reshape(unflat, _1dint(-1))
+    expanded_values = op.Reshape(expanded_values, _1dint(-1))
+    flat_x = op.Reshape(x, _1dint(-1))
+    scat_kwargs = {"reduction": "add"} if accumulate else {}
+    flat_up_x = op.ScatterElements(flat_x, flat_ind, expanded_values, **scat_kwargs)
+    return op.Reshape(flat_up_x, op.Shape(x))
+
+
 @torch_op("aten::index_put", trace_only=True)
 def aten_index_put_bool(
     self: TReal,
diff --git a/tests/function_libs/torch_lib/e2e_ops_tests.py b/tests/function_libs/torch_lib/e2e_ops_tests.py
index 754f5e2a2..2402d024c 100644
--- a/tests/function_libs/torch_lib/e2e_ops_tests.py
+++ b/tests/function_libs/torch_lib/e2e_ops_tests.py
@@ -225,6 +225,49 @@ def forward(self, q, k, v):
         )
         _testing.assert_onnx_program(onnx_program)
 
+    def test_index_put_dynamic(self):
+        for dimension in [3, 4, 2]:
+            with self.subTest(dimension=dimension):
+
+                class Model(torch.nn.Module):
+                    def __init__(self, dimension):
+                        super().__init__()
+                        self.params = torch.zeros(
+                            (4, 5)
+                            if dimension == 2
+                            else ((2, 4, 5) if dimension == 3 else (1, 1, 4, 5))
+                        )
+                        self.dimension = dimension
+
+                    def forward(self, update, index1, index2):
+                        copy = self.params.clone()
+                        if self.dimension == 2:
+                            copy[index1, index2] = update
+                        elif self.dimension == 3:
+                            copy[:, index1, index2] = update
+                        else:
+                            copy[:, :, index1, index2] = update
+                        return copy
+
+                update = (torch.arange(2) + 10).reshape((2,)).to(torch.float32)
+                index1 = torch.tensor([1, 2], dtype=torch.int64)
+                index2 = torch.tensor([3, 4], dtype=torch.int64)
+                feeds = dict(zip(["update", "index1", "index2"], (update, index1, index2)))
+                onnx_program = torch.onnx.export(
+                    Model(dimension),
+                    tuple(feeds.values()),
+                    input_names=["update", "index1", "index2"],
+                    output_names=["output"],
+                    opset_version=18,
+                    dynamo=True,
+                    dynamic_shapes={
+                        "update": {0: "dn"},
+                        "index1": {0: "dn"},
+                        "index2": {0: "dn"},
+                    },
+                )
+                _testing.assert_onnx_program(onnx_program)
+
     def test_bitwise_and_scalar(self):
         class Model(torch.nn.Module):
             def forward(self, x):
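
As a sketch of the idea behind `_aten_index_put_dynamic` (not part of the patch): each index tensor is scaled by the stride of its axis (the product of the trailing extents, which is what `_mkstride` emits as ONNX ops), the per-axis terms are broadcast-added into one flat index, and a ScatterElements on the flattened input followed by a Reshape reproduces the advanced-indexing assignment. The NumPy below mirrors the `dimension == 3` test case; names and shapes are illustrative only.

import numpy as np

# Replay x[:, index1, index2] = values with flat indices.
x = np.zeros((2, 4, 5), dtype=np.float32)
index1 = np.array([1, 2], dtype=np.int64)
index2 = np.array([3, 4], dtype=np.int64)
values = np.array([10.0, 11.0], dtype=np.float32)

# Per-axis strides, as _mkstride derives them from op.Shape:
# axis 0 -> 4 * 5, axis 1 -> 5, axis 2 -> 1.
s0, s1, s2 = 4 * 5, 5, 1

# Axis 0 is a ':' slice, so it becomes an implicit Range (the expanded case
# of _make_range_or_cast); the user-supplied tensors keep their values.
i0 = np.arange(2).reshape(2, 1)
flat_ind = (i0 * s0 + index1 * s1 + index2 * s2).reshape(-1)

# Values are broadcast across the Range axis, then scattered into the flat input.
flat_values = np.broadcast_to(values, (2, 2)).reshape(-1)
flat_x = x.reshape(-1).copy()
flat_x[flat_ind] = flat_values  # ScatterElements with no reduction
result = flat_x.reshape(x.shape)

expected = x.copy()
expected[:, index1, index2] = values
assert np.array_equal(result, expected)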