Skip to content

Commit

Permalink
support on CPU mode for FaceRestoreHelper (#17)
Browse files Browse the repository at this point in the history
Co-authored-by: 钟长鸿 <zhongchanghong@adtiger.hk>
  • Loading branch information
longredzhong and 钟长鸿 authored Jun 22, 2022
1 parent c6c5aa6 commit 4f7e851
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions facexlib/utils/face_restoration_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,16 +90,16 @@ def __init__(self,
self.pad_input_imgs = []

if device is None:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
else:
device = device
self.device = device

# init face detection model
self.face_det = init_detection_model(det_model, half=False, device=device)
self.face_det = init_detection_model(det_model, half=False, device=self.device)

# init face parsing model
self.use_parse = use_parse
self.face_parse = init_parsing_model(model_name='parsenet', device=device)
self.face_parse = init_parsing_model(model_name='parsenet', device=self.device)

def set_upscale_factor(self, upscale_factor):
self.upscale_factor = upscale_factor
Expand Down Expand Up @@ -303,7 +303,7 @@ def paste_faces_to_input_image(self, save_path=None, upsample_img=None):
face_input = cv2.resize(restored_face, (512, 512), interpolation=cv2.INTER_LINEAR)
face_input = img2tensor(face_input.astype('float32') / 255., bgr2rgb=True, float32=True)
normalize(face_input, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
face_input = torch.unsqueeze(face_input, 0).cuda()
face_input = torch.unsqueeze(face_input, 0).to(self.device)
with torch.no_grad():
out = self.face_parse(face_input)[0]
out = out.argmax(dim=1).squeeze().cpu().numpy()
Expand Down

0 comments on commit 4f7e851

Please sign in to comment.