
Commit c4b0e30

feat: add DeepLabV3, DeepLabV3+ example (#735)

1 parent c76842f, commit c4b0e30

15 files changed: +1937 -0 lines changed

examples/seg/deeplabv3/README.md

+175
# DeepLabV3, DeepLabV3+ Based on MindCV Backbones

> DeepLabV3: [Rethinking Atrous Convolution for Semantic Image Segmentation](https://arxiv.org/abs/1706.05587)
>
> DeepLabV3+: [Encoder-Decoder with Atrous Separable Convolution for Semantic Image Segmentation](https://arxiv.org/abs/1802.02611)

## Introduction

**DeepLabV3** is a semantic segmentation architecture that improves on its predecessor. DeepLabV3 makes two main contributions: 1) it designs modules that employ atrous convolution in cascade or in parallel, adopting multiple atrous rates to capture multi-scale context and to handle the problem of segmenting objects at multiple scales; 2) it augments the Atrous Spatial Pyramid Pooling (ASPP) module with image-level features that encode global context and further boost performance. The improved ASPP applies global average pooling to the last feature map of the model, feeds the resulting image-level features to a 1 × 1 convolution with 256 filters (and batch normalization), and then bilinearly upsamples the features to the desired spatial dimension. The DenseCRF post-processing from DeepLabV2 is deprecated.

<p align="center">
  <img src="https://github.com/mindspore-lab/mindcv/assets/33061146/db2076ed-bccd-455f-badb-e03deb131dc5" width=700/>
</p>
<p align="center">
  <em>Figure 1. Architecture of DeepLabV3 with output_stride=16 [<a href="#references">1</a>]</em>
</p>
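
The image-level branch of the improved ASPP described above is small enough to sketch directly. The snippet below is a minimal PyTorch-style illustration only (the example in this repository is implemented in MindSpore, so the class and parameter names here are hypothetical): global average pooling, a 1 × 1 convolution with 256 filters plus batch normalization, then bilinear upsampling back to the feature-map size.

```python
import torch.nn as nn
import torch.nn.functional as F

class ASPPImagePooling(nn.Module):
    """Image-level ASPP branch: global average pooling -> 1x1 conv (256 filters)
    + BN + ReLU -> bilinear upsampling to the input spatial size."""

    def __init__(self, in_channels, out_channels=256):
        super().__init__()
        self.pool = nn.AdaptiveAvgPool2d(1)  # global average pooling
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU()

    def forward(self, x):
        size = x.shape[-2:]
        y = self.relu(self.bn(self.conv(self.pool(x))))
        # upsample the 1x1 image-level features back to the feature-map size
        return F.interpolate(y, size=size, mode="bilinear", align_corners=False)
```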

**DeepLabV3+** extends DeepLabV3 by adding a simple yet effective decoder module to refine the segmentation results, especially along object boundaries. It combines the advantages of the spatial pyramid pooling module and the encoder-decoder structure. The last feature map before the logits in the original DeepLabV3 becomes the encoder output. The encoder features are first bilinearly upsampled by a factor of 4 and then concatenated with the corresponding low-level features from the network backbone that have the same spatial resolution. A 1 × 1 convolution is applied to the low-level features beforehand to reduce their number of channels. After the concatenation, a few 3 × 3 convolutions are applied to refine the features, followed by another simple bilinear upsampling by a factor of 4.

<p align="center">
  <img src="https://github.com/mindspore-lab/mindcv/assets/33061146/e1a17518-b19a-46f1-b28a-ec67cafa81be" width=700/>
</p>
<p align="center">
  <em>Figure 2. DeepLabV3+ extends DeepLabV3 by employing an encoder-decoder structure [<a href="#references">2</a>]</em>
</p>
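
The decoder data flow just described can likewise be sketched compactly. Again, this is a hedged PyTorch-style illustration rather than the repository's MindSpore implementation; the 48-channel reduction and the exact number of 3 × 3 convolutions follow the paper, and all names are hypothetical.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class DeepLabV3PlusDecoder(nn.Module):
    """DeepLabV3+ decoder: reduce low-level channels with a 1x1 conv, upsample
    the encoder output x4, concatenate, refine with 3x3 convs, upsample x4."""

    def __init__(self, low_level_channels, encoder_channels=256, num_classes=21):
        super().__init__()
        self.reduce = nn.Sequential(  # 1x1 conv on the low-level features
            nn.Conv2d(low_level_channels, 48, kernel_size=1, bias=False),
            nn.BatchNorm2d(48), nn.ReLU())
        self.refine = nn.Sequential(  # a few 3x3 convs after concatenation
            nn.Conv2d(encoder_channels + 48, 256, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(256), nn.ReLU(),
            nn.Conv2d(256, 256, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(256), nn.ReLU(),
            nn.Conv2d(256, num_classes, kernel_size=1))

    def forward(self, encoder_out, low_level):
        x = F.interpolate(encoder_out, size=low_level.shape[-2:],
                          mode="bilinear", align_corners=False)  # upsample x4
        x = torch.cat([x, self.reduce(low_level)], dim=1)
        return F.interpolate(self.refine(x), scale_factor=4.0,
                             mode="bilinear", align_corners=False)  # final x4
```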

This example provides implementations of DeepLabV3 and DeepLabV3+ using backbones from MindCV. More details about feature extraction with MindCV are given in [this tutorial](https://github.com/mindspore-lab/mindcv/blob/main/docs/en/how_to_guides/feature_extraction.md). Note that the ResNet used in DeepLab contains atrous convolutions with different rates, so `dilated_resnet.py` is provided as a modification of the MindCV ResNet, with atrous convolutions in blocks 3-4 (see the sketch below).
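How a dilated ResNet reaches a given output stride can be made concrete in a few lines. The sketch below (plain Python with a hypothetical function name, not the `dilated_resnet.py` code itself) shows the usual DeepLab rule: once the running stride of the backbone reaches the target output stride, later stages stop downsampling and multiply their atrous rate instead.

```python
def stage_plan(output_stride):
    """Return (stride, dilation) for ResNet stages 2-4. The stem and stage 1
    already downsample by a factor of 4; once the running stride reaches
    `output_stride`, stride-2 downsampling is replaced by atrous convolution."""
    current_stride, rate, plan = 4, 1, []
    for _ in range(3):  # stages 2, 3, 4
        if current_stride < output_stride:
            plan.append((2, rate))  # keep downsampling with stride 2
            current_stride *= 2
        else:
            rate *= 2
            plan.append((1, rate))  # stride 1, dilate instead
    return plan

print(stage_plan(16))  # [(2, 1), (2, 1), (1, 2)] -> atrous convs in block 4
print(stage_plan(8))   # [(2, 1), (1, 2), (1, 4)] -> atrous convs in blocks 3-4
```
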
## Quick Start

### Preparation

1. Clone the MindCV repository, enter `mindcv`, and assume we are always in this project root.

```shell
git clone https://github.com/mindspore-lab/mindcv.git
cd mindcv
```

2. Install dependencies as shown [here](https://mindspore-lab.github.io/mindcv/installation/), and also install `cv2` and `addict`.

```shell
pip install opencv-python
pip install addict
```

3. Prepare the dataset

* Download the Pascal VOC 2012 dataset ([VOC2012](http://host.robots.ox.ac.uk/pascal/VOC/voc2012/)) and the Semantic Boundaries Dataset ([SBD](https://www2.eecs.berkeley.edu/Research/Projects/CS/vision/grouping/semantic_contours/benchmark.tgz)).

* Prepare training and test data list files with the paths of image and annotation pairs. You can simply run `python examples/seg/deeplabv3/preprocess/get_dataset_list.py --data_root=/path/to/data` to generate the list files. This command produces 5 data list files. The lines in a list file should look as follows:

```
/path/to/data/JPEGImages/2007_000032.jpg /path/to/data/SegmentationClassGray/2007_000032.png
/path/to/data/JPEGImages/2007_000039.jpg /path/to/data/SegmentationClassGray/2007_000039.png
/path/to/data/JPEGImages/2007_000063.jpg /path/to/data/SegmentationClassGray/2007_000063.png
......
```

* Convert the training dataset to mindrecord format by running the `build_seg_data.py` script. In accordance with the paper, we train on the *trainaug* dataset (*voc train* + *SBD*). You can train on another dataset by changing the data list path at the keyword `data_list` to the path of your target training set.

```shell
python examples/seg/deeplabv3/preprocess/build_seg_data.py \
    --data_root=[root path of training data] \
    --data_list=[path of data list file prepared above] \
    --dst_path=[path to save mindrecords] \
    --num_shards=8
```

* Note: the training steps use datasets in mindrecord format, while the evaluation steps directly use the data list files.

4. Backbone: download the pre-trained backbone from MindCV; here we use [ResNet101](https://download.mindspore.cn/toolkits/mindcv/resnet/resnet101-689c5e77.ckpt).

### Train

Specify `deeplabv3` or `deeplabv3plus` at the keyword `model` in the config file.

It is highly recommended to use **distributed training** for this DeepLabV3 and DeepLabV3+ implementation.

For distributed training using **OpenMPI's `mpirun`**, simply run
```shell
mpirun -n [# of devices] python examples/seg/deeplabv3/train.py --config [the path to the config file]
```

For distributed training with [Ascend rank table](https://github.com/mindspore-lab/mindocr/blob/main/docs/en/tutorials/distribute_train.md#12-configure-rank_table_file-for-training), configure `ascend8p.sh` as follows

```shell
#!/bin/bash
export DEVICE_NUM=8
export RANK_SIZE=8
export RANK_TABLE_FILE="./hccl_8p_01234567_127.0.0.1.json"

for ((i = 0; i < ${DEVICE_NUM}; i++)); do
    export DEVICE_ID=$i
    export RANK_ID=$i
    python -u examples/seg/deeplabv3/train.py --config [the path to the config file] &> ./train_$i.log &
done
```

and start training by running:

```shell
bash ascend8p.sh
```

For single-device training, simply set the keyword `distribute` to `False` in the config file and run:

```shell
python examples/seg/deeplabv3/train.py --config [the path to the config file]
```

**Taking the `mpirun` command as an example, the training steps are as follows**:

- Step 1: Employ output_stride=16 and fine-tune the pretrained resnet101 on the *trainaug* dataset. In the config file, specify the path of the pretrained backbone checkpoint at the keyword `backbone_ckpt_path` and set `output_stride` to `16`.

```shell
# for deeplabv3
mpirun -n 8 python examples/seg/deeplabv3/train.py --config examples/seg/deeplabv3/deeplabv3_s16_dilated_resnet101.yaml

# for deeplabv3+
mpirun -n 8 python examples/seg/deeplabv3/train.py --config examples/seg/deeplabv3/deeplabv3plus_s16_dilated_resnet101.yaml
```

- Step 2: Employ output_stride=8 and fine-tune the model from step 1 on the *trainaug* dataset with a smaller base learning rate. In the config file, specify the path of the checkpoint from the previous step at `ckpt_path`, set `ckpt_pre_trained` to `True`, and set `output_stride` to `8`.

```shell
# for deeplabv3
mpirun -n 8 python examples/seg/deeplabv3/train.py --config examples/seg/deeplabv3/deeplabv3_s8_dilated_resnet101.yaml

# for deeplabv3+
mpirun -n 8 python examples/seg/deeplabv3/train.py --config examples/seg/deeplabv3/deeplabv3plus_s8_dilated_resnet101.yaml
```

### Test

To test the trained model, first specify the path to the model checkpoint at the keyword `ckpt_path` in the config file. You can also modify `output_stride`, `flip`, and `scales` in the config file for inference.

For example, after replacing `ckpt_path` in the config file with this [checkpoint](https://download.mindspore.cn/toolkits/mindcv/deeplabv3/deeplabv3_s8_resnet101-a297e7af.ckpt) from the 2-step training of deeplabv3, the command below runs evaluation with output_stride=8 and without left-right flipped or multiscale inputs.

```shell
python examples/seg/deeplabv3/eval.py --config examples/seg/deeplabv3/deeplabv3_s8_dilated_resnet101.yaml
```
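
What `flip` and `scales` control during evaluation is the standard multi-scale, flipped-input test protocol. The snippet below is a hedged PyTorch-style sketch of that protocol for illustration only; the repository's `eval.py` is implemented in MindSpore, and `model` and `image` here are assumed placeholders.

```python
import torch
import torch.nn.functional as F

def ms_flip_inference(model, image, scales=(0.5, 0.75, 1.0, 1.25, 1.75), flip=True):
    """Average class probabilities over rescaled (and optionally left-right
    flipped) copies of the input, then take the per-pixel argmax."""
    h, w = image.shape[-2:]
    probs = 0.0
    for s in scales:
        x = F.interpolate(image, scale_factor=s, mode="bilinear", align_corners=False)
        logits = F.interpolate(model(x), size=(h, w), mode="bilinear", align_corners=False)
        probs = probs + logits.softmax(dim=1)
        if flip:  # run the left-right flipped input and flip the logits back
            flipped = model(torch.flip(x, dims=[-1]))
            flipped = F.interpolate(torch.flip(flipped, dims=[-1]), size=(h, w),
                                    mode="bilinear", align_corners=False)
            probs = probs + flipped.softmax(dim=1)
    return probs.argmax(dim=1)  # per-pixel class prediction
```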

## Results

### Config

| Model | OS=16 config | OS=8 config | Download |
| :--------: | :----------------------------------------------------------: | :----------------------------------------------------------: | :----------------------------------------------------------: |
| DeepLabV3 | [yaml](https://github.com/mindspore-lab/mindcv/blob/main/examples/seg/deeplabv3/config/deeplabv3_s16_dilated_resnet101.yaml) | [yaml](https://github.com/mindspore-lab/mindcv/blob/main/examples/seg/deeplabv3/config/deeplabv3_s8_dilated_resnet101.yaml) | [weights](https://download.mindspore.cn/toolkits/mindcv/deeplabv3/deeplabv3_dilated_resnet101-8614f6af.ckpt) |
| DeepLabV3+ | [yaml](https://github.com/mindspore-lab/mindcv/blob/main/examples/seg/deeplabv3/config/deeplabv3plus_s16_dilated_resnet101.yaml) | [yaml](https://github.com/mindspore-lab/mindcv/blob/main/examples/seg/deeplabv3/config/deeplabv3plus_s8_dilated_resnet101.yaml) | [weights](https://download.mindspore.cn/toolkits/mindcv/deeplabv3/deeplabv3plus_dilated_resnet101-59ea7d95.ckpt) |

### Model results

| Model | Infer OS | MS | FLIP | mIoU |
| :--------: | :------: | :--: | :--: | :---: |
| DeepLabV3 | 16 | | | 77.33 |
| DeepLabV3 | 8 | | | 79.16 |
| DeepLabV3 | 8 | ✓ | | 79.93 |
| DeepLabV3 | 8 | ✓ | ✓ | 80.14 |
| DeepLabV3+ | 16 | | | 78.99 |
| DeepLabV3+ | 8 | | | 80.31 |
| DeepLabV3+ | 8 | ✓ | | 80.99 |
| DeepLabV3+ | 8 | ✓ | ✓ | 81.10 |

**Note**: **OS**: output stride. **MS**: multiscale inputs during test. **FLIP**: adding left-right flipped inputs during test. **Weights** are the checkpoint files saved after two-step training.

As illustrated in [<a href="#references">1</a>], adding left-right flipped inputs or multi-scale inputs during test can improve the performance. Also, once the model is trained, employing output_stride=8 during inference brings improvement over output_stride=16.
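
Since mIoU is the metric reported above (and the one tracked by the evaluation callback in `callbacks.py` below), here is a minimal sketch of how it is conventionally computed from a class confusion matrix. This is a generic definition for illustration, not the repository's evaluation code.

```python
import numpy as np

def mean_iou(conf_matrix):
    """Mean intersection-over-union from a class confusion matrix
    (rows: ground truth, columns: prediction)."""
    inter = np.diag(conf_matrix)
    union = conf_matrix.sum(axis=0) + conf_matrix.sum(axis=1) - inter
    with np.errstate(divide="ignore", invalid="ignore"):
        iou = inter / union  # classes absent from both GT and prediction give NaN
    return float(np.nanmean(iou))  # NaN classes are ignored in the mean
```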

## References

[1] Chen L C, Papandreou G, Schroff F, et al. Rethinking atrous convolution for semantic image segmentation[J]. arXiv preprint arXiv:1706.05587, 2017.

[2] Chen L C, Zhu Y, Papandreou G, et al. Encoder-decoder with atrous separable convolution for semantic image segmentation[C]. Proceedings of the European Conference on Computer Vision (ECCV), 2018.

examples/seg/deeplabv3/callbacks.py

+128
import logging
import os
import stat

from postprocess import apply_eval

from mindspore import save_checkpoint
from mindspore.train.callback import Callback, CheckpointConfig, LossMonitor, ModelCheckpoint, TimeMonitor

_logger = logging.getLogger(__name__)


class EvalCallBack(Callback):
    """
    Evaluation callback when training.

    Args:
        eval_function (function): evaluation function.
        eval_param_dict (dict): evaluation parameters' configure dict.
        interval (int): run evaluation interval, default is 1.
        eval_start_epoch (int): evaluation start epoch, default is 1.
        save_best_ckpt (bool): whether to save the best checkpoint, default is True.
        best_ckpt_name (str): best checkpoint name, default is `best.ckpt`.
        metrics_name (str): evaluation metrics name, default is `mIoU`.

    Returns:
        None

    Examples:
        >>> EvalCallBack(eval_function, eval_param_dict)
    """

    def __init__(
        self,
        eval_function,
        eval_param_dict,
        interval=1,
        eval_start_epoch=1,
        save_best_ckpt=True,
        ckpt_directory="./",
        best_ckpt_name="best.ckpt",
        metrics_name="mIoU",
    ) -> None:
        super(EvalCallBack, self).__init__()
        self.eval_function = eval_function
        self.eval_param_dict = eval_param_dict
        self.eval_start_epoch = eval_start_epoch

        if interval < 1:
            raise ValueError("interval should be >= 1.")

        self.interval = interval
        self.save_best_ckpt = save_best_ckpt
        self.best_res = 0
        self.best_epoch = 0

        if not os.path.isdir(ckpt_directory):
            os.makedirs(ckpt_directory)

        self.best_ckpt_path = os.path.join(ckpt_directory, best_ckpt_name)
        self.metrics_name = metrics_name

    def remove_ckpoint_file(self, file_name):
        """Remove the specified checkpoint file from this checkpoint manager and also from the directory."""
        try:
            os.chmod(file_name, stat.S_IWRITE)
            os.remove(file_name)
        except OSError:
            _logger.warning("OSError, failed to remove the older ckpt file %s.", file_name)
        except ValueError:
            _logger.warning("ValueError, failed to remove the older ckpt file %s.", file_name)

    def on_train_epoch_end(self, run_context):
        """Callback when an epoch ends: run evaluation and keep the best checkpoint."""
        cb_params = run_context.original_args()
        cur_epoch = cb_params.cur_epoch_num

        if cur_epoch >= self.eval_start_epoch and (cur_epoch - self.eval_start_epoch) % self.interval == 0:
            res = self.eval_function(self.eval_param_dict)
            # note: logging.Logger.info accepts no `flush` keyword, so messages are
            # passed lazily in %-style instead
            _logger.info("epoch: %s, %s: %s", cur_epoch, self.metrics_name, res)

            if res >= self.best_res:
                self.best_res = res
                self.best_epoch = cur_epoch
                _logger.info("update best result: %s", res)

                if self.save_best_ckpt:
                    if os.path.exists(self.best_ckpt_path):
                        self.remove_ckpoint_file(self.best_ckpt_path)

                    save_checkpoint(cb_params.train_network, self.best_ckpt_path)
                    _logger.info("update best checkpoint at: %s", self.best_ckpt_path)

    def on_train_end(self, run_context):
        _logger.info(
            "End training, the best %s is: %s, the best %s epoch is %s",
            self.metrics_name, self.best_res, self.metrics_name, self.best_epoch,
        )


def get_segment_train_callback(args, steps_per_epoch, rank_id):
    """Build training callbacks: time/loss monitors, plus checkpointing on rank 0."""
    callbacks = [TimeMonitor(data_size=steps_per_epoch), LossMonitor()]
    if rank_id == 0:
        ckpt_config = CheckpointConfig(
            save_checkpoint_steps=args.save_steps,
            keep_checkpoint_max=args.keep_checkpoint_max,
        )
        prefix_name = str(args.model) + "_s" + str(args.output_stride) + "_" + args.backbone
        ckpt_cb = ModelCheckpoint(prefix=prefix_name, directory=args.ckpt_save_dir, config=ckpt_config)
        callbacks.append(ckpt_cb)
    return callbacks


def get_segment_eval_callback(eval_model, eval_dataset, args):
    """Build the evaluation callback that tracks mIoU and keeps the best checkpoint."""
    eval_param_dict = {"net": eval_model, "dataset": eval_dataset, "args": args}

    eval_cb = EvalCallBack(
        eval_function=apply_eval,
        eval_param_dict=eval_param_dict,
        interval=args.eval_interval,
        eval_start_epoch=args.eval_start_epoch,
        save_best_ckpt=True,
        ckpt_directory=args.ckpt_save_dir,
        best_ckpt_name="best.ckpt",
        metrics_name="mIoU",
    )

    return eval_cb
examples/seg/deeplabv3/config/deeplabv3_s16_dilated_resnet101.yaml

+72
# finetune on os=16 with pretrained backbone: resnet101 from mindcv

# system
seed: 1
mode: 1
distribute: True
num_parallel_workers: 8
device_target: "Ascend"
all_reduce_fusion_config: [90, 183, 279]

# dataset
dataset: "vocaug"
data_dir: "/path/to/vocaug0"
batch_size: 32
crop_size: 513
image_mean: [103.53, 116.28, 123.675]
image_std: [57.375, 57.120, 58.395]
max_scale: 2.0
min_scale: 0.5
ignore_label: 255
num_classes: 21
shuffle: True

# backbone
backbone: "dilated_resnet101"
backbone_ckpt_path: "/path/to/resnet101-689c5e77.ckpt"
backbone_ckpt_auto_mapping: False
backbone_features_only: True
backbone_out_indices: [4]
output_stride: 16

# model
model: "deeplabv3"
ckpt_pre_trained: False
ckpt_path: ""
amp_level: "O3"
amp_cast_list: None

# scheduler
scheduler: "cosine_decay"
lr: 0.08
min_lr: 0.0000001
decay_epochs: 40000
decay_rate: 0.1
epoch_size: 300
lr_epoch_stair: False

# optimizer
loss_scale_type: "fixed"
drop_overflow_update: False
loss_scale: 3072.0
momentum: 0.9
weight_decay: 0.0001
filter_bias_and_bn: False
gradient_accumulation_steps: 1

# callbacks
save_steps: 410
keep_checkpoint_max: 2
ckpt_save_dir: "./ckpt"

# eval
eval_while_train: True
eval_data_lst: "/path/to/voc_val_lst.txt"
data_root: ""
eval_processing_log: False
eval_interval: 2
eval_start_epoch: 50
input_format: "NCHW"
flip: False
scales: [1.0,]
# scales: [0.5, 0.75, 1.0, 1.25, 1.75]
