Commit 9357b57

Merge pull request #489 from cangtianhuang/develop
Fix `API Tracer`
2 parents d44c918 + 7ca03a7 commit 9357b57

7 files changed (+137 −91 lines)


.gitignore

Lines changed: 3 additions & 1 deletion
```diff
@@ -5,4 +5,6 @@ __pycache__
 log.log
 run.sh
 tester/api_config/**/test_log*
-tools/api_tracer/.huggingface
+tester/api_config/api_config*
+tools/api_tracer/.huggingface
+tools/api_tracer/trace_output*
```

README.md

Lines changed: 1 addition & 1 deletion
````diff
@@ -106,7 +106,7 @@ paddle.concat(tuple(Tensor([31376, 768],"float32"),Tensor([1, 768],"float32"),),
 ```
 - Install the third-party libraries:
 ```bash
-pip install pebble pynvml pandas
+pip install func_timeout pandas pebble pynvml pyyaml
 ```

 4. Some dependencies of PaddlePaddle and PyTorch may conflict. Install *paddlepaddle-gpu* before *torch*; add `--force-reinstall` to pip when reinstalling, and `--no-deps` when updating only paddle. Python >= 3.10 is recommended for engineV2.
````

engineV2-README.md

Lines changed: 1 addition & 1 deletion
````diff
@@ -22,7 +22,7 @@
 ```bash
 pip install --pre paddlepaddle-gpu -i https://www.paddlepaddle.org.cn/packages/nightly/cu118/
 pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
-pip install pebble pynvml pandas
+pip install func_timeout pandas pebble pynvml pyyaml
 ```
 2. Clone the PaddleAPITest repository and enter the project directory
 ```bash
````

tools/api_tracer/framework_dialect.py

Lines changed: 93 additions & 53 deletions
```diff
@@ -5,7 +5,7 @@
 import os
 import pkgutil
 import traceback
-from functools import partial
+import types
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Union

 import torch
@@ -19,9 +19,15 @@
 class TracingHook(abc.ABC):
     """Abstract base class for hooks"""

-    def __init__(self, serializer: "ConfigSerializer", level: int):
+    def __init__(
+        self,
+        serializer: "ConfigSerializer",
+        level: int,
+        dialect: Optional["FrameworkDialect"] = None,
+    ):
         self.serializer = serializer
         self.level = level
+        self.dialect = dialect

     @abc.abstractmethod
     def install(self):
@@ -40,8 +46,7 @@ def __init__(
         level: int,
         dialect: "FrameworkDialect",
     ):
-        super().__init__(serializer, level)
-        self.dialect = dialect
+        super().__init__(serializer, level, dialect)
         self._original_apis: Dict[str, Any] = {}
         self._module_cache: Dict[str, Any] = {}

@@ -85,9 +90,12 @@ def wrapper(*args, **kwargs):
         return wrapper

     def install(self):
+        if self.dialect is None:
+            return
         api_list = self.dialect.discover_apis() + self.dialect.discover_custom_ops()

-        # with open(os.path.join(os.path.dirname(__file__), "trace_output", "api_list.yaml"), "w") as f:
+        # api_path = os.path.join(os.path.dirname(__file__), "trace_output/api_list.yaml")
+        # with open(api_path, "w") as f:
         #     yaml.dump(api_list, f)

         print(f"[SetattrHook] Attempting to patch {len(api_list)} APIs...")
@@ -106,29 +114,40 @@ def install(self):
                 original_api = getattr(parent_obj, func_name)
                 wrapper = None

-                if isinstance(original_api, property):
-                    if original_api.fget and original_api.fset:
-                        wrapped_getter = self._create_wrapper(
-                            f"{api_name}.fget",
-                            original_api.fget,
-                            self.serializer,
-                            self.level,
-                        )
-                        wrapper = property(
-                            wrapped_getter,
-                            original_api.fset,
-                            original_api.fdel,
-                            original_api.__doc__,
-                        )
+                if isinstance(
+                    original_api,
+                    (
+                        types.FunctionType,
+                        types.BuiltinFunctionType,
+                        types.MethodType,
+                        types.BuiltinMethodType,
+                    ),
+                ):
+                    wrapped_func = self._create_wrapper(
+                        api_name, original_api, self.serializer, self.level
+                    )
                 elif isinstance(original_api, (classmethod, staticmethod)):
                     original_func = original_api.__func__
                     wrapped_func = self._create_wrapper(
                         api_name, original_func, self.serializer, self.level
                     )
                     wrapper = type(original_api)(wrapped_func)
-                elif callable(original_api):
-                    wrapper = self._create_wrapper(
-                        api_name, original_api, self.serializer, self.level
+                elif (
+                    isinstance(original_api, property)
+                    and original_api.fget
+                    and original_api.fset
+                ):
+                    wrapped_getter = self._create_wrapper(
+                        f"{api_name}.fget",
+                        original_api.fget,
+                        self.serializer,
+                        self.level,
+                    )
+                    wrapper = property(
+                        wrapped_getter,
+                        original_api.fset,
+                        original_api.fdel,
+                        original_api.__doc__,
                     )

                 if wrapper:
```
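The reordered branches above follow a common monkey-patching pattern. A minimal standalone sketch of the same idea is given below; the `trace` helper is hypothetical and stands in for the tracer's `_create_wrapper`. Classmethods and staticmethods are rebuilt around their wrapped `__func__` via `type(original_api)(wrapped_func)`, and only properties with both a getter and a setter are re-created with a wrapped getter while `fset`, `fdel`, and `__doc__` are carried over unchanged.

```python
def trace(name, fn):
    """Hypothetical logging wrapper; stands in for the tracer's _create_wrapper."""
    def wrapper(*args, **kwargs):
        print(f"[trace] {name}")
        return fn(*args, **kwargs)
    return wrapper


class Demo:
    @classmethod
    def make(cls):
        return cls()

    @property
    def value(self):
        return 42

    @value.setter
    def value(self, v):
        self._v = v


# classmethod / staticmethod: unwrap __func__, wrap it, rebuild the same descriptor type.
raw = Demo.__dict__["make"]
setattr(Demo, "make", type(raw)(trace("Demo.make", raw.__func__)))

# property: only getter+setter pairs are patched; the getter is swapped,
# while fset, fdel and __doc__ are preserved.
prop = Demo.__dict__["value"]
setattr(Demo, "value", property(trace("Demo.value.fget", prop.fget), prop.fset, prop.fdel, prop.__doc__))

Demo.make()       # prints "[trace] Demo.make"
_ = Demo().value  # prints "[trace] Demo.value.fget"
```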
```diff
@@ -153,7 +172,8 @@ def install(self):
             f"[SetattrHook] Patched {patched_count} APIs. Skipped {skipped_count} non-writable APIs."
         )

-        # with open(os.path.join(os.path.dirname(__file__), "trace_output", "api_list_wrap.yaml"), "w") as f:
+        # api_path = os.path.join(os.path.dirname(__file__), "trace_output/api_list_wrap.yaml")
+        # with open(api_path, "w") as f:
         #     yaml.dump(list(self._original_apis.keys()), f)

     def uninstall(self):
@@ -171,9 +191,13 @@ def uninstall(self):


 class TorchFunctionModeTracer(torch.overrides.TorchFunctionMode):
-    def __init__(self, serializer: "ConfigSerializer", level: int):
+    def __init__(
+        self, serializer: "ConfigSerializer", level: int, dialect: "FrameworkDialect"
+    ):
         self.serializer = serializer
         self.level = level
+        self.disable_torch_api_list = getattr(dialect, "disable_torch_api_list", False)
+        self.target_apis = getattr(dialect, "target_apis", [])

         # skip these for duplicate property access of paddle.Tensor in SetattrHook
         # (SetattrHook and TorchFunctionHook are installed at the same time)
@@ -209,9 +233,11 @@ def __torch_function__(self, func, types, args=(), kwargs=None):


 class TorchFunctionHook(TracingHook):
-    def __init__(self, serializer: "ConfigSerializer", level: int):
-        super().__init__(serializer, level)
-        self.tracing_mode = TorchFunctionModeTracer(serializer, level)
+    def __init__(
+        self, serializer: "ConfigSerializer", level: int, dialect: "FrameworkDialect"
+    ):
+        super().__init__(serializer, level, dialect)
+        self.tracing_mode = TorchFunctionModeTracer(serializer, level, dialect)

     def install(self):
         print(f"[TorchFunctionHook] Enabling __torch_function__ tracing mode...")
@@ -225,9 +251,13 @@ def uninstall(self):


 class TorchDispatchModeTracer(TorchDispatchMode):
-    def __init__(self, serializer: "ConfigSerializer", level: int):
+    def __init__(
+        self, serializer: "ConfigSerializer", level: int, dialect: "FrameworkDialect"
+    ):
         self.serializer = serializer
         self.level = level
+        self.disable_torch_api_list = getattr(dialect, "disable_torch_api_list", False)
+        self.target_apis = getattr(dialect, "target_apis", [])

     def __torch_dispatch__(self, func, types, args=(), kwargs=None):
         kwargs = kwargs or {}
@@ -240,9 +270,11 @@ def __torch_dispatch__(self, func, types, args=(), kwargs=None):


 class TorchDispatchHook(TracingHook):
-    def __init__(self, serializer: "ConfigSerializer", level: int):
-        super().__init__(serializer, level)
-        self.tracing_mode = TorchDispatchModeTracer(serializer, level)
+    def __init__(
+        self, serializer: "ConfigSerializer", level: int, dialect: "FrameworkDialect"
+    ):
+        super().__init__(serializer, level, dialect)
+        self.tracing_mode = TorchDispatchModeTracer(serializer, level, dialect)

     def install(self):
         print(f"[TorchDispatchHook] Enabling __torch_dispatch__ tracing mode...")
```
```diff
@@ -394,13 +426,14 @@ class PyTorchDialect(FrameworkDialect):
         # recommended to skip
         "__call__",
         "__format__",
-        "__instancecheck__",
         "__iter__",
         "__repr__",
         "__str__",
+        "__instancecheck__",
         "__subclasscheck__",
         "__subclasshook__",
-        # optional to skip
+        "__getstate__",
+        "__setstate__",
         "__enter__",
         "__exit__",
     }
@@ -416,10 +449,11 @@ class PyTorchDialect(FrameworkDialect):
         "torch.TypedStorage",
         # methods
         "torch.autograd.function._is_setup_context_defined",
+        "torch.distributed.reduce_op",
         "torch.fx.experimental.unification.multipledispatch.dispatcher.str_signature",
         "torch.nn.functional.handle_torch_function",
         "torch.nn.functional.has_torch_function_unary",
-        "torch.distributed.reduce_op",
+        "torch.optim.Optimizer.profile_hook_step",
     }

     def get_framework_name(self) -> str:
@@ -480,7 +514,17 @@ def discover_apis(self) -> List[str]:
                     continue
                 if full_name in self.IGNORE_CLASSES_OR_METHODS:
                     continue
-                if callable(obj) and not inspect.isclass(obj):
+                if isinstance(
+                    obj,
+                    (
+                        types.FunctionType,
+                        types.BuiltinFunctionType,
+                        types.MethodType,
+                        types.BuiltinMethodType,
+                        staticmethod,
+                        classmethod,
+                    ),
+                ):
                     api_set.add(full_name)
                 elif inspect.isclass(obj):
                     # custom op class should be skip
@@ -490,23 +534,22 @@ def discover_apis(self) -> List[str]:
                         if cls_member_name in self.IGNORE_ATTRIBUTES:
                             continue
                         full_cls_name = f"{full_name}.{cls_member_name}"
-                        if inspect.ismethod(cls_member) or inspect.isfunction(
-                            cls_member
+                        if isinstance(
+                            cls_member,
+                            (
+                                types.FunctionType,
+                                types.BuiltinFunctionType,
+                                types.MethodType,
+                                types.BuiltinMethodType,
+                                staticmethod,
+                                classmethod,
+                            ),
                         ):
                             api_set.add(full_cls_name)
-                        elif isinstance(cls_member, (staticmethod, classmethod)):
-                            api_set.add(full_cls_name)
-                        elif isinstance(cls_member, property):
-                            if cls_member.fget and cls_member.fset:
-                                api_set.add(full_cls_name)
-                        elif isinstance(cls_member, partial):
-                            if hasattr(
-                                cls_member.func, "__module__"
-                            ) and cls_member.func.__module__.startswith("torch"):
-                                api_set.add(full_cls_name)
                         elif (
-                            hasattr(cls_member, "__isabstractmethod__")
-                            and cls_member.__isabstractmethod__
+                            isinstance(cls_member, property)
+                            and cls_member.fget
+                            and cls_member.fset
                         ):
                             api_set.add(full_cls_name)
         except Exception as e:
```
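The switch from broad `callable(...)` / `inspect.isfunction(...)` tests to an explicit tuple of `types` makes API discovery stricter: classes, `functools.partial` objects, and arbitrary instances defining `__call__` no longer qualify on their own. A small illustration of the difference, assuming the same type tuple used in the diff:

```python
import types
from functools import partial

API_TYPES = (
    types.FunctionType,
    types.BuiltinFunctionType,
    types.MethodType,
    types.BuiltinMethodType,
    staticmethod,
    classmethod,
)


def f():
    pass


print(callable(f), isinstance(f, API_TYPES))      # True True  -> still discovered
print(callable(int), isinstance(int, API_TYPES))  # True False -> classes go through the isclass branch
p = partial(f)
print(callable(p), isinstance(p, API_TYPES))      # True False -> partials are no longer collected
```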
```diff
@@ -584,10 +627,7 @@ def get_hooks(self, serializer, levels: List[int], **kwargs) -> List[TracingHook]:
         for level in levels:
             hook_class = hook_map.get(level)
             if hook_class:
-                if level == 0:
-                    hooks.append(hook_class(serializer, level, self))
-                else:
-                    hooks.append(hook_class(serializer, level))
+                hooks.append(hook_class(serializer, level, self))
             else:
                 raise ValueError(f"Invalid level: {level}")
         return hooks
```

tools/api_tracer/test_infer.py

Lines changed: 4 additions & 4 deletions
```diff
@@ -1,6 +1,7 @@
 import json
 import os
 import sys
+import traceback

 import yaml

@@ -58,11 +59,10 @@ def run_inference_test(model_name: str):
         print("\n--- Generated Response ---")
         print(response)
         print("--------------------------\n")
-
-    except Exception as e:
-        print(f"An error occurred during inference for {model_name}: {e}")
-    finally:
         print(f"✅ Test for {model_name} finished.")
+    except Exception as e:
+        traceback.print_exc()
+        print(f"❌ An error occurred during inference for {model_name}: {e}")


 def main():
```
