forked from apache/mxnet
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[MXNET-1382] Add the index_array operator (apache#14638)
* Implement the index_array operator * Add index_array operator tests * Add index_array operator GPU tests * Add the index_array operator to the Python docs autosummary * Add the author of the index_array operator to CONTRIBUTORS.md * Make index_array compatible with zero-dim and zero-size arrays Changes the implementation of index_array to be compatible with the recently merged support for zero-dim and zero-size arrays. Resolves the incompatibilities with apache#14661. * Fix the index_array gradient checks in the unit tests In the previous implementation, the output gradient had an incorrect shape. This commit fixes the shapes and makes the tests more readable. * Add zero-dim and zero-size array tests for index_array * Use mxnet::Tuple<int> instead of TShape for the axes parameter * Fix incorrect array indexing in index_array Solves access violations when compiling with MSVC++ 14.0. * Avoid copying the input shape array in the index_array shape function * Add unknown shape handling to index_array * Use SHAPE_ASSIGN_CHECK to assign the shape in index_array * Remove the redundant index_array GPU tests from test_operator_gpu.py * Move the index_array tests into a single function (test_index_array) * Use @mx.use_np_compat instead of mx.np_compat in index_array op tests * Remove the use of template specialization for IndexArrayForward * Add the index_array operator to the AMP symbol list * Retrigger CI
- Loading branch information
1 parent
0d41135
commit fe046bf
Showing
8 changed files
with
430 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,103 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one | ||
* or more contributor license agreements. See the NOTICE file | ||
* distributed with this work for additional information | ||
* regarding copyright ownership. The ASF licenses this file | ||
* to you under the Apache License, Version 2.0 (the | ||
* "License"); you may not use this file except in compliance | ||
* with the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, | ||
* software distributed under the License is distributed on an | ||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
* KIND, either express or implied. See the License for the | ||
* specific language governing permissions and limitations | ||
* under the License. | ||
*/ | ||
|
||
#ifndef MXNET_OPERATOR_CONTRIB_INDEX_ARRAY_INL_H_ | ||
#define MXNET_OPERATOR_CONTRIB_INDEX_ARRAY_INL_H_ | ||
|
||
#include <vector> | ||
#include <utility> | ||
#include "../mshadow_op.h" | ||
#include "../tensor/init_op.h" | ||
|
||
namespace mxnet { | ||
namespace op { | ||
|
||
// Symbolic indices for the operator's input/output/resource vectors.
namespace index_array_enum {
// Position of the single data input in `inputs`.
enum IndexArrayOpInputs {kIn};
// Position of the single output in `outputs`.
enum IndexArrayOpOutputs {kOut};
// Slot of the requested temp-space resource (axes workspace).
enum IndexArrayOpResource {kTempSpace};
}  // namespace index_array_enum
|
||
template<int req> | ||
struct IndexArrayKernel { | ||
MSHADOW_XINLINE static void Map(int i, | ||
int64_t* out_data, | ||
const int n, | ||
const int64_t* workspace) { | ||
for (ptrdiff_t j = 0; j < n; j++) { | ||
int64_t upper = workspace[ptrdiff_t(2) * j]; | ||
int64_t lower = workspace[ptrdiff_t(2) * j + ptrdiff_t(1)]; | ||
KERNEL_ASSIGN(out_data[ptrdiff_t(i) * ptrdiff_t(n) + j], req, (i % upper) / lower); | ||
} | ||
} | ||
}; | ||
|
||
template<int req> | ||
struct IndexArrayDefaultKernel { | ||
MSHADOW_XINLINE static void Map(int i, | ||
int64_t* out_data, | ||
const int ndim, | ||
const dim_t* shape) { | ||
int64_t index = i; | ||
for (ptrdiff_t j = ndim - 1; j >= 0; j--) { | ||
KERNEL_ASSIGN(out_data[ptrdiff_t(i) * ptrdiff_t(ndim) + j], req, index % shape[j]); | ||
index /= shape[j]; | ||
} | ||
} | ||
}; | ||
|
||
inline std::vector<int64_t> IndexArrayComputeIndexProducts(const TShape &inshape) { | ||
const int ndim = inshape.ndim(); | ||
|
||
std::vector<int64_t> index_products(static_cast<size_t>(ndim + 1)); | ||
|
||
index_products[ndim] = 1; | ||
|
||
for (int i = ndim - 1; i >= 0; i--) { | ||
index_products[i] = index_products[i + 1] * inshape[i]; | ||
} | ||
|
||
return index_products; | ||
} | ||
|
||
/*!
 * \brief Fill `workspace` with one (upper, lower) suffix-product pair per
 * selected axis, consumed by IndexArrayKernel.
 *
 * \param axes           the axes selected by the user; may be negative
 * \param index_products suffix products from IndexArrayComputeIndexProducts
 * \param workspace      destination buffer with room for 2 * axes.ndim() values
 * \param ndim           number of input dimensions
 *                       (assumes ndim > 0 when axes is non-empty)
 */
inline void IndexArrayBuildSelectedAxesWorkspace(const mxnet::Tuple<int> &axes,
                                                 const std::vector<int64_t> &index_products,
                                                 int64_t* workspace,
                                                 const int ndim) {
  const int naxes = axes.ndim();
  for (int k = 0; k < naxes; ++k) {
    // Normalize negative axis values into [0, ndim).
    const int axis = ((axes[k] % ndim) + ndim) % ndim;
    const ptrdiff_t base = ptrdiff_t(2) * ptrdiff_t(k);
    workspace[base] = index_products[axis];
    workspace[base + ptrdiff_t(1)] = index_products[axis + 1];
  }
}
|
||
// Parameters accepted by the index_array operator.
struct IndexArrayParam : public dmlc::Parameter<IndexArrayParam> {
  // Optional subset of axes to emit indices for; when unset, indices are
  // produced for every axis of the input.
  dmlc::optional<mxnet::Tuple<int>> axes;
  DMLC_DECLARE_PARAMETER(IndexArrayParam) {
    DMLC_DECLARE_FIELD(axes).set_default(dmlc::optional<mxnet::Tuple<int>>())
    .describe("The axes to include in the index array. Supports negative values.");
  }
};  // struct IndexArrayParam
|
||
} // namespace op | ||
} // namespace mxnet | ||
|
||
#endif // MXNET_OPERATOR_CONTRIB_INDEX_ARRAY_INL_H_ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,172 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one | ||
* or more contributor license agreements. See the NOTICE file | ||
* distributed with this work for additional information | ||
* regarding copyright ownership. The ASF licenses this file | ||
* to you under the Apache License, Version 2.0 (the | ||
* "License"); you may not use this file except in compliance | ||
* with the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, | ||
* software distributed under the License is distributed on an | ||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
* KIND, either express or implied. See the License for the | ||
* specific language governing permissions and limitations | ||
* under the License. | ||
*/ | ||
#include <mshadow/tensor.h> | ||
#include "./index_array-inl.h" | ||
|
||
|
||
namespace mxnet { | ||
namespace op { | ||
|
||
/*!
 * \brief CPU forward pass of index_array: writes, for every element of the
 * input, its coordinates along all axes (or along `param.axes` when given).
 */
void IndexArrayForwardCPU(const nnvm::NodeAttrs &attrs,
                          const OpContext &ctx,
                          const std::vector<TBlob> &inputs,
                          const std::vector<OpReqType> &req,
                          const std::vector<TBlob> &outputs) {
  using namespace mshadow;
  using namespace mxnet_op;
  CHECK_EQ(inputs.size(), 1U);
  CHECK_EQ(outputs.size(), 1U);
  CHECK_EQ(req.size(), 1U);

  const TBlob &in_data = inputs[0];
  const TBlob &out_data = outputs[0];
  const IndexArrayParam &param = nnvm::get<IndexArrayParam>(attrs.parsed);
  const TShape inshape = in_data.shape_;
  const int ndim = inshape.ndim();
  Stream<cpu> *s = ctx.get_stream<cpu>();

  if (!param.axes.has_value()) {
    // No axes requested: emit the full coordinate tuple for every element.
    MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, {
      Kernel<IndexArrayDefaultKernel<req_type>, cpu>::Launch(
          s, in_data.Size(), out_data.dptr<int64_t>(), ndim, inshape.data());
    });
    return;
  }

  const mxnet::Tuple<int> &axes = param.axes.value();
  const int naxes = axes.ndim();

  // Pack one (upper, lower) suffix-product pair per selected axis into
  // requested temp space for the kernel to consume.
  const std::vector<int64_t> index_products = IndexArrayComputeIndexProducts(inshape);
  Tensor<cpu, 1, int64_t> workspace =
      ctx.requested[0].get_space_typed<cpu, 1, int64_t>(Shape1(2 * naxes), s);
  IndexArrayBuildSelectedAxesWorkspace(axes, index_products, workspace.dptr_, ndim);

  MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, {
    Kernel<IndexArrayKernel<req_type>, cpu>::Launch(
        s, in_data.Size(), out_data.dptr<int64_t>(), naxes, workspace.dptr_);
  });
}
|
||
DMLC_REGISTER_PARAMETER(IndexArrayParam); | ||
|
||
// Registration of the _contrib_index_array operator (CPU compute attached here;
// the GPU compute is attached in index_array.cu).
NNVM_REGISTER_OP(_contrib_index_array)
.describe(R"code(Returns an array of indexes of the input array.
For an input array with shape  :math:`(d_1, d_2, ..., d_n)`, `index_array` returns a
:math:`(d_1, d_2, ..., d_n, n)` array `idx`, where
:math:`idx[i_1, i_2, ..., i_n, :] = [i_1, i_2, ..., i_n]`.
Additionally, when the parameter `axes` is specified, `idx` will be a
:math:`(d_1, d_2, ..., d_n, m)` array where `m` is the length of `axes`, and the following
equality will hold: :math:`idx[i_1, i_2, ..., i_n, j] = i_{axes[j]}`.
Examples::
    x = mx.nd.ones((3, 2))
    mx.nd.contrib.index_array(x) = [[[0 0]
                                     [0 1]]
                                    [[1 0]
                                     [1 1]]
                                    [[2 0]
                                     [2 1]]]
    x = mx.nd.ones((3, 2, 2))
    mx.nd.contrib.index_array(x, axes=(1, 0)) = [[[[0 0]
                                                   [0 0]]
                                                  [[1 0]
                                                   [1 0]]]
                                                 [[[0 1]
                                                   [0 1]]
                                                  [[1 1]
                                                   [1 1]]]
                                                 [[[0 2]
                                                   [0 2]]
                                                  [[1 2]
                                                   [1 2]]]]
)code" ADD_FILELINE)
.set_num_inputs(1)
.set_num_outputs(1)
.set_attr<nnvm::FListInputNames>("FListInputNames",
                                 [](const NodeAttrs &attrs) {
                                   return std::vector<std::string>{ "data" };
                                 })
.set_attr<nnvm::FListOutputNames>("FListOutputNames",
                                  [](const NodeAttrs &attrs) {
                                    return std::vector<std::string>{ "output" };
                                  })
.set_attr_parser(ParamParser<IndexArrayParam>)
// Output shape is the input shape with one extra trailing dimension: the
// number of reported axes (len(axes) if given, else ndim).
.set_attr<mxnet::FInferShape>("FInferShape", [](const nnvm::NodeAttrs &attrs,
                                                mxnet::ShapeVector *in_shape,
                                                mxnet::ShapeVector *out_shape) {
  const IndexArrayParam &param = nnvm::get<IndexArrayParam>(attrs.parsed);
  CHECK_EQ(in_shape->size(), 1U);
  CHECK_EQ(out_shape->size(), 1U);
  const mxnet::TShape &inshape = (*in_shape)[index_array_enum::kIn];
  // Cannot infer anything until the input's rank is known.
  if (!mxnet::ndim_is_known(inshape)) return false;

  mxnet::TShape oshape = mxnet::TShape(inshape.ndim() + 1, 0);

  for (int i = 0; i < inshape.ndim(); i++) {
    oshape[i] = inshape[i];
  }
  if (param.axes.has_value()) {
    oshape[inshape.ndim()] = param.axes.value().ndim();
  } else {
    oshape[inshape.ndim()] = inshape.ndim();
  }

  SHAPE_ASSIGN_CHECK(*out_shape, 0, oshape);
  return shape_is_known(oshape);
})
// The output is always int64 indices, regardless of the input dtype.
.set_attr<nnvm::FInferType>("FInferType", [](const nnvm::NodeAttrs &attrs,
                                             std::vector<int> *in_attrs,
                                             std::vector<int> *out_attrs) {
  CHECK_EQ(in_attrs->size(), 1U);
  CHECK_EQ(out_attrs->size(), 1U);
  TYPE_ASSIGN_CHECK(*out_attrs, 0, mshadow::kInt64);
  return out_attrs->at(0) != -1;
})
.set_attr<FCompute>("FCompute<cpu>", IndexArrayForwardCPU)
// Output depends only on the input's shape, so the gradient is zero.
.set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes)
// Temp space holds the per-axis workspace built in the forward pass.
.set_attr<FResourceRequest>("FResourceRequest", [](const NodeAttrs& n) {
  return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
})
.add_argument("data", "NDArray-or-Symbol", "Input data")
.add_arguments(IndexArrayParam::__FIELDS__());
|
||
|
||
} // namespace op | ||
} // namespace mxnet | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,85 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one | ||
* or more contributor license agreements. See the NOTICE file | ||
* distributed with this work for additional information | ||
* regarding copyright ownership. The ASF licenses this file | ||
* to you under the Apache License, Version 2.0 (the | ||
* "License"); you may not use this file except in compliance | ||
* with the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, | ||
* software distributed under the License is distributed on an | ||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
* KIND, either express or implied. See the License for the | ||
* specific language governing permissions and limitations | ||
* under the License. | ||
*/ | ||
#include <mshadow/tensor.h> | ||
#include "./index_array-inl.h" | ||
|
||
namespace mxnet { | ||
namespace op { | ||
|
||
using namespace mshadow::cuda; | ||
|
||
/*!
 * \brief GPU forward pass of index_array: writes, for every element of the
 * input, its coordinates along all axes (or along `param.axes` when given).
 * Mirrors IndexArrayForwardCPU, plus host-to-device copies of the small
 * shape/axes metadata the kernels need.
 */
void IndexArrayForwardGPU(const nnvm::NodeAttrs &attrs,
                          const OpContext &ctx,
                          const std::vector<TBlob> &inputs,
                          const std::vector<OpReqType> &req,
                          const std::vector<TBlob> &outputs) {
  using namespace mshadow;
  CHECK_EQ(inputs.size(), 1U);
  CHECK_EQ(outputs.size(), 1U);
  CHECK_EQ(req.size(), 1U);
  const TBlob& in_data = inputs[0];
  const TBlob& out_data = outputs[0];

  const IndexArrayParam& param = nnvm::get<IndexArrayParam>(attrs.parsed);

  const TShape inshape = in_data.shape_;
  const int ndim = inshape.ndim();

  Stream<gpu> *stream = ctx.get_stream<gpu>();

  using namespace mxnet_op;

  if (param.axes.has_value()) {
    const mxnet::Tuple<int>& axes = param.axes.value();
    const int naxes = axes.ndim();

    std::vector<int64_t> index_products = IndexArrayComputeIndexProducts(inshape);

    // Build the (upper, lower) suffix-product pairs on the host first,
    // then copy them into device temp space for the kernel.
    std::vector<int64_t> cpu_workspace(2 * naxes);
    IndexArrayBuildSelectedAxesWorkspace(axes, index_products, cpu_workspace.data(), ndim);

    Tensor<gpu, 1, int64_t> workspace =
        ctx.requested[0].get_space_typed<gpu, 1, int64_t>(Shape1(2 * naxes), stream);

    // NOTE(review): synchronous cudaMemcpy blocks the host until the data is
    // on the device, so the launch below sees a populated workspace;
    // cudaMemcpyAsync on `stream` would presumably avoid the stall — verify.
    CUDA_CALL(cudaMemcpy(workspace.dptr_, cpu_workspace.data(), sizeof(int64_t) * (2 * naxes),
                         cudaMemcpyHostToDevice));

    MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, {
      Kernel<IndexArrayKernel<req_type>, gpu>::Launch(stream, in_data.Size(),
          out_data.dptr<int64_t>(), naxes, workspace.dptr_);
    });
  } else {
    // Default path: the kernel indexes the input shape per element, so the
    // shape array itself must live in device memory.
    Tensor<gpu, 1, dim_t> workspace =
        ctx.requested[0].get_space_typed<gpu, 1, dim_t>(Shape1(ndim), stream);

    CUDA_CALL(cudaMemcpy(workspace.dptr_, inshape.data(), sizeof(dim_t) * ndim,
                         cudaMemcpyHostToDevice));

    MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, {
      Kernel<IndexArrayDefaultKernel<req_type>, gpu>::Launch(stream, in_data.Size(),
          out_data.dptr<int64_t>(), ndim, workspace.dptr_);
    });
  }
}
|
||
// Attach the GPU compute function to the operator registered in index_array.cc.
NNVM_REGISTER_OP(_contrib_index_array)
.set_attr<FCompute>("FCompute<gpu>", IndexArrayForwardGPU);
|
||
} // namespace op | ||
} // namespace mxnet |
Oops, something went wrong.