diff --git a/python/paddle/jit/dy2static/partial_program.py b/python/paddle/jit/dy2static/partial_program.py
index f85e0bbe44fd3..3f2d60ff12806 100644
--- a/python/paddle/jit/dy2static/partial_program.py
+++ b/python/paddle/jit/dy2static/partial_program.py
@@ -143,6 +143,19 @@ def __call__(self, key, prog_creator):
         return self.programs[key], self.op_size[key]
 
 
+class PartialProgramLayerHook:
+    def before_append_backward(self, partial_program_layer, forward_program):
+        ...
+
+    def after_append_backward(
+        self, partial_program_layer, whole_program, backward_start_idx
+    ):
+        ...
+
+    def after_infer(self, partial_program_layer, infer_program):
+        ...
+
+
 class PartialProgramLayer:
     """
     PartialProgramLayer wraps all the ops from layers decorated by `@to_static`
@@ -182,6 +195,7 @@ def __init__(
         # Set default mode to train
         self.training = True
         self._infer_info = ProgramInfo()
+        self._backward_start_index_map = {}
 
         custom_white_list, custom_black_list = None, None
         tracer = framework._dygraph_tracer()
@@ -195,6 +209,7 @@ def __init__(
 
         # program_id -> list(scope)
         self._scope_cache = {}
+        self._hooker = None
 
     def __call__(self, inputs):
         """
@@ -218,6 +233,9 @@ def __call__(self, inputs):
         restored_nest_out = self._restore_out(out_vars)
         return self._remove_no_value(restored_nest_out)
 
+    def set_hooker(self, hooker):
+        self._hooker = hooker
+
     def _get_scope(self, program_id=None, use_scope_cache=False):
         if use_scope_cache:
             if program_id not in self._scope_cache:
@@ -242,7 +260,12 @@ def _double_grads(self):
     @switch_to_static_graph
     def _create_program(self, is_infer_mode=False):
         if is_infer_mode:
-            return self._origin_main_program.clone(for_test=is_infer_mode)
+            infer_program = self._origin_main_program.clone(
+                for_test=is_infer_mode
+            )
+            if self._hooker:
+                infer_program = self._hooker.after_infer(self, infer_program)
+            return infer_program
         else:
             train_program = self._append_backward_desc(
                 self._origin_main_program
@@ -609,6 +632,8 @@ def _insert_aggregation_ops_for_var(target_program, var):
     def _append_backward_desc(self, main_program):
         # make sure all status of is_test are False in train mode.
         program = _change_is_test_status(main_program.clone(), is_test=False)
+        if self._hooker:
+            program = self._hooker.before_append_backward(self, program)
         targets = []
         for out in self._outputs.tolist():
             if isinstance(out, framework.Variable):
@@ -618,10 +643,16 @@ def _append_backward_desc(self, main_program):
             # TODO(CZ): later when use cinn, set_prim_all_enabled and check_and_set_prim_all_enabled will be set at else branch.
             core.check_and_set_prim_all_enabled()
             backward.gradients(targets=targets, inputs=[])
-
-        start_idx = len(main_program.block(0).ops) + len(self._outputs.tolist())
-
-        self.prepare_gradient_aggregation(start_idx, main_program, program)
+        start_idx = (
+            len(main_program.block(0).ops) + len(self._outputs.tolist()) + 1
+        )
+        if self._hooker:
+            program, start_idx = self._hooker.after_append_backward(
+                self, program, start_idx
+            )
+        # self._backward_start_index_map[self._hash_with_id(program, self)]
+        # TODO: prim makes this complicated.
+        self.prepare_gradient_aggregation(start_idx, main_program, program)
 
         return program
 
diff --git a/python/paddle/jit/dy2static/program_translator.py b/python/paddle/jit/dy2static/program_translator.py
index f61f34af40655..a51019f4cf85b 100644
--- a/python/paddle/jit/dy2static/program_translator.py
+++ b/python/paddle/jit/dy2static/program_translator.py
@@ -19,7 +19,6 @@
 import warnings
 import weakref
 
-from paddle.amp.auto_cast import _in_amp_guard, _in_pure_fp16_guard
 from paddle.fluid import _non_static_mode, core, framework
 from paddle.fluid.data_feeder import check_type
 from paddle.fluid.dygraph import layers
@@ -39,7 +38,7 @@
     create_and_update_origin_info_map,
     update_op_callstack_with_origin_info,
 )
-from .partial_program import partial_program_from
+from .partial_program import PartialProgramLayerHook, partial_program_from
 from .utils import (
     ALREADY_D2S,
     ast_to_func,
@@ -1182,26 +1181,44 @@ def _build_once(self, cache_key):
                 )
             )
 
-        custom_vjps = set()
-        if core._is_fwd_prim_enabled() and core._is_bwd_prim_enabled():
-            custom_vjps = {
-                op.type
-                for op in concrete_program.main_program.block(0).ops
-                if core.has_comp_grad_op_maker(op.type)
-            }
-
-        if core._is_fwd_prim_enabled():
-            if not _in_amp_guard() and not _in_pure_fp16_guard():
-                _to_prim(
-                    concrete_program.main_program.blocks, exclude=custom_vjps
+        class PrimHooker(PartialProgramLayerHook):
+            def __init__(self):
+                custom_vjps = set()
+                if core._is_fwd_prim_enabled() and core._is_bwd_prim_enabled():
+                    custom_vjps = {
+                        op.type
+                        for op in concrete_program.main_program.block(0).ops
+                        if core.has_comp_grad_op_maker(op.type)
+                    }
+                self.custom_vjps = custom_vjps
+
+            def before_append_backward(
+                self, partial_program_layer, forward_program
+            ):
+                if core._is_fwd_prim_enabled():
+                    to_prim(forward_program.block(0), self.custom_vjps)
+                return forward_program
+
+            def after_append_backward(
+                self, partial_program_layer, whole_program, backward_start_idx
+            ):
+                backward_length = (
+                    len(whole_program.block(0).ops) - backward_start_idx
                 )
+                if core._is_fwd_prim_enabled() and len(self.custom_vjps) != 0:
+                    to_prim(whole_program.block(0))
+                new_start_index = (
+                    len(whole_program.block(0).ops) - backward_length
+                )
+                return whole_program, new_start_index
 
-        partial_program = partial_program_from(concrete_program)
-
-        if core._is_fwd_prim_enabled() and len(custom_vjps) != 0:
-            if not _in_amp_guard() and not _in_pure_fp16_guard():
-                _to_prim(partial_program.forward_program.blocks)
+            def after_infer(self, partial_program_layer, infer_program):
+                if core._is_fwd_prim_enabled():
+                    to_prim(infer_program.block(0))
+                return infer_program
 
+        partial_program = partial_program_from(concrete_program)
+        partial_program.set_hooker(PrimHooker())
         return concrete_program, partial_program
 
@@ -1675,8 +1692,8 @@ def func(x):
 
 
 @switch_to_static_graph
-def _to_prim(blocks, exclude=frozenset()):
+def to_prim(blocks, exclude=frozenset()):
     # TODO(Aurelius84): Fix this cycle import problem
     from paddle.incubate.autograd import primapi
 
-    primapi.to_prim(blocks, exclude=exclude)
+    primapi.to_prim(blocks, exclude)
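Below is a minimal usage sketch (not part of the diff) of how the hook API introduced above is meant to be wired in: subclass PartialProgramLayerHook, override any of the three callbacks, and attach the instance with set_hooker(). LoggingHooker and the concrete_program variable are hypothetical placeholders; the only APIs assumed are the ones added or already present in these files (PartialProgramLayerHook, PartialProgramLayer.set_hooker, partial_program_from).

# Hypothetical example -- not part of this change.
from paddle.jit.dy2static.partial_program import (
    PartialProgramLayerHook,
    partial_program_from,
)


class LoggingHooker(PartialProgramLayerHook):
    """Leaves every program untouched and only reports op counts."""

    def before_append_backward(self, partial_program_layer, forward_program):
        # Runs on the cloned train program before backward ops are appended;
        # must return the (possibly rewritten) forward program.
        print("forward ops:", len(forward_program.block(0).ops))
        return forward_program

    def after_append_backward(
        self, partial_program_layer, whole_program, backward_start_idx
    ):
        # Runs on the joint forward+backward program; must return the program
        # and the (possibly shifted) index where backward ops begin.
        print(
            "backward ops:",
            len(whole_program.block(0).ops) - backward_start_idx,
        )
        return whole_program, backward_start_idx

    def after_infer(self, partial_program_layer, infer_program):
        # Runs on the for_test clone built in infer mode.
        print("infer ops:", len(infer_program.block(0).ops))
        return infer_program


# `concrete_program` stands for the ConcreteProgram produced by the dy2static
# tracing pipeline (see _build_once above).
partial_program = partial_program_from(concrete_program)
partial_program.set_hooker(LoggingHooker())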