-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathoption1.py
47 lines (43 loc) · 2.01 KB
/
option1.py
1
import torch,os,sys,torchvision,argparseimport torch,warningswarnings.filterwarnings('ignore')##定义超参数parser=argparse.ArgumentParser()parser.add_argument('--epochs',type=int,default=200)#parser.add_argument('--device',type=str,default='Automatic detection')parser.add_argument('--device',type=str,default='cuda')parser.add_argument('--resume',type=bool,default=True)parser.add_argument('--eval_epoch',type=int,default=1)parser.add_argument('--eval_out',type=int,default=10)parser.add_argument('--lr', default=0.001, type=float, help='learning rate')parser.add_argument('--model_dir',type=str, default='./trained_models/UIE_PVtransformer_4/UIE_PVtransformer_epoch_200.pk')parser.add_argument('--trainset',type=str,default='its_train')parser.add_argument('--net',type=str,default='model_NEW')parser.add_argument('--blocks',type=int,default=5,help='residual_blocks')parser.add_argument('--bs',type=int,default=6,help='batch size')parser.add_argument('--checkpoint',type=str, default=r'E:\xzx\训练代码\trained_models/')parser.add_argument('--crop',action='store_true')parser.add_argument('--crop_size',type=int,default=256,help='Takes effect when using --crop ')parser.add_argument('--no_lr_sche',action='store_true',help='no lr cos schedule')parser.add_argument('--perloss',action='store_true',default=True,help='perceptual loss')parser.add_argument('--resize',type=int,default=256,help='resize dataset')opt= parser.parse_args(args=[])opt.device='cuda' if torch.cuda.is_available() else 'cpu'model_name=opt.netcheckpoint_save=opt.checkpoint+'/UIE_PVtransformer_4/'log_dir='logs/'+model_name+'_11_00'print(opt)##print('model_dir:',opt.model_dir)if not os.path.exists('trained_models'): os.mkdir('trained_models')if not os.path.exists('numpy_files'): os.mkdir('numpy_files')if not os.path.exists('logs'): os.mkdir('logs')if not os.path.exists('samples'): os.mkdir('samples')if not os.path.exists(f"samples/{model_name}"): os.mkdir(f'samples/{model_name}')if not os.path.exists(log_dir): os.mkdir(log_dir)