Support Skip Layer Guidance (SLG) for Stable Diffusion 3.5 Medium (#2590)

* support skip layer guidance (slg) for stable diffusion 3.5 medium

* Tweak the comments formatting.

* Proper error message.

* Cosmetic tweaks.

---------

Co-authored-by: Laurent <laurent.mazare@gmail.com>
Czxck001 and LaurentMazare authored Nov 1, 2024
1 parent 7ac0de1 commit 530ab96
Showing 3 changed files with 79 additions and 10 deletions.
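For readers skimming the diff: skip layer guidance adds a second denoising pass in which a few transformer layers are skipped, and then steers the sample away from that degraded prediction. As a rough orientation, here is a scalar sketch of the kind of combination the new code in sampling.rs performs; it is illustrative only (the real code works on tensors and batch halves, and the exact CFG form is assumed here, not taken from the crate):

/// Scalar sketch: standard CFG plus a skip-layer term.
fn slg_guidance(uncond: f64, cond: f64, cond_skipped: f64, cfg_scale: f64, slg_scale: f64) -> f64 {
    // Classifier-free guidance: move from the unconditional toward the conditional prediction.
    let cfg = uncond + cfg_scale * (cond - uncond);
    // Skip layer guidance: additionally move away from the prediction made with layers skipped.
    cfg + slg_scale * (cond - cond_skipped)
}

fn main() {
    // Made-up numbers: 0.1 + 4.0 * (0.3 - 0.1) + 2.5 * (0.3 - 0.25) ≈ 1.025
    println!("{}", slg_guidance(0.1, 0.3, 0.25, 4.0, 2.5));
}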
candle-examples/examples/stable-diffusion-3/main.rs (25 additions & 2 deletions)
@@ -75,14 +75,19 @@ struct Args {
     #[arg(long)]
     num_inference_steps: Option<usize>,
 
-    // CFG scale.
+    /// CFG scale.
     #[arg(long)]
     cfg_scale: Option<f64>,
 
-    // Time shift factor (alpha).
+    /// Time shift factor (alpha).
    #[arg(long, default_value_t = 3.0)]
     time_shift: f64,
 
+    /// Use Skip Layer Guidance (SLG) for the sampling.
+    /// Currently only supports Stable Diffusion 3.5 Medium.
+    #[arg(long)]
+    use_slg: bool,
+
     /// The seed to use when generating random samples.
     #[arg(long)]
     seed: Option<u64>,
@@ -105,6 +110,7 @@ fn main() -> Result<()> {
         time_shift,
         seed,
         which,
+        use_slg,
     } = Args::parse();
 
     let _guard = if tracing {
@@ -211,6 +217,22 @@ fn main() -> Result<()> {
     if let Some(seed) = seed {
         device.set_seed(seed)?;
     }
+
+    let slg_config = if use_slg {
+        match which {
+            // https://github.com/Stability-AI/sd3.5/blob/4e484e05308d83fb77ae6f680028e6c313f9da54/sd3_infer.py#L388-L394
+            Which::V3_5Medium => Some(sampling::SkipLayerGuidanceConfig {
+                scale: 2.5,
+                start: 0.01,
+                end: 0.2,
+                layers: vec![7, 8, 9],
+            }),
+            _ => anyhow::bail!("--use-slg can only be used with 3.5-medium"),
+        }
+    } else {
+        None
+    };
+
     let start_time = std::time::Instant::now();
     let x = {
         let mmdit = MMDiT::new(
@@ -227,6 +249,7 @@
             time_shift,
             height,
             width,
+            slg_config,
         )?
     };
     let dt = start_time.elapsed().as_secs_f32();
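Usage note: when calling euler_sample directly instead of going through the example binary's --use-slg flag, the SLG settings are just a plain struct. The sketch below mirrors the 3.5-medium values chosen above; the helper function is hypothetical, and the struct definition is restated from the sampling.rs diff below only so the snippet compiles on its own.

// Restated from the sampling.rs change in this commit, for a self-contained sketch.
pub struct SkipLayerGuidanceConfig {
    pub scale: f64,
    pub start: f64,
    pub end: f64,
    pub layers: Vec<usize>,
}

// Hypothetical helper mirroring the defaults the example picks for SD 3.5 Medium.
fn sd3_5_medium_slg_defaults() -> SkipLayerGuidanceConfig {
    SkipLayerGuidanceConfig {
        scale: 2.5,            // weight of the skip-layer term
        start: 0.01,           // SLG becomes active after 1% of the steps
        end: 0.2,              // and is disabled again after 20% of the steps
        layers: vec![7, 8, 9], // joint-block indices to skip in the extra pass
    }
}

fn main() {
    let slg = sd3_5_medium_slg_defaults();
    println!("skip layers {:?}, scale {}", slg.layers, slg.scale);
}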
candle-examples/examples/stable-diffusion-3/sampling.rs (32 additions & 4 deletions)
@@ -1,8 +1,15 @@
 use anyhow::{Ok, Result};
-use candle::{DType, Tensor};
+use candle::{DType, IndexOp, Tensor};
 
 use candle_transformers::models::flux;
-use candle_transformers::models::mmdit::model::MMDiT; // for the get_noise function
+use candle_transformers::models::mmdit::model::MMDiT;
+
+pub struct SkipLayerGuidanceConfig {
+    pub scale: f64,
+    pub start: f64,
+    pub end: f64,
+    pub layers: Vec<usize>,
+}
 
 #[allow(clippy::too_many_arguments)]
 pub fn euler_sample(
@@ -14,6 +21,7 @@ pub fn euler_sample(
     time_shift: f64,
     height: usize,
     width: usize,
+    slg_config: Option<SkipLayerGuidanceConfig>,
 ) -> Result<Tensor> {
     let mut x = flux::sampling::get_noise(1, height, width, y.device())?.to_dtype(DType::F16)?;
     let sigmas = (0..=num_inference_steps)
@@ -22,7 +30,7 @@
         .map(|x| time_snr_shift(time_shift, x))
         .collect::<Vec<f64>>();
 
-    for window in sigmas.windows(2) {
+    for (step, window) in sigmas.windows(2).enumerate() {
         let (s_curr, s_prev) = match window {
             [a, b] => (a, b),
             _ => continue,
@@ -34,8 +42,28 @@
             &Tensor::full(timestep as f32, (2,), x.device())?.contiguous()?,
             y,
             context,
+            None,
         )?;
-        x = (x + (apply_cfg(cfg_scale, &noise_pred)? * (*s_prev - *s_curr))?)?;
+
+        let mut guidance = apply_cfg(cfg_scale, &noise_pred)?;
+
+        if let Some(slg_config) = slg_config.as_ref() {
+            if (num_inference_steps as f64) * slg_config.start < (step as f64)
+                && (step as f64) < (num_inference_steps as f64) * slg_config.end
+            {
+                let slg_noise_pred = mmdit.forward(
+                    &x,
+                    &Tensor::full(timestep as f32, (1,), x.device())?.contiguous()?,
+                    &y.i(..1)?,
+                    &context.i(..1)?,
+                    Some(&slg_config.layers),
+                )?;
+                guidance = (guidance
+                    + (slg_config.scale * (noise_pred.i(..1)? - slg_noise_pred.i(..1))?)?)?;
+            }
+        }
+
+        x = (x + (guidance * (*s_prev - *s_curr))?)?;
     }
     Ok(x)
 }
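The start and end fields gate the extra forward pass to a fraction of the schedule: SLG only runs when the step index lies strictly between start * num_inference_steps and end * num_inference_steps, which is exactly the inequality in the loop above. A small self-contained sketch of that check:

/// True when SLG should be applied at this step, mirroring the check in euler_sample.
fn slg_active(step: usize, num_inference_steps: usize, start: f64, end: f64) -> bool {
    let n = num_inference_steps as f64;
    n * start < step as f64 && (step as f64) < n * end
}

fn main() {
    // Example only: with the 3.5-medium values (start = 0.01, end = 0.2) and 28 steps,
    // the window covers steps 1 through 5.
    let active: Vec<usize> = (0..28).filter(|&s| slg_active(s, 28, 0.01, 0.2)).collect();
    println!("{active:?}");
}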
candle-transformers/src/models/mmdit/model.rs (22 additions & 4 deletions)
@@ -130,7 +130,14 @@ impl MMDiT {
         })
     }
 
-    pub fn forward(&self, x: &Tensor, t: &Tensor, y: &Tensor, context: &Tensor) -> Result<Tensor> {
+    pub fn forward(
+        &self,
+        x: &Tensor,
+        t: &Tensor,
+        y: &Tensor,
+        context: &Tensor,
+        skip_layers: Option<&[usize]>,
+    ) -> Result<Tensor> {
         // Following the convention of the ComfyUI implementation.
         // https://github.com/comfyanonymous/ComfyUI/blob/78e133d0415784924cd2674e2ee48f3eeca8a2aa/comfy/ldm/modules/diffusionmodules/mmdit.py#L919
         //
@@ -150,7 +157,7 @@
         let c = (c + y)?;
         let context = self.context_embedder.forward(context)?;
 
-        let x = self.core.forward(&context, &x, &c)?;
+        let x = self.core.forward(&context, &x, &c, skip_layers)?;
         let x = self.unpatchifier.unpatchify(&x, h, w)?;
         x.narrow(2, 0, h)?.narrow(3, 0, w)
     }
@@ -211,9 +218,20 @@ impl MMDiTCore {
         })
     }
 
-    pub fn forward(&self, context: &Tensor, x: &Tensor, c: &Tensor) -> Result<Tensor> {
+    pub fn forward(
+        &self,
+        context: &Tensor,
+        x: &Tensor,
+        c: &Tensor,
+        skip_layers: Option<&[usize]>,
+    ) -> Result<Tensor> {
         let (mut context, mut x) = (context.clone(), x.clone());
-        for joint_block in &self.joint_blocks {
+        for (i, joint_block) in self.joint_blocks.iter().enumerate() {
+            if let Some(skip_layers) = &skip_layers {
+                if skip_layers.contains(&i) {
+                    continue;
+                }
+            }
             (context, x) = joint_block.forward(&context, &x, c)?;
         }
         let x = self.context_qkv_only_joint_block.forward(&context, &x, c)?;
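On the model side, the behavioural change is small: MMDiTCore::forward now enumerates its joint blocks and skips the indices listed in skip_layers, and with skip_layers set to None the forward pass is unchanged. A minimal stand-in below shows the same control flow, using plain closures in place of joint blocks (names are illustrative):

// Apply each "block" in order, skipping the listed indices, as MMDiTCore::forward now does.
fn run_blocks(blocks: &[fn(f64) -> f64], skip_layers: Option<&[usize]>, mut x: f64) -> f64 {
    for (i, block) in blocks.iter().enumerate() {
        if let Some(skip) = skip_layers {
            if skip.contains(&i) {
                continue; // this block is skipped in the SLG pass
            }
        }
        x = block(x);
    }
    x
}

fn main() {
    let blocks: Vec<fn(f64) -> f64> = vec![|x| x + 1.0, |x| x * 2.0, |x| x - 3.0];
    // Full pass: ((0 + 1) * 2) - 3 = -1; skipping block 1: (0 + 1) - 3 = -2.
    println!("{}", run_blocks(&blocks, None, 0.0));
    println!("{}", run_blocks(&blocks, Some(&[1]), 0.0));
}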
