diff --git a/src/peft/tuners/adaption_prompt/config.py b/src/peft/tuners/adaption_prompt/config.py
index fdce845edf..90e2984149 100644
--- a/src/peft/tuners/adaption_prompt/config.py
+++ b/src/peft/tuners/adaption_prompt/config.py
@@ -54,6 +54,13 @@ def is_adaption_prompt(self) -> bool:
         v_proj_layer="v_proj",
         o_proj_layer="o_proj",
     ),
+    "mistral": ModelTypeConfig(  # same as llama,
+        compute_query_states=llama_compute_query_states,
+        target_modules="self_attn",
+        k_proj_layer="k_proj",
+        v_proj_layer="v_proj",
+        o_proj_layer="o_proj",
+    ),
 }
diff --git a/src/peft/tuners/adaption_prompt/layer.py b/src/peft/tuners/adaption_prompt/layer.py
index cdd7895eaa..31fb51e0de 100644
--- a/src/peft/tuners/adaption_prompt/layer.py
+++ b/src/peft/tuners/adaption_prompt/layer.py
@@ -74,31 +74,38 @@ def forward(self, **kwargs):
         k_proj_layer = TRANSFORMERS_MODEL_CONFIG[self.model_type].k_proj_layer
         v_proj_layer = TRANSFORMERS_MODEL_CONFIG[self.model_type].v_proj_layer
         o_proj_layer = TRANSFORMERS_MODEL_CONFIG[self.model_type].o_proj_layer
+        factor = (
+            self.model.k_proj.in_features // self.model.k_proj.out_features
+        )  # Mistral has different input and output dimension for k_proj and v_proj layers
 
         if k_proj_layer == v_proj_layer:
             _, key, value = getattr(self.model, k_proj_layer)(self.adaption_prompt).split(embed_dim, dim=2)
         else:
             key = getattr(self.model, k_proj_layer)(self.adaption_prompt)
             value = getattr(self.model, v_proj_layer)(self.adaption_prompt)
-        # (bsz, num_heads, adapter_len, head_dim)
+
+        # (bsz, num_key_value_heads, adapter_len, head_dim)
         adapter_k = (
-            key.view(1, self.adapter_len, self.model.num_heads, self.model.head_dim)
+            key.view(1, self.adapter_len, (self.model.num_heads // factor), self.model.head_dim)
             .repeat(bsz, 1, 1, 1)
             .transpose(1, 2)
         )
-        # (bsz, num_heads, adapter_len, head_dim)
         adapter_v = (
-            value.view(1, self.adapter_len, self.model.num_heads, self.model.head_dim)
+            value.view(1, self.adapter_len, (self.model.num_heads // factor), self.model.head_dim)
             .repeat(bsz, 1, 1, 1)
             .transpose(1, 2)
         )
-
+        # Below is taken from https://github.com/huggingface/transformers/blob/e547458c43dfdbbb8f6a7757237e234c44e20a8f/src/transformers/models/mistral/modeling_mistral.py#L181
+        # (bsz, num_heads, adapter_len, head_dim)
+        adapter_k = torch.repeat_interleave(adapter_k, repeats=factor, dim=1)
+        adapter_v = torch.repeat_interleave(adapter_v, repeats=factor, dim=1)
         # Recompute query states.
         compute_query_states = TRANSFORMERS_MODEL_CONFIG[self.model_type].compute_query_states
         # (bsz, num_heads, q_len, head_dim)
         query_states = compute_query_states(model=self.model, **kwargs)
 
         previous_dtype = query_states.dtype
+
         # (bsz, num_heads, q_len, adapter_len)
         scores = torch.matmul(query_states, adapter_k.transpose(2, 3).to(previous_dtype)) / math.sqrt(
             self.model.head_dim
@@ -108,6 +115,7 @@ def forward(self, **kwargs):
         scores = self.adaption_gate * F.softmax(scores, dim=-1, dtype=torch.float32).to(previous_dtype)
         # (bsz, q_len, num_heads * head_dim)
         adapter_output = torch.matmul(scores, adapter_v).transpose(1, 2).reshape(bsz, q_len, -1)
+        # (bsz, q_len, hidden_size)
         if o_proj_layer is not None:
             adapter_output = getattr(self.model, o_proj_layer)(adapter_output)
diff --git a/src/peft/tuners/adaption_prompt/utils.py b/src/peft/tuners/adaption_prompt/utils.py
index 2722193889..d70f31e389 100644
--- a/src/peft/tuners/adaption_prompt/utils.py
+++ b/src/peft/tuners/adaption_prompt/utils.py
@@ -68,7 +68,12 @@ def llama_compute_query_states(model: nn.Module, **kwargs) -> torch.Tensor:
     past_key_value = kwargs.get("past_key_value")
     bsz, q_len, _ = hidden_states.size()
     query_states = model.q_proj(hidden_states).view(bsz, q_len, model.num_heads, model.head_dim).transpose(1, 2)
-    value_states = model.v_proj(hidden_states).view(bsz, q_len, model.num_heads, model.head_dim).transpose(1, 2)
+
+    factor = model.k_proj.in_features // model.k_proj.out_features
+    value_states = (
+        model.v_proj(hidden_states).view(bsz, q_len, (model.num_heads // factor), model.head_dim).transpose(1, 2)
+    )
+
     seq_len = q_len
 
     if past_key_value is not None:
diff --git a/tests/test_adaption_prompt.py b/tests/test_adaption_prompt.py
index 93cdb27d9a..ec3deb150a 100644
--- a/tests/test_adaption_prompt.py
+++ b/tests/test_adaption_prompt.py
@@ -38,11 +38,24 @@ def is_llama_available() -> bool:
         return False
 
 
+def is_mistral_available() -> bool:
+    """Check if mistral is available in the transformers library (it's not in earlier versions)."""
+    try:
+        return importlib.util.find_spec("transformers.models.mistral.modeling_mistral") is not None
+    except ModuleNotFoundError:
+        return False
+
+
 if is_llama_available():
     # We guard the import statement so that our unit tests will pass in CI environments
     # that don't have a transformers package with Llama.
     from transformers import LlamaConfig, LlamaForCausalLM, LlamaModel
 
+if is_mistral_available():
+    # We guard the import statement so that our unit tests will pass in CI environments
+    # that don't have a transformers package with Mistral.
+    from transformers import MistralConfig, MistralForCausalLM, MistralModel
+
 
 class AdaptionPromptTester(TestCase, PeftCommonTester):
     """
@@ -55,7 +68,10 @@ class AdaptionPromptTester(TestCase, PeftCommonTester):
     def setUp(self):
         # Check that llama is available in transformers package before running each test.
         if not is_llama_available():
-            self.skipTest("Llama not available in transformers. Skipping test.")
+            self.skipTest("Llama not available in transformers. Skipping all tests.")
+        else:
+            # Check for Mistral's availability. It might or might not be available.
+            self.mistral_available = is_mistral_available()
 
     @staticmethod
     def _create_test_llama_config():
@@ -69,6 +85,19 @@ def _create_test_llama_config():
             use_cache=False,
         )
 
+    @staticmethod
+    def _create_test_mistral_config():
+        """Create a test config for a small Mistral model for testing."""
+        return MistralConfig(
+            vocab_size=16,
+            hidden_size=8,
+            intermediate_size=8,
+            num_hidden_layers=8,
+            num_attention_heads=4,
+            num_key_value_heads=2,
+            use_cache=False,
+        )
+
     def test_attributes(self) -> None:
         model = LlamaModel(self._create_test_llama_config())
         config = AdaptionPromptConfig(adapter_layers=1, adapter_len=4)
@@ -78,7 +107,18 @@ def test_attributes(self) -> None:
         assert hasattr(model, "from_pretrained")
         assert hasattr(model, "push_to_hub")
 
+    @unittest.skipIf(not is_mistral_available(), "Mistral is not available")
+    def test_attributes_mistral(self) -> None:
+        model_mistral = MistralModel(self._create_test_mistral_config())
+        config_mistral = AdaptionPromptConfig(adapter_layers=1, adapter_len=4)
+        model_mistral = get_peft_model(model_mistral, config_mistral)
+
+        assert hasattr(model_mistral, "save_pretrained")
+        assert hasattr(model_mistral, "from_pretrained")
+        assert hasattr(model_mistral, "push_to_hub")
+
     def test_prepare_for_training(self) -> None:
+        # Test Llama
         model = LlamaForCausalLM(self._create_test_llama_config())
         config = AdaptionPromptConfig(adapter_layers=1, adapter_len=4, task_type="CAUSAL_LM")
         model = get_peft_model(model, config)
@@ -89,6 +129,18 @@ def test_prepare_for_training(self) -> None:
 
         assert not dummy_output.requires_grad
 
+    @unittest.skipIf(not is_mistral_available(), "Mistral is not available")
+    def test_prepare_for_training_mistral(self) -> None:
+        model_mistral = MistralForCausalLM(self._create_test_mistral_config())
+        config_mistral = AdaptionPromptConfig(adapter_layers=1, adapter_len=4, task_type="CAUSAL_LM")
+        model_mistral = get_peft_model(model_mistral, config_mistral)
+        model_mistral = model_mistral.to(self.torch_device)
+
+        dummy_input = torch.LongTensor([[1, 1, 1]]).to(self.torch_device)
+        dummy_output = model_mistral.get_input_embeddings()(dummy_input)
+
+        assert not dummy_output.requires_grad
+
     def test_prepare_for_int8_training(self) -> None:
         model = LlamaForCausalLM(self._create_test_llama_config())
         model = prepare_model_for_int8_training(model)
@@ -115,6 +167,33 @@ def make_inputs_require_grad(module, input, output):
 
         assert dummy_output.requires_grad
 
+    @unittest.skipIf(not is_mistral_available(), "Mistral is not available")
+    def test_prepare_model_for_int8_training_mistral(self) -> None:
+        model_mistral = MistralForCausalLM(self._create_test_mistral_config())
+        model_mistral = prepare_model_for_int8_training(model_mistral)
+        model_mistral = model_mistral.to(self.torch_device)
+
+        for param in model_mistral.parameters():
+            assert not param.requires_grad
+
+        config_mistral = AdaptionPromptConfig(adapter_layers=1, adapter_len=4, task_type="CAUSAL_LM")
+        model_mistral = get_peft_model(model_mistral, config_mistral)
+
+        # For backward compatibility
+        if hasattr(model_mistral, "enable_input_require_grads"):
+            model_mistral.enable_input_require_grads()
+        else:
+
+            def make_inputs_require_grad(module, input, output):
+                output.requires_grad_(True)
+
+            model_mistral.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
+
+        dummy_input = torch.LongTensor([[1, 1, 1]]).to(self.torch_device)
+        dummy_output = model_mistral.get_input_embeddings()(dummy_input)
+
+        assert dummy_output.requires_grad
+
     def test_save_pretrained_regression(self) -> None:
         seed = 420
         torch.manual_seed(seed)
@@ -158,6 +237,50 @@ def test_save_pretrained_regression(self) -> None:
         # check if `config.json` is not present
         assert not os.path.exists(os.path.join(tmp_dirname, "config.json"))
 
+    @unittest.skipIf(not is_mistral_available(), "Mistral is not available")
+    def test_save_pretrained_regression_mistral(self) -> None:
+        seed = 420
+        torch.manual_seed(seed)
+        model_mistral = MistralForCausalLM(self._create_test_mistral_config())
+        config_mistral = AdaptionPromptConfig(adapter_layers=2, adapter_len=4, task_type="CAUSAL_LM")
+        model_mistral = get_peft_model(model_mistral, config_mistral)
+        model_mistral = model_mistral.to(self.torch_device)
+
+        with tempfile.TemporaryDirectory() as tmp_dirname:
+            model_mistral.save_pretrained(tmp_dirname, safe_serialization=False)
+
+            torch.manual_seed(seed)
+            model_from_pretrained_mistral = MistralForCausalLM(self._create_test_mistral_config())
+            model_from_pretrained_mistral = PeftModel.from_pretrained(model_from_pretrained_mistral, tmp_dirname)
+
+            # check if the state dicts are equal
+            state_dict = get_peft_model_state_dict(model_mistral)
+            state_dict_from_pretrained = get_peft_model_state_dict(model_from_pretrained_mistral)
+
+            # check if same keys
+            assert state_dict.keys() == state_dict_from_pretrained.keys()
+
+            # Check that the number of saved parameters is 4 -- 2 layers of (tokens and gate).
+            assert len(state_dict) == 4
+
+            # check if tensors equal
+            for key in state_dict.keys():
+                assert torch.allclose(
+                    state_dict[key].to(self.torch_device), state_dict_from_pretrained[key].to(self.torch_device)
+                )
+
+            # check if `adapter_model.bin` is present
+            assert os.path.exists(os.path.join(tmp_dirname, "adapter_model.bin"))
+
+            # check if `adapter_config.json` is present
+            assert os.path.exists(os.path.join(tmp_dirname, "adapter_config.json"))
+
+            # check if `model.safetensors` is not present
+            assert not os.path.exists(os.path.join(tmp_dirname, "model.safetensors"))
+
+            # check if `config.json` is not present
+            assert not os.path.exists(os.path.join(tmp_dirname, "config.json"))
+
     def test_save_pretrained(self) -> None:
         seed = 420
         torch.manual_seed(seed)
@@ -201,6 +324,50 @@ def test_save_pretrained(self) -> None:
         # check if `config.json` is not present
         assert not os.path.exists(os.path.join(tmp_dirname, "config.json"))
 
+    @unittest.skipIf(not is_mistral_available(), "Mistral is not available")
+    def test_save_pretrained_mistral(self) -> None:
+        seed = 420
+        torch.manual_seed(seed)
+        model_mistral = MistralForCausalLM(self._create_test_mistral_config())
+        config_mistral = AdaptionPromptConfig(adapter_layers=2, adapter_len=4, task_type="CAUSAL_LM")
+        model_mistral = get_peft_model(model_mistral, config_mistral)
+        model_mistral = model_mistral.to(self.torch_device)
+
+        with tempfile.TemporaryDirectory() as tmp_dirname:
+            model_mistral.save_pretrained(tmp_dirname)
+
+            torch.manual_seed(seed)
+            model_from_pretrained_mistral = MistralForCausalLM(self._create_test_mistral_config())
+            model_from_pretrained_mistral = PeftModel.from_pretrained(model_from_pretrained_mistral, tmp_dirname)
+
+            # check if the state dicts are equal
+            state_dict = get_peft_model_state_dict(model_mistral)
+            state_dict_from_pretrained = get_peft_model_state_dict(model_from_pretrained_mistral)
+
+            # check if same keys
+            assert state_dict.keys() == state_dict_from_pretrained.keys()
+
+            # Check that the number of saved parameters is 4 -- 2 layers of (tokens and gate).
+            assert len(state_dict) == 4
+
+            # check if tensors equal
+            for key in state_dict.keys():
+                assert torch.allclose(
+                    state_dict[key].to(self.torch_device), state_dict_from_pretrained[key].to(self.torch_device)
+                )
+
+            # check if `adapter_model.safetensors` is present
+            assert os.path.exists(os.path.join(tmp_dirname, "adapter_model.safetensors"))
+
+            # check if `adapter_config.json` is present
+            assert os.path.exists(os.path.join(tmp_dirname, "adapter_config.json"))
+
+            # check if `model.safetensors` is not present
+            assert not os.path.exists(os.path.join(tmp_dirname, "model.safetensors"))
+
+            # check if `config.json` is not present
+            assert not os.path.exists(os.path.join(tmp_dirname, "config.json"))
+
     def test_save_pretrained_selected_adapters(self) -> None:
         seed = 420
         torch.manual_seed(seed)
@@ -249,6 +416,55 @@ def test_save_pretrained_selected_adapters(self) -> None:
         # check if `config.json` is not present
         assert not os.path.exists(os.path.join(tmp_dirname, "config.json"))
 
+    @unittest.skipIf(not is_mistral_available(), "Mistral is not available")
+    def test_save_pretrained_selected_adapters_mistral(self) -> None:
+        seed = 420
+        torch.manual_seed(seed)
+        model_mistral = MistralForCausalLM(self._create_test_mistral_config())
+        config_mistral = AdaptionPromptConfig(adapter_layers=2, adapter_len=4, task_type="CAUSAL_LM")
+        model_mistral = get_peft_model(model_mistral, config_mistral)
+        model_mistral = model_mistral.to(self.torch_device)
+
+        new_adapter_config_mistral = AdaptionPromptConfig(adapter_layers=2, adapter_len=4, task_type="CAUSAL_LM")
+        model_mistral.add_adapter("new_adapter", new_adapter_config_mistral)
+
+        with tempfile.TemporaryDirectory() as tmp_dirname:
+            model_mistral.save_pretrained(tmp_dirname)
+
+            torch.manual_seed(seed)
+            model_from_pretrained_mistral = MistralForCausalLM(self._create_test_mistral_config())
+            model_from_pretrained_mistral = PeftModel.from_pretrained(model_from_pretrained_mistral, tmp_dirname)
+
+            model_from_pretrained_mistral.load_adapter(tmp_dirname, "new_adapter")
+
+            # check if the state dicts are equal
+            state_dict = get_peft_model_state_dict(model_mistral)
+            state_dict_from_pretrained = get_peft_model_state_dict(model_from_pretrained_mistral)
+
+            # check if same keys
+            assert state_dict.keys() == state_dict_from_pretrained.keys()
+
+            # Check that the number of saved parameters is 4 -- 2 layers of (tokens and gate).
+            assert len(state_dict) == 4
+
+            # check if tensors equal
+            for key in state_dict.keys():
+                assert torch.allclose(
+                    state_dict[key].to(self.torch_device), state_dict_from_pretrained[key].to(self.torch_device)
+                )
+
+            # check if `adapter_model.safetensors` is present
+            assert os.path.exists(os.path.join(tmp_dirname, "adapter_model.safetensors"))
+
+            # check if `adapter_config.json` is present
+            assert os.path.exists(os.path.join(tmp_dirname, "adapter_config.json"))
+
+            # check if `model.safetensors` is not present
+            assert not os.path.exists(os.path.join(tmp_dirname, "model.safetensors"))
+
+            # check if `config.json` is not present
+            assert not os.path.exists(os.path.join(tmp_dirname, "config.json"))
+
     def test_generate(self) -> None:
         model = LlamaForCausalLM(self._create_test_llama_config())
         config = AdaptionPromptConfig(adapter_layers=2, adapter_len=4, task_type="CAUSAL_LM")
@@ -264,6 +480,22 @@ def test_generate(self) -> None:
         # check if `generate` works if positional arguments are passed
         _ = model.generate(input_ids, attention_mask=attention_mask)
 
+    @unittest.skipIf(not is_mistral_available(), "Mistral is not available")
+    def test_generate_mistral(self) -> None:
+        model_mistral = MistralForCausalLM(self._create_test_mistral_config())
+        config_mistral = AdaptionPromptConfig(adapter_layers=2, adapter_len=4, task_type="CAUSAL_LM")
+        model_mistral = get_peft_model(model_mistral, config_mistral)
+        model_mistral = model_mistral.to(self.torch_device)
+
+        input_ids = torch.LongTensor([[1, 1, 1], [2, 1, 2]]).to(self.torch_device)
+        attention_mask = torch.LongTensor([[1, 1, 1], [1, 0, 1]]).to(self.torch_device)
+
+        # check if `generate` works
+        _ = model_mistral.generate(input_ids=input_ids, attention_mask=attention_mask)
+
+        # check if `generate` works if positional arguments are passed
+        _ = model_mistral.generate(input_ids, attention_mask=attention_mask)
+
     def test_sequence_adapter_ops(self) -> None:
         """Test sequence of adapter operations."""
         # Test input data.
@@ -333,6 +565,77 @@ def test_sequence_adapter_ops(self) -> None:
         assert not torch.allclose(original_before.logits, default_after_set.logits)
         assert not torch.allclose(adapter_1_after.logits, default_after_set.logits)
 
+    @unittest.skipIf(not is_mistral_available(), "Mistral is not available")
+    def test_sequence_adapter_ops_mistral(self) -> None:
+        # Test input data.
+        input_ids = torch.LongTensor([[1, 1, 1], [2, 1, 2]]).to(self.torch_device)
+        target_ids = torch.LongTensor([[0, 0, 0], [0, 0, 0]]).to(self.torch_device)
+        attention_mask = torch.LongTensor([[1, 1, 1], [1, 0, 1]]).to(self.torch_device)
+
+        # Create original mistral model.
+        model_mistral = MistralForCausalLM(self._create_test_mistral_config())
+        model_mistral = model_mistral.to(self.torch_device)
+        original_before = model_mistral(input_ids=input_ids, attention_mask=attention_mask)
+
+        # Get AdaptionPrompt model.
+        adapted_mistral = get_peft_model(
+            model_mistral, AdaptionPromptConfig(adapter_layers=2, adapter_len=4, task_type="CAUSAL_LM")
+        )
+        adapted_mistral = adapted_mistral.to(self.torch_device)
+        default_before = adapted_mistral(input_ids=input_ids, attention_mask=attention_mask, labels=target_ids)
+
+        # Test zero-init: The logits should be exactly the same.
+        assert_close(original_before.logits, default_before.logits, rtol=0, atol=0)
+
+        # Single fine-tuning step on "default" adapter.
+        optimizer = torch.optim.SGD(adapted_mistral.parameters(), lr=1)
+        optimizer.zero_grad()
+        default_before.loss.backward()
+        optimizer.step()
+
+        # Test that the output changed.
+        default_after = adapted_mistral(input_ids=input_ids, attention_mask=attention_mask, labels=target_ids)
+        assert not torch.allclose(default_before.logits, default_after.logits)
+
+        with adapted_mistral.disable_adapter():
+            # Test that the output is the same as the original output.
+            default_disabled = adapted_mistral(input_ids=input_ids, attention_mask=attention_mask, labels=target_ids)
+            assert_close(original_before.logits, default_disabled.logits, rtol=0, atol=0)
+
+        # Add new adapter 1.
+        adapted_mistral.add_adapter(
+            "adapter 1", AdaptionPromptConfig(adapter_layers=3, adapter_len=8, task_type="CAUSAL_LM")
+        )
+        # Test zero-init
+        adapter_1_before = adapted_mistral(input_ids=input_ids, attention_mask=attention_mask, labels=target_ids)
+        assert_close(original_before.logits, adapter_1_before.logits, rtol=0, atol=0)
+
+        # Single fine-tuning step on adapter 1.
+        optimizer = torch.optim.SGD(adapted_mistral.parameters(), lr=1)
+        optimizer.zero_grad()
+        adapter_1_before.loss.backward()
+        optimizer.step()
+
+        # Test that adapter 1 output changed.
+        adapter_1_after = adapted_mistral(input_ids=input_ids, attention_mask=attention_mask, labels=target_ids)
+        assert not torch.allclose(adapter_1_before.logits, adapter_1_after.logits)
+        assert not torch.allclose(original_before.logits, adapter_1_after.logits)
+        assert not torch.allclose(default_after.logits, adapter_1_after.logits)
+
+        with adapted_mistral.disable_adapter():
+            # Test that the output is the same as the original output.
+            adapter_1_disabled = adapted_mistral(input_ids=input_ids, attention_mask=attention_mask, labels=target_ids)
+            assert_close(original_before.logits, adapter_1_disabled.logits, rtol=0, atol=0)
+
+        # Set adapter back to default.
+        adapted_mistral.set_adapter("default")
+
+        # Test that the output is the same as the default output after training.
+        default_after_set = adapted_mistral(input_ids=input_ids, attention_mask=attention_mask, labels=target_ids)
+        assert_close(default_after.logits, default_after_set.logits, rtol=0, atol=0)
+        assert not torch.allclose(original_before.logits, default_after_set.logits)
+        assert not torch.allclose(adapter_1_after.logits, default_after_set.logits)
+
     def test_add_and_set_while_disabled(self):
         """Test that adding and setting adapters while disabled works as intended."""
         # Test input data.
@@ -378,6 +681,51 @@ def test_add_and_set_while_disabled(self):
         adapter_1_after_set = adapted(input_ids=input_ids, attention_mask=attention_mask, labels=target_ids)
         assert_close(adapter_1_after.logits, adapter_1_after_set.logits, rtol=0, atol=0)
 
+    @unittest.skipIf(not is_mistral_available(), "Mistral is not available")
+    def test_add_and_set_while_disabled_mistral(self):
+        # Test input data.
+        input_ids = torch.LongTensor([[1, 1, 1], [2, 1, 2]]).to(self.torch_device)
+        target_ids = torch.LongTensor([[0, 0, 0], [0, 0, 0]]).to(self.torch_device)
+        attention_mask = torch.LongTensor([[1, 1, 1], [1, 0, 1]]).to(self.torch_device)
+
+        # Create original mistral model.
+        model_mistral = MistralForCausalLM(self._create_test_mistral_config())
+        model_mistral = model_mistral.to(self.torch_device)
+        original_before = model_mistral(input_ids=input_ids, attention_mask=attention_mask)
+
+        # Get AdaptionPrompt model.
+        adapted_mistral = get_peft_model(
+            model_mistral, AdaptionPromptConfig(adapter_layers=2, adapter_len=4, task_type="CAUSAL_LM")
+        )
+        adapted_mistral = adapted_mistral.to(self.torch_device)
+
+        with adapted_mistral.disable_adapter():
+            adapted_mistral.add_adapter(
+                "adapter 1", AdaptionPromptConfig(adapter_layers=3, adapter_len=8, task_type="CAUSAL_LM")
+            )
+
+        # Test that the output is the same as the original output.
+        adapter_1_before = adapted_mistral(input_ids=input_ids, attention_mask=attention_mask, labels=target_ids)
+        assert_close(original_before.logits, adapter_1_before.logits, rtol=0, atol=0)
+
+        # Single fine-tuning step on adapter 1.
+        optimizer = torch.optim.SGD(adapted_mistral.parameters(), lr=1)
+        optimizer.zero_grad()
+        adapter_1_before.loss.backward()
+        optimizer.step()
+
+        # Test that adapter 1 output changed.
+        adapter_1_after = adapted_mistral(input_ids=input_ids, attention_mask=attention_mask, labels=target_ids)
+        assert not torch.allclose(original_before.logits, adapter_1_after.logits)
+
+        adapted_mistral.set_adapter("default")
+        with adapted_mistral.disable_adapter():
+            adapted_mistral.set_adapter("adapter 1")
+
+        # Test that adapter 1 is active again.
+        adapter_1_after_set = adapted_mistral(input_ids=input_ids, attention_mask=attention_mask, labels=target_ids)
+        assert_close(adapter_1_after.logits, adapter_1_after_set.logits, rtol=0, atol=0)
+
     def test_use_cache(self) -> None:
         """Test that AdaptionPrompt works when Llama config use_cache=True."""
         torch.manual_seed(0)
@@ -403,6 +751,32 @@ def test_use_cache(self) -> None:
         actual = adapted.generate(input_ids=input_ids, max_length=8)
         assert_close(expected, actual, rtol=0, atol=0)
 
+    @unittest.skipIf(not is_mistral_available(), "Mistral is not available")
+    def test_use_cache_mistral(self) -> None:
+        torch.manual_seed(0)
+        input_ids = torch.LongTensor([[1, 1, 1], [2, 1, 2]]).to(self.torch_device)
+        original = MistralForCausalLM(
+            MistralConfig(
+                vocab_size=16,
+                hidden_size=8,
+                intermediate_size=8,
+                num_hidden_layers=8,
+                num_attention_heads=4,
+                num_key_value_heads=2,
+                use_cache=False,
+            )
+        ).eval()
+        adapted = get_peft_model(
+            original, AdaptionPromptConfig(adapter_layers=2, adapter_len=4, task_type="CAUSAL_LM")
+        )
+        adapted = adapted.to(self.torch_device)
+        expected = adapted.generate(input_ids=input_ids, max_length=8)
+
+        # Set use_cache = True and generate output again.
+        adapted.base_model.config.use_cache = True
+        actual = adapted.generate(input_ids=input_ids, max_length=8)
+        assert_close(expected, actual, rtol=0, atol=0)
+
     def test_bf16_inference(self) -> None:
         if self.torch_device == "mps":
             return pytest.skip("Skipping bf16 test on MPS")
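
Reviewer note (not part of the patch): Mistral uses grouped-query attention, so `k_proj`/`v_proj` project to `num_key_value_heads * head_dim` features instead of `num_heads * head_dim`. The diff derives the grouping factor from the projection shapes and then expands the adapter's key/value heads with `torch.repeat_interleave`, mirroring `repeat_kv` in transformers' Mistral modeling code. The sketch below illustrates only that shape manipulation; all dimensions are made up for illustration and are not taken from the patch.

import torch

bsz, adapter_len, head_dim = 2, 4, 8
num_heads, num_key_value_heads = 32, 8  # illustrative GQA configuration (Mistral-7B-like)
factor = num_heads // num_key_value_heads  # same quantity the patch gets from k_proj.in_features // k_proj.out_features

# k_proj/v_proj of a GQA model output num_key_value_heads * head_dim features.
key = torch.randn(1, adapter_len, num_key_value_heads * head_dim)

# (bsz, num_key_value_heads, adapter_len, head_dim)
adapter_k = (
    key.view(1, adapter_len, num_heads // factor, head_dim)
    .repeat(bsz, 1, 1, 1)
    .transpose(1, 2)
)

# Repeat each key/value head `factor` times so it lines up with every query head.
adapter_k = torch.repeat_interleave(adapter_k, repeats=factor, dim=1)
assert adapter_k.shape == (bsz, num_heads, adapter_len, head_dim)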