forked from mattans/AgeProgression
-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.py
157 lines (130 loc) · 6.01 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
import model
import consts
import logging
import os
import re
import glob
import numpy as np
import argparse
import sys
import random
import datetime
import torch
from utils import *
from torchvision.datasets.folder import pil_loader
import gc
import torch
# Run a garbage-collection pass when the module is loaded. Only imports have
# executed at this point, so this presumably frees little -- NOTE(review):
# confirm it is still needed.
gc.collect()
# Intended runtime requirements (the checks below are commented out, so they
# are NOT enforced): Python >= 3.6 (TODO: 3.7?) and PyTorch >= 0.4.0
# (TODO: 0.4.1?). If the PyTorch check is ever re-enabled, note that int()
# cannot parse version components with local suffixes such as "0+cu118".
#assert sys.version_info >= (3, 6),\
# "This script requires Python >= 3.6" # TODO 3.7?
#assert tuple(int(ver_num) for ver_num in torch.__version__.split('.')) >= (0, 4, 0),\
# "This script requires PyTorch >= 0.4.0" # TODO 0.4.1?
class _GenderParseError(KeyError, ValueError):
    """Raised by ``str_to_gender`` for an unrecognised gender token.

    Subclasses ``KeyError`` so existing ``except KeyError`` callers keep
    working, and ``ValueError`` so that argparse turns it into a clean usage
    error (instead of an uncaught traceback) when ``str_to_gender`` is used as
    an argument ``type``.
    """


def str_to_gender(s):
    """Convert a gender token to the integer label used by this script.

    Matching is case-insensitive and ignores surrounding whitespace; non-string
    inputs (e.g. the ints ``0`` / ``1``) are converted with ``str`` first.

    Args:
        s: ``'m'``, ``'male'``, ``'man'`` or ``0`` for male;
           ``'f'``, ``'female'``, ``'woman'`` or ``1`` for female.

    Returns:
        ``0`` for male, ``1`` for female.

    Raises:
        KeyError: if ``s`` is not a recognised token (the exception is also a
            ``ValueError``, which argparse reports as a usage error).
    """
    token = str(s).strip().lower()
    # 'male' / 'woman' added so both genders accept the same kinds of spelling.
    if token in ('m', 'male', 'man', '0'):
        return 0
    if token in ('f', 'female', 'woman', '1'):
        return 1
    raise _GenderParseError("No gender found")
class _BoolParseError(KeyError, ValueError):
    """Raised by ``str_to_bool`` for an unrecognised boolean token.

    ``KeyError`` keeps compatibility with existing callers; ``ValueError`` lets
    argparse report a usage error when ``str_to_bool`` is an argument ``type``.
    """


def str_to_bool(s):
    """Parse a command-line style boolean token.

    Matching is case-insensitive and ignores surrounding whitespace; non-string
    inputs are converted with ``str`` first (so ``True`` / ``1`` also work).

    Args:
        s: ``'true'``, ``'t'``, ``'yes'``, ``'y'``, ``'1'`` for True;
           ``'false'``, ``'f'``, ``'no'``, ``'n'``, ``'0'`` for False.

    Returns:
        The parsed ``bool``.

    Raises:
        KeyError: if ``s`` is not a recognised token (the exception is also a
            ``ValueError``, which argparse reports as a usage error).
    """
    token = str(s).strip().lower()
    if token in ('true', 't', 'yes', 'y', '1'):
        return True
    # '0' was missing: the original list had the letter 'o', an apparent typo
    # for the digit. 'o' is still accepted so previously valid input works.
    if token in ('false', 'f', 'no', 'n', '0', 'o'):
        return False
    raise _BoolParseError("Invalid boolean")
def _build_arg_parser():
    """Return the argparse parser for the AgeProgression train/test CLI."""
    parser = argparse.ArgumentParser(description='AgeProgression on PyTorch.', formatter_class=argparse.RawTextHelpFormatter)
    parser.add_argument('--mode', choices=['train', 'test'], default='train')

    # train params
    parser.add_argument('--epochs', '-e', default=1, type=int)
    parser.add_argument(
        '--models-saving',
        '--ms',
        dest='models_saving',
        choices=('always', 'last', 'tail', 'never'),
        default='always',
        type=str,
        help='Model saving preference.{br}'
             '\talways: Save trained model at the end of every epoch (default){br}'
             '\tUse this option if you have a lot of free memory and you wish to experiment with the progress of your results.{br}'
             '\tlast: Save trained model only at the end of the last epoch{br}'
             '\tUse this option if you don\'t have a lot of free memory and removing large binary files is a costly operation.{br}'
             '\ttail: "Safe-last". Save trained model at the end of every epoch and remove the saved model of the previous epoch{br}'
             '\tUse this option if you don\'t have a lot of free memory and removing large binary files is a cheap operation.{br}'
             '\tnever: Don\'t save trained model{br}'
             '\tUse this option if you only wish to collect statistics and validation results.{br}'
             'All options except \'never\' will also save when interrupted by the user.'.format(br=os.linesep)
    )
    parser.add_argument('--batch-size', '--bs', dest='batch_size', default=64, type=int)
    parser.add_argument('--weight-decay', '--wd', dest='weight_decay', default=1e-5, type=float)
    parser.add_argument('--learning-rate', '--lr', dest='learning_rate', default=2e-4, type=float)
    parser.add_argument('--b1', '-b', dest='b1', default=0.5, type=float)
    parser.add_argument('--b2', '-B', dest='b2', default=0.999, type=float)
    # Parsed with str_to_bool: the previous type=bool turned any non-empty
    # value, including "False" or "0", into True.
    parser.add_argument('--shouldplot', '--sp', dest='sp', default=False, type=str_to_bool)

    # test params
    parser.add_argument('--age', '-a', required=False, type=int)
    parser.add_argument('--gender', '-g', required=False, type=str_to_gender)
    parser.add_argument('--watermark', '-w', action='store_true')

    # shared params
    parser.add_argument('--cpu', '-c', action='store_true', help='Run on CPU even if CUDA is available.')
    parser.add_argument('--load', '-l', required=False, default=None, help='Trained models path for pre-training or for testing')
    # The training default really is consts.UTKFACE_DEFAULT_PATH (see
    # _run_train); the help used to advertise the results directory instead.
    parser.add_argument('--input', '-i', default=None, help='Training dataset path (default is {}) or testing image path (a .jpg file or a directory of .jpg files)'.format(consts.UTKFACE_DEFAULT_PATH))
    parser.add_argument('--output', '-o', default='')
    parser.add_argument('-z', dest='z_channels', default=50, type=int, help='Length of Z vector')
    return parser


def _run_train(net, args):
    """Train ``net`` per ``args``, writing arguments, logs and models to the results dir."""
    if args.load is None:
        betas = (args.b1, args.b2)
        weight_decay = args.weight_decay
        lr = args.learning_rate
    else:
        # When resuming, the optimizer hyper-parameters are passed as None.
        # NOTE(review): presumably Net.teach then keeps the loaded/default
        # values -- confirm in model.Net.teach.
        betas = weight_decay = lr = None
        print("Loading pre-trained models from {}".format(args.load))
        net.load(args.load)

    data_src = args.input or consts.UTKFACE_DEFAULT_PATH
    print("Data folder is {}".format(data_src))
    results_dest = args.output or default_train_results_dir()
    os.makedirs(results_dest, exist_ok=True)
    print("Results folder is {}".format(results_dest))

    # Record the exact command line so the session can be reproduced.
    with open(os.path.join(results_dest, 'session_arguments.txt'), 'w', encoding='utf-8') as info_file:
        info_file.write(' '.join(sys.argv))

    # Start every session with a fresh log file.
    log_path = os.path.join(results_dest, 'log_results.log')
    if os.path.exists(log_path):
        os.remove(log_path)
    logging.basicConfig(filename=log_path, level=logging.DEBUG)

    net.teach(
        utkface_path=data_src,
        batch_size=args.batch_size,
        betas=betas,
        epochs=args.epochs,
        weight_decay=weight_decay,
        lr=lr,
        should_plot=args.sp,
        where_to_save=results_dest,
        models_saving=args.models_saving
    )


def _collect_test_images(input_path):
    """Return the .jpg path(s) to test: the file itself, or a directory's .jpg files sorted."""
    if os.path.isfile(input_path):
        return [input_path]
    return sorted(glob.glob(os.path.join(input_path, '*.jpg')))


def _run_test(net, args):
    """Load trained models and run ``net.test_single`` on each input image."""
    if args.load is None:
        raise RuntimeError("Must provide path of trained models")
    # Validate before loading models; previously a missing --input crashed
    # with a TypeError on ``None + "/*.jpg"``.
    if args.input is None:
        raise RuntimeError("Must provide path of test image(s) with --input")

    net.load(path=args.load, slim=True)
    results_dest = args.output or default_test_results_dir()
    os.makedirs(results_dest, exist_ok=True)

    input_images = _collect_test_images(args.input)
    if not input_images:
        print("No .jpg images found at {}".format(args.input))
    for filename in input_images:
        image_tensor = pil_to_model_tensor_transform(pil_loader(filename)).to(net.device)
        net.test_single(
            image_tensor=image_tensor,
            age=args.age,
            gender=args.gender,
            target=results_dest,
            watermark=args.watermark,
            filename=os.path.basename(filename)
        )


def main():
    """Parse the command line, build the network and dispatch to train or test."""
    args = _build_arg_parser().parse_args()
    # Set before model.Net() is constructed. NOTE(review): presumably Net reads
    # this constant at construction time -- confirm in model.py.
    consts.NUM_Z_CHANNELS = args.z_channels
    net = model.Net()
    if not args.cpu and torch.cuda.is_available():
        print("Using CUDA")
        net.cuda()

    if args.mode == 'train':
        _run_train(net, args)
    elif args.mode == 'test':
        _run_test(net, args)


if __name__ == '__main__':
    main()