Skip to content

Commit 2bec694

Browse files
authored
[API Compatible ]Provide mechanical support for the Python API to sink to the C++ layer (#74601)
* test * fix * import amax and amin from _C_ops * fix __all__ export error for build ci * add # type: ignore to ignore type check * ignore max and amax diff in docs * rm print and add the test case time out * add time out seconds and revert some error * format * recover config * reconfig cmakefile * revert config * using ctest lists instead of cmake * add time out
1 parent e4e9446 commit 2bec694

File tree

25 files changed

+1405
-394
lines changed

25 files changed

+1405
-394
lines changed

.gitignore

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ paddle/phi/kernels/fusion/cutlass/gemm_epilogue/build
117117
paddle/phi/kernels/fusion/cutlass/gemm_epilogue/cutlass
118118
python/paddle/_typing/libs/**/*.pyi
119119
third_party.tar.gz
120-
120+
python/paddle/base/dygraph/generated_tensor_methods_patch.py
121121
#fp8
122122
paddle/fluid/fp8/deep_gemm/include/cute/*
123123
paddle/fluid/fp8/deep_gemm/include/cutlass/*

paddle/fluid/eager/auto_code_generator/generator/CMakeLists.txt

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,3 +94,23 @@ add_custom_target(
9494
COMMAND ${CMAKE_COMMAND} -E copy_if_different ${tmp_python_c_header_path}
9595
${python_c_header_path}
9696
VERBATIM)
97+
98+
set(ops_yaml_path "${PADDLE_SOURCE_DIR}/paddle/phi/ops/yaml/ops.yaml")
99+
set(monkey_patch_tensor_methods_path
100+
"${PADDLE_SOURCE_DIR}/python/paddle/base/dygraph/generated_tensor_methods_patch.py"
101+
)
102+
set(tmp_monkey_patch_tensor_methods_path
103+
"${PADDLE_SOURCE_DIR}/python/paddle/base/dygraph/generated_tensor_methods_patch.py.tmp"
104+
)
105+
message("Eager monkey path tensor methods CodeGen")
106+
add_custom_target(
107+
eager_monkey_patch_codegen
108+
COMMAND
109+
"${PYTHON_EXECUTABLE}"
110+
"${PADDLE_SOURCE_DIR}/paddle/fluid/eager/auto_code_generator/generator/monkey_patch_gen.py"
111+
"--api_yaml_path=${ops_yaml_path}"
112+
"--output_path=${tmp_monkey_patch_tensor_methods_path}"
113+
COMMAND
114+
${CMAKE_COMMAND} -E copy_if_different
115+
${tmp_monkey_patch_tensor_methods_path} ${monkey_patch_tensor_methods_path}
116+
VERBATIM)

paddle/fluid/eager/auto_code_generator/generator/codegen_utils.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -479,6 +479,7 @@ def __init__(self, forward_api_contents, namespace):
479479
)
480480

481481
self.forward_api_name = ""
482+
self.python_api_info = {}
482483

483484
self.orig_forward_inputs_list = (
484485
[]
@@ -506,6 +507,15 @@ def __init__(self, forward_api_contents, namespace):
506507
) # {name: func_name, args: [input_name, ...]}
507508
self.intermediate_outputs = [] # [name, ...]
508509
self.forward_inplace_map = {} # {name : name, ...}
510+
self.args_alias_map = {} # {arg_name: alias_vector, ...}
511+
self.dygraph_pre_process = (
512+
"" # The pre_process function calling code for dygraph
513+
)
514+
self.static_pre_process = (
515+
"" # The pre_process function calling code for static graph
516+
)
517+
self.args_parser_func_name = "" # The custom args parser function
518+
self.python_api_names = ""
509519

510520
def ParseForwardInplaceInfo(self):
511521
forward_api_contents = self.forward_api_contents
@@ -515,6 +525,40 @@ def ParseForwardInplaceInfo(self):
515525
inplace_map_str = forward_api_contents['inplace']
516526
self.forward_inplace_map = ParseYamlInplaceInfo(inplace_map_str)
517527

528+
# Function for parameters parse
529+
def ParsePythonAPIInfo(self):
530+
python_api_info = self.python_api_info
531+
args_alias = {}
532+
if 'name' in python_api_info.keys():
533+
self.python_api_names = python_api_info['name']
534+
if 'args_alias' in python_api_info.keys():
535+
for arg, alias in python_api_info['args_alias'].items():
536+
alias_set = set(alias)
537+
# Add the original argument name to the alias set
538+
alias_set.add(arg)
539+
# Convert to C++ vector format
540+
alias_vector = (
541+
"{" + ",".join(f'"{name}"' for name in alias_set) + "}"
542+
)
543+
args_alias.update({arg: alias_vector})
544+
self.args_alias_map = args_alias
545+
if 'pre_process' in python_api_info.keys():
546+
pre_process = python_api_info['pre_process']
547+
if 'func' in pre_process.keys():
548+
self.dygraph_pre_process = pre_process['func']
549+
self.static_pre_process = pre_process['func']
550+
# TODO check len(pre_process) > 1
551+
552+
if 'dygraph_func' in pre_process.keys():
553+
self.dygraph_pre_process = pre_process['dygraph_func']
554+
if 'static_func' in pre_process.keys():
555+
self.static_pre_process = pre_process['static_func']
556+
if (
557+
'args_parser' in python_api_info.keys()
558+
and 'func' in python_api_info['args_parser']
559+
):
560+
self.args_parser_func_name = python_api_info['args_parser']['func']
561+
518562
def ParseNoNeedBuffer(self):
519563
grad_api_contents = self.grad_api_contents
520564

@@ -575,6 +619,8 @@ def CollectOriginalForwardInfo(self):
575619
), 'Unable to find "output" in forward_api_contents keys'
576620

577621
forward_returns_str = forward_api_contents['output']
622+
if 'python_api' in forward_api_contents.keys():
623+
self.python_api_info = forward_api_contents['python_api']
578624

579625
# Collect Original Forward Inputs/Outputs and then perform validation checks
580626
(
Lines changed: 143 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,143 @@
1+
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import argparse
16+
17+
from codegen_utils import (
18+
FunctionGeneratorBase,
19+
GeneratorBase,
20+
)
21+
22+
IMPORT_TEMPLATE = """
23+
import paddle
24+
from paddle import _C_ops
25+
from .. import core
26+
"""
27+
28+
FUNCTION_NAME_TEMPLATE = """
29+
def {func_name}():
30+
"""
31+
32+
NAME_METHOD_MAPPING_TEMPLATE = """ ('{api_name}',_{api_name})"""
33+
34+
METHODS_MAP_TEMPLATE = """
35+
methods_map = [
36+
{}
37+
]
38+
"""
39+
40+
METHOD_TEMPLATE = """
41+
def _{name}(self,*args, **kwargs):
42+
return _C_ops.{name}(self,*args, **kwargs)
43+
"""
44+
SET_METHOD_TEMPLATE = """
45+
# set methods for Tensor in dygraph
46+
local_tensor = core.eager.Tensor
47+
for method_name, method in methods_map:
48+
setattr(local_tensor, method_name, method)
49+
50+
"""
51+
52+
53+
class MethodGenerator(FunctionGeneratorBase):
54+
def __init__(self, forward_api_contents, namespace):
55+
FunctionGeneratorBase.__init__(self, forward_api_contents, namespace)
56+
self.need_parse_python_api_args = False
57+
# Generated Results
58+
self.Method_str = ""
59+
60+
def GenerateMethod(self, name):
61+
self.Method_str = METHOD_TEMPLATE.format(name=name)
62+
63+
def run(self):
64+
# Initialized orig_forward_inputs_list, orig_forward_returns_list, orig_forward_attrs_list
65+
self.CollectOriginalForwardInfo()
66+
67+
if len(self.python_api_info) > 0:
68+
self.need_parse_python_api_args = True
69+
self.ParsePythonAPIInfo()
70+
for name in self.python_api_names:
71+
if "Tensor." in name:
72+
api_name = name.split(".")[-1]
73+
self.GenerateMethod(api_name)
74+
self.api_name = api_name
75+
break
76+
77+
78+
class MonkeyPatchTensorMethodsGenerator(GeneratorBase):
79+
def __init__(self, path):
80+
# Parent members:
81+
# self.namespace
82+
# self.api_yaml_path
83+
# self.forward_api_list
84+
GeneratorBase.__init__(self, path)
85+
86+
# Generated Result
87+
self.MonkeyPatchTensorMethods_str = ""
88+
89+
def GenerateMonkeyPatchTensorMethods(self):
90+
self.MonkeyPatchTensorMethods_str += IMPORT_TEMPLATE
91+
92+
forward_api_list = self.forward_api_list
93+
methods_map = [] # [("method_name",method),]
94+
for forward_api_content in forward_api_list:
95+
f_generator = MethodGenerator(forward_api_content, None)
96+
status = f_generator.run()
97+
method_str = f_generator.Method_str
98+
if method_str != "":
99+
methods_map.append(
100+
NAME_METHOD_MAPPING_TEMPLATE.format(
101+
api_name=f_generator.api_name
102+
)
103+
)
104+
self.MonkeyPatchTensorMethods_str += method_str
105+
result = ',\n '.join(methods_map)
106+
self.MonkeyPatchTensorMethods_str += METHODS_MAP_TEMPLATE.format(result)
107+
self.MonkeyPatchTensorMethods_str += FUNCTION_NAME_TEMPLATE.format(
108+
func_name="monkey_patch_generated_methods_for_tensor"
109+
)
110+
self.MonkeyPatchTensorMethods_str += SET_METHOD_TEMPLATE
111+
112+
def run(self):
113+
# Read Yaml file
114+
self.ParseForwardYamlContents()
115+
self.GenerateMonkeyPatchTensorMethods()
116+
117+
118+
##########################
119+
# Code Generation Helper #
120+
##########################
121+
def ParseArguments():
122+
parser = argparse.ArgumentParser(
123+
description='Eager Code Generator Args Parser for Monkey patch methods '
124+
)
125+
parser.add_argument('--api_yaml_path', type=str)
126+
parser.add_argument('--output_path', type=str)
127+
128+
args = parser.parse_args()
129+
return args
130+
131+
132+
def GenerateMonkeyPathFile(filepath, python_c_str):
133+
with open(filepath, 'w') as f:
134+
f.write(python_c_str)
135+
136+
137+
if __name__ == "__main__":
138+
args = ParseArguments()
139+
api_yaml_path = args.api_yaml_path
140+
output_path = args.output_path
141+
gen = MonkeyPatchTensorMethodsGenerator(api_yaml_path)
142+
gen.run()
143+
GenerateMonkeyPathFile(output_path, gen.MonkeyPatchTensorMethods_str)

0 commit comments

Comments
 (0)