diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index c8fbd5a5c10c..c41d6802edd9 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -2357,6 +2357,37 @@ def _impl(inputs, input_types): return _impl +def _bincount(): + def _impl(inputs, input_types): + data = inputs[0] + weights = inputs[1] + maximum = _op.max(data) + dim = maximum + _expr.const(1, dtype="int64") + if weights: + weight_type = _infer_type(weights).checked_type + out_dtype = weight_type.dtype + updates = weights + else: + out_dtype = "int64" + updates = _op.ones_like(data) + + counts = _op.zeros(_op.reshape(dim, [1]), out_dtype) + return _op.scatter_add(counts, data, updates, axis=0) + + return _impl + + +def _scatter_add(): + def _impl(inputs, input_types): + data = inputs[0] + axis = inputs[1] + index = inputs[2] + src = inputs[3] + return _op.scatter_add(data, index, src, axis=axis) + + return _impl + + def _pytorch_result_type(dtypes, non_tensor_inputs): """This promotes TVM dtypes like PyTorch would""" import torch @@ -2699,6 +2730,8 @@ def _get_convert_map(prelude, default_dtype): "aten::tensor": _identity(), # used for example in tensor(1.0) "aten::numel": _numel(), "aten::empty": _empty(), + "aten::bincount": _bincount(), + "aten::scatter_add": _scatter_add(), } return convert_map diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py index 54c3daf25385..e997ebe07a50 100644 --- a/tests/python/frontend/pytorch/test_forward.py +++ b/tests/python/frontend/pytorch/test_forward.py @@ -3139,26 +3139,27 @@ def forward(self, data): def test_forward_scatter(): - class Scatter(Module): - def __init__(self, dim=0): - super().__init__() - self.dim = dim + # integer cannot be traced + def test_fn_scatter(dim): + return lambda data, index, src: torch.scatter(data, dim=dim, index=index, src=src) - def forward(self, data, index, src): - return torch.scatter(data, dim=self.dim, index=index, src=src) + def test_fn_scatter_add(dim): + return lambda data, index, src: torch.scatter_add(data, dim=dim, index=index, src=src) in_data = torch.zeros(3, 5) in_index = torch.tensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]]) in_src = torch.rand(2, 5) # TODO: add scatter gpu schedule to enable gpu test. - verify_trace_model(Scatter(), [in_data, in_index, in_src], ["llvm"]) + verify_trace_model(test_fn_scatter(0), [in_data, in_index, in_src], ["llvm"]) + verify_trace_model(test_fn_scatter_add(0), [in_data, in_index, in_src], ["llvm"]) in_data = torch.zeros(2, 4) in_index = torch.tensor([[2], [3]]) in_src = torch.rand(2, 1) - # TODO: add scatter gpu schedule to enable gpu test. - verify_trace_model(Scatter(1), [in_data, in_index, in_src], ["llvm"]) + # # TODO: add scatter gpu schedule to enable gpu test. + verify_trace_model(test_fn_scatter(1), [in_data, in_index, in_src], ["llvm"]) + verify_trace_model(test_fn_scatter_add(1), [in_data, in_index, in_src], ["llvm"]) def test_numel(): @@ -3350,6 +3351,18 @@ def expected(x_shape, y_shape): assert tvm.ir.structural_equal(expected_mod, mod["main"], map_free_vars=True) +def test_bincount(): + def test_fn(x, weights=None): + return torch.bincount(x, weights=weights) + + inp = torch.randint(0, 8, (5,), dtype=torch.int64) + weights = torch.linspace(0, 1, steps=5) + + verify_trace_model(test_fn, [inp], ["llvm"]) + verify_trace_model(test_fn, [inp, weights], ["llvm"]) + verify_trace_model(test_fn, [inp, weights.to(torch.float64)], ["llvm"]) + + if __name__ == "__main__": # some structural tests test_forward_traced_function() @@ -3476,6 +3489,7 @@ def expected(x_shape, y_shape): test_forward_nonzero() test_forward_scatter() test_numel() + test_bincount() # Model tests test_resnet18()