From 226aee367a7c189193321ead4e48adc9d8bac806 Mon Sep 17 00:00:00 2001 From: sxjscience Date: Mon, 4 Jun 2018 16:32:55 +0800 Subject: [PATCH 01/14] set the dtype of index to be int32 for ordering ops --- src/operator/tensor/ordering_op-inl.h | 274 +++++++++++++++----------- src/operator/tensor/ordering_op.cc | 4 +- tests/python/unittest/test_ndarray.py | 219 +++++++++++--------- 3 files changed, 284 insertions(+), 213 deletions(-) diff --git a/src/operator/tensor/ordering_op-inl.h b/src/operator/tensor/ordering_op-inl.h index 105ee8b90db8..0da54b8383d8 100644 --- a/src/operator/tensor/ordering_op-inl.h +++ b/src/operator/tensor/ordering_op-inl.h @@ -154,19 +154,12 @@ inline void ParseTopKParam(const TShape& src_shape, const TopKParam& param, TSha using namespace mshadow; -template -void TopKSort(const Tensor& dat, - const Tensor& ind, - const Tensor& work, - int K, int N, bool is_ascend, - Stream *s); - -template<> -MSHADOW_FORCE_INLINE void TopKSort(const Tensor& dat, - const Tensor& ind, - const Tensor& work, - int K, int N, bool is_ascend, - Stream *s) { +template +MSHADOW_FORCE_INLINE void TopKSort(const Tensor& dat, + const Tensor& ind, + const Tensor& work, + int K, int N, bool is_ascend, + Stream *s) { // Use full sort when K is relatively large. const bool full_sort(K*8 > N); // Batch size. @@ -174,7 +167,7 @@ MSHADOW_FORCE_INLINE void TopKSort(const Tensor& dat, const int omp_threads(engine::OpenMP::Get()->GetRecommendedOMPThreadCount()); #pragma omp parallel for num_threads(omp_threads) for (int i = 0; i < M; ++i) { - real_t *vals = dat.dptr_; + DType *vals = dat.dptr_; int *indices = ind.dptr_+i*N; if (is_ascend) { if (full_sort) { @@ -193,7 +186,7 @@ MSHADOW_FORCE_INLINE void TopKSort(const Tensor& dat, [&](const int& i1, const int& i2){ return vals[i1] > vals[i2]; }); } } - real_t *buff = reinterpret_cast(work.dptr_)+i*K; + DType *buff = reinterpret_cast(work.dptr_)+i*K; for (int j = 0; j < K; ++j) { buff[j] = vals[indices[j]]; } @@ -285,12 +278,12 @@ __global__ void PartialSortSmallK(int K, int N, DType *val, int *ind, bool is_as } } -template<> -MSHADOW_FORCE_INLINE void TopKSort(const Tensor& dat, - const Tensor& ind, - const Tensor& work, - int K, int N, bool is_ascend, - Stream *s) { +template +MSHADOW_FORCE_INLINE void TopKSort(const Tensor& dat, + const Tensor& ind, + const Tensor& work, + int K, int N, bool is_ascend, + Stream *s) { // Use full sort for all but very small K for which we // can do a partial sort entirely within shared memory. const bool full_sort(K > 5); @@ -311,7 +304,7 @@ MSHADOW_FORCE_INLINE void TopKSort(const Tensor& dat, } } else { const int nthreads(mshadow::cuda::kBaseThreadNum); - PartialSortSmallK<<::GetStream(s)>>> (K, N, dat.dptr_, ind.dptr_, is_ascend); } @@ -332,24 +325,22 @@ MSHADOW_FORCE_INLINE void TopKSort(const Tensor& dat, * \param param the topk parameters * \tparam xpu the device type. */ -template +template void TopKImpl(RunContext ctx, Resource resource, + const std::vector& req, const TBlob& src, const std::vector& ret, const TopKParam& param) { using namespace mshadow; using namespace mshadow::expr; - for (auto ret_ele : ret) { - CHECK_EQ(ret_ele.type_flag_, src.type_flag_); - } // 1. 
Parse and initialize information Stream *s = ctx.get_stream(); Tensor workspace; Tensor temp_workspace; - Tensor sorted_dat; + Tensor sorted_dat; Tensor indices, sel_indices; - Tensor mask_val; + Tensor mask_val; int batch_size, element_num; // number of batches + the size of each batch int axis = 0; bool do_transpose = false; @@ -358,25 +349,25 @@ void TopKImpl(RunContext ctx, TShape target_shape; ParseTopKParam(src.shape_, param, &target_shape, &batch_size, &element_num, &axis, &k, &do_transpose, &is_ascend); - Tensor dat = src.FlatTo3D(axis, axis, s); + Tensor dat = src.FlatTo3D(axis, axis, s); size_t temp_size = 0; // Temp space needed by the gpu-based full sorts. temp_size = std::max(temp_size, mxnet::op::SortByKeyWorkspaceSize(src.Size())); - temp_size = std::max(temp_size, mxnet::op::SortByKeyWorkspaceSize(src.Size())); - temp_size = std::max(temp_size, mxnet::op::SortByKeyWorkspaceSize(src.Size())); + temp_size = std::max(temp_size, mxnet::op::SortByKeyWorkspaceSize(src.Size())); + temp_size = std::max(temp_size, mxnet::op::SortByKeyWorkspaceSize(src.Size())); // Additional temp space for gpu full sorts for batch ids. temp_size += sizeof(int) * src.Size(); // Temp space for cpu sorts. - temp_size = std::max(temp_size, sizeof(real_t) * src.Size()); - size_t workspace_size = temp_size + sizeof(real_t) * src.Size() + sizeof(int) * src.Size(); + temp_size = std::max(temp_size, sizeof(DType) * src.Size()); + size_t workspace_size = temp_size + sizeof(DType) * src.Size() + sizeof(int) * src.Size(); if (param.ret_typ == topk_enum::kReturnMask) { - workspace_size += sizeof(int) * batch_size * k + sizeof(real_t) * batch_size * k; + workspace_size += sizeof(int) * batch_size * k + sizeof(DType) * batch_size * k; } workspace = resource.get_space_typed(Shape1(workspace_size), s); char* workspace_curr_ptr = workspace.dptr_; - sorted_dat = Tensor(reinterpret_cast(workspace_curr_ptr), + sorted_dat = Tensor(reinterpret_cast(workspace_curr_ptr), Shape1(src.Size()), s); // contain sorted dat - workspace_curr_ptr += sizeof(real_t) * src.Size(); + workspace_curr_ptr += sizeof(DType) * src.Size(); indices = Tensor(reinterpret_cast(workspace_curr_ptr), Shape1(src.Size()), s); // indices in the original matrix workspace_curr_ptr += sizeof(int) * src.Size(); @@ -394,10 +385,10 @@ void TopKImpl(RunContext ctx, sel_indices = Tensor(reinterpret_cast(workspace_curr_ptr), Shape1(batch_size * k), s); workspace_curr_ptr += sizeof(int) * batch_size * k; - mask_val = Tensor(reinterpret_cast(workspace_curr_ptr), + mask_val = Tensor(reinterpret_cast(workspace_curr_ptr), Shape2(batch_size * k, 1), s); - workspace_curr_ptr += sizeof(real_t) * batch_size * k; - mask_val = scalar(1); + workspace_curr_ptr += sizeof(DType) * batch_size * k; + mask_val = scalar(1); CHECK_EQ(sel_indices.CheckContiguous(), true); CHECK_EQ(mask_val.CheckContiguous(), true); } @@ -411,9 +402,9 @@ void TopKImpl(RunContext ctx, // 3. 
Assign results to the ret blob if (param.ret_typ == topk_enum::kReturnMask) { - Tensor ret_mask = - ret[0].get_with_shape(Shape2(ret[0].Size(), 1), s); - ret_mask = scalar(0); + Tensor ret_mask = + ret[0].get_with_shape(Shape2(ret[0].Size(), 1), s); + ret_mask = scalar(0); sel_indices = reshape(slice<1>( inplace_reshape(indices, Shape2(batch_size, @@ -425,12 +416,18 @@ void TopKImpl(RunContext ctx, sel_indices = transpose_indices(sel_indices, Shape3(src_shape[0], src_shape[2], src_shape[1]), Shape3(0, 2, 1)); } - IndexFill(ret_mask, sel_indices, mask_val); + if(req[0] == kNullOp) { + return; + } else if(req[0] == kWriteTo) { + IndexFill(ret_mask, sel_indices, mask_val); + } else { + LOG(FATAL) << "req=" << req[0] << " is not supported yet."; + } } else if (param.ret_typ == topk_enum::kReturnIndices) { indices = F(indices, element_num); if (do_transpose) { - Tensor ret_indices = ret[0].FlatTo3D(axis, axis, s); - ret_indices = tcast(transpose( + Tensor ret_indices = ret[0].FlatTo3D(axis, axis, s); + Assign(ret_indices, req[0], transpose( slice<2>(inplace_reshape(indices, Shape3(ret_indices.shape_[0], ret_indices.shape_[2], @@ -438,22 +435,22 @@ void TopKImpl(RunContext ctx, 0, k), Shape3(0, 2, 1))); } else { - Tensor ret_indices = - ret[0].get_with_shape(Shape2(batch_size, k), s); - ret_indices = tcast(slice<1>( + Tensor ret_indices = + ret[0].get_with_shape(Shape2(batch_size, k), s); + Assign(ret_indices, req[0], slice<1>( inplace_reshape(indices, Shape2(batch_size, element_num)), 0, k)); } } else { indices = F(indices, element_num); if (do_transpose) { - Tensor ret_value = ret[0].FlatTo3D(axis, axis, s); - Tensor ret_indices = ret[1].FlatTo3D(axis, axis, s); - ret_value = transpose( + Tensor ret_value = ret[0].FlatTo3D(axis, axis, s); + Tensor ret_indices = ret[1].FlatTo3D(axis, axis, s); + Assign(ret_value, req[0], transpose( slice<2>(inplace_reshape(sorted_dat, Shape3(ret_value.shape_[0], ret_value.shape_[2], element_num)), 0, k), - Shape3(0, 2, 1)); - ret_indices = tcast(transpose( + Shape3(0, 2, 1))); + Assign(ret_indices, req[1], transpose( slice<2>(inplace_reshape(indices, Shape3(ret_indices.shape_[0], ret_indices.shape_[2], @@ -461,12 +458,12 @@ void TopKImpl(RunContext ctx, 0, k), Shape3(0, 2, 1))); } else { - Tensor ret_value = - ret[0].get_with_shape(Shape2(batch_size, k), s); - Tensor ret_indices = - ret[1].get_with_shape(Shape2(batch_size, k), s); - ret_value = slice<1>(inplace_reshape(sorted_dat, Shape2(batch_size, element_num)), 0, k); - ret_indices = tcast(slice<1>( + Tensor ret_value = + ret[0].get_with_shape(Shape2(batch_size, k), s); + Tensor ret_indices = + ret[1].get_with_shape(Shape2(batch_size, k), s); + Assign(ret_value, req[0], slice<1>(inplace_reshape(sorted_dat, Shape2(batch_size, element_num)), 0, k)); + Assign(ret_indices, req[1], slice<1>( inplace_reshape(indices, Shape2(batch_size, element_num)), 0, k)); } } @@ -479,9 +476,9 @@ void TopK(const nnvm::NodeAttrs& attrs, const std::vector& req, const std::vector& outputs) { const TopKParam& param = nnvm::get(attrs.parsed); - // TODO(sxjscience) We can support inplace in the future - CHECK_EQ(req[0], kWriteTo) << "TopK does not support inplace"; - TopKImpl(ctx.run_ctx, ctx.requested[0], inputs[0], outputs, param); + MXNET_NO_FLOAT16_TYPE_SWITCH(inputs[0].type_flag_, DType, { + TopKImpl(ctx.run_ctx, ctx.requested[0], req, inputs[0], outputs, param); + }); } template @@ -491,13 +488,14 @@ void Sort(const nnvm::NodeAttrs& attrs, const std::vector& req, const std::vector& outputs) { const SortParam& param = 
nnvm::get(attrs.parsed); - CHECK_EQ(req[0], kWriteTo) << "Sort does not support inplace"; TopKParam topk_param; topk_param.axis = param.axis; topk_param.is_ascend = param.is_ascend; topk_param.k = 0; topk_param.ret_typ = topk_enum::kReturnValue; - TopKImpl(ctx.run_ctx, ctx.requested[0], inputs[0], outputs, topk_param); + MXNET_NO_FLOAT16_TYPE_SWITCH(inputs[0].type_flag_, DType, { + TopKImpl(ctx.run_ctx, ctx.requested[0], req, inputs[0], outputs, topk_param); + }); } template @@ -507,13 +505,14 @@ void ArgSort(const nnvm::NodeAttrs& attrs, const std::vector& req, const std::vector& outputs) { const ArgSortParam& param = nnvm::get(attrs.parsed); - CHECK_EQ(req[0], kWriteTo) << "ArgSort does not support inplace"; TopKParam topk_param; topk_param.axis = param.axis; topk_param.is_ascend = param.is_ascend; topk_param.k = 0; topk_param.ret_typ = topk_enum::kReturnIndices; - TopKImpl(ctx.run_ctx, ctx.requested[0], inputs[0], outputs, topk_param); + MXNET_NO_FLOAT16_TYPE_SWITCH(inputs[0].type_flag_, DType, { + TopKImpl(ctx.run_ctx, ctx.requested[0], req, inputs[0], outputs, topk_param); + }); } template @@ -536,56 +535,56 @@ void TopKBackward_(const nnvm::NodeAttrs& attrs, TShape target_shape; ParseTopKParam(outputs[0].shape_, param, &target_shape, &batch_size, &element_num, &axis, &k, &do_transpose, &is_ascend); - Tensor workspace = - ctx.requested[0].get_space_typed(Shape1(batch_size * k * 2 + batch_size), s); - Tensor sel_indices = - Tensor(workspace.dptr_, Shape1(batch_size * k), s); - Tensor batch_shift = - Tensor(workspace.dptr_ + batch_size * k, Shape1(batch_size), s); - Tensor dummy_index = - Tensor(workspace.dptr_ + batch_size * k + batch_size, + Tensor workspace = + ctx.requested[0].get_space_typed(Shape1(batch_size * k * 2 + batch_size), s); + Tensor sel_indices = + Tensor(workspace.dptr_, Shape1(batch_size * k), s); + Tensor batch_shift = + Tensor(workspace.dptr_ + batch_size * k, Shape1(batch_size), s); + Tensor dummy_index = + Tensor(workspace.dptr_ + batch_size * k + batch_size, Shape1(batch_size * k), s); - Tensor out_grad = - inputs[0].get_with_shape(Shape2(inputs[0].shape_.Size(), 1), s); - Tensor in_grad = - outputs[0].get_with_shape(Shape2(outputs[0].shape_.Size(), 1), s); - mxnet_op::Kernel::Launch(s, batch_size, 1, 0.0f, - static_cast(element_num), kWriteTo, batch_shift.dptr_); - if (do_transpose) { - Tensor indices = inputs[2].FlatTo1D(s); - TShape src_shape = outputs[0].shape_.FlatTo3D(axis); - sel_indices = reshape(transpose( - broadcast_to(inplace_reshape(batch_shift, - Shape3(src_shape[0], src_shape[2], 1)), - TShape(Shape3(src_shape[0], src_shape[2], k))), - Shape3(0, 2, 1)), - Shape1(batch_size * k)); - sel_indices += indices; - sel_indices = transpose_indices(sel_indices, Shape3(src_shape[0], src_shape[2], src_shape[1]), - Shape3(0, 2, 1)); - } else { - Tensor indices = - inputs[2].get_with_shape(Shape2(batch_size, k), s); - sel_indices = reshape(indices + - broadcast_to(inplace_reshape(batch_shift, Shape2(batch_size, 1)), - TShape(Shape2(batch_size, k))), - Shape1(batch_size * k)); - } - CHECK_EQ(sel_indices.CheckContiguous(), true); - if (kWriteTo == req[0]) { - in_grad = scalar(0); - IndexFill(in_grad, sel_indices, out_grad); - } else if (kAddTo == req[0]) { - // TODO(sxjscience) We can use AddTakeGrad in the future. - // However, the current implementation of AddTakeGrad is not so efficient. 
- mxnet_op::Kernel::Launch(s, sel_indices.shape_.Size(), 1, 0.0f, - 1.0f, kWriteTo, dummy_index.dptr_); - mxnet::op::AddTakeGradLargeBatch(in_grad, sel_indices, dummy_index, out_grad); - } else if (kNullOp == req[0]) { - return; - } else { - LOG(FATAL) << "Not Implemented!"; - } + MXNET_NO_FLOAT16_TYPE_SWITCH(inputs[0].type_flag_, DType, { + Tensor out_grad = + inputs[0].get_with_shape(Shape2(inputs[0].shape_.Size(), 1), s); + Tensor in_grad = + outputs[0].get_with_shape(Shape2(outputs[0].shape_.Size(), 1), s); + mxnet_op::Kernel::Launch(s, batch_size, 1, 0, element_num, kWriteTo, batch_shift.dptr_); + if (do_transpose) { + Tensor indices = inputs[2].FlatTo1D(s); + TShape src_shape = outputs[0].shape_.FlatTo3D(axis); + sel_indices = reshape(transpose( + broadcast_to(inplace_reshape(batch_shift, + Shape3(src_shape[0], src_shape[2], 1)), + TShape(Shape3(src_shape[0], src_shape[2], k))), + Shape3(0, 2, 1)), + Shape1(batch_size * k)); + sel_indices += indices; + sel_indices = transpose_indices(sel_indices, Shape3(src_shape[0], src_shape[2], src_shape[1]), + Shape3(0, 2, 1)); + } else { + Tensor indices = + inputs[2].get_with_shape(Shape2(batch_size, k), s); + sel_indices = reshape(indices + + broadcast_to(inplace_reshape(batch_shift, Shape2(batch_size, 1)), + TShape(Shape2(batch_size, k))), + Shape1(batch_size * k)); + } + CHECK_EQ(sel_indices.CheckContiguous(), true); + if (kWriteTo == req[0]) { + in_grad = scalar(0); + IndexFill(in_grad, sel_indices, out_grad); + } else if (kAddTo == req[0]) { + // TODO(sxjscience) We can use AddTakeGrad in the future. + // However, the current implementation of AddTakeGrad is not so efficient. + mxnet_op::Kernel::Launch(s, sel_indices.shape_.Size(), 1, 0, 1, kWriteTo, dummy_index.dptr_); + mxnet::op::AddTakeGradLargeBatch(in_grad, sel_indices, dummy_index, out_grad); + } else if (kNullOp == req[0]) { + return; + } else { + LOG(FATAL) << "Not Implemented!"; + } + }); } inline uint32_t TopKNumOutputs(const NodeAttrs& attrs) { @@ -610,8 +609,25 @@ inline uint32_t TopKNumVisibleOutputs(const NodeAttrs& attrs) { inline bool TopKType(const nnvm::NodeAttrs& attrs, std::vector *in_attrs, std::vector *out_attrs) { - return ElemwiseAttr( - attrs, in_attrs, out_attrs, -1); + const TopKParam& param = nnvm::get(attrs.parsed); + int data_type = -1; + size_t in_size = in_attrs->size(); + size_t out_size = out_attrs->size(); + CHECK_EQ(in_size, 1); + CHECK(out_size == 1 || out_size == 2); + if(out_size > 1) { + CHECK(type_assign(&(*out_attrs)[1], mshadow::kInt32)) << "Failed to set the type of ret_indices to int32."; + } + if(param.ret_typ == topk_enum::kReturnIndices) { + CHECK(type_assign(&(*out_attrs)[0], mshadow::kInt32)) << "Failed to set the type of ret_indices to int32."; + } else { + CHECK(type_assign(&data_type, (*in_attrs)[0])) << "Incompatible dtype of input, in_attrs[0]=" << (*in_attrs)[0]; + CHECK(type_assign(&data_type, (*out_attrs)[0])) << "Incompatible dtype of output, out_attrs[0]=" << (*out_attrs)[0]; + CHECK(type_assign(&(*in_attrs)[0], data_type)) << "Incompatible dtype of input, in_attrs[0]=" << (*in_attrs)[0]; + CHECK(type_assign(&(*out_attrs)[0], data_type)) << "Incompatible dtype of output, out_attrs[0]=" << (*out_attrs)[0]; + if(data_type == -1) return false; + } + return true; } inline bool TopKShapeImpl(const TopKParam& param, @@ -650,6 +666,23 @@ inline bool TopKShape(const nnvm::NodeAttrs& attrs, return TopKShapeImpl(param, in_attrs, out_attrs); } +inline bool SortType(const nnvm::NodeAttrs& attrs, + std::vector *in_attrs, + std::vector 
*out_attrs) { + int data_type = -1; + size_t in_size = in_attrs->size(); + size_t out_size = out_attrs->size(); + CHECK_EQ(in_size, 1); + CHECK_EQ(out_size, 2); + CHECK(type_assign(&(*out_attrs)[1], mshadow::kInt32)) << "Failed to set the type of ret_indices to int32."; + CHECK(type_assign(&data_type, (*in_attrs)[0])) << "Incompatible dtype of input, in_attrs[0]=" << (*in_attrs)[0]; + CHECK(type_assign(&data_type, (*out_attrs)[0])) << "Incompatible dtype of output, out_attrs[0]=" << (*out_attrs)[0]; + CHECK(type_assign(&(*in_attrs)[0], data_type)) << "Incompatible dtype of input, in_attrs[0]=" << (*in_attrs)[0]; + CHECK(type_assign(&(*out_attrs)[0], data_type)) << "Incompatible dtype of output, out_attrs[0]=" << (*out_attrs)[0]; + if(data_type == -1) return false; + return true; +} + inline bool SortShape(const nnvm::NodeAttrs& attrs, std::vector *in_attrs, std::vector *out_attrs) { @@ -662,6 +695,13 @@ inline bool SortShape(const nnvm::NodeAttrs& attrs, return TopKShapeImpl(topk_param, in_attrs, out_attrs); } +inline bool ArgSortType(const nnvm::NodeAttrs& attrs, + std::vector *in_attrs, + std::vector *out_attrs) { + CHECK(type_assign(&(*out_attrs)[0], mshadow::kInt32)) << "Failed to set the type of ret_indices to int32."; + return true; +} + inline bool ArgSortShape(const nnvm::NodeAttrs& attrs, std::vector *in_attrs, std::vector *out_attrs) { diff --git a/src/operator/tensor/ordering_op.cc b/src/operator/tensor/ordering_op.cc index ebd7c62ec886..f30be6c12b8e 100644 --- a/src/operator/tensor/ordering_op.cc +++ b/src/operator/tensor/ordering_op.cc @@ -128,7 +128,7 @@ Examples:: .set_num_outputs(2) .set_attr_parser(ParamParser) .set_attr("FInferShape", SortShape) -.set_attr("FInferType", ElemwiseType<1, 2>) +.set_attr("FInferType", SortType) .set_attr("FNumVisibleOutputs", [](const NodeAttrs& attrs) { return 1; }) .set_attr("FCompute", Sort) .set_attr("FGradient", @@ -178,7 +178,7 @@ Examples:: .set_num_outputs(1) .set_attr_parser(ParamParser) .set_attr("FInferShape", ArgSortShape) -.set_attr("FInferType", ElemwiseType<1, 1>) +.set_attr("FInferType", ArgSortType) .set_attr("FCompute", ArgSort) .set_attr("FGradient", MakeZeroGradNodes) .set_attr("FResourceRequest", diff --git a/tests/python/unittest/test_ndarray.py b/tests/python/unittest/test_ndarray.py index 496f80f927f6..495df690ec90 100644 --- a/tests/python/unittest/test_ndarray.py +++ b/tests/python/unittest/test_ndarray.py @@ -639,107 +639,138 @@ def gt_topk(dat, axis, ret_typ, k, is_ascend): # values, making it hard to generate a numpy 'golden copy' to compare against # the mxnet operator. The 'mask' function is particularly hard to test given that # equal values might span the 'k' boundary. Issue exposed with seed 1405838964. 
- def get_values(ensure_unique): - while True: - data = np.float32(np.random.normal(size=(dat_size, dat_size, dat_size, dat_size))) - if not ensure_unique: - return data - num_unique_values = len(set(data.flatten())) - if data.size == num_unique_values: - return data - - a_npy = get_values(ensure_unique=True) - a_nd = mx.nd.array(a_npy, ctx=ctx) - - # test for ret_typ=indices - nd_ret_topk = mx.nd.topk(a_nd, axis=1, ret_typ="indices", k=3, is_ascend=True).asnumpy() - gt = gt_topk(a_npy, axis=1, ret_typ="indices", k=3, is_ascend=True) - assert_almost_equal(nd_ret_topk, gt) - nd_ret_topk = mx.nd.topk(a_nd, axis=3, ret_typ="indices", k=2, is_ascend=False).asnumpy() - gt = gt_topk(a_npy, axis=3, ret_typ="indices", k=2, is_ascend=False) - assert_almost_equal(nd_ret_topk, gt) - nd_ret_topk = mx.nd.topk(a_nd, axis=None, ret_typ="indices", k=21, is_ascend=False).asnumpy() - gt = gt_topk(a_npy, axis=None, ret_typ="indices", k=21, is_ascend=False) - assert_almost_equal(nd_ret_topk, gt) - - # test for ret_typ=value - nd_ret_topk = mx.nd.topk(a_nd, axis=1, ret_typ="value", k=3, is_ascend=True).asnumpy() - gt = gt_topk(a_npy, axis=1, ret_typ="value", k=3, is_ascend=True) - assert_almost_equal(nd_ret_topk, gt) - nd_ret_topk = mx.nd.topk(a_nd, axis=3, ret_typ="value", k=2, is_ascend=False).asnumpy() - gt = gt_topk(a_npy, axis=3, ret_typ="value", k=2, is_ascend=False) - assert_almost_equal(nd_ret_topk, gt) - nd_ret_topk = mx.nd.topk(a_nd, axis=None, ret_typ="value", k=21, is_ascend=False).asnumpy() - gt = gt_topk(a_npy, axis=None, ret_typ="value", k=21, is_ascend=False) - assert_almost_equal(nd_ret_topk, gt) - - # test for ret_typ=mask - nd_ret_topk = mx.nd.topk(a_nd, axis=1, ret_typ="mask", k=3, is_ascend=True).asnumpy() - gt = gt_topk(a_npy, axis=1, ret_typ="mask", k=3, is_ascend=True) - assert_almost_equal(nd_ret_topk, gt) - nd_ret_topk = mx.nd.topk(a_nd, axis=1, ret_typ="mask", k=2, is_ascend=False).asnumpy() - gt = gt_topk(a_npy, axis=1, ret_typ="mask", k=2, is_ascend=False) - assert_almost_equal(nd_ret_topk, gt) - nd_ret_topk = mx.nd.topk(a_nd, axis=None, ret_typ="mask", k=21, is_ascend=False).asnumpy() - gt = gt_topk(a_npy, axis=None, ret_typ="mask", k=21, is_ascend=False) - assert_almost_equal(nd_ret_topk, gt) - - # test for ret_typ=both - nd_ret_topk_val, nd_ret_topk_ind = mx.nd.topk(a_nd, axis=1, ret_typ="both", k=3, is_ascend=True) - nd_ret_topk_val = nd_ret_topk_val.asnumpy() - nd_ret_topk_ind = nd_ret_topk_ind.asnumpy() - gt_val = gt_topk(a_npy, axis=1, ret_typ="value", k=3, is_ascend=True) - gt_ind = gt_topk(a_npy, axis=1, ret_typ="indices", k=3, is_ascend=True) - assert_almost_equal(nd_ret_topk_val, gt_val) - assert_almost_equal(nd_ret_topk_ind, gt_ind) - - # test for sort - nd_ret_sort = mx.nd.sort(a_nd, axis=1, is_ascend=True).asnumpy() - gt = gt_topk(a_npy, axis=1, ret_typ="value", k=dat_size, is_ascend=True) - assert_almost_equal(nd_ret_sort, gt) - nd_ret_sort = mx.nd.sort(a_nd, axis=None, is_ascend=False).asnumpy() - gt = gt_topk(a_npy, axis=None, ret_typ="value", - k=dat_size*dat_size*dat_size*dat_size, is_ascend=False) - assert_almost_equal(nd_ret_sort, gt) - - # test for argsort - nd_ret_argsort = mx.nd.argsort(a_nd, axis=3, is_ascend=True).asnumpy() - gt = gt_topk(a_npy, axis=3, ret_typ="indices", k=dat_size, is_ascend=True) - assert_almost_equal(nd_ret_argsort, gt) - nd_ret_argsort = mx.nd.argsort(a_nd, axis=None, is_ascend=False).asnumpy() - gt = gt_topk(a_npy, axis=None, ret_typ="indices", - k=dat_size*dat_size*dat_size*dat_size, is_ascend=False) - 
assert_almost_equal(nd_ret_argsort, gt) + def get_values(ensure_unique, dtype): + if dtype == np.int16 or dtype == np.int32 or dtype == np.int64: + return np.arange(dat_size ** 4, dtype=dtype).reshape((dat_size, dat_size, dat_size, dat_size)) + elif dtype == np.float32 or dtype == np.float64: + while True: + data = np.random.normal(size=(dat_size, dat_size, dat_size, dat_size)).astype(dtype) + if not ensure_unique: + return data + num_unique_values = len(set(data.flatten())) + if data.size == num_unique_values: + return data + else: + raise NotImplementedError + + for dtype in [np.int16, np.int32, np.int64, np.float32, np.float64]: + a_npy = get_values(ensure_unique=True, dtype=dtype) + a_nd = mx.nd.array(a_npy, ctx=ctx) + + # test for ret_typ=indices + nd_ret_topk = mx.nd.topk(a_nd, axis=1, ret_typ="indices", k=3, is_ascend=True).asnumpy() + gt = gt_topk(a_npy, axis=1, ret_typ="indices", k=3, is_ascend=True) + assert_almost_equal(nd_ret_topk, gt) + nd_ret_topk = mx.nd.topk(a_nd, axis=3, ret_typ="indices", k=2, is_ascend=False).asnumpy() + gt = gt_topk(a_npy, axis=3, ret_typ="indices", k=2, is_ascend=False) + assert_almost_equal(nd_ret_topk, gt) + nd_ret_topk = mx.nd.topk(a_nd, axis=None, ret_typ="indices", k=21, is_ascend=False).asnumpy() + gt = gt_topk(a_npy, axis=None, ret_typ="indices", k=21, is_ascend=False) + assert_almost_equal(nd_ret_topk, gt) + + # test for ret_typ=value + nd_ret_topk = mx.nd.topk(a_nd, axis=1, ret_typ="value", k=3, is_ascend=True).asnumpy() + gt = gt_topk(a_npy, axis=1, ret_typ="value", k=3, is_ascend=True) + assert_almost_equal(nd_ret_topk, gt) + nd_ret_topk = mx.nd.topk(a_nd, axis=3, ret_typ="value", k=2, is_ascend=False).asnumpy() + gt = gt_topk(a_npy, axis=3, ret_typ="value", k=2, is_ascend=False) + assert_almost_equal(nd_ret_topk, gt) + nd_ret_topk = mx.nd.topk(a_nd, axis=None, ret_typ="value", k=21, is_ascend=False).asnumpy() + gt = gt_topk(a_npy, axis=None, ret_typ="value", k=21, is_ascend=False) + assert_almost_equal(nd_ret_topk, gt) + + # test for ret_typ=mask + nd_ret_topk = mx.nd.topk(a_nd, axis=1, ret_typ="mask", k=3, is_ascend=True).asnumpy() + gt = gt_topk(a_npy, axis=1, ret_typ="mask", k=3, is_ascend=True) + assert_almost_equal(nd_ret_topk, gt) + nd_ret_topk = mx.nd.topk(a_nd, axis=1, ret_typ="mask", k=2, is_ascend=False).asnumpy() + gt = gt_topk(a_npy, axis=1, ret_typ="mask", k=2, is_ascend=False) + assert_almost_equal(nd_ret_topk, gt) + nd_ret_topk = mx.nd.topk(a_nd, axis=None, ret_typ="mask", k=21, is_ascend=False).asnumpy() + gt = gt_topk(a_npy, axis=None, ret_typ="mask", k=21, is_ascend=False) + assert_almost_equal(nd_ret_topk, gt) + + # test for ret_typ=both + nd_ret_topk_val, nd_ret_topk_ind = mx.nd.topk(a_nd, axis=1, ret_typ="both", k=3, is_ascend=True) + nd_ret_topk_val = nd_ret_topk_val.asnumpy() + nd_ret_topk_ind = nd_ret_topk_ind.asnumpy() + gt_val = gt_topk(a_npy, axis=1, ret_typ="value", k=3, is_ascend=True) + gt_ind = gt_topk(a_npy, axis=1, ret_typ="indices", k=3, is_ascend=True) + assert_almost_equal(nd_ret_topk_val, gt_val) + assert_almost_equal(nd_ret_topk_ind, gt_ind) + + # test for sort + nd_ret_sort = mx.nd.sort(a_nd, axis=1, is_ascend=True).asnumpy() + gt = gt_topk(a_npy, axis=1, ret_typ="value", k=dat_size, is_ascend=True) + assert_almost_equal(nd_ret_sort, gt) + nd_ret_sort = mx.nd.sort(a_nd, axis=None, is_ascend=False).asnumpy() + gt = gt_topk(a_npy, axis=None, ret_typ="value", + k=dat_size*dat_size*dat_size*dat_size, is_ascend=False) + assert_almost_equal(nd_ret_sort, gt) + + # test for argsort + nd_ret_argsort = 
mx.nd.argsort(a_nd, axis=3, is_ascend=True).asnumpy() + gt = gt_topk(a_npy, axis=3, ret_typ="indices", k=dat_size, is_ascend=True) + assert_almost_equal(nd_ret_argsort, gt) + nd_ret_argsort = mx.nd.argsort(a_nd, axis=None, is_ascend=False).asnumpy() + gt = gt_topk(a_npy, axis=None, ret_typ="indices", + k=dat_size*dat_size*dat_size*dat_size, is_ascend=False) + assert_almost_equal(nd_ret_argsort, gt) # test topk with a big shape - a = mx.nd.arange(0, 54686454, step=1, repeat=1) + a = mx.nd.arange(0, 54686454, step=1, repeat=1, dtype=np.int32) assert_almost_equal(a.topk(k=54686454).asnumpy(), a.asnumpy()[::-1]) + a.attach_grad() + + k = 10 + with mx.autograd.record(): + b = mx.nd.topk(a, k=k, ret_typ='value') + b.backward(mx.nd.ones((k,), dtype=np.int32)) + a_grad = a.grad.asnumpy() + for i in range(-1, - k - 1, -1): + assert a_grad[i] == 1 + + # test topk gradient with a small shape + for dtype in [np.int32, np.int64, np.float32, np.float64]: + a = mx.nd.arange(0, 1000, step=1, repeat=1, dtype=dtype) + a.attach_grad() + k = 10 + ograd = mx.nd.arange(0, k, dtype=dtype) + with mx.autograd.record(): + b = mx.nd.topk(a, k=k, ret_typ='value') + b.backward(ograd) + a_grad = a.grad.asnumpy() + ograd_npy = ograd.asnumpy() + for i in range(-1, - k - 1, -1): + assert a_grad[i] == ograd_npy[-i - 1] + # Repeat those tests that don't involve indices. These should pass even with # duplicated input data values (over many repeated runs with different random seeds, # this will be tested). - a_npy = get_values(ensure_unique=False) - a_nd = mx.nd.array(a_npy, ctx=ctx) - - # test for ret_typ=value - nd_ret_topk = mx.nd.topk(a_nd, axis=1, ret_typ="value", k=3, is_ascend=True).asnumpy() - gt = gt_topk(a_npy, axis=1, ret_typ="value", k=3, is_ascend=True) - assert_almost_equal(nd_ret_topk, gt) - nd_ret_topk = mx.nd.topk(a_nd, axis=3, ret_typ="value", k=2, is_ascend=False).asnumpy() - gt = gt_topk(a_npy, axis=3, ret_typ="value", k=2, is_ascend=False) - assert_almost_equal(nd_ret_topk, gt) - nd_ret_topk = mx.nd.topk(a_nd, axis=None, ret_typ="value", k=21, is_ascend=False).asnumpy() - gt = gt_topk(a_npy, axis=None, ret_typ="value", k=21, is_ascend=False) - assert_almost_equal(nd_ret_topk, gt) - - # test for sort - nd_ret_sort = mx.nd.sort(a_nd, axis=1, is_ascend=True).asnumpy() - gt = gt_topk(a_npy, axis=1, ret_typ="value", k=dat_size, is_ascend=True) - assert_almost_equal(nd_ret_sort, gt) - nd_ret_sort = mx.nd.sort(a_nd, axis=None, is_ascend=False).asnumpy() - gt = gt_topk(a_npy, axis=None, ret_typ="value", - k=dat_size*dat_size*dat_size*dat_size, is_ascend=False) - assert_almost_equal(nd_ret_sort, gt) + for dtype in [np.int16, np.int32, np.int64, np.float32, np.float64]: + a_npy = get_values(ensure_unique=False, dtype=dtype) + a_nd = mx.nd.array(a_npy, ctx=ctx) + + # test for ret_typ=value + nd_ret_topk = mx.nd.topk(a_nd, axis=1, ret_typ="value", k=3, is_ascend=True).asnumpy() + gt = gt_topk(a_npy, axis=1, ret_typ="value", k=3, is_ascend=True) + assert_almost_equal(nd_ret_topk, gt) + nd_ret_topk = mx.nd.topk(a_nd, axis=3, ret_typ="value", k=2, is_ascend=False).asnumpy() + gt = gt_topk(a_npy, axis=3, ret_typ="value", k=2, is_ascend=False) + assert_almost_equal(nd_ret_topk, gt) + nd_ret_topk = mx.nd.topk(a_nd, axis=None, ret_typ="value", k=21, is_ascend=False).asnumpy() + gt = gt_topk(a_npy, axis=None, ret_typ="value", k=21, is_ascend=False) + assert_almost_equal(nd_ret_topk, gt) + + # test for sort + nd_ret_sort = mx.nd.sort(a_nd, axis=1, is_ascend=True).asnumpy() + gt = gt_topk(a_npy, axis=1, ret_typ="value", 
k=dat_size, is_ascend=True) + assert_almost_equal(nd_ret_sort, gt) + nd_ret_sort = mx.nd.sort(a_nd, axis=None, is_ascend=False).asnumpy() + gt = gt_topk(a_npy, axis=None, ret_typ="value", + k=dat_size*dat_size*dat_size*dat_size, is_ascend=False) + assert_almost_equal(nd_ret_sort, gt) @with_seed() def test_ndarray_equal(): From 0781f552848b9b7b30339d6238dbca4cf008b5b8 Mon Sep 17 00:00:00 2001 From: sxjscience Date: Mon, 4 Jun 2018 17:13:25 +0800 Subject: [PATCH 02/14] fix --- src/operator/tensor/ordering_op-inl.h | 55 +++++++++++++++++---------- tests/python/unittest/test_ndarray.py | 8 ++++ 2 files changed, 43 insertions(+), 20 deletions(-) diff --git a/src/operator/tensor/ordering_op-inl.h b/src/operator/tensor/ordering_op-inl.h index 0da54b8383d8..08320aceb44b 100644 --- a/src/operator/tensor/ordering_op-inl.h +++ b/src/operator/tensor/ordering_op-inl.h @@ -416,9 +416,9 @@ void TopKImpl(RunContext ctx, sel_indices = transpose_indices(sel_indices, Shape3(src_shape[0], src_shape[2], src_shape[1]), Shape3(0, 2, 1)); } - if(req[0] == kNullOp) { + if (req[0] == kNullOp) { return; - } else if(req[0] == kWriteTo) { + } else if (req[0] == kWriteTo) { IndexFill(ret_mask, sel_indices, mask_val); } else { LOG(FATAL) << "req=" << req[0] << " is not supported yet."; @@ -462,7 +462,8 @@ void TopKImpl(RunContext ctx, ret[0].get_with_shape(Shape2(batch_size, k), s); Tensor ret_indices = ret[1].get_with_shape(Shape2(batch_size, k), s); - Assign(ret_value, req[0], slice<1>(inplace_reshape(sorted_dat, Shape2(batch_size, element_num)), 0, k)); + Assign(ret_value, req[0], + slice<1>(inplace_reshape(sorted_dat, Shape2(batch_size, element_num)), 0, k)); Assign(ret_indices, req[1], slice<1>( inplace_reshape(indices, Shape2(batch_size, element_num)), 0, k)); } @@ -549,7 +550,8 @@ void TopKBackward_(const nnvm::NodeAttrs& attrs, inputs[0].get_with_shape(Shape2(inputs[0].shape_.Size(), 1), s); Tensor in_grad = outputs[0].get_with_shape(Shape2(outputs[0].shape_.Size(), 1), s); - mxnet_op::Kernel::Launch(s, batch_size, 1, 0, element_num, kWriteTo, batch_shift.dptr_); + mxnet_op::Kernel::Launch(s, batch_size, 1, 0, element_num, kWriteTo, + batch_shift.dptr_); if (do_transpose) { Tensor indices = inputs[2].FlatTo1D(s); TShape src_shape = outputs[0].shape_.FlatTo3D(axis); @@ -577,7 +579,8 @@ void TopKBackward_(const nnvm::NodeAttrs& attrs, } else if (kAddTo == req[0]) { // TODO(sxjscience) We can use AddTakeGrad in the future. // However, the current implementation of AddTakeGrad is not so efficient. 
- mxnet_op::Kernel::Launch(s, sel_indices.shape_.Size(), 1, 0, 1, kWriteTo, dummy_index.dptr_); + mxnet_op::Kernel::Launch(s, sel_indices.shape_.Size(), 1, 0, 1, kWriteTo, + dummy_index.dptr_); mxnet::op::AddTakeGradLargeBatch(in_grad, sel_indices, dummy_index, out_grad); } else if (kNullOp == req[0]) { return; @@ -615,17 +618,23 @@ inline bool TopKType(const nnvm::NodeAttrs& attrs, size_t out_size = out_attrs->size(); CHECK_EQ(in_size, 1); CHECK(out_size == 1 || out_size == 2); - if(out_size > 1) { - CHECK(type_assign(&(*out_attrs)[1], mshadow::kInt32)) << "Failed to set the type of ret_indices to int32."; + if (out_size > 1) { + CHECK(type_assign(&(*out_attrs)[1], mshadow::kInt32)) + << "Failed to set the type of ret_indices to int32."; } - if(param.ret_typ == topk_enum::kReturnIndices) { - CHECK(type_assign(&(*out_attrs)[0], mshadow::kInt32)) << "Failed to set the type of ret_indices to int32."; + if (param.ret_typ == topk_enum::kReturnIndices) { + CHECK(type_assign(&(*out_attrs)[0], mshadow::kInt32)) + << "Failed to set the type of ret_indices to int32."; } else { - CHECK(type_assign(&data_type, (*in_attrs)[0])) << "Incompatible dtype of input, in_attrs[0]=" << (*in_attrs)[0]; - CHECK(type_assign(&data_type, (*out_attrs)[0])) << "Incompatible dtype of output, out_attrs[0]=" << (*out_attrs)[0]; - CHECK(type_assign(&(*in_attrs)[0], data_type)) << "Incompatible dtype of input, in_attrs[0]=" << (*in_attrs)[0]; - CHECK(type_assign(&(*out_attrs)[0], data_type)) << "Incompatible dtype of output, out_attrs[0]=" << (*out_attrs)[0]; - if(data_type == -1) return false; + CHECK(type_assign(&data_type, (*in_attrs)[0])) << "Incompatible dtype of input, in_attrs[0]=" + << (*in_attrs)[0]; + CHECK(type_assign(&data_type, (*out_attrs)[0])) << "Incompatible dtype of output, out_attrs[0]=" + << (*out_attrs)[0]; + CHECK(type_assign(&(*in_attrs)[0], data_type)) << "Incompatible dtype of input, in_attrs[0]=" + << (*in_attrs)[0]; + CHECK(type_assign(&(*out_attrs)[0], data_type)) << "Incompatible dtype of output, out_attrs[0]=" + << (*out_attrs)[0]; + if (data_type == -1) return false; } return true; } @@ -674,11 +683,16 @@ inline bool SortType(const nnvm::NodeAttrs& attrs, size_t out_size = out_attrs->size(); CHECK_EQ(in_size, 1); CHECK_EQ(out_size, 2); - CHECK(type_assign(&(*out_attrs)[1], mshadow::kInt32)) << "Failed to set the type of ret_indices to int32."; - CHECK(type_assign(&data_type, (*in_attrs)[0])) << "Incompatible dtype of input, in_attrs[0]=" << (*in_attrs)[0]; - CHECK(type_assign(&data_type, (*out_attrs)[0])) << "Incompatible dtype of output, out_attrs[0]=" << (*out_attrs)[0]; - CHECK(type_assign(&(*in_attrs)[0], data_type)) << "Incompatible dtype of input, in_attrs[0]=" << (*in_attrs)[0]; - CHECK(type_assign(&(*out_attrs)[0], data_type)) << "Incompatible dtype of output, out_attrs[0]=" << (*out_attrs)[0]; + CHECK(type_assign(&(*out_attrs)[1], mshadow::kInt32)) + << "Failed to set the type of ret_indices to int32."; + CHECK(type_assign(&data_type, (*in_attrs)[0])) << "Incompatible dtype of input, in_attrs[0]=" + << (*in_attrs)[0]; + CHECK(type_assign(&data_type, (*out_attrs)[0])) << "Incompatible dtype of output, out_attrs[0]=" + << (*out_attrs)[0]; + CHECK(type_assign(&(*in_attrs)[0], data_type)) << "Incompatible dtype of input, in_attrs[0]=" + << (*in_attrs)[0]; + CHECK(type_assign(&(*out_attrs)[0], data_type)) << "Incompatible dtype of output, out_attrs[0]=" + << (*out_attrs)[0]; if(data_type == -1) return false; return true; } @@ -698,7 +712,8 @@ inline bool SortShape(const nnvm::NodeAttrs& 
attrs, inline bool ArgSortType(const nnvm::NodeAttrs& attrs, std::vector *in_attrs, std::vector *out_attrs) { - CHECK(type_assign(&(*out_attrs)[0], mshadow::kInt32)) << "Failed to set the type of ret_indices to int32."; + CHECK(type_assign(&(*out_attrs)[0], mshadow::kInt32)) + << "Failed to set the type of ret_indices to int32."; return true; } diff --git a/tests/python/unittest/test_ndarray.py b/tests/python/unittest/test_ndarray.py index 495df690ec90..47c4e408733b 100644 --- a/tests/python/unittest/test_ndarray.py +++ b/tests/python/unittest/test_ndarray.py @@ -698,6 +698,14 @@ def get_values(ensure_unique, dtype): gt_ind = gt_topk(a_npy, axis=1, ret_typ="indices", k=3, is_ascend=True) assert_almost_equal(nd_ret_topk_val, gt_val) assert_almost_equal(nd_ret_topk_ind, gt_ind) + # test for kNullOp + _, nd_ret_topk_ind = mx.nd.topk(a_nd, axis=1, ret_typ="both", k=3, is_ascend=True) + nd_ret_topk_ind = nd_ret_topk_ind.asnumpy() + assert_almost_equal(nd_ret_topk_ind, gt_ind) + # test for kNullOp + nd_ret_topk_val, _ = mx.nd.topk(a_nd, axis=1, ret_typ="both", k=3, is_ascend=True) + nd_ret_topk_val = nd_ret_topk_val.asnumpy() + assert_almost_equal(nd_ret_topk_val, gt_val) # test for sort nd_ret_sort = mx.nd.sort(a_nd, axis=1, is_ascend=True).asnumpy() From 9c2235802ada3ae75d01d28a9e2730c015b6a78f Mon Sep 17 00:00:00 2001 From: sxjscience Date: Mon, 4 Jun 2018 17:26:03 +0800 Subject: [PATCH 03/14] fix --- src/operator/tensor/ordering_op-inl.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/operator/tensor/ordering_op-inl.h b/src/operator/tensor/ordering_op-inl.h index 08320aceb44b..efa364f84f90 100644 --- a/src/operator/tensor/ordering_op-inl.h +++ b/src/operator/tensor/ordering_op-inl.h @@ -693,7 +693,7 @@ inline bool SortType(const nnvm::NodeAttrs& attrs, << (*in_attrs)[0]; CHECK(type_assign(&(*out_attrs)[0], data_type)) << "Incompatible dtype of output, out_attrs[0]=" << (*out_attrs)[0]; - if(data_type == -1) return false; + if (data_type == -1) return false; return true; } From d38cc016d6b05a283563087c7099ced795897804 Mon Sep 17 00:00:00 2001 From: sxjscience Date: Mon, 4 Jun 2018 22:21:05 +0800 Subject: [PATCH 04/14] fix --- src/operator/mxnet_op.h | 45 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 45 insertions(+) diff --git a/src/operator/mxnet_op.h b/src/operator/mxnet_op.h index c3f6dc6558e3..696006d4df87 100644 --- a/src/operator/mxnet_op.h +++ b/src/operator/mxnet_op.h @@ -176,6 +176,51 @@ inline int get_num_threads(const int N) { LOG(FATAL) << "Unknown type enum " << type; \ } +#define MXNET_NO_FLOAT16_TYPE_SWITCH(type, DType, ...) \ + switch (type) { \ + case mshadow::kFloat32: \ + { \ + typedef float DType; \ + {__VA_ARGS__} \ + } \ + break; \ + case mshadow::kFloat64: \ + { \ + typedef double DType; \ + {__VA_ARGS__} \ + } \ + break; \ + case mshadow::kFloat16: \ + LOG(FATAL) << "This operation does not " \ + "support float16"; \ + break; \ + case mshadow::kUint8: \ + { \ + typedef uint8_t DType; \ + {__VA_ARGS__} \ + } \ + break; \ + case mshadow::kInt8: \ + { \ + typedef int8_t DType; \ + {__VA_ARGS__} \ + } \ + break; \ + case mshadow::kInt32: \ + { \ + typedef int32_t DType; \ + {__VA_ARGS__} \ + } \ + break; \ + case mshadow::kInt64: \ + { \ + typedef int64_t DType; \ + {__VA_ARGS__} \ + } \ + break; \ + default: \ + LOG(FATAL) << "Unknown type enum " << type; \ + } /*! 
* \brief assign the val to out according From 4bb6be77283de5d8f95770d0a7396e1d935747bf Mon Sep 17 00:00:00 2001 From: sxjscience Date: Tue, 12 Jun 2018 10:54:24 +0800 Subject: [PATCH 05/14] try to add the dtype option --- src/operator/tensor/ordering_op-inl.h | 218 +++++++++++++++++--------- tests/python/unittest/test_ndarray.py | 2 +- 2 files changed, 144 insertions(+), 76 deletions(-) diff --git a/src/operator/tensor/ordering_op-inl.h b/src/operator/tensor/ordering_op-inl.h index efa364f84f90..152e01ef3de9 100644 --- a/src/operator/tensor/ordering_op-inl.h +++ b/src/operator/tensor/ordering_op-inl.h @@ -79,6 +79,16 @@ struct TopKParam : public dmlc::Parameter { DMLC_DECLARE_FIELD(is_ascend).set_default(false) .describe("Whether to choose k largest or k smallest elements." " Top K largest elements will be chosen if set to false."); + DMLC_DECLARE_FIELD(dtype) + .add_enum("uint8", mshadow::kUint8) + .add_enum("int32", mshadow::kInt32) + .add_enum("float16", mshadow::kFloat16) + .add_enum("float32", mshadow::kFloat32) + .add_enum("float64", mshadow::kFloat64) + .set_default(mshadow::kFloat32) + .describe("DType of the output indices when ret_typ is \"indices\" or \"both\". " + "An error will be raised if the selected data type cannot precisely represent the " + "indices."); } }; @@ -103,6 +113,16 @@ struct ArgSortParam : public dmlc::Parameter { " If not given, the flattened array is used. Default is -1."); DMLC_DECLARE_FIELD(is_ascend).set_default(true) .describe("Whether to sort in ascending or descending order."); + DMLC_DECLARE_FIELD(dtype) + .add_enum("uint8", mshadow::kUint8) + .add_enum("int32", mshadow::kInt32) + .add_enum("float16", mshadow::kFloat16) + .add_enum("float32", mshadow::kFloat32) + .add_enum("float64", mshadow::kFloat64) + .set_default(mshadow::kFloat32) + .describe("DType of the output indices. It is only valid when ret_typ is \"indices\" or" + " \"both\". An error will be raised if the selected data type cannot precisely " + "represent the indices."); } }; @@ -324,10 +344,12 @@ MSHADOW_FORCE_INLINE void TopKSort(const Tensor& dat, * \param k the K elements to keep * \param param the topk parameters * \tparam xpu the device type. + * \tparam DType type of the output value/mask. + * \tparam IDType type of the output indices. */ -template -void TopKImpl(RunContext ctx, - Resource resource, +template +void TopKImpl(const RunContext &ctx, + const Resource &resource, const std::vector& req, const TBlob& src, const std::vector& ret, @@ -349,6 +371,10 @@ void TopKImpl(RunContext ctx, TShape target_shape; ParseTopKParam(src.shape_, param, &target_shape, &batch_size, &element_num, &axis, &k, &do_transpose, &is_ascend); + CHECK_LE(element_num, mxnet::common::MaxIntegerValue()) + << "'IDType' does not have a sufficient precision to represent the indices of the input array. " + << "The total element_num is " << element_num << ", but the selected IDType can only represent " + << mxnet::common::MaxIntegerValue() << " elements"; Tensor dat = src.FlatTo3D(axis, axis, s); size_t temp_size = 0; // Temp space needed by the gpu-based full sorts. 
@@ -426,46 +452,46 @@ void TopKImpl(RunContext ctx, } else if (param.ret_typ == topk_enum::kReturnIndices) { indices = F(indices, element_num); if (do_transpose) { - Tensor ret_indices = ret[0].FlatTo3D(axis, axis, s); - Assign(ret_indices, req[0], transpose( + Tensor ret_indices = ret[0].FlatTo3D(axis, axis, s); + Assign(ret_indices, req[0], tcast(transpose( slice<2>(inplace_reshape(indices, Shape3(ret_indices.shape_[0], ret_indices.shape_[2], element_num)), 0, k), - Shape3(0, 2, 1))); + Shape3(0, 2, 1)))); } else { Tensor ret_indices = ret[0].get_with_shape(Shape2(batch_size, k), s); - Assign(ret_indices, req[0], slice<1>( - inplace_reshape(indices, Shape2(batch_size, element_num)), 0, k)); + Assign(ret_indices, req[0], tcast(slice<1>( + inplace_reshape(indices, Shape2(batch_size, element_num)), 0, k))); } } else { indices = F(indices, element_num); if (do_transpose) { Tensor ret_value = ret[0].FlatTo3D(axis, axis, s); - Tensor ret_indices = ret[1].FlatTo3D(axis, axis, s); + Tensor ret_indices = ret[1].FlatTo3D(axis, axis, s); Assign(ret_value, req[0], transpose( slice<2>(inplace_reshape(sorted_dat, Shape3(ret_value.shape_[0], ret_value.shape_[2], element_num)), 0, k), Shape3(0, 2, 1))); - Assign(ret_indices, req[1], transpose( + Assign(ret_indices, req[1], tcast(transpose( slice<2>(inplace_reshape(indices, Shape3(ret_indices.shape_[0], ret_indices.shape_[2], element_num)), 0, k), - Shape3(0, 2, 1))); + Shape3(0, 2, 1)))); } else { Tensor ret_value = ret[0].get_with_shape(Shape2(batch_size, k), s); - Tensor ret_indices = - ret[1].get_with_shape(Shape2(batch_size, k), s); + Tensor ret_indices = + ret[1].get_with_shape(Shape2(batch_size, k), s); Assign(ret_value, req[0], slice<1>(inplace_reshape(sorted_dat, Shape2(batch_size, element_num)), 0, k)); - Assign(ret_indices, req[1], slice<1>( - inplace_reshape(indices, Shape2(batch_size, element_num)), 0, k)); + Assign(ret_indices, req[1], tcast(slice<1>( + inplace_reshape(indices, Shape2(batch_size, element_num)), 0, k))); } } } @@ -477,9 +503,17 @@ void TopK(const nnvm::NodeAttrs& attrs, const std::vector& req, const std::vector& outputs) { const TopKParam& param = nnvm::get(attrs.parsed); - MXNET_NO_FLOAT16_TYPE_SWITCH(inputs[0].type_flag_, DType, { - TopKImpl(ctx.run_ctx, ctx.requested[0], req, inputs[0], outputs, param); - }); + if(param.ret_typ == topk_enum::kReturnIndices || param.ret_typ == topk_enum::kReturnBoth) { + MXNET_NO_FLOAT16_TYPE_SWITCH(inputs[0].type_flag_, DType, { + MSHADOW_TYPE_SWITCH(param.dtype, IDType, { + TopKImpl(ctx.run_ctx, ctx.requested[0], req, inputs[0], outputs, param); + }) + }); + } else { + MXNET_NO_FLOAT16_TYPE_SWITCH(inputs[0].type_flag_, DType, { + TopKImpl(ctx.run_ctx, ctx.requested[0], req, inputs[0], outputs, param); + }); + } } template @@ -495,7 +529,7 @@ void Sort(const nnvm::NodeAttrs& attrs, topk_param.k = 0; topk_param.ret_typ = topk_enum::kReturnValue; MXNET_NO_FLOAT16_TYPE_SWITCH(inputs[0].type_flag_, DType, { - TopKImpl(ctx.run_ctx, ctx.requested[0], req, inputs[0], outputs, topk_param); + TopKImpl(ctx.run_ctx, ctx.requested[0], req, inputs[0], outputs, topk_param); }); } @@ -510,23 +544,26 @@ void ArgSort(const nnvm::NodeAttrs& attrs, topk_param.axis = param.axis; topk_param.is_ascend = param.is_ascend; topk_param.k = 0; + topk_param.dtype = param.dtype; topk_param.ret_typ = topk_enum::kReturnIndices; MXNET_NO_FLOAT16_TYPE_SWITCH(inputs[0].type_flag_, DType, { - TopKImpl(ctx.run_ctx, ctx.requested[0], req, inputs[0], outputs, topk_param); + MSHADOW_TYPE_SWITCH(param.dtype, IDType, { + 
TopKImpl(ctx.run_ctx, + ctx.requested[0], req, inputs[0], outputs, topk_param); + }); }); } -template -void TopKBackward_(const nnvm::NodeAttrs& attrs, - const OpContext& ctx, - const std::vector& inputs, - const std::vector& req, - const std::vector& outputs) { +template +void TopKBackwardImpl(const OpContext &ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs, + const TopKParam& param) { CHECK_NE(req[0], kWriteInplace); using namespace mshadow; using namespace mshadow::expr; Stream *s = ctx.run_ctx.get_stream(); - const TopKParam& param = nnvm::get(attrs.parsed); CHECK(param.ret_typ == topk_enum::kReturnValue || param.ret_typ == topk_enum::kReturnBoth); int batch_size, element_num; // number of batches + the size of each batch int axis = 0; @@ -536,6 +573,10 @@ void TopKBackward_(const nnvm::NodeAttrs& attrs, TShape target_shape; ParseTopKParam(outputs[0].shape_, param, &target_shape, &batch_size, &element_num, &axis, &k, &do_transpose, &is_ascend); + CHECK_LE(element_num, mxnet::common::MaxIntegerValue()) + << "'IDType' does not have a sufficient precision to represent the indices of the input array. " + << "The total element_num is " << element_num << ", but the selected IDType can only represent " + << mxnet::common::MaxIntegerValue() << " elements"; Tensor workspace = ctx.requested[0].get_space_typed(Shape1(batch_size * k * 2 + batch_size), s); Tensor sel_indices = @@ -545,49 +586,70 @@ void TopKBackward_(const nnvm::NodeAttrs& attrs, Tensor dummy_index = Tensor(workspace.dptr_ + batch_size * k + batch_size, Shape1(batch_size * k), s); - MXNET_NO_FLOAT16_TYPE_SWITCH(inputs[0].type_flag_, DType, { - Tensor out_grad = - inputs[0].get_with_shape(Shape2(inputs[0].shape_.Size(), 1), s); - Tensor in_grad = - outputs[0].get_with_shape(Shape2(outputs[0].shape_.Size(), 1), s); - mxnet_op::Kernel::Launch(s, batch_size, 1, 0, element_num, kWriteTo, - batch_shift.dptr_); - if (do_transpose) { - Tensor indices = inputs[2].FlatTo1D(s); - TShape src_shape = outputs[0].shape_.FlatTo3D(axis); - sel_indices = reshape(transpose( - broadcast_to(inplace_reshape(batch_shift, - Shape3(src_shape[0], src_shape[2], 1)), - TShape(Shape3(src_shape[0], src_shape[2], k))), - Shape3(0, 2, 1)), - Shape1(batch_size * k)); - sel_indices += indices; - sel_indices = transpose_indices(sel_indices, Shape3(src_shape[0], src_shape[2], src_shape[1]), - Shape3(0, 2, 1)); - } else { - Tensor indices = - inputs[2].get_with_shape(Shape2(batch_size, k), s); - sel_indices = reshape(indices + - broadcast_to(inplace_reshape(batch_shift, Shape2(batch_size, 1)), - TShape(Shape2(batch_size, k))), - Shape1(batch_size * k)); - } - CHECK_EQ(sel_indices.CheckContiguous(), true); - if (kWriteTo == req[0]) { - in_grad = scalar(0); - IndexFill(in_grad, sel_indices, out_grad); - } else if (kAddTo == req[0]) { - // TODO(sxjscience) We can use AddTakeGrad in the future. - // However, the current implementation of AddTakeGrad is not so efficient. 
- mxnet_op::Kernel::Launch(s, sel_indices.shape_.Size(), 1, 0, 1, kWriteTo, - dummy_index.dptr_); - mxnet::op::AddTakeGradLargeBatch(in_grad, sel_indices, dummy_index, out_grad); - } else if (kNullOp == req[0]) { - return; - } else { - LOG(FATAL) << "Not Implemented!"; - } - }); + + Tensor out_grad = + inputs[0].get_with_shape(Shape2(inputs[0].shape_.Size(), 1), s); + Tensor in_grad = + outputs[0].get_with_shape(Shape2(outputs[0].shape_.Size(), 1), s); + mxnet_op::Kernel::Launch(s, batch_size, 1, 0, element_num, kWriteTo, + batch_shift.dptr_); + if (do_transpose) { + Tensor indices = inputs[2].FlatTo1D(s); + TShape src_shape = outputs[0].shape_.FlatTo3D(axis); + sel_indices = reshape(transpose( + broadcast_to(inplace_reshape(batch_shift, + Shape3(src_shape[0], src_shape[2], 1)), + TShape(Shape3(src_shape[0], src_shape[2], k))), + Shape3(0, 2, 1)), + Shape1(batch_size * k)); + sel_indices += tcast(indices); + sel_indices = transpose_indices(sel_indices, Shape3(src_shape[0], src_shape[2], src_shape[1]), + Shape3(0, 2, 1)); + } else { + Tensor indices = + inputs[2].get_with_shape(Shape2(batch_size, k), s); + sel_indices = reshape(tcast(indices) + + broadcast_to(inplace_reshape(batch_shift, Shape2(batch_size, 1)), + TShape(Shape2(batch_size, k))), + Shape1(batch_size * k)); + } + CHECK_EQ(sel_indices.CheckContiguous(), true); + if (kWriteTo == req[0]) { + in_grad = scalar(0); + IndexFill(in_grad, sel_indices, out_grad); + } else if (kAddTo == req[0]) { + // TODO(sxjscience) We can use AddTakeGrad in the future. + // However, the current implementation of AddTakeGrad is not so efficient. + mxnet_op::Kernel::Launch(s, sel_indices.shape_.Size(), 1, 0, 1, kWriteTo, + dummy_index.dptr_); + mxnet::op::AddTakeGradLargeBatch(in_grad, sel_indices, dummy_index, out_grad); + } else if (kNullOp == req[0]) { + return; + } else { + LOG(FATAL) << "Not Implemented!"; + } +} + +template +void TopKBackward_(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + const TopKParam& param = nnvm::get(attrs.parsed); + if (param.ret_typ == topk_enum::kReturnBoth) { + MXNET_NO_FLOAT16_TYPE_SWITCH(inputs[0].type_flag_, DType, { + MSHADOW_TYPE_SWITCH(param.dtype, IDType, { + TopKBackwardImpl(ctx, inputs, req, outputs, param); + }); + }); + } else if (param.ret_typ == topk_enum::kReturnValue) { + MXNET_NO_FLOAT16_TYPE_SWITCH(inputs[0].type_flag_, DType, { + TopKBackwardImpl(ctx, inputs, req, outputs, param); + }); + } else { + LOG(FATAL) << "Not Implemented"; + } } inline uint32_t TopKNumOutputs(const NodeAttrs& attrs) { @@ -619,12 +681,17 @@ inline bool TopKType(const nnvm::NodeAttrs& attrs, CHECK_EQ(in_size, 1); CHECK(out_size == 1 || out_size == 2); if (out_size > 1) { - CHECK(type_assign(&(*out_attrs)[1], mshadow::kInt32)) - << "Failed to set the type of ret_indices to int32."; + if(param.ret_typ == topk_enum::kReturnValue) { + CHECK(type_assign(&(*out_attrs)[1], mshadow::kInt32)) + << "Failed to set the type of ret_indices."; + } else { + CHECK(type_assign(&(*out_attrs)[1], param.dtype)) + << "Failed to set the type of ret_indices."; + } } if (param.ret_typ == topk_enum::kReturnIndices) { - CHECK(type_assign(&(*out_attrs)[0], mshadow::kInt32)) - << "Failed to set the type of ret_indices to int32."; + CHECK(type_assign(&(*out_attrs)[0], param.dtype)) + << "Failed to set the type of ret_indices."; } else { CHECK(type_assign(&data_type, (*in_attrs)[0])) << "Incompatible dtype of input, in_attrs[0]=" << (*in_attrs)[0]; @@ -712,7 
+779,8 @@ inline bool SortShape(const nnvm::NodeAttrs& attrs, inline bool ArgSortType(const nnvm::NodeAttrs& attrs, std::vector *in_attrs, std::vector *out_attrs) { - CHECK(type_assign(&(*out_attrs)[0], mshadow::kInt32)) + const ArgSortParam& param = nnvm::get(attrs.parsed); + CHECK(type_assign(&(*out_attrs)[0], param.dtype)) << "Failed to set the type of ret_indices to int32."; return true; } diff --git a/tests/python/unittest/test_ndarray.py b/tests/python/unittest/test_ndarray.py index fc4217bac676..ba9852aecb51 100644 --- a/tests/python/unittest/test_ndarray.py +++ b/tests/python/unittest/test_ndarray.py @@ -727,7 +727,7 @@ def get_values(ensure_unique, dtype): # test topk with a big shape a = mx.nd.arange(0, 54686454, step=1, repeat=1, dtype=np.int32) - assert_almost_equal(a.topk(k=54686454).asnumpy(), a.asnumpy()[::-1]) + assert_almost_equal(a.topk(k=54686454, dtype=np.int32).asnumpy(), a.asnumpy()[::-1]) a.attach_grad() k = 10 From d16f61471356138cd9c63175b04ebfb6a5ed37a1 Mon Sep 17 00:00:00 2001 From: sxjscience Date: Tue, 12 Jun 2018 11:25:01 +0800 Subject: [PATCH 06/14] fix --- src/operator/tensor/ordering_op-inl.h | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/operator/tensor/ordering_op-inl.h b/src/operator/tensor/ordering_op-inl.h index 152e01ef3de9..ab0ff0f419f8 100644 --- a/src/operator/tensor/ordering_op-inl.h +++ b/src/operator/tensor/ordering_op-inl.h @@ -58,6 +58,7 @@ struct TopKParam : public dmlc::Parameter { int k; int ret_typ; bool is_ascend; + int dtype; DMLC_DECLARE_PARAMETER(TopKParam) { DMLC_DECLARE_FIELD(axis).set_default(dmlc::optional(-1)) .describe("Axis along which to choose the top k indices." @@ -107,6 +108,7 @@ struct SortParam : public dmlc::Parameter { struct ArgSortParam : public dmlc::Parameter { dmlc::optional axis; bool is_ascend; + int dtype; DMLC_DECLARE_PARAMETER(ArgSortParam) { DMLC_DECLARE_FIELD(axis).set_default(dmlc::optional(-1)) .describe("Axis along which to sort the input tensor." 
From 9d3bdf3829302bc973f4a65a0df0b0ca5e26211b Mon Sep 17 00:00:00 2001 From: sxjscience Date: Tue, 12 Jun 2018 11:26:39 +0800 Subject: [PATCH 07/14] try to fix --- src/operator/tensor/ordering_op-inl.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/operator/tensor/ordering_op-inl.h b/src/operator/tensor/ordering_op-inl.h index ab0ff0f419f8..c722288fc6d1 100644 --- a/src/operator/tensor/ordering_op-inl.h +++ b/src/operator/tensor/ordering_op-inl.h @@ -604,7 +604,7 @@ void TopKBackwardImpl(const OpContext &ctx, TShape(Shape3(src_shape[0], src_shape[2], k))), Shape3(0, 2, 1)), Shape1(batch_size * k)); - sel_indices += tcast(indices); + sel_indices = sel_indices + tcast(indices); sel_indices = transpose_indices(sel_indices, Shape3(src_shape[0], src_shape[2], src_shape[1]), Shape3(0, 2, 1)); } else { From 3c1d162406cd9d0be469b0218c21172ff7e33475 Mon Sep 17 00:00:00 2001 From: sxjscience Date: Tue, 12 Jun 2018 11:39:15 +0800 Subject: [PATCH 08/14] try to fix --- src/operator/tensor/ordering_op-inl.h | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/operator/tensor/ordering_op-inl.h b/src/operator/tensor/ordering_op-inl.h index c722288fc6d1..a444565c7bd5 100644 --- a/src/operator/tensor/ordering_op-inl.h +++ b/src/operator/tensor/ordering_op-inl.h @@ -463,8 +463,8 @@ void TopKImpl(const RunContext &ctx, 0, k), Shape3(0, 2, 1)))); } else { - Tensor ret_indices = - ret[0].get_with_shape(Shape2(batch_size, k), s); + Tensor ret_indices = + ret[0].get_with_shape(Shape2(batch_size, k), s); Assign(ret_indices, req[0], tcast(slice<1>( inplace_reshape(indices, Shape2(batch_size, element_num)), 0, k))); } @@ -604,7 +604,7 @@ void TopKBackwardImpl(const OpContext &ctx, TShape(Shape3(src_shape[0], src_shape[2], k))), Shape3(0, 2, 1)), Shape1(batch_size * k)); - sel_indices = sel_indices + tcast(indices); + sel_indices += tcast(indices); sel_indices = transpose_indices(sel_indices, Shape3(src_shape[0], src_shape[2], src_shape[1]), Shape3(0, 2, 1)); } else { From 51736e11d916c482d55b1000b65b7c5d14d01fa8 Mon Sep 17 00:00:00 2001 From: sxjscience Date: Tue, 12 Jun 2018 13:15:01 +0800 Subject: [PATCH 09/14] fix lint --- src/operator/tensor/ordering_op-inl.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/operator/tensor/ordering_op-inl.h b/src/operator/tensor/ordering_op-inl.h index a444565c7bd5..4e0bc4237853 100644 --- a/src/operator/tensor/ordering_op-inl.h +++ b/src/operator/tensor/ordering_op-inl.h @@ -505,7 +505,7 @@ void TopK(const nnvm::NodeAttrs& attrs, const std::vector& req, const std::vector& outputs) { const TopKParam& param = nnvm::get(attrs.parsed); - if(param.ret_typ == topk_enum::kReturnIndices || param.ret_typ == topk_enum::kReturnBoth) { + if (param.ret_typ == topk_enum::kReturnIndices || param.ret_typ == topk_enum::kReturnBoth) { MXNET_NO_FLOAT16_TYPE_SWITCH(inputs[0].type_flag_, DType, { MSHADOW_TYPE_SWITCH(param.dtype, IDType, { TopKImpl(ctx.run_ctx, ctx.requested[0], req, inputs[0], outputs, param); @@ -683,7 +683,7 @@ inline bool TopKType(const nnvm::NodeAttrs& attrs, CHECK_EQ(in_size, 1); CHECK(out_size == 1 || out_size == 2); if (out_size > 1) { - if(param.ret_typ == topk_enum::kReturnValue) { + if (param.ret_typ == topk_enum::kReturnValue) { CHECK(type_assign(&(*out_attrs)[1], mshadow::kInt32)) << "Failed to set the type of ret_indices."; } else { From c1016c72e242abb40f8ef9623f70a087d21d1b63 Mon Sep 17 00:00:00 2001 From: sxjscience Date: Tue, 12 Jun 2018 13:29:08 +0800 Subject: [PATCH 10/14] add more tests --- 
tests/python/unittest/test_ndarray.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/tests/python/unittest/test_ndarray.py b/tests/python/unittest/test_ndarray.py index ba9852aecb51..da40c4e89344 100644 --- a/tests/python/unittest/test_ndarray.py +++ b/tests/python/unittest/test_ndarray.py @@ -717,13 +717,14 @@ def get_values(ensure_unique, dtype): assert_almost_equal(nd_ret_sort, gt) # test for argsort - nd_ret_argsort = mx.nd.argsort(a_nd, axis=3, is_ascend=True).asnumpy() - gt = gt_topk(a_npy, axis=3, ret_typ="indices", k=dat_size, is_ascend=True) - assert_almost_equal(nd_ret_argsort, gt) - nd_ret_argsort = mx.nd.argsort(a_nd, axis=None, is_ascend=False).asnumpy() - gt = gt_topk(a_npy, axis=None, ret_typ="indices", - k=dat_size*dat_size*dat_size*dat_size, is_ascend=False) - assert_almost_equal(nd_ret_argsort, gt) + for idtype in [np.int, np.float16, np.float32, np.float64]: + nd_ret_argsort = mx.nd.argsort(a_nd, axis=3, is_ascend=True, dtype=idtype).asnumpy() + gt = gt_topk(a_npy, axis=3, ret_typ="indices", k=dat_size, is_ascend=True) + assert_almost_equal(nd_ret_argsort, gt) + nd_ret_argsort = mx.nd.argsort(a_nd, axis=None, is_ascend=False, dtype=idtype).asnumpy() + gt = gt_topk(a_npy, axis=None, ret_typ="indices", + k=dat_size*dat_size*dat_size*dat_size, is_ascend=False) + assert_almost_equal(nd_ret_argsort, gt) # test topk with a big shape a = mx.nd.arange(0, 54686454, step=1, repeat=1, dtype=np.int32) From b26d24ba71f51ea8fa65f436d65f2f30582c5af1 Mon Sep 17 00:00:00 2001 From: sxjscience Date: Tue, 12 Jun 2018 13:29:49 +0800 Subject: [PATCH 11/14] fix test --- tests/python/unittest/test_ndarray.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python/unittest/test_ndarray.py b/tests/python/unittest/test_ndarray.py index da40c4e89344..0234476c4693 100644 --- a/tests/python/unittest/test_ndarray.py +++ b/tests/python/unittest/test_ndarray.py @@ -717,7 +717,7 @@ def get_values(ensure_unique, dtype): assert_almost_equal(nd_ret_sort, gt) # test for argsort - for idtype in [np.int, np.float16, np.float32, np.float64]: + for idtype in [np.int32, np.float16, np.float32, np.float64]: nd_ret_argsort = mx.nd.argsort(a_nd, axis=3, is_ascend=True, dtype=idtype).asnumpy() gt = gt_topk(a_npy, axis=3, ret_typ="indices", k=dat_size, is_ascend=True) assert_almost_equal(nd_ret_argsort, gt) From cde4b5f98b71b028980e7adaf040420655862516 Mon Sep 17 00:00:00 2001 From: sxjscience Date: Fri, 15 Jun 2018 11:15:21 +0800 Subject: [PATCH 12/14] Do not change unrelated file --- tests/python/mkl/test_quantization_mkldnn.py | 56 ++++++++++---------- 1 file changed, 28 insertions(+), 28 deletions(-) diff --git a/tests/python/mkl/test_quantization_mkldnn.py b/tests/python/mkl/test_quantization_mkldnn.py index ca36ae3158c2..290f1a195c24 100644 --- a/tests/python/mkl/test_quantization_mkldnn.py +++ b/tests/python/mkl/test_quantization_mkldnn.py @@ -1,28 +1,28 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. 
You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -import os -import sys -import mxnet as mx - -os.environ['ENABLE_MKLDNN_QUANTIZATION_TEST'] = '1' -curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__))) -sys.path.insert(0, os.path.join(curr_path, '../quantization')) -from test_quantization import * - -if __name__ == '__main__': - import nose - nose.runmodule() +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import os +import sys +import mxnet as mx + +os.environ['ENABLE_MKLDNN_QUANTIZATION_TEST'] = '1' +curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__))) +sys.path.insert(0, os.path.join(curr_path, '../quantization')) +from test_quantization import * + +if __name__ == '__main__': + import nose + nose.runmodule() From 30a4457325de1125f8c93c767fa665bb09771e5c Mon Sep 17 00:00:00 2001 From: sxjscience Date: Wed, 4 Jul 2018 09:24:13 +0800 Subject: [PATCH 13/14] remove float16 --- src/operator/tensor/ordering_op-inl.h | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/operator/tensor/ordering_op-inl.h b/src/operator/tensor/ordering_op-inl.h index 4e0bc4237853..f1db25e48e26 100644 --- a/src/operator/tensor/ordering_op-inl.h +++ b/src/operator/tensor/ordering_op-inl.h @@ -83,7 +83,6 @@ struct TopKParam : public dmlc::Parameter { DMLC_DECLARE_FIELD(dtype) .add_enum("uint8", mshadow::kUint8) .add_enum("int32", mshadow::kInt32) - .add_enum("float16", mshadow::kFloat16) .add_enum("float32", mshadow::kFloat32) .add_enum("float64", mshadow::kFloat64) .set_default(mshadow::kFloat32) @@ -118,7 +117,6 @@ struct ArgSortParam : public dmlc::Parameter { DMLC_DECLARE_FIELD(dtype) .add_enum("uint8", mshadow::kUint8) .add_enum("int32", mshadow::kInt32) - .add_enum("float16", mshadow::kFloat16) .add_enum("float32", mshadow::kFloat32) .add_enum("float64", mshadow::kFloat64) .set_default(mshadow::kFloat32) From c4bae4dc7b44d36bccfe98f8871bff1a034705a0 Mon Sep 17 00:00:00 2001 From: sxjscience Date: Wed, 4 Jul 2018 11:50:01 +0800 Subject: [PATCH 14/14] Revert "remove float16" This reverts commit 30a4457325de1125f8c93c767fa665bb09771e5c. 
--- src/operator/tensor/ordering_op-inl.h | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/operator/tensor/ordering_op-inl.h b/src/operator/tensor/ordering_op-inl.h index f1db25e48e26..4e0bc4237853 100644 --- a/src/operator/tensor/ordering_op-inl.h +++ b/src/operator/tensor/ordering_op-inl.h @@ -83,6 +83,7 @@ struct TopKParam : public dmlc::Parameter { DMLC_DECLARE_FIELD(dtype) .add_enum("uint8", mshadow::kUint8) .add_enum("int32", mshadow::kInt32) + .add_enum("float16", mshadow::kFloat16) .add_enum("float32", mshadow::kFloat32) .add_enum("float64", mshadow::kFloat64) .set_default(mshadow::kFloat32) @@ -117,6 +118,7 @@ struct ArgSortParam : public dmlc::Parameter { DMLC_DECLARE_FIELD(dtype) .add_enum("uint8", mshadow::kUint8) .add_enum("int32", mshadow::kInt32) + .add_enum("float16", mshadow::kFloat16) .add_enum("float32", mshadow::kFloat32) .add_enum("float64", mshadow::kFloat64) .set_default(mshadow::kFloat32)
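
A minimal usage sketch of the `dtype` argument added by this series (assuming the patches above are applied as shown; the array contents and shapes below are illustrative only and are not taken from the test suite):

    import mxnet as mx
    import numpy as np

    a = mx.nd.array(np.random.normal(size=(5, 1000)), dtype=np.float32)

    # Default is unchanged: returned indices are float32 (set_default(mshadow::kFloat32)).
    idx_default = mx.nd.topk(a, axis=1, k=3, ret_typ='indices')

    # Requesting int32 indices avoids the float32 precision limit (integers are exact
    # only up to 2**24), which is why the big-shape topk test above passes
    # dtype=np.int32 for its 54686454-element axis.
    idx_i32 = mx.nd.topk(a, axis=1, k=3, ret_typ='indices', dtype=np.int32)
    order = mx.nd.argsort(a, axis=1, is_ascend=True, dtype=np.int32)

    assert idx_i32.dtype == np.int32
    assert order.dtype == np.int32

The default stays float32, so existing callers see no behavioural change; float16 remains selectable (re-added by the final patch), but it can only represent index values exactly up to 2**11, so it is only suitable for short axes.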