[TOPI] Enable scatter_add on GPU #6856

Merged · 5 commits · Nov 5, 2020
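For context, scatter_add consumes a data tensor, an indices tensor, and an updates tensor, and returns a copy of data in which every update has been added at the position picked out by the corresponding index along the scatter axis (negative indices wrap around). A minimal NumPy sketch of the semantics, mirroring the ref_scatter_add reference used in the Relay tests at the bottom of this diff:

import numpy as np

def scatter_add_ref(data, indices, updates, axis=0):
    # out[..., indices[idx], ...] += updates[idx] along the scatter axis
    out = np.copy(data)
    for idx in np.ndindex(*indices.shape):
        pos = list(idx)
        pos[axis] = indices[idx]
        out[tuple(pos)] += updates[idx]
    return out

# Two updates target position 0, so they accumulate: [4., 0., 2.]
print(scatter_add_ref(np.zeros(3), np.array([0, 2, 0]), np.array([1.0, 2.0, 3.0])))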
2 changes: 1 addition & 1 deletion python/tvm/relay/op/_transform.py
@@ -113,7 +113,7 @@ def compute_scatter_add(attrs, inputs, output_type):
return [topi.scatter_add(inputs[0], inputs[1], inputs[2], attrs.axis)]


-_reg.register_schedule("scatter_add", strategy.schedule_scatter_add)
+_reg.register_strategy("scatter_add", strategy.scatter_add_strategy)

#####################
# Shape functions #
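This one-line swap is the hook for everything below: register_schedule attaches a single generic schedule to the op, while register_strategy lets each target contribute its own (compute, schedule) pair through an OpStrategy, which is what allows the CUDA backend to supply a GPU implementation. A hedged usage sketch of what the change enables (illustrative, not part of the PR; assumes a CUDA-enabled TVM build):

import tvm
from tvm import relay

data = relay.var("data", shape=(3, 5), dtype="float32")
indices = relay.var("indices", shape=(2, 5), dtype="int64")
updates = relay.var("updates", shape=(2, 5), dtype="float32")
out = relay.op.scatter_add(data, indices, updates, axis=0)
mod = tvm.IRModule.from_expr(relay.Function([data, indices, updates], out))

# Previously there was no GPU schedule for scatter_add (see the TODOs removed
# from the tests below); with the strategy in place this now compiles for CUDA.
lib = relay.build(mod, target="cuda")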
15 changes: 14 additions & 1 deletion python/tvm/relay/op/strategy/cuda.py
@@ -664,7 +664,7 @@ def sparse_dense_padded_strategy_cuda(attrs, inputs, out_type, target):

@scatter_strategy.register(["cuda", "gpu"])
def scatter_cuda(attrs, inputs, out_type, target):
"""sparse dense cuda strategy"""
"""scatter cuda strategy"""
strategy = _op.OpStrategy()
strategy.add_implementation(
wrap_compute_scatter(topi.cuda.scatter),
@@ -675,6 +675,19 @@ def scatter_cuda(attrs, inputs, out_type, target):
return strategy


+@scatter_add_strategy.register(["cuda", "gpu"])
+def scatter_add_cuda(attrs, inputs, out_type, target):
+    """scatter_add cuda strategy"""
+    strategy = _op.OpStrategy()
+    strategy.add_implementation(
+        wrap_compute_scatter(topi.cuda.scatter_add),
+        wrap_topi_schedule(topi.generic.schedule_extern),
+        name="scatter_add.cuda",
+        plevel=10,
+    )
+    return strategy


@argsort_strategy.register(["cuda", "gpu"])
def argsort_strategy_cuda(attrs, inputs, out_type, target):
"""argsort cuda strategy"""
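Two details in the registration above are worth spelling out. The schedule is topi.generic.schedule_extern because the scatter kernels are emitted directly with the TIR builder through te.extern (see scatter.py below), so there is no compute DAG for a schedule to transform. And wrap_topi_schedule is only a thin adapter from a TOPI schedule function to the strategy interface; a simplified sketch of the pattern, consistent with the generic strategy helpers:

def wrap_topi_schedule(topi_schedule):
    """Adapt a TOPI schedule function to the strategy calling convention."""
    def wrapper(attrs, outs, target):
        with target:
            return topi_schedule(outs)
    return wrapper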
15 changes: 9 additions & 6 deletions python/tvm/relay/op/strategy/generic.py
@@ -1052,12 +1052,15 @@ def _compute_scatter(attrs, inputs, _):
return _compute_scatter


-# scatter_add
-@generic_func
-def schedule_scatter_add(attrs, outs, target):
-    """schedule scatter_add"""
-    with target:
-        return topi.generic.schedule_scatter_add(outs)
+@override_native_generic_func("scatter_add_strategy")
+def scatter_add_strategy(attrs, outs, out_type, target):
+    strategy = _op.OpStrategy()
+    strategy.add_implementation(
+        wrap_compute_scatter(topi.scatter_add),
+        wrap_topi_schedule(topi.generic.schedule_scatter),
+        name="scatter_add.generic",
+    )
+    return strategy


# bitserial_conv2d
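The _compute_scatter closure visible at the top of this hunk comes from wrap_compute_scatter, which adapts any scatter-style TOPI compute (scatter or scatter_add alike) to the strategy calling convention; a simplified sketch, consistent with the surrounding file and with compute_scatter_add in _transform.py above:

def wrap_compute_scatter(topi_compute):
    """Adapt a scatter-style TOPI compute to the strategy calling convention."""
    def _compute_scatter(attrs, inputs, _):
        return [topi_compute(inputs[0], inputs[1], inputs[2], attrs.axis)]
    return _compute_scatter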
133 changes: 107 additions & 26 deletions python/tvm/topi/cuda/scatter.py
@@ -24,7 +24,7 @@ def ceil_div(a, b):
return (a + b - 1) // b


-def gen_ir_1d(data, indices, updates, axis, out):
+def gen_ir_1d(data, indices, updates, axis, out, update_func):
"""Generate scatter ir for 1d inputs

Parameters
@@ -44,6 +44,9 @@ def gen_ir_1d(data, indices, updates, axis, out):
out : tir.Tensor
The output tensor.

+update_func : function
+    The function to be applied to a destination and the corresponding update.

Returns
-------
ret : tir
@@ -73,14 +76,14 @@
with ib.for_range(0, ni, name="i") as i:
index = indices_ptr[i]
with ib.if_scope(index < 0):
-out_ptr[index + n] = updates_ptr[i]
+update_func(out_ptr, index + n, updates_ptr[i])
with ib.else_scope():
-out_ptr[index] = updates_ptr[i]
+update_func(out_ptr, index, updates_ptr[i])

return ib.get()


-def gen_ir_2d(data, indices, updates, axis, out):
+def gen_ir_2d(data, indices, updates, axis, out, update_func):
"""Generate scatter ir for 2d inputs

Parameters
@@ -100,6 +103,9 @@ def gen_ir_2d(data, indices, updates, axis, out):
out : tir.Tensor
The output tensor.

+update_func : function
+    The function to be applied to a destination and the corresponding update.

Returns
-------
ret : tir
@@ -140,9 +146,9 @@
idx = i * ci + j
index = indices_ptr[idx]
with ib.if_scope(index < 0):
-out_ptr[(index + n) * c + j] = updates_ptr[idx]
+update_func(out_ptr, (index + n) * c + j, updates_ptr[idx])
with ib.else_scope():
-out_ptr[index * c + j] = updates_ptr[idx]
+update_func(out_ptr, index * c + j, updates_ptr[idx])
else:
with ib.new_scope():
i = te.thread_axis("blockIdx.x")
@@ -151,13 +157,13 @@
idx = i * ci + j
index = indices_ptr[idx]
with ib.if_scope(index < 0):
-out_ptr[i * c + (index + c)] = updates_ptr[idx]
+update_func(out_ptr, i * c + (index + c), updates_ptr[idx])
with ib.else_scope():
-out_ptr[i * c + index] = updates_ptr[idx]
+update_func(out_ptr, i * c + index, updates_ptr[idx])
return ib.get()


-def gen_ir_3d(data, indices, updates, axis, out):
+def gen_ir_3d(data, indices, updates, axis, out, update_func):
"""Generate scatter ir for 3d inputs

Parameters
@@ -177,6 +183,9 @@ def gen_ir_3d(data, indices, updates, axis, out):
out : tir.Tensor
The output tensor.

+update_func : function
+    The function to be applied to a destination and the corresponding update.

Returns
-------
ret : tir
@@ -225,9 +234,9 @@
idx = (i * ci + j) * hi + k
index = indices_ptr[idx]
with ib.if_scope(index < 0):
-out_ptr[((index + n) * c + j) * h + k] = updates_ptr[idx]
+update_func(out_ptr, ((index + n) * c + j) * h + k, updates_ptr[idx])
with ib.else_scope():
-out_ptr[(index * c + j) * h + k] = updates_ptr[idx]
+update_func(out_ptr, (index * c + j) * h + k, updates_ptr[idx])
elif axis == 1:
with ib.new_scope():
i = te.thread_axis("blockIdx.x")
@@ -241,9 +250,9 @@
idx = (i * ci + j) * hi + k
index = indices_ptr[idx]
with ib.if_scope(index < 0):
-out_ptr[(i * c + (index + c)) * h + k] = updates_ptr[idx]
+update_func(out_ptr, (i * c + (index + c)) * h + k, updates_ptr[idx])
with ib.else_scope():
-out_ptr[(i * c + index) * h + k] = updates_ptr[idx]
+update_func(out_ptr, (i * c + index) * h + k, updates_ptr[idx])
else:
with ib.new_scope():
i = te.thread_axis("blockIdx.x")
@@ -254,13 +263,13 @@
idx = (i * ci + j) * hi + k
index = indices_ptr[idx]
with ib.if_scope(index < 0):
-out_ptr[(i * c + j) * h + (index + h)] = updates_ptr[idx]
+update_func(out_ptr, (i * c + j) * h + (index + h), updates_ptr[idx])
with ib.else_scope():
-out_ptr[(i * c + j) * h + index] = updates_ptr[idx]
+update_func(out_ptr, (i * c + j) * h + index, updates_ptr[idx])
return ib.get()


-def gen_ir_4d(data, indices, updates, axis, out):
+def gen_ir_4d(data, indices, updates, axis, out, update_func):
"""Generate scatter ir for 4d inputs

Parameters
@@ -280,6 +289,9 @@ def gen_ir_4d(data, indices, updates, axis, out):
out : tir.Tensor
The output tensor.

+update_func : function
+    The function to be applied to a destination and the corresponding update.

Returns
-------
ret : tir
@@ -333,9 +345,13 @@
idx = ((i * ci + j) * hi + k) * wi + l
index = indices_ptr[idx]
with ib.if_scope(index < 0):
-out_ptr[(((index + n) * c + j) * h + k) * w + l] = updates_ptr[idx]
+update_func(
+    out_ptr, (((index + n) * c + j) * h + k) * w + l, updates_ptr[idx]
+)
with ib.else_scope():
-out_ptr[((index * c + j) * h + k) * w + l] = updates_ptr[idx]
+update_func(
+    out_ptr, ((index * c + j) * h + k) * w + l, updates_ptr[idx]
+)
elif axis == 1:
with ib.new_scope():
i = te.thread_axis("blockIdx.x")
@@ -351,9 +367,13 @@
idx = ((i * ci + j) * hi + k) * wi + l
index = indices_ptr[idx]
with ib.if_scope(index < 0):
-out_ptr[((i * c + (index + c)) * h + k) * w + l] = updates_ptr[idx]
+update_func(
+    out_ptr, ((i * c + (index + c)) * h + k) * w + l, updates_ptr[idx]
+)
with ib.else_scope():
-out_ptr[((i * c + index) * h + k) * w + l] = updates_ptr[idx]
+update_func(
+    out_ptr, ((i * c + index) * h + k) * w + l, updates_ptr[idx]
+)
elif axis == 2:
with ib.new_scope():
i = te.thread_axis("blockIdx.x")
@@ -369,9 +389,13 @@
idx = ((i * ci + j) * hi + k) * wi + l
index = indices_ptr[idx]
with ib.if_scope(index < 0):
-out_ptr[((i * c + j) * h + (index + h)) * w + l] = updates_ptr[idx]
+update_func(
+    out_ptr, ((i * c + j) * h + (index + h)) * w + l, updates_ptr[idx]
+)
with ib.else_scope():
-out_ptr[((i * c + j) * h + index) * w + l] = updates_ptr[idx]
+update_func(
+    out_ptr, ((i * c + j) * h + index) * w + l, updates_ptr[idx]
+)
else:
with ib.new_scope():
i = te.thread_axis("blockIdx.x")
@@ -384,10 +408,9 @@
idx = ((i * ci + j) * hi + k) * wi + l
index = indices_ptr[idx]
with ib.if_scope(index < 0):
-out_ptr[((i * c + j) * h + k) * w + (index + w)] = updates_ptr[idx]
+update_func(out_ptr, ((i * c + j) * h + k) * w + (index + w), updates_ptr[idx])
with ib.else_scope():
-out_ptr[((i * c + j) * h + k) * w + index] = updates_ptr[idx]
-
+update_func(out_ptr, ((i * c + j) * h + k) * w + index, updates_ptr[idx])
return ib.get()
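The index arithmetic in all four generators is ordinary row-major flattening: for a tensor of shape (n, c, h, w), element (i, j, k, l) lives at flat offset ((i * c + j) * h + k) * w + l, and a negative scatter index is wrapped by adding the extent of the scatter axis (n, c, h, or w). A quick NumPy check of the 4-D case:

import numpy as np

n, c, h, w = 2, 3, 4, 5
a = np.arange(n * c * h * w).reshape(n, c, h, w)
i, j, k, l = 1, 2, 3, 4
# Row-major flattening matches the offset used by gen_ir_4d
assert a.flat[((i * c + j) * h + k) * w + l] == a[i, j, k, l]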


@@ -428,16 +451,74 @@ def scatter(data, indices, updates, axis=0):
4: gen_ir_4d,
}

+def update_func(dst_ptr, dst_index, update):
+    dst_ptr[dst_index] = update

out_shape = data.shape
out_buf = tvm.tir.decl_buffer(out_shape, data.dtype, "out_buf")
out = te.extern(
[out_shape],
[data, indices, updates],
-lambda ins, outs: ir_funcs[rank](ins[0], ins[1], ins[2], axis, outs[0]),
+lambda ins, outs: ir_funcs[rank](ins[0], ins[1], ins[2], axis, outs[0], update_func),
dtype=data.dtype,
out_buffers=[out_buf],
name="scatter_gpu",
tag="scatter_gpu",
)

return out


+def scatter_add(data, indices, updates, axis=0):
+    """Update data by adding values from updates at positions defined by indices.
+
+    Parameters
+    ----------
+    data : tir.Tensor
+        The input data to the operator.
+
+    indices : tir.Tensor
+        The index locations to update.
+
+    updates : tir.Tensor
+        The values to be added.
+
+    axis : int
+        The axis to scatter_add on.
+
+    Returns
+    -------
+    ret : tir.Tensor
+        The computed result.
+    """
+    if axis < 0:
+        axis += len(data.shape)
+    assert axis >= 0
+    assert axis < len(data.shape)
+
+    rank = len(data.shape)
+    assert 1 <= rank <= 4, "scatter_add only supports 1-4 dimensions"
+
+    ir_funcs = {
+        1: gen_ir_1d,
+        2: gen_ir_2d,
+        3: gen_ir_3d,
+        4: gen_ir_4d,
+    }
+
+    def update_func(dst_ptr, dst_index, update):
+        dst_ptr[dst_index] += update
+
+    out_shape = data.shape
+    out_buf = tvm.tir.decl_buffer(out_shape, data.dtype, "out_buf")
+    out = te.extern(
+        [out_shape],
+        [data, indices, updates],
+        lambda ins, outs: ir_funcs[rank](ins[0], ins[1], ins[2], axis, outs[0], update_func),
+        dtype=data.dtype,
+        out_buffers=[out_buf],
+        name="scatter_add_gpu",
+        tag="scatter_add_gpu",
+    )
+
+    return out
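A hedged sketch of driving the new TOPI op directly, outside Relay (illustrative, not part of the PR; assumes a CUDA-enabled build — under the cuda target, topi.generic.schedule_extern dispatches to the CUDA extern schedule):

import tvm
from tvm import te, topi

data = te.placeholder((3, 5), name="data", dtype="float32")
indices = te.placeholder((2, 5), name="indices", dtype="int64")
updates = te.placeholder((2, 5), name="updates", dtype="float32")

with tvm.target.Target("cuda"):
    out = topi.cuda.scatter_add(data, indices, updates, axis=0)
    s = topi.generic.schedule_extern([out])

f = tvm.build(s, [data, indices, updates, out], target="cuda")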
12 changes: 6 additions & 6 deletions tests/python/frontend/pytorch/test_forward.py
@@ -3149,17 +3149,17 @@ def test_fn_scatter_add(dim):
in_data = torch.zeros(3, 5)
in_index = torch.tensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]])
in_src = torch.rand(2, 5)
-# TODO: add scatter gpu schedule to enable gpu test.
-verify_trace_model(test_fn_scatter(0), [in_data, in_index, in_src], ["llvm"])
-verify_trace_model(test_fn_scatter_add(0), [in_data, in_index, in_src], ["llvm"])
+targets = ["llvm", "cuda"]
+verify_trace_model(test_fn_scatter(0), [in_data, in_index, in_src], targets)
+verify_trace_model(test_fn_scatter_add(0), [in_data, in_index, in_src], targets)

in_data = torch.zeros(2, 4)
in_index = torch.tensor([[2], [3]])
in_src = torch.rand(2, 1)

-# # TODO: add scatter gpu schedule to enable gpu test.
-verify_trace_model(test_fn_scatter(1), [in_data, in_index, in_src], ["llvm"])
-verify_trace_model(test_fn_scatter_add(1), [in_data, in_index, in_src], ["llvm"])
+verify_trace_model(test_fn_scatter(1), [in_data, in_index, in_src], targets)
+verify_trace_model(test_fn_scatter_add(1), [in_data, in_index, in_src], targets)


def test_numel():
4 changes: 2 additions & 2 deletions tests/python/relay/test_op_level3.py
@@ -961,6 +961,7 @@ def verify_dynamic_scatter(dshape, ishape, axis=0):
verify_dynamic_scatter((16, 16, 4, 5), (16, 16, 4, 5), 3)


+@tvm.testing.uses_gpu
def test_scatter_add():
def ref_scatter_add(data, indices, updates, axis=0):
output = np.copy(data)
@@ -983,8 +984,7 @@ def verify_scatter_add(dshape, ishape, axis=0):
indices_np = np.random.randint(-dshape[axis], dshape[axis] - 1, ishape).astype("int64")

ref_res = ref_scatter_add(data_np, indices_np, updates_np, axis)
-# TODO(mbrookhart): expand testing when adding more backend schedules
-for target, ctx in [("llvm", tvm.cpu())]:
+for target, ctx in tvm.testing.enabled_targets():
for kind in ["graph", "debug"]:
intrp = relay.create_executor(kind, ctx=ctx, target=target)
op_res = intrp.evaluate(func)(data_np, indices_np, updates_np)
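Replacing the hard-coded [("llvm", tvm.cpu())] with tvm.testing.enabled_targets() is what actually switches the GPU coverage on: combined with the @tvm.testing.uses_gpu marker, the test iterates over every target enabled in the local TVM build, so the new CUDA strategy is exercised without a test-specific target list. Roughly what the loop sees on a CUDA machine (illustrative):

import tvm.testing

for target, ctx in tvm.testing.enabled_targets():
    print(target, ctx)  # e.g. "llvm" cpu(0), then "cuda" gpu(0)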