3535
3636### 多功能钩子策略
3737
38- - ` TorchFunctionHook ` :
39- 通过重写 PyTorch 官方的 ` torch.overrides.TorchFunctionMode ` 类实现。经过测试,这是追踪 PyTorch API 的首选方法,能够高效、准确地捕获所有进入 PyTorch C++ 后端的函数调用(Torch C API),覆盖范围广、对用户代码无侵入。
40-
4138- ` SetattrHook ` :
42- 通过 Python 的 ` setattr ` 机制,在运行时动态遍历并替换模块中的函数对象。它可以挂钩任意 Python 库的 API。可用于扫描全库并产出 api_list,追踪纯 Python 函数的调用。
39+ 通过 Python 的 ` setattr ` 机制,在运行时动态遍历并替换模块中的函数对象。它可以挂钩任意 Python 库的 API。可用于扫描全库并产出 ` api_list/torch_api_list_full.yaml ` ,追踪纯 Python 函数的调用。
4340
44- 但由于 PyTorch 的大量核心功能由 C++ 实现,` SetattrHook ` 无法覆盖到底层算子,表现不尽人意。因此对于 PyTorch 不再启用,留作备用/辅助策略,当前默认仅使用 ` TorchFunctionHook ` 对 PyTorch 进行追踪。
41+ 由于 PyTorch 的大量核心功能由 C++ 实现,` SetattrHook ` 钩子无法覆盖到底层算子,但可以抓取如 ` nn.Linear ` 等类级别 API。覆盖范围默认是 ` api_list/torch_api_list.yaml ` 的子集,由参数 ` disable_torch_api_list ` 控制。
42+
43+ - ` TorchFunctionHook ` :
44+ 通过重写 PyTorch 官方的 ` torch.overrides.TorchFunctionMode ` 类实现。该方法可以捕获所有支持 ` __torch_function__ ` 协议的 API 调用,即进入 PyTorch C++ 后端的函数调用。经过测试,这是追踪 PyTorch API 的首选方法(目前采用 ` SetattrHook`` + ` TorchFunctionHook`` 结合的方式),能够高效、准确地捕获所有 Torch C API 调用,覆盖范围广、对用户代码无侵入。
4545
46- ### PyTorch 追踪相关实现
46+ - ` TorchDispatchHook ` :
47+ 通过重写 PyTorch 官方的 ` torch.utils._python_dispatch ` 库的 ` TorchDispatchMode ` 类实现。该方法可以捕获所有通过 ` torch.dispatch ` 调用的函数,包括自定义的 ` Tensor ` 操作。 ` torch.dispatch ` 是 PyTorch 内部使用的调度机制,可以捕获到所有底层算子的调用(如 ` aten:: ` ),是抓取 PyTorch 底层算子的首选方法。
4748
48- - ` TorchFunctionHook ` 启用全局的 ` TorchFunctionMode ` ,所有支持该协议的 PyTorch API 调用都会被重载的 ` __torch_function__ ` 方法捕获
49- - ` FrameworkDialect ` 抽象类实现 ` PyTorchDialect ` ,其方法 ` serialize_special_type ` , ` format_special_type ` 实现了针对 ` torch.Tensor ` , ` torch.dtype ` , ` torch.device ` 等 PyTorch 特有类型的序列化逻辑
49+ ### PyTorch 其他定制实现
50+
51+ - ` PyTorchDialect ` 实现了 ` FrameworkDialect ` 抽象类 ,其方法 ` serialize_special_type ` 、 ` format_special_type ` 实现了针对 ` torch.Tensor ` 、 ` torch.dtype ` 、 ` torch.device ` 等 PyTorch 特有类型的序列化逻辑
5052
5153## 如何使用
5254
5860import torch
5961from api_tracer import APITracer
6062
61- # 初始化 Tracer,指定框架方言(当前为 'torch')和输出路径
62- tracer = APITracer(dialect = " torch" , output_path = " trace_output" )
63+ # 初始化 Tracer,指定框架方言为 'torch'
64+ tracer = APITracer(
65+ dialect = " torch" , output_path = " trace_output" , levels = 1 , merge_output = True
66+ )
6367
6468# 使用 with 语句来自动启动和停止追踪
6569with tracer:
6670 # 执行 PyTorch 模型或代码
67- tensor1 = torch.randn(2 , 3 , device = ' cpu' )
71+ tensor1 = torch.randn(2 , 3 , device = " cpu" )
6872 tensor2 = torch.ones(2 , 3 )
6973 result = torch.add(tensor1, tensor2, alpha = 10 )
7074 final_sum = result.sum()
@@ -73,13 +77,27 @@ with tracer:
7377# tracer.start()
7478# 执行 PyTorch 模型或代码
7579# tracer.stop()
80+
7681```
7782
83+ ** 参数说明**
84+
85+ - ` dialect ` (str): 支持的框架方言,目前仅支持 ` torch `
86+ - ` output_path ` (str): 抓取结果的保存目录路径
87+ - ` levels ` (int|List[ int] ): 控制钩子的粒度,可同时启用多个钩子,默认为 ` 0 ` 。映射如下:
88+ - ` 0 ` : ` SetattrHook `
89+ - ` 1 ` : ` TorchFunctionHook `
90+ - ` 2 ` : ` TorchDispatchHook `
91+ - ` merge_output ` (bool): 输出时是否将不同 level 的结果合并,默认为 ` False ` 。
92+
93+ ** 可选参数**
94+ - ` disable_torch_api_list ` (bool): 是否禁用 ` torch_api_list ` ,仅影响 ` PyTorchDialect ` 的 ` SetattrHook ` 钩子。设置为 ` True ` 时将抓取所有遍历到并被 ` setattr ` 钩住的 API ,除非在 ` PyTorchDialect ` 中被排除。默认为 ` False ` 。
95+
7896### 输出文件
7997
80- 执行上述代码后,你将在 ` trace_output ` 目录下找到两个文件 :
98+ 执行上述代码后,你将在 ` trace_output ` 目录下找到五个文件 :
8199
82- 1 . ** ` api_trace.yaml ` ** : 结构化的 API 调用记录
100+ 1 . ** ` api_trace.yaml ` ** : 结构化的 API 调用记录
83101
84102 ``` yaml
85103 - api : torch.randn
@@ -120,11 +138,40 @@ with tracer:
120138 kwargs : {}
121139 ` ` `
122140
123- 2. **` api_trace.txt`**: 更易读的格式
141+ 2. **` api_trace.txt`**: 更易读的格式
124142
125143 ` ` ` text
126144 torch.randn(2, 3, device="cpu")
127145 torch.ones(2, 3)
128146 torch.add(Tensor([2, 3], "float32"), Tensor([2, 3], "float32"), alpha=10)
129147 torch.Tensor.sum(Tensor([2, 3], "float32"))
130148 ` ` `
149+ 3. **`api_apis.txt`** : API 集合
150+
151+ ` ` ` text
152+ torch.Tensor.sum
153+ torch.add
154+ torch.ones
155+ torch.randn
156+ ` ` `
157+
158+ 4. **`api_configs.yaml`** : API 配置集合(去重排序 `api_trace.txt` )
159+
160+ ` ` ` text
161+ torch.Tensor.sum(Tensor([2, 3], "float32"))
162+ torch.add(Tensor([2, 3], "float32"), Tensor([2, 3], "float32"), alpha=10)
163+ torch.ones(2, 3)
164+ torch.randn(2, 3, device="cpu")
165+ ` ` `
166+
167+ 5. **`api_statistics.yaml`** : API 统计信息
168+
169+ ` ` ` text
170+ Total APIs: 4
171+ Total API calls: 4
172+
173+ torch.randn: 1 (25.00%)
174+ torch.ones: 1 (25.00%)
175+ torch.add: 1 (25.00%)
176+ torch.Tensor.sum: 1 (25.00%)
177+ ` ` `
0 commit comments