Skip to content

Commit

Permalink
[Infrt] add skip method for inferShape codegen (#41014)
Browse files Browse the repository at this point in the history
  • Loading branch information
DannyIsFunny authored Mar 29, 2022
1 parent cc52501 commit 1840349
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 2 deletions.
28 changes: 28 additions & 0 deletions tools/infrt/generate_phi_kernel_dialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

import json
import yaml
import sys
import os
from get_compat_kernel_signature import get_compat_kernels_info
Expand Down Expand Up @@ -52,6 +53,28 @@
kernel_types_info_file = "./kernels.json"
kernel_signature_info_file = "./kernel_signature.json"

skipped_phi_api_list_file = "./skipped_phi_api.json"


def get_skipped_kernel_list():
skiped_kernel_list = []
with open(skipped_phi_api_list_file, 'r') as f:
skiped_api_list = json.load(f)
infer_meta_data = get_api_yaml_info("../../")
for api in infer_meta_data:
if "kernel" not in api or "infer_meta" not in api:
continue
if api["api"] in skiped_api_list["phi_apis"]:
skiped_kernel_list.append(api["kernel"]["func"])
skiped_kernel_list += skiped_api_list["phi_kernels"]
return skiped_kernel_list


def get_api_yaml_info(file_path):
f = open(file_path + "/python/paddle/utils/code_gen/api.yaml", "r")
cont = f.read()
return yaml.load(cont, Loader=yaml.FullLoader)


def generate_kernel_name(op_name, place_str):
[target_, layout_, precision_] = place_str[1:-1].split(',')
Expand Down Expand Up @@ -140,6 +163,10 @@ def generate_supported_kernel_list(load_dict):
if flag and op_name in kernel_attrs_names:
supported_kernels_list_.append(op_name)
supported_kernels_list_ = list(set(supported_kernels_list_))
skipped_kernel_list = get_skipped_kernel_list()
for skipped_kernel in skipped_kernel_list:
if skipped_kernel in skipped_kernel_list:
supported_kernels_list_.remove(skipped_kernel)
return supported_kernels_list_


Expand Down Expand Up @@ -250,6 +277,7 @@ def main():
cpu_registry_ = ""
gpu_registry_ = ""
supported_kernels = generate_supported_kernel_list(load_dict)

print("Supported kernels:")
print(supported_kernels)
for op_name in load_dict:
Expand Down
26 changes: 24 additions & 2 deletions tools/infrt/get_phi_kernel_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,23 @@
import yaml
from typing import List, Dict, Any

skipped_phi_api_list_file = "/tools/infrt/skipped_phi_api.json"
api_yaml_file = "/python/paddle/utils/code_gen/api.yaml"


def get_skipped_kernel_list():
skiped_kernel_list = []
with open(skipped_phi_api_list_file, 'r') as f:
skiped_api_list = json.load(f)
infer_meta_data = get_api_yaml_info(api_yaml_file)
for api in infer_meta_data:
if "kernel" not in api or "infer_meta" not in api:
continue
if api["api"] in skiped_api_list["phi_apis"]:
skiped_kernel_list.append(api["kernel"]["func"])
skiped_kernel_list += skiped_api_list["phi_kernels"]
return skiped_kernel_list


def parse_args():
parser = argparse.ArgumentParser("gather phi kernel and infermate info")
Expand Down Expand Up @@ -50,7 +67,7 @@ def parse_args():


def get_api_yaml_info(file_path):
f = open(file_path + "/python/paddle/utils/code_gen/api.yaml", "r")
f = open(file_path, "r")
cont = f.read()
return yaml.load(cont, Loader=yaml.FullLoader)

Expand Down Expand Up @@ -259,8 +276,11 @@ def gen_register_code_info(item: List[str], attr_data: Dict[str, List[str]]):
# TODO(wilber): handle the unknown inferShape func.
return ""

skipped_kernel_list = get_skipped_kernel_list()
for ir_dtype, origin_dtype in zip(ir_dtypes, origin_dtypes):
kernel_func = gen_kernel_func(item[3], ctx_name, origin_dtype)
if item[0].lower() in skipped_kernel_list:
continue
ir_name = ir_ctx_name + '.' + item[0].lower(
) + '.' + ir_dtype + '.' + item[2].lower()
if ir_name in attr_data.keys() and attr_data[ir_name] is not None:
Expand Down Expand Up @@ -342,7 +362,9 @@ def gen_phi_kernel_register_code(resources: List[List[str]],

if __name__ == "__main__":
args = parse_args()
infer_meta_data = get_api_yaml_info(args.paddle_root_path)
skipped_phi_api_list_file = args.paddle_root_path + skipped_phi_api_list_file
api_yaml_file = args.paddle_root_path + api_yaml_file
infer_meta_data = get_api_yaml_info(api_yaml_file)
kernel_data = get_kernel_info(args.kernel_info_file)
info_meta_wrap_data = get_infermeta_info(args.infermeta_wrap_file)
attr_data = get_attr_info(args.attr_info_file)
Expand Down
4 changes: 4 additions & 0 deletions tools/infrt/skipped_phi_api.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
{
"phi_apis":["conj"],
"phi_kernels":["equal_all"]
}

0 comments on commit 1840349

Please sign in to comment.