From 4e4bd96114827c1685ef462bafb01a9967a9fd0d Mon Sep 17 00:00:00 2001
From: "B. Gawrych"
Date: Fri, 2 Jul 2021 14:38:03 +0200
Subject: [PATCH 1/2] Add checks in batchnorm's infer shape

---
 src/operator/nn/batch_norm.cc                 | 20 ++++++-------
 .../nn/mkldnn/mkldnn_batch_norm-inl.h         | 30 +++++++++----------
 2 files changed, 25 insertions(+), 25 deletions(-)

diff --git a/src/operator/nn/batch_norm.cc b/src/operator/nn/batch_norm.cc
index 87456dd59f87..86b3e0acd9d1 100644
--- a/src/operator/nn/batch_norm.cc
+++ b/src/operator/nn/batch_norm.cc
@@ -142,7 +142,7 @@ void BatchNormForwardImpl(mshadow::Stream<cpu> *,
   const size_t itemCountPerChannel = inputData.Size() / channelCount;
 
 #pragma omp parallel for
-  for (int channel = 0; channel < static_cast<int>(channelCount); ++channel) {
+  for (size_t channel = 0; channel < channelCount; ++channel) {
     if (is_train_and_not_global_stats) {
       // compute mean per input
       mean[channel] = 0;
@@ -253,7 +253,7 @@ void BatchNormBackwardImpl(mshadow::Stream<cpu> *,
   const bool is_train_and_not_global_stats = ctx.is_train && !param_.use_global_stats;
 
 #pragma omp parallel for
-  for (int channel = 0; channel < static_cast<int>(channelCount); ++channel) {
+  for (size_t channel = 0; channel < channelCount; ++channel) {
     const AccReal *weight = weights.dptr<AccReal>();
     const AccReal w = !param_.fix_gamma ? weight[channel] : AccReal(1);
     AccReal mean, invstd;
@@ -375,15 +375,15 @@ static bool BatchNormShape(const nnvm::NodeAttrs& attrs,
 
   const index_t channelCount = dshape[channelAxis];
 
-  in_shape->at(batchnorm::kGamma) = mxnet::TShape(Shape1(channelCount));
-  in_shape->at(batchnorm::kBeta) = mxnet::TShape(Shape1(channelCount));
-  in_shape->at(batchnorm::kInMovingMean) = mxnet::TShape(Shape1(channelCount));  // kMovingMean
-  in_shape->at(batchnorm::kInMovingVar) = mxnet::TShape(Shape1(channelCount));  // kMovingVar
+  SHAPE_ASSIGN_CHECK(*in_shape, batchnorm::kGamma, Shape1(channelCount));
+  SHAPE_ASSIGN_CHECK(*in_shape, batchnorm::kBeta, Shape1(channelCount));
+  SHAPE_ASSIGN_CHECK(*in_shape, batchnorm::kInMovingMean, Shape1(channelCount));  // kMovingMean
+  SHAPE_ASSIGN_CHECK(*in_shape, batchnorm::kInMovingVar, Shape1(channelCount));  // kMovingVar
 
-  out_shape->clear();
-  out_shape->push_back(dshape);                // kOut
-  out_shape->push_back(Shape1(channelCount));  // kMean
-  out_shape->push_back(Shape1(channelCount));  // kVar
+
+  SHAPE_ASSIGN_CHECK(*out_shape, batchnorm::kOut, dshape);
+  SHAPE_ASSIGN_CHECK(*out_shape, batchnorm::kMean, Shape1(channelCount));
+  SHAPE_ASSIGN_CHECK(*out_shape, batchnorm::kVar, Shape1(channelCount));
 
   return true;
 }
diff --git a/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h b/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h
index 963ed2c5c475..5a6f84c4cdbf 100644
--- a/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h
+++ b/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h
@@ -159,10 +159,10 @@ void MKLDNNBatchNormForward(const nnvm::NodeAttrs &attrs, const OpContext &ctx,
   if (param.axis != 1 || shape.ndim() != 4) {
     // reshape to (N, C, 1, D)
     mxnet::TShape new_shape{
-      static_cast<dim_t>(shape.ProdShape(0, real_axis)),
+      static_cast<index_t>(shape.ProdShape(0, real_axis)),
       shape[real_axis],
       1,
-      static_cast<dim_t>(shape.ProdShape(real_axis + 1,
+      static_cast<index_t>(shape.ProdShape(real_axis + 1,
                                          static_cast<int>(shape.ndim())))
     };
     in_data[batchnorm::kData] = in_data[batchnorm::kData].Reshape(new_shape);
@@ -195,7 +195,7 @@ void MKLDNNBatchNormForward(const nnvm::NodeAttrs &attrs, const OpContext &ctx,
 
     const mkldnn::memory &weight_mem = fwd.GetWeight();
     float* weight_buf = reinterpret_cast<float *>(weight_mem.get_data_handle());
-    nnvm::dim_t channels_ = data.shape()[1];
+    index_t channels_ = data.shape()[1];
     CHECK(weight_mem.get_desc().get_size() == channels_ * sizeof(float) * 2);
     float* weight_ptr = gamma.data().dptr<float>();
     float* bias_ptr = beta.data().dptr<float>();
@@ -204,13 +204,13 @@ void MKLDNNBatchNormForward(const nnvm::NodeAttrs &attrs, const OpContext &ctx,
       memcpy(weight_buf, weight_ptr, copy_size);
       memcpy(&weight_buf[channels_], bias_ptr, copy_size);
     } else if (IsBNWriting(req[batchnorm::kGamma])) {
-      for (int i = 0; i < channels_; i++) {
+      for (index_t i = 0; i < channels_; i++) {
         weight_buf[i] = 1.0f;
         weight_ptr[i] = 1.0f;
         weight_buf[channels_ + i] = bias_ptr[i];  // bias
       }
     } else {
-      for (int i = 0; i < channels_; i++) {
+      for (index_t i = 0; i < channels_; i++) {
         weight_buf[i] = 1.0f;
         weight_buf[channels_ + i] = bias_ptr[i];  // bias
       }
@@ -237,7 +237,7 @@ void MKLDNNBatchNormForward(const nnvm::NodeAttrs &attrs, const OpContext &ctx,
       float* inmean = aux_states[batchnorm::kMovingMean].data().dptr<float>();
       float* invar = aux_states[batchnorm::kMovingVar].data().dptr<float>();
       // to align with origin implmentation: batch_norm.cc: L164
-      for (int i = 0; i < channels_; i++) {
+      for (index_t i = 0; i < channels_; i++) {
         omean[i] = inmean[i];
         ovar[i] = VARIANCE_TO_INVSTD(invar[i], param.eps);
       }
@@ -254,7 +254,7 @@ void MKLDNNBatchNormForward(const nnvm::NodeAttrs &attrs, const OpContext &ctx,
       MKLDNNStream::Get()->Submit();
 
       float* ovar = outVar.data().dptr<float>();
-      for (int i = 0; i < channels_; i++) {
+      for (index_t i = 0; i < channels_; i++) {
         ovar[i] = VARIANCE_TO_INVSTD(ovar[i], param.eps);
       }
     }
@@ -357,10 +357,10 @@ void MKLDNNBatchNormBackward(const nnvm::NodeAttrs &attrs, const OpContext &ctx,
   if (param.axis != 1 || shape.ndim() != 4) {
     // reshape to (N, C, 1, D)
     mxnet::TShape new_shape{
-      static_cast<dim_t>(shape.ProdShape(0, real_axis)),
+      static_cast<index_t>(shape.ProdShape(0, real_axis)),
       shape[real_axis],
       1,
-      static_cast<dim_t>(shape.ProdShape(real_axis + 1,
+      static_cast<index_t>(shape.ProdShape(real_axis + 1,
                                          static_cast<int>(shape.ndim())))
     };
     data = data.Reshape(new_shape);
@@ -384,7 +384,7 @@ void MKLDNNBatchNormBackward(const nnvm::NodeAttrs &attrs, const OpContext &ctx,
   const NDArray &gamma = in_data[batchnorm::kGamma];
   const NDArray &beta = in_data[batchnorm::kBeta];
   DType *weight_buf = reinterpret_cast<DType *>(bwd.GetWeight().get_data_handle());
-  nnvm::dim_t channels_ = data.shape()[1];
+  index_t channels_ = data.shape()[1];
   DType *weight_ptr = gamma.data().dptr<DType>();
   DType* bias_ptr = beta.data().dptr<DType>();
   const size_t copy_size = sizeof(DType) * channels_;
@@ -392,7 +392,7 @@ void MKLDNNBatchNormBackward(const nnvm::NodeAttrs &attrs, const OpContext &ctx,
     memcpy(weight_buf, weight_ptr, copy_size);
     memcpy(&weight_buf[channels_], bias_ptr, copy_size);
   } else {
-    for (int i = 0; i < channels_; i++) {
+    for (index_t i = 0; i < channels_; i++) {
       weight_buf[i] = static_cast<DType>(1.0f);
     }
     memcpy(&weight_buf[channels_], bias_ptr, copy_size);
@@ -422,7 +422,7 @@ void MKLDNNBatchNormBackward(const nnvm::NodeAttrs &attrs, const OpContext &ctx,
       DType *tmp_var_ptr = reinterpret_cast<DType *>(var_mem.get_data_handle());
 
       DType minus_mom = (1.0f - param.momentum);
-      for (int i = 0; i < channels_; i++) {
+      for (index_t i = 0; i < channels_; i++) {
         moving_mean_ptr[i] = moving_mean_ptr[i] * param.momentum +
                              out_mean_ptr[i] * minus_mom;
         float variance = INVSTD_TO_VARIANCE(out_var_ptr[i], param.eps);
@@ -451,13 +451,13 @@ void MKLDNNBatchNormBackward(const nnvm::NodeAttrs &attrs, const OpContext &ctx,
         if (req[batchnorm::kGamma] != kAddTo) {
           memcpy(w_grad_1, gw_buf, copy_size);
         } else {
-          for (int i = 0; i < channels_; i++) {
+          for (index_t i = 0; i < channels_; i++) {
             w_grad_1[i] += gw_buf[i];
           }
         }
       }
     } else {
-      for (int i = 0; i < channels_; i++) {
+      for (index_t i = 0; i < channels_; i++) {
        (in_grad[1].data().dptr<DType>())[i] = 0.0f;
      }
    }
@@ -468,7 +468,7 @@ void MKLDNNBatchNormBackward(const nnvm::NodeAttrs &attrs, const OpContext &ctx,
       memcpy(w_grad_2, &gw_buf[channels_], copy_size);
     } else {
       DType *grad_beta = &gw_buf[channels_];
-      for (int i = 0; i < channels_; i++) {
+      for (index_t i = 0; i < channels_; i++) {
         w_grad_2[i] += grad_beta[i];
       }
     }

From 00247222e4f5d57307aadff6c6a1ee1a7a251132 Mon Sep 17 00:00:00 2001
From: "B. Gawrych"
Date: Mon, 5 Jul 2021 08:36:48 +0200
Subject: [PATCH 2/2] Fix windows build

---
 src/operator/nn/batch_norm.cc | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/operator/nn/batch_norm.cc b/src/operator/nn/batch_norm.cc
index 86b3e0acd9d1..be0b015f5c6c 100644
--- a/src/operator/nn/batch_norm.cc
+++ b/src/operator/nn/batch_norm.cc
@@ -142,7 +142,7 @@ void BatchNormForwardImpl(mshadow::Stream<cpu> *,
   const size_t itemCountPerChannel = inputData.Size() / channelCount;
 
 #pragma omp parallel for
-  for (size_t channel = 0; channel < channelCount; ++channel) {
+  for (int channel = 0; channel < static_cast<int>(channelCount); ++channel) {
     if (is_train_and_not_global_stats) {
       // compute mean per input
       mean[channel] = 0;
@@ -253,7 +253,7 @@ void BatchNormBackwardImpl(mshadow::Stream<cpu> *,
   const bool is_train_and_not_global_stats = ctx.is_train && !param_.use_global_stats;
 
 #pragma omp parallel for
-  for (size_t channel = 0; channel < channelCount; ++channel) {
+  for (int channel = 0; channel < static_cast<int>(channelCount); ++channel) {
     const AccReal *weight = weights.dptr<AccReal>();
     const AccReal w = !param_.fix_gamma ? weight[channel] : AccReal(1);
     AccReal mean, invstd;
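
Note (not part of the patch): the functional effect of replacing the direct in_shape->at(...) assignments and out_shape->push_back(...) calls with SHAPE_ASSIGN_CHECK is that shapes already provided for gamma, beta, and the moving statistics are now checked against Shape1(channelCount) during shape inference instead of being silently overwritten, while still-unknown shapes get filled in. A minimal standalone sketch of that assign-or-check idea, using hypothetical Shape and AssignOrCheck names rather than MXNet's actual TShape/SHAPE_ASSIGN_CHECK:

    #include <iostream>
    #include <vector>

    // Hypothetical stand-in for a 1-D shape; MXNet uses mxnet::TShape.
    using Shape = std::vector<long>;

    // Assign-or-check: an unknown (empty) shape adopts the inferred one,
    // a known shape must match it exactly or shape inference fails.
    bool AssignOrCheck(Shape* target, const Shape& inferred) {
      if (target->empty()) {
        *target = inferred;          // unknown -> fill in
        return true;
      }
      return *target == inferred;    // known -> must agree
    }

    int main() {
      const long channelCount = 8;
      Shape gamma;                   // unknown shape: gets inferred
      Shape movingMean = {4};        // wrong channel count: now rejected
      std::cout << AssignOrCheck(&gamma, {channelCount}) << '\n';       // 1
      std::cout << AssignOrCheck(&movingMean, {channelCount}) << '\n';  // 0
      return 0;
    }

The second commit restores the signed int loop index with an explicit cast, likely because MSVC's OpenMP support requires a signed loop variable in a parallel for loop, which is consistent with the "Fix windows build" subject.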