Merge branch 'main' into exponential-sigmas
yiyixuxu authored Sep 23, 2024
2 parents 40c5eb2 + 00f5b41 commit 7993eb3
Showing 2 changed files with 7 additions and 7 deletions.
src/diffusers/models/controlnet_sd3.py (2 changes: 1 addition & 1 deletion)
@@ -336,7 +336,7 @@ def custom_forward(*inputs):
                     return custom_forward

                 ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
-                hidden_states = torch.utils.checkpoint.checkpoint(
+                encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
                     create_custom_forward(block),
                     hidden_states,
                     encoder_hidden_states,
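This hunk fixes a return-value unpacking bug: under gradient checkpointing, the wrapped SD3 transformer block returns two tensors (encoder_hidden_states and hidden_states), but the old code bound the whole tuple to the single name hidden_states. A minimal sketch of the failure mode follows; the toy block and tensor shapes are illustrative stand-ins, not the library's actual module.

import torch
import torch.utils.checkpoint

def block(hidden_states, encoder_hidden_states):
    # Stand-in for the SD3 transformer block: it returns *two* tensors.
    return encoder_hidden_states + 1, hidden_states + 1

def create_custom_forward(module):
    def custom_forward(*inputs):
        return module(*inputs)
    return custom_forward

hidden = torch.zeros(2, 4, requires_grad=True)
encoder = torch.zeros(2, 4)

# Before the fix: the block's 2-tuple is bound to a single name, so anything
# downstream that expects a tensor silently receives a tuple instead.
hidden_states = torch.utils.checkpoint.checkpoint(
    create_custom_forward(block), hidden, encoder, use_reentrant=False
)
print(type(hidden_states))  # <class 'tuple'>

# After the fix: both outputs are unpacked into their own names.
encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
    create_custom_forward(block), hidden, encoder, use_reentrant=False
)
print(type(hidden_states))  # <class 'torch.Tensor'>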
tests/pipelines/lumina/test_lumina_nextdit.py (12 changes: 6 additions & 6 deletions)
@@ -34,19 +34,19 @@ class LuminaText2ImgPipelinePipelineFastTests(unittest.TestCase, PipelineTesterMixin):
     def get_dummy_components(self):
         torch.manual_seed(0)
         transformer = LuminaNextDiT2DModel(
-            sample_size=16,
+            sample_size=4,
             patch_size=2,
             in_channels=4,
-            hidden_size=24,
+            hidden_size=4,
             num_layers=2,
-            num_attention_heads=3,
+            num_attention_heads=1,
             num_kv_heads=1,
             multiple_of=16,
             ffn_dim_multiplier=None,
             norm_eps=1e-5,
             learn_sigma=True,
             qk_norm=True,
-            cross_attention_dim=32,
+            cross_attention_dim=8,
             scaling_factor=1.0,
         )
         torch.manual_seed(0)
@@ -57,8 +57,8 @@ def get_dummy_components(self):

         torch.manual_seed(0)
         config = GemmaConfig(
-            head_dim=4,
-            hidden_size=32,
+            head_dim=2,
+            hidden_size=8,
             intermediate_size=37,
             num_attention_heads=4,
             num_hidden_layers=2,
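These hunks shrink the dummy transformer and Gemma text-encoder configs so the fast test instantiates a much smaller model. A minimal sketch comparing parameter counts before and after the change; it assumes LuminaNextDiT2DModel is importable from the top-level diffusers package, with the remaining constructor arguments taken verbatim from the test above.

from diffusers import LuminaNextDiT2DModel

def n_params(**overrides):
    # Fixed arguments copied from get_dummy_components in the test file.
    model = LuminaNextDiT2DModel(
        patch_size=2,
        in_channels=4,
        num_layers=2,
        num_kv_heads=1,
        multiple_of=16,
        ffn_dim_multiplier=None,
        norm_eps=1e-5,
        learn_sigma=True,
        qk_norm=True,
        scaling_factor=1.0,
        **overrides,
    )
    return sum(p.numel() for p in model.parameters())

# Old vs. new dummy-transformer dimensions from the diff.
before = n_params(sample_size=16, hidden_size=24, num_attention_heads=3, cross_attention_dim=32)
after = n_params(sample_size=4, hidden_size=4, num_attention_heads=1, cross_attention_dim=8)
print(f"before: {before:,} params / after: {after:,} params")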
