default_configs.py
from fvcore.common.config import CfgNode
from loguru import logger

import clip

from dataset.transforms import AUGMENTATIONS

_C = CfgNode()

# ---------------------------------------------------------------------------- #
# general options
_C.RNG_SEED = 0
_C.EVAL_INTERVAL = 1
_C.LOG_INTERVAL = 10
_C.SAVE_INTERVAL = None
_C.TRAIN_PRECISION = "fp16"
_C.SAVE_PATH = None
_C.RESUME_CHECKPOINT = None
_C.LOG_TO_WANDB = False
_C.EVAL_ONLY = False
# ---------------------------------------------------------------------------- #
# ---------------------------------------------------------------------------- #
# Model options
_C.MODEL = CfgNode()
_C.MODEL.VIZ_BACKBONE = "ViT-B/16"
_C.MODEL.PER_CLS_PROMPTS = False # CoOp option
_C.MODEL.PROMPT_POSITION = "start" # CoOp option, position of the class names
_C.MODEL.IMG_CONDITIONING = False # CoCoop option
_C.MODEL.NUM_PROMPTS = 16 # CoOp option
_C.MODEL.FRAME_AGGREGATION = (
    "transformer_2"  # Video option: mean, max, or transformer_<n_layers>
)
_C.MODEL.SOFTMAX_TEMP = None # if None, use CLIP's temp
# ---------------------------------------------------------------------------- #
# ---------------------------------------------------------------------------- #
# Data options
_C.DATA = CfgNode()
_C.DATA.TYPE = "image" # image or video
_C.DATA.DATASET_NAME = "ImageNet"
_C.DATA.DATA_PATH = ""
_C.DATA.MEAN = [0.48145466, 0.4578275, 0.40821073]
_C.DATA.STD = [0.26862954, 0.26130258, 0.27577711]
_C.DATA.TRAIN_AUGS = ["random_resized_crop", "random_flip", "normalize"]
_C.DATA.TRAIN_RESIZE = None
_C.DATA.TRAIN_CROP_SIZE = 224
_C.DATA.TEST_AUGS = ["resize", "center_crop", "normalize"]
_C.DATA.TEST_RESIZE = 224
_C.DATA.TEST_CROP_SIZE = 224
_C.DATA.TEST_STRIDES = [8]
# Video data options
_C.DATA.NUM_FRAMES = 8
_C.DATA.TARGET_FPS = 30
_C.DATA.TRAIN_STRIDES = [8]
_C.DATA.TRAIN_VIDEO_SAMPLER = "random"
_C.DATA.TEST_NUM_CLIPS = 1
# single_view, multi_view_strides, multi_view_sliding
_C.DATA.TEST_METHOD = "single_view"
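# (Descriptive note, derived from check_and_update_configs below:
#  single_view        -> one centered clip; TEST_NUM_CLIPS == 1, one stride
#  multi_view_strides -> one centered clip per stride in TEST_STRIDES
#  multi_view_sliding -> TEST_NUM_CLIPS clips from a sliding window, one stride)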
# Fewshot / Zeroshot options
_C.DATA.FEWSHOT = False # Video option, for images just set N_SHOT > 0
_C.DATA.ZEROSHOT = False # Video option
_C.DATA.USE_ALL_CLASSES = False # Video option
_C.DATA.N_SHOT = 0 # Video option
_C.DATA.C_WAY = -1 # Image and video option
_C.DATA.N_QUERY_SHOT = 95 # Video option
_C.DATA.USE_BASE_AND_NEW = False # Image option
# Image option, domain generalization targets:
# ImageNet-A, ImageNet-R, ImageNet-V2, and ImageNet-Sketch
_C.DATA.TARGET_DATASET = None
# ---------------------------------------------------------------------------- #
# ---------------------------------------------------------------------------- #
# data loader options
_C.DATALOADER = CfgNode()
_C.DATALOADER.NUM_WORKERS = 8
_C.DATALOADER.PIN_MEMORY = True
_C.DATALOADER.TRAIN_BATCHSIZE = 8
_C.DATALOADER.TEST_BATCHSIZE = 8
# ---------------------------------------------------------------------------- #
# ---------------------------------------------------------------------------- #
# Optimizer options
_C.OPT = CfgNode()
_C.OPT.MAX_EPOCHS = 1
_C.OPT.LR = 0.1
_C.OPT.TYPE = "sgd"
_C.OPT.LR_POLICY = "cosine"
_C.OPT.COSINE_END_LR = 0.0
_C.OPT.LINEAR_END_LR = 0.0
_C.OPT.STEPS = []
_C.OPT.WEIGHT_DECAY = 1e-4
_C.OPT.WARMUP_EPOCHS = 0.0
_C.OPT.ZERO_WD_1D_PARAM = False
_C.OPT.CLIP_L2_GRADNORM = None
# ---------------------------------------------------------------------------- #
# ---------------------------------------------------------------------------- #


def check_and_update_configs(cfg):
    """Validate config values and derive dependent options."""
    assert cfg.TRAIN_PRECISION in ["fp32", "fp16", "amp"]
    assert cfg.MODEL.VIZ_BACKBONE in clip.available_models()
    assert cfg.MODEL.PROMPT_POSITION in ["start", "middle", "end"]
    assert (
        cfg.MODEL.NUM_PROMPTS < 65
    ), "Possibly no room left for class name and other tokens"
    assert cfg.MODEL.FRAME_AGGREGATION.split("_")[0] in ["mean", "max", "transformer"]
    assert cfg.DATA.TRAIN_VIDEO_SAMPLER in ["random", "center"]
    assert cfg.OPT.TYPE in ["sgd", "adam", "adamw"]
    assert cfg.OPT.LR_POLICY in ["cosine", "step", "linear", "constant"]
    if cfg.OPT.CLIP_L2_GRADNORM:
        assert cfg.OPT.CLIP_L2_GRADNORM > 0.0
    assert cfg.DATA.TEST_METHOD in [
        "single_view",
        "multi_view_strides",
        "multi_view_sliding",
    ]
    if cfg.DATA.TEST_METHOD == "single_view":
        assert cfg.DATA.TEST_NUM_CLIPS == 1 and len(cfg.DATA.TEST_STRIDES) == 1
        cfg.DATA.TEST_VIDEO_SAMPLER = "center"
    elif cfg.DATA.TEST_METHOD == "multi_view_strides":
        assert cfg.DATA.TEST_NUM_CLIPS == len(cfg.DATA.TEST_STRIDES)
        cfg.DATA.TEST_VIDEO_SAMPLER = "center"
    else:
        assert len(cfg.DATA.TEST_STRIDES) == 1 and cfg.DATA.TEST_NUM_CLIPS > 1
        cfg.DATA.TEST_VIDEO_SAMPLER = "sliding"
    for aug in cfg.DATA.TRAIN_AUGS:
        assert aug in list(AUGMENTATIONS.keys()), f"Augmentation {aug} is not supported"
    if cfg.DATA.FEWSHOT:
        assert cfg.DATA.TYPE == "video"
        assert cfg.DATA.DATASET_NAME in [
            "UCF101",
            "HMDB51",
            "K400",
        ], "Few-shot datasets are only supported for UCF101, HMDB51, K400"
        cfg.DATA.DATASET_NAME = f"{cfg.DATA.DATASET_NAME}FewShot"
    if cfg.DATA.ZEROSHOT:
        assert cfg.DATA.TYPE == "video"
        assert cfg.DATA.DATASET_NAME in [
            "UCF101",
            "HMDB51",
            "K700",
        ], "Zero-shot datasets are only supported for UCF101, HMDB51, K700"
        cfg.DATA.DATASET_NAME = f"{cfg.DATA.DATASET_NAME}ZeroShot"
    if cfg.DATA.USE_ALL_CLASSES:
        assert cfg.DATA.TYPE == "video"
        assert cfg.DATA.C_WAY == -1, "C_WAY must be -1 if USE_ALL_CLASSES is True"
        assert (
            cfg.DATA.FEWSHOT
        ), "USE_ALL_CLASSES (C-way over all classes) is only supported for few-shot"
    assert (
        cfg.DATA.TRAIN_CROP_SIZE == cfg.DATA.TEST_CROP_SIZE
    ), "For CLIP, use the same crop size for train and test"
    if cfg.RESUME_CHECKPOINT:
        assert cfg.RESUME_CHECKPOINT.endswith(".pyth")
    return cfg
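

# Illustrative example of a YAML file consumed by get_cfg below (the file name
# and values are hypothetical, not part of the original repo). It only needs to
# list the keys it overrides, mirroring the CfgNode hierarchy defined above:
#
#   # configs/ucf101_fewshot.yaml
#   DATA:
#     TYPE: "video"
#     DATASET_NAME: "UCF101"
#     FEWSHOT: True
#     N_SHOT: 5
#   OPT:
#     LR: 0.002
#     MAX_EPOCHS: 50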
def get_cfg(args):
    """Build the run config: defaults, then the YAML file, then command-line overrides."""
    cfg = _C.clone()
    # Update the defaults with values from the YAML config file.
    if args.config_file is not None:
        cfg.merge_from_file(args.config_file)
    # Apply key-value overrides passed on the command line.
    if args.opts is not None:
        cfg.merge_from_list(args.opts)
    cfg = check_and_update_configs(cfg)
    logger.info(f"Configs of this run:\n{cfg}")
    return cfg
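

# Minimal usage sketch (illustrative, not part of the original module). The
# argparse flags below are assumptions about how a training entry point might
# expose them; get_cfg only requires an object with `config_file` and `opts`.
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Build the run configuration")
    parser.add_argument(
        "--config-file",
        default=None,
        help="Path to a YAML file overriding the defaults above",
    )
    parser.add_argument(
        "opts",
        nargs=argparse.REMAINDER,
        default=None,
        help="Space-separated key-value overrides, e.g. OPT.LR 0.002 DATA.TYPE video",
    )
    cfg = get_cfg(parser.parse_args())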