-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathpostpretrain.py
362 lines (318 loc) · 11.3 KB
/
postpretrain.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
"""PyTorch Lightning training script for the midtraining task with weighted contrastive loss."""
import os
import torch
from torch.utils.data import DataLoader, Subset
import pytorch_lightning as pl
from pytorch_lightning import loggers as pl_loggers
from pytorch_lightning.callbacks import ModelCheckpoint
from termcolor import colored
from external.utils_videoclip import load_videoclip_model
import package.datasets as datasets
import package.models as models
from package.utils.log import print_update
import warnings
# Silence every Python warning for the whole process so that library
# warnings do not clutter the training logs.
warnings.filterwarnings("ignore")
# Set before any tokenizer is created. Presumably read by the HuggingFace
# `tokenizers` library to turn off its internal parallelism (commonly done
# when DataLoader workers fork) -- TODO confirm against the tokenizer in use.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
def check_args(args):
    """Validate and normalise the parsed command-line arguments.

    Args:
        args: Parsed arguments namespace. Reads ``only_eval``, ``only_train``,
            ``dataset``, ``eval_split``, ``eval_subset`` and ``gpus``.

    Returns:
        The same ``args`` object, updated in place.

    Raises:
        AssertionError: If the evaluation split/subset is not supported for
            the selected dataset.
    """
    # The two "only" modes are mutually exclusive; when both flags are
    # given, evaluation-only wins because it is checked first.
    if args.only_eval:
        args.only_train = False
    if args.only_train:
        args.only_eval = False

    if args.dataset == "synthetic":
        assert args.eval_split == "test", (
            f"Not a valid split(={args.eval_split}) for synthetic dataset."
        )

    if args.dataset == "synthetic" and args.eval_split == "test":
        assert args.eval_subset in ["v2.0"], (
            f"Not a valid subset(={args.eval_subset}) for synthetic dataset. "
            "Only --eval_subset v2.0 is supported."
        )

    if args.dataset == "tempo" and args.eval_split in ["val", "test"]:
        assert args.eval_subset in ["temporal_1k"], (
            f"Not a valid subset(={args.eval_subset}) for tempo dataset. "
            "Only --eval_subset temporal_1k is supported."
        )

    # Default to every visible CUDA device; stay on None (CPU) otherwise.
    if args.gpus is None and torch.cuda.is_available():
        args.gpus = torch.cuda.device_count()

    return args
def update_config(config, args):
    """Copy the run-time hyper-parameters from ``args`` onto ``config``.

    Args:
        config: Model configuration object; attributes are set on it in place.
        args: Parsed command-line arguments.

    Returns:
        The same ``config`` object that was passed in.
    """
    config.lr = args.lr
    # The contrastive weight is pinned to 1.0 here; it is not read from args.
    config.contrastive_lambda = 1.0

    # (attribute on config, attribute on args) -- copied in this order.
    copied_fields = (
        ("batch_size", "batch_size"),
        ("epoch", "epochs"),
        ("freeze_layers", "freeze_layers"),
        ("video_freeze_layers", "video_freeze_layers"),
        ("text_freeze_layers", "text_freeze_layers"),
        ("alpha_same", "alpha_same"),
        ("alpha_cross", "alpha_cross"),
        ("beta", "beta"),
    )
    for config_field, args_field in copied_fields:
        setattr(config, config_field, getattr(args, args_field))

    return config
def freeze_required_layers(model, args):
    """Turn off gradients for the model parts selected through ``args``.

    The first ``args.video_freeze_layers`` / ``args.text_freeze_layers``
    encoder layers and both word-embedding tables are always frozen. The
    video MLP, the poolers and the positional embeddings are frozen only
    when the matching flag is set.

    Args:
        model: Model exposing ``video_encoder`` (with ``bert`` and
            ``videomlp``) and ``text_encoder`` sub-modules.
        args: Parsed command-line arguments.

    Returns:
        The same ``model`` with ``requires_grad`` updated in place.
    """
    video_bert = model.video_encoder.bert
    text_encoder = model.text_encoder

    # Modules that are frozen regardless of the optional flags.
    frozen = [
        video_bert.encoder.layer[:args.video_freeze_layers],
        text_encoder.encoder.layer[:args.text_freeze_layers],
        video_bert.embeddings.word_embeddings,
        text_encoder.embeddings.word_embeddings,
    ]

    if args.freeze_videomlp:
        print(">>> Freezing video MLP")
        frozen.append(model.video_encoder.videomlp)

    if args.freeze_pooler:
        print(">>> Freezing pooler for video/text")
        frozen.append(video_bert.pooler)
        frozen.append(text_encoder.pooler)

    if args.freeze_pos_emb:
        print("\n>>> Freezing positional embeddings")
        frozen.append(video_bert.embeddings.position_embeddings)
        frozen.append(text_encoder.embeddings.position_embeddings)

    if args.remove_pos_emb:
        print(">>> Removing positional embeddings")
        # Overwrite the positional tables with zeros. They are NOT added to
        # the frozen list here, so they stay trainable unless
        # args.freeze_pos_emb is set as well.
        for embeddings in (video_bert.embeddings, text_encoder.embeddings):
            embeddings.position_embeddings.weight.data.fill_(0)
        print(">>> Positional embeddings set to 0")
        print(video_bert.embeddings.position_embeddings.weight.data)

    for frozen_module in frozen:
        for weight in frozen_module.parameters():
            weight.requires_grad = False

    # Sanity check: list everything that will still receive gradients.
    print(">>> Parameters to train:")
    for param_name, tensor in model.named_parameters():
        if not tensor.requires_grad:
            continue
        print(param_name, tensor.shape)

    return model
if __name__ == "__main__":
    # Script entry point: parse the CLI arguments, build the datasets and
    # dataloaders, load the VideoCLIP model, then run validation and/or
    # training depending on --only_eval / --only_train.

    # read arguments
    import argparse
    parser = argparse.ArgumentParser("Train a model")

    # Model args
    parser.add_argument(
        "--model", type=str, default="videoclip",
        choices=["videoclip"],
    )
    parser.add_argument(
        "--config", type=str,
        default="external/fairseq/examples/MMPT/"
        "projects/retri/videoclip/test_vtt_zs.yaml",
        help="Path to config file",
    )
    parser.add_argument(
        "--freeze_videomlp", action="store_true",
        help="Freeze video MLP",
    )
    parser.add_argument(
        "--freeze_pooler", action="store_true",
        help="Freeze pooler for both video/text",
    )
    parser.add_argument(
        "--freeze_pos_emb", action="store_true",
        help="Freeze positional embeddings or not",
    )
    parser.add_argument(
        "--freeze_layers", type=int, default=5,
        help="Number of layers to freeze",
    )
    parser.add_argument(
        "--video_freeze_layers", type=int, default=5,
        help="Number of layers to freeze in video transformer",
    )
    parser.add_argument(
        "--text_freeze_layers", type=int, default=5,
        help="Number of layers to freeze in text transformer",
    )
    parser.add_argument(
        "--alpha_same", type=float, default=1.0,
        help="Alpha for same-sample time-reversal",
    )
    parser.add_argument(
        "--alpha_cross", type=float, default=1.0,
        help="Alpha for cross-sample time-reversal",
    )
    parser.add_argument(
        "--beta", type=float, default=1.0,
        help="Beta for the contrastive loss",
    )
    parser.add_argument(
        "--remove_pos_emb", action="store_true",
        help="Remove positional embeddings at all",
    )
    # NOTE: despite the help text, the checkpoint (when given) is loaded
    # before both evaluation and training (see step 3.D below).
    parser.add_argument(
        "-c", "--ckpt_path", type=str,
        default=None,
        help="Path to checkpoint (only used for evaluation)",
    )

    # Dataset args
    parser.add_argument(
        "--dataset", type=str, default="tempo",
        help="Dataset name", choices=["tempo", "synthetic"],
    )
    parser.add_argument(
        "--data_root", type=str, required=True,
    )
    parser.add_argument(
        "--eval_split", type=str, default="val",
        help="Split name for validation dataset", choices=["val", "test"],
    )
    parser.add_argument(
        "--eval_subset", type=str, default=None,
        help="Subset name for validation dataset",
    )

    # Optimization and other args
    parser.add_argument(
        "--lr", type=float, default=5.0e-06,
        help="Learning rate",
    )
    parser.add_argument(
        "--gpus", type=int, default=None, nargs="+",
    )
    parser.add_argument(
        "--batch_size", type=int, default=32,
        help="Batch size",
    )
    parser.add_argument(
        "--epochs", type=int, default=40,
        help="Number of epochs",
    )
    parser.add_argument(
        "--no_wandb", action="store_true",
        help="Force not to use wandb",
    )
    parser.add_argument(
        "--overfit_batches", type=float, default=0.0,
        help="Overfit batches",
    )
    parser.add_argument(
        "--suffix", type=str, default="",
        help="Suffix to add to the name of the run",
    )
    parser.add_argument(
        "--save_every", type=int, default=10,
        help="Save every n epochs",
    )
    parser.add_argument(
        "--only_eval", action="store_true",
        help="Only evaluate the model",
    )
    parser.add_argument(
        "--only_train", action="store_true",
        help="Only train the model",
    )
    parser.add_argument(
        "--debug", action="store_true",
        help="Debug mode",
    )
    args = parser.parse_args()
    args = check_args(args)

    # 1. Load the datasets
    # The loader is looked up by name, e.g. `load_tempo_dataset`.
    dataset_load_function = getattr(datasets, f"load_{args.dataset}_dataset")
    dataset_load_args = dict(
        data_root=args.data_root,
    )

    if not args.only_eval:
        print_update(">>> Loading train set")
        # 1.A. Load train set (only if not only_eval)
        dataset_load_args.update(dict(mode="train", subset=None))
        train_dataset = dataset_load_function(**dataset_load_args)

    # 1.B. Load val set
    print_update(">>> Loading validation set")
    dataset_load_args.update(dict(mode=args.eval_split, subset=args.eval_subset))
    valid_dataset = dataset_load_function(**dataset_load_args)

    if args.debug:
        # Shrink the datasets for quick debugging runs. The sizes are clamped
        # so that datasets smaller than the debug budget still work.
        if not args.only_eval:
            # The train set only exists when training is requested.
            train_dataset = Subset(
                train_dataset, range(min(1000, len(train_dataset))),
            )
        # Rebind `valid_dataset` (the name the dataloader below reads) so the
        # debug subset is actually used.
        valid_dataset = Subset(
            valid_dataset, range(min(500, len(valid_dataset))),
        )

    # 1.C. Load the dataloaders
    if not args.only_eval:
        train_dataloader = DataLoader(
            train_dataset,
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=4,
            drop_last=True,
        )
    # NOTE(review): drop_last=True also discards the final partial batch of
    # the validation set -- confirm this is intended for the reported metrics.
    valid_dataloader = DataLoader(
        valid_dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=4,
        drop_last=True,
    )

    # 2. Load the model
    print_update(">>> Loading model")

    # 2.A. Load the base VideoCLIP model
    config, model = load_videoclip_model(
        cfg_path=args.config,
    )
    config = update_config(config, args)
    model = freeze_required_layers(model, args)

    # 2.B. Load PL module
    model_loading_function = getattr(models, "VideoCLIP")
    pl_module = model_loading_function(config, model)

    # 3. Run the experiment (train/eval)
    print_update(">>> Setting up experiment")

    # 3.A. Setup logging/other cosmetics
    log = not args.no_wandb
    logger = None
    if log:
        run_name = f"ppt-{args.model}-{args.dataset}-bs_{args.batch_size}"\
            f"-frozen-lr_{args.lr}-ep{args.epochs}"
        run_name += "-overfit" if args.overfit_batches > 0 else ""
        run_name += "-" + args.suffix
        run_name += "-alpha_same_" + str(args.alpha_same) \
            + "-alpha_cross_" + str(args.alpha_cross) \
            + "-beta_" + str(args.beta)
        print("WARNING: If you need to log to W&B, "
              "you need to change entity & project.")
        logger = pl_loggers.WandbLogger(
            project="test-of-time",
            entity="bpiyush",
            name=run_name,
        )

    callbacks = []
    if not args.only_eval:
        # 3.B. Save a checkpoint every `--save_every` epochs; all of them are
        # kept (save_top_k=-1), plus the last one.
        save_every_k_epochs = ModelCheckpoint(
            every_n_epochs=args.save_every,
            save_top_k=-1,
            save_last=True,
        )
        callbacks.append(save_every_k_epochs)

    # 3.C. Define the trainer
    trainer = pl.Trainer(
        logger=logger,
        gpus=args.gpus,
        max_epochs=args.epochs,
        log_every_n_steps=2,
        callbacks=callbacks,
        overfit_batches=args.overfit_batches,
    )

    # 3.D. Load the checkpoint if required
    if args.ckpt_path is not None:
        print(
            colored(
                f">>> Initializing with checkpoint: {args.ckpt_path}",
                "magenta"
            )
        )
        # NOTE: torch.load unpickles the file, so only load checkpoints that
        # come from a trusted source.
        state_dict = torch.load(args.ckpt_path, map_location='cpu')['state_dict']
        pl_module.load_state_dict(state_dict)

    # 3.E. Run the evaluation (before training)
    if not args.only_train:
        print_update(">>> Running evaluation")
        trainer.validate(pl_module, dataloaders=valid_dataloader)

    # 3.F. Run the training
    if not args.only_eval:
        print_update(">>> Running training")
        trainer.fit(
            model=pl_module,
            train_dataloaders=train_dataloader,
            val_dataloaders=valid_dataloader,
        )