Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TOPI][Relay][OP] Add a strided_set operation. #4303

Merged
merged 10 commits into from
Dec 3, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions python/tvm/relay/op/_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
_reg.register_schedule("cast_like", schedule_injective)
_reg.register_schedule("reinterpret", schedule_injective)
_reg.register_schedule("strided_slice", schedule_injective)
_reg.register_schedule("strided_set", schedule_injective)
_reg.register_schedule("slice_like", schedule_injective)
_reg.register_schedule("split", schedule_injective)
_reg.register_schedule("take", schedule_injective)
Expand Down Expand Up @@ -304,6 +305,11 @@ def compute_argwhere(attrs, inputs, output_type, _):
new_output_type = tvm.relay.ty.TensorType(output_shape, "int32")
return [topi.argwhere(new_output_type, inputs[0])]

@_reg.register_compute("strided_set")
def compute_strided_set(attrs, inputs, output_type, _):
"""Compute definition of strided_set"""
return [topi.strided_set(inputs[0], inputs[1], inputs[2], inputs[3], inputs[4])]

@script
def _layout_transform_shape_func(data_shape,
out_layout_len,
Expand Down
30 changes: 30 additions & 0 deletions python/tvm/relay/op/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -631,6 +631,36 @@ def strided_slice(data, begin, end, strides=None):
return _make.strided_slice(data, list(begin), list(end), list(strides))


def strided_set(data, v, begin, end, strides=None):
"""Strided set of an array.

Parameters
----------
data : relay.Expr
The source array to be sliced.

v : relay.Expr
The data to be set.

begin: relay.Expr
The indices to begin with in the slicing.

end: relay.Expr
Indices indicating end of the slice.

strides: relay.Expr, optional
Specifies the stride values, it can be negative in that case,
the input tensor will be reversed in that particular axis.

Returns
-------
ret : relay.Expr
The computed result.
"""
strides = strides or const([1], dtype="int32")
return _make.strided_set(data, v, begin, end, strides)


def slice_like(data, shape_like, axes=None):
"""Slice the first input with respect to the second input.

Expand Down
48 changes: 48 additions & 0 deletions src/relay/op/tensor/transform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2049,6 +2049,54 @@ Examples::
.set_attr<TOpPattern>("TOpPattern", kInjective)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout", StridedSliceInferCorrectLayout);

// strided_set
bool StridedSetRel(const Array<Type>& types,
int num_inputs,
const Attrs& attrs,
const TypeReporter& reporter) {
CHECK_EQ(types.size(), 6);
reporter->Assign(types[5], types[0]);
return true;
}

Expr MakeStridedSet(Expr data,
Expr v,
Expr begin,
Expr end,
Expr strides) {
static const Op& op = Op::Get("strided_set");
return CallNode::make(op, {data, v, begin, end, strides}, {});
}

TVM_REGISTER_API("relay.op._make.strided_set")
.set_body_typed(MakeStridedSet);


RELAY_REGISTER_OP("strided_set")
.describe(R"code(Strided set of an array.
Example::

x = [[ 1., 4., 7., 10.],
[ 2., 5., 8., 11.],
[ 3., 6., 9., 12.]]

v = [[ 11., 22., 33.]
[ 44., 55., 66.]]

strided_set(x, v, begin=[0, 1], end=[2, 4], stride=[1, 1]) = \
[[ 1., 11., 22., 33.],
[ 2., 44., 55., 66.],
[ 3., 6., 9., 12.]]
)code" TVM_ADD_FILELINE)
.set_num_inputs(5)
.add_argument("data", "Tensor", "The input tensor.")
.add_argument("v", "Tensor", "The data to set.")
.add_argument("begin", "Tensor", "Indices for the start of the slice.")
.add_argument("end", "Tensor", "Indices indicating the end of the slice.")
.add_argument("strides", "Tensor", "The strides values.")
.set_support_level(4)
.set_attr<TOpPattern>("TOpPattern", kInjective)
.add_type_rel("StridedSet", StridedSetRel);

// relay.split
TVM_REGISTER_NODE_TYPE(SplitAttrs);
Expand Down
40 changes: 40 additions & 0 deletions tests/python/relay/test_op_level4.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,8 +300,48 @@ def verify(dshape, begin, end, strides, output, test_ref=True):
verify((3, 4, 3), [1, 1], [4, 4, 3], None, (2, 3, 3))


def test_strided_set():
def verify(dshape, begin, end, strides, vshape, test_ref=True):
x = relay.var("x", relay.TensorType(dshape, "float32"))
v = relay.var("v", relay.TensorType(vshape, "float32"))
begin_c = relay.const(begin, dtype="int32")
end_c = relay.const(end, dtype="int32")
if strides:
strides_c = relay.const(strides, dtype="int32")
z = relay.strided_set(x, v, begin=begin_c, end=end_c, strides=strides_c)
else:
z = relay.strided_set(x, v, begin=begin_c, end=end_c)
func = relay.Function([x, v], z)
func = run_infer_type(func)
text = func.astext()
assert "strided_set" in text
print(text)
assert func.body.checked_type == relay.ty.TensorType(dshape, "float32")
if not test_ref:
return
x_data = np.random.uniform(size=dshape).astype("float32")
v_data = np.random.uniform(size=vshape).astype("float32")
ref_res = topi.testing.strided_set_python(
x_data, v_data, begin, end, strides)
for target, ctx in ctx_list():
intrp = relay.create_executor("graph", ctx=ctx, target=target)
op_res = intrp.evaluate(func)(x_data, v_data)
tvm.testing.assert_allclose(op_res.asnumpy(), ref_res)

verify((3, 4, 3), [0, 0, 0], [4, -5, 4], [1, -1, 2], (3, 1, 2))
verify((3, 4, 3), [1, 1, 0], [4, 4, 3], [2, 1, 1], (1, 3, 3))
verify((3, 4, 3), [1, -1, 0], [4, -5, 3], [2, -1, 1], (1, 4, 3))
verify((3, 4, 3), [1, 0, 0], [2, 2, 3], [1, 1, 2], (1, 2, 2))
verify((3, 4, 3), [1, -1, 0], [2, -3, 3], [1, -1, 1], (1, 2, 3))
verify((3, 4, 3), [1, 1, 0], [4, 4, 3], None, (2, 3, 3))
verify((3, 4, 3), [1, 1, 0], [4, 1000, 3], None, (2, 3, 3))
verify((3, 4, 3), [1, 1, 0], [4, 4], None, (2, 3, 3))
verify((3, 4, 3), [1, 1], [4, 4, 3], None, (2, 3, 3))


if __name__ == "__main__":
test_strided_slice()
test_strided_set()
test_binary_op()
test_cmp_type()
test_binary_int_broadcast()
Expand Down
2 changes: 1 addition & 1 deletion topi/python/topi/testing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from .lrn_python import lrn_python
from .l2_normalize_python import l2_normalize_python
from .gather_nd_python import gather_nd_python
from .strided_slice_python import strided_slice_python
from .strided_slice_python import strided_slice_python, strided_set_python
from .batch_matmul import batch_matmul
from .slice_axis_python import slice_axis_python
from .sequence_mask_python import sequence_mask
Expand Down
40 changes: 39 additions & 1 deletion topi/python/topi/testing/strided_slice_python.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""gather_nd in python"""
"""strided_slice/set in python"""


def strided_slice_python(data, begin, end, strides):
"""Python version of strided slice operator.
Expand Down Expand Up @@ -46,3 +47,40 @@ def strided_slice_python(data, begin, end, strides):
end[i] if i < len(end) else None,
strides[i] if i < len(strides) else None))
return data[tuple(slices)]


def strided_set_python(data, v, begin, end, strides):
"""Python version of strided slice operator.

Parameters
----------
data : numpy.ndarray
Input data

v : numpy.ndarray
Value data

begin : list
Begining of the slices.

end : list
End of the slices.

strides : list
The stride of each slice.

Returns
-------
result : numpy.ndarray
The updated result.
"""
strides = [] if strides is None else strides
slices = []
res = data.copy()
for i in range(len(data.shape)):
slices.append(slice(
begin[i] if i < len(begin) else None,
end[i] if i < len(end) else None,
strides[i] if i < len(strides) else None))
res[tuple(slices)] = v
return res
93 changes: 93 additions & 0 deletions topi/python/topi/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
import tvm
import topi
from . import cpp
from . import tag
from .util import within_index, make_idx


def expand_dims(a, axis, num_newaxis=1):
Expand Down Expand Up @@ -155,6 +157,97 @@ def strided_slice(a, begin, end, strides=None):
strides = []
return cpp.strided_slice(a, begin, end, strides)

@tvm.tag_scope(tag=tag.INJECTIVE+",strided_set")
def strided_set(a, v, begin, end, strides=None):
"""Set slice of an array.

Parameters
----------
a : tvm.Tensor
The tensor to be sliced.

v : tvm.Tensor
The values to set

begin: tvm.Tensor
The indices to begin with in the slicing.

end: tvm.Tensor
Indicies indicating end of the slice.

strides: tvm.Tensor, optional
Specifies the stride values, it can be negative
in that case, the input tensor will be reversed
in that particular axis.

Returns
-------
ret : tvm.Tensor
"""
n = len(a.shape)

if len(begin.shape) != 1:
raise ValueError("begin should be a vector")
if not begin.dtype == 'int32':
raise TypeError("begin should be int32")
if len(end.shape) != 1:
raise ValueError("end should be a vector")
if not end.dtype == 'int32':
raise TypeError("end should be int32")
if strides is not None:
if len(strides.shape) != 1:
raise ValueError("strides should be a vector")
if not strides.dtype == 'int32':
raise TypeError("strides should be int32")

def _max(a, b):
return tvm.expr.Select(a > b, a, b)

if strides is None:
strides = [tvm.const(1, 'int32')] * n
else:
strides = [tvm.if_then_else(strides.shape[0] > i,
strides[i],
tvm.const(1, 'int32'))
for i in range(n)]

begin = [tvm.if_then_else(begin.shape[0] > i,
begin[i],
tvm.expr.Select(strides[i] > 0,
tvm.const(0, 'int32'),
a.shape[i]))
for i in range(n)]
end = [tvm.if_then_else(end.shape[0] > i,
end[i],
tvm.expr.Select(strides[i] > 0,
a.shape[i] + 1,
-(a.shape[i] + 1)))
for i in range(n)]


# Convert negative indexes
for i in range(n):
begin[i] = tvm.if_then_else(begin[i] < 0,
begin[i] + a.shape[i],
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Even begin[i] + a.shape[i] could trigger OOB. I am not sure how to assert the bound...Could anyone help here?

Copy link
Contributor Author

@abergeron abergeron Nov 15, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

By OOB do you mean out of bounds of the indexed array? If yes, that is not a problem because the code will never try to fetch indices that are out the array shape.

This code just tries to handle numpy-style negative indexing (starts from the end of the array) just like strided_slice does.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yep, I don’t have much idea about this either

Copy link
Contributor Author

@abergeron abergeron Nov 17, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Even if the logical result is very similar to strided_slice the implementation has a difference.

In this operation: a[begin:end:stride] = b

The core kernel loops over all valid indexes for a and check if that index is part of the values selected by the combination of begin, end and stride. If it is, it will compute the corresponding index in b and map the output to that value. Otherwise it will pick up the value from a at that index.

In all cases it doesn't matter if begin, end, or stride doesn't fall within the bounds of a because they are never used to directly or indirectly index into a.

begin[i])
end[i] = tvm.if_then_else(end[i] < 0,
end[i] + a.shape[i],
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same here. OOB is not completely detected.

end[i])

def _select(*indices):
from_val = []
index_tuple = []
for i in range(n):
from_val.append(
within_index(begin[i], end[i], strides[i], indices[i]))
index_tuple.append(
make_idx(begin[i], end[i], strides[i], a.shape[i], indices[i]))
return tvm.if_then_else(tvm.all(*from_val),
v(*index_tuple),
a(*indices))

return tvm.compute(a.shape, _select, name="strided_set")


def reshape(a, newshape):
"""Reshape the array
Expand Down
Loading