Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[RELAY][MXNET][FRONTEND] add support for MXNET numpy operators #6054

Merged
merged 19 commits into from
Aug 21, 2020

Conversation

sandyhu533
Copy link
Contributor

@sandyhu533 sandyhu533 commented Jul 14, 2020

add support for some MXNET numpy operators, related issue: dmlc/gluon-nlp#1244
For now, these operators have been implemented

_npi_transpose
_npi_pad
_npi_multiply_scalar
_npi_true_divide_scalar
_npi_add
_npi_concatenate
_npi_multiply
_np_copy
_npi_tanh
_npi_power_scalar
_npi_less
_npi_add_scalar
_npx_reshape
_split_v2
_npi_where_rscalar


def verify(data_shape, scalar):
cond_np = np.random.uniform(size=data_shape).astype("bool")
data_np = np.random.uniform(size=data_shape).astype("float32")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We may need to test for multiple dtypes. Especially for cond_np, we can have bool or float32.

@@ -1350,7 +1351,7 @@ def verify(batch, seq_length, num_heads, head_dim):
verify(1, 10, 4, 16)
verify(3, 10, 6, 8)


Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

remove the change

pytest.skip("mx.sym.np.pad hasn't been publish yet")

def verify(data_shape, mode, pad_width, constant_value=0.0):
for dtype in dtype_list:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

would it be easier to use @pytest.mark.parametrize?

test_forward_npi_binary_scalar()
test_forward_npi_tanh()
test_forward_npi_where_rscalar()
test_forward_split_v2()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe we can just do pytest.main([__file__])


def test_forward_npi_pad():
if not hasattr(mx.sym.np, 'pad'):
pytest.skip("mx.sym.np.pad hasn't been publish yet")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do we need to update CI's mxnet version?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm afraid we couldn't update CI's mxnet version now. The latest version of mxnet has made some changes that the test functions in tvm didn't updated with them. So the CI would fail if we update mxnet version now.
This is the error message when I try to run pytest for tests/python/frontend/mxnet/test_forward.py.
E AttributeError: module 'mxnet' has no attribute 'mod'

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We will need to find a good way to support both mx 1.x and 2.0.

For marking the test, maybe use @pytest.mark.skipif ?

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not use two separate pipelines to test with the different mxnet versions?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@tqchen what do you think?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

skip ci for now sounds good

@tqchen
Copy link
Member

tqchen commented Jul 21, 2020

cc @antinucleon @junrushao1994

mx_sym = mx.sym.np.pad(data.as_np_ndarray(), mode=mode, pad_width=pad_width)
mod, _ = relay.frontend.from_mxnet(mx_sym, {"data": data_shape}, dtype=dtype)
for target, ctx in ctx_list():
for kind in ["debug"]:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why is this one using debug only?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe we can also use @pytest.mark.parametrize for ctx_list and kind.

verify((3, 2, 1), axis=0, indices_or_sections=3, squeeze_axis=True)
verify((3, 2, 1), axis=0, indices_or_sections=(1, 2))


if __name__ == '__main__':
test_forward_mlp()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

remove all the test_xxx function calls and add pytest.main([__file__])

@yzhliu
Copy link
Member

yzhliu commented Jul 30, 2020

Please check the ci failure, otherwise good to me.

@yzhliu yzhliu merged commit 4c728d5 into apache:master Aug 21, 2020
@yzhliu
Copy link
Member

yzhliu commented Aug 21, 2020

Thanks @sandyhu533 @sxjscience

@sxjscience
Copy link
Member

Thanks! Cannot wait to try it in GluonNLP 👍

trevor-m pushed a commit to trevor-m/tvm that referenced this pull request Aug 26, 2020
…e#6054)

* [RELAY][MXNET][FRONTEND] add supports for OPs in numpy from mxnet

* Update test_forward.py

* Update mxnet.py

* Update mxnet.py

* Update test_forward.py

* update and bugfix

* test for multiple dtypes

* Update test_forward.py

* add data type and optimize coding style

* replace pytest.skip with @pytest.mark.skipif

* Update test_forward.py

* update pytest style

* Update test_forward.py

* Update test_forward.py

* Update test_forward.py

* Update test_forward.py

Co-authored-by: Ubuntu <ubuntu@ip-172-31-39-169.ap-northeast-1.compute.internal>
trevor-m pushed a commit to trevor-m/tvm that referenced this pull request Aug 26, 2020
…e#6054)

* [RELAY][MXNET][FRONTEND] add supports for OPs in numpy from mxnet

* Update test_forward.py

* Update mxnet.py

* Update mxnet.py

* Update test_forward.py

* update and bugfix

* test for multiple dtypes

* Update test_forward.py

* add data type and optimize coding style

* replace pytest.skip with @pytest.mark.skipif

* Update test_forward.py

* update pytest style

* Update test_forward.py

* Update test_forward.py

* Update test_forward.py

* Update test_forward.py

Co-authored-by: Ubuntu <ubuntu@ip-172-31-39-169.ap-northeast-1.compute.internal>
trevor-m pushed a commit to trevor-m/tvm that referenced this pull request Aug 26, 2020
…e#6054)

* [RELAY][MXNET][FRONTEND] add supports for OPs in numpy from mxnet

* Update test_forward.py

* Update mxnet.py

* Update mxnet.py

* Update test_forward.py

* update and bugfix

* test for multiple dtypes

* Update test_forward.py

* add data type and optimize coding style

* replace pytest.skip with @pytest.mark.skipif

* Update test_forward.py

* update pytest style

* Update test_forward.py

* Update test_forward.py

* Update test_forward.py

* Update test_forward.py

Co-authored-by: Ubuntu <ubuntu@ip-172-31-39-169.ap-northeast-1.compute.internal>
trevor-m pushed a commit to trevor-m/tvm that referenced this pull request Sep 2, 2020
…e#6054)

* [RELAY][MXNET][FRONTEND] add supports for OPs in numpy from mxnet

* Update test_forward.py

* Update mxnet.py

* Update mxnet.py

* Update test_forward.py

* update and bugfix

* test for multiple dtypes

* Update test_forward.py

* add data type and optimize coding style

* replace pytest.skip with @pytest.mark.skipif

* Update test_forward.py

* update pytest style

* Update test_forward.py

* Update test_forward.py

* Update test_forward.py

* Update test_forward.py

Co-authored-by: Ubuntu <ubuntu@ip-172-31-39-169.ap-northeast-1.compute.internal>
trevor-m pushed a commit to neo-ai/tvm that referenced this pull request Sep 3, 2020
…e#6054)

* [RELAY][MXNET][FRONTEND] add supports for OPs in numpy from mxnet

* Update test_forward.py

* Update mxnet.py

* Update mxnet.py

* Update test_forward.py

* update and bugfix

* test for multiple dtypes

* Update test_forward.py

* add data type and optimize coding style

* replace pytest.skip with @pytest.mark.skipif

* Update test_forward.py

* update pytest style

* Update test_forward.py

* Update test_forward.py

* Update test_forward.py

* Update test_forward.py

Co-authored-by: Ubuntu <ubuntu@ip-172-31-39-169.ap-northeast-1.compute.internal>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

6 participants