diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py
index 65f18c029441..20eb95ba7c00 100644
--- a/python/tvm/relay/frontend/tensorflow.py
+++ b/python/tvm/relay/frontend/tensorflow.py
@@ -1157,6 +1157,15 @@ def _impl(inputs, attr, params, mod):
     return _impl
 
 
+def _sparse_reshape():
+    def _impl(inputs, attr, params, mod):
+        assert len(inputs) == 3, "There should be 3 input tensors"
+        new_indices, new_shape = get_relay_op("sparse_reshape")(inputs[0], inputs[1], inputs[2])
+        return _expr.TupleWrapper(_expr.Tuple([new_indices, new_shape]), 2)
+
+    return _impl
+
+
 def _identity():
     def _impl(inputs, attr, params, mod):
         return inputs[0]
@@ -2650,6 +2659,7 @@ def _impl(inputs, attr, params, mod):
     "SparseToDense": _sparse_to_dense(),
     "SparseTensorDenseMatMul": _sparse_tensor_dense_matmul(),
     "SparseFillEmptyRows": _sparse_fill_empty_rows(),
+    "SparseReshape": _sparse_reshape(),
     "Split": _split(False),
     "SplitV": _split(True),
     "Sqrt": AttrCvt("sqrt"),
diff --git a/python/tvm/relay/op/_transform.py b/python/tvm/relay/op/_transform.py
index e9cf3d83eaeb..97f45278f073 100644
--- a/python/tvm/relay/op/_transform.py
+++ b/python/tvm/relay/op/_transform.py
@@ -66,6 +66,7 @@ _reg.register_injective_schedule("matrix_set_diag")
 _reg.register_injective_schedule("adv_index")
 
+
 # concatenate
 _reg.register_schedule("concatenate", strategy.schedule_concatenate)
 
@@ -114,6 +115,22 @@ def compute_sparse_fill_empty_rows(attrs, inputs, output_type):
 _reg.register_strategy("sparse_fill_empty_rows", strategy.sparse_fill_empty_rows_strategy)
 
 
+# sparse_reshape
+@_reg.register_compute("sparse_reshape")
+def compute_sparse_reshape(attrs, inputs, output_type):
+    """Compute definition of sparse_reshape"""
+
+    return topi.sparse_reshape(
+        inputs[0],
+        inputs[1],
+        inputs[2],
+        output_type.fields[0].shape,
+        output_type.fields[1].shape,
+    )
+
+
+_reg.register_strategy("sparse_reshape", strategy.sparse_reshape_strategy)
+
+
 # scatter_add
 @_reg.register_compute("scatter_add")
 def compute_scatter_add(attrs, inputs, output_type):
@@ -515,6 +532,24 @@ def sparse_fill_empty_rows_func(attrs, inputs, _):
     return _sparse_fill_empty_rows_shape_func(inputs[0], inputs[2])
 
 
+@script
+def _sparse_reshape_shape_func(sparse_indices_shape, prev_shape_shape, new_shape_shape):
+    indices_shape = output_tensor((2,), "int64")
+    indices_shape[0] = int64(sparse_indices_shape[0])
+    indices_shape[1] = int64(new_shape_shape[0])
+    shape_tensor = output_tensor((1,), "int64")
+    shape_tensor[0] = int64(new_shape_shape[0])
+    return (indices_shape, shape_tensor)
+
+
+@_reg.register_shape_func("sparse_reshape", False)
+def sparse_reshape_shape_func(attrs, inputs, _):
+    """
+    Shape func for sparse_reshape.
+    """
+    return _sparse_reshape_shape_func(inputs[0], inputs[1], inputs[2])
+
+
 @script
 def _layout_transform_shape_func(
     data_shape, out_layout_len, dst_equal_list, dst_mul_list, dst_div_list, dst_mix_list
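Note that the hybrid-script shape function above depends only on the ranks of its inputs: the reshaped indices tensor is always (N, len(new_shape)) and the output shape tensor is (len(new_shape),). A plain-Python sketch of the same calculation (illustrative only; the helper name is mine):

import numpy as np


def sparse_reshape_out_shapes(sparse_indices, new_shape):
    """Mirror of _sparse_reshape_shape_func: static shapes of the two outputs."""
    num_sparse_values = sparse_indices.shape[0]  # N
    ndims_new = new_shape.shape[0]  # rank of the reshaped tensor
    return (num_sparse_values, ndims_new), (ndims_new,)


# 5 nonzeros reshaped into a rank-2 shape -> indices (5, 2), shape tensor (2,)
assert sparse_reshape_out_shapes(np.zeros((5, 3)), np.array([9, -1])) == ((5, 2), (2,))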
+ """ + return _sparse_reshape_shape_func(inputs[0], inputs[1], inputs[2]) + + @script def _layout_transform_shape_func( data_shape, out_layout_len, dst_equal_list, dst_mul_list, dst_div_list, dst_mix_list diff --git a/python/tvm/relay/op/strategy/cuda.py b/python/tvm/relay/op/strategy/cuda.py index 3abc9c42b659..85bbab692574 100644 --- a/python/tvm/relay/op/strategy/cuda.py +++ b/python/tvm/relay/op/strategy/cuda.py @@ -764,6 +764,17 @@ def sparse_dense_strategy_cuda(attrs, inputs, out_type, target): return strategy +@sparse_reshape_strategy.register(["cuda", "gpu"]) +def sparse_reshape_strategy_cuda(attrs, inputs, out_type, target): + strategy = _op.OpStrategy() + strategy.add_implementation( + wrap_compute_sparse_reshape(topi.cuda.sparse_reshape), + wrap_topi_schedule(topi.generic.schedule_extern), + name="sparse_reshape.cuda", + ) + return strategy + + @sparse_dense_padded_strategy.register(["cuda", "gpu"]) def sparse_dense_padded_strategy_cuda(attrs, inputs, out_type, target): """sparse dense cuda strategy""" diff --git a/python/tvm/relay/op/strategy/generic.py b/python/tvm/relay/op/strategy/generic.py index 8a2724dfb614..be86ea9d9184 100644 --- a/python/tvm/relay/op/strategy/generic.py +++ b/python/tvm/relay/op/strategy/generic.py @@ -1105,6 +1105,33 @@ def _compute_sparse_fill_empty_rows(attrs, inputs, output_type): return _compute_sparse_fill_empty_rows +# sparse_reshape +@override_native_generic_func("sparse_reshape_strategy") +def sparse_reshape_strategy(attrs, outs, out_type, target): + strategy = _op.OpStrategy() + strategy.add_implementation( + wrap_compute_sparse_reshape(topi.sparse_reshape), + wrap_topi_schedule(topi.generic.schedule_extern), + name="sparse_reshape.generic", + ) + return strategy + + +def wrap_compute_sparse_reshape(topi_compute): + """Wrap sparse_reshape compute""" + + def _compute_sparse_reshape(attrs, inputs, output_type): + return topi_compute( + inputs[0], + inputs[1], + inputs[2], + output_type.fields[0].shape, + output_type.fields[1].shape, + ) + + return _compute_sparse_reshape + + # roi_pool @generic_func def schedule_roi_pool(attrs, outs, target): diff --git a/python/tvm/relay/op/transform.py b/python/tvm/relay/op/transform.py index c0a0d31478ef..73508ddd2603 100644 --- a/python/tvm/relay/op/transform.py +++ b/python/tvm/relay/op/transform.py @@ -1410,6 +1410,46 @@ def sparse_fill_empty_rows(sparse_indices, sparse_values, dense_shape, default_v return Tuple((new_sparse_indices, new_sparse_values, empty_row_indicator)) +def sparse_reshape(sparse_indices, prev_shape, new_shape): + """ + Reshape a Sparse Tensor. The sparse array is in COO format. + + Parameters + ---------- + sparse_indices : relay.Expr + A 2-D tensor[N, n_dim] of integers containing location of sparse values, where N is the + number of sparse values and n_dim is the number of dimensions of the dense_shape + prev_shape : relay.Expr + A 1-D tensor containing the previous shape of the dense tensor + new_shape : relay.Expr + A 1-D tensor containing the new shape of the dense tensor + Returns + ------- + result: relay.Expr + Output tensor. + Examples + -------- + .. 
diff --git a/python/tvm/topi/__init__.py b/python/tvm/topi/__init__.py
index 63dc4bd4ab83..c196b33cf880 100644
--- a/python/tvm/topi/__init__.py
+++ b/python/tvm/topi/__init__.py
@@ -39,6 +39,7 @@ from .sort import *
 from .scatter import *
 from .sparse_fill_empty_rows import *
+from .sparse_reshape import *
 from .scatter_add import *
 from .argwhere import *
 from .cumsum import *
diff --git a/python/tvm/topi/cuda/__init__.py b/python/tvm/topi/cuda/__init__.py
index df75c676fad3..52e64804d692 100644
--- a/python/tvm/topi/cuda/__init__.py
+++ b/python/tvm/topi/cuda/__init__.py
@@ -58,4 +58,5 @@ from . import tensorcore_alter_op
 from .argwhere import *
 from .scan import *
+from .sparse_reshape import *
 from .unique import *
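Both of the new TOPI kernels below (CUDA first, then CPU) implement the same arithmetic: flatten each COO index using the row-major strides of prev_shape, then re-expand it using the strides of the resolved new shape. A NumPy reference of that logic may help when reading the IR builders (illustrative; the names are mine, and divisibility for the -1 inference is not validated):

import numpy as np


def sparse_reshape_ref(indices, prev_shape, new_shape):
    """NumPy mirror of the kernels: flatten with old strides, re-expand with new ones."""
    total = np.prod(prev_shape)
    # Resolve a single -1 the same way dense reshape does.
    new_shape = np.where(new_shape == -1, total // np.prod(new_shape[new_shape != -1]), new_shape)
    if np.array_equal(prev_shape, new_shape):
        return indices.copy(), new_shape
    # Row-major strides: the "cumulative reverse exclusive multiply" in the IR.
    prev_strides = np.append(np.cumprod(prev_shape[::-1])[-2::-1], 1)
    new_strides = np.append(np.cumprod(new_shape[::-1])[-2::-1], 1)
    flat = indices @ prev_strides  # flatten each COO index to a linear offset
    return (flat[:, None] // new_strides) % new_shape, new_shape


indices = np.array([[0, 0, 0], [0, 0, 1], [0, 1, 0], [1, 0, 0], [1, 2, 3]])
out_indices, out_shape = sparse_reshape_ref(indices, np.array([2, 3, 6]), np.array([9, -1]))
# out_indices -> [[0 0] [0 1] [1 2] [4 2] [8 1]], out_shape -> [9 4]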
diff --git a/python/tvm/topi/cuda/sparse_reshape.py b/python/tvm/topi/cuda/sparse_reshape.py
new file mode 100644
index 000000000000..4476648e0aa4
--- /dev/null
+++ b/python/tvm/topi/cuda/sparse_reshape.py
@@ -0,0 +1,209 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name, too-many-arguments, too-many-nested-blocks
+"""Sparse_Reshape operator"""
+import tvm
+from tvm import te
+from ...tir import decl_buffer, ir_builder, Cast
+from ...te import extern, div, floordiv, floormod
+from ..utils import ceil_div
+
+
+def sparse_reshape(
+    sparse_indices,
+    prev_shape,
+    new_shape,
+    new_sparse_indices_shape,
+    new_shape_shape,
+):
+    """
+    Reshape a sparse tensor. The sparse array is in COO format.
+
+    Parameters
+    ----------
+    sparse_indices : tvm.te.Tensor
+        A 2-D tensor of shape [N, n_dim] of integers containing the locations of the sparse
+        values, where N is the number of sparse values and n_dim is the number of dimensions
+        of the dense_shape
+    prev_shape : tvm.te.Tensor
+        A 1-D tensor containing the previous shape of the dense tensor
+    new_shape : tvm.te.Tensor
+        A 1-D tensor containing the new shape of the dense tensor
+    new_sparse_indices_shape : shape
+        The static shape of the output indices tensor, [N, n_dim_new]
+    new_shape_shape : shape
+        The static shape of the output shape tensor, [n_dim_new]
+
+    Returns
+    -------
+    result : tuple of tvm.te.Tensor
+        (new_sparse_indices, new_shape)
+
+    Examples
+    --------
+    .. code-block:: python
+
+        sparse_indices = [[0, 0, 0],
+                          [0, 0, 1],
+                          [0, 1, 0],
+                          [1, 0, 0],
+                          [1, 2, 3]]
+        prev_shape = [2, 3, 6]
+        new_shape = [9, -1]
+        new_sparse_indices, new_shape = relay.sparse_reshape(sparse_indices,
+                                                             prev_shape,
+                                                             new_shape)
+        new_sparse_indices = [[0, 0],
+                              [0, 1],
+                              [1, 2],
+                              [4, 2],
+                              [8, 1]]
+        new_shape = [9, 4]
+    """
+
+    def gen_ir(
+        sparse_indices_ptr,
+        prev_shape_ptr,
+        new_shape_ptr,
+        new_sparse_indices_ptr,
+        out_new_shape_ptr,
+    ):
+        ib = ir_builder.create()
+
+        sparse_indices = ib.buffer_ptr(sparse_indices_ptr)
+        prev_shape = ib.buffer_ptr(prev_shape_ptr)
+        new_shape = ib.buffer_ptr(new_shape_ptr)
+        new_sparse_indices = ib.buffer_ptr(new_sparse_indices_ptr)
+        out_new_shape = ib.buffer_ptr(out_new_shape_ptr)
+
+        prev_shape_size = prev_shape_ptr.shape[0]
+        new_shape_size = new_shape_ptr.shape[0]
+
+        multipliers = ib.allocate(
+            new_shape_ptr.dtype, (prev_shape_size,), name="multipliers", scope="global"
+        )
+        dividers = ib.allocate(
+            new_shape_ptr.dtype, (new_shape_size,), name="dividers", scope="global"
+        )
+        flattened_indices = ib.allocate(
+            new_shape_ptr.dtype,
+            (sparse_indices_ptr.shape[0],),
+            name="flattened_indices",
+            scope="global",
+        )
+        total_ele = ib.allocate(new_shape_ptr.dtype, (1,), name="total_ele", scope="global")
+        division_total_ele = ib.allocate(
+            new_shape_ptr.dtype, (1,), name="division_total_ele", scope="global"
+        )
+        equal_shape = ib.allocate("bool", (1,), name="equal_shape", scope="global")
+        max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)
+        with ib.new_scope():
+            # The computation in this block is minuscule: we only iterate over the shape
+            # tensors, which are tiny (fewer than ~10 elements), so there is no need for
+            # parallelization.
+            nthread_tx = 1
+            nthread_bx = 1
+            tx = te.thread_axis("threadIdx.x")
+            bx = te.thread_axis("blockIdx.x")
+            ib.scope_attr(tx, "thread_extent", nthread_tx)
+            ib.scope_attr(bx, "thread_extent", nthread_bx)
+
+            total_ele[0] = prev_shape[0]
+
+            # Cumulative reverse exclusive multiply (row-major strides of prev_shape)
+            multipliers[prev_shape_size - 1] = Cast(new_shape_ptr.dtype, 1)
+            with ib.for_range(0, prev_shape_size - 1) as i_:
+                i = i_ + 1
+                multipliers[prev_shape_size - 1 - i] = (
+                    prev_shape[prev_shape_size - i] * multipliers[prev_shape_size - i]
+                )
+                total_ele[0] *= prev_shape[prev_shape_size - i]
+
+            division_total_ele[0] = Cast(new_shape_ptr.dtype, 1)
+            with ib.for_range(0, new_shape_size) as i:
+                with ib.if_scope(new_shape[i] != -1):
+                    division_total_ele[0] *= new_shape[i]
+
+            # Compute the true output shape (replace any -1)
+            with ib.for_range(0, new_shape_size) as i:
+                with ib.if_scope(new_shape[i] == -1):
+                    out_new_shape[i] = Cast(
+                        new_shape_ptr.dtype, div(total_ele[0], division_total_ele[0])
+                    )
+                with ib.else_scope():
+                    out_new_shape[i] = new_shape[i]
+
+            # Check if prev_shape and the resolved new shape are equal
+            equal_shape[0] = True
+            with ib.if_scope(prev_shape_size == new_shape_size):
+                with ib.for_range(0, prev_shape_size) as i:
+                    with ib.if_scope(prev_shape[i] != out_new_shape[i]):
+                        equal_shape[0] = False
+            with ib.else_scope():
+                equal_shape[0] = False
+
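+        # One thread per sparse row below: each row's COO index is flattened with
+        # `multipliers` (the strides of prev_shape) and re-expanded with `dividers`
+        # (the strides of the resolved new shape); e.g. index [1, 2, 3] in [2, 3, 6]
+        # flattens to 33, which re-expands to [8, 1] in [9, 4].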
+        with ib.new_scope():
+            nthread_tx = max_threads
+            nthread_bx = ceil_div(sparse_indices_ptr.shape[0], max_threads)
+            tx = te.thread_axis("threadIdx.x")
+            bx = te.thread_axis("blockIdx.x")
+            ib.scope_attr(tx, "thread_extent", nthread_tx)
+            ib.scope_attr(bx, "thread_extent", nthread_bx)
+
+            row_number = bx * max_threads + tx
+
+            # Return same inputs if shapes are equal
+            with ib.if_scope(equal_shape[0]):
+                with ib.if_scope(row_number < sparse_indices_ptr.shape[0]):
+                    with ib.for_range(0, sparse_indices_ptr.shape[1]) as j:
+                        new_sparse_indices[row_number, j] = sparse_indices[row_number, j]
+
+            # Else compute new_sparse_indices
+            with ib.else_scope():
+                dividers[new_shape_size - 1] = Cast(new_shape_ptr.dtype, 1)
+                with ib.for_range(0, new_shape_size - 1) as i_:
+                    i = i_ + 1
+                    dividers[new_shape_size - 1 - i] = (
+                        dividers[new_shape_size - i] * out_new_shape[new_shape_size - i]
+                    )
+
+                with ib.if_scope(row_number < sparse_indices_ptr.shape[0]):
+                    flattened_indices[row_number] = Cast(new_shape_ptr.dtype, 0)
+                    with ib.for_range(0, sparse_indices_ptr.shape[1]) as j:
+                        flattened_indices[row_number] += (
+                            sparse_indices[row_number, j] * multipliers[j]
+                        )
+
+                with ib.if_scope(row_number < sparse_indices_ptr.shape[0]):
+                    current_element = ib.allocate(
+                        new_shape_ptr.dtype, (1,), name="current_element", scope="local"
+                    )
+                    current_element[0] = flattened_indices[row_number]
+
+                    with ib.for_range(0, new_sparse_indices_ptr.shape[1]) as j:
+                        new_sparse_indices[row_number, j] = Cast(
+                            sparse_indices_ptr.dtype, floordiv(current_element[0], dividers[j])
+                        )
+                        current_element[0] = floormod(current_element[0], dividers[j])
+
+        return ib.get()
+
+    new_sparse_indices_buf = decl_buffer(
+        new_sparse_indices_shape, sparse_indices.dtype, "new_sparse_indices_buf"
+    )
+    new_shape_buf = decl_buffer(new_shape_shape, prev_shape.dtype, "new_shape_buf")
+
+    return extern(
+        [new_sparse_indices_shape, new_shape_shape],
+        [sparse_indices, prev_shape, new_shape],
+        lambda ins, outs: gen_ir(ins[0], ins[1], ins[2], outs[0], outs[1]),
+        out_buffers=[new_sparse_indices_buf, new_shape_buf],
+        name="sparse_reshape_cuda",
+        tag="sparse_reshape_cuda",
+    )
diff --git a/python/tvm/topi/sparse_reshape.py b/python/tvm/topi/sparse_reshape.py
new file mode 100644
index 000000000000..5535477e17c8
--- /dev/null
+++ b/python/tvm/topi/sparse_reshape.py
@@ -0,0 +1,185 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name, too-many-arguments, too-many-nested-blocks
+"""Sparse_Reshape operator"""
+from ..tir import decl_buffer, ir_builder, Cast
+from ..te import extern, div, floordiv, floormod
+
+
+def sparse_reshape(
+    sparse_indices,
+    prev_shape,
+    new_shape,
+    new_sparse_indices_shape,
+    new_shape_shape,
+):
+    """
+    Reshape a sparse tensor. The sparse array is in COO format.
+
+    Parameters
+    ----------
+    sparse_indices : tvm.te.Tensor
+        A 2-D tensor of shape [N, n_dim] of integers containing the locations of the sparse
+        values, where N is the number of sparse values and n_dim is the number of dimensions
+        of the dense_shape
+    prev_shape : tvm.te.Tensor
+        A 1-D tensor containing the previous shape of the dense tensor
+    new_shape : tvm.te.Tensor
+        A 1-D tensor containing the new shape of the dense tensor
+    new_sparse_indices_shape : shape
+        The static shape of the output indices tensor, [N, n_dim_new]
+    new_shape_shape : shape
+        The static shape of the output shape tensor, [n_dim_new]
+
+    Returns
+    -------
+    result : tuple of tvm.te.Tensor
+        (new_sparse_indices, new_shape)
+
+    Examples
+    --------
+    .. code-block:: python
+
+        sparse_indices = [[0, 0, 0],
+                          [0, 0, 1],
+                          [0, 1, 0],
+                          [1, 0, 0],
+                          [1, 2, 3]]
+        prev_shape = [2, 3, 6]
+        new_shape = [9, -1]
+        new_sparse_indices, new_shape = relay.sparse_reshape(sparse_indices,
+                                                             prev_shape,
+                                                             new_shape)
+        new_sparse_indices = [[0, 0],
+                              [0, 1],
+                              [1, 2],
+                              [4, 2],
+                              [8, 1]]
+        new_shape = [9, 4]
+    """
+
+    def gen_ir(
+        sparse_indices_ptr,
+        prev_shape_ptr,
+        new_shape_ptr,
+        new_sparse_indices_ptr,
+        out_new_shape_ptr,
+    ):
+        ib = ir_builder.create()
+
+        sparse_indices = ib.buffer_ptr(sparse_indices_ptr)
+        prev_shape = ib.buffer_ptr(prev_shape_ptr)
+        new_shape = ib.buffer_ptr(new_shape_ptr)
+        new_sparse_indices = ib.buffer_ptr(new_sparse_indices_ptr)
+        out_new_shape = ib.buffer_ptr(out_new_shape_ptr)
+
+        prev_shape_size = prev_shape_ptr.shape[0]
+        new_shape_size = new_shape_ptr.shape[0]
+
+        multipliers = ib.allocate(
+            new_shape_ptr.dtype, (prev_shape_size,), name="multipliers", scope="local"
+        )
+        dividers = ib.allocate(
+            new_shape_ptr.dtype, (new_shape_size,), name="dividers", scope="local"
+        )
+        flattened_indices = ib.allocate(
+            new_shape_ptr.dtype,
+            (sparse_indices_ptr.shape[0],),
+            name="flattened_indices",
+            scope="local",
+        )
+
+        total_ele = ib.allocate(new_shape_ptr.dtype, (1,), name="total_ele", scope="local")
+        total_ele[0] = prev_shape[0]
+
+        # Cumulative reverse exclusive multiply (row-major strides of prev_shape)
+        multipliers[prev_shape_size - 1] = Cast(new_shape_ptr.dtype, 1)
+        with ib.for_range(0, prev_shape_size - 1) as i_:
+            i = i_ + 1
+            multipliers[prev_shape_size - 1 - i] = (
+                prev_shape[prev_shape_size - i] * multipliers[prev_shape_size - i]
+            )
+            total_ele[0] *= prev_shape[prev_shape_size - i]
+
+        division_total_ele = ib.allocate(
+            new_shape_ptr.dtype, (1,), name="division_total_ele", scope="local"
+        )
+        division_total_ele[0] = Cast(new_shape_ptr.dtype, 1)
+        with ib.for_range(0, new_shape_size) as i:
+            with ib.if_scope(new_shape[i] != -1):
+                division_total_ele[0] *= new_shape[i]
+
+        # Compute the true output shape (replace any -1)
+        with ib.for_range(0, new_shape_size) as i:
+            with ib.if_scope(new_shape[i] == -1):
+                out_new_shape[i] = Cast(
+                    new_shape_ptr.dtype, div(total_ele[0], division_total_ele[0])
+                )
+            with ib.else_scope():
+                out_new_shape[i] = new_shape[i]
+
+        equal_shape = ib.allocate("bool", (1,), name="equal_shape", scope="local")
+
+        # Check if prev_shape and the resolved new shape are equal
+        equal_shape[0] = True
+        with ib.if_scope(prev_shape_size == new_shape_size):
+            with ib.for_range(0, prev_shape_size) as i:
+                with ib.if_scope(prev_shape[i] != out_new_shape[i]):
+                    equal_shape[0] = False
+        with ib.else_scope():
+            equal_shape[0] = False
+
+        # Return same inputs if shapes are equal
+        with ib.if_scope(equal_shape[0]):
+            with ib.for_range(0, sparse_indices_ptr.shape[0], kind="parallel") as i:
+                with ib.for_range(0, sparse_indices_ptr.shape[1]) as j:
+                    new_sparse_indices[i, j] = sparse_indices[i, j]
+
+        # Else compute new_sparse_indices
+        with ib.else_scope():
+            dividers[new_shape_size - 1] = Cast(new_shape_ptr.dtype, 1)
+            with ib.for_range(0, new_shape_size - 1) as i_:
+                i = i_ + 1
+                dividers[new_shape_size - 1 - i] = (
+                    dividers[new_shape_size - i] * out_new_shape[new_shape_size - i]
+                )
+
+            with ib.for_range(0, sparse_indices_ptr.shape[0], kind="parallel") as i:
+                flattened_indices[i] = Cast(new_shape_ptr.dtype, 0)
+                with ib.for_range(0, sparse_indices_ptr.shape[1]) as j:
+                    flattened_indices[i] += sparse_indices[i, j] * multipliers[j]
+
+            with ib.for_range(0, new_sparse_indices_ptr.shape[0], kind="parallel") as i:
+                current_element = ib.allocate(
+                    new_shape_ptr.dtype, (1,), name="current_element", scope="local"
+                )
+                current_element[0] = flattened_indices[i]
+
+                with ib.for_range(0, new_sparse_indices_ptr.shape[1]) as j:
+                    new_sparse_indices[i, j] = Cast(
+                        sparse_indices_ptr.dtype, floordiv(current_element[0], dividers[j])
+                    )
+                    current_element[0] = floormod(current_element[0], dividers[j])
+
+        return ib.get()
+
+    new_sparse_indices_buf = decl_buffer(
+        new_sparse_indices_shape, sparse_indices.dtype, "new_sparse_indices_buf"
+    )
+    new_shape_buf = decl_buffer(new_shape_shape, prev_shape.dtype, "new_shape_buf")
+
+    return extern(
+        [new_sparse_indices_shape, new_shape_shape],
+        [sparse_indices, prev_shape, new_shape],
+        lambda ins, outs: gen_ir(ins[0], ins[1], ins[2], outs[0], outs[1]),
+        out_buffers=[new_sparse_indices_buf, new_shape_buf],
+        name="sparse_reshape_cpu",
+        tag="sparse_reshape_cpu",
+    )
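The CPU kernel above can be smoke-tested directly at the TE level, without going through Relay, by supplying the static output shapes by hand (a sketch under the assumption of an LLVM-enabled build; for the running example those shapes are (5, 2) and (2,)):

import numpy as np
import tvm
from tvm import te, topi

indices = te.placeholder((5, 3), name="sparse_indices", dtype="int64")
prev_shape = te.placeholder((3,), name="prev_shape", dtype="int64")
new_shape = te.placeholder((2,), name="new_shape", dtype="int64")
outs = topi.sparse_reshape(indices, prev_shape, new_shape, (5, 2), (2,))
s = te.create_schedule(outs[0].op)
f = tvm.build(s, [indices, prev_shape, new_shape, outs[0], outs[1]], "llvm")

dev = tvm.cpu(0)
args = [
    tvm.nd.array(
        np.array([[0, 0, 0], [0, 0, 1], [0, 1, 0], [1, 0, 0], [1, 2, 3]], dtype="int64"), dev
    ),
    tvm.nd.array(np.array([2, 3, 6], dtype="int64"), dev),
    tvm.nd.array(np.array([9, -1], dtype="int64"), dev),
    tvm.nd.empty((5, 2), "int64", dev),
    tvm.nd.empty((2,), "int64", dev),
]
f(*args)
# args[3] -> [[0 0] [0 1] [1 2] [4 2] [8 1]], args[4] -> [9 4]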
diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc
index eae231fd8d06..941f43a5a2c4 100644
--- a/src/relay/op/tensor/transform.cc
+++ b/src/relay/op/tensor/transform.cc
@@ -1631,6 +1631,56 @@ RELAY_REGISTER_OP("sparse_fill_empty_rows")
     .set_support_level(3)
     .set_attr<TOpPattern>("TOpPattern", kOpaque);
 
+bool SparseReshapeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
+                      const TypeReporter& reporter) {
+  // types: [sparse_indices, prev_shape, new_shape, result]
+  ICHECK_EQ(types.size(), 4) << "SparseReshapeRel expects 4 types but " << types.size()
+                             << " provided";
+  ICHECK_EQ(num_inputs, 3) << "SparseReshapeRel expects 3 inputs but " << num_inputs
+                           << " provided";
+  auto sparse_indices = types[0].as<TensorTypeNode>();
+  auto prev_shape = types[1].as<TensorTypeNode>();
+  auto new_shape = types[2].as<TensorTypeNode>();
+  if (sparse_indices == nullptr || prev_shape == nullptr || new_shape == nullptr) {
+    return false;
+  }
+  CHECK(sparse_indices->dtype.is_int()) << "sparse_indices must be a tensor of integers";
+  CHECK(prev_shape->dtype.is_int()) << "prev_shape must be a tensor of integers";
+  CHECK(new_shape->dtype.is_int()) << "new_shape must be a tensor of integers";
+  ICHECK_EQ(sparse_indices->shape.size(), 2) << "sparse_indices must be a 2-D tensor";
+  ICHECK_EQ(prev_shape->shape.size(), 1) << "prev_shape must be a 1-D tensor";
+  ICHECK_EQ(new_shape->shape.size(), 1) << "new_shape must be a 1-D tensor";
+  std::vector<Type> fields;
+  Array<IndexExpr> new_sparse_indices_shape{sparse_indices->shape[0], new_shape->shape[0]};
+  fields.push_back(TensorType(new_sparse_indices_shape, sparse_indices->dtype));
+  fields.push_back(TensorType(new_shape->shape, new_shape->dtype));
+  reporter->Assign(types[3], TupleType(Array<Type>(fields)));
+  return true;
+}
+
+Expr MakeSparseReshape(Expr sparse_indices, Expr prev_shape, Expr new_shape) {
+  static const Op& op = Op::Get("sparse_reshape");
+  return Call(op, {sparse_indices, prev_shape, new_shape}, Attrs(), {});
+}
+
+TVM_REGISTER_GLOBAL("relay.op._make.sparse_reshape").set_body_typed(MakeSparseReshape);
+
+RELAY_REGISTER_OP("sparse_reshape")
+    .describe(R"code(Return new sparse indices of the reshaped tensor
+)code" TVM_ADD_FILELINE)
+    .set_num_inputs(3)
+    .add_argument("sparse_indices", "Tensor",
+                  "A 2-D tensor of shape [N, ndims], which specifies the indices of the "
+                  "elements in the sparse tensor that contain nonzero values. COO format.")
+    .add_argument("prev_shape", "Tensor",
+                  "A 1-D tensor of shape [ndims], which specifies the previous dense shape of "
+                  "the sparse tensor.")
+    .add_argument("new_shape", "Tensor",
+                  "A 1-D tensor of shape [ndims], which specifies the desired dense shape of "
+                  "the sparse tensor.")
+    .add_type_rel("sparse_reshape", SparseReshapeRel)
+    .set_attr<TOpPattern>("TOpPattern", kInjective)
+    .set_support_level(3);
+
 // meshgrid operator
 TVM_REGISTER_NODE_TYPE(MeshgridAttrs);
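For reference, the frontend tests below check the converter against tf.sparse.reshape; in TF2 eager mode the expected semantics on the running example look like this (illustrative):

import tensorflow as tf

sp = tf.sparse.SparseTensor(
    indices=[[0, 0, 0], [1, 2, 3]], values=[7, 9], dense_shape=[2, 3, 6]
)
out = tf.sparse.reshape(sp, [9, -1])
print(out.indices.numpy())      # [[0 0] [8 1]]
print(out.dense_shape.numpy())  # [9 4]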
diff --git a/tests/python/frontend/tensorflow/test_forward.py b/tests/python/frontend/tensorflow/test_forward.py
index 8b146b6511ce..41145bf77218 100644
--- a/tests/python/frontend/tensorflow/test_forward.py
+++ b/tests/python/frontend/tensorflow/test_forward.py
@@ -1956,6 +1956,130 @@ def test_forward_sparse_fill_empty_rows(
 
 
 #######################################################################
+# SparseReshape
+# -------------
+
+
+def _test_sparse_reshape(indices_np, values_np, prev_shape_np, new_shape_np, use_dyn=False):
+    with tf.Graph().as_default():
+        if use_dyn:
+            indices = tf.placeholder(shape=(None, None), dtype=indices_np.dtype, name="indices")
+            values = tf.placeholder(shape=(None), dtype=values_np.dtype, name="values")
+            prev_shape = tf.placeholder(shape=(None), dtype=prev_shape_np.dtype, name="prev_shape")
+            new_shape = tf.placeholder(shape=(None), dtype=new_shape_np.dtype, name="new_shape")
+        else:
+            indices = tf.placeholder(shape=indices_np.shape, dtype=indices_np.dtype, name="indices")
+            values = tf.placeholder(shape=values_np.shape, dtype=values_np.dtype, name="values")
+            prev_shape = tf.placeholder(
+                shape=prev_shape_np.shape, dtype=prev_shape_np.dtype, name="prev_shape"
+            )
+            new_shape = tf.placeholder(
+                shape=new_shape_np.shape, dtype=new_shape_np.dtype, name="new_shape"
+            )
+        sp_input = tf.sparse.SparseTensor(indices=indices, values=values, dense_shape=prev_shape)
+
+        _ = tf.sparse.reshape(sp_input, new_shape, name="sparse_reshape")
+        compare_tf_with_tvm(
+            [indices_np, values_np, prev_shape_np, new_shape_np],
+            [indices.name, values.name, prev_shape.name, new_shape.name],
+            ["sparse_reshape:0", "sparse_reshape:1", "sparse_reshape/Identity:0"],
+            mode="vm",
+        )
+
+
+@pytest.mark.parametrize(
+    "sparse_indices_np, sparse_values_np, prev_shape_np, new_shape_np",
+    [
+        (
+            np.ones((0, 1), dtype=np.int64),
+            np.array([], dtype=np.int64),
+            np.array([4], dtype=np.int64),
+            np.array([2, -1], dtype=np.int64),
+        ),
+        (
+            np.ones((0, 1), dtype=np.int64),
+            np.array([], dtype=np.int64),
+            np.array([4], dtype=np.int64),
+            np.array([2, 2], dtype=np.int64),
+        ),
+        (
+            np.ones((0, 2), dtype=np.int64),
+            np.array([], dtype=np.int64),
+            np.array([3, 6], dtype=np.int64),
+            np.array([-1, 2], dtype=np.int64),
+        ),
+        (
+            np.array([[0, 0, 0], [0, 0, 1], [0, 1, 0], [1, 0, 0], [1, 2, 3]], dtype=np.int64),
+            np.array([7, 5, 6, 3, 9], dtype=np.int64),
+            np.array([2, 3, 6], dtype=np.int64),
+            np.array([-1, 9], dtype=np.int64),
+        ),
+        (
+            np.array(
+                [
+                    [0, 0, 0, 0, 0],
+                    [0, 0, 1, 2, 3],
+                    [0, 1, 0, 3, 5],
+                    [1, 0, 0, 4, 6],
+                    [1, 2, 3, 6, 8],
+                ],
+                dtype=np.int64,
+            ),
+            np.array([7, 5, 6, 3, 9], dtype=np.int64),
+            np.array([2, 3, 6, 7, 9], dtype=np.int64),
+            np.array([9, -1, 7], dtype=np.int64),
+        ),
+        (
+            np.array([[0, 0], [0, 1], [3, 4], [4, 3], [7, 3]], dtype=np.int64),
+            np.array([7, 5, 6, 3, 9], dtype=np.int64),
+            np.array([9, 4], dtype=np.int64),
+            np.array([-1], dtype=np.int64),
+        ),
+        (
+            np.array([[0], [5], [10], [20], [24]], dtype=np.int64),
+            np.array([7, 5, 6, 3, 9], dtype=np.int64),
+            np.array([25], dtype=np.int64),
+            np.array([5, 5], dtype=np.int64),
+        ),
+        (
+            np.array([[0, 100], [200, 100], [300, 400], [50, 20], [400, 50]], dtype=np.int64),
+            np.array([7, 5, 6, 3, 9], dtype=np.int64),
+            np.array([500, 20], dtype=np.int64),
+            np.array([500, 20], dtype=np.int64),
+        ),
+        (
+            np.array([[0, 100], [200, 100], [300, 400], [50, 20], [400, 50]], dtype=np.int64),
+            np.array([7, 5, 6, 3, 9], dtype=np.int64),
+            np.array([500, 20], dtype=np.int64),
+            np.array([500, -1], dtype=np.int64),
+        ),
+        (
+            np.array([[0, 100], [200, 100], [300, 400], [50, 20], [400, 50]], dtype=np.int64),
+            np.array([7, 5, 6, 3, 9], dtype=np.int64),
+            np.array([500, 20], dtype=np.int64),
+            np.array([250, 40], dtype=np.int64),
+        ),
+    ],
+)
+@pytest.mark.parametrize("use_dyn", [True, False])
+def test_forward_sparse_reshape(
+    sparse_indices_np, sparse_values_np, prev_shape_np, new_shape_np, use_dyn
+):
+    """sparse_reshape op test"""
+    ###################################################################
+    #
+    # Creating a SparseTensor requires three inputs, as below:
+    #    SparseTensor(indices=[[0, 0], [1, 2]], values=[1, 2], dense_shape=[3, 4])
+    #
+    # The sparse tensor above is represented in dense form as:
+    #    [[1, 0, 0, 0],
+    #     [0, 0, 2, 0],
+    #     [0, 0, 0, 0]]
+    #
+    # ------------------------------------------------------------------
+    _test_sparse_reshape(sparse_indices_np, sparse_values_np, prev_shape_np, new_shape_np, use_dyn)
+
+
 # tensorflow.compat.v1.sparse_to_dense
 # ---------------
 def _test_sparse_to_dense(sparse_indices, sparse_values, default_value, output_shape):
diff --git a/tests/python/relay/test_op_level3.py b/tests/python/relay/test_op_level3.py
index ee55b532218d..c9ed975c3b9b 100644
--- a/tests/python/relay/test_op_level3.py
+++ b/tests/python/relay/test_op_level3.py
@@ -1311,6 +1311,229 @@ def verify_sparse_to_dense(sparse_indices, sparse_values, default_value, output_
     # verify_sparse_to_dense([[[[0, 1, 4], [0, 2, 4]]]], [[[[3.1, 3.1, 3.1]]]], 3.5, [5], [3.1, 3.1, 3.5, 3.5, 3.1])
 
 
+@tvm.testing.uses_gpu
+@pytest.mark.parametrize(
+    "sparse_indices_np, sparse_values_np, prev_shape_np, new_shape_np",
+    [
+        (
+            np.array([[0, 0, 0], [0, 0, 1], [0, 1, 0], [1, 0, 0], [1, 2, 3]], dtype=np.int32),
+            np.array([7, 5, 6, 3, 9], dtype=np.int32),
+            np.array([2, 3, 6], dtype=np.int32),
+            np.array([9, -1], dtype=np.int32),
+        ),
+        (
+            np.array(
+                [[0, 0, 0, 0], [0, 0, 1, 2], [0, 1, 0, 3], [1, 0, 0, 4], [1, 2, 3, 6]],
+                dtype=np.int64,
+            ),
+            np.array([7, 5, 6, 3, 9], dtype=np.int64),
+            np.array([2, 3, 6, 7], dtype=np.int64),
+            np.array([9, -1, 7], dtype=np.int64),
+        ),
+        (
+            np.array(
+                [
+                    [0, 0, 0, 0, 0],
+                    [0, 0, 1, 2, 3],
+                    [0, 1, 0, 3, 5],
+                    [1, 0, 0, 4, 6],
+                    [1, 2, 3, 6, 8],
+                ],
+                dtype=np.int64,
+            ),
+            np.array([7, 5, 6, 3, 9], dtype=np.int64),
+            np.array([2, 3, 6, 7, 9], dtype=np.int64),
+            np.array([9, -1, 7], dtype=np.int64),
+        ),
+        (
+            np.array([[0, 0], [0, 1], [3, 4], [4, 3], [7, 3]], dtype=np.int32),
+            np.array([7, 5, 6, 3, 9], dtype=np.int32),
+            np.array([9, 4], dtype=np.int32),
+            np.array([2, -1, 6], dtype=np.int32),
+        ),
+        (
+            np.array([[0, 0], [0, 1], [3, 4], [4, 3], [7, 3]], dtype=np.int64),
+            np.array([7, 5, 6, 3, 9], dtype=np.int64),
+            np.array([9, 4], dtype=np.int64),
+            np.array([-1], dtype=np.int64),
+        ),
+        (
+            np.array([[0], [5], [10], [20], [24]], dtype=np.int32),
+            np.array([7, 5, 6, 3, 9], dtype=np.int32),
+            np.array([25], dtype=np.int32),
+            np.array([5, 5], dtype=np.int32),
+        ),
+        (
+            np.array([[0, 100], [200, 100], [300, 400], [50, 20], [400, 50]], dtype=np.int64),
+            np.array([7, 5, 6, 3, 9], dtype=np.int64),
+            np.array([500, 20], dtype=np.int64),
+            np.array([500, 20], dtype=np.int64),
+        ),
+        (
+            np.array([[0, 100], [200, 100], [300, 400], [50, 20], [400, 50]], dtype=np.int32),
+            np.array([7, 5, 6, 3, 9], dtype=np.int32),
+            np.array([500, 20], dtype=np.int32),
+            np.array([500, -1], dtype=np.int32),
+        ),
+        (
+            np.array([[0, 100], [200, 100], [300, 400], [50, 20], [400, 50]], dtype=np.int64),
+            np.array([7, 5, 6, 3, 9], dtype=np.int64),
+            np.array([500, 20], dtype=np.int64),
+            np.array([250, 40], dtype=np.int64),
+        ),
+        (
+            np.ones((0, 1), dtype=np.int32),
+            np.array([], dtype=np.int32),
+            np.array([4], dtype=np.int32),
+            np.array([2, -1], dtype=np.int32),
+        ),
+        (
+            np.ones((0, 1), dtype=np.int64),
+            np.array([], dtype=np.int64),
+            np.array([4], dtype=np.int64),
+            np.array([2, 2], dtype=np.int64),
+        ),
+        (
+            np.ones((0, 2), dtype=np.int32),
+            np.array([], dtype=np.int32),
+            np.array([3, 6], dtype=np.int32),
+            np.array([-1, 2], dtype=np.int32),
+        ),
+    ],
+)
+@pytest.mark.parametrize("use_dyn", [True, False])
+def test_sparse_reshape(sparse_indices_np, sparse_values_np, prev_shape_np, new_shape_np, use_dyn):
+    def ref_sparse_reshape(
+        sparse_indices: np.ndarray,
+        prev_shape: np.ndarray,
+        new_shape: np.ndarray,
+    ):
+        """
+        This function calculates the expected output of the sparse_reshape operator
+        given the inputs.
+        """
+
+        new_sparse_indices = np.ones(
+            (sparse_indices.shape[0], new_shape.shape[0]), dtype=sparse_indices.dtype
+        )
+        multipliers = np.ones(prev_shape.shape[0])
+        dividers = np.ones(new_shape.shape[0])
+        total_ele = np.prod(prev_shape)
+        division_total_ele = 1
+        for i in range(new_shape.shape[0]):
+            if new_shape[i] == -1:
+                continue
+            division_total_ele *= new_shape[i]
+        for i in range(prev_shape.shape[0] - 2, -1, -1):
+            multipliers[i] = prev_shape[i + 1] * multipliers[i + 1]
+
+        for i in range(len(new_shape)):
+            if new_shape[i] == -1:
+                new_shape[i] = total_ele // division_total_ele
+
+        if np.array_equal(prev_shape, new_shape):
+            return sparse_indices, prev_shape
+
+        for i in range(new_shape.shape[0] - 2, -1, -1):
+            dividers[i] = new_shape[i + 1] * dividers[i + 1]
+
+        for row_num, sparse_row in enumerate(sparse_indices):
+            flat_idx = 0
+            if len(sparse_indices.shape) != 1:
+                for i, ele in enumerate(sparse_row):
+                    flat_idx += ele * multipliers[i]
+            else:
+                flat_idx += sparse_row
+            if len(new_sparse_indices.shape) != 1:
+                for i in range(new_sparse_indices.shape[1]):
+                    new_sparse_indices[row_num][i] = flat_idx // dividers[i]
+                    flat_idx = flat_idx % dividers[i]
+            else:
+                new_sparse_indices[row_num] = flat_idx
+
+        return new_sparse_indices, new_shape
+
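+    # e.g. ref_sparse_reshape([[1, 2, 3]], [2, 3, 6], [9, -1]):
+    #   flat index = 1 * 18 + 2 * 6 + 3 = 33 -> [33 // 4, 33 % 4] = [8, 1],
+    #   and the -1 resolves to 36 // 9 = 4, so new_shape = [9, 4]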
+ """ + if use_dyn: + sparse_indices = relay.var( + "sparse_indices", + shape=[relay.Any(), relay.Any()], + dtype=str(sparse_indices_np.dtype), + ) + prev_shape = relay.var( + "prev_shape", + shape=[relay.Any()], + dtype=str(prev_shape_np.dtype), + ) + new_shape = relay.var( + "new_shape", + shape=[relay.Any()], + dtype=str(new_shape_np.dtype), + ) + else: + sparse_indices = relay.var( + "sparse_indices", + relay.TensorType(sparse_indices_np.shape, str(sparse_indices_np.dtype)), + ) + prev_shape = relay.var( + "prev_shape", relay.TensorType(prev_shape_np.shape, str(prev_shape_np.dtype)) + ) + new_shape = relay.var( + "new_shape", relay.TensorType(new_shape_np.shape, str(new_shape_np.dtype)) + ) + z = relay.op.sparse_reshape(sparse_indices, prev_shape, new_shape).astuple() + + func = relay.Function([sparse_indices, prev_shape, new_shape], z) + + ref_res = ref_sparse_reshape(sparse_indices_np, prev_shape_np, new_shape_np) + outputs = run_infer_type(z) + new_sparse_indices_infer_type, new_shape_infer_type = ( + outputs.checked_type.fields[0].dtype, + outputs.checked_type.fields[1].dtype, + ) + + assert new_sparse_indices_infer_type == sparse_indices_np.dtype + assert new_shape_infer_type == new_shape_np.dtype + verify_func( + func, + [sparse_indices_np, prev_shape_np, new_shape_np], + ref_res, + ) + + verify_sparse_reshape( + sparse_indices_np, + sparse_values_np, + prev_shape_np, + new_shape_np, + ) + + +def verify_func(func, data, ref_res, target_ctx=tvm.testing.enabled_targets()): + assert isinstance(data, list) + for target, ctx in target_ctx: + for kind in ["vm"]: + mod = tvm.ir.IRModule.from_expr(func) + intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target) + op_res = intrp.evaluate()(*data) + if isinstance(op_res, tvm.runtime.container.ADT): + assert len(op_res) == len( + ref_res + ), "Outputs from TVM and Python implementation must be equal " + + for op_result, ref_result in zip(op_res, ref_res): + tvm.testing.assert_allclose(op_result.asnumpy(), ref_result, rtol=1e-5) + else: + tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5) + relay.backend.compile_engine.get().clear() + + @tvm.testing.uses_gpu def test_adv_index(): def verify_adv_index(data_shape, index_shapes):