-
Notifications
You must be signed in to change notification settings - Fork 6.8k
Added large tensor support and test for gather_nd #16371
Conversation
@mxnet-label-bot add [pr-work-in-progress] |
92fcb15
to
a751c02
Compare
a751c02
to
3d93d25
Compare
@mxnet-label-bot add [pr-awaiting-review] |
@ChaiBapchya @apeforest This PR is ready for review |
if sys.version_info[0] > 2 and _int64_enabled(): | ||
idx_dtype = 'int64' | ||
else: | ||
idx_dtype = 'int32' |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Performing this check since idx can be an instance of NDArray, list, tuple, slice(python), integer_type etc. Also, earlier idx was hard coded to be 'int32' therefore with Large Tensor enabled
it makes sense to use only 'int64'.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Have we measured runtime difference by adding this check in ndarray.py? Can we some simple test?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ran test_ndarray.py on my branch and on master. Both take 20-21 secs. Master is taking 1-2 secs more than my change.
src/operator/tensor/indexing_op.h
Outdated
int offset = 0; | ||
for (int j = 0; j < M; ++j) { | ||
index_t offset = 0; | ||
for (index_t j = 0; j < M; ++j) { | ||
offset += strides[j] * static_cast<int>(indices[j*N + i]); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Would static_cast cause the same overflow problem?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
my bad. it should be cast to index_t
@@ -1212,6 +1212,15 @@ def test_full(): | |||
assert a[-1][-1] == 3 | |||
|
|||
|
|||
def test_gather(): | |||
arr = mx.nd.ones((LARGE_X, SMALL_Y)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We don't need nested parentheses right??
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We do need parentheses
mx.nd.ones(3,4)
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "/Users/bapac/workspace/transpose/incubator-mxnet/python/mxnet/ndarray/ndarray.py", line 3153, in ones
return _internal._ones(shape=shape, ctx=ctx, dtype=dtype, **kwargs)
File "<string>", line 36, in _ones
File "/Users/bapac/workspace/transpose/incubator-mxnet/python/mxnet/_ctypes/ndarray.py", line 100, in _imperative_invoke
ctypes.byref(out_stypes)))
File "/Users/bapac/workspace/transpose/incubator-mxnet/python/mxnet/base.py", line 254, in check_call
raise MXNetError(py_str(_LIB.MXGetLastError()))
mxnet.base.MXNetError: [11:15:17] ../include/mxnet/./base.h:526: Invalid context string 4
tests/nightly/test_large_array.py
Outdated
@@ -1212,6 +1212,15 @@ def test_full(): | |||
assert a[-1][-1] == 3 | |||
|
|||
|
|||
def test_gather(): | |||
arr = mx.nd.ones((LARGE_X, SMALL_Y)) | |||
idx = mx.nd.random.randint(0, LARGE_X, (SMALL_X), dtype=np.int64) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
remove parenthesis around SMALL_X
3d93d25
to
44f7b72
Compare
@sxjscience @apeforest @anirudh2290 @zheng-da @pengzhao-intel This PR is ready for review |
bd9a7a6
to
51c5386
Compare
` Ran 1 test in 3.056s OK |
tests/nightly/test_large_array.py
Outdated
@@ -1199,6 +1199,15 @@ def test_full(): | |||
assert a[-1][-1] == 3 | |||
|
|||
|
|||
def test_gather(): | |||
arr = mx.nd.ones((LARGE_X, SMALL_Y)) | |||
idx = mx.nd.random.randint(0, LARGE_X, SMALL_X, dtype=np.int64) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
some comment here on the test will help.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done. added test on large vector as well which i missed in previous commit. Thanks for pointing that out
51c5386
to
66f31c9
Compare
9dc433f
to
e27758a
Compare
e27758a
to
e6f8d59
Compare
Description
changed the operator code to use
index_t
instead ofint
Checklist
Essentials
Testing
Currently running full test suite. Will update once done.