-
Notifications
You must be signed in to change notification settings - Fork 3.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Relay/TOPI][TFLite] Implemented MATRIX_SET_DIAG Operator for Relay/TOPI and TFLite Frontend. #6303
Conversation
jainris
commented
Aug 19, 2020
- Added implementation for MATRIX_SET_DIAG Operator in Relay and TOPI.
- Added tests for MATRIX_SET_DIAG Operator in Relay and TOPI.
- Added implementation for MATRIX_SET_DIAG Operator for TFLite Frontend.
- Added tests for MATRIX_SET_DIAG Operator for TFLite Frontend.
a3bcb34
to
adb905e
Compare
adb905e
to
d9bc1f3
Compare
auto min_dim = if_then_else(input->shape[d_ndims - 1] >= input->shape[d_ndims], | ||
input->shape[d_ndims], input->shape[d_ndims - 1]); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is if_then_else
appropriate here? Could a ? x : y
not be used?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for reviewing.
if_then_else
is needed here because input->shape[i]
is a PrimExpr
, and so a ? x : y
can't be used.
diagonal_shape = list(input_shape[:-2]) | ||
diagonal_shape.append(min(input_shape[-2], input_shape[-1])) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should the broadcasting case be tested here?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
TFLite MATRIX_SET_DIAG doesn't seem to be a broadcast operator. So, I'll change the registration to be injective.
Also cc @siju-samuel |
src/relay/op/tensor/transform.cc
Outdated
.set_support_level(10) | ||
.add_type_rel("MatrixSetDiag", MatrixSetDiagRel) | ||
.set_attr<FTVMCompute>("FTVMCompute", MatrixSetDiagCompute) | ||
.set_attr<TOpPattern>("TOpPattern", kBroadcast); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why kBroadcast? i think it shud be injective.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for reviewing.
Changed it to be injective.
[out]) | ||
|
||
def test_forward_matrix_set_diag(): | ||
""" MATRIX_SET_DIAG """ |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
add a pkg version check > '1.14.0'
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The API docs seem to suggest that matrix_set_diag is present even in version '1.0'.
So, is there some other reason to add this check?
* \param tag output tensor tag. | ||
* \return new tensor with given diagonal values. | ||
*/ | ||
inline Tensor matrix_set_diag(const Tensor& input, const Tensor& diagonal, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
A suggestion:- may be if we can support alignment
and k
(offset) similar to MatrixSetDiagV3
in tf, it will be good. we can support directly for tensorflow ops as well.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That might take some time.
Would it be fine to have that in a follow-up PR?
Thanks @jainris @siju-samuel @mbaret it is merged. For the alignment, wish @jainris could follow up, thanks! |
…OPI and TFLite Frontend. (apache#6303) * Corrected docstring error. * Minor changes. * Changed MATRIX_SET_DIAG registration from broadcast to injective.
…OPI and TFLite Frontend. (apache#6303) * Corrected docstring error. * Minor changes. * Changed MATRIX_SET_DIAG registration from broadcast to injective.
…OPI and TFLite Frontend. (apache#6303) * Corrected docstring error. * Minor changes. * Changed MATRIX_SET_DIAG registration from broadcast to injective.