Skip to content

Commit

Permalink
[RELAY,TOPI] Add scatter_nd op (apache#6854)
Browse files Browse the repository at this point in the history
* [RELAY,TOPI] Add scatter_nd op

Scatter_nd is the inverse of gather_nd and also happens to be its
gradient. The implementation here is not optimized. There are no cpu or
gpu specific implementations.

* formatting

* Fix tests

* formatting

* specify types on test

* Fix grad test

* scatter_nd cuda impl

* cuda impl

* x86 impl

* formatting

* fix shape rel

* fix tests

* formatting
  • Loading branch information
tkonolige authored and Trevor Morris committed Dec 4, 2020
1 parent e6b06d5 commit 4ef456c
Show file tree
Hide file tree
Showing 21 changed files with 627 additions and 8 deletions.
8 changes: 8 additions & 0 deletions include/tvm/relay/attrs/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,14 @@ struct ScatterAddAttrs : public tvm::AttrsNode<ScatterAddAttrs> {
}
};

struct ScatterNDAttrs : public tvm::AttrsNode<ScatterNDAttrs> {
Array<Integer> out_shape;

TVM_DECLARE_ATTRS(ScatterNDAttrs, "relay.attrs.ScatterNDAttrs") {
TVM_ATTR_FIELD(out_shape).describe("Output shape of the scatter.");
}
};

struct GatherAttrs : public tvm::AttrsNode<GatherAttrs> {
Integer axis;

Expand Down
5 changes: 4 additions & 1 deletion python/tvm/relay/backend/compile_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,10 @@ def get_valid_implementations(op, attrs, inputs, out_type, target):
The list of all valid op implementations.
"""
fstrategy = op.get_attr("FTVMStrategy")
assert fstrategy is not None, "%s doesn't have FTVMStrategy registered" % op.name
assert fstrategy is not None, (
"%s doesn't have an FTVMStrategy registered. You can register "
"one in python with `tvm.relay.op.register_strategy`." % op.name
)
with target:
strategy = fstrategy(attrs, inputs, out_type, target)
analyzer = tvm.arith.Analyzer()
Expand Down
7 changes: 7 additions & 0 deletions python/tvm/relay/op/_tensor_grad.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@
squeeze,
strided_set,
arange,
scatter_nd,
)


Expand Down Expand Up @@ -803,3 +804,9 @@ def arange_grad(orig, grad):
grad_step = cast_like(_sum(grad_step), step)

return [grad_start, grad_stop, grad_step]


@register_gradient("gather_nd")
def gather_nd_grad(orig, grad):
data, indices = orig.args
return [scatter_nd(grad, indices, data.checked_type.concrete_shape), zeros_like(indices)]
9 changes: 9 additions & 0 deletions python/tvm/relay/op/_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,15 @@ def compute_interpolate(attrs, inputs, output_type):

_reg.register_schedule("interpolate", strategy.schedule_interpolate)

# scatter
@_reg.register_compute("scatter_nd")
def compute_scatter_nd(attrs, inputs, output_type):
"""Compute definition of scatter_nd"""
return [topi.scatter_nd(inputs[0], inputs[1], attrs.out_shape)]


_reg.register_strategy("scatter_nd", strategy.scatter_nd_strategy)

#####################
# Shape functions #
#####################
Expand Down
13 changes: 13 additions & 0 deletions python/tvm/relay/op/strategy/cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -776,6 +776,19 @@ def scatter_add_cuda(attrs, inputs, out_type, target):
return strategy


@scatter_nd_strategy.register(["cuda", "gpu"])
def scatter_nd_cuda(attrs, inputs, out_type, target):
"""scatter_nd cuda strategy"""
strategy = _op.OpStrategy()
strategy.add_implementation(
wrap_compute_scatter_nd(topi.cuda.scatter_nd),
wrap_topi_schedule(topi.generic.schedule_extern),
name="scatter_nd.cuda",
plevel=10,
)
return strategy


@argsort_strategy.register(["cuda", "gpu"])
def argsort_strategy_cuda(attrs, inputs, out_type, target):
"""argsort cuda strategy"""
Expand Down
22 changes: 22 additions & 0 deletions python/tvm/relay/op/strategy/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -1071,6 +1071,28 @@ def schedule_interpolate(attrs, outs, target):
return topi.generic.schedule_interpolate(outs)


# scatter_nd
@override_native_generic_func("scatter_nd_strategy")
def scatter_nd_strategy(attrs, inputs, out_type, target):
"""scatter_nd generic strategy"""
strategy = _op.OpStrategy()
strategy.add_implementation(
wrap_compute_scatter_nd(topi.scatter_nd),
wrap_topi_schedule(topi.generic.schedule_extern),
name="scatter_nd.generic",
)
return strategy


def wrap_compute_scatter_nd(topi_compute):
"""Wrap scatter_nd topi compute"""

def _compute_scatter_nd(attrs, inputs, _):
return [topi_compute(inputs[0], inputs[1], attrs.out_shape)]

return _compute_scatter_nd


# bitserial_conv2d
def wrap_compute_bitserial_conv2d(topi_compute):
"""wrap bitserial_conv2d topi compute"""
Expand Down
13 changes: 13 additions & 0 deletions python/tvm/relay/op/strategy/x86.py
Original file line number Diff line number Diff line change
Expand Up @@ -446,3 +446,16 @@ def bitserial_dense_strategy_cpu(attrs, inputs, out_type, target):
name="bitserial_dense.x86",
)
return strategy


@scatter_nd_strategy.register("cpu")
def scatter_nd_strategy_cpu(attrs, inputs, out_type, target):
"""scatter_nd x86 strategy"""
strategy = _op.OpStrategy()
strategy.add_implementation(
wrap_compute_scatter_nd(topi.x86.scatter_nd),
wrap_topi_schedule(topi.generic.schedule_extern),
name="scatter_nd.x86",
plevel=10,
)
return strategy
24 changes: 24 additions & 0 deletions python/tvm/relay/op/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,6 +308,30 @@ def scatter_add(data, indices, updates, axis):
return _make.scatter_add(data, indices, updates, axis)


def scatter_nd(data, indices, out_shape):
"""Scatter values from an array.
See :py:func:`tvm.topi.scatter` for how data is scattered.
Parameters
----------
data : relay.Expr
The input data to the operator.
indices : relay.Expr
The index locations to update.
out_shape : relay.Expr
Output shape of the scatter.
Returns
-------
ret : relay.Expr
The computed result.
"""
return _make.scatter_nd(data, indices, out_shape)


def reshape_like(data, shape_like, lhs_begin=0, lhs_end=None, rhs_begin=0, rhs_end=None):
"""Reshapes the input tensor by the size of another tensor.
For an input tensor with shape ``(d0, d1, ..., d(k-1))``, `reshape_like` operation reshapes
Expand Down
2 changes: 2 additions & 0 deletions python/tvm/relay/testing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,8 @@ def check_grad(
break
grads = tmp

assert len(grads) > 0, "You must test at least one gradient."

# Get numeric gradients for each dimension of each param, using two-sided approximation.
approx_grads = []
for x in test_inputs:
Expand Down
6 changes: 5 additions & 1 deletion python/tvm/te/operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,7 +317,11 @@ def extern(
if isinstance(body, tvm.tir.PrimExpr):
body = tvm.tir.Evaluate(body)
if not isinstance(body, tvm.tir.Stmt):
raise ValueError("Function '{}' should return PrimExpr or Stmt".format(fcompute.__name__))
raise ValueError(
"Function '{}' should return PrimExpr or Stmt, but it returned '{}'".format(
fcompute.__name__, type(body)
)
)

op = _ffi_api.ExternOp(name, tag, attrs, inputs, input_placeholders, output_placeholders, body)
res = [op.output(i) for i in range(len(output_placeholders))]
Expand Down
106 changes: 106 additions & 0 deletions python/tvm/topi/cuda/scatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
"""Scatter operator """
import tvm
from tvm import te
from ..scatter import _verify_scatter_nd_inputs


def ceil_div(a, b):
Expand Down Expand Up @@ -522,3 +523,108 @@ def update_func(dst_ptr, dst_index, update):
)

return out


def scatter_nd(data, indices, shape):
"""Scatter elements from a n-dimension array.
Given data with shape (Y_0, ..., Y_{K-1}, X_M, ..., X_{N-1}), indices with shape
(M, Y_0, ..., Y_{K-1}), and output with shape (X_0, X_1, ..., X_{N-1}), scatter_nd computes
.. code-block::
output[indices[0, y_0, ..., y_{K-1}],
...,
indices[M-1, y_0, ..., y_{K-1}],
x_M,
...,
x_{N-1}
] = data[y_0, ..., y_{K-1}, x_M, ..., x_{N-1}]
all other entries in the output are 0. Repeated indices are summed.
Parameters
----------
data : tvm.te.Tensor
The source array.
indices : tvm.te.Tensor
The indices of the values to extract.
shape : Sequence[int]
The output shape. This must be specified because it cannot be inferred.
Returns
-------
ret : tvm.te.Tensor
"""
_verify_scatter_nd_inputs(data, indices, shape)

def gen_ir(data_ptr, indices_ptr, out_ptr):
ib = tvm.tir.ir_builder.create()

data = ib.buffer_ptr(data_ptr)
indices = ib.buffer_ptr(indices_ptr)
out = ib.buffer_ptr(out_ptr)

# We combine all the indices dimensions but the first one into a single
# dimension so we can iterate it in single loop instead of an arbitrary
# number of loops. We do the same thing for all the data dimensions.
fused_indices_dimension = 1
for i in indices_ptr.shape[1:]:
fused_indices_dimension *= i

fused_data_dimension = 1
for i in data_ptr.shape[len(indices_ptr.shape) - 1 :]:
fused_data_dimension *= i

fused_shape = 1
for i in shape:
fused_shape *= i

# For now we avoid parallizing over dimensions indexed by `indices` as
# there may be repeated indices and hadling parallel accumulation can
# be hard. So we parallelize over X_M .. X_{N-1} instead. This will
# work well when these dimensions are large enough to saturate memory
# bandwidth, but performance will be bad when these dimensions are
# small.
bx = te.thread_axis("blockIdx.x")
tx = te.thread_axis("threadIdx.x")
max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)
tdim = min(max_threads, fused_data_dimension)
ib.scope_attr(tx, "thread_extent", tdim)
bdim = ceil_div(fused_data_dimension, tdim)
ib.scope_attr(bx, "thread_extent", bdim)

# zero data
# TODO(tkonolige): could we use topi.full to zero it instead?
with ib.for_range(0, ceil_div(fused_shape, bdim)) as i:
index = i * fused_data_dimension + bx * tdim + tx
with ib.if_scope(index < fused_shape):
out[index] = tvm.tir.Cast(data_ptr.dtype, 0)

with ib.for_range(0, fused_indices_dimension) as i:
j = bx * tdim + tx
with ib.if_scope(j < fused_data_dimension):
offset = fused_data_dimension
index = j # This is x_M, .. x_{N-1} part of the index into out.
# Build up the indices[0, y_0, .. y_{K-1}], .. indices[M-1, y_0, .. y_{K-1}] part
# of the index into out.
for l in reversed(range(indices_ptr.shape[0].value)):
# indices[i * l * fused_indices_dimension] = indices[l, y_0, ... y_{k-1}]
index += offset * indices[i + l * fused_indices_dimension]
offset *= shape[l]
out[index] += data[i * fused_data_dimension + j]

return ib.get()

out_buf = tvm.tir.decl_buffer(shape, data.dtype, "out_buf")
return te.extern(
[shape],
[data, indices],
lambda ins, outs: gen_ir(ins[0], ins[1], outs[0]),
dtype=data.dtype,
out_buffers=[out_buf],
name="scatter_nd_cuda",
tag="scatter_nd_cuda",
)
Loading

0 comments on commit 4ef456c

Please sign in to comment.