Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add type promotion static T+T logit. #60638

Merged
merged 13 commits into from
Jan 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions python/paddle/base/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@
is_compiled_with_rocm,
is_compiled_with_xpu,
name_scope,
process_type_promotion,
program_guard,
require_version,
set_flags,
Expand Down
3 changes: 3 additions & 0 deletions python/paddle/base/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
get_flags,
in_pir_mode,
paddle_type_to_proto_type,
process_type_promotion,
set_flags,
)
from .incubate.checkpoint import auto_checkpoint as acp
Expand Down Expand Up @@ -1770,6 +1771,8 @@ def run(
return_numpy=return_numpy,
)
else:
# do type promotion if necessary
program = process_type_promotion(program)
res = self._run_impl(
program=program,
feed=feed,
Expand Down
107 changes: 107 additions & 0 deletions python/paddle/base/framework.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,17 @@
CONTROL_DEP_VAR_PREFIX = core.kControlDepVarName()
_global_flags_ = core.globals()

SUPPORT_PROMOTION_OPS_AND_INPUTNAME = {
"elementwise_add": ['X', 'Y'],
"elementwise_add_grad": ['X', 'Y'],
"elementwise_sub": ['X', 'Y'],
"elementwise_sub_grad": ['X', 'Y'],
"elementwise_mul": ['X', 'Y'],
"elementwise_mul_grad": ['X', 'Y'],
"where": ['X', 'Y'],
zxcd marked this conversation as resolved.
Show resolved Hide resolved
"where_grad": ['X', 'Y'],
}


def _global_flags():
return _global_flags_
Expand Down Expand Up @@ -8144,3 +8155,99 @@ def _get_paddle_place_list(places):
ret.append(p)

return ret


def dtype_to_str(in_dtype):
if in_dtype == core.VarDesc.VarType.FP16:
return "fp16"
elif in_dtype == core.VarDesc.VarType.BF16:
return "bf16"
elif in_dtype == core.VarDesc.VarType.FP32:
return "fp32"
elif in_dtype == core.VarDesc.VarType.FP64:
return "fp64"
else:
return None


def add_cast_for_type_promotion(op, block, idx, var_name, out_dtype):
op_device = op.attr('op_device')
cast_name = var_name.name + '.cast_' + dtype_to_str(out_dtype)
out_var = block.create_var(
name=cast_name,
dtype=out_dtype,
persistable=False,
stop_gradient=var_name.stop_gradient,
)
op_role = (
int(core.op_proto_and_checker_maker.OpRole.Forward)
if not op.has_attr('op_role')
else op.attr('op_role')
)
block._insert_op_without_sync(
idx,
type="cast",
inputs={"X": var_name},
outputs={"Out": out_var},
attrs={
"in_dtype": var_name.dtype,
"out_dtype": out_var.dtype,
"op_device": op_device,
"op_role": op_role,
},
)
op.desc._rename_input(var_name.name, out_var.name)


def process_type_promotion(program):
org_program = program
if program is None:
program = default_main_program()
# not support pir for now
if not isinstance(program, Program):
return org_program
global_block = program.global_block()
all_params = global_block.all_parameters()
for block in program.blocks:
ops = block.ops
idx = 0
while idx < len(ops):
op = ops[idx]
var_name = None
all_dtypes = []
all_input_name_need_cast = []

need_transed_var_names = SUPPORT_PROMOTION_OPS_AND_INPUTNAME.get(
op.type, None
)
# type promotion only support some dyadic api
if need_transed_var_names is None:
idx += 1
continue

# get all dtype and input_name
for input_idx in range(len(op.input_arg_names)):
if op.input_names[input_idx] in need_transed_var_names:
input_arg_name = op.input_arg_names[input_idx]
all_dtypes.append(
op.block._var_recursive(input_arg_name).dtype
)
all_input_name_need_cast.append(input_arg_name)

# only support promote between float
if core.need_type_promotion(*all_dtypes):
common_dtype = core.get_promote_dtype(op.type, *all_dtypes)
for input_name_need_cast in all_input_name_need_cast:
var_name = op.block._var_recursive(input_name_need_cast)
if var_name.dtype != common_dtype:
# add cast op for different dtype
add_cast_for_type_promotion(
op,
block,
idx,
var_name,
common_dtype,
)
idx += 1
idx += 1
return program
10 changes: 2 additions & 8 deletions python/paddle/base/layers/math_op_patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -534,19 +534,13 @@ def __impl__(self, other_var):
if lhs_dtype != rhs_dtype:
if method_name in SUPPORT_PROMOTION_OPS:
if core.need_type_promotion(lhs_dtype, rhs_dtype):
common_dtype = core.get_promote_dtype(
op_type, lhs_dtype, rhs_dtype
)
# only report warning here, real promotion deal in Executor
warnings.warn(
f"The input dtypes of OP {op_type} are {lhs_dtype} and {rhs_dtype}, the output will be auto-promoted to {common_dtype}"
f"The input dtypes of OP {op_type} are {lhs_dtype} and {rhs_dtype}, the output will be auto-promoted"
)
warnings.filterwarnings(
"ignore", message="The input dtypes of OP"
)
if rhs_dtype != common_dtype:
other_var = astype(other_var, common_dtype)
if lhs_dtype != common_dtype:
self = astype(self, common_dtype)
else:
# NOTE(zoooo0820): Currently, we still keep the old illogical \
# logic for compatibility reasons
Expand Down
17 changes: 16 additions & 1 deletion python/paddle/static/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,12 @@
unique_name,
)
from paddle.base.executor import Executor, global_scope
from paddle.base.framework import Parameter, dygraph_not_support, static_only
from paddle.base.framework import (
Parameter,
dygraph_not_support,
process_type_promotion,
static_only,
)
from paddle.base.log_helper import get_logger
from paddle.framework.io_utils import (
_clone_var_in_block_,
Expand Down Expand Up @@ -587,6 +592,10 @@ def save_inference_model(
_check_vars('fetch_vars', fetch_vars)

program = _get_valid_program(kwargs.get('program', None))

# do type promotion
program = process_type_promotion(program)

clip_extra = kwargs.get('clip_extra', True)
program = normalize_program(
program,
Expand Down Expand Up @@ -903,6 +912,9 @@ def load_inference_model(path_prefix, executor, **kwargs):
# deserialize bytes to program
program = deserialize_program(program_bytes)

# do type promotion
program = process_type_promotion(program)

vars = list(filter(is_persistable, program.list_vars()))
if len(vars) > 0:
load_vars(
Expand Down Expand Up @@ -958,6 +970,9 @@ def load_inference_model(path_prefix, executor, **kwargs):
# deserialize bytes to program
program = deserialize_program(program_bytes)

# do type promotion
program = process_type_promotion(program)

vars = list(filter(is_persistable, program.list_vars()))
if len(vars) > 0:
load_dirname = os.path.dirname(params_path)
Expand Down
85 changes: 50 additions & 35 deletions test/legacy_test/test_tensor_type_promotion.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,31 @@ def test_dtype_is_expected(self):
)


class TestAPIAddInStatic(TestOperatorOverloadAddInStatic):
def run_api(self):
prog = paddle.static.Program()
with paddle.static.program_guard(prog):
self.generate_test_value()

out = paddle.add(self.l_value, self.r_value)
out_reverse = paddle.add(self.r_value, self.l_value)

res = self.exe.run(prog, fetch_list=[out, out_reverse])
return res


create_test_case(TestAPIAddInStatic, 'float16', 'float32', 'float32')
create_test_case(TestAPIAddInStatic, 'float16', 'float64', 'float64')

create_test_case(TestAPIAddInStatic, 'float32', 'float64', 'float64')


if paddle.is_compiled_with_cuda() and paddle.base.core.supports_bfloat16():
create_test_case(TestAPIAddInStatic, 'bfloat16', 'float16', 'float32')
create_test_case(TestAPIAddInStatic, 'bfloat16', 'float32', 'float32')
create_test_case(TestAPIAddInStatic, 'bfloat16', 'float64', 'float64')


class TestOperatorOverloadSubInStatic(TestOperatorOverloadAddInStatic):
def run_api(self):
prog = paddle.static.Program()
Expand Down Expand Up @@ -156,74 +181,64 @@ def run_api(self):
)


class TestOperatorOverloadMulInStatic(TestOperatorOverloadAddInStatic):
class TestAPISubInStatic(TestOperatorOverloadAddInStatic):
def run_api(self):
prog = paddle.static.Program()
with paddle.static.program_guard(prog):
self.generate_test_value()

out = self.l_value * self.r_value
out_reverse = self.r_value * self.l_value
out = paddle.subtract(self.l_value, self.r_value)
out_reverse = paddle.subtract(self.r_value, self.l_value)

res = self.exe.run(prog, fetch_list=[out, out_reverse])
return res


create_test_case(
TestOperatorOverloadMulInStatic, 'float16', 'float32', 'float32'
)
create_test_case(
TestOperatorOverloadMulInStatic, 'float16', 'float64', 'float64'
)
create_test_case(TestAPISubInStatic, 'float16', 'float32', 'float32')
create_test_case(TestAPISubInStatic, 'float16', 'float64', 'float64')

create_test_case(
TestOperatorOverloadMulInStatic, 'float32', 'float64', 'float64'
)
create_test_case(TestAPIAddInStatic, 'float32', 'float64', 'float64')

if paddle.is_compiled_with_cuda() and paddle.base.core.supports_bfloat16():
create_test_case(
TestOperatorOverloadMulInStatic, 'bfloat16', 'float16', 'float32'
)
create_test_case(
TestOperatorOverloadMulInStatic, 'bfloat16', 'float32', 'float32'
)
create_test_case(
TestOperatorOverloadMulInStatic, 'bfloat16', 'float64', 'float64'
)

if paddle.is_compiled_with_cuda() and paddle.base.core.supports_bfloat16():
create_test_case(TestAPISubInStatic, 'bfloat16', 'float16', 'float32')
create_test_case(TestAPISubInStatic, 'bfloat16', 'float32', 'float32')
create_test_case(TestAPISubInStatic, 'bfloat16', 'float64', 'float64')

class TestOperatorOverloadGTInStatic(TestOperatorOverloadAddInStatic):
def set_dtype(self):
self.ldtype = 'float32'
self.rdtype = 'float64'
self.expected_out_dtype = 'bool'

class TestOperatorOverloadMulInStatic(TestOperatorOverloadAddInStatic):
def run_api(self):
prog = paddle.static.Program()
with paddle.static.program_guard(prog):
self.generate_test_value()

out = self.l_value > self.r_value
out_reverse = self.r_value > self.l_value
out = self.l_value * self.r_value
out_reverse = self.r_value * self.l_value

res = self.exe.run(prog, fetch_list=[out, out_reverse])
return res


create_test_case(TestOperatorOverloadGTInStatic, 'float16', 'float32', 'bool')
create_test_case(TestOperatorOverloadGTInStatic, 'float16', 'float64', 'bool')
create_test_case(
TestOperatorOverloadMulInStatic, 'float16', 'float32', 'float32'
)
create_test_case(
TestOperatorOverloadMulInStatic, 'float16', 'float64', 'float64'
)

create_test_case(TestOperatorOverloadGTInStatic, 'float32', 'float64', 'bool')
create_test_case(
TestOperatorOverloadMulInStatic, 'float32', 'float64', 'float64'
)

if paddle.is_compiled_with_cuda() and paddle.base.core.supports_bfloat16():
create_test_case(
TestOperatorOverloadGTInStatic, 'bfloat16', 'float16', 'bool'
TestOperatorOverloadMulInStatic, 'bfloat16', 'float16', 'float32'
)
create_test_case(
TestOperatorOverloadGTInStatic, 'bfloat16', 'float32', 'bool'
TestOperatorOverloadMulInStatic, 'bfloat16', 'float32', 'float32'
)
create_test_case(
TestOperatorOverloadGTInStatic, 'bfloat16', 'float64', 'bool'
TestOperatorOverloadMulInStatic, 'bfloat16', 'float64', 'float64'
)


Expand Down
Loading