Skip to content

Commit

Permalink
[Doc] Refine runner doc. (open-mmlab#178)
Browse files Browse the repository at this point in the history
* [Doc] Refine runner doc.

* resolve comments
  • Loading branch information
RangiLyu authored Apr 20, 2022
1 parent 5571320 commit ecf816e
Showing 1 changed file with 72 additions and 30 deletions.
102 changes: 72 additions & 30 deletions docs/zh_cn/tutorials/runner.md
Original file line number Diff line number Diff line change
Expand Up @@ -91,17 +91,18 @@ runner.train()
model = FasterRCNN()
test_dataset = CocoDataset()
test_dataloader = Dataloader(dataset=test_dataset, batch_size=2, num_workers=2)
evaluator = CocoEvaluator(metric='bbox')
metric = CocoMetric()
test_evaluator = Evaluator(metric)

# 初始化执行器
runner = Runner(model=model, test_dataloader=test_dataloader, evaluator=evaluator,
load_checkpoint='./faster_rcnn.pth')
runner = Runner(model=model, test_dataloader=test_dataloader, test_evaluator=test_evaluator,
load_from='./faster_rcnn.pth')

# 执行测试
runner.test()
```

这个例子中我们手动构建了一个 Faster R-CNN 检测模型,以及测试用的 COCO 数据集和对应的 COCO 评测器,并使用这些模块初始化执行器,最后通过调用执行器的 `test` 函数进行模型测试。
这个例子中我们手动构建了一个 Faster R-CNN 检测模型,以及测试用的 COCO 数据集和使用 COCO 指标的评测器,并使用这些模块初始化执行器,最后通过调用执行器的 `test` 函数进行模型测试。

### 通过配置文件使用执行器

Expand Down Expand Up @@ -146,12 +147,13 @@ test_dataloader = ...
optimizer = dict(type='SGD', lr=0.01)
# 参数调度器配置
param_scheduler = dict(type='MultiStepLR', milestones=[80, 90])
#评测器配置
evaluator = dict(type='Accuracy')
#验证和测试的评测器配置
val_evaluator = dict(type='Accuracy')
test_evaluator = dict(type='Accuracy')

# 训练、验证、测试流程配置
train_cfg = dict(by_epoch=True, max_epochs=100)
validation_cfg = dict(interval=1) # 每隔一个 epoch 进行一次验证
val_cfg = dict(interval=1) # 每隔一个 epoch 进行一次验证
test_cfg = dict()

# 自定义钩子
Expand All @@ -163,20 +165,40 @@ default_hooks = dict(
checkpoint=dict(type='CheckpointHook', interval=1), # 模型保存钩子
logger=dict(type='TextLoggerHook'), # 训练日志钩子
optimizer=dict(type='OptimzierHook', grad_clip=False), # 优化器钩子
param_scheduler=dict(type='ParamSchedulerHook')) # 参数调度器执行钩子
param_scheduler=dict(type='ParamSchedulerHook'), # 参数调度器执行钩子
sampler_seed=dict(type='DistSamplerSeedHook')) # 为每轮次的数据采样设置随机种子的钩子

# 环境配置
env_cfg = dict(
dist_params=dict(backend='nccl'),
cudnn_benchmark=False,
dist_cfg=dict(backend='nccl'),
mp_cfg=dict(mp_start_method='fork')
)
# 系统日志配置
log_cfg = dict(log_level='INFO')
# 日志等级配置
log_level = 'INFO'

# 加载权重
load_from = None
# 恢复训练
resume = False
```

一个完整的配置文件主要由模型、数据、优化器、参数调度器、评测器等模块的配置,训练、验证、测试等流程的配置,还有执行流程过程中的各种钩子模块的配置,以及环境和日志等其他配置的字段组成。
通过配置文件构建的执行器采用了懒初始化 (lazy initialization),只有当调用到训练或测试等执行函数时,才会根据配置文件去完整初始化所需要的模块。

## 加载权重或恢复训练

执行器可以通过 `load_from` 参数加载检查点(checkpoint)文件中的模型权重,只需要将 `load_from` 参数设置为检查点文件的路径即可。

```python
runner = Runner(model=model, test_dataloader=test_dataloader, test_evaluator=test_evaluator,
load_from='./faster_rcnn.pth')
```

如果是通过配置文件使用执行器,只需修改配置文件中的 `load_from` 字段即可。

用户也可通过设置 `resume=True` 来,加载检查点中的训练状态信息来恢复训练。当 `load_from``resume=True` 同时被设置时,执行器将加载 `load_from` 路径对应的检查点文件中的训练状态。如果仅设置 `resume=True`,执行器将会尝试从 `work_dir` 文件夹中寻找并读取最新的检查点文件。

## 进阶使用

MMEngine 中的默认执行器能够完成大部分的深度学习任务,但不可避免会存在无法满足的情况。有的用户希望能够对执行器进行更多自定义修改,因此,MMEngine 支持自定义模型的训练、验证以及测试的流程。
Expand All @@ -195,48 +217,68 @@ MMEngine 内提供了四种默认的循环:

用户可以通过继承循环基类来实现自己的训练流程。循环基类需要提供两个输入:`runner` 执行器的实例和 `loader` 循环所需要迭代的迭代器。
用户如果有自定义的需求,也可以增加更多的输入参数。MMEngine 中同样提供了 LOOPS 注册器对循环类进行管理,用户可以向注册器内注册自定义的循环模块,
然后在配置文件的 `train_cfg``validation_cfg``test_cfg` 中增加 `type` 字段来指定使用何种循环。
然后在配置文件的 `train_cfg``val_cfg``test_cfg` 中增加 `type` 字段来指定使用何种循环。
用户可以在自定义的循环中实现任意的执行逻辑,也可以增加或删减钩子(hook)点位,但需要注意的是一旦钩子点位被修改,默认的钩子函数可能不会被执行,导致一些训练过程中默认发生的行为发生变化。
因此,我们强烈建议用户按照本文档中定义的循环执行流程图以及[钩子规范](https://mmengine.readthedocs.io/zh_CN/latest/tutorials/hook.html) 去重载循环基类。

```python
from mmengine.registry import LOOPS
from mmengine.registry import LOOPS, HOOKS
from mmengine.runner.loop import BaseLoop
from mmengine.hooks import Hook


# 自定义验证循环
@LOOPS.register_module()
class CustomValLoop(BaseLoop):
def __init__(self, runner, loader, evaluator, loader2):
super().__init__(runner, loader, evaluator)
self.loader2 = runner.build_dataloader(loader2)
def __init__(self, runner, dataloader, evaluator, dataloader2):
super().__init__(runner, dataloader, evaluator)
self.dataloader2 = runner.build_dataloader(dataloader2)

def run(self):
self.runner.call_hooks('before_val_epoch')
for idx, databatch in enumerate(self.loader):
self.runner.call_hooks('before_val_iter',
args=dict(databatch=databatch))
outputs = self.run_iter(idx, databatch)
self.runner.call_hooks('after_val_iter',
args=dict(databatch=databatch, outputs=outputs))
for idx, data_batch in enumerate(self.dataloader):
self.runner.call_hooks(
'before_val_iter', batch_idx=idx, data_batch=data_batch)
outputs = self.run_iter(idx, data_batch)
self.runner.call_hooks(
'after_val_iter', batch_idx=idx, data_batch=data_batch, outputs=outputs)
metric = self.evaluator.evaluate()
for idx, databatch in enumerate(self.loader2):
self.runner.call_hooks('before_val_iter2',
args=dict(databatch=databatch))
self.run_iter(idx, databatch)
self.runner.call_hooks('after_val_iter2',
args=dict(databatch=databatch, outputs=outputs))

# 增加额外的验证循环
for idx, data_batch in enumerate(self.dataloader2):
# 增加额外的钩子点位
self.runner.call_hooks(
'before_valloader2_iter', batch_idx=idx, data_batch=data_batch)
self.run_iter(idx, data_batch)
# 增加额外的钩子点位
self.runner.call_hooks(
'after_valloader2_iter', batch_idx=idx, data_batch=data_batch, outputs=outputs)
metric2 = self.evaluator.evaluate()

...

self.runner.call_hooks('after_val_epoch')


# 定义额外点位的钩子类
@HOOKS.register_module()
class CustomValHook(Hook):
def before_valloader2_iter(self, batch_idx, data_batch):
...

def after_valloader2_iter(self, batch_idx, data_batch, outputs):
...

```

上面的例子中实现了一个与默认验证循环不一样的自定义验证循环,它在两个不同的验证集上进行验证,同时对第二次验证增加了额外的钩子点位,并在最后对两个验证结果进行进一步的处理。在实现了自定义的循环类之后,
只需要在配置文件的 `validation_cfg` 内设置 `type='CustomValLoop'`,并添加额外的配置即可。
只需要在配置文件的 `val_cfg` 内设置 `type='CustomValLoop'`,并添加额外的配置即可。

```python
validation_cfg = dict(type='CustomValLoop', loader2=dict(dataset=dict(type='ValDataset2'), ...))
# 自定义验证循环
val_cfg = dict(type='CustomValLoop', dataloader2=dict(dataset=dict(type='ValDataset2'), ...))
# 额外点位的钩子
custom_hooks = [dict(type='CustomValHook')]
```

### 自定义执行器
Expand Down

0 comments on commit ecf816e

Please sign in to comment.