add fuse_bn_act op #27230
Conversation
Thanks for your contribution!
Force-pushed from 832e4a3 to baabd41
Force-pushed from baabd41 to 4efbc49
LGTM
LGTM
LGTM for paddle.fluid.contrib.fused_bn_add_act without core.ops, since it is used in the static graph only.
    'matmul',
    'mul',
}
white_list = {'conv2d', 'matmul', 'mul', 'fused_bn_add_activation'}
I think fused_bn_add_activation should be added to the gray_list.
done.
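For reference, a minimal sketch of the requested change, assuming the usual three AMP op lists; the neighbouring entries shown are illustrative, not taken from the PR:

# Illustrative only: drop fused_bn_add_activation from the pure-fp16 white_list
# and add it to the gray_list (ops that follow the dtype of their inputs).
white_list = {'conv2d', 'matmul', 'mul'}
gray_list = {
    'elementwise_add',            # assumed existing entry, for illustration
    'batch_norm',                 # assumed existing entry, for illustration
    'fused_bn_add_activation',
}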
check_variable_and_dtype(x, 'input', ['float16', 'float32', 'float64'],
                         'fused_bn_add_act')
check_variable_and_dtype(y, 'input', ['float16', 'float32', 'float64'],
BTW, you have only registered the float16 kernel, so 'float32' and 'float64' are not needed here.
The dtype check is performed at compile time, and limiting it to float16 would cause the check to fail.
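A minimal sketch of why the broader list is kept, assuming the usual AMP flow in which the program is first built in float32 and cast ops are inserted afterwards (the variable name and shape below are made up):

import paddle.fluid as fluid
from paddle.fluid.data_feeder import check_variable_and_dtype

# At graph-construction time the variable is still float32; the AMP pass only
# casts it to float16 later, so a float16-only check would reject this program.
x = fluid.data(name='x', shape=[-1, 56, 56, 64], dtype='float32')
check_variable_and_dtype(x, 'input', ['float16', 'float32', 'float64'],
                         'fused_bn_add_act')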
if in_name != 'X' or in_name != 'Z':
    continue
I think the condition test here may be wrong: in_name can never equal both 'X' and 'Z' at once, so this test is always true and the continue is always taken. Maybe the logic below is more understandable:
if in_name not in {'X', 'Z'}:
    continue
Maybe the condition test about batch_norm can be simplified as below:
if src_dtype == core.VarDesc.VarType.FP32 and op.type in {'batch_norm', 'fused_bn_add_activation'}:
    if in_name not in {'X', 'Z'}:
        continue
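A tiny standalone check of the difference (the input names are only examples):

# The original test is true for every in_name, so every input gets skipped;
# the rewritten test only skips inputs other than 'X' and 'Z'.
for in_name in ['X', 'Z', 'Scale', 'Bias']:
    always_true = in_name != 'X' or in_name != 'Z'   # True for all four names
    intended = in_name not in {'X', 'Z'}             # True only for 'Scale', 'Bias'
    print(in_name, always_true, intended)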
done.
if op.type == 'batch_norm' and out_name != 'Y':
    continue
if op.type == 'fused_bn_add_activation' and out_name != 'Y':
    continue
if op.type in {'batch_norm', 'fused_bn_add_activation'} and out_name != 'Y':
    continue
done.
    saved_mean->template data<BatchNormParamType<float>>();
const auto *saved_var_data =
    saved_var->template data<BatchNormParamType<float>>();
Please use T instead of float.
done
Force-pushed from 2e1df1d to f1ba94b
Force-pushed from f1ba94b to bf5a7f2
LGTM.
LGTM
PR types
Function optimization
PR changes
OPs
Describe
This Op performs batch norm on input x and adds the result to input y, then applies an activation to the sum. We use the cuDNN API to implement this function; the following points should be noted:
cudnnBatchNormalizationForwardTrainingEx requires that inputs x and z be float16.

The data format of the inputs must be NHWC, i.e. [batch, in_height, in_width, in_channels].
This Op will be used in automatic mixed precision (AMP) training of the ResNet model. The following image shows part of the model: the red parts represent the inputs of this Op, and the green parts represent the computation performed by the Op.

Performance of ResNet50 AMP Training
Tested on V100, CUDA 10.1, cuDNN 7.6, single card, BS=128.
Loss and accuracy (with fuse_bn_add_act=true, trained for 63 epochs):




