bnorm+relu fuse for mkldnn (inference) #11434
Conversation
Force-pushed from 7080d1a to bfc74df
Force-pushed from 894c34f to 3063cdc
@@ -21,13 +22,13 @@
class InferenceTranspiler:
    def transpile(self, program, place, scope=None):
        '''
        Transpile the program. Support only fuse batch normalization now.
        Transpile the program. Support only batch normalization and relu fuse now.
You mean the batch norm and relu fuse in MKLDNN, but the plain batch norm fuse is also done here.
self.block.remove_op(i + 1)
i = i + 1

# TODO(luotao): use clone() method to flush the program.desc in force,
@luotao1 please help review below.
python/paddle/fluid/layers/nn.py
Outdated
**Gather Layer**

Output is obtained by gathering entries of the outer-most dimension
It seems that there is some diff in the annotation of nn.py. Did you modify this file? If not, you could leave it unchanged.
Seems like I have unnecessarily deleted "Gather Layer" - I will restore it.
benchmark/fluid/fluid_benchmark.py
Outdated
@@ -131,6 +131,10 @@ def train(avg_loss, infer_prog, optimizer, train_reader, test_reader, batch_acc,
    exe = fluid.Executor(place)
    exe.run(startup_prog)

    # Use inference_transpiler to speedup
    t = fluid.InferenceTranspiler()
    t.transpile(infer_prog, place)
It's not suitable to add the transpiler function here, since it's a training benchmark. @typhoonzero Do you have some better suggestion?
I could move it higher, to main. However, InferenceTranspiler requires place as a parameter, and extracting it happens in train. So the drawback would be that I would have to extract place again. Or I could add a place parameter to train and extract it in main.
The benchmark of inference is on the original model now. Thus, how about removing lines 134-136, or adding an option to control it?
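If the second option is chosen, a minimal sketch of how it could look (the --use_inference_transpiler flag name and the helper below are assumptions for illustration, not part of this PR):

import argparse

import paddle.fluid as fluid

def parse_args():
    parser = argparse.ArgumentParser()
    # Hypothetical flag; off by default so the benchmark keeps measuring
    # the original, untranspiled model.
    parser.add_argument(
        '--use_inference_transpiler',
        action='store_true',
        help='Run InferenceTranspiler on the inference program before testing.')
    return parser.parse_args()

def maybe_transpile(args, infer_prog, place):
    # Only transpile when explicitly requested on the command line.
    if args.use_inference_transpiler:
        t = fluid.InferenceTranspiler()
        t.transpile(infer_prog, place)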
if not use_mkldnn:
    self.fuse_batch_norm(program, place, scope)
else:
    self.fuse_relu(program)
- fuse_batch_norm is also suitable for MKLDNN.
- After fuse_batch_norm, fuse_relu is conv+relu, do you mean that?
- If fuse_batch_norm is suitable for MKLDNN, then I'll leave it without checking the use_mkldnn flag.
- fuse_relu deletes relu from the "batch norm + relu" pair.
Since we have fuse_batch_norm already, there is no batch_norm op in the inference program. Thus, what's the usage of fuse_relu?
@@ -80,6 +80,7 @@ class BatchNormMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
    const float epsilon = ctx.Attr<float>("epsilon");
    const float momentum = ctx.Attr<float>("momentum");
    const bool is_test = ctx.Attr<bool>("is_test");
    const bool fuse_with_relu = ctx.Attr<bool>("fuse_with_relu");
Why should we add the fuse_with_relu attribute here?
t = fluid.InferenceTranspiler()
t.transpile(infer_prog, place)
is enough. Do you mean there is a mkldnn::fuse_bn_relu function in MKLDNN?
mkldnn::fuse_bn_relu is a flag for MKLDNN batch norm telling it to execute relu along with batch norm.
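For illustration only, a small numpy sketch (not project code) of the contract behind that flag: a batch norm kernel with fuse_with_relu enabled should produce exactly what a separate relu applied to the batch norm output would produce, which is why the standalone relu OP can be dropped.

import numpy as np

def batch_norm(x, mean, var, scale, bias, eps=1e-5, fuse_with_relu=False):
    # Reference (inference-mode) batch norm; optionally applies relu itself.
    y = scale * (x - mean) / np.sqrt(var + eps) + bias
    return np.maximum(y, 0.0) if fuse_with_relu else y

x = np.random.randn(4, 3).astype('float32')
mean, var = x.mean(axis=0), x.var(axis=0)
scale, bias = np.ones(3, 'float32'), np.zeros(3, 'float32')

fused = batch_norm(x, mean, var, scale, bias, fuse_with_relu=True)
separate = np.maximum(batch_norm(x, mean, var, scale, bias), 0.0)
assert np.allclose(fused, separate)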
If we execute fuse_batch_norm always (without the "not use_mkldnn" if), then fuse_relu makes sense only in the case where there is no conv before batch norm. I don't know if such a case in fact ever exists.
If not, then I can skip this PR and create a similar one for the training, which I have already completed.
Regarding "where there is no conv before batch norm": for DenseNet (https://github.com/liuzhuang13/DenseNet) there is BN+Relu+Conv; thus, fuse_relu is useful in this case.
Regarding "I can skip this PR and create a similar one for the training": fuse_relu for training is needed as well.
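For context, a minimal sketch of such a BN+Relu+Conv unit in fluid (layer names assumed from the fluid API of that time; not code from this PR). Here the batch_norm has no preceding conv, so fuse_batch_norm leaves it alone and only the bn+relu fuse can remove the standalone relu:

import paddle.fluid as fluid

def dense_block_unit(x, num_filters):
    # DenseNet-style pre-activation ordering: BN -> ReLU -> Conv.
    bn = fluid.layers.batch_norm(input=x)
    act = fluid.layers.relu(bn)
    return fluid.layers.conv2d(
        input=act, num_filters=num_filters, filter_size=3, padding=1)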
So I assume this PR is OK. After it's merged, I'll create a PR for the training transpiler.
Force-pushed from 2ef33af to c2a8d2c
current_op.set_attr("fuse_with_relu", True) | ||
# remove relu OP | ||
self.block.remove_op(i + 1) | ||
i = i + 1 |
- Could you give some unit test to validate the accuracy of fuse_with_relu?
- Could you call self._remove_unused_var, since remove_op will not remove variables?
done
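As a rough idea of the accuracy check asked for above, a hypothetical test (layer names and tolerance are assumptions; the actual unit test added in this PR may look different) could run a small bn+relu+conv network before and after transpiling and compare the outputs; on a build without MKLDNN the transpile should simply leave the program unchanged.

import numpy as np
import paddle.fluid as fluid

def check_fuse_relu_accuracy():
    # Hypothetical accuracy check: the transpiled program must give the
    # same result as the original one.
    x = fluid.layers.data(name='x', shape=[3, 8, 8], dtype='float32')
    bn = fluid.layers.batch_norm(input=x, is_test=True)  # no conv before bn
    act = fluid.layers.relu(bn)
    out = fluid.layers.conv2d(input=act, num_filters=4, filter_size=3)

    place = fluid.CPUPlace()
    exe = fluid.Executor(place)
    exe.run(fluid.default_startup_program())

    feed = {'x': np.random.random((1, 3, 8, 8)).astype('float32')}
    prog = fluid.default_main_program()
    before = exe.run(prog, feed=feed, fetch_list=[out])[0]

    t = fluid.InferenceTranspiler()
    t.transpile(prog, place)
    after = exe.run(prog, feed=feed, fetch_list=[out])[0]

    assert np.allclose(before, after, atol=1e-5)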
There are several optimizations, only fuse batch normalization is supported now.
Convert the fluid program to optimized inference program.

There are several optimizations.
There are several optimizations:
- fuse convolution and batch normalization
- fuse batch normalization and relu (MKLDNN only)
Transpile the program by fused relu activation for MKLDNN program.

Relu activation following batch norm OP can be fused by adding
'fuse_with_relu' attribute to batch norm OP.
'fuse_with_relu' -> :math:`fuse_with_relu`
- before:
  - batch_norm->relu->any_other_op
- after:
  - batch_norm->any_other_op
The format of lines 69-73 is not correct. You can use https://github.com/PaddlePaddle/FluidDoc to see the generated html, and paste the generated picture here, like #11521.
If you have any question about how to generate the API reference, please feel free to ask me.
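A possible reST-friendly shape for that docstring fragment (a sketch only; exact markup and wording depend on the FluidDoc toolchain and reviewer preference):

def fuse_relu(self, program):
    '''
    Transpile the program by fusing the relu activation into the
    preceding batch norm OP (MKLDNN only). The fuse is expressed by
    setting the :math:`fuse_with_relu` attribute on the batch norm OP.

    - before:

      - batch_norm->relu->any_other_op

    - after:

      - batch_norm->any_other_op

    :param program: program to transpile
    :type program: Program
    '''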
if use_mkldnn:
    self.fuse_relu(program)

def fuse_relu(self, program):
- How about the function name fuse_relu_mkldnn?
- You may add a check of FLAGS_use_mkldnn=True in this function, so that other people will not use it on plain CPU (see the sketch below).
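A minimal sketch of such a guard (assuming, as suggested, that MKLDNN use is signalled through the FLAGS_use_mkldnn environment variable; the renamed fuse_relu_mkldnn and the exact check are illustrative, not the final PR code):

import os

def fuse_relu_mkldnn(self, program):
    '''Fuse relu into the preceding batch norm OP; no-op without MKLDNN.'''
    # Illustrative guard: skip the pass unless MKLDNN is enabled, so the
    # fuse is never applied on a plain CPU setup.
    use_mkldnn = os.getenv('FLAGS_use_mkldnn', '0').lower() in ('1', 'true')
    if not use_mkldnn:
        return
    # ... scan the block for batch_norm followed by relu and fuse them ...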
Force-pushed from 07398ef to 67d1640
Force-pushed from 67d1640 to 610dc19
i = 0
while i < len(self.block.ops):
while i < len(self.block.ops) - 2:
why do you change line 159?
Because in lines 174 and 183 we access the i+2 element.
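As an illustration of why the bound matters (simplified, not the actual transpiler code): the scan may look ahead up to two ops from position i, so the loop has to stop at len(ops) - 2 to keep i + 2 in range.

def scan_ops(ops):
    # Simplified sketch: each step may inspect ops[i + 1] and ops[i + 2],
    # so the loop bound is len(ops) - 2 to avoid an index error.
    i = 0
    while i < len(ops) - 2:
        current_op = ops[i]
        next_op = ops[i + 1]
        next_next_op = ops[i + 2]
        # ... pattern-match the three ops and fuse where possible ...
        i = i + 1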
LGTM
I've added the batch norm + relu fuse case to the inference_transpiler.
In the next step, I'm going to create a training transpiler doing the same operation.