[Dy2St] support train step in to_static #51693
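A rough usage sketch of what this PR enables (the model, optimizer, and data below are illustrative and not taken from the PR): a full training step, including loss.backward() and optimizer.step(), can now be traced by @paddle.jit.to_static instead of only the forward pass.

```python
import paddle

# Illustrative setup; any Layer / optimizer combination is intended to work.
net = paddle.nn.Linear(10, 1)
sgd = paddle.optimizer.SGD(learning_rate=0.01, parameters=net.parameters())

@paddle.jit.to_static
def train_step(x, y):
    # Forward, backward, and the parameter update are all captured into
    # the static program built by to_static.
    loss = paddle.nn.functional.mse_loss(net(x), y)
    loss.backward()
    sgd.step()
    sgd.clear_grad()
    return loss

loss = train_step(paddle.randn([4, 10]), paddle.randn([4, 1]))
```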

Merged · 55 commits · Apr 4, 2023

Commits
3bb1a5f
merge
2742195759 Mar 9, 2023
a7b4981
fix bugs while backward multi-times.
2742195759 Mar 9, 2023
337ca6f
train-step first commit: loss.backward support.
2742195759 Mar 10, 2023
5718f65
1. support Adam and Accumulator.
2742195759 Mar 13, 2023
b88961c
code format by ci
2742195759 Mar 13, 2023
1b27907
merge
2742195759 Mar 14, 2023
54d34f1
Merge remote-tracking branch '2742195759/dy2static-auto-release' into…
SigureMo Mar 15, 2023
79d61d1
Merge branch 'develop' into train-step-siguremo
SigureMo Mar 15, 2023
0aecaac
use `non_static_only` instead of `_non_static_only_`
SigureMo Mar 15, 2023
f7fe08d
fix assertion
SigureMo Mar 15, 2023
eb278aa
restore sample code
SigureMo Mar 15, 2023
354155f
fix op._set_attr
SigureMo Mar 15, 2023
7ca1384
support all optimizer, add uts
SigureMo Mar 15, 2023
6775b4b
fix codestyle
SigureMo Mar 15, 2023
46ce083
add resnet18 tests
SigureMo Mar 15, 2023
618df70
up the train step time limit
SigureMo Mar 16, 2023
531a5ec
up the train step time limit
SigureMo Mar 16, 2023
9f01a0c
up the train step time limit
SigureMo Mar 16, 2023
f0a2da7
Merge branch 'develop' into train-step-siguremo
SigureMo Mar 20, 2023
f30242f
fix prim issues due to merge conflicts
SigureMo Mar 20, 2023
45afdc3
Merge branch 'develop' into train-step-siguremo
SigureMo Mar 20, 2023
b82ee42
`_non_static_only_` -> `non_static_only`
SigureMo Mar 20, 2023
821d4f1
add all lr scheduler tests
SigureMo Mar 20, 2023
d02dced
add notes about `why remove first arg from method`
SigureMo Mar 21, 2023
7507cde
Merge branch 'develop' into train-step-siguremo
SigureMo Mar 21, 2023
72d5c78
for codestyle
SigureMo Mar 21, 2023
90b6714
Merge branch 'develop' into train-step-siguremo
SigureMo Mar 21, 2023
d473b41
inc train step time limit
SigureMo Mar 21, 2023
fabb8c6
remove some lr scheduler test
SigureMo Mar 22, 2023
84d5dc6
Revert "remove some lr scheduler test"
SigureMo Mar 22, 2023
251144d
split test into 3 difference tests
SigureMo Mar 22, 2023
8785129
remove a useless TODO
SigureMo Mar 23, 2023
635d2f5
Merge branch 'develop' into train-step-siguremo
SigureMo Mar 23, 2023
bead669
add lr scheduler support
SigureMo Mar 24, 2023
973df89
Merge branch develop into train-step-siguremo
SigureMo Mar 24, 2023
4017eff
move imports
SigureMo Mar 24, 2023
b8ccbd9
use name to identify the dropout_state in rnn
SigureMo Mar 27, 2023
3d87c63
inc train step time limit
SigureMo Mar 27, 2023
82c5cf0
fix 0d lr scheduler value
SigureMo Mar 27, 2023
3db82c2
add some missing committed changes
SigureMo Mar 27, 2023
3001565
`new_var` -> `tensor`
SigureMo Mar 27, 2023
6ca3d23
Merge branch develop into train-step-siguremo
SigureMo Mar 28, 2023
ed5b0dd
`sheduler` -> `scheduler`
SigureMo Mar 28, 2023
7183931
apply some suggestions
SigureMo Mar 29, 2023
25cac14
test on gpu only
SigureMo Mar 29, 2023
faa5a16
add TIMEOUT for gpu tests
SigureMo Mar 30, 2023
3fb4888
Merge branch develop into train-step-siguremo
SigureMo Mar 30, 2023
21d4f83
set uts timeout to 240
SigureMo Mar 31, 2023
306c19e
Merge branch 'develop' into train-step-siguremo
SigureMo Mar 31, 2023
5a4e933
move uts to new place
SigureMo Mar 31, 2023
fbfe26a
inc train step time limit
SigureMo Mar 31, 2023
38f2255
inc train step time limit
SigureMo Apr 1, 2023
5ba356f
adjust cmake for ut
SigureMo Apr 3, 2023
8e98ae2
adjust cmake for ut
SigureMo Apr 3, 2023
0c48762
adjust cmake for ut
SigureMo Apr 3, 2023
Files changed
4 changes: 3 additions & 1 deletion python/paddle/fluid/compiler.py
@@ -548,7 +548,9 @@ def patch_getter(self, item):

self._caches[item_id] = (
concrete_program,
partial_program_from(concrete_program),
partial_program_from(
concrete_program, item.class_instance is not None
),
)
# Note: raise warnings if number of traced program is more than `max_tracing_count`
current_tracing_count = len(self._caches)
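A hedged note on the new second argument: `item.class_instance is not None` tells `partial_program_from` (see the partial_program.py hunk below) that the traced callable is a method, so the leading `self` is dropped from the recorded inputs. A minimal illustration of that idea (the helper name is hypothetical):

```python
def strip_self_if_method(inputs, from_method):
    # Mirrors the NOTE in partial_program_from: for a method, the first
    # recorded input is the bound instance, not a real program input.
    return inputs[1:] if (inputs and from_method) else inputs
```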
11 changes: 5 additions & 6 deletions python/paddle/fluid/dygraph/base.py
@@ -44,7 +44,7 @@
'to_variable',
]

# Flag that indicates whether running code under `@to_static`
NON_PERSISTABLE_VAR_NAME_SUFFIX = "__non_persistable"


def in_declarative_mode():
@@ -143,17 +143,16 @@ def _convert_into_variable(tensor):
# and necessary for inferring. It will be pruned if it's not necessary for inferring.

# But if its shape is empty while created from `create_variable()`, we consider this buffer
# non-persistable. See case of `drop_state` in lstm api.
is_persistable = len(tensor.shape) > 0
# non-persistable. See case of `dropout_state` in lstm api.
is_persistable = True
if tensor.name.endswith(NON_PERSISTABLE_VAR_NAME_SUFFIX):
is_persistable = False

new_var = tensor._to_static_var(
to_parameter=False, persistable=is_persistable
)
# add param into parameter recorder to collect all the params used in this program.
if new_var.persistable is True:
# TODO(@xiongkun): 0d-tensor may be affected at present,
# but there is no particularly good method to identify whether 0d-tensor
# is used as buffer or "drop_out_state" in LSTM buffer variable.
from paddle.jit.dy2static.program_translator import (
ProgramTranslator,
)
27 changes: 25 additions & 2 deletions python/paddle/fluid/framework.py
@@ -1684,7 +1684,7 @@ def numpy(self):
"""
pass

@fake_interface_only
@non_static_only
def backward(self, retain_graph=False):
"""
**Notes**:
@@ -1721,7 +1721,17 @@ def backward(self, retain_graph=False):
loss.backward()

"""
pass
from .backward import append_backward

if retain_graph is True:
raise AssertionError(
"`retain_graph` == True is not supported in @to_static function."
"please set retain_graph = False."
)
param_grad_list = append_backward(self)
for param, param_grad in param_grad_list:
# set grad to simulate dygraph loss.backward() in static mode.
setattr(param, "grad", param_grad)

@fake_interface_only
def gradient(self):
@@ -7473,6 +7483,19 @@ def _get_var(name, program=None):
return program.global_block().var(name)


@signature_safe_contextmanager
def dygraph_guard_if_declarative():
from .dygraph.base import in_declarative_mode
from .dygraph import Tracer

if in_declarative_mode():
# Under @paddle.jit.to_static decorator, we switch back dygraph mode temporarily.
with _dygraph_guard(tracer=Tracer()):
yield
else:
yield


@signature_safe_contextmanager
def _dygraph_guard(tracer):
tmp_tracer = global_var._dygraph_tracer_
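With this change, calling backward() on a static Variable inside @to_static appends the backward ops via append_backward and attaches each parameter's gradient as its `grad` attribute, while `retain_graph=True` is rejected. A hedged sketch of the resulting behaviour (names and shapes are illustrative):

```python
import paddle

net = paddle.nn.Linear(4, 2)

@paddle.jit.to_static
def forward_backward(x):
    loss = net(x).mean()
    # Under to_static this now runs append_backward(loss) and sets
    # param.grad for each parameter instead of raising a fake-interface error.
    loss.backward()
    return loss

forward_backward(paddle.randn([3, 4]))
# loss.backward(retain_graph=True) would raise an AssertionError in static mode.
```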
46 changes: 25 additions & 21 deletions python/paddle/jit/dy2static/partial_program.py
@@ -21,9 +21,10 @@
from paddle.amp.auto_cast import _in_amp_guard, _in_pure_fp16_guard
from paddle.fluid import backward, core, framework, program_guard
from paddle.fluid.compiler import BuildStrategy
from paddle.fluid.data_feeder import convert_dtype
from paddle.fluid.dygraph.base import switch_to_static_graph
from paddle.fluid.framework import _apply_pass
from paddle.nn.layer import layers
from paddle.optimizer.lr import LRScheduler

from . import logging_utils
from .utils import RETURN_NO_VALUE_MAGIC_NUM, _out_grad_names, _param_grad_names
@@ -205,6 +206,8 @@ def __call__(self, inputs):
self._cast_fp16_if_pure_fp16(in_vars)
attrs = self._prepare_attributes()

self._sync_lr_value_with_scheduler()

_legacy_C_ops.run_program(
self._valid_vars(in_vars),
self._valid_vars(self._params),
@@ -219,6 +222,21 @@
restored_nest_out = self._restore_out(out_vars)
return self._remove_no_value(restored_nest_out)

def _sync_lr_value_with_scheduler(self):
"""Update lr_var value with calculated by lr_scheduler."""
main_program = self._origin_main_program
if hasattr(main_program, 'lr_scheduler') and hasattr(
main_program, 'lr_var'
):
lr_scheduler = main_program.lr_scheduler
lr_var = main_program.lr_var

assert isinstance(lr_scheduler, LRScheduler), "must be LRScheduler"
lr_scheduler = self._origin_main_program.lr_scheduler
lr_value = lr_scheduler()
data = np.array(lr_value).astype(convert_dtype(lr_var.dtype))
lr_var.set_value(data)

def set_hooker(self, hooker):
self._hooker = hooker

Expand All @@ -240,7 +258,8 @@ def _get_scope(self, program_id=None, use_scope_cache=False):

@LazyInitialized
def _double_grads(self):
return self._get_double_grads(self._origin_main_program)
# TODO: check the affects.
return None

# whole
@switch_to_static_graph
@@ -658,23 +677,6 @@ def _prune_unused_params(self, program):

self._params = required_params

def _get_double_grads(self, program):
double_grads = []
for block in program.blocks:
for name in block.vars:
if "@GRAD" in name:
var_desc = block.vars[name].desc
var_base = None
var_base = core.eager.Tensor(
var_desc.dtype(),
var_desc.shape(),
var_desc.name(),
var_desc.type(),
False,
)
double_grads.append(var_base)
return self._valid_vars(double_grads)

def _cast_fp16_if_pure_fp16(self, in_vars):
if _in_pure_fp16_guard():
for i, var in enumerate(in_vars):
@@ -1053,9 +1055,11 @@ def _valid_vars(self, vars):
return vars if vars else None


def partial_program_from(concrete_program):
def partial_program_from(concrete_program, from_method=False):
inputs = concrete_program.inputs
if inputs and isinstance(inputs[0], layers.Layer):

# NOTE(SigureMo): Remove the first arg `self` from method args.
if inputs and from_method:
inputs = inputs[1:]

return PartialProgramLayer(
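The new `_sync_lr_value_with_scheduler` copies the scheduler's current value into the program's learning-rate variable before each run, so stepping an LRScheduler between calls takes effect inside the traced train step. A hedged end-to-end sketch (StepDecay and the shapes are illustrative):

```python
import paddle

net = paddle.nn.Linear(10, 1)
scheduler = paddle.optimizer.lr.StepDecay(learning_rate=0.1, step_size=2, gamma=0.5)
adam = paddle.optimizer.Adam(learning_rate=scheduler, parameters=net.parameters())

@paddle.jit.to_static
def train_step(x, y):
    loss = paddle.nn.functional.mse_loss(net(x), y)
    loss.backward()
    adam.step()
    adam.clear_grad()
    return loss

for _ in range(4):
    train_step(paddle.randn([4, 10]), paddle.randn([4, 1]))
    # The updated lr value is written into the program's lr variable on the next call.
    scheduler.step()
```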
4 changes: 3 additions & 1 deletion python/paddle/jit/dy2static/program_translator.py
@@ -1184,7 +1184,9 @@ def _build_once(self, cache_key):
)
)

partial_program = partial_program_from(concrete_program)
partial_program = partial_program_from(
concrete_program, cache_key.class_instance is not None
)
if core._is_fwd_prim_enabled() and not _in_amp_guard():
partial_program.set_hooker(
PrimHooker(concrete_program.main_program)
4 changes: 3 additions & 1 deletion python/paddle/nn/layer/rnn.py
@@ -22,6 +22,7 @@
from paddle import _C_ops, _legacy_C_ops, framework, in_dynamic_mode
from paddle.common_ops_import import Variable
from paddle.fluid.data_feeder import check_type, check_variable_and_dtype
from paddle.fluid.dygraph.base import NON_PERSISTABLE_VAR_NAME_SUFFIX
from paddle.fluid.framework import (
_non_static_mode,
default_startup_program,
@@ -1428,7 +1429,8 @@ def flatten_parameters(self):
# dropout state may also can be hided and avoid saving
# should dropout state be persistable for static-graph
self._dropout_state = self.create_variable(
dtype=core.VarDesc.VarType.UINT8
dtype=core.VarDesc.VarType.UINT8,
name=f"dropout_state{NON_PERSISTABLE_VAR_NAME_SUFFIX}",
)
if in_dynamic_mode():
with paddle.no_grad():
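The RNN layer now names its dropout_state buffer with NON_PERSISTABLE_VAR_NAME_SUFFIX so that `_convert_into_variable` can recognise it by name instead of by its (possibly empty) shape. A hedged sketch of that naming convention (the helper is hypothetical):

```python
from paddle.fluid.dygraph.base import NON_PERSISTABLE_VAR_NAME_SUFFIX

def is_persistable_buffer(name):
    # Mirrors the check added in _convert_into_variable: buffers are persistable
    # unless their name carries the non-persistable suffix, as dropout_state now does.
    return not name.endswith(NON_PERSISTABLE_VAR_NAME_SUFFIX)

assert not is_persistable_buffer("dropout_state" + NON_PERSISTABLE_VAR_NAME_SUFFIX)
```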
6 changes: 5 additions & 1 deletion python/paddle/optimizer/adam.py
@@ -389,7 +389,7 @@ def _append_optimize_op(self, block, param_and_grad):
return adam_op

@imperative_base.no_grad
@framework.dygraph_only
@framework.non_static_only
def step(self):
"""
Execute the optimizer and update parameters once.
@@ -412,6 +412,10 @@ def step(self):
adam.step()
adam.clear_grad()
"""
if paddle.fluid.dygraph.base.in_declarative_mode():
self._declarative_step()
return

if not isinstance(self._parameter_list[0], dict):
params_grads = []
for param in self._parameter_list:
6 changes: 5 additions & 1 deletion python/paddle/optimizer/adamw.py
@@ -530,7 +530,7 @@ def __str__(self):
return " ".join(["Weight Decay, params:", ",".join(self._params_name)])

@imperative_base.no_grad
@framework.dygraph_only
@framework.non_static_only
def step(self):
"""
Execute the optimizer and update parameters once.
@@ -553,6 +553,10 @@ def step(self):
opt.step()
opt.clear_grad()
"""
if paddle.fluid.dygraph.base.in_declarative_mode():
self._declarative_step()
return

if not isinstance(self._parameter_list[0], dict):
params_grads = []
for param in self._parameter_list:
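Together with the switch from @framework.dygraph_only to @framework.non_static_only, both Adam.step and AdamW.step now check in_declarative_mode() and route to _declarative_step() when tracing under @to_static. A hedged usage sketch for AdamW (the model and weight_decay value are illustrative):

```python
import paddle

net = paddle.nn.Linear(8, 2)
opt = paddle.optimizer.AdamW(
    learning_rate=1e-3, parameters=net.parameters(), weight_decay=0.01
)

@paddle.jit.to_static
def train_step(x):
    loss = net(x).mean()
    loss.backward()
    opt.step()  # no longer rejected under to_static; delegates to _declarative_step()
    opt.clear_grad()
    return loss

train_step(paddle.randn([4, 8]))
```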