fixing batch_norm and layer_norm for large tensors (apache#17805) (apache#18261)

Co-authored-by: Rohit Kumar Srivastava <srivastava.141@buckeyemail.osu.edu>

ChaiBapchya and Rohit Kumar Srivastava authored May 11, 2020
1 parent 80baab8 commit ceb0f06
Showing 2 changed files with 2 additions and 2 deletions.
2 changes: 1 addition & 1 deletion src/operator/nn/batch_norm.cc
@@ -330,7 +330,7 @@ static bool BatchNormShape(const nnvm::NodeAttrs& attrs,
                                       : param.axis);
   CHECK_LT(channelAxis, dshape.ndim()) << "Channel axis out of range: " << param.axis;
 
-  const int channelCount = dshape[channelAxis];
+  const index_t channelCount = dshape[channelAxis];
 
   if (!mxnet::ndim_is_known(dshape)) {
     return false;
2 changes: 1 addition & 1 deletion src/operator/nn/layer_norm.cc
@@ -47,7 +47,7 @@ static bool LayerNormShape(const nnvm::NodeAttrs& attrs,
   CHECK(axis >= 0 && axis < dshape.ndim())
     << "Channel axis out of range: axis=" << param.axis;
 
-  const int channelCount = dshape[axis];
+  const index_t channelCount = dshape[axis];
 
   if (!mxnet::ndim_is_known(dshape)) {
     return false;
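
In both files the fix is the same one-line change: channelCount is widened from int to index_t so that a channel dimension larger than INT_MAX is not truncated during shape inference. Below is a minimal standalone sketch (not MXNet code) of the underlying issue; it assumes index_t is a 64-bit signed integer, which is how MXNet defines it when large-tensor support is enabled.

// Minimal standalone sketch, not MXNet code: it shows why the channel count
// needs a wide index type. index_t is assumed here to be a 64-bit signed
// integer, matching MXNet builds with large-tensor support enabled.
#include <cstdint>
#include <iostream>
#include <limits>

using index_t = std::int64_t;  // stand-in for mshadow's index_t (assumption)

int main() {
  // A shape dimension larger than INT_MAX, as can occur with large tensors.
  const index_t dim =
      static_cast<index_t>(std::numeric_limits<int>::max()) + 1000;

  const int narrow = static_cast<int>(dim);  // no longer fits in 32 bits
  const index_t wide = dim;                  // full dimension is preserved

  std::cout << "as int:     " << narrow << '\n'   // wrapped, implementation-defined value
            << "as index_t: " << wide << '\n';    // 2147484647
  return 0;
}

With a 64-bit channelCount, downstream uses of the value in these shape functions see the true dimension rather than a wrapped one.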
