Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
review
Browse files Browse the repository at this point in the history
  • Loading branch information
Bartlomiej Gawrych committed Jan 4, 2022
1 parent 8a46b85 commit 300e3b3
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 6 deletions.
6 changes: 3 additions & 3 deletions src/operator/tensor/indexing_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ struct TakeNonzeroAxisCPU {
* \param i global thread id
* \param out_data ptr to output buffer
* \param in_data ptr to input buffer
* \param idx ptr to indices buffer
* \param indices ptr to indices buffer
* \param outer_dim_stride stride of dimension before axis
* \param axis_dim_stride stride of axis dimension
* \param idx_size size of the indices tensor
Expand All @@ -87,8 +87,8 @@ struct TakeNonzeroAxisCPU {
for (index_t j = 0; j < static_cast<index_t>(idx_size); ++j) {
int index = indices[j];
if (clip) {
index = (index < 0) ? 0 : index;
index = (index > axis_dim - 1) ? (axis_dim - 1) : index;
index = std::max(index, 0);
index = std::min(axis_dim - 1, index);
} else {
index %= axis_dim;
index += (index < 0) ? axis_dim : 0;
Expand Down
6 changes: 3 additions & 3 deletions src/operator/tensor/indexing_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -215,9 +215,9 @@ inline bool EmbeddingOpBackwardStorageType(const nnvm::NodeAttrs& attrs,
return dispatched;
}

/*! \brief name the struct TakeNonzeroAxis for general take when
* axis is not zero, use TakeZeroAxisGPU or TakeZeroAxisCPU for axis zero
* or TakeNonZeroAxisCPU for CPU optimized version
/*! \brief TakeNonzeroAxis is desinated for general take when
* axis is not zero (for CPU optimized version use TakeNonZeroAxisCPU and
for axis zero use TakeZeroAxisGPU or TakeZeroAxisCPU)
*/
template <bool clip = true>
struct TakeNonzeroAxis {
Expand Down

0 comments on commit 300e3b3

Please sign in to comment.