From a6a18cff5ef6af3396809dbe5200392551983b1e Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Thu, 12 Dec 2024 12:52:50 +0530 Subject: [PATCH] [LoRA] add a test to ensure `set_adapters()` and attn kwargs outs match (#10110) * add a test to ensure set_adapters() and attn kwargs outs match * remove print * fix * Apply suggestions from code review Co-authored-by: Benjamin Bossan * assertFalse. --------- Co-authored-by: Benjamin Bossan --- tests/lora/utils.py | 92 ++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 90 insertions(+), 2 deletions(-) diff --git a/tests/lora/utils.py b/tests/lora/utils.py index 474c31150538..990cf71f298e 100644 --- a/tests/lora/utils.py +++ b/tests/lora/utils.py @@ -76,6 +76,9 @@ def initialize_dummy_state_dict(state_dict): return {k: torch.randn(v.shape, device=torch_device, dtype=v.dtype) for k, v in state_dict.items()} +POSSIBLE_ATTENTION_KWARGS_NAMES = ["cross_attention_kwargs", "joint_attention_kwargs", "attention_kwargs"] + + @require_peft_backend class PeftLoraLoaderMixinTests: pipeline_class = None @@ -429,7 +432,7 @@ def test_simple_inference_with_text_lora_and_scale(self): call_signature_keys = inspect.signature(self.pipeline_class.__call__).parameters.keys() # TODO(diffusers): Discuss a common naming convention across library for 1.0.0 release - for possible_attention_kwargs in ["cross_attention_kwargs", "joint_attention_kwargs", "attention_kwargs"]: + for possible_attention_kwargs in POSSIBLE_ATTENTION_KWARGS_NAMES: if possible_attention_kwargs in call_signature_keys: attention_kwargs_name = possible_attention_kwargs break @@ -790,7 +793,7 @@ def test_simple_inference_with_text_denoiser_lora_and_scale(self): and makes sure it works as expected """ call_signature_keys = inspect.signature(self.pipeline_class.__call__).parameters.keys() - for possible_attention_kwargs in ["cross_attention_kwargs", "joint_attention_kwargs", "attention_kwargs"]: + for possible_attention_kwargs in POSSIBLE_ATTENTION_KWARGS_NAMES: if possible_attention_kwargs in call_signature_keys: attention_kwargs_name = possible_attention_kwargs break @@ -1885,3 +1888,88 @@ def set_pad_mode(network, mode="circular"): _, _, inputs = self.get_dummy_inputs() _ = pipe(**inputs)[0] + + def test_set_adapters_match_attention_kwargs(self): + """Test to check if outputs after `set_adapters()` and attention kwargs match.""" + call_signature_keys = inspect.signature(self.pipeline_class.__call__).parameters.keys() + for possible_attention_kwargs in POSSIBLE_ATTENTION_KWARGS_NAMES: + if possible_attention_kwargs in call_signature_keys: + attention_kwargs_name = possible_attention_kwargs + break + assert attention_kwargs_name is not None + + for scheduler_cls in self.scheduler_classes: + components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls) + pipe = self.pipeline_class(**components) + pipe = pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + _, _, inputs = self.get_dummy_inputs(with_generator=False) + + output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] + self.assertTrue(output_no_lora.shape == self.output_shape) + + if "text_encoder" in self.pipeline_class._lora_loadable_modules: + pipe.text_encoder.add_adapter(text_lora_config) + self.assertTrue( + check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder" + ) + + denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet + denoiser.add_adapter(denoiser_lora_config) + self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.") + + if self.has_two_text_encoders or self.has_three_text_encoders: + if "text_encoder_2" in self.pipeline_class._lora_loadable_modules: + pipe.text_encoder_2.add_adapter(text_lora_config) + self.assertTrue( + check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" + ) + + lora_scale = 0.5 + attention_kwargs = {attention_kwargs_name: {"scale": lora_scale}} + output_lora_scale = pipe(**inputs, generator=torch.manual_seed(0), **attention_kwargs)[0] + self.assertFalse( + np.allclose(output_no_lora, output_lora_scale, atol=1e-3, rtol=1e-3), + "Lora + scale should change the output", + ) + + pipe.set_adapters("default", lora_scale) + output_lora_scale_wo_kwargs = pipe(**inputs, generator=torch.manual_seed(0))[0] + self.assertTrue( + not np.allclose(output_no_lora, output_lora_scale_wo_kwargs, atol=1e-3, rtol=1e-3), + "Lora + scale should change the output", + ) + self.assertTrue( + np.allclose(output_lora_scale, output_lora_scale_wo_kwargs, atol=1e-3, rtol=1e-3), + "Lora + scale should match the output of `set_adapters()`.", + ) + + with tempfile.TemporaryDirectory() as tmpdirname: + modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True) + lora_state_dicts = self._get_lora_state_dicts(modules_to_save) + self.pipeline_class.save_lora_weights( + save_directory=tmpdirname, safe_serialization=True, **lora_state_dicts + ) + + self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))) + pipe = self.pipeline_class(**components) + pipe = pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")) + + for module_name, module in modules_to_save.items(): + self.assertTrue(check_if_lora_correctly_set(module), f"Lora not correctly set in {module_name}") + + output_lora_from_pretrained = pipe(**inputs, generator=torch.manual_seed(0), **attention_kwargs)[0] + self.assertTrue( + not np.allclose(output_no_lora, output_lora_from_pretrained, atol=1e-3, rtol=1e-3), + "Lora + scale should change the output", + ) + self.assertTrue( + np.allclose(output_lora_scale, output_lora_from_pretrained, atol=1e-3, rtol=1e-3), + "Loading from saved checkpoints should give same results as attention_kwargs.", + ) + self.assertTrue( + np.allclose(output_lora_scale_wo_kwargs, output_lora_from_pretrained, atol=1e-3, rtol=1e-3), + "Loading from saved checkpoints should give same results as set_adapters().", + )