Skip to content

Commit

Permalink
fix: ✨ small edits
Browse files Browse the repository at this point in the history
  • Loading branch information
melMass committed Jun 28, 2023
1 parent 2dae020 commit bcf55ca
Show file tree
Hide file tree
Showing 8 changed files with 149 additions and 84 deletions.
75 changes: 50 additions & 25 deletions __init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import traceback
from .log import log
from .log import log, blue_text, get_summary, get_label
from .utils import here
from pathlib import Path
import importlib
import os

NODE_CLASS_MAPPINGS = {}
NODE_CLASS_MAPPINGS_DEBUG = {}
Expand All @@ -11,39 +11,64 @@
def load_nodes():
errors = []
nodes = []
for filename in (here / "nodes").iterdir():
if filename.suffix == '.py':
for filename in (here / "nodes").iterdir():
if filename.suffix == ".py":
module_name = filename.stem
module_path = filename.resolve().as_posix()


try:
module = importlib.import_module(f".nodes.{module_name}",package=__package__)
_nodes = getattr(module, '__nodes__')

module = importlib.import_module(
f".nodes.{module_name}", package=__package__
)
_nodes = getattr(module, "__nodes__")
nodes.extend(_nodes)
# Use the `nodes` variable here as needed
log.debug(f"Imported __nodes__ from {module_name}")


log.debug(f"Imported {module_name} nodes")

except AttributeError:
pass # wip nodes
except Exception:
error_message = traceback.format_exc().splitlines()[-1]
errors.append(f"Failed to import {module_name}. {error_message}")
# log.error(f"Failed to import {module_name}. {error_message}")
errors.append(f"Failed to import {module_name} because {error_message}")

if errors:
log.error(f"Some nodes failed to load:\n\t" + "\n\t".join(errors) + "\n\n" + "Check that you properly installed the dependencies.\n" + "If you think this is a bug, please report it on the github page (https://github.com/melMass/comfy_mtb/issues)")
log.error(
f"Some nodes failed to load:\n\t"
+ "\n\t".join(errors)
+ "\n\n"
+ "Check that you properly installed the dependencies.\n"
+ "If you think this is a bug, please report it on the github page (https://github.com/melMass/comfy_mtb/issues)"
)

return nodes
nodes = load_nodes()


# - REGISTER WEB EXTENSIONS
web_extensions_root = utils.comfy_dir / "web" / "extensions"
web_mtb = web_extensions_root / "mtb"

if web_mtb.exists():
log.debug(f"Web extensions folder found at {web_mtb}")
elif web_extensions_root.exists():
os.symlink((here / "web"), web_mtb.as_posix())
else:
log.error(
f"Comfy root probably not found automatically, please copy the folder {web_mtb} manually in the web/extensions folder of ComfyUI"
)

# - REGISTER NODES
nodes = load_nodes()
for node_class in nodes:
class_name = node_class.__name__
class_name = node_class.__name__

NODE_CLASS_MAPPINGS[class_name] = node_class
NODE_CLASS_MAPPINGS_DEBUG[class_name] = node_class.__doc__


def get_summary(docstring):
return docstring.strip().split('\n\n', 1)[0]

log.debug(f"Loaded the following nodes:\n\t" + "\n\t".join(f"{k}: {get_summary(doc) if doc else '-'}" for k,doc in NODE_CLASS_MAPPINGS_DEBUG.items()))
node_name = f"{get_label(class_name)} (mtb)"
NODE_CLASS_MAPPINGS[node_name] = node_class
NODE_CLASS_MAPPINGS_DEBUG[node_name] = node_class.__doc__


log.debug(
f"Loaded the following nodes:\n\t"
+ "\n\t".join(
f"{k}: {blue_text(get_summary(doc)) if doc else '-'}"
for k, doc in NODE_CLASS_MAPPINGS_DEBUG.items()
)
)
25 changes: 18 additions & 7 deletions log.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from logging import getLogger
import logging
import re

class Formatter(logging.Formatter):

class Formatter(logging.Formatter):
grey = "\x1b[38;20m"
yellow = "\x1b[33;20m"
red = "\x1b[31;20m"
Expand All @@ -16,7 +16,7 @@ class Formatter(logging.Formatter):
logging.INFO: grey + format + reset,
logging.WARNING: yellow + format + reset,
logging.ERROR: red + format + reset,
logging.CRITICAL: bold_red + format + reset
logging.CRITICAL: bold_red + format + reset,
}

def format(self, record):
Expand All @@ -25,7 +25,6 @@ def format(self, record):
return formatter.format(record)



def mklog(name, level=logging.DEBUG):
logger = logging.getLogger(name)
logger.setLevel(level)
Expand All @@ -39,10 +38,22 @@ def mklog(name, level=logging.DEBUG):
return logger


#- The main app logger
# - The main app logger
log = mklog(__package__)



def log_user(arg):
print("\033[34mComfy MTB Utils:\033[0m {arg}")
print("\033[34mComfy MTB Utils:\033[0m {arg}")


def get_summary(docstring):
return docstring.strip().split("\n\n", 1)[0]


def blue_text(text):
return f"\033[94m{text}\033[0m"


def get_label(label):
words = re.findall(r"(?:^|[A-Z])[a-z]*", label)
return " ".join(words).strip()
16 changes: 8 additions & 8 deletions nodes/conditions.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,21 +38,21 @@ def INPUT_TYPES(cls):
CATEGORY = "conditioning"

def do_step(self, step, start_percent, end_percent):

start = int(step * start_percent / 100)
end = int(step * end_percent / 100)

return (step, start, end)


def install_default_styles():
def install_default_styles(force=False):
styles_dir = Path(folder_paths.base_path) / "styles"
styles_dir.mkdir(parents=True, exist_ok=True)
default_style = here / "styles.csv"
dest_style = styles_dir / "default.csv"
log.debug("\n\n\n\tINSTALLING DEFAULT STYLE\n\n\n")
shutil.copy2(default_style.as_posix(), dest_style.as_posix())
log.debug("\n\n\n\tDEFAULT STYLE INSTALLED\n\n\n")

if force or not dest_style.exists():
log.debug(f"Copying default style to {dest_style}")
shutil.copy2(default_style.as_posix(), dest_style.as_posix())

return dest_style

Expand Down Expand Up @@ -101,6 +101,7 @@ def INPUT_TYPES(cls):
def load_style(self, style_name):
return (self.options[style_name][0], self.options[style_name][1])


class TextToImage:
"""Utils to convert text to image using a font
Expand All @@ -115,7 +116,6 @@ def __init__(self):

@classmethod
def INPUT_TYPES(cls):

fonts = list(Path(folder_paths.base_path).glob("**/*.ttf"))
if not fonts:
log.error(
Expand Down Expand Up @@ -157,11 +157,11 @@ def INPUT_TYPES(cls):
"color": (
"COLOR",
{"default": "black"},
),
),
"background": (
"COLOR",
{"default": "white"},
),
),
}
}

Expand Down
10 changes: 6 additions & 4 deletions nodes/deep_bump.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,12 @@
# Disable MS telemetry
ort.disable_telemetry_events()


# - COLOR to NORMALS
def color_to_normals(color_img, overlap, progress_callback):
"""Computes a normal map from the given color map. 'color_img' must be a numpy array
in C,H,W format (with C as RGB). 'overlap' must be one of 'SMALL', 'MEDIUM', 'LARGE'."""
in C,H,W format (with C as RGB). 'overlap' must be one of 'SMALL', 'MEDIUM', 'LARGE'.
"""

# Remove alpha & convert to grayscale
img = np.mean(color_img[:3], axis=0, keepdimss=True)
Expand Down Expand Up @@ -237,6 +239,8 @@ def normals_to_height(normals_img, seamless, progress_callback):

# - ADDON
class DeepBump:
"""Normal & height maps generation from single pictures"""

def __init__(self):
pass

Expand Down Expand Up @@ -277,8 +281,6 @@ def apply(
normals_to_curvature_blur_radius="SMALL",
normals_to_height_seamless="TRUE",
):


image = utils_inference.tensor2pil(image)

in_img = np.transpose(image, (2, 0, 1)) / 255
Expand All @@ -302,4 +304,4 @@ def apply(
return (utils_inference.pil2tensor(out_img),)


__nodes__ = [DeepBump]
__nodes__ = [DeepBump]
50 changes: 27 additions & 23 deletions nodes/roop.py → nodes/faceswap.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,17 @@

from ..utils import pil2tensor, tensor2pil
from ..log import mklog

# endregion

logger = mklog(__name__)
providers = onnxruntime.get_available_providers()


# region roop node
class Roop:
class FaceSwap:
"""Face swap using deepinsight/insightface models"""

model = None
model_path = None

Expand All @@ -29,7 +34,7 @@ def __init__(self) -> None:

@staticmethod
def get_models() -> List[Path]:
models_path = os.path.join(folder_paths.models_dir, "roop/*")
models_path = os.path.join(folder_paths.models_dir, "insightface/*")
models = glob.glob(models_path)
models = [Path(x) for x in models if x.endswith(".onnx") or x.endswith(".pth")]
return models
Expand All @@ -41,53 +46,53 @@ def INPUT_TYPES(cls):
"image": ("IMAGE",),
"reference": ("IMAGE",),
"faces_index": ("STRING", {"default": "0"}),
"roop_model": ([x.name for x in cls.get_models()], {"default": "None"}),
},
"optional": {
"debug": (["true", "false"], {"default": "false"})

"faceswap_model": (
[x.name for x in cls.get_models()],
{"default": "None"},
),
},
"optional": {"debug": (["true", "false"], {"default": "false"})},
}

RETURN_TYPES = ("IMAGE",)
FUNCTION = "swap"
CATEGORY = "image"
CATEGORY = "face"

def swap(
self,
image: torch.Tensor,
reference: torch.Tensor,
faces_index: str,
roop_model: str,
debug:str
faceswap_model: str,
debug: str,
):
def do_swap(img):
img = tensor2pil(img)
ref = tensor2pil(reference)
face_ids = {
int(x) for x in faces_index.strip(",").split(",") if x.isnumeric()
}
model = self.getFaceSwapModel(roop_model)
model = self.getFaceSwapModel(faceswap_model)
swapped = swap_face(ref, img, model, face_ids)
return pil2tensor(swapped)

batch_count = image.size(0)
logger.info(f"Running roop swap (batch size: {batch_count})")

logger.info(f"Running insightface swap (batch size: {batch_count})")

if reference.size(0) != 1:
raise ValueError("Reference image must have batch size 1")
if batch_count == 1:
image = do_swap(image)

else:
image = [do_swap(image[i]) for i in range(batch_count)]
image = torch.cat(image, dim=0)
image = [do_swap(image[i]) for i in range(batch_count)]
image = torch.cat(image, dim=0)

return (image,)

def getFaceSwapModel(self, model_path: str):
model_path = os.path.join(folder_paths.models_dir, "roop", model_path)
model_path = os.path.join(folder_paths.models_dir, "insightface", model_path)
if self.model_path is None or self.model_path != model_path:
logger.info(f"Loading model {model_path}")
self.model_path = model_path
Expand All @@ -103,6 +108,7 @@ def getFaceSwapModel(self, model_path: str):

# endregion


# region face swap utils
def get_face_single(img_data: np.ndarray, face_index=0, det_size=(640, 640)):
face_analyser = insightface.app.FaceAnalysis(name="buffalo_l", providers=providers)
Expand Down Expand Up @@ -174,6 +180,4 @@ def swap_face(
# endregion face swap utils


__nodes__ = [
Roop
]
__nodes__ = [FaceSwap]
Loading

0 comments on commit bcf55ca

Please sign in to comment.