Commit

Revert "modify"
This reverts commit fd4e10b.
ZzSean committed Oct 14, 2021
1 parent fd4e10b commit 33e261b
Showing 7 changed files with 92 additions and 96 deletions.
18 changes: 9 additions & 9 deletions paddle/fluid/operators/fused/cudnn_scale_bias_add_relu.cu.h
@@ -94,13 +94,13 @@ template <typename T>
class CudnnScaleBiasAddRelu {
public:
CudnnScaleBiasAddRelu(const platform::CUDADeviceContext &ctx,
const std::string &act_type, bool fuse_add,
const std::string &act_type, bool fused_add,
bool has_shortcut, const std::vector<int> &data_shape,
const std::vector<int> &param_shape,
const std::vector<int> &bitmask_shape)
: fwd_op_(CUDNN_FUSED_SCALE_BIAS_ADD_ACTIVATION_GEN_BITMASK),
bwd_op_(CUDNN_FUSED_DACTIVATION_FORK_DBATCHNORM) {
fuse_add_ = fuse_add;
fused_add_ = fused_add;
has_shortcut_ = has_shortcut;
args_.Set(act_type, data_shape, param_shape, bitmask_shape);
}
@@ -132,7 +132,7 @@ class CudnnScaleBiasAddRelu {
fwd_op_.SetOpVariantParamAttrPtr(CUDNN_PTR_BN_Z_EQSCALE, z_scale_ptr);
fwd_op_.SetOpVariantParamAttrPtr(CUDNN_PTR_BN_Z_EQBIAS, z_bias_ptr);
} else {
if (fuse_add_) {
if (fused_add_) {
T *z_ptr = const_cast<T *>(z->data<T>());
fwd_op_.SetOpVariantParamAttrPtr(CUDNN_PTR_ZDATA, z_ptr);
}
@@ -200,7 +200,7 @@ class CudnnScaleBiasAddRelu {
bwd_op_.SetOpVariantParamAttrPtr(CUDNN_PTR_BN_DBIAS, dbias_ptr);
bwd_op_.SetOpVariantParamAttrPtr<double>(CUDNN_SCALAR_DOUBLE_BN_EPSILON,
&eps);
if (has_shortcut_ || fuse_add_) {
if (has_shortcut_ || fused_add_) {
bwd_op_.SetOpVariantParamAttrPtr(CUDNN_PTR_DZDATA, dz_ptr);
}

@@ -227,14 +227,14 @@ class CudnnScaleBiasAddRelu {
{CUDNN_PARAM_ZDATA_PLACEHOLDER, CUDNN_PARAM_BN_Z_EQSCALE_PLACEHOLDER,
CUDNN_PARAM_BN_Z_EQBIAS_PLACEHOLDER},
CUDNN_PTR_16B_ALIGNED);
} else if (fuse_add_) {
} else if (fused_add_) {
fwd_op_.SetOpConstParamAttr(CUDNN_PARAM_ZDATA_PLACEHOLDER,
CUDNN_PTR_16B_ALIGNED);
}

// input desc
fwd_op_.SetOpConstParamDesc(CUDNN_PARAM_XDESC, args_.in_desc.desc());
if (has_shortcut_ || fuse_add_) {
if (has_shortcut_ || fused_add_) {
fwd_op_.SetOpConstParamDesc(CUDNN_PARAM_ZDESC, args_.in_desc.desc());
}

@@ -272,15 +272,15 @@ class CudnnScaleBiasAddRelu {
CUDNN_PARAM_BN_DSCALE_PLACEHOLDER, CUDNN_PARAM_BN_DBIAS_PLACEHOLDER,
CUDNN_PARAM_ACTIVATION_BITMASK_PLACEHOLDER},
CUDNN_PTR_16B_ALIGNED);
if (has_shortcut_ || fuse_add_) {
if (has_shortcut_ || fused_add_) {
bwd_op_.SetOpConstParamAttr(CUDNN_PARAM_DZDATA_PLACEHOLDER,
CUDNN_PTR_16B_ALIGNED);
}

// input desc
bwd_op_.SetOpConstParamDesc(CUDNN_PARAM_XDESC, args_.in_desc.desc());
bwd_op_.SetOpConstParamDesc(CUDNN_PARAM_DXDESC, args_.in_desc.desc());
if (has_shortcut_ || fuse_add_) {
if (has_shortcut_ || fused_add_) {
bwd_op_.SetOpConstParamDesc(CUDNN_PARAM_DZDESC, args_.in_desc.desc());
}

@@ -304,7 +304,7 @@ class CudnnScaleBiasAddRelu {
CUDNN_BATCHNORM_SPATIAL_PERSISTENT);
}

bool fuse_add_ = false;
bool fused_add_ = false;
bool has_shortcut_ = false;
size_t fwd_workspace_byte_;
size_t bwd_workspace_byte_;
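For readers piecing the hunks above back together: all nine changed lines in this file are the same rename (fuse_add_ ↔ fused_add_), threaded through the Z-input gating of the fused op, where has_shortcut_ registers Z together with its equivalent scale/bias and the fused-add flag alone registers only the raw Z pointer. The sketch below is a toy restatement of that gating with made-up types and names; it is not the Paddle or cuDNN API.

```cpp
#include <iostream>
#include <string>

// Toy stand-in for the fused op wrapper; it only records which Z-related
// parameters get registered, mirroring the branches visible in the diff.
struct ToyFusedOp {
  void Set(const std::string& key) { std::cout << "set " << key << "\n"; }
};

void ConfigureZ(ToyFusedOp& op, bool has_shortcut, bool fused_add) {
  if (has_shortcut) {
    // Z went through its own conv + BN, so the op also consumes its
    // equivalent scale/bias (CUDNN_PTR_BN_Z_EQSCALE / EQBIAS in the diff).
    op.Set("z_data");
    op.Set("z_equiv_scale");
    op.Set("z_equiv_bias");
  } else if (fused_add) {
    // Plain residual add: only the raw Z data pointer is registered
    // (CUDNN_PTR_ZDATA in the diff).
    op.Set("z_data");
  }
  // Otherwise Z is not an input at all.
}

int main() {
  ToyFusedOp op;
  ConfigureZ(op, /*has_shortcut=*/false, /*fused_add=*/true);  // prints "set z_data"
}
```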
81 changes: 37 additions & 44 deletions paddle/fluid/operators/fused/resnet_unit_op.cc
@@ -20,18 +20,6 @@ namespace operators {

using Tensor = framework::Tensor;

// Shape of bitmask
static framework::DDim GetBitmaskDims(std::vector<int> out_shape) {
int c = out_shape.back();
int64_t nhw = std::accumulate(out_shape.begin(), out_shape.end(), 1,
std::multiplies<int>()) /
c;
int32_t c_int32_elems = ((c + 63) & ~63) / 32;
int32_t nhw_int32_elems = ((nhw + 31) & ~31);
std::vector<int> bitmask_shape = {nhw_int32_elems, c_int32_elems, 1};
return framework::make_ddim(bitmask_shape);
}

class ResNetUnitOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
@@ -46,9 +34,9 @@ class ResNetUnitOp : public framework::OperatorWithKernel {
OP_INOUT_CHECK(ctx->HasInput("MeanX"), "Input", "MeanX", "ResNetUnitOp");
OP_INOUT_CHECK(ctx->HasInput("VarX"), "Input", "VarX", "ResNetUnitOp");

bool fuse_add = ctx->Attrs().Get<bool>("fuse_add");
bool fused_add = ctx->Attrs().Get<bool>("fused_add");
bool has_shortcut = ctx->Attrs().Get<bool>("has_shortcut");
if (fuse_add || has_shortcut) {
if (fused_add || has_shortcut) {
OP_INOUT_CHECK(ctx->HasInput("Z"), "Input", "Z", "ResNetUnitOp");
}
if (has_shortcut) {
@@ -110,46 +98,50 @@ class ResNetUnitOp : public framework::OperatorWithKernel {
const auto x_dims = ctx->GetInputDim("X");
const auto w_dims = ctx->GetInputDim("FilterX");
const auto bn_param_dims = ctx->GetInputDim("ScaleX");
PADDLE_ENFORCE_EQ(x_dims.size(), 4, platform::errors::InvalidArgument(
"The dimensions of input "
"must equal to 4."
"But received: the shape of input "
"= [%s], the dimension of input = "
"[%d]",
x_dims, x_dims.size()));
PADDLE_ENFORCE_EQ(
x_dims.size(), 4,
platform::errors::InvalidArgument("ShapeError: the dimensions of input "
"must equal to 4."
"But received: the shape of input "
"= [%s], the dimension of input = "
"[%d]",
x_dims, x_dims.size()));
PADDLE_ENFORCE_EQ(w_dims.size(), 4,
platform::errors::InvalidArgument(
"The dimensions of filter "
"ShapeError: the dimensions of filter "
"must equal to 4."
"But received: the shape of filter "
"= [%s], the dimension of filter = [%d] ",
"= [%s], the dimension of filter = "
"[%d]",
w_dims, w_dims.size()));
PADDLE_ENFORCE_EQ(bn_param_dims.size(), 4,
platform::errors::InvalidArgument(
"The dimensions of bn param "
"ShapeError: the dimensions of bn param "
"must equal to 4."
"But received: the shape of bn param "
"= [%s], the dimension of bn param = [%d] ",
"= [%s], the dimension of bn param = "
"[%d]",
bn_param_dims, bn_param_dims.size()));
auto data_format = ctx->Attrs().Get<std::string>("data_format");
PADDLE_ENFORCE_EQ(
data_format, "NHWC",
platform::errors::InvalidArgument("The data format must equal to NHWC. "
"But received: the data format "
"= [%s]",
data_format));
// Calculate the dims of outputs
int batch = x_dims[0];
int output_channel = w_dims[0];
int filter_size = w_dims[2];
int stride = ctx->Attrs().Get<int>("stride");
int padding = ctx->Attrs().Get<int>("padding");
int out_h = (x_dims[1] + padding * 2 - filter_size) / stride + 1;
int out_w = (x_dims[2] + padding * 2 - filter_size) / stride + 1;
int pad = ctx->Attrs().Get<int>("pad");
int out_h = (x_dims[1] + pad * 2 - filter_size) / stride + 1;
int out_w = (x_dims[2] + pad * 2 - filter_size) / stride + 1;
std::vector<int> out_shape = {batch, out_h, out_w, output_channel};
// Shape of bitmask
int C = output_channel;
int64_t NHW = std::accumulate(out_shape.begin(), out_shape.end(), 1,
std::multiplies<int>()) /
output_channel;
int32_t C_int32Elems = ((C + 63) & ~63) / 32;
int32_t NHW_int32Elems = ((NHW + 31) & ~31);
std::vector<int> bitmask_shape = {NHW_int32Elems, C_int32Elems, 1};

auto y_dims = framework::make_ddim(out_shape);
auto bitmask_dims = GetBitmaskDims(out_shape);
auto bitmask_dims = framework::make_ddim(bitmask_shape);
// Set dims of outputs
ctx->SetOutputDim("Y", y_dims);
ctx->SetOutputDim("BitMask", bitmask_dims);
@@ -219,13 +211,14 @@ class ResNetUnitOpMaker : public framework::OpProtoAndCheckerMaker {
AddOutput("RunningVarZ", "Shared memory with VarZ").AsDispensable();
AddAttr<int>("stride", "").SetDefault(1);
AddAttr<int>("stride_z", "").SetDefault(1);
AddAttr<int>("padding", "").SetDefault(0);
AddAttr<int>("dilation", "").SetDefault(1);
AddAttr<int>("pad", "").SetDefault(0);
AddAttr<int>("dilate", "").SetDefault(1);
AddAttr<int>("group", "").SetDefault(1);
AddAttr<float>("momentum", "").SetDefault(0.9);
AddAttr<float>("epsilon", "").SetDefault(1e-5);
AddAttr<std::string>("data_format", "").SetDefault("NHWC");
AddAttr<bool>("fuse_add", "").SetDefault(false);
AddAttr<std::string>("conv_format", "").SetDefault("NHWC");
AddAttr<std::string>("bn_format", "").SetDefault("NHWC");
AddAttr<bool>("fused_add", "").SetDefault(false);
AddAttr<bool>("has_shortcut", "").SetDefault(false);
AddAttr<bool>("use_global_stats", "").SetDefault(false);
AddAttr<bool>("is_test",
@@ -266,9 +259,9 @@ class ResNetUnitGradOp : public framework::OperatorWithKernel {
OP_INOUT_CHECK(ctx->HasInput("SavedInvstdX"), "Input", "SavedInvstdX",
"ResNetUnitGradOp");

bool fuse_add = ctx->Attrs().Get<bool>("fuse_add");
bool fused_add = ctx->Attrs().Get<bool>("fused_add");
bool has_shortcut = ctx->Attrs().Get<bool>("has_shortcut");
if (fuse_add || has_shortcut) {
if (fused_add || has_shortcut) {
OP_INOUT_CHECK(ctx->HasInput("Z"), "Input", "Z", "ResNetUnitGradOp");
}
if (has_shortcut) {
@@ -300,7 +293,7 @@ class ResNetUnitGradOp : public framework::OperatorWithKernel {
framework::GradVarName("ScaleX"), "ResNetUnitGradOp");
OP_INOUT_CHECK(ctx->HasOutput(framework::GradVarName("BiasX")), "Output",
framework::GradVarName("BiasX"), "ResNetUnitGradOp");
if (fuse_add) {
if (fused_add) {
OP_INOUT_CHECK(ctx->HasOutput(framework::GradVarName("Z")), "Output",
framework::GradVarName("Z"), "ResNetUnitGradOp");
}
@@ -320,7 +313,7 @@ class ResNetUnitGradOp : public framework::OperatorWithKernel {
ctx->SetOutputDim(framework::GradVarName("FilterX"), filter_x_dims);
ctx->SetOutputDim(framework::GradVarName("ScaleX"), param_dims);
ctx->SetOutputDim(framework::GradVarName("BiasX"), param_dims);
if (fuse_add || has_shortcut) {
if (fused_add || has_shortcut) {
const auto z_dims = ctx->GetInputDim("Z");
ctx->SetOutputDim(framework::GradVarName("Z"), z_dims);
}
27 changes: 13 additions & 14 deletions paddle/fluid/operators/fused/resnet_unit_op.cu
@@ -52,15 +52,15 @@ class ResNetUnitKernel : public framework::OpKernel<T> {
Tensor *output = ctx.Output<Tensor>("Y");
Tensor *bitmask = ctx.Output<Tensor>("BitMask");
// attrs
int padding = ctx.Attr<int>("padding");
int pad = ctx.Attr<int>("pad");
int stride = ctx.Attr<int>("stride");
int stride_z = ctx.Attr<int>("stride_z");
int dilate = ctx.Attr<int>("dilate");
int group = ctx.Attr<int>("group");
double eps = static_cast<double>(ctx.Attr<float>("epsilon"));
double momentum = static_cast<double>(ctx.Attr<float>("momentum"));
bool has_shortcut = ctx.Attr<bool>("has_shortcut");
bool fuse_add = ctx.Attr<bool>("fuse_add");
bool fused_add = ctx.Attr<bool>("fused_add");
bool use_global_stats = ctx.Attr<bool>("use_global_stats");
bool is_test = ctx.Attr<bool>("is_test");
bool is_train = !is_test && !use_global_stats;
@@ -87,8 +87,7 @@ class ResNetUnitKernel : public framework::OpKernel<T> {
sum_x.Resize(param_dims);
sum_of_squares_x.Resize(param_dims);
CudnnNormConvolution<T> conv_x_op(dev_ctx, input_x_shape, filter_x_shape,
output_shape, padding, stride, dilate,
group);
output_shape, pad, stride, dilate, group);
conv_x_op.Forward(dev_ctx, *input_x, *filter_x, conv_out_x, &sum_x,
&sum_of_squares_x);

@@ -104,7 +103,7 @@ class ResNetUnitKernel : public framework::OpKernel<T> {
is_train);

// 3. scale + bias + add + relu
CudnnScaleBiasAddRelu<T> sbar_op(dev_ctx, act_type, fuse_add, has_shortcut,
CudnnScaleBiasAddRelu<T> sbar_op(dev_ctx, act_type, fused_add, has_shortcut,
output_shape, param_shape, bitmask_shape);
if (has_shortcut) {
// input z
@@ -129,7 +128,7 @@ class ResNetUnitKernel : public framework::OpKernel<T> {
sum_z.Resize(param_dims);
sum_of_squares_z.Resize(param_dims);
CudnnNormConvolution<T> conv_z_op(dev_ctx, input_z_shape, filter_z_shape,
output_shape, padding, stride_z, dilate,
output_shape, pad, stride_z, dilate,
group);
conv_z_op.Forward(dev_ctx, *input_z, *filter_z, conv_out_z, &sum_z,
&sum_of_squares_z);
@@ -149,7 +148,7 @@ class ResNetUnitKernel : public framework::OpKernel<T> {
conv_out_z, &equiv_scale_z, &equiv_bias_z, output,
bitmask);
} else {
const Tensor *input_z = fuse_add ? ctx.Input<Tensor>("Z") : nullptr;
const Tensor *input_z = fused_add ? ctx.Input<Tensor>("Z") : nullptr;
sbar_op.Forward(dev_ctx, *conv_out_x, equiv_scale_x, equiv_bias_x,
input_z, nullptr, nullptr, output, bitmask);
}
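Condensed, the forward kernel above handles three Z scenarios: with has_shortcut, Z gets its own normalized convolution and batch-norm before the fused add; with only the fused-add flag set, Z is consumed as-is by the fused scale-bias-add-relu; otherwise no Z input is read. The toy dispatcher below only illustrates that branch structure; the function names are stand-ins, not the actual Cudnn* helpers.

```cpp
#include <iostream>

// Stand-ins for the fused helpers used in the kernel; they only log here.
void ConvBnOnZ()    { std::cout << "conv + bn on Z, then fused scale/bias/add/relu\n"; }
void FusedAddRawZ() { std::cout << "fused scale/bias/add/relu with raw Z\n"; }
void FusedNoZ()     { std::cout << "fused scale/bias/relu, no Z input\n"; }

void ForwardZPath(bool has_shortcut, bool fused_add) {
  if (has_shortcut) {
    ConvBnOnZ();      // mirrors the conv_z_op / bn branch in the diff
  } else if (fused_add) {
    FusedAddRawZ();   // mirrors `input_z = fused_add ? ctx.Input<Tensor>("Z") : nullptr`
  } else {
    FusedNoZ();
  }
}

int main() {
  ForwardZPath(/*has_shortcut=*/true,  /*fused_add=*/false);
  ForwardZPath(/*has_shortcut=*/false, /*fused_add=*/true);
  ForwardZPath(/*has_shortcut=*/false, /*fused_add=*/false);
}
```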
@@ -186,15 +185,15 @@ class ResNetUnitGradKernel : public framework::OpKernel<T> {
Tensor *scale_x_grad = ctx.Output<Tensor>(framework::GradVarName("ScaleX"));
Tensor *bias_x_grad = ctx.Output<Tensor>(framework::GradVarName("BiasX"));

int padding = ctx.Attr<int>("padding");
int pad = ctx.Attr<int>("pad");
int stride = ctx.Attr<int>("stride");
int stride_z = ctx.Attr<int>("stride_z");
int dilate = ctx.Attr<int>("dilate");
int group = ctx.Attr<int>("group");
double eps = static_cast<double>(ctx.Attr<float>("epsilon"));
double momentum = static_cast<double>(ctx.Attr<float>("momentum"));
bool has_shortcut = ctx.Attr<bool>("has_shortcut");
bool fuse_add = ctx.Attr<bool>("fuse_add");
bool fused_add = ctx.Attr<bool>("fused_add");
bool use_global_stats = ctx.Attr<bool>("use_global_stats");
std::string act_type = ctx.Attr<std::string>("act_type");

@@ -211,7 +210,7 @@ class ResNetUnitGradKernel : public framework::OpKernel<T> {
// scale_x_grad, bias_x_grad
Tensor conv_out_x_grad;
conv_out_x_grad.Resize(conv_out_x->dims());
CudnnScaleBiasAddRelu<T> sbar_x_op(dev_ctx, act_type, fuse_add,
CudnnScaleBiasAddRelu<T> sbar_x_op(dev_ctx, act_type, fused_add,
has_shortcut, output_shape, param_shape,
bitmask_shape);
if (has_shortcut) {
@@ -262,15 +261,15 @@ class ResNetUnitGradKernel : public framework::OpKernel<T> {
auto z_shape = framework::vectorize<int>(z->dims());
auto filter_z_shape = framework::vectorize<int>(filter_z->dims());
CudnnNormConvolutionGrad<T> conv_z_op(dev_ctx, z_shape, filter_z_shape,
output_shape, padding, stride_z,
dilate, group);
output_shape, pad, stride_z, dilate,
group);
conv_z_op.Backward(dev_ctx, *z, *filter_z, conv_out_z_grad, z_grad,
filter_z_grad);
} else {
// 1.1 Backward of BN (+ Add + Relu) for x, get conv_out_x_grad,
// scale_x_grad, bias_x_grad (and z_grad)
Tensor *z_grad =
fuse_add ? ctx.Output<Tensor>(framework::GradVarName("Z")) : nullptr;
fused_add ? ctx.Output<Tensor>(framework::GradVarName("Z")) : nullptr;
sbar_x_op.Backward(dev_ctx, *y_grad, *conv_out_x, *scale_x, *bias_x,
*saved_mean_x, *saved_invstd_x, bitmask,
&conv_out_x_grad, z_grad, scale_x_grad, bias_x_grad,
@@ -280,7 +279,7 @@ class ResNetUnitGradKernel : public framework::OpKernel<T> {
// 2. Backward of Conv for x, get x_grad and filter_x_grad
bool use_addto = ctx.Attr<bool>("use_addto");
CudnnNormConvolutionGrad<T> conv_x_op(dev_ctx, x_shape, filter_x_shape,
output_shape, padding, stride, dilate,
output_shape, pad, stride, dilate,
group);
conv_x_op.Backward(dev_ctx, *x, *filter_x, conv_out_x_grad, x_grad,
filter_x_grad, use_addto);
14 changes: 5 additions & 9 deletions python/paddle/fluid/contrib/mixed_precision/fp16_utils.py
@@ -81,24 +81,20 @@ def _dtype_to_str(dtype):


def _keep_fp32_input(op, in_name):
op_type = op.type
if op_type in ['batch_norm', 'layer_norm']:
if op.type in ['batch_norm', 'layer_norm']:
# Scale, Bias, Mean, Variance should be float32.
return in_name != 'X'
if op_type == 'fused_bn_add_activation':
if op.type == 'fused_bn_add_activation':
return in_name not in {'X', 'Z'}
if op_type == 'resnet_unit':
if op.type == 'resnet_unit':
return in_name not in {'X', 'FilterX', 'Z', 'FilterZ'}
return False


def _keep_fp32_output(op, out_name):
op_type = op.type
if op_type in ['batch_norm', 'fused_bn_add_activation', 'layer_norm']:
if op.type in ['batch_norm', 'fused_bn_add_activation', 'layer_norm']:
return out_name != 'Y'
if op_type == 'resnet_unit':
if op.type == 'resnet_unit':
return out_name not in {'Y', 'ConvX', 'ConvZ'}
return False


def _insert_cast_op(block, op, idx, src_dtype, dest_dtype):
2 changes: 1 addition & 1 deletion python/paddle/fluid/tests/unittests/test_resnet_unit_op.py
@@ -69,7 +69,7 @@ def build_fused_program(self,
num_filters=256,
filter_size=1,
stride=1,
fuse_add=True,
fused_add=True,
has_shortcut=True,
filter_x_attr=self.conv_param_attr1,
scale_x_attr=self.bn_scale_attr1,