-
Notifications
You must be signed in to change notification settings - Fork 0
/
config.py
130 lines (123 loc) · 6 KB
/
config.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
import argparse
def get_parser():
    """Build the command-line argument parser for RIS training and testing.

    The parser groups its options into dataset settings, general model
    settings, training options, testing options, experiment settings,
    seeding, and MaskRIS hyperparameters.

    Returns:
        argparse.ArgumentParser: the configured parser. Call ``parse_args()``
        on it to obtain the argument namespace.
    """
    parser = argparse.ArgumentParser(description='RIS training and testing')

    # Dataset settings
    parser.add_argument('--dataset',
                        default='refcoco',
                        choices=['refcoco', 'refcoco+', 'refcocog'],
                        help='refcoco, refcoco+, or refcocog')
    parser.add_argument('--img_size',
                        default=448,
                        type=int, help='input image size')
    parser.add_argument('--split',
                        default='test',
                        help='only used when testing')
    parser.add_argument('--splitBy',
                        default='unc',
                        help='change to umd or google when the dataset is G-Ref (RefCOCOg)')
    parser.add_argument('--refer_data_root',
                        default=None,
                        help='REFER dataset root directory')
    parser.add_argument('--refer_root',
                        default=None,
                        help='REFER annotations root directory')

    # General model settings
    parser.add_argument('--model',
                        default='maskris',
                        help='model')
    parser.add_argument('--model_id',
                        default=None,
                        help='name to identify the model')
    parser.add_argument('--bert_tokenizer',
                        default='bert-base-uncased',
                        help='BERT tokenizer')
    parser.add_argument('--ck_bert',
                        default='bert-base-uncased',
                        help='pre-trained BERT weights')
    parser.add_argument('--swin_type',
                        default='base',
                        help='tiny, small, base, or large variants of the Swin Transformer')
    parser.add_argument('--pretrained_swin_weights',
                        default='',
                        help='path to pre-trained Swin backbone weights')

    # For training
    parser.add_argument('--amsgrad', action='store_true',
                        help='if true, set amsgrad to True in an Adam or AdamW optimizer.')
    parser.add_argument('--clip_grads',
                        action='store_true',
                        help='if true, enable gradient clipping.')
    parser.add_argument('--clip_value',
                        default=1.0, type=float,
                        help='max norm of the gradients.')
    # Stored as ``args.batch_size`` (argparse maps '-' to '_' in the dest).
    parser.add_argument('-b', '--batch-size',
                        default=8, type=int)
    parser.add_argument('--epochs',
                        default=40, type=int, metavar='N',
                        help='number of total epochs to run')
    parser.add_argument('--lr',
                        default=0.00005, type=float, help='the initial learning rate')
    # Float default so the value has the same type whether or not the flag is
    # given (argparse does not run ``type`` on non-string defaults).
    parser.add_argument('--min_lr',
                        default=0.0, type=float, help='the minimal learning rate')
    parser.add_argument('--wd', '--weight-decay',
                        default=1e-2, type=float, metavar='W',
                        help='weight decay',
                        dest='weight_decay')
    parser.add_argument('--warmup',
                        action='store_true',
                        help='if true, use warmup for training.')
    parser.add_argument('--warmup_iters',
                        default=100, type=int)
    parser.add_argument('--warmup_ratio',
                        default=0.1, type=float)

    # For testing
    # NOTE: adjacent string literals are concatenated verbatim, so each
    # fragment below ends with an explicit space.
    parser.add_argument('--ddp_trained_weights',
                        action='store_true',
                        help='Only needs specified when testing, '
                             'whether the weights to be loaded are from a DDP-trained model')
    parser.add_argument('--device',
                        default='cuda:0', help='device')  # only used when testing on a single machine
    parser.add_argument('--window12',
                        action='store_true',
                        help='only needs specified when testing, '
                             'when training, window size is inferred from pre-trained weights file name '
                             "(containing 'window12'). Initialize Swin with window size 12 instead of the default 7.")
    parser.add_argument('--eval_ori_size',
                        action='store_true',
                        help='evaluation with original image size')

    # Experiment settings
    # No default is given, so ``args.local_rank`` is None unless the flag is passed.
    parser.add_argument("--local_rank",
                        type=int, help='local rank for DistributedDataParallel')
    parser.add_argument('--output-dir',
                        default=None, help='path where to save checkpoint weights')
    parser.add_argument('--pin_mem',
                        action='store_true',
                        help='If true, pin memory when using the data loader.')
    parser.add_argument('--print-freq',
                        default=10, type=int, help='print frequency')
    parser.add_argument('--resume',
                        default='', help='resume from checkpoint')
    parser.add_argument('-j', '--workers',
                        default=4, type=int, metavar='N', help='number of data loading workers')
    parser.add_argument('--mix',
                        action='store_true',
                        help='if true, use refcoco/+/g mixed dataset for training.')
    parser.add_argument('--save_visual_dir',
                        type=str, default=None)

    # seed
    parser.add_argument('--seed',
                        type=int, default=0)
    parser.add_argument('--deterministic',
                        action='store_true')

    # MaskRIS hyperparameters
    parser.add_argument('--img_mask_ratio',
                        type=float, default=0.0)
    parser.add_argument('--img_patch_size',
                        type=int, default=32)
    parser.add_argument('--txt_mask_ratio',
                        type=float, default=0.0)
    # The default is a tuple; when given on the command line the two values
    # arrive as a list of floats (nargs=2).
    parser.add_argument('--txt_mask_ratio_sub',
                        type=float, default=(0.8, 0.2), nargs=2)

    return parser
if __name__ == '__main__':
    # Script entry point: build the CLI parser, then parse the process's
    # command-line arguments (args=None makes argparse read sys.argv[1:]).
    # Despite its name, ``args_dict`` is an argparse.Namespace, not a dict.
    parser = get_parser()
    args_dict = parser.parse_args(args=None)