Skip to content

Commit

Permalink
feat: ✨ Math Expression node
Browse files Browse the repository at this point in the history
  • Loading branch information
melMass committed Nov 4, 2023
1 parent c8658df commit 142624e
Show file tree
Hide file tree
Showing 3 changed files with 109 additions and 12 deletions.
47 changes: 47 additions & 0 deletions nodes/graph_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,52 @@ def replace_str(self, string: str, old: str, new: str):
return (string,)


class MTB_MathExpression:
"""Node to evaluate a simple math expression string"""

@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"expression": ("STRING", {"default": "", "multiline": True}),
}
}

FUNCTION = "eval_expression"
RETURN_TYPES = ("FLOAT", "INT")
RETURN_NAMES = ("result (float)", "result (int)")
CATEGORY = "mtb/math"
DESCRIPTION = "evaluate a simple math expression string (!! Fallsback to eval)"

def eval_expression(self, expression, **kwargs):
import math
from ast import literal_eval

for key, value in kwargs.items():
print(f"Replacing placeholder <{key}> with value {value}")
expression = expression.replace(f"<{key}>", str(value))

result = -1
try:
result = literal_eval(expression)
except SyntaxError as e:
raise ValueError(
f"The expression syntax is wrong '{expression}': {e}"
) from e

except ValueError:
try:
expression = expression.replace("^", "**")
result = eval(expression)
except Exception as e:
# Handle any other exceptions and provide a meaningful error message
raise ValueError(
f"Error evaluating expression '{expression}': {e}"
) from e

return (result, int(result))


class FitNumber:
"""Fit the input float using a source and target range"""

Expand Down Expand Up @@ -276,4 +322,5 @@ def concatenate_tensors(self, reverse, **kwargs):
GetBatchFromHistory,
AnyToString,
ConcatImages,
MTB_MathExpression,
]
26 changes: 14 additions & 12 deletions web/comfy_shared.js
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,8 @@ export const dynamic_connection = (
index,
connected,
connectionPrefix = 'input_',
connectionType = 'PSDLAYER'
connectionType = 'PSDLAYER',
nameArray = []
) => {
// remove all non connected inputs
if (!connected && node.inputs.length > 1) {
Expand All @@ -134,23 +135,24 @@ export const dynamic_connection = (

// make inputs sequential again
for (let i = 0; i < node.inputs.length; i++) {
node.inputs[i].label = `${connectionPrefix}${i + 1}`
node.inputs[i].name = `${connectionPrefix}${i + 1}`
const name =
i < nameArray.length ? nameArray[i] : `${connectionPrefix}${i + 1}`
node.inputs[i].label = name
node.inputs[i].name = name
}
}

// add an extra input
if (node.inputs[node.inputs.length - 1].link != undefined) {
log(
`Adding input ${node.inputs.length + 1} (${connectionPrefix}${
node.inputs.length + 1
})`
)
const nextIndex = node.inputs.length
const name =
nextIndex < nameArray.length
? nameArray[nextIndex]
: `${connectionPrefix}${nextIndex + 1}`

node.addInput(
`${connectionPrefix}${node.inputs.length + 1}`,
connectionType
)
log(`Adding input ${nextIndex + 1} (${name})`)

node.addInput(name, connectionType)
}
}

Expand Down
48 changes: 48 additions & 0 deletions web/mtb_widgets.js
Original file line number Diff line number Diff line change
Expand Up @@ -494,6 +494,10 @@ const mtb_widgets = {
async beforeRegisterNodeDef(nodeType, nodeData, app) {
// const rinputs = nodeData.input?.required

if (!nodeData.name.endsWith('(mtb)')) {
return
}

let has_custom = false
if (nodeData.input && nodeData.input.required) {
for (const i of Object.keys(nodeData.input.required)) {
Expand Down Expand Up @@ -892,6 +896,50 @@ const mtb_widgets = {

break
}
case 'Math Expression (mtb)': {
const onNodeCreated = nodeType.prototype.onNodeCreated
nodeType.prototype.onNodeCreated = function () {
const r = onNodeCreated
? onNodeCreated.apply(this, arguments)
: undefined
this.addInput(`x`, '*')
return r
}

const onConnectionsChange = nodeType.prototype.onConnectionsChange
nodeType.prototype.onConnectionsChange = function (
type,
index,
connected,
link_info
) {
const r = onConnectionsChange
? onConnectionsChange.apply(this, arguments)
: undefined
shared.dynamic_connection(this, index, connected, 'var_', '*', [
'x',
'y',
'z',
])

//- infer type
if (link_info) {
const fromNode = this.graph._nodes.find(
(otherNode) => otherNode.id == link_info.origin_id
)
const type = fromNode.outputs[link_info.origin_slot].type
this.inputs[index].type = type
// this.inputs[index].label = type.toLowerCase()
}
//- restore dynamic input
if (!connected) {
this.inputs[index].type = '*'
this.inputs[index].label = `number_${index + 1}`
}
}

break
}
case 'Save Tensors (mtb)': {
const onDrawBackground = nodeType.prototype.onDrawBackground
nodeType.prototype.onDrawBackground = function (ctx, canvas) {
Expand Down

0 comments on commit 142624e

Please sign in to comment.