[DoubleGrad PR #2] Adjusted logics of GenerateNodeCreationCodes and GenerateForwardDefinition #41016

Merged
@@ -162,9 +162,24 @@ class {} : public egr::GradNodeBase {{
FORWARD_FUNCTION_TEMPLATE = \
"""
{} {}({}) {{
{}
{}
{}
// Dygraph Record Event
{}
// AMP Logic
{}

// Get Input AutoGradMeta
{}
// Forward API Call
{}
// Get Output AutoGradMeta
{}
bool trace_backward = egr::Controller::Instance().HasGrad();
bool require_any_grad = egr::EagerUtils::ComputeRequireGrad({});
// Check Inplace & Bump Inplace Version
{}
{}
// Node Creation
{}

// Returns
return {};
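For context on how these positional templates are consumed, here is a minimal, self-contained sketch of the fill pattern. The template is abbreviated and every argument value is hypothetical, not actual generator output:

```python
# Minimal sketch of the str.format fill pattern used for templates such as
# FORWARD_FUNCTION_TEMPLATE. Doubled braces {{ }} escape to literal C++
# braces; each bare {} is filled positionally, so the argument order of the
# .format() call must match the placeholder order exactly.
SKETCH_TEMPLATE = """
{} {}({}) {{
  // Forward API Call
{}
  // Returns
  return {};
}}
"""

generated = SKETCH_TEMPLATE.format(
    "paddle::experimental::Tensor",           # returns_type_str
    "relu_dygraph_function",                  # hypothetical function name
    "const paddle::experimental::Tensor& x",  # inputs_args_definition_str
    "  auto api_result = paddle::experimental::relu(x);",  # forward_call_str
    "api_result",                             # returns_str
)
print(generated)
```

Because the slots are purely positional, every placeholder added to or removed from a template in this PR requires the matching `.format(...)` call site to change in lockstep, which is what the edits below do.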
@@ -174,18 +189,8 @@ class {} : public egr::GradNodeBase {{

FORWARD_BODY_TEMPLATE = \
"""
// Get AutoGradMeta
{}
bool trace_backward = egr::Controller::Instance().HasGrad();
bool require_any_grad = egr::EagerUtils::ComputeRequireGrad({});
{}
// Forward API Call
{}
{}
{{
{}
{}
if(require_any_grad) {{
{}
egr::EagerUtils::PassStopGradient({});

// Node Construction
@@ -203,7 +208,6 @@ class {} : public egr::GradNodeBase {{
{}
{}
}}
}}
"""

NAMESPACE_WRAPPER_TEMPLATE = \
@@ -294,7 +298,6 @@ class {} : public egr::GradNodeBase {{

CHECK_INPLACE_TEMPLATE = \
"""
// Check Inplace
egr::EagerUtils::CheckInplace({}, {}, require_any_grad);\n
"""

@@ -625,7 +628,7 @@ def SlotNameMatching(self):
f"Generated Backward Grad Output Map: {self.backward_grad_outputs_map}"
)

def GenerateNodeCreationCodes(self, forward_call_str, is_inplaced):
def GenerateNodeCreationCodes(self):
forward_api_name = self.forward_api_name
forward_inputs_position_map = self.forward_inputs_position_map
forward_outputs_position_map = self.forward_outputs_position_map
@@ -635,67 +638,14 @@ def GenerateNodeCreationCodes(self, forward_call_str, is_inplaced):
backward_grad_outputs_map = self.backward_grad_outputs_map
backward_attrs_list = self.backward_attrs_list
optional_inputs = self.optional_inputs
inplace_map = self.inplace_map if is_inplaced else {}

# Get Input AutoGradMeta
inputs_autograd_meta_list = []
compute_require_grad_args_list = ["trace_backward"]
for name, (ttype, pos) in forward_inputs_position_map.items():
input_autograd_meta_name = GetAutoGradMetaName(name)
if IsPlainTensorType(ttype):
input_autograd_meta = f" egr::AutogradMeta* {input_autograd_meta_name} = egr::EagerUtils::nullable_autograd_meta({name});"
else:
assert IsVectorTensorType(ttype)
input_autograd_meta_vec_name = GetAutoGradMetaVectorName(name)
input_autograd_meta = f" std::vector<egr::AutogradMeta*> {input_autograd_meta_vec_name} = egr::EagerUtils::nullable_autograd_meta({name});\n"
input_autograd_meta += f" std::vector<egr::AutogradMeta*>* {input_autograd_meta_name} = &{input_autograd_meta_vec_name};"

inputs_autograd_meta_list.append(input_autograd_meta)
compute_require_grad_args_list.append(input_autograd_meta_name)
inputs_autograd_meta_str = "\n".join(inputs_autograd_meta_list)
compute_require_grad_args_str = ",".join(compute_require_grad_args_list)

# Get Output AutoGradMeta
outputs_autograd_meta_list = []
# Pass Stop Gradient Args
pass_stop_gradient_args_list = ["false"]
num_fwd_outputs = len(forward_outputs_position_map.keys())
for name, (rtype, pos) in forward_outputs_position_map.items():
for name, (_, _) in forward_outputs_position_map.items():
output_autograd_meta_name = GetAutoGradMetaName(name)
output_autograd_meta_vec_name = GetAutoGradMetaVectorName(name)
if num_fwd_outputs == 1:
if IsPlainTensorType(rtype):
output_autograd_meta = f" egr::AutogradMeta* {output_autograd_meta_name} = egr::EagerUtils::autograd_meta(&api_result);"
else:
assert IsVectorTensorType(rtype)
output_autograd_meta = f" std::vector<egr::AutogradMeta*> {output_autograd_meta_vec_name} = egr::EagerUtils::autograd_meta(&api_result);\n"
output_autograd_meta += f" std::vector<egr::AutogradMeta*>* {output_autograd_meta_name} = &{output_autograd_meta_vec_name};"
else:
# Tuple api_result
if IsPlainTensorType(rtype):
output_autograd_meta = f" egr::AutogradMeta* {output_autograd_meta_name} = egr::EagerUtils::autograd_meta(&std::get<{pos}>(api_result));"
else:
assert IsVectorTensorType(rtype)
output_autograd_meta = f" std::vector<egr::AutogradMeta*> {output_autograd_meta_vec_name} = egr::EagerUtils::autograd_meta(&std::get<{pos}>(api_result));\n"
output_autograd_meta += f" std::vector<egr::AutogradMeta*>* {output_autograd_meta_name} = &{output_autograd_meta_vec_name};"

outputs_autograd_meta_list.append(output_autograd_meta)
pass_stop_gradient_args_list.append(output_autograd_meta_name)

# ComputeRequireGrad & PassStopGradient
outputs_autograd_meta_str = "\n".join(outputs_autograd_meta_list)
pass_stop_gradient_args_str = ",".join(pass_stop_gradient_args_list)

# Check Inplace
check_inplace_str = ""
bump_inplace_version_str = ""
if is_inplaced:
for inplace_name in inplace_map.keys():
inplace_autograd_meta_name = GetAutoGradMetaName(inplace_name)
check_inplace_str += CHECK_INPLACE_TEMPLATE.format(
inplace_name, inplace_autograd_meta_name)
bump_inplace_version_str += BUMP_INPLACE_VERSION_TEMPLATE.format(
inplace_name, inplace_name)

# Node Construction
num_backward_inputs = len(forward_outputs_position_map.keys())
num_backward_outputs = len(forward_inputs_position_map.keys())
@@ -719,6 +669,7 @@ def GenerateNodeCreationCodes(self, forward_call_str, is_inplaced):

# SetTensorWrappers
set_tensor_wrappers_list = []
num_fwd_outputs = len(forward_outputs_position_map.keys())
for name, (atype, is_fwd_input,
pos) in backward_forward_inputs_map.items():
is_optional = (name in optional_inputs)
@@ -794,13 +745,10 @@ def GenerateNodeCreationCodes(self, forward_call_str, is_inplaced):
node_creation_event_str = f"paddle::platform::RecordEvent node_creation_record_event(\"{node_event_name}\", paddle::platform::TracerEventType::Operator, 1);\n"

self.node_creation_str = FORWARD_BODY_TEMPLATE.format(
inputs_autograd_meta_str, compute_require_grad_args_str,
check_inplace_str, forward_call_str, bump_inplace_version_str,
node_creation_event_str, outputs_autograd_meta_str,
pass_stop_gradient_args_str, node_construction_str,
set_attributes_str, set_tensor_wrappers_str, set_grad_out_meta_str,
set_edges_str, set_out_rank_str, set_history_str,
set_grad_in_meta_str, set_retain_grad_str)
node_creation_event_str, pass_stop_gradient_args_str,
node_construction_str, set_attributes_str, set_tensor_wrappers_str,
set_grad_out_meta_str, set_edges_str, set_out_rank_str,
set_history_str, set_grad_in_meta_str, set_retain_grad_str)

def run(self):
# Basic Validation Check
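To make the narrowed role of FORWARD_BODY_TEMPLATE concrete, here is a toy fill with far fewer slots than the real template and entirely hypothetical argument values; after this PR the template covers only the `if(require_any_grad)` node-creation block:

```python
# Toy version of the slimmed-down FORWARD_BODY_TEMPLATE: it now wraps only
# node construction in the require_any_grad guard, while AutoGradMeta setup,
# the forward call, and inplace checks are emitted by GenerateForwardDefinition.
TOY_BODY_TEMPLATE = """
  if(require_any_grad) {{
{}
    egr::EagerUtils::PassStopGradient({});
    // Node Construction
{}
  }}
"""

node_creation_str = TOY_BODY_TEMPLATE.format(
    '    paddle::platform::RecordEvent node_creation_record_event("relu node_creation", paddle::platform::TracerEventType::Operator, 1);',
    "false,out_autograd_meta",  # pass_stop_gradient_args_str
    "    auto grad_node = std::make_shared<ReluGradNode>(1, 1);",  # hypothetical node type
)
print(node_creation_str)
```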
@@ -973,7 +921,64 @@ def GenerateForwardDefinition(self, is_inplaced):
returns_str = ", ".join(returns_list)
returns_str = f"std::make_tuple({returns_str})"

self.GenerateNodeCreationCodes(forward_call_str, is_inplaced)
# Node Creation Pre-Processing
# 1. Get Input AutoGradMeta
inputs_autograd_meta_list = []
compute_require_grad_args_list = ["trace_backward"]
for name, (ttype, pos) in forward_inputs_position_map.items():
input_autograd_meta_name = GetAutoGradMetaName(name)
if IsPlainTensorType(ttype):
input_autograd_meta = f" egr::AutogradMeta* {input_autograd_meta_name} = egr::EagerUtils::nullable_autograd_meta({name});"
else:
assert IsVectorTensorType(ttype)
input_autograd_meta_vec_name = GetAutoGradMetaVectorName(name)
input_autograd_meta = f" std::vector<egr::AutogradMeta*> {input_autograd_meta_vec_name} = egr::EagerUtils::nullable_autograd_meta({name});\n"
input_autograd_meta += f" std::vector<egr::AutogradMeta*>* {input_autograd_meta_name} = &{input_autograd_meta_vec_name};"

inputs_autograd_meta_list.append(input_autograd_meta)
compute_require_grad_args_list.append(input_autograd_meta_name)
inputs_autograd_meta_str = "\n".join(inputs_autograd_meta_list)
compute_require_grad_args_str = ",".join(compute_require_grad_args_list)

# 2. Get Output AutoGradMeta
outputs_autograd_meta_list = []
num_fwd_outputs = len(forward_outputs_position_map.keys())
for name, (rtype, pos) in forward_outputs_position_map.items():
output_autograd_meta_name = GetAutoGradMetaName(name)
output_autograd_meta_vec_name = GetAutoGradMetaVectorName(name)
if num_fwd_outputs == 1:
if IsPlainTensorType(rtype):
output_autograd_meta = f" egr::AutogradMeta* {output_autograd_meta_name} = egr::EagerUtils::autograd_meta(&api_result);"
else:
assert IsVectorTensorType(rtype)
output_autograd_meta = f" std::vector<egr::AutogradMeta*> {output_autograd_meta_vec_name} = egr::EagerUtils::autograd_meta(&api_result);\n"
output_autograd_meta += f" std::vector<egr::AutogradMeta*>* {output_autograd_meta_name} = &{output_autograd_meta_vec_name};"
else:
# Tuple api_result
if IsPlainTensorType(rtype):
output_autograd_meta = f" egr::AutogradMeta* {output_autograd_meta_name} = egr::EagerUtils::autograd_meta(&std::get<{pos}>(api_result));"
else:
assert IsVectorTensorType(rtype)
output_autograd_meta = f" std::vector<egr::AutogradMeta*> {output_autograd_meta_vec_name} = egr::EagerUtils::autograd_meta(&std::get<{pos}>(api_result));\n"
output_autograd_meta += f" std::vector<egr::AutogradMeta*>* {output_autograd_meta_name} = &{output_autograd_meta_vec_name};"

outputs_autograd_meta_list.append(output_autograd_meta)

# 3. ComputeRequireGrad & PassStopGradient
outputs_autograd_meta_str = "\n".join(outputs_autograd_meta_list)

# 4. Check Inplace
check_inplace_str = ""
bump_inplace_version_str = ""
if is_inplaced:
for inplace_name in inplace_map.keys():
inplace_autograd_meta_name = GetAutoGradMetaName(inplace_name)
check_inplace_str += CHECK_INPLACE_TEMPLATE.format(
inplace_name, inplace_autograd_meta_name)
bump_inplace_version_str += BUMP_INPLACE_VERSION_TEMPLATE.format(
inplace_name, inplace_name)

self.GenerateNodeCreationCodes()

node_creation_str = self.node_creation_str
dygraph_event_str = f"paddle::platform::RecordEvent dygraph_entrance_record_event(\"{forward_api_name} dygraph\", paddle::platform::TracerEventType::Operator, 1);"
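To see what the relocated Input-AutoGradMeta loop (step 1 above) produces, here is a standalone sketch with stubbed helpers and a hypothetical input map; the real helper functions are defined elsewhere in eager_gen.py:

```python
# Standalone sketch of the Input AutoGradMeta loop that now lives in
# GenerateForwardDefinition. Helpers are stubbed and the input map is
# hypothetical; the branch mirrors the plain-vs-vector tensor handling above.
def GetAutoGradMetaName(name):
    return f"{name}_autograd_meta"      # assumed naming convention

def GetAutoGradMetaVectorName(name):
    return f"{name}_autograd_meta_vec"  # assumed naming convention

def IsPlainTensorType(ttype):
    return ttype == "Tensor"

forward_inputs_position_map = {
    "x": ("Tensor", 0),
    "params": ("std::vector<Tensor>", 1),
}

inputs_autograd_meta_list = []
compute_require_grad_args_list = ["trace_backward"]
for name, (ttype, pos) in forward_inputs_position_map.items():
    meta = GetAutoGradMetaName(name)
    if IsPlainTensorType(ttype):
        code = f"  egr::AutogradMeta* {meta} = egr::EagerUtils::nullable_autograd_meta({name});"
    else:
        vec = GetAutoGradMetaVectorName(name)
        code = (f"  std::vector<egr::AutogradMeta*> {vec} = egr::EagerUtils::nullable_autograd_meta({name});\n"
                f"  std::vector<egr::AutogradMeta*>* {meta} = &{vec};")
    inputs_autograd_meta_list.append(code)
    compute_require_grad_args_list.append(meta)

print("\n".join(inputs_autograd_meta_list))
print(",".join(compute_require_grad_args_list))
```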
@@ -1001,7 +1006,10 @@ def GenerateForwardDefinition(self, is_inplaced):

self.forward_definition_str += FORWARD_FUNCTION_TEMPLATE.format(
returns_type_str, forward_function_name, inputs_args_definition_str,
dygraph_event_str, amp_logic_str, node_creation_str, returns_str)
dygraph_event_str, amp_logic_str, inputs_autograd_meta_str,
forward_call_str, outputs_autograd_meta_str,
compute_require_grad_args_str, check_inplace_str,
bump_inplace_version_str, node_creation_str, returns_str)
self.forward_declaration_str += f"{returns_type_str} {forward_function_name}({inputs_args_declaration_str});\n"

logging.info(