Skip to content

Commit

Permalink
feat: ✨ WIP batch from history
Browse files Browse the repository at this point in the history
Barely tested, and not much safe guards (I need to analyze more
history sessions first, but the basic idea is there)
  • Loading branch information
melMass committed Jul 16, 2023
1 parent 38f6147 commit cde7293
Showing 1 changed file with 95 additions and 4 deletions.
99 changes: 95 additions & 4 deletions nodes/image_interpolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,94 @@
import numpy as np
import subprocess
import comfy

from PIL import Image
import urllib.request
import urllib.parse
import json
import tensorflow as tf
import comfy.model_management as model_management
import io

from comfy.cli_args import args
from ..utils import pil2tensor


def get_image(filename, subfolder, folder_type):
data = {"filename": filename, "subfolder": subfolder, "type": folder_type}
url_values = urllib.parse.urlencode(data)
with urllib.request.urlopen(
"http://{}:{}/view?{}".format(args.listen, args.port, url_values)
) as response:
return io.BytesIO(response.read())


class GetBatchFromHistory:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"enable": ("BOOL", {"default": True}),
"count": ("INT", {"default": 1, "min": 0}),
"offset": ("INT", {"default": 0, "min": -1e9, "max": 1e9}),
},
}

RETURN_TYPES = ("IMAGE",)
RETURN_NAMES = "images"
CATEGORY = "animation"
FUNCTION = "load_from_history"

def load_from_history(
self,
enable=True,
count=0,
offset=0,
):
if not enable or count == 0:
return (torch.zeros(0),)
frames = []
server_address = "localhost:3000"

with urllib.request.urlopen(
"http://{}/history".format(server_address)
) as response:
history = json.loads(response.read())

output_images = []
for k, run in history.items():
for o in run["outputs"]:
for node_id in run["outputs"]:
node_output = run["outputs"][node_id]
if "images" in node_output:
images_output = []
for image in node_output["images"]:
image_data = get_image(
image["filename"], image["subfolder"], image["type"]
)
images_output.append(image_data)
output_images.extend(images_output)
if len(output_images) == 0:
return (torch.zeros(0),)
for i, image in enumerate(list(reversed(output_images))):
if i < offset:
continue
if i >= offset + count:
break
# Decode image as tensor
img = Image.open(image)
log.debug(f"Image from history {i} of shape {img.size}")
frames.append(img)

# Display the shape of the tensor
# print("Tensor shape:", image_tensor.shape)

# return (output_images,)

output = pil2tensor(
list(reversed(frames)),
)

return (output,)


class LoadFilmModel:
Expand Down Expand Up @@ -160,7 +245,7 @@ def concat_images(self, imageA: torch.Tensor, imageB: torch.Tensor):
return (self.concatenate_tensors(imageA, imageB),)


class ExportToProRes:
class ExportToProres:
"""Export to ProRes 4444 (Experimental)"""

def __init__(self):
Expand Down Expand Up @@ -196,7 +281,7 @@ def export_prores(
log.debug(f"Exporting to {output_dir / id}")

frames = tensor2np(images)
log.debug(f"Frames type {type(frames)}")
log.debug(f"Frames type {type(frames[0])}")
log.debug(f"Exporting {len(frames)} frames")

frames = [frame.astype(np.uint16) * 257 for frame in frames]
Expand Down Expand Up @@ -245,4 +330,10 @@ def export_prores(
return (out_path,)


__nodes__ = [LoadFilmModel, FilmInterpolation, ExportToProRes, ConcatImages]
__nodes__ = [
LoadFilmModel,
FilmInterpolation,
ExportToProres,
ConcatImages,
GetBatchFromHistory,
]

0 comments on commit cde7293

Please sign in to comment.