Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,9 @@ set(nodes_h_path
# StringTensor only needs forward api
set(fwd_api_yaml_path
"${PADDLE_SOURCE_DIR}/paddle/phi/ops/yaml/strings_ops.yaml")
# The yaml file which include the python api info for ops
set(python_api_info_yaml_path
"${PADDLE_SOURCE_DIR}/paddle/phi/ops/yaml/python_api_info.yaml")

message("Final State Eager CodeGen")
add_custom_target(
Expand Down Expand Up @@ -87,6 +90,7 @@ add_custom_target(
"${PYTHON_EXECUTABLE}"
"${PADDLE_SOURCE_DIR}/paddle/fluid/eager/auto_code_generator/generator/python_c_gen.py"
"--api_yaml_path=${api_yaml_path},${fwd_api_yaml_path},${backward_yaml_path}"
"--python_api_info_yaml_path=${python_api_info_yaml_path}"
"--source_path=${tmp_python_c_source_path}"
"--header_path=${tmp_python_c_header_path}"
COMMAND ${CMAKE_COMMAND} -E copy_if_different ${tmp_python_c_source_path}
Expand All @@ -109,6 +113,7 @@ add_custom_target(
"${PYTHON_EXECUTABLE}"
"${PADDLE_SOURCE_DIR}/paddle/fluid/eager/auto_code_generator/generator/monkey_patch_gen.py"
"--api_yaml_path=${ops_yaml_path}"
"--python_api_info_yaml_path=${python_api_info_yaml_path}"
"--output_path=${tmp_monkey_patch_tensor_methods_path}"
COMMAND
${CMAKE_COMMAND} -E copy_if_different
Expand Down
26 changes: 26 additions & 0 deletions paddle/fluid/eager/auto_code_generator/generator/codegen_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -381,6 +381,32 @@ def ParseYamlReturns(string):
return returns_list


def ParsePythonAPIInfoFromYAML(path) -> dict:
"""
Parse Python API information from a YAML file.

Args:
path (str): The path to the YAML file.

Returns:
dict: A dictionary containing Python API information, where the keys are operation names and the values are related api information.

Raises:
RuntimeError: This exception is raised if an error occurs while parsing the YAML file.
"""
res_dict = {}
with open(path, "r", encoding="utf-8") as f:
try:
data = yaml.safe_load(f)
except yaml.YAMLError as e:
raise RuntimeError(f"read_python_api_info load error: {e}")
# Trans list to dict, the key is op in yaml item
for item in data:
if "op" in item.keys():
res_dict.update({item["op"]: item})
return res_dict


def ParseYamlForwardFromBackward(string):
# Example: matmul (const Tensor& x, const Tensor& y, bool transpose_x, bool transpose_y) -> Tensor(out)

Expand Down
111 changes: 89 additions & 22 deletions paddle/fluid/eager/auto_code_generator/generator/monkey_patch_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from codegen_utils import (
FunctionGeneratorBase,
GeneratorBase,
ParsePythonAPIInfoFromYAML,
)

IMPORT_TEMPLATE = """
Expand All @@ -29,26 +30,57 @@
def {func_name}():
"""

NAME_METHOD_MAPPING_TEMPLATE = """ ('{api_name}',_{api_name})"""
NAME_METHOD_MAPPING_TEMPLATE = """ ('{op_name}',_{op_name})"""

METHODS_MAP_TEMPLATE = """
methods_map = [
{}
]

"""
FUNCTIONS_MAP_TEMPLATE = """
funcs_map = [
{}
]

"""
NN_FUNCTIONS_MAP_TEMPLATE = """
nn_funcs_map = [
{}
]

"""

METHOD_TEMPLATE = """
def _{name}(*args, **kwargs):
return _C_ops.{name}(*args, **kwargs)
"""
SET_METHOD_TEMPLATE = """
# set methods for Tensor in dygraph
# set methods for paddle.Tensor in dygraph
local_tensor = core.eager.Tensor
for method_name, method in methods_map:
setattr(local_tensor, method_name, method)

"""
SET_FUNCTION_TEMPLATE = """
# set functions for paddle
for method_name, method in funcs_map:
setattr(paddle, method_name, method)

"""
SET_NN_FUNCTION_TEMPLATE = """
# set functions for paddle.nn.functional
for method_name, method in nn_funcs_map:
setattr(paddle.nn.functional, method_name, method)
"""
# The pair of name and func which should be added to paddle
paddle_func_map = []
# The pair of name and func which should be added to paddle.Tensor
tensor_method_map = []
# The pair of name and func which should be added to paddle.nn.functional
nn_func_map = []
# The python api info which not in ops.yaml
python_api_info_from_yaml = {}


class MethodGenerator(FunctionGeneratorBase):
Expand All @@ -58,22 +90,40 @@ def __init__(self, forward_api_contents, namespace):
# Generated Results
self.Method_str = ""

def GenerateMethod(self, name):
self.Method_str = METHOD_TEMPLATE.format(name=name)

def run(self):
# Initialized orig_forward_inputs_list, orig_forward_returns_list, orig_forward_attrs_list
self.CollectOriginalForwardInfo()

if len(self.python_api_info) > 0:
self.need_parse_python_api_args = True
self.ParsePythonAPIInfo()
for name in self.python_api_names:
if "Tensor." in name:
api_name = name.split(".")[-1]
self.GenerateMethod(api_name)
self.api_name = api_name
break
self.Method_str = GenerateMethod(self.forward_api_name)
ClassifyAPIByPrefix(self.python_api_info, self.forward_api_name)


def ExtractPrefix(full_name):
res = ""
for m in full_name.split(".")[:-1]:
res += m + '.'
return res


def GenerateMethod(name):
return METHOD_TEMPLATE.format(name=name)


def ClassifyAPIByPrefix(python_api_info, op_name):
python_api_names = python_api_info["name"]
name_func_mapping = NAME_METHOD_MAPPING_TEMPLATE.format(op_name=op_name)
for name in python_api_names:
prefix = ExtractPrefix(name)
if prefix == "paddle.":
paddle_func_map.append(name_func_mapping)
elif prefix == "paddle.Tensor.":
tensor_method_map.append(name_func_mapping)
elif prefix == "paddle.nn.functional.":
nn_func_map.append(name_func_mapping)
else:
raise Exception("Unsupported Prefix " + prefix, "API : " + name)


class MonkeyPatchTensorMethodsGenerator(GeneratorBase):
Expand All @@ -92,23 +142,34 @@ def GenerateMonkeyPatchTensorMethods(self):

forward_api_list = self.forward_api_list
methods_map = [] # [("method_name",method),]
method_str = ""
# some python api info in ops.yaml
for forward_api_content in forward_api_list:
f_generator = MethodGenerator(forward_api_content, None)
status = f_generator.run()
method_str = f_generator.Method_str
if method_str != "":
methods_map.append(
NAME_METHOD_MAPPING_TEMPLATE.format(
api_name=f_generator.api_name
)
)
self.MonkeyPatchTensorMethods_str += method_str
result = ',\n '.join(methods_map)
method_str += f_generator.Method_str
# some python api info not in ops.yaml but in python_api_info.yaml
for ops_name, python_api_info in python_api_info_from_yaml.items():
method_str += GenerateMethod(ops_name)
ClassifyAPIByPrefix(python_api_info, ops_name)

self.MonkeyPatchTensorMethods_str += method_str
result = ',\n '.join(tensor_method_map)
self.MonkeyPatchTensorMethods_str += METHODS_MAP_TEMPLATE.format(result)
result = ',\n '.join(paddle_func_map)
self.MonkeyPatchTensorMethods_str += FUNCTIONS_MAP_TEMPLATE.format(
result
)
result = ',\n '.join(nn_func_map)
self.MonkeyPatchTensorMethods_str += NN_FUNCTIONS_MAP_TEMPLATE.format(
result
)
self.MonkeyPatchTensorMethods_str += FUNCTION_NAME_TEMPLATE.format(
func_name="monkey_patch_generated_methods_for_tensor"
)
self.MonkeyPatchTensorMethods_str += SET_METHOD_TEMPLATE
self.MonkeyPatchTensorMethods_str += SET_FUNCTION_TEMPLATE
self.MonkeyPatchTensorMethods_str += SET_NN_FUNCTION_TEMPLATE

def run(self):
# Read Yaml file
Expand All @@ -125,7 +186,7 @@ def ParseArguments():
)
parser.add_argument('--api_yaml_path', type=str)
parser.add_argument('--output_path', type=str)

parser.add_argument('--python_api_info_yaml_path', type=str)
args = parser.parse_args()
return args

Expand All @@ -139,6 +200,12 @@ def GenerateMonkeyPathFile(filepath, python_c_str):
args = ParseArguments()
api_yaml_path = args.api_yaml_path
output_path = args.output_path
python_api_info_yaml_path = args.python_api_info_yaml_path

python_api_info_from_yaml = ParsePythonAPIInfoFromYAML(
python_api_info_yaml_path
)

gen = MonkeyPatchTensorMethodsGenerator(api_yaml_path)
gen.run()
GenerateMonkeyPathFile(output_path, gen.MonkeyPatchTensorMethods_str)
27 changes: 20 additions & 7 deletions paddle/fluid/eager/auto_code_generator/generator/python_c_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
GetForwardFunctionName,
GetInplacedFunctionName,
IsVectorTensorType,
ParsePythonAPIInfoFromYAML,
)

args_default_mapping = {
Expand All @@ -39,6 +40,8 @@
"multiply_grad",
"pull_sparse_v2_grad",
}
# The python api info which not in ops.yaml
python_api_info_from_yaml = {}


def SkipAPIGeneration(forward_api_name):
Expand Down Expand Up @@ -799,6 +802,16 @@ def pre_process_add_ampersand(s):
# Generate Python-C Function Registration
self.python_c_function_reg_str += python_c_inplace_func_reg_str

def InitAndParsePythonAPIInfo(self):
global python_api_info_from_yaml
if self.forward_api_name in python_api_info_from_yaml.keys():
self.python_api_info = python_api_info_from_yaml[
self.forward_api_name
]
if len(self.python_api_info) > 0:
self.need_parse_python_api_args = True
self.ParsePythonAPIInfo()

def run(self, no_input_out_tensor=False):
# Initialized is_forward_only
self.CollectIsForwardOnly()
Expand All @@ -811,11 +824,7 @@ def run(self, no_input_out_tensor=False):

# Initialized orig_forward_inputs_list, orig_forward_returns_list, orig_forward_attrs_list
self.CollectOriginalForwardInfo()

if len(self.python_api_info) > 0:
self.need_parse_python_api_args = True
self.ParsePythonAPIInfo()

self.InitAndParsePythonAPIInfo()
if SkipAPIGeneration(self.forward_api_name):
return False

Expand Down Expand Up @@ -905,6 +914,7 @@ def ParseArguments():
description='Eager Code Generator Args Parser'
)
parser.add_argument('--api_yaml_path', type=str)
parser.add_argument('--python_api_info_yaml_path', type=str)
parser.add_argument('--source_path', type=str)
parser.add_argument('--header_path', type=str)

Expand Down Expand Up @@ -941,10 +951,14 @@ def GeneratePythonCFile(filepath, python_c_str):
if __name__ == "__main__":
args = ParseArguments()
api_yaml_paths = args.api_yaml_path.split(",")

generated_python_c_functions = ""
generated_python_c_registration = ""
generated_python_c_functions_header = ""
python_api_info_yaml_path = args.python_api_info_yaml_path

python_api_info_from_yaml = ParsePythonAPIInfoFromYAML(
python_api_info_yaml_path
)
for i in range(len(api_yaml_paths)):
api_yaml_path = api_yaml_paths[i]

Expand All @@ -970,7 +984,6 @@ def GeneratePythonCFile(filepath, python_c_str):
python_c_str = GeneratePythonCWrappers(
generated_python_c_functions, generated_python_c_registration
)

source_path = args.source_path
header_path = args.header_path
for path in [source_path, header_path]:
Expand Down
5 changes: 4 additions & 1 deletion paddle/fluid/pir/dialect/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -193,14 +193,17 @@ set(python_c_source_file_tmp ${python_c_source_file}.tmp)
set(trimmed_op_yaml_files
${op_fwd_yaml},${op_bwd_yaml},${fused_op_fwd_yaml},${fused_op_bwd_yaml},${pir_op_fwd_yaml},${pir_op_bwd_yaml},${pir_update_op_fwd_yaml},${pir_op_fwd_sparse_yaml},${pir_op_bfd_sparse_yaml}
)
set(python_api_info_yaml_path
"${PADDLE_SOURCE_DIR}/paddle/phi/ops/yaml/python_api_info.yaml")

execute_process(
COMMAND
${PYTHON_EXECUTABLE} ${python_c_gen_file} --op_yaml_files
${trimmed_op_yaml_files} --op_compat_yaml_file ${op_compat_yaml_file}
--namespaces "paddle,pybind" --python_c_def_h_file
${python_c_header_file_tmp} --python_c_def_cc_file
${python_c_source_file_tmp})
${python_c_source_file_tmp} --python_api_info_yaml_path
${python_api_info_yaml_path})

set(generated_files_python_c "${python_c_header_file}"
"${python_c_source_file}")
Expand Down
27 changes: 27 additions & 0 deletions paddle/fluid/pir/dialect/op_generator/gen_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,33 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import yaml


def ParsePythonAPIInfoFromYAML(path: str) -> dict:
"""
Parse Python API information from a YAML file.

Args:
path (str): The path to the YAML file.

Returns:
dict: A dictionary containing Python API information, where the keys are operation names and the values are related api information.

Raises:
RuntimeError: This exception is raised if an error occurs while parsing the YAML file.
"""
res_dict = {}
with open(path, "r", encoding="utf-8") as f:
try:
data = yaml.safe_load(f)
except yaml.YAMLError as e:
raise RuntimeError(f"read_python_api_info load error: {e}")
# Trans list to dict, the key is op in yaml item
for item in data:
if "op" in item.keys():
res_dict.update({item["op"]: item})
return res_dict


def to_pascal_case(s):
Expand Down
Loading