diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py
index 8e3737b72d6fd..ac2dde6ba6998 100644
--- a/python/paddle/fluid/framework.py
+++ b/python/paddle/fluid/framework.py
@@ -3751,7 +3751,6 @@ def __init__(self, program, idx):
         self.vars = collections.OrderedDict()  # var_name --> var
         self.ops = list()  # operator list
         self.program = program
-        self.removed_vars = collections.OrderedDict()

     def __str__(self):
         return self._to_readable_code()
diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_cinn_prim.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_cinn_prim.py
index 1a0fe1a6938cb..d25fe730308d4 100644
--- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_cinn_prim.py
+++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_cinn_prim.py
@@ -77,7 +77,12 @@ def train(self, use_prim):
     def check_prim(self, net, use_prim):
         if not use_prim:
             return
-        fwd_ops = [op.type for op in net.forward.main_program.block(0).ops]
+        fwd_ops = [
+            op.type
+            for op in net.forward.get_concrete_program(self.x)[1]
+            .train_program.block(0)
+            .ops
+        ]
         # Ensure that softmax is splitted into small ops
         self.assertTrue('softmax' not in fwd_ops)

@@ -128,7 +133,12 @@ def train(self, use_prim):
     def check_prim(self, net, use_prim):
         if not use_prim:
             return
-        fwd_ops = [op.type for op in net.forward.main_program.block(0).ops]
+        fwd_ops = [
+            op.type
+            for op in net.forward.get_concrete_program(self.x)[1]
+            .train_program.block(0)
+            .ops
+        ]
         all_ops = [
             op.type
             for op in net.forward.program_cache.last()[-1][-1]
diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_cinn_prim_gelu.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_cinn_prim_gelu.py
index 2fce19b3943f1..ad68e1195a968 100644
--- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_cinn_prim_gelu.py
+++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_cinn_prim_gelu.py
@@ -77,6 +77,7 @@ def _train(self, use_prim, approximate, data):
         net = apply_to_static(net, use_prim)

         res = []
+        self.x = data
         for _ in range(10):
             out = net(data)
             loss = paddle.mean(out)
@@ -92,7 +93,12 @@ def _train(self, use_prim, approximate, data):
     def check_prim(self, net, use_prim):
         if not use_prim:
             return
-        fwd_ops = [op.type for op in net.forward.main_program.block(0).ops]
+        fwd_ops = [
+            op.type
+            for op in net.forward.get_concrete_program(self.x)[1]
+            .train_program.block(0)
+            .ops
+        ]
         # Ensure that gelu is splitted into small ops
         self.assertTrue('gelu' not in fwd_ops)

diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_cinn_prim_layer_norm.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_cinn_prim_layer_norm.py
index 6460515c0a8dd..28aac57b2f526 100644
--- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_cinn_prim_layer_norm.py
+++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_cinn_prim_layer_norm.py
@@ -89,7 +89,14 @@ def train(self, use_prim):
     def check_prim(self, net, use_prim):
         if not use_prim:
             return
-        fwd_ops = [op.type for op in net.forward.main_program.block(0).ops]
+        fwd_ops = [
+            op.type
+            for op in net.forward.get_concrete_program(self.x, self.w, self.b)[
+                1
+            ]
+            .train_program.block(0)
+            .ops
+        ]
         # Ensure that layer_norm is splitted into small ops
         self.assertTrue('layer_norm' not in fwd_ops)

@@ -150,7 +157,14 @@ def train(self, use_prim):
     def check_prim(self, net, use_prim):
         if not use_prim:
             return
-        fwd_ops = [op.type for op in net.forward.main_program.block(0).ops]
+        fwd_ops = [
+            op.type
+            for op in net.forward.get_concrete_program(self.x, self.w, self.b)[
+                1
+            ]
+            .train_program.block(0)
+            .ops
+        ]
         # Ensure that layer_norm is splitted into small ops
         self.assertTrue('layer_norm' not in fwd_ops)

diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_cinn_prim_mean.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_cinn_prim_mean.py
index e77388742af36..ff18964f7a360 100644
--- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_cinn_prim_mean.py
+++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_cinn_prim_mean.py
@@ -83,6 +83,7 @@ def _train(self, use_prim, data, axis, keep_dim):
         net = apply_to_static(net, use_prim)

         res = []
+        self.x = data
         for _ in range(10):
             out = net(data)
             loss = paddle.mean(out, axis, keep_dim)
@@ -99,7 +100,12 @@ def _train(self, use_prim, data, axis, keep_dim):
     def check_prim(self, net, use_prim):
         if not use_prim:
             return
-        fwd_ops = [op.type for op in net.forward.main_program.block(0).ops]
+        fwd_ops = [
+            op.type
+            for op in net.forward.get_concrete_program(self.x)[1]
+            .train_program.block(0)
+            .ops
+        ]
         # Ensure that reduce_mean is splitted into small ops
         self.assertTrue('reduce_mean' not in fwd_ops)

@@ -150,6 +156,7 @@ def _train(self, use_prim, data, axis, keep_dim):
         net = apply_to_static(net, use_prim)

         res = []
+        self.x = data
         for _ in range(10):
             out = net(data)
             loss = paddle.mean(out, axis, keep_dim)
@@ -166,7 +173,12 @@ def _train(self, use_prim, data, axis, keep_dim):
     def check_prim(self, net, use_prim):
         if not use_prim:
             return
-        fwd_ops = [op.type for op in net.forward.main_program.block(0).ops]
+        fwd_ops = [
+            op.type
+            for op in net.forward.get_concrete_program(self.x)[1]
+            .train_program.block(0)
+            .ops
+        ]
         # Ensure that reduce_mean is splitted into small ops
         self.assertTrue('reduce_mean' not in fwd_ops)

diff --git a/python/paddle/jit/dy2static/partial_program.py b/python/paddle/jit/dy2static/partial_program.py
index 595f97980f9db..4afa8c1f90f7c 100644
--- a/python/paddle/jit/dy2static/partial_program.py
+++ b/python/paddle/jit/dy2static/partial_program.py
@@ -315,9 +315,7 @@ def _create_pure_fp16_program(self, is_infer_mode=False):
     def _create_forward_backward_train_program(self):
         whole_program = self._train_program
         # _, forward_end_op_index = self._infer_info('fp32', self._create_program)
-        forward_end_op_index = self._forward_end_index_map[
-            _hash_with_id(whole_program, self)
-        ]
+        forward_end_op_index = self.get_forward_end_op_idx(whole_program)
         assert forward_end_op_index >= 0

         return self._get_forward_backward_program_form(
@@ -438,11 +436,14 @@ def _infer_pure_fp16_program_id(self):
     def _param_grad_names(self):
         return _param_grad_names(self._train_program.desc, self._params)

+    def get_forward_end_op_idx(self, program):
+        return self._forward_end_index_map[_hash_with_id(program, self)]
+
     @LazyInitialized
     def _out_grad_names(self):
         return _out_grad_names(
             self._train_program.desc,
-            self._create_program(is_infer_mode=True).desc.block(0).op_size(),
+            self.get_forward_end_op_idx(self._train_program),
             len(self._outputs.var_ids),
         )

@@ -642,6 +643,7 @@ def _append_backward_desc(self, main_program):
             if isinstance(out, framework.Variable):
                 targets.append(program.global_block().var(out.name))

+        start_idx = len(program.block(0).ops) + len(self._outputs.tolist())
         if targets:
             # TODO(CZ): later when use cinn, set_prim_all_enabled and check_and_set_prim_all_enabled will be set at else branch.
             core.check_and_set_prim_all_enabled()
@@ -652,12 +654,11 @@ def _append_backward_desc(self, main_program):
                 program, start_idx = self._hooker.after_append_backward(
                     self, program, start_idx
                 )
-            self._forward_end_index_map[
-                _hash_with_id(program, self)
-            ] = start_idx - len(self._outputs.tolist())
-            # TODO: prim make this complicate
             self.prepare_gradient_aggregation(start_idx, main_program, program)

+        self._forward_end_index_map[
+            _hash_with_id(program, self)
+        ] = start_idx - len(self._outputs.tolist())
         return program

     def _prune_unused_params(self, program):
@@ -1155,5 +1156,8 @@ def add_build_strategy_for(
     if hasattr(compiled_program._program, 'lr_sheduler'):
         builded_program.lr_sheduler = compiled_program._program.lr_sheduler
     else:
+        # can't just create a new program, we need copy the vardesc.
         builded_program = paddle.static.Program()
+        for var in program.block(0).vars.values():
+            builded_program.block(0)._clone_variable(var, False)
     return builded_program
diff --git a/python/paddle/jit/dy2static/program_translator.py b/python/paddle/jit/dy2static/program_translator.py
index 69a6e004606af..e201915310e41 100644
--- a/python/paddle/jit/dy2static/program_translator.py
+++ b/python/paddle/jit/dy2static/program_translator.py
@@ -1226,7 +1226,6 @@ def after_infer(self, partial_program_layer, infer_program):
             partial_program.set_hooker(PrimHooker())
         return concrete_program, partial_program

-
     def __getitem__(self, item):
         if not isinstance(item, CacheKey):
             raise ValueError(
diff --git a/python/paddle/jit/dy2static/utils.py b/python/paddle/jit/dy2static/utils.py
index 34d628c1d35c4..b37ee05f9f0aa 100644
--- a/python/paddle/jit/dy2static/utils.py
+++ b/python/paddle/jit/dy2static/utils.py
@@ -1568,7 +1568,7 @@ def _out_grad_names(program_desc, fwd_end_op_index, out_size):
         min(fwd_end_op_index + out_size, program_desc.block(0).op_size()),
     ):
         op = program_desc.block(0).op(i)
-        if op.type() == 'fill_any_like':
+        if op.type() in ['fill_any_like', "fill_constant"]:
             var_name = op.output('Out')[0]
             names.append(var_name)
     return names