[Torch] Support bincount and scatter_add ops (apache#6740)
masahi authored and Trevor Morris committed Dec 4, 2020
1 parent 57441b6 commit c39de69
Showing 2 changed files with 56 additions and 9 deletions.
33 changes: 33 additions & 0 deletions python/tvm/relay/frontend/pytorch.py
@@ -2357,6 +2357,37 @@ def _impl(inputs, input_types):
     return _impl
 
 
+def _bincount():
+    def _impl(inputs, input_types):
+        data = inputs[0]
+        weights = inputs[1]
+        maximum = _op.max(data)
+        dim = maximum + _expr.const(1, dtype="int64")
+        if weights:
+            weight_type = _infer_type(weights).checked_type
+            out_dtype = weight_type.dtype
+            updates = weights
+        else:
+            out_dtype = "int64"
+            updates = _op.ones_like(data)
+
+        counts = _op.zeros(_op.reshape(dim, [1]), out_dtype)
+        return _op.scatter_add(counts, data, updates, axis=0)
+
+    return _impl
+
+
+def _scatter_add():
+    def _impl(inputs, input_types):
+        data = inputs[0]
+        axis = inputs[1]
+        index = inputs[2]
+        src = inputs[3]
+        return _op.scatter_add(data, index, src, axis=axis)
+
+    return _impl
+
+
 def _pytorch_result_type(dtypes, non_tensor_inputs):
     """This promotes TVM dtypes like PyTorch would"""
     import torch
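
The _bincount converter above lowers aten::bincount to scatter_add: it allocates a
zero-filled vector of length max(data) + 1 and scatter-adds ones (or the optional
weights) at each index. A minimal PyTorch sketch of the same equivalence, illustrative
only and not part of this commit:

import torch

data = torch.tensor([1, 3, 1, 0], dtype=torch.int64)
weights = torch.tensor([0.5, 1.0, 2.0, 0.25])

# Unweighted case: scatter-add ones into zeros of length max(data) + 1.
counts = torch.zeros(int(data.max()) + 1, dtype=torch.int64)
counts.scatter_add_(0, data, torch.ones_like(data))
assert torch.equal(counts, torch.bincount(data))

# Weighted case: the scattered updates are the weights instead of ones.
wcounts = torch.zeros(int(data.max()) + 1, dtype=weights.dtype)
wcounts.scatter_add_(0, data, weights)
expected = torch.bincount(data, weights=weights)
assert torch.allclose(wcounts.to(expected.dtype), expected)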
@@ -2699,6 +2730,8 @@ def _get_convert_map(prelude, default_dtype):
"aten::tensor": _identity(), # used for example in tensor(1.0)
"aten::numel": _numel(),
"aten::empty": _empty(),
"aten::bincount": _bincount(),
"aten::scatter_add": _scatter_add(),
}
return convert_map

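With the two entries registered in the convert map, a traced function that calls
torch.bincount or torch.scatter_add can now be imported through the Relay frontend. A
rough end-to-end sketch, illustrative only and not part of this commit (the input name
"x" is arbitrary, and the (shape, dtype) input tuple assumes a TVM version whose
from_pytorch accepts explicit input dtypes; older versions took shapes only):

import torch
import tvm
from tvm import relay

def fn(x):
    return torch.bincount(x)

inp = torch.randint(0, 8, (5,), dtype=torch.int64)
script = torch.jit.trace(fn, inp)

# Import the traced function into Relay; the dtype is passed explicitly
# because the frontend otherwise assumes float32 inputs.
mod, params = relay.frontend.from_pytorch(script, [("x", ((5,), "int64"))])

# Compile for CPU; the tests below stick to "llvm" since scatter has no
# GPU schedule yet (see the TODO comments in test_forward.py).
with tvm.transform.PassContext(opt_level=3):
    lib = relay.build(mod, target="llvm", params=params)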
32 changes: 23 additions & 9 deletions tests/python/frontend/pytorch/test_forward.py
@@ -3139,26 +3139,27 @@ def forward(self, data):


 def test_forward_scatter():
-    class Scatter(Module):
-        def __init__(self, dim=0):
-            super().__init__()
-            self.dim = dim
+    # integer cannot be traced
+    def test_fn_scatter(dim):
+        return lambda data, index, src: torch.scatter(data, dim=dim, index=index, src=src)
 
-        def forward(self, data, index, src):
-            return torch.scatter(data, dim=self.dim, index=index, src=src)
+    def test_fn_scatter_add(dim):
+        return lambda data, index, src: torch.scatter_add(data, dim=dim, index=index, src=src)
 
     in_data = torch.zeros(3, 5)
     in_index = torch.tensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]])
     in_src = torch.rand(2, 5)
     # TODO: add scatter gpu schedule to enable gpu test.
-    verify_trace_model(Scatter(), [in_data, in_index, in_src], ["llvm"])
+    verify_trace_model(test_fn_scatter(0), [in_data, in_index, in_src], ["llvm"])
+    verify_trace_model(test_fn_scatter_add(0), [in_data, in_index, in_src], ["llvm"])
 
     in_data = torch.zeros(2, 4)
     in_index = torch.tensor([[2], [3]])
     in_src = torch.rand(2, 1)
 
-    # TODO: add scatter gpu schedule to enable gpu test.
-    verify_trace_model(Scatter(1), [in_data, in_index, in_src], ["llvm"])
+    # TODO: add scatter gpu schedule to enable gpu test.
+    verify_trace_model(test_fn_scatter(1), [in_data, in_index, in_src], ["llvm"])
+    verify_trace_model(test_fn_scatter_add(1), [in_data, in_index, in_src], ["llvm"])
 
 
 def test_numel():
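
For reference, torch.scatter_add(data, dim, index, src) adds each element of src into
data at the position named by index along dim, keeping the other coordinates, so for
dim=0: out[index[i][j]][j] += src[i][j]. A small worked example with the same index as
the first test case above, illustrative only and not part of this commit:

import torch

data = torch.zeros(3, 5)
index = torch.tensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]])
src = torch.ones(2, 5)

# Along dim=0, out[index[i][j]][j] += src[i][j]:
out = torch.scatter_add(data, 0, index, src)
# out == [[1., 1., 1., 1., 1.],
#         [0., 1., 0., 1., 0.],
#         [1., 0., 1., 0., 1.]]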
@@ -3350,6 +3351,18 @@ def expected(x_shape, y_shape):
     assert tvm.ir.structural_equal(expected_mod, mod["main"], map_free_vars=True)
 
 
+def test_bincount():
+    def test_fn(x, weights=None):
+        return torch.bincount(x, weights=weights)
+
+    inp = torch.randint(0, 8, (5,), dtype=torch.int64)
+    weights = torch.linspace(0, 1, steps=5)
+
+    verify_trace_model(test_fn, [inp], ["llvm"])
+    verify_trace_model(test_fn, [inp, weights], ["llvm"])
+    verify_trace_model(test_fn, [inp, weights.to(torch.float64)], ["llvm"])
+
+
 if __name__ == "__main__":
     # some structural tests
     test_forward_traced_function()
@@ -3476,6 +3489,7 @@ def expected(x_shape, y_shape):
     test_forward_nonzero()
     test_forward_scatter()
     test_numel()
+    test_bincount()
 
     # Model tests
     test_resnet18()
