[TOPI] Enable scatter_add on GPU (#6856)
* enable scatter gpu test on cuda

* adding update_func arg

* pytorch scatter_add gpu tests working

* update 3d and 4d scatter

* enable scatter_add gpu test

Co-authored-by: masa <masa@pop-os.localdomain>
masahi and masa authored Nov 5, 2020
1 parent a4bd5f8 commit 7ee91da
Showing 6 changed files with 139 additions and 42 deletions.
2 changes: 1 addition & 1 deletion python/tvm/relay/op/_transform.py
@@ -113,7 +113,7 @@ def compute_scatter_add(attrs, inputs, output_type):
return [topi.scatter_add(inputs[0], inputs[1], inputs[2], attrs.axis)]


-_reg.register_schedule("scatter_add", strategy.schedule_scatter_add)
+_reg.register_strategy("scatter_add", strategy.scatter_add_strategy)

#####################
# Shape functions #
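Switching from register_schedule to register_strategy is the enabling change here: an op strategy lets each target contribute its own compute/schedule pair, with the generic strategy (generic.py below) as the CPU fallback and the new CUDA strategy taking over on GPU. A minimal end-to-end sketch, assuming a TVM build of this vintage with CUDA enabled (shapes and the relay.op.scatter_add entry point are illustrative):

import tvm
from tvm import relay

data = relay.var("data", shape=(3, 5), dtype="float32")
indices = relay.var("indices", shape=(2, 5), dtype="int64")
updates = relay.var("updates", shape=(2, 5), dtype="float32")
func = relay.Function(
    [data, indices, updates], relay.op.scatter_add(data, indices, updates, axis=0)
)
mod = tvm.IRModule.from_expr(func)

# The op strategy now resolves to scatter_add.generic on llvm
# and scatter_add.cuda on cuda.
for target in ["llvm", "cuda"]:
    lib = relay.build(mod, target=target)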
15 changes: 14 additions & 1 deletion python/tvm/relay/op/strategy/cuda.py
@@ -664,7 +664,7 @@ def sparse_dense_padded_strategy_cuda(attrs, inputs, out_type, target):

@scatter_strategy.register(["cuda", "gpu"])
def scatter_cuda(attrs, inputs, out_type, target):
"""sparse dense cuda strategy"""
"""scatter cuda strategy"""
strategy = _op.OpStrategy()
strategy.add_implementation(
wrap_compute_scatter(topi.cuda.scatter),
@@ -675,6 +675,19 @@ def scatter_cuda(attrs, inputs, out_type, target):
return strategy


@scatter_add_strategy.register(["cuda", "gpu"])
def scatter_add_cuda(attrs, inputs, out_type, target):
"""scatter_add cuda strategy"""
strategy = _op.OpStrategy()
strategy.add_implementation(
wrap_compute_scatter(topi.cuda.scatter_add),
wrap_topi_schedule(topi.generic.schedule_extern),
name="scatter_add.cuda",
plevel=10,
)
return strategy


@argsort_strategy.register(["cuda", "gpu"])
def argsort_strategy_cuda(attrs, inputs, out_type, target):
"""argsort cuda strategy"""
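The CUDA strategy pairs the new TOPI kernel with the generic extern schedule, since the kernel body is already emitted through te.extern and needs no further scheduling. The same pattern would extend the op to another GPU target; a hypothetical sketch (ROCm support is not part of this commit):

@scatter_add_strategy.register(["rocm"])
def scatter_add_rocm(attrs, inputs, out_type, target):
    """scatter_add rocm strategy (hypothetical)"""
    strategy = _op.OpStrategy()
    strategy.add_implementation(
        wrap_compute_scatter(topi.cuda.scatter_add),
        wrap_topi_schedule(topi.generic.schedule_extern),
        name="scatter_add.rocm",
        plevel=10,
    )
    return strategy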
15 changes: 9 additions & 6 deletions python/tvm/relay/op/strategy/generic.py
@@ -1052,12 +1052,15 @@ def _compute_scatter(attrs, inputs, _):
return _compute_scatter


-# scatter_add
-@generic_func
-def schedule_scatter_add(attrs, outs, target):
-    """schedule scatter_add"""
-    with target:
-        return topi.generic.schedule_scatter_add(outs)
+@override_native_generic_func("scatter_add_strategy")
+def scatter_add_strategy(attrs, outs, out_type, target):
+    strategy = _op.OpStrategy()
+    strategy.add_implementation(
+        wrap_compute_scatter(topi.scatter_add),
+        wrap_topi_schedule(topi.generic.schedule_scatter),
+        name="scatter_add.generic",
+    )
+    return strategy


# bitserial_conv2d
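Both registrations lean on the wrap_compute_scatter helper, whose closure pattern is visible at the top of this hunk. A hedged reconstruction of that wrapper for context (its exact body is elided from this diff):

def wrap_compute_scatter(topi_compute):
    """Wrap a scatter-style TOPI compute into the FTVMCompute signature."""

    def _compute_scatter(attrs, inputs, _):
        return [topi_compute(inputs[0], inputs[1], inputs[2], axis=attrs.axis)]

    return _compute_scatter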
133 changes: 107 additions & 26 deletions python/tvm/topi/cuda/scatter.py
@@ -24,7 +24,7 @@ def ceil_div(a, b):
return (a + b - 1) // b


-def gen_ir_1d(data, indices, updates, axis, out):
+def gen_ir_1d(data, indices, updates, axis, out, update_func):
"""Generate scatter ir for 1d inputs
Parameters
@@ -44,6 +44,9 @@ def gen_ir_1d(data, indices, updates, axis, out):
out : tir.Tensor
The output tensor.
update_func: function
The function to be applied to a destination and the corresponding update.
Returns
-------
ret : tir
@@ -73,14 +76,14 @@
with ib.for_range(0, ni, name="i") as i:
index = indices_ptr[i]
with ib.if_scope(index < 0):
-                out_ptr[index + n] = updates_ptr[i]
+                update_func(out_ptr, index + n, updates_ptr[i])
with ib.else_scope():
-                out_ptr[index] = updates_ptr[i]
+                update_func(out_ptr, index, updates_ptr[i])

return ib.get()
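The new update_func parameter is what lets scatter and scatter_add share these IR generators: scatter's callback overwrites the destination element, scatter_add's accumulates into it. A plain-Python sketch of the contract (illustrative only; the real callbacks receive TIR buffer pointers, not lists):

def overwrite(dst, i, v):  # scatter semantics
    dst[i] = v

def accumulate(dst, i, v):  # scatter_add semantics
    dst[i] += v

out = [0, 0, 0]
overwrite(out, 1, 5)   # out == [0, 5, 0]
accumulate(out, 1, 5)  # out == [0, 10, 0]

Note that the generators parallelize only over the non-scattered axes and walk the scatter axis serially within a thread, so each output element is updated by a single thread and the += form needs no atomics.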


-def gen_ir_2d(data, indices, updates, axis, out):
+def gen_ir_2d(data, indices, updates, axis, out, update_func):
"""Generate scatter ir for 2d inputs
Parameters
@@ -100,6 +103,9 @@ def gen_ir_2d(data, indices, updates, axis, out):
out : tir.Tensor
The output tensor.
update_func: function
        The function to be applied to a destination and the corresponding update.
Returns
-------
ret : tir
@@ -140,9 +146,9 @@
idx = i * ci + j
index = indices_ptr[idx]
with ib.if_scope(index < 0):
-                    out_ptr[(index + n) * c + j] = updates_ptr[idx]
+                    update_func(out_ptr, (index + n) * c + j, updates_ptr[idx])
with ib.else_scope():
-                    out_ptr[index * c + j] = updates_ptr[idx]
+                    update_func(out_ptr, index * c + j, updates_ptr[idx])
else:
with ib.new_scope():
i = te.thread_axis("blockIdx.x")
@@ -151,13 +157,13 @@ def gen_ir_2d(data, indices, updates, axis, out):
idx = i * ci + j
index = indices_ptr[idx]
with ib.if_scope(index < 0):
-                    out_ptr[i * c + (index + c)] = updates_ptr[idx]
+                    update_func(out_ptr, i * c + (index + c), updates_ptr[idx])
with ib.else_scope():
-                    out_ptr[i * c + index] = updates_ptr[idx]
+                    update_func(out_ptr, i * c + index, updates_ptr[idx])
return ib.get()


-def gen_ir_3d(data, indices, updates, axis, out):
+def gen_ir_3d(data, indices, updates, axis, out, update_func):
"""Generate scatter ir for 3d inputs
Parameters
@@ -177,6 +183,9 @@ def gen_ir_3d(data, indices, updates, axis, out):
out : tir.Tensor
The output tensor.
update_func: function
        The function to be applied to a destination and the corresponding update.
Returns
-------
ret : tir
@@ -225,9 +234,9 @@
idx = (i * ci + j) * hi + k
index = indices_ptr[idx]
with ib.if_scope(index < 0):
-                    out_ptr[((index + n) * c + j) * h + k] = updates_ptr[idx]
+                    update_func(out_ptr, ((index + n) * c + j) * h + k, updates_ptr[idx])
with ib.else_scope():
-                    out_ptr[(index * c + j) * h + k] = updates_ptr[idx]
+                    update_func(out_ptr, (index * c + j) * h + k, updates_ptr[idx])
elif axis == 1:
with ib.new_scope():
i = te.thread_axis("blockIdx.x")
@@ -241,9 +250,9 @@
idx = (i * ci + j) * hi + k
index = indices_ptr[idx]
with ib.if_scope(index < 0):
-                    out_ptr[(i * c + (index + c)) * h + k] = updates_ptr[idx]
+                    update_func(out_ptr, (i * c + (index + c)) * h + k, updates_ptr[idx])
with ib.else_scope():
-                    out_ptr[(i * c + index) * h + k] = updates_ptr[idx]
+                    update_func(out_ptr, (i * c + index) * h + k, updates_ptr[idx])
else:
with ib.new_scope():
i = te.thread_axis("blockIdx.x")
@@ -254,13 +263,13 @@
idx = (i * ci + j) * hi + k
index = indices_ptr[idx]
with ib.if_scope(index < 0):
-                    out_ptr[(i * c + j) * h + (index + h)] = updates_ptr[idx]
+                    update_func(out_ptr, (i * c + j) * h + (index + h), updates_ptr[idx])
with ib.else_scope():
-                    out_ptr[(i * c + j) * h + index] = updates_ptr[idx]
+                    update_func(out_ptr, (i * c + j) * h + index, updates_ptr[idx])
return ib.get()


-def gen_ir_4d(data, indices, updates, axis, out):
+def gen_ir_4d(data, indices, updates, axis, out, update_func):
"""Generate scatter ir for 4d inputs
Parameters
@@ -280,6 +289,9 @@ def gen_ir_4d(data, indices, updates, axis, out):
out : tir.Tensor
The output tensor.
update_func: function
        The function to be applied to a destination and the corresponding update.
Returns
-------
ret : tir
@@ -333,9 +345,9 @@
idx = ((i * ci + j) * hi + k) * wi + l
index = indices_ptr[idx]
with ib.if_scope(index < 0):
-                    out_ptr[(((index + n) * c + j) * h + k) * w + l] = updates_ptr[idx]
+                    update_func(
+                        out_ptr, (((index + n) * c + j) * h + k) * w + l, updates_ptr[idx]
+                    )
with ib.else_scope():
-                    out_ptr[((index * c + j) * h + k) * w + l] = updates_ptr[idx]
+                    update_func(
+                        out_ptr, ((index * c + j) * h + k) * w + l, updates_ptr[idx]
+                    )
elif axis == 1:
with ib.new_scope():
i = te.thread_axis("blockIdx.x")
@@ -351,9 +367,9 @@
idx = ((i * ci + j) * hi + k) * wi + l
index = indices_ptr[idx]
with ib.if_scope(index < 0):
-                    out_ptr[((i * c + (index + c)) * h + k) * w + l] = updates_ptr[idx]
+                    update_func(
+                        out_ptr, ((i * c + (index + c)) * h + k) * w + l, updates_ptr[idx]
+                    )
with ib.else_scope():
-                    out_ptr[((i * c + index) * h + k) * w + l] = updates_ptr[idx]
+                    update_func(
+                        out_ptr, ((i * c + index) * h + k) * w + l, updates_ptr[idx]
+                    )
elif axis == 2:
with ib.new_scope():
i = te.thread_axis("blockIdx.x")
@@ -369,9 +389,9 @@
idx = ((i * ci + j) * hi + k) * wi + l
index = indices_ptr[idx]
with ib.if_scope(index < 0):
-                    out_ptr[((i * c + j) * h + (index + h)) * w + l] = updates_ptr[idx]
+                    update_func(
+                        out_ptr, ((i * c + j) * h + (index + h)) * w + l, updates_ptr[idx]
+                    )
with ib.else_scope():
-                    out_ptr[((i * c + j) * h + index) * w + l] = updates_ptr[idx]
+                    update_func(
+                        out_ptr, ((i * c + j) * h + index) * w + l, updates_ptr[idx]
+                    )
else:
with ib.new_scope():
i = te.thread_axis("blockIdx.x")
@@ -384,10 +408,9 @@
idx = ((i * ci + j) * hi + k) * wi + l
index = indices_ptr[idx]
with ib.if_scope(index < 0):
-                    out_ptr[((i * c + j) * h + k) * w + (index + w)] = updates_ptr[idx]
+                    update_func(out_ptr, ((i * c + j) * h + k) * w + (index + w), updates_ptr[idx])
with ib.else_scope():
-                    out_ptr[((i * c + j) * h + k) * w + index] = updates_ptr[idx]
+                    update_func(out_ptr, ((i * c + j) * h + k) * w + index, updates_ptr[idx])
return ib.get()


@@ -428,16 +451,74 @@ def scatter(data, indices, updates, axis=0):
4: gen_ir_4d,
}

def update_func(dst_ptr, dst_index, update):
dst_ptr[dst_index] = update

out_shape = data.shape
out_buf = tvm.tir.decl_buffer(out_shape, data.dtype, "out_buf")
out = te.extern(
[out_shape],
[data, indices, updates],
-        lambda ins, outs: ir_funcs[rank](ins[0], ins[1], ins[2], axis, outs[0]),
+        lambda ins, outs: ir_funcs[rank](ins[0], ins[1], ins[2], axis, outs[0], update_func),
dtype=data.dtype,
out_buffers=[out_buf],
name="scatter_gpu",
tag="scatter_gpu",
)

return out


def scatter_add(data, indices, updates, axis=0):
    """Update data by adding values in updates at positions defined by indices.
    Parameters
    ----------
    data : tvm.te.Tensor
        The input data to the operator.
    indices : tvm.te.Tensor
        The index locations to update.
    updates : tvm.te.Tensor
        The values to be added.
    axis : int
        The axis to scatter on.
    Returns
    -------
    ret : tvm.te.Tensor
        The computed result.
    """
if axis < 0:
axis += len(data.shape)
assert axis >= 0
assert axis < len(data.shape)

rank = len(data.shape)
assert 1 <= rank <= 4, "scatter_add only supports 1-4 dimensions"

ir_funcs = {
1: gen_ir_1d,
2: gen_ir_2d,
3: gen_ir_3d,
4: gen_ir_4d,
}

def update_func(dst_ptr, dst_index, update):
dst_ptr[dst_index] += update

out_shape = data.shape
out_buf = tvm.tir.decl_buffer(out_shape, data.dtype, "out_buf")
out = te.extern(
[out_shape],
[data, indices, updates],
lambda ins, outs: ir_funcs[rank](ins[0], ins[1], ins[2], axis, outs[0], update_func),
dtype=data.dtype,
out_buffers=[out_buf],
name="scatter_add_gpu",
tag="scatter_add_gpu",
)

return out
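A hedged sketch of driving the new kernel directly through TOPI, mirroring how the CUDA strategy wires it up (shapes illustrative; requires a CUDA-enabled build):

import tvm
from tvm import te, topi

data = te.placeholder((3, 5), name="data", dtype="float32")
indices = te.placeholder((2, 5), name="indices", dtype="int64")
updates = te.placeholder((2, 5), name="updates", dtype="float32")

with tvm.target.Target("cuda"):
    out = topi.cuda.scatter_add(data, indices, updates, axis=0)
    # The strategy pairs this compute with the generic extern schedule,
    # since te.extern already emitted the kernel body.
    s = topi.generic.schedule_extern([out])
    f = tvm.build(s, [data, indices, updates, out], target="cuda")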
12 changes: 6 additions & 6 deletions tests/python/frontend/pytorch/test_forward.py
@@ -3149,17 +3149,17 @@ def test_fn_scatter_add(dim):
in_data = torch.zeros(3, 5)
in_index = torch.tensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]])
in_src = torch.rand(2, 5)
-    # TODO: add scatter gpu schedule to enable gpu test.
-    verify_trace_model(test_fn_scatter(0), [in_data, in_index, in_src], ["llvm"])
-    verify_trace_model(test_fn_scatter_add(0), [in_data, in_index, in_src], ["llvm"])

+    targets = ["llvm", "cuda"]
+    verify_trace_model(test_fn_scatter(0), [in_data, in_index, in_src], targets)
+    verify_trace_model(test_fn_scatter_add(0), [in_data, in_index, in_src], targets)

in_data = torch.zeros(2, 4)
in_index = torch.tensor([[2], [3]])
in_src = torch.rand(2, 1)

-    # # TODO: add scatter gpu schedule to enable gpu test.
-    verify_trace_model(test_fn_scatter(1), [in_data, in_index, in_src], ["llvm"])
-    verify_trace_model(test_fn_scatter_add(1), [in_data, in_index, in_src], ["llvm"])
+    verify_trace_model(test_fn_scatter(1), [in_data, in_index, in_src], targets)
+    verify_trace_model(test_fn_scatter_add(1), [in_data, in_index, in_src], targets)


def test_numel():
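For reference, the PyTorch op these tests trace follows out[index[i][j]][j] += src[i][j] for dim=0, with duplicate indices accumulating. A small standalone illustration (values arbitrary):

import torch

data = torch.zeros(3, 5)
index = torch.tensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]])
src = torch.ones(2, 5)
out = data.scatter_add(0, index, src)  # out[index[i][j]][j] += src[i][j]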
4 changes: 2 additions & 2 deletions tests/python/relay/test_op_level3.py
@@ -961,6 +961,7 @@ def verify_dynamic_scatter(dshape, ishape, axis=0):
verify_dynamic_scatter((16, 16, 4, 5), (16, 16, 4, 5), 3)


@tvm.testing.uses_gpu
def test_scatter_add():
def ref_scatter_add(data, indices, updates, axis=0):
output = np.copy(data)
@@ -983,8 +984,7 @@ def verify_scatter_add(dshape, ishape, axis=0):
indices_np = np.random.randint(-dshape[axis], dshape[axis] - 1, ishape).astype("int64")

ref_res = ref_scatter_add(data_np, indices_np, updates_np, axis)
-        # TODO(mbrookhart): expand testing when adding more backend schedules
-        for target, ctx in [("llvm", tvm.cpu())]:
+        for target, ctx in tvm.testing.enabled_targets():
for kind in ["graph", "debug"]:
intrp = relay.create_executor(kind, ctx=ctx, target=target)
op_res = intrp.evaluate(func)(data_np, indices_np, updates_np)
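The ref_scatter_add baseline used by this test is NumPy-based; a plausible standalone version with the same semantics (the helper's body is largely elided from this diff), including the negative-index wrap-around the GPU kernels also implement:

import numpy as np

def ref_scatter_add(data, indices, updates, axis=0):
    output = np.copy(data)
    # Iterate element-wise: duplicate indices must accumulate, which a single
    # vectorized assignment would not do.
    for idx in np.ndindex(*indices.shape):
        out_idx = list(idx)
        out_idx[axis] = indices[idx]  # NumPy wraps negative indices, matching the kernels
        output[tuple(out_idx)] += updates[idx]
    return output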
