Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

support inplace in dygraph eager_final state #40695

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,14 @@ def ParseArguments():
#################
### Helpers ###
#################
def RecoverBaseNameOfInplaceFunction(function_name):
    """Return the base API name for an inplaced function name.

    Inverse of GetInplacedFunctionName: "add_" -> "add". Only strips the
    trailing '_' inplace suffix when it is actually present, so passing a
    name that is not inplaced returns it unchanged instead of silently
    chopping off its last character.
    """
    if function_name.endswith("_"):
        return function_name[:-1]
    return function_name


def GetInplacedFunctionName(function_name):
    """Derive the inplaced variant's name by appending the '_' suffix.

    e.g. "add" -> "add_".
    """
    return f"{function_name}_"


def FindGradName(string):
    """Build the gradient-API name by appending the '_grad' suffix.

    e.g. "matmul" -> "matmul_grad".
    """
    return f"{string}_grad"

Expand Down Expand Up @@ -149,6 +157,24 @@ def ReadBwdFile(filepath):
######################
### Yaml Parsers ###
######################
def ParseInplaceInfo(string):
    """Parse a yaml 'inplace' spec into an {input_name: output_name} map.

    Args:
        string: spec of the form "(x -> out0), (y -> out2)".

    Returns:
        dict mapping each inplaced input name to its aliased output name,
        e.g. {"x": "out0", "y": "out2"}.
    """
    inplace_map = {}
    for pair in string.split(","):
        pair = pair.strip()
        # Tolerate trailing commas / empty input instead of raising.
        if not pair:
            continue

        # Strip the optional surrounding parentheses.
        if pair.startswith("("):
            pair = pair[1:]
        if pair.endswith(")"):
            pair = pair[:-1]

        # Split once instead of twice per pair.
        parts = pair.split("->")
        inplace_map[parts[0].strip()] = parts[1].strip()

    return inplace_map


def RemoveSpecialSymbolsInName(string):
# Remove any name after '@'
ret = string.split("@")[0]
Expand Down Expand Up @@ -683,9 +709,10 @@ def GenerateNodeDefinition(fwd_api_name, bwd_api_name, backward_fwd_input_map,

def GenerateNodeCreationCodes(
fwd_api_name, bwd_api_name, forward_inputs_position_map,
forward_outputs_position_map, forward_attrs_list,
forward_outputs_position_map, forward_attrs_list, forward_call_str,
backward_fwd_input_map, backward_grad_input_map,
backward_grad_output_map, backward_attrs_list, optional_inputs):
backward_grad_output_map, backward_attrs_list, optional_inputs,
inplace_map):
# fwd_api_name = ""
# forward_inputs_position_map = { "name" : [type, fwd_position] }
# forward_outputs_position_map = { "name" : [type, fwd_position] }
Expand Down Expand Up @@ -722,19 +749,19 @@ def GenerateNodeCreationCodes(
output_autograd_meta_vec_name = GetAutoGradMetaVectorName(name)
if num_fwd_outputs == 1:
if IsPlainTensorType(rtype):
output_autograd_meta = f" egr::AutogradMeta* {output_autograd_meta_name} = egr::EagerUtils::autograd_meta(&api_result);"
output_autograd_meta = f" egr::AutogradMeta* {output_autograd_meta_name} = egr::EagerUtils::autograd_meta(&api_result);"
else:
assert IsVectorTensorType(rtype)
output_autograd_meta = f" std::vector<egr::AutogradMeta*> {output_autograd_meta_vec_name} = egr::EagerUtils::autograd_meta(&api_result);\n"
output_autograd_meta += f" std::vector<egr::AutogradMeta*>* {output_autograd_meta_name} = &{output_autograd_meta_vec_name};"
output_autograd_meta = f" std::vector<egr::AutogradMeta*> {output_autograd_meta_vec_name} = egr::EagerUtils::autograd_meta(&api_result);\n"
output_autograd_meta += f" std::vector<egr::AutogradMeta*>* {output_autograd_meta_name} = &{output_autograd_meta_vec_name};"
else:
# Tuple api_result
if IsPlainTensorType(rtype):
output_autograd_meta = f" egr::AutogradMeta* {output_autograd_meta_name} = egr::EagerUtils::autograd_meta(&std::get<{pos}>(api_result));"
output_autograd_meta = f" egr::AutogradMeta* {output_autograd_meta_name} = egr::EagerUtils::autograd_meta(&std::get<{pos}>(api_result));"
else:
assert IsVectorTensorType(rtype)
output_autograd_meta = f" std::vector<egr::AutogradMeta*> {output_autograd_meta_vec_name} = egr::EagerUtils::autograd_meta(&api_result[{pos}]);\n"
output_autograd_meta += f" std::vector<egr::AutogradMeta*>* {output_autograd_meta_name} = &{output_autograd_meta_vec_name};"
output_autograd_meta = f" std::vector<egr::AutogradMeta*> {output_autograd_meta_vec_name} = egr::EagerUtils::autograd_meta(&api_result[{pos}]);\n"
output_autograd_meta += f" std::vector<egr::AutogradMeta*>* {output_autograd_meta_name} = &{output_autograd_meta_vec_name};"

outputs_autograd_meta_list.append(output_autograd_meta)
pass_stop_gradient_args_list.append(output_autograd_meta_name)
Expand All @@ -743,16 +770,34 @@ def GenerateNodeCreationCodes(
outputs_autograd_meta_str = "\n".join(outputs_autograd_meta_list)
pass_stop_gradient_args_str = ",".join(pass_stop_gradient_args_list)

# Check Inplace
check_inplace_str = ""
bump_inplace_version_str = ""
for inplace_name in inplace_map.keys():
inplace_autograd_meta_name = GetAutoGradMetaName(inplace_name)
check_inplace_str += f"""
// Check Inplace
egr::EagerUtils::CheckInplace({inplace_name}, {inplace_autograd_meta_name}, require_any_grad);\n
"""

bump_inplace_version_str += f"""
// Bump Inplace Version
{inplace_name}.bump_inplace_version();
VLOG(3) << \"Tensor(\" << {inplace_name}.name() << \") uses Inplace Strategy.\";\n
"""

# Node Construction
num_bwd_inputs = len(backward_grad_input_map.keys())
num_bwd_outputs = len(backward_grad_output_map.keys())
grad_node_name = GetGradNodeName(fwd_api_name)
node_construction_str = f" auto grad_node = std::make_shared<{grad_node_name}>({num_bwd_inputs}, {num_bwd_outputs});"
grad_node_name = GetGradNodeName(
RecoverBaseNameOfInplaceFunction(
fwd_api_name)) if inplace_map else GetGradNodeName(fwd_api_name)
node_construction_str = f" auto grad_node = std::make_shared<{grad_node_name}>({num_bwd_inputs}, {num_bwd_outputs});"

# SetAttributes
set_attributes_list = []
for name, _, _, _ in backward_attrs_list:
set_attributes = f" grad_node->SetAttribute{name}({name});"
set_attributes = f" grad_node->SetAttribute{name}({name});"
set_attributes_list.append(set_attributes)
set_attributes_str = "\n".join(set_attributes_list)

Expand All @@ -763,9 +808,9 @@ def GenerateNodeCreationCodes(

if is_fwd_input:
if is_optional:
set_tensor_wrappers = f" if({name}.is_initialized()) grad_node->SetTensorWrapper{name}({name}, true);"
set_tensor_wrappers = f" if({name}.is_initialized()) grad_node->SetTensorWrapper{name}({name}, true);"
else:
set_tensor_wrappers = f" grad_node->SetTensorWrapper{name}({name}, true);"
set_tensor_wrappers = f" grad_node->SetTensorWrapper{name}({name}, true);"
else:
if num_fwd_outputs > 1:
# Aligned with forward output position
Expand All @@ -776,9 +821,9 @@ def GenerateNodeCreationCodes(
tw_name = f"api_result"

if is_optional:
set_tensor_wrappers = f" if({tw_name}.is_initialized()) grad_node->SetTensorWrapper{name}({tw_name}, false);"
set_tensor_wrappers = f" if({tw_name}.is_initialized()) grad_node->SetTensorWrapper{name}({tw_name}, false);"
else:
set_tensor_wrappers = f" grad_node->SetTensorWrapper{name}({tw_name}, false);"
set_tensor_wrappers = f" grad_node->SetTensorWrapper{name}({tw_name}, false);"
set_tensor_wrappers_list.append(set_tensor_wrappers)
set_tensor_wrappers_str = "\n".join(set_tensor_wrappers_list)

Expand All @@ -787,8 +832,8 @@ def GenerateNodeCreationCodes(
set_edges_list = []
for name, (_, pos) in forward_inputs_position_map.items():
input_autograd_meta_name = GetAutoGradMetaName(name)
set_grad_out_meta = f" grad_node->SetGradOutMeta({name}, {pos});"
set_edges = f" grad_node->AddEdges({input_autograd_meta_name}, {pos});"
set_grad_out_meta = f" grad_node->SetGradOutMeta({name}, {pos});"
set_edges = f" grad_node->AddEdges({input_autograd_meta_name}, {pos});"
set_grad_out_meta_list.append(set_grad_out_meta)
set_edges_list.append(set_edges)
set_grad_out_meta_str = "\n".join(set_grad_out_meta_list)
Expand All @@ -802,14 +847,14 @@ def GenerateNodeCreationCodes(
num_outputs = len(forward_outputs_position_map.keys())
for name, (_, pos) in forward_outputs_position_map.items():
output_autograd_meta_name = GetAutoGradMetaName(name)
set_out_rank = f" egr::EagerUtils::SetOutRankWithSlot({output_autograd_meta_name}, {pos});"
set_history = f" egr::EagerUtils::SetHistory({output_autograd_meta_name}, grad_node);"
set_out_rank = f" egr::EagerUtils::SetOutRankWithSlot({output_autograd_meta_name}, {pos});"
set_history = f" egr::EagerUtils::SetHistory({output_autograd_meta_name}, grad_node);"
if num_outputs == 1:
set_retain_grad = f" egr::EagerUtils::CheckAndRetainGrad(api_result);"
set_grad_in_meta = f" grad_node->SetGradInMeta(api_result, {pos});"
set_retain_grad = f" egr::EagerUtils::CheckAndRetainGrad(api_result);"
set_grad_in_meta = f" grad_node->SetGradInMeta(api_result, {pos});"
else:
set_retain_grad = f" egr::EagerUtils::CheckAndRetainGrad(std::get<{pos}>(api_result));"
set_grad_in_meta = f" grad_node->SetGradInMeta(std::get<{pos}>(api_result), {pos});"
set_retain_grad = f" egr::EagerUtils::CheckAndRetainGrad(std::get<{pos}>(api_result));"
set_grad_in_meta = f" grad_node->SetGradInMeta(std::get<{pos}>(api_result), {pos});"

set_out_rank_list.append(set_out_rank)
set_history_list.append(set_history)
Expand All @@ -821,55 +866,64 @@ def GenerateNodeCreationCodes(
set_grad_in_meta_str = "\n".join(set_grad_in_meta_list)
set_retain_grad_str = "\n".join(set_retain_grad_list)

node_event_name = fwd_api_name + " node_creation"
NODE_CREATION_TEMPLATE = """
paddle::platform::RecordEvent node_creation_record_event(\"{}\", paddle::platform::TracerEventType::Operator, 1);\n
"""
node_creation_event_str = NODE_CREATION_TEMPLATE.format(node_event_name)

NODE_CREATION_TEMPLATE = """

// Get AutoGradMeta
{}
{}
bool trace_backward = egr::Controller::Instance().HasGrad();

bool require_any_grad = egr::EagerUtils::ComputeRequireGrad({});
if(require_any_grad) {{
egr::EagerUtils::PassStopGradient({});

// Node Construction
{}

// SetAttributes
// Forward API Call
{}
{}

// SetTensorWrappers
{{
{}

// SetGradOutMeta & SetEdges
{}
if(require_any_grad) {{
egr::EagerUtils::PassStopGradient({});

// Node Construction
{}

// SetOutRank & SetHistory & SetGradInMeta & RetainGrad
// SetAttributes
{}
// SetTensorWrappers
{}
// SetGradOutMeta & SetEdges
{}
{}

// SetOutRank & SetHistory & SetGradInMeta & RetainGrad
{}
{}
{}
{}
}}
}}

"""
node_creation_str = NODE_CREATION_TEMPLATE.format(
inputs_autograd_meta_str, outputs_autograd_meta_str,
compute_require_grad_args_str, pass_stop_gradient_args_str,
node_construction_str, set_attributes_str, set_tensor_wrappers_str,
set_grad_out_meta_str, set_edges_str, set_out_rank_str, set_history_str,
set_grad_in_meta_str, set_retain_grad_str)
inputs_autograd_meta_str, compute_require_grad_args_str,
check_inplace_str, forward_call_str, bump_inplace_version_str,
node_creation_event_str, outputs_autograd_meta_str,
pass_stop_gradient_args_str, node_construction_str, set_attributes_str,
set_tensor_wrappers_str, set_grad_out_meta_str, set_edges_str,
set_out_rank_str, set_history_str, set_grad_in_meta_str,
set_retain_grad_str)

return node_creation_str


def GenerateForwardDefinition(fwd_api_name, bwd_api_name,
forward_inputs_position_map,
forward_outputs_position_map, forward_attrs_list,
backward_fwd_input_map, backward_grad_input_map,
backward_grad_output_map, backward_attrs_list,
optional_inputs, intermediate_outputs):
def GenerateForwardDefinition(
fwd_api_name, bwd_api_name, forward_inputs_position_map,
forward_outputs_position_map, forward_attrs_list,
backward_fwd_input_map, backward_grad_input_map,
backward_grad_output_map, backward_attrs_list, optional_inputs,
intermediate_outputs, inplace_map):
# fwd_api_name = ""
# forward_inputs_position_map = { "name" : [type, fwd_position] }
# forward_outputs_position_map = { "name" : [type, fwd_position] }
Expand All @@ -893,7 +947,10 @@ def GenerateForwardDefinition(fwd_api_name, bwd_api_name,
if is_optional:
arg_str = f"const paddle::optional<paddle::experimental::Tensor>& {name}"
else:
arg_str = f"const paddle::experimental::Tensor& {name}"
if inplace_map and name in inplace_map.keys():
arg_str = f"paddle::experimental::Tensor& {name}"
else:
arg_str = f"const paddle::experimental::Tensor& {name}"
else:
assert IsVectorTensorType(ttype)
arg_str = f"const std::vector<paddle::experimental::Tensor>& {name}"
Expand Down Expand Up @@ -956,26 +1013,16 @@ def GenerateForwardDefinition(fwd_api_name, bwd_api_name,

node_creation_str = GenerateNodeCreationCodes(
fwd_api_name, bwd_api_name, forward_inputs_position_map,
forward_outputs_position_map, forward_attrs_list,
forward_outputs_position_map, forward_attrs_list, forward_call_str,
backward_fwd_input_map, backward_grad_input_map,
backward_grad_output_map, backward_attrs_list, optional_inputs)

node_event_name = fwd_api_name + " node_creation"
NODE_CREATION_TEMPLATE = """{{\n
paddle::platform::RecordEvent node_creation_record_event(\"{}\", paddle::platform::TracerEventType::Operator, 1);\n
{}\n
}}"""
node_creation_str = NODE_CREATION_TEMPLATE.format(node_event_name,
node_creation_str)
backward_grad_output_map, backward_attrs_list, optional_inputs,
inplace_map)

dygraph_event_str = f"paddle::platform::RecordEvent dygraph_entrance_record_event(\"{fwd_api_name} dygraph\", paddle::platform::TracerEventType::Operator, 1);"

FORWARD_FUNCTION_TEMPLATE = """
{} {}({}) {{
{}

// Forward API Call
{}

{}

Expand All @@ -987,7 +1034,7 @@ def GenerateForwardDefinition(fwd_api_name, bwd_api_name,
forward_function_name = GetForwardFunctionName(fwd_api_name)
forward_function_str = FORWARD_FUNCTION_TEMPLATE.format(
returns_type_str, forward_function_name, inputs_args_definition_str,
dygraph_event_str, forward_call_str, node_creation_str, returns_str)
dygraph_event_str, node_creation_str, returns_str)
forward_function_declaration_str = f"{returns_type_str} {forward_function_name}({inputs_args_declaration_str});"

return forward_function_str, forward_function_declaration_str
Expand Down Expand Up @@ -1189,6 +1236,10 @@ def GenerateForwardHFile(filepath, forward_function_declaration_str):
fwd_args_str = fwd_api['args']
fwd_returns_str = fwd_api['output']

inplace_map = {}
if 'inplace' in fwd_api.keys():
inplace_map = ParseInplaceInfo(fwd_api['inplace'])

bwd_api_name = fwd_api['backward']
assert bwd_api_name in grad_api_dict.keys()
bwd_api = grad_api_dict[bwd_api_name]
Expand Down Expand Up @@ -1285,7 +1336,7 @@ def GenerateForwardHFile(filepath, forward_function_declaration_str):
forward_outputs_position_map, orig_forward_attrs_list,
backward_fwd_input_map, backward_grad_input_map,
backward_grad_output_map, backward_attrs_list, optional_inputs,
intermediate_outputs)
intermediate_outputs, {})
print("Generated Forward Definition: ", forward_definition_str)
print("Generated Forward Declaration: ", forward_declaration_str)
yaml_forward_definition_str += definition_declaration_pair[0]
Expand All @@ -1296,6 +1347,30 @@ def GenerateForwardHFile(filepath, forward_function_declaration_str):
forward_outputs_position_map,
orig_forward_attrs_list)

# Inplaced Version Dygraph Function Generation
if fwd_api_name != "sum" and "inplace" in fwd_api.keys():
fwd_api_name_inplaced = GetInplacedFunctionName(fwd_api_name)

# Node Definition Generation
definition_declaration_pair = GenerateForwardDefinition(
fwd_api_name_inplaced, bwd_api_name,
forward_inputs_position_map, forward_outputs_position_map,
forward_attrs_list, backward_fwd_input_map,
backward_grad_input_map, backward_grad_output_map,
backward_attrs_list, optional_inputs, intermediate_outputs,
inplace_map)
print("Generated Inplaced Forward Definition: ",
forward_definition_str)
print("Generated Inplaced Forward Declaration: ",
forward_declaration_str)
forward_definition_str += definition_declaration_pair[0]
forward_declaration_str += definition_declaration_pair[1]

# For python-level API dispatch
CollectCoreOpsInformation(
fwd_api_name_inplaced, forward_inputs_position_map,
forward_outputs_position_map, forward_attrs_list)

if len(namespace) > 0:
forward_definition_str += f"""namespace {namespace} {{
{yaml_forward_definition_str}
Expand Down
Loading