Skip to content

Commit

Permalink
--refactor=support build
Browse files Browse the repository at this point in the history
  • Loading branch information
xiexinch committed Aug 9, 2023
1 parent 1f1c66e commit e7a75f3
Show file tree
Hide file tree
Showing 21 changed files with 181 additions and 7 deletions.
2 changes: 2 additions & 0 deletions projects/CAT-Seg/cat_seg/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from .models import * # noqa: F401,F403
from .utils import * # noqa: F401,F403
File renamed without changes.
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,10 @@
from mmengine.model import BaseModule

from mmseg.registry import MODELS
from mmseg.utils import ConfigType, clip_templates
from mmseg.utils import ConfigType
from ..utils import clip_wrapper
from ..utils.clip_templates import (IMAGENET_TEMPLATES,
IMAGENET_TEMPLATES_SELECT)


@MODELS.register_module()
Expand Down Expand Up @@ -88,9 +90,9 @@ def __init__(
# prepare clip templates
self.prompt_ensemble_type = prompt_ensemble_type
if self.prompt_ensemble_type == 'imagenet_select':
prompt_templates = clip_templates.IMAGENET_TEMPLATES_SELECT
prompt_templates = IMAGENET_TEMPLATES_SELECT
elif self.prompt_ensemble_type == 'imagenet':
prompt_templates = clip_templates.IMAGENET_TEMPLATES
prompt_templates = IMAGENET_TEMPLATES
elif self.prompt_ensemble_type == 'single':
prompt_templates = [
'A photo of a {} in the scene',
Expand Down
10 changes: 10 additions & 0 deletions projects/CAT-Seg/cat_seg/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .clip_templates import (IMAGENET_TEMPLATES, IMAGENET_TEMPLATES_SELECT,
IMAGENET_TEMPLATES_SELECT_CLIP, ViLD_templates)
from .self_attention_block import FullAttention, LinearAttention

__all__ = [
'FullAttention', 'LinearAttention', 'IMAGENET_TEMPLATES',
'IMAGENET_TEMPLATES_SELECT', 'IMAGENET_TEMPLATES_SELECT_CLIP',
'ViLD_templates'
]
File renamed without changes.
File renamed without changes.
File renamed without changes.
15 changes: 15 additions & 0 deletions projects/CAT-Seg/configs/_base_/default_runtime.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
default_scope = 'mmseg'
env_cfg = dict(
cudnn_benchmark=True,
mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
dist_cfg=dict(backend='nccl'),
)
vis_backends = [dict(type='LocalVisBackend')]
visualizer = dict(
type='SegLocalVisualizer', vis_backends=vis_backends, name='visualizer')
log_processor = dict(by_epoch=False)
log_level = 'INFO'
load_from = None
resume = False

tta_model = dict(type='SegTTAModel')
25 changes: 25 additions & 0 deletions projects/CAT-Seg/configs/_base_/schedules/schedule_160k.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
# optimizer
optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0005)
optim_wrapper = dict(type='OptimWrapper', optimizer=optimizer, clip_grad=None)
# learning policy
param_scheduler = [
dict(
type='PolyLR',
eta_min=1e-4,
power=0.9,
begin=0,
end=160000,
by_epoch=False)
]
# training schedule for 160k
train_cfg = dict(
type='IterBasedTrainLoop', max_iters=160000, val_interval=16000)
val_cfg = dict(type='ValLoop')
test_cfg = dict(type='TestLoop')
default_hooks = dict(
timer=dict(type='IterTimerHook'),
logger=dict(type='LoggerHook', interval=50, log_metric_by_epoch=False),
param_scheduler=dict(type='ParamSchedulerHook'),
checkpoint=dict(type='CheckpointHook', by_epoch=False, interval=16000),
sampler_seed=dict(type='DistSamplerSeedHook'),
visualization=dict(type='SegVisualizationHook'))
24 changes: 24 additions & 0 deletions projects/CAT-Seg/configs/_base_/schedules/schedule_20k.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
# optimizer
optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0005)
optim_wrapper = dict(type='OptimWrapper', optimizer=optimizer, clip_grad=None)
# learning policy
param_scheduler = [
dict(
type='PolyLR',
eta_min=1e-4,
power=0.9,
begin=0,
end=20000,
by_epoch=False)
]
# training schedule for 20k
train_cfg = dict(type='IterBasedTrainLoop', max_iters=20000, val_interval=2000)
val_cfg = dict(type='ValLoop')
test_cfg = dict(type='TestLoop')
default_hooks = dict(
timer=dict(type='IterTimerHook'),
logger=dict(type='LoggerHook', interval=50, log_metric_by_epoch=False),
param_scheduler=dict(type='ParamSchedulerHook'),
checkpoint=dict(type='CheckpointHook', by_epoch=False, interval=2000),
sampler_seed=dict(type='DistSamplerSeedHook'),
visualization=dict(type='SegVisualizationHook'))
25 changes: 25 additions & 0 deletions projects/CAT-Seg/configs/_base_/schedules/schedule_240k.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
# optimizer
optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0005)
optim_wrapper = dict(type='OptimWrapper', optimizer=optimizer, clip_grad=None)
# learning policy
param_scheduler = [
dict(
type='PolyLR',
eta_min=1e-4,
power=0.9,
begin=0,
end=240000,
by_epoch=False)
]
# training schedule for 240k
train_cfg = dict(
type='IterBasedTrainLoop', max_iters=240000, val_interval=24000)
val_cfg = dict(type='ValLoop')
test_cfg = dict(type='TestLoop')
default_hooks = dict(
timer=dict(type='IterTimerHook'),
logger=dict(type='LoggerHook', interval=50, log_metric_by_epoch=False),
param_scheduler=dict(type='ParamSchedulerHook'),
checkpoint=dict(type='CheckpointHook', by_epoch=False, interval=24000),
sampler_seed=dict(type='DistSamplerSeedHook'),
visualization=dict(type='SegVisualizationHook'))
25 changes: 25 additions & 0 deletions projects/CAT-Seg/configs/_base_/schedules/schedule_320k.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
# optimizer
optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0005)
optim_wrapper = dict(type='OptimWrapper', optimizer=optimizer, clip_grad=None)
# learning policy
param_scheduler = [
dict(
type='PolyLR',
eta_min=1e-4,
power=0.9,
begin=0,
end=320000,
by_epoch=False)
]
# training schedule for 320k
train_cfg = dict(
type='IterBasedTrainLoop', max_iters=320000, val_interval=32000)
val_cfg = dict(type='ValLoop')
test_cfg = dict(type='TestLoop')
default_hooks = dict(
timer=dict(type='IterTimerHook'),
logger=dict(type='LoggerHook', interval=50, log_metric_by_epoch=False),
param_scheduler=dict(type='ParamSchedulerHook'),
checkpoint=dict(type='CheckpointHook', by_epoch=False, interval=32000),
sampler_seed=dict(type='DistSamplerSeedHook'),
visualization=dict(type='SegVisualizationHook'))
24 changes: 24 additions & 0 deletions projects/CAT-Seg/configs/_base_/schedules/schedule_40k.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
# optimizer
optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0005)
optim_wrapper = dict(type='OptimWrapper', optimizer=optimizer, clip_grad=None)
# learning policy
param_scheduler = [
dict(
type='PolyLR',
eta_min=1e-4,
power=0.9,
begin=0,
end=40000,
by_epoch=False)
]
# training schedule for 40k
train_cfg = dict(type='IterBasedTrainLoop', max_iters=40000, val_interval=4000)
val_cfg = dict(type='ValLoop')
test_cfg = dict(type='TestLoop')
default_hooks = dict(
timer=dict(type='IterTimerHook'),
logger=dict(type='LoggerHook', interval=50, log_metric_by_epoch=False),
param_scheduler=dict(type='ParamSchedulerHook'),
checkpoint=dict(type='CheckpointHook', by_epoch=False, interval=4000),
sampler_seed=dict(type='DistSamplerSeedHook'),
visualization=dict(type='SegVisualizationHook'))
24 changes: 24 additions & 0 deletions projects/CAT-Seg/configs/_base_/schedules/schedule_80k.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
# optimizer
optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0005)
optim_wrapper = dict(type='OptimWrapper', optimizer=optimizer, clip_grad=None)
# learning policy
param_scheduler = [
dict(
type='PolyLR',
eta_min=1e-4,
power=0.9,
begin=0,
end=80000,
by_epoch=False)
]
# training schedule for 80k
train_cfg = dict(type='IterBasedTrainLoop', max_iters=80000, val_interval=8000)
val_cfg = dict(type='ValLoop')
test_cfg = dict(type='TestLoop')
default_hooks = dict(
timer=dict(type='IterTimerHook'),
logger=dict(type='LoggerHook', interval=50, log_metric_by_epoch=False),
param_scheduler=dict(type='ParamSchedulerHook'),
checkpoint=dict(type='CheckpointHook', by_epoch=False, interval=8000),
sampler_seed=dict(type='DistSamplerSeedHook'),
visualization=dict(type='SegVisualizationHook'))
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
'../_base_/datasets/coco-stuff164k_384x384.py'
]

custom_imports = dict(imports=['cat_seg'])

norm_cfg = dict(type='SyncBN', requires_grad=True)
crop_size = (384, 384)
data_preprocessor = dict(
Expand Down
4 changes: 0 additions & 4 deletions projects/CAT-Seg/mmseg/utils/__init__.py

This file was deleted.

0 comments on commit e7a75f3

Please sign in to comment.