[Torch] Various updates for PyTorch frontend #7348

Merged (9 commits, Jan 27, 2021)
python/tvm/relay/frontend/pytorch.py (63 changes: 50 additions, 13 deletions)
@@ -399,10 +399,7 @@ def slice(self, inputs, input_types):
begin = [0] * ndim
dim = int(inputs[1])
stride = int(inputs[4])
- if isinstance(inputs[2], _expr.Call):
- begin[dim], _ = try_infer_value(inputs[2], lambda ret: np.asscalar(ret.astype(np.int)))
- else:
- begin[dim] = int(inputs[2])
+ begin[dim], _ = try_infer_value(inputs[2], lambda ret: np.asscalar(ret.astype(np.int)))

# Process begin
if not isinstance(begin[dim], int):
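Context for the change above: after tracing, the slice start can arrive as a relay expression (e.g. an `_expr.Call`) rather than a Python int, and `try_infer_value` handles both cases uniformly. A minimal PyTorch-side sketch (hypothetical module name) of a data-dependent slice start:

```python
import torch

class DynSlice(torch.nn.Module):
    def forward(self, x, y):
        # The slice start is derived from another tensor's shape, so the traced
        # graph records it as a computed value instead of a constant.
        return x[y.shape[0]:, :]

traced = torch.jit.trace(DynSlice(), (torch.randn(8, 4), torch.randn(3)))
print(traced.graph)  # expect aten::slice whose start flows from aten::size
```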
@@ -518,13 +515,13 @@ def select(self, inputs, input_types):
data = inputs[0]
dim = int(inputs[1])
index = _wrap_const(inputs[2])
- return _op.transform.take(data, index, axis=dim)
+ return _op.transform.take(data, index, axis=dim, mode="wrap")

def take(self, inputs, input_types):
data = inputs[0]
indices = _op.cast(inputs[1], "int32")

- return _op.transform.take(data, indices=indices)
+ return _op.transform.take(data, indices=indices, mode="wrap")

def topk(self, inputs, input_types):
data = inputs[0]
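Passing `mode="wrap"` makes relay's `take` follow PyTorch's negative-index semantics for `select` and `take`. Illustrative inputs, mirroring the tests added below:

```python
import torch

x = torch.arange(12.0).reshape(3, 4)
print(x[-1])                                 # aten::select along dim 0 with a negative index
print(torch.take(x, torch.tensor([0, -1])))  # aten::take with a negative flattened index
```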
@@ -551,7 +548,13 @@ def reciprocal(self, inputs, input_types):

def repeat(self, inputs, input_types):
data = inputs[0]
- reps = inputs[1]
+ reps = []
+ for r in inputs[1]:
+ if isinstance(r, int):
+ reps.append(r)
+ else:
+ reps.append(int(_infer_value(r, {}).asnumpy()))

return _op.transform.tile(data, reps=reps)

def repeat_interleave(self, inputs, input_types):
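The change above stops assuming every repeat count is a Python int; a count traced from a tensor shape reaches the converter as a graph value and is resolved through `_infer_value`. A hypothetical PyTorch-side sketch of that situation:

```python
import torch

class DynRepeat(torch.nn.Module):
    def forward(self, x, y):
        # The second repeat count comes from a tensor shape, so tracing records
        # it as a computed value rather than a literal int.
        return x.repeat(2, y.shape[0])

traced = torch.jit.trace(DynRepeat(), (torch.randn(3, 4), torch.randn(5)))
```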
@@ -1520,12 +1523,6 @@ def matmul(self, inputs, input_types):
# Convert a and b into 3 dimensional tensors.
a = _op.reshape(inputs_0, [-1, a_shape[-2], a_shape[-1]])
b = _op.reshape(inputs_1, [-1, b_shape[-2], b_shape[-1]])
- # Broadcast b to match batch size of a
- new_b_shape = list(self.infer_shape_with_prelude(b))
- new_a_shape = self.infer_shape_with_prelude(a)
- if new_a_shape[0] > new_b_shape[0]:
- new_b_shape[0] = new_a_shape[0]
- b = _op.broadcast_to(b, new_b_shape)
# Transpose matrix dimensions of b.
b = _op.transpose(b, [0, 2, 1])
# Perform a batch matmul.
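For reference, this path covers shape combinations where only one operand carries batch dimensions, roughly the case the removed `broadcast_to` used to handle explicitly (illustration on the PyTorch side):

```python
import torch

a = torch.randn(10, 3, 4)        # batched operand
b = torch.randn(4, 5)            # un-batched operand, broadcast across the batch
print(torch.matmul(a, b).shape)  # torch.Size([10, 3, 5])
```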
@@ -2070,6 +2067,40 @@ def scatter_add(self, inputs, input_types):
src = inputs[3]
return _op.scatter_add(data, index, src, axis=axis)

def cumsum(self, inputs, input_types):
data = inputs[0]
dim = inputs[1]
dtype = inputs[2]

if inputs[2] is not None:
dtype = _convert_dtype_value(inputs[2])

return _op.cumsum(data, axis=dim, dtype=dtype)

def masked_fill(self, inputs, input_types):
mask = inputs[1]
value = _op.cast(_wrap_const(inputs[2]), input_types[0])
return _op.where(mask, value, inputs[0])
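A quick PyTorch-side sanity check of the `where`-based lowering of `masked_fill` (illustrative only):

```python
import torch

x = torch.randn(3, 4)
mask = x > 0.5
filled = torch.masked_fill(x, mask, 0.0)
# masked_fill picks the fill value wherever the mask is true, i.e. a where().
assert torch.equal(filled, torch.where(mask, torch.tensor(0.0), x))
```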

def masked_select(self, inputs, input_types):
mask = inputs[1]
indices = self.nonzero([mask], input_types, is_numpy_style=True)
return _op.adv_index([inputs[0]] + [indices[i] for i in range(indices.size)])

def sort(self, inputs, input_types):
data = inputs[0]
dim = inputs[1]
is_descending = inputs[2]
# pytorch sort returns both sorted indices and values
indices = _op.argsort(data, dim, not is_descending)
return _op.gather(data, dim, indices), indices

def argsort(self, inputs, input_types):
data = inputs[0]
dim = inputs[1]
is_descending = inputs[2]
return _op.argsort(data, dim, not is_descending)

def is_floating_point(self, inputs, input_types):
assert len(inputs) == 1
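As a reminder of the PyTorch contract the new `sort`/`argsort` converters mirror, `torch.sort` returns a `(values, indices)` pair, which is why `sort` above rebuilds the values by gathering with the `argsort` result (illustrative check):

```python
import torch

x = torch.tensor([3.0, 1.0, 2.0])
values, indices = torch.sort(x, dim=0, descending=False)
print(values)                       # tensor([1., 2., 3.])
print(indices)                      # tensor([1, 2, 0])
print(torch.gather(x, 0, indices))  # matches values
```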

@@ -2263,6 +2294,7 @@ def create_convert_map(self):
"torchvision::roi_align": self.roi_align,
"aten::unbind": self.unbind,
"aten::__and__": self.logical_and,
"aten::logical_and": self.logical_and,
"aten::_shape_as_tensor": self.shape_as_tensor,
"aten::nonzero": self.nonzero,
"aten::nonzero_numpy": self.nonzero_numpy,
@@ -2278,6 +2310,11 @@
"aten::__not__": self.logical_not,
"aten::hardswish_": self.hard_swish,
"aten::hardswish": self.hard_swish,
"aten::cumsum": self.cumsum,
"aten::masked_fill": self.masked_fill,
"aten::masked_select": self.masked_select,
"aten::argsort": self.argsort,
"aten::sort": self.sort,
}

def update_convert_map(self, custom_map):
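The same table is what user-supplied converters extend through `update_convert_map`. A minimal sketch, assuming the `custom_convert_map` argument of `relay.frontend.from_pytorch` and the usual `(inputs, input_types)` converter signature (the converter name `my_relu` is made up for illustration):

```python
import torch
from tvm import relay

def my_relu(inputs, input_types):
    # inputs[0] is the relay expression for the op's first argument.
    return relay.nn.relu(inputs[0])

scripted = torch.jit.trace(torch.nn.ReLU().eval(), torch.randn(1, 3))
mod, params = relay.frontend.from_pytorch(
    scripted, [("x", (1, 3))], custom_convert_map={"aten::relu": my_relu}
)
```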
tests/python/frontend/pytorch/test_forward.py (101 changes: 100 additions, 1 deletion)
@@ -1147,7 +1147,7 @@ def forward(self, *args):
@tvm.testing.uses_gpu
def test_forward_select():
torch.set_grad_enabled(False)
- input_shape = [1, 3, 10, 10]
+ input_shape = [5, 3, 10, 10]

class Select1(Module):
def forward(self, *args):
@@ -1167,6 +1167,9 @@ def forward(self, index):
input_data = torch.rand(input_shape).float()
verify_model(Select1().float().eval(), input_data=input_data)

# test negative indexing
verify_model(lambda x: x[-1], input_data=input_data)

x = torch.randn(3, 4)
indices = torch.tensor([0, 2])
verify_model(IndexedSelect(x, 0).eval(), input_data=indices)
@@ -2653,6 +2656,8 @@ def forward(self, *args):
verify_model(Take1().float().eval(), input_data=input_data)
indices = torch.tensor([[0, 0], [1, 0]])
verify_model(Take2().float().eval(), input_data=[input_data, indices])
indices = torch.tensor([0, -1])
verify_model(Take2().float().eval(), input_data=[input_data, indices])


@tvm.testing.uses_gpu
@@ -3452,6 +3457,93 @@ def test_hard_swish():
verify_model(torch.nn.Hardswish(inplace=True).eval(), input_data=input)


def test_cumsum():
def test_fn(dim, dtype=None):
return lambda x: torch.cumsum(x, dim=dim, dtype=dtype)

inp = torch.randint(0, 100, (10000,), dtype=torch.int32)
verify_model(test_fn(0), [inp])
verify_model(test_fn(0), [inp.to(torch.int64)])
verify_model(test_fn(0, dtype=torch.int64), [inp.to(torch.int64)])

inp = torch.randn((100, 100), dtype=torch.float32)
verify_model(test_fn(dim=0, dtype=torch.float64), [inp])
verify_model(test_fn(dim=1), [inp])

inp = torch.randn((100, 100), dtype=torch.float32) > 0.5
verify_model(test_fn(dim=0, dtype=torch.int32), [inp])


def test_masked_fill():
def test_fn(x, mask):
return torch.masked_fill(x, mask, 0.0)

inp = torch.randn(100, 100)
verify_model(test_fn, [inp, inp > 0.5])
verify_model(test_fn, [inp.to(torch.float64), inp > 0.5])


def test_transformer():
model = torch.nn.Transformer(d_model=256, nhead=8, num_encoder_layers=6, num_decoder_layers=6)
model = model.eval()
src = torch.rand((10, 32, 256))
tgt = torch.rand((20, 32, 256))
verify_model(model.eval(), input_data=[src, tgt])


def test_argsort():
def test_fn(dim, descending):
return lambda x: torch.argsort(x, dim=dim, descending=descending)

inp = torch.randn(100)
verify_model(test_fn(0, True), [inp])
verify_model(test_fn(0, False), [inp])

inp = torch.randn(100, 100)
verify_model(test_fn(0, True), [inp])
verify_model(test_fn(0, False), [inp])
verify_model(test_fn(1, True), [inp])
verify_model(test_fn(1, False), [inp])


def test_sort():
def test_fn(dim, descending):
return lambda x: torch.sort(x, dim=dim, descending=descending)

inp = torch.randn(100)
verify_model(test_fn(0, True), [inp])
verify_model(test_fn(0, False), [inp])

inp = torch.randn(100, 100)
verify_model(test_fn(0, True), [inp])
verify_model(test_fn(0, False), [inp])
verify_model(test_fn(1, True), [inp])
verify_model(test_fn(1, False), [inp])


def test_logical_and():
def test_fn(x, y):
return torch.logical_and(x, y)

a = torch.tensor([0, 1, 10, 0], dtype=torch.int8)
b = torch.tensor([4, 0, 1, 0], dtype=torch.int8)
verify_model(test_fn, [a, b])

a = torch.tensor([True, False, True])
b = torch.tensor([True, False, False])
verify_model(test_fn, [a, b])


def test_masked_select():
def test_fn(x, mask):
return torch.masked_select(x, mask)

for shape in [(10,), (3, 4), (16, 32, 64)]:
x = torch.randn(*shape)
mask = x.ge(0.5)
verify_trace_model(test_fn, [x, mask], ["llvm", "cuda", "nvptx"])
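
masked_select goes through verify_trace_model with explicit targets, presumably because its output shape depends on the mask contents, as the illustration below shows:

```python
import torch

x = torch.randn(3, 4)
mask = x.ge(0.5)
# The number of selected elements, and hence the output shape, depends on the data.
print(torch.masked_select(x, mask).shape)
```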


if __name__ == "__main__":
# some structural tests
test_forward_traced_function()
@@ -3580,6 +3672,13 @@ def test_hard_swish():
test_forward_scatter()
test_numel()
test_bincount()
test_cumsum()
test_masked_fill()
test_transformer()
test_sort()
test_argsort()
test_logical_and()
test_masked_select()

# Model tests
test_resnet18()