generated from songquanpeng/pytorch-template
-
Notifications
You must be signed in to change notification settings - Fork 1
/
main.py
29 lines (25 loc) · 818 Bytes
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
from munch import Munch
from config import load_cfg, save_cfg, print_cfg
from utils.misc import setup, validate
from solver.solver import Solver
from data.loader import get_train_loader, get_test_loader, get_selected_loader
def main(args):
    """Run the solver in the mode selected by ``args.mode``.

    Calls ``setup`` and ``validate`` on the config, builds a ``Solver``, then
    dispatches on the mode:

    * ``'train'``  -- builds train/test loaders (plus a ``selected`` loader
      when ``args.selected_path`` is truthy) and calls ``solver.train``.
    * ``'sample'`` -- calls ``solver.sample()``.
    * ``'eval'``   -- calls ``solver.evaluate()``.

    Args:
        args: Run configuration. It is read by attribute (``args.mode``,
            ``args.selected_path``) and also unpacked with ``**args`` into
            the loader factories, so it must be a string-keyed mapping that
            also supports attribute access -- presumably the ``Munch``
            returned by ``config.load_cfg``; TODO confirm.

    Raises:
        NotImplementedError: If ``args.mode`` is not one of the modes above.
    """
    setup(args)
    validate(args)
    solver = Solver(args)
    if args.mode == 'train':
        # Each loader factory receives the whole config as keyword arguments.
        loaders = Munch(train=get_train_loader(**args),
                        test=get_test_loader(**args))
        # The "selected" loader is optional; only built when a path is set.
        if args.selected_path:
            loaders.selected = get_selected_loader(**args)
        solver.train(loaders)
    elif args.mode == 'sample':
        solver.sample()
    elif args.mode == 'eval':
        solver.evaluate()
    else:
        # Explicit raise instead of ``assert False``: asserts are stripped
        # under ``python -O``, which would make a bad mode a silent no-op.
        raise NotImplementedError(f"Unimplemented mode: {args.mode}")
if __name__ == '__main__':
    # Script entry point: load the run configuration, persist it, echo it
    # for the log, then hand it to main(). save/print happen before the
    # run starts.
    run_config = load_cfg()
    save_cfg(run_config)
    print_cfg(run_config)
    main(run_config)