Skip to content

Commit

Permalink
optimization batch_norm 2D and NCHW format on CPU (#34585)
Browse files Browse the repository at this point in the history
  • Loading branch information
Zjq9409 authored Aug 9, 2021
1 parent a3cc2d0 commit 56759ff
Showing 1 changed file with 14 additions and 4 deletions.
18 changes: 14 additions & 4 deletions paddle/fluid/operators/batch_norm_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -295,8 +295,7 @@ class BatchNormKernel<platform::CPUDeviceContext, T>
bool global_stats = test_mode || use_global_stats;

const std::string data_layout_str = ctx.Attr<std::string>("data_layout");
const DataLayout data_layout =
framework::StringToDataLayout(data_layout_str);
DataLayout data_layout = framework::StringToDataLayout(data_layout_str);

const auto *x = ctx.Input<Tensor>("X");
const auto &x_dims = x->dims();
Expand Down Expand Up @@ -332,6 +331,12 @@ class BatchNormKernel<platform::CPUDeviceContext, T>
saved_mean->mutable_data<T>(ctx.GetPlace());
saved_variance->mutable_data<T>(ctx.GetPlace());

// When the input is 2-D there are no spatial dimensions, so an NCHW input
// can equally be regarded as NHWC; switch to the NHWC layout in that case.
if (x_dims.size() == 2 && data_layout == DataLayout::kNCHW) {
data_layout = DataLayout::kNHWC;
}

if (!global_stats) {
// saved_xx is used only for this batch of data
EigenVectorArrayMap<T> saved_mean_e(
Expand Down Expand Up @@ -578,8 +583,7 @@ class BatchNormGradKernel<platform::CPUDeviceContext, T>
bool use_global_stats = ctx.Attr<bool>("use_global_stats");
const bool is_test = ctx.Attr<bool>("is_test");
const float epsilon = ctx.Attr<float>("epsilon");
const DataLayout data_layout =
framework::StringToDataLayout(data_layout_str);
DataLayout data_layout = framework::StringToDataLayout(data_layout_str);

auto *d_x = ctx.Output<Tensor>(framework::GradVarName("X"));
auto *d_scale = ctx.Output<Tensor>(framework::GradVarName("Scale"));
Expand Down Expand Up @@ -633,6 +637,12 @@ class BatchNormGradKernel<platform::CPUDeviceContext, T>
: x_dims[x_dims.size() - 1]);
const int sample_size = x->numel() / N / C;

// When the input is 2-D there are no spatial dimensions, so an NCHW input
// can equally be regarded as NHWC; switch to the NHWC layout in that case.
if (x_dims.size() == 2 && data_layout == DataLayout::kNCHW) {
data_layout = DataLayout::kNHWC;
}

// init output
if (d_x) {
d_x->mutable_data<T>(ctx.GetPlace());
Expand Down

0 comments on commit 56759ff

Please sign in to comment.