Skip to content

Commit

Permalink
Unset ReserveSpace of batch_norm for inference program. (#32493)
Browse files Browse the repository at this point in the history
* Unset ReserveSpace for inference program.

* Support training from an inference program.
  • Loading branch information
Xreki authored Apr 26, 2021
1 parent 41bfec8 commit 202b0ea
Show file tree
Hide file tree
Showing 5 changed files with 27 additions and 0 deletions.
5 changes: 5 additions & 0 deletions paddle/fluid/framework/op_desc.cc
Original file line number Diff line number Diff line change
Expand Up @@ -447,6 +447,11 @@ void OpDesc::SetOutput(const std::string &param_name,
this->outputs_[param_name] = args;
}

void OpDesc::RemoveOutput(const std::string &name) {
outputs_.erase(name);
need_update_ = true;
}

bool OpDesc::HasProtoAttr(const std::string &name) const {
auto &op_info = OpInfoMap::Instance();
if (op_info.Has(desc_.type())) {
Expand Down
1 change: 1 addition & 0 deletions paddle/fluid/framework/op_desc.h
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ class OpDesc {

void SetOutput(const std::string &param_name,
const std::vector<std::string> &args);
void RemoveOutput(const std::string &name);

bool HasAttr(const std::string &name) const {
return attrs_.find(name) != attrs_.end();
Expand Down
1 change: 1 addition & 0 deletions paddle/fluid/pybind/protobuf.cc
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,7 @@ void BindOpDesc(pybind11::module *m) {
const std::vector<std::string> &vec_var_name) {
self.SetOutput(name, vec_var_name);
})
.def("remove_output", &pd::OpDesc::RemoveOutput)
.def("input_arg_names", &pd::OpDesc::InputArgumentNames)
.def("output_arg_names", &pd::OpDesc::OutputArgumentNames)
.def("_rename_input", &pd::OpDesc::RenameInput)
Expand Down
17 changes: 17 additions & 0 deletions python/paddle/fluid/dygraph/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -413,6 +413,23 @@ def _append_backward_desc(self, infer_program_desc):
# Therefore, in order to reuse the method of backward.py, build the program here.
program = _build_program_by_desc(program_desc_copy)

# 3. Add the outputs that are only used for training and are not saved
# in the inference program.
for block_idx in six.moves.range(program.num_blocks):
block = program.block(block_idx)
for op in block.ops:
if op.type == "batch_norm":
if "ReserveSpace" not in op.output_names or len(
op.output("ReserveSpace")) == 0:
reserve_space = block.create_var(
name=unique_name.generate_with_ignorable_key(
".".join(["reserve_space", 'tmp'])),
dtype=block.var(op.input("X")[0]).dtype,
type=core.VarDesc.VarType.LOD_TENSOR,
persistable=False,
stop_gradient=True)
op.desc.set_output("ReserveSpace", [reserve_space.name])

targets = []
for out in self._output_descs:
targets.append(program.global_block().var(out.name()))
Expand Down
3 changes: 3 additions & 0 deletions python/paddle/fluid/framework.py
Original file line number Diff line number Diff line change
Expand Up @@ -5021,6 +5021,9 @@ def _inference_optimize(self, prune_read_op=True):
op = block.op(j)
if op.has_attr('is_test'):
op._set_attr('is_test', True)
if op.type() == "batch_norm":
# Remove the output ReserveSpace of batch_norm if it exists.
op.remove_output("ReserveSpace")
res.blocks = [
Block(res, i) for i in six.moves.range(res.desc.num_blocks())
]
Expand Down

0 comments on commit 202b0ea

Please sign in to comment.