Skip to content

Commit

Permalink
[Frontend][Pytorch]Add Pytorch advanced indexing (#6318)
Browse files Browse the repository at this point in the history
* Add Pytorch advanced indexing

* Minor fix for test

* Fix for cuda
  • Loading branch information
kevinthesun authored Aug 22, 2020
1 parent 9a0413a commit 88f6f79
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 5 deletions.
53 changes: 50 additions & 3 deletions python/tvm/relay/frontend/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,16 +274,18 @@ def _impl(inputs, input_types):
end[dim] = min(end[dim], int(inputs[3]))
else:
if isinstance(inputs[3], _expr.Call):
end[dim] = np.asscalar(_infer_value(inputs[3], {}).asnumpy().astype(np.int))
target_end = np.asscalar(_infer_value(inputs[3], {}).asnumpy().astype(np.int))
else:
end[dim] = inputs[3]
target_end = inputs[3]

end[dim] = min(end[dim], target_end)

strides.append(int(inputs[4]))
return _op.transform.strided_slice(data,
begin=_expr.const(begin),
end=_expr.const(end),
strides=_expr.const(strides),
slice_mode="size")
slice_mode="end")
return _impl

def _split():
Expand Down Expand Up @@ -1759,6 +1761,50 @@ def _impl(inputs, input_types):
return _impl


def _index():
def _impl(inputs, input_types):
data = inputs[0]
indices = []
raw_indices = []
max_indices_len = -1
for index in inputs[1]:
if not isinstance(index, _expr.Constant):
try:
index = _expr.const(_infer_value(index, {}))
except Exception:
raise RuntimeError("Only supports constant indices for "
"pytorch advanced indexing ")
raw_indices.append(index)
cindex_len = index.data.shape[0]
if cindex_len > max_indices_len:
max_indices_len = cindex_len

for index in raw_indices:
cnp = index.data.asnumpy()
cindex_len = cnp.shape[0]
if cindex_len < max_indices_len:
cnp = np.tile(cnp, max_indices_len // cindex_len)
indices.append(cnp)

ret = []
slice_map = {}
for i in range(indices[0].shape[0]):
tmp = data
current_indices = []
for index in indices:
current_indices.append(index[i])
index_key = tuple(current_indices)
if index_key in slice_map:
tmp = slice_map[index_key]
else:
tmp = _op.take(tmp, _expr.const(index[i]), axis=0)
slice_map[index_key] = tmp
ret.append(_op.expand_dims(tmp, axis=0))

return _op.concatenate(ret, axis=0)
return _impl


def _meshgrid():
def _impl(inputs, input_types):
data = inputs[0]
Expand Down Expand Up @@ -2064,6 +2110,7 @@ def _get_convert_map(prelude):
"aten::type_as" : _type_as(),
"aten::gather" : _gather(),
"aten::index_select" : _select(),
"aten::index" : _index(),
}
return convert_map

Expand Down
24 changes: 22 additions & 2 deletions tests/python/frontend/pytorch/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -1202,13 +1202,13 @@ def forward(self, *args):

class Slice2(Module):
def forward(self, *args):
return args[0][0, :, :, :]
return args[0][0, :, :-3, :]

class Slice3(Module):
def forward(self, *args):
x0 = torch.tensor(2) - torch.tensor(1)
x1 = torch.tensor(3) + torch.tensor(1)
return args[0][:, x0:, :x1, :]
return args[0][:, x0:, 1:x1, :]

input_data = torch.rand(input_shape).float()
verify_model(Slice1().float().eval(), input_data=input_data)
Expand Down Expand Up @@ -2620,6 +2620,25 @@ def forward(self, *args):
verify_model(MatMul1().float().eval(), input_data=[tensor1, tensor2])


def test_forward_index():
torch.set_grad_enabled(False)
input_shape = [3, 4, 5, 6]

class Index0(Module):
def forward(self, x):
return x[[0, 1], [0, 2], :2, 4]

input_data = torch.rand(input_shape).float()
verify_model(Index0().eval(), input_data=input_data)

class Index1(Module):
def forward(self, x):
return x[[0], [1, 2, 3, 0], [3, 1, 2, 2], [4, 2, 1, 0]]

input_data = torch.rand(input_shape).float()
verify_model(Index1().eval(), input_data=input_data)


def test_forward_pretrained_bert_base_uncased():
######################################################################
# This is an example how to run BERT models using TVM
Expand Down Expand Up @@ -2859,6 +2878,7 @@ def test_forward_pretrained_bert_base_uncased():
test_adaptive_pool3d()
test_conv3d()
test_conv3d_transpose()
test_forward_index()

# Model tests
test_resnet18()
Expand Down

0 comments on commit 88f6f79

Please sign in to comment.