From ee532627e54ffc4239b2d630fa1af4b6cbfa97c3 Mon Sep 17 00:00:00 2001 From: Zhang Ting Date: Fri, 12 May 2023 16:11:35 +0800 Subject: [PATCH] fix dtype missmatch error (#53712) --- python/paddle/static/amp/fp16_utils.py | 27 +++++++++++++++++++++++--- 1 file changed, 24 insertions(+), 3 deletions(-) diff --git a/python/paddle/static/amp/fp16_utils.py b/python/paddle/static/amp/fp16_utils.py index c2f8a12d33b46..333d652f9f59d 100644 --- a/python/paddle/static/amp/fp16_utils.py +++ b/python/paddle/static/amp/fp16_utils.py @@ -629,12 +629,15 @@ def cast_model_to_fp16( def need_process(op): need_process = True - if op.type in ["cast", "create_py_reader", "read"]: + if op.type in ["create_py_reader", "read"]: need_process = False else: for attr_name in ['out_dtype', 'dtype']: - if op.has_attr(attr_name) and is_float_dtype( - op.attr(attr_name) + # output type of some operators such as fill_constant will be determined by the attribute value. + # + if not op.has_attr('in_dtype') and ( + op.has_attr(attr_name) + and is_float_dtype(op.attr(attr_name)) ): need_process = False @@ -667,6 +670,24 @@ def need_process(op): "---- Add into keep_fp16_ops because the op in white_list ----" ) else: + # if cast in orgin program, we only modifiy attr and output's dtype to avoid dtype mismatch errors. + if op.type == 'cast': + in_var = block._find_var_recursive(op.input('X')[0]) + out_var = block._find_var_recursive(op.output('Out')[0]) + op._set_attr('in_dtype', in_var.dtype) + out_var.desc.set_dtype(paddle.dtype(op.attr('out_dtype'))) + _logger.debug( + "---- op type: {}, in var [name: {} dtype: {}], out var [name: {} dtype: {}], attr [in_dtype {} out_dtype {}] ----".format( + op.type, + op.input('X')[0], + in_var.dtype, + op.output('Out')[0], + out_var.dtype, + op.attr('in_dtype'), + op.attr('out_dtype'), + ) + ) + continue # divide others ops into fp16/fp32 sets according to promoting principle. dst_dtype = dest_type if not use_promote: