Skip to content

Commit

Permalink
Fix bug of amp code-gen (PaddlePaddle#44570)
Browse files Browse the repository at this point in the history
* fix bug of amp code_gen

* fix bug
  • Loading branch information
zyfncg authored and Aurelius84 committed Jul 29, 2022
1 parent ca44e83 commit 4154dde
Showing 1 changed file with 20 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,25 @@ def FindParsingFunctionFromAttributeType(atype):
}}
"""

INPLACE_AMP_DYGRAPH_FUNCTION_TEMPLATE = \
"""
using result_type = decltype({}({}));
std::unique_ptr<result_type> out_ptr;
// AMP Logic
if (egr::Controller::Instance().GetAMPLevel() != paddle::imperative::AmpLevel::O0) {{
VLOG(5) << "Check and Prepare For AMP";
{}
paddle::small_vector<std::vector<paddle::experimental::Tensor>, egr::kSlotSmallVectorSize> amp_tensors_vector = {};
{}
{}
{}
out_ptr = std::make_unique<result_type>({}({}));
}} else {{
out_ptr = std::make_unique<result_type>({}({}));
}}
result_type& out = *out_ptr;
"""

FUNCTION_SET_DEVICE_TEMPLATE = \
"""{} if (paddle::platform::is_gpu_place(place)) {{
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
Expand Down Expand Up @@ -531,7 +550,7 @@ def GeneratePythonCFunction(self):
inplaced_fwd_function_name, dygraph_function_call_str,
inplaced_fwd_function_name, dygraph_function_call_str)

inplace_amp_dygraph_function_str = AMP_DYGRAPH_FUNCTION_TEMPLATE.format(
inplace_amp_dygraph_function_str = INPLACE_AMP_DYGRAPH_FUNCTION_TEMPLATE.format(
inplaced_fwd_function_name, dygraph_function_call_str,
kernel_trans2_op_name_str, amp_tensors_vector_list_str,
amp_tensors_vector_optional_list_str, amp_get_dst_dtype_str,
Expand Down

0 comments on commit 4154dde

Please sign in to comment.