Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add fuse_bn_act op #27230

Merged
merged 6 commits into from
Sep 23, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion cmake/operators.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,8 @@ function(op_library TARGET)
"tensor_array_read_write_op" "tensorrt_engine_op" "conv_fusion_op"
"fusion_transpose_flatten_concat_op" "fusion_conv_inception_op"
"sync_batch_norm_op" "dgc_op" "fused_fc_elementwise_layernorm_op"
"multihead_matmul_op" "fusion_group_op" "fused_bn_activation_op" "fused_embedding_eltwise_layernorm_op" "fusion_gru_op")
"multihead_matmul_op" "fusion_group_op" "fused_bn_activation_op" "fused_embedding_eltwise_layernorm_op" "fusion_gru_op"
"fused_bn_add_activation_op")
if ("${TARGET}" STREQUAL "${manual_pybind_op}")
set(pybind_flag 1)
endif()
Expand Down
8 changes: 7 additions & 1 deletion paddle/fluid/operators/fused/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@ register_operators(EXCLUDES
multihead_matmul_op
fused_embedding_eltwise_layernorm_op
fusion_group_op
fusion_gru_op)
fusion_gru_op
fused_bn_add_activation_op)

# fusion_gru_op does not have CUDA kernel
op_library(fusion_gru_op)
Expand Down Expand Up @@ -47,4 +48,9 @@ if (WITH_GPU)
file(APPEND ${pybind_file} "USE_CUDA_ONLY_OP(fusion_group);\n")
cc_test(test_fusion_group_op SRCS fusion_group_op_test.cc DEPS fusion_group_op)
endif()
# fused_bn_add_activation
if (NOT ${CUDNN_VERSION} VERSION_LESS 7401)
op_library(fused_bn_add_activation_op)
file(APPEND ${pybind_file} "USE_CUDA_ONLY_OP(fused_bn_add_activation);\n")
endif()
endif()
255 changes: 255 additions & 0 deletions paddle/fluid/operators/fused/fused_bn_add_activation_op.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,255 @@
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/fused/fused_bn_add_activation_op.h"
#include <memory>
#include <string>
#include <unordered_map>
#include "paddle/fluid/framework/op_registry.h"

namespace paddle {
namespace operators {

using LoDTensor = framework::LoDTensor;

void FusedBatchNormAddActOp::InferShape(
framework::InferShapeContext *ctx) const {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "FusedBatchNormAddActOp");
OP_INOUT_CHECK(ctx->HasInput("Z"), "Input", "Z", "FusedBatchNormAddActOp");
OP_INOUT_CHECK(ctx->HasInput("Scale"), "Input", "Scale",
"FusedBatchNormAddActOp");
OP_INOUT_CHECK(ctx->HasInput("Bias"), "Input", "Bias",
"FusedBatchNormAddActOp");

// check output
OP_INOUT_CHECK(ctx->HasOutput("Y"), "Output", "Y", "FusedBatchNormAddActOp");
OP_INOUT_CHECK(ctx->HasOutput("MeanOut"), "Output", "MeanOut",
"FusedBatchNormAddActOp");
OP_INOUT_CHECK(ctx->HasOutput("VarianceOut"), "Output", "VarianceOut",
"FusedBatchNormAddActOp");
OP_INOUT_CHECK(ctx->HasOutput("SavedMean"), "Output", "SavedMean",
"FusedBatchNormAddActOp");
OP_INOUT_CHECK(ctx->HasOutput("SavedVariance"), "Output", "SavedVariance",
"FusedBatchNormAddActOp");

const auto x_dims = ctx->GetInputDim("X");
const auto z_dims = ctx->GetInputDim("Z");
PADDLE_ENFORCE_EQ(x_dims, z_dims,
platform::errors::InvalidArgument(
"ShapeError: the shapes of input "
"must be equal. But received: the shape "
"of input X = [%s], and the shape of "
"input Y = [%s]",
x_dims, z_dims));
PADDLE_ENFORCE_GE(x_dims.size(), 2, platform::errors::InvalidArgument(
"ShapeError: the dimensions of input "
"must greater than or equal to 2."
"But received: the shape of input "
"= [%s], the dimension of input = "
"[%d]",
x_dims, x_dims.size()));
PADDLE_ENFORCE_LE(x_dims.size(), 5, platform::errors::InvalidArgument(
"ShapeError: the dimensions of input "
"must smaller than or equal to 5."
"But received: the shape of input "
"= [%s], the dimension of input = "
"[%d]",
x_dims, x_dims.size()));

const int64_t C = x_dims[x_dims.size() - 1];

auto scale_dim = ctx->GetInputDim("Scale");
auto bias_dim = ctx->GetInputDim("Bias");

PADDLE_ENFORCE_EQ(
scale_dim.size(), 1UL,
platform::errors::InvalidArgument(
"ShapeError: the dimension of scale must equal to 1."
"But received: the shape of scale is [%s], the dimension "
"of scale is [%d]",
scale_dim, scale_dim.size()));
PADDLE_ENFORCE_EQ(bias_dim.size(), 1UL,
platform::errors::InvalidArgument(
"ShapeError: the dimension of bias must equal to 1."
"But received: the shape of bias is [%s],the dimension "
"of bias is [%d]",
bias_dim, bias_dim.size()));

bool check = true;
if ((!ctx->IsRuntime()) && (framework::product(scale_dim) <= 0 ||
framework::product(bias_dim) <= 0)) {
check = false;
}

if (check) {
PADDLE_ENFORCE_EQ(scale_dim[0], C,
platform::errors::InvalidArgument(
"ShapeError: the shape of scale must equal to [%d]"
"But received: the shape of scale is [%d]",
C, scale_dim[0]));
PADDLE_ENFORCE_EQ(bias_dim[0], C,
platform::errors::InvalidArgument(
"ShapeError: the shape of bias must equal to [%d]"
"But received: the shape of bias is [%d]",
C, bias_dim[0]));
}
ctx->SetOutputDim("Y", x_dims);
ctx->SetOutputDim("MeanOut", {C});
ctx->SetOutputDim("VarianceOut", {C});
ctx->SetOutputDim("SavedMean", {C});
ctx->SetOutputDim("SavedVariance", {C});
ctx->ShareLoD("X", "Y");
}

framework::OpKernelType FusedBatchNormAddActOp::GetExpectedKernelType(
const framework::ExecutionContext &ctx) const {
auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
// By default, the type of the scale, bias, mean,
// and var tensors should be float when input tensor's dtype is float16.
auto bn_param_type = framework::proto::VarType::FP32;

PADDLE_ENFORCE_EQ(
bn_param_type, ctx.Input<Tensor>("Scale")->type(),
platform::errors::InvalidArgument("Scale input should be of float type"));
PADDLE_ENFORCE_EQ(
bn_param_type, ctx.Input<Tensor>("Bias")->type(),
platform::errors::InvalidArgument("Bias input should be of float type"));

framework::LibraryType library = framework::LibraryType::kPlain;
framework::DataLayout layout = framework::DataLayout::kAnyLayout;

return framework::OpKernelType(input_data_type, ctx.GetPlace(), layout,
library);
}

void FusedBatchNormAddActOpMaker::Make() {
AddInput("X", "The input tensor");
AddInput("Z", "The input tensor");
AddInput("Scale",
"Scale is a 1-dimensional tensor of size C "
"that is applied to the output");
AddInput("Bias",
"Bias is a 1-dimensional tensor of size C "
"that is applied to the output");
AddOutput("Y", "result after normalization");
AddOutput("MeanOut",
"Share memory with Mean. "
"Store the global mean when training");
AddOutput("VarianceOut",
"Share memory with Variance. "
"Store the global Variance when training");
AddOutput("SavedMean",
"Mean of the current mini batch, "
"will apply to output when training")
.AsIntermediate();
AddOutput("SavedVariance",
"Variance of the current mini batch, "
"will apply to output when training")
.AsIntermediate();
AddOutput("ReserveSpace",
"Reserve GPU space for triggering the new semi-persistent "
"NHWC kernel");
AddAttr<float>("momentum", "").SetDefault(0.9);
AddAttr<float>("epsilon", "")
.SetDefault(1e-5)
.AddCustomChecker([](const float &epsilon) {
PADDLE_ENFORCE_EQ(epsilon >= 0.0f && epsilon <= 0.001f, true,
platform::errors::InvalidArgument(
"'epsilon' should be between 0.0 and 0.001."));
});
AddAttr<std::string>("act_type", "The activation type to be fused.")
.SetDefault("relu");
AddComment(R"DOC(
Fused Batch Normalization with activation.

Batch Norm has been implemented as discussed in the paper:
https://arxiv.org/pdf/1502.03167.pdf
Batch Norm can be used as a normalizer function for conv2d and fully_connected operations.
Now, the required data format for FusedBatchNormAddActOp is NHWC `[batch, in_height, in_width, in_channels]`.

)DOC");
}

void FusedBatchNormAddActGradOp::InferShape(
framework::InferShapeContext *ctx) const {
// check input
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X",
"FusedBatchNormAddActGradOp");
OP_INOUT_CHECK(ctx->HasInput("Z"), "Input", "Z",
"FusedBatchNormAddActGradOp");
OP_INOUT_CHECK(ctx->HasInput("Scale"), "Input", "Scale",
"FusedBatchNormAddActGradOp");
OP_INOUT_CHECK(ctx->HasInput("SavedMean"), "Input", "SavedMean",
"FusedBatchNormAddActGradOp");
OP_INOUT_CHECK(ctx->HasInput("SavedVariance"), "Input", "SavedVariance",
"FusedBatchNormAddActGradOp");
OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Y")), "Input",
framework::GradVarName("Y"), "FusedBatchNormAddActGradOp");

// check output
OP_INOUT_CHECK(ctx->HasOutput(framework::GradVarName("X")), "Output",
framework::GradVarName("X"), "FusedBatchNormAddActGradOp");
OP_INOUT_CHECK(ctx->HasOutput(framework::GradVarName("Z")), "Output",
framework::GradVarName("Z"), "FusedBatchNormAddActGradOp");
OP_INOUT_CHECK(ctx->HasOutput(framework::GradVarName("Scale")), "Output",
framework::GradVarName("Scale"), "FusedBatchNormAddActGradOp");
OP_INOUT_CHECK(ctx->HasOutput(framework::GradVarName("Bias")), "Output",
framework::GradVarName("Bias"), "FusedBatchNormAddActGradOp");

const auto in_dims = ctx->GetInputDim("X");
const int C = in_dims[in_dims.size() - 1];

ctx->SetOutputDim(framework::GradVarName("X"), in_dims);
ctx->SetOutputDim(framework::GradVarName("Z"), in_dims);
ctx->SetOutputDim(framework::GradVarName("Scale"), {C});
ctx->SetOutputDim(framework::GradVarName("Bias"), {C});
}

framework::OpKernelType FusedBatchNormAddActGradOp::GetExpectedKernelType(
const framework::ExecutionContext &ctx) const {
const auto *var = ctx.InputVar(framework::GradVarName("Y"));
if (var == nullptr) {
PADDLE_THROW(platform::errors::NotFound(
"Can not find Y@GRAD in the execution context."));
}
const Tensor *t = nullptr;
if (var->IsType<Tensor>()) {
t = &var->Get<Tensor>();
} else if (var->IsType<LoDTensor>()) {
t = &var->Get<LoDTensor>();
}
if (t == nullptr) {
PADDLE_THROW(
platform::errors::NotFound("Can not get the tensor value of Y@GRAD."));
}

framework::LibraryType library = framework::LibraryType::kPlain;
framework::DataLayout layout = framework::DataLayout::kAnyLayout;

return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace(), layout,
library);
}

} // namespace operators
} // namespace paddle

namespace ops = paddle::operators;
REGISTER_OPERATOR(
fused_bn_add_activation, ops::FusedBatchNormAddActOp,
ops::FusedBatchNormAddActOpMaker, ops::FusedBatchNormAddActOpInferVarType,
ops::FusedBatchNormAddActGradOpMaker<paddle::framework::OpDesc>,
ops::FusedBatchNormAddActGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(fused_bn_add_activation_grad,
ops::FusedBatchNormAddActGradOp);
Loading