[Torch] Add index_put operator (#7465)
* [Torch] Add index_put operator

* Skip test_frontends.py::test_load_model__pth
apivovarov authored Feb 18, 2021
1 parent 944d8d1 commit 50e013d
Showing 3 changed files with 61 additions and 0 deletions.
28 changes: 28 additions & 0 deletions python/tvm/relay/frontend/pytorch.py
@@ -2010,6 +2010,32 @@ def scatter(self, inputs, input_types):
        src = inputs[3]
        return _op.transform.scatter(data, index, src, axis)

    def index_put(self, inputs, input_types):
        in_tensor = inputs[0]
        indices = inputs[1]
        values = inputs[2]
        accumulate = inputs[3]
        # The accumulate parameter is ignored.
        # torch.index_put defaults to accumulate=False, but relay.scatter_nd accumulates values.
        # We assume there are no duplicate indices in the torch.index_put input.
        if not accumulate:
            logging.warning(
                "torch.index_put accumulate parameter is False. "
                "TVM uses the tvm.relay.scatter_nd operator, which accumulates values. "
                "Make sure there are no duplicate indices in the torch.index_put input."
            )
        # Relay scatter_nd does not support an input tensor parameter.
        # We assume that torch.index_put is used with an all-zeros input tensor;
        # scatter_nd creates an all-zeros tensor of the given shape.
        out_shape = self.infer_shape(in_tensor)
        logging.warning(
            "tvm.relay.scatter_nd operator does not support input tensor parameter. "
            "TVM assumes that torch.index_put is used with an all-zeros input tensor."
        )
        # Combine the list of index tensors into one index tensor of shape (N, _)
        index_tensor = _op.stack(indices, axis=0)
        return _op.transform.scatter_nd(values, index_tensor, out_shape)

    def scalar_tensor(self, inputs, input_types):
        data = inputs[0]
        cast_map = {
@@ -2326,6 +2352,8 @@ def create_convert_map(self):
"aten::nonzero": self.nonzero,
"aten::nonzero_numpy": self.nonzero_numpy,
"aten::scatter": self.scatter,
"aten::index_put": self.index_put,
"aten::index_put_": self.index_put,
"aten::scalar_tensor": self.scalar_tensor,
"aten::__interpolate": self.interpolate,
"aten::IntImplicit": self.identity,
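A minimal sketch (not part of this commit) of the duplicate-index caveat the warning above describes, using only stock PyTorch: with accumulate=False, torch.index_put leaves the result unspecified when indices repeat, whereas relay.scatter_nd always sums duplicate writes, matching accumulate=True.

# Sketch only; illustrates the converter's caveat, not TVM API usage.
import torch

data = torch.zeros(4)
idx = torch.tensor([1, 1])  # deliberately duplicated index
vals = torch.tensor([3.0, 5.0])

# accumulate=False with duplicates: which write survives is unspecified in torch.
print(torch.index_put(data, (idx,), vals))
# accumulate=True: duplicate writes are summed -> tensor([0., 8., 0., 0.]),
# which matches what relay.scatter_nd produces.
print(torch.index_put(data, (idx,), vals, accumulate=True))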
1 change: 1 addition & 0 deletions tests/python/driver/tvmc/test_frontends.py
@@ -174,6 +174,7 @@ def test_load_model___wrong_language__to_onnx(tflite_mobilenet_v1_1_quant):
        tvmc.frontends.load_model(tflite_mobilenet_v1_1_quant, model_format="onnx")


@pytest.mark.skip(reason="https://github.com/apache/tvm/issues/7455")
def test_load_model__pth(pytorch_resnet18):
    # some CI environments won't offer torch, so skip in case it is not present
    pytest.importorskip("torch")
32 changes: 32 additions & 0 deletions tests/python/frontend/pytorch/test_forward.py
@@ -3327,6 +3327,38 @@ def test_fn_scatter_add(dim):
    verify_trace_model(test_fn_scatter_add(1), [in_data, in_index, in_src], targets)


def test_forward_index_put():
    # torch.index_put for 2D tensor and default accumulate (False)
    def test_fn_index_put2():
        return lambda data, xidx, yidx, values: torch.index_put(
            data, indices=[xidx, yidx], values=values
        )

    # torch.index_put for 3D tensor and accumulate=True
    def test_fn_index_put3a():
        return lambda data, xidx, yidx, zidx, values: torch.index_put(
            data, indices=[xidx, yidx, zidx], values=values, accumulate=True
        )

    shape = (3, 5)
    in_data = torch.zeros(shape)
    xidx = torch.tensor([0, 1, 2, 2])
    yidx = torch.tensor([0, 1, 3, 4])
    values = torch.tensor([2.0, 4.0, 7.0, 9.0])

    targets = ["llvm", "cuda"]
    verify_trace_model(test_fn_index_put2(), [in_data, xidx, yidx, values], targets)

    shape = (3, 5, 3)
    in_data = torch.zeros(shape)
    xidx = torch.tensor([0, 1, 2, 2, 0])
    yidx = torch.tensor([0, 1, 3, 4, 0])
    zidx = torch.tensor([0, 1, 1, 2, 0])
    values = torch.tensor([2.0, 4.0, 7.0, 9.0, 1.0])

    verify_trace_model(test_fn_index_put3a(), [in_data, xidx, yidx, zidx, values], targets)


def test_numel():
    class Numel(Module):
        def forward(self, data):
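For reference, a quick sanity sketch (my addition, not part of the test file) of what the 2D case above computes: the four values are scattered into a (3, 5) zero tensor at coordinates (0, 0), (1, 1), (2, 3), (2, 4).

# Sanity check of the 2D test case's expected output, in plain PyTorch.
import torch

out = torch.index_put(
    torch.zeros(3, 5),
    indices=[torch.tensor([0, 1, 2, 2]), torch.tensor([0, 1, 3, 4])],
    values=torch.tensor([2.0, 4.0, 7.0, 9.0]),
)
expected = torch.tensor(
    [
        [2.0, 0.0, 0.0, 0.0, 0.0],
        [0.0, 4.0, 0.0, 0.0, 0.0],
        [0.0, 0.0, 0.0, 7.0, 9.0],
    ]
)
assert torch.equal(out, expected)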
