-
Notifications
You must be signed in to change notification settings - Fork 17
/
run.py
executable file
·63 lines (41 loc) · 1.59 KB
/
run.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import os
import json
import torch
import torch.distributed as dist
from utils.args import get_args,logging_cfgs
from utils.initialize import initialize
from utils.build_model import build_model
from utils.build_optimizer import build_optimizer
from utils.build_dataloader import create_train_dataloaders, create_val_dataloaders
from utils.pipeline import train, test
def main():
    """Parse the run configuration and dispatch to training or testing.

    The behaviour is selected by ``args.run_cfg.mode``:

    * ``'training'``: build the train and val dataloaders, then the model and
      its optimizer. Optionally evaluate first (``run_cfg.first_eval``) or
      evaluate only and stop (``run_cfg.zero_shot``). Otherwise call
      ``train``.
    * ``'testing'``: build the model and the val dataloaders, then call
      ``test``.

    Raises:
        NotImplementedError: if ``args.run_cfg.mode`` is neither
            ``'training'`` nor ``'testing'``.
    """
    ### init
    args = get_args()
    initialize(args)
    ### logging cfgs
    logging_cfgs(args)

    mode = args.run_cfg.mode
    if mode == 'training':
        ### create datasets and dataloader
        train_loader = create_train_dataloaders(args)
        val_loaders = create_val_dataloaders(args)
        ### build model and optimizer
        # build_model also hands back optimizer state and a starting step,
        # presumably restored from a checkpoint (TODO: confirm in
        # utils.build_model). Both are forwarded so training can resume.
        model, optimizer_ckpt, start_step = build_model(args)
        optimizer = build_optimizer(model, args, optimizer_ckpt)
        ### start evaluation
        # Evaluate before training on request. A zero-shot run evaluates only
        # and returns without entering the training loop. NOTE: the train
        # loader and optimizer are still built in that case.
        if args.run_cfg.first_eval or args.run_cfg.zero_shot:
            test(model, val_loaders, args.run_cfg)
        if args.run_cfg.zero_shot:
            return
        ### start training
        train(model, optimizer, train_loader, val_loaders, args.run_cfg,
              start_step=start_step, verbose_time=False)
    elif mode == 'testing':
        ### build model
        # Only the model is needed for evaluation; the optimizer state and
        # the start step are discarded.
        model, _, _ = build_model(args)
        ### create datasets and dataloader
        val_loaders = create_val_dataloaders(args)
        ### start evaluation
        test(model, val_loaders, args.run_cfg)
    else:
        # Name the offending value so a config typo is obvious from the
        # traceback alone.
        raise NotImplementedError(
            f"Unsupported run_cfg.mode {mode!r}; "
            "expected 'training' or 'testing'"
        )
# Run the pipeline only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()