Commit bcafcdd

[Feature] Add segformer decode head and related train config (open-mmlab#599)
* [Feature] SegFormer re-implementation
* Using act_cfg and norm_cfg to control activation and normalization
* Split this PR into several little PRs
* Fix lint error
* Remove SegFormerHead
* [Feature] Add segformer decode head and related train config
* Add ADE20K trainval support for SegFormer: 1. Add related train and val configs; 2. Add AlignedResize;
* Set arg: find_unused_parameters = True
* Parameters init refactor
* 1. Refactor segformer backbone parameters init; 2. Remove redundant functions and unit tests;
* Remove redundant code
* Replace Linear layer with 1x1 Conv
* Use nn.ModuleList to refactor segformer head.
* Remove local to_xtuple
* 1. Remove redundant code; 2. Modify module name;
* Refactor the backbone of segformer using mmcv.cnn.bricks.transformer.py
* Fix some code logic bugs.
* Add mit_convert.py to match pretrain keys of segformer.
* Resolve some comments.
* 1. Add some asserts to ensure right params; 2. Support flexible peconv position;
* Add pe_index assert and fix unit test.
* 1. Add doc string for MixVisionTransformer; 2. Add some unit tests for MixVisionTransformer;
* Use hw_shape to pass shape of feature map.
* 1. Fix doc string of MixVisionTransformer; 2. Simplify MixFFN; 3. Modify H, W to hw_shape;
* Add more unit tests.
* Add doc string for shape conversion functions.
* Add some unit tests to improve code coverage.
* Fix SegFormer backbone pretrain weights match bug.
* Modify configs of segformer.
* Resolve the shape conversion functions doc string.
* Add pad_to_patch_size arg.
* Support progressive test with less memory cost.
* Modify default value of pad_to_patch_size arg.
* Temp code
* Use processor to refactor evaluation workflow.
* Refactor eval hook.
* Fix progress bar.
* Fix middle save argument.
* Modify some variable names of dataset evaluate API.
* Modify some variable names of eval hook.
* Fix some priority bugs of eval hook.
* Fix some bugs about model loading and eval hook.
* Add ADE20K 640x640 dataset.
* Fix related segformer configs.
* Deprecated efficient_test.
* Fix training progress blocked by eval hook.
* Deprecated old test API.
* Modify erroneous patch size.
* Fix pretrain of mit_b0
* Fix the test API error.
* Modify dataset base config.
* Fix test API error.
* Modify outer API.
* Build a sampler test API.
* TODO: Refactor format_results.
* Modify variable names.
* Fix num_classes bug.
* Fix sampler index bug.
* Fix grammar bug.
* Add part of benchmark results.
* Support batch sampler.
* More readable test API.
* Remove some command args and fix eval hook bug.
* Support format-only arg.
* Modify format_results of datasets.
* Modify tools which use test APIs.
* Update readme.
* Update readme of segformer.
* Update readme of segformer.
* Update segformer readme and fix segformer mit_b4.
* Update readme of segformer.
* Clean AlignedResize related config.
* Clean code from pr open-mmlab#709
* Clean code from pr open-mmlab#709
* Add 512x512 segformer_mit-b5.
* Fix lint.
* Fix some segformer head bugs.
* Add segformer unit tests.
* Replace AlignedResize with ResizeToMultiple.
* Modify readme of segformer.
* Fix bug of ResizeToMultiple.
* Add ResizeToMultiple unit tests.
* Resolve conflict.
* Simplify the implementation of ResizeToMultiple.
* Update test results.
* Fix multi-scale test error when resize_ratio=1.75 and input size=640x640.
* Update segformer results.
* Update SegFormer results.
* Fix some URL bugs and pipeline bugs.
* Move ckpt conversion to tools.
* Add segformer official pretrain weights usage.
* Clean redundant code.
* Remove redundant code.
* Unified format.
* Add description for segformer converter.
* Update workers.
1 parent f6dca38 commit bcafcdd

18 files changed: +494 −62 lines

configs/_base_/models/segformer_mit-b0.py

+34
@@ -0,0 +1,34 @@
# model settings
norm_cfg = dict(type='SyncBN', requires_grad=True)
model = dict(
    type='EncoderDecoder',
    pretrained=None,
    backbone=dict(
        type='MixVisionTransformer',
        in_channels=3,
        embed_dims=32,
        num_stages=4,
        num_layers=[2, 2, 2, 2],
        num_heads=[1, 2, 5, 8],
        patch_sizes=[7, 3, 3, 3],
        sr_ratios=[8, 4, 2, 1],
        out_indices=(0, 1, 2, 3),
        mlp_ratio=4,
        qkv_bias=True,
        drop_rate=0.0,
        attn_drop_rate=0.0,
        drop_path_rate=0.1),
    decode_head=dict(
        type='SegformerHead',
        in_channels=[32, 64, 160, 256],
        in_index=[0, 1, 2, 3],
        channels=256,
        dropout_ratio=0.1,
        num_classes=19,
        norm_cfg=norm_cfg,
        align_corners=False,
        loss_decode=dict(
            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
    # model training and testing settings
    train_cfg=dict(),
    test_cfg=dict(mode='whole'))
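A quick sanity check (an editor's illustration, not part of the commit) of how `decode_head.in_channels` relates to the backbone settings above, assuming each MixVisionTransformer stage is `embed_dims * num_heads[i]` channels wide, which is consistent with the mit-b1..b5 configs later in this commit:

```python
# Assumed relation: stage width = embed_dims * num_heads[i].
embed_dims = 32               # mit-b0; the mit-b1..b5 configs below use 64
num_heads = [1, 2, 5, 8]
stage_channels = [embed_dims * heads for heads in num_heads]
print(stage_channels)         # [32, 64, 160, 256] -> matches decode_head.in_channels
```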

configs/segformer/readme.md

+73
@@ -0,0 +1,73 @@
# SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers

## Introduction

<!-- [ALGORITHM] -->

```latex
@article{xie2021segformer,
  title={SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers},
  author={Xie, Enze and Wang, Wenhai and Yu, Zhiding and Anandkumar, Anima and Alvarez, Jose M and Luo, Ping},
  journal={arXiv preprint arXiv:2105.15203},
  year={2021}
}
```

## Results and models

### ADE20k

| Method | Backbone | Crop Size | Lr schd | Mem (GB) | Inf time (fps) | mIoU | mIoU(ms+flip) | config | download |
| ------ | -------- | --------- | ------: | -------: | -------------- | ---: | ------------- | ------ | -------- |
|Segformer | MIT-B0 | 512x512 | 160000 | 2.1 | 51.32 | 37.41 | 38.34 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/segformer/segformer_mit-b0_512x512_160k_ade20k.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/segformer/segformer_mit-b0_512x512_160k_ade20k/segformer_mit-b0_512x512_160k_ade20k_20210726_101530-8ffa8fda.pth) &#124; [log](https://download.openmmlab.com/mmsegmentation/v0.5/segformer/segformer_mit-b0_512x512_160k_ade20k/segformer_mit-b0_512x512_160k_ade20k_20210726_101530.log.json) |
|Segformer | MIT-B1 | 512x512 | 160000 | 2.6 | 47.66 | 40.97 | 42.54 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/segformer/segformer_mit-b1_512x512_160k_ade20k.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/segformer/segformer_mit-b1_512x512_160k_ade20k/segformer_mit-b1_512x512_160k_ade20k_20210726_112106-d70e859d.pth) &#124; [log](https://download.openmmlab.com/mmsegmentation/v0.5/segformer/segformer_mit-b1_512x512_160k_ade20k/segformer_mit-b1_512x512_160k_ade20k_20210726_112106.log.json) |
|Segformer | MIT-B2 | 512x512 | 160000 | 3.6 | 30.88 | 45.58 | 47.03 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/segformer/segformer_mit-b2_512x512_160k_ade20k.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/segformer/segformer_mit-b2_512x512_160k_ade20k/segformer_mit-b2_512x512_160k_ade20k_20210726_112103-cbd414ac.pth) &#124; [log](https://download.openmmlab.com/mmsegmentation/v0.5/segformer/segformer_mit-b2_512x512_160k_ade20k/segformer_mit-b2_512x512_160k_ade20k_20210726_112103.log.json) |
|Segformer | MIT-B3 | 512x512 | 160000 | 4.8 | 22.11 | 47.82 | 48.81 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/segformer/segformer_mit-b3_512x512_160k_ade20k.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/segformer/segformer_mit-b3_512x512_160k_ade20k/segformer_mit-b3_512x512_160k_ade20k_20210726_081410-962b98d2.pth) &#124; [log](https://download.openmmlab.com/mmsegmentation/v0.5/segformer/segformer_mit-b3_512x512_160k_ade20k/segformer_mit-b3_512x512_160k_ade20k_20210726_081410.log.json) |
|Segformer | MIT-B4 | 512x512 | 160000 | 6.1 | 15.45 | 48.46 | 49.76 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/segformer/segformer_mit-b4_512x512_160k_ade20k.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/segformer/segformer_mit-b4_512x512_160k_ade20k/segformer_mit-b4_512x512_160k_ade20k_20210728_183055-7f509d7d.pth) &#124; [log](https://download.openmmlab.com/mmsegmentation/v0.5/segformer/segformer_mit-b4_512x512_160k_ade20k/segformer_mit-b4_512x512_160k_ade20k_20210728_183055.log.json) |
|Segformer | MIT-B5 | 512x512 | 160000 | 7.2 | 11.89 | 49.13 | 50.22 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/segformer/segformer_mit-b5_512x512_160k_ade20k.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/segformer/segformer_mit-b5_512x512_160k_ade20k/segformer_mit-b5_512x512_160k_ade20k_20210726_145235-94cedf59.pth) &#124; [log](https://download.openmmlab.com/mmsegmentation/v0.5/segformer/segformer_mit-b5_512x512_160k_ade20k/segformer_mit-b5_512x512_160k_ade20k_20210726_145235.log.json) |
|Segformer | MIT-B5 | 640x640 | 160000 | 11.5 | 11.30 | 49.62 | 50.36 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/segformer/segformer_mit-b5_640x640_160k_ade20k.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/segformer/segformer_mit-b5_640x640_160k_ade20k/segformer_mit-b5_640x640_160k_ade20k_20210801_121243-41d2845b.pth) &#124; [log](https://download.openmmlab.com/mmsegmentation/v0.5/segformer/segformer_mit-b5_640x640_160k_ade20k/segformer_mit-b5_640x640_160k_ade20k_20210801_121243.log.json) |

Evaluation with `AlignedResize`:

| Method | Backbone | Crop Size | Lr schd | mIoU | mIoU(ms+flip) |
| ------ | -------- | --------- | ------: | ---: | ------------- |
| Segformer | MIT-B0 | 512x512 | 160000 | 38.1 | 38.57 |
| Segformer | MIT-B1 | 512x512 | 160000 | 41.64 | 42.76 |
| Segformer | MIT-B2 | 512x512 | 160000 | 46.53 | 47.49 |
| Segformer | MIT-B3 | 512x512 | 160000 | 48.46 | 49.14 |
| Segformer | MIT-B4 | 512x512 | 160000 | 49.34 | 50.29 |
| Segformer | MIT-B5 | 512x512 | 160000 | 50.08 | 50.72 |
| Segformer | MIT-B5 | 640x640 | 160000 | 50.58 | 50.8 |

We replace `AlignedResize` in the original implementation with `Resize + ResizeToMultiple`. If you want to test
with `AlignedResize`, you can change the test pipeline like this:

```python
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='MultiScaleFlipAug',
        img_scale=(2048, 512),
        # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
        flip=False,
        transforms=[
            dict(type='Resize', keep_ratio=True),
            # Resize image to a multiple of 32; improves SegFormer by 0.5-1.0 mIoU.
            dict(type='ResizeToMultiple', size_divisor=32),
            dict(type='RandomFlip'),
            dict(type='Normalize', **img_norm_cfg),
            dict(type='ImageToTensor', keys=['img']),
            dict(type='Collect', keys=['img']),
        ])
]
```

## How to use SegFormer official pretrained weights

We convert the backbone weights from the official repo (https://github.com/NVlabs/SegFormer) with `tools/model_converters/mit_convert.py`.

You may follow the steps below to prepare SegFormer training:

1. Download the SegFormer pretrained weights (we suggest putting them in `pretrain/`);
2. Run the conversion script on the official pretrained weights: `python tools/model_converters/mit_convert.py pretrain/mit_b0.pth pretrain/mit_b0.pth`;
3. Set `pretrained` in the SegFormer model config; for example, `pretrained` in `segformer_mit-b0_512x512_160k_ade20k.py` is set to `pretrain/mit_b0.pth` (see the sketch below).
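As a minimal sketch of step 3 (this simply mirrors the `segformer_mit-b0_512x512_160k_ade20k.py` config added later in this commit), the converted checkpoint path is passed through `pretrained`:

```python
# Sketch of step 3: point the model config at the converted checkpoint.
model = dict(
    pretrained='pretrain/mit_b0.pth',  # converted checkpoint from step 2
    decode_head=dict(num_classes=150))
```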
configs/segformer/segformer_mit-b0_512x512_160k_ade20k.py

@@ -0,0 +1,33 @@

_base_ = [
    '../_base_/models/segformer_mit-b0.py', '../_base_/datasets/ade20k.py',
    '../_base_/default_runtime.py', '../_base_/schedules/schedule_160k.py'
]

model = dict(
    pretrained='pretrain/mit_b0.pth', decode_head=dict(num_classes=150))

# optimizer
optimizer = dict(
    _delete_=True,
    type='AdamW',
    lr=0.00006,
    betas=(0.9, 0.999),
    weight_decay=0.01,
    paramwise_cfg=dict(
        custom_keys={
            'pos_block': dict(decay_mult=0.),
            'norm': dict(decay_mult=0.),
            'head': dict(lr_mult=10.)
        }))

lr_config = dict(
    _delete_=True,
    policy='poly',
    warmup='linear',
    warmup_iters=1500,
    warmup_ratio=1e-6,
    power=1.0,
    min_lr=0.0,
    by_epoch=False)

data = dict(samples_per_gpu=2, workers_per_gpu=2)
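For reference, a minimal sketch of how the `poly` policy with linear warmup declared above evolves the learning rate over the 160k schedule (an illustrative reimplementation under MMCV's usual definitions, not the hook itself):

```python
def poly_lr(it, base_lr=6e-5, max_iters=160000, power=1.0, min_lr=0.0,
            warmup='linear', warmup_iters=1500, warmup_ratio=1e-6):
    """Approximate lr at iteration `it` under the lr_config above."""
    # Regular poly decay towards min_lr.
    coeff = (1 - it / max_iters)**power
    lr = (base_lr - min_lr) * coeff + min_lr
    if warmup == 'linear' and it < warmup_iters:
        # Linear warmup: scaling factor grows from warmup_ratio up to 1.
        k = (1 - it / warmup_iters) * (1 - warmup_ratio)
        lr *= (1 - k)
    return lr

print(poly_lr(0))       # ~6e-11, almost zero at the first iteration
print(poly_lr(1500))    # warmup finished, close to the base lr of 6e-5
print(poly_lr(80000))   # halfway through training, ~3e-5
```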
configs/segformer/segformer_mit-b1_512x512_160k_ade20k.py

@@ -0,0 +1,8 @@

_base_ = ['./segformer_mit-b0_512x512_160k_ade20k.py']

# model settings
model = dict(
    pretrained='pretrain/mit_b1.pth',
    backbone=dict(
        embed_dims=64, num_heads=[1, 2, 5, 8], num_layers=[2, 2, 2, 2]),
    decode_head=dict(in_channels=[64, 128, 320, 512]))
configs/segformer/segformer_mit-b2_512x512_160k_ade20k.py

@@ -0,0 +1,8 @@

_base_ = ['./segformer_mit-b0_512x512_160k_ade20k.py']

# model settings
model = dict(
    pretrained='pretrain/mit_b2.pth',
    backbone=dict(
        embed_dims=64, num_heads=[1, 2, 5, 8], num_layers=[3, 4, 6, 3]),
    decode_head=dict(in_channels=[64, 128, 320, 512]))
configs/segformer/segformer_mit-b3_512x512_160k_ade20k.py

@@ -0,0 +1,8 @@

_base_ = ['./segformer_mit-b0_512x512_160k_ade20k.py']

# model settings
model = dict(
    pretrained='pretrain/mit_b3.pth',
    backbone=dict(
        embed_dims=64, num_heads=[1, 2, 5, 8], num_layers=[3, 4, 18, 3]),
    decode_head=dict(in_channels=[64, 128, 320, 512]))
configs/segformer/segformer_mit-b4_512x512_160k_ade20k.py

@@ -0,0 +1,8 @@

_base_ = ['./segformer_mit-b0_512x512_160k_ade20k.py']

# model settings
model = dict(
    pretrained='pretrain/mit_b4.pth',
    backbone=dict(
        embed_dims=64, num_heads=[1, 2, 5, 8], num_layers=[3, 8, 27, 3]),
    decode_head=dict(in_channels=[64, 128, 320, 512]))
configs/segformer/segformer_mit-b5_512x512_160k_ade20k.py

@@ -0,0 +1,8 @@

_base_ = ['./segformer_mit-b0_512x512_160k_ade20k.py']

# model settings
model = dict(
    pretrained='pretrain/mit_b5.pth',
    backbone=dict(
        embed_dims=64, num_heads=[1, 2, 5, 8], num_layers=[3, 6, 40, 3]),
    decode_head=dict(in_channels=[64, 128, 320, 512]))
configs/segformer/segformer_mit-b5_640x640_160k_ade20k.py

@@ -0,0 +1,44 @@

_base_ = ['./segformer_mit-b0_512x512_160k_ade20k.py']

# dataset settings
img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
crop_size = (640, 640)
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations', reduce_zero_label=True),
    dict(type='Resize', img_scale=(2048, 640), ratio_range=(0.5, 2.0)),
    dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
    dict(type='RandomFlip', prob=0.5),
    dict(type='PhotoMetricDistortion'),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),
    dict(type='DefaultFormatBundle'),
    dict(type='Collect', keys=['img', 'gt_semantic_seg']),
]
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='MultiScaleFlipAug',
        img_scale=(2048, 640),
        # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
        flip=False,
        transforms=[
            dict(type='Resize', keep_ratio=True),
            dict(type='RandomFlip'),
            dict(type='Normalize', **img_norm_cfg),
            dict(type='ImageToTensor', keys=['img']),
            dict(type='Collect', keys=['img']),
        ])
]
data = dict(
    train=dict(pipeline=train_pipeline),
    val=dict(pipeline=test_pipeline),
    test=dict(pipeline=test_pipeline))

# model settings
model = dict(
    pretrained='pretrain/mit_b5.pth',
    backbone=dict(
        embed_dims=64, num_heads=[1, 2, 5, 8], num_layers=[3, 6, 40, 3]),
    decode_head=dict(in_channels=[64, 128, 320, 512]))

mmseg/datasets/pipelines/transforms.py

+57
@@ -6,6 +6,63 @@
 from ..builder import PIPELINES
 
 
+@PIPELINES.register_module()
+class ResizeToMultiple(object):
+    """Resize images & seg to multiple of divisor.
+
+    Args:
+        size_divisor (int): images and gt seg maps need to resize to multiple
+            of size_divisor. Default: 32.
+        interpolation (str, optional): The interpolation mode of image resize.
+            Default: None
+    """
+
+    def __init__(self, size_divisor=32, interpolation=None):
+        self.size_divisor = size_divisor
+        self.interpolation = interpolation
+
+    def __call__(self, results):
+        """Call function to resize images, semantic segmentation map to
+        multiple of size divisor.
+
+        Args:
+            results (dict): Result dict from loading pipeline.
+
+        Returns:
+            dict: Resized results, 'img_shape', 'pad_shape' keys are updated.
+        """
+        # Align image to multiple of size divisor.
+        img = results['img']
+        img = mmcv.imresize_to_multiple(
+            img,
+            self.size_divisor,
+            scale_factor=1,
+            interpolation=self.interpolation
+            if self.interpolation else 'bilinear')
+
+        results['img'] = img
+        results['img_shape'] = img.shape
+        results['pad_shape'] = img.shape
+
+        # Align segmentation map to multiple of size divisor.
+        for key in results.get('seg_fields', []):
+            gt_seg = results[key]
+            gt_seg = mmcv.imresize_to_multiple(
+                gt_seg,
+                self.size_divisor,
+                scale_factor=1,
+                interpolation='nearest')
+            results[key] = gt_seg
+
+        return results
+
+    def __repr__(self):
+        repr_str = self.__class__.__name__
+        repr_str += (f'(size_divisor={self.size_divisor}, '
+                     f'interpolation={self.interpolation})')
+        return repr_str
+
+
 @PIPELINES.register_module()
 class Resize(object):
     """Resize images & seg.

mmseg/models/backbones/mit.py

+8-10
@@ -11,7 +11,7 @@
 
 from ...utils import get_root_logger
 from ..builder import BACKBONES
-from ..utils import PatchEmbed, mit_convert, nchw_to_nlc, nlc_to_nchw
+from ..utils import PatchEmbed, nchw_to_nlc, nlc_to_nchw
 
 
 class MixFFN(BaseModule):
@@ -159,7 +159,13 @@ def forward(self, x, hw_shape, identity=None):
         if identity is None:
             identity = x_q
 
-        out = self.attn(query=x_q, key=x_kv, value=x_kv)[0]
+        # `need_weights=True` makes nn.MultiheadAttention return
+        # `attn_output, attn_output_weights.sum(dim=1) / num_heads`.
+        # `attn_output_weights.sum(dim=1)` may cause a CUDA error, so we set
+        # `need_weights=False` to skip computing it.
+        # See https://github.com/pytorch/pytorch/issues/37583, which reports
+        # that summing a large tensor may cause a CUDA error.
+        out = self.attn(query=x_q, key=x_kv, value=x_kv, need_weights=False)[0]
 
         return identity + self.dropout_layer(self.proj_drop(out))
 
@@ -387,17 +393,9 @@ def init_weights(self):
                 self.pretrained, logger=logger, map_location='cpu')
             if 'state_dict' in checkpoint:
                 state_dict = checkpoint['state_dict']
-            elif 'model' in checkpoint:
-                state_dict = checkpoint['model']
             else:
                 state_dict = checkpoint
 
-            if self.pretrain_style == 'official':
-                # Because segformer backbone is not support by mmcls,
-                # so we need to convert pretrain weights to match this
-                # implementation.
-                state_dict = mit_convert(state_dict)
-
             self.load_state_dict(state_dict, False)
 
     def forward(self, x):
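A small illustration (not part of this commit) of the `need_weights` flag on `torch.nn.MultiheadAttention` that the fix above relies on: with `need_weights=False` the averaged attention map is never returned, so the large reduction noted in the linked PyTorch issue is avoided.

```python
import torch
import torch.nn as nn

attn = nn.MultiheadAttention(embed_dim=64, num_heads=8)
x = torch.randn(1024, 2, 64)  # (sequence length, batch, embed_dim)

out, weights = attn(x, x, x, need_weights=True)
print(weights.shape)          # torch.Size([2, 1024, 1024]): averaged attention weights

out, weights = attn(x, x, x, need_weights=False)
print(weights)                # None: the attention map (and its head-wise sum) is skipped
```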

mmseg/models/decode_heads/__init__.py

+3-1
@@ -16,6 +16,7 @@
 from .point_head import PointHead
 from .psa_head import PSAHead
 from .psp_head import PSPHead
+from .segformer_head import SegformerHead
 from .sep_aspp_head import DepthwiseSeparableASPPHead
 from .sep_fcn_head import DepthwiseSeparableFCNHead
 from .setr_mla_head import SETRMLAHead
@@ -26,5 +27,6 @@
     'FCNHead', 'PSPHead', 'ASPPHead', 'PSAHead', 'NLHead', 'GCHead', 'CCHead',
     'UPerHead', 'DepthwiseSeparableASPPHead', 'ANNHead', 'DAHead', 'OCRHead',
     'EncHead', 'DepthwiseSeparableFCNHead', 'FPNHead', 'EMAHead', 'DNLHead',
-    'PointHead', 'APCHead', 'DMHead', 'LRASPPHead', 'SETRUPHead', 'SETRMLAHead'
+    'PointHead', 'APCHead', 'DMHead', 'LRASPPHead', 'SETRUPHead',
+    'SETRMLAHead', 'SegformerHead'
 ]
