Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add custom device support for RetinaFace class in detection #19

Merged
merged 1 commit into from
Apr 15, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions facexlib/detection/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@

def init_detection_model(model_name, half=False, device='cuda'):
if model_name == 'retinaface_resnet50':
model = RetinaFace(network_name='resnet50', half=half)
model = RetinaFace(network_name='resnet50', half=half, device=device)
model_url = 'https://github.com/xinntao/facexlib/releases/download/v0.1.0/detection_Resnet50_Final.pth'
elif model_name == 'retinaface_mobile0.25':
model = RetinaFace(network_name='mobile0.25', half=half)
model = RetinaFace(network_name='mobile0.25', half=half, device=device)
model_url = 'https://github.com/xinntao/facexlib/releases/download/v0.1.0/detection_mobilenet0.25_Final.pth'
else:
raise NotImplementedError(f'{model_name} is not implemented.')
Expand Down
22 changes: 11 additions & 11 deletions facexlib/detection/retinaface.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,6 @@
from facexlib.detection.retinaface_utils import (PriorBox, batched_decode, batched_decode_landm, decode, decode_landm,
py_cpu_nms)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def generate_config(network_name):

Expand Down Expand Up @@ -72,7 +70,9 @@ def generate_config(network_name):

class RetinaFace(nn.Module):

def __init__(self, network_name='resnet50', half=False, phase='test'):
def __init__(self, network_name='resnet50', half=False, phase='test', device=None):
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') if device is None else device

super(RetinaFace, self).__init__()
self.half_inference = half
cfg = generate_config(network_name)
Expand All @@ -83,7 +83,7 @@ def __init__(self, network_name='resnet50', half=False, phase='test'):
self.phase = phase
self.target_size, self.max_size = 1600, 2150
self.resize, self.scale, self.scale1 = 1., None, None
self.mean_tensor = torch.tensor([[[[104.]], [[117.]], [[123.]]]]).to(device)
self.mean_tensor = torch.tensor([[[[104.]], [[117.]], [[123.]]]], device=self.device)
self.reference = get_reference_facial_points(default_square=True)
# Build network.
backbone = None
Expand Down Expand Up @@ -112,7 +112,7 @@ def __init__(self, network_name='resnet50', half=False, phase='test'):
self.BboxHead = make_bbox_head(fpn_num=3, inchannels=cfg['out_channel'])
self.LandmarkHead = make_landmark_head(fpn_num=3, inchannels=cfg['out_channel'])

self.to(device)
self.to(self.device)
self.eval()
if self.half_inference:
self.half()
Expand Down Expand Up @@ -145,19 +145,19 @@ def forward(self, inputs):
def __detect_faces(self, inputs):
# get scale
height, width = inputs.shape[2:]
self.scale = torch.tensor([width, height, width, height], dtype=torch.float32).to(device)
self.scale = torch.tensor([width, height, width, height], dtype=torch.float32, device=self.device)
tmp = [width, height, width, height, width, height, width, height, width, height]
self.scale1 = torch.tensor(tmp, dtype=torch.float32).to(device)
self.scale1 = torch.tensor(tmp, dtype=torch.float32, device=self.device)

# forawrd
inputs = inputs.to(device)
inputs = inputs.to(self.device)
if self.half_inference:
inputs = inputs.half()
loc, conf, landmarks = self(inputs)

# get priorbox
priorbox = PriorBox(self.cfg, image_size=inputs.shape[2:])
priors = priorbox.forward().to(device)
priors = priorbox.forward().to(self.device)

return loc, conf, landmarks, priors

Expand Down Expand Up @@ -197,7 +197,7 @@ def detect_faces(
use_origin_size=True,
):
image, self.resize = self.transform(image, use_origin_size)
image = image.to(device)
image = image.to(self.device)
if self.half_inference:
image = image.half()
image = image - self.mean_tensor
Expand Down Expand Up @@ -316,7 +316,7 @@ def batched_detect_faces(self, frames, conf_threshold=0.8, nms_threshold=0.4, us
"""
# self.t['forward_pass'].tic()
frames, self.resize = self.batched_transform(frames, use_origin_size)
frames = frames.to(device)
frames = frames.to(self.device)
frames = frames - self.mean_tensor

b_loc, b_conf, b_landmarks, priors = self.__detect_faces(frames)
Expand Down