Skip to content

Commit

Permalink
support alignment cpu
Browse files Browse the repository at this point in the history
  • Loading branch information
xinntao committed Apr 15, 2023
1 parent 213c8bb commit b30ca12
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 12 deletions.
2 changes: 1 addition & 1 deletion facexlib/alignment/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

def init_alignment_model(model_name, half=False, device='cuda', model_rootpath=None):
if model_name == 'awing_fan':
model = FAN(num_modules=4, num_landmarks=98)
model = FAN(num_modules=4, num_landmarks=98, device=device)
model_url = 'https://github.com/xinntao/facexlib/releases/download/v0.1.0/alignment_WFLW_4HG.pth'
else:
raise NotImplementedError(f'{model_name} is not implemented.')
Expand Down
19 changes: 10 additions & 9 deletions facexlib/alignment/awing_arch.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,19 +56,19 @@ def forward(self, input_tensor, heatmap=None):
"""
batch_size_tensor = input_tensor.shape[0]

xx_ones = torch.ones([1, self.y_dim], dtype=torch.int32).cuda()
xx_ones = torch.ones([1, self.y_dim], dtype=torch.int32, device=input_tensor.device)
xx_ones = xx_ones.unsqueeze(-1)

xx_range = torch.arange(self.x_dim, dtype=torch.int32).unsqueeze(0).cuda()
xx_range = torch.arange(self.x_dim, dtype=torch.int32, device=input_tensor.device).unsqueeze(0)
xx_range = xx_range.unsqueeze(1)

xx_channel = torch.matmul(xx_ones.float(), xx_range.float())
xx_channel = xx_channel.unsqueeze(-1)

yy_ones = torch.ones([1, self.x_dim], dtype=torch.int32).cuda()
yy_ones = torch.ones([1, self.x_dim], dtype=torch.int32, device=input_tensor.device)
yy_ones = yy_ones.unsqueeze(1)

yy_range = torch.arange(self.y_dim, dtype=torch.int32).unsqueeze(0).cuda()
yy_range = torch.arange(self.y_dim, dtype=torch.int32, device=input_tensor.device).unsqueeze(0)
yy_range = yy_range.unsqueeze(-1)

yy_channel = torch.matmul(yy_range.float(), yy_ones.float())
Expand All @@ -93,8 +93,8 @@ def forward(self, input_tensor, heatmap=None):
xx_boundary_channel = torch.where(boundary_channel > 0.05, xx_channel, zero_tensor)
yy_boundary_channel = torch.where(boundary_channel > 0.05, yy_channel, zero_tensor)
if self.with_boundary and heatmap is not None:
xx_boundary_channel = xx_boundary_channel.cuda()
yy_boundary_channel = yy_boundary_channel.cuda()
xx_boundary_channel = xx_boundary_channel.to(input_tensor.device)
yy_boundary_channel = yy_boundary_channel.to(input_tensor.device)
ret = torch.cat([input_tensor, xx_channel, yy_channel], dim=1)

if self.with_r:
Expand Down Expand Up @@ -268,8 +268,9 @@ def forward(self, x, heatmap):

class FAN(nn.Module):

def __init__(self, num_modules=1, end_relu=False, gray_scale=False, num_landmarks=68):
def __init__(self, num_modules=1, end_relu=False, gray_scale=False, num_landmarks=68, device='cuda'):
super(FAN, self).__init__()
self.device = device
self.num_modules = num_modules
self.gray_scale = gray_scale
self.end_relu = end_relu
Expand Down Expand Up @@ -355,14 +356,14 @@ def forward(self, x):

return outputs, boundary_channels

def get_landmarks(self, img, device='cuda'):
def get_landmarks(self, img):
H, W, _ = img.shape
offset = W / 64, H / 64, 0, 0

img = cv2.resize(img, (256, 256))
inp = img[..., ::-1]
inp = torch.from_numpy(np.ascontiguousarray(inp.transpose((2, 0, 1)))).float()
inp = inp.to(device)
inp = inp.to(self.device)
inp.div_(255.0).unsqueeze_(0)

outputs, _ = self.forward(inp)
Expand Down
4 changes: 2 additions & 2 deletions inference/inference_alignment.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

def main(args):
# initialize model
align_net = init_alignment_model(args.model_name)
align_net = init_alignment_model(args.model_name, device=args.device)

img = cv2.imread(args.img_path)
with torch.no_grad():
Expand All @@ -23,7 +23,7 @@ def main(args):
parser.add_argument('--img_path', type=str, default='assets/test2.jpg')
parser.add_argument('--save_path', type=str, default='test_alignment.png')
parser.add_argument('--model_name', type=str, default='awing_fan')
parser.add_argument('--half', action='store_true')
parser.add_argument('--device', type=str, default='cuda')
parser.add_argument('--to68', action='store_true')
args = parser.parse_args()

Expand Down

0 comments on commit b30ca12

Please sign in to comment.