diff --git a/keras_cv/models/stable_diffusion/stable_diffusion.py b/keras_cv/models/stable_diffusion/stable_diffusion.py index 2ec798b40a..975788ac74 100644 --- a/keras_cv/models/stable_diffusion/stable_diffusion.py +++ b/keras_cv/models/stable_diffusion/stable_diffusion.py @@ -235,6 +235,7 @@ def generate_image( + unconditional_guidance_scale * (latent - unconditional_latent) ) a_t, a_prev = alphas[index], alphas_prev[index] + latent = ops.cast(latent, latent_prev.dtype) pred_x0 = (latent_prev - math.sqrt(1 - a_t) * latent) / math.sqrt( a_t )