-
Notifications
You must be signed in to change notification settings - Fork 5
/
train_main.py
47 lines (40 loc) · 1.36 KB
/
train_main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
from __future__ import print_function
import os
import torch
import train_options
from train_Processor import Processor
from utils.utils import write_settings_to_file
def main():
    """Parse training options, prepare output folders and run the Processor.

    Steps: parse CLI args, validate ``run_type``, restrict visible GPUs,
    seed torch, create the checkpoint/log folders for this dataset and
    model id, optionally dump the settings, then hand off to
    ``Processor(args).processing()``.

    Raises:
        ValueError: if ``args.run_type`` is not one of 0, 1, 2, 3.
    """
    args = train_options.parser.parse_args()
    # Validate explicitly: an ``assert`` is stripped when running ``python -O``.
    if args.run_type not in (0, 1, 2, 3):
        raise ValueError(
            'run_type must be one of 0, 1, 2, 3, got {!r}'.format(args.run_type))
    # Restrict visible GPUs before any torch call: CUDA reads
    # CUDA_VISIBLE_DEVICES only once, when it is first initialized.
    args.gpu_ids = visible_gpu(args.gpus)
    # Fix the random seed for reproducible results.
    torch.manual_seed(args.seed)
    # Create the checkpoint and log folders (plus their roots).
    # exist_ok=True avoids the check-then-create race of os.path.exists().
    out_dir = os.path.join('./ckpt/', args.dataset_name, str(args.model_id))
    log_dir = os.path.join('./logs/', args.dataset_name, str(args.model_id))
    for folder in ('./ckpt/', './logs/', out_dir, log_dir):
        os.makedirs(folder, exist_ok=True)
    # Grouping made explicit; it is identical to the original precedence
    # (``and`` binds tighter than ``or``).
    # NOTE(review): confirm run_type 0 should always write settings, while
    # run_type 1 writes them only when not pretrained. ``== False`` is kept
    # deliberately because the type of ``args.pretrained`` is not visible here.
    if args.run_type == 0 or (args.run_type == 1 and args.pretrained == False):  # noqa: E712
        write_settings_to_file(args)
    processor = Processor(args)
    processor.processing()
def visible_gpu(gpus):
    """Restrict CUDA to the given physical GPU ids via CUDA_VISIBLE_DEVICES.

    Args:
        gpus: a single int GPU id, or an iterable of GPU ids. Any non-int
            value is iterated element by element.
            NOTE(review): a str such as "0,1" is therefore split per
            character; confirm callers pass ints or a list of ints.

    Returns:
        list[int]: ``[0, 1, ..., len(ids) - 1]``, the indices the selected
        devices are enumerated as once the environment variable is applied.

    The variable only takes effect if CUDA has not yet been initialized in
    this process.
    """
    if isinstance(gpus, int):
        device_ids = [gpus]
    else:
        device_ids = [gpu for gpu in gpus]

    os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(str(gpu) for gpu in device_ids)

    return [idx for idx, _ in enumerate(device_ids)]
if __name__ == '__main__':
    # Start training only when run as a script, not when this module is imported.
    main()