From 5e95fd1f1888bd2bdada5ad22e8575823486084b Mon Sep 17 00:00:00 2001 From: jim19930609 Date: Thu, 24 Mar 2022 09:28:08 +0000 Subject: [PATCH 1/6] [Refactor] refactored eager_gen.py PR #2 --- .../final_state_generator/codegen_utils.py | 10 +- .../final_state_generator/eager_gen.py | 109 ++++++++++++------ 2 files changed, 78 insertions(+), 41 deletions(-) diff --git a/paddle/fluid/eager/auto_code_generator/final_state_generator/codegen_utils.py b/paddle/fluid/eager/auto_code_generator/final_state_generator/codegen_utils.py index 6e1bee37a4e59..f5e41a6ae340a 100644 --- a/paddle/fluid/eager/auto_code_generator/final_state_generator/codegen_utils.py +++ b/paddle/fluid/eager/auto_code_generator/final_state_generator/codegen_utils.py @@ -46,6 +46,10 @@ ############################# ### File Reader Helpers ### ############################# +def AssertMessage(lhs_str, rhs_str): + return f"lhs: {lhs_str}, rhs: {rhs_str}" + + def ReadFwdFile(filepath): f = open(filepath, 'r') contents = yaml.load(f, Loader=yaml.FullLoader) @@ -58,10 +62,10 @@ def ReadBwdFile(filepath): contents = yaml.load(f, Loader=yaml.FullLoader) ret = {} for content in contents: + assert 'backward_api' in content.keys(), AssertMessage('backward_api', + content.keys()) if 'backward_api' in content.keys(): api_name = content['backward_api'] - else: - assert False ret[api_name] = content f.close() @@ -221,7 +225,7 @@ def ParseYamlReturns(string): ), f"The return type {ret_type} in yaml config is not supported in yaml_types_mapping." ret_type = yaml_types_mapping[ret_type] - assert "Tensor" in ret_type + assert "Tensor" in ret_type, AssertMessage("Tensor", ret_type) ret_name = RemoveSpecialSymbolsInName(ret_name) returns_list.append([ret_name, ret_type, i]) diff --git a/paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py b/paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py index fd750c0d07369..cd59211f02f3b 100644 --- a/paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py +++ b/paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py @@ -16,6 +16,7 @@ import re import argparse import os +import logging from codegen_utils import core_ops_returns_info, core_ops_args_info, core_ops_args_type_info from codegen_utils import yaml_types_mapping from codegen_utils import ReadFwdFile, ReadBwdFile @@ -30,6 +31,7 @@ from codegen_utils import ParseYamlForward, ParseYamlBackward from codegen_utils import FunctionGeneratorBase, YamlGeneratorBase from codegen_utils import ops_to_fill_zero_for_empty_grads +from codegen_utils import AssertMessage ########### @@ -398,14 +400,21 @@ def DygraphYamlValidationCheck(self): forward_api_contents = self.forward_api_contents grad_api_contents = self.grad_api_contents - assert 'api' in forward_api_contents.keys() - assert 'args' in forward_api_contents.keys() - assert 'output' in forward_api_contents.keys() - assert 'backward' in forward_api_contents.keys() - - assert 'args' in grad_api_contents.keys() - assert 'output' in grad_api_contents.keys() - assert 'forward' in grad_api_contents.keys() + assert 'api' in forward_api_contents.keys( + ), "Unable to find \"api\" in api.yaml" + assert 'args' in forward_api_contents.keys( + ), "Unable to find \"args\" in api.yaml" + assert 'output' in forward_api_contents.keys( + ), "Unable to find \"output\" in api.yaml" + assert 'backward' in forward_api_contents.keys( + ), "Unable to find \"backward\" in api.yaml" + + assert 'args' in grad_api_contents.keys( + ), "Unable to find \"args\" 
in backward.yaml" + assert 'output' in grad_api_contents.keys( + ), "Unable to find \"output\" in backward.yaml" + assert 'forward' in grad_api_contents.keys( + ), "Unable to find \"forward\" in backward.yaml" def ForwardsValidationCheck(self): forward_inputs_list = self.forward_inputs_list @@ -424,8 +433,10 @@ def ForwardsValidationCheck(self): orig_input_type = orig_forward_inputs_list[i][1] orig_input_pos = orig_forward_inputs_list[i][2] - assert forward_input_type == orig_input_type - assert forward_input_pos == orig_input_pos + assert forward_input_type == orig_input_type, AssertMessage( + forward_input_type, orig_input_type) + assert forward_input_pos == orig_input_pos, AssertMessage( + forward_input_pos, orig_input_pos) for i in range(len(forward_attrs_list)): orig_attr_name = orig_forward_attrs_list[i][0] @@ -436,9 +447,12 @@ def ForwardsValidationCheck(self): forward_attr_type = forward_attrs_list[i][1] forward_attr_default = forward_attrs_list[i][2] forward_attr_pos = forward_attrs_list[i][3] - assert orig_attr_type == forward_attr_type - assert orig_attr_default == forward_attr_default - assert orig_attr_pos == forward_attr_pos + assert orig_attr_type == forward_attr_type, AssertMessage( + orig_attr_type, forward_attr_type) + assert orig_attr_default == forward_attr_default, AssertMessage( + orig_attr_default, forward_attr_default) + assert orig_attr_pos == forward_attr_pos, AssertMessage( + orig_attr_pos, forward_attr_pos) for i in range(len(forward_returns_list)): orig_return_type = orig_forward_returns_list[i][1] @@ -446,8 +460,10 @@ def ForwardsValidationCheck(self): forward_return_type = forward_returns_list[i][1] forward_return_pos = forward_returns_list[i][2] - assert orig_return_type == forward_return_type - assert orig_return_pos == forward_return_pos + assert orig_return_type == forward_return_type, AssertMessage( + orig_return_type, forward_return_type) + assert orig_return_pos == forward_return_pos, AssertMessage( + orig_return_pos, forward_return_pos) # Check Order: Inputs, Attributes max_input_position = -1 @@ -456,7 +472,8 @@ def ForwardsValidationCheck(self): max_attr_position = -1 for _, _, _, pos in forward_attrs_list: - assert pos > max_input_position + assert pos > max_input_position, AssertMessage(pos, + max_input_position) max_attr_position = max(max_attr_position, pos) def BackwardValidationCheck(self): @@ -471,12 +488,14 @@ def BackwardValidationCheck(self): max_grad_tensor_position = -1 for _, (_, _, pos) in backward_grad_inputs_map.items(): - assert pos > max_fwd_input_position + assert pos > max_fwd_input_position, AssertMessage( + pos, max_grad_tensor_position) max_grad_tensor_position = max(max_grad_tensor_position, pos) max_attr_position = -1 for _, _, _, pos in backward_attrs_list: - assert pos > max_grad_tensor_position + assert pos > max_grad_tensor_position, AssertMessage( + pos, max_grad_tensor_position) max_attr_position = max(max_attr_position, pos) def IntermediateValidationCheck(self): @@ -491,7 +510,8 @@ def IntermediateValidationCheck(self): len(forward_returns_list)) for ret_name, _, pos in forward_returns_list: if ret_name in intermediate_outputs: - assert pos in intermediate_positions + assert pos in intermediate_positions, AssertMessage( + pos, intermediate_positions) def CollectBackwardInfo(self): forward_api_contents = self.forward_api_contents @@ -505,9 +525,12 @@ def CollectBackwardInfo(self): self.backward_inputs_list, self.backward_attrs_list, self.backward_returns_list = ParseYamlBackward( backward_args_str, 
backward_returns_str) - print("Parsed Backward Inputs List: ", self.backward_inputs_list) - print("Prased Backward Attrs List: ", self.backward_attrs_list) - print("Parsed Backward Returns List: ", self.backward_returns_list) + + logging.info( + f"Parsed Backward Inputs List: {self.backward_inputs_list}") + logging.info(f"Prased Backward Attrs List: {self.backward_attrs_list}") + logging.info( + f"Parsed Backward Returns List: {self.backward_returns_list}") def CollectForwardInfoFromBackwardContents(self): @@ -530,7 +553,9 @@ def SlotNameMatching(self): backward_fwd_name = FindForwardName(backward_input_name) if backward_fwd_name: # Grad Input - assert backward_fwd_name in forward_outputs_position_map.keys() + assert backward_fwd_name in forward_outputs_position_map.keys( + ), AssertMessage(backward_fwd_name, + forward_outputs_position_map.keys()) matched_forward_output_type = forward_outputs_position_map[ backward_fwd_name][0] matched_forward_output_pos = forward_outputs_position_map[ @@ -556,7 +581,7 @@ def SlotNameMatching(self): backward_input_type, False, backward_input_pos ] else: - assert False, backward_input_name + assert False, f"Cannot find {backward_input_name} in forward position map" for backward_output in backward_returns_list: backward_output_name = backward_output[0] @@ -564,9 +589,10 @@ def SlotNameMatching(self): backward_output_pos = backward_output[2] backward_fwd_name = FindForwardName(backward_output_name) - assert backward_fwd_name is not None + assert backward_fwd_name is not None, f"Detected {backward_fwd_name} = None" assert backward_fwd_name in forward_inputs_position_map.keys( - ), f"Unable to find {backward_fwd_name} in forward inputs" + ), AssertMessage(backward_fwd_name, + forward_inputs_position_map.keys()) matched_forward_input_type = forward_inputs_position_map[ backward_fwd_name][0] @@ -577,12 +603,15 @@ def SlotNameMatching(self): backward_output_type, matched_forward_input_pos, backward_output_pos ] - print("Generated Backward Fwd Input Map: ", - self.backward_forward_inputs_map) - print("Generated Backward Grad Input Map: ", - self.backward_grad_inputs_map) - print("Generated Backward Grad Output Map: ", - self.backward_grad_outputs_map) + logging.info( + f"Generated Backward Fwd Input Map: {self.backward_forward_inputs_map}" + ) + logging.info( + f"Generated Backward Grad Input Map: {self.backward_grad_inputs_map}" + ) + logging.info( + f"Generated Backward Grad Output Map: {self.backward_grad_outputs_map}" + ) def GenerateNodeDeclaration(self): forward_op_name = self.forward_api_name @@ -642,7 +671,7 @@ def GenerateNodeDeclaration(self): set_tensor_wrapper_methods_str, set_attribute_methods_str, tensor_wrapper_members_str, attribute_members_str) - print("Generated Node Declaration: ", self.node_declaration_str) + logging.info(f"Generated Node Declaration: {self.node_declaration_str}") def GenerateNodeDefinition(self): namespace = self.namespace @@ -710,7 +739,7 @@ def GenerateNodeDefinition(self): grad_node_name, fill_zero_str, grad_node_name, grad_api_namespace, backward_api_name, grad_api_args_str, returns_str) - print("Generated Node Definition: ", self.node_definition_str) + logging.info(f"Generated Node Definition: {self.node_definition_str}") def GenerateForwardDefinition(self, is_inplaced): namespace = self.namespace @@ -813,8 +842,10 @@ def GenerateForwardDefinition(self, is_inplaced): dygraph_event_str, node_creation_str, returns_str) self.forward_declaration_str += f"{returns_type_str} 
{forward_function_name}({inputs_args_declaration_str});\n" - print("Generated Forward Definition: ", self.forward_definition_str) - print("Generated Forward Declaration: ", self.forward_declaration_str) + logging.info( + f"Generated Forward Definition: {self.forward_definition_str}") + logging.info( + f"Generated Forward Declaration: {self.forward_declaration_str}") def GenerateNodeCreationCodes(self, forward_call_str): forward_api_name = self.forward_api_name @@ -921,7 +952,8 @@ def GenerateNodeCreationCodes(self, forward_call_str): else: if num_fwd_outputs > 1: # Aligned with forward output position - assert name in forward_outputs_position_map.keys() + assert name in forward_outputs_position_map.keys( + ), AssertMessage(name, forward_outputs_position_map.keys()) fwd_output_pos = forward_outputs_position_map[name][1] tw_name = f"std::get<{fwd_output_pos}>(api_result)" else: @@ -1114,7 +1146,8 @@ def GetBackwardAPIContents(self, forward_api_contents): if 'backward' not in forward_api_contents.keys(): return None backward_api_name = forward_api_contents['backward'] - assert backward_api_name in grad_api_dict.keys() + assert backward_api_name in grad_api_dict.keys(), AssertMessage( + backward_api_name, grad_api_dict.keys()) backward_api_contents = grad_api_dict[backward_api_name] return backward_api_contents From 214ad3626b7b63b9a11871fa25d79c1f51c20a66 Mon Sep 17 00:00:00 2001 From: jim19930609 Date: Fri, 25 Mar 2022 05:54:46 +0000 Subject: [PATCH 2/6] [DoubleGrad PR #1] Decoupled code generation logics for Dygraph ForwardFunctions and GradNodes --- .../final_state_generator/eager_gen.py | 624 +++++++++--------- 1 file changed, 326 insertions(+), 298 deletions(-) diff --git a/paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py b/paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py index cd59211f02f3b..e6e829ffd44f2 100644 --- a/paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py +++ b/paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py @@ -172,7 +172,7 @@ class {} : public egr::GradNodeBase {{ """ -NODE_CREATION_TEMPLATE = \ +FORWARD_BODY_TEMPLATE = \ """ // Get AutoGradMeta {} @@ -344,7 +344,7 @@ def GenerateCoreOpInfoDefinition(): ##################### ## Generator Class ## ##################### -class DygraphSingleFunctionGenerator(FunctionGeneratorBase): +class DygraphFunctionGeneratorBase(FunctionGeneratorBase): def __init__(self, forward_api_contents, grad_api_contents, namespace): self.forward_api_contents = forward_api_contents # Members from Parent: @@ -390,12 +390,6 @@ def __init__(self, forward_api_contents, grad_api_contents, namespace): self.backward_grad_outputs_map = { } #{ "name" : [type, fwd_position, orig_position] ...} - # Generated Results - self.forward_definition_str = "" - self.forward_declaration_str = "" - self.node_declaration_str = "" - self.node_definition_str = "" - def DygraphYamlValidationCheck(self): forward_api_contents = self.forward_api_contents grad_api_contents = self.grad_api_contents @@ -613,241 +607,7 @@ def SlotNameMatching(self): f"Generated Backward Grad Output Map: {self.backward_grad_outputs_map}" ) - def GenerateNodeDeclaration(self): - forward_op_name = self.forward_api_name - backward_forward_inputs_map = self.backward_forward_inputs_map - backward_attrs_list = self.backward_attrs_list - no_need_buffers = self.no_need_buffers - - # SetTensorWrapper Methods & TensorWrapper Members - set_tensor_wrapper_methods_str = "" - tensor_wrapper_members_str = "" - 
clear_tensor_wrapper_str = "" - for tname, (ttype, is_fwd_input, - _) in backward_forward_inputs_map.items(): - no_need_buffer = "true" if tname in no_need_buffers else "false" - tensor_wrapper_name = GetSavedName(tname) - if IsPlainTensorType(ttype): - set_tensor_wrapper_methods_str += SET_PLAIN_TENSOR_WRAPPER_TEMPLATE.format( - tname, tname, tensor_wrapper_name, tname, no_need_buffer) - - tensor_wrapper_members_str += PLAIN_TENSOR_MEMBER_TEMPLATE.format( - tensor_wrapper_name) - - clear_tensor_wrapper_str += CLEAR_TENSOR_WRAPPER_TEMPLATE.format( - tensor_wrapper_name) - - else: - assert IsVectorTensorType(ttype) - set_tensor_wrapper_methods_str += SET_VECTOR_TENSOR_WRAPPER_TEMPLATE.format( - tname, tname, tname, tensor_wrapper_name, no_need_buffer) - - tensor_wrapper_members_str += VECTOR_TENSOR_MEMBER_TEMPLATE.format( - tensor_wrapper_name) - - clear_tensor_wrapper_str += CLEAR_VECTOR_TENSOR_WRAPPERS_TEMPLATE.format( - tensor_wrapper_name) - - # SetAttributes & Attribute Members - set_attribute_methods_str = "" - attribute_members_str = "" - for aname, atype, default_val, _ in backward_attrs_list: - saved_attr_name = GetSavedName(aname) - set_attribute_methods_str += SET_ATTR_METHOD_TEMPLATE.format( - aname, GetConstReference(atype), aname, saved_attr_name, aname) - - if default_val: - attribute_members_str += ATTRIBUTE_MEMBER_WITH_DEFAULT_TEMPLATE.format( - RemoveConstAndReference(atype), saved_attr_name, - default_val) - else: - attribute_members_str += ATTRIBUTE_MEMBER_TEMPLATE.format( - RemoveConstAndReference(atype), saved_attr_name) - - grad_node_name = GetGradNodeName(forward_op_name) - self.node_declaration_str = NODE_DECLARATION_TEMPLATE.format( - grad_node_name, grad_node_name, grad_node_name, grad_node_name, - grad_node_name, clear_tensor_wrapper_str, - set_tensor_wrapper_methods_str, set_attribute_methods_str, - tensor_wrapper_members_str, attribute_members_str) - - logging.info(f"Generated Node Declaration: {self.node_declaration_str}") - - def GenerateNodeDefinition(self): - namespace = self.namespace - forward_api_name = self.forward_api_name - backward_api_name = self.backward_api_name - backward_forward_inputs_map = self.backward_forward_inputs_map - backward_grad_inputs_map = self.backward_grad_inputs_map - backward_grad_outputs_map = self.backward_grad_outputs_map - backward_attrs_list = self.backward_attrs_list - - # Construct grad_api function args - # Order: TensorWrappers, GradTensors, Attributes - grad_api_args_len = len(backward_forward_inputs_map.keys()) + len( - backward_grad_inputs_map.keys()) + len(backward_attrs_list) - grad_api_args = ["" for i in range(grad_api_args_len)] - for name, (_, is_fwd_input, - grad_api_position), in backward_forward_inputs_map.items(): - tensor_wrapper_name = GetSavedName(name) - grad_api_args[ - grad_api_position] = f"egr::EagerUtils::RecoverTensorWrapper(&this->{tensor_wrapper_name}, nullptr)" - - for _, (ttype, fwd_position, - grad_api_position) in backward_grad_inputs_map.items(): - if IsPlainTensorType(ttype): - grad_api_args[ - grad_api_position] = f"hooked_grads[{fwd_position}][0]" - else: - assert IsVectorTensorType(ttype) - grad_api_args[ - grad_api_position] = f"hooked_grads[{fwd_position}]" - - for name, _, _, grad_api_position in backward_attrs_list: - saved_attribute_name = GetSavedName(name) - grad_api_args[grad_api_position] = f"this->{saved_attribute_name}" - grad_api_args_str = ", ".join(grad_api_args) - - # Construct grad_api returns - num_bwd_outputs = len(backward_grad_outputs_map.keys()) - returns_str = 
f"std::vector> returns({num_bwd_outputs});\n" - for _, (ttype, fwd_position, - grad_api_position) in backward_grad_outputs_map.items(): - # Infer Grad API Return Type - if num_bwd_outputs == 1: - # Single tensor output, return as is - if IsPlainTensorType(ttype): - returns_str += "returns[0] = { grad_api_returns };\n" - else: - assert IsVectorTensorType(ttype) - returns_str += "returns[0] = grad_api_returns;\n" - else: - # Rearrange output order accordingly - returns_str += f"returns[{fwd_position}] = grad_api_returns[{grad_api_position}];\n" - returns_str += f"if(NeedComplexToRealConversion()) HandleComplexGradToRealGrad(&returns);\n" - returns_str += f"return returns;\n" - - grad_node_name = GetGradNodeName(forward_api_name) - - fill_zero_str = "" - if forward_api_name in ops_to_fill_zero_for_empty_grads: - fill_zero_str = "egr::EagerUtils::FillZeroForEmptyGradInputs(&grads, this->InputMeta());\n" - - grad_api_namespace = f"paddle::experimental::{namespace}" - - self.node_definition_str = FUNCTION_TEMPLATE.format( - grad_node_name, fill_zero_str, grad_node_name, grad_api_namespace, - backward_api_name, grad_api_args_str, returns_str) - - logging.info(f"Generated Node Definition: {self.node_definition_str}") - - def GenerateForwardDefinition(self, is_inplaced): - namespace = self.namespace - forward_api_name = GetInplacedFunctionName( - self.forward_api_name) if is_inplaced else self.forward_api_name - backward_api_name = self.backward_api_name - forward_inputs_position_map = self.forward_inputs_position_map - forward_outputs_position_map = self.forward_outputs_position_map - forward_attrs_list = self.forward_attrs_list - backward_forward_inputs_map = self.backward_forward_inputs_map - backward_grad_inputs_map = self.backward_grad_inputs_map - backward_grad_outputs_map = self.backward_grad_outputs_map - backward_attrs_list = self.backward_attrs_list - optional_inputs = self.optional_inputs - intermediate_outputs = self.intermediate_outputs - inplace_map = self.inplace_map - - # Get Function Args - num_inputs = len(forward_attrs_list) + len( - forward_inputs_position_map.keys()) - inputs_args_definition_list = ["" for i in range(num_inputs)] - inputs_args_declaration_list = ["" for i in range(num_inputs)] - inputs_call_list = ["" for i in range(num_inputs)] - for name, (ttype, pos) in forward_inputs_position_map.items(): - inputs_call_list[pos] = f"{name}" - is_optional = (name in optional_inputs) - if IsPlainTensorType(ttype): - if is_optional: - arg_str = f"const paddle::optional& {name}" - else: - if inplace_map and name in inplace_map.keys(): - arg_str = f"paddle::experimental::Tensor& {name}" - else: - arg_str = f"const paddle::experimental::Tensor& {name}" - else: - assert IsVectorTensorType(ttype) - arg_str = f"const std::vector& {name}" - - inputs_args_definition_list[pos] = arg_str - inputs_args_declaration_list[pos] = arg_str - - for name, atype, default_val, pos in forward_attrs_list: - inputs_call_list[pos] = name - if default_val is not None: - inputs_args_declaration_list[ - pos] = f"{atype} {name} = {default_val}" - else: - inputs_args_declaration_list[pos] = f"{atype} {name}" - inputs_args_definition_list[pos] = f"{atype} {name}" - - inputs_args_declaration_str = ", ".join(inputs_args_declaration_list) - inputs_args_definition_str = ", ".join(inputs_args_definition_list) - inputs_call_args_str = ", ".join(inputs_call_list) - - # Forward Full Logic - function_name = forward_api_name - if len(intermediate_outputs) > 0: - function_name = 
GetIntermediateAPIFunctionName(function_name) - - forward_call_str = f"auto api_result = paddle::experimental::{namespace}{function_name}({inputs_call_args_str});" - - # Get return type list & outputs - num_outputs = len(forward_outputs_position_map.keys()) - len( - intermediate_outputs) - returns_type_list = ["" for i in range(num_outputs)] - returns_list = ["" for i in range(num_outputs)] - for name, (rtype, pos) in forward_outputs_position_map.items(): - if name in intermediate_outputs: - continue - if num_outputs == 1: - returns_list[0] = f"api_result" - else: - # Tuple api_result - returns_list[pos] = f"std::get<{pos}>(api_result)" - - if IsPlainTensorType(rtype): - returns_type_list[pos] = "paddle::experimental::Tensor" - else: - assert IsVectorTensorType(rtype) - returns_type_list[ - pos] = "std::vector" - - if num_outputs == 1: - returns_str = returns_list[0] - returns_type_str = returns_type_list[0] - else: - returns_type_str = ", ".join(returns_type_list) - returns_type_str = f"std::tuple<{returns_type_str}>" - returns_str = ", ".join(returns_list) - returns_str = f"std::make_tuple({returns_str})" - - self.GenerateNodeCreationCodes(forward_call_str) - - node_creation_str = self.node_creation_str - dygraph_event_str = f"paddle::platform::RecordEvent dygraph_entrance_record_event(\"{forward_api_name} dygraph\", paddle::platform::TracerEventType::Operator, 1);" - forward_function_name = GetDygraphForwardFunctionName(forward_api_name) - - self.forward_definition_str += FORWARD_FUNCTION_TEMPLATE.format( - returns_type_str, forward_function_name, inputs_args_definition_str, - dygraph_event_str, node_creation_str, returns_str) - self.forward_declaration_str += f"{returns_type_str} {forward_function_name}({inputs_args_declaration_str});\n" - - logging.info( - f"Generated Forward Definition: {self.forward_definition_str}") - logging.info( - f"Generated Forward Declaration: {self.forward_declaration_str}") - - def GenerateNodeCreationCodes(self, forward_call_str): + def GenerateForwardFunctionBody(self, forward_call_str, is_inplaced): forward_api_name = self.forward_api_name forward_inputs_position_map = self.forward_inputs_position_map forward_outputs_position_map = self.forward_outputs_position_map @@ -857,7 +617,7 @@ def GenerateNodeCreationCodes(self, forward_call_str): backward_grad_outputs_map = self.backward_grad_outputs_map backward_attrs_list = self.backward_attrs_list optional_inputs = self.optional_inputs - inplace_map = self.inplace_map + inplace_map = self.inplace_map if is_inplaced else {} # Get Input AutoGradMeta inputs_autograd_meta_list = [] @@ -1008,7 +768,7 @@ def GenerateNodeCreationCodes(self, forward_call_str): node_event_name = forward_api_name + " node_creation" node_creation_event_str = f"paddle::platform::RecordEvent node_creation_record_event(\"{node_event_name}\", paddle::platform::TracerEventType::Operator, 1);\n" - self.node_creation_str = NODE_CREATION_TEMPLATE.format( + self.node_creation_str = FORWARD_BODY_TEMPLATE.format( inputs_autograd_meta_str, compute_require_grad_args_str, check_inplace_str, forward_call_str, bump_inplace_version_str, node_creation_event_str, outputs_autograd_meta_str, @@ -1017,18 +777,179 @@ def GenerateNodeCreationCodes(self, forward_call_str): set_edges_str, set_out_rank_str, set_history_str, set_grad_in_meta_str, set_retain_grad_str) - def GenerateInplacedForwardDygraphFunctions(self): - # Inplaced Version Dygraph Function Generation - forward_api_name = self.forward_api_name - forward_api_contents = self.forward_api_contents + 
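# [Illustrative sketch, not part of the patch] The class split introduced
# below follows the template-method pattern: the base class run() (next)
# owns the shared parse/validate pipeline, and each subclass extends run()
# with its own code-generation step. All names mirror this patch; the shared
# pipeline body is elided:
#
#     class DygraphFunctionGeneratorBase(FunctionGeneratorBase):
#         def run(self):
#             self.DygraphYamlValidationCheck()  # shared parsing + validation
#             ...                                # parse yaml, build position maps
#
#     class DygraphForwardFunctionGenerator(DygraphFunctionGeneratorBase):
#         def run(self):
#             super().run()                      # reuse the shared pipeline
#             self.GenerateForwardDefinition(is_inplaced=False)
#
#     class DygraphNodeGenerator(DygraphFunctionGeneratorBase):
#         def run(self):
#             super().run()
#             self.GenerateNodeDeclaration()
#             self.GenerateNodeDefinition()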
def run(self): + # Basic Validation Check + self.DygraphYamlValidationCheck() - if forward_api_name != "sum" and "inplace" in forward_api_contents.keys( - ): - # Node Definition Generation - self.GenerateForwardDefinition(is_inplaced=True) - self.UpdateCoreOpsInformation(is_inplaced=True) + ########################## + ## Parsing Raw Contents ## + ########################## + # Parse inplace_map + self.ParseInplaceInfo() - def UpdateCoreOpsInformation(self, is_inplaced): + # Parse no_need_buffer + self.ParseNoNeedBuffer() + + # Parse optional_inputs + self.ParseDispensable() + + # Parse intermediate_outputs + self.ParseIntermediate() + self.IntermediateValidationCheck() + + # Initialize backward_forward_str, backward_inputs_list, backward_attrs_list, backward_returns_list + self.CollectBackwardInfo() + + # Initialize forward_inputs_list, forward_attrs_list, forward_returns_list + self.CollectForwardInfoFromBackwardContents() + + # Initialize orig_forward_inputs_list, orig_forward_attrs_list, orig_forward_returns_list + self.CollectOriginalForwardInfo() + + # Forwards Validation Check + self.ForwardsValidationCheck() + + ############################# + ## Process Parsed Contents ## + ############################# + # Initialize forward_inputs_position_map, forward_outputs_position_map + self.DetermineForwardPositionMap(self.forward_inputs_list, + self.forward_returns_list) + + # Initialize forward_inputs_position_map, forward_outputs_position_map + self.SlotNameMatching() + + # Backward Validation Check + self.BackwardValidationCheck() + + +class DygraphForwardFunctionGenerator(DygraphFunctionGeneratorBase): + def __init__(self, forward_api_contents, grad_api_contents, namespace): + DygraphFunctionGeneratorBase.__init__(self, forward_api_contents, + grad_api_contents, namespace) + + # Generated Results + self.forward_definition_str = "" + self.forward_declaration_str = "" + + def GenerateForwardDefinition(self, is_inplaced): + namespace = self.namespace + forward_api_name = GetInplacedFunctionName( + self.forward_api_name) if is_inplaced else self.forward_api_name + backward_api_name = self.backward_api_name + forward_inputs_position_map = self.forward_inputs_position_map + forward_outputs_position_map = self.forward_outputs_position_map + forward_attrs_list = self.forward_attrs_list + backward_forward_inputs_map = self.backward_forward_inputs_map + backward_grad_inputs_map = self.backward_grad_inputs_map + backward_grad_outputs_map = self.backward_grad_outputs_map + backward_attrs_list = self.backward_attrs_list + optional_inputs = self.optional_inputs + intermediate_outputs = self.intermediate_outputs + inplace_map = self.inplace_map if is_inplaced else {} + + # Get Function Args + num_inputs = len(forward_attrs_list) + len( + forward_inputs_position_map.keys()) + inputs_args_definition_list = ["" for i in range(num_inputs)] + inputs_args_declaration_list = ["" for i in range(num_inputs)] + inputs_call_list = ["" for i in range(num_inputs)] + for name, (ttype, pos) in forward_inputs_position_map.items(): + inputs_call_list[pos] = f"{name}" + is_optional = (name in optional_inputs) + if IsPlainTensorType(ttype): + if is_optional: + arg_str = f"const paddle::optional& {name}" + else: + if inplace_map and name in inplace_map.keys(): + arg_str = f"paddle::experimental::Tensor& {name}" + else: + arg_str = f"const paddle::experimental::Tensor& {name}" + else: + assert IsVectorTensorType(ttype) + arg_str = f"const std::vector& {name}" + + inputs_args_definition_list[pos] = arg_str + 
inputs_args_declaration_list[pos] = arg_str + + for name, atype, default_val, pos in forward_attrs_list: + inputs_call_list[pos] = name + if default_val is not None: + inputs_args_declaration_list[ + pos] = f"{atype} {name} = {default_val}" + else: + inputs_args_declaration_list[pos] = f"{atype} {name}" + inputs_args_definition_list[pos] = f"{atype} {name}" + + inputs_args_declaration_str = ", ".join(inputs_args_declaration_list) + inputs_args_definition_str = ", ".join(inputs_args_definition_list) + inputs_call_args_str = ", ".join(inputs_call_list) + + # Forward Full Logic + function_name = forward_api_name + if len(intermediate_outputs) > 0: + function_name = GetIntermediateAPIFunctionName(function_name) + + forward_call_str = f"auto api_result = paddle::experimental::{namespace}{function_name}({inputs_call_args_str});" + + # Get return type list & outputs + num_outputs = len(forward_outputs_position_map.keys()) - len( + intermediate_outputs) + returns_type_list = ["" for i in range(num_outputs)] + returns_list = ["" for i in range(num_outputs)] + for name, (rtype, pos) in forward_outputs_position_map.items(): + if name in intermediate_outputs: + continue + if num_outputs == 1: + returns_list[0] = f"api_result" + else: + # Tuple api_result + returns_list[pos] = f"std::get<{pos}>(api_result)" + + if IsPlainTensorType(rtype): + returns_type_list[pos] = "paddle::experimental::Tensor" + else: + assert IsVectorTensorType(rtype) + returns_type_list[ + pos] = "std::vector" + + if num_outputs == 1: + returns_str = returns_list[0] + returns_type_str = returns_type_list[0] + else: + returns_type_str = ", ".join(returns_type_list) + returns_type_str = f"std::tuple<{returns_type_str}>" + returns_str = ", ".join(returns_list) + returns_str = f"std::make_tuple({returns_str})" + + self.GenerateForwardFunctionBody(forward_call_str, is_inplaced) + + node_creation_str = self.node_creation_str + dygraph_event_str = f"paddle::platform::RecordEvent dygraph_entrance_record_event(\"{forward_api_name} dygraph\", paddle::platform::TracerEventType::Operator, 1);" + forward_function_name = GetDygraphForwardFunctionName(forward_api_name) + + self.forward_definition_str += FORWARD_FUNCTION_TEMPLATE.format( + returns_type_str, forward_function_name, inputs_args_definition_str, + dygraph_event_str, node_creation_str, returns_str) + self.forward_declaration_str += f"{returns_type_str} {forward_function_name}({inputs_args_declaration_str});\n" + + logging.info( + f"Generated Forward Definition: {self.forward_definition_str}") + logging.info( + f"Generated Forward Declaration: {self.forward_declaration_str}") + + def GenerateInplacedForwardDygraphFunctions(self): + # Inplaced Version Dygraph Function Generation + forward_api_name = self.forward_api_name + forward_api_contents = self.forward_api_contents + + if forward_api_name != "sum" and "inplace" in forward_api_contents.keys( + ): + # Node Definition Generation + self.GenerateForwardDefinition(is_inplaced=True) + self.UpdateCoreOpsInformation(is_inplaced=True) + + def UpdateCoreOpsInformation(self, is_inplaced): forward_api_name = GetInplacedFunctionName( self.forward_api_name) if is_inplaced else self.forward_api_name forward_inputs_position_map = self.forward_inputs_position_map @@ -1062,60 +983,163 @@ def UpdateCoreOpsInformation(self, is_inplaced): core_ops_returns_info[final_state_fwd_api_name][pos] = name def run(self): - # Basic Validation Check - self.DygraphYamlValidationCheck() + super().run() - ########################## - ## Parsing Raw Contents ## - 
########################## - # Parse inplace_map - self.ParseInplaceInfo() + ##################### + ## Code Generation ## + ##################### + self.GenerateForwardDefinition(is_inplaced=False) - # Parse no_need_buffer - self.ParseNoNeedBuffer() + self.UpdateCoreOpsInformation(is_inplaced=False) - # Parse optional_inputs - self.ParseDispensable() + self.GenerateInplacedForwardDygraphFunctions() - # Parse intermediate_outputs - self.ParseIntermediate() - self.IntermediateValidationCheck() - # Initialize backward_forward_str, backward_inputs_list, backward_attrs_list, backward_returns_list - self.CollectBackwardInfo() +class DygraphNodeGenerator(DygraphFunctionGeneratorBase): + def __init__(self, forward_api_contents, grad_api_contents, namespace): + DygraphFunctionGeneratorBase.__init__(self, forward_api_contents, + grad_api_contents, namespace) - # Initialize forward_inputs_list, forward_attrs_list, forward_returns_list - self.CollectForwardInfoFromBackwardContents() + # Generated Results + self.node_declaration_str = "" + self.node_definition_str = "" - # Initialize orig_forward_inputs_list, orig_forward_attrs_list, orig_forward_returns_list - self.CollectOriginalForwardInfo() + def GenerateNodeDeclaration(self): + forward_op_name = self.forward_api_name + backward_forward_inputs_map = self.backward_forward_inputs_map + backward_attrs_list = self.backward_attrs_list + no_need_buffers = self.no_need_buffers - # Forwards Validation Check - self.ForwardsValidationCheck() + # SetTensorWrapper Methods & TensorWrapper Members + set_tensor_wrapper_methods_str = "" + tensor_wrapper_members_str = "" + clear_tensor_wrapper_str = "" + for tname, (ttype, is_fwd_input, + _) in backward_forward_inputs_map.items(): + no_need_buffer = "true" if tname in no_need_buffers else "false" + tensor_wrapper_name = GetSavedName(tname) + if IsPlainTensorType(ttype): + set_tensor_wrapper_methods_str += SET_PLAIN_TENSOR_WRAPPER_TEMPLATE.format( + tname, tname, tensor_wrapper_name, tname, no_need_buffer) - ############################# - ## Process Parsed Contents ## - ############################# - # Initialize forward_inputs_position_map, forward_outputs_position_map - self.DetermineForwardPositionMap(self.forward_inputs_list, - self.forward_returns_list) + tensor_wrapper_members_str += PLAIN_TENSOR_MEMBER_TEMPLATE.format( + tensor_wrapper_name) - # Initialize forward_inputs_position_map, forward_outputs_position_map - self.SlotNameMatching() + clear_tensor_wrapper_str += CLEAR_TENSOR_WRAPPER_TEMPLATE.format( + tensor_wrapper_name) - # Backward Validation Check - self.BackwardValidationCheck() + else: + assert IsVectorTensorType(ttype) + set_tensor_wrapper_methods_str += SET_VECTOR_TENSOR_WRAPPER_TEMPLATE.format( + tname, tname, tname, tensor_wrapper_name, no_need_buffer) + + tensor_wrapper_members_str += VECTOR_TENSOR_MEMBER_TEMPLATE.format( + tensor_wrapper_name) + + clear_tensor_wrapper_str += CLEAR_VECTOR_TENSOR_WRAPPERS_TEMPLATE.format( + tensor_wrapper_name) + + # SetAttributes & Attribute Members + set_attribute_methods_str = "" + attribute_members_str = "" + for aname, atype, default_val, _ in backward_attrs_list: + saved_attr_name = GetSavedName(aname) + set_attribute_methods_str += SET_ATTR_METHOD_TEMPLATE.format( + aname, GetConstReference(atype), aname, saved_attr_name, aname) + + if default_val: + attribute_members_str += ATTRIBUTE_MEMBER_WITH_DEFAULT_TEMPLATE.format( + RemoveConstAndReference(atype), saved_attr_name, + default_val) + else: + attribute_members_str += 
ATTRIBUTE_MEMBER_TEMPLATE.format( + RemoveConstAndReference(atype), saved_attr_name) + + grad_node_name = GetGradNodeName(forward_op_name) + self.node_declaration_str = NODE_DECLARATION_TEMPLATE.format( + grad_node_name, grad_node_name, grad_node_name, grad_node_name, + grad_node_name, clear_tensor_wrapper_str, + set_tensor_wrapper_methods_str, set_attribute_methods_str, + tensor_wrapper_members_str, attribute_members_str) + + logging.info(f"Generated Node Declaration: {self.node_declaration_str}") + + def GenerateNodeDefinition(self): + namespace = self.namespace + forward_api_name = self.forward_api_name + backward_api_name = self.backward_api_name + backward_forward_inputs_map = self.backward_forward_inputs_map + backward_grad_inputs_map = self.backward_grad_inputs_map + backward_grad_outputs_map = self.backward_grad_outputs_map + backward_attrs_list = self.backward_attrs_list + + # Construct grad_api function args + # Order: TensorWrappers, GradTensors, Attributes + grad_api_args_len = len(backward_forward_inputs_map.keys()) + len( + backward_grad_inputs_map.keys()) + len(backward_attrs_list) + grad_api_args = ["" for i in range(grad_api_args_len)] + for name, (_, is_fwd_input, + grad_api_position), in backward_forward_inputs_map.items(): + tensor_wrapper_name = GetSavedName(name) + grad_api_args[ + grad_api_position] = f"egr::EagerUtils::RecoverTensorWrapper(&this->{tensor_wrapper_name}, nullptr)" + + for _, (ttype, fwd_position, + grad_api_position) in backward_grad_inputs_map.items(): + if IsPlainTensorType(ttype): + grad_api_args[ + grad_api_position] = f"hooked_grads[{fwd_position}][0]" + else: + assert IsVectorTensorType(ttype) + grad_api_args[ + grad_api_position] = f"hooked_grads[{fwd_position}]" + + for name, _, _, grad_api_position in backward_attrs_list: + saved_attribute_name = GetSavedName(name) + grad_api_args[grad_api_position] = f"this->{saved_attribute_name}" + grad_api_args_str = ", ".join(grad_api_args) + + # Construct grad_api returns + num_bwd_outputs = len(backward_grad_outputs_map.keys()) + returns_str = f"std::vector> returns({num_bwd_outputs});\n" + for _, (ttype, fwd_position, + grad_api_position) in backward_grad_outputs_map.items(): + # Infer Grad API Return Type + if num_bwd_outputs == 1: + # Single tensor output, return as is + if IsPlainTensorType(ttype): + returns_str += "returns[0] = { grad_api_returns };\n" + else: + assert IsVectorTensorType(ttype) + returns_str += "returns[0] = grad_api_returns;\n" + else: + # Rearrange output order accordingly + returns_str += f"returns[{fwd_position}] = grad_api_returns[{grad_api_position}];\n" + returns_str += f"if(NeedComplexToRealConversion()) HandleComplexGradToRealGrad(&returns);\n" + returns_str += f"return returns;\n" + + grad_node_name = GetGradNodeName(forward_api_name) + + fill_zero_str = "" + if forward_api_name in ops_to_fill_zero_for_empty_grads: + fill_zero_str = "egr::EagerUtils::FillZeroForEmptyGradInputs(&grads, this->InputMeta());\n" + + grad_api_namespace = f"paddle::experimental::{namespace}" + + self.node_definition_str = FUNCTION_TEMPLATE.format( + grad_node_name, fill_zero_str, grad_node_name, grad_api_namespace, + backward_api_name, grad_api_args_str, returns_str) + + logging.info(f"Generated Node Definition: {self.node_definition_str}") + + def run(self): + super().run() ##################### ## Code Generation ## ##################### self.GenerateNodeDeclaration() self.GenerateNodeDefinition() - self.GenerateForwardDefinition(is_inplaced=False) - - 
self.UpdateCoreOpsInformation(is_inplaced=False) - - self.GenerateInplacedForwardDygraphFunctions() class DygraphYamlGenerator(YamlGeneratorBase): @@ -1162,14 +1186,18 @@ def GenerateCode(self): forward_api_contents) if backward_api_contents is None: continue - d_generator = DygraphSingleFunctionGenerator( + function_generator = DygraphForwardFunctionGenerator( + forward_api_contents, backward_api_contents, namespace) + function_generator.run() + + node_generator = DygraphNodeGenerator( forward_api_contents, backward_api_contents, namespace) - d_generator.run() + node_generator.run() - self.forward_definition_str += d_generator.forward_definition_str + "\n" - self.forward_declaration_str += d_generator.forward_declaration_str + "\n" - self.node_declaration_str += d_generator.node_declaration_str + "\n" - self.node_definition_str += d_generator.node_definition_str + "\n" + self.forward_definition_str += function_generator.forward_definition_str + "\n" + self.forward_declaration_str += function_generator.forward_declaration_str + "\n" + self.node_declaration_str += node_generator.node_declaration_str + "\n" + self.node_definition_str += node_generator.node_definition_str + "\n" if len(namespace) > 0: if namespace.endswith("::"): From 14ebc19eed8844edae58950c4a93c5f36559f6c2 Mon Sep 17 00:00:00 2001 From: jim19930609 Date: Fri, 25 Mar 2022 08:01:21 +0000 Subject: [PATCH 3/6] Fixed minor issue --- .../final_state_generator/python_c_gen.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/paddle/fluid/eager/auto_code_generator/final_state_generator/python_c_gen.py b/paddle/fluid/eager/auto_code_generator/final_state_generator/python_c_gen.py index c7be9480f557d..0b45fee2de088 100644 --- a/paddle/fluid/eager/auto_code_generator/final_state_generator/python_c_gen.py +++ b/paddle/fluid/eager/auto_code_generator/final_state_generator/python_c_gen.py @@ -310,7 +310,7 @@ def GeneratePythonCFunction(self): dygraph_function_call_list[pos] = f"{name}" dygraph_function_call_str = ",".join(dygraph_function_call_list) - # Generate Python-C Function Definitions + # Generate Python-C Function Definitions if is_forward_only: fwd_function_name = FUNCTION_NAME_TEMPLATE.format( "paddle::experimental::", namespace, forward_api_name) @@ -332,9 +332,18 @@ def GeneratePythonCFunction(self): self.python_c_function_reg_str = PYTHON_C_FUNCTION_REG_TEMPLATE.format( forward_api_name, namespace, forward_api_name, forward_api_name) - if len(inplace_map) > 0: + if inplace_map: inplaced_forward_api_name = GetInplacedFunctionName( self.forward_api_name) + if is_forward_only: + inplaced_fwd_function_name = FUNCTION_NAME_TEMPLATE.format( + "paddle::experimental::", namespace, + inplaced_forward_api_name) + else: + inplaced_fwd_function_name = FUNCTION_NAME_TEMPLATE.format( + "::", namespace, + GetForwardFunctionName(inplaced_forward_api_name)) + assert len( inplace_map ) == 1, f"size of inplace_map must be 1, but inplace_map of \"{forward_api_name}\" op got {len(inplace_map)}" @@ -347,7 +356,7 @@ def GeneratePythonCFunction(self): self.python_c_function_str += PYTHON_C_FUNCTION_TEMPLATE.format( inplaced_forward_api_name, pythonc_record_event_str, inplaced_forward_api_name, get_eager_tensor_str, - parse_attributes_str, fwd_function_name, + parse_attributes_str, inplaced_fwd_function_name, dygraph_function_call_str, return_str) # Generate Python-C Function Registration From 9e7ecf689e8e678ccd8055f8e6358faf4d1c5e70 Mon Sep 17 00:00:00 2001 From: jim19930609 Date: Mon, 28 Mar 2022 06:31:35 +0000 
Subject: [PATCH 4/6] Adjusted logics of GenerateNodeCreationCodes and GenerateForwardDefinition --- .../final_state_generator/eager_gen.py | 156 ++++++++++-------- 1 file changed, 85 insertions(+), 71 deletions(-) diff --git a/paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py b/paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py index 3bac4046909e4..119cea68fd8c6 100644 --- a/paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py +++ b/paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py @@ -162,9 +162,21 @@ class {} : public egr::GradNodeBase {{ FORWARD_FUNCTION_TEMPLATE = \ """ {} {}({}) {{ - {} - {} - {} + // Dygraph Record Event +{} + // AMP Logic +{} + + // Get Input AutoGradMeta +{} + // Check Inplace +{} + // Forward API Call +{} + // Get Output AutoGradMeta +{} + // Node Creation +{} // Returns return {}; @@ -174,18 +186,10 @@ class {} : public egr::GradNodeBase {{ FORWARD_BODY_TEMPLATE = \ """ - // Get AutoGradMeta -{} - bool trace_backward = egr::Controller::Instance().HasGrad(); - bool require_any_grad = egr::EagerUtils::ComputeRequireGrad({}); -{} - // Forward API Call - {} -{} - {{ -{} -{} + bool trace_backward = egr::Controller::Instance().HasGrad(); + bool require_any_grad = egr::EagerUtils::ComputeRequireGrad({}); if(require_any_grad) {{ +{} egr::EagerUtils::PassStopGradient({}); // Node Construction @@ -203,7 +207,6 @@ class {} : public egr::GradNodeBase {{ {} {} }} - }} """ NAMESPACE_WRAPPER_TEMPLATE = \ @@ -625,7 +628,7 @@ def SlotNameMatching(self): f"Generated Backward Grad Output Map: {self.backward_grad_outputs_map}" ) - def GenerateNodeCreationCodes(self, forward_call_str, is_inplaced): + def GenerateNodeCreationCodes(self): forward_api_name = self.forward_api_name forward_inputs_position_map = self.forward_inputs_position_map forward_outputs_position_map = self.forward_outputs_position_map @@ -635,67 +638,20 @@ def GenerateNodeCreationCodes(self, forward_call_str, is_inplaced): backward_grad_outputs_map = self.backward_grad_outputs_map backward_attrs_list = self.backward_attrs_list optional_inputs = self.optional_inputs - inplace_map = self.inplace_map if is_inplaced else {} - # Get Input AutoGradMeta - inputs_autograd_meta_list = [] compute_require_grad_args_list = ["trace_backward"] - for name, (ttype, pos) in forward_inputs_position_map.items(): + for name, (_, _) in forward_inputs_position_map.items(): input_autograd_meta_name = GetAutoGradMetaName(name) - if IsPlainTensorType(ttype): - input_autograd_meta = f" egr::AutogradMeta* {input_autograd_meta_name} = egr::EagerUtils::nullable_autograd_meta({name});" - else: - assert IsVectorTensorType(ttype) - input_autograd_meta_vec_name = GetAutoGradMetaVectorName(name) - input_autograd_meta = f" std::vector {input_autograd_meta_vec_name} = egr::EagerUtils::nullable_autograd_meta({name});\n" - input_autograd_meta += f" std::vector* {input_autograd_meta_name} = &{input_autograd_meta_vec_name};" - - inputs_autograd_meta_list.append(input_autograd_meta) compute_require_grad_args_list.append(input_autograd_meta_name) - inputs_autograd_meta_str = "\n".join(inputs_autograd_meta_list) compute_require_grad_args_str = ",".join(compute_require_grad_args_list) - # Get Output AutoGradMeta - outputs_autograd_meta_list = [] + # Pass Stop Gradient Args pass_stop_gradient_args_list = ["false"] - num_fwd_outputs = len(forward_outputs_position_map.keys()) - for name, (rtype, pos) in forward_outputs_position_map.items(): + for name, (_, _) in 
forward_outputs_position_map.items():
             output_autograd_meta_name = GetAutoGradMetaName(name)
-            output_autograd_meta_vec_name = GetAutoGradMetaVectorName(name)
-            if num_fwd_outputs == 1:
-                if IsPlainTensorType(rtype):
-                    output_autograd_meta = f"    egr::AutogradMeta* {output_autograd_meta_name} = egr::EagerUtils::autograd_meta(&api_result);"
-                else:
-                    assert IsVectorTensorType(rtype)
-                    output_autograd_meta = f"    std::vector<egr::AutogradMeta*> {output_autograd_meta_vec_name} = egr::EagerUtils::autograd_meta(&api_result);\n"
-                    output_autograd_meta += f"    std::vector<egr::AutogradMeta*>* {output_autograd_meta_name} = &{output_autograd_meta_vec_name};"
-            else:
-                # Tuple api_result
-                if IsPlainTensorType(rtype):
-                    output_autograd_meta = f"    egr::AutogradMeta* {output_autograd_meta_name} = egr::EagerUtils::autograd_meta(&std::get<{pos}>(api_result));"
-                else:
-                    assert IsVectorTensorType(rtype)
-                    output_autograd_meta = f"    std::vector<egr::AutogradMeta*> {output_autograd_meta_vec_name} = egr::EagerUtils::autograd_meta(&std::get<{pos}>(api_result));\n"
-                    output_autograd_meta += f"    std::vector<egr::AutogradMeta*>* {output_autograd_meta_name} = &{output_autograd_meta_vec_name};"
-
-            outputs_autograd_meta_list.append(output_autograd_meta)
             pass_stop_gradient_args_list.append(output_autograd_meta_name)
-
-        # ComputeRequireGrad & PassStopGradient
-        outputs_autograd_meta_str = "\n".join(outputs_autograd_meta_list)
         pass_stop_gradient_args_str = ",".join(pass_stop_gradient_args_list)
 
-        # Check Inplace
-        check_inplace_str = ""
-        bump_inplace_version_str = ""
-        if is_inplaced:
-            for inplace_name in inplace_map.keys():
-                inplace_autograd_meta_name = GetAutoGradMetaName(inplace_name)
-                check_inplace_str += CHECK_INPLACE_TEMPLATE.format(
-                    inplace_name, inplace_autograd_meta_name)
-                bump_inplace_version_str += BUMP_INPLACE_VERSION_TEMPLATE.format(
-                    inplace_name, inplace_name)
-
         # Node Construction
         num_backward_inputs = len(forward_outputs_position_map.keys())
         num_backward_outputs = len(forward_inputs_position_map.keys())
@@ -719,6 +675,7 @@
 
         # SetTensorWrappers
         set_tensor_wrappers_list = []
+        num_fwd_outputs = len(forward_outputs_position_map.keys())
         for name, (atype, is_fwd_input,
                    pos) in backward_forward_inputs_map.items():
             is_optional = (name in optional_inputs)
@@ -794,9 +751,7 @@
             node_creation_event_str = f"paddle::platform::RecordEvent node_creation_record_event(\"{node_event_name}\", paddle::platform::TracerEventType::Operator, 1);\n"
 
         self.node_creation_str = FORWARD_BODY_TEMPLATE.format(
-            inputs_autograd_meta_str, compute_require_grad_args_str,
-            check_inplace_str, forward_call_str, bump_inplace_version_str,
-            node_creation_event_str, outputs_autograd_meta_str,
+            compute_require_grad_args_str, node_creation_event_str,
             pass_stop_gradient_args_str, node_construction_str,
             set_attributes_str, set_tensor_wrappers_str, set_grad_out_meta_str,
             set_edges_str, set_out_rank_str, set_history_str,
             set_grad_in_meta_str, set_retain_grad_str)
@@ -973,7 +928,64 @@ def GenerateForwardDefinition(self, is_inplaced):
             returns_str = ", ".join(returns_list)
             returns_str = f"std::make_tuple({returns_str})"
 
-        self.GenerateNodeCreationCodes(forward_call_str, is_inplaced)
+        # Node Creation Pre-Processing
+        # 1. Get Input AutoGradMeta
+        inputs_autograd_meta_list = []
+        compute_require_grad_args_list = ["trace_backward"]
+        for name, (ttype, pos) in forward_inputs_position_map.items():
+            input_autograd_meta_name = GetAutoGradMetaName(name)
+            if IsPlainTensorType(ttype):
+                input_autograd_meta = f"    egr::AutogradMeta* {input_autograd_meta_name} = egr::EagerUtils::nullable_autograd_meta({name});"
+            else:
+                assert IsVectorTensorType(ttype)
+                input_autograd_meta_vec_name = GetAutoGradMetaVectorName(name)
+                input_autograd_meta = f"    std::vector<egr::AutogradMeta*> {input_autograd_meta_vec_name} = egr::EagerUtils::nullable_autograd_meta({name});\n"
+                input_autograd_meta += f"    std::vector<egr::AutogradMeta*>* {input_autograd_meta_name} = &{input_autograd_meta_vec_name};"
+
+            inputs_autograd_meta_list.append(input_autograd_meta)
+            compute_require_grad_args_list.append(input_autograd_meta_name)
+        inputs_autograd_meta_str = "\n".join(inputs_autograd_meta_list)
+        compute_require_grad_args_str = ",".join(compute_require_grad_args_list)
+
+        # 2. Get Output AutoGradMeta
+        outputs_autograd_meta_list = []
+        num_fwd_outputs = len(forward_outputs_position_map.keys())
+        for name, (rtype, pos) in forward_outputs_position_map.items():
+            output_autograd_meta_name = GetAutoGradMetaName(name)
+            output_autograd_meta_vec_name = GetAutoGradMetaVectorName(name)
+            if num_fwd_outputs == 1:
+                if IsPlainTensorType(rtype):
+                    output_autograd_meta = f"    egr::AutogradMeta* {output_autograd_meta_name} = egr::EagerUtils::autograd_meta(&api_result);"
+                else:
+                    assert IsVectorTensorType(rtype)
+                    output_autograd_meta = f"    std::vector<egr::AutogradMeta*> {output_autograd_meta_vec_name} = egr::EagerUtils::autograd_meta(&api_result);\n"
+                    output_autograd_meta += f"    std::vector<egr::AutogradMeta*>* {output_autograd_meta_name} = &{output_autograd_meta_vec_name};"
+            else:
+                # Tuple api_result
+                if IsPlainTensorType(rtype):
+                    output_autograd_meta = f"    egr::AutogradMeta* {output_autograd_meta_name} = egr::EagerUtils::autograd_meta(&std::get<{pos}>(api_result));"
+                else:
+                    assert IsVectorTensorType(rtype)
+                    output_autograd_meta = f"    std::vector<egr::AutogradMeta*> {output_autograd_meta_vec_name} = egr::EagerUtils::autograd_meta(&std::get<{pos}>(api_result));\n"
+                    output_autograd_meta += f"    std::vector<egr::AutogradMeta*>* {output_autograd_meta_name} = &{output_autograd_meta_vec_name};"
+
+            outputs_autograd_meta_list.append(output_autograd_meta)
+
+        # 3. ComputeRequireGrad & PassStopGradient
+        outputs_autograd_meta_str = "\n".join(outputs_autograd_meta_list)
+
+        # 4. 
Check Inplace + check_inplace_str = "" + bump_inplace_version_str = "" + if is_inplaced: + for inplace_name in inplace_map.keys(): + inplace_autograd_meta_name = GetAutoGradMetaName(inplace_name) + check_inplace_str += CHECK_INPLACE_TEMPLATE.format( + inplace_name, inplace_autograd_meta_name) + bump_inplace_version_str += BUMP_INPLACE_VERSION_TEMPLATE.format( + inplace_name, inplace_name) + + self.GenerateNodeCreationCodes() node_creation_str = self.node_creation_str dygraph_event_str = f"paddle::platform::RecordEvent dygraph_entrance_record_event(\"{forward_api_name} dygraph\", paddle::platform::TracerEventType::Operator, 1);" @@ -1001,7 +1013,9 @@ def GenerateForwardDefinition(self, is_inplaced): self.forward_definition_str += FORWARD_FUNCTION_TEMPLATE.format( returns_type_str, forward_function_name, inputs_args_definition_str, - dygraph_event_str, amp_logic_str, node_creation_str, returns_str) + dygraph_event_str, amp_logic_str, inputs_autograd_meta_str, + check_inplace_str, forward_call_str, outputs_autograd_meta_str, + node_creation_str, returns_str) self.forward_declaration_str += f"{returns_type_str} {forward_function_name}({inputs_args_declaration_str});\n" logging.info( From b9454249ff85eb2cd419e1a5e478790f6e7c73a5 Mon Sep 17 00:00:00 2001 From: jim19930609 Date: Mon, 28 Mar 2022 07:41:49 +0000 Subject: [PATCH 5/6] Fixed issues --- .../final_state_generator/eager_gen.py | 29 +++++++------------ 1 file changed, 11 insertions(+), 18 deletions(-) diff --git a/paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py b/paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py index 119cea68fd8c6..1abb47889c6ad 100644 --- a/paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py +++ b/paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py @@ -168,12 +168,14 @@ class {} : public egr::GradNodeBase {{ {} // Get Input AutoGradMeta -{} - // Check Inplace {} // Forward API Call {} // Get Output AutoGradMeta +{} + bool trace_backward = egr::Controller::Instance().HasGrad(); + bool require_any_grad = egr::EagerUtils::ComputeRequireGrad({}); + // Check Inplace {} // Node Creation {} @@ -186,8 +188,6 @@ class {} : public egr::GradNodeBase {{ FORWARD_BODY_TEMPLATE = \ """ - bool trace_backward = egr::Controller::Instance().HasGrad(); - bool require_any_grad = egr::EagerUtils::ComputeRequireGrad({}); if(require_any_grad) {{ {} egr::EagerUtils::PassStopGradient({}); @@ -297,7 +297,6 @@ class {} : public egr::GradNodeBase {{ CHECK_INPLACE_TEMPLATE = \ """ - // Check Inplace egr::EagerUtils::CheckInplace({}, {}, require_any_grad);\n """ @@ -639,12 +638,6 @@ def GenerateNodeCreationCodes(self): backward_attrs_list = self.backward_attrs_list optional_inputs = self.optional_inputs - compute_require_grad_args_list = ["trace_backward"] - for name, (_, _) in forward_inputs_position_map.items(): - input_autograd_meta_name = GetAutoGradMetaName(name) - compute_require_grad_args_list.append(input_autograd_meta_name) - compute_require_grad_args_str = ",".join(compute_require_grad_args_list) - # Pass Stop Gradient Args pass_stop_gradient_args_list = ["false"] for name, (_, _) in forward_outputs_position_map.items(): @@ -751,11 +744,10 @@ def GenerateNodeCreationCodes(self): node_creation_event_str = f"paddle::platform::RecordEvent node_creation_record_event(\"{node_event_name}\", paddle::platform::TracerEventType::Operator, 1);\n" self.node_creation_str = FORWARD_BODY_TEMPLATE.format( - compute_require_grad_args_str, node_creation_event_str, - 
pass_stop_gradient_args_str, node_construction_str, - set_attributes_str, set_tensor_wrappers_str, set_grad_out_meta_str, - set_edges_str, set_out_rank_str, set_history_str, - set_grad_in_meta_str, set_retain_grad_str) + node_creation_event_str, pass_stop_gradient_args_str, + node_construction_str, set_attributes_str, set_tensor_wrappers_str, + set_grad_out_meta_str, set_edges_str, set_out_rank_str, + set_history_str, set_grad_in_meta_str, set_retain_grad_str) def run(self): # Basic Validation Check @@ -1014,8 +1006,9 @@ def GenerateForwardDefinition(self, is_inplaced): self.forward_definition_str += FORWARD_FUNCTION_TEMPLATE.format( returns_type_str, forward_function_name, inputs_args_definition_str, dygraph_event_str, amp_logic_str, inputs_autograd_meta_str, - check_inplace_str, forward_call_str, outputs_autograd_meta_str, - node_creation_str, returns_str) + forward_call_str, outputs_autograd_meta_str, + compute_require_grad_args_str, check_inplace_str, node_creation_str, + returns_str) self.forward_declaration_str += f"{returns_type_str} {forward_function_name}({inputs_args_declaration_str});\n" logging.info( From d6d952070d84fedf6fe8f533381e74706ab57096 Mon Sep 17 00:00:00 2001 From: jim19930609 Date: Mon, 28 Mar 2022 09:48:27 +0000 Subject: [PATCH 6/6] Fixed minor issue --- .../auto_code_generator/final_state_generator/eager_gen.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py b/paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py index 1abb47889c6ad..f5ae3da91dadc 100644 --- a/paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py +++ b/paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py @@ -175,7 +175,8 @@ class {} : public egr::GradNodeBase {{ {} bool trace_backward = egr::Controller::Instance().HasGrad(); bool require_any_grad = egr::EagerUtils::ComputeRequireGrad({}); - // Check Inplace + // Check Inplace & Bump Inplace Version +{} {} // Node Creation {} @@ -1007,8 +1008,8 @@ def GenerateForwardDefinition(self, is_inplaced): returns_type_str, forward_function_name, inputs_args_definition_str, dygraph_event_str, amp_logic_str, inputs_autograd_meta_str, forward_call_str, outputs_autograd_meta_str, - compute_require_grad_args_str, check_inplace_str, node_creation_str, - returns_str) + compute_require_grad_args_str, check_inplace_str, + bump_inplace_version_str, node_creation_str, returns_str) self.forward_declaration_str += f"{returns_type_str} {forward_function_name}({inputs_args_declaration_str});\n" logging.info(