Skip to content

Commit

Permalink
fix dtype missmatch error (#53712)
Browse files Browse the repository at this point in the history
  • Loading branch information
zhangting2020 authored May 12, 2023
1 parent 1019b26 commit 772b490
Showing 1 changed file with 24 additions and 3 deletions.
27 changes: 24 additions & 3 deletions python/paddle/static/amp/fp16_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -629,12 +629,15 @@ def cast_model_to_fp16(

def need_process(op):
need_process = True
if op.type in ["cast", "create_py_reader", "read"]:
if op.type in ["create_py_reader", "read"]:
need_process = False
else:
for attr_name in ['out_dtype', 'dtype']:
if op.has_attr(attr_name) and is_float_dtype(
op.attr(attr_name)
# output type of some operators such as fill_constant will be determined by the attribute value.
#
if not op.has_attr('in_dtype') and (
op.has_attr(attr_name)
and is_float_dtype(op.attr(attr_name))
):
need_process = False

Expand Down Expand Up @@ -667,6 +670,24 @@ def need_process(op):
"---- Add into keep_fp16_ops because the op in white_list ----"
)
else:
# if cast in orgin program, we only modifiy attr and output's dtype to avoid dtype mismatch errors.
if op.type == 'cast':
in_var = block._find_var_recursive(op.input('X')[0])
out_var = block._find_var_recursive(op.output('Out')[0])
op._set_attr('in_dtype', in_var.dtype)
out_var.desc.set_dtype(paddle.dtype(op.attr('out_dtype')))
_logger.debug(
"---- op type: {}, in var [name: {} dtype: {}], out var [name: {} dtype: {}], attr [in_dtype {} out_dtype {}] ----".format(
op.type,
op.input('X')[0],
in_var.dtype,
op.output('Out')[0],
out_var.dtype,
op.attr('in_dtype'),
op.attr('out_dtype'),
)
)
continue
# divide others ops into fp16/fp32 sets according to promoting principle.
dst_dtype = dest_type
if not use_promote:
Expand Down

0 comments on commit 772b490

Please sign in to comment.