From 494d33e029ea0f0acd3fbf973ea7ac883ac23988 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Fri, 23 Oct 2020 12:55:11 +0900 Subject: [PATCH 1/4] add bincount support to pytorch frontend --- python/tvm/relay/frontend/pytorch.py | 21 +++++++++++++++++++ tests/python/frontend/pytorch/test_forward.py | 21 +++++++++++++++++++ 2 files changed, 42 insertions(+) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index c8fbd5a5c10c..04e9d038f779 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -2357,6 +2357,25 @@ def _impl(inputs, input_types): return _impl +def _bincount(): + def _impl(inputs, input_types): + data = inputs[0] + weights = inputs[1] + maximum = _op.max(data) + dim = maximum + _expr.const(1, dtype="int64") + if weights: + out_dtype = "float32" + updates = weights + else: + out_dtype = "int64" + updates = _op.ones_like(data) + + counts = _op.zeros(_op.reshape(dim, [1]), out_dtype) + return _op.scatter_add(counts, data, updates, axis=0) + + return _impl + + def _pytorch_result_type(dtypes, non_tensor_inputs): """This promotes TVM dtypes like PyTorch would""" import torch @@ -2699,6 +2718,7 @@ def _get_convert_map(prelude, default_dtype): "aten::tensor": _identity(), # used for example in tensor(1.0) "aten::numel": _numel(), "aten::empty": _empty(), + "aten::bincount": _bincount(), } return convert_map @@ -3330,6 +3350,7 @@ def from_pytorch(script_module, input_infos, custom_convert_map=None, default_dt graph = script_module.graph.copy() _run_jit_passes(graph) + print(graph) if custom_convert_map: convert_map.update(custom_convert_map) diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py index 54c3daf25385..31fc46d03b8b 100644 --- a/tests/python/frontend/pytorch/test_forward.py +++ b/tests/python/frontend/pytorch/test_forward.py @@ -3350,6 +3350,26 @@ def expected(x_shape, y_shape): assert tvm.ir.structural_equal(expected_mod, mod["main"], map_free_vars=True) +def test_bincount(): + class Bincount(torch.nn.Module): + def __init__(self, weights=None): + super().__init__() + self.weights = weights + + def forward(self, x): + return torch.bincount(x, weights=self.weights) + + inp = torch.randint(0, 8, (5,), dtype=torch.int64) + weights = torch.linspace(0, 1, steps=5) + + verify_trace_model(Bincount(), [inp], ["llvm"]) + verify_trace_model(Bincount(weights), [inp], ["llvm"]) + + +def test_scatter_add(): + pass + + if __name__ == "__main__": # some structural tests test_forward_traced_function() @@ -3476,6 +3496,7 @@ def expected(x_shape, y_shape): test_forward_nonzero() test_forward_scatter() test_numel() + test_bincount() # Model tests test_resnet18() From b2eeb0db40a664734306de4691cd490963730f96 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Fri, 23 Oct 2020 13:19:18 +0900 Subject: [PATCH 2/4] add scatter_add support --- python/tvm/relay/frontend/pytorch.py | 12 ++++++++++++ tests/python/frontend/pytorch/test_forward.py | 10 ++++++++++ 2 files changed, 22 insertions(+) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index 04e9d038f779..9e0fe3414a29 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -2376,6 +2376,17 @@ def _impl(inputs, input_types): return _impl +def _scatter_add(): + def _impl(inputs, input_types): + data = inputs[0] + axis = inputs[1] + index = inputs[2] + src = inputs[3] + return _op.scatter_add(data, index, src, axis=axis) + + return _impl + + def _pytorch_result_type(dtypes, non_tensor_inputs): """This promotes TVM dtypes like PyTorch would""" import torch @@ -2719,6 +2730,7 @@ def _get_convert_map(prelude, default_dtype): "aten::numel": _numel(), "aten::empty": _empty(), "aten::bincount": _bincount(), + "aten::scatter_add": _scatter_add(), } return convert_map diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py index 31fc46d03b8b..d963d307791f 100644 --- a/tests/python/frontend/pytorch/test_forward.py +++ b/tests/python/frontend/pytorch/test_forward.py @@ -3147,11 +3147,20 @@ def __init__(self, dim=0): def forward(self, data, index, src): return torch.scatter(data, dim=self.dim, index=index, src=src) + class ScatterAdd(Module): + def __init__(self, dim=0): + super().__init__() + self.dim = dim + + def forward(self, data, index, src): + return torch.scatter_add(data, dim=self.dim, index=index, src=src) + in_data = torch.zeros(3, 5) in_index = torch.tensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]]) in_src = torch.rand(2, 5) # TODO: add scatter gpu schedule to enable gpu test. verify_trace_model(Scatter(), [in_data, in_index, in_src], ["llvm"]) + verify_trace_model(ScatterAdd(), [in_data, in_index, in_src], ["llvm"]) in_data = torch.zeros(2, 4) in_index = torch.tensor([[2], [3]]) @@ -3159,6 +3168,7 @@ def forward(self, data, index, src): # TODO: add scatter gpu schedule to enable gpu test. verify_trace_model(Scatter(1), [in_data, in_index, in_src], ["llvm"]) + verify_trace_model(ScatterAdd(1), [in_data, in_index, in_src], ["llvm"]) def test_numel(): From 068f2e8eebd4b75f238894160436167ed5d2da04 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Fri, 23 Oct 2020 13:41:57 +0900 Subject: [PATCH 3/4] remove stuff --- python/tvm/relay/frontend/pytorch.py | 1 - tests/python/frontend/pytorch/test_forward.py | 4 ---- 2 files changed, 5 deletions(-) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index 9e0fe3414a29..07ac5cd124c2 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -3362,7 +3362,6 @@ def from_pytorch(script_module, input_infos, custom_convert_map=None, default_dt graph = script_module.graph.copy() _run_jit_passes(graph) - print(graph) if custom_convert_map: convert_map.update(custom_convert_map) diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py index d963d307791f..5512980062e1 100644 --- a/tests/python/frontend/pytorch/test_forward.py +++ b/tests/python/frontend/pytorch/test_forward.py @@ -3376,10 +3376,6 @@ def forward(self, x): verify_trace_model(Bincount(weights), [inp], ["llvm"]) -def test_scatter_add(): - pass - - if __name__ == "__main__": # some structural tests test_forward_traced_function() From c3e042fbd009fae8f99c22bf6c610e32b008d0f5 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Fri, 23 Oct 2020 16:30:33 +0900 Subject: [PATCH 4/4] fix weight dtype, cleanup test --- python/tvm/relay/frontend/pytorch.py | 3 +- tests/python/frontend/pytorch/test_forward.py | 43 +++++++------------ 2 files changed, 17 insertions(+), 29 deletions(-) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index 07ac5cd124c2..c41d6802edd9 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -2364,7 +2364,8 @@ def _impl(inputs, input_types): maximum = _op.max(data) dim = maximum + _expr.const(1, dtype="int64") if weights: - out_dtype = "float32" + weight_type = _infer_type(weights).checked_type + out_dtype = weight_type.dtype updates = weights else: out_dtype = "int64" diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py index 5512980062e1..e997ebe07a50 100644 --- a/tests/python/frontend/pytorch/test_forward.py +++ b/tests/python/frontend/pytorch/test_forward.py @@ -3139,36 +3139,27 @@ def forward(self, data): def test_forward_scatter(): - class Scatter(Module): - def __init__(self, dim=0): - super().__init__() - self.dim = dim - - def forward(self, data, index, src): - return torch.scatter(data, dim=self.dim, index=index, src=src) - - class ScatterAdd(Module): - def __init__(self, dim=0): - super().__init__() - self.dim = dim + # integer cannot be traced + def test_fn_scatter(dim): + return lambda data, index, src: torch.scatter(data, dim=dim, index=index, src=src) - def forward(self, data, index, src): - return torch.scatter_add(data, dim=self.dim, index=index, src=src) + def test_fn_scatter_add(dim): + return lambda data, index, src: torch.scatter_add(data, dim=dim, index=index, src=src) in_data = torch.zeros(3, 5) in_index = torch.tensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]]) in_src = torch.rand(2, 5) # TODO: add scatter gpu schedule to enable gpu test. - verify_trace_model(Scatter(), [in_data, in_index, in_src], ["llvm"]) - verify_trace_model(ScatterAdd(), [in_data, in_index, in_src], ["llvm"]) + verify_trace_model(test_fn_scatter(0), [in_data, in_index, in_src], ["llvm"]) + verify_trace_model(test_fn_scatter_add(0), [in_data, in_index, in_src], ["llvm"]) in_data = torch.zeros(2, 4) in_index = torch.tensor([[2], [3]]) in_src = torch.rand(2, 1) - # TODO: add scatter gpu schedule to enable gpu test. - verify_trace_model(Scatter(1), [in_data, in_index, in_src], ["llvm"]) - verify_trace_model(ScatterAdd(1), [in_data, in_index, in_src], ["llvm"]) + # # TODO: add scatter gpu schedule to enable gpu test. + verify_trace_model(test_fn_scatter(1), [in_data, in_index, in_src], ["llvm"]) + verify_trace_model(test_fn_scatter_add(1), [in_data, in_index, in_src], ["llvm"]) def test_numel(): @@ -3361,19 +3352,15 @@ def expected(x_shape, y_shape): def test_bincount(): - class Bincount(torch.nn.Module): - def __init__(self, weights=None): - super().__init__() - self.weights = weights - - def forward(self, x): - return torch.bincount(x, weights=self.weights) + def test_fn(x, weights=None): + return torch.bincount(x, weights=weights) inp = torch.randint(0, 8, (5,), dtype=torch.int64) weights = torch.linspace(0, 1, steps=5) - verify_trace_model(Bincount(), [inp], ["llvm"]) - verify_trace_model(Bincount(weights), [inp], ["llvm"]) + verify_trace_model(test_fn, [inp], ["llvm"]) + verify_trace_model(test_fn, [inp, weights], ["llvm"]) + verify_trace_model(test_fn, [inp, weights.to(torch.float64)], ["llvm"]) if __name__ == "__main__":