diff --git a/src/operator/tensor/indexing_op.cc b/src/operator/tensor/indexing_op.cc index 28dea2631885..28eca41b2ae4 100644 --- a/src/operator/tensor/indexing_op.cc +++ b/src/operator/tensor/indexing_op.cc @@ -67,7 +67,7 @@ struct TakeNonzeroAxisCPU { * \param i global thread id * \param out_data ptr to output buffer * \param in_data ptr to input buffer - * \param idx ptr to indices buffer + * \param indices ptr to indices buffer * \param outer_dim_stride stride of dimension before axis * \param axis_dim_stride stride of axis dimension * \param idx_size size of the indices tensor @@ -87,8 +87,8 @@ struct TakeNonzeroAxisCPU { for (index_t j = 0; j < static_cast(idx_size); ++j) { int index = indices[j]; if (clip) { - index = (index < 0) ? 0 : index; - index = (index > axis_dim - 1) ? (axis_dim - 1) : index; + index = std::max(index, 0); + index = std::min(axis_dim - 1, index); } else { index %= axis_dim; index += (index < 0) ? axis_dim : 0; diff --git a/src/operator/tensor/indexing_op.h b/src/operator/tensor/indexing_op.h index af4b559d0692..ed75c8fd270a 100644 --- a/src/operator/tensor/indexing_op.h +++ b/src/operator/tensor/indexing_op.h @@ -215,9 +215,9 @@ inline bool EmbeddingOpBackwardStorageType(const nnvm::NodeAttrs& attrs, return dispatched; } -/*! \brief name the struct TakeNonzeroAxis for general take when - * axis is not zero, use TakeZeroAxisGPU or TakeZeroAxisCPU for axis zero - * or TakeNonZeroAxisCPU for CPU optimized version +/*! \brief TakeNonzeroAxis is desinated for general take when + * axis is not zero (for CPU optimized version use TakeNonZeroAxisCPU and + for axis zero use TakeZeroAxisGPU or TakeZeroAxisCPU) */ template struct TakeNonzeroAxis {