-
Notifications
You must be signed in to change notification settings - Fork 6
/
main.py
72 lines (63 loc) · 3.29 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import logging
import argparse
import torch
from lib.config import Config
from lib.runner import Runner
from lib.experiment import Experiment
def parse_args(argv=None):
    """Parse and validate the command-line arguments of the lane detector.

    Args:
        argv: Optional list of argument strings to parse instead of
            ``sys.argv[1:]``. ``None`` (the default) keeps the usual
            command-line behaviour.

    Returns:
        The ``argparse.Namespace`` holding ``mode``, ``exp_name``, ``cfg``,
        ``resume``, ``epoch``, ``cpu``, ``save_predictions``, ``view`` and
        ``deterministic``.

    Raises:
        Exception: if ``mode`` is "train" and no ``--cfg`` was given, if
            ``--resume`` is combined with "test" mode, if ``--epoch`` is
            combined with "train" mode, or if ``--cpu`` is requested.
    """
    parser = argparse.ArgumentParser(description="Train lane detector")
    parser.add_argument("mode", help="Train or test?")
    parser.add_argument("--exp_name", help="Experiment name", required=True)
    parser.add_argument("--cfg", help="Config file")
    parser.add_argument("--resume", action="store_true", help="Resume training")
    parser.add_argument("--epoch", type=int, help="Epoch to test the model on")
    parser.add_argument("--cpu", action="store_true", help="(Unsupported) Use CPU instead of GPU")
    parser.add_argument("--save_predictions", action="store_true", help="Save predictions to pickle file")
    parser.add_argument("--view", choices=["all", "mistakes", "corrects"], help="Show predictions")
    parser.add_argument("--deterministic",
                        action="store_true",
                        help="set cudnn.deterministic = True and cudnn.benchmark = False")
    # argv=None makes argparse fall back to sys.argv[1:].
    args = parser.parse_args(argv)

    # Cross-argument validation that argparse cannot express on its own.
    if args.cfg is None and args.mode == "train":
        raise Exception("If you are training, you have to set a config file using --cfg /path/to/your/config.yaml")
    if args.resume and args.mode == "test":
        raise Exception("args.resume is set on `test` mode: can't resume testing")
    if args.epoch is not None and args.mode == 'train':
        raise Exception("The `epoch` parameter should not be set when training")
    # NOTE: --view is currently accepted in every mode; the check that
    # restricted visualization to the "test" mode is disabled.
    if args.cpu:
        raise Exception("CPU training/testing is not supported: the NMS procedure is only implemented for CUDA")

    return args
def main():
    """Run the experiment selected on the command line.

    Parses the arguments, builds the ``Experiment``, ``Config`` and
    ``Runner`` objects, then dispatches on ``args.mode``:

    * a mode containing "train" trains, then evaluates (a
      ``KeyboardInterrupt`` during training is logged and evaluation
      still runs);
    * a mode containing "test" evaluates only;
    * "fps" calls ``runner.fps`` on the "train_10000" split;
    * "draw" calls ``runner.draw`` on the "test_100" split;
    * any other mode calls ``runner.eval_by_class`` without saving
      predictions.

    The epoch passed to the runner is ``--epoch`` when given, otherwise
    the experiment's last checkpoint epoch.
    """
    args = parse_args()  # read the command line
    # Initialise the experiment's path-related parameters.
    exp = Experiment(args.exp_name, args, mode=args.mode)
    # Fall back to the experiment's own config path when --cfg is not given.
    cfg_path = exp.cfg_path if args.cfg is None else args.cfg
    cfg = Config(cfg_path)  # load the selected config file
    exp.set_cfg(cfg, override=False)
    device = torch.device('cpu') if not torch.cuda.is_available() or args.cpu else torch.device('cuda')

    # The runner receives the mode as its split name, except for the two
    # modes that map onto dedicated splits.
    split = args.mode
    if args.mode == 'fps':
        split = 'train_10000'
    elif args.mode == 'draw':
        split = 'test_100'
    runner = Runner(cfg, exp, device, args, view=args.view, resume=args.resume, deterministic=args.deterministic, mode=split)

    def resolve_epoch():
        # Use an explicit `is not None` test so that `--epoch 0` is honoured
        # instead of being treated as "not given". Kept as a function so
        # that, in train mode, the checkpoint lookup happens only after
        # training has run.
        if args.epoch is not None:
            return args.epoch
        return exp.get_last_checkpoint_epoch()

    if 'train' in args.mode:
        try:
            runner.train()
        except KeyboardInterrupt:
            logging.info('Training interrupted.')
        runner.eval(epoch=resolve_epoch(), save_predictions=args.save_predictions)
    elif 'test' in args.mode:
        runner.eval(epoch=resolve_epoch(), save_predictions=args.save_predictions)
    elif args.mode == 'fps':
        runner.fps(epoch=resolve_epoch(), save_predictions=args.save_predictions)
    elif args.mode == 'draw':
        runner.draw(epoch=resolve_epoch(), save_predictions=args.save_predictions)
    else:
        runner.eval_by_class(epoch=resolve_epoch(), save_predictions=False)
# Script entry point: only run the experiment when executed directly,
# not when this module is imported.
if __name__ == "__main__":
    main()