Skip to content

Commit

Permalink
[Relay/TOPI][TFLite] Implemented MATRIX_SET_DIAG Operator for Relay/T…
Browse files Browse the repository at this point in the history
…OPI and TFLite Frontend. (#6303)

* Corrected docstring error.

* Minor changes.

* Changed MATRIX_SET_DIAG registration from broadcast to injective.
  • Loading branch information
jainris authored Aug 27, 2020
1 parent f6d3cee commit 082f27e
Show file tree
Hide file tree
Showing 12 changed files with 377 additions and 0 deletions.
29 changes: 29 additions & 0 deletions include/tvm/topi/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -1511,6 +1511,35 @@ inline Tensor sparse_to_dense(const Tensor& sparse_indices, const Array<Integer>
name, tag);
}

/*!
* \brief Returns a tensor with the diagonal of input tensor replaced with the provided diagonal.
* \param input input tensor.
* \param diagonal values to be filled in the diagonal.
* \param name output tensor name.
* \param tag output tensor tag.
* \return new tensor with given diagonal values.
*/
inline Tensor matrix_set_diag(const Tensor& input, const Tensor& diagonal,
const std::string name = "T_matrix_set_diag",
const std::string tag = kInjective) {
size_t ndim = input->shape.size() - 1;

return compute(
input->shape,
[&](const Array<Var>& iter_vars) {
auto get_diag = [&]() {
Array<PrimExpr> diagonal_indices;
for (size_t i = 0; i < ndim; i++) {
diagonal_indices.push_back(iter_vars[i]);
}
return diagonal(diagonal_indices);
};
return if_then_else((PrimExpr)iter_vars[ndim] == iter_vars[ndim - 1], get_diag(),
input(iter_vars));
},
name, tag);
}

} // namespace topi
} // namespace tvm
#endif // TVM_TOPI_TRANSFORM_H_
28 changes: 28 additions & 0 deletions python/tvm/relay/frontend/tflite.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@ def __init__(self, model, subgraph, exp_tab):
'LOGICAL_NOT': self.convert_logical_not,
'LOGICAL_OR': self.convert_logical_or,
'LOGISTIC': self.convert_logistic,
'MATRIX_SET_DIAG': self.convert_matrix_set_diag,
'MAX_POOL_2D': self.convert_max_pool2d,
'MAXIMUM': self.convert_maximum,
'MEAN': self.convert_reduce_mean,
Expand Down Expand Up @@ -2989,6 +2990,33 @@ def convert_reverse_v2(self, op):
out = _op.reverse(input_expr, axis)
return out

def convert_matrix_set_diag(self, op):
"""Convert TFLite MATRIX_SET_DIAG"""

input_tensors = self.get_input_tensors(op)
assert len(input_tensors) == 2, "input tensor's length should be 2"

assert input_tensors[0].tensor.Type() == input_tensors[1].tensor.Type(), \
"input and diagonal should be the same type of tensors"

if input_tensors[0].qnn_params:
# Check that input and output tensor have same qnn params.
output_tensors = self.get_output_tensors(op)
assert self.has_same_qnn_params(input_tensors[0], output_tensors[0]), \
"TFLite MATRIX_SET_DIAG requires input and output tensors' \
scale and zero points to be equal"

# Check that input and diagonal tensor have same qnn params.
assert self.has_same_qnn_params(input_tensors[0], input_tensors[1]), \
"TFLite MATRIX_SET_DIAG requires input and diagonal tensors' \
scale and zero points to be equal"

input_expr = self.get_tensor_expr(input_tensors[0])
diagonal_expr = self.get_tensor_expr(input_tensors[1])

out = _op.matrix_set_diag(input_expr, diagonal_expr)
return out


def get_expr(self, input_tensor_idx):
return self.exp_tab.get_expr(get_tensor_name(self.subgraph, input_tensor_idx))
Expand Down
1 change: 1 addition & 0 deletions python/tvm/relay/op/_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@
_reg.register_reduce_schedule("collapse_sum_to")
_reg.register_injective_schedule("unravel_index")
_reg.register_injective_schedule("sparse_to_dense")
_reg.register_injective_schedule("matrix_set_diag")

# concatenate
_reg.register_schedule("concatenate", strategy.schedule_concatenate)
Expand Down
41 changes: 41 additions & 0 deletions python/tvm/relay/op/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -1167,3 +1167,44 @@ def sparse_to_dense(sparse_indices, output_shape, sparse_values, default_value=0
if default_value == 0:
default_value = const(0)
return _make.sparse_to_dense(sparse_indices, output_shape, sparse_values, default_value)


def matrix_set_diag(data, diagonal):
"""
Returns a tensor with the diagonal of input tensor replaced with the provided diagonal values.
Parameters
----------
data : relay.Expr
Input Tensor.
diagonal : relay.Expr
Values to be filled in the diagonal.
Returns
-------
result : relay.Expr
New tensor with given diagonal values.
Examples
--------
.. code-block:: python
data = [[[7, 7, 7, 7],
[7, 7, 7, 7],
[7, 7, 7, 7]],
[[7, 7, 7, 7],
[7, 7, 7, 7],
[7, 7, 7, 7]]]
diagonal = [[1, 2, 3],
[4, 5, 6]]
relay.matrix_set_diag(input, diagonal) =
[[[1, 7, 7, 7],
[7, 2, 7, 7],
[7, 7, 3, 7]],
[[4, 7, 7, 7],
[7, 5, 7, 7],
[7, 7, 6, 7]]]
"""
return _make.matrix_set_diag(data, diagonal)
1 change: 1 addition & 0 deletions python/tvm/topi/testing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,3 +60,4 @@
get_elemwise_schedule, get_conv2d_nchw_implement, dispatch
from .adaptive_pool_python import adaptive_pool
from .grid_sample_python import affine_grid_python, grid_sample_nchw_python
from .matrix_set_diag import matrix_set_diag
47 changes: 47 additions & 0 deletions python/tvm/topi/testing/matrix_set_diag.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name
"""MatrixSetDiag in Python"""
import numpy as np

def matrix_set_diag(input_np, diagonal):
"""matrix_set_diag operator implemented in numpy.
Returns a numpy array with the diagonal of input array
replaced with the provided diagonal values.
Parameters
----------
input : numpy.ndarray
Input Array.
Shape = [D1, D2, D3, ... , Dn-1 , Dn]
diagonal : numpy.ndarray
Values to be filled in the diagonal.
Shape = [D1, D2, D3, ... , Dn-1]
Returns
-------
result : numpy.ndarray
New Array with given diagonal values.
Shape = [D1, D2, D3, ... , Dn-1 , Dn]
"""
out = np.array(input_np, copy=True)
n = min(input_np.shape[-1], input_np.shape[-2])
for i in range(n):
out[..., i, i] = diagonal[..., i]

return out
40 changes: 40 additions & 0 deletions python/tvm/topi/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -798,3 +798,43 @@ def sparse_to_dense(sparse_indices, output_shape, sparse_values, default_value=0
"""

return cpp.sparse_to_dense(sparse_indices, output_shape, sparse_values, default_value)

def matrix_set_diag(data, diagonal):
"""
Returns a tensor with the diagonal of input tensor replaced with the provided diagonal values.
Parameters
----------
data : relay.Expr
Input Tensor.
diagonal : relay.Expr
Values to be filled in the diagonal.
Returns
-------
result : relay.Expr
New tensor with given diagonal values.
Examples
--------
.. code-block:: python
data = [[[7, 7, 7, 7],
[7, 7, 7, 7],
[7, 7, 7, 7]],
[[7, 7, 7, 7],
[7, 7, 7, 7],
[7, 7, 7, 7]]]
diagonal = [[1, 2, 3],
[4, 5, 6]]
relay.matrix_set_diag(input, diagonal) =
[[[1, 7, 7, 7],
[7, 2, 7, 7],
[7, 7, 3, 7]],
[[4, 7, 7, 7],
[7, 5, 7, 7],
[7, 7, 6, 7]]]
"""
return cpp.matrix_set_diag(data, diagonal)
50 changes: 50 additions & 0 deletions src/relay/op/tensor/transform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3093,5 +3093,55 @@ RELAY_REGISTER_OP("sparse_to_dense")
.set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout)
.set_attr<FTVMCompute>("FTVMCompute", SparseToDenseCompute);

// relay.matrix_set_diag
bool MatrixSetDiagRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const TypeReporter& reporter) {
// `types` contains: [input, diagonal, result]
CHECK_EQ(types.size(), 3);

const auto* input = types[0].as<TensorTypeNode>();
CHECK(input);

const auto* diagonal = types[1].as<TensorTypeNode>();
CHECK(diagonal);

int d_ndims = diagonal->shape.size();
for (int i = 0; i < d_ndims - 1; i++) {
reporter->AssertEQ(input->shape[i], diagonal->shape[i]);
}
auto min_dim = if_then_else(input->shape[d_ndims - 1] >= input->shape[d_ndims],
input->shape[d_ndims], input->shape[d_ndims - 1]);
reporter->Assert(diagonal->shape[d_ndims - 1] >= min_dim);

reporter->Assign(types[2], TensorType(input->shape, input->dtype));
return true;
}

Array<te::Tensor> MatrixSetDiagCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
const Type& out_type) {
return Array<te::Tensor>{topi::matrix_set_diag(inputs[0], inputs[1])};
}

Expr MakeMatrixSetDiag(Expr input, Expr diagonal) {
static const Op& op = Op::Get("matrix_set_diag");
return Call(op, {input, diagonal}, Attrs(), {});
}

TVM_REGISTER_GLOBAL("relay.op._make.matrix_set_diag").set_body_typed(MakeMatrixSetDiag);

RELAY_REGISTER_OP("matrix_set_diag")
.describe(
R"code(Returns a tensor with the diagonal of input tensor replaced with the provided diagonal values.
**input** Input tensor.
**diagonal** Values to be filled in the diagonal.
)code" TVM_ADD_FILELINE)
.set_num_inputs(2)
.add_argument("input", "Tensor", "Input Tensor.")
.add_argument("diagonal", "Tensor", "Values to be filled in the diagonal.")
.set_support_level(10)
.add_type_rel("MatrixSetDiag", MatrixSetDiagRel)
.set_attr<FTVMCompute>("FTVMCompute", MatrixSetDiagCompute)
.set_attr<TOpPattern>("TOpPattern", kInjective);

} // namespace relay
} // namespace tvm
4 changes: 4 additions & 0 deletions src/topi/transform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -176,5 +176,9 @@ TVM_REGISTER_GLOBAL("topi.one_hot").set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = one_hot(args[0], args[1], args[2], depth, axis, dtype);
});

TVM_REGISTER_GLOBAL("topi.matrix_set_diag").set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = matrix_set_diag(args[0], args[1]);
});

} // namespace topi
} // namespace tvm
72 changes: 72 additions & 0 deletions tests/python/frontend/tflite/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -2652,6 +2652,77 @@ def test_forward_reverse_v2():
_test_reverse_v2((5, 6, 4, 2), np.array([2], dtype='int32'), dtype)


#######################################################################
# MATRIX_SET_DIAG
# ---------------

def _test_matrix_set_diag(input_shape, input_type, quantized=False):
""" One iteration of MATRIX_SET_DIAG """
with tf.Graph().as_default():
diagonal_shape = list(input_shape[:-2])
diagonal_shape.append(min(input_shape[-2], input_shape[-1]))

if quantized:
# ignoring input_type as quantized requires uint8
input = np.random.uniform(0, 256, input_shape).astype('uint8')
in_input = tf.placeholder(dtype='float32', shape=input.shape, name="input")
inq_input = tf.quantization.fake_quant_with_min_max_args(
in_input,
min=-100,
max=100,
name="q_input")

diagonal = np.random.uniform(0, 256, diagonal_shape).astype('uint8')
in_diagonal = tf.placeholder(dtype='float32', shape=diagonal.shape, name="diagonal")
inq_diagonal = tf.quantization.fake_quant_with_min_max_args(
in_diagonal,
min=-100,
max=100,
name="q_diagonal")

input_range = {'q_input': (-100, 100), 'q_diagonal': (-100, 100)}

out = array_ops.matrix_set_diag(inq_input, inq_diagonal)
out = tf.quantization.fake_quant_with_min_max_args(
out,
min=-100,
max=100,
name="out")

compare_tflite_with_tvm(
[input, diagonal],
["q_input", "q_diagonal"],
[inq_input, inq_diagonal],
[out],
quantized=True,
input_range=input_range)
else:
input = np.random.uniform(0, 100, input_shape).astype(input_type)
diagonal = np.random.uniform(0, 100, diagonal_shape).astype(input_type)

in_input = tf.placeholder(dtype=input.dtype, shape=input.shape, name="input")
in_diagonal = tf.placeholder(dtype=diagonal.dtype, shape=diagonal.shape, name="diagonal")

out = array_ops.matrix_set_diag(in_input, in_diagonal)

compare_tflite_with_tvm(
[input, diagonal],
["input", "diagonal"],
[in_input, in_diagonal],
[out])

def test_forward_matrix_set_diag():
""" MATRIX_SET_DIAG """
for dtype in [np.float32, np.int32]:
_test_matrix_set_diag((4, 4), dtype)
_test_matrix_set_diag((5, 4, 3, 4), dtype)
_test_matrix_set_diag((4, 4, 2), dtype)

_test_matrix_set_diag((4, 4), np.uint8, quantized=True)
_test_matrix_set_diag((5, 4, 3, 4), np.uint8, quantized=True)
_test_matrix_set_diag((4, 4, 2), np.uint8, quantized=True)


#######################################################################
# Custom Operators
# ----------------
Expand Down Expand Up @@ -3131,6 +3202,7 @@ def test_forward_mediapipe_hand_landmark():
test_forward_arg_min_max()
test_forward_expand_dims()
test_forward_reverse_v2()
test_forward_matrix_set_diag()

# NN
test_forward_convolution()
Expand Down
Loading

0 comments on commit 082f27e

Please sign in to comment.