Commit 4e4b6c6

Tidy variable management and dtype handling in FluxTextToImageInvocation.
RyanJDick committed Aug 29, 2024
1 parent 5e8cf9f commit 4e4b6c6
Showing 3 changed files with 24 additions and 26 deletions.
28 changes: 15 additions & 13 deletions invokeai/app/invocations/flux_text_to_image.py
@@ -58,26 +58,28 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):

     @torch.no_grad()
     def invoke(self, context: InvocationContext) -> ImageOutput:
-        # Load the conditioning data.
-        cond_data = context.conditioning.load(self.positive_text_conditioning.conditioning_name)
-        assert len(cond_data.conditionings) == 1
-        flux_conditioning = cond_data.conditionings[0]
-        assert isinstance(flux_conditioning, FLUXConditioningInfo)
-
-        latents = self._run_diffusion(context, flux_conditioning.clip_embeds, flux_conditioning.t5_embeds)
+        latents = self._run_diffusion(context)
         image = self._run_vae_decoding(context, latents)
         image_dto = context.images.save(image=image)
         return ImageOutput.build(image_dto)

     def _run_diffusion(
         self,
         context: InvocationContext,
-        clip_embeddings: torch.Tensor,
-        t5_embeddings: torch.Tensor,
     ):
-        transformer_info = context.models.load(self.transformer.transformer)
+        inference_dtype = torch.bfloat16
+
+        # Load the conditioning data.
+        cond_data = context.conditioning.load(self.positive_text_conditioning.conditioning_name)
+        assert len(cond_data.conditionings) == 1
+        flux_conditioning = cond_data.conditionings[0]
+        assert isinstance(flux_conditioning, FLUXConditioningInfo)
+        flux_conditioning = flux_conditioning.to(dtype=inference_dtype)
+        t5_embeddings = flux_conditioning.t5_embeds
+        clip_embeddings = flux_conditioning.clip_embeds
+
+        transformer_info = context.models.load(self.transformer.transformer)

         # Prepare input noise.
         x = get_noise(
             num_samples=1,
@@ -88,13 +90,13 @@ def _run_diffusion(
             seed=self.seed,
         )

-        img, img_ids = prepare_latent_img_patches(x)
+        x, img_ids = prepare_latent_img_patches(x)

         is_schnell = "schnell" in transformer_info.config.config_path

         timesteps = get_schedule(
             num_steps=self.num_steps,
-            image_seq_len=img.shape[1],
+            image_seq_len=x.shape[1],
             shift=not is_schnell,
         )
@@ -135,7 +137,7 @@ def step_callback() -> None:

         x = denoise(
             model=transformer,
-            img=img,
+            img=x,
             img_ids=img_ids,
             txt=t5_embeddings,
             txt_ids=txt_ids,
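Net effect in this file: the conditioning embeddings are now loaded and cast to `inference_dtype` once, inside `_run_diffusion`, instead of being threaded through `invoke` and defensively re-cast inside `denoise` (see the next file). A minimal sketch of the cast-once-at-the-boundary pattern; the tensors are placeholders for the embeddings returned by `context.conditioning.load(...)`, and the shapes are illustrative, not the real FLUX embedding shapes:

```python
import torch

inference_dtype = torch.bfloat16

# Placeholder embeddings; real values come from the conditioning cache.
clip_embeds = torch.randn(1, 768)
t5_embeds = torch.randn(1, 512, 4096)

# Cast once at the boundary so everything downstream sees a single dtype.
clip_embeds = clip_embeds.to(dtype=inference_dtype)
t5_embeds = t5_embeds.to(dtype=inference_dtype)

assert clip_embeds.dtype == t5_embeds.dtype == inference_dtype
```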
17 changes: 4 additions & 13 deletions invokeai/backend/flux/sampling.py
@@ -111,16 +111,7 @@ def denoise(
     step_callback: Callable[[], None],
     guidance: float = 4.0,
 ):
-    dtype = model.txt_in.bias.dtype
-
-    # TODO(ryand): This shouldn't be necessary if we manage the dtypes properly in the caller.
-    img = img.to(dtype=dtype)
-    img_ids = img_ids.to(dtype=dtype)
-    txt = txt.to(dtype=dtype)
-    txt_ids = txt_ids.to(dtype=dtype)
-    vec = vec.to(dtype=dtype)
-
-    # this is ignored for schnell
+    # guidance_vec is ignored for schnell.
     guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
     for t_curr, t_prev in tqdm(list(zip(timesteps[:-1], timesteps[1:], strict=True))):
         t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
@@ -168,9 +159,9 @@ def prepare_latent_img_patches(latent_img: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
     img = repeat(img, "1 ... -> bs ...", bs=bs)

     # Generate patch position ids.
-    img_ids = torch.zeros(h // 2, w // 2, 3, device=img.device)
-    img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2, device=img.device)[:, None]
-    img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2, device=img.device)[None, :]
+    img_ids = torch.zeros(h // 2, w // 2, 3, device=img.device, dtype=img.dtype)
+    img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2, device=img.device, dtype=img.dtype)[:, None]
+    img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2, device=img.device, dtype=img.dtype)[None, :]
     img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)

     return img, img_ids
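With the defensive casts gone from `denoise`, `prepare_latent_img_patches` must emit `img_ids` in the latent image's dtype; `torch.zeros` and `torch.arange` otherwise default to float32, which would reintroduce a mixed-dtype call under bfloat16 inference. A self-contained sketch of just the id-generation change (`make_patch_ids` is a hypothetical stand-in, and the latent shape is illustrative):

```python
import torch
from einops import repeat

def make_patch_ids(latent_img: torch.Tensor) -> torch.Tensor:
    # Mirror the patched logic: ids inherit the latent's device *and*
    # dtype instead of defaulting to float32.
    bs, _, h, w = latent_img.shape
    img_ids = torch.zeros(h // 2, w // 2, 3, device=latent_img.device, dtype=latent_img.dtype)
    img_ids[..., 1] += torch.arange(h // 2, device=latent_img.device, dtype=latent_img.dtype)[:, None]
    img_ids[..., 2] += torch.arange(w // 2, device=latent_img.device, dtype=latent_img.dtype)[None, :]
    return repeat(img_ids, "h w c -> b (h w) c", b=bs)

x = torch.randn(1, 16, 64, 64, dtype=torch.bfloat16)
ids = make_patch_ids(x)
assert ids.dtype == torch.bfloat16  # was float32 before this commit
```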
5 changes: 5 additions & 0 deletions invokeai/backend/stable_diffusion/diffusion/conditioning_data.py
@@ -43,6 +43,11 @@ class FLUXConditioningInfo:
     clip_embeds: torch.Tensor
     t5_embeds: torch.Tensor

+    def to(self, device: torch.device | None = None, dtype: torch.dtype | None = None):
+        self.clip_embeds = self.clip_embeds.to(device=device, dtype=dtype)
+        self.t5_embeds = self.t5_embeds.to(device=device, dtype=dtype)
+        return self
+

 @dataclass
 class ConditioningFieldData:
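The new `FLUXConditioningInfo.to` mirrors `torch.Tensor.to`: it mutates the dataclass in place and returns `self`, so the call can be chained as in `flux_conditioning.to(dtype=inference_dtype)` above. A usage sketch with placeholder embedding shapes:

```python
import torch
from dataclasses import dataclass

@dataclass
class FLUXConditioningInfo:
    clip_embeds: torch.Tensor
    t5_embeds: torch.Tensor

    def to(self, device: torch.device | None = None, dtype: torch.dtype | None = None):
        self.clip_embeds = self.clip_embeds.to(device=device, dtype=dtype)
        self.t5_embeds = self.t5_embeds.to(device=device, dtype=dtype)
        return self

# Placeholder embeddings; real ones come from the CLIP and T5 text encoders.
cond = FLUXConditioningInfo(
    clip_embeds=torch.randn(1, 768),
    t5_embeds=torch.randn(1, 512, 4096),
)
cond = cond.to(dtype=torch.bfloat16)
assert cond.clip_embeds.dtype == cond.t5_embeds.dtype == torch.bfloat16
```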
