PR 50885 (#7)
* [CINN] Enhance CacheKey hash logic by considering input dtypes (PaddlePaddle#50557)

* [CINN] Enhance CacheKey hash logic by considering input dtypes

* add unittest

* fix typo

* fix typo

* fix map.at

* fix find

* fix test

* fix CINN cache key structure realization

* use an ordered map for attributes

* add test per review advice

---------

Co-authored-by: jiangcheng <thisjiang@qq.com>

* [prim] enable dygraph_to_static to support custom_vjp

* fix code in a dy2static-friendly way.

* [dy2static] add hooker for prim

---------

Co-authored-by: Aurelius84 <zhangliujie@baidu.com>
Co-authored-by: jiangcheng <thisjiang@qq.com>
Co-authored-by: cxxly <chenxx_id@163.com>
4 people committed Mar 14, 2023
1 parent d0c80f4 commit ecc842f
Showing 2 changed files with 80 additions and 26 deletions.
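
Before the diffs, a minimal sketch (not part of the commit) of how the PartialProgramLayerHook interface added in partial_program.py below can be attached to a layer. The NoOpHook name is an illustrative assumption, and concrete_program is assumed to come from the usual dy2static tracing:

    from paddle.jit.dy2static.partial_program import (
        PartialProgramLayerHook,
        partial_program_from,
    )

    class NoOpHook(PartialProgramLayerHook):
        # Runs on a clone of the forward program before backward ops are appended.
        def before_append_backward(self, partial_program_layer, forward_program):
            return forward_program

        # Runs on the joint forward+backward program; returns the (possibly
        # rewritten) program and the index where its backward section starts.
        def after_append_backward(
            self, partial_program_layer, whole_program, backward_start_idx
        ):
            return whole_program, backward_start_idx

        # Runs on the inference clone created by _create_program(is_infer_mode=True).
        def after_infer(self, partial_program_layer, infer_program):
            return infer_program

    partial_program = partial_program_from(concrete_program)  # concrete_program: assumed input
    partial_program.set_hooker(NoOpHook())
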
46 changes: 41 additions & 5 deletions python/paddle/jit/dy2static/partial_program.py
@@ -143,6 +143,19 @@ def __call__(self, key, prog_creator):
return self.programs[key], self.op_size[key]


class PartialProgramLayerHook:
def before_append_backward(self, partial_program_layer, forward_program):
...

def after_append_backward(
self, partial_program_layer, whole_program, backward_start_idx
):
...

def after_infer(self, partial_program_layer, infer_program):
...


class PartialProgramLayer:
"""
PartialProgramLayer wraps all the ops from layers decorated by `@to_static`
@@ -182,6 +195,7 @@ def __init__(
# Set default mode to train
self.training = True
self._infer_info = ProgramInfo()
self._backward_start_index_map = {}

custom_white_list, custom_black_list = None, None
tracer = framework._dygraph_tracer()
@@ -195,6 +209,7 @@ def __init__(

# program_id -> list(scope)
self._scope_cache = {}
self._hooker = None

def __call__(self, inputs):
"""
@@ -218,6 +233,9 @@ def __call__(self, inputs):
restored_nest_out = self._restore_out(out_vars)
return self._remove_no_value(restored_nest_out)

def set_hooker(self, hooker):
self._hooker = hooker

def _get_scope(self, program_id=None, use_scope_cache=False):
if use_scope_cache:
if program_id not in self._scope_cache:
@@ -242,7 +260,12 @@ def _double_grads(self):
@switch_to_static_graph
def _create_program(self, is_infer_mode=False):
if is_infer_mode:
return self._origin_main_program.clone(for_test=is_infer_mode)
infer_program = self._origin_main_program.clone(
for_test=is_infer_mode
)
if self._hooker:
infer_program = self._hooker.after_infer(self, infer_program)
return infer_program
else:
train_program = self._append_backward_desc(
self._origin_main_program
@@ -609,6 +632,8 @@ def _insert_aggregation_ops_for_var(target_program, var):
def _append_backward_desc(self, main_program):
# make sure all status of is_test are False in train mode.
program = _change_is_test_status(main_program.clone(), is_test=False)
if self._hooker:
program = self._hooker.before_append_backward(self, program)
targets = []
for out in self._outputs.tolist():
if isinstance(out, framework.Variable):
@@ -618,10 +643,16 @@ def _append_backward_desc(self, main_program):
# TODO(CZ): later, when CINN is used, set_prim_all_enabled and check_and_set_prim_all_enabled will be set in the else branch.
core.check_and_set_prim_all_enabled()
backward.gradients(targets=targets, inputs=[])

start_idx = len(main_program.block(0).ops) + len(self._outputs.tolist())

self.prepare_gradient_aggregation(start_idx, main_program, program)
start_idx = (
len(main_program.block(0).ops) + len(self._outputs.tolist()) + 1
)
if self._hooker:
program, start_idx = self._hooker.after_append_backward(
self, program, start_idx
)
# self._backward_start_index_map[self._hash_with_id(program, self)]
# TODO: prim makes this complicated
self.prepare_gradient_aggregation(start_idx, main_program, program)

return program

@@ -701,6 +732,11 @@ def _prepare_attributes(self):
'program_id',
self.program_id,
]

print(self.forward_program)
print(self.backward_program)
print(self.program_id)

if self.training:
# NOTE: In the case of higher-order gradient, the names of the parameter grads may be like
# `grad/grad/grad/linear_0.w_0@GRAD` instead of simply `linear_0.w_0@GRAD`, so we get
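
One detail worth spelling out before the second file: after_append_backward receives the index at which the backward section of the joint program starts, and a hook that rewrites the whole program must return an updated index. The PrimHooker in program_translator.py below derives it by assuming the rewrite leaves the length of the backward section unchanged. A self-contained toy of that bookkeeping, with made-up op counts:

    # Toy numbers only; in the real hook these come from whole_program.block(0).ops.
    ops_before_rewrite = 120   # ops in the joint forward+backward program
    backward_start_idx = 80    # index of the first backward op before the rewrite
    backward_length = ops_before_rewrite - backward_start_idx   # 40 backward ops

    # Suppose lowering composite ops to primitives grows the program to 150 ops
    # while the backward section keeps its 40 ops.
    ops_after_rewrite = 150
    new_start_index = ops_after_rewrite - backward_length
    print(new_start_index)   # 110
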
60 changes: 39 additions & 21 deletions python/paddle/jit/dy2static/program_translator.py
@@ -19,7 +19,6 @@
import warnings
import weakref

from paddle.amp.auto_cast import _in_amp_guard, _in_pure_fp16_guard
from paddle.fluid import _non_static_mode, core, framework
from paddle.fluid.data_feeder import check_type
from paddle.fluid.dygraph import layers
@@ -39,7 +38,7 @@
create_and_update_origin_info_map,
update_op_callstack_with_origin_info,
)
from .partial_program import partial_program_from
from .partial_program import PartialProgramLayerHook, partial_program_from
from .utils import (
ALREADY_D2S,
ast_to_func,
@@ -1182,26 +1181,45 @@ def _build_once(self, cache_key):
)
)

custom_vjps = set()
if core._is_fwd_prim_enabled() and core._is_bwd_prim_enabled():
custom_vjps = {
op.type
for op in concrete_program.main_program.block(0).ops
if core.has_comp_grad_op_maker(op.type)
}

if core._is_fwd_prim_enabled():
if not _in_amp_guard() and not _in_pure_fp16_guard():
_to_prim(
concrete_program.main_program.blocks, exclude=custom_vjps
class PrimHooker(PartialProgramLayerHook):
def __init__(self):
custom_vjps = set()
if core._is_fwd_prim_enabled() and core._is_bwd_prim_enabled():
custom_vjps = {
op.type
for op in concrete_program.main_program.block(0).ops
if core.has_comp_grad_op_maker(op.type)
}
self.custom_vjps = custom_vjps
self.custom_vjps = {"softmax"}

def before_append_backward(
self, partial_program_layer, forward_program
):
if core._is_fwd_prim_enabled():
to_prim(forward_program.block(0), self.custom_vjps)
return forward_program

def after_append_backward(
self, partial_program_layer, whole_program, backward_start_idx
):
backward_length = (
len(whole_program.block(0).ops) - backward_start_idx
)
if core._is_fwd_prim_enabled() and len(self.custom_vjps) != 0:
to_prim(whole_program.block(0))
new_start_index = (
len(whole_program.block(0).ops) - backward_length
)
return whole_program, new_start_index

partial_program = partial_program_from(concrete_program)

if core._is_fwd_prim_enabled() and len(custom_vjps) != 0:
if not _in_amp_guard() and not _in_pure_fp16_guard():
_to_prim(partial_program.forward_program.blocks)
def after_infer(self, partial_program_layer, infer_program):
if core._is_fwd_prim_enabled():
to_prim(infer_program.block(0))
return infer_program

partial_program = partial_program_from(concrete_program)
partial_program.set_hooker(PrimHooker())
return concrete_program, partial_program


@@ -1675,8 +1693,8 @@ def func(x):


@switch_to_static_graph
def _to_prim(blocks, exclude=frozenset()):
def to_prim(blocks, exclude=frozenset()):
# TODO(Aurelius84): Fix this cycle import problem
from paddle.incubate.autograd import primapi

primapi.to_prim(blocks, exclude=exclude)
primapi.to_prim(blocks, exclude)
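
For context on the renamed helper: to_prim lowers the ops of the given block(s) to primitive ops through paddle.incubate.autograd.primapi, and the second argument names op types to leave untouched. The PrimHooker above passes its custom_vjps set there on the first pass so that ops with composite (custom) VJP rules survive until backward is built, then lowers the whole program with no exclusions. A hedged usage sketch, assuming the prim switches are already enabled (for example via check_and_set_prim_all_enabled) and using an illustrative network:

    import paddle
    from paddle.incubate.autograd import primapi

    paddle.enable_static()
    main_program = paddle.static.Program()
    with paddle.static.program_guard(main_program):
        x = paddle.static.data("x", shape=[4, 8], dtype="float32")
        y = paddle.nn.functional.softmax(paddle.nn.functional.gelu(x))

    # Lower composite ops in the forward program to primitive ops, but keep
    # softmax intact so its composite grad rule (custom VJP) can still be
    # applied when the backward program is appended later.
    primapi.to_prim(main_program.blocks, {"softmax"})
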
