From 245bc5ef766e23e1a9d571f13740b2aff1a1d97d Mon Sep 17 00:00:00 2001 From: zxcd <228587199@qq.com> Date: Tue, 9 Jan 2024 04:00:34 +0000 Subject: [PATCH 01/13] add type promotion static T+T logit. --- python/paddle/base/executor.py | 85 ++++++++++++++++++++++ python/paddle/base/layers/math_op_patch.py | 10 +-- 2 files changed, 87 insertions(+), 8 deletions(-) diff --git a/python/paddle/base/executor.py b/python/paddle/base/executor.py index 4a7b24d6618c8..50b83bc6fe6e5 100755 --- a/python/paddle/base/executor.py +++ b/python/paddle/base/executor.py @@ -53,6 +53,13 @@ InferNativeConfig = core.NativeConfig InferAnalysisConfig = core.AnalysisConfig +SUPPORT_PROMOTION_OPS_AND_INPUTNAME = { + "elementwise_add": ['X', 'Y'], + "elementwise_sub": ['X', 'Y'], + "elementwise_mul": ['X', 'Y'], + "where": ['X', 'Y'], +} + def global_scope(): """ @@ -1605,6 +1612,80 @@ def flush(self): del trainer_instance self.trainer_caches.clear() + def _add_cast_for_type_promotion(self, op, block, idx, var_name, out_dtype): + op_device = op.attr('op_device') + cast_name = var_name.name + '.cast_' + out_var = block.create_var( + name=cast_name, + dtype=out_dtype, + persistable=False, + stop_gradient=var_name.stop_gradient, + ) + op_role = ( + int(core.op_proto_and_checker_maker.OpRole.Forward) + if not op.has_attr('op_role') + else op.attr('op_role') + ) + block._insert_op_without_sync( + idx, + type="cast", + inputs={"X": var_name}, + outputs={"Out": out_var}, + attrs={ + "in_dtype": var_name.dtype, + "out_dtype": out_var.dtype, + "op_device": op_device, + "op_role": op_role, + }, + ) + op.desc._rename_input(var_name.name, out_var.name) + + def _process_type_promotion(self, program): + global_block = program.global_block() + all_params = global_block.all_parameters() + for block in program.blocks: + ops = block.ops + idx = 0 + while idx < len(ops): + op = ops[idx] + var_name = None + all_dtypes = [] + all_input_name_need_cast = [] + + need_transed_var_names = ( + SUPPORT_PROMOTION_OPS_AND_INPUTNAME.get(op.type, None) + ) + # type promotion only support some dyadic api + if need_transed_var_names is None: + idx += 1 + continue + + # get all dtype and input_name + for input_idx in range(len(op.input_arg_names)): + if op.input_names[input_idx] in need_transed_var_names: + input_arg_name = op.input_arg_names[input_idx] + all_dtypes.append( + op.block._var_recursive(input_arg_name).dtype + ) + all_input_name_need_cast.append(input_arg_name) + + # only support promote between float + if core.need_type_promotion(*all_dtypes): + common_dtype = core.get_promote_dtype(op.type, *all_dtypes) + for input_name_need_cast in all_input_name_need_cast: + var_name = op.block._var_recursive(input_name_need_cast) + if var_name.dtype != common_dtype: + # add cast op for different dtype + self._add_cast_for_type_promotion( + op, + block, + idx, + var_name, + common_dtype, + ) + idx += 1 + idx += 1 + def run( self, program=None, @@ -1759,6 +1840,10 @@ def run( 'true', ] self._log_force_set_program_cache(use_program_cache) + + # do type promotion if necessary + self._process_type_promotion(self, program) + if in_pir_mode(): res = self._run_pir_impl( program=program, diff --git a/python/paddle/base/layers/math_op_patch.py b/python/paddle/base/layers/math_op_patch.py index dbf23b5fff2ff..0a2a63ae8ead2 100644 --- a/python/paddle/base/layers/math_op_patch.py +++ b/python/paddle/base/layers/math_op_patch.py @@ -534,19 +534,13 @@ def __impl__(self, other_var): if lhs_dtype != rhs_dtype: if method_name in SUPPORT_PROMOTION_OPS: if 
core.need_type_promotion(lhs_dtype, rhs_dtype): - common_dtype = core.get_promote_dtype( - op_type, lhs_dtype, rhs_dtype - ) + # only report warning here, real promotion deal in Executor before run warnings.warn( - f"The input dtypes of OP {op_type} are {lhs_dtype} and {rhs_dtype}, the output will be auto-promoted to {common_dtype}" + f"The input dtypes of OP {op_type} are {lhs_dtype} and {rhs_dtype}, the output will be auto-promoted" ) warnings.filterwarnings( "ignore", message="The input dtypes of OP" ) - if rhs_dtype != common_dtype: - other_var = astype(other_var, common_dtype) - if lhs_dtype != common_dtype: - self = astype(self, common_dtype) else: # NOTE(zoooo0820): Currently, we still keep the old illogical \ # logic for compatibility reasons From 3ca540d028abf507733b7c5cf8640766054c8ffa Mon Sep 17 00:00:00 2001 From: zxcd <228587199@qq.com> Date: Tue, 9 Jan 2024 07:17:07 +0000 Subject: [PATCH 02/13] fix bug --- python/paddle/base/executor.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/python/paddle/base/executor.py b/python/paddle/base/executor.py index 50b83bc6fe6e5..e9de69214356e 100755 --- a/python/paddle/base/executor.py +++ b/python/paddle/base/executor.py @@ -1641,6 +1641,9 @@ def _add_cast_for_type_promotion(self, op, block, idx, var_name, out_dtype): op.desc._rename_input(var_name.name, out_var.name) def _process_type_promotion(self, program): + # not support pir for now + if not isinstance(program, Program): + return global_block = program.global_block() all_params = global_block.all_parameters() for block in program.blocks: @@ -1841,9 +1844,6 @@ def run( ] self._log_force_set_program_cache(use_program_cache) - # do type promotion if necessary - self._process_type_promotion(self, program) - if in_pir_mode(): res = self._run_pir_impl( program=program, @@ -1855,6 +1855,8 @@ def run( return_numpy=return_numpy, ) else: + # do type promotion if necessary + self._process_type_promotion(program) res = self._run_impl( program=program, feed=feed, From cb8cce628b0cd395a38d82991702dadad73fc0b1 Mon Sep 17 00:00:00 2001 From: zxcd <228587199@qq.com> Date: Tue, 9 Jan 2024 07:30:52 +0000 Subject: [PATCH 03/13] fix code comment --- python/paddle/base/executor.py | 1 - python/paddle/base/layers/math_op_patch.py | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/python/paddle/base/executor.py b/python/paddle/base/executor.py index e9de69214356e..90ab12aecbc49 100755 --- a/python/paddle/base/executor.py +++ b/python/paddle/base/executor.py @@ -1843,7 +1843,6 @@ def run( 'true', ] self._log_force_set_program_cache(use_program_cache) - if in_pir_mode(): res = self._run_pir_impl( program=program, diff --git a/python/paddle/base/layers/math_op_patch.py b/python/paddle/base/layers/math_op_patch.py index 0a2a63ae8ead2..dc1e9dd731ce3 100644 --- a/python/paddle/base/layers/math_op_patch.py +++ b/python/paddle/base/layers/math_op_patch.py @@ -534,7 +534,7 @@ def __impl__(self, other_var): if lhs_dtype != rhs_dtype: if method_name in SUPPORT_PROMOTION_OPS: if core.need_type_promotion(lhs_dtype, rhs_dtype): - # only report warning here, real promotion deal in Executor before run + # only report warning here, real promotion deal in Executor warnings.warn( f"The input dtypes of OP {op_type} are {lhs_dtype} and {rhs_dtype}, the output will be auto-promoted" ) From b707ee0c13ca74b68db09d99c4bcf162006684df Mon Sep 17 00:00:00 2001 From: zxcd <228587199@qq.com> Date: Tue, 9 Jan 2024 08:54:04 +0000 Subject: [PATCH 04/13] add where op test for type promotion. 
--- test/legacy_test/test_where_op.py | 112 ++++++++++++++++++++++++++++++ 1 file changed, 112 insertions(+) diff --git a/test/legacy_test/test_where_op.py b/test/legacy_test/test_where_op.py index 09fe941feed85..5529e641d8259 100644 --- a/test/legacy_test/test_where_op.py +++ b/test/legacy_test/test_where_op.py @@ -318,6 +318,61 @@ def __test_where_with_broadcast_static(self, cond_shape, x_shape, y_shape): expect = np.where(cond_data, x_data, y_data) np.testing.assert_array_equal(out[0], expect) + def __test_where_with_type_promotion( + self, x_dtype, y_dtype, expeced_dtype=None + ): + paddle.enable_static() + main_program = paddle.static.Program() + shape = [3, 10] + with paddle.static.program_guard(main_program): + cond = paddle.static.data(name='cond', shape=[3, 10], dtype='bool') + x = paddle.static.data(name='x', shape=shape, dtype=x_dtype) + y = paddle.static.data(name='y', shape=shape, dtype=y_dtype) + cond_data_tmp = np.random.random(size=shape).astype('float32') + cond_data = cond_data_tmp < 0.3 + + if x_dtype != 'bfloat16': + x_data = np.random.random(size=shape).astype(x_dtype) + else: + x_data = convert_float_to_uint16( + np.random.random(size=shape).astype('float32') + ) + if y_dtype != 'bfloat16': + y_data = np.random.random(size=shape).astype(y_dtype) + else: + y_data = convert_float_to_uint16( + np.random.random(size=shape).astype('float32') + ) + result = paddle.where(condition=cond, x=x, y=y) + for use_cuda in [False, True]: + if use_cuda and (not base.core.is_compiled_with_cuda()): + return + place = base.CUDAPlace(0) if use_cuda else base.CPUPlace() + exe = base.Executor(place) + out = exe.run( + paddle.static.default_main_program(), + feed={'cond': cond_data, 'x': x_data, 'y': y_data}, + fetch_list=[result], + ) + if x_dtype == 'bfloat16' or y_dtype == 'bfloat16': + x_data_convert = ( + convert_uint16_to_float(x_data) + if x_dtype == 'bfloat16' + else x_data + ) + y_data_convert = ( + convert_uint16_to_float(y_data) + if y_dtype == 'bfloat16' + else y_data + ) + expect = np.where(cond_data, x_data_convert, y_data_convert) + np.testing.assert_array_equal(out[0], expect) + self.assertEqual(out[0].dtype.__str__(), expeced_dtype) + else: + expect = np.where(cond_data, x_data, y_data) + np.testing.assert_array_equal(out[0], expect) + self.assertEqual(out[0].dtype, expect.dtype) + @test_with_pir_api def test_static_api_broadcast_1(self): cond_shape = [2, 4] @@ -374,6 +429,63 @@ def test_static_api_broadcast_8(self): b_shape = [2, 2, 1] self.__test_where_with_broadcast_static(cond_shape, a_shape, b_shape) + def test_static_api_type_promotion_fp16_fp32(self): + x_dtype = 'float16' + y_dtype = 'float32' + self.__test_where_with_type_promotion(x_dtype, y_dtype) + self.__test_where_with_type_promotion(y_dtype, x_dtype) + + def test_static_api_type_promotion_fp16_fp64(self): + x_dtype = 'float16' + y_dtype = 'float64' + self.__test_where_with_type_promotion(x_dtype, y_dtype) + self.__test_where_with_type_promotion(y_dtype, x_dtype) + + def test_static_api_type_promotion_fp32_fp64(self): + x_dtype = 'float32' + y_dtype = 'float64' + self.__test_where_with_type_promotion(x_dtype, y_dtype) + self.__test_where_with_type_promotion(y_dtype, x_dtype) + + @unittest.skipIf( + not ( + paddle.is_compiled_with_cuda() + and paddle.base.core.supports_bfloat16() + ), + "bf16 is not supported in current device", + ) + def test_static_api_type_promotion_bf16_fp16(self): + x_dtype = 'bfloat16' + y_dtype = 'float16' + self.__test_where_with_type_promotion(x_dtype, y_dtype, 'float32') + 
self.__test_where_with_type_promotion(y_dtype, x_dtype, 'float32') + + @unittest.skipIf( + not ( + paddle.is_compiled_with_cuda() + and paddle.base.core.supports_bfloat16() + ), + "bf16 is not supported in current device", + ) + def test_static_api_type_promotion_bf16_fp32(self): + x_dtype = 'bfloat16' + y_dtype = 'float32' + self.__test_where_with_type_promotion(x_dtype, y_dtype, 'float32') + self.__test_where_with_type_promotion(y_dtype, x_dtype, 'float32') + + @unittest.skipIf( + not ( + paddle.is_compiled_with_cuda() + and paddle.base.core.supports_bfloat16() + ), + "bf16 is not supported in current device", + ) + def test_static_api_type_promotion_bf16_fp64(self): + x_dtype = 'bfloat16' + y_dtype = 'float64' + self.__test_where_with_type_promotion(x_dtype, y_dtype, 'float64') + self.__test_where_with_type_promotion(y_dtype, x_dtype, 'float64') + class TestWhereDygraphAPI(unittest.TestCase): def test_api(self): From d545c0daadf9e05734ac7f7e82d80e2af8f08c34 Mon Sep 17 00:00:00 2001 From: zxcd <228587199@qq.com> Date: Tue, 9 Jan 2024 10:35:10 +0000 Subject: [PATCH 05/13] fix --- .../legacy_test/test_tensor_type_promotion.py | 85 +++++++++++-------- 1 file changed, 50 insertions(+), 35 deletions(-) diff --git a/test/legacy_test/test_tensor_type_promotion.py b/test/legacy_test/test_tensor_type_promotion.py index c47bfe8e5d1d5..19d26048f6997 100644 --- a/test/legacy_test/test_tensor_type_promotion.py +++ b/test/legacy_test/test_tensor_type_promotion.py @@ -119,6 +119,31 @@ def test_dtype_is_expected(self): ) +class TestAPIAddInStatic(TestOperatorOverloadAddInStatic): + def run_api(self): + prog = paddle.static.Program() + with paddle.static.program_guard(prog): + self.generate_test_value() + + out = paddle.add(self.l_value, self.r_value) + out_reverse = paddle.add(self.r_value, self.l_value) + + res = self.exe.run(prog, fetch_list=[out, out_reverse]) + return res + + +create_test_case(TestAPIAddInStatic, 'float16', 'float32', 'float32') +create_test_case(TestAPIAddInStatic, 'float16', 'float64', 'float64') + +create_test_case(TestAPIAddInStatic, 'float32', 'float64', 'float64') + + +if paddle.is_compiled_with_cuda() and paddle.base.core.supports_bfloat16(): + create_test_case(TestAPIAddInStatic, 'bfloat16', 'float16', 'float32') + create_test_case(TestAPIAddInStatic, 'bfloat16', 'float32', 'float32') + create_test_case(TestAPIAddInStatic, 'bfloat16', 'float64', 'float64') + + class TestOperatorOverloadSubInStatic(TestOperatorOverloadAddInStatic): def run_api(self): prog = paddle.static.Program() @@ -156,74 +181,64 @@ def run_api(self): ) -class TestOperatorOverloadMulInStatic(TestOperatorOverloadAddInStatic): +class TestAPISubInStatic(TestOperatorOverloadAddInStatic): def run_api(self): prog = paddle.static.Program() with paddle.static.program_guard(prog): self.generate_test_value() - out = self.l_value * self.r_value - out_reverse = self.r_value * self.l_value + out = paddle.subtract(self.l_value, self.r_value) + out_reverse = paddle.subtract(self.r_value, self.l_value) res = self.exe.run(prog, fetch_list=[out, out_reverse]) return res -create_test_case( - TestOperatorOverloadMulInStatic, 'float16', 'float32', 'float32' -) -create_test_case( - TestOperatorOverloadMulInStatic, 'float16', 'float64', 'float64' -) +create_test_case(TestAPISubInStatic, 'float16', 'float32', 'float32') +create_test_case(TestAPISubInStatic, 'float16', 'float64', 'float64') -create_test_case( - TestOperatorOverloadMulInStatic, 'float32', 'float64', 'float64' -) +create_test_case(TestAPIAddInStatic, 'float32', 
'float64', 'float64') -if paddle.is_compiled_with_cuda() and paddle.base.core.supports_bfloat16(): - create_test_case( - TestOperatorOverloadMulInStatic, 'bfloat16', 'float16', 'float32' - ) - create_test_case( - TestOperatorOverloadMulInStatic, 'bfloat16', 'float32', 'float32' - ) - create_test_case( - TestOperatorOverloadMulInStatic, 'bfloat16', 'float64', 'float64' - ) +if paddle.is_compiled_with_cuda() and paddle.base.core.supports_bfloat16(): + create_test_case(TestAPISubInStatic, 'bfloat16', 'float16', 'float32') + create_test_case(TestAPISubInStatic, 'bfloat16', 'float32', 'float32') + create_test_case(TestAPISubInStatic, 'bfloat16', 'float64', 'float64') -class TestOperatorOverloadGTInStatic(TestOperatorOverloadAddInStatic): - def set_dtype(self): - self.ldtype = 'float32' - self.rdtype = 'float64' - self.expected_out_dtype = 'bool' +class TestOperatorOverloadMulInStatic(TestOperatorOverloadAddInStatic): def run_api(self): prog = paddle.static.Program() with paddle.static.program_guard(prog): self.generate_test_value() - out = self.l_value > self.r_value - out_reverse = self.r_value > self.l_value + out = self.l_value * self.r_value + out_reverse = self.r_value * self.l_value res = self.exe.run(prog, fetch_list=[out, out_reverse]) return res -create_test_case(TestOperatorOverloadGTInStatic, 'float16', 'float32', 'bool') -create_test_case(TestOperatorOverloadGTInStatic, 'float16', 'float64', 'bool') +create_test_case( + TestOperatorOverloadMulInStatic, 'float16', 'float32', 'float32' +) +create_test_case( + TestOperatorOverloadMulInStatic, 'float16', 'float64', 'float64' +) -create_test_case(TestOperatorOverloadGTInStatic, 'float32', 'float64', 'bool') +create_test_case( + TestOperatorOverloadMulInStatic, 'float32', 'float64', 'float64' +) if paddle.is_compiled_with_cuda() and paddle.base.core.supports_bfloat16(): create_test_case( - TestOperatorOverloadGTInStatic, 'bfloat16', 'float16', 'bool' + TestOperatorOverloadMulInStatic, 'bfloat16', 'float16', 'float32' ) create_test_case( - TestOperatorOverloadGTInStatic, 'bfloat16', 'float32', 'bool' + TestOperatorOverloadMulInStatic, 'bfloat16', 'float32', 'float32' ) create_test_case( - TestOperatorOverloadGTInStatic, 'bfloat16', 'float64', 'bool' + TestOperatorOverloadMulInStatic, 'bfloat16', 'float64', 'float64' ) From 2b9b4dec5fb8cf28bbb9c84ec935ec950e08f530 Mon Sep 17 00:00:00 2001 From: zxcd <228587199@qq.com> Date: Wed, 10 Jan 2024 07:01:40 +0000 Subject: [PATCH 06/13] fix bug --- python/paddle/base/executor.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/python/paddle/base/executor.py b/python/paddle/base/executor.py index 90ab12aecbc49..19b6c2e84ad27 100755 --- a/python/paddle/base/executor.py +++ b/python/paddle/base/executor.py @@ -1641,6 +1641,8 @@ def _add_cast_for_type_promotion(self, op, block, idx, var_name, out_dtype): op.desc._rename_input(var_name.name, out_var.name) def _process_type_promotion(self, program): + if program is None: + program = default_main_program() # not support pir for now if not isinstance(program, Program): return From 991143f126ee3c8c0a94b9031776ecd5cae054e2 Mon Sep 17 00:00:00 2001 From: zxcd <228587199@qq.com> Date: Thu, 11 Jan 2024 06:52:37 +0000 Subject: [PATCH 07/13] fix --- python/paddle/base/executor.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/python/paddle/base/executor.py b/python/paddle/base/executor.py index 19b6c2e84ad27..e1b7737bb199e 100755 --- a/python/paddle/base/executor.py +++ b/python/paddle/base/executor.py @@ -44,6 +44,7 @@ set_flags, ) from 
.incubate.checkpoint import auto_checkpoint as acp +from .static.amp.fp16_utils import _dtype_to_str from .trainer_factory import FetchHandlerMonitor, TrainerFactory from .wrapped_decorator import signature_safe_contextmanager @@ -55,8 +56,11 @@ SUPPORT_PROMOTION_OPS_AND_INPUTNAME = { "elementwise_add": ['X', 'Y'], + "elementwise_mul_add": ['X', 'Y'], "elementwise_sub": ['X', 'Y'], + "elementwise_sub_grad": ['X', 'Y'], "elementwise_mul": ['X', 'Y'], + "elementwise_mul_grad": ['X', 'Y'], "where": ['X', 'Y'], } @@ -1614,7 +1618,7 @@ def flush(self): def _add_cast_for_type_promotion(self, op, block, idx, var_name, out_dtype): op_device = op.attr('op_device') - cast_name = var_name.name + '.cast_' + cast_name = var_name.name + '.cast_' + _dtype_to_str(out_dtype) out_var = block.create_var( name=cast_name, dtype=out_dtype, From f54397319f665436dfe5ba771eb10c3216b1d32d Mon Sep 17 00:00:00 2001 From: zxcd <228587199@qq.com> Date: Thu, 11 Jan 2024 08:03:50 +0000 Subject: [PATCH 08/13] fix path --- python/paddle/base/executor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/paddle/base/executor.py b/python/paddle/base/executor.py index e1b7737bb199e..1175fe9d2f33e 100755 --- a/python/paddle/base/executor.py +++ b/python/paddle/base/executor.py @@ -29,6 +29,7 @@ translate_to_pir, translate_to_pir_with_param_map, ) +from ..static.amp.fp16_utils import _dtype_to_str from . import compiler, core, framework, unique_name from .data_feeder import convert_dtype from .framework import ( @@ -44,7 +45,6 @@ set_flags, ) from .incubate.checkpoint import auto_checkpoint as acp -from .static.amp.fp16_utils import _dtype_to_str from .trainer_factory import FetchHandlerMonitor, TrainerFactory from .wrapped_decorator import signature_safe_contextmanager From f2555a63962a922034ae178259c03c8c830db10c Mon Sep 17 00:00:00 2001 From: zxcd <228587199@qq.com> Date: Thu, 11 Jan 2024 08:55:16 +0000 Subject: [PATCH 09/13] fix --- python/paddle/base/executor.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/python/paddle/base/executor.py b/python/paddle/base/executor.py index 1175fe9d2f33e..37283472b8d50 100755 --- a/python/paddle/base/executor.py +++ b/python/paddle/base/executor.py @@ -29,7 +29,6 @@ translate_to_pir, translate_to_pir_with_param_map, ) -from ..static.amp.fp16_utils import _dtype_to_str from . 
import compiler, core, framework, unique_name from .data_feeder import convert_dtype from .framework import ( @@ -1616,6 +1615,16 @@ def flush(self): del trainer_instance self.trainer_caches.clear() + def _dtype_to_str(in_dtype): + if in_dtype == core.VarDesc.VarType.FP16: + return "fp16" + elif in_dtype == core.VarDesc.VarType.BF16: + return "bf16" + elif in_dtype == core.VarDesc.VarType.FP32: + return "fp32" + elif in_dtype == core.VarDesc.VarType.FP64: + return "fp64" + def _add_cast_for_type_promotion(self, op, block, idx, var_name, out_dtype): op_device = op.attr('op_device') cast_name = var_name.name + '.cast_' + _dtype_to_str(out_dtype) From eab2a1a7f5bfed15ed65beab0ddd8218902729af Mon Sep 17 00:00:00 2001 From: zxcd <228587199@qq.com> Date: Thu, 11 Jan 2024 09:40:07 +0000 Subject: [PATCH 10/13] fix --- python/paddle/base/executor.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/python/paddle/base/executor.py b/python/paddle/base/executor.py index 37283472b8d50..4297914b44fb5 100755 --- a/python/paddle/base/executor.py +++ b/python/paddle/base/executor.py @@ -1615,7 +1615,7 @@ def flush(self): del trainer_instance self.trainer_caches.clear() - def _dtype_to_str(in_dtype): + def _dtype_to_str(self, in_dtype): if in_dtype == core.VarDesc.VarType.FP16: return "fp16" elif in_dtype == core.VarDesc.VarType.BF16: @@ -1624,10 +1624,12 @@ def _dtype_to_str(in_dtype): return "fp32" elif in_dtype == core.VarDesc.VarType.FP64: return "fp64" + else: + return None def _add_cast_for_type_promotion(self, op, block, idx, var_name, out_dtype): op_device = op.attr('op_device') - cast_name = var_name.name + '.cast_' + _dtype_to_str(out_dtype) + cast_name = var_name.name + '.cast_' + self._dtype_to_str(out_dtype) out_var = block.create_var( name=cast_name, dtype=out_dtype, From 8652c79855386a4179b5fe0525ec1474fcf4b6cb Mon Sep 17 00:00:00 2001 From: zxcd <228587199@qq.com> Date: Mon, 15 Jan 2024 03:48:31 +0000 Subject: [PATCH 11/13] fix spelling problem. --- python/paddle/base/executor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/paddle/base/executor.py b/python/paddle/base/executor.py index 4297914b44fb5..c4e5593494821 100755 --- a/python/paddle/base/executor.py +++ b/python/paddle/base/executor.py @@ -55,7 +55,7 @@ SUPPORT_PROMOTION_OPS_AND_INPUTNAME = { "elementwise_add": ['X', 'Y'], - "elementwise_mul_add": ['X', 'Y'], + "elementwise_add_grad": ['X', 'Y'], "elementwise_sub": ['X', 'Y'], "elementwise_sub_grad": ['X', 'Y'], "elementwise_mul": ['X', 'Y'], From 39a22dad20b5f9d218a72149986fba50e49df0fb Mon Sep 17 00:00:00 2001 From: zxcd <228587199@qq.com> Date: Wed, 17 Jan 2024 12:16:12 +0000 Subject: [PATCH 12/13] support paddle inference. 
--- python/paddle/base/__init__.py | 1 + python/paddle/base/executor.py | 104 +------------------------------ python/paddle/base/framework.py | 106 ++++++++++++++++++++++++++++++++ python/paddle/static/io.py | 17 ++++- 4 files changed, 125 insertions(+), 103 deletions(-) diff --git a/python/paddle/base/__init__.py b/python/paddle/base/__init__.py index 9a6d8914feddb..83fe57b21ce4c 100644 --- a/python/paddle/base/__init__.py +++ b/python/paddle/base/__init__.py @@ -107,6 +107,7 @@ is_compiled_with_rocm, is_compiled_with_xpu, name_scope, + process_type_promotion, program_guard, require_version, set_flags, diff --git a/python/paddle/base/executor.py b/python/paddle/base/executor.py index c4e5593494821..f73b2c999b227 100755 --- a/python/paddle/base/executor.py +++ b/python/paddle/base/executor.py @@ -41,6 +41,7 @@ get_flags, in_pir_mode, paddle_type_to_proto_type, + process_type_promotion, set_flags, ) from .incubate.checkpoint import auto_checkpoint as acp @@ -53,16 +54,6 @@ InferNativeConfig = core.NativeConfig InferAnalysisConfig = core.AnalysisConfig -SUPPORT_PROMOTION_OPS_AND_INPUTNAME = { - "elementwise_add": ['X', 'Y'], - "elementwise_add_grad": ['X', 'Y'], - "elementwise_sub": ['X', 'Y'], - "elementwise_sub_grad": ['X', 'Y'], - "elementwise_mul": ['X', 'Y'], - "elementwise_mul_grad": ['X', 'Y'], - "where": ['X', 'Y'], -} - def global_scope(): """ @@ -1615,97 +1606,6 @@ def flush(self): del trainer_instance self.trainer_caches.clear() - def _dtype_to_str(self, in_dtype): - if in_dtype == core.VarDesc.VarType.FP16: - return "fp16" - elif in_dtype == core.VarDesc.VarType.BF16: - return "bf16" - elif in_dtype == core.VarDesc.VarType.FP32: - return "fp32" - elif in_dtype == core.VarDesc.VarType.FP64: - return "fp64" - else: - return None - - def _add_cast_for_type_promotion(self, op, block, idx, var_name, out_dtype): - op_device = op.attr('op_device') - cast_name = var_name.name + '.cast_' + self._dtype_to_str(out_dtype) - out_var = block.create_var( - name=cast_name, - dtype=out_dtype, - persistable=False, - stop_gradient=var_name.stop_gradient, - ) - op_role = ( - int(core.op_proto_and_checker_maker.OpRole.Forward) - if not op.has_attr('op_role') - else op.attr('op_role') - ) - block._insert_op_without_sync( - idx, - type="cast", - inputs={"X": var_name}, - outputs={"Out": out_var}, - attrs={ - "in_dtype": var_name.dtype, - "out_dtype": out_var.dtype, - "op_device": op_device, - "op_role": op_role, - }, - ) - op.desc._rename_input(var_name.name, out_var.name) - - def _process_type_promotion(self, program): - if program is None: - program = default_main_program() - # not support pir for now - if not isinstance(program, Program): - return - global_block = program.global_block() - all_params = global_block.all_parameters() - for block in program.blocks: - ops = block.ops - idx = 0 - while idx < len(ops): - op = ops[idx] - var_name = None - all_dtypes = [] - all_input_name_need_cast = [] - - need_transed_var_names = ( - SUPPORT_PROMOTION_OPS_AND_INPUTNAME.get(op.type, None) - ) - # type promotion only support some dyadic api - if need_transed_var_names is None: - idx += 1 - continue - - # get all dtype and input_name - for input_idx in range(len(op.input_arg_names)): - if op.input_names[input_idx] in need_transed_var_names: - input_arg_name = op.input_arg_names[input_idx] - all_dtypes.append( - op.block._var_recursive(input_arg_name).dtype - ) - all_input_name_need_cast.append(input_arg_name) - - # only support promote between float - if core.need_type_promotion(*all_dtypes): - common_dtype = 
core.get_promote_dtype(op.type, *all_dtypes) - for input_name_need_cast in all_input_name_need_cast: - var_name = op.block._var_recursive(input_name_need_cast) - if var_name.dtype != common_dtype: - # add cast op for different dtype - self._add_cast_for_type_promotion( - op, - block, - idx, - var_name, - common_dtype, - ) - idx += 1 - idx += 1 - def run( self, program=None, @@ -1872,7 +1772,7 @@ def run( ) else: # do type promotion if necessary - self._process_type_promotion(program) + program = process_type_promotion(program) res = self._run_impl( program=program, feed=feed, diff --git a/python/paddle/base/framework.py b/python/paddle/base/framework.py index 1225eba4e4242..8d9e5635e3600 100644 --- a/python/paddle/base/framework.py +++ b/python/paddle/base/framework.py @@ -56,6 +56,16 @@ CONTROL_DEP_VAR_PREFIX = core.kControlDepVarName() _global_flags_ = core.globals() +SUPPORT_PROMOTION_OPS_AND_INPUTNAME = { + "elementwise_add": ['X', 'Y'], + "elementwise_add_grad": ['X', 'Y'], + "elementwise_sub": ['X', 'Y'], + "elementwise_sub_grad": ['X', 'Y'], + "elementwise_mul": ['X', 'Y'], + "elementwise_mul_grad": ['X', 'Y'], + "where": ['X', 'Y'], +} + def _global_flags(): return _global_flags_ @@ -8144,3 +8154,99 @@ def _get_paddle_place_list(places): ret.append(p) return ret + + +def dtype_to_str(in_dtype): + if in_dtype == core.VarDesc.VarType.FP16: + return "fp16" + elif in_dtype == core.VarDesc.VarType.BF16: + return "bf16" + elif in_dtype == core.VarDesc.VarType.FP32: + return "fp32" + elif in_dtype == core.VarDesc.VarType.FP64: + return "fp64" + else: + return None + + +def add_cast_for_type_promotion(op, block, idx, var_name, out_dtype): + op_device = op.attr('op_device') + cast_name = var_name.name + '.cast_' + dtype_to_str(out_dtype) + out_var = block.create_var( + name=cast_name, + dtype=out_dtype, + persistable=False, + stop_gradient=var_name.stop_gradient, + ) + op_role = ( + int(core.op_proto_and_checker_maker.OpRole.Forward) + if not op.has_attr('op_role') + else op.attr('op_role') + ) + block._insert_op_without_sync( + idx, + type="cast", + inputs={"X": var_name}, + outputs={"Out": out_var}, + attrs={ + "in_dtype": var_name.dtype, + "out_dtype": out_var.dtype, + "op_device": op_device, + "op_role": op_role, + }, + ) + op.desc._rename_input(var_name.name, out_var.name) + + +def process_type_promotion(program): + org_program = program + if program is None: + program = default_main_program() + # not support pir for now + if not isinstance(program, Program): + return org_program + global_block = program.global_block() + all_params = global_block.all_parameters() + for block in program.blocks: + ops = block.ops + idx = 0 + while idx < len(ops): + op = ops[idx] + var_name = None + all_dtypes = [] + all_input_name_need_cast = [] + + need_transed_var_names = SUPPORT_PROMOTION_OPS_AND_INPUTNAME.get( + op.type, None + ) + # type promotion only support some dyadic api + if need_transed_var_names is None: + idx += 1 + continue + + # get all dtype and input_name + for input_idx in range(len(op.input_arg_names)): + if op.input_names[input_idx] in need_transed_var_names: + input_arg_name = op.input_arg_names[input_idx] + all_dtypes.append( + op.block._var_recursive(input_arg_name).dtype + ) + all_input_name_need_cast.append(input_arg_name) + + # only support promote between float + if core.need_type_promotion(*all_dtypes): + common_dtype = core.get_promote_dtype(op.type, *all_dtypes) + for input_name_need_cast in all_input_name_need_cast: + var_name = 
op.block._var_recursive(input_name_need_cast) + if var_name.dtype != common_dtype: + # add cast op for different dtype + add_cast_for_type_promotion( + op, + block, + idx, + var_name, + common_dtype, + ) + idx += 1 + idx += 1 + return program diff --git a/python/paddle/static/io.py b/python/paddle/static/io.py index 7426dc0b05b0f..52b5ab97c8f38 100644 --- a/python/paddle/static/io.py +++ b/python/paddle/static/io.py @@ -33,7 +33,12 @@ unique_name, ) from paddle.base.executor import Executor, global_scope -from paddle.base.framework import Parameter, dygraph_not_support, static_only +from paddle.base.framework import ( + Parameter, + dygraph_not_support, + process_type_promotion, + static_only, +) from paddle.base.log_helper import get_logger from paddle.framework.io_utils import ( _clone_var_in_block_, @@ -587,6 +592,10 @@ def save_inference_model( _check_vars('fetch_vars', fetch_vars) program = _get_valid_program(kwargs.get('program', None)) + + # do type promotion + program = process_type_promotion(program) + clip_extra = kwargs.get('clip_extra', True) program = normalize_program( program, @@ -903,6 +912,9 @@ def load_inference_model(path_prefix, executor, **kwargs): # deserialize bytes to program program = deserialize_program(program_bytes) + # do type promotion + program = process_type_promotion(program) + vars = list(filter(is_persistable, program.list_vars())) if len(vars) > 0: load_vars( @@ -958,6 +970,9 @@ def load_inference_model(path_prefix, executor, **kwargs): # deserialize bytes to program program = deserialize_program(program_bytes) + # do type promotion + program = process_type_promotion(program) + vars = list(filter(is_persistable, program.list_vars())) if len(vars) > 0: load_dirname = os.path.dirname(params_path) From 055b9ac89f84bee72affa878882e5b35b993d9bf Mon Sep 17 00:00:00 2001 From: zxcd <228587199@qq.com> Date: Thu, 18 Jan 2024 07:46:58 +0000 Subject: [PATCH 13/13] add where grad --- python/paddle/base/framework.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/paddle/base/framework.py b/python/paddle/base/framework.py index 8d9e5635e3600..f170bdd1279dd 100644 --- a/python/paddle/base/framework.py +++ b/python/paddle/base/framework.py @@ -64,6 +64,7 @@ "elementwise_mul": ['X', 'Y'], "elementwise_mul_grad": ['X', 'Y'], "where": ['X', 'Y'], + "where_grad": ['X', 'Y'], }
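
Usage note (illustrative only, not part of any patch above): assuming a Paddle build that carries this series, mixing float dtypes in the supported binary ops (elementwise_add/sub/mul and where) in static-graph mode is expected to work as sketched below: the operator overload in math_op_patch.py only emits a warning, and Executor.run() (as well as save/load_inference_model) calls process_type_promotion() to insert the needed cast ops. Names, shapes, and data here are arbitrary; the pattern follows the tests added in test_tensor_type_promotion.py and test_where_op.py.

    # Illustrative sketch; assumes a build that includes the patches above.
    import numpy as np
    import paddle

    paddle.enable_static()

    main_prog = paddle.static.Program()
    with paddle.static.program_guard(main_prog):
        x = paddle.static.data(name='x', shape=[3, 10], dtype='float32')
        y = paddle.static.data(name='y', shape=[3, 10], dtype='float64')
        # Only a warning is raised here; the real promotion happens later,
        # when the Executor rewrites the program and inserts a cast op so
        # that elementwise_add runs on the common dtype (float64).
        out = x + y

    exe = paddle.static.Executor(paddle.CPUPlace())
    res = exe.run(
        main_prog,
        feed={
            'x': np.random.random([3, 10]).astype('float32'),
            'y': np.random.random([3, 10]).astype('float64'),
        },
        fetch_list=[out],
    )
    print(res[0].dtype)  # expected: float64, the promoted common dtype

The bfloat16 combinations exercised by the new tests follow the same path but are guarded by paddle.is_compiled_with_cuda() and core.supports_bfloat16(), since the promoted result for bf16/fp16 and bf16/fp32 mixes is float32 rather than either input dtype.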