Skip to content

Commit

Permalink
refactor: ⚡️ small local fixes
Browse files Browse the repository at this point in the history
made while writting wiki
  • Loading branch information
melMass committed Oct 21, 2023
1 parent 049983d commit bcac665
Show file tree
Hide file tree
Showing 4 changed files with 95 additions and 5 deletions.
7 changes: 7 additions & 0 deletions errors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
class ModelNotFound(Exception):
def __init__(self, model_name, *args, **kwargs):
super().__init__(
f"The model {model_name} could not be found, make sure to download it using ComfyManager first.\nrepository: https://github.com/ltdrdata/ComfyUI-Manager",
*args,
**kwargs,
)
1 change: 0 additions & 1 deletion nodes/deep_bump.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import tempfile
from pathlib import Path

import folder_paths
import numpy as np
import onnxruntime as ort
import torch
Expand Down
7 changes: 3 additions & 4 deletions nodes/faceenhance.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
import folder_paths
import numpy as np
import torch
from basicsr.utils import imwrite
from comfy import model_management
from gfpgan import GFPGANer
from PIL import Image
Expand Down Expand Up @@ -247,16 +246,16 @@ def save_intermediate_images(self, cropped_faces, restored_faces, height, width)
):
face_id = idx + 1
file = self.get_step_image_path("cropped_faces", face_id)
imwrite(cropped_face, file)
cv2.imwrite(file, cropped_face)

file = self.get_step_image_path("cropped_faces_restored", face_id)
imwrite(restored_face, file)
cv2.imwrite(file, restored_face)

file = self.get_step_image_path("cropped_faces_compare", face_id)

# save comparison image
cmp_img = np.concatenate((cropped_face, restored_face), axis=1)
imwrite(cmp_img, file)
cv2.imwrite(file, cmp_img)


__nodes__ = [RestoreFace, LoadFaceEnhanceModel]
85 changes: 85 additions & 0 deletions nodes/image_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,18 @@
# log.warning("cv2.ximgproc.guidedFilter not found, use opencv-contrib-python")


def gaussian_kernel(kernel_size: int, sigma_x: float, sigma_y: float, device=None):
x, y = torch.meshgrid(
torch.linspace(-1, 1, kernel_size, device=device),
torch.linspace(-1, 1, kernel_size, device=device),
indexing="ij",
)
d_x = x * x / (2.0 * sigma_x * sigma_x)
d_y = y * y / (2.0 * sigma_y * sigma_y)
g = torch.exp(-(d_x + d_y))
return g / g.sum()


class ColorCorrect:
"""Various color correction methods"""

Expand Down Expand Up @@ -273,6 +285,78 @@ def blur(self, image: torch.Tensor, sigmaX, sigmaY):
return (torch.from_numpy(image),)


class Sharpen_:
"""Sharpens an image using a Gaussian kernel."""

@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"image": ("IMAGE",),
"sharpen_radius": (
"INT",
{"default": 1, "min": 1, "max": 31, "step": 1},
),
"sigma_x": (
"FLOAT",
{"default": 1.0, "min": 0.1, "max": 10.0, "step": 0.1},
),
"sigma_y": (
"FLOAT",
{"default": 1.0, "min": 0.1, "max": 10.0, "step": 0.1},
),
"alpha": (
"FLOAT",
{"default": 1.0, "min": 0.0, "max": 5.0, "step": 0.1},
),
},
}

RETURN_TYPES = ("IMAGE",)
FUNCTION = "do_sharp"
CATEGORY = "mtb/image processing"

def do_sharp(
self,
image: torch.Tensor,
sharpen_radius: int,
sigma_x: float,
sigma_y: float,
alpha: float,
):
if sharpen_radius == 0:
return (image,)

channels = image.shape[3]

kernel_size = 2 * sharpen_radius + 1
kernel = gaussian_kernel(kernel_size, sigma_x, sigma_y) * -(alpha * 10)

# Modify center of kernel to make it a sharpening kernel
center = kernel_size // 2
kernel[center, center] = kernel[center, center] - kernel.sum() + 1.0

kernel = kernel.repeat(channels, 1, 1).unsqueeze(1)
tensor_image = image.permute(0, 3, 1, 2)

tensor_image = F.pad(
tensor_image,
(sharpen_radius, sharpen_radius, sharpen_radius, sharpen_radius),
"reflect",
)
sharpened = F.conv2d(tensor_image, kernel, padding=center, groups=channels)

# Remove padding
sharpened = sharpened[
:, :, sharpen_radius:-sharpen_radius, sharpen_radius:-sharpen_radius
]

sharpened = sharpened.permute(0, 2, 3, 1)
result = torch.clamp(sharpened, 0, 1)

return (result,)


# https://github.com/lllyasviel/AdverseCleaner/blob/main/clean.py
# def deglaze_np_img(np_img):
# y = np_img.copy()
Expand Down Expand Up @@ -702,4 +786,5 @@ def tile_image(self, image: torch.Tensor, tiles: int = 2):
ImageResizeFactor,
SaveImageGrid_,
LoadImageFromUrl_,
Sharpen_,
]

0 comments on commit bcac665

Please sign in to comment.