From ef4cc881a514752ce4b97a7545c204ca31062fc4 Mon Sep 17 00:00:00 2001
From: liangjianzhong
Date: Wed, 8 Mar 2023 16:36:46 +0800
Subject: [PATCH] add bf16 o2

---
 .../auto_parallel/operators/dist_embedding.py |  11 +-
 .../distributed/passes/auto_parallel_fp16.py  | 363 +++++++++++-------
 2 files changed, 223 insertions(+), 151 deletions(-)

diff --git a/python/paddle/distributed/auto_parallel/operators/dist_embedding.py b/python/paddle/distributed/auto_parallel/operators/dist_embedding.py
index cb4060b2593ee..08b00a5c7f63b 100644
--- a/python/paddle/distributed/auto_parallel/operators/dist_embedding.py
+++ b/python/paddle/distributed/auto_parallel/operators/dist_embedding.py
@@ -455,7 +455,7 @@ def forward(ctx, *args, **kwargs):
         check_variable_and_dtype(
             Out_var,
             'tensor',
-            ['float16', 'float32', 'float64', 'int32', 'int64'],
+            ['float16', 'float32', 'float64', 'int32', 'int64', 'uint16'],
             'c_allreduce_sum',
         )

@@ -645,7 +645,7 @@ def backward(ctx, *args, **kwargs):
         check_variable_and_dtype(
             Out_grad,
             'tensor',
-            ['float16', 'float32', 'float64', 'int32', 'int64'],
+            ['float16', 'float32', 'float64', 'int32', 'int64', 'uint16'],
             '_c_identity',
         )

@@ -687,12 +687,15 @@ def backward(ctx, *args, **kwargs):
             },
         )
         check_variable_and_dtype(
-            intermediate_var_0, 'x', ['float16', 'float32', 'float64'], 'linear'
+            intermediate_var_0,
+            'x',
+            ['float16', 'float32', 'float64', 'uint16'],
+            'linear',
         )
         check_dtype(
             intermediate_var_0.dtype,
             'dtype',
-            ['float16', 'float32', 'float64'],
+            ['float16', 'float32', 'float64', 'uint16'],
             'linear',
         )

diff --git a/python/paddle/distributed/passes/auto_parallel_fp16.py b/python/paddle/distributed/passes/auto_parallel_fp16.py
index 941f805dbd65a..ccf14043389ee 100644
--- a/python/paddle/distributed/passes/auto_parallel_fp16.py
+++ b/python/paddle/distributed/passes/auto_parallel_fp16.py
@@ -29,13 +29,9 @@
 from paddle.distributed.fleet.meta_optimizers.common import OP_ROLE_KEY, OpRole
 from paddle.framework import core
 from paddle.static import default_main_program, default_startup_program
-from paddle.static.amp.fp16_utils import (
-    AutoMixedPrecisionLists,
-    _dtype_to_str,
-    _keep_layer_norm_scale_bias_to_fp32,
-    _need_keep_fp32,
-    _valid_types,
-)
+
+# NOTE bf16 and fp16 may have diff logic for _keep_layer_norm_scale_bias_to_fp32
+from paddle.static.amp.fp16_utils import _keep_layer_norm_scale_bias_to_fp32
 from paddle.utils import unique_name

 from ..auto_parallel.process_mesh import ProcessMesh
@@ -50,6 +46,8 @@
     'while',
     'cast',
 ]
+__target_dtype__ = None
+__amp_utils__ = None


 def set_op_dtype_to_fp16(op):
@@ -57,17 +55,24 @@ def set_op_dtype_to_fp16(op):
         op.has_attr('in_dtype')
         and op.attr('in_dtype') == core.VarDesc.VarType.FP32
     ):
-        op._set_attr('in_dtype', core.VarDesc.VarType.FP16)
+        op._set_attr('in_dtype', __target_dtype__)
     if (
         op.has_attr('out_dtype')
         and op.attr('out_dtype') == core.VarDesc.VarType.FP32
     ):
-        op._set_attr('out_dtype', core.VarDesc.VarType.FP16)
+        op._set_attr('out_dtype', __target_dtype__)
     if op.has_attr('dtype') and op.attr('dtype') == core.VarDesc.VarType.FP32:
-        op._set_attr('dtype', core.VarDesc.VarType.FP16)
+        op._set_attr('dtype', __target_dtype__)
+
+    if __target_dtype__ == core.VarDesc.VarType.BF16:
+        if op.has_attr('use_mkldnn'):
+            op._set_attr('use_mkldnn', True)
+        if op.has_attr('mkldnn_data_type'):
+            op._set_attr('mkldnn_data_type', 'bfloat16')


 # adapot for backward op
+# TODO check if bf16 and fp16 still share the same logic
 def _keep_fp32_input(op, in_name):
     op_type = op.type
     if op_type == 'batch_norm':
@@ -96,6 +101,7 @@ def _keep_fp32_input(op, in_name):
     return False


+# TODO check if bf16 and fp16 still share the same logic
 def _keep_fp32_output(op, out_name):
     op_type = op.type
     if op_type in ['batch_norm', 'fused_bn_add_activation']:
@@ -208,7 +214,7 @@ def _mark_op(self, op):
                 self._op_fp16_dict[op.desc.original_id()] = True
                 return

-            if _need_keep_fp32(
+            if __amp_utils__._need_keep_fp32(
                 op, self.amp_list.unsupported_list, self.use_fp16_guard
             ):
                 self._op_fp16_dict[op.desc.original_id()] = False
@@ -240,11 +246,15 @@ def set_var_to_fp16(self, var_name, block):

         # NOTE(JZ-LIANG) "array_" is a hack to adopt for ernie3.0 inference, since there is
         # a trick which make the LOD_TENSOR_ARRAY to the float32 in while block to reset the LOD_TENSOR_ARRAY
-        if var is None or var.type not in _valid_types or "array_" in var_name:
+        if (
+            var is None
+            or var.type not in __amp_utils__._valid_types
+            or "array_" in var_name
+        ):
             return

         if var.dtype == core.VarDesc.VarType.FP32:
-            var.desc.set_dtype(core.VarDesc.VarType.FP16)
+            var.desc.set_dtype(__target_dtype__)

     def resolute_tensor_dtype(self, block):
@@ -274,9 +284,12 @@ def resolute_tensor_dtype(self, block):
                 elif self._is_fp16_op(op.desc.original_id()) is False:
                     for out_var_name in op.output_arg_names:
                         out_var = block.vars.get(out_var_name)
-                        if out_var is None or out_var.type not in _valid_types:
+                        if (
+                            out_var is None
+                            or out_var.type not in __amp_utils__._valid_types
+                        ):
                             continue
-                        if out_var.dtype == core.VarDesc.VarType.FP16:
+                        if out_var.dtype == __target_dtype__:
                             out_var.desc.set_dtype(core.VarDesc.VarType.FP32)
             elif is_backward_op(op):
                 if self._is_fp16_op(op.desc.original_id()) is True:
@@ -290,9 +303,12 @@ def resolute_tensor_dtype(self, block):
                 elif self._is_fp16_op(op.desc.original_id()) is False:
                     for out_var_name in op.output_arg_names:
                         out_var = block.vars.get(out_var_name)
-                        if out_var is None or out_var.type not in _valid_types:
+                        if (
+                            out_var is None
+                            or out_var.type not in __amp_utils__._valid_types
+                        ):
                             continue
-                        if out_var.dtype == core.VarDesc.VarType.FP16:
+                        if out_var.dtype == __target_dtype__:
                             out_var.desc.set_dtype(core.VarDesc.VarType.FP32)

     def cast_block(self, block):
@@ -311,7 +327,7 @@ def cast_block(self, block):
                         op,
                         idx,
                         block,
-                        core.VarDesc.VarType.FP16,
+                        __target_dtype__,
                         core.VarDesc.VarType.FP32,
                         self.dist_context,
                     )
@@ -321,7 +337,7 @@ def cast_block(self, block):
                         idx,
                         block,
                         core.VarDesc.VarType.FP32,
-                        core.VarDesc.VarType.FP16,
+                        __target_dtype__,
                         self.dist_context,
                     )
             elif is_backward_op(op):
@@ -331,7 +347,7 @@ def cast_block(self, block):
                             op,
                             idx,
                             block,
-                            core.VarDesc.VarType.FP16,
+                            __target_dtype__,
                             core.VarDesc.VarType.FP32,
                             self.dist_context,
                         )
@@ -341,7 +357,7 @@ def cast_block(self, block):
                             idx,
                             block,
                             core.VarDesc.VarType.FP32,
-                            core.VarDesc.VarType.FP16,
+                            __target_dtype__,
                             self.dist_context,
                         )
             elif op.type == "sum":
@@ -359,6 +375,7 @@ def cast_block(self, block):
                 out_var.desc.set_dtype(in_var.dtype)

             idx += num_cast_ops + 1
+        print("self.forward_input_cast_ops: ", self.forward_input_cast_ops)
         block._sync_with_cpp()

     def _insert_forward_cast_ops(
@@ -379,19 +396,30 @@ def _insert_forward_cast_ops(
                 in_var = block._find_var_recursive(in_var_name)
                 if (
                     in_var is None
-                    or in_var.type not in _valid_types
+                    or in_var.type not in __amp_utils__._valid_types
                     or in_var.dtype == dst_dtype
                 ):
                     continue

                 if in_var.dtype == src_dtype:
                     cast_name = (
-                        in_var.name + '.cast_' + _dtype_to_str(dst_dtype)
+                        in_var.name
+                        + '.cast_'
+                        + __amp_utils__._dtype_to_str(dst_dtype)
                     )
                     cast_var = block.vars.get(cast_name)
self.forward_input_cast_ops[op.desc.original_id()] += [ (cast_name, in_var.name, dst_dtype, src_dtype, in_name) ] + print( + "insert forward cast: ", + cast_name, + in_var.name, + dst_dtype, + src_dtype, + in_name, + ) + print(str(op)) in_var_dist_attr = consume_op_attr.get_input_dist_attr( in_var.name @@ -467,7 +495,8 @@ def _insert_backward_cast_ops( assert out_var.dtype == dst_dtype, "{}, {}".format( str(out_var), dst_dtype ) - + if int(forward_op_id) == 610: + print("forward_op_id = 610", str(op)) for ( cast_name, src_name, @@ -475,9 +504,14 @@ def _insert_backward_cast_ops( src_dtype, slot_name, ) in self.forward_input_cast_ops[forward_op_id]: + if int(forward_op_id) == 610: + print("forward_op_id = 610 insert: ") + print(cast_name, src_name, dst_dtype, src_dtype, slot_name) # some forward output is not need by backward computation, e.g. logit in softmax_with_cross_entropy - if slot_name not in op.input_names: + if slot_name not in op.input_names and op.type in [ + 'softmax_with_cross_entropy_grad' + ]: continue # rename input @@ -488,12 +522,14 @@ def _insert_backward_cast_ops( assert src_var_dist_attr is not None op._rename_input(src_name, cast_name) grad_op_attr.set_input_dist_attr(cast_name, src_var_dist_attr) - + if int(forward_op_id) == 610: + print("here0") # create cast grad grad_slot_name = slot_name + "@GRAD" if grad_slot_name not in op.output_names: continue - + if int(forward_op_id) == 610: + print("here1") # some forward input maybe stop_gradient=True, e.g. input_mask if len(op.output(grad_slot_name)) == 0: continue @@ -517,6 +553,8 @@ def _insert_backward_cast_ops( persistable=grad.persistable, stop_gradient=grad.stop_gradient, ) + if int(forward_op_id) == 610: + print("here2: ", str(cast_grad)) dist_context.set_tensor_dist_attr_for_program( cast_grad, grad_dist_attr ) @@ -535,6 +573,8 @@ def _insert_backward_cast_ops( OP_ROLE_KEY: OpRole.Backward, }, ) + if int(forward_op_id) == 610: + print("here2: ", str(cast_op)) grad.desc.set_dtype(src_dtype) naive_set_dist_op_attr_for_program_by_mesh_and_mapping( @@ -604,7 +644,7 @@ def _check_and_update_gradient(grads, loss_scaling, name, dist_context): def _split_grads(params_grads): grads = [g for _, g in params_grads] fp32_grads = [g for g in grads if g.dtype == core.VarDesc.VarType.FP32] - fp16_grads = [g for g in grads if g.dtype == core.VarDesc.VarType.FP16] + fp16_grads = [g for g in grads if g.dtype == __target_dtype__] assert len(fp32_grads) + len(fp16_grads) == len( grads ), "Data types of all grads must be either fp16 or fp32." @@ -707,17 +747,17 @@ def is_initialization_op(op): for op in startup_program.global_block().ops: if is_initialization_op(op): output_name = op.output_arg_names[0] - if ( - param_to_dtype.get(output_name, None) - == core.VarDesc.VarType.FP16 - ): + if param_to_dtype.get(output_name, None) == __target_dtype__: assert op.has_attr( 'dtype' ), "initialization op is supported to has dtype attribute but got {}.".format( str(op) ) + out_var = startup_program.global_block().var(output_name) + if out_var.dtype == core.VarDesc.VarType.FP32: + out_var.desc.set_dtype(__target_dtype__) if op.attr('dtype') == core.VarDesc.VarType.FP32: - op._set_attr('dtype', core.VarDesc.VarType.FP16) + op._set_attr('dtype', __target_dtype__) @register_pass("auto_parallel_fp16") @@ -730,12 +770,37 @@ def __init__(self): # in distributed scenario, all ranks should have the same modification. 
def _apply_single_impl(self, main_program, startup_program, context): self.dist_context = self.get_attr("dist_context") + self.target_dtype = self.get_attr("dtype") params_grads = self.get_attr("params_grads") + self.use_optimizer_fp16 = self.get_attr("use_optimizer_fp16", None) if self.use_optimizer_fp16 is None: self.use_optimizer_fp16 = self.get_attr("level", None) == "o3" - amp_list = AutoMixedPrecisionLists( + # swith enviroment for fp16 / bf16. + if self.target_dtype == "float16": + import paddle.static.amp.fp16_utils as amp_utils + + AMPList = amp_utils.AutoMixedPrecisionLists + __target_dtype = core.VarDesc.VarType.FP16 + + elif self.target_dtype == "bfloat16": + import paddle.static.amp.bf16.amp_utils as amp_utils + + AMPList = amp_utils.AutoMixedPrecisionListsBF16 + __target_dtype = core.VarDesc.VarType.BF16 + + else: + raise NotImplementedError( + "target dtype [{}] is for amp o2 not supported yet.".format( + self.target_dtype + ) + ) + global __target_dtype__ + __target_dtype__ = __target_dtype + global __amp_utils__ + __amp_utils__ = amp_utils + amp_list = AMPList( set(self.get_attr("custom_white_list")), set(self.get_attr("custom_black_list")), None, @@ -750,7 +815,9 @@ def _apply_single_impl(self, main_program, startup_program, context): main_program, amp_list, self.dist_context, - self.get_attr("use_fp16_guard"), + self.get_attr( + "use_fp16_guard" + ), # TODO unify to use_amp_guard to be compatible with amp o1 input_data_var_names, ) is_train = fp16_state._build_state() @@ -758,128 +825,130 @@ def _apply_single_impl(self, main_program, startup_program, context): cast_startup_program() if is_train: - with paddle.static.program_guard(main_program, startup_program): - # TODO (JZ-LIANG)support cast forward program only when inference - self._init_amp_var() - self._scale_loss() - - grads, fp32_grads, fp16_grads = _split_grads(params_grads) - - if ( - self.get_attr("use_dynamic_loss_scaling") - or self.get_attr("init_loss_scaling") != 1.0 - ): - found_infs = [] - if fp32_grads: + if self.target_dtype == "fp16": + with paddle.static.program_guard(main_program, startup_program): + # TODO (JZ-LIANG)support cast forward program only when inference + self._init_amp_var() + self._scale_loss() + + grads, fp32_grads, fp16_grads = _split_grads(params_grads) + + if ( + self.get_attr("use_dynamic_loss_scaling") + or self.get_attr("init_loss_scaling") != 1.0 + ): + found_infs = [] + if fp32_grads: + with main_program._optimized_guard([]): + _, found_inf_fp32 = _check_and_update_gradient( + fp32_grads, + self._loss_scaling, + "@fp32", + self.dist_context, + ) + found_infs.append(found_inf_fp32) + if fp16_grads: + with main_program._optimized_guard([]): + _, found_inf_fp16 = _check_and_update_gradient( + fp16_grads, + self._loss_scaling, + "@fp16", + self.dist_context, + ) + found_infs.append(found_inf_fp16) with main_program._optimized_guard([]): - _, found_inf_fp32 = _check_and_update_gradient( - fp32_grads, - self._loss_scaling, - "@fp32", + block = main_program.global_block() + + # all_infs = paddle.fluid.layers.concat(found_infs) + all_infs = block.create_var( + name=paddle.utils.unique_name.generate_with_ignorable_key( + ".".join(['concat', 'tmp']) + ), + dtype=found_infs[0].dtype, + shape=None, + lod_level=found_infs[0].lod_level, + type=found_infs[0].type, + persistable=False, + stop_gradient=False, + ) + concat_op = block.append_op( + type='concat', + inputs={'X': found_infs}, + outputs={'Out': [all_infs]}, + attrs={'axis': 0}, + ) + set_var_dist_attr( self.dist_context, + all_infs, 
+ [-1], + world_process_group.ranks, ) - found_infs.append(found_inf_fp32) - if fp16_grads: - with main_program._optimized_guard([]): - _, found_inf_fp16 = _check_and_update_gradient( - fp16_grads, - self._loss_scaling, - "@fp16", + _set_op_dist_attr_with_ranks( + concat_op, + world_process_group.ranks, + block, self.dist_context, ) - found_infs.append(found_inf_fp16) - with main_program._optimized_guard([]): - block = main_program.global_block() - - # all_infs = paddle.fluid.layers.concat(found_infs) - all_infs = block.create_var( - name=paddle.utils.unique_name.generate_with_ignorable_key( - ".".join(['concat', 'tmp']) - ), - dtype=found_infs[0].dtype, - shape=None, - lod_level=found_infs[0].lod_level, - type=found_infs[0].type, - persistable=False, - stop_gradient=False, - ) - concat_op = block.append_op( - type='concat', - inputs={'X': found_infs}, - outputs={'Out': [all_infs]}, - attrs={'axis': 0}, - ) - set_var_dist_attr( - self.dist_context, - all_infs, - [-1], - world_process_group.ranks, - ) - _set_op_dist_attr_with_ranks( - concat_op, - world_process_group.ranks, - block, - self.dist_context, - ) - # found_inf = paddle.fluid.layers.reduce_any(all_infs) - found_inf = block.create_var( - name=paddle.utils.unique_name.generate_with_ignorable_key( - ".".join(['reduce_any', 'tmp']) - ), - dtype=all_infs.dtype, - shape=None, - lod_level=all_infs.lod_level, - type=all_infs.type, - persistable=False, - stop_gradient=False, - ) - reduce_any_op = block.append_op( - type='reduce_any', - inputs={'X': all_infs}, - outputs={'Out': found_inf}, - attrs={ - 'dim': [0], - 'keep_dim': False, - 'reduce_all': True, - }, - ) - set_var_dist_attr( - self.dist_context, - found_inf, - [-1], - world_process_group.ranks, - ) - _set_op_dist_attr_with_ranks( - reduce_any_op, - world_process_group.ranks, - block, - self.dist_context, - ) + # found_inf = paddle.fluid.layers.reduce_any(all_infs) + found_inf = block.create_var( + name=paddle.utils.unique_name.generate_with_ignorable_key( + ".".join(['reduce_any', 'tmp']) + ), + dtype=all_infs.dtype, + shape=None, + lod_level=all_infs.lod_level, + type=all_infs.type, + persistable=False, + stop_gradient=False, + ) + reduce_any_op = block.append_op( + type='reduce_any', + inputs={'X': all_infs}, + outputs={'Out': found_inf}, + attrs={ + 'dim': [0], + 'keep_dim': False, + 'reduce_all': True, + }, + ) + set_var_dist_attr( + self.dist_context, + found_inf, + [-1], + world_process_group.ranks, + ) + _set_op_dist_attr_with_ranks( + reduce_any_op, + world_process_group.ranks, + block, + self.dist_context, + ) - if self.get_attr("use_dynamic_loss_scaling"): - with main_program._optimized_guard([]): - if fp32_grads: - self._update_loss_scaling(fp32_grads, found_inf) - if fp16_grads: - self._update_loss_scaling(fp16_grads, found_inf) + if self.get_attr("use_dynamic_loss_scaling"): + with main_program._optimized_guard([]): + if fp32_grads: + self._update_loss_scaling(fp32_grads, found_inf) + if fp16_grads: + self._update_loss_scaling(fp16_grads, found_inf) # modify optimizer base_opt = self.get_attr("base_opt") base_opt._multi_precision = True if self.use_optimizer_fp16: base_opt._multi_precision = False - if isinstance( - base_opt, - (paddle.static.Adam, paddle.optimizer.AdamW), - ): - with main_program._optimized_guard([]): - # found_inf = paddle.tensor.creation._memcpy( - # found_inf, paddle.CPUPlace()) - insert_idx = _get_memcopy_idx(block, found_inf) - found_inf = _insert_memcopy( - block, insert_idx, found_inf, self.dist_context - ) - 
-                base_opt._set_auxiliary_var('found_inf', found_inf.name)
-            elif hasattr(base_opt, "_set_auxiliary_var"):
-                base_opt._set_auxiliary_var('found_inf', found_inf.name)
+
+            if self.target_dtype == "fp16":
+                if isinstance(
+                    base_opt, (paddle.static.Adam, paddle.optimizer.AdamW)
+                ):
+                    with main_program._optimized_guard([]):
+                        # found_inf = paddle.tensor.creation._memcpy(
+                        #     found_inf, paddle.CPUPlace())
+                        insert_idx = _get_memcopy_idx(block, found_inf)
+                        found_inf = _insert_memcopy(
+                            block, insert_idx, found_inf, self.dist_context
+                        )
+                    base_opt._set_auxiliary_var('found_inf', found_inf.name)
+                elif hasattr(base_opt, "_set_auxiliary_var"):
+                    base_opt._set_auxiliary_var('found_inf', found_inf.name)
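
Note on the dtype switch: the branch added to _apply_single_impl is the core of this patch. Based on the pass attribute "dtype" it picks which AMP utility module, AMP list class, and proto VarType the rest of the pass works with. The helper below is a minimal standalone sketch of that selection logic, written only for illustration; select_amp_target itself is hypothetical, but the modules and classes it references (paddle.static.amp.fp16_utils.AutoMixedPrecisionLists, paddle.static.amp.bf16.amp_utils.AutoMixedPrecisionListsBF16, core.VarDesc.VarType.FP16/BF16) are exactly the ones the patch imports.

    # Hypothetical helper mirroring the fp16/bf16 switch added in _apply_single_impl.
    from paddle.framework import core


    def select_amp_target(target_dtype):
        # Returns (amp_utils module, AMP list class, proto VarType) for an O2 run.
        if target_dtype == "float16":
            import paddle.static.amp.fp16_utils as amp_utils

            amp_list_cls = amp_utils.AutoMixedPrecisionLists
            var_type = core.VarDesc.VarType.FP16
        elif target_dtype == "bfloat16":
            import paddle.static.amp.bf16.amp_utils as amp_utils

            amp_list_cls = amp_utils.AutoMixedPrecisionListsBF16
            var_type = core.VarDesc.VarType.BF16
        else:
            raise NotImplementedError(
                "target dtype [{}] is not supported yet for amp o2.".format(
                    target_dtype
                )
            )
        return amp_utils, amp_list_cls, var_type

The pass stores the chosen VarType in the module-level __target_dtype__ and the chosen module in __amp_utils__, which is what keeps helpers such as set_op_dtype_to_fp16 and the cast-insertion code dtype-agnostic.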
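Note on the 'uint16' entries in dist_embedding.py: check_variable_and_dtype compares the string form of the variable's dtype against the allowed list, and bf16 variables (core.VarDesc.VarType.BF16) are surfaced through that mapping as 'uint16' because bf16 shares uint16 storage, which is why the c_allreduce_sum/_c_identity/linear checks now whitelist 'uint16'. A minimal check of that assumption, using the 2023-era import path (it may have moved in later releases):

    # Inspect how VarDesc dtypes map to the strings used by check_variable_and_dtype.
    from paddle.fluid.data_feeder import convert_dtype
    from paddle.framework import core

    print(convert_dtype(core.VarDesc.VarType.FP16))  # 'float16'
    print(convert_dtype(core.VarDesc.VarType.BF16))  # expected: 'uint16'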