-
Notifications
You must be signed in to change notification settings - Fork 109
/
main.py
75 lines (62 loc) · 2.12 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
import os
import cv2
import torch
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import argparse
import numpy as np
import random
from clrnet.utils.config import Config
from clrnet.engine.runner import Runner
from clrnet.datasets import build_dataloader
def main():
args = parse_args()
os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(
str(gpu) for gpu in args.gpus)
cfg = Config.fromfile(args.config)
cfg.gpus = len(args.gpus)
cfg.load_from = args.load_from
cfg.resume_from = args.resume_from
cfg.finetune_from = args.finetune_from
cfg.view = args.view
cfg.seed = args.seed
cfg.work_dirs = args.work_dirs if args.work_dirs else cfg.work_dirs
cudnn.benchmark = True
runner = Runner(cfg)
if args.validate:
runner.validate()
elif args.test:
runner.test()
else:
runner.train()
def parse_args():
parser = argparse.ArgumentParser(description='Train a detector')
parser.add_argument('config', help='train config file path')
parser.add_argument('--work_dirs',
type=str,
default=None,
help='work dirs')
parser.add_argument('--load_from',
default=None,
help='the checkpoint file to load from')
parser.add_argument('--resume_from',
default=None,
help='the checkpoint file to resume from')
parser.add_argument('--finetune_from',
default=None,
help='the checkpoint file to resume from')
parser.add_argument('--view', action='store_true', help='whether to view')
parser.add_argument(
'--validate',
action='store_true',
help='whether to evaluate the checkpoint during training')
parser.add_argument(
'--test',
action='store_true',
help='whether to test the checkpoint on testing set')
parser.add_argument('--gpus', nargs='+', type=int, default='0')
parser.add_argument('--seed', type=int, default=0, help='random seed')
args = parser.parse_args()
return args
if __name__ == '__main__':
main()