Skip to content

Commit

Permalink
[LoRA] add a test to ensure set_adapters() and attn kwargs outs mat…
Browse files Browse the repository at this point in the history
…ch (#10110)

* add a test to ensure set_adapters() and attn kwargs outs match

* remove print

* fix

* Apply suggestions from code review

Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>

* assertFalse.

---------

Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>
  • Loading branch information
sayakpaul and BenjaminBossan authored Dec 12, 2024
1 parent 7db9463 commit a6a18cf
Showing 1 changed file with 90 additions and 2 deletions.
92 changes: 90 additions & 2 deletions tests/lora/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,9 @@ def initialize_dummy_state_dict(state_dict):
return {k: torch.randn(v.shape, device=torch_device, dtype=v.dtype) for k, v in state_dict.items()}


POSSIBLE_ATTENTION_KWARGS_NAMES = ["cross_attention_kwargs", "joint_attention_kwargs", "attention_kwargs"]


@require_peft_backend
class PeftLoraLoaderMixinTests:
pipeline_class = None
Expand Down Expand Up @@ -429,7 +432,7 @@ def test_simple_inference_with_text_lora_and_scale(self):
call_signature_keys = inspect.signature(self.pipeline_class.__call__).parameters.keys()

# TODO(diffusers): Discuss a common naming convention across library for 1.0.0 release
for possible_attention_kwargs in ["cross_attention_kwargs", "joint_attention_kwargs", "attention_kwargs"]:
for possible_attention_kwargs in POSSIBLE_ATTENTION_KWARGS_NAMES:
if possible_attention_kwargs in call_signature_keys:
attention_kwargs_name = possible_attention_kwargs
break
Expand Down Expand Up @@ -790,7 +793,7 @@ def test_simple_inference_with_text_denoiser_lora_and_scale(self):
and makes sure it works as expected
"""
call_signature_keys = inspect.signature(self.pipeline_class.__call__).parameters.keys()
for possible_attention_kwargs in ["cross_attention_kwargs", "joint_attention_kwargs", "attention_kwargs"]:
for possible_attention_kwargs in POSSIBLE_ATTENTION_KWARGS_NAMES:
if possible_attention_kwargs in call_signature_keys:
attention_kwargs_name = possible_attention_kwargs
break
Expand Down Expand Up @@ -1885,3 +1888,88 @@ def set_pad_mode(network, mode="circular"):

_, _, inputs = self.get_dummy_inputs()
_ = pipe(**inputs)[0]

def test_set_adapters_match_attention_kwargs(self):
"""Test to check if outputs after `set_adapters()` and attention kwargs match."""
call_signature_keys = inspect.signature(self.pipeline_class.__call__).parameters.keys()
for possible_attention_kwargs in POSSIBLE_ATTENTION_KWARGS_NAMES:
if possible_attention_kwargs in call_signature_keys:
attention_kwargs_name = possible_attention_kwargs
break
assert attention_kwargs_name is not None

for scheduler_cls in self.scheduler_classes:
components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)

output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(output_no_lora.shape == self.output_shape)

if "text_encoder" in self.pipeline_class._lora_loadable_modules:
pipe.text_encoder.add_adapter(text_lora_config)
self.assertTrue(
check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
)

denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
denoiser.add_adapter(denoiser_lora_config)
self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")

if self.has_two_text_encoders or self.has_three_text_encoders:
if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
pipe.text_encoder_2.add_adapter(text_lora_config)
self.assertTrue(
check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
)

lora_scale = 0.5
attention_kwargs = {attention_kwargs_name: {"scale": lora_scale}}
output_lora_scale = pipe(**inputs, generator=torch.manual_seed(0), **attention_kwargs)[0]
self.assertFalse(
np.allclose(output_no_lora, output_lora_scale, atol=1e-3, rtol=1e-3),
"Lora + scale should change the output",
)

pipe.set_adapters("default", lora_scale)
output_lora_scale_wo_kwargs = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(
not np.allclose(output_no_lora, output_lora_scale_wo_kwargs, atol=1e-3, rtol=1e-3),
"Lora + scale should change the output",
)
self.assertTrue(
np.allclose(output_lora_scale, output_lora_scale_wo_kwargs, atol=1e-3, rtol=1e-3),
"Lora + scale should match the output of `set_adapters()`.",
)

with tempfile.TemporaryDirectory() as tmpdirname:
modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
self.pipeline_class.save_lora_weights(
save_directory=tmpdirname, safe_serialization=True, **lora_state_dicts
)

self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")))
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))

for module_name, module in modules_to_save.items():
self.assertTrue(check_if_lora_correctly_set(module), f"Lora not correctly set in {module_name}")

output_lora_from_pretrained = pipe(**inputs, generator=torch.manual_seed(0), **attention_kwargs)[0]
self.assertTrue(
not np.allclose(output_no_lora, output_lora_from_pretrained, atol=1e-3, rtol=1e-3),
"Lora + scale should change the output",
)
self.assertTrue(
np.allclose(output_lora_scale, output_lora_from_pretrained, atol=1e-3, rtol=1e-3),
"Loading from saved checkpoints should give same results as attention_kwargs.",
)
self.assertTrue(
np.allclose(output_lora_scale_wo_kwargs, output_lora_from_pretrained, atol=1e-3, rtol=1e-3),
"Loading from saved checkpoints should give same results as set_adapters().",
)

0 comments on commit a6a18cf

Please sign in to comment.