Skip to content

Commit 9f21f95

Browse files
Merge pull request #491 from cangtianhuang/develop
Fix `API Tracer`
2 parents 9357b57 + 989cefa commit 9f21f95

File tree

2 files changed

+37
-5
lines changed

2 files changed

+37
-5
lines changed

tools/api_tracer/api_alias.py

Lines changed: 28 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
1+
import os
12
import re
23
from pathlib import Path
34

5+
import yaml
6+
47
INPUT_DIR = Path("tools/api_tracer/trace_output_tmp")
58

69

@@ -12,33 +15,55 @@ def parse_api(api):
1215
return api
1316

1417

15-
def process_file(input_path):
18+
def process_file(input_path, target_apis):
1619
input_name = input_path.name
1720
output_name = input_name[4:]
1821
output_path = input_path.parent / output_name
1922

23+
output_excluded_name = output_name.replace(".txt", "_excluded.txt")
24+
output_excluded_path = input_path.parent / output_excluded_name
25+
2026
apis = set()
2127
with input_path.open("r") as f:
2228
apis = set([line.strip() for line in f if line.strip()])
2329
print(f"Read {len(apis)} apis from {input_path}", flush=True)
2430

2531
alias_apis = set()
32+
excluded_apis = set()
2633
for api in apis:
34+
alias_api = api
35+
if alias_api not in target_apis:
36+
excluded_apis.add(api)
37+
continue
2738
alias_apis.add(parse_api(api))
2839

2940
with output_path.open("w") as f:
3041
f.writelines(f"{line}\n" for line in sorted(alias_apis))
3142
print(f"Write {len(alias_apis)} alias apis to {output_path}", flush=True)
3243

44+
with output_excluded_path.open("w") as f:
45+
f.writelines(f"{line}\n" for line in sorted(excluded_apis))
46+
print(f"Write {len(excluded_apis)} excluded apis to {output_excluded_path}", flush=True)
47+
3348

3449
def main():
35-
input_files = list(INPUT_DIR.glob("api_apis*.txt"))
50+
yaml_path = os.path.join(
51+
os.path.dirname(os.path.abspath(__file__)),
52+
"api_list",
53+
"torch_api_list.yaml",
54+
)
55+
target_apis = []
56+
with open(yaml_path, "r", encoding="utf-8") as f:
57+
target_apis = yaml.safe_load(f)
58+
print(f"Loaded {len(target_apis)} target APIs.")
59+
60+
input_files = list(INPUT_DIR.glob("api_apis.txt"))
3661
if not input_files:
3762
print(f"No input files found in {INPUT_DIR}", flush=True)
3863
return
3964

4065
for input_file in sorted(input_files):
41-
process_file(input_file)
66+
process_file(input_file, target_apis)
4267

4368

4469
if __name__ == "__main__":

tools/api_tracer/framework_dialect.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ def install(self):
123123
types.BuiltinMethodType,
124124
),
125125
):
126-
wrapped_func = self._create_wrapper(
126+
wrapper = self._create_wrapper(
127127
api_name, original_api, self.serializer, self.level
128128
)
129129
elif isinstance(original_api, (classmethod, staticmethod)):
@@ -154,6 +154,8 @@ def install(self):
154154
setattr(parent_obj, func_name, wrapper)
155155
self._original_apis[api_name] = original_api
156156
patched_count += 1
157+
else:
158+
skipped_count += 1
157159
except (TypeError, AttributeError) as e:
158160
error_msg = str(e).lower()
159161
if (
@@ -453,7 +455,7 @@ class PyTorchDialect(FrameworkDialect):
453455
"torch.fx.experimental.unification.multipledispatch.dispatcher.str_signature",
454456
"torch.nn.functional.handle_torch_function",
455457
"torch.nn.functional.has_torch_function_unary",
456-
"torch.optim.Optimizer.profile_hook_step",
458+
"torch.optim.Optimizer.profile_hook_step", # it will be overridden by subclass of Optimizer
457459
}
458460

459461
def get_framework_name(self) -> str:
@@ -534,6 +536,11 @@ def discover_apis(self) -> List[str]:
534536
if cls_member_name in self.IGNORE_ATTRIBUTES:
535537
continue
536538
full_cls_name = f"{full_name}.{cls_member_name}"
539+
if full_cls_name in self.IGNORE_CLASSES_OR_METHODS:
540+
continue
541+
# it will be overridden by subclass of Optimizer
542+
if cls_member_name == "profile_hook_step":
543+
continue
537544
if isinstance(
538545
cls_member,
539546
(

0 commit comments

Comments
 (0)