Skip to content

Commit

Permalink
feat: ✨ add "To Device"
Browse files Browse the repository at this point in the history
To send an image or mask tensor to an available device.
Supports "cpu", "cuda", and "mps"
  • Loading branch information
melMass committed Mar 4, 2024
1 parent b7c8582 commit c28181f
Showing 1 changed file with 67 additions and 8 deletions.
75 changes: 67 additions & 8 deletions nodes/graph_utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from typing import Optional
import io, json, urllib.parse, urllib.request

import numpy as np
Expand All @@ -22,10 +23,58 @@ def get_image(filename, subfolder, folder_type):
return io.BytesIO(response.read())


class MTB_ToDevice:
"""Send a image or mask tensor to the given device."""

@classmethod
def INPUT_TYPES(cls):
devices = ["cpu"]
if torch.backends.mps.is_available():
devices.append("mps")
if torch.cuda.is_available():
for i in range(torch.cuda.device_count()):
devices.append(f"cuda{i}")

return {
"required": {
"ignore_errors": ("BOOLEAN", {"default": False}),
"device": (devices, {"default": "cpu"}),
},
"optional": {
"image": ("IMAGE",),
"mask": ("MASK",),
},
}

RETURN_TYPES = ("IMAGE", "MASK")
RETURN_NAMES = ("images", "masks")
CATEGORY = "mtb/utils"
FUNCTION = "to_device"

def to_device(
self,
*,
ignore_errors=False,
device="cuda",
image: Optional[torch.Tensor] = None,
mask: Optional[torch.Tensor] = None,
):
if not ignore_errors and image is None and mask is None:
raise ValueError(
"You must either provide an image or a mask,"
" use ignore_error to passthrough"
)
if image is not None:
image = image.to(device)
if mask is not None:
mask = mask.to(device)
return (image, mask)

class GetBatchFromHistory:
"""Very experimental node to load images from the history of the server.
Queue items without output are ignored in the count."""
Queue items without output are ignored in the count.
"""

@classmethod
def INPUT_TYPES(cls):
Expand All @@ -48,6 +97,7 @@ def INPUT_TYPES(cls):

def load_from_history(
self,
*,
enable=True,
count=0,
offset=0,
Expand Down Expand Up @@ -85,7 +135,9 @@ def load_batch_frames(self, response, offset, count, frames):
if "images" in node_output:
for image in node_output["images"]:
image_data = get_image(
image["filename"], image["subfolder"], image["type"]
image["filename"],
image["subfolder"],
image["type"],
)
output_images.append(image_data)

Expand All @@ -108,7 +160,7 @@ def load_batch_frames(self, response, offset, count, frames):


class AnyToString:
"""Tries to take any input and convert it to a string"""
"""Tries to take any input and convert it to a string."""

@classmethod
def INPUT_TYPES(cls):
Expand All @@ -128,18 +180,22 @@ def do_str(self, input):
elif isinstance(input, Image.Image):
return (f"PIL Image of size {input.size} and mode {input.mode}",)
elif isinstance(input, np.ndarray):
return (f"Numpy array of shape {input.shape} and dtype {input.dtype}",)
return (
f"Numpy array of shape {input.shape} and dtype {input.dtype}",
)

elif isinstance(input, dict):
return (f"Dictionary of {len(input)} items, with keys {input.keys()}",)
return (
f"Dictionary of {len(input)} items, with keys {input.keys()}",
)

else:
log.debug(f"Falling back to string conversion of {input}")
return (str(input),)


class StringReplace:
"""Basic string replacement"""
"""Basic string replacement."""

@classmethod
def INPUT_TYPES(cls):
Expand Down Expand Up @@ -182,7 +238,9 @@ def INPUT_TYPES(cls):
RETURN_TYPES = ("FLOAT", "INT")
RETURN_NAMES = ("result (float)", "result (int)")
CATEGORY = "mtb/math"
DESCRIPTION = "evaluate a simple math expression string (!! Fallsback to eval)"
DESCRIPTION = (
"evaluate a simple math expression string (!! Fallsback to eval)"
)

def eval_expression(self, expression, **kwargs):
import math
Expand Down Expand Up @@ -287,7 +345,7 @@ def set_range(


class ConcatImages:
"""Add images to batch"""
"""Add images to batch."""

RETURN_TYPES = ("IMAGE",)
FUNCTION = "concatenate_tensors"
Expand Down Expand Up @@ -320,4 +378,5 @@ def concatenate_tensors(self, reverse, **kwargs):
AnyToString,
ConcatImages,
MTB_MathExpression,
MTB_ToDevice,
]

0 comments on commit c28181f

Please sign in to comment.