FLUX memory management improvements #6791

Merged 11 commits on Aug 29, 2024
24 changes: 15 additions & 9 deletions invokeai/app/invocations/flux_text_encoder.py
@@ -40,20 +40,18 @@ class FluxTextEncoderInvocation(BaseInvocation):

@torch.no_grad()
def invoke(self, context: InvocationContext) -> FluxConditioningOutput:
- t5_embeddings, clip_embeddings = self._encode_prompt(context)
+ # Note: The T5 and CLIP encoding are done in separate functions to ensure that all model references are locally
+ # scoped. This ensures that the T5 model can be freed and gc'd before loading the CLIP model (if necessary).
+ t5_embeddings = self._t5_encode(context)
+ clip_embeddings = self._clip_encode(context)
conditioning_data = ConditioningFieldData(
conditionings=[FLUXConditioningInfo(clip_embeds=clip_embeddings, t5_embeds=t5_embeddings)]
)

conditioning_name = context.conditioning.save(conditioning_data)
return FluxConditioningOutput.build(conditioning_name)

- def _encode_prompt(self, context: InvocationContext) -> tuple[torch.Tensor, torch.Tensor]:
- # Load CLIP.
- clip_tokenizer_info = context.models.load(self.clip.tokenizer)
- clip_text_encoder_info = context.models.load(self.clip.text_encoder)
-
- # Load T5.
+ def _t5_encode(self, context: InvocationContext) -> torch.Tensor:
t5_tokenizer_info = context.models.load(self.t5_encoder.tokenizer)
t5_text_encoder_info = context.models.load(self.t5_encoder.text_encoder)

@@ -70,6 +68,15 @@ def _encode_prompt(self, context: InvocationContext) -> tuple[torch.Tensor, torch.Tensor]:

prompt_embeds = t5_encoder(prompt)

+ assert isinstance(prompt_embeds, torch.Tensor)
+ return prompt_embeds
+
+ def _clip_encode(self, context: InvocationContext) -> torch.Tensor:
+ clip_tokenizer_info = context.models.load(self.clip.tokenizer)
+ clip_text_encoder_info = context.models.load(self.clip.text_encoder)
+
+ prompt = [self.prompt]
+
with (
clip_text_encoder_info as clip_text_encoder,
clip_tokenizer_info as clip_tokenizer,
@@ -81,6 +88,5 @@ def _encode_prompt(self, context: InvocationContext) -> tuple[torch.Tensor, torch.Tensor]:

pooled_prompt_embeds = clip_encoder(prompt)

- assert isinstance(prompt_embeds, torch.Tensor)
assert isinstance(pooled_prompt_embeds, torch.Tensor)
- return prompt_embeds, pooled_prompt_embeds
+ return pooled_prompt_embeds
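The change above keeps each encoder reference locally scoped inside its own method, so the T5 weights become unreachable before the CLIP model is loaded. Below is a minimal, self-contained sketch of the same idea, using toy encoders in place of the real model loads; all names here are illustrative, not InvokeAI APIs, and the explicit gc.collect() is only there to make the point visible.

import gc

import torch

class ToyEncoder(torch.nn.Module):
    """Stand-in for a large text encoder such as T5 or CLIP."""

    def __init__(self, dim: int) -> None:
        super().__init__()
        self.proj = torch.nn.Linear(16, dim)

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        return self.proj(tokens)

def _t5_encode(tokens: torch.Tensor) -> torch.Tensor:
    # The encoder is referenced only inside this function, so its weights
    # become unreachable as soon as it returns.
    encoder = ToyEncoder(dim=64)  # imagine a multi-GB load here
    with torch.no_grad():
        return encoder(tokens)

def _clip_encode(tokens: torch.Tensor) -> torch.Tensor:
    encoder = ToyEncoder(dim=32)
    with torch.no_grad():
        return encoder(tokens)

tokens = torch.randn(1, 16)
t5_embeds = _t5_encode(tokens)
gc.collect()  # the first encoder can now be reclaimed before the second load
clip_embeds = _clip_encode(tokens)
print(t5_embeds.shape, clip_embeds.shape)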
33 changes: 15 additions & 18 deletions invokeai/app/invocations/flux_text_to_image.py
@@ -58,26 +58,28 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):

@torch.no_grad()
def invoke(self, context: InvocationContext) -> ImageOutput:
- # Load the conditioning data.
- cond_data = context.conditioning.load(self.positive_text_conditioning.conditioning_name)
- assert len(cond_data.conditionings) == 1
- flux_conditioning = cond_data.conditionings[0]
- assert isinstance(flux_conditioning, FLUXConditioningInfo)
-
- latents = self._run_diffusion(context, flux_conditioning.clip_embeds, flux_conditioning.t5_embeds)
+ latents = self._run_diffusion(context)
image = self._run_vae_decoding(context, latents)
image_dto = context.images.save(image=image)
return ImageOutput.build(image_dto)

def _run_diffusion(
self,
context: InvocationContext,
- clip_embeddings: torch.Tensor,
- t5_embeddings: torch.Tensor,
):
- transformer_info = context.models.load(self.transformer.transformer)
inference_dtype = torch.bfloat16

+ # Load the conditioning data.
+ cond_data = context.conditioning.load(self.positive_text_conditioning.conditioning_name)
+ assert len(cond_data.conditionings) == 1
+ flux_conditioning = cond_data.conditionings[0]
+ assert isinstance(flux_conditioning, FLUXConditioningInfo)
+ flux_conditioning = flux_conditioning.to(dtype=inference_dtype)
+ t5_embeddings = flux_conditioning.t5_embeds
+ clip_embeddings = flux_conditioning.clip_embeds
+
+ transformer_info = context.models.load(self.transformer.transformer)
+
# Prepare input noise.
x = get_noise(
num_samples=1,
@@ -88,24 +90,19 @@ def _run_diffusion(
seed=self.seed,
)

- img, img_ids = prepare_latent_img_patches(x)
+ x, img_ids = prepare_latent_img_patches(x)

is_schnell = "schnell" in transformer_info.config.config_path

timesteps = get_schedule(
num_steps=self.num_steps,
- image_seq_len=img.shape[1],
+ image_seq_len=x.shape[1],
shift=not is_schnell,
)

bs, t5_seq_len, _ = t5_embeddings.shape
txt_ids = torch.zeros(bs, t5_seq_len, 3, dtype=inference_dtype, device=TorchDevice.choose_torch_device())

- # HACK(ryand): Manually empty the cache. Currently we don't check the size of the model before loading it from
- # disk. Since the transformer model is large (24GB), there's a good chance that it will OOM on 32GB RAM systems
- # if the cache is not empty.
- context.models._services.model_manager.load.ram_cache.make_room(24 * 2**30)

with transformer_info as transformer:
assert isinstance(transformer, Flux)

@@ -140,7 +137,7 @@ def step_callback() -> None:

x = denoise(
model=transformer,
- img=img,
+ img=x,
img_ids=img_ids,
txt=t5_embeddings,
txt_ids=txt_ids,
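In the reworked _run_diffusion above, the conditioning tensors are loaded and cast to the inference dtype before the transformer is brought into memory, and the hard-coded make_room(24 GB) call is gone (see the loader change further down). A rough, self-contained sketch of that ordering follows; the function name, shapes, and the tiny Linear stand-in for the transformer are illustrative only, not the actual invocation code.

import torch

def run_diffusion_sketch(t5_embeds: torch.Tensor, clip_embeds: torch.Tensor) -> torch.Tensor:
    inference_dtype = torch.bfloat16

    # 1) Cast the comparatively small conditioning tensors first.
    t5_embeds = t5_embeds.to(dtype=inference_dtype)
    clip_embeds = clip_embeds.to(dtype=inference_dtype)

    # 2) Only now materialize the large model, in the same dtype, so the
    #    denoising loop does not have to re-cast its inputs on every step.
    transformer = torch.nn.Linear(t5_embeds.shape[-1], 64).to(dtype=inference_dtype)

    with torch.no_grad():
        return transformer(t5_embeds)

out = run_diffusion_sketch(torch.randn(1, 8, 128), torch.randn(1, 32))
print(out.shape, out.dtype)  # torch.Size([1, 8, 64]) torch.bfloat16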
17 changes: 4 additions & 13 deletions invokeai/backend/flux/sampling.py
@@ -111,16 +111,7 @@ def denoise(
step_callback: Callable[[], None],
guidance: float = 4.0,
):
- dtype = model.txt_in.bias.dtype
-
- # TODO(ryand): This shouldn't be necessary if we manage the dtypes properly in the caller.
- img = img.to(dtype=dtype)
- img_ids = img_ids.to(dtype=dtype)
- txt = txt.to(dtype=dtype)
- txt_ids = txt_ids.to(dtype=dtype)
- vec = vec.to(dtype=dtype)
-
- # this is ignored for schnell
+ # guidance_vec is ignored for schnell.
guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
for t_curr, t_prev in tqdm(list(zip(timesteps[:-1], timesteps[1:], strict=True))):
t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
@@ -168,9 +159,9 @@ def prepare_latent_img_patches(latent_img: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
img = repeat(img, "1 ... -> bs ...", bs=bs)

# Generate patch position ids.
- img_ids = torch.zeros(h // 2, w // 2, 3, device=img.device)
- img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2, device=img.device)[:, None]
- img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2, device=img.device)[None, :]
+ img_ids = torch.zeros(h // 2, w // 2, 3, device=img.device, dtype=img.dtype)
+ img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2, device=img.device, dtype=img.dtype)[:, None]
+ img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2, device=img.device, dtype=img.dtype)[None, :]
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)

return img, img_ids
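Because denoise() no longer casts its inputs, prepare_latent_img_patches now builds img_ids directly in the latent's dtype and device. The following standalone sketch shows that position-id construction with toy shapes, using plain reshape/expand in place of einops.repeat; it is an illustration of the pattern, not the repository function.

import torch

def make_patch_ids(latent_img: torch.Tensor) -> torch.Tensor:
    bs, _, h, w = latent_img.shape
    # Build the (h/2, w/2, 3) position ids in the latent's dtype and device,
    # so no downstream .to(dtype=...) cast is needed.
    img_ids = torch.zeros(h // 2, w // 2, 3, device=latent_img.device, dtype=latent_img.dtype)
    img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2, device=latent_img.device, dtype=latent_img.dtype)[:, None]
    img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2, device=latent_img.device, dtype=latent_img.dtype)[None, :]
    # Flatten the spatial grid and broadcast across the batch: (bs, h/2 * w/2, 3).
    return img_ids.reshape(1, -1, 3).expand(bs, -1, -1)

latent = torch.randn(2, 16, 64, 64, dtype=torch.bfloat16)
ids = make_patch_ids(latent)
print(ids.shape, ids.dtype)  # torch.Size([2, 1024, 3]) torch.bfloat16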
1 change: 1 addition & 0 deletions invokeai/backend/model_manager/load/load_default.py
@@ -72,6 +72,7 @@ def _load_and_cache(self, config: AnyModelConfig, submodel_type: Optional[SubMod
pass

config.path = str(self._get_model_path(config))
+ self._ram_cache.make_room(self.get_size_fs(config, Path(config.path), submodel_type))
loaded_model = self._load_model(config, submodel_type)

self._ram_cache.put(
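This one-line addition asks the RAM cache to make room for the model's estimated on-disk size before it is loaded, which is what allows the hard-coded make_room(24 * 2**30) hack to be removed from the FLUX invocation above. A toy sketch of the eviction idea follows; it is a simplified stand-in with made-up sizes, not the actual ModelCache implementation.

from collections import OrderedDict

class ToyRamCache:
    """Least-recently-used cache that evicts entries until a requested size fits."""

    def __init__(self, max_bytes: int) -> None:
        self.max_bytes = max_bytes
        self._entries: OrderedDict[str, int] = OrderedDict()  # key -> size in bytes

    def cache_size(self) -> int:
        return sum(self._entries.values())

    def make_room(self, size: int) -> None:
        # Evict least-recently-used entries until `size` additional bytes fit.
        while self._entries and self.cache_size() + size > self.max_bytes:
            evicted_key, evicted_size = self._entries.popitem(last=False)
            print(f"evicting {evicted_key} ({evicted_size / 2**30:.0f} GiB)")

    def put(self, key: str, size: int) -> None:
        self._entries[key] = size

cache = ToyRamCache(max_bytes=32 * 2**30)
cache.put("t5_encoder", 10 * 2**30)
cache.put("clip_encoder", 1 * 2**30)
cache.make_room(24 * 2**30)  # evicts the T5 entry so a ~24 GiB transformer fits
cache.put("flux_transformer", 24 * 2**30)
print(cache.cache_size() // 2**30, "GiB cached")  # 25 GiB cached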
@@ -193,15 +193,6 @@ def get(
"""
pass

- @abstractmethod
- def exists(
- self,
- key: str,
- submodel_type: Optional[SubModelType] = None,
- ) -> bool:
- """Return true if the model identified by key and submodel_type is in the cache."""
- pass

@abstractmethod
def cache_size(self) -> int:
"""Get the total size of the models currently cached."""