[Dy2St] support train step in to_static (#51693)
Co-authored-by: xiongkun <xiongkun03@baidu.com>
SigureMo and 2742195759 authored Apr 4, 2023
1 parent 15aa73d commit 7728efb
Showing 13 changed files with 673 additions and 100 deletions.
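For orientation before the per-file diffs: a hedged sketch of the usage pattern this change targets, i.e. tracing a whole training step — forward, backward, and optimizer update — inside a @paddle.jit.to_static function. The model, data, and train_step names below are illustrative, not taken from the PR.

import paddle

net = paddle.nn.Linear(10, 1)
adam = paddle.optimizer.Adam(learning_rate=0.01, parameters=net.parameters())

@paddle.jit.to_static
def train_step(x, y):
    loss = paddle.nn.functional.mse_loss(net(x), y)
    loss.backward()   # lowered to append_backward in the traced static program (see framework.py below)
    adam.step()       # dispatches to _declarative_step under to_static (see adam.py below)
    adam.clear_grad()
    return loss

loss = train_step(paddle.randn([4, 10]), paddle.randn([4, 1]))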
4 changes: 3 additions & 1 deletion python/paddle/fluid/compiler.py
@@ -548,7 +548,9 @@ def patch_getter(self, item):

self._caches[item_id] = (
concrete_program,
- partial_program_from(concrete_program),
+ partial_program_from(
+     concrete_program, item.class_instance is not None
+ ),
)
# Note: raise warnings if number of traced program is more than `max_tracing_count`
current_tracing_count = len(self._caches)
11 changes: 5 additions & 6 deletions python/paddle/fluid/dygraph/base.py
@@ -44,7 +44,7 @@
'to_variable',
]

- # Flag that indicates whether running code under `@to_static`
+ NON_PERSISTABLE_VAR_NAME_SUFFIX = "__non_persistable"


def in_declarative_mode():
@@ -143,17 +143,16 @@ def _convert_into_variable(tensor):
# and necessary for inferring. It will be pruned if it's not necessary for inferring.

# But if its shape is empty while created from `create_variable()`, we consider this buffer
- # non-persistable. See case of `drop_state` in lstm api.
- is_persistable = len(tensor.shape) > 0
+ # non-persistable. See case of `dropout_state` in lstm api.
+ is_persistable = True
+ if tensor.name.endswith(NON_PERSISTABLE_VAR_NAME_SUFFIX):
+     is_persistable = False

new_var = tensor._to_static_var(
to_parameter=False, persistable=is_persistable
)
# add param into parameter recorder to collect all the params used in this program.
if new_var.persistable is True:
- # TODO(@xiongkun): 0d-tensor may be affected at present,
- # but there is no particularly good method to identify whether 0d-tensor
- # is used as buffer or "drop_out_state" in LSTM buffer variable.
from paddle.jit.dy2static.program_translator import (
ProgramTranslator,
)
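A minimal sketch of the rule the hunk above encodes: when a dygraph buffer is converted into a static variable, persistability is now decided by a name suffix rather than by the tensor's shape. The helper name and example variable names are illustrative, not part of the diff.

NON_PERSISTABLE_VAR_NAME_SUFFIX = "__non_persistable"

def should_be_persistable(tensor_name: str) -> bool:
    # Buffers opt out of persistability by carrying the suffix in their name,
    # e.g. the LSTM dropout_state buffer renamed in python/paddle/nn/layer/rnn.py below.
    return not tensor_name.endswith(NON_PERSISTABLE_VAR_NAME_SUFFIX)

assert should_be_persistable("lstm_0.w_0")  # ordinary parameter/buffer name (illustrative)
assert not should_be_persistable("dropout_state" + NON_PERSISTABLE_VAR_NAME_SUFFIX)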
27 changes: 25 additions & 2 deletions python/paddle/fluid/framework.py
@@ -1620,7 +1620,7 @@ def numpy(self):
"""
pass

- @fake_interface_only
+ @non_static_only
def backward(self, retain_graph=False):
"""
**Notes**:
@@ -1657,7 +1657,17 @@ def backward(self, retain_graph=False):
loss.backward()
"""
- pass
+ from .backward import append_backward
+
+ if retain_graph is True:
+     raise AssertionError(
+         "`retain_graph` == True is not supported in @to_static function."
+         "please set retain_graph = False."
+     )
+ param_grad_list = append_backward(self)
+ for param, param_grad in param_grad_list:
+     # set grad to simulate dygraph loss.backward() in static mode.
+     setattr(param, "grad", param_grad)

@fake_interface_only
def gradient(self):
@@ -7396,6 +7406,19 @@ def _get_var(name, program=None):
return program.global_block().var(name)


+ @signature_safe_contextmanager
+ def dygraph_guard_if_declarative():
+     from .dygraph.base import in_declarative_mode
+     from .dygraph import Tracer
+
+     if in_declarative_mode():
+         # Under @paddle.jit.to_static decorator, we switch back dygraph mode temporarily.
+         with _dygraph_guard(tracer=Tracer()):
+             yield
+     else:
+         yield


@signature_safe_contextmanager
def _dygraph_guard(tracer):
tmp_tracer = global_var._dygraph_tracer_
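As a standalone illustration of what the rewritten Variable.backward does, the snippet below builds a tiny static program and calls paddle.static.append_backward directly to obtain the (parameter, gradient) pairs — the same pairs the new code attaches to each parameter as .grad. The program setup is illustrative only.

import paddle

paddle.enable_static()
main_prog, startup_prog = paddle.static.Program(), paddle.static.Program()
with paddle.static.program_guard(main_prog, startup_prog):
    x = paddle.static.data(name="x", shape=[None, 10], dtype="float32")
    loss = paddle.mean(paddle.static.nn.fc(x, size=1))
    # Append gradient ops for `loss` and collect (parameter, gradient) pairs.
    for param, grad in paddle.static.append_backward(loss):
        print(param.name, "->", grad.name)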
46 changes: 25 additions & 21 deletions python/paddle/jit/dy2static/partial_program.py
@@ -21,9 +21,10 @@
from paddle.amp.auto_cast import _in_amp_guard, _in_pure_fp16_guard
from paddle.fluid import backward, core, framework, program_guard
from paddle.fluid.compiler import BuildStrategy
+ from paddle.fluid.data_feeder import convert_dtype
from paddle.fluid.dygraph.base import switch_to_static_graph
from paddle.fluid.framework import _apply_pass
- from paddle.nn.layer import layers
+ from paddle.optimizer.lr import LRScheduler

from . import logging_utils
from .utils import RETURN_NO_VALUE_MAGIC_NUM, _out_grad_names, _param_grad_names
@@ -205,6 +206,8 @@ def __call__(self, inputs):
self._cast_fp16_if_pure_fp16(in_vars)
attrs = self._prepare_attributes()

+ self._sync_lr_value_with_scheduler()

_legacy_C_ops.run_program(
self._valid_vars(in_vars),
self._valid_vars(self._params),
@@ -219,6 +222,21 @@
restored_nest_out = self._restore_out(out_vars)
return self._remove_no_value(restored_nest_out)

+ def _sync_lr_value_with_scheduler(self):
+     """Update lr_var value with calculated by lr_scheduler."""
+     main_program = self._origin_main_program
+     if hasattr(main_program, 'lr_scheduler') and hasattr(
+         main_program, 'lr_var'
+     ):
+         lr_scheduler = main_program.lr_scheduler
+         lr_var = main_program.lr_var
+
+         assert isinstance(lr_scheduler, LRScheduler), "must be LRScheduler"
+         lr_scheduler = self._origin_main_program.lr_scheduler
+         lr_value = lr_scheduler()
+         data = np.array(lr_value).astype(convert_dtype(lr_var.dtype))
+         lr_var.set_value(data)

def set_hooker(self, hooker):
self._hooker = hooker

@@ -240,7 +258,8 @@ def _get_scope(self, program_id=None, use_scope_cache=False):

@LazyInitialized
def _double_grads(self):
- return self._get_double_grads(self._origin_main_program)
+ # TODO: check the affects.
+ return None

# whole
@switch_to_static_graph
@@ -658,23 +677,6 @@ def _prune_unused_params(self, program):

self._params = required_params

- def _get_double_grads(self, program):
-     double_grads = []
-     for block in program.blocks:
-         for name in block.vars:
-             if "@GRAD" in name:
-                 var_desc = block.vars[name].desc
-                 var_base = None
-                 var_base = core.eager.Tensor(
-                     var_desc.dtype(),
-                     var_desc.shape(),
-                     var_desc.name(),
-                     var_desc.type(),
-                     False,
-                 )
-                 double_grads.append(var_base)
-     return self._valid_vars(double_grads)

def _cast_fp16_if_pure_fp16(self, in_vars):
if _in_pure_fp16_guard():
for i, var in enumerate(in_vars):
@@ -1053,9 +1055,11 @@ def _valid_vars(self, vars):
return vars if vars else None


- def partial_program_from(concrete_program):
+ def partial_program_from(concrete_program, from_method=False):
inputs = concrete_program.inputs
- if inputs and isinstance(inputs[0], layers.Layer):
+
+ # NOTE(SigureMo): Remove the first arg `self` from method args.
+ if inputs and from_method:
inputs = inputs[1:]

return PartialProgramLayer(
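A hedged usage sketch of the scenario _sync_lr_value_with_scheduler handles: an optimizer driven by an LRScheduler inside a to_static train step, where the Python-side scheduler value must be copied into the program's learning-rate variable on every call (the lr_scheduler / lr_var attributes checked above are assumed to be attached to the main program elsewhere in this PR). The net and train_step names are illustrative.

import paddle

net = paddle.nn.Linear(10, 1)
scheduler = paddle.optimizer.lr.StepDecay(learning_rate=0.1, step_size=2, gamma=0.5)
adam = paddle.optimizer.Adam(learning_rate=scheduler, parameters=net.parameters())

@paddle.jit.to_static
def train_step(x, y):
    loss = paddle.nn.functional.mse_loss(net(x), y)
    loss.backward()
    adam.step()        # the traced program reads the learning rate from lr_var
    adam.clear_grad()
    return loss

for i in range(4):
    train_step(paddle.randn([4, 10]), paddle.randn([4, 1]))
    scheduler.step()   # Python-side update, copied into lr_var before the next run_program call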
4 changes: 3 additions & 1 deletion python/paddle/jit/dy2static/program_translator.py
@@ -1225,7 +1225,9 @@ def _build_once(self, cache_key):
)
)

- partial_program = partial_program_from(concrete_program)
+ partial_program = partial_program_from(
+     concrete_program, cache_key.class_instance is not None
+ )
if core._is_fwd_prim_enabled() and not _in_amp_guard():
partial_program.set_hooker(
PrimHooker(concrete_program.main_program)
4 changes: 3 additions & 1 deletion python/paddle/nn/layer/rnn.py
@@ -22,6 +22,7 @@
from paddle import _C_ops, _legacy_C_ops, framework, in_dynamic_mode
from paddle.common_ops_import import Variable
from paddle.fluid.data_feeder import check_type, check_variable_and_dtype
+ from paddle.fluid.dygraph.base import NON_PERSISTABLE_VAR_NAME_SUFFIX
from paddle.fluid.framework import (
_non_static_mode,
default_startup_program,
@@ -1428,7 +1429,8 @@ def flatten_parameters(self):
# dropout state may also can be hided and avoid saving
# should dropout state be persistable for static-graph
self._dropout_state = self.create_variable(
- dtype=core.VarDesc.VarType.UINT8
+ dtype=core.VarDesc.VarType.UINT8,
+ name=f"dropout_state{NON_PERSISTABLE_VAR_NAME_SUFFIX}",
)
if in_dynamic_mode():
with paddle.no_grad():
6 changes: 5 additions & 1 deletion python/paddle/optimizer/adam.py
@@ -389,7 +389,7 @@ def _append_optimize_op(self, block, param_and_grad):
return adam_op

@imperative_base.no_grad
- @framework.dygraph_only
+ @framework.non_static_only
def step(self):
"""
Execute the optimizer and update parameters once.
@@ -412,6 +412,10 @@ def step(self):
adam.step()
adam.clear_grad()
"""
+ if paddle.fluid.dygraph.base.in_declarative_mode():
+     self._declarative_step()
+     return

if not isinstance(self._parameter_list[0], dict):
params_grads = []
for param in self._parameter_list:
6 changes: 5 additions & 1 deletion python/paddle/optimizer/adamw.py
@@ -530,7 +530,7 @@ def __str__(self):
return " ".join(["Weight Decay, params:", ",".join(self._params_name)])

@imperative_base.no_grad
- @framework.dygraph_only
+ @framework.non_static_only
def step(self):
"""
Execute the optimizer and update parameters once.
@@ -553,6 +553,10 @@ def step(self):
opt.step()
opt.clear_grad()
"""
+ if paddle.fluid.dygraph.base.in_declarative_mode():
+     self._declarative_step()
+     return

if not isinstance(self._parameter_list[0], dict):
params_grads = []
for param in self._parameter_list:
(The diffs for the remaining 5 changed files were not loaded in this view.)
