diff --git a/__init__.py b/__init__.py index 1d3644c..550f7f1 100644 --- a/__init__.py +++ b/__init__.py @@ -1,8 +1,8 @@ import traceback -from .log import log +from .log import log, blue_text, get_summary, get_label from .utils import here -from pathlib import Path import importlib +import os NODE_CLASS_MAPPINGS = {} NODE_CLASS_MAPPINGS_DEBUG = {} @@ -11,39 +11,64 @@ def load_nodes(): errors = [] nodes = [] - for filename in (here / "nodes").iterdir(): - if filename.suffix == '.py': + for filename in (here / "nodes").iterdir(): + if filename.suffix == ".py": module_name = filename.stem - module_path = filename.resolve().as_posix() - + try: - module = importlib.import_module(f".nodes.{module_name}",package=__package__) - _nodes = getattr(module, '__nodes__') - + module = importlib.import_module( + f".nodes.{module_name}", package=__package__ + ) + _nodes = getattr(module, "__nodes__") nodes.extend(_nodes) - # Use the `nodes` variable here as needed - log.debug(f"Imported __nodes__ from {module_name}") - + + log.debug(f"Imported {module_name} nodes") + + except AttributeError: + pass # wip nodes except Exception: error_message = traceback.format_exc().splitlines()[-1] - errors.append(f"Failed to import {module_name}. {error_message}") - # log.error(f"Failed to import {module_name}. {error_message}") + errors.append(f"Failed to import {module_name} because {error_message}") if errors: - log.error(f"Some nodes failed to load:\n\t" + "\n\t".join(errors) + "\n\n" + "Check that you properly installed the dependencies.\n" + "If you think this is a bug, please report it on the github page (https://github.com/melMass/comfy_mtb/issues)") + log.error( + f"Some nodes failed to load:\n\t" + + "\n\t".join(errors) + + "\n\n" + + "Check that you properly installed the dependencies.\n" + + "If you think this is a bug, please report it on the github page (https://github.com/melMass/comfy_mtb/issues)" + ) return nodes -nodes = load_nodes() + +# - REGISTER WEB EXTENSIONS +web_extensions_root = utils.comfy_dir / "web" / "extensions" +web_mtb = web_extensions_root / "mtb" + +if web_mtb.exists(): + log.debug(f"Web extensions folder found at {web_mtb}") +elif web_extensions_root.exists(): + os.symlink((here / "web"), web_mtb.as_posix()) +else: + log.error( + f"Comfy root probably not found automatically, please copy the folder {web_mtb} manually in the web/extensions folder of ComfyUI" + ) + +# - REGISTER NODES +nodes = load_nodes() for node_class in nodes: class_name = node_class.__name__ class_name = node_class.__name__ - - NODE_CLASS_MAPPINGS[class_name] = node_class - NODE_CLASS_MAPPINGS_DEBUG[class_name] = node_class.__doc__ - - -def get_summary(docstring): - return docstring.strip().split('\n\n', 1)[0] - -log.debug(f"Loaded the following nodes:\n\t" + "\n\t".join(f"{k}: {get_summary(doc) if doc else '-'}" for k,doc in NODE_CLASS_MAPPINGS_DEBUG.items())) + node_name = f"{get_label(class_name)} (mtb)" + NODE_CLASS_MAPPINGS[node_name] = node_class + NODE_CLASS_MAPPINGS_DEBUG[node_name] = node_class.__doc__ + + +log.debug( + f"Loaded the following nodes:\n\t" + + "\n\t".join( + f"{k}: {blue_text(get_summary(doc)) if doc else '-'}" + for k, doc in NODE_CLASS_MAPPINGS_DEBUG.items() + ) +) diff --git a/log.py b/log.py index 02133cd..1924bcc 100644 --- a/log.py +++ b/log.py @@ -1,8 +1,8 @@ -from logging import getLogger import logging +import re -class Formatter(logging.Formatter): +class Formatter(logging.Formatter): grey = "\x1b[38;20m" yellow = "\x1b[33;20m" red = "\x1b[31;20m" @@ -16,7 +16,7 @@ class Formatter(logging.Formatter): logging.INFO: grey + format + reset, logging.WARNING: yellow + format + reset, logging.ERROR: red + format + reset, - logging.CRITICAL: bold_red + format + reset + logging.CRITICAL: bold_red + format + reset, } def format(self, record): @@ -25,7 +25,6 @@ def format(self, record): return formatter.format(record) - def mklog(name, level=logging.DEBUG): logger = logging.getLogger(name) logger.setLevel(level) @@ -39,10 +38,22 @@ def mklog(name, level=logging.DEBUG): return logger -#- The main app logger +# - The main app logger log = mklog(__package__) - def log_user(arg): - print("\033[34mComfy MTB Utils:\033[0m {arg}") \ No newline at end of file + print("\033[34mComfy MTB Utils:\033[0m {arg}") + + +def get_summary(docstring): + return docstring.strip().split("\n\n", 1)[0] + + +def blue_text(text): + return f"\033[94m{text}\033[0m" + + +def get_label(label): + words = re.findall(r"(?:^|[A-Z])[a-z]*", label) + return " ".join(words).strip() diff --git a/nodes/conditions.py b/nodes/conditions.py index bc5bfbb..287fb7a 100644 --- a/nodes/conditions.py +++ b/nodes/conditions.py @@ -38,21 +38,21 @@ def INPUT_TYPES(cls): CATEGORY = "conditioning" def do_step(self, step, start_percent, end_percent): - start = int(step * start_percent / 100) end = int(step * end_percent / 100) return (step, start, end) -def install_default_styles(): +def install_default_styles(force=False): styles_dir = Path(folder_paths.base_path) / "styles" styles_dir.mkdir(parents=True, exist_ok=True) default_style = here / "styles.csv" dest_style = styles_dir / "default.csv" - log.debug("\n\n\n\tINSTALLING DEFAULT STYLE\n\n\n") - shutil.copy2(default_style.as_posix(), dest_style.as_posix()) - log.debug("\n\n\n\tDEFAULT STYLE INSTALLED\n\n\n") + + if force or not dest_style.exists(): + log.debug(f"Copying default style to {dest_style}") + shutil.copy2(default_style.as_posix(), dest_style.as_posix()) return dest_style @@ -101,6 +101,7 @@ def INPUT_TYPES(cls): def load_style(self, style_name): return (self.options[style_name][0], self.options[style_name][1]) + class TextToImage: """Utils to convert text to image using a font @@ -115,7 +116,6 @@ def __init__(self): @classmethod def INPUT_TYPES(cls): - fonts = list(Path(folder_paths.base_path).glob("**/*.ttf")) if not fonts: log.error( @@ -157,11 +157,11 @@ def INPUT_TYPES(cls): "color": ( "COLOR", {"default": "black"}, - ), + ), "background": ( "COLOR", {"default": "white"}, - ), + ), } } diff --git a/nodes/deep_bump.py b/nodes/deep_bump.py index fefcd1e..f9ece06 100644 --- a/nodes/deep_bump.py +++ b/nodes/deep_bump.py @@ -9,10 +9,12 @@ # Disable MS telemetry ort.disable_telemetry_events() + # - COLOR to NORMALS def color_to_normals(color_img, overlap, progress_callback): """Computes a normal map from the given color map. 'color_img' must be a numpy array - in C,H,W format (with C as RGB). 'overlap' must be one of 'SMALL', 'MEDIUM', 'LARGE'.""" + in C,H,W format (with C as RGB). 'overlap' must be one of 'SMALL', 'MEDIUM', 'LARGE'. + """ # Remove alpha & convert to grayscale img = np.mean(color_img[:3], axis=0, keepdimss=True) @@ -237,6 +239,8 @@ def normals_to_height(normals_img, seamless, progress_callback): # - ADDON class DeepBump: + """Normal & height maps generation from single pictures""" + def __init__(self): pass @@ -277,8 +281,6 @@ def apply( normals_to_curvature_blur_radius="SMALL", normals_to_height_seamless="TRUE", ): - - image = utils_inference.tensor2pil(image) in_img = np.transpose(image, (2, 0, 1)) / 255 @@ -302,4 +304,4 @@ def apply( return (utils_inference.pil2tensor(out_img),) -__nodes__ = [DeepBump] \ No newline at end of file +__nodes__ = [DeepBump] diff --git a/nodes/roop.py b/nodes/faceswap.py similarity index 86% rename from nodes/roop.py rename to nodes/faceswap.py index 32c1fca..7b504fb 100644 --- a/nodes/roop.py +++ b/nodes/faceswap.py @@ -15,12 +15,17 @@ from ..utils import pil2tensor, tensor2pil from ..log import mklog + # endregion logger = mklog(__name__) providers = onnxruntime.get_available_providers() + + # region roop node -class Roop: +class FaceSwap: + """Face swap using deepinsight/insightface models""" + model = None model_path = None @@ -29,7 +34,7 @@ def __init__(self) -> None: @staticmethod def get_models() -> List[Path]: - models_path = os.path.join(folder_paths.models_dir, "roop/*") + models_path = os.path.join(folder_paths.models_dir, "insightface/*") models = glob.glob(models_path) models = [Path(x) for x in models if x.endswith(".onnx") or x.endswith(".pth")] return models @@ -41,25 +46,25 @@ def INPUT_TYPES(cls): "image": ("IMAGE",), "reference": ("IMAGE",), "faces_index": ("STRING", {"default": "0"}), - "roop_model": ([x.name for x in cls.get_models()], {"default": "None"}), - }, - "optional": { - "debug": (["true", "false"], {"default": "false"}) - + "faceswap_model": ( + [x.name for x in cls.get_models()], + {"default": "None"}, + ), }, + "optional": {"debug": (["true", "false"], {"default": "false"})}, } RETURN_TYPES = ("IMAGE",) FUNCTION = "swap" - CATEGORY = "image" + CATEGORY = "face" def swap( self, image: torch.Tensor, reference: torch.Tensor, faces_index: str, - roop_model: str, - debug:str + faceswap_model: str, + debug: str, ): def do_swap(img): img = tensor2pil(img) @@ -67,27 +72,27 @@ def do_swap(img): face_ids = { int(x) for x in faces_index.strip(",").split(",") if x.isnumeric() } - model = self.getFaceSwapModel(roop_model) + model = self.getFaceSwapModel(faceswap_model) swapped = swap_face(ref, img, model, face_ids) return pil2tensor(swapped) - + batch_count = image.size(0) - - logger.info(f"Running roop swap (batch size: {batch_count})") - + + logger.info(f"Running insightface swap (batch size: {batch_count})") + if reference.size(0) != 1: raise ValueError("Reference image must have batch size 1") if batch_count == 1: image = do_swap(image) - + else: - image = [do_swap(image[i]) for i in range(batch_count)] - image = torch.cat(image, dim=0) - + image = [do_swap(image[i]) for i in range(batch_count)] + image = torch.cat(image, dim=0) + return (image,) def getFaceSwapModel(self, model_path: str): - model_path = os.path.join(folder_paths.models_dir, "roop", model_path) + model_path = os.path.join(folder_paths.models_dir, "insightface", model_path) if self.model_path is None or self.model_path != model_path: logger.info(f"Loading model {model_path}") self.model_path = model_path @@ -103,6 +108,7 @@ def getFaceSwapModel(self, model_path: str): # endregion + # region face swap utils def get_face_single(img_data: np.ndarray, face_index=0, det_size=(640, 640)): face_analyser = insightface.app.FaceAnalysis(name="buffalo_l", providers=providers) @@ -174,6 +180,4 @@ def swap_face( # endregion face swap utils -__nodes__ = [ - Roop -] \ No newline at end of file +__nodes__ = [FaceSwap] diff --git a/nodes/fun.py b/nodes/fun.py index 040c1ee..a53341e 100644 --- a/nodes/fun.py +++ b/nodes/fun.py @@ -1,7 +1,11 @@ import qrcode -from ..utils import pil2tensor, tensor2pil +from ..utils import pil2tensor from PIL import Image -class QRNode: + + +class QrCode: + """Basic QR Code generator""" + def __init__(self): pass @@ -29,8 +33,7 @@ def INPUT_TYPES(cls): FUNCTION = "do_qr" CATEGORY = "fun" - def do_qr(self, url, width, height,error_correct, box_size, border,invert): - + def do_qr(self, url, width, height, error_correct, box_size, border, invert): if error_correct == "L" or error_correct not in ["M", "Q", "H"]: error_correct = qrcode.constants.ERROR_CORRECT_L elif error_correct == "M": @@ -39,7 +42,7 @@ def do_qr(self, url, width, height,error_correct, box_size, border,invert): error_correct = qrcode.constants.ERROR_CORRECT_Q else: error_correct = qrcode.constants.ERROR_CORRECT_H - + qr = qrcode.QRCode( version=1, error_correction=error_correct, @@ -51,13 +54,13 @@ def do_qr(self, url, width, height,error_correct, box_size, border,invert): back_color = (255, 255, 255) if invert == "True" else (0, 0, 0) fill_color = (0, 0, 0) if invert == "True" else (255, 255, 255) - + code = img = qr.make_image(back_color=back_color, fill_color=fill_color) # that we now resize without filtering code = code.resize((width, height), Image.NEAREST) return (pil2tensor(code),) - - -__nodes__ = [QRNode] \ No newline at end of file + + +__nodes__ = [QrCode] diff --git a/nodes/image_processing.py b/nodes/image_processing.py index fc91498..deb83de 100644 --- a/nodes/image_processing.py +++ b/nodes/image_processing.py @@ -22,6 +22,8 @@ class ColorCorrect: + """Various color correction methods""" + def __init__(self): pass @@ -178,7 +180,9 @@ def correct( return (image,) -class HSVtoRGB: +class HsvToRgb: + """Convert HSV image to RGB""" + def __init__(self): pass @@ -206,7 +210,9 @@ def convert(self, image): return (torch.from_numpy(image),) -class RGBtoHSV: +class RgbToHsv: + """Convert RGB image to HSV""" + def __init__(self): pass @@ -233,6 +239,8 @@ def convert(self, image): class ImageCompare: + """Compare two images and return a difference image""" + def __init__(self): pass @@ -267,6 +275,8 @@ def compare(self, imageA: torch.Tensor, imageB: torch.Tensor, mode): class Denoise: + """Denoise an image using total variation minimization.""" + def __init__(self): pass @@ -296,6 +306,8 @@ def denoise(self, image: torch.Tensor, weight): class Blur: + """Blur an image using a Gaussian filter.""" + def __init__(self): pass @@ -338,6 +350,8 @@ def deglaze_np_img(np_img): class DeglazeImage: + """Remove adversarial noise from images""" + @classmethod def INPUT_TYPES(cls): return {"required": {"image": ("IMAGE",)}} @@ -352,6 +366,8 @@ def deglaze_image(self, image): class MaskToImage: + """Converts a mask (alpha) to an RGB image with a color and background""" + def __init__(self): pass @@ -391,6 +407,8 @@ def render_mask(self, mask, color, background): class ColoredImage: + """Constant color image of given size""" + def __init__(self) -> None: pass @@ -419,6 +437,8 @@ def render_img(self, color, width, height): class ImagePremultiply: + """Premultiply image with mask""" + def __init__(self): pass @@ -579,6 +599,8 @@ def resize( class SaveImageGrid: + """Save all the images in the input batch as a grid of images.""" + def __init__(self): self.output_dir = folder_paths.get_output_directory() self.type = "output" @@ -681,8 +703,8 @@ def save_images( __nodes__ = [ ColorCorrect, - HSVtoRGB, - RGBtoHSV, + HsvToRgb, + RgbToHsv, ImageCompare, Denoise, Blur, diff --git a/scripts/download_models.py b/scripts/download_models.py index 54a8f01..3de0441 100644 --- a/scripts/download_models.py +++ b/scripts/download_models.py @@ -17,14 +17,14 @@ "download_url": "https://github.com/HugoTini/DeepBump/raw/master/deepbump256.onnx", "destination": "deepbump", }, - "Roop": { + "Face Swap": { "size": 660, "download_url": [ "https://github.com/xinntao/facexlib/releases/download/v0.1.0/detection_mobilenet0.25_Final.pth", "https://github.com/xinntao/facexlib/releases/download/v0.1.0/detection_Resnet50_Final.pth", "https://huggingface.co/deepinsight/inswapper/resolve/main/inswapper_128.onnx", ], - "destination": "roop", + "destination": "insightface", }, } @@ -35,7 +35,6 @@ def download_model(download_url, destination): - if isinstance(download_url, list): for url in download_url: download_model(url, destination) @@ -99,7 +98,6 @@ def main(models_to_download): models_to_download_selected = {} def check_destination(urls, destination): - if isinstance(urls, list): for url in urls: check_destination(url, destination)