Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,6 @@ __pycache__
log.log
run.sh
tester/api_config/**/test_log*
tools/api_tracer/.huggingface
tester/api_config/api_config*
tools/api_tracer/.huggingface
tools/api_tracer/trace_output*
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ paddle.concat(tuple(Tensor([31376, 768],"float32"),Tensor([1, 768],"float32"),),
```
- 安装第三方库:
```bash
pip install pebble pynvml pandas
pip install func_timeout pandas pebble pynvml pyyaml
```

4. PaddlePaddle 与 PyTorch 的部分依赖项可能发生冲突,请先安装 *paddlepaddle-gpu* 再安装 *torch*,重新安装请在 pip 后添加 `--force-reinstall` 参数,仅更新 paddle 请添加 `--no-deps` 参数;engineV2 建议使用 python>=3.10
Expand Down
2 changes: 1 addition & 1 deletion engineV2-README.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
```bash
pip install --pre paddlepaddle-gpu -i https://www.paddlepaddle.org.cn/packages/nightly/cu118/
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
pip install pebble pynvml pandas
pip install func_timeout pandas pebble pynvml pyyaml
```
2. 克隆 PaddleAPITest 仓库并进入项目目录
```bash
Expand Down
146 changes: 93 additions & 53 deletions tools/api_tracer/framework_dialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import os
import pkgutil
import traceback
from functools import partial
import types
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Union

import torch
Expand All @@ -19,9 +19,15 @@
class TracingHook(abc.ABC):
"""钩子的抽象基类"""

def __init__(self, serializer: "ConfigSerializer", level: int):
def __init__(
self,
serializer: "ConfigSerializer",
level: int,
dialect: Optional["FrameworkDialect"] = None,
):
self.serializer = serializer
self.level = level
self.dialect = dialect

@abc.abstractmethod
def install(self):
Expand All @@ -40,8 +46,7 @@ def __init__(
level: int,
dialect: "FrameworkDialect",
):
super().__init__(serializer, level)
self.dialect = dialect
super().__init__(serializer, level, dialect)
self._original_apis: Dict[str, Any] = {}
self._module_cache: Dict[str, Any] = {}

Expand Down Expand Up @@ -85,9 +90,12 @@ def wrapper(*args, **kwargs):
return wrapper

def install(self):
if self.dialect is None:
return
api_list = self.dialect.discover_apis() + self.dialect.discover_custom_ops()

# with open(os.path.join(os.path.dirname(__file__), "trace_output", "api_list.yaml"), "w") as f:
# api_path = os.path.join(os.path.dirname(__file__), "trace_output/api_list.yaml")
# with open(api_path, "w") as f:
# yaml.dump(api_list, f)

print(f"[SetattrHook] Attempting to patch {len(api_list)} APIs...")
Expand All @@ -106,29 +114,40 @@ def install(self):
original_api = getattr(parent_obj, func_name)
wrapper = None

if isinstance(original_api, property):
if original_api.fget and original_api.fset:
wrapped_getter = self._create_wrapper(
f"{api_name}.fget",
original_api.fget,
self.serializer,
self.level,
)
wrapper = property(
wrapped_getter,
original_api.fset,
original_api.fdel,
original_api.__doc__,
)
if isinstance(
original_api,
(
types.FunctionType,
types.BuiltinFunctionType,
types.MethodType,
types.BuiltinMethodType,
),
):
wrapped_func = self._create_wrapper(
api_name, original_api, self.serializer, self.level
)
elif isinstance(original_api, (classmethod, staticmethod)):
original_func = original_api.__func__
wrapped_func = self._create_wrapper(
api_name, original_func, self.serializer, self.level
)
wrapper = type(original_api)(wrapped_func)
elif callable(original_api):
wrapper = self._create_wrapper(
api_name, original_api, self.serializer, self.level
elif (
isinstance(original_api, property)
and original_api.fget
and original_api.fset
):
wrapped_getter = self._create_wrapper(
f"{api_name}.fget",
original_api.fget,
self.serializer,
self.level,
)
wrapper = property(
wrapped_getter,
original_api.fset,
original_api.fdel,
original_api.__doc__,
)

if wrapper:
Expand All @@ -153,7 +172,8 @@ def install(self):
f"[SetattrHook] Patched {patched_count} APIs. Skipped {skipped_count} non-writable APIs."
)

# with open(os.path.join(os.path.dirname(__file__), "trace_output", "api_list_wrap.yaml"), "w") as f:
# api_path = os.path.join(os.path.dirname(__file__), "trace_output/api_list_wrap.yaml")
# with open(api_path, "w") as f:
# yaml.dump(list(self._original_apis.keys()), f)

def uninstall(self):
Expand All @@ -171,9 +191,13 @@ def uninstall(self):


class TorchFunctionModeTracer(torch.overrides.TorchFunctionMode):
def __init__(self, serializer: "ConfigSerializer", level: int):
def __init__(
self, serializer: "ConfigSerializer", level: int, dialect: "FrameworkDialect"
):
self.serializer = serializer
self.level = level
self.disable_torch_api_list = getattr(dialect, "disable_torch_api_list", False)
self.target_apis = getattr(dialect, "target_apis", [])

# skip these for duplicate property access of paddle.Tensor in SetattrHook
# (SetattrHook and TorchFunctionHook are installed at the same time)
Expand Down Expand Up @@ -209,9 +233,11 @@ def __torch_function__(self, func, types, args=(), kwargs=None):


class TorchFunctionHook(TracingHook):
def __init__(self, serializer: "ConfigSerializer", level: int):
super().__init__(serializer, level)
self.tracing_mode = TorchFunctionModeTracer(serializer, level)
def __init__(
self, serializer: "ConfigSerializer", level: int, dialect: "FrameworkDialect"
):
super().__init__(serializer, level, dialect)
self.tracing_mode = TorchFunctionModeTracer(serializer, level, dialect)

def install(self):
print(f"[TorchFunctionHook] Enabling __torch_function__ tracing mode...")
Expand All @@ -225,9 +251,13 @@ def uninstall(self):


class TorchDispatchModeTracer(TorchDispatchMode):
def __init__(self, serializer: "ConfigSerializer", level: int):
def __init__(
self, serializer: "ConfigSerializer", level: int, dialect: "FrameworkDialect"
):
self.serializer = serializer
self.level = level
self.disable_torch_api_list = getattr(dialect, "disable_torch_api_list", False)
self.target_apis = getattr(dialect, "target_apis", [])

def __torch_dispatch__(self, func, types, args=(), kwargs=None):
kwargs = kwargs or {}
Expand All @@ -240,9 +270,11 @@ def __torch_dispatch__(self, func, types, args=(), kwargs=None):


class TorchDispatchHook(TracingHook):
def __init__(self, serializer: "ConfigSerializer", level: int):
super().__init__(serializer, level)
self.tracing_mode = TorchDispatchModeTracer(serializer, level)
def __init__(
self, serializer: "ConfigSerializer", level: int, dialect: "FrameworkDialect"
):
super().__init__(serializer, level, dialect)
self.tracing_mode = TorchDispatchModeTracer(serializer, level, dialect)

def install(self):
print(f"[TorchDispatchHook] Enabling __torch_dispatch__ tracing mode...")
Expand Down Expand Up @@ -394,13 +426,14 @@ class PyTorchDialect(FrameworkDialect):
# recommended to skip
"__call__",
"__format__",
"__instancecheck__",
"__iter__",
"__repr__",
"__str__",
"__instancecheck__",
"__subclasscheck__",
"__subclasshook__",
# optional to skip
"__getstate__",
"__setstate__",
"__enter__",
"__exit__",
}
Expand All @@ -416,10 +449,11 @@ class PyTorchDialect(FrameworkDialect):
"torch.TypedStorage",
# methods
"torch.autograd.function._is_setup_context_defined",
"torch.distributed.reduce_op",
"torch.fx.experimental.unification.multipledispatch.dispatcher.str_signature",
"torch.nn.functional.handle_torch_function",
"torch.nn.functional.has_torch_function_unary",
"torch.distributed.reduce_op",
"torch.optim.Optimizer.profile_hook_step",
}

def get_framework_name(self) -> str:
Expand Down Expand Up @@ -480,7 +514,17 @@ def discover_apis(self) -> List[str]:
continue
if full_name in self.IGNORE_CLASSES_OR_METHODS:
continue
if callable(obj) and not inspect.isclass(obj):
if isinstance(
obj,
(
types.FunctionType,
types.BuiltinFunctionType,
types.MethodType,
types.BuiltinMethodType,
staticmethod,
classmethod,
),
):
api_set.add(full_name)
elif inspect.isclass(obj):
# custom op class should be skip
Expand All @@ -490,23 +534,22 @@ def discover_apis(self) -> List[str]:
if cls_member_name in self.IGNORE_ATTRIBUTES:
continue
full_cls_name = f"{full_name}.{cls_member_name}"
if inspect.ismethod(cls_member) or inspect.isfunction(
cls_member
if isinstance(
cls_member,
(
types.FunctionType,
types.BuiltinFunctionType,
types.MethodType,
types.BuiltinMethodType,
staticmethod,
classmethod,
),
):
api_set.add(full_cls_name)
elif isinstance(cls_member, (staticmethod, classmethod)):
api_set.add(full_cls_name)
elif isinstance(cls_member, property):
if cls_member.fget and cls_member.fset:
api_set.add(full_cls_name)
elif isinstance(cls_member, partial):
if hasattr(
cls_member.func, "__module__"
) and cls_member.func.__module__.startswith("torch"):
api_set.add(full_cls_name)
elif (
hasattr(cls_member, "__isabstractmethod__")
and cls_member.__isabstractmethod__
isinstance(cls_member, property)
and cls_member.fget
and cls_member.fset
):
api_set.add(full_cls_name)
except Exception as e:
Expand Down Expand Up @@ -584,10 +627,7 @@ def get_hooks(self, serializer, levels: List[int], **kwargs) -> List[TracingHook
for level in levels:
hook_class = hook_map.get(level)
if hook_class:
if level == 0:
hooks.append(hook_class(serializer, level, self))
else:
hooks.append(hook_class(serializer, level))
hooks.append(hook_class(serializer, level, self))
else:
raise ValueError(f"Invalid level: {level}")
return hooks
8 changes: 4 additions & 4 deletions tools/api_tracer/test_infer.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import json
import os
import sys
import traceback

import yaml

Expand Down Expand Up @@ -58,11 +59,10 @@ def run_inference_test(model_name: str):
print("\n--- Generated Response ---")
print(response)
print("--------------------------\n")

except Exception as e:
print(f"An error occurred during inference for {model_name}: {e}")
finally:
print(f"✅ Test for {model_name} finished.")
except Exception as e:
traceback.print_exc()
print(f"❌ An error occurred during inference for {model_name}: {e}")


def main():
Expand Down
Loading