Skip to content

Commit

Permalink
feat: ✨ add Interpolate Clip Sequential
Browse files Browse the repository at this point in the history
Still need testing but works
  • Loading branch information
melMass committed Aug 12, 2023
1 parent 49c64c7 commit a71c273
Showing 1 changed file with 83 additions and 1 deletion.
84 changes: 83 additions & 1 deletion nodes/conditions.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,88 @@
import csv


class InterpolateClipSequential:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"base_text": ("STRING", {"multiline": True}),
"text_to_replace": ("STRING", {"default": ""}),
"clip": ("CLIP",),
"interpolation_strength": (
"FLOAT",
{"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01},
),
}
}

RETURN_TYPES = ("CONDITIONING",)
FUNCTION = "interpolate_encodings_sequential"

CATEGORY = "mtb/conditioning"

def interpolate_encodings_sequential(
self, base_text, text_to_replace, clip, interpolation_strength, **replacements
):
log.debug(f"Received interpolation_strength: {interpolation_strength}")

# - Ensure interpolation strength is within [0, 1]
interpolation_strength = max(0.0, min(1.0, interpolation_strength))

# - Check if replacements were provided
if not replacements:
raise ValueError("At least one replacement should be provided.")

num_replacements = len(replacements)
log.debug(f"Number of replacements: {num_replacements}")

segment_length = 1.0 / num_replacements
log.debug(f"Calculated segment_length: {segment_length}")

# - Find the segment that the interpolation_strength falls into
segment_index = min(
int(interpolation_strength // segment_length), num_replacements - 1
)
log.debug(f"Segment index: {segment_index}")

# - Calculate the local strength within the segment
local_strength = (
interpolation_strength - (segment_index * segment_length)
) / segment_length
log.debug(f"Local strength: {local_strength}")

# - If it's the first segment, interpolate between base_text and the first replacement
if segment_index == 0:
replacement_text = list(replacements.values())[0]
log.debug("Using the base text a the base blend")
# - Start with the base_text condition
tokens = clip.tokenize(base_text)
cond_from, pooled_from = clip.encode_from_tokens(tokens, return_pooled=True)
else:
base_replace = list(replacements.values())[segment_index - 1]
log.debug(f"Using {base_replace} a the base blend")

# - Start with the base_text condition replaced by the closest replacement
tokens = clip.tokenize(base_text.replace(text_to_replace, base_replace))
cond_from, pooled_from = clip.encode_from_tokens(tokens, return_pooled=True)

replacement_text = list(replacements.values())[segment_index]

interpolated_text = base_text.replace(text_to_replace, replacement_text)
tokens = clip.tokenize(interpolated_text)
cond_to, pooled_to = clip.encode_from_tokens(tokens, return_pooled=True)

# - Linearly interpolate between the two conditions
interpolated_condition = (
1.0 - local_strength
) * cond_from + local_strength * cond_to
interpolated_pooled = (
1.0 - local_strength
) * pooled_from + local_strength * pooled_to

return ([[interpolated_condition, {"pooled_output": interpolated_pooled}]],)


class SmartStep:
"""Utils to control the steps start/stop of the KAdvancedSampler in percentage"""

Expand Down Expand Up @@ -96,4 +178,4 @@ def load_style(self, style_name):
return (self.options[style_name][0], self.options[style_name][1])


__nodes__ = [SmartStep, StylesLoader]
__nodes__ = [SmartStep, StylesLoader, InterpolateClipSequential]

0 comments on commit a71c273

Please sign in to comment.