Commit: apply some suggestions

SigureMo committed Mar 29, 2023
1 parent ed5b0dd commit 7183931
Showing 4 changed files with 25 additions and 26 deletions.
7 changes: 2 additions & 5 deletions python/paddle/fluid/dygraph/base.py
@@ -44,7 +44,7 @@
     'to_variable',
 ]

-# Flag that indicates whether running code under `@to_static`
+NON_PERSISTABLE_VAR_NAME_SUFFIX = "__non_persistable"


 def in_declarative_mode():
@@ -147,17 +147,14 @@ def _convert_into_variable(tensor):
             # But if its shape is empty while created from `create_variable()`, we consider this buffer
             # non-persistable. See case of `dropout_state` in lstm api.
             is_persistable = True
-            if "dropout_state" in tensor.name:
+            if tensor.name.endswith(NON_PERSISTABLE_VAR_NAME_SUFFIX):
                 is_persistable = False

             new_var = tensor._to_static_var(
                 to_parameter=False, persistable=is_persistable
             )
         # add param into parameter recorder to collect all the params used in this program.
         if new_var.persistable is True:
-            # TODO(@xiongkun): 0d-tensor may be affected at present,
-            # but there is no particularly good method to identify whether 0d-tensor
-            # is used as buffer or "dropout_state" in LSTM buffer variable.
             from paddle.jit.dy2static.program_translator import (
                 ProgramTranslator,
             )
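The substance of this file's change is a naming convention: instead of special-casing the literal string "dropout_state", any variable whose name ends with `__non_persistable` is excluded from the persistable set when dygraph tensors are converted to static-graph variables. A standalone sketch of the convention (plain Python with illustrative names, not the actual `_convert_into_variable` implementation):

NON_PERSISTABLE_VAR_NAME_SUFFIX = "__non_persistable"

def is_persistable(var_name: str) -> bool:
    # Buffers default to persistable; only names carrying the suffix
    # (e.g. the LSTM dropout state) opt out.
    return not var_name.endswith(NON_PERSISTABLE_VAR_NAME_SUFFIX)

assert is_persistable("running_mean")
assert not is_persistable("dropout_state" + NON_PERSISTABLE_VAR_NAME_SUFFIX)

The old containment check would also match any unrelated variable that merely contains "dropout_state" in its name; the suffix check is opt-in and unambiguous, which is presumably why the TODO about distinguishing 0-d buffers could be dropped.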
11 changes: 5 additions & 6 deletions python/paddle/fluid/framework.py
@@ -27,11 +27,9 @@
 import multiprocessing
 import sys
 import logging
-import paddle

 from .proto import framework_pb2, data_feed_pb2

-
 from . import core
 from . import unique_name
 import paddle.version as fluid_version
@@ -1719,12 +1717,14 @@ def backward(self, retain_graph=False):
                 loss.backward()
         """
+        from .backward import append_backward
+
         if retain_graph is True:
             raise AssertionError(
                 "`retain_graph` == True is not supported in @to_static function."
                 "please set retain_graph = False."
             )
-        param_grad_list = paddle.fluid.backward.append_backward(self)
+        param_grad_list = append_backward(self)
         for param, param_grad in param_grad_list:
             # set grad to simulate dygraph loss.backward() in static mode.
             setattr(param, "grad", param_grad)

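Behaviorally this `backward` is unchanged: under `@to_static` it still rejects `retain_graph=True` and still appends gradient ops via `append_backward`; the rewrite only defers the import and drops the fully qualified `paddle.fluid.backward` path (and with it the module-level `import paddle` removed above, a common fix for import cycles inside the `paddle` package). A hedged usage sketch of the documented behavior; the layer and data are illustrative, not from the diff:

import paddle

net = paddle.nn.Linear(4, 1)

@paddle.jit.to_static
def train_step(x):
    loss = net(x).mean()
    # In static mode this backward() is lowered to append_backward, and
    # each parameter receives its gradient Variable as `param.grad`.
    loss.backward()
    return loss

loss = train_step(paddle.rand([2, 4]))
# Calling loss.backward(retain_graph=True) inside @to_static raises
# AssertionError, per the check added above.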
@@ -7625,12 +7625,11 @@ def _get_var(name, program=None):
 @signature_safe_contextmanager
 def dygraph_guard_if_declarative():
     from .dygraph.base import in_declarative_mode
+    from .dygraph import Tracer

     if in_declarative_mode():
         # Under @paddle.jit.to_static decorator, we switch back dygraph mode temporarily.
-        with paddle.fluid.framework._dygraph_guard(
-            tracer=paddle.fluid.dygraph.Tracer()
-        ):
+        with _dygraph_guard(tracer=Tracer()):
             yield
     else:
         yield
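`dygraph_guard_if_declarative` is a conditional guard: while tracing under `@to_static` it temporarily re-enters dygraph mode with a fresh `Tracer`, otherwise it is a pass-through. The commit only localizes the imports; the control flow is untouched. A simplified, Paddle-free sketch of the pattern (all names illustrative):

from contextlib import contextmanager

@contextmanager
def guard_if(condition, guard_factory):
    # Enter the real guard only when needed; otherwise pass through.
    if condition:
        with guard_factory():
            yield
    else:
        yield

@contextmanager
def fake_dygraph_guard():
    print("enter dygraph")
    yield
    print("exit dygraph")

with guard_if(True, fake_dygraph_guard):
    print("tracing under @to_static")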
10 changes: 6 additions & 4 deletions CMakeLists.txt
@@ -39,9 +39,13 @@ if(WITH_PYTHON)
   set_tests_properties(test_lac PROPERTIES TIMEOUT 120)
 endif()

+# Disable on Windows CPU CI for timeout
 if(WIN32 AND NOT WITH_GPU)
-  list(REMOVE_ITEM TEST_OPS test_resnet_amp
-  )# disable on Windows CPU CI for timeout
+  list(REMOVE_ITEM TEST_OPS test_resnet_amp)
+  # TODO(SigureMo): Temporarily disable train step on Windows CPU CI.
+  # We should remove this after fix the performance issue.
+  list(REMOVE_ITEM TEST_OPS test_train_step_resnet18_adam)
+  list(REMOVE_ITEM TEST_OPS test_train_step_resnet18_sgd)
 endif()

 foreach(TEST_OP ${TEST_OPS})
@@ -66,8 +70,6 @@ set_tests_properties(test_transformer PROPERTIES TIMEOUT 200)
 set_tests_properties(test_bmn PROPERTIES TIMEOUT 120)
 #set_tests_properties(test_mnist PROPERTIES TIMEOUT 120)
 set_tests_properties(test_build_strategy PROPERTIES TIMEOUT 120)
-set_tests_properties(test_train_step_resnet18_sgd PROPERTIES TIMEOUT 1200)
-set_tests_properties(test_train_step_resnet18_adam PROPERTIES TIMEOUT 1200)

 if(NOT WIN32)
   set_tests_properties(test_resnet_v2 PROPERTIES TIMEOUT 120)
23 changes: 12 additions & 11 deletions python/paddle/nn/layer/rnn.py
@@ -22,6 +22,7 @@
 from paddle import _C_ops, _legacy_C_ops, framework, in_dynamic_mode
 from paddle.common_ops_import import Variable
 from paddle.fluid.data_feeder import check_type, check_variable_and_dtype
+from paddle.fluid.dygraph.base import NON_PERSISTABLE_VAR_NAME_SUFFIX
 from paddle.fluid.framework import (
     _non_static_mode,
     default_startup_program,
@@ -47,7 +48,7 @@ def rnn(
     sequence_length=None,
     time_major=False,
     is_reverse=False,
-    **kwargs
+    **kwargs,
 ):
     r"""
     rnn creates a recurrent neural network specified by RNNCell `cell`,
@@ -109,7 +110,7 @@
             sequence_length,
             time_major,
             is_reverse,
-            **kwargs
+            **kwargs,
         )
     else:
         return _rnn_static_graph(
@@ -119,7 +120,7 @@
             sequence_length,
             time_major,
             is_reverse,
-            **kwargs
+            **kwargs,
         )


@@ -155,7 +156,7 @@ def _rnn_dynamic_graph(
     sequence_length=None,
     time_major=False,
     is_reverse=False,
-    **kwargs
+    **kwargs,
 ):
     time_step_index = 0 if time_major else 1
     flat_inputs = paddle.utils.flatten(inputs)
@@ -223,7 +224,7 @@ def _rnn_static_graph(
     sequence_length=None,
     time_major=False,
     is_reverse=False,
-    **kwargs
+    **kwargs,
 ):
     check_type(inputs, 'inputs', (Variable, list, tuple), 'rnn')
     if isinstance(inputs, (list, tuple)):
@@ -359,7 +360,7 @@ def birnn(
     initial_states=None,
     sequence_length=None,
     time_major=False,
-    **kwargs
+    **kwargs,
 ):
     r"""
     birnn creates a bidirectional recurrent neural network specified by
@@ -432,7 +433,7 @@
         states_fw,
         sequence_length,
         time_major=time_major,
-        **kwargs
+        **kwargs,
     )

     outputs_bw, states_bw = rnn(
@@ -442,7 +443,7 @@
         sequence_length,
         time_major=time_major,
         is_reverse=True,
-        **kwargs
+        **kwargs,
     )

     outputs = paddle.utils.map_structure(
@@ -1209,7 +1210,7 @@ def forward(
             sequence_length=sequence_length,
             time_major=self.time_major,
             is_reverse=self.is_reverse,
-            **kwargs
+            **kwargs,
         )
         return final_outputs, final_states

@@ -1296,7 +1297,7 @@ def forward(
             initial_states,
             sequence_length,
             self.time_major,
-            **kwargs
+            **kwargs,
         )
         return outputs, final_states

@@ -1429,7 +1430,7 @@ def flatten_parameters(self):
             # should dropout state be persistable for static-graph
             self._dropout_state = self.create_variable(
                 dtype=core.VarDesc.VarType.UINT8,
-                name="dropout_state",
+                name=f"dropout_state{NON_PERSISTABLE_VAR_NAME_SUFFIX}",
             )
             if in_dynamic_mode():
                 with paddle.no_grad():
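This last hunk is the producer side of the base.py change: the LSTM names its dropout-state buffer with the suffix, so `_convert_into_variable` can mark exactly this buffer non-persistable from the name alone. The remaining hunks in this file are mechanical, adding formatter-style trailing commas after `**kwargs`. A plain-Python mirror of the round trip (illustrative, not Paddle code):

NON_PERSISTABLE_VAR_NAME_SUFFIX = "__non_persistable"

# Producer (rnn.py): bake the suffix into the buffer's name.
dropout_state_name = f"dropout_state{NON_PERSISTABLE_VAR_NAME_SUFFIX}"

# Consumer (base.py): derive persistability from the name alone.
is_persistable = not dropout_state_name.endswith(NON_PERSISTABLE_VAR_NAME_SUFFIX)
assert is_persistable is False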
