From 530ab96036604b125276433b67ebb840e841aede Mon Sep 17 00:00:00 2001
From: Czxck001 <10724409+Czxck001@users.noreply.github.com>
Date: Fri, 1 Nov 2024 10:10:40 -0700
Subject: [PATCH] Support Skip Layer Guidance (SLG) for Stable Diffusion 3.5 Medium (#2590)

* support skip layer guidance (slg) for stable diffusion 3.5 medium

* Tweak the comments formatting.

* Proper error message.

* Cosmetic tweaks.

---------

Co-authored-by: Laurent <laurent.mazare@gmail.com>
---
 .../examples/stable-diffusion-3/main.rs       | 27 ++++++++++++--
 .../examples/stable-diffusion-3/sampling.rs   | 36 ++++++++++++++++---
 candle-transformers/src/models/mmdit/model.rs | 26 +++++++++++---
 3 files changed, 79 insertions(+), 10 deletions(-)

diff --git a/candle-examples/examples/stable-diffusion-3/main.rs b/candle-examples/examples/stable-diffusion-3/main.rs
index 9ad057e358..8c9a78d25b 100644
--- a/candle-examples/examples/stable-diffusion-3/main.rs
+++ b/candle-examples/examples/stable-diffusion-3/main.rs
@@ -75,14 +75,19 @@ struct Args {
     #[arg(long)]
     num_inference_steps: Option<usize>,
 
-    // CFG scale.
+    /// CFG scale.
     #[arg(long)]
     cfg_scale: Option<f64>,
 
-    // Time shift factor (alpha).
+    /// Time shift factor (alpha).
     #[arg(long, default_value_t = 3.0)]
     time_shift: f64,
 
+    /// Use Skip Layer Guidance (SLG) for the sampling.
+    /// Currently only supports Stable Diffusion 3.5 Medium.
+    #[arg(long)]
+    use_slg: bool,
+
     /// The seed to use when generating random samples.
     #[arg(long)]
     seed: Option<u64>,
@@ -105,6 +110,7 @@ fn main() -> Result<()> {
         time_shift,
         seed,
         which,
+        use_slg,
     } = Args::parse();
 
     let _guard = if tracing {
@@ -211,6 +217,22 @@ fn main() -> Result<()> {
     if let Some(seed) = seed {
         device.set_seed(seed)?;
     }
+
+    let slg_config = if use_slg {
+        match which {
+            // https://github.com/Stability-AI/sd3.5/blob/4e484e05308d83fb77ae6f680028e6c313f9da54/sd3_infer.py#L388-L394
+            Which::V3_5Medium => Some(sampling::SkipLayerGuidanceConfig {
+                scale: 2.5,
+                start: 0.01,
+                end: 0.2,
+                layers: vec![7, 8, 9],
+            }),
+            _ => anyhow::bail!("--use-slg can only be used with 3.5-medium"),
+        }
+    } else {
+        None
+    };
+
     let start_time = std::time::Instant::now();
     let x = {
         let mmdit = MMDiT::new(
@@ -227,6 +249,7 @@ fn main() -> Result<()> {
             time_shift,
             height,
             width,
+            slg_config,
         )?
     };
     let dt = start_time.elapsed().as_secs_f32();
diff --git a/candle-examples/examples/stable-diffusion-3/sampling.rs b/candle-examples/examples/stable-diffusion-3/sampling.rs
index cd881b6a2f..5e23437175 100644
--- a/candle-examples/examples/stable-diffusion-3/sampling.rs
+++ b/candle-examples/examples/stable-diffusion-3/sampling.rs
@@ -1,8 +1,15 @@
 use anyhow::{Ok, Result};
-use candle::{DType, Tensor};
+use candle::{DType, IndexOp, Tensor};
 
 use candle_transformers::models::flux;
-use candle_transformers::models::mmdit::model::MMDiT; // for the get_noise function
+use candle_transformers::models::mmdit::model::MMDiT;
+
+pub struct SkipLayerGuidanceConfig {
+    pub scale: f64,
+    pub start: f64,
+    pub end: f64,
+    pub layers: Vec<usize>,
+}
 
 #[allow(clippy::too_many_arguments)]
 pub fn euler_sample(
@@ -14,6 +21,7 @@ pub fn euler_sample(
     time_shift: f64,
     height: usize,
     width: usize,
+    slg_config: Option<SkipLayerGuidanceConfig>,
 ) -> Result<Tensor> {
     let mut x = flux::sampling::get_noise(1, height, width, y.device())?.to_dtype(DType::F16)?;
     let sigmas = (0..=num_inference_steps)
@@ -22,7 +30,7 @@ pub fn euler_sample(
         .map(|x| time_snr_shift(time_shift, x))
         .collect::<Vec<f64>>();
 
-    for window in sigmas.windows(2) {
+    for (step, window) in sigmas.windows(2).enumerate() {
         let (s_curr, s_prev) = match window {
             [a, b] => (a, b),
             _ => continue,
@@ -34,8 +42,28 @@ pub fn euler_sample(
             &Tensor::full(timestep as f32, (2,), x.device())?.contiguous()?,
             y,
             context,
+            None,
         )?;
-        x = (x + (apply_cfg(cfg_scale, &noise_pred)? * (*s_prev - *s_curr))?)?;
+
+        let mut guidance = apply_cfg(cfg_scale, &noise_pred)?;
+
+        if let Some(slg_config) = slg_config.as_ref() {
+            if (num_inference_steps as f64) * slg_config.start < (step as f64)
+                && (step as f64) < (num_inference_steps as f64) * slg_config.end
+            {
+                let slg_noise_pred = mmdit.forward(
+                    &x,
+                    &Tensor::full(timestep as f32, (1,), x.device())?.contiguous()?,
+                    &y.i(..1)?,
+                    &context.i(..1)?,
+                    Some(&slg_config.layers),
+                )?;
+                guidance = (guidance
+                    + (slg_config.scale * (noise_pred.i(..1)? - slg_noise_pred.i(..1))?)?)?;
+            }
+        }
+
+        x = (x + (guidance * (*s_prev - *s_curr))?)?;
     }
     Ok(x)
 }
diff --git a/candle-transformers/src/models/mmdit/model.rs b/candle-transformers/src/models/mmdit/model.rs
index c7b4deedb2..21897aa356 100644
--- a/candle-transformers/src/models/mmdit/model.rs
+++ b/candle-transformers/src/models/mmdit/model.rs
@@ -130,7 +130,14 @@ impl MMDiT {
         })
     }
 
-    pub fn forward(&self, x: &Tensor, t: &Tensor, y: &Tensor, context: &Tensor) -> Result<Tensor> {
+    pub fn forward(
+        &self,
+        x: &Tensor,
+        t: &Tensor,
+        y: &Tensor,
+        context: &Tensor,
+        skip_layers: Option<&[usize]>,
+    ) -> Result<Tensor> {
         // Following the convention of the ComfyUI implementation.
         // https://github.com/comfyanonymous/ComfyUI/blob/78e133d0415784924cd2674e2ee48f3eeca8a2aa/comfy/ldm/modules/diffusionmodules/mmdit.py#L919
         //
@@ -150,7 +157,7 @@ impl MMDiT {
         let c = (c + y)?;
         let context = self.context_embedder.forward(context)?;
 
-        let x = self.core.forward(&context, &x, &c)?;
+        let x = self.core.forward(&context, &x, &c, skip_layers)?;
         let x = self.unpatchifier.unpatchify(&x, h, w)?;
         x.narrow(2, 0, h)?.narrow(3, 0, w)
     }
@@ -211,9 +218,20 @@ impl MMDiTCore {
         })
     }
 
-    pub fn forward(&self, context: &Tensor, x: &Tensor, c: &Tensor) -> Result<Tensor> {
+    pub fn forward(
+        &self,
+        context: &Tensor,
+        x: &Tensor,
+        c: &Tensor,
+        skip_layers: Option<&[usize]>,
+    ) -> Result<Tensor> {
         let (mut context, mut x) = (context.clone(), x.clone());
-        for joint_block in &self.joint_blocks {
+        for (i, joint_block) in self.joint_blocks.iter().enumerate() {
+            if let Some(skip_layers) = &skip_layers {
+                if skip_layers.contains(&i) {
+                    continue;
+                }
+            }
             (context, x) = joint_block.forward(&context, &x, c)?;
         }
         let x = self.context_qkv_only_joint_block.forward(&context, &x, c)?;
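
Note on the guidance math in sampling.rs: the SLG branch runs a second, single-sample forward pass with the configured transformer layers skipped, then nudges the CFG result toward the conditional prediction and away from the layer-skipped prediction, but only inside the start..end fraction of the step schedule. The sketch below is illustrative only: it mirrors that per-step arithmetic with plain f32 slices instead of candle tensors, the function name `slg_step_guidance` is made up for this note, and it assumes the example's `apply_cfg` helper (not shown in this patch) computes the usual uncond + cfg_scale * (cond - uncond) blend.

// Illustrative only: not part of the patch.
// `cond` / `uncond` are the two halves of the batched noise prediction;
// `skip` is the prediction from the forward pass with `layers` skipped.
fn slg_step_guidance(
    cond: &[f32],
    uncond: &[f32],
    skip: &[f32],
    cfg_scale: f32,
    slg_scale: f32,
    step: usize,
    num_steps: usize,
    start: f32,
    end: f32,
) -> Vec<f32> {
    // The patch gates on `num_steps * start < step < num_steps * end`,
    // i.e. the step's fractional position in the schedule.
    let in_window =
        (num_steps as f32) * start < step as f32 && (step as f32) < (num_steps as f32) * end;
    cond.iter()
        .zip(uncond)
        .zip(skip)
        .map(|((&c, &u), &s)| {
            // Classifier-free guidance.
            let mut g = u + cfg_scale * (c - u);
            // Skip Layer Guidance: push toward the conditional prediction and
            // away from the degraded, layer-skipped prediction.
            if in_window {
                g += slg_scale * (c - s);
            }
            g
        })
        .collect()
}

With the defaults wired up in main.rs (scale 2.5, window 0.01..0.2, layers 7-9, taken from the Stability-AI sd3.5 reference linked above), the example enables this path via --use-slg, and any model other than 3.5-medium bails out with the error message added in this patch.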