diff --git a/src/operator/mxnet_op.h b/src/operator/mxnet_op.h index c3f6dc6558e3..696006d4df87 100644 --- a/src/operator/mxnet_op.h +++ b/src/operator/mxnet_op.h @@ -176,6 +176,51 @@ inline int get_num_threads(const int N) { LOG(FATAL) << "Unknown type enum " << type; \ } +#define MXNET_NO_FLOAT16_TYPE_SWITCH(type, DType, ...) \ + switch (type) { \ + case mshadow::kFloat32: \ + { \ + typedef float DType; \ + {__VA_ARGS__} \ + } \ + break; \ + case mshadow::kFloat64: \ + { \ + typedef double DType; \ + {__VA_ARGS__} \ + } \ + break; \ + case mshadow::kFloat16: \ + LOG(FATAL) << "This operation does not " \ + "support float16"; \ + break; \ + case mshadow::kUint8: \ + { \ + typedef uint8_t DType; \ + {__VA_ARGS__} \ + } \ + break; \ + case mshadow::kInt8: \ + { \ + typedef int8_t DType; \ + {__VA_ARGS__} \ + } \ + break; \ + case mshadow::kInt32: \ + { \ + typedef int32_t DType; \ + {__VA_ARGS__} \ + } \ + break; \ + case mshadow::kInt64: \ + { \ + typedef int64_t DType; \ + {__VA_ARGS__} \ + } \ + break; \ + default: \ + LOG(FATAL) << "Unknown type enum " << type; \ + } /*! * \brief assign the val to out according diff --git a/src/operator/tensor/ordering_op-inl.h b/src/operator/tensor/ordering_op-inl.h index 105ee8b90db8..4e0bc4237853 100644 --- a/src/operator/tensor/ordering_op-inl.h +++ b/src/operator/tensor/ordering_op-inl.h @@ -58,6 +58,7 @@ struct TopKParam : public dmlc::Parameter { int k; int ret_typ; bool is_ascend; + int dtype; DMLC_DECLARE_PARAMETER(TopKParam) { DMLC_DECLARE_FIELD(axis).set_default(dmlc::optional(-1)) .describe("Axis along which to choose the top k indices." @@ -79,6 +80,16 @@ struct TopKParam : public dmlc::Parameter { DMLC_DECLARE_FIELD(is_ascend).set_default(false) .describe("Whether to choose k largest or k smallest elements." " Top K largest elements will be chosen if set to false."); + DMLC_DECLARE_FIELD(dtype) + .add_enum("uint8", mshadow::kUint8) + .add_enum("int32", mshadow::kInt32) + .add_enum("float16", mshadow::kFloat16) + .add_enum("float32", mshadow::kFloat32) + .add_enum("float64", mshadow::kFloat64) + .set_default(mshadow::kFloat32) + .describe("DType of the output indices when ret_typ is \"indices\" or \"both\". " + "An error will be raised if the selected data type cannot precisely represent the " + "indices."); } }; @@ -97,12 +108,23 @@ struct SortParam : public dmlc::Parameter { struct ArgSortParam : public dmlc::Parameter { dmlc::optional axis; bool is_ascend; + int dtype; DMLC_DECLARE_PARAMETER(ArgSortParam) { DMLC_DECLARE_FIELD(axis).set_default(dmlc::optional(-1)) .describe("Axis along which to sort the input tensor." " If not given, the flattened array is used. Default is -1."); DMLC_DECLARE_FIELD(is_ascend).set_default(true) .describe("Whether to sort in ascending or descending order."); + DMLC_DECLARE_FIELD(dtype) + .add_enum("uint8", mshadow::kUint8) + .add_enum("int32", mshadow::kInt32) + .add_enum("float16", mshadow::kFloat16) + .add_enum("float32", mshadow::kFloat32) + .add_enum("float64", mshadow::kFloat64) + .set_default(mshadow::kFloat32) + .describe("DType of the output indices. It is only valid when ret_typ is \"indices\" or" + " \"both\". 
An error will be raised if the selected data type cannot precisely " + "represent the indices."); } }; @@ -154,19 +176,12 @@ inline void ParseTopKParam(const TShape& src_shape, const TopKParam& param, TSha using namespace mshadow; -template -void TopKSort(const Tensor& dat, - const Tensor& ind, - const Tensor& work, - int K, int N, bool is_ascend, - Stream *s); - -template<> -MSHADOW_FORCE_INLINE void TopKSort(const Tensor& dat, - const Tensor& ind, - const Tensor& work, - int K, int N, bool is_ascend, - Stream *s) { +template +MSHADOW_FORCE_INLINE void TopKSort(const Tensor& dat, + const Tensor& ind, + const Tensor& work, + int K, int N, bool is_ascend, + Stream *s) { // Use full sort when K is relatively large. const bool full_sort(K*8 > N); // Batch size. @@ -174,7 +189,7 @@ MSHADOW_FORCE_INLINE void TopKSort(const Tensor& dat, const int omp_threads(engine::OpenMP::Get()->GetRecommendedOMPThreadCount()); #pragma omp parallel for num_threads(omp_threads) for (int i = 0; i < M; ++i) { - real_t *vals = dat.dptr_; + DType *vals = dat.dptr_; int *indices = ind.dptr_+i*N; if (is_ascend) { if (full_sort) { @@ -193,7 +208,7 @@ MSHADOW_FORCE_INLINE void TopKSort(const Tensor& dat, [&](const int& i1, const int& i2){ return vals[i1] > vals[i2]; }); } } - real_t *buff = reinterpret_cast(work.dptr_)+i*K; + DType *buff = reinterpret_cast(work.dptr_)+i*K; for (int j = 0; j < K; ++j) { buff[j] = vals[indices[j]]; } @@ -285,12 +300,12 @@ __global__ void PartialSortSmallK(int K, int N, DType *val, int *ind, bool is_as } } -template<> -MSHADOW_FORCE_INLINE void TopKSort(const Tensor& dat, - const Tensor& ind, - const Tensor& work, - int K, int N, bool is_ascend, - Stream *s) { +template +MSHADOW_FORCE_INLINE void TopKSort(const Tensor& dat, + const Tensor& ind, + const Tensor& work, + int K, int N, bool is_ascend, + Stream *s) { // Use full sort for all but very small K for which we // can do a partial sort entirely within shared memory. const bool full_sort(K > 5); @@ -311,7 +326,7 @@ MSHADOW_FORCE_INLINE void TopKSort(const Tensor& dat, } } else { const int nthreads(mshadow::cuda::kBaseThreadNum); - PartialSortSmallK<<::GetStream(s)>>> (K, N, dat.dptr_, ind.dptr_, is_ascend); } @@ -331,25 +346,25 @@ MSHADOW_FORCE_INLINE void TopKSort(const Tensor& dat, * \param k the K elements to keep * \param param the topk parameters * \tparam xpu the device type. + * \tparam DType type of the output value/mask. + * \tparam IDType type of the output indices. */ -template -void TopKImpl(RunContext ctx, - Resource resource, +template +void TopKImpl(const RunContext &ctx, + const Resource &resource, + const std::vector& req, const TBlob& src, const std::vector& ret, const TopKParam& param) { using namespace mshadow; using namespace mshadow::expr; - for (auto ret_ele : ret) { - CHECK_EQ(ret_ele.type_flag_, src.type_flag_); - } // 1. 
Parse and initialize information Stream *s = ctx.get_stream(); Tensor workspace; Tensor temp_workspace; - Tensor sorted_dat; + Tensor sorted_dat; Tensor indices, sel_indices; - Tensor mask_val; + Tensor mask_val; int batch_size, element_num; // number of batches + the size of each batch int axis = 0; bool do_transpose = false; @@ -358,25 +373,29 @@ void TopKImpl(RunContext ctx, TShape target_shape; ParseTopKParam(src.shape_, param, &target_shape, &batch_size, &element_num, &axis, &k, &do_transpose, &is_ascend); - Tensor dat = src.FlatTo3D(axis, axis, s); + CHECK_LE(element_num, mxnet::common::MaxIntegerValue()) + << "'IDType' does not have a sufficient precision to represent the indices of the input array. " + << "The total element_num is " << element_num << ", but the selected IDType can only represent " + << mxnet::common::MaxIntegerValue() << " elements"; + Tensor dat = src.FlatTo3D(axis, axis, s); size_t temp_size = 0; // Temp space needed by the gpu-based full sorts. temp_size = std::max(temp_size, mxnet::op::SortByKeyWorkspaceSize(src.Size())); - temp_size = std::max(temp_size, mxnet::op::SortByKeyWorkspaceSize(src.Size())); - temp_size = std::max(temp_size, mxnet::op::SortByKeyWorkspaceSize(src.Size())); + temp_size = std::max(temp_size, mxnet::op::SortByKeyWorkspaceSize(src.Size())); + temp_size = std::max(temp_size, mxnet::op::SortByKeyWorkspaceSize(src.Size())); // Additional temp space for gpu full sorts for batch ids. temp_size += sizeof(int) * src.Size(); // Temp space for cpu sorts. - temp_size = std::max(temp_size, sizeof(real_t) * src.Size()); - size_t workspace_size = temp_size + sizeof(real_t) * src.Size() + sizeof(int) * src.Size(); + temp_size = std::max(temp_size, sizeof(DType) * src.Size()); + size_t workspace_size = temp_size + sizeof(DType) * src.Size() + sizeof(int) * src.Size(); if (param.ret_typ == topk_enum::kReturnMask) { - workspace_size += sizeof(int) * batch_size * k + sizeof(real_t) * batch_size * k; + workspace_size += sizeof(int) * batch_size * k + sizeof(DType) * batch_size * k; } workspace = resource.get_space_typed(Shape1(workspace_size), s); char* workspace_curr_ptr = workspace.dptr_; - sorted_dat = Tensor(reinterpret_cast(workspace_curr_ptr), + sorted_dat = Tensor(reinterpret_cast(workspace_curr_ptr), Shape1(src.Size()), s); // contain sorted dat - workspace_curr_ptr += sizeof(real_t) * src.Size(); + workspace_curr_ptr += sizeof(DType) * src.Size(); indices = Tensor(reinterpret_cast(workspace_curr_ptr), Shape1(src.Size()), s); // indices in the original matrix workspace_curr_ptr += sizeof(int) * src.Size(); @@ -394,10 +413,10 @@ void TopKImpl(RunContext ctx, sel_indices = Tensor(reinterpret_cast(workspace_curr_ptr), Shape1(batch_size * k), s); workspace_curr_ptr += sizeof(int) * batch_size * k; - mask_val = Tensor(reinterpret_cast(workspace_curr_ptr), + mask_val = Tensor(reinterpret_cast(workspace_curr_ptr), Shape2(batch_size * k, 1), s); - workspace_curr_ptr += sizeof(real_t) * batch_size * k; - mask_val = scalar(1); + workspace_curr_ptr += sizeof(DType) * batch_size * k; + mask_val = scalar(1); CHECK_EQ(sel_indices.CheckContiguous(), true); CHECK_EQ(mask_val.CheckContiguous(), true); } @@ -411,9 +430,9 @@ void TopKImpl(RunContext ctx, // 3. 
Assign results to the ret blob if (param.ret_typ == topk_enum::kReturnMask) { - Tensor ret_mask = - ret[0].get_with_shape(Shape2(ret[0].Size(), 1), s); - ret_mask = scalar(0); + Tensor ret_mask = + ret[0].get_with_shape(Shape2(ret[0].Size(), 1), s); + ret_mask = scalar(0); sel_indices = reshape(slice<1>( inplace_reshape(indices, Shape2(batch_size, @@ -425,49 +444,56 @@ void TopKImpl(RunContext ctx, sel_indices = transpose_indices(sel_indices, Shape3(src_shape[0], src_shape[2], src_shape[1]), Shape3(0, 2, 1)); } - IndexFill(ret_mask, sel_indices, mask_val); + if (req[0] == kNullOp) { + return; + } else if (req[0] == kWriteTo) { + IndexFill(ret_mask, sel_indices, mask_val); + } else { + LOG(FATAL) << "req=" << req[0] << " is not supported yet."; + } } else if (param.ret_typ == topk_enum::kReturnIndices) { indices = F(indices, element_num); if (do_transpose) { - Tensor ret_indices = ret[0].FlatTo3D(axis, axis, s); - ret_indices = tcast(transpose( + Tensor ret_indices = ret[0].FlatTo3D(axis, axis, s); + Assign(ret_indices, req[0], tcast(transpose( slice<2>(inplace_reshape(indices, Shape3(ret_indices.shape_[0], ret_indices.shape_[2], element_num)), 0, k), - Shape3(0, 2, 1))); + Shape3(0, 2, 1)))); } else { - Tensor ret_indices = - ret[0].get_with_shape(Shape2(batch_size, k), s); - ret_indices = tcast(slice<1>( - inplace_reshape(indices, Shape2(batch_size, element_num)), 0, k)); + Tensor ret_indices = + ret[0].get_with_shape(Shape2(batch_size, k), s); + Assign(ret_indices, req[0], tcast(slice<1>( + inplace_reshape(indices, Shape2(batch_size, element_num)), 0, k))); } } else { indices = F(indices, element_num); if (do_transpose) { - Tensor ret_value = ret[0].FlatTo3D(axis, axis, s); - Tensor ret_indices = ret[1].FlatTo3D(axis, axis, s); - ret_value = transpose( + Tensor ret_value = ret[0].FlatTo3D(axis, axis, s); + Tensor ret_indices = ret[1].FlatTo3D(axis, axis, s); + Assign(ret_value, req[0], transpose( slice<2>(inplace_reshape(sorted_dat, Shape3(ret_value.shape_[0], ret_value.shape_[2], element_num)), 0, k), - Shape3(0, 2, 1)); - ret_indices = tcast(transpose( + Shape3(0, 2, 1))); + Assign(ret_indices, req[1], tcast(transpose( slice<2>(inplace_reshape(indices, Shape3(ret_indices.shape_[0], ret_indices.shape_[2], element_num)), 0, k), - Shape3(0, 2, 1))); + Shape3(0, 2, 1)))); } else { - Tensor ret_value = - ret[0].get_with_shape(Shape2(batch_size, k), s); - Tensor ret_indices = - ret[1].get_with_shape(Shape2(batch_size, k), s); - ret_value = slice<1>(inplace_reshape(sorted_dat, Shape2(batch_size, element_num)), 0, k); - ret_indices = tcast(slice<1>( - inplace_reshape(indices, Shape2(batch_size, element_num)), 0, k)); + Tensor ret_value = + ret[0].get_with_shape(Shape2(batch_size, k), s); + Tensor ret_indices = + ret[1].get_with_shape(Shape2(batch_size, k), s); + Assign(ret_value, req[0], + slice<1>(inplace_reshape(sorted_dat, Shape2(batch_size, element_num)), 0, k)); + Assign(ret_indices, req[1], tcast(slice<1>( + inplace_reshape(indices, Shape2(batch_size, element_num)), 0, k))); } } } @@ -479,9 +505,17 @@ void TopK(const nnvm::NodeAttrs& attrs, const std::vector& req, const std::vector& outputs) { const TopKParam& param = nnvm::get(attrs.parsed); - // TODO(sxjscience) We can support inplace in the future - CHECK_EQ(req[0], kWriteTo) << "TopK does not support inplace"; - TopKImpl(ctx.run_ctx, ctx.requested[0], inputs[0], outputs, param); + if (param.ret_typ == topk_enum::kReturnIndices || param.ret_typ == topk_enum::kReturnBoth) { + MXNET_NO_FLOAT16_TYPE_SWITCH(inputs[0].type_flag_, DType, { 
+ MSHADOW_TYPE_SWITCH(param.dtype, IDType, { + TopKImpl(ctx.run_ctx, ctx.requested[0], req, inputs[0], outputs, param); + }) + }); + } else { + MXNET_NO_FLOAT16_TYPE_SWITCH(inputs[0].type_flag_, DType, { + TopKImpl(ctx.run_ctx, ctx.requested[0], req, inputs[0], outputs, param); + }); + } } template @@ -491,13 +525,14 @@ void Sort(const nnvm::NodeAttrs& attrs, const std::vector& req, const std::vector& outputs) { const SortParam& param = nnvm::get(attrs.parsed); - CHECK_EQ(req[0], kWriteTo) << "Sort does not support inplace"; TopKParam topk_param; topk_param.axis = param.axis; topk_param.is_ascend = param.is_ascend; topk_param.k = 0; topk_param.ret_typ = topk_enum::kReturnValue; - TopKImpl(ctx.run_ctx, ctx.requested[0], inputs[0], outputs, topk_param); + MXNET_NO_FLOAT16_TYPE_SWITCH(inputs[0].type_flag_, DType, { + TopKImpl(ctx.run_ctx, ctx.requested[0], req, inputs[0], outputs, topk_param); + }); } template @@ -507,26 +542,30 @@ void ArgSort(const nnvm::NodeAttrs& attrs, const std::vector& req, const std::vector& outputs) { const ArgSortParam& param = nnvm::get(attrs.parsed); - CHECK_EQ(req[0], kWriteTo) << "ArgSort does not support inplace"; TopKParam topk_param; topk_param.axis = param.axis; topk_param.is_ascend = param.is_ascend; topk_param.k = 0; + topk_param.dtype = param.dtype; topk_param.ret_typ = topk_enum::kReturnIndices; - TopKImpl(ctx.run_ctx, ctx.requested[0], inputs[0], outputs, topk_param); + MXNET_NO_FLOAT16_TYPE_SWITCH(inputs[0].type_flag_, DType, { + MSHADOW_TYPE_SWITCH(param.dtype, IDType, { + TopKImpl(ctx.run_ctx, + ctx.requested[0], req, inputs[0], outputs, topk_param); + }); + }); } -template -void TopKBackward_(const nnvm::NodeAttrs& attrs, - const OpContext& ctx, - const std::vector& inputs, - const std::vector& req, - const std::vector& outputs) { +template +void TopKBackwardImpl(const OpContext &ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs, + const TopKParam& param) { CHECK_NE(req[0], kWriteInplace); using namespace mshadow; using namespace mshadow::expr; Stream *s = ctx.run_ctx.get_stream(); - const TopKParam& param = nnvm::get(attrs.parsed); CHECK(param.ret_typ == topk_enum::kReturnValue || param.ret_typ == topk_enum::kReturnBoth); int batch_size, element_num; // number of batches + the size of each batch int axis = 0; @@ -536,23 +575,28 @@ void TopKBackward_(const nnvm::NodeAttrs& attrs, TShape target_shape; ParseTopKParam(outputs[0].shape_, param, &target_shape, &batch_size, &element_num, &axis, &k, &do_transpose, &is_ascend); - Tensor workspace = - ctx.requested[0].get_space_typed(Shape1(batch_size * k * 2 + batch_size), s); - Tensor sel_indices = - Tensor(workspace.dptr_, Shape1(batch_size * k), s); - Tensor batch_shift = - Tensor(workspace.dptr_ + batch_size * k, Shape1(batch_size), s); - Tensor dummy_index = - Tensor(workspace.dptr_ + batch_size * k + batch_size, + CHECK_LE(element_num, mxnet::common::MaxIntegerValue()) + << "'IDType' does not have a sufficient precision to represent the indices of the input array. 
" + << "The total element_num is " << element_num << ", but the selected IDType can only represent " + << mxnet::common::MaxIntegerValue() << " elements"; + Tensor workspace = + ctx.requested[0].get_space_typed(Shape1(batch_size * k * 2 + batch_size), s); + Tensor sel_indices = + Tensor(workspace.dptr_, Shape1(batch_size * k), s); + Tensor batch_shift = + Tensor(workspace.dptr_ + batch_size * k, Shape1(batch_size), s); + Tensor dummy_index = + Tensor(workspace.dptr_ + batch_size * k + batch_size, Shape1(batch_size * k), s); - Tensor out_grad = - inputs[0].get_with_shape(Shape2(inputs[0].shape_.Size(), 1), s); - Tensor in_grad = - outputs[0].get_with_shape(Shape2(outputs[0].shape_.Size(), 1), s); - mxnet_op::Kernel::Launch(s, batch_size, 1, 0.0f, - static_cast(element_num), kWriteTo, batch_shift.dptr_); + + Tensor out_grad = + inputs[0].get_with_shape(Shape2(inputs[0].shape_.Size(), 1), s); + Tensor in_grad = + outputs[0].get_with_shape(Shape2(outputs[0].shape_.Size(), 1), s); + mxnet_op::Kernel::Launch(s, batch_size, 1, 0, element_num, kWriteTo, + batch_shift.dptr_); if (do_transpose) { - Tensor indices = inputs[2].FlatTo1D(s); + Tensor indices = inputs[2].FlatTo1D(s); TShape src_shape = outputs[0].shape_.FlatTo3D(axis); sel_indices = reshape(transpose( broadcast_to(inplace_reshape(batch_shift, @@ -560,26 +604,26 @@ void TopKBackward_(const nnvm::NodeAttrs& attrs, TShape(Shape3(src_shape[0], src_shape[2], k))), Shape3(0, 2, 1)), Shape1(batch_size * k)); - sel_indices += indices; + sel_indices += tcast(indices); sel_indices = transpose_indices(sel_indices, Shape3(src_shape[0], src_shape[2], src_shape[1]), Shape3(0, 2, 1)); } else { - Tensor indices = - inputs[2].get_with_shape(Shape2(batch_size, k), s); - sel_indices = reshape(indices + + Tensor indices = + inputs[2].get_with_shape(Shape2(batch_size, k), s); + sel_indices = reshape(tcast(indices) + broadcast_to(inplace_reshape(batch_shift, Shape2(batch_size, 1)), TShape(Shape2(batch_size, k))), Shape1(batch_size * k)); } CHECK_EQ(sel_indices.CheckContiguous(), true); if (kWriteTo == req[0]) { - in_grad = scalar(0); + in_grad = scalar(0); IndexFill(in_grad, sel_indices, out_grad); } else if (kAddTo == req[0]) { // TODO(sxjscience) We can use AddTakeGrad in the future. // However, the current implementation of AddTakeGrad is not so efficient. 
- mxnet_op::Kernel::Launch(s, sel_indices.shape_.Size(), 1, 0.0f, - 1.0f, kWriteTo, dummy_index.dptr_); + mxnet_op::Kernel::Launch(s, sel_indices.shape_.Size(), 1, 0, 1, kWriteTo, + dummy_index.dptr_); mxnet::op::AddTakeGradLargeBatch(in_grad, sel_indices, dummy_index, out_grad); } else if (kNullOp == req[0]) { return; @@ -588,6 +632,28 @@ void TopKBackward_(const nnvm::NodeAttrs& attrs, } } +template +void TopKBackward_(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + const TopKParam& param = nnvm::get(attrs.parsed); + if (param.ret_typ == topk_enum::kReturnBoth) { + MXNET_NO_FLOAT16_TYPE_SWITCH(inputs[0].type_flag_, DType, { + MSHADOW_TYPE_SWITCH(param.dtype, IDType, { + TopKBackwardImpl(ctx, inputs, req, outputs, param); + }); + }); + } else if (param.ret_typ == topk_enum::kReturnValue) { + MXNET_NO_FLOAT16_TYPE_SWITCH(inputs[0].type_flag_, DType, { + TopKBackwardImpl(ctx, inputs, req, outputs, param); + }); + } else { + LOG(FATAL) << "Not Implemented"; + } +} + inline uint32_t TopKNumOutputs(const NodeAttrs& attrs) { const TopKParam& param = nnvm::get(attrs.parsed); if (param.ret_typ == topk_enum::kReturnIndices || @@ -610,8 +676,36 @@ inline uint32_t TopKNumVisibleOutputs(const NodeAttrs& attrs) { inline bool TopKType(const nnvm::NodeAttrs& attrs, std::vector *in_attrs, std::vector *out_attrs) { - return ElemwiseAttr( - attrs, in_attrs, out_attrs, -1); + const TopKParam& param = nnvm::get(attrs.parsed); + int data_type = -1; + size_t in_size = in_attrs->size(); + size_t out_size = out_attrs->size(); + CHECK_EQ(in_size, 1); + CHECK(out_size == 1 || out_size == 2); + if (out_size > 1) { + if (param.ret_typ == topk_enum::kReturnValue) { + CHECK(type_assign(&(*out_attrs)[1], mshadow::kInt32)) + << "Failed to set the type of ret_indices."; + } else { + CHECK(type_assign(&(*out_attrs)[1], param.dtype)) + << "Failed to set the type of ret_indices."; + } + } + if (param.ret_typ == topk_enum::kReturnIndices) { + CHECK(type_assign(&(*out_attrs)[0], param.dtype)) + << "Failed to set the type of ret_indices."; + } else { + CHECK(type_assign(&data_type, (*in_attrs)[0])) << "Incompatible dtype of input, in_attrs[0]=" + << (*in_attrs)[0]; + CHECK(type_assign(&data_type, (*out_attrs)[0])) << "Incompatible dtype of output, out_attrs[0]=" + << (*out_attrs)[0]; + CHECK(type_assign(&(*in_attrs)[0], data_type)) << "Incompatible dtype of input, in_attrs[0]=" + << (*in_attrs)[0]; + CHECK(type_assign(&(*out_attrs)[0], data_type)) << "Incompatible dtype of output, out_attrs[0]=" + << (*out_attrs)[0]; + if (data_type == -1) return false; + } + return true; } inline bool TopKShapeImpl(const TopKParam& param, @@ -650,6 +744,28 @@ inline bool TopKShape(const nnvm::NodeAttrs& attrs, return TopKShapeImpl(param, in_attrs, out_attrs); } +inline bool SortType(const nnvm::NodeAttrs& attrs, + std::vector *in_attrs, + std::vector *out_attrs) { + int data_type = -1; + size_t in_size = in_attrs->size(); + size_t out_size = out_attrs->size(); + CHECK_EQ(in_size, 1); + CHECK_EQ(out_size, 2); + CHECK(type_assign(&(*out_attrs)[1], mshadow::kInt32)) + << "Failed to set the type of ret_indices to int32."; + CHECK(type_assign(&data_type, (*in_attrs)[0])) << "Incompatible dtype of input, in_attrs[0]=" + << (*in_attrs)[0]; + CHECK(type_assign(&data_type, (*out_attrs)[0])) << "Incompatible dtype of output, out_attrs[0]=" + << (*out_attrs)[0]; + CHECK(type_assign(&(*in_attrs)[0], data_type)) << "Incompatible dtype of input, in_attrs[0]=" + 
<< (*in_attrs)[0]; + CHECK(type_assign(&(*out_attrs)[0], data_type)) << "Incompatible dtype of output, out_attrs[0]=" + << (*out_attrs)[0]; + if (data_type == -1) return false; + return true; +} + inline bool SortShape(const nnvm::NodeAttrs& attrs, std::vector *in_attrs, std::vector *out_attrs) { @@ -662,6 +778,15 @@ inline bool SortShape(const nnvm::NodeAttrs& attrs, return TopKShapeImpl(topk_param, in_attrs, out_attrs); } +inline bool ArgSortType(const nnvm::NodeAttrs& attrs, + std::vector *in_attrs, + std::vector *out_attrs) { + const ArgSortParam& param = nnvm::get(attrs.parsed); + CHECK(type_assign(&(*out_attrs)[0], param.dtype)) + << "Failed to set the type of ret_indices to int32."; + return true; +} + inline bool ArgSortShape(const nnvm::NodeAttrs& attrs, std::vector *in_attrs, std::vector *out_attrs) { diff --git a/src/operator/tensor/ordering_op.cc b/src/operator/tensor/ordering_op.cc index ebd7c62ec886..f30be6c12b8e 100644 --- a/src/operator/tensor/ordering_op.cc +++ b/src/operator/tensor/ordering_op.cc @@ -128,7 +128,7 @@ Examples:: .set_num_outputs(2) .set_attr_parser(ParamParser) .set_attr("FInferShape", SortShape) -.set_attr("FInferType", ElemwiseType<1, 2>) +.set_attr("FInferType", SortType) .set_attr("FNumVisibleOutputs", [](const NodeAttrs& attrs) { return 1; }) .set_attr("FCompute", Sort) .set_attr("FGradient", @@ -178,7 +178,7 @@ Examples:: .set_num_outputs(1) .set_attr_parser(ParamParser) .set_attr("FInferShape", ArgSortShape) -.set_attr("FInferType", ElemwiseType<1, 1>) +.set_attr("FInferType", ArgSortType) .set_attr("FCompute", ArgSort) .set_attr("FGradient", MakeZeroGradNodes) .set_attr("FResourceRequest", diff --git a/tests/python/unittest/test_ndarray.py b/tests/python/unittest/test_ndarray.py index aeaa0b726794..30abd5a9a3e7 100644 --- a/tests/python/unittest/test_ndarray.py +++ b/tests/python/unittest/test_ndarray.py @@ -639,106 +639,147 @@ def gt_topk(dat, axis, ret_typ, k, is_ascend): # values, making it hard to generate a numpy 'golden copy' to compare against # the mxnet operator. The 'mask' function is particularly hard to test given that # equal values might span the 'k' boundary. Issue exposed with seed 1405838964. 
- def get_values(ensure_unique): - while True: - data = np.float32(np.random.normal(size=(dat_size, dat_size, dat_size, dat_size))) - if not ensure_unique: - return data - num_unique_values = len(set(data.flatten())) - if data.size == num_unique_values: - return data - - a_npy = get_values(ensure_unique=True) - a_nd = mx.nd.array(a_npy, ctx=ctx) - - # test for ret_typ=indices - nd_ret_topk = mx.nd.topk(a_nd, axis=1, ret_typ="indices", k=3, is_ascend=True).asnumpy() - gt = gt_topk(a_npy, axis=1, ret_typ="indices", k=3, is_ascend=True) - assert_almost_equal(nd_ret_topk, gt) - nd_ret_topk = mx.nd.topk(a_nd, axis=3, ret_typ="indices", k=2, is_ascend=False).asnumpy() - gt = gt_topk(a_npy, axis=3, ret_typ="indices", k=2, is_ascend=False) - assert_almost_equal(nd_ret_topk, gt) - nd_ret_topk = mx.nd.topk(a_nd, axis=None, ret_typ="indices", k=21, is_ascend=False).asnumpy() - gt = gt_topk(a_npy, axis=None, ret_typ="indices", k=21, is_ascend=False) - assert_almost_equal(nd_ret_topk, gt) - - # test for ret_typ=value - nd_ret_topk = mx.nd.topk(a_nd, axis=1, ret_typ="value", k=3, is_ascend=True).asnumpy() - gt = gt_topk(a_npy, axis=1, ret_typ="value", k=3, is_ascend=True) - assert_almost_equal(nd_ret_topk, gt) - nd_ret_topk = mx.nd.topk(a_nd, axis=3, ret_typ="value", k=2, is_ascend=False).asnumpy() - gt = gt_topk(a_npy, axis=3, ret_typ="value", k=2, is_ascend=False) - assert_almost_equal(nd_ret_topk, gt) - nd_ret_topk = mx.nd.topk(a_nd, axis=None, ret_typ="value", k=21, is_ascend=False).asnumpy() - gt = gt_topk(a_npy, axis=None, ret_typ="value", k=21, is_ascend=False) - assert_almost_equal(nd_ret_topk, gt) - - # test for ret_typ=mask - nd_ret_topk = mx.nd.topk(a_nd, axis=1, ret_typ="mask", k=3, is_ascend=True).asnumpy() - gt = gt_topk(a_npy, axis=1, ret_typ="mask", k=3, is_ascend=True) - assert_almost_equal(nd_ret_topk, gt) - nd_ret_topk = mx.nd.topk(a_nd, axis=1, ret_typ="mask", k=2, is_ascend=False).asnumpy() - gt = gt_topk(a_npy, axis=1, ret_typ="mask", k=2, is_ascend=False) - assert_almost_equal(nd_ret_topk, gt) - nd_ret_topk = mx.nd.topk(a_nd, axis=None, ret_typ="mask", k=21, is_ascend=False).asnumpy() - gt = gt_topk(a_npy, axis=None, ret_typ="mask", k=21, is_ascend=False) - assert_almost_equal(nd_ret_topk, gt) - - # test for ret_typ=both - nd_ret_topk_val, nd_ret_topk_ind = mx.nd.topk(a_nd, axis=1, ret_typ="both", k=3, is_ascend=True) - nd_ret_topk_val = nd_ret_topk_val.asnumpy() - nd_ret_topk_ind = nd_ret_topk_ind.asnumpy() - gt_val = gt_topk(a_npy, axis=1, ret_typ="value", k=3, is_ascend=True) - gt_ind = gt_topk(a_npy, axis=1, ret_typ="indices", k=3, is_ascend=True) - assert_almost_equal(nd_ret_topk_val, gt_val) - assert_almost_equal(nd_ret_topk_ind, gt_ind) - - # test for sort - nd_ret_sort = mx.nd.sort(a_nd, axis=1, is_ascend=True).asnumpy() - gt = gt_topk(a_npy, axis=1, ret_typ="value", k=dat_size, is_ascend=True) - assert_almost_equal(nd_ret_sort, gt) - nd_ret_sort = mx.nd.sort(a_nd, axis=None, is_ascend=False).asnumpy() - gt = gt_topk(a_npy, axis=None, ret_typ="value", - k=dat_size*dat_size*dat_size*dat_size, is_ascend=False) - assert_almost_equal(nd_ret_sort, gt) - - # test for argsort - nd_ret_argsort = mx.nd.argsort(a_nd, axis=3, is_ascend=True).asnumpy() - gt = gt_topk(a_npy, axis=3, ret_typ="indices", k=dat_size, is_ascend=True) - assert_almost_equal(nd_ret_argsort, gt) - nd_ret_argsort = mx.nd.argsort(a_nd, axis=None, is_ascend=False).asnumpy() - gt = gt_topk(a_npy, axis=None, ret_typ="indices", - k=dat_size*dat_size*dat_size*dat_size, is_ascend=False) - 
assert_almost_equal(nd_ret_argsort, gt) - - a = mx.nd.arange(0, 1024, step=1, repeat=1) - assert_almost_equal(a.topk(k=1024).asnumpy(), a.asnumpy()[::-1]) + def get_values(ensure_unique, dtype): + if dtype == np.int16 or dtype == np.int32 or dtype == np.int64: + return np.arange(dat_size ** 4, dtype=dtype).reshape((dat_size, dat_size, dat_size, dat_size)) + elif dtype == np.float32 or dtype == np.float64: + while True: + data = np.random.normal(size=(dat_size, dat_size, dat_size, dat_size)).astype(dtype) + if not ensure_unique: + return data + num_unique_values = len(set(data.flatten())) + if data.size == num_unique_values: + return data + else: + raise NotImplementedError + + for dtype in [np.int16, np.int32, np.int64, np.float32, np.float64]: + a_npy = get_values(ensure_unique=True, dtype=dtype) + a_nd = mx.nd.array(a_npy, ctx=ctx) + + # test for ret_typ=indices + nd_ret_topk = mx.nd.topk(a_nd, axis=1, ret_typ="indices", k=3, is_ascend=True).asnumpy() + gt = gt_topk(a_npy, axis=1, ret_typ="indices", k=3, is_ascend=True) + assert_almost_equal(nd_ret_topk, gt) + nd_ret_topk = mx.nd.topk(a_nd, axis=3, ret_typ="indices", k=2, is_ascend=False).asnumpy() + gt = gt_topk(a_npy, axis=3, ret_typ="indices", k=2, is_ascend=False) + assert_almost_equal(nd_ret_topk, gt) + nd_ret_topk = mx.nd.topk(a_nd, axis=None, ret_typ="indices", k=21, is_ascend=False).asnumpy() + gt = gt_topk(a_npy, axis=None, ret_typ="indices", k=21, is_ascend=False) + assert_almost_equal(nd_ret_topk, gt) + + # test for ret_typ=value + nd_ret_topk = mx.nd.topk(a_nd, axis=1, ret_typ="value", k=3, is_ascend=True).asnumpy() + gt = gt_topk(a_npy, axis=1, ret_typ="value", k=3, is_ascend=True) + assert_almost_equal(nd_ret_topk, gt) + nd_ret_topk = mx.nd.topk(a_nd, axis=3, ret_typ="value", k=2, is_ascend=False).asnumpy() + gt = gt_topk(a_npy, axis=3, ret_typ="value", k=2, is_ascend=False) + assert_almost_equal(nd_ret_topk, gt) + nd_ret_topk = mx.nd.topk(a_nd, axis=None, ret_typ="value", k=21, is_ascend=False).asnumpy() + gt = gt_topk(a_npy, axis=None, ret_typ="value", k=21, is_ascend=False) + assert_almost_equal(nd_ret_topk, gt) + + # test for ret_typ=mask + nd_ret_topk = mx.nd.topk(a_nd, axis=1, ret_typ="mask", k=3, is_ascend=True).asnumpy() + gt = gt_topk(a_npy, axis=1, ret_typ="mask", k=3, is_ascend=True) + assert_almost_equal(nd_ret_topk, gt) + nd_ret_topk = mx.nd.topk(a_nd, axis=1, ret_typ="mask", k=2, is_ascend=False).asnumpy() + gt = gt_topk(a_npy, axis=1, ret_typ="mask", k=2, is_ascend=False) + assert_almost_equal(nd_ret_topk, gt) + nd_ret_topk = mx.nd.topk(a_nd, axis=None, ret_typ="mask", k=21, is_ascend=False).asnumpy() + gt = gt_topk(a_npy, axis=None, ret_typ="mask", k=21, is_ascend=False) + assert_almost_equal(nd_ret_topk, gt) + + # test for ret_typ=both + nd_ret_topk_val, nd_ret_topk_ind = mx.nd.topk(a_nd, axis=1, ret_typ="both", k=3, is_ascend=True) + nd_ret_topk_val = nd_ret_topk_val.asnumpy() + nd_ret_topk_ind = nd_ret_topk_ind.asnumpy() + gt_val = gt_topk(a_npy, axis=1, ret_typ="value", k=3, is_ascend=True) + gt_ind = gt_topk(a_npy, axis=1, ret_typ="indices", k=3, is_ascend=True) + assert_almost_equal(nd_ret_topk_val, gt_val) + assert_almost_equal(nd_ret_topk_ind, gt_ind) + # test for kNullOp + _, nd_ret_topk_ind = mx.nd.topk(a_nd, axis=1, ret_typ="both", k=3, is_ascend=True) + nd_ret_topk_ind = nd_ret_topk_ind.asnumpy() + assert_almost_equal(nd_ret_topk_ind, gt_ind) + # test for kNullOp + nd_ret_topk_val, _ = mx.nd.topk(a_nd, axis=1, ret_typ="both", k=3, is_ascend=True) + nd_ret_topk_val = nd_ret_topk_val.asnumpy() + 
assert_almost_equal(nd_ret_topk_val, gt_val) + + # test for sort + nd_ret_sort = mx.nd.sort(a_nd, axis=1, is_ascend=True).asnumpy() + gt = gt_topk(a_npy, axis=1, ret_typ="value", k=dat_size, is_ascend=True) + assert_almost_equal(nd_ret_sort, gt) + nd_ret_sort = mx.nd.sort(a_nd, axis=None, is_ascend=False).asnumpy() + gt = gt_topk(a_npy, axis=None, ret_typ="value", + k=dat_size*dat_size*dat_size*dat_size, is_ascend=False) + assert_almost_equal(nd_ret_sort, gt) + + # test for argsort + for idtype in [np.int32, np.float16, np.float32, np.float64]: + nd_ret_argsort = mx.nd.argsort(a_nd, axis=3, is_ascend=True, dtype=idtype).asnumpy() + gt = gt_topk(a_npy, axis=3, ret_typ="indices", k=dat_size, is_ascend=True) + assert_almost_equal(nd_ret_argsort, gt) + nd_ret_argsort = mx.nd.argsort(a_nd, axis=None, is_ascend=False, dtype=idtype).asnumpy() + gt = gt_topk(a_npy, axis=None, ret_typ="indices", + k=dat_size*dat_size*dat_size*dat_size, is_ascend=False) + assert_almost_equal(nd_ret_argsort, gt) + + # test topk with a big shape + a = mx.nd.arange(0, 1024, step=1, repeat=1, dtype=np.int32) + assert_almost_equal(a.topk(k=1024, dtype=np.int32).asnumpy(), a.asnumpy()[::-1]) + a.attach_grad() + + k = 10 + with mx.autograd.record(): + b = mx.nd.topk(a, k=k, ret_typ='value') + b.backward(mx.nd.ones((k,), dtype=np.int32)) + a_grad = a.grad.asnumpy() + for i in range(-1, - k - 1, -1): + assert a_grad[i] == 1 + + # test topk gradient with a small shape + for dtype in [np.int32, np.int64, np.float32, np.float64]: + a = mx.nd.arange(0, 1000, step=1, repeat=1, dtype=dtype) + a.attach_grad() + k = 10 + ograd = mx.nd.arange(0, k, dtype=dtype) + with mx.autograd.record(): + b = mx.nd.topk(a, k=k, ret_typ='value') + b.backward(ograd) + a_grad = a.grad.asnumpy() + ograd_npy = ograd.asnumpy() + for i in range(-1, - k - 1, -1): + assert a_grad[i] == ograd_npy[-i - 1] + # Repeat those tests that don't involve indices. These should pass even with # duplicated input data values (over many repeated runs with different random seeds, # this will be tested). 
-    a_npy = get_values(ensure_unique=False)
-    a_nd = mx.nd.array(a_npy, ctx=ctx)
-
-    # test for ret_typ=value
-    nd_ret_topk = mx.nd.topk(a_nd, axis=1, ret_typ="value", k=3, is_ascend=True).asnumpy()
-    gt = gt_topk(a_npy, axis=1, ret_typ="value", k=3, is_ascend=True)
-    assert_almost_equal(nd_ret_topk, gt)
-    nd_ret_topk = mx.nd.topk(a_nd, axis=3, ret_typ="value", k=2, is_ascend=False).asnumpy()
-    gt = gt_topk(a_npy, axis=3, ret_typ="value", k=2, is_ascend=False)
-    assert_almost_equal(nd_ret_topk, gt)
-    nd_ret_topk = mx.nd.topk(a_nd, axis=None, ret_typ="value", k=21, is_ascend=False).asnumpy()
-    gt = gt_topk(a_npy, axis=None, ret_typ="value", k=21, is_ascend=False)
-    assert_almost_equal(nd_ret_topk, gt)
-
-    # test for sort
-    nd_ret_sort = mx.nd.sort(a_nd, axis=1, is_ascend=True).asnumpy()
-    gt = gt_topk(a_npy, axis=1, ret_typ="value", k=dat_size, is_ascend=True)
-    assert_almost_equal(nd_ret_sort, gt)
-    nd_ret_sort = mx.nd.sort(a_nd, axis=None, is_ascend=False).asnumpy()
-    gt = gt_topk(a_npy, axis=None, ret_typ="value",
-                 k=dat_size*dat_size*dat_size*dat_size, is_ascend=False)
-    assert_almost_equal(nd_ret_sort, gt)
+    for dtype in [np.int16, np.int32, np.int64, np.float32, np.float64]:
+        a_npy = get_values(ensure_unique=False, dtype=dtype)
+        a_nd = mx.nd.array(a_npy, ctx=ctx)
+
+        # test for ret_typ=value
+        nd_ret_topk = mx.nd.topk(a_nd, axis=1, ret_typ="value", k=3, is_ascend=True).asnumpy()
+        gt = gt_topk(a_npy, axis=1, ret_typ="value", k=3, is_ascend=True)
+        assert_almost_equal(nd_ret_topk, gt)
+        nd_ret_topk = mx.nd.topk(a_nd, axis=3, ret_typ="value", k=2, is_ascend=False).asnumpy()
+        gt = gt_topk(a_npy, axis=3, ret_typ="value", k=2, is_ascend=False)
+        assert_almost_equal(nd_ret_topk, gt)
+        nd_ret_topk = mx.nd.topk(a_nd, axis=None, ret_typ="value", k=21, is_ascend=False).asnumpy()
+        gt = gt_topk(a_npy, axis=None, ret_typ="value", k=21, is_ascend=False)
+        assert_almost_equal(nd_ret_topk, gt)
+
+        # test for sort
+        nd_ret_sort = mx.nd.sort(a_nd, axis=1, is_ascend=True).asnumpy()
+        gt = gt_topk(a_npy, axis=1, ret_typ="value", k=dat_size, is_ascend=True)
+        assert_almost_equal(nd_ret_sort, gt)
+        nd_ret_sort = mx.nd.sort(a_nd, axis=None, is_ascend=False).asnumpy()
+        gt = gt_topk(a_npy, axis=None, ret_typ="value",
+                     k=dat_size*dat_size*dat_size*dat_size, is_ascend=False)
+        assert_almost_equal(nd_ret_sort, gt)
 
 @with_seed()
 def test_ndarray_equal():
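
Usage sketch (reviewer note, not part of the patch): the snippet below illustrates how the new dtype argument introduced above is expected to be used from Python once this change is applied, mirroring the updated unit tests. The array contents and variable names are illustrative only.

    import mxnet as mx
    import numpy as np

    # Integer input data is now accepted by topk/sort/argsort.
    x = mx.nd.array([[1, 3, 2],
                     [4, 0, 5]], dtype=np.int32)

    # Indices can now be returned in a chosen dtype instead of the former float32-only output.
    idx = mx.nd.topk(x, axis=1, k=2, ret_typ='indices', dtype=np.int32)
    val, ind = mx.nd.topk(x, axis=1, k=2, ret_typ='both', dtype=np.int32)
    order = mx.nd.argsort(x, axis=1, is_ascend=True, dtype=np.float64)

    assert idx.dtype == np.int32 and ind.dtype == np.int32
    assert order.dtype == np.float64
    assert val.dtype == np.int32   # returned values keep the input dtype

Per the CHECK_LE added in TopKImpl/TopKBackwardImpl, an error is raised if the requested index dtype cannot exactly represent every position along the sorted axis, and float16 input data is rejected by MXNET_NO_FLOAT16_TYPE_SWITCH.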