add type promotion for complex and real numbers. (#63842)
* add type promotion for complex and real number.

* fix

* reduce api support

* add more api support

* fix

* fix

* remove matmul

* add T+S (tensor + scalar) logic.

* fix bug

* fix unittest

* fix

* fix

* fix unittest

* fix gumbel

* rm print

* fix more unittests.

* fix test_llama_group_log_softmax.py

* fix bug, and add 0-d + 0-d logic.

* rm print

* fix behavior of bool and int

* add unittest for all type promotion.

* rm unittests that use unsupported dtypes

* fix

* fix

* add error unittest

* fix increase unittest

* bug fix

* fix per review comments

* remove useless code.

* fix

* fix

* fix TypePromotionForZeroDimTensor

* add inplace API support; add a special case that can skip type promotion (add with x=float32, y=float16/bfloat16).

* add broadcast support for MultiPrecisionAddKernelImpl.
zxcd authored May 9, 2024
1 parent 5248add commit 3f10cae
Showing 36 changed files with 5,013 additions and 801 deletions.
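
For orientation, below is a minimal sketch of the user-visible behavior this commit targets. The paddle.to_tensor calls and dtype names are standard Paddle Python API, but the exact promotion table is defined by the C++ changes in this diff, so the expected dtypes in the comments are illustrative assumptions rather than guaranteed output.

import paddle

# Mixing a real and a complex tensor in a binary op should now promote
# automatically instead of requiring a manual cast.
real = paddle.to_tensor([1.0, 2.0], dtype='float32')
cplx = paddle.to_tensor([1 + 2j, 3 - 1j], dtype='complex64')
out = real + cplx
print(out.dtype)        # expected: paddle.complex64

# Ordinary real/real mixes covered by the white list promote as well.
a = paddle.to_tensor([1.0], dtype='float16')
b = paddle.to_tensor([2.0], dtype='float32')
print((a + b).dtype)    # expected: paddle.float32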
@@ -70,7 +70,7 @@ paddle::Tensor multiply_ad_func(const paddle::Tensor& x,
}

// Type promotion Logic
if (phi::NeedTypePromotion(x.dtype(), y.dtype())) {
if (phi::NeedTypePromotion("multiply", x.dtype(), y.dtype())) {
VLOG(5) << "got different data type, run type promotion automatically.";
LOG_FIRST_N(WARNING, 1)
<< "got different data type, run type promotion "
@@ -247,6 +247,22 @@ paddle::Tensor& multiply__ad_func(paddle::Tensor& x, // NOLINT

VLOG(5)
<< " No AMP for multiply__ad_func because it is a inplace or cast api. ";

// Type promotion Logic
if (phi::NeedTypePromotion("multiply_", x.dtype(), y.dtype())) {
VLOG(5) << "got different data type, run type promotion automatically.";
LOG_FIRST_N(WARNING, 1)
<< "got different data type, run type promotion "
"automatically, this may cause data type been changed.";
auto op_name = phi::TransToFluidOpName("multiply_");
auto promotion_type = phi::GetPromoteDtype(op_name, x.dtype(), y.dtype());

x = egr::PromoteCastInplace("x", x, promotion_type);
auto new_y = egr::PromoteCast("y", y, promotion_type);

return multiply__ad_func(x, new_y);
}

// Layout autotune

if (egr::Controller::Instance().UseLayoutAutoTune()) {
@@ -424,7 +440,7 @@ paddle::Tensor multiply_ad_func(const paddle::Tensor& x,
}

// Type promotion Logic
if (phi::NeedTypePromotion(x.dtype(), y.dtype())) {
if (phi::NeedTypePromotion("multiply", x.dtype(), y.dtype())) {
VLOG(5) << "got different data type, run type promotion automatically.";
LOG_FIRST_N(WARNING, 1)
<< "got different data type, run type promotion "
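
The generated forward functions above all gain the same guard. Below is a self-contained Python sketch of that control flow: NumPy arrays stand in for paddle::Tensor, and need_type_promotion / get_promote_dtype / promote_cast are simplified stand-ins for phi::NeedTypePromotion, phi::GetPromoteDtype and egr::PromoteCast, with a deliberately tiny promotion table.

import numpy as np

# Toy promotion table; the real one lives in phi and is keyed per op.
_PROMOTE = {
    frozenset(["float32", "complex64"]): "complex64",
    frozenset(["int32", "float32"]): "float32",
}

def need_type_promotion(op_name, dx, dy):
    return dx != dy and frozenset([dx, dy]) in _PROMOTE

def get_promote_dtype(op_name, dx, dy):
    return _PROMOTE[frozenset([dx, dy])]

def promote_cast(arg_name, t, dtype):
    # Out-of-place cast, mirroring egr::PromoteCast.
    return t.astype(dtype) if t.dtype.name != dtype else t

def multiply_ad_func(x, y):
    # Same shape as the generated C++: detect the mismatch, cast both inputs,
    # then re-dispatch so the rest of the function sees matching dtypes.
    if need_type_promotion("multiply", x.dtype.name, y.dtype.name):
        p = get_promote_dtype("multiply", x.dtype.name, y.dtype.name)
        return multiply_ad_func(promote_cast("x", x, p), promote_cast("y", y, p))
    return x * y  # stand-in for the AMP / layout autotune / kernel call

def multiply__ad_func(x, y):
    # Inplace variant: x is cast "in place" (the C++ uses PromoteCastInplace,
    # i.e. cast_), while y still gets an out-of-place cast.
    if need_type_promotion("multiply_", x.dtype.name, y.dtype.name):
        p = get_promote_dtype("multiply_", x.dtype.name, y.dtype.name)
        x = promote_cast("x", x, p)
        return multiply__ad_func(x, promote_cast("y", y, p))
    x *= y
    return x

print(multiply_ad_func(np.float32([2.0]), np.complex64([1j])).dtype)  # complex64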
83 changes: 81 additions & 2 deletions paddle/fluid/eager/auto_code_generator/generator/eager_gen.py
@@ -85,11 +85,50 @@
type_promote_white_list = {
"add": ["x", "y"],
"subtract": ["x", "y"],
"divide": ["x", "y"],
"floor_divide": ["x", "y"],
"elementwise_pow": ["x", "y"],
"where": ["x", "y"],
"equal": ["x", "y"],
"not_equal": ["x", "y"],
"less_than": ["x", "y"],
"less_equal": ["x", "y"],
"greater_than": ["x", "y"],
"greater_equal": ["x", "y"],
"logical_and": ["x", "y"],
"logical_or": ["x", "y"],
"logical_xor": ["x", "y"],
"fmax": ["x", "y"],
"fmin": ["x", "y"],
"maximum": ["x", "y"],
"minimum": ["x", "y"],
"remainder": ["x", "y"],
"huber_loss": ["input", "label"],
"nextafter": ["x", "y"],
"atan2": ["x", "y"],
}

type_promote_inplace_white_list = {
"add_": ["x", "y"],
"subtract_": ["x", "y"],
"divide_": ["x", "y"],
"floor_divide_": ["x", "y"],
"where_": ["x", "y"],
"equal_": ["x", "y"],
"not_equal_": ["x", "y"],
"less_than_": ["x", "y"],
"less_equal_": ["x", "y"],
"greater_than_": ["x", "y"],
"greater_equal_": ["x", "y"],
"logical_and_": ["x", "y"],
"logical_or_": ["x", "y"],
"logical_xor_": ["x", "y"],
"remainder_": ["x", "y"],
}

# dict of special api that forward api's output will affect backward api's output
# backward api's output usually affected by backward api's input

special_prune_dict = {
"matmul_grad": {"x": "grad_y", "y": "grad_x"},
}
@@ -537,13 +576,13 @@ class {} : public egr::GradNodeBase {{
}}
"""

TYPE_PROMOTION_LOGIC_TEMPLATE = """ if (phi::NeedTypePromotion({x}.dtype(), {y}.dtype())) {{
TYPE_PROMOTION_LOGIC_TEMPLATE = """ if (phi::NeedTypePromotion({op_func_name}, {x}.dtype(), {y}.dtype())) {{
VLOG(5) << "got different data type, run type promotion automatically.";
LOG_FIRST_N(WARNING, 1) << "got different data type, run type promotion automatically, this may cause data type been changed.";
{op_name}
auto promotion_type = phi::GetPromoteDtype(op_name, {x}.dtype(), {y}.dtype());
auto new_{x} = egr::PromoteCast("{x}", {x}, promotion_type);
{x_cast}
auto new_{y} = egr::PromoteCast("{y}", {y}, promotion_type);
{return_value}
@@ -1511,6 +1550,18 @@ def GenerateForwardDefinitionAndDeclaration(self, is_inplaced):
type_promote_inputs_call_list[pos] = f"new_{name}"
else:
type_promote_inputs_call_list[pos] = f"{name}"
elif forward_api_name in type_promote_inplace_white_list:
if name in type_promote_inplace_white_list[forward_api_name]:
if (
is_inplaced
and forward_inplace_map
and name in forward_inplace_map
):
type_promote_inputs_call_list[pos] = f"{name}"
else:
type_promote_inputs_call_list[pos] = f"new_{name}"
else:
type_promote_inputs_call_list[pos] = f"{name}"
if IsPlainTensorType(ttype):
if is_optional:
if (
@@ -1601,6 +1652,7 @@ def GenerateForwardDefinitionAndDeclaration(self, is_inplaced):
for name, atype, default_val, pos in forward_attrs_list:
inputs_call_list[pos] = name
amp_inputs_call_list[pos] = name
type_promote_inputs_call_list[pos] = name
if default_val is not None:
inputs_args_declaration_list[
pos
@@ -1846,16 +1898,43 @@ def GenerateForwardDefinitionAndDeclaration(self, is_inplaced):
# Forward type promotion logic
if forward_api_name in type_promote_white_list:
# only support two inputs
op_func_name = f"\"{forward_api_name}\""
x = type_promote_white_list[forward_api_name][0]
y = type_promote_white_list[forward_api_name][1]
type_promote_inputs_call_args_str = ", ".join(
type_promote_inputs_call_list
)
type_promote_call_list = f"return {forward_ad_function_name}({type_promote_inputs_call_args_str});"

x_cast = f"auto new_{x} = egr::PromoteCast(\"{x}\", {x}, promotion_type);"

type_promotion_logic_str = TYPE_PROMOTION_LOGIC_TEMPLATE.format(
op_func_name=op_func_name,
x=x,
y=y,
x_cast=x_cast,
op_name=kernel_trans2_op_name_str,
return_value=type_promote_call_list,
)
elif forward_api_name in type_promote_inplace_white_list:
# only support two inputs
op_func_name = f"\"{forward_api_name}\""
x = type_promote_inplace_white_list[forward_api_name][0]
y = type_promote_inplace_white_list[forward_api_name][1]
type_promote_inputs_call_args_str = ", ".join(
type_promote_inputs_call_list
)
type_promote_call_list = f"return {forward_ad_function_name}({type_promote_inputs_call_args_str});"

x_cast = (
f"{x} = egr::PromoteCastInplace(\"{x}\", {x}, promotion_type);"
)

type_promotion_logic_str = TYPE_PROMOTION_LOGIC_TEMPLATE.format(
op_func_name=op_func_name,
x=x,
y=y,
x_cast=x_cast,
op_name=kernel_trans2_op_name_str,
return_value=type_promote_call_list,
)
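
To see what the generator change above produces, here is a trimmed, self-contained rendering of the template substitution. TEMPLATE is an abbreviated copy of TYPE_PROMOTION_LOGIC_TEMPLATE (the surrounding AMP/layout code is omitted), and render is a hypothetical condensation of GenerateForwardDefinitionAndDeclaration that only shows how the {x_cast} and {return_value} slots differ between the out-of-place and inplace white lists.

# Abbreviated copy of TYPE_PROMOTION_LOGIC_TEMPLATE; promotion guard only.
TEMPLATE = """  if (phi::NeedTypePromotion({op_func_name}, {x}.dtype(), {y}.dtype())) {{
    auto op_name = phi::TransToFluidOpName({op_func_name});
    auto promotion_type = phi::GetPromoteDtype(op_name, {x}.dtype(), {y}.dtype());
    {x_cast}
    auto new_{y} = egr::PromoteCast("{y}", {y}, promotion_type);
    {return_value}
  }}
"""

def render(api_name, x, y, inplace):
    # Inplace ops cast x with PromoteCastInplace and pass x through unchanged;
    # out-of-place ops build new_x with PromoteCast.
    if inplace:
        x_cast = f'{x} = egr::PromoteCastInplace("{x}", {x}, promotion_type);'
        call = f'return {api_name}_ad_func({x}, new_{y});'
    else:
        x_cast = f'auto new_{x} = egr::PromoteCast("{x}", {x}, promotion_type);'
        call = f'return {api_name}_ad_func(new_{x}, new_{y});'
    return TEMPLATE.format(op_func_name=f'"{api_name}"', x=x, y=y,
                           x_cast=x_cast, return_value=call)

print(render("add", "x", "y", inplace=False))
print(render("add_", "x", "y", inplace=True))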
11 changes: 11 additions & 0 deletions paddle/fluid/eager/type_promotion_utils.h
@@ -30,4 +30,15 @@ inline paddle::Tensor PromoteCast(const std::string& input_name,
}
}

inline paddle::Tensor PromoteCastInplace(const std::string& input_name,
paddle::Tensor& input, // NOLINT
const phi::DataType& dst_dtype,
bool trace_backward = true) {
if (input.dtype() != dst_dtype) {
return paddle::experimental::cast_(input, dst_dtype);
} else {
return input;
}
}

} // namespace egr
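
A note on what PromoteCastInplace means at the Python level: because it funnels into paddle::experimental::cast_, the receiving tensor of an inplace op can have its dtype widened in place when promotion kicks in. The snippet below assumes float32 + complex64 promotes to complex64 and that this pair is permitted for add_; treat the expected dtypes as illustrative.

import paddle

x = paddle.to_tensor([1.0, 2.0], dtype='float32')
y = paddle.to_tensor([1 + 1j, 2 - 1j], dtype='complex64')

x.add_(y)          # x is cast in place to the promoted dtype before the kernel runs
print(x.dtype)     # expected: paddle.complex64
print(y.dtype)     # paddle.complex64 (y already has the promoted dtype, untouched)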
