This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[MXNET-381] Enhancement of take operator #11326

Merged: 3 commits merged on Jul 17, 2018
9 changes: 9 additions & 0 deletions src/operator/mshadow_op.h
@@ -601,6 +601,15 @@ struct clip : public mxnet_op::tunable {
      return x;
    }
  }
  template<typename DType>
  MSHADOW_XINLINE static DType Map(DType x, DType lower_bound, DType upper_bound) {
    if (x > upper_bound) {
      return upper_bound;
    } else if (x < lower_bound) {
      return lower_bound;
    }
    return x;
  }
};
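
For context, here is a minimal standalone sketch (not part of the diff; clip_map is an illustrative stand-in) of what the added two-bound overload computes: clamping a value into an explicit [lower_bound, upper_bound] range, e.g. to keep a take index within the valid range of an axis:

    #include <iostream>

    // Stand-in for the MSHADOW_XINLINE overload added above: clamp x into
    // [lower_bound, upper_bound].
    template <typename DType>
    DType clip_map(DType x, DType lower_bound, DType upper_bound) {
      if (x > upper_bound) return upper_bound;
      if (x < lower_bound) return lower_bound;
      return x;
    }

    int main() {
      std::cout << clip_map(5, 0, 2) << "\n";   // index 5 on an axis of length 3 -> 2
      std::cout << clip_map(-1, 0, 2) << "\n";  // negative index -> 0
      std::cout << clip_map(1, 0, 2) << "\n";   // in-range value passes through unchanged
      return 0;
    }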

/***** gamma ******/
4 changes: 3 additions & 1 deletion src/operator/tensor/indexing_op-inl.cuh
@@ -28,6 +28,8 @@
#include <cub/device/device_run_length_encode.cuh>
#include <cub/device/device_scan.cuh>
#include "../mxnet_op.h"
#include "../mshadow_op.h"
#include "./util/tensor_util-inl.cuh"

#if CUDA_VERSION >= 9000
#define FULLMASK 0xFFFFFFFF
@@ -272,7 +274,7 @@ inline void AddTakeGradLargeBatch(mshadow::Tensor<gpu, 2, DType> dst,
const mshadow::Tensor<gpu, 1, IndexType>& sorted,
const mshadow::Tensor<gpu, 1, IndexType>& index,
const mshadow::Tensor<gpu, 2, DType> &src,
mshadow::Tensor<gpu, 1, char>* workspace) {
mshadow::Tensor<gpu, 1, char>* workspace = NULL) {
Contributor commented:

    nit: NULL -> nullptr. NULL has more semantic meanings than nullptr and should be deprecated in C++11.

Contributor (Author) replied:

    will change.

CHECK_EQ(dst.CheckContiguous(), true);
CHECK_EQ(sorted.CheckContiguous(), true);
CHECK_EQ(index.CheckContiguous(), true);
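As the reviewer notes, nullptr is the preferred C++11 spelling for a pointer default. A minimal sketch of the suggested pattern (simplified stand-in types and names, not the actual MXNet signature):

    #include <iostream>

    // Optional workspace argument defaulting to nullptr, as suggested in the review.
    void add_take_grad_like(const char* label, int* workspace = nullptr) {
      if (workspace == nullptr) {
        std::cout << label << ": no workspace supplied, fall back to internal allocation\n";
      } else {
        std::cout << label << ": use the caller-provided workspace\n";
      }
    }

    int main() {
      add_take_grad_like("default call");          // workspace defaults to nullptr
      int buffer[16];
      add_take_grad_like("explicit call", buffer); // caller provides workspace
      return 0;
    }
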
29 changes: 20 additions & 9 deletions src/operator/tensor/indexing_op.cc
@@ -367,36 +367,46 @@ NNVM_REGISTER_OP(take)

This function slices the input array along a particular axis with the provided indices.

Given an input array with shape ``(d0, d1, d2)`` and indices with shape ``(i0, i1)``, the output
will have shape ``(i0, i1, d1, d2)``, computed by::

output[i,j,:,:] = input[indices[i,j],:,:]

.. note::
- `axis`- Only slicing along axis 0 is supported for now.
- `mode`- Only `clip` mode is supported for now.
Given a data tensor of rank r >= 1 and an indices tensor of rank q, this operator gathers
entries along the data's axis dimension (by default the outermost one, axis=0) indexed by
indices, and concatenates them into an output tensor of rank q + (r - 1).

Examples::
x = [4. 5. 6.]

// Trivial case, take the second element along the first axis.

take(x, [1]) = [ 5. ]

// With axis=-1 (the last axis of a 1-D array) and mode='clip', the out-of-range index 3 is clipped to 2, the last element

take(x, [3], axis=-1, mode='clip') = [ 6. ]

x = [[ 1., 2.],
     [ 3., 4.],
     [ 5., 6.]]

// Along axis 0 (the default): the result stacks rows 0 and 1, then rows 1 and 2

take(x, [[0,1],[1,2]]) = [[[ 1., 2.],
                           [ 3., 4.]],

                          [[ 3., 4.],
                           [ 5., 6.]]]

// Along axis 1 with mode='wrap': indices wrap around the axis length (2),
// so 0, 3 select columns 0, 1 and -1, -2 select columns 1, 0

take(x, [[0, 3], [-1, -2]], axis=1, mode='wrap') = [[[ 1., 2.],
                                                     [ 2., 1.]],

                                                    [[ 3., 4.],
                                                     [ 4., 3.]],

                                                    [[ 5., 6.],
                                                     [ 6., 5.]]]

)code" ADD_FILELINE)
Member commented:

    "Given an input array with shape ``(d0, d1, d2)`` and indices with shape ``(i0, i1)``, the output ..." -- this only holds true for axis=0, right?

Contributor (Author) replied:

    Will update that doc.

.set_num_inputs(2)
.set_num_outputs(1)
.set_attr_parser(TakeParamParser<TakeParam>)
.set_attr_parser(ParamParser<TakeParam>)
.set_attr<nnvm::FListInputNames>("FListInputNames",
[](const NodeAttrs& attrs) {
return std::vector<std::string>{"a", "indices"};
@@ -420,6 +430,7 @@ Examples::
NNVM_REGISTER_OP(_backward_take)
.set_num_inputs(2)
.set_num_outputs(2)
.set_attr_parser(ParamParser<TakeParam>)
.set_attr<FResourceRequest>("FResourceRequest",
[](const NodeAttrs& attrs) {
return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
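
To make the gather rule from the docstring concrete, here is a self-contained C++ sketch (not part of the PR; the helper names are illustrative) that reproduces the clip example and the axis=1, mode='wrap' example, with output rank q + (r - 1):

    #include <cstdio>

    // wrap mode: reduce the index modulo the axis length (negative indices wrap around).
    int wrap_index(int idx, int size) {
      int m = idx % size;
      return m < 0 ? m + size : m;
    }

    // clip mode: clamp the index into [0, size - 1].
    int clip_index(int idx, int size) {
      if (idx < 0) return 0;
      if (idx >= size) return size - 1;
      return idx;
    }

    int main() {
      // clip example: index 3 on a length-3 array clamps to 2 -> prints 6.0
      float y[3] = {4.f, 5.f, 6.f};
      std::printf("clip: %.1f\n", y[clip_index(3, 3)]);

      // wrap example: data x with shape (3, 2), indices with shape (2, 2), axis=1.
      float x[3][2] = {{1.f, 2.f}, {3.f, 4.f}, {5.f, 6.f}};
      int indices[2][2] = {{0, 3}, {-1, -2}};

      // Output rank = q + (r - 1) = 2 + (2 - 1) = 3, shape (3, 2, 2):
      // out[n, i, j] = x[n, wrap_index(indices[i][j], 2)].
      for (int n = 0; n < 3; ++n) {
        for (int i = 0; i < 2; ++i) {
          for (int j = 0; j < 2; ++j) {
            std::printf("%.1f ", x[n][wrap_index(indices[i][j], 2)]);
          }
        }
        std::printf("\n");
      }
      // Prints 1.0 2.0 2.0 1.0 / 3.0 4.0 4.0 3.0 / 5.0 6.0 6.0 5.0,
      // matching the wrap example in the docstring above.
      return 0;
    }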