108 changes: 103 additions & 5 deletions onnxscript/function_libs/torch_lib/ops/core.py
@@ -4229,6 +4229,15 @@ def aten_index_put(
    See implementation of `torch.onnx.symbolic_opset11.index_put
    <https://github.com/pytorch/pytorch/blob/main/torch/onnx/symbolic_opset11.py#L212>`_.
    """
    if (
        len(indices) > 1
        and any(
            # SymbolicTensor indices are only known at run time, so the static
            # reshape/ScatterND construction below cannot be used.
            isinstance(indice, torch.onnx._internal.exporter._tensors.SymbolicTensor)  # pylint: disable=protected-access
            for indice in indices
        )
        and len(values.shape) == 1
    ):
        return _aten_index_put_dynamic(self, indices, values, accumulate=accumulate)

    def _make_reshape_list_broadcastable(reshape_list, values_shape):
        # Remove ones until the rank of reshape_list matches values_shape.
@@ -4298,14 +4307,103 @@ def _make_reshape_list_broadcastable(reshape_list, values_shape):
     # Flatten values to match the indices
     flat_values = op.Reshape(values, [-1])

-    if accumulate:
-        result = op.ScatterND(self, new_index, flat_values, reduction="add")
-    else:
-        result = op.ScatterND(self, new_index, flat_values)
+    scatter_kwargs = dict(reduction="add") if accumulate else {}
+    result = op.ScatterND(self, new_index, flat_values, **scatter_kwargs)
     return result
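For readers skimming the hunk above: ScatterND writes one entry of `updates` at each index tuple in `new_index`, and `reduction="add"` makes it accumulate instead of overwrite, which is what the `accumulate` flag maps to. A minimal NumPy sketch of those semantics (illustrative only, not part of this PR; `scatter_nd` is a hypothetical helper, and the sketch assumes full-rank index tuples with scalar per-tuple updates, as in the code above):

import numpy as np

def scatter_nd(data, indices, updates, reduction=None):
    # Each row of `indices` addresses one element of `data`; ScatterND writes
    # (or, with reduction="add", adds) the matching entry of `updates` there.
    out = data.copy()
    for idx, upd in zip(indices.reshape(-1, indices.shape[-1]), updates.reshape(-1)):
        if reduction == "add":
            out[tuple(idx)] += upd
        else:
            out[tuple(idx)] = upd
    return out

x = np.zeros((4, 5), dtype=np.float32)
new_index = np.array([[1, 3], [2, 4]])                     # one (row, col) per update
flat_values = np.array([10.0, 11.0], dtype=np.float32)
print(scatter_nd(x, new_index, flat_values)[1, 3])         # 10.0
print(scatter_nd(x, new_index, flat_values, "add")[2, 4])  # 11.0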


def _aten_index_put_dynamic(
    x: TReal,
    indices: Sequence[INT64],
    values: TReal,
    accumulate: bool = False,
) -> TReal:
    """Dynamic-shape variant of index_put: builds flat indices from
    per-dimension strides and scatters into the flattened input."""

    def _1dint(i: int):
        # 1-D constant tensor [i].
        return op.Constant(value_ints=ir.AttrInt64s("value_ints", [i]))

    def _0dint(i: int):
        # 0-D (scalar) constant tensor i.
        return op.Constant(value_int=ir.AttrInt64("value_int", i))

    def _make_range_or_cast(ind, shape_x, static_shape: bool, dim: int):
        # Cast an explicit index to INT64, or synthesize 0..dim_size-1 with
        # Range when no index was given for this dimension.
        if ind is not None:
            return op.Cast(ind, to=INT64.dtype), False
        return (
            op.Cast(
                op.Range(  # Range does not return a typed result
                    _0dint(0),
                    op.Squeeze(op.Shape(x, start=dim, end=dim + 1)),
                    _0dint(1),
                ),
                to=INT64.dtype,
            ),
            True,
        )

    shape_x = op.Shape(x)
    exped = []
    fixed = []
    reshape_value_shape2 = []
    expand_value_shape = []
    for i, ind in enumerate(indices):
        if isinstance(ind, torch.onnx._internal.exporter._tensors.SymbolicTensor):  # pylint: disable=protected-access
            ind.dtype = ir.DataType.INT64
        ind, expanded = _make_range_or_cast(ind, shape_x, False, i)
        if expanded:
            exped.append((i, ind))
            expand_value_shape.append(op.Shape(x, start=i, end=i + 1))
            reshape_value_shape2.append(_1dint(1))
        else:
            expand_value_shape.append(_1dint(1))
            reshape_value_shape2.append(op.Shape(ind))
            fixed.append((i, ind))

    reshape_value_shape1 = [_1dint(1)] * len(indices)
    if len(fixed) <= 1:
        reshape_value_shape1 = None
    elif fixed:
        reshape_value_shape1[fixed[-1][0]] = _1dint(-1)

    def _mkstride(x, i):
        # Number of elements spanned by one step along dimension i.
        if i >= len(x.shape) - 1:
            return _1dint(1)
        if i == len(x.shape) - 2:
            return op.Shape(x, start=i + 1)
        return op.ReduceProd(op.Shape(x, start=i + 1), keepdims=1)

    shape = [1] * (len(x.shape) + 1)
    r_fixed = []
    if fixed:
        new_shape = shape.copy()
        new_shape[-1] = -1
        r_fixed = [op.Reshape(op.Mul(_mkstride(x, i), f), new_shape) for i, f in fixed]

    r_exped = []
    for i, e in exped:
        new_shape = shape.copy()
        new_shape[i] = -1
        r_exped.append(op.Reshape(op.Mul(_mkstride(x, i), e), new_shape))

    # Final sum: broadcast-add the per-dimension contributions into flat indices.
    unflat = None
    for a in [*r_fixed, *r_exped]:
        if unflat is None:
            unflat = a
            continue
        unflat = op.Add(unflat, a)

    # Reshape/expand values so they line up with the flattened indices.
    expanded_values = values
    if reshape_value_shape1 is not None:
        expanded_values = op.Reshape(expanded_values, op.Concat(*reshape_value_shape1, axis=0))
    expanded_values = op.Expand(expanded_values, op.Concat(*expand_value_shape, axis=0))
    flat_ind = op.Reshape(unflat, _1dint(-1))
    expanded_values = op.Reshape(expanded_values, _1dint(-1))
    flat_x = op.Reshape(x, _1dint(-1))
    scat_kwargs = {"reduction": "add"} if accumulate else {}
    flat_up_x = op.ScatterElements(flat_x, flat_ind, expanded_values, **scat_kwargs)
    return op.Reshape(flat_up_x, op.Shape(x))
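`_aten_index_put_dynamic` sidesteps ScatterND's static-shape bookkeeping by scattering into the flattened tensor: each given index contributes stride_i * ind_i to a flat offset (`_mkstride`), dimensions without an explicit index are enumerated with Range and broadcast against the rest, and `values` is expanded to match before a single ScatterElements on the 1-D view. A NumPy sketch of the same stride trick for a rank-3 input indexed as `x[:, i1, i2]` (illustrative assumption, not part of the PR):

import numpy as np

x = np.zeros((2, 4, 5), dtype=np.float32)
i1 = np.array([1, 2])                        # dynamic indices of length dn
i2 = np.array([3, 4])
v = np.array([10.0, 11.0], dtype=np.float32)

d0 = np.arange(x.shape[0])                   # the "missing" leading dim, via Range
strides = [4 * 5, 5, 1]                      # ReduceProd of the trailing shape per dim
flat_ind = (strides[0] * d0[:, None]         # broadcast-sum over expanded + fixed dims
            + strides[1] * i1[None, :]
            + strides[2] * i2[None, :]).reshape(-1)
flat_x = x.reshape(-1).copy()
flat_x[flat_ind] = np.tile(v, x.shape[0])    # values expanded over the leading dim
out = flat_x.reshape(x.shape)
assert (out[:, i1, i2] == v).all()           # same result as x[:, i1, i2] = v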


@torch_op("aten::index_put", trace_only=True)
def aten_index_put_bool(
    self: TReal,
43 changes: 43 additions & 0 deletions tests/function_libs/torch_lib/e2e_ops_tests.py
@@ -225,6 +225,49 @@ def forward(self, q, k, v):
        )
        _testing.assert_onnx_program(onnx_program)

    def test_index_put_dynamic(self):
        for dimension in [3, 4, 2]:
            with self.subTest(dimension=dimension):

                class Model(torch.nn.Module):
                    def __init__(self, dimension):
                        super().__init__()
                        self.params = torch.zeros(
                            (4, 5)
                            if dimension == 2
                            else ((2, 4, 5) if dimension == 3 else (1, 1, 4, 5))
                        )
                        self.dimension = dimension

                    def forward(self, update, index1, index2):
                        copy = self.params.clone()
                        if self.dimension == 2:
                            copy[index1, index2] = update
                        elif self.dimension == 3:
                            copy[:, index1, index2] = update
                        else:
                            copy[:, :, index1, index2] = update
                        return copy

                update = (torch.arange(2) + 10).reshape((2,)).to(torch.float32)
                index1 = torch.tensor([1, 2], dtype=torch.int64)
                index2 = torch.tensor([3, 4], dtype=torch.int64)
                feeds = dict(zip(["update", "index1", "index2"], (update, index1, index2)))
                onnx_program = torch.onnx.export(
                    Model(dimension),
                    tuple(feeds.values()),
                    input_names=["update", "index1", "index2"],
                    output_names=["output"],
                    opset_version=18,
                    dynamo=True,
                    dynamic_shapes={
                        "update": {0: "dn"},
                        "index1": {0: "dn"},
                        "index2": {0: "dn"},
                    },
                )
                _testing.assert_onnx_program(onnx_program)
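As a sanity check of what the exporter must reproduce, the `dimension == 3` case of the model above behaves like this in eager mode (sketch reusing the test's names; the concrete values are assumptions for illustration):

m = Model(3)                                   # params has shape (2, 4, 5)
update = torch.tensor([10.0, 11.0])
index1 = torch.tensor([1, 2])
index2 = torch.tensor([3, 4])
out = m(update, index1, index2)                # copy[:, index1, index2] = update
assert (out[:, index1, index2] == update).all()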

    def test_bitwise_and_scalar(self):
        class Model(torch.nn.Module):
            def forward(self, x):