From 8e267c0204ce5abe8e113fd401234d49f377646a Mon Sep 17 00:00:00 2001 From: melMass Date: Sun, 2 Jul 2023 21:13:00 +0200 Subject: [PATCH] =?UTF-8?q?fix:=20=F0=9F=90=9B=20separate=20faceswap=20mod?= =?UTF-8?q?el=20load?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit closes #5 --- nodes/faceswap.py | 85 ++++++++++++++++++++++++++--------------------- 1 file changed, 48 insertions(+), 37 deletions(-) diff --git a/nodes/faceswap.py b/nodes/faceswap.py index 7b504fb..6e11779 100644 --- a/nodes/faceswap.py +++ b/nodes/faceswap.py @@ -18,8 +18,40 @@ # endregion -logger = mklog(__name__) -providers = onnxruntime.get_available_providers() +log = mklog(__name__) + + +class LoadFaceSwapModel: + """Loads a faceswap model""" + + @staticmethod + def get_models() -> List[Path]: + models_path = os.path.join(folder_paths.models_dir, "insightface/*") + models = glob.glob(models_path) + models = [Path(x) for x in models if x.endswith(".onnx") or x.endswith(".pth")] + return models + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "faceswap_model": ( + [x.name for x in cls.get_models()], + {"default": "None"}, + ), + }, + } + + RETURN_TYPES = ("FACESWAP_MODEL",) + FUNCTION = "load_model" + CATEGORY = "face" + + def load_model(self, faceswap_model: str): + model_path = os.path.join( + folder_paths.models_dir, "insightface", faceswap_model + ) + log.info(f"Loading model {model_path}") + return (insightface.model_zoo.get_model(model_path),) # region roop node @@ -32,13 +64,6 @@ class FaceSwap: def __init__(self) -> None: pass - @staticmethod - def get_models() -> List[Path]: - models_path = os.path.join(folder_paths.models_dir, "insightface/*") - models = glob.glob(models_path) - models = [Path(x) for x in models if x.endswith(".onnx") or x.endswith(".pth")] - return models - @classmethod def INPUT_TYPES(cls): return { @@ -46,10 +71,7 @@ def INPUT_TYPES(cls): "image": ("IMAGE",), "reference": ("IMAGE",), "faces_index": ("STRING", {"default": "0"}), - "faceswap_model": ( - [x.name for x in cls.get_models()], - {"default": "None"}, - ), + "faceswap_model": ("FACESWAP_MODEL", {"default": "None"}), }, "optional": {"debug": (["true", "false"], {"default": "false"})}, } @@ -63,7 +85,7 @@ def swap( image: torch.Tensor, reference: torch.Tensor, faces_index: str, - faceswap_model: str, + faceswap_model, debug: str, ): def do_swap(img): @@ -72,13 +94,13 @@ def do_swap(img): face_ids = { int(x) for x in faces_index.strip(",").split(",") if x.isnumeric() } - model = self.getFaceSwapModel(faceswap_model) - swapped = swap_face(ref, img, model, face_ids) + + swapped = swap_face(ref, img, faceswap_model, face_ids) return pil2tensor(swapped) batch_count = image.size(0) - logger.info(f"Running insightface swap (batch size: {batch_count})") + log.info(f"Running insightface swap (batch size: {batch_count})") if reference.size(0) != 1: raise ValueError("Reference image must have batch size 1") @@ -91,31 +113,20 @@ def do_swap(img): return (image,) - def getFaceSwapModel(self, model_path: str): - model_path = os.path.join(folder_paths.models_dir, "insightface", model_path) - if self.model_path is None or self.model_path != model_path: - logger.info(f"Loading model {model_path}") - self.model_path = model_path - self.model = insightface.model_zoo.get_model( - model_path, providers=providers - ) - else: - logger.info("Using cached model") - - logger.info("Model loaded") - return self.model - # endregion # region face swap utils def get_face_single(img_data: np.ndarray, face_index=0, det_size=(640, 640)): - face_analyser = insightface.app.FaceAnalysis(name="buffalo_l", providers=providers) + face_analyser = insightface.app.FaceAnalysis( + name="buffalo_l", root=os.path.join(folder_paths.models_dir, "insightface") + ) face_analyser.prepare(ctx_id=0, det_size=det_size) face = face_analyser.get(img_data) if len(face) == 0 and det_size[0] > 320 and det_size[1] > 320: + log.debug("No face detected, trying again with smaller image") det_size_half = (det_size[0] // 2, det_size[1] // 2) return get_face_single(img_data, face_index=face_index, det_size=det_size_half) @@ -139,7 +150,7 @@ def swap_face( ) -> Image.Image: if faces_index is None: faces_index = {0} - logger.info(f"Swapping faces: {faces_index}") + log.debug(f"Swapping faces: {faces_index}") result_image = target_img converted = convert_to_sd(target_img) scale, fn = converted[0], converted[1] @@ -167,17 +178,17 @@ def swap_face( if target_face is not None: result = face_swapper_model.get(result, target_face, source_face) else: - logger.warning(f"No target face found for {face_num}") + log.warning(f"No target face found for {face_num}") result_image = Image.fromarray(cv2.cvtColor(result, cv2.COLOR_BGR2RGB)) else: - logger.warning("No source face found") + log.warning("No source face found") else: - logger.error("No face swap model provided") + log.error("No face swap model provided") return result_image # endregion face swap utils -__nodes__ = [FaceSwap] +__nodes__ = [FaceSwap, LoadFaceSwapModel]