|
| 1 | +# 推理接口 |
| 2 | + |
| 3 | +基于 MMEngine 开发时,我们通常会为具体算法定义一个配置文件,根据配置文件去构建[执行器](./runner.md),执行训练、测试流程,并保存训练好的权重。基于训练好的模型进行推理时,通常需要执行以下步骤: |
| 4 | + |
| 5 | +1. 基于配置文件构建模型 |
| 6 | +2. 加载模型权重 |
| 7 | +3. 搭建数据预处理流程 |
| 8 | +4. 执行模型前向推理 |
| 9 | +5. 可视化推理结果 |
| 10 | +6. 输出推理结果 |
| 11 | + |
| 12 | +对于此类标准的推理流程,MMEngine 提供了统一的推理接口,并且建议用户基于这一套接口规范来开发推理代码。 |
| 13 | + |
| 14 | +## 使用样例 |
| 15 | + |
| 16 | +### 定义推理器 |
| 17 | + |
| 18 | +基于 `BaseInferencer` 实现自定义的推理器 |
| 19 | + |
| 20 | +```python |
| 21 | +from mmengine.infer import BaseInferencer |
| 22 | + |
| 23 | +class CustomInferencer(BaseInferencer) |
| 24 | + ... |
| 25 | +``` |
| 26 | + |
| 27 | +具体细节参考[开发规范](#推理接口开发规范) |
| 28 | + |
| 29 | +### 构建推理器 |
| 30 | + |
| 31 | +**基于配置文件路径构建推理器** |
| 32 | + |
| 33 | +```python |
| 34 | +cfg = 'path/to/config.py' |
| 35 | +weight = 'path/to/weight.pth' |
| 36 | + |
| 37 | +inferencer = CustomInferencer(model=cfg, weight=weight) |
| 38 | +``` |
| 39 | + |
| 40 | +**基于配置类实例构建推理器** |
| 41 | + |
| 42 | +```python |
| 43 | +from mmengine import Config |
| 44 | + |
| 45 | +cfg = Config.fromfile('path/to/config.py') |
| 46 | +weight = 'path/to/weight.pth' |
| 47 | + |
| 48 | +inferencer = CustomInferencer(model=cfg, weight=weight) |
| 49 | +``` |
| 50 | + |
| 51 | +**基于 model-index 中定义的 model name 构建推理器**,以 MMDetection 中的 [atss 检测器为例](https://github.com/open-mmlab/mmdetection/blob/31c84958f54287a8be2b99cbf87a6dcf12e57753/configs/atss/metafile.yml#L22),model name 为 `atss_r50_fpn_1x_coco`,由于 model-index 中已经定义了 weight 的路径,因此可以不配置 weight 参数。 |
| 52 | + |
| 53 | +```python |
| 54 | +inferencer = CustomInferencer(model='atss_r50_fpn_1x_coco') |
| 55 | +``` |
| 56 | + |
| 57 | +### 执行推理 |
| 58 | + |
| 59 | +**推理单张图片** |
| 60 | + |
| 61 | +```python |
| 62 | +# 输入为图片路径 |
| 63 | +img = 'path/to/img.jpg' |
| 64 | +result = inferencer(img) |
| 65 | + |
| 66 | +# 输入为读取的图片(类型为 np.ndarray) |
| 67 | +img = cv2.imread('path/to/img.jpg') |
| 68 | +result = inferencer(img) |
| 69 | + |
| 70 | +# 输入为 url |
| 71 | +img = 'https://xxx.com/img.jpg' |
| 72 | +result = inferencer(img) |
| 73 | +``` |
| 74 | + |
| 75 | +**推理多张图片** |
| 76 | + |
| 77 | +```python |
| 78 | +img_dir = 'path/to/directory' |
| 79 | +result = inferencer(img_dir) |
| 80 | +``` |
| 81 | + |
| 82 | +```{note} |
| 83 | +OpenMMLab 系列算法库要求 `inferencer(img)` 输出一个 `dict`,其中包含 `visualization: list` 和 `predictions: list` 两个字段,分别对应可视化结果和预测结果。 |
| 84 | +``` |
| 85 | + |
| 86 | +## 推理接口开发规范 |
| 87 | + |
| 88 | +inferencer 执行推理时,通常会执行以下步骤: |
| 89 | + |
| 90 | +1. preprocess:输入数据预处理,包括数据读取、数据预处理、数据格式转换等 |
| 91 | +2. forward: 模型前向推理 |
| 92 | +3. visualize:预测结果可视化 |
| 93 | +4. postprocess:预测结果后处理,包括结果格式转换、导出预测结果等 |
| 94 | + |
| 95 | +为了优化 inferencer 的使用体验,我们不希望使用者在执行推理时,需要为每个过程都配置一遍参数。换句话说,我们希望使用者可以在不感知上述流程的情况下,简单为 `__call__` 接口配置参数,即可完成推理。 |
| 96 | + |
| 97 | +`__call__` 接口会按照顺序执行上述步骤,但是本身却不知道使用者传入的参数需要分发给哪个步骤,因此开发者在实现 `CustomInferencer` 时,需要定义 `preprocess_kwargs`,`forward_kwargs`,`visualize_kwargs`,`postprocess_kwargs` 4 个类属性,每个属性均为一个字符集合(`Set[str]`),用于指定 `__call__` 接口中的参数对应哪个步骤: |
| 98 | + |
| 99 | +```python |
| 100 | +class CustomInferencer(BaseInferencer): |
| 101 | + preprocess_kwargs = {'a'} |
| 102 | + forward_kwargs = {'b'} |
| 103 | + visualize_kwargs = {'c'} |
| 104 | + postprocess_kwargs = {'d'} |
| 105 | + |
| 106 | + def preprocess(self, inputs, batch_size=1, a=None): |
| 107 | + pass |
| 108 | + |
| 109 | + def forward(self, inputs, b=None): |
| 110 | + pass |
| 111 | + |
| 112 | + def visualize(self, inputs, preds, show, c=None): |
| 113 | + pass |
| 114 | + |
| 115 | + def postprocess(self, preds, visualization, return_datasample=False, d=None): |
| 116 | + pass |
| 117 | + |
| 118 | + def __call__( |
| 119 | + self, |
| 120 | + inputs, |
| 121 | + batch_size=1, |
| 122 | + show=True, |
| 123 | + return_datasample=False, |
| 124 | + a=None, |
| 125 | + b=None, |
| 126 | + c=None, |
| 127 | + d=None): |
| 128 | + return super().__call__( |
| 129 | + inputs, batch_size, show, return_datasample, a=a, b=b, c=c, d=d) |
| 130 | +``` |
| 131 | + |
| 132 | +上述代码中,`preprocess`,`forward`,`visualize`,`postprocess` 四个函数的 `a`,`b`,`c`,`d` 为用户可以传入的额外参数(`inputs`, `preds` 等参数在 `__call__` 的执行过程中会被自动填入),因此开发者需要在类属性 `preprocess_kwargs`,`forward_kwargs`,`visualize_kwargs`,`postprocess_kwargs` 中指定这些参数,这样 `__call__` 阶段用户传入的参数就可以被正确分发给对应的步骤。分发过程由 `BaseInferencer.__call__` 函数实现,开发者无需关心。 |
| 133 | + |
| 134 | +此外,我们需要将 `CustomInferencer` 注册到自定义注册器或者 MMEngine 的注册器中 |
| 135 | + |
| 136 | +```python |
| 137 | +from mmseg.registry import INFERENCERS |
| 138 | +# 也可以注册到 MMEngine 的注册中 |
| 139 | +# from mmengine.registry import INFERENCERS |
| 140 | + |
| 141 | +@INFERENCERS.register_module() |
| 142 | +class CustomInferencer(BaseInferencer): |
| 143 | + ... |
| 144 | +``` |
| 145 | + |
| 146 | +```{note} |
| 147 | +OpenMMLab 系列算法仓库必须将 Inferencer 注册到下游仓库的注册器,而不能注册到 MMEngine 的根注册器(避免重名)。 |
| 148 | +``` |
| 149 | + |
| 150 | +**核心接口说明**: |
| 151 | + |
| 152 | +### `__init__()` |
| 153 | + |
| 154 | +`BaseInferencer.__init__` 已经实现了[使用样例](#构建推理器)中构建推理器的逻辑,因此通常情况下不需要重写 `__init__` 函数。如果想实现自定义的加载配置文件、权重初始化、pipeline 初始化等逻辑,也可以重写 `__init__` 方法。 |
| 155 | + |
| 156 | +### `_init_pipeline()` |
| 157 | + |
| 158 | +```{note} |
| 159 | +抽象方法,子类必须实现 |
| 160 | +``` |
| 161 | + |
| 162 | +初始化并返回 inferencer 所需的 pipeline。pipeline 用于单张图片,类似于 OpenMMLab 系列算法库中定义的 `train_pipeline`,`test_pipeline`。使用者调用 `__call__` 接口传入的每个 `inputs`,都会经过 pipeline 处理,组成 batch data,然后传入 `forward` 方法。 |
| 163 | + |
| 164 | +### `_init_collate()` |
| 165 | + |
| 166 | +初始化并返回 inferencer 所需的 `collate_fn`,其值等价于训练过程中 Dataloader 的 `collate_fn`。`BaseInferencer` 默认会从 `test_dataloader` 的配置中获取 `collate_fn`,因此通常情况下不需要重写 `_init_collate` 函数。 |
| 167 | + |
| 168 | +### `_init_visualizer()` |
| 169 | + |
| 170 | +初始化并返回 inferencer 所需的 `visualizer`,其值等价于训练过程中 `visualizer`。`BaseInferencer` 默认会从 `visualizer` 的配置中获取 `visualizer`,因此通常情况下不需要重写 `_init_visualizer` 函数。 |
| 171 | + |
| 172 | +### `preprocess()` |
| 173 | + |
| 174 | +入参: |
| 175 | + |
| 176 | +- inputs:输入数据,由 `__call__` 传入,通常为图片路径或者图片数据组成的列表 |
| 177 | +- batch_size:batch 大小,由使用者在调用 `__call__` 时传入 |
| 178 | +- 其他参数:由用户传入,且在 `preprocess_kwargs` 中指定 |
| 179 | + |
| 180 | +返回值: |
| 181 | + |
| 182 | +- 生成器,每次迭代返回一个 batch 的数据。 |
| 183 | + |
| 184 | +`preprocess` 默认是一个生成器函数,将 `pipeline` 及 `collate_fn` 应用于输入数据,生成器迭代返回的是组完 batch,预处理后的结果。通常情况下子类无需重写。 |
| 185 | + |
| 186 | +### `forward()` |
| 187 | + |
| 188 | +入参: |
| 189 | + |
| 190 | +- inputs:输入数据,由 `preprocess` 处理后的 batch data |
| 191 | +- 其他参数:由用户传入,且在 `forward_kwargs` 中指定 |
| 192 | + |
| 193 | +返回值: |
| 194 | + |
| 195 | +- 预测结果,默认类型为 `List[BaseDataElement]` |
| 196 | + |
| 197 | +调用 `model.test_step` 执行前向推理,并返回推理结果。通常情况下子类无需重写。 |
| 198 | + |
| 199 | +### `visualize()` |
| 200 | + |
| 201 | +```{note} |
| 202 | +抽象方法,子类必须实现 |
| 203 | +``` |
| 204 | + |
| 205 | +入参: |
| 206 | + |
| 207 | +- inputs:输入数据,未经过预处理的原始数据。 |
| 208 | +- preds:模型的预测结果 |
| 209 | +- show:是否可视化 |
| 210 | +- 其他参数:由用户传入,且在 `visualize_kwargs` 中指定 |
| 211 | + |
| 212 | +返回值: |
| 213 | + |
| 214 | +- 可视化结果,类型通常为 `List[np.ndarray]`,以目标检测任务为例,列表中的每个元素应该是画完检测框后的图像,直接使用 `cv2.imshow` 就能可视化检测结果。不同任务的可视化流程有所不同,`visualize` 应该返回该领域内,适用于常见可视化流程的结果。 |
| 215 | + |
| 216 | +### `postprocess()` |
| 217 | + |
| 218 | +```{note} |
| 219 | +抽象方法,子类必须实现 |
| 220 | +``` |
| 221 | + |
| 222 | +入参: |
| 223 | + |
| 224 | +- preds:模型预测结果,类型为 `list`,列表中的每个元素表示一个数据的预测结果。OpenMMLab 系列算法库中,预测结果中每个元素的类型均为 `BaseDataElement` |
| 225 | +- visualization:可视化结果 |
| 226 | +- return_datasample:是否维持 datasample 返回。`False` 时转换成 `dict` 返回 |
| 227 | +- 其他参数:由用户传入,且在 `postprocess_kwargs` 中指定 |
| 228 | + |
| 229 | +返回值: |
| 230 | + |
| 231 | +- 可视化结果和预测结果,类型为一个字典。OpenMMLab 系列算法库要求返回的字典包含 `predictions` 和 `visualization` 两个 key。 |
| 232 | + |
| 233 | +### `__call__()` |
| 234 | + |
| 235 | +入参: |
| 236 | + |
| 237 | +- inputs:输入数据,通常为图片路径、或者图片数据组成的列表。`inputs` 中的每个元素也可以是其他类型的数据,只需要保证数据能够被 [\_init_pipeline](#initpipeline) 返回的 `pipeline` 处理即可。当 `inputs` 只含一个推理数据时,它可以不是一个 `list`,`__call__` 会在内部将 `inputs` 包装成列表,以便于后续处理 |
| 238 | +- return_datasample:是否将 datasample 转换成 `dict` 返回 |
| 239 | +- batch_size:推理的 batch size,会被进一步传给 `preprocess` 函数 |
| 240 | +- 其他参数:分发给 `preprocess`、`forward`、`visualize`、`postprocess` 函数的额外参数 |
| 241 | + |
| 242 | +返回值: |
| 243 | + |
| 244 | +- `postprocess` 返回的可视化结果和预测结果,类型为一个字典。OpenMMLab 系列算法库要求返回的字典包含 `predictions` 和 `visualization` 两个 key |
0 commit comments