def strided_set(data, v, begin, end, strides=None):
    """Strided set of an array.

    Builds a call to the ``strided_set`` operator, whose result has the
    same type as ``data``.

    Parameters
    ----------
    data : relay.Expr
        The source array to be updated.

    v : relay.Expr
        The data to be set.

    begin : relay.Expr
        The indices to begin with in the slicing.

    end : relay.Expr
        Indices indicating end of the slice.

    strides : relay.Expr, optional
        Specifies the stride values. A stride can be negative, in which
        case the slice walks that axis in reverse. When omitted, a
        one-element int32 constant ``[1]`` is passed to the operator.

    Returns
    -------
    ret : relay.Expr
        The computed result.
    """
    # Test against None explicitly: `strides or ...` would depend on the
    # truthiness of a Relay expression, which is not a meaningful way to
    # ask "was this argument supplied?".
    if strides is None:
        strides = const([1], dtype="int32")
    return _make.strided_set(data, v, begin, end, strides)
def test_strided_set():
    """Check type inference and numerical results of relay.strided_set."""
    def verify(dshape, begin, end, strides, vshape, test_ref=True):
        # Symbolic inputs: the tensor being updated and the values written.
        x = relay.var("x", relay.TensorType(dshape, "float32"))
        v = relay.var("v", relay.TensorType(vshape, "float32"))

        slice_args = {
            "begin": relay.const(begin, dtype="int32"),
            "end": relay.const(end, dtype="int32"),
        }
        # Only pass `strides` when the case supplies some; otherwise the
        # call is made without that keyword at all.
        if strides:
            slice_args["strides"] = relay.const(strides, dtype="int32")
        call = relay.strided_set(x, v, **slice_args)

        func = run_infer_type(relay.Function([x, v], call))
        text = func.astext()
        assert "strided_set" in text
        print(text)
        # The inferred result type must be the type of the updated tensor.
        assert func.body.checked_type == relay.ty.TensorType(dshape, "float32")

        if test_ref:
            x_data = np.random.uniform(size=dshape).astype("float32")
            v_data = np.random.uniform(size=vshape).astype("float32")
            expected = topi.testing.strided_set_python(
                x_data, v_data, begin, end, strides)
            for target, ctx in ctx_list():
                executor = relay.create_executor("graph", ctx=ctx, target=target)
                actual = executor.evaluate(func)(x_data, v_data)
                tvm.testing.assert_allclose(actual.asnumpy(), expected)

    # (dshape, begin, end, strides, vshape)
    cases = [
        ((3, 4, 3), [0, 0, 0], [4, -5, 4], [1, -1, 2], (3, 1, 2)),
        ((3, 4, 3), [1, 1, 0], [4, 4, 3], [2, 1, 1], (1, 3, 3)),
        ((3, 4, 3), [1, -1, 0], [4, -5, 3], [2, -1, 1], (1, 4, 3)),
        ((3, 4, 3), [1, 0, 0], [2, 2, 3], [1, 1, 2], (1, 2, 2)),
        ((3, 4, 3), [1, -1, 0], [2, -3, 3], [1, -1, 1], (1, 2, 3)),
        ((3, 4, 3), [1, 1, 0], [4, 4, 3], None, (2, 3, 3)),
        ((3, 4, 3), [1, 1, 0], [4, 1000, 3], None, (2, 3, 3)),
        ((3, 4, 3), [1, 1, 0], [4, 4], None, (2, 3, 3)),
        ((3, 4, 3), [1, 1], [4, 4, 3], None, (2, 3, 3)),
    ]
    for case in cases:
        verify(*case)
@tvm.tag_scope(tag=tag.INJECTIVE+",strided_set")
def strided_set(a, v, begin, end, strides=None):
    """Set slice of an array.

    Produces a tensor shaped like ``a`` in which every position selected
    by the slice ``begin:end:strides`` is read from ``v`` and every other
    position is read from ``a``.

    Parameters
    ----------
    a : tvm.Tensor
        The tensor to be updated.

    v : tvm.Tensor
        The values to set.

    begin : tvm.Tensor
        1-D int32 tensor with the indices to begin with in the slicing.
        Axes beyond its length start at the first element for a positive
        stride and at the last element for a negative stride.

    end : tvm.Tensor
        1-D int32 tensor with the indices indicating the end of the slice.
        Axes beyond its length run to the end of the axis in the direction
        of the stride.

    strides : tvm.Tensor, optional
        1-D int32 tensor with the stride values. A stride can be negative,
        in which case the slice walks that axis in reverse. Axes beyond
        its length (or all axes when omitted) use a stride of 1.

    Returns
    -------
    ret : tvm.Tensor
        The updated tensor, with the same shape as ``a``.

    Raises
    ------
    ValueError
        If ``begin``, ``end`` or ``strides`` is not a vector.
    TypeError
        If ``begin``, ``end`` or ``strides`` is not int32.
    """
    n = len(a.shape)

    # Validate the index tensors in argument order; `strides` is optional.
    to_check = [("begin", begin), ("end", end)]
    if strides is not None:
        to_check.append(("strides", strides))
    for arg_name, tensor in to_check:
        if len(tensor.shape) != 1:
            raise ValueError("%s should be a vector" % arg_name)
        if not tensor.dtype == 'int32':
            raise TypeError("%s should be int32" % arg_name)

    # Expand every index tensor to one scalar expression per axis of `a`.
    # if_then_else guards the load so a short tensor is never read past
    # its end.
    if strides is None:
        strides = [tvm.const(1, 'int32')] * n
    else:
        strides = [tvm.if_then_else(strides.shape[0] > i,
                                    strides[i],
                                    tvm.const(1, 'int32'))
                   for i in range(n)]

    # A reversed slice with no explicit begin starts at the LAST element
    # (shape - 1), not one past it; starting at `shape` would shift every
    # position computed by make_idx by one.
    begin = [tvm.if_then_else(begin.shape[0] > i,
                              begin[i],
                              tvm.expr.Select(strides[i] > 0,
                                              tvm.const(0, 'int32'),
                                              a.shape[i] - 1))
             for i in range(n)]
    # The negative default becomes -1 after the conversion below, i.e.
    # "just before element 0".
    end = [tvm.if_then_else(end.shape[0] > i,
                            end[i],
                            tvm.expr.Select(strides[i] > 0,
                                            a.shape[i] + 1,
                                            -(a.shape[i] + 1)))
           for i in range(n)]

    for i in range(n):
        # Convert negative indexes
        begin[i] = tvm.if_then_else(begin[i] < 0,
                                    begin[i] + a.shape[i],
                                    begin[i])
        end[i] = tvm.if_then_else(end[i] < 0,
                                  end[i] + a.shape[i],
                                  end[i])
        # A reversed slice cannot start beyond the last element; clamp an
        # explicit begin that is at or past the end of the axis.
        begin[i] = tvm.expr.Select(tvm.all(strides[i] < 0,
                                           begin[i] >= a.shape[i]),
                                   a.shape[i] - 1,
                                   begin[i])

    def _select(*indices):
        in_slice = []
        v_index = []
        for i in range(n):
            in_slice.append(
                within_index(begin[i], end[i], strides[i], indices[i]))
            v_index.append(
                make_idx(begin[i], end[i], strides[i], a.shape[i], indices[i]))
        # Read from `v` only where every axis falls inside the slice.
        return tvm.if_then_else(tvm.all(*in_slice),
                                v(*v_index),
                                a(*indices))

    return tvm.compute(a.shape, _select, name="strided_set")
def within_index(b, e, s, i):
    """Return a boolean value that indicates if i is within the given index.

    A position is selected when it lies between ``b`` (inclusive) and
    ``e`` (exclusive) in the direction of the stride and is a whole
    number of strides away from ``b``.

    Parameters
    ----------
    b : Expr
        beginning of the index

    e : Expr
        end of the index

    s : Expr
        strides of index

    i : Expr
        array position

    Returns
    -------
    selected : Expr
        bool expression that is True if the array position would be
        selected by the index and False otherwise
    """
    # Outside the [b, e) range (or (e, b] when walking backwards)?
    bc = tvm.expr.Select(s < 0, i <= e, i < b)
    ec = tvm.expr.Select(s < 0, i > b, i >= e)
    # Distance from the start of the slice must be a multiple of |s|.
    # For a negative stride the distance is measured down from `b`.
    ss = tvm.if_then_else(s < 0,
                          (b - i) % tvm.abs(s),
                          (i - b) % s)
    return tvm.expr.Select(tvm.expr.Or(bc, ec), tvm.const(False), ss.equal(0))


def make_idx(b, e, s, z, i):
    """Return the array position in the selection that corresponds to an
    array position in the full array.

    The returned value is only meaningful if within_index() returns True
    for the same set of parameters.

    Parameters
    ----------
    b : Expr
        beginning of the index

    e : Expr
        end of the index

    s : Expr
        strides of index

    z : Expr
        size of the indexed dimension

    i : Expr
        array position

    Returns
    -------
    position : Expr
        int expression that corresponds to an array position in the
        selection. Positions outside the index yield 0 as a placeholder.
    """
    bc = tvm.expr.Select(s < 0, i <= e, i < b)
    ec = tvm.expr.Select(s < 0, i > b, i >= e)

    # Clamp to array size. `b == z` is already one past the last element,
    # so it has to be clamped as well.
    b = tvm.expr.Select(z <= b, z - 1, b)

    # Number of whole strides between the start of the slice and `i`.
    ss = tvm.if_then_else(s < 0,
                          (b - i) // tvm.abs(s),
                          (i - b) // s)
    # Out-of-range positions carry no meaning; 0 is used as the
    # placeholder instead of an arbitrary constant.
    return tvm.if_then_else(tvm.expr.Or(bc, ec), 0, ss)
def test_strided_set():
    """Run verify_strided_set over a table of slice configurations."""
    # (in_shape, v_shape, begin, end[, strides]); cases without a strides
    # entry call verify_strided_set with its default.
    cases = (
        ((3, 4, 3), (3, 2, 2), [0, 3, 0], [4, 1, 4], [1, -1, 2]),
        ((3, 4, 3), (3, 1, 2), [0, 0, 0], [4, -5, 4], [1, -1, 2]),
        ((3, 4, 3), (1, 3, 3), [1, 1, 0], [4, 4, 3], [2, 1, 1]),
        ((3, 4, 3), (1, 4, 3), [1, -1, 0], [4, -5, 3], [2, -1, 1]),
        ((3, 4, 3), (1, 2, 2), [1, 0, 0], [2, 2, 3], [1, 1, 2]),
        ((3, 4, 3), (1, 2, 3), [1, -1, 0], [2, -3, 3], [1, -1, 1]),
        ((3, 4, 3), (1, 2, 3), [1, 1, 0], [2, 3, 3], [1]),
        ((3, 4, 3), (2, 3, 3), [1, 1, 0], [4, 4, 3]),
        ((3, 4, 3), (2, 3, 3), [1, 1], [4, 4, 3]),
    )
    for case in cases:
        verify_strided_set(*case)