Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SparseFillEmptyRows Op #7442

Merged
merged 32 commits into from
Feb 17, 2021
Merged
Show file tree
Hide file tree
Changes from 31 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions python/tvm/relay/frontend/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -998,6 +998,30 @@ def _impl(inputs, attr, params, mod):
return _impl


def _sparse_fill_empty_rows():
def _impl(inputs, attr, params, mod):
assert len(inputs) == 4, "There should be 4 input tensors"
sparse_indices = inputs[0]
sparse_values = inputs[1]
sparse_indices_num_cols = _infer_shape(sparse_indices, mod)[1]
first_column = _op.split(sparse_indices, sparse_indices_num_cols, axis=1)[0]
sorted_indices = _op.argsort(_op.squeeze(first_column))
sorted_sparse_indices = _op.take(sparse_indices, sorted_indices, axis=0)
sorted_sparse_values = _op.take(sparse_values, sorted_indices, axis=0)
new_sparse_indices, new_sparse_values, empty_row_indicator = _op.sparse_fill_empty_rows(
sorted_sparse_indices, sorted_sparse_values, inputs[2], inputs[3]
)

return _expr.TupleWrapper(
_expr.Tuple(
[new_sparse_indices, new_sparse_values, _op.cast(empty_row_indicator, dtype="bool")]
),
3,
)

return _impl


def _identity():
def _impl(inputs, attr, params, mod):
return inputs[0]
Expand Down Expand Up @@ -2447,6 +2471,7 @@ def _impl(inputs, attr, params, mod):
"SpaceToDepth": _space_to_depth(),
"SparseToDense": _sparse_to_dense(),
"SparseTensorDenseMatMul": _sparse_tensor_dense_matmul(),
"SparseFillEmptyRows": _sparse_fill_empty_rows(),
"Split": _split(False),
"SplitV": _split(True),
"Sqrt": AttrCvt("sqrt"),
Expand Down
63 changes: 62 additions & 1 deletion python/tvm/relay/op/_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@
# specific language governing permissions and limitations
# under the License.
"""Backend compiler related feature registration"""
# pylint: disable=invalid-name,unused-argument, len-as-condition, too-many-nested-blocks, too-many-local-variables, too-many-arguments
# pylint: disable=invalid-name,unused-argument, len-as-condition, too-many-nested-blocks,
# pylint: disable=too-many-local-variables, too-many-arguments, no-else-return

from __future__ import absolute_import
import tvm
from tvm import te
Expand Down Expand Up @@ -94,6 +96,24 @@ def compute_scatter(attrs, inputs, output_type):

_reg.register_strategy("scatter", strategy.scatter_strategy)

# sparse_fill_empty_rows
@_reg.register_compute("sparse_fill_empty_rows")
def compute_sparse_fill_empty_rows(attrs, inputs, output_type):
"""Compute definition of sparse_fill_empty_rows"""

return topi.sparse_fill_empty_rows(
inputs[0],
inputs[1],
inputs[2],
inputs[3],
output_type.fields[0].shape,
output_type.fields[1].shape,
output_type.fields[2].shape,
)


_reg.register_strategy("sparse_fill_empty_rows", strategy.sparse_fill_empty_rows_strategy)

# scatter_add
@_reg.register_compute("scatter_add")
def compute_scatter_add(attrs, inputs, output_type):
Expand Down Expand Up @@ -445,6 +465,47 @@ def argwhere_shape_func(attrs, inputs, out_ndims):
_reg.register_shape_func("scatter_add", False, elemwise_shape_func)


@script
def _sparse_fill_empty_rows_shape_func(sparse_indices, dense_shape):

new_sparse_indices_shape = output_tensor((2,), "int64")
new_sparse_values_shape = output_tensor((1,), "int64")
empty_row_indicator_shape = output_tensor((1,), "int64")
num_dense_rows = int64(dense_shape[0])

if int64(sparse_indices.shape[0]) == int64(0): # Handle Empty Case
# Total rows will equal dense_shape[0]
new_sparse_indices_shape[0] = num_dense_rows
new_sparse_indices_shape[1] = int64(sparse_indices.shape[1])
new_sparse_values_shape[0] = num_dense_rows
empty_row_indicator_shape[0] = num_dense_rows
return (new_sparse_indices_shape, new_sparse_values_shape, empty_row_indicator_shape)

else:
count = int64(sparse_indices.shape[0]) # Add count of all rows already in sparse_indices
for i in range(1, int64(sparse_indices.shape[0])):
index = int64(sparse_indices[i, 0])
prev_index = int64(sparse_indices[i - 1, 0] + 1)

if index > prev_index:
count += index - prev_index # Add count of all rows between two consecutive indices

count += int64(sparse_indices[0, 0]) # Add count from 0 to first row id in sparse_indices
count += int64(
num_dense_rows - 1 - sparse_indices[sparse_indices.shape[0] - 1, 0]
) # Add count from last row id to dense_shape - 1
new_sparse_indices_shape[0] = int64(count)
new_sparse_indices_shape[1] = int64(sparse_indices.shape[1])
new_sparse_values_shape[0] = int64(count)
empty_row_indicator_shape[0] = num_dense_rows
return (new_sparse_indices_shape, new_sparse_values_shape, empty_row_indicator_shape)


@_reg.register_shape_func("sparse_fill_empty_rows", True)
def sparse_fill_empty_rows_func(attrs, inputs, _):
return _sparse_fill_empty_rows_shape_func(inputs[0], inputs[2])


@script
def _layout_transform_shape_func(
data_shape, out_layout_len, dst_equal_list, dst_mul_list, dst_div_list, dst_mix_list
Expand Down
29 changes: 29 additions & 0 deletions python/tvm/relay/op/strategy/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -1055,6 +1055,35 @@ def roi_align_strategy(attrs, inputs, out_type, target):
return strategy


# sparse_fill_empty_rows
@override_native_generic_func("sparse_fill_empty_rows_strategy")
def sparse_fill_empty_rows_strategy(attrs, outs, out_type, target):
strategy = _op.OpStrategy()
strategy.add_implementation(
wrap_compute_sparse_fill_empty_rows(topi.sparse_fill_empty_rows),
wrap_topi_schedule(topi.generic.schedule_sparse_fill_empty_rows),
name="sparse_fill_empty_rows.generic",
)
return strategy


def wrap_compute_sparse_fill_empty_rows(topi_compute):
"""Wrap sparse_fill_empty_rows compute"""

def _compute_sparse_fill_empty_rows(attrs, inputs, output_type):
return topi_compute(
inputs[0],
inputs[1],
inputs[2],
inputs[3],
output_type.fields[0].shape,
output_type.fields[1].shape,
output_type.fields[2].shape,
)

return _compute_sparse_fill_empty_rows


# roi_pool
@generic_func
def schedule_roi_pool(attrs, outs, target):
Expand Down
67 changes: 67 additions & 0 deletions python/tvm/relay/op/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -1322,6 +1322,73 @@ def adv_index(inputs):
return _make.adv_index(Tuple(inputs))


def sparse_fill_empty_rows(sparse_indices, sparse_values, dense_shape, default_value):
codeislife99 marked this conversation as resolved.
Show resolved Hide resolved
"""
Fill rows in a sparse matrix that do no contain any values. Values are placed in the first
column of empty rows. The sparse array is in COO format.
It returns a TupleWrapper with 3 outputs
Parameters
----------
sparse_indices : relay.Expr
A 2-D int64 tensor[N, ndims] of integers containing location of sparse values, where N is
the number of sparse values and n_dim is the number of dimensions of the dense_shape.
The first column of this relay parameter must be sorted in ascending order.
sparse_values : relay.Expr
A 1-D int64 tensor[N] containing the sparse values for the sparse indices.
dense_shape : relay.Expr
A 1-D int64 tensor[ndims] which contains shape of the dense output tensor.
default_value : relay.Expr
A 1-D tensor[1] containing the default value for the remaining locations.
Returns
-------
new_sparse_indices : relay.Expr
A 2-D int64 tensor[?, ndims] of integers containing location of new sparse
indices. The first column outputs must be sorted in ascending order.
new_sparse_values : relay.Expr
A 1-D int64 tensor[?] containing the sparse values for the sparse indices.
empty_row_indicator : relay.Expr
A 1-D int64 tensor[dense_shape[0]] filled with zeros and ones
indicating whether the particular row is empty or full respectively

Note
----
This op exactly follows the documentation here:
https://www.tensorflow.org/api_docs/python/tf/sparse/fill_empty_rows
codeislife99 marked this conversation as resolved.
Show resolved Hide resolved
There are two exceptions:
1. Input Sparse Indices are expected to be in row-major order.
2. Empty Row Indicator has int64 output type with 1(for True) and 0(for False).

Examples
-------
.. code-block:: python
sparse_indices = [[0, 1],
[0, 3],
[2, 0],
[3, 1]]
sparse_values = [1, 2, 3, 4]
default_value = [10]
dense_shape = [5, 6]
new_sparse_indices, empty_row_indicator, new_sparse_values, slice_element_index =
relay.sparse_fill_empty_rows(
sparse_indices,
sparse_values,
default_value,
dense_shape)
new_sparse_indices = [[0, 1],
[0, 3],
[1, 0],
[2, 0],
[3, 1],
[4, 0]]
empty_row_indicator = [0, 1, 0, 0, 1]
new_sparse_values = [1, 2, 10, 3, 4, 10]

"""
return TupleWrapper(
_make.sparse_fill_empty_rows(sparse_indices, sparse_values, dense_shape, default_value), 3
)


def cumsum(data, axis=None, dtype=None, exclusive=None):
"""Numpy style cumsum op. Return the cumulative inclusive sum of the elements along
a given axis.
Expand Down
1 change: 1 addition & 0 deletions python/tvm/topi/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
from .broadcast import *
from .sort import *
from .scatter import *
from .sparse_fill_empty_rows import *
from .scatter_add import *
from .argwhere import *
from .cumsum import *
Expand Down
4 changes: 4 additions & 0 deletions python/tvm/topi/generic/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,3 +66,7 @@ def schedule_scatter_add(outs):
The computation schedule for the op.
"""
return _default_schedule(outs, False)


def schedule_sparse_fill_empty_rows(outs):
return _default_schedule(outs, False)
109 changes: 109 additions & 0 deletions python/tvm/topi/sparse_fill_empty_rows.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHnew_sparse_indices WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

# pylint: disable=no-else-return, too-many-locals, too-many-arguments, too-many-branches
# pylint: disable=undefined-variable, invalid-name
codeislife99 marked this conversation as resolved.
Show resolved Hide resolved
"""SparseFillEmptyRows operator"""
from ..te import hybrid


@hybrid.script
def _sparse_fill_empty_rows(
sparse_indices,
sparse_values,
dense_shape,
default_value,
new_sparse_indices_shape,
new_sparse_values_shape,
empty_row_indicator_shape,
):
default_value_ = int64(default_value[0])
new_sparse_indices = output_tensor(new_sparse_indices_shape, "int64")
codeislife99 marked this conversation as resolved.
Show resolved Hide resolved
new_sparse_values = output_tensor(new_sparse_values_shape, "int64")
empty_row_indicator = output_tensor(empty_row_indicator_shape, "int64")
new_sparse_indices_row_id = 0

if int64(sparse_indices.shape[0]) == int64(0): # Handle Empty Case
# Fill all rows with default values
for i in range(0, int64(new_sparse_indices_shape[0])):
new_sparse_indices[i, 0] = int64(i)
new_sparse_values[i] = default_value_
empty_row_indicator[i] = int64(1)
for k in range(1, int64(new_sparse_indices_shape[1])):
new_sparse_indices[i, k] = int64(0)

return (new_sparse_indices, new_sparse_values, empty_row_indicator)

else:
# Iterate through sparse_indices and add rows if/when required
for i in range(0, int64(sparse_indices.shape[0])):
if i == 0:
prev_row_id = int64(0)
else:
prev_row_id = int64(sparse_indices[i - 1, 0] + 1)
row_id = int64(sparse_indices[i, 0])

# Since input is in row-major order, add rows between prev_row_id and row_id
for j in range(prev_row_id, row_id):
new_sparse_indices[new_sparse_indices_row_id, 0] = int64(j)
for k in range(1, int64(new_sparse_indices_shape[1])):
new_sparse_indices[new_sparse_indices_row_id, k] = int64(0)
empty_row_indicator[prev_row_id] = int64(1)
new_sparse_values[new_sparse_indices_row_id] = default_value_
new_sparse_indices_row_id += 1

# Add current element to output
new_sparse_indices[new_sparse_indices_row_id, 0] = row_id
for k in range(1, int64(new_sparse_indices_shape[1])):
new_sparse_indices[new_sparse_indices_row_id, k] = int64(sparse_indices[i, k])
new_sparse_values[new_sparse_indices_row_id] = int64(sparse_values[i])
empty_row_indicator[row_id] = int64(0)
new_sparse_indices_row_id += 1

# Add rows with default value if last row id of sparse_indices is not dense_shape[0] - 1
for i in range(
int64(sparse_indices[sparse_indices.shape[0] - 1, 0] + 1), int64(dense_shape[0])
codeislife99 marked this conversation as resolved.
Show resolved Hide resolved
):

new_sparse_indices[new_sparse_indices_row_id, 0] = int64(i)
for k in range(1, int64(new_sparse_indices_shape[1])):
new_sparse_indices[new_sparse_indices_row_id, k] = int64(0)
empty_row_indicator[i] = int64(1)
new_sparse_values[new_sparse_indices_row_id] = default_value_
new_sparse_indices_row_id += 1

return (new_sparse_indices, new_sparse_values, empty_row_indicator)


def sparse_fill_empty_rows(
sparse_indices,
sparse_values,
dense_shape,
default_value,
new_sparse_indices_shape,
new_sparse_values_shape,
empty_row_indicator_shape,
):
return _sparse_fill_empty_rows(
sparse_indices,
sparse_values,
dense_shape,
default_value,
new_sparse_indices_shape,
new_sparse_values_shape,
empty_row_indicator_shape,
)
Loading