Skip to content

Commit f981fdc

Browse files
HAOCHENYEzhouzaida
andauthored
[Docs] Add Inferencer design document. (#852)
* add infer.md * minor refine * minor refine * minor refine * Apply suggestions from code review Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> * Fix conflicts and minor refine * minor refine * Fix as comment Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>
1 parent c25f3ab commit f981fdc

File tree

4 files changed

+251
-3
lines changed

4 files changed

+251
-3
lines changed

docs/zh_cn/design/infer.md

+244
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,244 @@
1+
# 推理接口
2+
3+
基于 MMEngine 开发时,我们通常会为具体算法定义一个配置文件,根据配置文件去构建[执行器](./runner.md),执行训练、测试流程,并保存训练好的权重。基于训练好的模型进行推理时,通常需要执行以下步骤:
4+
5+
1. 基于配置文件构建模型
6+
2. 加载模型权重
7+
3. 搭建数据预处理流程
8+
4. 执行模型前向推理
9+
5. 可视化推理结果
10+
6. 输出推理结果
11+
12+
对于此类标准的推理流程,MMEngine 提供了统一的推理接口,并且建议用户基于这一套接口规范来开发推理代码。
13+
14+
## 使用样例
15+
16+
### 定义推理器
17+
18+
基于 `BaseInferencer` 实现自定义的推理器
19+
20+
```python
21+
from mmengine.infer import BaseInferencer
22+
23+
class CustomInferencer(BaseInferencer)
24+
...
25+
```
26+
27+
具体细节参考[开发规范](#推理接口开发规范)
28+
29+
### 构建推理器
30+
31+
**基于配置文件路径构建推理器**
32+
33+
```python
34+
cfg = 'path/to/config.py'
35+
weight = 'path/to/weight.pth'
36+
37+
inferencer = CustomInferencer(model=cfg, weight=weight)
38+
```
39+
40+
**基于配置类实例构建推理器**
41+
42+
```python
43+
from mmengine import Config
44+
45+
cfg = Config.fromfile('path/to/config.py')
46+
weight = 'path/to/weight.pth'
47+
48+
inferencer = CustomInferencer(model=cfg, weight=weight)
49+
```
50+
51+
**基于 model-index 中定义的 model name 构建推理器**,以 MMDetection 中的 [atss 检测器为例](https://github.com/open-mmlab/mmdetection/blob/31c84958f54287a8be2b99cbf87a6dcf12e57753/configs/atss/metafile.yml#L22),model name 为 `atss_r50_fpn_1x_coco`,由于 model-index 中已经定义了 weight 的路径,因此可以不配置 weight 参数。
52+
53+
```python
54+
inferencer = CustomInferencer(model='atss_r50_fpn_1x_coco')
55+
```
56+
57+
### 执行推理
58+
59+
**推理单张图片**
60+
61+
```python
62+
# 输入为图片路径
63+
img = 'path/to/img.jpg'
64+
result = inferencer(img)
65+
66+
# 输入为读取的图片(类型为 np.ndarray)
67+
img = cv2.imread('path/to/img.jpg')
68+
result = inferencer(img)
69+
70+
# 输入为 url
71+
img = 'https://xxx.com/img.jpg'
72+
result = inferencer(img)
73+
```
74+
75+
**推理多张图片**
76+
77+
```python
78+
img_dir = 'path/to/directory'
79+
result = inferencer(img_dir)
80+
```
81+
82+
```{note}
83+
OpenMMLab 系列算法库要求 `inferencer(img)` 输出一个 `dict`,其中包含 `visualization: list``predictions: list` 两个字段,分别对应可视化结果和预测结果。
84+
```
85+
86+
## 推理接口开发规范
87+
88+
inferencer 执行推理时,通常会执行以下步骤:
89+
90+
1. preprocess:输入数据预处理,包括数据读取、数据预处理、数据格式转换等
91+
2. forward: 模型前向推理
92+
3. visualize:预测结果可视化
93+
4. postprocess:预测结果后处理,包括结果格式转换、导出预测结果等
94+
95+
为了优化 inferencer 的使用体验,我们不希望使用者在执行推理时,需要为每个过程都配置一遍参数。换句话说,我们希望使用者可以在不感知上述流程的情况下,简单为 `__call__` 接口配置参数,即可完成推理。
96+
97+
`__call__` 接口会按照顺序执行上述步骤,但是本身却不知道使用者传入的参数需要分发给哪个步骤,因此开发者在实现 `CustomInferencer` 时,需要定义 `preprocess_kwargs``forward_kwargs``visualize_kwargs``postprocess_kwargs` 4 个类属性,每个属性均为一个字符集合(`Set[str]`),用于指定 `__call__` 接口中的参数对应哪个步骤:
98+
99+
```python
100+
class CustomInferencer(BaseInferencer):
101+
preprocess_kwargs = {'a'}
102+
forward_kwargs = {'b'}
103+
visualize_kwargs = {'c'}
104+
postprocess_kwargs = {'d'}
105+
106+
def preprocess(self, inputs, batch_size=1, a=None):
107+
pass
108+
109+
def forward(self, inputs, b=None):
110+
pass
111+
112+
def visualize(self, inputs, preds, show, c=None):
113+
pass
114+
115+
def postprocess(self, preds, visualization, return_datasample=False, d=None):
116+
pass
117+
118+
def __call__(
119+
self,
120+
inputs,
121+
batch_size=1,
122+
show=True,
123+
return_datasample=False,
124+
a=None,
125+
b=None,
126+
c=None,
127+
d=None):
128+
return super().__call__(
129+
inputs, batch_size, show, return_datasample, a=a, b=b, c=c, d=d)
130+
```
131+
132+
上述代码中,`preprocess``forward``visualize``postprocess` 四个函数的 `a``b``c``d` 为用户可以传入的额外参数(`inputs`, `preds` 等参数在 `__call__` 的执行过程中会被自动填入),因此开发者需要在类属性 `preprocess_kwargs``forward_kwargs``visualize_kwargs``postprocess_kwargs` 中指定这些参数,这样 `__call__` 阶段用户传入的参数就可以被正确分发给对应的步骤。分发过程由 `BaseInferencer.__call__` 函数实现,开发者无需关心。
133+
134+
此外,我们需要将 `CustomInferencer` 注册到自定义注册器或者 MMEngine 的注册器中
135+
136+
```python
137+
from mmseg.registry import INFERENCERS
138+
# 也可以注册到 MMEngine 的注册中
139+
# from mmengine.registry import INFERENCERS
140+
141+
@INFERENCERS.register_module()
142+
class CustomInferencer(BaseInferencer):
143+
...
144+
```
145+
146+
```{note}
147+
OpenMMLab 系列算法仓库必须将 Inferencer 注册到下游仓库的注册器,而不能注册到 MMEngine 的根注册器(避免重名)。
148+
```
149+
150+
**核心接口说明**
151+
152+
### `__init__()`
153+
154+
`BaseInferencer.__init__` 已经实现了[使用样例](#构建推理器)中构建推理器的逻辑,因此通常情况下不需要重写 `__init__` 函数。如果想实现自定义的加载配置文件、权重初始化、pipeline 初始化等逻辑,也可以重写 `__init__` 方法。
155+
156+
### `_init_pipeline()`
157+
158+
```{note}
159+
抽象方法,子类必须实现
160+
```
161+
162+
初始化并返回 inferencer 所需的 pipeline。pipeline 用于单张图片,类似于 OpenMMLab 系列算法库中定义的 `train_pipeline``test_pipeline`。使用者调用 `__call__` 接口传入的每个 `inputs`,都会经过 pipeline 处理,组成 batch data,然后传入 `forward` 方法。
163+
164+
### `_init_collate()`
165+
166+
初始化并返回 inferencer 所需的 `collate_fn`,其值等价于训练过程中 Dataloader 的 `collate_fn``BaseInferencer` 默认会从 `test_dataloader` 的配置中获取 `collate_fn`,因此通常情况下不需要重写 `_init_collate` 函数。
167+
168+
### `_init_visualizer()`
169+
170+
初始化并返回 inferencer 所需的 `visualizer`,其值等价于训练过程中 `visualizer``BaseInferencer` 默认会从 `visualizer` 的配置中获取 `visualizer`,因此通常情况下不需要重写 `_init_visualizer` 函数。
171+
172+
### `preprocess()`
173+
174+
入参:
175+
176+
- inputs:输入数据,由 `__call__` 传入,通常为图片路径或者图片数据组成的列表
177+
- batch_size:batch 大小,由使用者在调用 `__call__` 时传入
178+
- 其他参数:由用户传入,且在 `preprocess_kwargs` 中指定
179+
180+
返回值:
181+
182+
- 生成器,每次迭代返回一个 batch 的数据。
183+
184+
`preprocess` 默认是一个生成器函数,将 `pipeline``collate_fn` 应用于输入数据,生成器迭代返回的是组完 batch,预处理后的结果。通常情况下子类无需重写。
185+
186+
### `forward()`
187+
188+
入参:
189+
190+
- inputs:输入数据,由 `preprocess` 处理后的 batch data
191+
- 其他参数:由用户传入,且在 `forward_kwargs` 中指定
192+
193+
返回值:
194+
195+
- 预测结果,默认类型为 `List[BaseDataElement]`
196+
197+
调用 `model.test_step` 执行前向推理,并返回推理结果。通常情况下子类无需重写。
198+
199+
### `visualize()`
200+
201+
```{note}
202+
抽象方法,子类必须实现
203+
```
204+
205+
入参:
206+
207+
- inputs:输入数据,未经过预处理的原始数据。
208+
- preds:模型的预测结果
209+
- show:是否可视化
210+
- 其他参数:由用户传入,且在 `visualize_kwargs` 中指定
211+
212+
返回值:
213+
214+
- 可视化结果,类型通常为 `List[np.ndarray]`,以目标检测任务为例,列表中的每个元素应该是画完检测框后的图像,直接使用 `cv2.imshow` 就能可视化检测结果。不同任务的可视化流程有所不同,`visualize` 应该返回该领域内,适用于常见可视化流程的结果。
215+
216+
### `postprocess()`
217+
218+
```{note}
219+
抽象方法,子类必须实现
220+
```
221+
222+
入参:
223+
224+
- preds:模型预测结果,类型为 `list`,列表中的每个元素表示一个数据的预测结果。OpenMMLab 系列算法库中,预测结果中每个元素的类型均为 `BaseDataElement`
225+
- visualization:可视化结果
226+
- return_datasample:是否维持 datasample 返回。`False` 时转换成 `dict` 返回
227+
- 其他参数:由用户传入,且在 `postprocess_kwargs` 中指定
228+
229+
返回值:
230+
231+
- 可视化结果和预测结果,类型为一个字典。OpenMMLab 系列算法库要求返回的字典包含 `predictions``visualization` 两个 key。
232+
233+
### `__call__()`
234+
235+
入参:
236+
237+
- inputs:输入数据,通常为图片路径、或者图片数据组成的列表。`inputs` 中的每个元素也可以是其他类型的数据,只需要保证数据能够被 [\_init_pipeline](#initpipeline) 返回的 `pipeline` 处理即可。当 `inputs` 只含一个推理数据时,它可以不是一个 `list`,`__call__` 会在内部将 `inputs` 包装成列表,以便于后续处理
238+
- return_datasample:是否将 datasample 转换成 `dict` 返回
239+
- batch_size:推理的 batch size,会被进一步传给 `preprocess` 函数
240+
- 其他参数:分发给 `preprocess``forward``visualize``postprocess` 函数的额外参数
241+
242+
返回值:
243+
244+
- `postprocess` 返回的可视化结果和预测结果,类型为一个字典。OpenMMLab 系列算法库要求返回的字典包含 `predictions``visualization` 两个 key

docs/zh_cn/index.rst

+1
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@
5858
design/evaluation.md
5959
design/visualization.md
6060
design/logging.md
61+
design/infer.md
6162

6263
.. toctree::
6364
:maxdepth: 1

mmengine/registry/__init__.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,8 @@
33
build_runner_from_cfg, build_scheduler_from_cfg)
44
from .default_scope import DefaultScope
55
from .registry import Registry
6-
from .root import (DATA_SAMPLERS, DATASETS, EVALUATOR, HOOKS, LOG_PROCESSORS,
7-
LOOPS, METRICS, MODEL_WRAPPERS, MODELS,
6+
from .root import (DATA_SAMPLERS, DATASETS, EVALUATOR, HOOKS, INFERENCERS,
7+
LOG_PROCESSORS, LOOPS, METRICS, MODEL_WRAPPERS, MODELS,
88
OPTIM_WRAPPER_CONSTRUCTORS, OPTIM_WRAPPERS, OPTIMIZERS,
99
PARAM_SCHEDULERS, RUNNER_CONSTRUCTORS, RUNNERS, TASK_UTILS,
1010
TRANSFORMS, VISBACKENDS, VISUALIZERS, WEIGHT_INITIALIZERS)
@@ -16,7 +16,7 @@
1616
'DATA_SAMPLERS', 'TRANSFORMS', 'MODELS', 'WEIGHT_INITIALIZERS',
1717
'OPTIMIZERS', 'OPTIM_WRAPPER_CONSTRUCTORS', 'TASK_UTILS',
1818
'PARAM_SCHEDULERS', 'METRICS', 'MODEL_WRAPPERS', 'OPTIM_WRAPPERS', 'LOOPS',
19-
'VISBACKENDS', 'VISUALIZERS', 'LOG_PROCESSORS', 'EVALUATOR',
19+
'VISBACKENDS', 'VISUALIZERS', 'LOG_PROCESSORS', 'EVALUATOR', 'INFERENCERS',
2020
'DefaultScope', 'traverse_registry_tree', 'count_registered_modules',
2121
'build_model_from_cfg', 'build_runner_from_cfg', 'build_from_cfg',
2222
'build_scheduler_from_cfg', 'init_default_scope'

mmengine/registry/root.py

+3
Original file line numberDiff line numberDiff line change
@@ -56,3 +56,6 @@
5656

5757
# manage logprocessor
5858
LOG_PROCESSORS = Registry('log_processor')
59+
60+
# manage inferencer
61+
INFERENCERS = Registry('inferencer')

0 commit comments

Comments
 (0)