Skip to content

Commit

Permalink
Merge 0f5b626 into 194c1c9
Browse files Browse the repository at this point in the history
  • Loading branch information
MengzhangLI authored Jun 30, 2021
2 parents 194c1c9 + 0f5b626 commit 99d0571
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 3 deletions.
4 changes: 2 additions & 2 deletions README_zh-CN.md
Original file line number Diff line number Diff line change
Expand Up @@ -90,11 +90,11 @@ MMSegmentation 是一个基于 PyTorch 的语义分割开源工具箱。它是 O

## 安装

请参考[快速入门文档](docs/get_started.md#installation)进行安装和数据集准备。
请参考[快速入门文档](docs_zh-CN/get_started.md#installation)进行安装和数据集准备。

## 快速入门

请参考[训练教程](docs/train.md)[测试教程](docs/inference.md)学习 MMSegmentation 的基本使用。
请参考[训练教程](docs_zh-CN/train.md)[测试教程](docs_zh-CN/inference.md)学习 MMSegmentation 的基本使用。
我们也提供了一些进阶教程,内容覆盖了[增加自定义数据集](docs/tutorials/customize_datasets.md)[设计新的数据预处理流程](docs/tutorials/data_pipeline.md)[增加自定义模型](docs/tutorials/customize_models.md)[增加自定义的运行时配置](docs/tutorials/customize_runtime.md)
除此之外,我们也提供了很多实用的[训练技巧说明](docs/tutorials/training_tricks.md)

Expand Down
52 changes: 51 additions & 1 deletion docs_zh-CN/tutorials/training_tricks.md
Original file line number Diff line number Diff line change
@@ -1 +1,51 @@
# 教程 5: 训练小技巧
# 教程 5: 训练技巧

MMSegmentation 支持如下训练技巧:

## 主干网络和头组件使用不同的学习率 (Learning Rate, LR)

在语义分割里,一些方法会让头组件的学习率大于主干网络的学习率,这样可以获得更好的表现或更快的收敛。

在 MMSegmentation 里面,您也可以在配置文件里添加如下行来让头组件的学习率是主干组件的10倍。

```python
optimizer=dict(
paramwise_cfg = dict(
custom_keys={
'head': dict(lr_mult=10.)}))
```

通过这种修改,任何被分组到 `'head'` 的参数的学习率都将乘以10。您也可以参照 [MMCV 文档](https://mmcv.readthedocs.io/en/latest/api.html#mmcv.runner.DefaultOptimizerConstructor) 获取更详细的信息。

## 在线难样本挖掘 (Online Hard Example Mining, OHEM)

对于训练时采样,我们在 [这里](https://github.com/open-mmlab/mmsegmentation/tree/master/mmseg/core/seg/sampler) 做了像素采样器。
如下例子是使用 PSPNet 训练并采用 OHEM 策略的配置:

```python
_base_ = './pspnet_r50-d8_512x1024_40k_cityscapes.py'
model=dict(
decode_head=dict(
sampler=dict(type='OHEMPixelSampler', thresh=0.7, min_kept=100000)) )
```

通过这种方式,只有置信分数在0.7以下的像素值点会被拿来训练。在训练时我们至少要保留100000个像素值点。如果 `thresh` 并未被指定,前 ``min_kept``
个损失的像素值点才会被选择。

## 类别平衡损失 (Class Balanced Loss)

对于没有平衡类别分布的数据集,您也许可以改变每个类别的损失权重。这里以 cityscapes 数据集为例:

```python
_base_ = './pspnet_r50-d8_512x1024_40k_cityscapes.py'
model=dict(
decode_head=dict(
loss_decode=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0,
# DeepLab 对 cityscapes 使用这种权重
class_weight=[0.8373, 0.9180, 0.8660, 1.0345, 1.0166, 0.9969, 0.9754,
1.0489, 0.8786, 1.0023, 0.9539, 0.9843, 1.1116, 0.9037,
1.0865, 1.0955, 1.0865, 1.1529, 1.0507])))
```

`class_weight` 将被作为 `weight` 参数,传递给 `CrossEntropyLoss`。详细信息请参照 [PyTorch 文档](https://pytorch.org/docs/stable/nn.html?highlight=crossentropy#torch.nn.CrossEntropyLoss)

0 comments on commit 99d0571

Please sign in to comment.