-
Notifications
You must be signed in to change notification settings - Fork 12
/
main.py
99 lines (63 loc) · 2.11 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
import os
import cv2
import random
import numpy as np
import torch
import argparse
from shutil import copyfile
from src.config import Config
from src.misf import MISF
import torch.nn as nn
def main(mode=None):
    """Load the configuration, prepare the runtime, then train or test the model.

    Args:
        mode: Optional run-mode override forwarded to ``load_config``:
            1 selects training, 2 selects testing, and any other value
            (including ``None``) keeps the ``MODE`` loaded from the config.
            If the resulting ``config.MODE`` is neither 1 nor 2, the model is
            built and loaded but neither trained nor tested.
    """
    config = load_config(mode)

    # Restrict CUDA to the GPUs listed in the config. This only takes effect
    # if CUDA has not been initialised yet in this process. From here on the
    # listed GPUs are addressed as logical devices 0..len(config.GPU)-1.
    os.environ['CUDA_VISIBLE_DEVICES'] = ','.join(str(e) for e in config.GPU)

    # Pick the compute device.
    if torch.cuda.is_available():
        config.DEVICE = torch.device("cuda")
        torch.backends.cudnn.benchmark = True  # let cuDNN auto-tune conv algorithms
    else:
        config.DEVICE = torch.device("cpu")

    # 0 disables OpenCV's internal threading (functions run sequentially);
    # intended to avoid deadlocks with PyTorch DataLoader worker processes.
    cv2.setNumThreads(0)

    _seed_everything(config.SEED)

    # Build the model and load its weights.
    model = MISF(config)
    model.load()

    # The iteration counter is saved and restored around the DataParallel
    # wrapping (kept from the original code; whether the wrapping can affect
    # it depends on the inpaint model's implementation, not visible here).
    iteration = model.inpaint_model.iteration
    if len(config.GPU) > 1:
        print('GPU:{}'.format(config.GPU))
        _wrap_data_parallel(model.inpaint_model, len(config.GPU))
    model.inpaint_model.iteration = iteration

    if config.MODE == 1:
        print('\nstart training...\n')
        model.train()
    elif config.MODE == 2:
        print('\nstart testing...\n')
        model.test()


def _seed_everything(seed):
    """Seed the torch (CPU and every CUDA device), NumPy and ``random`` RNGs."""
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)


def _wrap_data_parallel(inpaint_model, num_gpus):
    """Wrap ``inpaint_model.generator`` and ``.discriminator`` in ``nn.DataParallel``.

    The device ids are the logical indices ``0..num_gpus-1``. Once
    CUDA_VISIBLE_DEVICES has been set to the configured GPUs, only those
    logical ids exist. Passing the physical ids (e.g. ``[2, 3]``) would name
    devices that are missing or that do not hold the model's parameters.
    """
    device_ids = list(range(num_gpus))
    inpaint_model.generator = nn.DataParallel(inpaint_model.generator, device_ids)
    inpaint_model.discriminator = nn.DataParallel(inpaint_model.discriminator, device_ids)
def load_config(mode=None):
    """Parse the command line and load ``config.yml`` from the checkpoints directory.

    The ``--path`` / ``--checkpoints`` option (default ``./checkpoints``) names
    the directory. It is created if missing, and ``<path>/config.yml`` is
    handed to ``Config``.

    Args:
        mode: 1 forces training mode, 2 forces testing mode; any other value
            (including ``None``) keeps the ``MODE`` that ``Config`` loaded.

    Returns:
        The ``Config`` instance, with ``MODE`` possibly overridden.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--path', '--checkpoints', type=str, default='./checkpoints', help='model checkpoints path (default: ./checkpoints)')
    args = parser.parse_args()
    config_path = os.path.join(args.path, 'config.yml')

    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    # NOTE(review): a freshly created directory contains no config.yml, so
    # Config(config_path) presumably fails in that case; `copyfile` is
    # imported by this file but never used; confirm whether a config
    # template was meant to be copied here.
    os.makedirs(args.path, exist_ok=True)

    config = Config(config_path)

    # Only the two known modes (1 = train, 2 = test) override the loaded value.
    if mode in (1, 2):
        config.MODE = mode
    return config
if __name__ == "__main__":
    # Script entry point: no mode override, so MODE comes from the loaded config.
    main(mode=None)