Sparse reshape op #7125

Closed · wants to merge 38 commits
84 changes: 83 additions & 1 deletion include/tvm/topi/transform.h
@@ -506,7 +506,7 @@ inline Array<Tensor> split(const Tensor& x, Array<PrimExpr> split_indices, int a
begin_ids.push_back(idx);
}

Array<Array<PrimExpr> > out_shapes;
Array<Array<PrimExpr>> out_shapes;
for (size_t i = 0; i < begin_ids.size(); ++i) {
PrimExpr out_axis_size;
if (i == begin_ids.size() - 1) {
@@ -1386,6 +1386,88 @@ inline Array<Tensor> meshgrid(const Array<Tensor>& inputs, const std::string& in
return result;
}

/*!
 * \brief Compute the new sparse indices and the resolved shape after the sparse_reshape operation
 *
 * \param sparse_indices Indices where values of the dense tensor exist
 * \param prev_shape Old shape of the sparse tensor corresponding to sparse_indices
 * \param new_shape Desired shape of the sparse tensor which will correspond to the output
 * \param name The name of the operation
 * \param tag The tag to mark the operation
 *
 * \return An Array of Tensors: the new sparse indices and the output shape with any -1 resolved
 */
inline Array<Tensor> SparseReshape(const Tensor& sparse_indices, const Tensor& prev_shape,
const Tensor& new_shape,
const std::string name = "T_sparse_reshape",
std::string tag = kInjective) {
Array<Tensor> result;
Array<PrimExpr> new_sparse_indices_shape{sparse_indices->shape[0], new_shape->shape[0]};

int new_shape_size = GetConstInt(new_shape->shape[0]);
int prev_shape_size = GetConstInt(prev_shape->shape[0]);
Comment on lines +1407 to +1408 (Contributor):

My main complaint is that this will fail with dynamic input shapes. From what I understand, you expect multiple chained dynamically-shaped sparse ops in the model you're trying to target, so I'm hesitant to merge this because I'm under the impression that this will not solve the larger problem you're trying to solve.

I'd really like to see you either test the model in a branch containing all three of your PRs, or write a unit test with a representative subgraph.
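To illustrate the reviewer's ask, here is a minimal sketch (not part of the PR; all names, shapes, and values are illustrative) of a unit test that chains two sparse_reshape calls in one Relay function. Note it uses fully static shapes, which sidesteps the dynamic-shape concern above; a dynamically-shaped variant would be the more representative subgraph:

import numpy as np
from tvm import relay

# Assumed API from this PR: relay.sparse_reshape returns (new_indices, new_shape).
indices = relay.var("indices", shape=(5, 3), dtype="int32")
prev_shape = relay.const(np.array([2, 3, 4], dtype="int32"))
mid_shape = relay.const(np.array([6, 4], dtype="int32"))
final_shape = relay.const(np.array([2, -1], dtype="int32"))

# Chain two reshapes: [2, 3, 4] -> [6, 4] -> [2, -1] (resolved to [2, 12]).
first = relay.sparse_reshape(indices, prev_shape, mid_shape)
second = relay.sparse_reshape(first[0], mid_shape, final_shape)
func = relay.Function([indices], relay.Tuple([second[0], second[1]]))

indices_np = np.array(
    [[0, 0, 0], [0, 0, 1], [0, 1, 0], [1, 0, 0], [1, 2, 3]], dtype="int32"
)
# res[0] holds the final indices, res[1] the resolved shape.
res = relay.create_executor(kind="graph", target="llvm").evaluate(func)(indices_np)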

std::vector<PrimExpr> multipliers(prev_shape_size, 1);
std::vector<PrimExpr> dividers(new_shape_size, 1);

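// multipliers[i] is the row-major stride of prev_shape axis i; dividers[i] plays the
// same role for new_shape, with any -1 entry replaced by the inferred dimension
// (total_ele / division_total_ele), which neg_shape_val also returns.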
auto neg_shape_val = compute(Array<PrimExpr>{1}, [&](const Array<Var>& indices) {
tvm::PrimExpr total_ele = prev_shape[0];
for (int i = prev_shape_size - 2; i >= 0; --i) {
multipliers[i] = prev_shape[i + 1] * multipliers[i + 1];
total_ele *= prev_shape[i + 1];
}
PrimExpr division_total_ele = 1;
for (int i = 0; i < new_shape_size; ++i) {
division_total_ele *= if_then_else(new_shape[i] != -1, new_shape[i], 1);
}
for (int i = new_shape_size - 2; i >= 0; --i) {
dividers[i] = dividers[i + 1] * if_then_else(new_shape[i + 1] != -1, new_shape[i + 1],
div(total_ele, division_total_ele));
}
return div(total_ele, division_total_ele);
});

result.push_back(compute(
new_sparse_indices_shape,
[&](const Array<Var>& indices) {
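// Flatten the incoming COO index into a linear offset over prev_shape, then
// decompose that offset along new_shape's strides.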
PrimExpr flattened_idx = 0;
if (sparse_indices->shape.size() == 1) {
flattened_idx += sparse_indices[indices[0]];
} else {
for (int k = 0; k < GetConstInt(sparse_indices->shape[1]); k++) {
flattened_idx += (sparse_indices[indices[0]][k] * multipliers[k]);
}
}
Array<PrimExpr> new_sparse_indices;
if (new_shape_size != 1) {
for (int i = 0; i < new_shape_size; i++) {
new_sparse_indices.push_back(floordiv(flattened_idx, dividers[i]));
flattened_idx = floormod(flattened_idx, dividers[i]);
}
PrimExpr ret = -1;

for (int i = 0; i < new_shape_size; i++) {
if (indices.size() == 1) {
return new_sparse_indices[0];
} else {
ret = if_then_else(indices[1] == i, new_sparse_indices[i], ret);
}
}
return ret;
} else {
return flattened_idx;
}
},
name, tag));
result.push_back(compute(
Array<PrimExpr>{new_shape_size},
[&](const Array<Var>& indices) {
PrimExpr ret = new_shape(indices);
ret = if_then_else(ret == -1, neg_shape_val[0], ret);
return ret;
},
name, tag));
return result;
}
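For intuition, a short NumPy reference sketch (not part of the PR; assumes static shapes, 2-D indices, and at most one -1 in new_shape) of the same arithmetic: flatten each COO index with prev_shape's row-major strides, resolve the -1, then unflatten along new_shape:

import numpy as np

def sparse_reshape_ref(sparse_indices, prev_shape, new_shape):
    # Resolve a single -1 in new_shape so the element counts match.
    total = int(np.prod(prev_shape))
    new_shape = np.array(new_shape)
    if (new_shape == -1).any():
        known = int(np.prod(new_shape[new_shape != -1]))
        new_shape[new_shape == -1] = total // known
    # Flatten each COO index with row-major strides, then unflatten.
    flat = np.ravel_multi_index(sparse_indices.T, tuple(prev_shape))
    new_indices = np.stack(np.unravel_index(flat, tuple(new_shape)), axis=1)
    return new_indices, new_shape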
/*!
* \brief Transform the layout according to \p src_layout and \p dst_layout
* \param src the source input.
17 changes: 17 additions & 0 deletions python/tvm/relay/frontend/tensorflow.py
@@ -975,6 +975,22 @@ def _impl(inputs, attr, params, mod):
return _impl


def _sparse_reshape():
def _impl(inputs, attr, params, mod):
assert len(inputs) == 3, "There should be 3 input tensors"

indices_tensor = _infer_value(inputs[0], params, mod).asnumpy()
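# Values are looked up in params by TF node name; this assumes the producing
# node in the graph is literally named "SparseTensor/values".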
values_tensor = params["SparseTensor/values"].asnumpy()
prev_shape_tensor = _infer_value(inputs[1], params, mod).asnumpy()
new_shape = inputs[2]
indices_data = _expr.const(indices_tensor, indices_tensor.dtype)
prev_shape_data = _expr.const(prev_shape_tensor, prev_shape_tensor.dtype)
ret = _op.sparse_reshape(indices_data, prev_shape_data, new_shape).astuple()
return ret, _expr.const(values_tensor, values_tensor.dtype)

return _impl


def _identity():
def _impl(inputs, attr, params, mod):
return inputs[0]
@@ -2423,6 +2439,7 @@ def _impl(inputs, attr, params, mod):
"SpaceToDepth": _space_to_depth(),
"SparseToDense": _sparse_to_dense(),
"SparseTensorDenseMatMul": _sparse_tensor_dense_matmul(),
"SparseReshape": _sparse_reshape(),
"Split": _split(False),
"SplitV": _split(True),
"Sqrt": AttrCvt("sqrt"),
2 changes: 2 additions & 0 deletions python/tvm/relay/op/_transform.py
@@ -63,6 +63,8 @@
_reg.register_injective_schedule("sparse_to_dense")
_reg.register_injective_schedule("matrix_set_diag")
_reg.register_injective_schedule("adv_index")
_reg.register_injective_schedule("sparse_reshape")


# concatenate
_reg.register_schedule("concatenate", strategy.schedule_concatenate)
45 changes: 45 additions & 0 deletions python/tvm/relay/op/transform.py
@@ -1320,3 +1320,48 @@ def adv_index(inputs):
Output tensor.
"""
return _make.adv_index(Tuple(inputs))


def sparse_reshape(sparse_indices, prev_shape, new_shape):
"""
Reshape a Sparse Tensor. The input is expected in COO format (indices, values, shape) rather than CSR.
Comment (Contributor):

Could you note that this function only supports tensors in COO format, not CSR? In other parts of the codebase, we tend to use CSR.

Reply (Contributor Author, codeislife99, Dec 22, 2020):

Can you explain how this convention is different from the sparse_to_dense operator? I could only find that operator as an example of existing representations.

Reply (Contributor):

The convention is the same as sparse_to_dense. However, sparse_dense uses CSR and BSR formats. We should probably add documentation to sparse_to_dense too.
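For context (an illustration, not from the PR), the same 3x4 matrix in COO and CSR form via scipy.sparse:

import numpy as np
from scipy import sparse

dense = np.array([[1, 0, 0, 0],
                  [0, 0, 2, 0],
                  [0, 0, 0, 0]])
coo = sparse.coo_matrix(dense)  # coo.row = [0, 1], coo.col = [0, 2], coo.data = [1, 2]
csr = sparse.csr_matrix(dense)  # csr.indptr = [0, 1, 2, 2], csr.indices = [0, 2], csr.data = [1, 2]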


Parameters
----------
sparse_indices : relay.Expr
A 2-D tensor[N, n_dim] of integers containing the locations of sparse values, where N is
the number of sparse values and n_dim is the number of dimensions of the dense tensor
prev_shape : relay.Expr
A 1-D tensor containing the previous shape of the dense tensor
new_shape : relay.Expr
A 1-D tensor containing the new shape of the dense tensor

Returns
-------
result : relay.TupleWrapper
Tuple of the new sparse indices and the new shape, with any -1 entry resolved.

Examples
--------
.. code-block:: python

sparse_indices = [[0, 0, 0],
[0, 0, 1],
[0, 1, 0],
[1, 0, 0],
[1, 2, 3]]

prev_shape = [2, 3, 4]

new_shape = [9, -1]

new_sparse_indices, new_shape = relay.sparse_reshape(sparse_indices,
prev_shape,
new_shape)
new_sparse_indices = [[0, 0],
[0, 1],
[1, 2],
[4, 2],
[8, 1]]
new_shape = [9, 4]
"""
return TupleWrapper(_make.sparse_reshape(sparse_indices, prev_shape, new_shape), 2)
45 changes: 45 additions & 0 deletions python/tvm/topi/transform.py
@@ -931,3 +931,48 @@ def adv_index(data, indices):
Output tensor
"""
return cpp.adv_index(data, indices)


def sparse_reshape(sparse_indices, prev_shape, new_shape):
"""
Reshape a Sparse Tensor. The input is expected in COO format (indices, values, shape) rather than CSR.

Parameters
----------
sparse_indices : tvm.te.Tensor
A 2-D tensor[N, n_dim] of integers containing the locations of sparse values, where N is
the number of sparse values and n_dim is the number of dimensions of the dense tensor
prev_shape : tvm.te.Tensor
A 1-D tensor containing the previous shape of the dense tensor
new_shape : tvm.te.Tensor
A 1-D tensor containing the new shape of the dense tensor

Returns
-------
result : tvm.te.Tensor
Tuple of the new sparse indices and the new shape, with any -1 entry resolved.

Examples
--------
.. code-block:: python

sparse_indices = [[0, 0, 0],
[0, 0, 1],
[0, 1, 0],
[1, 0, 0],
[1, 2, 3]]

prev_shape = [2, 3, 4]

new_shape = [9, -1]

new_sparse_indices, new_shape = topi.sparse_reshape(sparse_indices,
prev_shape,
new_shape)
new_sparse_indices = [[0, 0],
[0, 1],
[1, 2],
[4, 2],
[8, 1]]
new_shape = [9, 4]
"""
return cpp.sparse_reshape(sparse_indices, prev_shape, new_shape)
41 changes: 41 additions & 0 deletions src/relay/op/tensor/transform.cc
@@ -1553,6 +1553,47 @@ RELAY_REGISTER_OP("meshgrid")
.set_attr<FTVMCompute>("FTVMCompute", MeshgridCompute)
.set_attr<TOpPattern>("TOpPattern", kInjective);

bool SparseReshapeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const TypeReporter& reporter) {
// types: [sparse_indices, prev_shape, new_shape, result]
ICHECK_EQ(types.size(), 4) << "SparseReshapeRel expects 4 types but " << types.size()
<< " provided";
auto sparse_indices = types[0].as<TensorTypeNode>();
auto new_shape = types[2].as<TensorTypeNode>();
Array<PrimExpr> new_sparse_indices_shape{sparse_indices->shape[0], new_shape->shape[0]};
std::vector<Type> fields;
fields.push_back(TensorType(new_sparse_indices_shape, sparse_indices->dtype));
fields.push_back(TensorType(new_shape->shape, new_shape->dtype));
reporter->Assign(types[3], TupleType(Array<Type>(fields)));
return true;
}

Array<te::Tensor> SparseReshapeCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
const Type& out_type) {
ICHECK_EQ(inputs.size(), 3) << "SparseReshapeCompute expects 3 inputs but provided "
<< inputs.size();
return {topi::SparseReshape(inputs[0], inputs[1], inputs[2])};
}

Expr MakeSparseReshape(Expr sparse_indices, Expr prev_shape, Expr new_shape) {
static const Op& op = Op::Get("sparse_reshape");
return Call(op, {sparse_indices, prev_shape, new_shape}, Attrs(), {});
}

TVM_REGISTER_GLOBAL("relay.op._make.sparse_reshape").set_body_typed(MakeSparseReshape);

RELAY_REGISTER_OP("sparse_reshape")
.describe(R"code(Return new sparse indices of the reshaped tensor
)code" TVM_ADD_FILELINE)
.set_num_inputs(3)
.add_argument("sparse_indices", "Tensor", "The first tensor")
.add_argument("prev_shape", "Tensor", "The second tensor")
.add_argument("new_shape", "Tensor", "The third tensor")
.add_type_rel("sparse_reshape", SparseReshapeRel)
.set_attr<TOpPattern>("TOpPattern", kInjective)
.set_support_level(3)
.set_attr<FTVMCompute>("FTVMCompute", SparseReshapeCompute);

// tile operator
TVM_REGISTER_NODE_TYPE(TileAttrs);

70 changes: 70 additions & 0 deletions tests/python/frontend/tensorflow/test_forward.py
@@ -1811,6 +1811,76 @@ def test_forward_sparse_dense_matmul():
)


#######################################################################
# SparseReshape
# ------------


def _test_sparse_reshape(indices_np, values_np, prev_shape_np, new_shape_np, dtype):
with tf.Graph().as_default():
sp_input = tf.sparse.SparseTensor(
indices=indices_np, values=values_np, dense_shape=prev_shape_np
)
new_shape = tf.placeholder(
shape=new_shape_np.shape, dtype=new_shape_np.dtype, name="new_shape"
)

tf.sparse.reshape(sp_input, new_shape, name="sparse_reshape")
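# Compare the SparseReshape op's two outputs (new indices, new shape) plus the
# pass-through values tensor against TF.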
compare_tf_with_tvm(
[new_shape_np],
[new_shape.name],
["sparse_reshape:0", "sparse_reshape:1", "sparse_reshape/Identity:0"],
)


def test_forward_sparse_reshape():
""" sparse_reshape op test"""
###################################################################
#
# In order to create a SparseTensor, it requires 3 input as below:
# SparseTensor(indices=[[0, 0], [1, 2]], values=[1, 2], dense_shape=[3, 4])
#
# Above Sparse can be represented in Dense as below :
# [[1, 0, 0, 0]
# [0, 0, 2, 0]
# [0, 0, 0, 0]]
#
# ------------------------------------------------------------------
sparse_indices_np = np.array(
[[0, 0, 0], [0, 0, 1], [0, 1, 0], [1, 0, 0], [1, 2, 3]], dtype=np.int32
)
sparse_values_np = np.array([7, 5, 6, 3, 9], dtype=np.int32)
prev_shape_np = np.array([2, 3, 6], dtype=np.int32)
new_shape_np = np.array([9, 4], dtype=np.int32)
_test_sparse_reshape(sparse_indices_np, sparse_values_np, prev_shape_np, new_shape_np, "int32")

sparse_indices_np = np.array(
[[0, 0, 0, 0], [0, 0, 1, 2], [0, 1, 0, 3], [1, 0, 0, 4], [1, 2, 3, 6]], dtype=np.int32
)
sparse_values_np = np.array([7, 5, 6, 3, 9], dtype=np.int32)
prev_shape_np = np.array([2, 3, 6, 7], dtype=np.int32)
new_shape_np = np.array([9, -1, 7], dtype=np.int32)
_test_sparse_reshape(sparse_indices_np, sparse_values_np, prev_shape_np, new_shape_np, "int32")

sparse_indices_np = np.array([[0, 0], [0, 1], [3, 4], [4, 3], [7, 3]], dtype=np.int32)
sparse_values_np = np.array([7, 5, 6, 3, 9], dtype=np.int32)
prev_shape_np = np.array([9, 4], dtype=np.int32)
new_shape_np = np.array([2, -1, 6], dtype=np.int32)
_test_sparse_reshape(sparse_indices_np, sparse_values_np, prev_shape_np, new_shape_np, "int32")

sparse_indices_np = np.array([[0, 0], [0, 1], [3, 4], [4, 3], [7, 3]], dtype=np.int32)
sparse_values_np = np.array([7, 5, 6, 3, 9], dtype=np.int32)
prev_shape_np = np.array([9, 4], dtype=np.int32)
new_shape_np = np.array([-1], dtype=np.int32)
_test_sparse_reshape(sparse_indices_np, sparse_values_np, prev_shape_np, new_shape_np, "int32")

sparse_indices_np = np.array([[0], [5], [10], [20], [24]], dtype=np.int32)
sparse_values_np = np.array([7, 5, 6, 3, 9], dtype=np.int32)
prev_shape_np = np.array([25], dtype=np.int32)
new_shape_np = np.array([5, 5], dtype=np.int32)
_test_sparse_reshape(sparse_indices_np, sparse_values_np, prev_shape_np, new_shape_np, "int32")


#######################################################################
# StridedSlice
# ------------