diff --git a/inference_gfpgan_full.py b/inference_gfpgan_full.py
index 1b764157..07b3d5a8 100644
--- a/inference_gfpgan_full.py
+++ b/inference_gfpgan_full.py
@@ -19,7 +19,8 @@ def restoration(gfpgan,
                 has_aligned=False,
                 only_center_face=True,
                 suffix=None,
-                paste_back=False):
+                paste_back=False,
+                device='cuda'):
     # read image
     img_name = os.path.basename(img_path)
     print(f'Processing {img_name} ...')
@@ -43,7 +44,7 @@ def restoration(gfpgan,
         # prepare data
         cropped_face_t = img2tensor(cropped_face / 255., bgr2rgb=True, float32=True)
         normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
-        cropped_face_t = cropped_face_t.unsqueeze(0).to('cuda')
+        cropped_face_t = cropped_face_t.unsqueeze(0).to(device)
 
         try:
             with torch.no_grad():
@@ -77,17 +78,18 @@ def restoration(gfpgan,
 
 if __name__ == '__main__':
     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
     parser = argparse.ArgumentParser()
 
-    parser.add_argument('--upscale_factor', type=int, default=1)
+    parser.add_argument('--upscale_factor', type=int, default=2)
     parser.add_argument('--arch', type=str, default='clean')
     parser.add_argument('--channel', type=int, default=2)
-    parser.add_argument('--model_path', type=str, default='experiments/pretrained_models/GFPGANv1.pth')
+    parser.add_argument('--model_path', type=str, default='experiments/pretrained_models/GFPGANCleanv1-NoCE-C2.pth')
     parser.add_argument('--test_path', type=str, default='inputs/whole_imgs')
     parser.add_argument('--suffix', type=str, default=None, help='Suffix of the restored faces')
     parser.add_argument('--only_center_face', action='store_true')
     parser.add_argument('--aligned', action='store_true')
-    parser.add_argument('--paste_back', action='store_true')
+    parser.add_argument('--paste_back', action='store_false')
     parser.add_argument('--save_root', type=str, default='results')
 
     args = parser.parse_args()
@@ -123,14 +125,17 @@ def restoration(gfpgan,
         narrow=1,
         sft_half=True)
 
-    gfpgan.to(device)
-    checkpoint = torch.load(args.model_path, map_location=lambda storage, loc: storage)
-    gfpgan.load_state_dict(checkpoint['params_ema'])
-    gfpgan.eval()
+    gfpgan.load_state_dict(torch.load(args.model_path, map_location=lambda storage, loc: storage)['params_ema'])
+    gfpgan.to(device).eval()
 
     # initialize face helper
     face_helper = FaceRestoreHelper(
-        args.upscale_factor, face_size=512, crop_ratio=(1, 1), det_model='retinaface_resnet50', save_ext='png')
+        args.upscale_factor,
+        face_size=512,
+        crop_ratio=(1, 1),
+        det_model='retinaface_resnet50',
+        save_ext='png',
+        device=device)
 
     img_list = sorted(glob.glob(os.path.join(args.test_path, '*')))
     for img_path in img_list:
@@ -142,6 +147,7 @@ def restoration(gfpgan,
             has_aligned=args.aligned,
             only_center_face=args.only_center_face,
             suffix=args.suffix,
-            paste_back=args.paste_back)
+            paste_back=args.paste_back,
+            device=device)
 
     print(f'Results are in the [{args.save_root}] folder.')