@@ -660,6 +660,7 @@ class {} : public egr::GradNodeBase {{
660660#include "paddle/fluid/framework/op_registry.h"
661661#include "paddle/utils/test_macros.h"
662662#include "paddle/fluid/eager/api/manual/eager_manual/dygraph_forward_api.h"
663+ #include "paddle/utils/optional.h"
663664using CPUPlace = phi::CPUPlace;
664665{}
665666{}
@@ -1496,7 +1497,7 @@ def GenerateNodeCreationCodes(self, for_backward=False, is_inplaced=False):
14961497
14971498 self .grad_node_out_list = grad_node_out_list
14981499
1499- def run (self ):
1500+ def run (self , append_input_out = False ):
15001501 # Basic Validation Check
15011502 self .DygraphYamlValidationCheck ()
15021503
@@ -1684,7 +1685,9 @@ def GenerateForwardLayoutAutotune(
16841685
16851686 return layout_logic_str
16861687
1687- def GenerateForwardDefinitionAndDeclaration (self , is_inplaced , grad_flag ):
1688+ def GenerateForwardDefinitionAndDeclaration (
1689+ self , is_inplaced , grad_flag , append_input_out
1690+ ):
16881691 namespace = self .namespace
16891692 if self .forward_api_name [- 1 ] == '_' and not is_inplaced :
16901693 return
@@ -1881,6 +1884,24 @@ def GenerateForwardDefinitionAndDeclaration(self, is_inplaced, grad_flag):
18811884
18821885 inputs_args_declaration_str = ", " .join (inputs_args_declaration_list )
18831886 inputs_args_definition_str = ", " .join (inputs_args_definition_list )
1887+ if (
1888+ append_input_out
1889+ and not grad_flag
1890+ and not is_inplaced
1891+ and len (self .forward_outputs_position_map ) == 1
1892+ and next (iter (self .forward_outputs_position_map .values ()))[0 ]
1893+ == "Tensor"
1894+ and forward_api_name != "empty_like"
1895+ ):
1896+ inputs_args_declaration_str = (
1897+ inputs_args_declaration_str
1898+ + ", paddle::optional<paddle::Tensor*> input_out = paddle::none"
1899+ )
1900+ inputs_args_definition_str = (
1901+ inputs_args_definition_str
1902+ + ", paddle::optional<paddle::Tensor*> input_out"
1903+ )
1904+ inputs_call_list .append ("input_out" )
18841905 inputs_call_args_str = ", " .join (inputs_call_list )
18851906 self .inputs_call_list = inputs_call_list
18861907
@@ -2135,6 +2156,16 @@ def GenerateForwardDefinitionAndDeclaration(self, is_inplaced, grad_flag):
21352156 + " " .join (amp_autocast_optional_list )
21362157 )
21372158 amp_inputs_call_args_str = ", " .join (amp_inputs_call_list )
2159+ if (
2160+ append_input_out
2161+ and not grad_flag
2162+ and not is_inplaced
2163+ and len (self .forward_outputs_position_map ) == 1
2164+ and next (iter (self .forward_outputs_position_map .values ()))[0 ]
2165+ == "Tensor"
2166+ and forward_api_name != "empty_like"
2167+ ):
2168+ amp_inputs_call_args_str = amp_inputs_call_args_str + ", input_out"
21382169 amp_call_str = (
21392170 f"return { forward_ad_function_name } ({ amp_inputs_call_args_str } );"
21402171 )
@@ -2158,6 +2189,18 @@ def GenerateForwardDefinitionAndDeclaration(self, is_inplaced, grad_flag):
21582189 type_promote_inputs_call_args_str = ", " .join (
21592190 type_promote_inputs_call_list
21602191 )
2192+ if (
2193+ append_input_out
2194+ and not grad_flag
2195+ and not is_inplaced
2196+ and len (self .forward_outputs_position_map ) == 1
2197+ and next (iter (self .forward_outputs_position_map .values ()))[0 ]
2198+ == "Tensor"
2199+ and forward_api_name != "empty_like"
2200+ ):
2201+ type_promote_inputs_call_args_str = (
2202+ type_promote_inputs_call_args_str + ", input_out"
2203+ )
21612204 type_promote_call_list = f"return { forward_ad_function_name } ({ type_promote_inputs_call_args_str } );"
21622205
21632206 x_cast = (
@@ -2180,6 +2223,19 @@ def GenerateForwardDefinitionAndDeclaration(self, is_inplaced, grad_flag):
21802223 type_promote_inputs_call_args_str = ", " .join (
21812224 type_promote_inputs_call_list
21822225 )
2226+ if (
2227+ append_input_out
2228+ and not grad_flag
2229+ and not is_inplaced
2230+ and len (self .forward_outputs_position_map ) == 1
2231+ and next (iter (self .forward_outputs_position_map .values ()))[0 ]
2232+ == "Tensor"
2233+ and forward_api_name != "empty_like"
2234+ ):
2235+ type_promote_inputs_call_args_str = (
2236+ type_promote_inputs_call_args_str + ", input_out"
2237+ )
2238+
21832239 type_promote_call_list = f"return { forward_ad_function_name } ({ type_promote_inputs_call_args_str } );"
21842240
21852241 x_cast = (
@@ -2323,15 +2379,19 @@ def GenerateForwardDefinitionAndDeclaration(self, is_inplaced, grad_flag):
23232379
23242380 self .forward_declaration_str += f"TEST_API { returns_type_str } { forward_ad_function_name } ({ inputs_args_declaration_str } );\n "
23252381
2326- def GenerateInplacedForwardDygraphFunctions (self , grad_flag ):
2382+ def GenerateInplacedForwardDygraphFunctions (
2383+ self , grad_flag , append_input_out
2384+ ):
23272385 # Inplaced Version Dygraph Function Generation
23282386 forward_api_name = self .forward_api_name
23292387 forward_api_contents = self .forward_api_contents
23302388
23312389 if forward_api_name != "sum" and "inplace" in forward_api_contents :
23322390 # Function Definition and Declaration Generation
23332391 self .GenerateForwardDefinitionAndDeclaration (
2334- is_inplaced = True , grad_flag = grad_flag
2392+ is_inplaced = True ,
2393+ grad_flag = grad_flag ,
2394+ append_input_out = append_input_out ,
23352395 )
23362396 self .UpdateCoreOpsInformation (is_inplaced = True )
23372397
@@ -2367,21 +2427,25 @@ def UpdateCoreOpsInformation(self, is_inplaced):
23672427 for name , (ttype , pos ) in forward_outputs_position_map .items ():
23682428 core_ops_returns_info [fwd_api_name ][pos ] = name
23692429
2370- def run (self , grad_flag = False ):
2371- super ().run ()
2430+ def run (self , grad_flag = False , append_input_out = False ):
2431+ super ().run (append_input_out = append_input_out )
23722432
23732433 ###################
23742434 # Code Generation #
23752435 ###################
23762436
23772437 # Definition And Declaration
23782438 self .GenerateForwardDefinitionAndDeclaration (
2379- is_inplaced = False , grad_flag = grad_flag
2439+ is_inplaced = False ,
2440+ grad_flag = grad_flag ,
2441+ append_input_out = append_input_out ,
23802442 )
23812443
23822444 self .UpdateCoreOpsInformation (is_inplaced = False )
23832445
2384- self .GenerateInplacedForwardDygraphFunctions (grad_flag )
2446+ self .GenerateInplacedForwardDygraphFunctions (
2447+ grad_flag , append_input_out = append_input_out
2448+ )
23852449
23862450
23872451class DygraphNodeGenerator (DygraphFunctionGeneratorBase ):
@@ -3214,8 +3278,8 @@ def _gen_api_call_code_block(
32143278 returns_str ,
32153279 )
32163280
3217- def run (self ):
3218- super ().run ()
3281+ def run (self , append_input_out = False ):
3282+ super ().run (append_input_out = append_input_out )
32193283
32203284 self .ResetOptionalInputs ()
32213285
@@ -3299,7 +3363,7 @@ def GetBackwardAPIContents(self, forward_api_contents):
32993363
33003364 return backward_api_contents
33013365
3302- def GenerateCode (self , grad_flag = False ):
3366+ def GenerateCode (self , grad_flag = False , append_input_out = True ):
33033367 if grad_flag :
33043368 op_string = 'backward_op'
33053369 else :
@@ -3347,7 +3411,9 @@ def GenerateCode(self, grad_flag=False):
33473411 forward_apis_dict ,
33483412 namespace ,
33493413 )
3350- function_generator .run (grad_flag )
3414+ function_generator .run (
3415+ grad_flag , append_input_out = append_input_out
3416+ )
33513417
33523418 self .forward_definition_str += (
33533419 function_generator .forward_definition_str + "\n "
@@ -3372,7 +3438,7 @@ def GenerateCode(self, grad_flag=False):
33723438 namespace ,
33733439 next_grad_api_contents ,
33743440 )
3375- node_generator .run ()
3441+ node_generator .run (append_input_out = append_input_out )
33763442 self .node_declaration_str += (
33773443 node_generator .node_declaration_str + "\n "
33783444 )
@@ -3407,12 +3473,12 @@ def GenerateCode(self, grad_flag=False):
34073473 namespace , self .node_definition_str
34083474 )
34093475
3410- def run (self , grad_flag = False ):
3476+ def run (self , grad_flag = False , append_input_out = False ):
34113477 self .ParseYamlContents ()
34123478
34133479 self .InferNameSpace ()
34143480
3415- self .GenerateCode (grad_flag )
3481+ self .GenerateCode (grad_flag , append_input_out = append_input_out )
34163482
34173483
34183484################
@@ -3521,7 +3587,10 @@ def GenerateForwardHFile(filepath, forward_function_declaration_str, grad_flag):
35213587 generator = DygraphForwardAndNodesGenerator (
35223588 api_yaml_path , backward_yaml_path
35233589 )
3524- generator .run ()
3590+ append_input_out = (
3591+ "string" not in api_yaml_path and "sparse" not in api_yaml_path
3592+ )
3593+ generator .run (append_input_out = append_input_out )
35253594
35263595 node_declaration_str += generator .node_declaration_str + "\n "
35273596 node_definition_str += generator .node_definition_str + "\n "
@@ -3556,7 +3625,7 @@ def GenerateForwardHFile(filepath, forward_function_declaration_str, grad_flag):
35563625 backward_yaml_path , backward_yaml_path
35573626 )
35583627
3559- generator_grad .run (True )
3628+ generator_grad .run (True , append_input_out = False )
35603629
35613630 backward_declaration_str += (
35623631 generator_grad .forward_declaration_str + "\n "
0 commit comments