Skip to content

Commit

Permalink
[Dy2St]Fix clone for test state problem (#51966)
Browse files Browse the repository at this point in the history
* [Dy2St]Fix clone for test state problem

* clean code
  • Loading branch information
Aurelius84 authored Mar 23, 2023
1 parent 2142326 commit 5031b44
Showing 1 changed file with 1 addition and 11 deletions.
12 changes: 1 addition & 11 deletions python/paddle/jit/dy2static/partial_program.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,15 +108,6 @@ def __get__(self, instance, cls):
return val


def _change_is_test_status(program, is_test):
# change all `is_test` attributes
for block in program.blocks:
for op in block.ops:
if op.has_attr('is_test'):
op._set_attr('is_test', is_test)
return program


class ProgramInfo:
"""
A helper class to recoder Program information
Expand Down Expand Up @@ -618,8 +609,7 @@ def _insert_aggregation_ops_for_var(target_program, var):

@switch_to_static_graph
def _append_backward_desc(self, main_program):
# make sure all status of is_test are False in train mode.
program = _change_is_test_status(main_program.clone(), is_test=False)
program = main_program.clone(for_test=False)
if self._hooker:
program = self._hooker.before_append_backward(program)
targets = []
Expand Down

0 comments on commit 5031b44

Please sign in to comment.