Skip to content

Commit

Permalink
chore: 🔖 local updates
Browse files Browse the repository at this point in the history
- black -> Ruff
- wip nodes (Curve, FilterZ, Plot Batch Floats)
  • Loading branch information
melMass committed Dec 25, 2023
1 parent 90f3bc2 commit 6c5e5d3
Show file tree
Hide file tree
Showing 11 changed files with 617 additions and 50 deletions.
1 change: 1 addition & 0 deletions nodes/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""MTB Nodes module."""
2 changes: 1 addition & 1 deletion nodes/animation.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@


class AnimationBuilder:
"""Convenient way to manage basic animation maths at the core of many of my workflows"""
"""Simple maths for animation."""

@classmethod
def INPUT_TYPES(cls):
Expand Down
181 changes: 166 additions & 15 deletions nodes/batch.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,9 @@
import math
import os
from pathlib import Path
from typing import List
from io import BytesIO

import cv2
import folder_paths
import numpy as np
import torch
from PIL import Image

from ..log import log
from ..utils import apply_easing, pil2tensor
Expand Down Expand Up @@ -95,7 +92,9 @@ def generate_shapes(
res = []
for x in range(count):
# Initialize an image canvas
canvas = np.full((image_height, image_width, 3), bg_color, dtype=np.uint8)
canvas = np.full(
(image_height, image_width, 3), bg_color, dtype=np.uint8
)
mask = np.zeros((image_height, image_width), dtype=np.uint8)

# Compute the center point of the shape
Expand Down Expand Up @@ -125,7 +124,9 @@ def generate_shapes(
# Apply shading effects to a separate shading canvas
shading = np.zeros_like(canvas, dtype=np.float32)
shading[:, :, 0] = shadex * np.linspace(0, 1, image_width)
shading[:, :, 1] = shadey * np.linspace(0, 1, image_height).reshape(-1, 1)
shading[:, :, 1] = shadey * np.linspace(
0, 1, image_height
).reshape(-1, 1)
shading_canvas = cv2.addWeighted(
canvas.astype(np.float32), 1, shading, 1, 0
).astype(np.uint8)
Expand Down Expand Up @@ -158,7 +159,9 @@ def INPUT_TYPES(cls):
def fill_floats(self, floats, direction, value, count):
size = len(floats)
if size > count:
raise ValueError(f"Size ({size}) is less then target count ({count})")
raise ValueError(
f"Size ({size}) is less then target count ({count})"
)

rem = count - size
if direction == "tail":
Expand Down Expand Up @@ -261,7 +264,10 @@ class BatchMerge:
def INPUT_TYPES(cls):
return {
"required": {
"fusion_mode": (["add", "multiply", "average"], {"default": "average"}),
"fusion_mode": (
["add", "multiply", "average"],
{"default": "average"},
),
"fill": (["head", "tail"], {"default": "tail"}),
}
}
Expand All @@ -279,7 +285,9 @@ def merge_batches(self, fusion_mode, fill, **kwargs):
frame_count = img.shape[0]
if frame_count < max_frames:
fill_frame = img[0] if fill == "head" else img[-1]
fill_frames = fill_frame.repeat(max_frames - frame_count, 1, 1, 1)
fill_frames = fill_frame.repeat(
max_frames - frame_count, 1, 1, 1
)
adjusted_batch = (
torch.cat((fill_frames, img), dim=0)
if fill == "head"
Expand Down Expand Up @@ -352,9 +360,12 @@ def transform_batch(
shear=None,
):
if all(
self.get_num_elements(param) <= 0 for param in [x, y, zoom, angle, shear]
self.get_num_elements(param) <= 0
for param in [x, y, zoom, angle, shear]
):
raise ValueError("At least one transform parameter must be provided")
raise ValueError(
"At least one transform parameter must be provided"
)

keyframes = {"x": [], "y": [], "zoom": [], "angle": [], "shear": []}

Expand Down Expand Up @@ -397,6 +408,138 @@ def transform_batch(
return (torch.cat(res, dim=0),)


class PlotBatchFloat:
"""Plot floats"""

@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"width": ("INT", {"default": 768}),
"height": ("INT", {"default": 768}),
"point_size": ("INT", {"default": 4}),
"seed": ("INT", {"default": 1}),
}
}

RETURN_TYPES = ("IMAGE",)
RETURN_NAMES = ("plot",)
FUNCTION = "plot"
CATEGORY = "mtb/batch"

def plot(self, width, height, point_size, seed, **kwargs):
import matplotlib

matplotlib.use("Agg")
import matplotlib.pyplot as plt

fig, ax = plt.subplots(figsize=(width / 100, height / 100), dpi=100)
fig.set_edgecolor("black")
fig.patch.set_facecolor("#2e2e2e")
# Setting background color and grid
ax.set_facecolor("#2e2e2e") # Dark gray background
ax.grid(color="gray", linestyle="-", linewidth=0.5, alpha=0.5)

# Finding global min and max across all lists for scaling the plot
global_min = min(min(values) for values in kwargs.values())
global_max = max(max(values) for values in kwargs.values())

# Color cycle to ensure each plot has a distinct color
colormap = plt.cm.get_cmap("viridis", len(kwargs))
color_normalization_factor = (
0.5 if len(kwargs) == 1 else (len(kwargs) - 1)
)

# Plotting each list with a unique color
for i, (label, values) in enumerate(kwargs.items()):
color_value = i / color_normalization_factor
ax.plot(values, label=label, color=colormap(color_value))

ax.set_ylim(global_min, global_max) # Scaling the y-axis
ax.legend(
title="Legend",
title_fontsize="large",
fontsize="medium",
edgecolor="black",
)

# Setting labels and title
ax.set_xlabel("Time", fontsize="large", color="white")
ax.set_ylabel("Value", fontsize="large", color="white")
ax.set_title(
"Plot of Values over Time", fontsize="x-large", color="white"
)

# Adjusting tick colors to be visible on dark background
ax.tick_params(colors="white")

# Changing color of the axes border
for _, spine in ax.spines.items():
spine.set_edgecolor("white")

# Rendering the plot into a NumPy array
buf = BytesIO()
plt.savefig(buf, format="png", bbox_inches="tight")
buf.seek(0)
image = Image.open(buf)
plt.close(fig) # Closing the figure to free up memory

return (pil2tensor(image),)

def draw_point(self, image, point, color, point_size):
x, y = point
y = image.shape[0] - 1 - y # Invert Y-coordinate
half_size = point_size // 2
x_start, x_end = (
max(0, x - half_size),
min(image.shape[1], x + half_size + 1),
)
y_start, y_end = (
max(0, y - half_size),
min(image.shape[0], y + half_size + 1),
)
image[y_start:y_end, x_start:x_end] = color

def draw_line(self, image, start, end, color):
x1, y1 = start
x2, y2 = end

# Invert Y-coordinate
y1 = image.shape[0] - 1 - y1
y2 = image.shape[0] - 1 - y2

dx = x2 - x1
dy = y2 - y1
is_steep = abs(dy) > abs(dx)
if is_steep:
x1, y1 = y1, x1
x2, y2 = y2, x2
swapped = False
if x1 > x2:
x1, x2 = x2, x1
y1, y2 = y2, y1
swapped = True
dx = x2 - x1
dy = y2 - y1
error = int(dx / 2.0)
y = y1
ystep = None
if y1 < y2:
ystep = 1
else:
ystep = -1
for x in range(x1, x2 + 1):
coord = (y, x) if is_steep else (x, y)
image[coord] = color
error -= abs(dy)
if error < 0:
y += ystep
error += dx
if swapped:
image[(x1, y1)] = color
image[(x2, y2)] = color


DEFAULT_INTERPOLANT = lambda t: t * t * t * (t * (t * 6 - 15) + 10)


Expand Down Expand Up @@ -452,7 +595,9 @@ def generate_perlin_noise_2d(
delta = (res[0] / shape[0], res[1] / shape[1])
d = (shape[0] // res[0], shape[1] // res[1])
grid = (
np.mgrid[0 : res[0] : delta[0], 0 : res[1] : delta[1]].transpose(1, 2, 0)
np.mgrid[0 : res[0] : delta[0], 0 : res[1] : delta[1]].transpose(
1, 2, 0
)
% 1
)
# Gradients
Expand All @@ -471,7 +616,9 @@ def generate_perlin_noise_2d(
n00 = np.sum(np.dstack((grid[:, :, 0], grid[:, :, 1])) * g00, 2)
n10 = np.sum(np.dstack((grid[:, :, 0] - 1, grid[:, :, 1])) * g10, 2)
n01 = np.sum(np.dstack((grid[:, :, 0], grid[:, :, 1] - 1)) * g01, 2)
n11 = np.sum(np.dstack((grid[:, :, 0] - 1, grid[:, :, 1] - 1)) * g11, 2)
n11 = np.sum(
np.dstack((grid[:, :, 0] - 1, grid[:, :, 1] - 1)) * g11, 2
)
# Interpolation
t = interpolant(grid)
n0 = n00 * (1 - t[:, :, 0]) + t[:, :, 0] * n10
Expand Down Expand Up @@ -519,7 +666,10 @@ def generate_fractal_noise_2d(
amplitude = 1
for _ in range(octaves):
noise += amplitude * self.generate_perlin_noise_2d(
shape, (frequency * res[0], frequency * res[1]), tileable, interpolant
shape,
(frequency * res[0], frequency * res[1]),
tileable,
interpolant,
)
frequency *= lacunarity
amplitude *= persistence
Expand Down Expand Up @@ -622,4 +772,5 @@ def apply_shake(
BatchFloatFill,
BatchMerge,
BatchShake,
PlotBatchFloat,
]
36 changes: 36 additions & 0 deletions nodes/curve.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
import json


def deserialize_curve(curve):
if isinstance(curve, str):
curve = json.loads(curve)
return curve


def serialize_curve(curve):
if not isinstance(curve, str):
curve = json.dumps(curve)
return curve


class MTB_Curve:
"""A basic FLOAT_CURVE input node."""

@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"curve": ("FLOAT_CURVE",),
},
}

RETURN_TYPES = ("FLOAT_CURVE",)
FUNCTION = "do_curve"

CATEGORY = "mtb/curve"

def do_curve(self, curve):
return (curve,)


__nodes__ = [MTB_Curve]
41 changes: 29 additions & 12 deletions nodes/faceenhance.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,12 +40,12 @@ def get_models(cls):
log.warning("Face restoration models not found.")
return []
if not fr_models_path.exists():
log.warning(
f"No Face Restore checkpoints found at {fr_models_path} (if you've used mtb before these checkpoints were saved in upscale_models before)"
)
log.warning(
"For now we fallback to upscale_models but this will be removed in a future version"
)
# log.warning(
# f"No Face Restore checkpoints found at {fr_models_path} (if you've used mtb before these checkpoints were saved in upscale_models before)"
# )
# log.warning(
# "For now we fallback to upscale_models but this will be removed in a future version"
# )
if um_models_path.exists():
return [
x
Expand Down Expand Up @@ -98,7 +98,9 @@ def load_model(self, model_name, upscale=2, bg_upsampler=None):
(fr_root if fr_root.exists() else um_root) / model_name
).as_posix(),
upscale=upscale,
arch="clean" if basic else "RestoreFormer", # or original for v1.0 only
arch="clean"
if basic
else "RestoreFormer", # or original for v1.0 only
channel_multiplier=2, # 1 for v1.0 only
bg_upsampler=bg_upsampler,
)
Expand All @@ -122,7 +124,11 @@ def enhance(self, img: Image.Image, outscale=2):
imgt = imgt.movedim(-1, -3).to(device)

steps = imgt.shape[0] * comfy.utils.get_tiled_scale_steps(
imgt.shape[3], imgt.shape[2], tile_x=tile, tile_y=tile, overlap=overlap
imgt.shape[3],
imgt.shape[2],
tile_x=tile,
tile_y=tile,
overlap=overlap,
)

log.debug(f"Steps: {steps}")
Expand Down Expand Up @@ -199,10 +205,14 @@ def do_restore(
log.warning(f"Weight value has no effect for now. (value: {weight})")

if save_tmp_steps:
self.save_intermediate_images(cropped_faces, restored_faces, height, width)
self.save_intermediate_images(
cropped_faces, restored_faces, height, width
)
output = None
if restored_img is not None:
output = Image.fromarray(cv2.cvtColor(restored_img, cv2.COLOR_BGR2RGB))
output = Image.fromarray(
cv2.cvtColor(restored_img, cv2.COLOR_BGR2RGB)
)
# imwrite(restored_img, save_restore_path)

return pil2tensor(output)
Expand All @@ -218,7 +228,12 @@ def restore(
) -> Tuple[torch.Tensor]:
out = [
self.do_restore(
image[i], model, aligned, only_center_face, weight, save_tmp_steps
image[i],
model,
aligned,
only_center_face,
weight,
save_tmp_steps,
)
for i in range(image.size(0))
]
Expand All @@ -240,7 +255,9 @@ def get_step_image_path(self, step, idx):

return os.path.join(full_output_folder, file)

def save_intermediate_images(self, cropped_faces, restored_faces, height, width):
def save_intermediate_images(
self, cropped_faces, restored_faces, height, width
):
for idx, (cropped_face, restored_face) in enumerate(
zip(cropped_faces, restored_faces)
):
Expand Down
Loading

0 comments on commit 6c5e5d3

Please sign in to comment.