Skip to content

Commit

Permalink
[Torch] Add index_put operator
Browse files Browse the repository at this point in the history
  • Loading branch information
apivovarov committed Feb 18, 2021
1 parent d280118 commit d2c1b77
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 0 deletions.
25 changes: 25 additions & 0 deletions python/tvm/relay/frontend/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -2010,6 +2010,29 @@ def scatter(self, inputs, input_types):
src = inputs[3]
return _op.transform.scatter(data, index, src, axis)

def index_put(self, inputs, input_types):
    """Lower torch.index_put / torch.index_put_ to Relay scatter_nd.

    inputs layout: [input tensor, list of per-dimension index tensors,
    values tensor, accumulate flag].

    Known limitations (warned about below):
    - the accumulate flag is ignored: Relay scatter_nd always accumulates,
      so results only match torch when the indices contain no duplicates;
    - scatter_nd cannot take the original tensor as a base, so the torch
      input tensor is assumed to be all zeros.
    """
    data, index_tensors, values, accumulate = (
        inputs[0],
        inputs[1],
        inputs[2],
        inputs[3],
    )
    # accumulate parameter is ignored.
    # torch.index_put default is False but Relay.scatter_nd accumulates values.
    # We assume there is no duplicate indices in torch.index_put input
    if not accumulate:
        logging.warning(
            "torch.index_put accumulate parameter is False. "
            "TVM uses tvm.relay.scatter_nd operator which accumulates values. "
            "Make sure there is no duplicate indices in torch.index_put input."
        )
    # Relay scatter_nd does not support input tensor
    # We assume that torch.index_put is used with empty zero-values input tensor
    # scatter_nd will create empty zero-values tensor with a given shape
    out_shape = self.infer_shape(data)
    logging.warning(
        "tvm.relay.scatter_nd operator does not support input tensor parameter. "
        "TVM assumes that torch.index_put is used with empty zero-values input tensor"
    )
    # Combine array of index tensors into one index tensor with shape (N,_)
    expanded = []
    for idx in index_tensors:
        expanded.append(self.unsqueeze((idx, 0), None))
    merged_indices = self.concatenate((expanded, 0), None)
    return _op.transform.scatter_nd(values, merged_indices, out_shape)

def scalar_tensor(self, inputs, input_types):
data = inputs[0]
cast_map = {
Expand Down Expand Up @@ -2326,6 +2349,8 @@ def create_convert_map(self):
"aten::nonzero": self.nonzero,
"aten::nonzero_numpy": self.nonzero_numpy,
"aten::scatter": self.scatter,
"aten::index_put": self.index_put,
"aten::index_put_": self.index_put,
"aten::scalar_tensor": self.scalar_tensor,
"aten::__interpolate": self.interpolate,
"aten::IntImplicit": self.identity,
Expand Down
44 changes: 44 additions & 0 deletions tests/python/frontend/pytorch/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,20 @@ def test_forward_pixel_shuffle():
verify_model(torch.nn.PixelShuffle(3).float().eval(), input_data=input_data)
verify_model(torch.nn.PixelShuffle(4).float().eval(), input_data=input_data)

@tvm.testing.uses_gpu
def test_forward_input_put():
    """In-place torch.index_put_ on a zero-initialized 3x3 float tensor."""
    torch.set_grad_enabled(False)

    class Zeros1(Module):
        def forward(self, *args):
            # Scatter four values at (row, col) pairs into the input tensor.
            rows = torch.tensor([0, 1, 2, 2])
            cols = torch.tensor([0, 1, 1, 2])
            vals = torch.tensor([2.0, 4.0, 7.0, 9.0])
            return torch.index_put_(args[0], indices=[rows, cols], values=vals)

    verify_model(Zeros1(), input_data=torch.zeros([3, 3], dtype=torch.float))

@tvm.testing.uses_gpu
def test_forward_add():
Expand Down Expand Up @@ -3327,6 +3341,36 @@ def test_fn_scatter_add(dim):
verify_trace_model(test_fn_scatter_add(1), [in_data, in_index, in_src], targets)


def test_forward_index_put():
    """Out-of-place torch.index_put via tracing.

    Covers a 2D tensor with the default accumulate (False) and a 3D tensor
    with accumulate=True, on both llvm and cuda targets.
    """

    # torch.index_put for 2D tensor and default accumulate (False)
    def index_put_2d(data, xidx, yidx, values):
        return torch.index_put(data, indices=[xidx, yidx], values=values)

    # torch.index_put for 3D tensor and accumulate=True
    def index_put_3d_accumulate(data, xidx, yidx, zidx, values):
        return torch.index_put(
            data, indices=[xidx, yidx, zidx], values=values, accumulate=True
        )

    targets = ["llvm", "cuda"]

    args_2d = [
        torch.zeros((3, 5)),
        torch.tensor([0, 1, 2, 2]),
        torch.tensor([0, 1, 3, 4]),
        torch.tensor([2.0, 4.0, 7.0, 9.0]),
    ]
    verify_trace_model(index_put_2d, args_2d, targets)

    args_3d = [
        torch.zeros((3, 5, 3)),
        torch.tensor([0, 1, 2, 2, 0]),
        torch.tensor([0, 1, 3, 4, 0]),
        torch.tensor([0, 1, 1, 2, 0]),
        torch.tensor([2.0, 4.0, 7.0, 9.0, 1.0]),
    ]
    verify_trace_model(index_put_3d_accumulate, args_3d, targets)


def test_numel():
class Numel(Module):
def forward(self, data):
Expand Down

0 comments on commit d2c1b77

Please sign in to comment.