From 7aecc1a44d0988c175527fa974d2295927cb1007 Mon Sep 17 00:00:00 2001 From: padreofthegame <97688606+padreofthegame@users.noreply.github.com> Date: Thu, 2 Feb 2023 07:05:51 +0100 Subject: [PATCH] [Torch] Fix advanced indexing with NoneType index arguments (#13826) [Torch] Fix advanced indexing with NoneType index --- python/tvm/relay/frontend/pytorch.py | 40 ++++++++++++++++--- tests/python/frontend/pytorch/test_forward.py | 35 ++++++++++++++++ 2 files changed, 70 insertions(+), 5 deletions(-) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index 491c140c5cb4..fde2bfb26356 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -2330,19 +2330,49 @@ def one_hot(self, inputs, input_types): def index(self, inputs, input_types): data = inputs[0] + data_shape = self.infer_type(data).shape + + axes_adv_idx = [i for i, v in enumerate(inputs[1]) if v is not None] + axes_rest = [i for i in range(len(data_shape)) if i not in axes_adv_idx] + + # check if the adv_index axes are consecutive + # if consecutive, result must be transposed again at the end + consecutive = True + for curr, nxt in zip(axes_adv_idx[:-1], axes_adv_idx[1:]): + if nxt - curr != 1: + consecutive = False + break + indices_list = [] + axes_order = axes_adv_idx + axes_rest - for indices in inputs[1]: - if self.infer_type(indices).dtype == "bool": + for i in axes_adv_idx: + inp = inputs[1][i] + if self.infer_type(inp).dtype == "bool": # adv_index does not support a mask as the index tensor (it will treat 0/1 as # an index rather than a flag). # So we use argwhere to turn the mask into indices, which will also take care # of the dynamism in the indexing by mask. - indices_list.append(_op.squeeze(_op.transform.argwhere(indices), axis=[1])) + indices_list.append(_op.squeeze(_op.transform.argwhere(inp), axis=[1])) else: - indices_list.append(indices) + indices_list.append(inp) + + data_after_adv_index = _op.adv_index([_op.transpose(data, axes=axes_order)] + indices_list) - return _op.adv_index([data] + indices_list) + if consecutive: + num_dims = len(self.infer_type(data_after_adv_index).shape) + num_new_dims = num_dims - len(axes_rest) + + axes_final_order = list(range(num_dims)) + axes_final_order = ( + axes_final_order[num_new_dims : num_new_dims + axes_adv_idx[0]] + + axes_final_order[:num_new_dims] + + axes_final_order[num_new_dims + axes_adv_idx[0] :] + ) + + return _op.transpose(data_after_adv_index, axes=axes_final_order) + else: + return data_after_adv_index def meshgrid(self, inputs, input_types): data = inputs[0] diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py index 0035d202ded2..82992d287ace 100755 --- a/tests/python/frontend/pytorch/test_forward.py +++ b/tests/python/frontend/pytorch/test_forward.py @@ -4034,6 +4034,41 @@ def forward(self, x): input_data = torch.rand(input_shape).float() verify_model(Index1().eval(), input_data=input_data) + class Index2(Module): + def forward(self, x): + return x[None, [2, 2]] + + input_data = torch.rand(input_shape).float() + verify_model(Index2().eval(), input_data=input_data) + + class Index3(Module): + def forward(self, x): + return x[None, [0, 1, 2], 1, [2, 3, 4]] + + input_data = torch.rand(input_shape).float() + verify_model(Index3().eval(), input_data=input_data) + + class Index4(Module): + def forward(self, x): + return x[None, [0, 0], None, np.array([[0], [1], [2]]), None] + + input_data = torch.rand(input_shape).float() + verify_model(Index4().eval(), input_data=input_data) + + class Index5(Module): + def forward(self, x): + return x[None, None, [0, 0], np.array([[0], [1], [2]]), None] + + input_data = torch.rand(input_shape).float() + verify_model(Index5().eval(), input_data=input_data) + + class Index6(Module): + def forward(self, x): + return x[None, 1, None, [1, 2, 3]] + + input_data = torch.rand(input_shape).float() + verify_model(Index6().eval(), input_data=input_data) + def test_fn_bool_mask(): return lambda data, mask: data[0, mask]