-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathDemo_Student.py
128 lines (103 loc) · 4.76 KB
/
Demo_Student.py
1
"""Train or test the student saliency network (``Resnet_18``).

``--phase train`` optimises the student with ``Trainer`` (from
``Trainer_Student``). ``Trainer`` is also handed the focal / focal_sub /
ConvLSTM models; their weights are loaded from ``./snapshot_Teacher`` only
when ``--param`` is supplied on the command line (see the note on that flag).

``--phase test`` (the default) runs the student over the test set and writes
one PNG per image, built from softmax channel 1 of the student output, into
``--salmap_Studentroot``.
"""
import torch
from torch.autograd import Variable
from torch.utils.data import DataLoader
import torchvision
import torch.nn.functional as F
import torch.optim as optim
from dataset_loader import MyData, MyTestData
from model import FocalNet, FocalNet_sub
from conv_lstm import ConvLSTM
from functions import imsave
import argparse
from Trainer_Student import Trainer
from resnet_18 import Resnet_18
import os
import imageio


def _load_weights(module, directory, filename):
    """Load the state dict stored at ``directory/filename`` into ``module``.

    The checkpoint is mapped to CPU first so weights saved from a GPU run can
    also be restored on a CPU-only machine; the caller moves the module to the
    GPU afterwards when one is available.
    """
    state = torch.load(os.path.join(directory, filename), map_location='cpu')
    module.load_state_dict(state)


if __name__ == '__main__':
    configurations = {
        1: dict(
            max_iteration=300000,
            lr=1.0e-10,
            momentum=0.99,
            weight_decay=0.0005,
            spshot=10000,
            nclass=2,
            sshow=10,
            focal_num=12,
        )
    }

    parser = argparse.ArgumentParser()
    parser.add_argument('--phase', type=str, default='test', help='train or test')
    # NOTE(review): with type=str, args.param can only be the bool True when
    # the flag is omitted; any value given on the command line is a string and
    # therefore selects the teacher-weights branch below. Kept as-is because
    # callers may rely on it — confirm this is the intended interface.
    parser.add_argument('--param', type=str, default=True, help='path to pre-trained parameters')
    # Raw strings: '\L' is not a valid escape sequence (SyntaxWarning on
    # Python 3.12+). The resulting values are unchanged.
    parser.add_argument('--train_dataroot', type=str, default=r'H:\Light-Field\train_data', help='path to train data')
    parser.add_argument('--test_dataroot', type=str, default=r'H:\Light-Field\test_data', help='path to test data')
    parser.add_argument('--snapshot_root', type=str, default='./snapshot_Student', help='path to snapshot')
    parser.add_argument('--salmap_Studentroot', type=str, default='./sal_Student', help='path to Student saliency map')
    parser.add_argument('-c', '--config', type=int, default=1, choices=configurations.keys())
    args = parser.parse_args()
    cfg = configurations[args.config]
    # Must be *called*: the bare function object is always truthy, which made
    # every `.cuda()` below run even on machines without a GPU.
    cuda = torch.cuda.is_available()

    # ~~~ dataset loader ~~~
    train_dataRoot = args.train_dataroot
    test_dataRoot = args.test_dataroot
    os.makedirs(args.snapshot_root, exist_ok=True)
    os.makedirs(args.salmap_Studentroot, exist_ok=True)

    if args.phase == 'train':
        SnapRoot = args.snapshot_root  # checkpoint
        # NOTE(review): training data is not shuffled; confirm this is intended.
        # pin_memory only helps host->GPU copies, so enable it only with CUDA.
        train_loader = torch.utils.data.DataLoader(MyData(train_dataRoot, transform=True),
                                                   batch_size=1, shuffle=False,
                                                   num_workers=4, pin_memory=cuda)
    else:
        MapStudentRoot = args.salmap_Studentroot
        # Each map is saved under its own image name, so shuffling the test set
        # only made the processing order non-deterministic.
        test_loader = torch.utils.data.DataLoader(MyTestData(test_dataRoot, transform=True),
                                                  batch_size=1, shuffle=False,
                                                  num_workers=4, pin_memory=cuda)
    print('data already')

    # ~~~ nets ~~~
    start_epoch = 0
    start_iteration = 0
    Student = Resnet_18()
    model_focal = FocalNet(cfg['nclass'], refine=True)
    model_focal_sub = FocalNet_sub(cfg['nclass'])
    model_clstm = ConvLSTM(input_channels=64, hidden_channels=[64, 32, 64],
                           kernel_size=5, step=4, effective_step=[2, 4, 8])

    if args.param is True:
        _load_weights(Student, './snapshot_Student', 'student_snapshot_iter_292000.pth')
    else:
        _load_weights(model_focal, './snapshot_Teacher', 'focal_snapshot_iter_500000.pth')
        _load_weights(model_clstm, './snapshot_Teacher', 'clstm_snapshot_iter_500000.pth')
        _load_weights(model_focal_sub, './snapshot_Teacher', 'focal_sub_snapshot_iter_500000.pth')

    if cuda:
        Student = Student.cuda()
        model_focal = model_focal.cuda()
        model_focal_sub = model_focal_sub.cuda()
        model_clstm = model_clstm.cuda()

    if args.phase == 'train':
        # Trainer: class, defined in Trainer_Student.py
        optimizer_student = optim.SGD(Student.parameters(), lr=cfg['lr'],
                                      momentum=cfg['momentum'],
                                      weight_decay=cfg['weight_decay'])
        training = Trainer(
            cuda=cuda,
            student=Student,
            model_focal=model_focal,
            model_focal_sub=model_focal_sub,
            model_clstm=model_clstm,
            optimizer_student=optimizer_student,
            train_loader=train_loader,
            max_iter=cfg['max_iteration'],
            snapshot=cfg['spshot'],
            outpath=args.snapshot_root,
            sshow=cfg['sshow'],
        )
        training.epoch = start_epoch
        training.iteration = start_iteration
        training.train()
    else:
        # Inference: eval() switches layers such as BatchNorm/Dropout (if the
        # student has any) to inference behaviour, and no_grad() avoids
        # building an autograd graph for every test image.
        Student.eval()
        with torch.no_grad():
            for batch_idx, (data, _focal, img_name, img_size) in enumerate(test_loader):
                print('testing bach %d' % batch_idx)
                inputs = data.cuda() if cuda else data
                outputs_Student, _focus = Student(inputs)
                outputs_Student = F.softmax(outputs_Student, dim=1)
                # Channel 1 of the first (only) sample in the batch.
                outputs_Student = outputs_Student[0][1]
                outputs_Student = outputs_Student.cpu().data.resize_(img_size)
                imsave(os.path.join(MapStudentRoot, img_name[0] + '.png'), outputs_Student, img_size)
                torch.cuda.empty_cache()