-
Notifications
You must be signed in to change notification settings - Fork 2
/
cityscapes.py
77 lines (73 loc) · 2.63 KB
/
cityscapes.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
# Base configs this file inherits from: model skeleton, dataset settings,
# runtime defaults and training schedule (roles inferred from their paths).
# mmcv's Config merges every top-level assignment below over these bases.
_base_ = [
    '_base_/models/upernet_r50.py',
    '_base_/datasets/cityscapes.py',
    '_base_/default_runtime.py',
    '_base_/schedules/schedule_80k.py',
]
# Pixel normalisation: mean = std = 0.5 * 255 = 127.5 for each channel, so
# Normalize's (x - mean) / std maps 8-bit values from [0, 255] to [-1, 1].
# NOTE: despite its name, IMG_VAR is a standard deviation (it is passed as
# `std` below), not a variance. The name is kept because mmcv exposes every
# top-level variable as a config key.
IMG_MEAN = [0.5 * 255] * 3
IMG_VAR = [0.5 * 255] * 3
# to_rgb=True asks Normalize to reorder the loaded image's channels to RGB.
img_norm_cfg = dict(mean=IMG_MEAN, std=IMG_VAR, to_rgb=True)
# Channel width of the convolutions inside both segmentation heads.
head_c = 512
# Number of learnable prompts in MetaPromptsSeg. The same value is also the
# channel count of every feature level fed to both heads, so the model is
# presumably expected to emit one channel per prompt -- confirm in
# MetaPromptsSeg.
in_c = 150

model = dict(
    type='MetaPromptsSeg',
    # Stable Diffusion v1.5 weights (judging by the file name).
    sd_path='checkpoints/v1-5-pruned-emaonly.ckpt',
    # Refinement iteration count; its exact meaning is defined by
    # MetaPromptsSeg.
    refine_step=3,
    num_prompt=in_c,
    # Main head: UPerNet over four feature levels, each with in_c channels.
    decode_head=dict(
        type='UPerHead',
        in_channels=[in_c] * 4,
        in_index=list(range(4)),
        pool_scales=(1, 2, 3, 6),
        channels=head_c,
        dropout_ratio=0.1,
        # Cityscapes 19-class label set.
        num_classes=19,
        # Cross-entropy plus Lovasz loss, each with weight 1.0.
        # mmseg's LovaszLoss only accepts reduction='none' when per_image is
        # False (its default), hence the explicit setting.
        loss_decode=[
            dict(type='CrossEntropyLoss', loss_name='loss_ce',
                 use_sigmoid=False, loss_weight=1.0),
            dict(type='LovaszLoss', reduction='none', loss_weight=1.0),
        ],
    ),
    # Auxiliary FCN head on feature level 2 with a down-weighted (0.4) CE loss;
    # mmseg uses auxiliary heads only during training.
    auxiliary_head=dict(
        type='FCNHead',
        in_channels=in_c,
        in_index=2,
        channels=head_c,
        num_convs=1,
        dropout_ratio=0.1,
        num_classes=19,
        loss_decode=dict(type='CrossEntropyLoss', loss_name='loss_ce_aux',
                         use_sigmoid=False, loss_weight=0.4),
    ),
    # Sliding-window inference: 1024x1024 windows, stride 512 (50% overlap).
    test_cfg=dict(mode='slide', crop_size=(1024, 1024), stride=(512, 512)),
)
# Single-scale, no-flip inference pipeline: load the image, resize it toward
# img_scale=(2048, 1024) keeping the aspect ratio, normalise it with
# img_norm_cfg, and collect only the image tensor.
# For multi-scale + flip evaluation instead, add
#     img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75, 2.0],
# to MultiScaleFlipAug and set flip=True.
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='MultiScaleFlipAug',
        img_scale=(2048, 1024),
        flip=False,
        transforms=[
            dict(type='Resize', keep_ratio=True),
            # Flipping is decided by MultiScaleFlipAug's `flip` flag; this step
            # stays in place so that toggling the flag is enough.
            dict(type='RandomFlip'),
            dict(type='Normalize', **img_norm_cfg),
            dict(type='ImageToTensor', keys=['img']),
            dict(type='Collect', keys=['img']),
        ],
    ),
]
# Data-loader settings, merged over the dataset base config.
# NOTE(review): only data.test receives the test_pipeline defined in this
# file. data.val (used for evaluation during training) and data.train keep the
# pipelines defined in the dataset base file. They match the mean/std = 127.5
# normalisation of this file's img_norm_cfg only if that base uses the same
# values -- confirm.
data = dict(
    samples_per_gpu=1,
    workers_per_gpu=2,
    train=dict(img_dir='leftImg8bit/train', ann_dir='gtFine/train'),
    test=dict(pipeline=test_pipeline),
)
# Poly schedule with power=1 (i.e. linear decay to min_lr=0), updated per
# iteration rather than per epoch. It is preceded by a 1500-iteration linear
# warmup that starts from warmup_ratio (1e-6) times the base learning rate.
lr_config = dict(
    policy='poly',
    power=1,
    min_lr=0.0,
    by_epoch=False,
    warmup='linear',
    warmup_iters=1500,
    warmup_ratio=1e-6,
)
# AdamW with per-parameter multipliers. mmcv's custom_keys are matched as
# substrings of the parameter name:
#   'unet'         -> 0.1x learning rate
#   'encoder_vq'   -> 0x learning rate (its weights receive no updates)
#   'text_encoder' -> 0.1x learning rate
#   'norm'         -> no weight decay
# bypass_duplicate=True skips parameters that are shared between modules
# instead of adding them to the optimizer twice.
# NOTE(review): mmcv's DefaultOptimizerConstructor tries custom_keys
# longest-first (ties broken alphabetically) and applies only the first match.
# A parameter whose name contains both 'unet' and 'norm' therefore gets only
# the 'norm' rule (1x lr, no decay), not the 0.1x 'unet' lr. Confirm intended.
optimizer = dict(
    type='AdamW',
    lr=8e-5,
    weight_decay=0.001,
    paramwise_cfg=dict(
        bypass_duplicate=True,
        custom_keys={
            'unet': dict(lr_mult=0.1),
            'encoder_vq': dict(lr_mult=0.0),
            'text_encoder': dict(lr_mult=0.1),
            'norm': dict(decay_mult=0.0),
        },
    ),
)
# Every 8000 iterations: save a checkpoint, and run mIoU evaluation that also
# keeps the best-mIoU checkpoint. pre_eval=True lets mmseg accumulate
# per-image statistics during testing instead of holding every prediction.
checkpoint_config = dict(interval=8000, by_epoch=False)
evaluation = dict(
    interval=8000,
    metric='mIoU',
    save_best='mIoU',
    pre_eval=True,
)