This repository has been archived by the owner on Jan 12, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 64
/
options.py
120 lines (102 loc) · 6.49 KB
/
options.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
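""" Options

Command-line option definitions shared by training and testing.
This script is largely based on junyanz/pytorch-CycleGAN-and-pix2pix.

Provides:
    Options: class that builds the argparse parser and parses the
    train and test options.
"""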
import argparse
import os
import torch
# pylint: disable=C0103,C0301,R0903,W0622
class Options():
    """Options class

    Builds an ``argparse`` parser holding the base and training options and,
    through :meth:`parse`, turns the command line into a namespace, creates
    the experiment output folders and records the options in ``opt.txt``.

    Attributes:
        parser: the ``argparse.ArgumentParser`` holding every option.
        isTrain: flag copied onto the parsed namespace as ``opt.isTrain``.
        opt: the parsed namespace; ``None`` until :meth:`parse` is called.

    Returns:
        [argparse]: argparse containing train and test options
    """

    def __init__(self):
        ##
        #
        self.parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)

        ##
        # Base
        self.parser.add_argument('--dataset', default='cifar10', help='folder | cifar10 | mnist ')
        self.parser.add_argument('--dataroot', default='', help='path to dataset')
        self.parser.add_argument('--path', default='', help='path to the folder or image to be predicted.')
        self.parser.add_argument('--batchsize', type=int, default=64, help='input batch size')
        self.parser.add_argument('--workers', type=int, help='number of data loading workers', default=8)
        self.parser.add_argument('--droplast', action='store_true', default=True, help='Drop last batch size.')
        # '--droplast' is a store_true flag whose default is already True, so
        # by itself it can never be switched off. '--no_droplast' writes False
        # to the same destination. Its own default is suppressed so that the
        # namespace default (True) keeps coming from '--droplast' only.
        self.parser.add_argument('--no_droplast', dest='droplast', action='store_false', default=argparse.SUPPRESS,
                                 help='Keep the last incomplete batch (disables --droplast).')
        self.parser.add_argument('--isize', type=int, default=32, help='input image size.')
        self.parser.add_argument('--nc', type=int, default=3, help='input image channels')
        self.parser.add_argument('--nz', type=int, default=100, help='size of the latent z vector')
        self.parser.add_argument('--ngf', type=int, default=64)
        self.parser.add_argument('--ndf', type=int, default=64)
        self.parser.add_argument('--extralayers', type=int, default=0, help='Number of extra layers on gen and disc')
        self.parser.add_argument('--device', type=str, default='gpu', help='Device: gpu | cpu')
        self.parser.add_argument('--gpu_ids', type=str, default='0', help='gpu ids: e.g. 0 0,1,2, 0,2. use -1 for CPU')
        self.parser.add_argument('--ngpu', type=int, default=1, help='number of GPUs to use')
        self.parser.add_argument('--name', type=str, default='experiment_name', help='name of the experiment')
        self.parser.add_argument('--model', type=str, default='skipganomaly', help='chooses which model to use. ganomaly')
        self.parser.add_argument('--display_server', type=str, default="http://localhost", help='visdom server of the web display')
        self.parser.add_argument('--display_port', type=int, default=8097, help='visdom port of the web display')
        self.parser.add_argument('--display_id', type=int, default=0, help='window id of the web display')
        self.parser.add_argument('--display', action='store_true', help='Use visdom.')
        self.parser.add_argument('--verbose', action='store_true', help='Print the training and model details.')
        self.parser.add_argument('--outf', default='./output', help='folder to output images and model checkpoints')
        self.parser.add_argument('--manualseed', default=-1, type=int, help='manual seed')
        self.parser.add_argument('--abnormal_class', default='automobile', help='Anomaly class idx for mnist and cifar datasets')
        self.parser.add_argument('--metric', type=str, default='roc', help='Evaluation metric.')

        ##
        # Train
        self.parser.add_argument('--print_freq', type=int, default=100, help='frequency of showing training results on console')
        self.parser.add_argument('--save_image_freq', type=int, default=100, help='frequency of saving real and fake images')
        self.parser.add_argument('--save_test_images', action='store_true', help='Save test images for demo.')
        self.parser.add_argument('--load_weights', action='store_true', help='Load the pretrained weights')
        self.parser.add_argument('--resume', default='', help="path to checkpoints (to continue training)")
        self.parser.add_argument('--phase', type=str, default='train', help='train, val, test, etc')
        self.parser.add_argument('--iter', type=int, default=0, help='Start from iteration i')
        self.parser.add_argument('--niter', type=int, default=15, help='number of epochs to train for')
        self.parser.add_argument('--niter_decay', type=int, default=100, help='# of iter to linearly decay learning rate to zero')
        self.parser.add_argument('--beta1', type=float, default=0.5, help='momentum term of adam')
        self.parser.add_argument('--lr', type=float, default=0.0002, help='initial learning rate for adam')
        self.parser.add_argument('--w_adv', type=float, default=1, help='Weight for adversarial loss. default=1')
        self.parser.add_argument('--w_con', type=float, default=50, help='Weight for reconstruction loss. default=50')
        self.parser.add_argument('--w_lat', type=float, default=1, help='Weight for latent space loss. default=1')
        self.parser.add_argument('--lr_policy', type=str, default='lambda', help='lambda|step|plateau')
        self.parser.add_argument('--lr_decay_iters', type=int, default=50, help='multiply by a gamma every lr_decay_iters iterations')
        self.isTrain = True
        self.opt = None

    @staticmethod
    def _parse_gpu_ids(raw_ids):
        """Turn a comma-separated id string into a list of non-negative ints.

        Blank tokens (e.g. from a trailing comma or an empty string) are
        skipped and negative ids are dropped, so ``'-1'`` yields ``[]``.

        Raises:
            ValueError: if a non-blank token is not an integer.
        """
        parsed = (int(token) for token in raw_ids.split(',') if token.strip())
        return [gpu_id for gpu_id in parsed if gpu_id >= 0]

    def _option_lines(self):
        """Return one ``'key: value'`` string per parsed option, sorted by key."""
        return ['%s: %s' % (str(k), str(v)) for k, v in sorted(vars(self.opt).items())]

    def _save_options(self):
        """Create the train/test output folders and write ``train/opt.txt``."""
        expr_dir = os.path.join(self.opt.outf, self.opt.name, 'train')
        test_dir = os.path.join(self.opt.outf, self.opt.name, 'test')
        # exist_ok avoids the check-then-create race of isdir() + makedirs().
        os.makedirs(expr_dir, exist_ok=True)
        os.makedirs(test_dir, exist_ok=True)

        file_name = os.path.join(expr_dir, 'opt.txt')
        with open(file_name, 'wt', encoding='utf-8') as opt_file:
            opt_file.write('------------ Options -------------\n')
            opt_file.writelines(line + '\n' for line in self._option_lines())
            opt_file.write('-------------- End ----------------\n')

    def parse(self, args=None):
        """ Parse Arguments.

        Args:
            args: optional list of argument strings. When ``None`` (the
                default) ``argparse`` reads the process command line, which
                is what the original no-argument call did.

        Returns:
            The parsed namespace, also stored on ``self.opt``. ``gpu_ids`` is
            converted to a list of ints, ``isTrain`` is added, and a ``name``
            left at ``'experiment_name'`` becomes ``'<model>/<dataset>'``.

        Side effects:
            Selects the first requested GPU when CUDA is available, prints
            the options when ``--verbose`` is set, creates
            ``<outf>/<name>/train`` and ``<outf>/<name>/test`` and writes the
            options to ``<outf>/<name>/train/opt.txt``.
        """
        self.opt = self.parser.parse_args(args)
        self.opt.isTrain = self.isTrain   # train or test
        self.opt.gpu_ids = self._parse_gpu_ids(self.opt.gpu_ids)

        # set gpu ids
        # Only touch CUDA when it is actually usable: with the default
        # '--gpu_ids 0' an unconditional set_device() raises on a machine
        # without CUDA, even when the run is meant for the CPU.
        if self.opt.gpu_ids and torch.cuda.is_available():
            torch.cuda.set_device(self.opt.gpu_ids[0])

        # The console dump happens before the experiment name is resolved,
        # exactly as before; the file written below shows the resolved name.
        if self.opt.verbose:
            print('------------ Options -------------')
            for line in self._option_lines():
                print(line)
            print('-------------- End ----------------')

        # save to the disk
        if self.opt.name == 'experiment_name':
            self.opt.name = "%s/%s" % (self.opt.model, self.opt.dataset)
        self._save_options()
        return self.opt