Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Torch] Support bincount and scatter_add ops #6740

Merged
merged 4 commits into from
Oct 24, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 33 additions & 0 deletions python/tvm/relay/frontend/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -2357,6 +2357,37 @@ def _impl(inputs, input_types):
return _impl


def _bincount():
def _impl(inputs, input_types):
data = inputs[0]
weights = inputs[1]
maximum = _op.max(data)
dim = maximum + _expr.const(1, dtype="int64")
if weights:
weight_type = _infer_type(weights).checked_type
out_dtype = weight_type.dtype
updates = weights
else:
out_dtype = "int64"
updates = _op.ones_like(data)

counts = _op.zeros(_op.reshape(dim, [1]), out_dtype)
return _op.scatter_add(counts, data, updates, axis=0)

return _impl


def _scatter_add():
def _impl(inputs, input_types):
data = inputs[0]
axis = inputs[1]
index = inputs[2]
src = inputs[3]
return _op.scatter_add(data, index, src, axis=axis)

return _impl


def _pytorch_result_type(dtypes, non_tensor_inputs):
"""This promotes TVM dtypes like PyTorch would"""
import torch
Expand Down Expand Up @@ -2699,6 +2730,8 @@ def _get_convert_map(prelude, default_dtype):
"aten::tensor": _identity(), # used for example in tensor(1.0)
"aten::numel": _numel(),
"aten::empty": _empty(),
"aten::bincount": _bincount(),
"aten::scatter_add": _scatter_add(),
}
return convert_map

Expand Down
32 changes: 23 additions & 9 deletions tests/python/frontend/pytorch/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -3139,26 +3139,27 @@ def forward(self, data):


def test_forward_scatter():
class Scatter(Module):
def __init__(self, dim=0):
super().__init__()
self.dim = dim
# integer cannot be traced
def test_fn_scatter(dim):
return lambda data, index, src: torch.scatter(data, dim=dim, index=index, src=src)

def forward(self, data, index, src):
return torch.scatter(data, dim=self.dim, index=index, src=src)
def test_fn_scatter_add(dim):
return lambda data, index, src: torch.scatter_add(data, dim=dim, index=index, src=src)

masahi marked this conversation as resolved.
Show resolved Hide resolved
in_data = torch.zeros(3, 5)
in_index = torch.tensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]])
in_src = torch.rand(2, 5)
# TODO: add scatter gpu schedule to enable gpu test.
verify_trace_model(Scatter(), [in_data, in_index, in_src], ["llvm"])
verify_trace_model(test_fn_scatter(0), [in_data, in_index, in_src], ["llvm"])
verify_trace_model(test_fn_scatter_add(0), [in_data, in_index, in_src], ["llvm"])

in_data = torch.zeros(2, 4)
in_index = torch.tensor([[2], [3]])
in_src = torch.rand(2, 1)

# TODO: add scatter gpu schedule to enable gpu test.
verify_trace_model(Scatter(1), [in_data, in_index, in_src], ["llvm"])
# # TODO: add scatter gpu schedule to enable gpu test.
verify_trace_model(test_fn_scatter(1), [in_data, in_index, in_src], ["llvm"])
verify_trace_model(test_fn_scatter_add(1), [in_data, in_index, in_src], ["llvm"])


def test_numel():
Expand Down Expand Up @@ -3350,6 +3351,18 @@ def expected(x_shape, y_shape):
assert tvm.ir.structural_equal(expected_mod, mod["main"], map_free_vars=True)


def test_bincount():
def test_fn(x, weights=None):
return torch.bincount(x, weights=weights)

inp = torch.randint(0, 8, (5,), dtype=torch.int64)
weights = torch.linspace(0, 1, steps=5)

verify_trace_model(test_fn, [inp], ["llvm"])
verify_trace_model(test_fn, [inp, weights], ["llvm"])
verify_trace_model(test_fn, [inp, weights.to(torch.float64)], ["llvm"])


if __name__ == "__main__":
# some structural tests
test_forward_traced_function()
Expand Down Expand Up @@ -3476,6 +3489,7 @@ def expected(x_shape, y_shape):
test_forward_nonzero()
test_forward_scatter()
test_numel()
test_bincount()

# Model tests
test_resnet18()
Expand Down