Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TOPI] Minor perf improvement for GPU scatter #7233

Merged
merged 16 commits into from
Jan 19, 2021
10 changes: 6 additions & 4 deletions python/tvm/autotvm/measure/measure_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from random import getrandbits
from collections import namedtuple
import tempfile
import numpy as np

import tvm._ffi
import tvm.ir.transform
Expand Down Expand Up @@ -560,10 +561,11 @@ def run_through_rpc(
raise AttributeError(
"Please make sure USE_RANDOM is ON in the config.cmake " "on the remote devices"
)
args = [nd.empty(x[0], dtype=x[1], ctx=ctx) for x in build_result.arg_info]
for arg in args:
random_fill(arg)
ctx.sync()
args = [nd.array(np.zeros(x[0], dtype=x[1]), ctx=ctx) for x in build_result.arg_info]
FrozenGene marked this conversation as resolved.
Show resolved Hide resolved
if "scatter" not in measure_input.task.name:
# the index tensor of scatter op cannot be randomly initialized
for arg in args:
random_fill(arg)
masahi marked this conversation as resolved.
Show resolved Hide resolved

costs = time_f(*args).results

Expand Down
15 changes: 14 additions & 1 deletion python/tvm/relay/op/strategy/cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -783,10 +783,23 @@ def scatter_cuda(attrs, inputs, out_type, target):
strategy = _op.OpStrategy()
strategy.add_implementation(
wrap_compute_scatter(topi.cuda.scatter),
wrap_topi_schedule(topi.generic.schedule_extern),
wrap_topi_schedule(topi.cuda.schedule_scatter),
name="scatter.cuda",
plevel=10,
)

rank = len(inputs[0].shape)

with SpecializedCondition(rank == 1):
if target.kind.name == "cuda" and get_global_func(
"tvm.contrib.thrust.stable_sort_by_key", allow_missing=True
):
strategy.add_implementation(
wrap_compute_scatter(topi.cuda.scatter_via_sort),
wrap_topi_schedule(topi.cuda.schedule_scatter_via_sort),
name="scatter_via_sort.cuda",
plevel=9, # use the sequential version by default
)
return strategy


Expand Down
2 changes: 1 addition & 1 deletion python/tvm/relay/op/strategy/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -1123,7 +1123,7 @@ def wrap_compute_scatter(topi_compute):
"""Wrap scatter topi compute"""

def _compute_scatter(attrs, inputs, _):
return [topi_compute(inputs[0], inputs[1], inputs[2], axis=attrs.axis)]
return [topi_compute(inputs[0], inputs[1], inputs[2], attrs.axis)]

return _compute_scatter

Expand Down
179 changes: 102 additions & 77 deletions python/tvm/topi/cuda/scatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,33 @@
# pylint: disable=invalid-name, no-member, too-many-locals, too-many-arguments, too-many-statements, singleton-comparison, unused-argument
"""Scatter operator """
import tvm
from tvm import te
from tvm import te, autotvm
from ..scatter import _verify_scatter_nd_inputs
from ..generic import schedule_extern
from .nms import atomic_add
from .sort import stable_sort_by_key_thrust, is_thrust_available
from ..utils import prod


def ceil_div(a, b):
return (a + b - 1) // b


def _memcpy_ir(ib, out_ptr, data_ptr, shape):
fused = prod(shape)
with ib.new_scope():
num_thread = int(tvm.target.Target.current(allow_none=False).max_num_threads)
num_blocks = ceil_div(fused, num_thread)
bx = te.thread_axis("blockIdx.x")
ib.scope_attr(bx, "thread_extent", num_blocks)
tx = te.thread_axis("threadIdx.x")
ib.scope_attr(tx, "thread_extent", num_thread)
tid = bx * num_thread + tx

with ib.if_scope(tid < fused):
out_ptr[tid] = data_ptr[tid]


def gen_ir_1d(data, indices, updates, axis, out, update_func):
"""Generate scatter ir for 1d inputs

Expand Down Expand Up @@ -63,10 +80,7 @@ def gen_ir_1d(data, indices, updates, axis, out, update_func):
out_ptr = ib.buffer_ptr(out)
data_ptr = ib.buffer_ptr(data)

with ib.new_scope():
bx = te.thread_axis("blockIdx.x")
ib.scope_attr(bx, "thread_extent", n)
out_ptr[bx] = data_ptr[bx]
_memcpy_ir(ib, out_ptr, data_ptr, data.shape)

indices_ptr = ib.buffer_ptr(indices)
updates_ptr = ib.buffer_ptr(updates)
Expand Down Expand Up @@ -114,8 +128,6 @@ def gen_ir_2d(data, indices, updates, axis, out, update_func):
ret : tir
The computational ir.
"""
warp_size = tvm.target.Target.current(False).thread_warp_size

n = data.shape[0]
c = data.shape[1]

Expand All @@ -124,16 +136,7 @@ def gen_ir_2d(data, indices, updates, axis, out, update_func):
out_ptr = ib.buffer_ptr(out)
data_ptr = ib.buffer_ptr(data)

with ib.new_scope():
bx = te.thread_axis("blockIdx.x")
ib.scope_attr(bx, "thread_extent", n)
tx = te.thread_axis("threadIdx.x")
ib.scope_attr(tx, "thread_extent", warp_size)
with ib.for_range(0, ceil_div(c, warp_size), name="j") as j_:
j = j_ * warp_size + tx
with ib.if_scope(j < c):
idx = bx * c + j
out_ptr[idx] = data_ptr[idx]
_memcpy_ir(ib, out_ptr, data_ptr, data.shape)

indices_ptr = ib.buffer_ptr(indices)
updates_ptr = ib.buffer_ptr(updates)
Expand Down Expand Up @@ -205,18 +208,7 @@ def gen_ir_3d(data, indices, updates, axis, out, update_func):
out_ptr = ib.buffer_ptr(out)
data_ptr = ib.buffer_ptr(data)

with ib.new_scope():
bx = te.thread_axis("blockIdx.x")
ib.scope_attr(bx, "thread_extent", n)
by = te.thread_axis("blockIdx.y")
ib.scope_attr(by, "thread_extent", c)
tx = te.thread_axis("threadIdx.x")
ib.scope_attr(tx, "thread_extent", warp_size)
with ib.for_range(0, ceil_div(h, warp_size), name="k") as k_:
k = k_ * warp_size + tx
with ib.if_scope(k < h):
idx = (bx * c + by) * h + k
out_ptr[idx] = data_ptr[idx]
_memcpy_ir(ib, out_ptr, data_ptr, data.shape)

indices_ptr = ib.buffer_ptr(indices)
updates_ptr = ib.buffer_ptr(updates)
Expand Down Expand Up @@ -311,20 +303,7 @@ def gen_ir_4d(data, indices, updates, axis, out, update_func):

out_ptr = ib.buffer_ptr(out)
data_ptr = ib.buffer_ptr(data)
with ib.new_scope():
i = te.thread_axis("blockIdx.x")
ib.scope_attr(i, "thread_extent", n)
j = te.thread_axis("blockIdx.y")
ib.scope_attr(j, "thread_extent", c)
k = te.thread_axis("blockIdx.z")
ib.scope_attr(k, "thread_extent", h)
tx = te.thread_axis("threadIdx.x")
ib.scope_attr(tx, "thread_extent", warp_size)
with ib.for_range(0, ceil_div(w, warp_size), name="l") as l_:
l = l_ * warp_size + tx
with ib.if_scope(l < w):
idx = ((i * c + j) * h + k) * w + l
out_ptr[idx] = data_ptr[idx]
_memcpy_ir(ib, out_ptr, data_ptr, data.shape)

indices_ptr = ib.buffer_ptr(indices)
updates_ptr = ib.buffer_ptr(updates)
Expand Down Expand Up @@ -417,7 +396,71 @@ def gen_ir_4d(data, indices, updates, axis, out, update_func):
return ib.get()


def gen_scatter_1d_thrust(data, indices_sorted, updates_sorted, axis, out, _):
@autotvm.register_topi_compute("scatter.cuda")
def scatter(cfg, data, indices, updates, axis=0):
"""Update data at positions defined by indices with values in updates

Parameters
----------
data : relay.Expr
The input data to the operator.

indices : relay.Expr
The index locations to update.

updates : relay.Expr
The values to update.

axis : int
The axis to scatter on

Returns
-------
ret : relay.Expr
The computed result.
"""
if axis < 0:
axis += len(data.shape)
assert axis >= 0
assert axis < len(data.shape)

rank = len(data.shape)
assert 1 <= rank <= 4, "scatter only supports 1-4 dimensions"

ir_funcs = {
1: gen_ir_1d,
2: gen_ir_2d,
3: gen_ir_3d,
4: gen_ir_4d,
}

def update_func(dst_ptr, dst_index, update):
dst_ptr[dst_index] = update

out_shape = data.shape
out_buf = tvm.tir.decl_buffer(out_shape, data.dtype, "out_buf")

cfg.add_flop(1) # A dummy value to satisfy AutoTVM

out = te.extern(
[out_shape],
[data, indices, updates],
lambda ins, outs: ir_funcs[rank](ins[0], ins[1], ins[2], axis, outs[0], update_func),
dtype=data.dtype,
out_buffers=[out_buf],
name="scatter_gpu",
tag="scatter_gpu",
)

return out


@autotvm.register_topi_schedule("scatter.cuda")
def schedule_scatter(_, outs):
return schedule_extern(outs)


def gen_scatter_1d_thrust(data, indices_sorted, updates_sorted, out):
"""Generate scatter ir for 1d inputs, using a sorting based approach.
By sorting indices and comparing neighboring two indices, we can tell which
of elements in the indices tensor can scatter its update value into the output.
Expand All @@ -438,9 +481,6 @@ def gen_scatter_1d_thrust(data, indices_sorted, updates_sorted, axis, out, _):
updates : tir.Tensor
The values to update, sorted by indices.

axis : int
The axis to scatter on. It must be 0 for this function.

out : tir.Tensor
The output tensor.

Expand All @@ -449,7 +489,6 @@ def gen_scatter_1d_thrust(data, indices_sorted, updates_sorted, axis, out, _):
ret : tir
The computational ir.
"""
assert axis == 0
n = data.shape[0]

ib = tvm.tir.ir_builder.create()
Expand Down Expand Up @@ -504,7 +543,8 @@ def gen_scatter_1d_thrust(data, indices_sorted, updates_sorted, axis, out, _):
return ib.get()


def scatter(data, indices, updates, axis=0):
@autotvm.register_topi_compute("scatter_via_sort.cuda")
def scatter_via_sort(cfg, data, indices, updates, axis=0):
"""Update data at positions defined by indices with values in updates

Parameters
Expand All @@ -528,49 +568,34 @@ def scatter(data, indices, updates, axis=0):
"""
if axis < 0:
axis += len(data.shape)
assert axis >= 0
assert axis < len(data.shape)
assert axis == 0 and len(data.shape) == 1, "sorting based scatter only supported for 1d input"
assert is_thrust_available(), "Thrust is required for this op"

rank = len(data.shape)
assert 1 <= rank <= 4, "scatter only supports 1-4 dimensions"

ir_funcs = {
1: gen_ir_1d,
2: gen_ir_2d,
3: gen_ir_3d,
4: gen_ir_4d,
}

def update_func(dst_ptr, dst_index, update):
dst_ptr[dst_index] = update
cfg.add_flop(1) # A dummy value to satisfy AutoTVM

out_shape = data.shape
out_buf = tvm.tir.decl_buffer(out_shape, data.dtype, "out_buf")

in_bufs = [data]

if rank == 1 and is_thrust_available():
ir_funcs[1] = gen_scatter_1d_thrust
indices_sorted, updates_sorted = stable_sort_by_key_thrust(
indices, updates, for_scatter=True
)
in_bufs += [indices_sorted, updates_sorted]
else:
in_bufs += [indices, updates]
indices_sorted, updates_sorted = stable_sort_by_key_thrust(indices, updates, for_scatter=True)

out = te.extern(
[out_shape],
in_bufs,
lambda ins, outs: ir_funcs[rank](ins[0], ins[1], ins[2], axis, outs[0], update_func),
[data, indices_sorted, updates_sorted],
lambda ins, outs: gen_scatter_1d_thrust(ins[0], ins[1], ins[2], outs[0]),
dtype=data.dtype,
out_buffers=[out_buf],
name="scatter_gpu",
tag="scatter_gpu",
name="scatter_via_sort_gpu",
tag="scatter_via_sort_gpu",
)

return out


@autotvm.register_topi_schedule("scatter_via_sort.cuda")
def schedule_scatter_via_sort(_, outs):
return schedule_extern(outs)


def gen_scatter_add_1d_atomic(data, indices, updates, axis, out, _):
"""Generate scatter add ir for 1d inputs, using atomic_add instruction

Expand Down