Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SparseReshape Op #7477

Merged
merged 27 commits into from
Feb 26, 2021
Merged
Show file tree
Hide file tree
Changes from 24 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions python/tvm/relay/frontend/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -1157,6 +1157,15 @@ def _impl(inputs, attr, params, mod):
return _impl


def _sparse_reshape():
def _impl(inputs, attr, params, mod):
assert len(inputs) == 3, "There should be 3 input tensors"
new_indices, new_shape = get_relay_op("sparse_reshape")(inputs[0], inputs[1], inputs[2])
return _expr.TupleWrapper(_expr.Tuple([new_indices, new_shape]), 2)

return _impl


def _identity():
def _impl(inputs, attr, params, mod):
return inputs[0]
Expand Down Expand Up @@ -2626,6 +2635,7 @@ def _impl(inputs, attr, params, mod):
"SparseToDense": _sparse_to_dense(),
"SparseTensorDenseMatMul": _sparse_tensor_dense_matmul(),
"SparseFillEmptyRows": _sparse_fill_empty_rows(),
"SparseReshape": _sparse_reshape(),
"Split": _split(False),
"SplitV": _split(True),
"Sqrt": AttrCvt("sqrt"),
Expand Down
35 changes: 35 additions & 0 deletions python/tvm/relay/op/_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@
_reg.register_injective_schedule("matrix_set_diag")
_reg.register_injective_schedule("adv_index")


# concatenate
_reg.register_schedule("concatenate", strategy.schedule_concatenate)

Expand Down Expand Up @@ -114,6 +115,22 @@ def compute_sparse_fill_empty_rows(attrs, inputs, output_type):

_reg.register_strategy("sparse_fill_empty_rows", strategy.sparse_fill_empty_rows_strategy)

# sparse_reshape
@_reg.register_compute("sparse_reshape")
def compute_reshape(attrs, inputs, output_type):
"""Compute definition of sparse_reshape"""

return topi.sparse_reshape(
inputs[0],
inputs[1],
inputs[2],
output_type.fields[0].shape,
output_type.fields[1].shape,
)


_reg.register_strategy("sparse_reshape", strategy.sparse_reshape_strategy)

# scatter_add
@_reg.register_compute("scatter_add")
def compute_scatter_add(attrs, inputs, output_type):
Expand Down Expand Up @@ -506,6 +523,24 @@ def sparse_fill_empty_rows_func(attrs, inputs, _):
return _sparse_fill_empty_rows_shape_func(inputs[0], inputs[2])


@script
def _sparse_reshape_shape_func(sparse_indices_shape, prev_shape_shape, new_shape_shape):
indices_shape = output_tensor((2,), "int64")
indices_shape[0] = int64(sparse_indices_shape[0])
indices_shape[1] = int64(new_shape_shape[0])
shape_tensor = output_tensor((1,), "int64")
shape_tensor[0] = int64(new_shape_shape[0])
return (indices_shape, shape_tensor)


@_reg.register_shape_func("sparse_reshape", False)
def sparse_reshape_shape_func(attrs, inputs, _):
"""
Shape func for sparse_reshape.
"""
return _sparse_reshape_shape_func(inputs[0], inputs[1], inputs[2])


@script
def _layout_transform_shape_func(
data_shape, out_layout_len, dst_equal_list, dst_mul_list, dst_div_list, dst_mix_list
Expand Down
11 changes: 11 additions & 0 deletions python/tvm/relay/op/strategy/cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -764,6 +764,17 @@ def sparse_dense_strategy_cuda(attrs, inputs, out_type, target):
return strategy


@sparse_reshape_strategy.register(["cuda", "gpu"])
def sparse_reshape_strategy_cuda(attrs, inputs, out_type, target):
strategy = _op.OpStrategy()
strategy.add_implementation(
wrap_compute_sparse_reshape(topi.cuda.sparse_reshape),
wrap_topi_schedule(topi.generic.schedule_extern),
name="sparse_reshape.cuda",
)
return strategy


@sparse_dense_padded_strategy.register(["cuda", "gpu"])
def sparse_dense_padded_strategy_cuda(attrs, inputs, out_type, target):
"""sparse dense cuda strategy"""
Expand Down
27 changes: 27 additions & 0 deletions python/tvm/relay/op/strategy/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -1105,6 +1105,33 @@ def _compute_sparse_fill_empty_rows(attrs, inputs, output_type):
return _compute_sparse_fill_empty_rows


# sparse_reshape
@override_native_generic_func("sparse_reshape_strategy")
def sparse_reshape_strategy(attrs, outs, out_type, target):
strategy = _op.OpStrategy()
strategy.add_implementation(
wrap_compute_sparse_reshape(topi.sparse_reshape),
wrap_topi_schedule(topi.generic.schedule_extern),
name="sparse_reshape.generic",
)
return strategy


def wrap_compute_sparse_reshape(topi_compute):
"""Wrap sparse_reshape compute"""

def _compute_sparse_reshape(attrs, inputs, output_type):
return topi_compute(
inputs[0],
inputs[1],
inputs[2],
output_type.fields[0].shape,
output_type.fields[1].shape,
)

return _compute_sparse_reshape


# roi_pool
@generic_func
def schedule_roi_pool(attrs, outs, target):
Expand Down
40 changes: 40 additions & 0 deletions python/tvm/relay/op/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -1410,6 +1410,46 @@ def sparse_fill_empty_rows(sparse_indices, sparse_values, dense_shape, default_v
return Tuple((new_sparse_indices, new_sparse_values, empty_row_indicator))


def sparse_reshape(sparse_indices, prev_shape, new_shape):
"""
Reshape a Sparse Tensor. The sparse array is in COO format.

Parameters
----------
sparse_indices : relay.Expr
A 2-D tensor[N, n_dim] of integers containing location of sparse values, where N is the
number of sparse values and n_dim is the number of dimensions of the dense_shape
prev_shape : relay.Expr
A 1-D tensor containing the previous shape of the dense tensor
new_shape : relay.Expr
A 1-D tensor containing the new shape of the dense tensor
Returns
-------
result: relay.Expr
Output tensor.
Examples
--------
.. code-block:: python
sparse_indices = [[0, 0, 0],
[0, 0, 1],
[0, 1, 0],
[1, 0, 0],
[1, 2, 3]]
prev_shape = [2, 3, 4]
new_shape = [9, -1]
new_sparse_indices, new_shape = relay.sparse_reshape(sparse_indices,
prev_shape,
new_shape)
new_sparse_indices = [[0, 0],
[0, 1],
[1, 2],
[4, 2],
[8, 1]]
new_shape = [9, 4]
"""
return TupleWrapper(_make.sparse_reshape(sparse_indices, prev_shape, new_shape), 2)


def cumsum(data, axis=None, dtype=None, exclusive=None):
"""Numpy style cumsum op. Return the cumulative inclusive sum of the elements along
a given axis.
Expand Down
1 change: 1 addition & 0 deletions python/tvm/topi/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
from .sort import *
from .scatter import *
from .sparse_fill_empty_rows import *
from .sparse_reshape import *
from .scatter_add import *
from .argwhere import *
from .cumsum import *
Expand Down
1 change: 1 addition & 0 deletions python/tvm/topi/cuda/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,3 +58,4 @@
from . import tensorcore_alter_op
from .argwhere import *
from .scan import *
from .sparse_reshape import *
Loading