Skip to content

Commit

Permalink
Add Scatter to Topi/Relay/ONNX via hybrid script (#5619)
Browse files Browse the repository at this point in the history
* I can construct scatter but not embed it in a Relay Graph

* working 1-4 dimesion scatter

* add scatter to ONNX

fix lint

* isolate tests to cpu backend

* Fix i386 test

* fix gpu tolerance

* use elemwise_shape_func for scatter

* fix incorrect rebase
  • Loading branch information
Matthew Brookhart authored Jun 9, 2020
1 parent aa80857 commit 9f79199
Show file tree
Hide file tree
Showing 12 changed files with 376 additions and 1 deletion.
8 changes: 8 additions & 0 deletions include/tvm/relay/attrs/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,14 @@ struct ReshapeAttrs : public tvm::AttrsNode<ReshapeAttrs> {
}
}; // struct ReshapeAttrs

struct ScatterAttrs : public tvm::AttrsNode<ScatterAttrs> {
Integer axis;

TVM_DECLARE_ATTRS(ScatterAttrs, "relay.attrs.ScatterAttrs") {
TVM_ATTR_FIELD(axis).set_default(0).describe("The axis over which to select values.");
}
};

struct TakeAttrs : public tvm::AttrsNode<TakeAttrs> {
Integer axis;
std::string mode;
Expand Down
12 changes: 12 additions & 0 deletions python/tvm/relay/frontend/onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -1058,6 +1058,16 @@ def _impl_v1(cls, inputs, attr, params):
return _op.gather_nd(inputs[0], inputs[1])


class Scatter(OnnxOpConverter):
""" Operator converter for Scatter.
"""

@classmethod
def _impl_v1(cls, inputs, attr, params):
axis = attr.get('axis', 0)
return _op.scatter(inputs[0], inputs[1], inputs[2], axis)


class Greater(OnnxOpConverter):
""" Operator logical greater.
"""
Expand Down Expand Up @@ -1863,6 +1873,8 @@ def _get_convert_map(opset):
'SpaceToDepth': SpaceToDepth.get_converter(opset),
'Gather': Gather.get_converter(opset),
'GatherND': GatherND.get_converter(opset),
'Scatter': Scatter.get_converter(opset),
'ScatterElements': Scatter.get_converter(opset),
'Squeeze': AttrCvt('squeeze', {'axes': 'axis'}),
'Unsqueeze': Unsqueeze.get_converter(opset),
'Pad': Pad.get_converter(opset),
Expand Down
11 changes: 11 additions & 0 deletions python/tvm/relay/op/_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from . import op as _reg
from . import strategy
from .op import OpPattern
from ._tensor import elemwise_shape_func

_reg.register_broadcast_schedule("broadcast_to")
_reg.register_broadcast_schedule("broadcast_to_like")
Expand Down Expand Up @@ -88,6 +89,14 @@ def compute_argwhere(attrs, inputs, output_type):

_reg.register_schedule("argwhere", strategy.schedule_argwhere)

# scatter
@_reg.register_compute("scatter")
def compute_scatter(attrs, inputs, output_type):
"""Compute definition of scatter"""
return [topi.scatter(inputs[0], inputs[1], inputs[2], attrs.axis)]

_reg.register_schedule("scatter", strategy.schedule_scatter)

#####################
# Shape functions #
#####################
Expand Down Expand Up @@ -453,6 +462,8 @@ def argwhere_shape_func(attrs, inputs, out_ndims):
return [_argwhere_shape_func_5d(inputs[0])]
return ValueError("Does not support rank higher than 5 in argwhere")

_reg.register_shape_func("scatter", False, elemwise_shape_func)

@script
def _layout_transform_shape_func(data_shape,
out_layout_len,
Expand Down
7 changes: 7 additions & 0 deletions python/tvm/relay/op/strategy/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -774,6 +774,13 @@ def schedule_argwhere(attrs, outs, target):
with target:
return topi.generic.schedule_argwhere(outs)

# scatter
@generic_func
def schedule_scatter(attrs, outs, target):
"""schedule scatter"""
with target:
return topi.generic.schedule_scatter(outs)

# bitserial_conv2d
def wrap_compute_bitserial_conv2d(topi_compute):
"""wrap bitserial_conv2d topi compute"""
Expand Down
24 changes: 24 additions & 0 deletions python/tvm/relay/op/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,30 @@ def argwhere(condition):
"""
return _make.argwhere(condition)

def scatter(data, indices, updates, axis):
"""Update data at positions defined by indices with values in updates
Parameters
----------
data : relay.Expr
The input data to the operator.
indices : relay.Expr
The index locations to update.
updates : relay.Expr
The values to update.
axis : int
The axis to scatter on
Returns
-------
ret : relay.Expr
The computed result.
"""
return _make.scatter(data, indices, updates, axis)

def reshape_like(data, shape_like):
"""Reshapes the input array by the size of another array.
For an input array with shape ``(d1, d2, ..., dk)``, `reshape_like` operation reshapes
Expand Down
47 changes: 47 additions & 0 deletions src/relay/op/tensor/transform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -780,6 +780,53 @@ non-zero)doc" TVM_ADD_FILELINE)
.set_attr<TOpPattern>("TOpPattern", kOpaque)
.set_support_level(10);

// Scatter
TVM_REGISTER_NODE_TYPE(ScatterAttrs);

// Scatter
bool ScatterRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const TypeReporter& reporter) {
CHECK_EQ(num_inputs, 3);
CHECK_EQ(types.size(), 4);
auto data = types[0].as<TensorTypeNode>();
if (data == nullptr) {
return false;
}
auto indices = types[1].as<TensorTypeNode>();
if (indices == nullptr) {
return false;
}
auto updates = types[2].as<TensorTypeNode>();
if (updates == nullptr) {
return false;
}
CHECK(indices->dtype.is_int()) << "indices of take must be tensor of integer";
const auto param = attrs.as<ScatterAttrs>();
CHECK(param != nullptr);
reporter->Assign(types[3], TensorType(data->shape, data->dtype));
return true;
}

TVM_REGISTER_GLOBAL("relay.op._make.scatter")
.set_body_typed([](Expr data, Expr indices, Expr updates, int axis) {
auto attrs = make_object<ScatterAttrs>();
attrs->axis = std::move(axis);
static const Op& op = Op::Get("scatter");
return Call(op, {data, indices, updates}, Attrs(attrs), {});
});

RELAY_REGISTER_OP("scatter")
.describe(
R"doc(Update data at positions defined by indices with values in updates)doc" TVM_ADD_FILELINE)
.set_num_inputs(3)
.add_argument("data", "Tensor", "The input data tensor.")
.add_argument("indicies", "Tensor", "The indicies location tensor.")
.add_argument("updates", "Tensor", "The values to update the input with.")
.add_type_rel("Scatter", ScatterRel)
.set_attr<TOpIsStateful>("TOpIsStateful", false)
.set_attr<TOpPattern>("TOpPattern", kOpaque)
.set_support_level(10);

// Take
TVM_REGISTER_NODE_TYPE(TakeAttrs);

Expand Down
36 changes: 36 additions & 0 deletions tests/python/frontend/onnx/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -408,6 +408,41 @@ def test_gather():
verify_gather((4, 3, 5, 6), [[2, 1, 0, 0]], 0, 'float32')


def verify_scatter(in_shape, indices, axis):
x = np.random.uniform(size=in_shape).astype("float32")
indices = np.array(indices, dtype="int32")
updates = np.random.uniform(size=indices.shape).astype("float32")

y = helper.make_node("ScatterElements", ['data', 'indices', 'updates'], ['output'], axis=axis)

graph = helper.make_graph([y],
'scatter_test',
inputs=[helper.make_tensor_value_info("data",
TensorProto.FLOAT, list(in_shape)),
helper.make_tensor_value_info("indices",
TensorProto.INT32, list(indices.shape)),
helper.make_tensor_value_info("updates",
TensorProto.FLOAT, list(indices.shape))],
outputs=[helper.make_tensor_value_info("output",
TensorProto.FLOAT, list(in_shape))])
model = helper.make_model(graph, producer_name='scatter_test')
onnx_out = get_onnxruntime_output(model, [x, indices, updates])

for target, ctx in ctx_list():
tvm_out = get_tvm_output(
model, [x, indices, updates], target, ctx, onnx_out[0].shape)
tvm.testing.assert_allclose(onnx_out[0], tvm_out)


def test_scatter():
verify_scatter((4,), [1], 0)
verify_scatter((1, 4), [[0]], 0)
verify_scatter((4,), [2, 3], 0)
verify_scatter((2, 2), [[1, 0], [0, 1]], 1)
verify_scatter((3, 3, 3), [[[-1, -3]]], -1)
verify_scatter((4, 3, 5, 6), [[[[2, 1, 0, 0]]]], 0)


def _test_slice_iteration_v1(indata, outdata, starts, ends, axes=None):
if axes:
y = helper.make_node(
Expand Down Expand Up @@ -2823,6 +2858,7 @@ def verify_roi_align(input_dims, num_roi, output_height, output_width, sampling_
test_batch_matmul()
test_gather()
test_gather_nd()
test_scatter()
test_lrn()
test_instance_norm()
test_upsample()
Expand Down
48 changes: 48 additions & 0 deletions tests/python/relay/test_op_level3.py
Original file line number Diff line number Diff line change
Expand Up @@ -663,6 +663,54 @@ def verify_reverse(dshape, axis):
verify_reverse((2, 3, 4), -1)


def test_scatter():

def ref_scatter(data, indices, updates, axis=0):
idx = np.indices(indices.shape).reshape(indices.ndim, -1)

updated_idx = np.copy(idx)
indices = indices.reshape(-1)
for i in range(len(indices)):
updated_idx[axis, i] = indices[i]
scattered = np.copy(data)
scattered[tuple(updated_idx)] = updates[tuple(idx)]
return scattered

def verify_scatter(dshape, ishape, axis=0):
d = relay.var("d", relay.TensorType(dshape, "float32"))
i = relay.var("i", relay.TensorType(ishape, "int64"))
u = relay.var("u", relay.TensorType(ishape, "float32"))
z = relay.op.scatter(d, i, u, axis)

func = relay.Function([d, i, u], z)

data_np = np.random.uniform(size=dshape).astype("float32")
updates_np = np.random.uniform(size=ishape).astype("float32")
indices_np = np.random.randint(-dshape[axis], dshape[axis] - 1, ishape).astype("int64")

ref_res = ref_scatter(data_np, indices_np, updates_np, axis)
# TODO(mbrookhart): expand testing when adding more backend schedules
for target, ctx in [("llvm", tvm.cpu())]:
for kind in ["graph", "debug"]:
intrp = relay.create_executor(kind, ctx=ctx, target=target)
op_res = intrp.evaluate(func)(data_np, indices_np, updates_np)
tvm.testing.assert_allclose(
op_res.asnumpy(), ref_res, rtol=1e-5)

verify_scatter((10, ), (10, ), 0)
verify_scatter((10, 5), (10, 5), -2)
verify_scatter((10, 5), (10, 5), -1)
verify_scatter((10, 5), (3, 5), 0)
verify_scatter((12, 4), (7, 2), 1)
verify_scatter((2, 3, 4), (1, 3, 4), 0)
verify_scatter((2, 3, 4), (2, 1, 4), 1)
verify_scatter((2, 3, 4), (2, 3, 1), 2)
verify_scatter((2, 3, 4, 5), (1, 3, 4, 5), 0)
verify_scatter((6, 3, 4, 5), (2, 3, 4, 5), 1)
verify_scatter((2, 3, 8, 5), (2, 3, 1, 1), 2)
verify_scatter((16, 16, 4, 5), (16, 16, 4, 5), 3)


def test_gather_nd():
def verify_gather_nd(xshape, yshape, y_data):
x = relay.var("x", relay.TensorType(xshape, "float32"))
Expand Down
2 changes: 1 addition & 1 deletion tests/python/relay/test_op_level5.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def verify_resize(dshape, scale, method, layout):
for kind in ["graph", "debug"]:
intrp = relay.create_executor(kind, ctx=ctx, target=target)
op_res = intrp.evaluate(func)(x_data)
tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-4)
tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-4, atol=1e-6)
for method in ["bilinear", "nearest_neighbor"]:
for layout in ["NHWC", "NCHW"]:
verify_resize((1, 4, 4, 4), 2, method, layout)
Expand Down
1 change: 1 addition & 0 deletions topi/python/topi/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
from .transform import *
from .broadcast import *
from .sort import *
from .scatter import *
from .argwhere import *
from . import generic
from . import nn
Expand Down
16 changes: 16 additions & 0 deletions topi/python/topi/generic/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,3 +34,19 @@ def schedule_argwhere(outs):
The computation schedule for the op.
"""
return _default_schedule(outs, False)


def schedule_scatter(outs):
"""Schedule for scatter operator.
Parameters
----------
outs: Array of Tensor
The computation graph description of scatter.
Returns
-------
s: Schedule
The computation schedule for the op.
"""
return _default_schedule(outs, False)
Loading

0 comments on commit 9f79199

Please sign in to comment.