Skip to content

Commit

Permalink
Unify CodeFormer and GFPGAN restoration backends, use Spandrel for GF…
Browse files Browse the repository at this point in the history
…PGAN
  • Loading branch information
akx committed Dec 27, 2023
1 parent 6014c97 commit 37458e6
Show file tree
Hide file tree
Showing 12 changed files with 300 additions and 239 deletions.
2 changes: 2 additions & 0 deletions .github/workflows/run_tests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,8 @@ jobs:
run: |
wait-for-it --service 127.0.0.1:7860 -t 600
python -m pytest -vv --junitxml=test/results.xml --cov . --cov-report=xml --verify-base-url test
env:
IGNORE_CMD_ARGS_ERRORS: "1"
- name: Kill test server
if: always()
run: curl -vv -XPOST http://127.0.0.1:7860/sdapi/v1/server-stop && sleep 10
Expand Down
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -37,3 +37,4 @@ notification.mp3
/node_modules
/package-lock.json
/.coverage*
/test/test_outputs
158 changes: 40 additions & 118 deletions modules/codeformer_model.py
Original file line number Diff line number Diff line change
@@ -1,140 +1,62 @@
import os
from __future__ import annotations

import logging

import cv2
import torch

import modules.face_restoration
import modules.shared
from modules import shared, devices, modelloader, errors
from modules.paths import models_path
from modules import (
devices,
errors,
face_restoration,
face_restoration_utils,
modelloader,
shared,
)

logger = logging.getLogger(__name__)

model_dir = "Codeformer"
model_path = os.path.join(models_path, model_dir)
model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth'
model_download_name = 'codeformer-v0.1.0.pth'

codeformer = None
# used by e.g. postprocessing_codeformer.py
codeformer: face_restoration.FaceRestoration | None = None


class FaceRestorerCodeFormer(modules.face_restoration.FaceRestoration):
class FaceRestorerCodeFormer(face_restoration_utils.CommonFaceRestoration):
def name(self):
return "CodeFormer"

def __init__(self, dirname):
self.net = None
self.face_helper = None
self.cmd_dir = dirname

def create_models(self):
from facexlib.detection import retinaface
from facexlib.utils.face_restoration_helper import FaceRestoreHelper

if self.net is not None and self.face_helper is not None:
self.net.to(devices.device_codeformer)
return self.net, self.face_helper
model_paths = modelloader.load_models(
model_path,
model_url,
self.cmd_dir,
download_name='codeformer-v0.1.0.pth',
def load_net(self) -> torch.Module:
for model_path in modelloader.load_models(
model_path=self.model_path,
model_url=model_url,
command_path=self.model_path,
download_name=model_download_name,
ext_filter=['.pth'],
)

if len(model_paths) != 0:
ckpt_path = model_paths[0]
else:
print("Unable to load codeformer model.")
return None, None
net = modelloader.load_spandrel_model(ckpt_path, device=devices.device_codeformer)

if hasattr(retinaface, 'device'):
retinaface.device = devices.device_codeformer

face_helper = FaceRestoreHelper(
upscale_factor=1,
face_size=512,
crop_ratio=(1, 1),
det_model='retinaface_resnet50',
save_ext='png',
use_parse=True,
device=devices.device_codeformer,
)

self.net = net
self.face_helper = face_helper

def send_model_to(self, device):
self.net.to(device)
self.face_helper.face_det.to(device)
self.face_helper.face_parse.to(device)

def restore(self, np_image, w=None):
from torchvision.transforms.functional import normalize
from basicsr.utils import img2tensor, tensor2img
np_image = np_image[:, :, ::-1]

original_resolution = np_image.shape[0:2]

self.create_models()
if self.net is None or self.face_helper is None:
return np_image

self.send_model_to(devices.device_codeformer)

self.face_helper.clean_all()
self.face_helper.read_image(np_image)
self.face_helper.get_face_landmarks_5(only_center_face=False, resize=640, eye_dist_threshold=5)
self.face_helper.align_warp_face()

for cropped_face in self.face_helper.cropped_faces:
cropped_face_t = img2tensor(cropped_face / 255., bgr2rgb=True, float32=True)
normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
cropped_face_t = cropped_face_t.unsqueeze(0).to(devices.device_codeformer)

try:
with torch.no_grad():
res = self.net(cropped_face_t, w=w if w is not None else shared.opts.code_former_weight, adain=True)
if isinstance(res, tuple):
output = res[0]
else:
output = res
if not isinstance(res, torch.Tensor):
raise TypeError(f"Expected torch.Tensor, got {type(res)}")
restored_face = tensor2img(output, rgb2bgr=True, min_max=(-1, 1))
del output
devices.torch_gc()
except Exception:
errors.report('Failed inference for CodeFormer', exc_info=True)
restored_face = tensor2img(cropped_face_t, rgb2bgr=True, min_max=(-1, 1))

restored_face = restored_face.astype('uint8')
self.face_helper.add_restored_face(restored_face)

self.face_helper.get_inverse_affine(None)

restored_img = self.face_helper.paste_faces_to_input_image()
restored_img = restored_img[:, :, ::-1]
):
return modelloader.load_spandrel_model(
model_path,
device=devices.device_codeformer,
).model
raise ValueError("No codeformer model found")

if original_resolution != restored_img.shape[0:2]:
restored_img = cv2.resize(
restored_img,
(0, 0),
fx=original_resolution[1]/restored_img.shape[1],
fy=original_resolution[0]/restored_img.shape[0],
interpolation=cv2.INTER_LINEAR,
)
def get_device(self):
return devices.device_codeformer

self.face_helper.clean_all()
def restore(self, np_image, w: float | None = None):
if w is None:
w = getattr(shared.opts, "code_former_weight", 0.5)

if shared.opts.face_restoration_unload:
self.send_model_to(devices.cpu)
def restore_face(cropped_face_t):
assert self.net is not None
return self.net(cropped_face_t, w=w, adain=True)[0]

return restored_img
return self.restore_with_helper(np_image, restore_face)


def setup_model(dirname):
os.makedirs(model_path, exist_ok=True)
def setup_model(dirname: str) -> None:
global codeformer
try:
global codeformer
codeformer = FaceRestorerCodeFormer(dirname)
shared.face_restorers.append(codeformer)
except Exception:
Expand Down
163 changes: 163 additions & 0 deletions modules/face_restoration_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,163 @@
from __future__ import annotations

import logging
import os
from functools import cached_property
from typing import TYPE_CHECKING, Callable

import cv2
import numpy as np
import torch

from modules import devices, errors, face_restoration, shared

if TYPE_CHECKING:
from facexlib.utils.face_restoration_helper import FaceRestoreHelper

logger = logging.getLogger(__name__)


def create_face_helper(device) -> FaceRestoreHelper:
from facexlib.detection import retinaface
from facexlib.utils.face_restoration_helper import FaceRestoreHelper
if hasattr(retinaface, 'device'):
retinaface.device = device
return FaceRestoreHelper(
upscale_factor=1,
face_size=512,
crop_ratio=(1, 1),
det_model='retinaface_resnet50',
save_ext='png',
use_parse=True,
device=device,
)


def restore_with_face_helper(
np_image: np.ndarray,
face_helper: FaceRestoreHelper,
restore_face: Callable[[np.ndarray], np.ndarray],
) -> np.ndarray:
"""
Find faces in the image using face_helper, restore them using restore_face, and paste them back into the image.
`restore_face` should take a cropped face image and return a restored face image.
"""
from basicsr.utils import img2tensor, tensor2img
from torchvision.transforms.functional import normalize
np_image = np_image[:, :, ::-1]
original_resolution = np_image.shape[0:2]

try:
logger.debug("Detecting faces...")
face_helper.clean_all()
face_helper.read_image(np_image)
face_helper.get_face_landmarks_5(only_center_face=False, resize=640, eye_dist_threshold=5)
face_helper.align_warp_face()
logger.debug("Found %d faces, restoring", len(face_helper.cropped_faces))
for cropped_face in face_helper.cropped_faces:
cropped_face_t = img2tensor(cropped_face / 255., bgr2rgb=True, float32=True)
normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
cropped_face_t = cropped_face_t.unsqueeze(0).to(devices.device_codeformer)

try:
with torch.no_grad():
restored_face = tensor2img(
restore_face(cropped_face_t),
rgb2bgr=True,
min_max=(-1, 1),
)
devices.torch_gc()
except Exception:
errors.report('Failed face-restoration inference', exc_info=True)
restored_face = tensor2img(cropped_face_t, rgb2bgr=True, min_max=(-1, 1))

restored_face = restored_face.astype('uint8')
face_helper.add_restored_face(restored_face)

logger.debug("Merging restored faces into image")
face_helper.get_inverse_affine(None)
img = face_helper.paste_faces_to_input_image()
img = img[:, :, ::-1]
if original_resolution != img.shape[0:2]:
img = cv2.resize(
img,
(0, 0),
fx=original_resolution[1] / img.shape[1],
fy=original_resolution[0] / img.shape[0],
interpolation=cv2.INTER_LINEAR,
)
logger.debug("Face restoration complete")
finally:
face_helper.clean_all()
return img


class CommonFaceRestoration(face_restoration.FaceRestoration):
net: torch.Module | None
model_url: str
model_download_name: str

def __init__(self, model_path: str):
super().__init__()
self.net = None
self.model_path = model_path
os.makedirs(model_path, exist_ok=True)

@cached_property
def face_helper(self) -> FaceRestoreHelper:
return create_face_helper(self.get_device())

def send_model_to(self, device):
if self.net:
logger.debug("Sending %s to %s", self.net, device)
self.net.to(device)
if self.face_helper:
logger.debug("Sending face helper to %s", device)
self.face_helper.face_det.to(device)
self.face_helper.face_parse.to(device)

def get_device(self):
raise NotImplementedError("get_device must be implemented by subclasses")

def load_net(self) -> torch.Module:
raise NotImplementedError("load_net must be implemented by subclasses")

def restore_with_helper(
self,
np_image: np.ndarray,
restore_face: Callable[[np.ndarray], np.ndarray],
) -> np.ndarray:
try:
if self.net is None:
self.net = self.load_net()
except Exception:
logger.warning("Unable to load face-restoration model", exc_info=True)
return np_image

try:
self.send_model_to(self.get_device())
return restore_with_face_helper(np_image, self.face_helper, restore_face)
finally:
if shared.opts.face_restoration_unload:
self.send_model_to(devices.cpu)


def patch_facexlib(dirname: str) -> None:
import facexlib.detection
import facexlib.parsing

det_facex_load_file_from_url = facexlib.detection.load_file_from_url
par_facex_load_file_from_url = facexlib.parsing.load_file_from_url

def update_kwargs(kwargs):
return dict(kwargs, save_dir=dirname, model_dir=None)

def facex_load_file_from_url(**kwargs):
return det_facex_load_file_from_url(**update_kwargs(kwargs))

def facex_load_file_from_url2(**kwargs):
return par_facex_load_file_from_url(**update_kwargs(kwargs))

facexlib.detection.load_file_from_url = facex_load_file_from_url
facexlib.parsing.load_file_from_url = facex_load_file_from_url2
Loading

0 comments on commit 37458e6

Please sign in to comment.