From 6a9faad10296085a537fbcc670b8ff1205522f4a Mon Sep 17 00:00:00 2001 From: Chengkun Cao Date: Fri, 1 Jul 2022 11:09:03 +0800 Subject: [PATCH] add custom device support for RetinaFace class in detection --- facexlib/detection/__init__.py | 4 ++-- facexlib/detection/retinaface.py | 22 +++++++++++----------- 2 files changed, 13 insertions(+), 13 deletions(-) diff --git a/facexlib/detection/__init__.py b/facexlib/detection/__init__.py index ce867f8..b7b9366 100644 --- a/facexlib/detection/__init__.py +++ b/facexlib/detection/__init__.py @@ -7,10 +7,10 @@ def init_detection_model(model_name, half=False, device='cuda'): if model_name == 'retinaface_resnet50': - model = RetinaFace(network_name='resnet50', half=half) + model = RetinaFace(network_name='resnet50', half=half, device=device) model_url = 'https://github.com/xinntao/facexlib/releases/download/v0.1.0/detection_Resnet50_Final.pth' elif model_name == 'retinaface_mobile0.25': - model = RetinaFace(network_name='mobile0.25', half=half) + model = RetinaFace(network_name='mobile0.25', half=half, device=device) model_url = 'https://github.com/xinntao/facexlib/releases/download/v0.1.0/detection_mobilenet0.25_Final.pth' else: raise NotImplementedError(f'{model_name} is not implemented.') diff --git a/facexlib/detection/retinaface.py b/facexlib/detection/retinaface.py index 8f6adc5..cd1ab41 100644 --- a/facexlib/detection/retinaface.py +++ b/facexlib/detection/retinaface.py @@ -11,8 +11,6 @@ from facexlib.detection.retinaface_utils import (PriorBox, batched_decode, batched_decode_landm, decode, decode_landm, py_cpu_nms) -device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') - def generate_config(network_name): @@ -72,7 +70,9 @@ def generate_config(network_name): class RetinaFace(nn.Module): - def __init__(self, network_name='resnet50', half=False, phase='test'): + def __init__(self, network_name='resnet50', half=False, phase='test', device=None): + self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') if device is None else device + super(RetinaFace, self).__init__() self.half_inference = half cfg = generate_config(network_name) @@ -83,7 +83,7 @@ def __init__(self, network_name='resnet50', half=False, phase='test'): self.phase = phase self.target_size, self.max_size = 1600, 2150 self.resize, self.scale, self.scale1 = 1., None, None - self.mean_tensor = torch.tensor([[[[104.]], [[117.]], [[123.]]]]).to(device) + self.mean_tensor = torch.tensor([[[[104.]], [[117.]], [[123.]]]], device=self.device) self.reference = get_reference_facial_points(default_square=True) # Build network. backbone = None @@ -112,7 +112,7 @@ def __init__(self, network_name='resnet50', half=False, phase='test'): self.BboxHead = make_bbox_head(fpn_num=3, inchannels=cfg['out_channel']) self.LandmarkHead = make_landmark_head(fpn_num=3, inchannels=cfg['out_channel']) - self.to(device) + self.to(self.device) self.eval() if self.half_inference: self.half() @@ -145,19 +145,19 @@ def forward(self, inputs): def __detect_faces(self, inputs): # get scale height, width = inputs.shape[2:] - self.scale = torch.tensor([width, height, width, height], dtype=torch.float32).to(device) + self.scale = torch.tensor([width, height, width, height], dtype=torch.float32, device=self.device) tmp = [width, height, width, height, width, height, width, height, width, height] - self.scale1 = torch.tensor(tmp, dtype=torch.float32).to(device) + self.scale1 = torch.tensor(tmp, dtype=torch.float32, device=self.device) # forawrd - inputs = inputs.to(device) + inputs = inputs.to(self.device) if self.half_inference: inputs = inputs.half() loc, conf, landmarks = self(inputs) # get priorbox priorbox = PriorBox(self.cfg, image_size=inputs.shape[2:]) - priors = priorbox.forward().to(device) + priors = priorbox.forward().to(self.device) return loc, conf, landmarks, priors @@ -197,7 +197,7 @@ def detect_faces( use_origin_size=True, ): image, self.resize = self.transform(image, use_origin_size) - image = image.to(device) + image = image.to(self.device) if self.half_inference: image = image.half() image = image - self.mean_tensor @@ -316,7 +316,7 @@ def batched_detect_faces(self, frames, conf_threshold=0.8, nms_threshold=0.4, us """ # self.t['forward_pass'].tic() frames, self.resize = self.batched_transform(frames, use_origin_size) - frames = frames.to(device) + frames = frames.to(self.device) frames = frames - self.mean_tensor b_loc, b_conf, b_landmarks, priors = self.__detect_faces(frames)