Skip to content

Commit

Permalink
Add operation scatter_add to relay, based on scatter implementation.
Browse files Browse the repository at this point in the history
  • Loading branch information
notoraptor committed Jul 10, 2020
1 parent 8a0249c commit c39bd9e
Show file tree
Hide file tree
Showing 9 changed files with 324 additions and 0 deletions.
8 changes: 8 additions & 0 deletions include/tvm/relay/attrs/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,14 @@ struct ScatterAttrs : public tvm::AttrsNode<ScatterAttrs> {
}
};

struct ScatterAddAttrs : public tvm::AttrsNode<ScatterAddAttrs> {
Integer axis;

TVM_DECLARE_ATTRS(ScatterAddAttrs, "relay.attrs.ScatterAddAttrs") {
TVM_ATTR_FIELD(axis).set_default(0).describe("The axis over which to select values.");
}
};

struct GatherAttrs : public tvm::AttrsNode<GatherAttrs> {
Integer axis;

Expand Down
9 changes: 9 additions & 0 deletions python/tvm/relay/op/_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,14 @@ def compute_scatter(attrs, inputs, output_type):

_reg.register_schedule("scatter", strategy.schedule_scatter)

# scatter_add
@_reg.register_compute("scatter_add")
def compute_scatter_add(attrs, inputs, output_type):
"""Compute definition of scatter_add"""
return [topi.scatter_add(inputs[0], inputs[1], inputs[2], attrs.axis)]

_reg.register_schedule("scatter_add", strategy.schedule_scatter_add)

#####################
# Shape functions #
#####################
Expand Down Expand Up @@ -396,6 +404,7 @@ def argwhere_shape_func(attrs, inputs, out_ndims):
return ValueError("Does not support rank higher than 5 in argwhere")

_reg.register_shape_func("scatter", False, elemwise_shape_func)
_reg.register_shape_func("scatter_add", False, elemwise_shape_func)

@script
def _layout_transform_shape_func(data_shape,
Expand Down
7 changes: 7 additions & 0 deletions python/tvm/relay/op/strategy/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -841,6 +841,13 @@ def schedule_scatter(attrs, outs, target):
with target:
return topi.generic.schedule_scatter(outs)

# scatter_add
@generic_func
def schedule_scatter_add(attrs, outs, target):
"""schedule scatter_add"""
with target:
return topi.generic.schedule_scatter_add(outs)

# bitserial_conv2d
def wrap_compute_bitserial_conv2d(topi_compute):
"""wrap bitserial_conv2d topi compute"""
Expand Down
24 changes: 24 additions & 0 deletions python/tvm/relay/op/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,6 +275,30 @@ def scatter(data, indices, updates, axis):
"""
return _make.scatter(data, indices, updates, axis)

def scatter_add(data, indices, updates, axis):
"""Update data by adding values in updates at positions defined by indices
Parameters
----------
data : relay.Expr
The input data to the operator.
indices : relay.Expr
The index locations to update.
updates : relay.Expr
The values to add.
axis : int
The axis to scatter_add on
Returns
-------
ret : relay.Expr
The computed result.
"""
return _make.scatter_add(data, indices, updates, axis)

def reshape_like(data, shape_like):
"""Reshapes the input array by the size of another array.
For an input array with shape ``(d1, d2, ..., dk)``, `reshape_like` operation reshapes
Expand Down
49 changes: 49 additions & 0 deletions src/relay/op/tensor/transform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -805,6 +805,55 @@ RELAY_REGISTER_OP("scatter")
.set_attr<TOpPattern>("TOpPattern", kOpaque)
.set_support_level(10);

// Scatter_add
TVM_REGISTER_NODE_TYPE(ScatterAddAttrs);

// Scatter Add
bool ScatterAddRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const TypeReporter& reporter) {
CHECK_EQ(num_inputs, 3);
CHECK_EQ(types.size(), 4);
auto data = types[0].as<TensorTypeNode>();
if (data == nullptr) {
return false;
}
auto indices = types[1].as<TensorTypeNode>();
if (indices == nullptr) {
return false;
}
auto updates = types[2].as<TensorTypeNode>();
if (updates == nullptr) {
return false;
}
CHECK(indices->dtype.is_int()) << "indices of scatter_add must be tensor of integer";
const auto param = attrs.as<ScatterAddAttrs>();
CHECK(param != nullptr);
reporter->Assign(types[3], TensorType(data->shape, data->dtype));
return true;
}

TVM_REGISTER_GLOBAL("relay.op._make.scatter_add")
.set_body_typed([](Expr data, Expr indices, Expr updates, int axis) {
auto attrs = make_object<ScatterAddAttrs>();
attrs->axis = std::move(axis);
static const Op& op = Op::Get("scatter_add");
return Call(op, {data, indices, updates}, Attrs(attrs), {});
});

RELAY_REGISTER_OP("scatter_add")
.describe(
R"doc(Update data by adding values in updates at positions defined by indices)doc" TVM_ADD_FILELINE)
.set_num_inputs(3)
.add_argument("data", "Tensor", "The input data tensor.")
.add_argument("indicies", "Tensor", "The indicies location tensor.")
.add_argument("updates", "Tensor", "The values to update the input with.")
.add_type_rel("ScatterAdd", ScatterAddRel)
.set_attr<TOpIsStateful>("TOpIsStateful", false)
.set_attr<TOpPattern>("TOpPattern", kOpaque)
.set_support_level(10);

////

// Take
TVM_REGISTER_NODE_TYPE(TakeAttrs);

Expand Down
45 changes: 45 additions & 0 deletions tests/python/relay/test_op_level3.py
Original file line number Diff line number Diff line change
Expand Up @@ -811,6 +811,51 @@ def verify_scatter(dshape, ishape, axis=0):
verify_scatter((16, 16, 4, 5), (16, 16, 4, 5), 3)


def test_scatter_add():

def ref_scatter_add(data, indices, updates, axis=0):
output = np.copy(data)
for index in np.ndindex(*indices.shape):
new_index = list(index)
new_index[axis] = indices[index]
output[tuple(new_index)] += updates[index]
return output

def verify_scatter_add(dshape, ishape, axis=0):
d = relay.var("d", relay.TensorType(dshape, "float32"))
i = relay.var("i", relay.TensorType(ishape, "int64"))
u = relay.var("u", relay.TensorType(ishape, "float32"))
z = relay.op.scatter_add(d, i, u, axis)

func = relay.Function([d, i, u], z)

data_np = np.random.uniform(size=dshape).astype("float32")
updates_np = np.random.uniform(size=ishape).astype("float32")
indices_np = np.random.randint(-dshape[axis], dshape[axis] - 1, ishape).astype("int64")

ref_res = ref_scatter_add(data_np, indices_np, updates_np, axis)
# TODO(mbrookhart): expand testing when adding more backend schedules
for target, ctx in [("llvm", tvm.cpu())]:
for kind in ["graph", "debug"]:
intrp = relay.create_executor(kind, ctx=ctx, target=target)
op_res = intrp.evaluate(func)(data_np, indices_np, updates_np)
tvm.testing.assert_allclose(
op_res.asnumpy(), ref_res, rtol=1e-5)

verify_scatter_add((10, ), (10, ), 0)
verify_scatter_add((10, 5), (10, 5), -2)
verify_scatter_add((10, 5), (10, 5), -1)
verify_scatter_add((10, 5), (3, 5), 0)
verify_scatter_add((12, 4), (7, 2), 1)
verify_scatter_add((2, 3, 4), (1, 3, 4), 0)
verify_scatter_add((2, 3, 4), (2, 1, 4), 1)
verify_scatter_add((2, 3, 4), (2, 3, 1), 2)
verify_scatter_add((2, 3, 4, 5), (1, 3, 4, 5), 0)
verify_scatter_add((6, 3, 4, 5), (2, 3, 4, 5), 1)
verify_scatter_add((2, 3, 8, 5), (2, 3, 1, 1), 2)
verify_scatter_add((16, 16, 4, 5), (16, 16, 4, 5), 3)


def test_gather():
def verify_gather(data, axis, indices, ref_res):
data = np.asarray(data, dtype='float32')
Expand Down
1 change: 1 addition & 0 deletions topi/python/topi/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
from .broadcast import *
from .sort import *
from .scatter import *
from .scatter_add import *
from .argwhere import *
from . import generic
from . import nn
Expand Down
16 changes: 16 additions & 0 deletions topi/python/topi/generic/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,3 +50,19 @@ def schedule_scatter(outs):
The computation schedule for the op.
"""
return _default_schedule(outs, False)


def schedule_scatter_add(outs):
"""Schedule for scatter_add operator.
Parameters
----------
outs: Array of Tensor
The computation graph description of scatter_add.
Returns
-------
s: Schedule
The computation schedule for the op.
"""
return _default_schedule(outs, False)
165 changes: 165 additions & 0 deletions topi/python/topi/scatter_add.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,165 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name, too-many-arguments, too-many-nested-blocks
"""Scatter Add operator"""
from tvm.te import hybrid


@hybrid.script
def _scatter_add_1d(data, indices, updates):
out = output_tensor(data.shape, data.dtype)
for i in range(data.shape[0]):
out[i] = data[i]
for i in range(indices.shape[0]):
out[indices[i] if indices[i] >= 0 else indices[i] +
data.shape[0]] += updates[i]
return out


@hybrid.script
def _scatter_add_2d(data, indices, updates, axis):
out = output_tensor(data.shape, data.dtype)
for i in const_range(data.shape[0]):
for j in const_range(data.shape[1]):
out[i, j] = data[i, j]
if axis == 0:
for i in range(indices.shape[0]):
for j in range(indices.shape[1]):
out[indices[i, j] if indices[i, j] >=
0 else indices[i, j] + data.shape[axis], j] += updates[i, j]
else:
for i in range(indices.shape[0]):
for j in range(indices.shape[1]):
out[i, indices[i, j] if indices[i, j] >=
0 else indices[i, j] + data.shape[axis]] += updates[i, j]

return out


@hybrid.script
def _scatter_add_3d(data, indices, updates, axis):
out = output_tensor(data.shape, data.dtype)
for i in const_range(data.shape[0]):
for j in const_range(data.shape[1]):
for k in const_range(data.shape[2]):
out[i, j, k] = data[i, j, k]
if axis == 0:
for i in range(indices.shape[0]):
for j in range(indices.shape[1]):
for k in const_range(indices.shape[2]):
out[indices[i, j, k] if indices[i, j, k] >=
0 else indices[i, j, k] + data.shape[axis], j, k] += updates[i, j, k]
elif axis == 1:
for i in range(indices.shape[0]):
for j in range(indices.shape[1]):
for k in const_range(indices.shape[2]):
out[i, indices[i, j, k] if indices[i, j, k] >=
0 else indices[i, j, k] + data.shape[axis], k] += updates[i, j, k]
else:
for i in range(indices.shape[0]):
for j in range(indices.shape[1]):
for k in const_range(indices.shape[2]):
out[i, j, indices[i, j, k] if indices[i, j, k] >=
0 else indices[i, j, k] + data.shape[axis]] += updates[i, j, k]

return out


@hybrid.script
def _scatter_add_4d(data, indices, updates, axis):
out = output_tensor(data.shape, data.dtype)
for i in const_range(data.shape[0]):
for j in const_range(data.shape[1]):
for k in const_range(data.shape[2]):
for l in const_range(data.shape[3]):
out[i, j, k, l] = data[i, j, k, l]

if axis == 0:
for i in range(indices.shape[0]):
for j in range(indices.shape[1]):
for k in const_range(indices.shape[2]):
for l in const_range(indices.shape[3]):
out[indices[i, j, k, l] if indices[i, j, k, l] >=
0 else indices[i, j, k, l] + data.shape[axis],
j, k, l] += updates[i, j, k, l]
elif axis == 1:
for i in range(indices.shape[0]):
for j in range(indices.shape[1]):
for k in const_range(indices.shape[2]):
for l in const_range(indices.shape[3]):
out[i,
indices[i, j, k, l] if indices[i, j, k, l] >=
0 else indices[i, j, k, l] + data.shape[axis],
k, l] += updates[i, j, k, l]
elif axis == 2:
for i in range(indices.shape[0]):
for j in range(indices.shape[1]):
for k in const_range(indices.shape[2]):
for l in const_range(indices.shape[3]):
out[i, j,
indices[i, j, k, l] if indices[i, j, k, l] >=
0 else indices[i, j, k, l] + data.shape[axis],
l] += updates[i, j, k, l]
else:
for i in range(indices.shape[0]):
for j in range(indices.shape[1]):
for k in const_range(indices.shape[2]):
for l in const_range(indices.shape[3]):
out[i, j, k,
indices[i, j, k, l] if indices[i, j, k, l] >=
0 else indices[i, j, k, l] + data.shape[axis]
] += updates[i, j, k, l]

return out


def scatter_add(data, indices, updates, axis=0):
"""Update data by adding values in updates at positions defined by indices
Parameters
----------
data : relay.Expr
The input data to the operator.
indices : relay.Expr
The index locations to update.
updates : relay.Expr
The values to update.
axis : int
The axis to scatter_add on
Returns
-------
ret : relay.Expr
The computed result.
"""
if axis < 0:
axis += len(data.shape)
assert axis >= 0
assert axis < len(data.shape)

if len(data.shape) == 1:
return _scatter_add_1d(data, indices, updates)
if len(data.shape) == 2:
return _scatter_add_2d(data, indices, updates, axis)
if len(data.shape) == 3:
return _scatter_add_3d(data, indices, updates, axis)
if len(data.shape) == 4:
return _scatter_add_4d(data, indices, updates, axis)
raise ValueError("scatter_add only support for 1-4 dimensions")

0 comments on commit c39bd9e

Please sign in to comment.