Skip to content

Commit

Permalink
fix: ✨ refactor existing
Browse files Browse the repository at this point in the history
  • Loading branch information
melMass committed Aug 10, 2023
1 parent 2eccba4 commit 1144466
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 20 deletions.
1 change: 1 addition & 0 deletions nodes/graph_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@


def get_image(filename, subfolder, folder_type):
log.debug(f"Getting image {filename} from {subfolder} of {folder_type}")
data = {"filename": filename, "subfolder": subfolder, "type": folder_type}
url_values = urllib.parse.urlencode(data)
with urllib.request.urlopen(
Expand Down
70 changes: 50 additions & 20 deletions nodes/transform.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
import torch
import torchvision.transforms.functional as F
from ..utils import log, hex_to_rgb, tensor2pil, pil2tensor
from math import sqrt, ceil
from typing import cast
from PIL import Image


class TransformImage:
Expand All @@ -14,15 +18,19 @@ def INPUT_TYPES(cls):
return {
"required": {
"image": ("IMAGE",),
"x": ("FLOAT", {"default": 0}),
"y": ("FLOAT", {"default": 0}),
"zoom": ("FLOAT", {"default": 1.0, "min": 0.001}),
"angle": ("FLOAT", {"default": 0}),
"shear": ("FLOAT", {"default": 0}),
"x": ("FLOAT", {"default": 0, "step": 1, "min": -4096, "max": 4096}),
"y": ("FLOAT", {"default": 0, "step": 1, "min": -4096, "max": 4096}),
"zoom": ("FLOAT", {"default": 1.0, "min": 0.001, "step": 0.01}),
"angle": ("FLOAT", {"default": 0, "step": 1, "min": -360, "max": 360}),
"shear": (
"FLOAT",
{"default": 0, "step": 1, "min": -4096, "max": 4096},
),
"border_handling": (
["edge", "constant", "reflect", "symmetric"],
{"default": "edge"},
),
"constant_color": ("COLOR", {"default": "black"}),
},
}

Expand All @@ -36,45 +44,67 @@ def transform(
x: float,
y: float,
zoom: float,
angle: int,
shear,
angle: float,
shear: float,
border_handling="edge",
constant_color=None,
):
x = int(x)
y = int(y)
angle = int(angle)

log.debug(f"Zoom: {zoom} | x: {x}, y: {y}, angle: {angle}, shear: {shear}")

if image.size(0) == 0:
return (torch.zeros(0),)
transformed_images = []
frames_count, frame_height, frame_width, frame_channel_count = image.size()

new_height, new_width = int(frame_height * zoom), int(frame_width * zoom)

log.debug(f"New height: {new_height}, New width: {new_width}")

# - Calculate diagonal of the original image
diagonal = sqrt(frame_width**2 + frame_height**2)
max_padding = ceil(diagonal * zoom - min(frame_width, frame_height))
# Calculate padding for zoom
pw = int(frame_width - new_width)
ph = int(frame_height - new_height)
padding = [max(0, pw + x), max(0, ph + y), max(0, pw - x), max(0, ph - y)]

for img in image:
img = img.permute(2, 0, 1)
new_height, new_width = int(frame_height * zoom), int(frame_width * zoom)
pw = int(frame_width - new_width)
ph = int(frame_height - new_height)
pw += max_padding
ph += max_padding

padding = [max(0, pw + x), max(0, ph + y), max(0, pw - x), max(0, ph - y)]

padding = [int(i) for i in padding]
constant_color = hex_to_rgb(constant_color)
log.debug(f"Fill Tuple: {constant_color}")

for img in tensor2pil(image):
img = F.pad(
img, # transformed_frame,
padding=padding,
padding_mode=border_handling,
fill=constant_color or 0,
)

img = F.affine(img, angle=angle, scale=zoom, translate=[x, y], shear=shear)
img = cast(
Image.Image,
F.affine(img, angle=angle, scale=zoom, translate=[x, y], shear=shear),
)

crop = [ph + y, -(ph - y), x + pw, -(pw - x)]
left = abs(padding[0])
upper = abs(padding[1])
right = img.width - abs(padding[2])
bottom = img.height - abs(padding[3])

img = img[:, crop[0] : crop[1], crop[2] : crop[3]]
# log.debug("crop is [:,top:bottom, left:right] for tensors")
log.debug("crop is [left, top, right, bottom] for PIL")
log.debug(f"crop is {left}, {upper}, {right}, {bottom}")
img = img.crop((left, upper, right, bottom))

img = img.permute(1, 2, 0)
transformed_images.append(img.unsqueeze(0))
transformed_images.append(img)

return (torch.cat(transformed_images, dim=0),)
return (pil2tensor(transformed_images),)


__nodes__ = [TransformImage]
5 changes: 5 additions & 0 deletions utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,11 @@
comfy_mode = "venv"


def hex_to_rgb(hex_color):
hex_color = hex_color.lstrip("#")
return tuple(int(hex_color[i : i + 2], 16) for i in (0, 2, 4))


# region MISC Utilities
def add_path(path, prepend=False):
if isinstance(path, list):
Expand Down

0 comments on commit 1144466

Please sign in to comment.