diff --git a/examples/controlnet/train_controlnet_sd3.py b/examples/controlnet/train_controlnet_sd3.py index f4aadc2577f7..ffe460d72de8 100644 --- a/examples/controlnet/train_controlnet_sd3.py +++ b/examples/controlnet/train_controlnet_sd3.py @@ -1283,8 +1283,8 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): noisy_model_input = (1.0 - sigmas) * model_input + sigmas * noise # Get the text embedding for conditioning - prompt_embeds = batch["prompt_embeds"] - pooled_prompt_embeds = batch["pooled_prompt_embeds"] + prompt_embeds = batch["prompt_embeds"].to(dtype=weight_dtype) + pooled_prompt_embeds = batch["pooled_prompt_embeds"].to(dtype=weight_dtype) # controlnet(s) inference controlnet_image = batch["conditioning_pixel_values"].to(dtype=weight_dtype)