Skip to content

Commit

Permalink
feat: 💄 add a few more batch nodes
Browse files Browse the repository at this point in the history
  • Loading branch information
melMass committed Oct 9, 2023
1 parent 4605f74 commit c1d42de
Show file tree
Hide file tree
Showing 4 changed files with 118 additions and 21 deletions.
3 changes: 3 additions & 0 deletions node_list.json
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@
"Animation Builder (mtb)": "Convenient way to manage basic animation maths at the core of many of my workflows",
"Any To String (mtb)": "Tries to take any input and convert it to a string",
"Batch Float (mtb)": "Generates a batch of float values with interpolation",
"Batch Float Assemble (mtb)": "Assembles mutiple batches of floats into a single stream (batch)",
"Batch Float Fill (mtb)": "Fills a batch float with a single value until it reaches the target length",
"Batch Make (mtb)": "Simply duplicates the input frame as a batch",
"Batch Shape (mtb)": "Generates a batch of 2D shapes with optional shading (experimental)",
"Batch Transform (mtb)": "Transform a batch of images using a batch of keyframes",
"Bbox (mtb)": "The bounding box (BBOX) custom type used by other nodes",
Expand Down
86 changes: 85 additions & 1 deletion nodes/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,29 @@ def hex_to_rgb(hex_color, bgr=False):
return tuple(int(hex_color[i : i + 2], 16) for i in (0, 2, 4))


class BatchMake:
"""Simply duplicates the input frame as a batch"""

@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"image": ("IMAGE",),
"count": ("INT", {"default": 1}),
}
}

RETURN_TYPES = ("IMAGE",)
FUNCTION = "generate_batch"
CATEGORY = "mtb/batch"

def generate_batch(self, image: torch.Tensor, count):
if len(image.shape) == 3:
image = image.unsqueeze(0)

return (image.repeat(count, 1, 1, 1),)


class BatchShape:
"""Generates a batch of 2D shapes with optional shading (experimental)"""

Expand Down Expand Up @@ -112,6 +135,60 @@ def generate_shapes(
return (pil2tensor(res),)


class BatchFloatFill:
"""Fills a batch float with a single value until it reaches the target length"""

@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"floats": ("FLOATS",),
"direction": (["head", "tail"], {"default": "tail"}),
"value": ("FLOAT", {"default": 0.0}),
"count": ("INT", {"default": 1}),
}
}

FUNCTION = "fill_floats"
RETURN_TYPES = ("FLOATS",)
CATEGORY = "mtb/batch"

def fill_floats(self, floats, direction, value, count):
size = len(floats)
if size > count:
raise ValueError(f"Size ({size}) is less then target count ({count})")

rem = count - size
if direction == "tail":
floats = floats + [value] * rem
else:
floats = [value] * rem + floats
return (floats,)


class BatchFloatAssemble:
"""Assembles mutiple batches of floats into a single stream (batch)"""

@classmethod
def INPUT_TYPES(cls):
return {"required": {"reverse": ("BOOLEAN", {"default": False})}}

FUNCTION = "assemble_floats"
RETURN_TYPES = ("FLOATS",)
CATEGORY = "mtb/batch"

def assemble_floats(self, reverse, **kwargs):
res = []
if reverse:
for x in reversed(kwargs.values()):
res += x
else:
for x in kwargs.values():
res += x

return (res,)


class BatchFloat:
"""Generates a batch of float values with interpolation"""

Expand Down Expand Up @@ -257,4 +334,11 @@ def transform_batch(
return (torch.cat(res, dim=0),)


__nodes__ = [BatchFloat, Batch2dTransform, BatchShape]
__nodes__ = [
BatchFloat,
Batch2dTransform,
BatchShape,
BatchMake,
BatchFloatAssemble,
BatchFloatFill,
]
20 changes: 20 additions & 0 deletions web/comfy_shared.js
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,27 @@ export function getWidgetType(config) {
}
return { type, linkType }
}
export const setupDynamicConnections = (nodeType, prefix, inputType) => {
const onNodeCreated = nodeType.prototype.onNodeCreated
nodeType.prototype.onNodeCreated = function () {
const r = onNodeCreated ? onNodeCreated.apply(this, arguments) : undefined
this.addInput(`${prefix}_1`, inputType)
return r
}

const onConnectionsChange = nodeType.prototype.onConnectionsChange
nodeType.prototype.onConnectionsChange = function (
type,
index,
connected,
link_info
) {
const r = onConnectionsChange
? onConnectionsChange.apply(this, arguments)
: undefined
dynamic_connection(this, index, connected, `${prefix}_`, inputType)
}
}
export const dynamic_connection = (
node,
index,
Expand Down
30 changes: 10 additions & 20 deletions web/mtb_widgets.js
Original file line number Diff line number Diff line change
Expand Up @@ -878,27 +878,17 @@ const mtb_widgets = {
break
}
case 'Stack Images (mtb)': {
const onNodeCreated = nodeType.prototype.onNodeCreated
nodeType.prototype.onNodeCreated = function () {
const r = onNodeCreated
? onNodeCreated.apply(this, arguments)
: undefined
this.addInput(`image_1`, 'IMAGE')
return r
}
shared.setupDynamicConnections(nodeType, 'image', 'IMAGE')

break
}
case 'Batch Float Assemble (mtb)': {
shared.setupDynamicConnections(nodeType, 'floats', 'FLOATS')
break
}
case 'Batch Merge (mtb)': {
shared.setupDynamicConnections(nodeType, 'batches', 'IMAGE')

const onConnectionsChange = nodeType.prototype.onConnectionsChange
nodeType.prototype.onConnectionsChange = function (
type,
index,
connected,
link_info
) {
const r = onConnectionsChange
? onConnectionsChange.apply(this, arguments)
: undefined
shared.dynamic_connection(this, index, connected, 'image_', 'IMAGE')
}
break
}
case 'Save Tensors (mtb)': {
Expand Down

0 comments on commit c1d42de

Please sign in to comment.