-
Notifications
You must be signed in to change notification settings - Fork 22
/
my_args.py
123 lines (98 loc) · 6.49 KB
/
my_args.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import os
import datetime
import argparse
import numpy
import networks
import torch
# Public names listed in ``networks.__all__``; used below as the accepted
# values (``choices``) of the --netName option.
modelnames = networks.__all__
import datasets_benchmark
# Public names listed in ``datasets_benchmark.__all__``; used below as the
# accepted values (``choices``) of the --datasetName option.
datasetNames = datasets_benchmark.__all__
def _str2bool(value):
    """Convert a command-line string into a real boolean.

    ``type=bool`` cannot be used with argparse: ``bool("False")`` is ``True``
    because every non-empty string is truthy, so ``--use_cuda False`` used to
    silently enable CUDA.

    Raises:
        argparse.ArgumentTypeError: if *value* is not a recognised spelling
            of true/false.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ('1', 'true', 't', 'yes', 'y', 'on'):
        return True
    # The empty string stays "false": it was the only value the former
    # ``type=bool`` conversion mapped to False.
    if text in ('', '0', 'false', 'f', 'no', 'n', 'off'):
        return False
    raise argparse.ArgumentTypeError('expected a boolean value, got %r' % (value,))


# Command-line spellings accepted by --dtype, mapped to the tensor classes.
_TENSOR_TYPES = {
    'torch.cuda.FloatTensor': torch.cuda.FloatTensor,
    'torch.FloatTensor': torch.FloatTensor,
}


def _tensor_type(name):
    """Map a tensor type name given on the command line to its torch class.

    Without this converter argparse compares the raw string against the
    class objects in ``choices``, which never matches, so every explicit
    ``--dtype`` value was rejected.

    Raises:
        argparse.ArgumentTypeError: if *name* is not a known tensor type name.
    """
    try:
        return _TENSOR_TYPES[name]
    except KeyError:
        raise argparse.ArgumentTypeError(
            'unknown tensor type %r (expected one of: %s)'
            % (name, ', '.join(sorted(_TENSOR_TYPES)))) from None


parser = argparse.ArgumentParser(description='MEMC-Net')

parser.add_argument('--debug', action='store_true', help='Enable debug mode')
parser.add_argument('--netName', type=str, default='MEMC_Net',
                    choices=modelnames,
                    help='model architecture: ' +
                         ' | '.join(modelnames) +
                         ' (default: MEMC_Net)')

# NOTE(review): with nargs='+' an explicit value is parsed into a list, while
# the defaults of --datasetName / --datasetPath are plain strings; consumers
# apparently have to handle both shapes -- confirm before changing.
parser.add_argument('--datasetName', default='Vimeo_90K_interp',
                    choices=datasetNames, nargs='+',
                    help='dataset type : ' +
                         ' | '.join(datasetNames) +
                         ' (default: Vimeo_90K_interp)')
parser.add_argument('--datasetPath', nargs='+', default='', help='the path of selected datasets')
parser.add_argument('--dataset_split', type=int, default=97, help='Split a dataset into trainining and validation by percentage (default: 97)')
parser.add_argument('--seed', type=int, default=1, help='random seed (default: 1)')

# Help texts below state the defaults that are actually configured.
parser.add_argument('--numEpoch', '-e', type=int, default=100, help='Number of epochs to train(default:100)')

parser.add_argument('--batch_size', '-b', type=int, default=1, help='batch size (default:1)')
parser.add_argument('--workers', '-w', type=int, default=8, help='parallel workers for loading training samples (default: 8)')
parser.add_argument('--channels', '-c', type=int, default=3, choices=[1, 3], help='channels of images (default:3)')
parser.add_argument('--filter_size', '-f', type=int, default=4, help='the size of filters used (default: 4)',
                    choices=[2, 4, 6, 5, 51]
                    )

# No default is configured for --task, so it is None unless given.
parser.add_argument('--task', type=str, choices=['sr', 'denoise', 'deblock'], help='select a task to train for (default: None)')
parser.add_argument('--task_param', type=float, nargs='+', default=[4.0], help='the task parameters such as sr ratio, denoise variance, salt&pepper ratio')

# parser.add_argument('--lr', type =float, default= 0.002, help= 'the basic learning rate for three subnetworks (default: 0.002)')
parser.add_argument('--save_which', '-s', type=int, default=1, choices=[0, 1], help='choose which result to save: 0 ==> interpolated, 1==> rectified')
# parser.add_argument('--flow_lr_coe', type = float, default=0.01, help = 'relative learning rate w.r.t basic learning rate (default: 0.01)')
# parser.add_argument('--occ_lr_coe', type = float, default=1.0, help = 'relative learning rate w.r.t basic learning rate (default: 1.0)')
# parser.add_argument('--filter_lr_coe', type = float, default=1.0, help = 'relative learning rate w.r.t basic learning rate (default: 1.0)')
# parser.add_argument('--keepRectifyWeights', type=int, default=1, choices=[1,0], help='whether to keep the weights of RectifyNet or not: 1 ==> dropped, 1==> kept')
#
# parser.add_argument('--use_negPSNR', action='store_true', help ='whether to use negPSNR as loss to replace L1-norm loss (default:false)' )
# parser.add_argument('--alpha', type=float,nargs='+', default=[1.0, 0.0], help= 'the ration of loss for interpolated and rectified result (default: [1.0, 0.0])')
# parser.add_argument('--lambda1', type = float,nargs='+', default=[0.0], help = 'regularize the total variation of flow')
# parser.add_argument('--lambda2', type = float,nargs='+', default=[0.0], help = 'regularize the sum of two occlusion maps ')
# parser.add_argument('--lambda3', type = float,nargs='+', default=[0.0], help = 'regularize the symmetry of two estimated flow')
#
# parser.add_argument('--epsilon', type = float, default=1e-6, help = 'the epsilon for charbonier loss,etc (default: 1e-6)')
# parser.add_argument('--weight_decay', type = float, default=0, help = 'the weight decay for whole network ' )
# parser.add_argument('--patience', type=int, default=5, help = 'the patience of reduce on plateou')
# parser.add_argument('--factor', type = float, default=0.2, help = 'the factor of reduce on plateou')

parser.add_argument('--pretrained', dest='SAVED_MODEL', default='MEMC-Net_best.pth', help='path to the pretrained model weights')
parser.add_argument('--no-date', action='store_true', help="don't append date timestamp to folder")
# _str2bool instead of bool: see the helper's docstring.
parser.add_argument('--use_cuda', default=True, type=_str2bool, help='use cuda or not')
parser.add_argument('--use_cudnn', default=1, type=int, help='use cudnn or not')
# parser.add_argument('--nocudnn', dest='use_cudnn', default=)
# The converter turns the name into the class before argparse checks it
# against ``choices``; the (non-string) default is used as-is.
parser.add_argument('--dtype', default=torch.cuda.FloatTensor, type=_tensor_type,
                    choices=[torch.cuda.FloatTensor, torch.FloatTensor],
                    metavar='TENSOR_TYPE',
                    help='tensor data type: torch.cuda.FloatTensor | torch.FloatTensor '
                         '(default: torch.cuda.FloatTensor)')
# parser.add_argument('--resume', default='', type=str, help='path to latest checkpoint (default: none)')

parser.add_argument('--uid', type=str, default=None, help='unique id for the training')
parser.add_argument('--force', action='store_true', help='force to override the given uid')

# First pass only: --save_path / --log / --arg are registered further down
# (their defaults depend on the uid parsed here) and the command line is
# parsed again afterwards.  A strict parse_args() at this point would reject
# those options as unrecognized, so unknown arguments are tolerated here and
# validated by the later, final parse.
args = parser.parse_known_args()[0]
import shutil

# Resolve the directory that receives the weights, the log and the args dump.
if args.uid is None:
    # No uid given: draw a random one and make the folder name unique with a
    # timestamp.
    # NOTE(review): the --no-date flag is declared on the parser but is not
    # consulted here, so the timestamp is always appended -- confirm intent.
    unique_id = str(numpy.random.randint(0, 100000))
    print("revise the unique id to a random numer " + str(unique_id))
    args.uid = unique_id
    timestamp = datetime.datetime.now().strftime("%a-%b-%d-%H:%M")
    save_path = './model_weights/' + args.uid + '-' + timestamp
else:
    save_path = './model_weights/' + str(args.uid)

if not os.path.exists(save_path + "/best" + ".pth"):
    # No previous best checkpoint under this uid: (re)use the directory.
    os.makedirs(save_path, exist_ok=True)
elif not args.force:
    # Raising a bare string is itself a TypeError ("exceptions must derive
    # from BaseException") and hides the message; raise a real exception.
    raise RuntimeError("please use another uid ")
else:
    print("override this uid" + args.uid)
    # Back up the previous run's log/args into the first free ".bkN" slot
    # (N = 1..9) before they get overwritten.  If all nine slots are taken,
    # no backup is made.
    for m in range(1, 10):
        if not os.path.exists(save_path + "/log.txt.bk" + str(m)):
            for file_name in ("log.txt", "args.txt"):
                previous = save_path + "/" + file_name
                # A checkpoint can exist without these files; skip what is
                # missing instead of crashing in shutil.copy.
                if os.path.exists(previous):
                    shutil.copy(previous, previous + ".bk" + str(m))
            break
# The uid may have been generated after the first parse and stored on the old
# namespace only.  Register it as the parser default so the re-parse below
# does not reset args.uid to None (which also ended up in args.txt).  A uid
# given on the command line still takes precedence and is identical anyway.
parser.set_defaults(uid=args.uid)

# Output locations default to the run directory computed above.
parser.add_argument('--save_path', default=save_path, help='the output dir of weights')
parser.add_argument('--log', default=save_path + '/log.txt', help='the log file in training')
parser.add_argument('--arg', default=save_path + '/args.txt', help='the args used')

# Final parse, now that every option is registered.
args = parser.parse_args()

# Create (or truncate) the training log; the context manager closes the file.
with open(args.log, 'w'):
    pass

# Record the effective arguments on stdout and in the args file.
with open(args.arg, 'w') as f:
    print(args)
    print(args, file=f)
# Report the --use_cudnn choice and set cuDNN benchmark mode to match it.
# Only ``torch.backends.cudnn.benchmark`` is changed by this block.
cudnn_requested = bool(args.use_cudnn)
print("cudnn is used" if cudnn_requested else "cudnn is not used")
torch.backends.cudnn.benchmark = cudnn_requested