From 27fa7dfce2b080596e2098f222b2e04a029e9150 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Fri, 19 Feb 2021 06:25:52 +0000 Subject: [PATCH 01/25] SparseReshape Inital Code --- python/tvm/relay/frontend/tensorflow.py | 10 + python/tvm/relay/op/_transform.py | 36 ++++ python/tvm/relay/op/strategy/generic.py | 27 +++ python/tvm/relay/op/transform.py | 39 ++++ python/tvm/topi/__init__.py | 1 + python/tvm/topi/generic/search.py | 4 + python/tvm/topi/sparse_reshape.py | 138 ++++++++++++++ src/relay/op/tensor/transform.cc | 51 +++++ .../frontend/tensorflow/test_forward.py | 107 +++++++++++ tests/python/relay/test_op_level3.py | 178 ++++++++++++++++++ 10 files changed, 591 insertions(+) create mode 100644 python/tvm/topi/sparse_reshape.py diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py index 6a29ce266ea6..ec5fda503b8c 100644 --- a/python/tvm/relay/frontend/tensorflow.py +++ b/python/tvm/relay/frontend/tensorflow.py @@ -1020,6 +1020,15 @@ def _impl(inputs, attr, params, mod): return _impl +def _sparse_reshape(): + def _impl(inputs, attr, params, mod): + assert len(inputs) == 3, "There should be 3 input tensors" + new_indices, new_shape = get_relay_op("sparse_reshape")(inputs[0], inputs[1], inputs[2]) + return _expr.TupleWrapper(_expr.Tuple([new_indices, new_shape]), 2) + + return _impl + + def _identity(): def _impl(inputs, attr, params, mod): return inputs[0] @@ -2478,6 +2487,7 @@ def _impl(inputs, attr, params, mod): "SparseToDense": _sparse_to_dense(), "SparseTensorDenseMatMul": _sparse_tensor_dense_matmul(), "SparseFillEmptyRows": _sparse_fill_empty_rows(), + "SparseReshape": _sparse_reshape(), "Split": _split(False), "SplitV": _split(True), "Sqrt": AttrCvt("sqrt"), diff --git a/python/tvm/relay/op/_transform.py b/python/tvm/relay/op/_transform.py index 01bcf4a6cf60..4ce9679c2881 100644 --- a/python/tvm/relay/op/_transform.py +++ b/python/tvm/relay/op/_transform.py @@ -65,6 +65,8 @@ _reg.register_injective_schedule("sparse_to_dense") _reg.register_injective_schedule("matrix_set_diag") _reg.register_injective_schedule("adv_index") +# _reg.register_injective_schedule("sparse_reshape") + # concatenate _reg.register_schedule("concatenate", strategy.schedule_concatenate) @@ -114,6 +116,22 @@ def compute_sparse_fill_empty_rows(attrs, inputs, output_type): _reg.register_strategy("sparse_fill_empty_rows", strategy.sparse_fill_empty_rows_strategy) +# sparse_reshape +@_reg.register_compute("sparse_reshape") +def compute_reshape(attrs, inputs, output_type): + """Compute definition of sparse_reshape""" + + return topi.reshape( + inputs[0], + inputs[1], + inputs[2], + output_type.fields[0].shape, + output_type.fields[1].shape, + ) + + +_reg.register_strategy("sparse_reshape", strategy.sparse_reshape_strategy) + # scatter_add @_reg.register_compute("scatter_add") def compute_scatter_add(attrs, inputs, output_type): @@ -506,6 +524,24 @@ def sparse_fill_empty_rows_func(attrs, inputs, _): return _sparse_fill_empty_rows_shape_func(inputs[0], inputs[2]) +@script +def _sparse_reshape_shape_func(sparse_indices_shape, prev_shape_shape, new_shape_shape): + indices_shape = output_tensor((2,), "int64") + indices_shape[0] = int64(sparse_indices_shape[0]) + indices_shape[1] = int64(new_shape_shape[0]) + shape_tensor = output_tensor((1,), "int64") + shape_tensor[0] = int64(new_shape_shape[0]) + return (indices_shape, shape_tensor) + + +@_reg.register_shape_func("sparse_reshape", False) +def sparse_reshape_shape_func(attrs, inputs, _): + """ + Shape func for sparse_reshape. 
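+
+    The output shapes depend only on the input shapes: new_sparse_indices is
+    [N, len(new_shape)] and the returned shape tensor is [len(new_shape)], so
+    both can be computed without inspecting the tensor values.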
+ """ + return _sparse_reshape_shape_func(inputs[0], inputs[1], inputs[2]) + + @script def _layout_transform_shape_func( data_shape, out_layout_len, dst_equal_list, dst_mul_list, dst_div_list, dst_mix_list diff --git a/python/tvm/relay/op/strategy/generic.py b/python/tvm/relay/op/strategy/generic.py index f076176c5d8a..270db261ac15 100644 --- a/python/tvm/relay/op/strategy/generic.py +++ b/python/tvm/relay/op/strategy/generic.py @@ -1105,6 +1105,33 @@ def _compute_sparse_fill_empty_rows(attrs, inputs, output_type): return _compute_sparse_fill_empty_rows +# sparse_reshape +@override_native_generic_func("sparse_reshape_strategy") +def sparse_reshape_strategy(attrs, outs, out_type, target): + strategy = _op.OpStrategy() + strategy.add_implementation( + wrap_compute_sparse_reshape(topi.sparse_reshape), + wrap_topi_schedule(topi.generic.schedule_sparse_reshape), + name="sparse_reshape.generic", + ) + return strategy + + +def wrap_compute_sparse_reshape(topi_compute): + """Wrap sparse_reshape compute""" + + def _compute_sparse_reshape(attrs, inputs, output_type): + return topi_compute( + inputs[0], + inputs[1], + inputs[2], + output_type.fields[0].shape, + output_type.fields[1].shape, + ) + + return _compute_sparse_reshape + + # roi_pool @generic_func def schedule_roi_pool(attrs, outs, target): diff --git a/python/tvm/relay/op/transform.py b/python/tvm/relay/op/transform.py index b676fe742544..f1d3963ccad7 100644 --- a/python/tvm/relay/op/transform.py +++ b/python/tvm/relay/op/transform.py @@ -1410,6 +1410,45 @@ def sparse_fill_empty_rows(sparse_indices, sparse_values, dense_shape, default_v return Tuple((new_sparse_indices, new_sparse_values, empty_row_indicator)) +def sparse_reshape(sparse_indices, prev_shape, new_shape): + """ + Reshape a Sparse Tensor + Parameters + ---------- + sparse_indices : relay.Expr + A 2-D tensor[N, n_dim] of integers containing location of sparse values, where N is the + number of sparse values and n_dim is the number of dimensions of the dense_shape + prev_shape : relay.Expr + A 1-D tensor containing the previous shape of the dense tensor + new_shape : relay.Expr + A 1-D tensor containing the new shape of the dense tensor + Returns + ------- + result: relay.Expr + Output tensor. + Examples + -------- + .. code-block:: python + sparse_indices = [[0, 0, 0], + [0, 0, 1], + [0, 1, 0], + [1, 0, 0], + [1, 2, 3]] + prev_shape = [2, 3, 4] + new_shape = [9, -1] + new_sparse_indices, new_shape = relay.sparse_reshape(sparse_indices, + prev_shape, + new_shape) + new_sparse_indices = [[0, 0], + [0, 1], + [1, 2], + [4, 2], + [8, 1]] + new_shape = [9, 4] + """ + return TupleWrapper(_make.sparse_reshape(sparse_indices, prev_shape, new_shape), 2) + + def cumsum(data, axis=None, dtype=None, exclusive=None): """Numpy style cumsum op. Return the cumulative inclusive sum of the elements along a given axis. 
diff --git a/python/tvm/topi/__init__.py b/python/tvm/topi/__init__.py index 2b17162048e0..fe90ca5a94b8 100644 --- a/python/tvm/topi/__init__.py +++ b/python/tvm/topi/__init__.py @@ -39,6 +39,7 @@ from .sort import * from .scatter import * from .sparse_fill_empty_rows import * +from .sparse_reshape import * from .scatter_add import * from .argwhere import * from .cumsum import * diff --git a/python/tvm/topi/generic/search.py b/python/tvm/topi/generic/search.py index 5924d35def73..a7761da40a59 100644 --- a/python/tvm/topi/generic/search.py +++ b/python/tvm/topi/generic/search.py @@ -70,3 +70,7 @@ def schedule_scatter_add(outs): def schedule_sparse_fill_empty_rows(outs): return _default_schedule(outs, False) + + +def schedule_sparse_reshape(outs): + return _default_schedule(outs, False) \ No newline at end of file diff --git a/python/tvm/topi/sparse_reshape.py b/python/tvm/topi/sparse_reshape.py new file mode 100644 index 000000000000..9b432db1e9ef --- /dev/null +++ b/python/tvm/topi/sparse_reshape.py @@ -0,0 +1,138 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
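+#
+# Kernel outline: every COO index row is flattened to a single linear offset
+# using the row-major strides of prev_shape, then expanded against the strides
+# of the new shape; -1 entries in new_shape are resolved from the total
+# element count beforehand.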
+# pylint: disable=invalid-name, too-many-arguments, too-many-nested-blocks +"""Scatter operator""" +from ..tir import decl_buffer, ir_builder, Cast, AssertStmt, StringImm, Evaluate +from ..te import extern, hybrid, div, floordiv, floormod + + +def sparse_reshape( + sparse_indices, + prev_shape, + new_shape, + new_sparse_indices_shape, + new_shape_shape, +): + def gen_ir( + sparse_indices_ptr, + prev_shape_ptr, + new_shape_ptr, + new_sparse_indices_ptr, + out_new_shape_ptr, + ): + ib = ir_builder.create() + + sparse_indices = ib.buffer_ptr(sparse_indices_ptr) + prev_shape = ib.buffer_ptr(prev_shape_ptr) + + new_shape = ib.buffer_ptr(new_shape_ptr) + out_new_shape = ib.buffer_ptr(out_new_shape_ptr) + new_sparse_indices = ib.buffer_ptr(new_sparse_indices_ptr) + out_new_shape = ib.buffer_ptr(out_new_shape_ptr) + + prev_shape_size = prev_shape_ptr.shape[0] + new_shape_size = new_shape_ptr.shape[0] + + multipliers = ib.allocate("int64", (prev_shape_size,), name="multipliers", scope="local") + dividers = ib.allocate("int64", (new_shape_size,), name="dividers", scope="local") + flattened_indices = ib.allocate( + "int64", (sparse_indices_ptr.shape[0],), name="flattened_indices", scope="local" + ) + + total_ele = ib.allocate("int64", (1,), name="total_ele", scope="local") + total_ele[0] = prev_shape[0] + + # Cumulative Reverse Exclusive Multiply + multipliers[prev_shape_size - 1] = Cast("int64", 1) + with ib.for_range(0, prev_shape_size - 1) as i_: + i = i_ + 1 + multipliers[prev_shape_size - 1 - i] = ( + prev_shape[prev_shape_size - i] * multipliers[prev_shape_size - i] + ) + total_ele[0] *= prev_shape[prev_shape_size - i] + + division_total_ele = ib.allocate("int64", (1,), name="division_total_ele", scope="local") + division_total_ele[0] = Cast("int64", 1) + with ib.for_range(0, new_shape_size) as i: + with ib.if_scope(new_shape[i] != -1): + division_total_ele[0] *= new_shape[i] + + # Compute true output shape (replace negative ones) + with ib.for_range(0, new_shape_size) as i: + with ib.if_scope(new_shape[i] == -1): + # if Cast("int64", new_shape[i]) == Cast("int64", -1): + out_new_shape[i] = Cast("int64", div(total_ele[0], division_total_ele[0])) + with ib.else_scope(): + out_new_shape[i] = new_shape[i] + + equal_shape = ib.allocate("bool", (1,), name="equal_shape", scope="local") + + # Check if prev_shape and new_shape are equal + equal_shape[0] = True + with ib.if_scope(prev_shape_size == new_shape_size): + with ib.for_range(0, prev_shape_size) as i: + with ib.if_scope(prev_shape[i] != out_new_shape[i]): + equal_shape[0] = False + with ib.else_scope(): + equal_shape[0] = False + + # Return same inputs if shapes are equal + with ib.if_scope(equal_shape[0]): + with ib.for_range(0, sparse_indices_ptr.shape[0]) as i: + with ib.for_range(0, sparse_indices_ptr.shape[1]) as j: + new_sparse_indices[i, j] = sparse_indices[i, j] + + # Else compute new_sparse_indices + with ib.else_scope(): + dividers[new_shape_size - 1] = Cast("int64", 1) + with ib.for_range(0, new_shape_size - 1) as i_: + i = i_ + 1 + dividers[new_shape_size - 1 - i] = ( + dividers[new_shape_size - i] * out_new_shape[new_shape_size - i] + ) + + with ib.for_range(0, sparse_indices_ptr.shape[0]) as i: + flattened_indices[i] = Cast("int64", 0) + with ib.for_range(0, sparse_indices_ptr.shape[1]) as j: + flattened_indices[i] += sparse_indices[i, j] * multipliers[j] + + with ib.for_range(0, new_sparse_indices_ptr.shape[0]) as i: + current_element = ib.allocate("int64", (1,), name="current_element", scope="local") + current_element[0] = 
flattened_indices[i] + + with ib.for_range(0, new_sparse_indices_ptr.shape[1]) as j: + new_sparse_indices[i, j] = Cast( + "int64", floordiv(current_element[0], dividers[j]) + ) + current_element[0] = floormod(current_element[0], dividers[j]) + + return ib.get() + + new_sparse_indices_buf = decl_buffer( + new_sparse_indices_shape, "int64", "new_sparse_indices_buf" + ) + new_shape_buf = decl_buffer(new_shape_shape, "int64", "new_shape_buf") + + return extern( + [new_sparse_indices_shape, new_shape_shape], + [sparse_indices, prev_shape, new_shape], + lambda ins, outs: gen_ir(ins[0], ins[1], ins[2], outs[0], outs[1]), + dtype="int64", + out_buffers=[new_sparse_indices_buf, new_shape_buf], + name="sparse_reshape", + tag="sparse_reshape", + ) diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index 1e782a568fe9..1b8f467a5e49 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -1628,6 +1628,57 @@ RELAY_REGISTER_OP("sparse_fill_empty_rows") .set_support_level(3) .set_attr("TOpPattern", kOpaque); +bool SparseReshapeRel(const Array& types, int num_inputs, const Attrs& attrs, + const TypeReporter& reporter) { + // types: [sparse_indices, prev_shape, new_shape, result] + ICHECK_EQ(types.size(), 4) << "SparseReshapeRel expects 4 types but " << types.size() + << " provided"; + ICHECK_EQ(num_inputs, 3) << "SparseReshapeRel expects 4 inputs but " << num_inputs << " provided"; + auto sparse_indices = types[0].as(); + auto prev_shape = types[1].as(); + auto new_shape = types[2].as(); + if (sparse_indices == nullptr || prev_shape == nullptr || new_shape == nullptr) { + return false; + } + CHECK(sparse_indices->dtype.is_int()) << "sparse_indices must be tensor of integers"; + CHECK(prev_shape->dtype.is_int()) << "prev_shape must be tensor of integers"; + CHECK(new_shape->dtype.is_int()) << "new_shape must be tensor of integers"; + ICHECK_EQ(sparse_indices->shape.size(), 2) << "sparse_indices must be 2-D tensor"; + ICHECK_EQ(prev_shape->shape.size(), 1) << "prev_shape must be 1-D tensor"; + ICHECK_EQ(new_shape->shape.size(), 1) << "new_shape must be 1-D tensor"; + std::vector fields; + Array new_sparse_indices_shape{sparse_indices->shape[0], new_shape->shape[0]}; + fields.push_back(TensorType(new_sparse_indices_shape, sparse_indices->dtype)); + fields.push_back(TensorType(new_shape->shape, new_shape->dtype)); + reporter->Assign(types[3], TupleType(Array(fields))); + return true; +} + +// Array SparseReshapeCompute(const Attrs& attrs, const Array& inputs, +// const Type& out_type) { +// ICHECK_EQ(inputs.size(), 3) << "SparseReshapeCompute expects 3 inputs but " << inputs.size() << "provided"; +// return {topi::SparseReshape(inputs[0], inputs[1], inputs[2])}; +// } + +Expr MakeSparseReshape(Expr sparse_indices, Expr prev_shape, Expr new_shape) { + static const Op& op = Op::Get("sparse_reshape"); + return Call(op, {sparse_indices, prev_shape, new_shape}, Attrs(), {}); +} + +TVM_REGISTER_GLOBAL("relay.op._make.sparse_reshape").set_body_typed(MakeSparseReshape); + +RELAY_REGISTER_OP("sparse_reshape") + .describe(R"code(Return new sparse indices of the reshaped tensor +)code" TVM_ADD_FILELINE) + .set_num_inputs(3) + .add_argument("sparse_indices", "Tensor", "The first tensor") + .add_argument("prev_shape", "Tensor", "The second tensor") + .add_argument("new_shape", "Tensor", "The third tensor") + .add_type_rel("sparse_reshape", SparseReshapeRel) + .set_attr("TOpPattern", kInjective) + .set_support_level(3); + // .set_attr("FTVMCompute", 
SparseReshapeCompute); + // meshgrid operator TVM_REGISTER_NODE_TYPE(MeshgridAttrs); diff --git a/tests/python/frontend/tensorflow/test_forward.py b/tests/python/frontend/tensorflow/test_forward.py index f956ea02eb47..3c19263af830 100644 --- a/tests/python/frontend/tensorflow/test_forward.py +++ b/tests/python/frontend/tensorflow/test_forward.py @@ -1915,6 +1915,113 @@ def test_forward_sparse_fill_empty_rows( ) +####################################################################### +# SparseReshape +# ------------ + + +def _test_sparse_reshape(indices_np, values_np, prev_shape_np, new_shape_np, use_dyn=False): + with tf.Graph().as_default(): + if use_dyn: + indices = tf.placeholder(shape=(None, None), dtype=indices_np.dtype, name="indices") + values = tf.placeholder(shape=(None), dtype=values_np.dtype, name="values") + prev_shape = tf.placeholder(shape=(None), dtype=prev_shape_np.dtype, name="prev_shape") + new_shape = tf.placeholder(shape=(None), dtype=new_shape_np.dtype, name="new_shape") + else: + indices = tf.placeholder(shape=indices_np.shape, dtype=indices_np.dtype, name="indices") + values = tf.placeholder(shape=values_np.shape, dtype=values_np.dtype, name="values") + prev_shape = tf.placeholder( + shape=prev_shape_np.shape, dtype=prev_shape_np.dtype, name="prev_shape" + ) + new_shape = tf.placeholder( + shape=new_shape_np.shape, dtype=new_shape_np.dtype, name="new_shape" + ) + sp_input = tf.sparse.SparseTensor(indices=indices, values=values, dense_shape=prev_shape) + + _ = tf.sparse.reshape(sp_input, new_shape, name="sparse_reshape") + compare_tf_with_tvm( + [indices_np, values_np, prev_shape_np, new_shape_np], + [indices.name, values.name, prev_shape.name, new_shape.name], + ["sparse_reshape:0", "sparse_reshape:1", "sparse_reshape/Identity:0"], + mode="vm", + ) + + +@pytest.mark.parametrize( + "sparse_indices_np, sparse_values_np, prev_shape_np, new_shape_np", + [ + ( + np.array([[0, 0, 0], [0, 0, 1], [0, 1, 0], [1, 0, 0], [1, 2, 3]], dtype=np.int64), + np.array([7, 5, 6, 3, 9], dtype=np.int64), + np.array([2, 3, 6], dtype=np.int64), + np.array([9, -1], dtype=np.int64), + ), + ( + np.array( + [[0, 0, 0, 0], [0, 0, 1, 2], [0, 1, 0, 3], [1, 0, 0, 4], [1, 2, 3, 6]], + dtype=np.int64, + ), + np.array([7, 5, 6, 3, 9], dtype=np.int64), + np.array([2, 3, 6, 7], dtype=np.int64), + np.array([9, -1, 7], dtype=np.int64), + ), + ( + np.array([[0, 0], [0, 1], [3, 4], [4, 3], [7, 3]], dtype=np.int64), + np.array([7, 5, 6, 3, 9], dtype=np.int64), + np.array([9, 4], dtype=np.int64), + np.array([2, -1, 6], dtype=np.int64), + ), + ( + np.array([[0, 0], [0, 1], [3, 4], [4, 3], [7, 3]], dtype=np.int64), + np.array([7, 5, 6, 3, 9], dtype=np.int64), + np.array([9, 4], dtype=np.int64), + np.array([-1], dtype=np.int64), + ), + ( + np.array([[0], [5], [10], [20], [24]], dtype=np.int64), + np.array([7, 5, 6, 3, 9], dtype=np.int64), + np.array([25], dtype=np.int64), + np.array([5, 5], dtype=np.int64), + ), + ( + np.array([[0, 100], [200, 100], [300, 400], [50, 20], [400, 50]], dtype=np.int64), + np.array([7, 5, 6, 3, 9], dtype=np.int64), + np.array([500, 20], dtype=np.int64), + np.array([500, 20], dtype=np.int64), + ), + ( + np.array([[0, 100], [200, 100], [300, 400], [50, 20], [400, 50]], dtype=np.int64), + np.array([7, 5, 6, 3, 9], dtype=np.int64), + np.array([500, 20], dtype=np.int64), + np.array([500, -1], dtype=np.int64), + ), + ( + np.array([[0, 100], [200, 100], [300, 400], [50, 20], [400, 50]], dtype=np.int64), + np.array([7, 5, 6, 3, 9], dtype=np.int64), + np.array([500, 20], dtype=np.int64), + 
np.array([250, 40], dtype=np.int64), + ), + ], +) +@pytest.mark.parametrize("use_dyn", [True, False]) +def test_forward_sparse_reshape( + sparse_indices_np, sparse_values_np, prev_shape_np, new_shape_np, use_dyn +): + """ sparse_reshape op test""" + ################################################################### + # + # In order to create a SparseTensor, it requires 3 input as below: + # SparseTensor(indices=[[0, 0], [1, 2]], values=[1, 2], dense_shape=[3, 4]) + # + # Above Sparse can be represented in Dense as below : + # [[1, 0, 0, 0] + # [0, 0, 2, 0] + # [0, 0, 0, 0]] + # + # ------------------------------------------------------------------ + _test_sparse_reshape(sparse_indices_np, sparse_values_np, prev_shape_np, new_shape_np, use_dyn) + + ####################################################################### # StridedSlice # ------------ diff --git a/tests/python/relay/test_op_level3.py b/tests/python/relay/test_op_level3.py index 625c47240326..09ad0fd3c15d 100644 --- a/tests/python/relay/test_op_level3.py +++ b/tests/python/relay/test_op_level3.py @@ -1311,6 +1311,184 @@ def verify_sparse_to_dense(sparse_indices, sparse_values, default_value, output_ # verify_sparse_to_dense([[[[0, 1, 4], [0, 2, 4]]]], [[[[3.1, 3.1, 3.1]]]], 3.5, [5], [3.1, 3.1, 3.5, 3.5, 3.1]) +@tvm.testing.uses_gpu +@pytest.mark.parametrize( + "sparse_indices_np, sparse_values_np, prev_shape_np, new_shape_np", + [ + ( + np.array([[0, 0, 0], [0, 0, 1], [0, 1, 0], [1, 0, 0], [1, 2, 3]], dtype=np.int64), + np.array([7, 5, 6, 3, 9], dtype=np.int64), + np.array([2, 3, 6], dtype=np.int64), + np.array([9, -1], dtype=np.int64), + ), + ( + np.array( + [[0, 0, 0, 0], [0, 0, 1, 2], [0, 1, 0, 3], [1, 0, 0, 4], [1, 2, 3, 6]], + dtype=np.int64, + ), + np.array([7, 5, 6, 3, 9], dtype=np.int64), + np.array([2, 3, 6, 7], dtype=np.int64), + np.array([9, -1, 7], dtype=np.int64), + ), + ( + np.array([[0, 0], [0, 1], [3, 4], [4, 3], [7, 3]], dtype=np.int64), + np.array([7, 5, 6, 3, 9], dtype=np.int64), + np.array([9, 4], dtype=np.int64), + np.array([2, -1, 6], dtype=np.int64), + ), + ( + np.array([[0, 0], [0, 1], [3, 4], [4, 3], [7, 3]], dtype=np.int64), + np.array([7, 5, 6, 3, 9], dtype=np.int64), + np.array([9, 4], dtype=np.int64), + np.array([-1], dtype=np.int64), + ), + ( + np.array([[0], [5], [10], [20], [24]], dtype=np.int64), + np.array([7, 5, 6, 3, 9], dtype=np.int64), + np.array([25], dtype=np.int64), + np.array([5, 5], dtype=np.int64), + ), + ( + np.array([[0, 100], [200, 100], [300, 400], [50, 20], [400, 50]], dtype=np.int64), + np.array([7, 5, 6, 3, 9], dtype=np.int64), + np.array([500, 20], dtype=np.int64), + np.array([500, 20], dtype=np.int64), + ), + ( + np.array([[0, 100], [200, 100], [300, 400], [50, 20], [400, 50]], dtype=np.int64), + np.array([7, 5, 6, 3, 9], dtype=np.int64), + np.array([500, 20], dtype=np.int64), + np.array([500, -1], dtype=np.int64), + ), + ( + np.array([[0, 100], [200, 100], [300, 400], [50, 20], [400, 50]], dtype=np.int64), + np.array([7, 5, 6, 3, 9], dtype=np.int64), + np.array([500, 20], dtype=np.int64), + np.array([250, 40], dtype=np.int64), + ), + ], +) +@pytest.mark.parametrize("use_dyn", [True, False]) +def test_sparse_reshape(sparse_indices_np, sparse_values_np, prev_shape_np, new_shape_np, use_dyn): + def ref_sparse_reshape( + sparse_indices: np.ndarray, + prev_shape: np.ndarray, + new_shape: np.ndarray, + ): + """ + This function calculates the expected output of sparseshape operator given the inputs. 
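+        It mirrors the TOPI kernels: flatten each index with the row-major
+        strides of prev_shape, then expand the flat offset with the strides
+        of the inferred new shape.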
+ """ + + new_sparse_indices = np.ones( + (sparse_indices.shape[0], new_shape.shape[0]), dtype=sparse_indices.dtype + ) + multipliers = np.ones(prev_shape.shape[0]) + dividers = np.ones(new_shape.shape[0]) + total_ele = np.prod(prev_shape) + division_total_ele = 1 + for i in range(new_shape.shape[0]): + if new_shape[i] == -1: + continue + division_total_ele *= new_shape[i] + for i in range(prev_shape.shape[0] - 2, -1, -1): + multipliers[i] = prev_shape[i + 1] * multipliers[i + 1] + + for i in range(len(new_shape)): + if new_shape[i] == -1: + new_shape[i] = total_ele // division_total_ele + + if np.array_equal(prev_shape, new_shape): + return sparse_indices, prev_shape + + for i in range(new_shape.shape[0] - 2, -1, -1): + dividers[i] = new_shape[i + 1] * dividers[i + 1] + + for row_num, sparse_row in enumerate(sparse_indices): + flat_idx = 0 + if len(sparse_indices.shape) != 1: + for i, ele in enumerate(sparse_row): + flat_idx += sparse_row[i] * multipliers[i] + else: + flat_idx += sparse_row + if len(new_sparse_indices.shape) != 1: + for i in range(new_sparse_indices.shape[1]): + new_sparse_indices[row_num][i] = flat_idx // dividers[i] + flat_idx = flat_idx % dividers[i] + else: + new_sparse_indices[row_num] = flat_idx + + return new_sparse_indices, new_shape + + def verify_sparse_reshape( + sparse_indices_np: np.ndarray, + sparse_values_np: np.ndarray, + prev_shape_np: np.ndarray, + new_shape_np: np.ndarray, + ): + """ + This function verifies the relay output of sparse_reshape with its expected output. + """ + if use_dyn: + sparse_indices = relay.var( + "sparse_indices", + shape=[relay.Any(), relay.Any()], + dtype=str(sparse_indices_np.dtype), + ) + prev_shape = relay.var( + "prev_shape", + shape=[relay.Any()], + dtype=str(prev_shape_np.dtype), + ) + new_shape = relay.var( + "new_shape", + shape=[relay.Any()], + dtype=str(new_shape_np.dtype), + ) + else: + sparse_indices = relay.var( + "sparse_indices", + relay.TensorType(sparse_indices_np.shape, str(sparse_indices_np.dtype)), + ) + prev_shape = relay.var( + "prev_shape", relay.TensorType(prev_shape_np.shape, str(prev_shape_np.dtype)) + ) + new_shape = relay.var( + "new_shape", relay.TensorType(new_shape_np.shape, str(new_shape_np.dtype)) + ) + z = relay.op.sparse_reshape(sparse_indices, prev_shape, new_shape).astuple() + + func = relay.Function([sparse_indices, prev_shape, new_shape], z) + + ref_res = ref_sparse_reshape(sparse_indices_np, prev_shape_np, new_shape_np) + + verify_func( + func, + [sparse_indices_np, prev_shape_np, new_shape_np], + ref_res, + ) + + verify_sparse_reshape(sparse_indices_np, sparse_values_np, prev_shape_np, new_shape_np) + + +def verify_func(func, data, ref_res, target_ctx=tvm.testing.enabled_targets()): + assert isinstance(data, list) + for target, ctx in target_ctx: + for kind in ["vm"]: + mod = tvm.ir.IRModule.from_expr(func) + intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target) + op_res = intrp.evaluate()(*data) + if isinstance(op_res, tvm.runtime.container.ADT): + assert len(op_res) == len( + ref_res + ), "Outputs from TVM and Python implementation must be equal " + + for op_result, ref_result in zip(op_res, ref_res): + tvm.testing.assert_allclose(op_result.asnumpy(), ref_result, rtol=1e-5) + else: + tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5) + relay.backend.compile_engine.get().clear() + + @tvm.testing.uses_gpu def test_adv_index(): def verify_adv_index(data_shape, index_shapes): From a4f3d73a1c9dc3e04833b56c511a84a28e915e80 Mon Sep 17 00:00:00 2001 From: Ubuntu 
Date: Fri, 19 Feb 2021 06:28:11 +0000 Subject: [PATCH 02/25] Done --- python/tvm/topi/generic/search.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/topi/generic/search.py b/python/tvm/topi/generic/search.py index a7761da40a59..6746e6ad3979 100644 --- a/python/tvm/topi/generic/search.py +++ b/python/tvm/topi/generic/search.py @@ -73,4 +73,4 @@ def schedule_sparse_fill_empty_rows(outs): def schedule_sparse_reshape(outs): - return _default_schedule(outs, False) \ No newline at end of file + return _default_schedule(outs, False) From e818cedfbf5b63ce1a6bad68da399000acec70b4 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Fri, 19 Feb 2021 06:29:27 +0000 Subject: [PATCH 03/25] Format --- python/tvm/relay/op/_transform.py | 1 - src/relay/op/tensor/transform.cc | 7 ------- 2 files changed, 8 deletions(-) diff --git a/python/tvm/relay/op/_transform.py b/python/tvm/relay/op/_transform.py index 4ce9679c2881..a781bc62b373 100644 --- a/python/tvm/relay/op/_transform.py +++ b/python/tvm/relay/op/_transform.py @@ -65,7 +65,6 @@ _reg.register_injective_schedule("sparse_to_dense") _reg.register_injective_schedule("matrix_set_diag") _reg.register_injective_schedule("adv_index") -# _reg.register_injective_schedule("sparse_reshape") # concatenate diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index 1b8f467a5e49..86212419785d 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -1654,12 +1654,6 @@ bool SparseReshapeRel(const Array& types, int num_inputs, const Attrs& att return true; } -// Array SparseReshapeCompute(const Attrs& attrs, const Array& inputs, -// const Type& out_type) { -// ICHECK_EQ(inputs.size(), 3) << "SparseReshapeCompute expects 3 inputs but " << inputs.size() << "provided"; -// return {topi::SparseReshape(inputs[0], inputs[1], inputs[2])}; -// } - Expr MakeSparseReshape(Expr sparse_indices, Expr prev_shape, Expr new_shape) { static const Op& op = Op::Get("sparse_reshape"); return Call(op, {sparse_indices, prev_shape, new_shape}, Attrs(), {}); @@ -1677,7 +1671,6 @@ RELAY_REGISTER_OP("sparse_reshape") .add_type_rel("sparse_reshape", SparseReshapeRel) .set_attr("TOpPattern", kInjective) .set_support_level(3); - // .set_attr("FTVMCompute", SparseReshapeCompute); // meshgrid operator TVM_REGISTER_NODE_TYPE(MeshgridAttrs); From afc5d52e0c3490c4f994ebee2b9b1fe813afa3da Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Fri, 19 Feb 2021 08:03:08 +0000 Subject: [PATCH 04/25] Add empty tests --- .../frontend/tensorflow/test_forward.py | 35 ++++++++++++++++++- tests/python/relay/test_op_level3.py | 33 +++++++++++++++++ 2 files changed, 67 insertions(+), 1 deletion(-) diff --git a/tests/python/frontend/tensorflow/test_forward.py b/tests/python/frontend/tensorflow/test_forward.py index 3c19263af830..e56d47e2541a 100644 --- a/tests/python/frontend/tensorflow/test_forward.py +++ b/tests/python/frontend/tensorflow/test_forward.py @@ -1950,11 +1950,29 @@ def _test_sparse_reshape(indices_np, values_np, prev_shape_np, new_shape_np, use @pytest.mark.parametrize( "sparse_indices_np, sparse_values_np, prev_shape_np, new_shape_np", [ + ( + np.ones((0, 1), dtype=np.int64), + np.array([], dtype=np.int64), + np.array([4], dtype=np.int64), + np.array([2, -1], dtype=np.int64), + ), + ( + np.ones((0, 1), dtype=np.int64), + np.array([], dtype=np.int64), + np.array([4], dtype=np.int64), + np.array([2, 2], dtype=np.int64), + ), + ( + np.ones((0, 2), dtype=np.int64), + np.array([], dtype=np.int64), + np.array([3, 6], dtype=np.int64), 
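+            # -1 marks the one dimension tf.sparse.reshape infers from the others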
+ np.array([-1, 2], dtype=np.int64), + ), ( np.array([[0, 0, 0], [0, 0, 1], [0, 1, 0], [1, 0, 0], [1, 2, 3]], dtype=np.int64), np.array([7, 5, 6, 3, 9], dtype=np.int64), np.array([2, 3, 6], dtype=np.int64), - np.array([9, -1], dtype=np.int64), + np.array([-1, 9], dtype=np.int64), ), ( np.array( @@ -1965,6 +1983,21 @@ def _test_sparse_reshape(indices_np, values_np, prev_shape_np, new_shape_np, use np.array([2, 3, 6, 7], dtype=np.int64), np.array([9, -1, 7], dtype=np.int64), ), + ( + np.array( + [ + [0, 0, 0, 0, 0], + [0, 0, 1, 2, 3], + [0, 1, 0, 3, 5], + [1, 0, 0, 4, 6], + [1, 2, 3, 6, 8], + ], + dtype=np.int64, + ), + np.array([7, 5, 6, 3, 9], dtype=np.int64), + np.array([2, 3, 6, 7, 9], dtype=np.int64), + np.array([9, -1, 7], dtype=np.int64), + ), ( np.array([[0, 0], [0, 1], [3, 4], [4, 3], [7, 3]], dtype=np.int64), np.array([7, 5, 6, 3, 9], dtype=np.int64), diff --git a/tests/python/relay/test_op_level3.py b/tests/python/relay/test_op_level3.py index 09ad0fd3c15d..8ed19fb446f2 100644 --- a/tests/python/relay/test_op_level3.py +++ b/tests/python/relay/test_op_level3.py @@ -1315,6 +1315,24 @@ def verify_sparse_to_dense(sparse_indices, sparse_values, default_value, output_ @pytest.mark.parametrize( "sparse_indices_np, sparse_values_np, prev_shape_np, new_shape_np", [ + ( + np.ones((0, 1), dtype=np.int64), + np.array([], dtype=np.int64), + np.array([4], dtype=np.int64), + np.array([2, -1], dtype=np.int64), + ), + ( + np.ones((0, 1), dtype=np.int64), + np.array([], dtype=np.int64), + np.array([4], dtype=np.int64), + np.array([2, 2], dtype=np.int64), + ), + ( + np.ones((0, 2), dtype=np.int64), + np.array([], dtype=np.int64), + np.array([3, 6], dtype=np.int64), + np.array([-1, 2], dtype=np.int64), + ), ( np.array([[0, 0, 0], [0, 0, 1], [0, 1, 0], [1, 0, 0], [1, 2, 3]], dtype=np.int64), np.array([7, 5, 6, 3, 9], dtype=np.int64), @@ -1330,6 +1348,21 @@ def verify_sparse_to_dense(sparse_indices, sparse_values, default_value, output_ np.array([2, 3, 6, 7], dtype=np.int64), np.array([9, -1, 7], dtype=np.int64), ), + ( + np.array( + [ + [0, 0, 0, 0, 0], + [0, 0, 1, 2, 3], + [0, 1, 0, 3, 5], + [1, 0, 0, 4, 6], + [1, 2, 3, 6, 8], + ], + dtype=np.int64, + ), + np.array([7, 5, 6, 3, 9], dtype=np.int64), + np.array([2, 3, 6, 7, 9], dtype=np.int64), + np.array([9, -1, 7], dtype=np.int64), + ), ( np.array([[0, 0], [0, 1], [3, 4], [4, 3], [7, 3]], dtype=np.int64), np.array([7, 5, 6, 3, 9], dtype=np.int64), From fdc275e7913dddbf343245abb29ade5018bf0c26 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Fri, 19 Feb 2021 08:05:02 +0000 Subject: [PATCH 05/25] Formatting --- python/tvm/topi/sparse_reshape.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/tvm/topi/sparse_reshape.py b/python/tvm/topi/sparse_reshape.py index 9b432db1e9ef..6088b43b2bff 100644 --- a/python/tvm/topi/sparse_reshape.py +++ b/python/tvm/topi/sparse_reshape.py @@ -16,8 +16,8 @@ # under the License. 
# pylint: disable=invalid-name, too-many-arguments, too-many-nested-blocks """Scatter operator""" -from ..tir import decl_buffer, ir_builder, Cast, AssertStmt, StringImm, Evaluate -from ..te import extern, hybrid, div, floordiv, floormod +from ..tir import decl_buffer, ir_builder, Cast +from ..te import extern, div, floordiv, floormod def sparse_reshape( From bc584e5ab97962531a90409c2709b7612fa5f13a Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Fri, 19 Feb 2021 08:14:33 +0000 Subject: [PATCH 06/25] SanityCheck --- python/tvm/relay/op/_transform.py | 2 +- python/tvm/topi/sparse_reshape.py | 36 +++++++++++++++++++++++++++++++ 2 files changed, 37 insertions(+), 1 deletion(-) diff --git a/python/tvm/relay/op/_transform.py b/python/tvm/relay/op/_transform.py index a781bc62b373..c6419ebabe53 100644 --- a/python/tvm/relay/op/_transform.py +++ b/python/tvm/relay/op/_transform.py @@ -120,7 +120,7 @@ def compute_sparse_fill_empty_rows(attrs, inputs, output_type): def compute_reshape(attrs, inputs, output_type): """Compute definition of sparse_reshape""" - return topi.reshape( + return topi.sparse_reshape( inputs[0], inputs[1], inputs[2], diff --git a/python/tvm/topi/sparse_reshape.py b/python/tvm/topi/sparse_reshape.py index 6088b43b2bff..84b87a6cf134 100644 --- a/python/tvm/topi/sparse_reshape.py +++ b/python/tvm/topi/sparse_reshape.py @@ -27,6 +27,42 @@ def sparse_reshape( new_sparse_indices_shape, new_shape_shape, ): + """ + Reshape a Sparse Tensor + Parameters + ---------- + sparse_indices : relay.Expr + A 2-D tensor[N, n_dim] of integers containing location of sparse values, where N is the + number of sparse values and n_dim is the number of dimensions of the dense_shape + prev_shape : relay.Expr + A 1-D tensor containing the previous shape of the dense tensor + new_shape : relay.Expr + A 1-D tensor containing the new shape of the dense tensor + Returns + ------- + result: relay.Expr + Output tensor. + Examples + -------- + .. code-block:: python + sparse_indices = [[0, 0, 0], + [0, 0, 1], + [0, 1, 0], + [1, 0, 0], + [1, 2, 3]] + prev_shape = [2, 3, 4] + new_shape = [9, -1] + new_sparse_indices, new_shape = relay.sparse_reshape(sparse_indices, + prev_shape, + new_shape) + new_sparse_indices = [[0, 0], + [0, 1], + [1, 2], + [4, 2], + [8, 1]] + new_shape = [9, 4] + """ + def gen_ir( sparse_indices_ptr, prev_shape_ptr, From 8896133efeebde8fcb332c01a20ee26edecb1503 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Fri, 19 Feb 2021 08:21:05 +0000 Subject: [PATCH 07/25] formatting documentation --- src/relay/op/tensor/transform.cc | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index 86212419785d..8772bed02638 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -1665,9 +1665,15 @@ RELAY_REGISTER_OP("sparse_reshape") .describe(R"code(Return new sparse indices of the reshaped tensor )code" TVM_ADD_FILELINE) .set_num_inputs(3) - .add_argument("sparse_indices", "Tensor", "The first tensor") - .add_argument("prev_shape", "Tensor", "The second tensor") - .add_argument("new_shape", "Tensor", "The third tensor") + .add_argument("sparse_indices", "Tensor", + "A 2-D tensor of shape [N, ndims], which specifies the indices of the" + "elements in the sparse tensor that contain nonzero values. 
COO Format") + .add_argument("prev_shape", "Tensor", + "A 1-D tensor of shape [ndims], which specifies the previous dense shape of the" + "sparse tensor") + .add_argument("new_shape", "Tensor", + "A 1-D tensor of shape [ndims], which specifies the desired dense shape of the" + "sparse tensor") .add_type_rel("sparse_reshape", SparseReshapeRel) .set_attr("TOpPattern", kInjective) .set_support_level(3); From 96169224b1a07bde1bc6c6d944d932a372157099 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Fri, 19 Feb 2021 08:22:33 +0000 Subject: [PATCH 08/25] Documentation --- python/tvm/relay/op/transform.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/tvm/relay/op/transform.py b/python/tvm/relay/op/transform.py index f1d3963ccad7..bb5f7d6f1c1d 100644 --- a/python/tvm/relay/op/transform.py +++ b/python/tvm/relay/op/transform.py @@ -1412,7 +1412,8 @@ def sparse_fill_empty_rows(sparse_indices, sparse_values, dense_shape, default_v def sparse_reshape(sparse_indices, prev_shape, new_shape): """ - Reshape a Sparse Tensor + Reshape a Sparse Tensor. The sparse array is in COO format. + Parameters ---------- sparse_indices : relay.Expr From 6577d2a4640cd6d6c1ebc5e0d0efdc1beadbdcad Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Fri, 19 Feb 2021 08:43:58 +0000 Subject: [PATCH 09/25] Only Enable CPU --- tests/python/relay/test_op_level3.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/python/relay/test_op_level3.py b/tests/python/relay/test_op_level3.py index 8ed19fb446f2..99528a395bb9 100644 --- a/tests/python/relay/test_op_level3.py +++ b/tests/python/relay/test_op_level3.py @@ -1498,6 +1498,7 @@ def verify_sparse_reshape( func, [sparse_indices_np, prev_shape_np, new_shape_np], ref_res, + [("llvm", tvm.cpu())], ) verify_sparse_reshape(sparse_indices_np, sparse_values_np, prev_shape_np, new_shape_np) From f6d386d7408ff96aee0c74fafbfe9a56434a2ac1 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Fri, 19 Feb 2021 23:33:15 +0000 Subject: [PATCH 10/25] Add support for CUDA --- python/tvm/relay/op/strategy/cuda.py | 11 ++ python/tvm/topi/cuda/__init__.py | 1 + python/tvm/topi/cuda/sparse_reshape.py | 221 +++++++++++++++++++++++++ tests/python/relay/test_op_level3.py | 40 ++--- 4 files changed, 253 insertions(+), 20 deletions(-) create mode 100644 python/tvm/topi/cuda/sparse_reshape.py diff --git a/python/tvm/relay/op/strategy/cuda.py b/python/tvm/relay/op/strategy/cuda.py index cb4688c4889e..30fe7741890d 100644 --- a/python/tvm/relay/op/strategy/cuda.py +++ b/python/tvm/relay/op/strategy/cuda.py @@ -764,6 +764,17 @@ def sparse_dense_strategy_cuda(attrs, inputs, out_type, target): return strategy +@sparse_reshape_strategy.register(["cuda", "gpu"]) +def sparse_reshape_strategy_cuda(attrs, inputs, out_type, target): + strategy = _op.OpStrategy() + strategy.add_implementation( + wrap_compute_sparse_reshape(topi.cuda.sparse_reshape), + wrap_topi_schedule(topi.cuda.schedule_sparse_reshape), + name="sparse_reshape.cuda", + ) + return strategy + + @sparse_dense_padded_strategy.register(["cuda", "gpu"]) def sparse_dense_padded_strategy_cuda(attrs, inputs, out_type, target): """sparse dense cuda strategy""" diff --git a/python/tvm/topi/cuda/__init__.py b/python/tvm/topi/cuda/__init__.py index bf3582c01d4f..a2e11e8c9f47 100644 --- a/python/tvm/topi/cuda/__init__.py +++ b/python/tvm/topi/cuda/__init__.py @@ -58,3 +58,4 @@ from . 
import tensorcore_alter_op from .argwhere import * from .scan import * +from .sparse_reshape import * diff --git a/python/tvm/topi/cuda/sparse_reshape.py b/python/tvm/topi/cuda/sparse_reshape.py new file mode 100644 index 000000000000..92827a0f85ce --- /dev/null +++ b/python/tvm/topi/cuda/sparse_reshape.py @@ -0,0 +1,221 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name, too-many-arguments, too-many-nested-blocks +"""Scatter operator""" +from ...tir import decl_buffer, ir_builder, Cast +from tvm import te +from ...te import extern, div, floordiv, floormod + + +def sparse_reshape( + sparse_indices, + prev_shape, + new_shape, + new_sparse_indices_shape, + new_shape_shape, +): + """ + Reshape a Sparse Tensor + Parameters + ---------- + sparse_indices : relay.Expr + A 2-D tensor[N, n_dim] of integers containing location of sparse values, where N is the + number of sparse values and n_dim is the number of dimensions of the dense_shape + prev_shape : relay.Expr + A 1-D tensor containing the previous shape of the dense tensor + new_shape : relay.Expr + A 1-D tensor containing the new shape of the dense tensor + Returns + ------- + result: relay.Expr + Output tensor. + Examples + -------- + .. 
code-block:: python + sparse_indices = [[0, 0, 0], + [0, 0, 1], + [0, 1, 0], + [1, 0, 0], + [1, 2, 3]] + prev_shape = [2, 3, 4] + new_shape = [9, -1] + new_sparse_indices, new_shape = relay.sparse_reshape(sparse_indices, + prev_shape, + new_shape) + new_sparse_indices = [[0, 0], + [0, 1], + [1, 2], + [4, 2], + [8, 1]] + new_shape = [9, 4] + """ + + def gen_ir( + sparse_indices_ptr, + prev_shape_ptr, + new_shape_ptr, + new_sparse_indices_ptr, + out_new_shape_ptr, + ): + ib = ir_builder.create() + + sparse_indices = ib.buffer_ptr(sparse_indices_ptr) + prev_shape = ib.buffer_ptr(prev_shape_ptr) + + new_shape = ib.buffer_ptr(new_shape_ptr) + out_new_shape = ib.buffer_ptr(out_new_shape_ptr) + new_sparse_indices = ib.buffer_ptr(new_sparse_indices_ptr) + out_new_shape = ib.buffer_ptr(out_new_shape_ptr) + + prev_shape_size = prev_shape_ptr.shape[0] + new_shape_size = new_shape_ptr.shape[0] + + multipliers = ib.allocate("int64", (prev_shape_size,), name="multipliers", scope="local") + dividers = ib.allocate("int64", (new_shape_size,), name="dividers", scope="local") + flattened_indices = ib.allocate( + "int64", (sparse_indices_ptr.shape[0],), name="flattened_indices", scope="local" + ) + + with ib.new_scope(): + + nthread_tx = 1 + nthread_bx = 1 + tx = te.thread_axis("threadIdx.x") + bx = te.thread_axis("blockIdx.x") + ib.scope_attr(tx, "thread_extent", nthread_tx) + ib.scope_attr(bx, "thread_extent", nthread_bx) + + total_ele = ib.allocate("int64", (1,), name="total_ele", scope="local") + total_ele[0] = prev_shape[0] + + # Cumulative Reverse Exclusive Multiply + multipliers[prev_shape_size - 1] = Cast("int64", 1) + with ib.for_range(0, prev_shape_size - 1) as i_: + i = i_ + 1 + multipliers[prev_shape_size - 1 - i] = ( + prev_shape[prev_shape_size - i] * multipliers[prev_shape_size - i] + ) + total_ele[0] *= prev_shape[prev_shape_size - i] + + division_total_ele = ib.allocate( + "int64", (1,), name="division_total_ele", scope="local" + ) + division_total_ele[0] = Cast("int64", 1) + with ib.for_range(0, new_shape_size) as i: + with ib.if_scope(new_shape[i] != -1): + division_total_ele[0] *= new_shape[i] + + # Compute true output shape (replace negative ones) + with ib.for_range(0, new_shape_size) as i: + with ib.if_scope(new_shape[i] == -1): + # if Cast("int64", new_shape[i]) == Cast("int64", -1): + out_new_shape[i] = Cast("int64", div(total_ele[0], division_total_ele[0])) + with ib.else_scope(): + out_new_shape[i] = new_shape[i] + + equal_shape = ib.allocate("bool", (1,), name="equal_shape", scope="local") + + # Check if prev_shape and new_shape are equal + equal_shape[0] = True + with ib.if_scope(prev_shape_size == new_shape_size): + with ib.for_range(0, prev_shape_size) as i: + with ib.if_scope(prev_shape[i] != out_new_shape[i]): + equal_shape[0] = False + with ib.else_scope(): + equal_shape[0] = False + + # Return same inputs if shapes are equal + with ib.if_scope(equal_shape[0]): + with ib.for_range(0, sparse_indices_ptr.shape[0]) as i: + with ib.for_range(0, sparse_indices_ptr.shape[1]) as j: + new_sparse_indices[i, j] = sparse_indices[i, j] + + # Else compute new_sparse_indices + with ib.else_scope(): + dividers[new_shape_size - 1] = Cast("int64", 1) + with ib.for_range(0, new_shape_size - 1) as i_: + i = i_ + 1 + dividers[new_shape_size - 1 - i] = ( + dividers[new_shape_size - i] * out_new_shape[new_shape_size - i] + ) + + with ib.for_range(0, sparse_indices_ptr.shape[0]) as i: + flattened_indices[i] = Cast("int64", 0) + with ib.for_range(0, sparse_indices_ptr.shape[1]) as j: + 
flattened_indices[i] += sparse_indices[i, j] * multipliers[j] + + with ib.for_range(0, new_sparse_indices_ptr.shape[0]) as i: + current_element = ib.allocate( + "int64", (1,), name="current_element", scope="local" + ) + current_element[0] = flattened_indices[i] + + with ib.for_range(0, new_sparse_indices_ptr.shape[1]) as j: + new_sparse_indices[i, j] = Cast( + "int64", floordiv(current_element[0], dividers[j]) + ) + current_element[0] = floormod(current_element[0], dividers[j]) + + return ib.get() + + new_sparse_indices_buf = decl_buffer( + new_sparse_indices_shape, "int64", "new_sparse_indices_buf" + ) + new_shape_buf = decl_buffer(new_shape_shape, "int64", "new_shape_buf") + + return extern( + [new_sparse_indices_shape, new_shape_shape], + [sparse_indices, prev_shape, new_shape], + lambda ins, outs: gen_ir(ins[0], ins[1], ins[2], outs[0], outs[1]), + dtype="int64", + out_buffers=[new_sparse_indices_buf, new_shape_buf], + name="sparse_reshape", + tag="sparse_reshape", + ) + + +def _default_schedule(outs): + """Default schedule for gpu.""" + outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs + s = te.create_schedule([x.op for x in outs]) + scheduled_ops = [] + + def traverse(op): + for tensor in op.input_tensors: + if tensor.op.input_tensors and tensor.op not in scheduled_ops: + traverse(tensor.op) + scheduled_ops.append(op) + + traverse(outs[0].op) + return s + + +def schedule_sparse_reshape(outs): + """Schedule for Sparse Reshape + + Parameters + ---------- + outs: Array of Tensor + The computation graph description of nms + in the format of an array of tensors. + + Returns + ------- + s: Schedule + The computation schedule for the op. + """ + return _default_schedule(outs) \ No newline at end of file diff --git a/tests/python/relay/test_op_level3.py b/tests/python/relay/test_op_level3.py index 99528a395bb9..63d14cbd77fc 100644 --- a/tests/python/relay/test_op_level3.py +++ b/tests/python/relay/test_op_level3.py @@ -1315,24 +1315,6 @@ def verify_sparse_to_dense(sparse_indices, sparse_values, default_value, output_ @pytest.mark.parametrize( "sparse_indices_np, sparse_values_np, prev_shape_np, new_shape_np", [ - ( - np.ones((0, 1), dtype=np.int64), - np.array([], dtype=np.int64), - np.array([4], dtype=np.int64), - np.array([2, -1], dtype=np.int64), - ), - ( - np.ones((0, 1), dtype=np.int64), - np.array([], dtype=np.int64), - np.array([4], dtype=np.int64), - np.array([2, 2], dtype=np.int64), - ), - ( - np.ones((0, 2), dtype=np.int64), - np.array([], dtype=np.int64), - np.array([3, 6], dtype=np.int64), - np.array([-1, 2], dtype=np.int64), - ), ( np.array([[0, 0, 0], [0, 0, 1], [0, 1, 0], [1, 0, 0], [1, 2, 3]], dtype=np.int64), np.array([7, 5, 6, 3, 9], dtype=np.int64), @@ -1399,9 +1381,27 @@ def verify_sparse_to_dense(sparse_indices, sparse_values, default_value, output_ np.array([500, 20], dtype=np.int64), np.array([250, 40], dtype=np.int64), ), + ( + np.ones((0, 1), dtype=np.int64), + np.array([], dtype=np.int64), + np.array([4], dtype=np.int64), + np.array([2, -1], dtype=np.int64), + ), + ( + np.ones((0, 1), dtype=np.int64), + np.array([], dtype=np.int64), + np.array([4], dtype=np.int64), + np.array([2, 2], dtype=np.int64), + ), + ( + np.ones((0, 2), dtype=np.int64), + np.array([], dtype=np.int64), + np.array([3, 6], dtype=np.int64), + np.array([-1, 2], dtype=np.int64), + ), ], ) -@pytest.mark.parametrize("use_dyn", [True, False]) +@pytest.mark.parametrize("use_dyn", [True]) def test_sparse_reshape(sparse_indices_np, sparse_values_np, prev_shape_np, new_shape_np, use_dyn): def 
ref_sparse_reshape( sparse_indices: np.ndarray, @@ -1498,7 +1498,7 @@ def verify_sparse_reshape( func, [sparse_indices_np, prev_shape_np, new_shape_np], ref_res, - [("llvm", tvm.cpu())], + [("cuda", tvm.gpu(0))], ) verify_sparse_reshape(sparse_indices_np, sparse_values_np, prev_shape_np, new_shape_np) From eadacdbbaf7eb1d4bafaa2fab878c970cd8e272a Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Fri, 19 Feb 2021 23:34:34 +0000 Subject: [PATCH 11/25] Stuff --- python/tvm/topi/cuda/sparse_reshape.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/topi/cuda/sparse_reshape.py b/python/tvm/topi/cuda/sparse_reshape.py index 92827a0f85ce..ce2fe7da12d2 100644 --- a/python/tvm/topi/cuda/sparse_reshape.py +++ b/python/tvm/topi/cuda/sparse_reshape.py @@ -218,4 +218,4 @@ def schedule_sparse_reshape(outs): s: Schedule The computation schedule for the op. """ - return _default_schedule(outs) \ No newline at end of file + return _default_schedule(outs) From e08b319624c9348a3fff9a7790ecb961852afe40 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Fri, 19 Feb 2021 23:34:56 +0000 Subject: [PATCH 12/25] Add Dynamic Support --- tests/python/relay/test_op_level3.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python/relay/test_op_level3.py b/tests/python/relay/test_op_level3.py index 63d14cbd77fc..347e9e12bd9e 100644 --- a/tests/python/relay/test_op_level3.py +++ b/tests/python/relay/test_op_level3.py @@ -1401,7 +1401,7 @@ def verify_sparse_to_dense(sparse_indices, sparse_values, default_value, output_ ), ], ) -@pytest.mark.parametrize("use_dyn", [True]) +@pytest.mark.parametrize("use_dyn", [True, False]) def test_sparse_reshape(sparse_indices_np, sparse_values_np, prev_shape_np, new_shape_np, use_dyn): def ref_sparse_reshape( sparse_indices: np.ndarray, From 09f70c2f77b03c59d9638b0088166b48aab28ad9 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Sat, 20 Feb 2021 00:16:26 +0000 Subject: [PATCH 13/25] Parallelize GPU Impl --- python/tvm/topi/cuda/sparse_reshape.py | 49 ++++++++++++++++---------- python/tvm/topi/sparse_reshape.py | 5 ++- tests/python/relay/test_op_level3.py | 1 - 3 files changed, 32 insertions(+), 23 deletions(-) diff --git a/python/tvm/topi/cuda/sparse_reshape.py b/python/tvm/topi/cuda/sparse_reshape.py index ce2fe7da12d2..3c6b57d7fcd5 100644 --- a/python/tvm/topi/cuda/sparse_reshape.py +++ b/python/tvm/topi/cuda/sparse_reshape.py @@ -15,10 +15,12 @@ # specific language governing permissions and limitations # under the License. 
# pylint: disable=invalid-name, too-many-arguments, too-many-nested-blocks -"""Scatter operator""" +"""Sparse_Reshape operator""" +import tvm from ...tir import decl_buffer, ir_builder, Cast from tvm import te from ...te import extern, div, floordiv, floormod +from ..utils import ceil_div def sparse_reshape( @@ -89,7 +91,10 @@ def gen_ir( flattened_indices = ib.allocate( "int64", (sparse_indices_ptr.shape[0],), name="flattened_indices", scope="local" ) - + total_ele = ib.allocate("int64", (1,), name="total_ele", scope="local") + division_total_ele = ib.allocate("int64", (1,), name="division_total_ele", scope="local") + equal_shape = ib.allocate("bool", (1,), name="equal_shape", scope="local") + max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads) with ib.new_scope(): nthread_tx = 1 @@ -99,7 +104,6 @@ def gen_ir( ib.scope_attr(tx, "thread_extent", nthread_tx) ib.scope_attr(bx, "thread_extent", nthread_bx) - total_ele = ib.allocate("int64", (1,), name="total_ele", scope="local") total_ele[0] = prev_shape[0] # Cumulative Reverse Exclusive Multiply @@ -111,9 +115,6 @@ def gen_ir( ) total_ele[0] *= prev_shape[prev_shape_size - i] - division_total_ele = ib.allocate( - "int64", (1,), name="division_total_ele", scope="local" - ) division_total_ele[0] = Cast("int64", 1) with ib.for_range(0, new_shape_size) as i: with ib.if_scope(new_shape[i] != -1): @@ -122,13 +123,10 @@ def gen_ir( # Compute true output shape (replace negative ones) with ib.for_range(0, new_shape_size) as i: with ib.if_scope(new_shape[i] == -1): - # if Cast("int64", new_shape[i]) == Cast("int64", -1): out_new_shape[i] = Cast("int64", div(total_ele[0], division_total_ele[0])) with ib.else_scope(): out_new_shape[i] = new_shape[i] - equal_shape = ib.allocate("bool", (1,), name="equal_shape", scope="local") - # Check if prev_shape and new_shape are equal equal_shape[0] = True with ib.if_scope(prev_shape_size == new_shape_size): @@ -139,10 +137,21 @@ def gen_ir( equal_shape[0] = False # Return same inputs if shapes are equal + with ib.new_scope(): + + nthread_tx = max_threads + nthread_bx = ceil_div(sparse_indices_ptr.shape[0], max_threads) + tx = te.thread_axis("threadIdx.x") + bx = te.thread_axis("blockIdx.x") + ib.scope_attr(tx, "thread_extent", nthread_tx) + ib.scope_attr(bx, "thread_extent", nthread_bx) + + row_number = bx * max_threads + tx + with ib.if_scope(equal_shape[0]): - with ib.for_range(0, sparse_indices_ptr.shape[0]) as i: + with ib.if_scope(row_number < sparse_indices_ptr.shape[0]): with ib.for_range(0, sparse_indices_ptr.shape[1]) as j: - new_sparse_indices[i, j] = sparse_indices[i, j] + new_sparse_indices[row_number, j] = sparse_indices[row_number, j] # Else compute new_sparse_indices with ib.else_scope(): @@ -153,19 +162,21 @@ def gen_ir( dividers[new_shape_size - i] * out_new_shape[new_shape_size - i] ) - with ib.for_range(0, sparse_indices_ptr.shape[0]) as i: - flattened_indices[i] = Cast("int64", 0) + with ib.if_scope(row_number < sparse_indices_ptr.shape[0]): + flattened_indices[row_number] = Cast("int64", 0) with ib.for_range(0, sparse_indices_ptr.shape[1]) as j: - flattened_indices[i] += sparse_indices[i, j] * multipliers[j] + flattened_indices[row_number] += ( + sparse_indices[row_number, j] * multipliers[j] + ) - with ib.for_range(0, new_sparse_indices_ptr.shape[0]) as i: + with ib.if_scope(row_number < sparse_indices_ptr.shape[0]): current_element = ib.allocate( "int64", (1,), name="current_element", scope="local" ) - current_element[0] = flattened_indices[i] + current_element[0] = 
flattened_indices[row_number] with ib.for_range(0, new_sparse_indices_ptr.shape[1]) as j: - new_sparse_indices[i, j] = Cast( + new_sparse_indices[row_number, j] = Cast( "int64", floordiv(current_element[0], dividers[j]) ) current_element[0] = floormod(current_element[0], dividers[j]) @@ -183,8 +194,8 @@ def gen_ir( lambda ins, outs: gen_ir(ins[0], ins[1], ins[2], outs[0], outs[1]), dtype="int64", out_buffers=[new_sparse_indices_buf, new_shape_buf], - name="sparse_reshape", - tag="sparse_reshape", + name="sparse_reshape_cuda", + tag="sparse_reshape_cuda", ) diff --git a/python/tvm/topi/sparse_reshape.py b/python/tvm/topi/sparse_reshape.py index 84b87a6cf134..26d31fe31c52 100644 --- a/python/tvm/topi/sparse_reshape.py +++ b/python/tvm/topi/sparse_reshape.py @@ -110,7 +110,6 @@ def gen_ir( # Compute true output shape (replace negative ones) with ib.for_range(0, new_shape_size) as i: with ib.if_scope(new_shape[i] == -1): - # if Cast("int64", new_shape[i]) == Cast("int64", -1): out_new_shape[i] = Cast("int64", div(total_ele[0], division_total_ele[0])) with ib.else_scope(): out_new_shape[i] = new_shape[i] @@ -169,6 +168,6 @@ def gen_ir( lambda ins, outs: gen_ir(ins[0], ins[1], ins[2], outs[0], outs[1]), dtype="int64", out_buffers=[new_sparse_indices_buf, new_shape_buf], - name="sparse_reshape", - tag="sparse_reshape", + name="sparse_reshape_cpu", + tag="sparse_reshape_cpu", ) diff --git a/tests/python/relay/test_op_level3.py b/tests/python/relay/test_op_level3.py index 347e9e12bd9e..78379f32633b 100644 --- a/tests/python/relay/test_op_level3.py +++ b/tests/python/relay/test_op_level3.py @@ -1498,7 +1498,6 @@ def verify_sparse_reshape( func, [sparse_indices_np, prev_shape_np, new_shape_np], ref_res, - [("cuda", tvm.gpu(0))], ) verify_sparse_reshape(sparse_indices_np, sparse_values_np, prev_shape_np, new_shape_np) From 84ed9666afd9959545516eb890fa380f1c57ca6c Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Sat, 20 Feb 2021 00:17:44 +0000 Subject: [PATCH 14/25] Documentation --- python/tvm/topi/sparse_reshape.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/topi/sparse_reshape.py b/python/tvm/topi/sparse_reshape.py index 26d31fe31c52..671a03b3b499 100644 --- a/python/tvm/topi/sparse_reshape.py +++ b/python/tvm/topi/sparse_reshape.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. 
# pylint: disable=invalid-name, too-many-arguments, too-many-nested-blocks -"""Scatter operator""" +"""Sparse_Reshape operator""" from ..tir import decl_buffer, ir_builder, Cast from ..te import extern, div, floordiv, floormod From df488bd0523c01dd54c72363b628e9fb76606d6c Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Sat, 20 Feb 2021 00:20:19 +0000 Subject: [PATCH 15/25] Documentation --- python/tvm/topi/cuda/sparse_reshape.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/tvm/topi/cuda/sparse_reshape.py b/python/tvm/topi/cuda/sparse_reshape.py index 3c6b57d7fcd5..98009d130b0f 100644 --- a/python/tvm/topi/cuda/sparse_reshape.py +++ b/python/tvm/topi/cuda/sparse_reshape.py @@ -96,7 +96,8 @@ def gen_ir( equal_shape = ib.allocate("bool", (1,), name="equal_shape", scope="local") max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads) with ib.new_scope(): - + # The computation in this block is very very miniscule since we are just iterating over + # shape tensors which are very small (< 10) and there is no need of parallelization nthread_tx = 1 nthread_bx = 1 tx = te.thread_axis("threadIdx.x") From 4f9ebd9d5c29ddd3c628d704c6221af9de9d6b6b Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Sat, 20 Feb 2021 00:55:48 +0000 Subject: [PATCH 16/25] Import --- python/tvm/topi/cuda/sparse_reshape.py | 77 +++++++++++++++++++++----- 1 file changed, 64 insertions(+), 13 deletions(-) diff --git a/python/tvm/topi/cuda/sparse_reshape.py b/python/tvm/topi/cuda/sparse_reshape.py index 98009d130b0f..5904b66d017e 100644 --- a/python/tvm/topi/cuda/sparse_reshape.py +++ b/python/tvm/topi/cuda/sparse_reshape.py @@ -17,8 +17,8 @@ # pylint: disable=invalid-name, too-many-arguments, too-many-nested-blocks """Sparse_Reshape operator""" import tvm -from ...tir import decl_buffer, ir_builder, Cast from tvm import te +from ...tir import decl_buffer, ir_builder, Cast from ...te import extern, div, floordiv, floormod from ..utils import ceil_div @@ -138,24 +138,32 @@ def gen_ir( equal_shape[0] = False # Return same inputs if shapes are equal - with ib.new_scope(): - - nthread_tx = max_threads - nthread_bx = ceil_div(sparse_indices_ptr.shape[0], max_threads) - tx = te.thread_axis("threadIdx.x") - bx = te.thread_axis("blockIdx.x") - ib.scope_attr(tx, "thread_extent", nthread_tx) - ib.scope_attr(bx, "thread_extent", nthread_bx) + with ib.if_scope(equal_shape[0]): + with ib.new_scope(): + nthread_tx = max_threads + nthread_bx = ceil_div(sparse_indices_ptr.shape[0], max_threads) + tx = te.thread_axis("threadIdx.x") + bx = te.thread_axis("blockIdx.x") + ib.scope_attr(tx, "thread_extent", nthread_tx) + ib.scope_attr(bx, "thread_extent", nthread_bx) - row_number = bx * max_threads + tx - - with ib.if_scope(equal_shape[0]): + row_number = bx * max_threads + tx with ib.if_scope(row_number < sparse_indices_ptr.shape[0]): with ib.for_range(0, sparse_indices_ptr.shape[1]) as j: new_sparse_indices[row_number, j] = sparse_indices[row_number, j] # Else compute new_sparse_indices - with ib.else_scope(): + with ib.else_scope(): + with ib.new_scope(): + + nthread_tx = max_threads + nthread_bx = ceil_div(sparse_indices_ptr.shape[0], max_threads) + tx = te.thread_axis("threadIdx.x") + bx = te.thread_axis("blockIdx.x") + ib.scope_attr(tx, "thread_extent", nthread_tx) + ib.scope_attr(bx, "thread_extent", nthread_bx) + + row_number = bx * max_threads + tx dividers[new_shape_size - 1] = Cast("int64", 1) with ib.for_range(0, new_shape_size - 1) as i_: i = i_ + 1 @@ -181,6 +189,49 @@ def gen_ir( "int64", 
floordiv(current_element[0], dividers[j]) ) current_element[0] = floormod(current_element[0], dividers[j]) + # with ib.new_scope(): + + # nthread_tx = max_threads + # nthread_bx = ceil_div(sparse_indices_ptr.shape[0], max_threads) + # tx = te.thread_axis("threadIdx.x") + # bx = te.thread_axis("blockIdx.x") + # ib.scope_attr(tx, "thread_extent", nthread_tx) + # ib.scope_attr(bx, "thread_extent", nthread_bx) + + # row_number = bx * max_threads + tx + + # with ib.if_scope(equal_shape[0]): + # with ib.if_scope(row_number < sparse_indices_ptr.shape[0]): + # with ib.for_range(0, sparse_indices_ptr.shape[1]) as j: + # new_sparse_indices[row_number, j] = sparse_indices[row_number, j] + + # # Else compute new_sparse_indices + # with ib.else_scope(): + # dividers[new_shape_size - 1] = Cast("int64", 1) + # with ib.for_range(0, new_shape_size - 1) as i_: + # i = i_ + 1 + # dividers[new_shape_size - 1 - i] = ( + # dividers[new_shape_size - i] * out_new_shape[new_shape_size - i] + # ) + + # with ib.if_scope(row_number < sparse_indices_ptr.shape[0]): + # flattened_indices[row_number] = Cast("int64", 0) + # with ib.for_range(0, sparse_indices_ptr.shape[1]) as j: + # flattened_indices[row_number] += ( + # sparse_indices[row_number, j] * multipliers[j] + # ) + + # with ib.if_scope(row_number < sparse_indices_ptr.shape[0]): + # current_element = ib.allocate( + # "int64", (1,), name="current_element", scope="local" + # ) + # current_element[0] = flattened_indices[row_number] + + # with ib.for_range(0, new_sparse_indices_ptr.shape[1]) as j: + # new_sparse_indices[row_number, j] = Cast( + # "int64", floordiv(current_element[0], dividers[j]) + # ) + # current_element[0] = floormod(current_element[0], dividers[j]) return ib.get() From ccdb67ce31a397d5118a17edd97f9d4ff3d35e60 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Sat, 20 Feb 2021 00:56:43 +0000 Subject: [PATCH 17/25] Import --- python/tvm/topi/cuda/sparse_reshape.py | 70 ++++++-------------------- 1 file changed, 16 insertions(+), 54 deletions(-) diff --git a/python/tvm/topi/cuda/sparse_reshape.py b/python/tvm/topi/cuda/sparse_reshape.py index 5904b66d017e..33b6b886ad58 100644 --- a/python/tvm/topi/cuda/sparse_reshape.py +++ b/python/tvm/topi/cuda/sparse_reshape.py @@ -137,7 +137,6 @@ def gen_ir( with ib.else_scope(): equal_shape[0] = False - # Return same inputs if shapes are equal with ib.if_scope(equal_shape[0]): with ib.new_scope(): nthread_tx = max_threads @@ -152,18 +151,24 @@ def gen_ir( with ib.for_range(0, sparse_indices_ptr.shape[1]) as j: new_sparse_indices[row_number, j] = sparse_indices[row_number, j] - # Else compute new_sparse_indices - with ib.else_scope(): - with ib.new_scope(): + with ib.new_scope(): + nthread_tx = max_threads + nthread_bx = ceil_div(sparse_indices_ptr.shape[0], max_threads) + tx = te.thread_axis("threadIdx.x") + bx = te.thread_axis("blockIdx.x") + ib.scope_attr(tx, "thread_extent", nthread_tx) + ib.scope_attr(bx, "thread_extent", nthread_bx) - nthread_tx = max_threads - nthread_bx = ceil_div(sparse_indices_ptr.shape[0], max_threads) - tx = te.thread_axis("threadIdx.x") - bx = te.thread_axis("blockIdx.x") - ib.scope_attr(tx, "thread_extent", nthread_tx) - ib.scope_attr(bx, "thread_extent", nthread_bx) + row_number = bx * max_threads + tx - row_number = bx * max_threads + tx + # Return same inputs if shapes are equal + with ib.if_scope(equal_shape[0]): + with ib.if_scope(row_number < sparse_indices_ptr.shape[0]): + with ib.for_range(0, sparse_indices_ptr.shape[1]) as j: + new_sparse_indices[row_number, j] = 
sparse_indices[row_number, j] + + # Else compute new_sparse_indices + with ib.else_scope(): dividers[new_shape_size - 1] = Cast("int64", 1) with ib.for_range(0, new_shape_size - 1) as i_: i = i_ + 1 @@ -189,49 +194,6 @@ def gen_ir( "int64", floordiv(current_element[0], dividers[j]) ) current_element[0] = floormod(current_element[0], dividers[j]) - # with ib.new_scope(): - - # nthread_tx = max_threads - # nthread_bx = ceil_div(sparse_indices_ptr.shape[0], max_threads) - # tx = te.thread_axis("threadIdx.x") - # bx = te.thread_axis("blockIdx.x") - # ib.scope_attr(tx, "thread_extent", nthread_tx) - # ib.scope_attr(bx, "thread_extent", nthread_bx) - - # row_number = bx * max_threads + tx - - # with ib.if_scope(equal_shape[0]): - # with ib.if_scope(row_number < sparse_indices_ptr.shape[0]): - # with ib.for_range(0, sparse_indices_ptr.shape[1]) as j: - # new_sparse_indices[row_number, j] = sparse_indices[row_number, j] - - # # Else compute new_sparse_indices - # with ib.else_scope(): - # dividers[new_shape_size - 1] = Cast("int64", 1) - # with ib.for_range(0, new_shape_size - 1) as i_: - # i = i_ + 1 - # dividers[new_shape_size - 1 - i] = ( - # dividers[new_shape_size - i] * out_new_shape[new_shape_size - i] - # ) - - # with ib.if_scope(row_number < sparse_indices_ptr.shape[0]): - # flattened_indices[row_number] = Cast("int64", 0) - # with ib.for_range(0, sparse_indices_ptr.shape[1]) as j: - # flattened_indices[row_number] += ( - # sparse_indices[row_number, j] * multipliers[j] - # ) - - # with ib.if_scope(row_number < sparse_indices_ptr.shape[0]): - # current_element = ib.allocate( - # "int64", (1,), name="current_element", scope="local" - # ) - # current_element[0] = flattened_indices[row_number] - - # with ib.for_range(0, new_sparse_indices_ptr.shape[1]) as j: - # new_sparse_indices[row_number, j] = Cast( - # "int64", floordiv(current_element[0], dividers[j]) - # ) - # current_element[0] = floormod(current_element[0], dividers[j]) return ib.get() From 0f8310df828ba93fb4a9eec6a7ee812b4a18a508 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Sat, 20 Feb 2021 00:59:56 +0000 Subject: [PATCH 18/25] Remove unnecessary code --- python/tvm/topi/cuda/sparse_reshape.py | 14 -------------- 1 file changed, 14 deletions(-) diff --git a/python/tvm/topi/cuda/sparse_reshape.py b/python/tvm/topi/cuda/sparse_reshape.py index 33b6b886ad58..8947bb3dbeb3 100644 --- a/python/tvm/topi/cuda/sparse_reshape.py +++ b/python/tvm/topi/cuda/sparse_reshape.py @@ -137,20 +137,6 @@ def gen_ir( with ib.else_scope(): equal_shape[0] = False - with ib.if_scope(equal_shape[0]): - with ib.new_scope(): - nthread_tx = max_threads - nthread_bx = ceil_div(sparse_indices_ptr.shape[0], max_threads) - tx = te.thread_axis("threadIdx.x") - bx = te.thread_axis("blockIdx.x") - ib.scope_attr(tx, "thread_extent", nthread_tx) - ib.scope_attr(bx, "thread_extent", nthread_bx) - - row_number = bx * max_threads + tx - with ib.if_scope(row_number < sparse_indices_ptr.shape[0]): - with ib.for_range(0, sparse_indices_ptr.shape[1]) as j: - new_sparse_indices[row_number, j] = sparse_indices[row_number, j] - with ib.new_scope(): nthread_tx = max_threads nthread_bx = ceil_div(sparse_indices_ptr.shape[0], max_threads) From 87bbec07d8aeab30e17614dbc1c0ec2cb3b684b7 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Thu, 25 Feb 2021 00:02:36 +0000 Subject: [PATCH 19/25] PR Comments --- python/tvm/topi/cuda/sparse_reshape.py | 42 ++++++++++++++++---------- python/tvm/topi/sparse_reshape.py | 42 +++++++++++++++++--------- tests/python/relay/test_op_level3.py | 19 
++++++++++++++++--- 3 files changed, 70 insertions(+), 33 deletions(-) diff --git a/python/tvm/topi/cuda/sparse_reshape.py b/python/tvm/topi/cuda/sparse_reshape.py index 8947bb3dbeb3..bbfc85daacb9 100644 --- a/python/tvm/topi/cuda/sparse_reshape.py +++ b/python/tvm/topi/cuda/sparse_reshape.py @@ -86,14 +86,23 @@ def gen_ir( prev_shape_size = prev_shape_ptr.shape[0] new_shape_size = new_shape_ptr.shape[0] - multipliers = ib.allocate("int64", (prev_shape_size,), name="multipliers", scope="local") - dividers = ib.allocate("int64", (new_shape_size,), name="dividers", scope="local") + multipliers = ib.allocate( + new_shape_ptr.dtype, (prev_shape_size,), name="multipliers", scope="global" + ) + dividers = ib.allocate( + new_shape_ptr.dtype, (new_shape_size,), name="dividers", scope="global" + ) flattened_indices = ib.allocate( - "int64", (sparse_indices_ptr.shape[0],), name="flattened_indices", scope="local" + new_shape_ptr.dtype, + (sparse_indices_ptr.shape[0],), + name="flattened_indices", + scope="global", ) - total_ele = ib.allocate("int64", (1,), name="total_ele", scope="local") - division_total_ele = ib.allocate("int64", (1,), name="division_total_ele", scope="local") - equal_shape = ib.allocate("bool", (1,), name="equal_shape", scope="local") + total_ele = ib.allocate(new_shape_ptr.dtype, (1,), name="total_ele", scope="global") + division_total_ele = ib.allocate( + new_shape_ptr.dtype, (1,), name="division_total_ele", scope="global" + ) + equal_shape = ib.allocate("bool", (1,), name="equal_shape", scope="global") max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads) with ib.new_scope(): # The computation in this block is minuscule: we are just iterating over @@ -108,7 +117,7 @@ def gen_ir( total_ele[0] = prev_shape[0] # Cumulative Reverse Exclusive Multiply - multipliers[prev_shape_size - 1] = Cast("int64", 1) + multipliers[prev_shape_size - 1] = Cast(new_shape_ptr.dtype, 1) with ib.for_range(0, prev_shape_size - 1) as i_: i = i_ + 1 multipliers[prev_shape_size - 1 - i] = ( @@ -116,7 +125,7 @@ def gen_ir( ) total_ele[0] *= prev_shape[prev_shape_size - i] - division_total_ele[0] = Cast("int64", 1) + division_total_ele[0] = Cast(new_shape_ptr.dtype, 1) with ib.for_range(0, new_shape_size) as i: with ib.if_scope(new_shape[i] != -1): division_total_ele[0] *= new_shape[i] @@ -124,7 +133,9 @@ def gen_ir( # Compute true output shape (replace negative ones) with ib.for_range(0, new_shape_size) as i: with ib.if_scope(new_shape[i] == -1): - out_new_shape[i] = Cast("int64", div(total_ele[0], division_total_ele[0])) + out_new_shape[i] = Cast( + new_shape_ptr.dtype, div(total_ele[0], division_total_ele[0]) + ) with ib.else_scope(): out_new_shape[i] = new_shape[i] @@ -155,7 +166,7 @@ def gen_ir( # Else compute new_sparse_indices with ib.else_scope(): - dividers[new_shape_size - 1] = Cast("int64", 1) + dividers[new_shape_size - 1] = Cast(new_shape_ptr.dtype, 1) with ib.for_range(0, new_shape_size - 1) as i_: i = i_ + 1 dividers[new_shape_size - 1 - i] = ( @@ -163,7 +174,7 @@ def gen_ir( ) with ib.if_scope(row_number < sparse_indices_ptr.shape[0]): - flattened_indices[row_number] = Cast("int64", 0) + flattened_indices[row_number] = Cast(new_shape_ptr.dtype, 0) with ib.for_range(0, sparse_indices_ptr.shape[1]) as j: flattened_indices[row_number] += ( sparse_indices[row_number, j] * multipliers[j] @@ -171,28 +182,27 @@ def gen_ir( with ib.if_scope(row_number < sparse_indices_ptr.shape[0]): current_element = ib.allocate( - "int64", (1,), name="current_element", scope="local"
+ new_shape_ptr.dtype, (1,), name="current_element", scope="local" ) current_element[0] = flattened_indices[row_number] with ib.for_range(0, new_sparse_indices_ptr.shape[1]) as j: new_sparse_indices[row_number, j] = Cast( - "int64", floordiv(current_element[0], dividers[j]) + sparse_indices_ptr.dtype, floordiv(current_element[0], dividers[j]) ) current_element[0] = floormod(current_element[0], dividers[j]) return ib.get() new_sparse_indices_buf = decl_buffer( - new_sparse_indices_shape, "int64", "new_sparse_indices_buf" + new_sparse_indices_shape, sparse_indices.dtype, "new_sparse_indices_buf" ) - new_shape_buf = decl_buffer(new_shape_shape, "int64", "new_shape_buf") + new_shape_buf = decl_buffer(new_shape_shape, prev_shape.dtype, "new_shape_buf") return extern( [new_sparse_indices_shape, new_shape_shape], [sparse_indices, prev_shape, new_shape], lambda ins, outs: gen_ir(ins[0], ins[1], ins[2], outs[0], outs[1]), - dtype="int64", out_buffers=[new_sparse_indices_buf, new_shape_buf], name="sparse_reshape_cuda", tag="sparse_reshape_cuda", diff --git a/python/tvm/topi/sparse_reshape.py b/python/tvm/topi/sparse_reshape.py index 671a03b3b499..7b4b2746a516 100644 --- a/python/tvm/topi/sparse_reshape.py +++ b/python/tvm/topi/sparse_reshape.py @@ -83,17 +83,24 @@ def gen_ir( prev_shape_size = prev_shape_ptr.shape[0] new_shape_size = new_shape_ptr.shape[0] - multipliers = ib.allocate("int64", (prev_shape_size,), name="multipliers", scope="local") - dividers = ib.allocate("int64", (new_shape_size,), name="dividers", scope="local") + multipliers = ib.allocate( + new_shape_ptr.dtype, (prev_shape_size,), name="multipliers", scope="local" + ) + dividers = ib.allocate( + new_shape_ptr.dtype, (new_shape_size,), name="dividers", scope="local" + ) flattened_indices = ib.allocate( - "int64", (sparse_indices_ptr.shape[0],), name="flattened_indices", scope="local" + new_shape_ptr.dtype, + (sparse_indices_ptr.shape[0],), + name="flattened_indices", + scope="local", ) - total_ele = ib.allocate("int64", (1,), name="total_ele", scope="local") + total_ele = ib.allocate(new_shape_ptr.dtype, (1,), name="total_ele", scope="local") total_ele[0] = prev_shape[0] # Cumulative Reverse Exclusive Multiply - multipliers[prev_shape_size - 1] = Cast("int64", 1) + multipliers[prev_shape_size - 1] = Cast(new_shape_ptr.dtype, 1) with ib.for_range(0, prev_shape_size - 1) as i_: i = i_ + 1 multipliers[prev_shape_size - 1 - i] = ( @@ -101,8 +108,10 @@ def gen_ir( ) total_ele[0] *= prev_shape[prev_shape_size - i] - division_total_ele = ib.allocate("int64", (1,), name="division_total_ele", scope="local") - division_total_ele[0] = Cast("int64", 1) + division_total_ele = ib.allocate( + new_shape_ptr.dtype, (1,), name="division_total_ele", scope="local" + ) + division_total_ele[0] = Cast(new_shape_ptr.dtype, 1) with ib.for_range(0, new_shape_size) as i: with ib.if_scope(new_shape[i] != -1): division_total_ele[0] *= new_shape[i] @@ -110,7 +119,9 @@ def gen_ir( # Compute true output shape (replace negative ones) with ib.for_range(0, new_shape_size) as i: with ib.if_scope(new_shape[i] == -1): - out_new_shape[i] = Cast("int64", div(total_ele[0], division_total_ele[0])) + out_new_shape[i] = Cast( + new_shape_ptr.dtype, div(total_ele[0], division_total_ele[0]) + ) with ib.else_scope(): out_new_shape[i] = new_shape[i] @@ -133,7 +144,7 @@ def gen_ir( # Else compute new_sparse_indices with ib.else_scope(): - dividers[new_shape_size - 1] = Cast("int64", 1) + dividers[new_shape_size - 1] = Cast(new_shape_ptr.dtype, 1) with ib.for_range(0, 
new_shape_size - 1) as i_: i = i_ + 1 dividers[new_shape_size - 1 - i] = ( @@ -141,32 +152,33 @@ def gen_ir( ) with ib.for_range(0, sparse_indices_ptr.shape[0]) as i: - flattened_indices[i] = Cast("int64", 0) + flattened_indices[i] = Cast(new_shape_ptr.dtype, 0) with ib.for_range(0, sparse_indices_ptr.shape[1]) as j: flattened_indices[i] += sparse_indices[i, j] * multipliers[j] with ib.for_range(0, new_sparse_indices_ptr.shape[0]) as i: - current_element = ib.allocate("int64", (1,), name="current_element", scope="local") + current_element = ib.allocate( + new_shape_ptr.dtype, (1,), name="current_element", scope="local" + ) current_element[0] = flattened_indices[i] with ib.for_range(0, new_sparse_indices_ptr.shape[1]) as j: new_sparse_indices[i, j] = Cast( - "int64", floordiv(current_element[0], dividers[j]) + sparse_indices_ptr.dtype, floordiv(current_element[0], dividers[j]) ) current_element[0] = floormod(current_element[0], dividers[j]) return ib.get() new_sparse_indices_buf = decl_buffer( - new_sparse_indices_shape, "int64", "new_sparse_indices_buf" + new_sparse_indices_shape, sparse_indices.dtype, "new_sparse_indices_buf" ) - new_shape_buf = decl_buffer(new_shape_shape, "int64", "new_shape_buf") + new_shape_buf = decl_buffer(new_shape_shape, prev_shape.dtype, "new_shape_buf") return extern( [new_sparse_indices_shape, new_shape_shape], [sparse_indices, prev_shape, new_shape], lambda ins, outs: gen_ir(ins[0], ins[1], ins[2], outs[0], outs[1]), - dtype="int64", out_buffers=[new_sparse_indices_buf, new_shape_buf], name="sparse_reshape_cpu", tag="sparse_reshape_cpu", diff --git a/tests/python/relay/test_op_level3.py b/tests/python/relay/test_op_level3.py index 78379f32633b..5bc87128cc71 100644 --- a/tests/python/relay/test_op_level3.py +++ b/tests/python/relay/test_op_level3.py @@ -1401,8 +1401,11 @@ def verify_sparse_to_dense(sparse_indices, sparse_values, default_value, output_ ), ], ) +@pytest.mark.parametrize("dtype", [np.int32, np.int64]) @pytest.mark.parametrize("use_dyn", [True, False]) -def test_sparse_reshape(sparse_indices_np, sparse_values_np, prev_shape_np, new_shape_np, use_dyn): +def test_sparse_reshape( + sparse_indices_np, sparse_values_np, prev_shape_np, new_shape_np, dtype, use_dyn +): def ref_sparse_reshape( sparse_indices: np.ndarray, prev_shape: np.ndarray, @@ -1493,14 +1496,26 @@ def verify_sparse_reshape( func = relay.Function([sparse_indices, prev_shape, new_shape], z) ref_res = ref_sparse_reshape(sparse_indices_np, prev_shape_np, new_shape_np) + outputs = run_infer_type(z) + new_sparse_indices_infer_type, new_shape_infer_type = ( + outputs.checked_type.fields[0].dtype, + outputs.checked_type.fields[1].dtype, + ) + assert new_sparse_indices_infer_type == sparse_indices_np.dtype + assert new_shape_infer_type == new_shape_np.dtype verify_func( func, [sparse_indices_np, prev_shape_np, new_shape_np], ref_res, ) - verify_sparse_reshape(sparse_indices_np, sparse_values_np, prev_shape_np, new_shape_np) + verify_sparse_reshape( + sparse_indices_np.astype(dtype), + sparse_values_np.astype(dtype), + prev_shape_np.astype(dtype), + new_shape_np.astype(dtype), + ) def verify_func(func, data, ref_res, target_ctx=tvm.testing.enabled_targets()): From 479ea654ec4c1424fe43210baa55856f0a495c02 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Thu, 25 Feb 2021 00:50:53 +0000 Subject: [PATCH 20/25] Schedules --- python/tvm/relay/op/strategy/cuda.py | 2 +- python/tvm/relay/op/strategy/generic.py | 2 +- python/tvm/topi/cuda/sparse_reshape.py | 33 ------------------------- 
python/tvm/topi/generic/search.py | 4 --- 4 files changed, 2 insertions(+), 39 deletions(-) diff --git a/python/tvm/relay/op/strategy/cuda.py b/python/tvm/relay/op/strategy/cuda.py index 30fe7741890d..75ac49b7b268 100644 --- a/python/tvm/relay/op/strategy/cuda.py +++ b/python/tvm/relay/op/strategy/cuda.py @@ -769,7 +769,7 @@ def sparse_reshape_strategy_cuda(attrs, inputs, out_type, target): strategy = _op.OpStrategy() strategy.add_implementation( wrap_compute_sparse_reshape(topi.cuda.sparse_reshape), - wrap_topi_schedule(topi.cuda.schedule_sparse_reshape), + wrap_topi_schedule(topi.generic.schedule_extern), name="sparse_reshape.cuda", ) return strategy diff --git a/python/tvm/relay/op/strategy/generic.py b/python/tvm/relay/op/strategy/generic.py index 270db261ac15..0514538fd883 100644 --- a/python/tvm/relay/op/strategy/generic.py +++ b/python/tvm/relay/op/strategy/generic.py @@ -1111,7 +1111,7 @@ def sparse_reshape_strategy(attrs, outs, out_type, target): strategy = _op.OpStrategy() strategy.add_implementation( wrap_compute_sparse_reshape(topi.sparse_reshape), - wrap_topi_schedule(topi.generic.schedule_sparse_reshape), + wrap_topi_schedule(topi.generic.schedule_extern), name="sparse_reshape.generic", ) return strategy diff --git a/python/tvm/topi/cuda/sparse_reshape.py b/python/tvm/topi/cuda/sparse_reshape.py index bbfc85daacb9..4476648e0aa4 100644 --- a/python/tvm/topi/cuda/sparse_reshape.py +++ b/python/tvm/topi/cuda/sparse_reshape.py @@ -207,36 +207,3 @@ def gen_ir( name="sparse_reshape_cuda", tag="sparse_reshape_cuda", ) - - -def _default_schedule(outs): - """Default schedule for gpu.""" - outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs - s = te.create_schedule([x.op for x in outs]) - scheduled_ops = [] - - def traverse(op): - for tensor in op.input_tensors: - if tensor.op.input_tensors and tensor.op not in scheduled_ops: - traverse(tensor.op) - scheduled_ops.append(op) - - traverse(outs[0].op) - return s - - -def schedule_sparse_reshape(outs): - """Schedule for Sparse Reshape - - Parameters - ---------- - outs: Array of Tensor - The computation graph description of nms - in the format of an array of tensors. - - Returns - ------- - s: Schedule - The computation schedule for the op. 
- """ - return _default_schedule(outs) diff --git a/python/tvm/topi/generic/search.py b/python/tvm/topi/generic/search.py index 6746e6ad3979..5924d35def73 100644 --- a/python/tvm/topi/generic/search.py +++ b/python/tvm/topi/generic/search.py @@ -70,7 +70,3 @@ def schedule_scatter_add(outs): def schedule_sparse_fill_empty_rows(outs): return _default_schedule(outs, False) - - -def schedule_sparse_reshape(outs): - return _default_schedule(outs, False) From 37ff458bedea75b9f484b3c1e9dcda53646fe893 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Thu, 25 Feb 2021 00:55:27 +0000 Subject: [PATCH 21/25] Tests --- tests/python/relay/test_op_level3.py | 21 +++++++++------------ 1 file changed, 9 insertions(+), 12 deletions(-) diff --git a/tests/python/relay/test_op_level3.py b/tests/python/relay/test_op_level3.py index 5bc87128cc71..3aac27165db3 100644 --- a/tests/python/relay/test_op_level3.py +++ b/tests/python/relay/test_op_level3.py @@ -1394,18 +1394,15 @@ def verify_sparse_to_dense(sparse_indices, sparse_values, default_value, output_ np.array([2, 2], dtype=np.int64), ), ( - np.ones((0, 2), dtype=np.int64), - np.array([], dtype=np.int64), - np.array([3, 6], dtype=np.int64), - np.array([-1, 2], dtype=np.int64), + np.ones((0, 2), dtype=np.int32), + np.array([], dtype=np.int32), + np.array([3, 6], dtype=np.int32), + np.array([-1, 2], dtype=np.int32), ), ], ) -@pytest.mark.parametrize("dtype", [np.int32, np.int64]) @pytest.mark.parametrize("use_dyn", [True, False]) -def test_sparse_reshape( - sparse_indices_np, sparse_values_np, prev_shape_np, new_shape_np, dtype, use_dyn -): +def test_sparse_reshape(sparse_indices_np, sparse_values_np, prev_shape_np, new_shape_np, use_dyn): def ref_sparse_reshape( sparse_indices: np.ndarray, prev_shape: np.ndarray, @@ -1511,10 +1508,10 @@ def verify_sparse_reshape( ) verify_sparse_reshape( - sparse_indices_np.astype(dtype), - sparse_values_np.astype(dtype), - prev_shape_np.astype(dtype), - new_shape_np.astype(dtype), + sparse_indices_np, + sparse_values_np, + prev_shape_np, + new_shape_np, ) From b3bec7de1776f555ef74990397a2a7fdd0c60689 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Thu, 25 Feb 2021 00:56:21 +0000 Subject: [PATCH 22/25] Dtypes --- tests/python/relay/test_op_level3.py | 40 ++++++++++++++-------------- 1 file changed, 20 insertions(+), 20 deletions(-) diff --git a/tests/python/relay/test_op_level3.py b/tests/python/relay/test_op_level3.py index 3aac27165db3..97ad478edcd3 100644 --- a/tests/python/relay/test_op_level3.py +++ b/tests/python/relay/test_op_level3.py @@ -1316,10 +1316,10 @@ def verify_sparse_to_dense(sparse_indices, sparse_values, default_value, output_ "sparse_indices_np, sparse_values_np, prev_shape_np, new_shape_np", [ ( - np.array([[0, 0, 0], [0, 0, 1], [0, 1, 0], [1, 0, 0], [1, 2, 3]], dtype=np.int64), - np.array([7, 5, 6, 3, 9], dtype=np.int64), - np.array([2, 3, 6], dtype=np.int64), - np.array([9, -1], dtype=np.int64), + np.array([[0, 0, 0], [0, 0, 1], [0, 1, 0], [1, 0, 0], [1, 2, 3]], dtype=np.int32), + np.array([7, 5, 6, 3, 9], dtype=np.int32), + np.array([2, 3, 6], dtype=np.int32), + np.array([9, -1], dtype=np.int32), ), ( np.array( @@ -1346,10 +1346,10 @@ def verify_sparse_to_dense(sparse_indices, sparse_values, default_value, output_ np.array([9, -1, 7], dtype=np.int64), ), ( - np.array([[0, 0], [0, 1], [3, 4], [4, 3], [7, 3]], dtype=np.int64), - np.array([7, 5, 6, 3, 9], dtype=np.int64), - np.array([9, 4], dtype=np.int64), - np.array([2, -1, 6], dtype=np.int64), + np.array([[0, 0], [0, 1], [3, 4], [4, 3], [7, 3]], dtype=np.int32), + 
np.array([7, 5, 6, 3, 9], dtype=np.int32), + np.array([9, 4], dtype=np.int32), + np.array([2, -1, 6], dtype=np.int32), ), ( np.array([[0, 0], [0, 1], [3, 4], [4, 3], [7, 3]], dtype=np.int64), @@ -1358,10 +1358,10 @@ def verify_sparse_to_dense(sparse_indices, sparse_values, default_value, output_ np.array([-1], dtype=np.int64), ), ( - np.array([[0], [5], [10], [20], [24]], dtype=np.int64), - np.array([7, 5, 6, 3, 9], dtype=np.int64), - np.array([25], dtype=np.int64), - np.array([5, 5], dtype=np.int64), + np.array([[0], [5], [10], [20], [24]], dtype=np.int32), + np.array([7, 5, 6, 3, 9], dtype=np.int32), + np.array([25], dtype=np.int32), + np.array([5, 5], dtype=np.int32), ), ( np.array([[0, 100], [200, 100], [300, 400], [50, 20], [400, 50]], dtype=np.int64), @@ -1370,10 +1370,10 @@ def verify_sparse_to_dense(sparse_indices, sparse_values, default_value, output_ np.array([500, 20], dtype=np.int64), ), ( - np.array([[0, 100], [200, 100], [300, 400], [50, 20], [400, 50]], dtype=np.int64), - np.array([7, 5, 6, 3, 9], dtype=np.int64), - np.array([500, 20], dtype=np.int64), - np.array([500, -1], dtype=np.int64), + np.array([[0, 100], [200, 100], [300, 400], [50, 20], [400, 50]], dtype=np.int32), + np.array([7, 5, 6, 3, 9], dtype=np.int32), + np.array([500, 20], dtype=np.int32), + np.array([500, -1], dtype=np.int32), ), ( np.array([[0, 100], [200, 100], [300, 400], [50, 20], [400, 50]], dtype=np.int64), @@ -1382,10 +1382,10 @@ def verify_sparse_to_dense(sparse_indices, sparse_values, default_value, output_ np.array([250, 40], dtype=np.int64), ), ( - np.ones((0, 1), dtype=np.int64), - np.array([], dtype=np.int64), - np.array([4], dtype=np.int64), - np.array([2, -1], dtype=np.int64), + np.ones((0, 1), dtype=np.int32), + np.array([], dtype=np.int32), + np.array([4], dtype=np.int32), + np.array([2, -1], dtype=np.int32), ), ( np.ones((0, 1), dtype=np.int64), From 3ba3f1e7558666c52714bc7ae8ad86b80cea0ddd Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Thu, 25 Feb 2021 13:40:56 +0000 Subject: [PATCH 23/25] Black --- tests/python/frontend/tensorflow/test_forward.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/python/frontend/tensorflow/test_forward.py b/tests/python/frontend/tensorflow/test_forward.py index c2adc96d8ece..4bff5df4c453 100644 --- a/tests/python/frontend/tensorflow/test_forward.py +++ b/tests/python/frontend/tensorflow/test_forward.py @@ -2094,6 +2094,7 @@ def test_forward_sparse_reshape( # ------------------------------------------------------------------ _test_sparse_reshape(sparse_indices_np, sparse_values_np, prev_shape_np, new_shape_np, use_dyn) + # tensorflow.compat.v1.sparse_to_dense # --------------- def _test_sparse_to_dense(sparse_indices, sparse_values, default_value, output_shape): From c80561cb59c78ff83712e711b3ac26791e2f5cc6 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Thu, 25 Feb 2021 21:54:41 +0000 Subject: [PATCH 24/25] Parallelize CPU --- python/tvm/topi/sparse_reshape.py | 6 +++--- tests/python/frontend/tensorflow/test_forward.py | 15 --------------- 2 files changed, 3 insertions(+), 18 deletions(-) diff --git a/python/tvm/topi/sparse_reshape.py b/python/tvm/topi/sparse_reshape.py index 7b4b2746a516..5535477e17c8 100644 --- a/python/tvm/topi/sparse_reshape.py +++ b/python/tvm/topi/sparse_reshape.py @@ -138,7 +138,7 @@ def gen_ir( # Return same inputs if shapes are equal with ib.if_scope(equal_shape[0]): - with ib.for_range(0, sparse_indices_ptr.shape[0]) as i: + with ib.for_range(0, sparse_indices_ptr.shape[0], kind="parallel") as i: with ib.for_range(0, 
sparse_indices_ptr.shape[1]) as j: new_sparse_indices[i, j] = sparse_indices[i, j] @@ -151,12 +151,12 @@ def gen_ir( dividers[new_shape_size - i] * out_new_shape[new_shape_size - i] ) - with ib.for_range(0, sparse_indices_ptr.shape[0]) as i: + with ib.for_range(0, sparse_indices_ptr.shape[0], kind="parallel") as i: flattened_indices[i] = Cast(new_shape_ptr.dtype, 0) with ib.for_range(0, sparse_indices_ptr.shape[1]) as j: flattened_indices[i] += sparse_indices[i, j] * multipliers[j] - with ib.for_range(0, new_sparse_indices_ptr.shape[0]) as i: + with ib.for_range(0, new_sparse_indices_ptr.shape[0], kind="parallel") as i: current_element = ib.allocate( new_shape_ptr.dtype, (1,), name="current_element", scope="local" ) diff --git a/tests/python/frontend/tensorflow/test_forward.py b/tests/python/frontend/tensorflow/test_forward.py index 4bff5df4c453..35ca42d23f18 100644 --- a/tests/python/frontend/tensorflow/test_forward.py +++ b/tests/python/frontend/tensorflow/test_forward.py @@ -2014,15 +2014,6 @@ def _test_sparse_reshape(indices_np, values_np, prev_shape_np, new_shape_np, use np.array([2, 3, 6], dtype=np.int64), np.array([-1, 9], dtype=np.int64), ), - ( - np.array( - [[0, 0, 0, 0], [0, 0, 1, 2], [0, 1, 0, 3], [1, 0, 0, 4], [1, 2, 3, 6]], - dtype=np.int64, - ), - np.array([7, 5, 6, 3, 9], dtype=np.int64), - np.array([2, 3, 6, 7], dtype=np.int64), - np.array([9, -1, 7], dtype=np.int64), - ), ( np.array( [ @@ -2038,12 +2029,6 @@ def _test_sparse_reshape(indices_np, values_np, prev_shape_np, new_shape_np, use np.array([2, 3, 6, 7, 9], dtype=np.int64), np.array([9, -1, 7], dtype=np.int64), ), - ( - np.array([[0, 0], [0, 1], [3, 4], [4, 3], [7, 3]], dtype=np.int64), - np.array([7, 5, 6, 3, 9], dtype=np.int64), - np.array([9, 4], dtype=np.int64), - np.array([2, -1, 6], dtype=np.int64), - ), ( np.array([[0, 0], [0, 1], [3, 4], [4, 3], [7, 3]], dtype=np.int64), np.array([7, 5, 6, 3, 9], dtype=np.int64), From 651cb93b8eb5254f164bf06399af79f4adfff725 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Fri, 26 Feb 2021 12:03:03 +0000 Subject: [PATCH 25/25] CI error
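A note for reviewers of the kernels in this series: both the CPU and CUDA IR implement reshape on sparse indices by flattening each index with the row-major strides of the previous shape (the "Cumulative Reverse Exclusive Multiply" into multipliers), resolving at most one -1 in the requested shape via total_ele / division_total_ele, and then peeling the flattened position apart with the strides of the new shape (dividers). The following is a minimal NumPy sketch of those semantics, assuming only the behavior visible in the diffs above; the function name ref_sparse_reshape_numpy and its vectorized formulation are illustrative, not part of the patch series.

    import numpy as np


    def ref_sparse_reshape_numpy(sparse_indices, prev_shape, new_shape):
        """Reference semantics for sparse_reshape: map sparse indices
        from prev_shape to new_shape without touching the values."""
        prev_shape = np.asarray(prev_shape, dtype=np.int64)
        new_shape = np.asarray(new_shape, dtype=np.int64).copy()
        # Resolve at most one -1, mirroring total_ele / division_total_ele.
        total_ele = np.prod(prev_shape)
        if (new_shape == -1).any():
            new_shape[new_shape == -1] = total_ele // np.prod(new_shape[new_shape != -1])
        # Fast path mirroring the equal_shape branch: indices pass through.
        if np.array_equal(prev_shape, new_shape):
            return sparse_indices.copy(), new_shape
        # Row-major strides, i.e. the "Cumulative Reverse Exclusive Multiply".
        multipliers = np.append(np.cumprod(prev_shape[::-1])[-2::-1], 1)
        dividers = np.append(np.cumprod(new_shape[::-1])[-2::-1], 1)
        # Flatten each index with the old strides, then unflatten with the
        # new ones; equivalent to the floordiv/floormod loop in gen_ir.
        flattened = sparse_indices.astype(np.int64) @ multipliers
        new_indices = (flattened[:, None] // dividers) % new_shape
        return new_indices.astype(sparse_indices.dtype), new_shape

On the first fixture in test_op_level3.py above (indices [[0, 0, 0], [0, 0, 1], [0, 1, 0], [1, 0, 0], [1, 2, 3]], prev_shape [2, 3, 6], new_shape [9, -1]) this yields indices [[0, 0], [0, 1], [1, 2], [4, 2], [8, 1]] and shape [9, 4], agreeing with the ref_sparse_reshape helper the tests define.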