-
Notifications
You must be signed in to change notification settings - Fork 5.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Recompute upgrade #47985
Recompute upgrade #47985
Changes from all commits
3ed8d7d
dad7966
7e89a42
05a0064
ecc19e2
0c216b8
8d76bbb
fb212aa
81929d6
db416b0
bad716d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -13,12 +13,16 @@ | |
# limitations under the License. | ||
|
||
import contextlib | ||
import weakref | ||
|
||
import paddle | ||
from paddle import framework | ||
from paddle.autograd import PyLayer | ||
from paddle.autograd.py_layer import LegacyPyLayer | ||
from paddle.fluid import core, framework | ||
from paddle.fluid.framework import in_dygraph_mode | ||
from paddle.distributed.fleet.meta_parallel.parallel_layers.random import ( | ||
get_rng_state_tracker, | ||
) | ||
from paddle.framework import core, in_dygraph_mode | ||
|
||
from ..utils.log_util import logger | ||
|
||
|
@@ -52,10 +56,6 @@ def check_recompute_necessary(inputs): | |
|
||
@contextlib.contextmanager | ||
def swith_rng_state_tracker(rng_state, tracker): | ||
from paddle.distributed.fleet.meta_parallel.parallel_layers.random import ( | ||
get_rng_state_tracker, | ||
) | ||
|
||
orig_cuda_rng_state = paddle.get_cuda_rng_state() | ||
orig_cuda_rng_tracker = get_rng_state_tracker().get_states_tracker() | ||
|
||
|
@@ -71,10 +71,6 @@ def swith_rng_state_tracker(rng_state, tracker): | |
class LegacyRecomputeFunction(LegacyPyLayer): | ||
@staticmethod | ||
def forward(ctx, run_function, preserve_rng_state, *args): | ||
from paddle.distributed.fleet.meta_parallel.parallel_layers.random import ( | ||
get_rng_state_tracker, | ||
) | ||
|
||
# store for recomputing | ||
ctx.run_function = run_function | ||
ctx.preserve_rng_state = preserve_rng_state | ||
|
@@ -223,10 +219,6 @@ def backward(ctx, *args): | |
class RecomputeFunction(PyLayer): | ||
@staticmethod | ||
def forward(ctx, run_function, preserve_rng_state, *args, **kwargs): | ||
from paddle.distributed.fleet.meta_parallel.parallel_layers.random import ( | ||
get_rng_state_tracker, | ||
) | ||
|
||
# store for recomputing | ||
ctx.run_function = run_function | ||
ctx.preserve_rng_state = preserve_rng_state | ||
|
@@ -382,6 +374,116 @@ def backward(ctx, *args): | |
return grads | ||
|
||
|
||
def _recompute_without_reentrant( | ||
function, preserve_rng_state=True, *args, **kwargs | ||
): | ||
""" | ||
recompute without reentrant, that means use hook to implement the recompute function rather than re-entrant autograd. | ||
""" | ||
|
||
if preserve_rng_state: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. User can't use CPU or other XXPU? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. only support GPU to preserve rng state now. |
||
cur_device = paddle.get_device() | ||
if 'gpu:' not in cur_device: | ||
raise RuntimeError( | ||
"Recompute with RNG perserve is not support current device: {}.".format( | ||
cur_device | ||
) | ||
) | ||
fw_cuda_rng_state = paddle.get_cuda_rng_state() | ||
fwd_cuda_rng_state_tracker = ( | ||
get_rng_state_tracker().get_states_tracker() | ||
) | ||
tracer = framework._dygraph_tracer() | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Are these functions all internal functions with _? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. yes, and only use this way to get tracer from python, the name of _dygraph_tracer is defined by framewark. |
||
is_fw_autocast = False if tracer._amp_level == core.AmpLevel.O0 else True | ||
if tracer._amp_level == core.AmpLevel.O2: | ||
amp_level = 'O2' | ||
elif tracer._amp_level in (core.AmpLevel.O1, core.AmpLevel.O0): | ||
amp_level = 'O1' | ||
|
||
if tracer._amp_dtype == 'float16': | ||
amp_dtype = 'float16' | ||
elif tracer._amp_dtype in ('bfloat16', 'float32'): | ||
amp_dtype = 'bfloat16' | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. float32->bfloat16? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
|
||
amp_white_list, amp_black_list = tracer._get_amp_op_list() | ||
|
||
class Intermediate_Holder: | ||
pass | ||
|
||
storage = weakref.WeakKeyDictionary() | ||
holder_list = [] | ||
|
||
def pack(x): | ||
res = Intermediate_Holder() | ||
holder_list.append(weakref.ref(res)) | ||
return res | ||
|
||
def unpack(x): | ||
unpack_counter = 0 | ||
if len(storage) == 0: | ||
|
||
def inner_pack(inner_x): | ||
nonlocal unpack_counter | ||
unpack_counter += 1 | ||
|
||
if holder_list[unpack_counter - 1]() is None: | ||
return | ||
|
||
tmp_tensor = core.eager.Tensor( | ||
inner_x.dtype, | ||
inner_x.shape, | ||
inner_x.name + "cpy", | ||
core.VarDesc.VarType.LOD_TENSOR, | ||
inner_x.persistable, | ||
) | ||
inner_x._share_buffer_to(tmp_tensor) | ||
storage[holder_list[unpack_counter - 1]()] = tmp_tensor | ||
return | ||
|
||
def inner_unpack(inner_x): | ||
raise Exception("An unexcepted backward called on a tensor!") | ||
|
||
if preserve_rng_state: | ||
with swith_rng_state_tracker( | ||
fw_cuda_rng_state, fwd_cuda_rng_state_tracker | ||
): | ||
with paddle.set_grad_enabled(True): | ||
with paddle.amp.auto_cast( | ||
enable=is_fw_autocast, | ||
custom_white_list=amp_white_list, | ||
custom_black_list=amp_black_list, | ||
level=amp_level, | ||
dtype=amp_dtype, | ||
): | ||
with paddle.autograd.saved_tensors_hooks( | ||
inner_pack, inner_unpack | ||
): | ||
unused_outputs = function(*args, **kwargs) | ||
else: | ||
with paddle.set_grad_enabled(True), paddle.amp.auto_cast( | ||
enable=is_fw_autocast, | ||
custom_white_list=amp_white_list, | ||
custom_black_list=amp_black_list, | ||
level=amp_level, | ||
dtype=amp_dtype, | ||
), paddle.autograd.saved_tensors_hooks( | ||
inner_pack, inner_unpack | ||
): | ||
unused_outputs = function(*args, **kwargs) | ||
|
||
if x not in storage: | ||
raise Exception( | ||
"Not supported to retrieve a tensor saved by autograd multiple times that is no need to recompute." | ||
) | ||
|
||
return storage[x] | ||
|
||
with paddle.autograd.saved_tensors_hooks(pack, unpack): | ||
outputs = function(*args, **kwargs) | ||
|
||
return outputs | ||
|
||
|
||
def recompute(function, *args, **kwargs): | ||
""" | ||
recompute intermediate activations to save then memory. | ||
|
@@ -391,11 +493,13 @@ def recompute(function, *args, **kwargs): | |
whose intermediate activations will be released to save memory in forward stage and will be recomputed | ||
in backward stage for gradient calculation. | ||
*args(Tensor): inputs to the function. | ||
**kwargs(Dict): Kwargs should only contain the key-value pair of preserve_rng_state, which is used to | ||
indicate whether to save the forward rng. If it is True, then the last forward rng value will be | ||
restored when the forward recalculation of backpropagation is performed. The default | ||
preserve_rng_state is True. | ||
|
||
**kwargs(Dict): Kwargs should only contain two kinds of key-value params, the one is part of function's key-value params, | ||
and the other contains 'preserve_rng_state' and 'use_reentrant'. the key-value pair of preserve_rng_state, | ||
which is used to indicate whether to save the forward rng. If it is True, then the last forward rng value | ||
will be restored when the forward recalculation of backpropagation is performed, its default value is True. | ||
the key-value pair of use_reentrant is used to indicate which implementation of recompute you will be used. | ||
'use_reentrant=True' means to use the PyLayer implementation of recompute, 'use_reentrant=False' means to | ||
use the Hook implementation of recompute, its default value is True. | ||
Returns: | ||
Output of function on args. | ||
|
||
|
@@ -487,10 +591,21 @@ def run_model(cuda_state, recompute_block=[], recompute_kwargs={}): | |
# Hack to mix *args with **kwargs in a python 2.7-compliant way | ||
preserve = kwargs.pop('preserve_rng_state', True) | ||
|
||
# whether to use reentrant method to implement recompute | ||
use_reentrant = kwargs.pop('use_reentrant', True) | ||
|
||
if kwargs and use_reentrant: | ||
raise ValueError( | ||
"Error, if you want to send kwargs(dict parameter) to function, please set use_reentrant=False." | ||
) | ||
|
||
if framework._dygraph_tracer()._has_grad: | ||
check_recompute_necessary(args) | ||
|
||
return RecomputeFunction.apply(function, preserve, *args, **kwargs) | ||
if use_reentrant: | ||
return RecomputeFunction.apply(function, preserve, *args) | ||
else: | ||
return _recompute_without_reentrant(function, preserve, *args, **kwargs) | ||
|
||
|
||
def recompute_sequential(ctx, functions, *args, **kwargs): | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
OK