Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[numpy] Support zero-dim and zero-size tensors in MXNet #14661

Merged
merged 32 commits into from
Apr 16, 2019

Conversation

reminisce
Copy link
Contributor

@reminisce reminisce commented Apr 10, 2019

Description

This PR provides the infrastructure of supporting zero-dim and zero-size tensors as the first outcome of the initiative of introducing NumPy compatible coding experience into MXNet (see this RFC). Great thanks to many folks who have contributed design, implementation and code review to this feature. The credits should go to all of them. Sorry if I missed anyone below. It would be impossible to make so many things correct scattered all over the codebase with just a couple of hands.

FAQ

What are zero-dim and zero-size tensors?

  • Zero-dim tensors are scalars with shape equal to ().
  • A zero-size tensor is the one whose shape has at least one dimension size as 0. For example, np.ones((1, 0, 3)) would generate a zero-size tensor with shape equal to (1, 0, 3).

Why are they important?

Mathematically speaking, their presence keeps the completeness and consistency of the all the tensor operations. For example, given x = mx.nd.array([1, 2, 3]), x[0] should return 1, instead of [1], and x[0:0] should return [] with shape equal to (0,). Zero-size tensors can also be convenient in easing the code logic as placeholders for accumulations or aggregations.

I find this thread provides very good insights on the importance of zero-dim tensors.

How are they implemented?

In the backend (C++), we use ndim = -1 to represent unknown shapes and dim_size = -1 for unknown dim sizes. Before, they were represented by 0.

Is backward compatibility guaranteed in this PR?

Yes, the backward compatibility is guaranteed by default. That means 0 still represents unknown ndim or dim size in any frontend language bindings. It's just converted to -1 in the infer shape logic of the backend.

How to enable zero-dim or zero-size tensors?

Since we are committed to keep backward compatibility, we provided APIs for users to decide whether to opt in for this NumPy compatibility. Users can call mx.set_np_compat(True) to opt in and mx.set_np_compat(False) to opt out. Or in a safer way, use the with statement. To turn on this:

 with mx.np_compat(active=True):
     # A scalar tensor's shape is `()`, whose `ndim` is `0`.
     scalar = mx.nd.ones(shape=())
     assert scalar.shape == ()
 
     # In NumPy compatible mode, 0 in a shape means that dimension contains zero elements.
     data = mx.sym.var("data", shape=(0, 2, 3))
     ret = mx.sym.sin(data)
     arg_shapes, out_shapes, _ = ret.infer_shape()
     assert arg_shapes[0] == (0, 2, 3)
     assert out_shapes[0] == (0, 2, 3)
 
     # -1 means unknown shape dimension size in the new NumPy-compatible shape definition
     data = mx.sym.var("data", shape=(-1, 2, 3))
     ret = mx.sym.sin(data)
     arg_shapes, out_shapes, _ = ret.infer_shape_partial()
     assert arg_shapes[0] == (-1, 2, 3)
     assert out_shapes[0] == (-1, 2, 3)
 
     # When a shape is completely unknown in NumPy-compatible mode, it is
     # represented as `None` in Python.
     data = mx.sym.var("data")
     ret = mx.sym.sin(data)
     arg_shapes, out_shapes, _ = ret.infer_shape_partial()
     assert arg_shapes[0] is None
     assert out_shapes[0] is None

or to disable this:

 with mx.np_compat(active=False):
     # 0 means unknown shape dimension size in the legacy shape definition.
     data = mx.sym.var("data", shape=(0, 2, 3))
     ret = mx.sym.sin(data)
     arg_shapes, out_shapes, _ = ret.infer_shape_partial()
     assert arg_shapes[0] == (0, 2, 3)
     assert out_shapes[0] == (0, 2, 3)
 
     # When a shape is completely unknown in the legacy mode (default), its ndim is
     # equal to 0 and it is represented as `()` in Python.
     data = mx.sym.var("data")
     ret = mx.sym.sin(data)
     arg_shapes, out_shapes, _ = ret.infer_shape_partial()
     assert arg_shapes[0] == ()
     assert out_shapes[0] == ()

Does this mean that every existing operator should support zero-dim or zero-size tensors in this PR?

Please note that the existing operators were implemented when these two types of tensors were not supported in MXNet. Some strong assumptions may have been made in their implementation and hence, lead to errors when NumPy compatibility is turned on. This PR only provides the infrastructure of supporting zero-dim and zero-size tensors in the backend. It does not guarantee that every existing operator would deal with zero-dim/size tensors correctly as in NumPy. As discussed in the RFC, we are going to implement NumPy operators under mxnet.numpy that would fully support these two types of tensors.

@wkcn
Copy link
Member

wkcn commented Apr 10, 2019

Great work! Thank you!

Copy link
Member

@eric-haibin-lin eric-haibin-lin left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I reviewed all functions in *.py files, they look good to me

@reminisce reminisce force-pushed the final_merge_numpy_to_master branch 3 times, most recently from 5cf1517 to 1b6132d Compare April 13, 2019 21:02
@reminisce
Copy link
Contributor Author

@hetong007 @yzhliu @sergeykolychev Could you please help review the changes in R, Scala, and Perl respectively? We would like to merge this PR sooner than later since our subsequent work depends on this and it's becoming harder to rebase.

@szha
Copy link
Member

szha commented Apr 14, 2019

@jeremiedb @lanking520 could you help out on this for reviews too? Thanks!

@yzhliu
Copy link
Member

yzhliu commented Apr 14, 2019

LGTM

cpp-package/include/mxnet-cpp/symbol.hpp Outdated Show resolved Hide resolved
include/mxnet/tuple.h Show resolved Hide resolved
include/mxnet/tuple.h Show resolved Hide resolved
include/mxnet/tuple.h Show resolved Hide resolved
include/mxnet/tuple.h Outdated Show resolved Hide resolved
python/mxnet/base.py Outdated Show resolved Hide resolved
Copy link
Member

@lanking520 lanking520 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please add some comments on the changes for Scala part.

@jeremiedb
Copy link
Contributor

@hetong007 Were you seing a need for adding integrating a set_np_compat option in the R-package? Otherwise, LGTM.

@TaoLv
Copy link
Member

TaoLv commented Apr 15, 2019

Reviewed changes for mkldnn files. @juliusshufan you might want to have a performance validation on this PR.

@hetong007
Copy link
Contributor

hetong007 commented Apr 15, 2019

@hetong007 Were you seing a need for adding integrating a set_np_compat option in the R-package? Otherwise, LGTM.

@jeremiedb Would you elaborate more on this need? I think the current R API maps nd.array to native R array so this PR has minimum effect on R-related uses.

reminisce and others added 10 commits April 15, 2019 11:33
* Support scalar and zero-size tensors with np.sum

* Add sanity check when ndim is set
* Init checkin

* Fix ndarray alloc bug

* Use TShape(0) as default empty tuple params

* Fix bugs

* Fix TShape init value

* Fix infer shape pass shape type and reshape infer shape func
…che#14487)

* Fix infer shape rnn

* Fix boolean mask and custom op unit tests

* Fix multi proposal

* Fix diag

* Add global switch for backward compatibility and fix infer shape bugs

* Fix slice op infer shape

* Fix rnn infer shape

* Add util funcs for ndim_is_known and dim_size_is_known

* Revert rnn_cell.py
* fix.

* remove type conversion.

* remove type cast.
* Initial commit

* Address comments from Jun
* Fix several test failures

* Fix subgraph op infer shape

* Fix sparse slice

* Fix deconv infer shape

* Fix numpy import compatibility problem in python2
@anirudh2290
Copy link
Member

Thanks for the decorator! I think customers may not understand the internal workings of custom op and the dispatch to a different thread workflow, and run into this pitfall. Can we synchronize the state of is_numpy_compat when pushing the op to custom op thread. Custom op is one area where customers use numpy operations and I feel there may be users running into this.

@reminisce
Copy link
Contributor Author

@anirudh2290 That's a good point from users' perspective. But this may not be a trivial implementation by propagating the np_compat from main thread to worker thread as potentially there are other pitfalls. I discussed this @eric-haibin-lin, maybe we can attach a compatibility state to the generated custom op. This can become a follow-up implementation in another PR. This PR does not need to be blocked by this.

@szha
Copy link
Member

szha commented Apr 16, 2019

@anirudh2290 good points on custom op. Given the size of this PR and the fact that users need to opt-in to be affected by this change, I'd recommend we address these points in a follow-up PR, so that @reminisce won't have to keep full-time resolving conflicts :)

Copy link
Member

@szha szha left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

python, C API, and operator changes LGTM

@anirudh2290
Copy link
Member

I am okay with making custom op changes later. just to be sure, the the is_numpy_compat thread local variable is set in the main thread, and is expected to be accessed only from main thread ? For example, this wont cause any issue the multi threaded inference interface provided by frontends like Scala ?

Copy link
Member

@junrushao junrushao left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@reminisce
Copy link
Contributor Author

@anirudh2290 Thanks for the review. In practice, np_compat needs to be set in the same thread as the one invoking shape inference functions, because that's where zero-dim and zero-size shapes are treated as unknown. In your example, the custom op's forward function invokes mx.nd.ones whose shape inference function was executed on the worker thread, and hence leads to failure. We are going to move all ops towards the direction of being numpy compatible, and that should include the custom op, where the forward/backward functions should always be scoped by the numpy-compatible state, and this problem can be resolved without asking users to use the decorator.

I'm not familiar with how scala multi-threaded inference is implemented, but as long as the np_compat is set in the same thread as the one invoking shape inference functions, the result should be as expected.

Copy link
Member

@TaoLv TaoLv left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Verified performance and accuracy for several CNN models with MKL-DNN backend. They look good to me.

@szha szha merged commit 3f3ba92 into apache:master Apr 16, 2019
kedarbellare pushed a commit to kedarbellare/incubator-mxnet that referenced this pull request Apr 20, 2019
* [numpy] Shape support scalar tensor (apache#14315)

* Support scalar and zero-size tensors with np.sum

* Add sanity check when ndim is set

* [Numpy] Change semantics of ndim for operators in `src/operator/contrib` (apache#14409)

* Initial commit

* Address comments

* [WIP] Use new shape definition (apache#14453)

* Init checkin

* Fix ndarray alloc bug

* Use TShape(0) as default empty tuple params

* Fix bugs

* Fix TShape init value

* Fix infer shape pass shape type and reshape infer shape func

* [numpy] Fix unit tests after introducing numpy compatible shapes (apache#14487)

* Fix infer shape rnn

* Fix boolean mask and custom op unit tests

* Fix multi proposal

* Fix diag

* Add global switch for backward compatibility and fix infer shape bugs

* Fix slice op infer shape

* Fix rnn infer shape

* Add util funcs for ndim_is_known and dim_size_is_known

* Revert rnn_cell.py

* Fix a bug to pass the test in test_contrib_rnn (apache#14520)

* fix.

* remove type conversion.

* remove type cast.

* [numpy] Fix test_dynamic_shape.test_dynamic_shape (apache#14538)

* Initial commit

* Address comments from Jun

* [numpy] Fix numpy import in python2 (apache#14537)

* Fix several test failures

* Fix subgraph op infer shape

* Fix sparse slice

* Fix deconv infer shape

* Fix numpy import compatibility problem in python2

* fix concat and slice (apache#14549)

* fix R-package (apache#14536)

* Fix cpp package build after using new shape definition (apache#14554)

* Fix pooling_v1 and deformable_convolution param initialization (apache#14577)

* Fix pooling_v1 param initialization

* Fix deformable_convolution param initialization

* [Numpy] Misc fix (apache#14612)

* [Numpy] Misc Fix

* fix build

* !shape_is_none => shape_is_known

* Address comments

* Fix

* [Numpy] fix test_operator_gpu.test_upsampling_bilinear_with_type (apache#14557)

* Fix test_operator_gpu.test_upsampling_bilinear_with_type

* Address comments

* [Numpy] Java/Scala modification (apache#14625)

* modify jni to support 0 dim/shape

* fix transpose axes default value

* fix shape index bug (apache#14630)

* fix jni lint (apache#14634)

* [numpy] Fix numpy branch failing tests in CI (apache#14639)

* Remove numpy namespaces for operator registration

* Fix bug when shape is compeltely unknown

* Fix singed/unsigned compare warning

* Fix CI

* Fix pylint

* Avoid launching gpu kernels for zero-size output tensors

* Fix test_ndarray

* Fix binary broadcast with zero-size tensors

* Better error message for infer shape failure in imperative

* Fix TShape constructor ambiguity on certain platforms

* Fix mkldnn build failure

* Fix build failure in gpu and cpp test

* Fix gpu cpp test build with mkldnn

* Fix mkldnn cpp test

* Fix concatenating zero-size tensors

* Avoid letting mkldnn handle zero-size tensors in concat

* Fix quantized_concat infer shape

* Try to fix perl c api

* fix invalid ndarray dispose (apache#14657)

* swig fixes for the changes in c_api.h (apache#14655)

* Rename np_comp to np_compat for readability

* Fix import error

* Keep old c apis unchanged

* Fix lint

* Rebase and fix build

* Fix R build failure

* Fix Perl build failure

* Rebase with master

* Address cr comments

* Use just one scope to represent numpy compatibility

* Add code comment to NumpyScope object in Scala

* Add use_np_compat decorator

* Fix pylint
nickguletskii added a commit to nickguletskii/incubator-mxnet that referenced this pull request Apr 30, 2019
Changes the implementation of index_array to be compatible with the
recently merged support for zero-dim and zero-size arrays. Resolves the
incompatibilities with apache#14661.
nickguletskii added a commit to nickguletskii/incubator-mxnet that referenced this pull request May 22, 2019
Changes the implementation of index_array to be compatible with the
recently merged support for zero-dim and zero-size arrays. Resolves the
incompatibilities with apache#14661.
szha pushed a commit that referenced this pull request May 25, 2019
* Implement the index_array operator

* Add index_array operator tests

* Add index_array operator GPU tests

* Add the index_array operator to the Python docs autosummary

* Add the author of the index_array operator to CONTRIBUTORS.md

* Make index_array compatible with zero-dim and zero-size arrays

Changes the implementation of index_array to be compatible with the
recently merged support for zero-dim and zero-size arrays. Resolves the
incompatibilities with #14661.

* Fix the index_array gradient checks in the unit tests

In the previous implementation, the output gradient had an incorrect
shape. This commit fixes the shapes and makes the tests more readable.

* Add zero-dim and zero-size array tests for index_array

* Use mxnet::Tuple<int> instead of TShape for the axes parameter

* Fix incorrect array indexing in index_array

Solves access violations when compiling with MSVC++ 14.0.

* Avoid copying the input shape array in the index_array shape function

* Add unknown shape handling to index_array

* Use SHAPE_ASSIGN_CHECK to assign the shape in index_array

* Remove the redundant index_array GPU tests from test_operator_gpu.py

* Move the index_array tests into a single function (test_index_array)

* Use @mx.use_np_compat instead of mx.np_compat in index_array op tests

* Remove the use of template specialization for IndexArrayForward

* Add the index_array operator to the AMP symbol list

* Retrigger CI
haohuanw pushed a commit to haohuanw/incubator-mxnet that referenced this pull request Jun 23, 2019
* [numpy] Shape support scalar tensor (apache#14315)

* Support scalar and zero-size tensors with np.sum

* Add sanity check when ndim is set

* [Numpy] Change semantics of ndim for operators in `src/operator/contrib` (apache#14409)

* Initial commit

* Address comments

* [WIP] Use new shape definition (apache#14453)

* Init checkin

* Fix ndarray alloc bug

* Use TShape(0) as default empty tuple params

* Fix bugs

* Fix TShape init value

* Fix infer shape pass shape type and reshape infer shape func

* [numpy] Fix unit tests after introducing numpy compatible shapes (apache#14487)

* Fix infer shape rnn

* Fix boolean mask and custom op unit tests

* Fix multi proposal

* Fix diag

* Add global switch for backward compatibility and fix infer shape bugs

* Fix slice op infer shape

* Fix rnn infer shape

* Add util funcs for ndim_is_known and dim_size_is_known

* Revert rnn_cell.py

* Fix a bug to pass the test in test_contrib_rnn (apache#14520)

* fix.

* remove type conversion.

* remove type cast.

* [numpy] Fix test_dynamic_shape.test_dynamic_shape (apache#14538)

* Initial commit

* Address comments from Jun

* [numpy] Fix numpy import in python2 (apache#14537)

* Fix several test failures

* Fix subgraph op infer shape

* Fix sparse slice

* Fix deconv infer shape

* Fix numpy import compatibility problem in python2

* fix concat and slice (apache#14549)

* fix R-package (apache#14536)

* Fix cpp package build after using new shape definition (apache#14554)

* Fix pooling_v1 and deformable_convolution param initialization (apache#14577)

* Fix pooling_v1 param initialization

* Fix deformable_convolution param initialization

* [Numpy] Misc fix (apache#14612)

* [Numpy] Misc Fix

* fix build

* !shape_is_none => shape_is_known

* Address comments

* Fix

* [Numpy] fix test_operator_gpu.test_upsampling_bilinear_with_type (apache#14557)

* Fix test_operator_gpu.test_upsampling_bilinear_with_type

* Address comments

* [Numpy] Java/Scala modification (apache#14625)

* modify jni to support 0 dim/shape

* fix transpose axes default value

* fix shape index bug (apache#14630)

* fix jni lint (apache#14634)

* [numpy] Fix numpy branch failing tests in CI (apache#14639)

* Remove numpy namespaces for operator registration

* Fix bug when shape is compeltely unknown

* Fix singed/unsigned compare warning

* Fix CI

* Fix pylint

* Avoid launching gpu kernels for zero-size output tensors

* Fix test_ndarray

* Fix binary broadcast with zero-size tensors

* Better error message for infer shape failure in imperative

* Fix TShape constructor ambiguity on certain platforms

* Fix mkldnn build failure

* Fix build failure in gpu and cpp test

* Fix gpu cpp test build with mkldnn

* Fix mkldnn cpp test

* Fix concatenating zero-size tensors

* Avoid letting mkldnn handle zero-size tensors in concat

* Fix quantized_concat infer shape

* Try to fix perl c api

* fix invalid ndarray dispose (apache#14657)

* swig fixes for the changes in c_api.h (apache#14655)

* Rename np_comp to np_compat for readability

* Fix import error

* Keep old c apis unchanged

* Fix lint

* Rebase and fix build

* Fix R build failure

* Fix Perl build failure

* Rebase with master

* Address cr comments

* Use just one scope to represent numpy compatibility

* Add code comment to NumpyScope object in Scala

* Add use_np_compat decorator

* Fix pylint
haohuanw pushed a commit to haohuanw/incubator-mxnet that referenced this pull request Jun 23, 2019
* Implement the index_array operator

* Add index_array operator tests

* Add index_array operator GPU tests

* Add the index_array operator to the Python docs autosummary

* Add the author of the index_array operator to CONTRIBUTORS.md

* Make index_array compatible with zero-dim and zero-size arrays

Changes the implementation of index_array to be compatible with the
recently merged support for zero-dim and zero-size arrays. Resolves the
incompatibilities with apache#14661.

* Fix the index_array gradient checks in the unit tests

In the previous implementation, the output gradient had an incorrect
shape. This commit fixes the shapes and makes the tests more readable.

* Add zero-dim and zero-size array tests for index_array

* Use mxnet::Tuple<int> instead of TShape for the axes parameter

* Fix incorrect array indexing in index_array

Solves access violations when compiling with MSVC++ 14.0.

* Avoid copying the input shape array in the index_array shape function

* Add unknown shape handling to index_array

* Use SHAPE_ASSIGN_CHECK to assign the shape in index_array

* Remove the redundant index_array GPU tests from test_operator_gpu.py

* Move the index_array tests into a single function (test_index_array)

* Use @mx.use_np_compat instead of mx.np_compat in index_array op tests

* Remove the use of template specialization for IndexArrayForward

* Add the index_array operator to the AMP symbol list

* Retrigger CI
}
ReverseReshapeInferShape(&dshape, oshape);
#if 0
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@reminisce is this #if 0 intentional? Will it be removed eventually?

Sign up for free to subscribe to this conversation on GitHub. Already have an account? Sign in.
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.