diff --git a/python/paddle/distributed/passes/auto_parallel_amp.py b/python/paddle/distributed/passes/auto_parallel_amp.py
index ac533db098c61..a7782b6d8d130 100644
--- a/python/paddle/distributed/passes/auto_parallel_amp.py
+++ b/python/paddle/distributed/passes/auto_parallel_amp.py
@@ -232,6 +232,11 @@ def build_state(self):
         return is_train
 
     def _mark_black_white_ops(self, op, ops, block):
+        # honor auto_cast info: ops with AMP disabled are kept in FP32
+        if not op.amp_options.enable:
+            self._op_fp16_dict[op.desc.original_id()] = False
+            return
+
         # ernie inference trick
         if op.type == "assign" and "array_" in op.input_arg_names[0]:
             self._op_fp16_dict[op.desc.original_id()] = False
diff --git a/python/paddle/distributed/passes/auto_parallel_fp16.py b/python/paddle/distributed/passes/auto_parallel_fp16.py
index 2985b4da290f4..92259dee3ae05 100644
--- a/python/paddle/distributed/passes/auto_parallel_fp16.py
+++ b/python/paddle/distributed/passes/auto_parallel_fp16.py
@@ -209,6 +209,9 @@ def _build_state(self):
         for block in self.program.blocks:
             self.resolute_tensor_dtype(block)
 
+        for block in self.program.blocks:
+            self.resolute_cast_op(block)
+
         # insert cast ops
         for block in self.program.blocks:
             self.cast_block(block)
@@ -296,6 +299,19 @@ def set_var_to_fp16(self, var_name, block):
             if var.dtype == core.VarDesc.VarType.FP32:
                 var.desc.set_dtype(__target_dtype__)
 
+    def resolute_cast_op(self, block):
+        """
+        Sync each "cast" op's in/out dtype attrs with its FP32/FP16/BF16 variables.
+        """
+        for op in block.ops:
+            if op.type == "cast":
+                in_name = op.input('X')[0]
+                out_name = op.output('Out')[0]
+                in_var = block._find_var_recursive(in_name)
+                out_var = block._find_var_recursive(out_name)
+                op._set_attr("in_dtype", in_var.dtype)
+                op._set_attr("out_dtype", out_var.dtype)
+
     def resolute_tensor_dtype(self, block):
         for op in block.ops:
             # 'amp_options' flag has highest priority
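
Note (illustrative, not part of the patch): the new resolute_cast_op keeps a "cast" op's recorded "in_dtype"/"out_dtype" attributes in sync with the dtypes that resolute_tensor_dtype has already rewritten on its variables. The standalone sketch below shows the same idea on a toy static Program; the helper name sync_cast_dtypes and the toy network are assumptions for illustration only, while block._find_var_recursive, op._set_attr and the cast attributes are the same APIs the patch itself uses.

    import paddle

    paddle.enable_static()


    def sync_cast_dtypes(program):
        """Refresh every cast op's in_dtype/out_dtype from its actual variables."""
        for block in program.blocks:
            for op in block.ops:
                if op.type != "cast":
                    continue
                in_var = block._find_var_recursive(op.input("X")[0])
                out_var = block._find_var_recursive(op.output("Out")[0])
                # After an AMP pass rewrites variable dtypes, the attributes
                # recorded on the op can go stale; copy them back from the vars.
                op._set_attr("in_dtype", in_var.dtype)
                op._set_attr("out_dtype", out_var.dtype)


    main = paddle.static.Program()
    startup = paddle.static.Program()
    with paddle.static.program_guard(main, startup):
        x = paddle.static.data(name="x", shape=[4, 8], dtype="float32")
        y = paddle.cast(x, "float16")

    sync_cast_dtypes(main)  # no-op here; the dtypes already agree

Inside the pass, this step runs after resolute_tensor_dtype and before cast_block (see the _build_state hunk above), so any later logic that reads the cast attributes sees values consistent with the rewritten variables.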