Skip to content

Commit 46fd127

Browse files
Merge pull request #467 from cangtianhuang/develop
Update `API Tracer`
2 parents 0bf0ff4 + ca1f4b6 commit 46fd127

30 files changed

+4377
-10138012
lines changed

tools/api_tracer/README.md

Lines changed: 61 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -35,18 +35,20 @@
3535

3636
### 多功能钩子策略
3737

38-
- `TorchFunctionHook`:
39-
通过重写 PyTorch 官方的 `torch.overrides.TorchFunctionMode` 类实现。经过测试,这是追踪 PyTorch API 的首选方法,能够高效、准确地捕获所有进入 PyTorch C++ 后端的函数调用(Torch C API),覆盖范围广、对用户代码无侵入。
40-
4138
- `SetattrHook`:
42-
通过 Python 的 `setattr` 机制,在运行时动态遍历并替换模块中的函数对象。它可以挂钩任意 Python 库的 API。可用于扫描全库并产出 api_list,追踪纯 Python 函数的调用。
39+
通过 Python 的 `setattr` 机制,在运行时动态遍历并替换模块中的函数对象。它可以挂钩任意 Python 库的 API。可用于扫描全库并产出 `api_list/torch_api_list_full.yaml` ,追踪纯 Python 函数的调用。
4340

44-
但由于 PyTorch 的大量核心功能由 C++ 实现,`SetattrHook` 无法覆盖到底层算子,表现不尽人意。因此对于 PyTorch 不再启用,留作备用/辅助策略,当前默认仅使用 `TorchFunctionHook` 对 PyTorch 进行追踪。
41+
由于 PyTorch 的大量核心功能由 C++ 实现,`SetattrHook` 钩子无法覆盖到底层算子,但可以抓取如 `nn.Linear` 等类级别 API。覆盖范围默认是 `api_list/torch_api_list.yaml` 的子集,由参数 `disable_torch_api_list` 控制。
42+
43+
- `TorchFunctionHook`:
44+
通过重写 PyTorch 官方的 `torch.overrides.TorchFunctionMode` 类实现。该方法可以捕获所有支持 `__torch_function__` 协议的 API 调用,即进入 PyTorch C++ 后端的函数调用。经过测试,这是追踪 PyTorch API 的首选方法(目前采用 `SetattrHook`` + `TorchFunctionHook`` 结合的方式),能够高效、准确地捕获所有 Torch C API 调用,覆盖范围广、对用户代码无侵入。
4545

46-
### PyTorch 追踪相关实现
46+
- `TorchDispatchHook`:
47+
通过重写 PyTorch 官方的 `torch.utils._python_dispatch` 库的 `TorchDispatchMode` 类实现。该方法可以捕获所有通过 `torch.dispatch` 调用的函数,包括自定义的 `Tensor` 操作。 `torch.dispatch` 是 PyTorch 内部使用的调度机制,可以捕获到所有底层算子的调用(如 `aten::`),是抓取 PyTorch 底层算子的首选方法。
4748

48-
- `TorchFunctionHook` 启用全局的 `TorchFunctionMode`,所有支持该协议的 PyTorch API 调用都会被重载的 `__torch_function__` 方法捕获
49-
- `FrameworkDialect` 抽象类实现 `PyTorchDialect`,其方法 `serialize_special_type`, `format_special_type` 实现了针对 `torch.Tensor`, `torch.dtype`, `torch.device` 等 PyTorch 特有类型的序列化逻辑
49+
### PyTorch 其他定制实现
50+
51+
- `PyTorchDialect` 实现了 `FrameworkDialect` 抽象类 ,其方法 `serialize_special_type``format_special_type` 实现了针对 `torch.Tensor``torch.dtype``torch.device` 等 PyTorch 特有类型的序列化逻辑
5052

5153
## 如何使用
5254

@@ -58,13 +60,15 @@
5860
import torch
5961
from api_tracer import APITracer
6062

61-
# 初始化 Tracer,指定框架方言(当前为 'torch')和输出路径
62-
tracer = APITracer(dialect="torch", output_path="trace_output")
63+
# 初始化 Tracer,指定框架方言为 'torch'
64+
tracer = APITracer(
65+
dialect="torch", output_path="trace_output", levels=1, merge_output=True
66+
)
6367

6468
# 使用 with 语句来自动启动和停止追踪
6569
with tracer:
6670
# 执行 PyTorch 模型或代码
67-
tensor1 = torch.randn(2, 3, device='cpu')
71+
tensor1 = torch.randn(2, 3, device="cpu")
6872
tensor2 = torch.ones(2, 3)
6973
result = torch.add(tensor1, tensor2, alpha=10)
7074
final_sum = result.sum()
@@ -73,13 +77,27 @@ with tracer:
7377
# tracer.start()
7478
# 执行 PyTorch 模型或代码
7579
# tracer.stop()
80+
7681
```
7782

83+
**参数说明**
84+
85+
- `dialect` (str): 支持的框架方言,目前仅支持 `torch`
86+
- `output_path` (str): 抓取结果的保存目录路径
87+
- `levels` (int|List[int]): 控制钩子的粒度,可同时启用多个钩子,默认为 `0` 。映射如下:
88+
- `0`: `SetattrHook`
89+
- `1`: `TorchFunctionHook`
90+
- `2`: `TorchDispatchHook`
91+
- `merge_output` (bool): 输出时是否将不同 level 的结果合并,默认为 `False`
92+
93+
**可选参数**
94+
- `disable_torch_api_list` (bool): 是否禁用 `torch_api_list` ,仅影响 `PyTorchDialect``SetattrHook` 钩子。设置为 `True` 时将抓取所有遍历到并被 `setattr` 钩住的 API ,除非在 `PyTorchDialect` 中被排除。默认为 `False`
95+
7896
### 输出文件
7997

80-
执行上述代码后,你将在 `trace_output` 目录下找到两个文件
98+
执行上述代码后,你将在 `trace_output` 目录下找到五个文件
8199

82-
1. **`api_trace.yaml`**: 结构化的 API 调用记录
100+
1. **`api_trace.yaml`**: 结构化的 API 调用记录
83101

84102
```yaml
85103
- api: torch.randn
@@ -120,11 +138,40 @@ with tracer:
120138
kwargs: {}
121139
```
122140
123-
2. **`api_trace.txt`**: 更易读的格式
141+
2. **`api_trace.txt`**: 更易读的格式
124142

125143
```text
126144
torch.randn(2, 3, device="cpu")
127145
torch.ones(2, 3)
128146
torch.add(Tensor([2, 3], "float32"), Tensor([2, 3], "float32"), alpha=10)
129147
torch.Tensor.sum(Tensor([2, 3], "float32"))
130148
```
149+
3. **`api_apis.txt`**: API 集合
150+
151+
```text
152+
torch.Tensor.sum
153+
torch.add
154+
torch.ones
155+
torch.randn
156+
```
157+
158+
4. **`api_configs.yaml`**: API 配置集合(去重排序 `api_trace.txt` )
159+
160+
```text
161+
torch.Tensor.sum(Tensor([2, 3], "float32"))
162+
torch.add(Tensor([2, 3], "float32"), Tensor([2, 3], "float32"), alpha=10)
163+
torch.ones(2, 3)
164+
torch.randn(2, 3, device="cpu")
165+
```
166+
167+
5. **`api_statistics.yaml`**: API 统计信息
168+
169+
```text
170+
Total APIs: 4
171+
Total API calls: 4
172+
173+
torch.randn: 1 (25.00%)
174+
torch.ones: 1 (25.00%)
175+
torch.add: 1 (25.00%)
176+
torch.Tensor.sum: 1 (25.00%)
177+
```

tools/api_tracer/api_alias.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
import re
2+
from pathlib import Path
3+
4+
INPUT_DIR = Path("tools/api_tracer/trace_output_tmp")
5+
6+
7+
def parse_api(api):
8+
if ".Tensor." in api:
9+
return api
10+
if re.search(r"\.[A-Z][a-zA-Z0-9]*\.", api):
11+
return api.rsplit(".", 1)[0]
12+
return api
13+
14+
15+
def process_file(input_path):
16+
input_name = input_path.name
17+
output_name = input_name[4:]
18+
output_path = input_path.parent / output_name
19+
20+
apis = set()
21+
with input_path.open("r") as f:
22+
apis = set([line.strip() for line in f if line.strip()])
23+
print(f"Read {len(apis)} apis from {input_path}", flush=True)
24+
25+
alias_apis = set()
26+
for api in apis:
27+
alias_apis.add(parse_api(api))
28+
29+
with output_path.open("w") as f:
30+
f.writelines(f"{line}\n" for line in sorted(alias_apis))
31+
print(f"Write {len(alias_apis)} alias apis to {output_path}", flush=True)
32+
33+
34+
def main():
35+
input_files = list(INPUT_DIR.glob("api_apis*.txt"))
36+
if not input_files:
37+
print(f"No input files found in {INPUT_DIR}", flush=True)
38+
return
39+
40+
for input_file in sorted(input_files):
41+
process_file(input_file)
42+
43+
44+
if __name__ == "__main__":
45+
main()

0 commit comments

Comments
 (0)