Skip to content

Commit

Permalink
fix: 🐛 separate faceswap model load
Browse files Browse the repository at this point in the history
closes #5
  • Loading branch information
melMass committed Jul 2, 2023
1 parent f8dc768 commit 8e267c0
Showing 1 changed file with 48 additions and 37 deletions.
85 changes: 48 additions & 37 deletions nodes/faceswap.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,40 @@

# endregion

logger = mklog(__name__)
providers = onnxruntime.get_available_providers()
log = mklog(__name__)


class LoadFaceSwapModel:
"""Loads a faceswap model"""

@staticmethod
def get_models() -> List[Path]:
models_path = os.path.join(folder_paths.models_dir, "insightface/*")
models = glob.glob(models_path)
models = [Path(x) for x in models if x.endswith(".onnx") or x.endswith(".pth")]
return models

@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"faceswap_model": (
[x.name for x in cls.get_models()],
{"default": "None"},
),
},
}

RETURN_TYPES = ("FACESWAP_MODEL",)
FUNCTION = "load_model"
CATEGORY = "face"

def load_model(self, faceswap_model: str):
model_path = os.path.join(
folder_paths.models_dir, "insightface", faceswap_model
)
log.info(f"Loading model {model_path}")
return (insightface.model_zoo.get_model(model_path),)


# region roop node
Expand All @@ -32,24 +64,14 @@ class FaceSwap:
def __init__(self) -> None:
pass

@staticmethod
def get_models() -> List[Path]:
models_path = os.path.join(folder_paths.models_dir, "insightface/*")
models = glob.glob(models_path)
models = [Path(x) for x in models if x.endswith(".onnx") or x.endswith(".pth")]
return models

@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"image": ("IMAGE",),
"reference": ("IMAGE",),
"faces_index": ("STRING", {"default": "0"}),
"faceswap_model": (
[x.name for x in cls.get_models()],
{"default": "None"},
),
"faceswap_model": ("FACESWAP_MODEL", {"default": "None"}),
},
"optional": {"debug": (["true", "false"], {"default": "false"})},
}
Expand All @@ -63,7 +85,7 @@ def swap(
image: torch.Tensor,
reference: torch.Tensor,
faces_index: str,
faceswap_model: str,
faceswap_model,
debug: str,
):
def do_swap(img):
Expand All @@ -72,13 +94,13 @@ def do_swap(img):
face_ids = {
int(x) for x in faces_index.strip(",").split(",") if x.isnumeric()
}
model = self.getFaceSwapModel(faceswap_model)
swapped = swap_face(ref, img, model, face_ids)

swapped = swap_face(ref, img, faceswap_model, face_ids)
return pil2tensor(swapped)

batch_count = image.size(0)

logger.info(f"Running insightface swap (batch size: {batch_count})")
log.info(f"Running insightface swap (batch size: {batch_count})")

if reference.size(0) != 1:
raise ValueError("Reference image must have batch size 1")
Expand All @@ -91,31 +113,20 @@ def do_swap(img):

return (image,)

def getFaceSwapModel(self, model_path: str):
model_path = os.path.join(folder_paths.models_dir, "insightface", model_path)
if self.model_path is None or self.model_path != model_path:
logger.info(f"Loading model {model_path}")
self.model_path = model_path
self.model = insightface.model_zoo.get_model(
model_path, providers=providers
)
else:
logger.info("Using cached model")

logger.info("Model loaded")
return self.model


# endregion


# region face swap utils
def get_face_single(img_data: np.ndarray, face_index=0, det_size=(640, 640)):
face_analyser = insightface.app.FaceAnalysis(name="buffalo_l", providers=providers)
face_analyser = insightface.app.FaceAnalysis(
name="buffalo_l", root=os.path.join(folder_paths.models_dir, "insightface")
)
face_analyser.prepare(ctx_id=0, det_size=det_size)
face = face_analyser.get(img_data)

if len(face) == 0 and det_size[0] > 320 and det_size[1] > 320:
log.debug("No face detected, trying again with smaller image")
det_size_half = (det_size[0] // 2, det_size[1] // 2)
return get_face_single(img_data, face_index=face_index, det_size=det_size_half)

Expand All @@ -139,7 +150,7 @@ def swap_face(
) -> Image.Image:
if faces_index is None:
faces_index = {0}
logger.info(f"Swapping faces: {faces_index}")
log.debug(f"Swapping faces: {faces_index}")
result_image = target_img
converted = convert_to_sd(target_img)
scale, fn = converted[0], converted[1]
Expand Down Expand Up @@ -167,17 +178,17 @@ def swap_face(
if target_face is not None:
result = face_swapper_model.get(result, target_face, source_face)
else:
logger.warning(f"No target face found for {face_num}")
log.warning(f"No target face found for {face_num}")

result_image = Image.fromarray(cv2.cvtColor(result, cv2.COLOR_BGR2RGB))
else:
logger.warning("No source face found")
log.warning("No source face found")
else:
logger.error("No face swap model provided")
log.error("No face swap model provided")
return result_image


# endregion face swap utils


__nodes__ = [FaceSwap]
__nodes__ = [FaceSwap, LoadFaceSwapModel]

0 comments on commit 8e267c0

Please sign in to comment.