From 2c71bd0991db90eb2d0b66c9c6724711ad9a9677 Mon Sep 17 00:00:00 2001 From: Prakhar Saxena Date: Mon, 5 Feb 2024 07:41:20 +0000 Subject: [PATCH 1/8] Support Mistral For llama-adapter --- src/peft/tuners/adaption_prompt/config.py | 7 +++++++ src/peft/tuners/adaption_prompt/layer.py | 21 +++++++++++++++++++-- src/peft/tuners/adaption_prompt/utils.py | 6 +++++- 3 files changed, 31 insertions(+), 3 deletions(-) diff --git a/src/peft/tuners/adaption_prompt/config.py b/src/peft/tuners/adaption_prompt/config.py index 37d206248a..9aee0550dc 100644 --- a/src/peft/tuners/adaption_prompt/config.py +++ b/src/peft/tuners/adaption_prompt/config.py @@ -55,6 +55,13 @@ def is_adaption_prompt(self) -> bool: v_proj_layer="v_proj", o_proj_layer="o_proj", ), + "mistral": ModelTypeConfig( # same for Mistral, + compute_query_states=llama_compute_query_states, + target_modules="self_attn", + k_proj_layer="k_proj", + v_proj_layer="v_proj", + o_proj_layer="o_proj", + ), } diff --git a/src/peft/tuners/adaption_prompt/layer.py b/src/peft/tuners/adaption_prompt/layer.py index 0ca4701f58..18a81aca65 100644 --- a/src/peft/tuners/adaption_prompt/layer.py +++ b/src/peft/tuners/adaption_prompt/layer.py @@ -75,6 +75,9 @@ def forward(self, **kwargs): k_proj_layer = TRANSFORMERS_MODEL_CONFIG[self.model_type].k_proj_layer v_proj_layer = TRANSFORMERS_MODEL_CONFIG[self.model_type].v_proj_layer o_proj_layer = TRANSFORMERS_MODEL_CONFIG[self.model_type].o_proj_layer + factor = ( + self.model.k_proj.in_features // self.model.k_proj.out_features + ) # Mistral has different input and output dimension for k_proj and v_proj layers if k_proj_layer == v_proj_layer: _, key, value = getattr(self.model, k_proj_layer)(self.adaption_prompt).split(embed_dim, dim=2) @@ -82,14 +85,15 @@ def forward(self, **kwargs): key = getattr(self.model, k_proj_layer)(self.adaption_prompt) value = getattr(self.model, v_proj_layer)(self.adaption_prompt) # (bsz, num_heads, adapter_len, head_dim) + adapter_k = ( - key.view(1, self.adapter_len, self.model.num_heads, self.model.head_dim) + key.view(1, self.adapter_len, self.model.num_heads, (self.model.head_dim // factor)) .repeat(bsz, 1, 1, 1) .transpose(1, 2) ) # (bsz, num_heads, adapter_len, head_dim) adapter_v = ( - value.view(1, self.adapter_len, self.model.num_heads, self.model.head_dim) + value.view(1, self.adapter_len, self.model.num_heads, (self.model.head_dim // factor)) .repeat(bsz, 1, 1, 1) .transpose(1, 2) ) @@ -100,6 +104,15 @@ def forward(self, **kwargs): query_states = compute_query_states(model=self.model, **kwargs) previous_dtype = query_states.dtype + + # Reshape and average the extra tensors + query_states_reshaped = query_states.reshape( + bsz, self.model.num_heads, -1, (self.model.head_dim // factor), factor + ) + + # Take the mean along the last dimension to get [bsz, 32, X, 32] + query_states = query_states_reshaped.mean(dim=-1) + # (bsz, num_heads, q_len, adapter_len) scores = torch.matmul(query_states, adapter_k.transpose(2, 3).to(previous_dtype)) / math.sqrt( self.model.head_dim @@ -109,6 +122,10 @@ def forward(self, **kwargs): scores = self.adaption_gate * F.softmax(scores, dim=-1, dtype=torch.float32).to(previous_dtype) # (bsz, q_len, num_heads * head_dim) adapter_output = torch.matmul(scores, adapter_v).transpose(1, 2).reshape(bsz, q_len, -1) + + adapter_output = torch.repeat_interleave( + adapter_output, repeats=factor, dim=2 + ) # https://github.com/huggingface/transformers/blob/e547458c43dfdbbb8f6a7757237e234c44e20a8f/src/transformers/models/mistral/modeling_mistral.py#L181 # (bsz, 
q_len, hidden_size) if o_proj_layer is not None: adapter_output = getattr(self.model, o_proj_layer)(adapter_output) diff --git a/src/peft/tuners/adaption_prompt/utils.py b/src/peft/tuners/adaption_prompt/utils.py index 0cbc95c1a1..cdc1bb047e 100644 --- a/src/peft/tuners/adaption_prompt/utils.py +++ b/src/peft/tuners/adaption_prompt/utils.py @@ -69,7 +69,11 @@ def llama_compute_query_states(model: nn.Module, **kwargs) -> torch.Tensor: past_key_value = kwargs.get("past_key_value") bsz, q_len, _ = hidden_states.size() query_states = model.q_proj(hidden_states).view(bsz, q_len, model.num_heads, model.head_dim).transpose(1, 2) - value_states = model.v_proj(hidden_states).view(bsz, q_len, model.num_heads, model.head_dim).transpose(1, 2) + + factor = model.k_proj.in_features // model.k_proj.out_features + value_states = ( + model.v_proj(hidden_states).view(bsz, q_len, model.num_heads, (model.head_dim // factor)).transpose(1, 2) + ) seq_len = q_len if past_key_value is not None: From 9ebe2952c9535162271cbb3b5095ae261b1229ca Mon Sep 17 00:00:00 2001 From: PrakharSaxena24 <50725987+PrakharSaxena24@users.noreply.github.com> Date: Wed, 7 Feb 2024 18:04:31 +0900 Subject: [PATCH 2/8] Update src/peft/tuners/adaption_prompt/layer.py Co-authored-by: Sourab Mangrulkar <13534540+pacman100@users.noreply.github.com> --- src/peft/tuners/adaption_prompt/layer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/peft/tuners/adaption_prompt/layer.py b/src/peft/tuners/adaption_prompt/layer.py index 18a81aca65..19a1de5a78 100644 --- a/src/peft/tuners/adaption_prompt/layer.py +++ b/src/peft/tuners/adaption_prompt/layer.py @@ -87,7 +87,7 @@ def forward(self, **kwargs): # (bsz, num_heads, adapter_len, head_dim) adapter_k = ( - key.view(1, self.adapter_len, self.model.num_heads, (self.model.head_dim // factor)) + key.view(1, self.adapter_len, (self.model.num_heads// factor), self.model.head_dim) .repeat(bsz, 1, 1, 1) .transpose(1, 2) ) From c42ef25282af3965e4409a188d6e3448f0210565 Mon Sep 17 00:00:00 2001 From: PrakharSaxena24 <50725987+PrakharSaxena24@users.noreply.github.com> Date: Wed, 7 Feb 2024 18:04:41 +0900 Subject: [PATCH 3/8] Update src/peft/tuners/adaption_prompt/layer.py Co-authored-by: Sourab Mangrulkar <13534540+pacman100@users.noreply.github.com> --- src/peft/tuners/adaption_prompt/layer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/peft/tuners/adaption_prompt/layer.py b/src/peft/tuners/adaption_prompt/layer.py index 19a1de5a78..41df50a817 100644 --- a/src/peft/tuners/adaption_prompt/layer.py +++ b/src/peft/tuners/adaption_prompt/layer.py @@ -93,7 +93,7 @@ def forward(self, **kwargs): ) # (bsz, num_heads, adapter_len, head_dim) adapter_v = ( - value.view(1, self.adapter_len, self.model.num_heads, (self.model.head_dim // factor)) + value.view(1, self.adapter_len, (self.model.num_heads// factor), self.model.head_dim) .repeat(bsz, 1, 1, 1) .transpose(1, 2) ) From 139e3153eedf0d9108677bcdd88ddfcb1ef417f3 Mon Sep 17 00:00:00 2001 From: Prakhar Saxena Date: Thu, 8 Feb 2024 13:30:04 +0000 Subject: [PATCH 4/8] corrected logic and added test --- src/peft/tuners/adaption_prompt/layer.py | 23 +- src/peft/tuners/adaption_prompt/utils.py | 5 +- tests/test_adaption_prompt.py | 397 ++++++++++++++++++++++- 3 files changed, 405 insertions(+), 20 deletions(-) diff --git a/src/peft/tuners/adaption_prompt/layer.py b/src/peft/tuners/adaption_prompt/layer.py index 41df50a817..378712792f 100644 --- a/src/peft/tuners/adaption_prompt/layer.py +++ 
b/src/peft/tuners/adaption_prompt/layer.py @@ -84,20 +84,22 @@ def forward(self, **kwargs): else: key = getattr(self.model, k_proj_layer)(self.adaption_prompt) value = getattr(self.model, v_proj_layer)(self.adaption_prompt) - # (bsz, num_heads, adapter_len, head_dim) + # (bsz, num_key_value_heads, adapter_len, head_dim) adapter_k = ( - key.view(1, self.adapter_len, (self.model.num_heads// factor), self.model.head_dim) + key.view(1, self.adapter_len, (self.model.num_heads // factor), self.model.head_dim) .repeat(bsz, 1, 1, 1) .transpose(1, 2) ) - # (bsz, num_heads, adapter_len, head_dim) adapter_v = ( - value.view(1, self.adapter_len, (self.model.num_heads// factor), self.model.head_dim) + value.view(1, self.adapter_len, (self.model.num_heads // factor), self.model.head_dim) .repeat(bsz, 1, 1, 1) .transpose(1, 2) ) - + # Below is taken from https://github.com/huggingface/transformers/blob/e547458c43dfdbbb8f6a7757237e234c44e20a8f/src/transformers/models/mistral/modeling_mistral.py#L181 + # (bsz, num_heads, adapter_len, head_dim) + adapter_k = torch.repeat_interleave(adapter_k, repeats=factor, dim=1) + adapter_v = torch.repeat_interleave(adapter_v, repeats=factor, dim=1) # Recompute query states. compute_query_states = TRANSFORMERS_MODEL_CONFIG[self.model_type].compute_query_states # (bsz, num_heads, q_len, head_dim) @@ -105,14 +107,6 @@ def forward(self, **kwargs): previous_dtype = query_states.dtype - # Reshape and average the extra tensors - query_states_reshaped = query_states.reshape( - bsz, self.model.num_heads, -1, (self.model.head_dim // factor), factor - ) - - # Take the mean along the last dimension to get [bsz, 32, X, 32] - query_states = query_states_reshaped.mean(dim=-1) - # (bsz, num_heads, q_len, adapter_len) scores = torch.matmul(query_states, adapter_k.transpose(2, 3).to(previous_dtype)) / math.sqrt( self.model.head_dim @@ -123,9 +117,6 @@ def forward(self, **kwargs): # (bsz, q_len, num_heads * head_dim) adapter_output = torch.matmul(scores, adapter_v).transpose(1, 2).reshape(bsz, q_len, -1) - adapter_output = torch.repeat_interleave( - adapter_output, repeats=factor, dim=2 - ) # https://github.com/huggingface/transformers/blob/e547458c43dfdbbb8f6a7757237e234c44e20a8f/src/transformers/models/mistral/modeling_mistral.py#L181 # (bsz, q_len, hidden_size) if o_proj_layer is not None: adapter_output = getattr(self.model, o_proj_layer)(adapter_output) diff --git a/src/peft/tuners/adaption_prompt/utils.py b/src/peft/tuners/adaption_prompt/utils.py index cdc1bb047e..a4e5a7ccea 100644 --- a/src/peft/tuners/adaption_prompt/utils.py +++ b/src/peft/tuners/adaption_prompt/utils.py @@ -71,8 +71,11 @@ def llama_compute_query_states(model: nn.Module, **kwargs) -> torch.Tensor: query_states = model.q_proj(hidden_states).view(bsz, q_len, model.num_heads, model.head_dim).transpose(1, 2) factor = model.k_proj.in_features // model.k_proj.out_features + # value_states = ( + # model.v_proj(hidden_states).view(bsz, q_len, model.num_heads, (model.head_dim // factor)).transpose(1, 2) + # ) value_states = ( - model.v_proj(hidden_states).view(bsz, q_len, model.num_heads, (model.head_dim // factor)).transpose(1, 2) + model.v_proj(hidden_states).view(bsz, q_len, (model.num_heads // factor), model.head_dim).transpose(1, 2) ) seq_len = q_len diff --git a/tests/test_adaption_prompt.py b/tests/test_adaption_prompt.py index 2607c185e9..4e28cb8d18 100644 --- a/tests/test_adaption_prompt.py +++ b/tests/test_adaption_prompt.py @@ -36,6 +36,13 @@ def is_llama_available() -> bool: return 
importlib.util.find_spec("transformers.models.llama.modeling_llama") is not None except ModuleNotFoundError: return False + +def is_mistral_available() -> bool: + """Check if mistral is available in the transformers library (it's not in earlier versions).""" + try: + return importlib.util.find_spec("transformers.models.mistral.modeling_mistral") is not None + except ModuleNotFoundError: + return False if is_llama_available(): @@ -43,6 +50,10 @@ def is_llama_available() -> bool: # that don't have a transformers package with Llama. from transformers import LlamaConfig, LlamaForCausalLM, LlamaModel +if is_mistral_available(): + # We guard the import statement so that our unit tests will pass in CI environments + # that don't have a transformers package with Mistral. + from transformers import MistralConfig, MistralForCausalLM, MistralModel class AdaptionPromptTester(TestCase, PeftCommonTester): """ @@ -55,7 +66,10 @@ class AdaptionPromptTester(TestCase, PeftCommonTester): def setUp(self): # Check that llama is available in transformers package before running each test. if not is_llama_available(): - self.skipTest("Llama not available in transformers. Skipping test.") + self.skipTest("Llama not available in transformers. Skipping all tests.") + else: + # Check for Mistral's availability. It might or might not be available. + self.mistral_available = is_mistral_available() @staticmethod def _create_test_llama_config(): @@ -68,8 +82,22 @@ def _create_test_llama_config(): num_attention_heads=4, use_cache=False, ) + + @staticmethod + def _create_test_mistral_config(): + """Create a test config for a small Mistral model for testing.""" + return MistralConfig( + vocab_size=16, + hidden_size=8, + intermediate_size=8, + num_hidden_layers=8, + num_attention_heads=4, + num_key_value_heads=2, + use_cache=False, + ) def test_attributes(self) -> None: + #Test Llama model = LlamaModel(self._create_test_llama_config()) config = AdaptionPromptConfig(adapter_layers=1, adapter_len=4) model = get_peft_model(model, config) @@ -78,7 +106,19 @@ def test_attributes(self) -> None: self.assertTrue(hasattr(model, "from_pretrained")) self.assertTrue(hasattr(model, "push_to_hub")) + #Test Mistral + if self.mistral_available: + model_mistral = MistralModel(self._create_test_mistral_config()) + config_mistral = AdaptionPromptConfig(adapter_layers=1, adapter_len=4) + model_mistral = get_peft_model(model_mistral, config_mistral) + + self.assertTrue(hasattr(model_mistral, "save_pretrained")) + self.assertTrue(hasattr(model_mistral, "from_pretrained")) + self.assertTrue(hasattr(model_mistral, "push_to_hub")) + + def test_prepare_for_training(self) -> None: + #Test Llama model = LlamaForCausalLM(self._create_test_llama_config()) config = AdaptionPromptConfig(adapter_layers=1, adapter_len=4, task_type="CAUSAL_LM") model = get_peft_model(model, config) @@ -89,7 +129,23 @@ def test_prepare_for_training(self) -> None: self.assertTrue(not dummy_output.requires_grad) + #Test Mistral + if self.mistral_available: + model_mistral = MistralForCausalLM(self._create_test_mistral_config()) + config_mistral = AdaptionPromptConfig(adapter_layers=1, adapter_len=4, task_type="CAUSAL_LM") + model_mistral = get_peft_model(model_mistral, config_mistral) + model_mistral = model_mistral.to(self.torch_device) + + dummy_input = torch.LongTensor([[1, 1, 1]]).to(self.torch_device) + dummy_output = model_mistral.get_input_embeddings()(dummy_input) + + self.assertTrue(not dummy_output.requires_grad) + + + + def test_prepare_for_int8_training(self) -> 
None: + #Test Llama model = LlamaForCausalLM(self._create_test_llama_config()) model = prepare_model_for_int8_training(model) model = model.to(self.torch_device) @@ -115,9 +171,41 @@ def make_inputs_require_grad(module, input, output): self.assertTrue(dummy_output.requires_grad) + #Test mistral + if self.mistral_available: + model_mistral = MistralForCausalLM(self._create_test_mistral_config()) + model_mistral = prepare_model_for_int8_training(model_mistral) + model_mistral = model_mistral.to(self.torch_device) + + for param in model_mistral.parameters(): + self.assertTrue(not param.requires_grad) + + config_mistral = AdaptionPromptConfig(adapter_layers=1, adapter_len=4, task_type="CAUSAL_LM") + model_mistral = get_peft_model(model_mistral, config_mistral) + + # For backward compatibility + if hasattr(model_mistral, "enable_input_require_grads"): + model_mistral.enable_input_require_grads() + else: + + def make_inputs_require_grad(module, input, output): + output.requires_grad_(True) + + model_mistral.get_input_embeddings().register_forward_hook(make_inputs_require_grad) + + dummy_input = torch.LongTensor([[1, 1, 1]]).to(self.torch_device) + dummy_output = model_mistral.get_input_embeddings()(dummy_input) + + self.assertTrue(dummy_output.requires_grad) + + + + def test_save_pretrained_regression(self) -> None: seed = 420 torch.manual_seed(seed) + + #Test Llama model = LlamaForCausalLM(self._create_test_llama_config()) config = AdaptionPromptConfig(adapter_layers=2, adapter_len=4, task_type="CAUSAL_LM") model = get_peft_model(model, config) @@ -160,9 +248,57 @@ def test_save_pretrained_regression(self) -> None: # check if `config.json` is not present self.assertFalse(os.path.exists(os.path.join(tmp_dirname, "config.json"))) + #Test Mistral + if self.mistral_available: + model_mistral = MistralForCausalLM(self._create_test_mistral_config()) + config_mistral = AdaptionPromptConfig(adapter_layers=2, adapter_len=4, task_type="CAUSAL_LM") + model_mistral = get_peft_model(model_mistral, config_mistral) + model_mistral = model_mistral.to(self.torch_device) + + with tempfile.TemporaryDirectory() as tmp_dirname: + model_mistral.save_pretrained(tmp_dirname, safe_serialization=False) + + torch.manual_seed(seed) + model_from_pretrained_mistral = MistralForCausalLM(self._create_test_mistral_config()) + model_from_pretrained_mistral = PeftModel.from_pretrained(model_from_pretrained_mistral, tmp_dirname) + + # check if the state dicts are equal + state_dict = get_peft_model_state_dict(model_mistral) + state_dict_from_pretrained = get_peft_model_state_dict(model_from_pretrained_mistral) + + # check if same keys + self.assertEqual(state_dict.keys(), state_dict_from_pretrained.keys()) + + # Check that the number of saved parameters is 4 -- 2 layers of (tokens and gate). 
+ self.assertEqual(len(list(state_dict.keys())), 4) + + # check if tensors equal + for key in state_dict.keys(): + self.assertTrue( + torch.allclose( + state_dict[key].to(self.torch_device), state_dict_from_pretrained[key].to(self.torch_device) + ) + ) + + # check if `adapter_model.bin` is present + self.assertTrue(os.path.exists(os.path.join(tmp_dirname, "adapter_model.bin"))) + + # check if `adapter_config.json` is present + self.assertTrue(os.path.exists(os.path.join(tmp_dirname, "adapter_config.json"))) + + # check if `model.safetensors` is not present + self.assertFalse(os.path.exists(os.path.join(tmp_dirname, "mode.safetensors"))) + + # check if `config.json` is not present + self.assertFalse(os.path.exists(os.path.join(tmp_dirname, "config.json"))) + + + def test_save_pretrained(self) -> None: seed = 420 torch.manual_seed(seed) + + #Test Llama model = LlamaForCausalLM(self._create_test_llama_config()) config = AdaptionPromptConfig(adapter_layers=2, adapter_len=4, task_type="CAUSAL_LM") model = get_peft_model(model, config) @@ -193,7 +329,7 @@ def test_save_pretrained(self) -> None: ) ) - # check if `adapter_model.bin` is present + # check if `adapter_model.safetensors` is present self.assertTrue(os.path.exists(os.path.join(tmp_dirname, "adapter_model.safetensors"))) # check if `adapter_config.json` is present @@ -205,9 +341,58 @@ def test_save_pretrained(self) -> None: # check if `config.json` is not present self.assertFalse(os.path.exists(os.path.join(tmp_dirname, "config.json"))) + #Test Mistral + if self.mistral_available: + model_mistral = MistralForCausalLM(self._create_test_mistral_config()) + config_mistral = AdaptionPromptConfig(adapter_layers=2, adapter_len=4, task_type="CAUSAL_LM") + model_mistral = get_peft_model(model_mistral, config_mistral) + model_mistral = model_mistral.to(self.torch_device) + + with tempfile.TemporaryDirectory() as tmp_dirname: + model_mistral.save_pretrained(tmp_dirname) + + torch.manual_seed(seed) + model_from_pretrained_mistral = MistralForCausalLM(self._create_test_mistral_config()) + model_from_pretrained_mistral = PeftModel.from_pretrained(model_from_pretrained_mistral, tmp_dirname) + + # check if the state dicts are equal + state_dict = get_peft_model_state_dict(model_mistral) + state_dict_from_pretrained = get_peft_model_state_dict(model_from_pretrained_mistral) + + # check if same keys + self.assertEqual(state_dict.keys(), state_dict_from_pretrained.keys()) + + # Check that the number of saved parameters is 4 -- 2 layers of (tokens and gate). 
+ self.assertEqual(len(list(state_dict.keys())), 4) + + # check if tensors equal + for key in state_dict.keys(): + self.assertTrue( + torch.allclose( + state_dict[key].to(self.torch_device), state_dict_from_pretrained[key].to(self.torch_device) + ) + ) + + # check if `adapter_model.bin` is present + self.assertTrue(os.path.exists(os.path.join(tmp_dirname, "adapter_model.safetensors"))) + + # check if `adapter_config.json` is present + self.assertTrue(os.path.exists(os.path.join(tmp_dirname, "adapter_config.json"))) + + # check if `model.safetensors` is not present + self.assertFalse(os.path.exists(os.path.join(tmp_dirname, "model.safetensors"))) + + # check if `config.json` is not present + self.assertFalse(os.path.exists(os.path.join(tmp_dirname, "config.json"))) + + + + def test_save_pretrained_selected_adapters(self) -> None: seed = 420 torch.manual_seed(seed) + + #Test Llama model = LlamaForCausalLM(self._create_test_llama_config()) config = AdaptionPromptConfig(adapter_layers=2, adapter_len=4, task_type="CAUSAL_LM") model = get_peft_model(model, config) @@ -243,7 +428,7 @@ def test_save_pretrained_selected_adapters(self) -> None: ) ) - # check if `adapter_model.bin` is present + # check if `adapter_model.safetensors` is present self.assertTrue(os.path.exists(os.path.join(tmp_dirname, "adapter_model.safetensors"))) # check if `adapter_config.json` is present @@ -255,7 +440,62 @@ def test_save_pretrained_selected_adapters(self) -> None: # check if `config.json` is not present self.assertFalse(os.path.exists(os.path.join(tmp_dirname, "config.json"))) + + #Test Mistral + if self.mistral_available: + model_mistral = MistralForCausalLM(self._create_test_mistral_config()) + config_mistral = AdaptionPromptConfig(adapter_layers=2, adapter_len=4, task_type="CAUSAL_LM") + model_mistral = get_peft_model(model_mistral, config_mistral) + model_mistral = model_mistral.to(self.torch_device) + + new_adapter_config_mistral = AdaptionPromptConfig(adapter_layers=2, adapter_len=4, task_type="CAUSAL_LM") + model_mistral.add_adapter("new_adapter", new_adapter_config_mistral) + + with tempfile.TemporaryDirectory() as tmp_dirname: + model_mistral.save_pretrained(tmp_dirname) + + torch.manual_seed(seed) + model_from_pretrained_mistral = MistralForCausalLM(self._create_test_mistral_config()) + model_from_pretrained_mistral = PeftModel.from_pretrained(model_from_pretrained_mistral, tmp_dirname) + + model_from_pretrained_mistral.load_adapter(tmp_dirname, "new_adapter") + + # check if the state dicts are equal + state_dict = get_peft_model_state_dict(model_mistral) + state_dict_from_pretrained = get_peft_model_state_dict(model_from_pretrained_mistral) + + # check if same keys + self.assertEqual(state_dict.keys(), state_dict_from_pretrained.keys()) + + # Check that the number of saved parameters is 4 -- 2 layers of (tokens and gate). 
+ self.assertEqual(len(list(state_dict.keys())), 4) + + # check if tensors equal + for key in state_dict.keys(): + self.assertTrue( + torch.allclose( + state_dict[key].to(self.torch_device), state_dict_from_pretrained[key].to(self.torch_device) + ) + ) + + # check if `adapter_model.safetensors` is present + self.assertTrue(os.path.exists(os.path.join(tmp_dirname, "adapter_model.safetensors"))) + + # check if `adapter_config.json` is present + self.assertTrue(os.path.exists(os.path.join(tmp_dirname, "adapter_config.json"))) + + # check if `model.safetensors` is not present + self.assertFalse(os.path.exists(os.path.join(tmp_dirname, "model.safetensors"))) + + # check if `config.json` is not present + self.assertFalse(os.path.exists(os.path.join(tmp_dirname, "config.json"))) + + + + + def test_generate(self) -> None: + #Test Llama model = LlamaForCausalLM(self._create_test_llama_config()) config = AdaptionPromptConfig(adapter_layers=2, adapter_len=4, task_type="CAUSAL_LM") model = get_peft_model(model, config) @@ -270,6 +510,23 @@ def test_generate(self) -> None: # check if `generate` works if positional arguments are passed _ = model.generate(input_ids, attention_mask=attention_mask) + #Test Mistral + if self.mistral_available: + model_mistral = MistralForCausalLM(self._create_test_mistral_config()) + config_mistral = AdaptionPromptConfig(adapter_layers=2, adapter_len=4, task_type="CAUSAL_LM") + model_mistral = get_peft_model(model_mistral, config_mistral) + model_mistral = model_mistral.to(self.torch_device) + + input_ids = torch.LongTensor([[1, 1, 1], [2, 1, 2]]).to(self.torch_device) + attention_mask = torch.LongTensor([[1, 1, 1], [1, 0, 1]]).to(self.torch_device) + + # check if `generate` works + _ = model_mistral.generate(input_ids=input_ids, attention_mask=attention_mask) + + # check if `generate` works if positional arguments are passed + _ = model_mistral.generate(input_ids, attention_mask=attention_mask) + + def test_sequence_adapter_ops(self) -> None: """Test sequence of adapter operations.""" # Test input data. @@ -339,6 +596,71 @@ def test_sequence_adapter_ops(self) -> None: self.assertFalse(torch.allclose(original_before.logits, default_after_set.logits)) self.assertFalse(torch.allclose(adapter_1_after.logits, default_after_set.logits)) + #Test Mistral + if self.mistral_available: + model_mistral = MistralForCausalLM(self._create_test_mistral_config()) + model_mistral = model_mistral.to(self.torch_device) + original_before = model_mistral(input_ids=input_ids, attention_mask=attention_mask) + + # Get AdaptionPrompt model. + adapted_mistral = get_peft_model( + model_mistral, AdaptionPromptConfig(adapter_layers=2, adapter_len=4, task_type="CAUSAL_LM") + ) + adapted_mistral = adapted_mistral.to(self.torch_device) + default_before = adapted_mistral(input_ids=input_ids, attention_mask=attention_mask, labels=target_ids) + + # Test zero-init: The logits should be exactly the same. + assert_close(original_before.logits, default_before.logits, rtol=0, atol=0) + + # Single fine-tuning step on "default" adapter. + optimizer = torch.optim.SGD(adapted_mistral.parameters(), lr=1) + optimizer.zero_grad() + default_before.loss.backward() + optimizer.step() + + # Test that the output changed. + default_after = adapted_mistral(input_ids=input_ids, attention_mask=attention_mask, labels=target_ids) + self.assertFalse(torch.allclose(default_before.logits, default_after.logits)) + + with adapted_mistral.disable_adapter(): + # Test that the output is the same as the original output. 
+ default_disabled = adapted_mistral(input_ids=input_ids, attention_mask=attention_mask, labels=target_ids) + assert_close(original_before.logits, default_disabled.logits, rtol=0, atol=0) + + # Add new adapter 1. + adapted_mistral.add_adapter("adapter 1", AdaptionPromptConfig(adapter_layers=3, adapter_len=8, task_type="CAUSAL_LM")) + # Test zero-init + adapter_1_before = adapted_mistral(input_ids=input_ids, attention_mask=attention_mask, labels=target_ids) + assert_close(original_before.logits, adapter_1_before.logits, rtol=0, atol=0) + + # Single fine-tuning step on adapter 1. + optimizer = torch.optim.SGD(adapted_mistral.parameters(), lr=1) + optimizer.zero_grad() + adapter_1_before.loss.backward() + optimizer.step() + + # Test that adapter 1 output changed. + adapter_1_after = adapted_mistral(input_ids=input_ids, attention_mask=attention_mask, labels=target_ids) + self.assertFalse(torch.allclose(adapter_1_before.logits, adapter_1_after.logits)) + self.assertFalse(torch.allclose(original_before.logits, adapter_1_after.logits)) + self.assertFalse(torch.allclose(default_after.logits, adapter_1_after.logits)) + + with adapted_mistral.disable_adapter(): + # Test that the output is the same as the original output. + adapter_1_disabled = adapted_mistral(input_ids=input_ids, attention_mask=attention_mask, labels=target_ids) + assert_close(original_before.logits, adapter_1_disabled.logits, rtol=0, atol=0) + + # Set adapter back to default. + adapted_mistral.set_adapter("default") + + # Test that the output is the same as the default output after training. + default_after_set = adapted_mistral(input_ids=input_ids, attention_mask=attention_mask, labels=target_ids) + assert_close(default_after.logits, default_after_set.logits, rtol=0, atol=0) + self.assertFalse(torch.allclose(original_before.logits, default_after_set.logits)) + self.assertFalse(torch.allclose(adapter_1_after.logits, default_after_set.logits)) + + + def test_add_and_set_while_disabled(self): """Test that adding and setting adapters while disabled works as intended.""" # Test input data. @@ -384,6 +706,49 @@ def test_add_and_set_while_disabled(self): adapter_1_after_set = adapted(input_ids=input_ids, attention_mask=attention_mask, labels=target_ids) assert_close(adapter_1_after.logits, adapter_1_after_set.logits, rtol=0, atol=0) + + #Test Mistral + if self.mistral_available: + model_mistral = MistralForCausalLM(self._create_test_mistral_config()) + model_mistral = model_mistral.to(self.torch_device) + original_before = model_mistral(input_ids=input_ids, attention_mask=attention_mask) + + # Get AdaptionPrompt model. + adapted_mistral = get_peft_model( + model_mistral, AdaptionPromptConfig(adapter_layers=2, adapter_len=4, task_type="CAUSAL_LM") + ) + adapted_mistral = adapted_mistral.to(self.torch_device) + + with adapted_mistral.disable_adapter(): + adapted_mistral.add_adapter( + "adapter 1", AdaptionPromptConfig(adapter_layers=3, adapter_len=8, task_type="CAUSAL_LM") + ) + + # Test that the output is the same as the original output. + adapter_1_before = adapted_mistral(input_ids=input_ids, attention_mask=attention_mask, labels=target_ids) + assert_close(original_before.logits, adapter_1_before.logits, rtol=0, atol=0) + + # Single fine-tuning step on adapter 1. + optimizer = torch.optim.SGD(adapted_mistral.parameters(), lr=1) + optimizer.zero_grad() + adapter_1_before.loss.backward() + optimizer.step() + + # Test that adapter 1 output changed. 
+ adapter_1_after = adapted_mistral(input_ids=input_ids, attention_mask=attention_mask, labels=target_ids) + self.assertFalse(torch.allclose(original_before.logits, adapter_1_after.logits)) + + adapted_mistral.set_adapter("default") + with adapted_mistral.disable_adapter(): + adapted_mistral.set_adapter("adapter 1") + + # Test that adapter 1 is active again. + adapter_1_after_set = adapted_mistral(input_ids=input_ids, attention_mask=attention_mask, labels=target_ids) + assert_close(adapter_1_after.logits, adapter_1_after_set.logits, rtol=0, atol=0) + + + + def test_use_cache(self) -> None: """Test that AdaptionPrompt works when Llama config use_cache=True.""" torch.manual_seed(0) @@ -409,6 +774,32 @@ def test_use_cache(self) -> None: actual = adapted.generate(input_ids=input_ids, max_length=8) assert_close(expected, actual, rtol=0, atol=0) + #Test Mistral + if self.mistral_available: + original = MistralForCausalLM( + MistralConfig( + vocab_size=16, + hidden_size=8, + intermediate_size=8, + num_hidden_layers=8, + num_attention_heads=4, + num_key_value_heads=2, + use_cache=False, + ) + ).eval() + adapted = get_peft_model( + original, AdaptionPromptConfig(adapter_layers=2, adapter_len=4, task_type="CAUSAL_LM") + ) + adapted = adapted.to(self.torch_device) + expected = adapted.generate(input_ids=input_ids, max_length=8) + + # Set use_cache = True and generate output again. + adapted.base_model.config.use_cache = True + actual = adapted.generate(input_ids=input_ids, max_length=8) + assert_close(expected, actual, rtol=0, atol=0) + + + def test_bf16_inference(self) -> None: """Test that AdaptionPrompt works when Llama using a half-precision model.""" input_ids = torch.LongTensor([[1, 1, 1], [2, 1, 2]]).to(self.torch_device) From e4011606af97e766958b8148dcb5b1dccbad3f17 Mon Sep 17 00:00:00 2001 From: Prakhar Saxena Date: Fri, 9 Feb 2024 05:57:05 +0000 Subject: [PATCH 5/8] removed commented out code --- src/peft/tuners/adaption_prompt/utils.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/peft/tuners/adaption_prompt/utils.py b/src/peft/tuners/adaption_prompt/utils.py index a4e5a7ccea..85d90b194e 100644 --- a/src/peft/tuners/adaption_prompt/utils.py +++ b/src/peft/tuners/adaption_prompt/utils.py @@ -71,9 +71,6 @@ def llama_compute_query_states(model: nn.Module, **kwargs) -> torch.Tensor: query_states = model.q_proj(hidden_states).view(bsz, q_len, model.num_heads, model.head_dim).transpose(1, 2) factor = model.k_proj.in_features // model.k_proj.out_features - # value_states = ( - # model.v_proj(hidden_states).view(bsz, q_len, model.num_heads, (model.head_dim // factor)).transpose(1, 2) - # ) value_states = ( model.v_proj(hidden_states).view(bsz, q_len, (model.num_heads // factor), model.head_dim).transpose(1, 2) ) From 6a5d7298bae31aadb258994b4a7224ba796dfd8c Mon Sep 17 00:00:00 2001 From: Prakhar Saxena Date: Mon, 12 Feb 2024 09:52:04 +0000 Subject: [PATCH 6/8] Added seperate test functions for mistral --- src/peft/tuners/adaption_prompt/config.py | 2 +- tests/test_adaption_prompt.py | 561 +++++++++++----------- 2 files changed, 276 insertions(+), 287 deletions(-) diff --git a/src/peft/tuners/adaption_prompt/config.py b/src/peft/tuners/adaption_prompt/config.py index 9aee0550dc..24f6580c55 100644 --- a/src/peft/tuners/adaption_prompt/config.py +++ b/src/peft/tuners/adaption_prompt/config.py @@ -55,7 +55,7 @@ def is_adaption_prompt(self) -> bool: v_proj_layer="v_proj", o_proj_layer="o_proj", ), - "mistral": ModelTypeConfig( # same for Mistral, + "mistral": ModelTypeConfig( # same as llama, 
compute_query_states=llama_compute_query_states, target_modules="self_attn", k_proj_layer="k_proj", diff --git a/tests/test_adaption_prompt.py b/tests/test_adaption_prompt.py index 4e28cb8d18..7dd9e7e027 100644 --- a/tests/test_adaption_prompt.py +++ b/tests/test_adaption_prompt.py @@ -36,7 +36,8 @@ def is_llama_available() -> bool: return importlib.util.find_spec("transformers.models.llama.modeling_llama") is not None except ModuleNotFoundError: return False - + + def is_mistral_available() -> bool: """Check if mistral is available in the transformers library (it's not in earlier versions).""" try: @@ -55,6 +56,7 @@ def is_mistral_available() -> bool: # that don't have a transformers package with Mistral. from transformers import MistralConfig, MistralForCausalLM, MistralModel + class AdaptionPromptTester(TestCase, PeftCommonTester): """ Tests for the AdaptionPrompt model. @@ -82,7 +84,7 @@ def _create_test_llama_config(): num_attention_heads=4, use_cache=False, ) - + @staticmethod def _create_test_mistral_config(): """Create a test config for a small Mistral model for testing.""" @@ -97,7 +99,6 @@ def _create_test_mistral_config(): ) def test_attributes(self) -> None: - #Test Llama model = LlamaModel(self._create_test_llama_config()) config = AdaptionPromptConfig(adapter_layers=1, adapter_len=4) model = get_peft_model(model, config) @@ -106,19 +107,18 @@ def test_attributes(self) -> None: self.assertTrue(hasattr(model, "from_pretrained")) self.assertTrue(hasattr(model, "push_to_hub")) - #Test Mistral - if self.mistral_available: - model_mistral = MistralModel(self._create_test_mistral_config()) - config_mistral = AdaptionPromptConfig(adapter_layers=1, adapter_len=4) - model_mistral = get_peft_model(model_mistral, config_mistral) - - self.assertTrue(hasattr(model_mistral, "save_pretrained")) - self.assertTrue(hasattr(model_mistral, "from_pretrained")) - self.assertTrue(hasattr(model_mistral, "push_to_hub")) + @unittest.skipIf(not is_mistral_available(), "Mistral is not available") + def test_attributes_mistral(self) -> None: + model_mistral = MistralModel(self._create_test_mistral_config()) + config_mistral = AdaptionPromptConfig(adapter_layers=1, adapter_len=4) + model_mistral = get_peft_model(model_mistral, config_mistral) + self.assertTrue(hasattr(model_mistral, "save_pretrained")) + self.assertTrue(hasattr(model_mistral, "from_pretrained")) + self.assertTrue(hasattr(model_mistral, "push_to_hub")) def test_prepare_for_training(self) -> None: - #Test Llama + # Test Llama model = LlamaForCausalLM(self._create_test_llama_config()) config = AdaptionPromptConfig(adapter_layers=1, adapter_len=4, task_type="CAUSAL_LM") model = get_peft_model(model, config) @@ -129,23 +129,19 @@ def test_prepare_for_training(self) -> None: self.assertTrue(not dummy_output.requires_grad) - #Test Mistral - if self.mistral_available: - model_mistral = MistralForCausalLM(self._create_test_mistral_config()) - config_mistral = AdaptionPromptConfig(adapter_layers=1, adapter_len=4, task_type="CAUSAL_LM") - model_mistral = get_peft_model(model_mistral, config_mistral) - model_mistral = model_mistral.to(self.torch_device) - - dummy_input = torch.LongTensor([[1, 1, 1]]).to(self.torch_device) - dummy_output = model_mistral.get_input_embeddings()(dummy_input) - - self.assertTrue(not dummy_output.requires_grad) - + @unittest.skipIf(not is_mistral_available(), "Mistral is not available") + def test_prepare_for_training_mistral(self) -> None: + model_mistral = MistralForCausalLM(self._create_test_mistral_config()) + 
config_mistral = AdaptionPromptConfig(adapter_layers=1, adapter_len=4, task_type="CAUSAL_LM") + model_mistral = get_peft_model(model_mistral, config_mistral) + model_mistral = model_mistral.to(self.torch_device) + dummy_input = torch.LongTensor([[1, 1, 1]]).to(self.torch_device) + dummy_output = model_mistral.get_input_embeddings()(dummy_input) + self.assertTrue(not dummy_output.requires_grad) def test_prepare_for_int8_training(self) -> None: - #Test Llama model = LlamaForCausalLM(self._create_test_llama_config()) model = prepare_model_for_int8_training(model) model = model.to(self.torch_device) @@ -171,41 +167,36 @@ def make_inputs_require_grad(module, input, output): self.assertTrue(dummy_output.requires_grad) - #Test mistral - if self.mistral_available: - model_mistral = MistralForCausalLM(self._create_test_mistral_config()) - model_mistral = prepare_model_for_int8_training(model_mistral) - model_mistral = model_mistral.to(self.torch_device) - - for param in model_mistral.parameters(): - self.assertTrue(not param.requires_grad) - - config_mistral = AdaptionPromptConfig(adapter_layers=1, adapter_len=4, task_type="CAUSAL_LM") - model_mistral = get_peft_model(model_mistral, config_mistral) + @unittest.skipIf(not is_mistral_available(), "Mistral is not available") + def test_prepare_model_for_int8_training_mistral(self) -> None: + model_mistral = MistralForCausalLM(self._create_test_mistral_config()) + model_mistral = prepare_model_for_int8_training(model_mistral) + model_mistral = model_mistral.to(self.torch_device) - # For backward compatibility - if hasattr(model_mistral, "enable_input_require_grads"): - model_mistral.enable_input_require_grads() - else: + for param in model_mistral.parameters(): + self.assertTrue(not param.requires_grad) - def make_inputs_require_grad(module, input, output): - output.requires_grad_(True) + config_mistral = AdaptionPromptConfig(adapter_layers=1, adapter_len=4, task_type="CAUSAL_LM") + model_mistral = get_peft_model(model_mistral, config_mistral) - model_mistral.get_input_embeddings().register_forward_hook(make_inputs_require_grad) + # For backward compatibility + if hasattr(model_mistral, "enable_input_require_grads"): + model_mistral.enable_input_require_grads() + else: - dummy_input = torch.LongTensor([[1, 1, 1]]).to(self.torch_device) - dummy_output = model_mistral.get_input_embeddings()(dummy_input) + def make_inputs_require_grad(module, input, output): + output.requires_grad_(True) - self.assertTrue(dummy_output.requires_grad) - + model_mistral.get_input_embeddings().register_forward_hook(make_inputs_require_grad) + dummy_input = torch.LongTensor([[1, 1, 1]]).to(self.torch_device) + dummy_output = model_mistral.get_input_embeddings()(dummy_input) + self.assertTrue(dummy_output.requires_grad) def test_save_pretrained_regression(self) -> None: seed = 420 torch.manual_seed(seed) - - #Test Llama model = LlamaForCausalLM(self._create_test_llama_config()) config = AdaptionPromptConfig(adapter_layers=2, adapter_len=4, task_type="CAUSAL_LM") model = get_peft_model(model, config) @@ -248,57 +239,55 @@ def test_save_pretrained_regression(self) -> None: # check if `config.json` is not present self.assertFalse(os.path.exists(os.path.join(tmp_dirname, "config.json"))) - #Test Mistral - if self.mistral_available: - model_mistral = MistralForCausalLM(self._create_test_mistral_config()) - config_mistral = AdaptionPromptConfig(adapter_layers=2, adapter_len=4, task_type="CAUSAL_LM") - model_mistral = get_peft_model(model_mistral, config_mistral) - model_mistral = 
model_mistral.to(self.torch_device) + @unittest.skipIf(not is_mistral_available(), "Mistral is not available") + def test_save_pretrained_regression_mistral(self) -> None: + seed = 420 + torch.manual_seed(seed) + model_mistral = MistralForCausalLM(self._create_test_mistral_config()) + config_mistral = AdaptionPromptConfig(adapter_layers=2, adapter_len=4, task_type="CAUSAL_LM") + model_mistral = get_peft_model(model_mistral, config_mistral) + model_mistral = model_mistral.to(self.torch_device) - with tempfile.TemporaryDirectory() as tmp_dirname: - model_mistral.save_pretrained(tmp_dirname, safe_serialization=False) + with tempfile.TemporaryDirectory() as tmp_dirname: + model_mistral.save_pretrained(tmp_dirname, safe_serialization=False) - torch.manual_seed(seed) - model_from_pretrained_mistral = MistralForCausalLM(self._create_test_mistral_config()) - model_from_pretrained_mistral = PeftModel.from_pretrained(model_from_pretrained_mistral, tmp_dirname) + torch.manual_seed(seed) + model_from_pretrained_mistral = MistralForCausalLM(self._create_test_mistral_config()) + model_from_pretrained_mistral = PeftModel.from_pretrained(model_from_pretrained_mistral, tmp_dirname) - # check if the state dicts are equal - state_dict = get_peft_model_state_dict(model_mistral) - state_dict_from_pretrained = get_peft_model_state_dict(model_from_pretrained_mistral) + # check if the state dicts are equal + state_dict = get_peft_model_state_dict(model_mistral) + state_dict_from_pretrained = get_peft_model_state_dict(model_from_pretrained_mistral) - # check if same keys - self.assertEqual(state_dict.keys(), state_dict_from_pretrained.keys()) + # check if same keys + self.assertEqual(state_dict.keys(), state_dict_from_pretrained.keys()) - # Check that the number of saved parameters is 4 -- 2 layers of (tokens and gate). - self.assertEqual(len(list(state_dict.keys())), 4) + # Check that the number of saved parameters is 4 -- 2 layers of (tokens and gate). 
+ self.assertEqual(len(list(state_dict.keys())), 4) - # check if tensors equal - for key in state_dict.keys(): - self.assertTrue( - torch.allclose( - state_dict[key].to(self.torch_device), state_dict_from_pretrained[key].to(self.torch_device) - ) + # check if tensors equal + for key in state_dict.keys(): + self.assertTrue( + torch.allclose( + state_dict[key].to(self.torch_device), state_dict_from_pretrained[key].to(self.torch_device) ) + ) - # check if `adapter_model.bin` is present - self.assertTrue(os.path.exists(os.path.join(tmp_dirname, "adapter_model.bin"))) - - # check if `adapter_config.json` is present - self.assertTrue(os.path.exists(os.path.join(tmp_dirname, "adapter_config.json"))) - - # check if `model.safetensors` is not present - self.assertFalse(os.path.exists(os.path.join(tmp_dirname, "mode.safetensors"))) + # check if `adapter_model.bin` is present + self.assertTrue(os.path.exists(os.path.join(tmp_dirname, "adapter_model.bin"))) - # check if `config.json` is not present - self.assertFalse(os.path.exists(os.path.join(tmp_dirname, "config.json"))) + # check if `adapter_config.json` is present + self.assertTrue(os.path.exists(os.path.join(tmp_dirname, "adapter_config.json"))) + # check if `model.safetensors` is not present + self.assertFalse(os.path.exists(os.path.join(tmp_dirname, "model.safetensors"))) + # check if `config.json` is not present + self.assertFalse(os.path.exists(os.path.join(tmp_dirname, "config.json"))) def test_save_pretrained(self) -> None: seed = 420 torch.manual_seed(seed) - - #Test Llama model = LlamaForCausalLM(self._create_test_llama_config()) config = AdaptionPromptConfig(adapter_layers=2, adapter_len=4, task_type="CAUSAL_LM") model = get_peft_model(model, config) @@ -341,58 +330,55 @@ def test_save_pretrained(self) -> None: # check if `config.json` is not present self.assertFalse(os.path.exists(os.path.join(tmp_dirname, "config.json"))) - #Test Mistral - if self.mistral_available: - model_mistral = MistralForCausalLM(self._create_test_mistral_config()) - config_mistral = AdaptionPromptConfig(adapter_layers=2, adapter_len=4, task_type="CAUSAL_LM") - model_mistral = get_peft_model(model_mistral, config_mistral) - model_mistral = model_mistral.to(self.torch_device) + @unittest.skipIf(not is_mistral_available(), "Mistral is not available") + def test_save_pretrained_mistral(self) -> None: + seed = 420 + torch.manual_seed(seed) + model_mistral = MistralForCausalLM(self._create_test_mistral_config()) + config_mistral = AdaptionPromptConfig(adapter_layers=2, adapter_len=4, task_type="CAUSAL_LM") + model_mistral = get_peft_model(model_mistral, config_mistral) + model_mistral = model_mistral.to(self.torch_device) - with tempfile.TemporaryDirectory() as tmp_dirname: - model_mistral.save_pretrained(tmp_dirname) + with tempfile.TemporaryDirectory() as tmp_dirname: + model_mistral.save_pretrained(tmp_dirname) - torch.manual_seed(seed) - model_from_pretrained_mistral = MistralForCausalLM(self._create_test_mistral_config()) - model_from_pretrained_mistral = PeftModel.from_pretrained(model_from_pretrained_mistral, tmp_dirname) + torch.manual_seed(seed) + model_from_pretrained_mistral = MistralForCausalLM(self._create_test_mistral_config()) + model_from_pretrained_mistral = PeftModel.from_pretrained(model_from_pretrained_mistral, tmp_dirname) - # check if the state dicts are equal - state_dict = get_peft_model_state_dict(model_mistral) - state_dict_from_pretrained = get_peft_model_state_dict(model_from_pretrained_mistral) + # check if the state dicts are equal + 
state_dict = get_peft_model_state_dict(model_mistral) + state_dict_from_pretrained = get_peft_model_state_dict(model_from_pretrained_mistral) - # check if same keys - self.assertEqual(state_dict.keys(), state_dict_from_pretrained.keys()) + # check if same keys + self.assertEqual(state_dict.keys(), state_dict_from_pretrained.keys()) - # Check that the number of saved parameters is 4 -- 2 layers of (tokens and gate). - self.assertEqual(len(list(state_dict.keys())), 4) + # Check that the number of saved parameters is 4 -- 2 layers of (tokens and gate). + self.assertEqual(len(list(state_dict.keys())), 4) - # check if tensors equal - for key in state_dict.keys(): - self.assertTrue( - torch.allclose( - state_dict[key].to(self.torch_device), state_dict_from_pretrained[key].to(self.torch_device) - ) + # check if tensors equal + for key in state_dict.keys(): + self.assertTrue( + torch.allclose( + state_dict[key].to(self.torch_device), state_dict_from_pretrained[key].to(self.torch_device) ) + ) - # check if `adapter_model.bin` is present - self.assertTrue(os.path.exists(os.path.join(tmp_dirname, "adapter_model.safetensors"))) - - # check if `adapter_config.json` is present - self.assertTrue(os.path.exists(os.path.join(tmp_dirname, "adapter_config.json"))) - - # check if `model.safetensors` is not present - self.assertFalse(os.path.exists(os.path.join(tmp_dirname, "model.safetensors"))) - - # check if `config.json` is not present - self.assertFalse(os.path.exists(os.path.join(tmp_dirname, "config.json"))) + # check if `adapter_model.bin` is present + self.assertTrue(os.path.exists(os.path.join(tmp_dirname, "adapter_model.safetensors"))) + # check if `adapter_config.json` is present + self.assertTrue(os.path.exists(os.path.join(tmp_dirname, "adapter_config.json"))) + # check if `model.safetensors` is not present + self.assertFalse(os.path.exists(os.path.join(tmp_dirname, "model.safetensors"))) + # check if `config.json` is not present + self.assertFalse(os.path.exists(os.path.join(tmp_dirname, "config.json"))) def test_save_pretrained_selected_adapters(self) -> None: seed = 420 torch.manual_seed(seed) - - #Test Llama model = LlamaForCausalLM(self._create_test_llama_config()) config = AdaptionPromptConfig(adapter_layers=2, adapter_len=4, task_type="CAUSAL_LM") model = get_peft_model(model, config) @@ -440,62 +426,58 @@ def test_save_pretrained_selected_adapters(self) -> None: # check if `config.json` is not present self.assertFalse(os.path.exists(os.path.join(tmp_dirname, "config.json"))) - - #Test Mistral - if self.mistral_available: - model_mistral = MistralForCausalLM(self._create_test_mistral_config()) - config_mistral = AdaptionPromptConfig(adapter_layers=2, adapter_len=4, task_type="CAUSAL_LM") - model_mistral = get_peft_model(model_mistral, config_mistral) - model_mistral = model_mistral.to(self.torch_device) + @unittest.skipIf(not is_mistral_available(), "Mistral is not available") + def test_save_pretrained_selected_adapters_mistral(self) -> None: + seed = 420 + torch.manual_seed(seed) + model_mistral = MistralForCausalLM(self._create_test_mistral_config()) + config_mistral = AdaptionPromptConfig(adapter_layers=2, adapter_len=4, task_type="CAUSAL_LM") + model_mistral = get_peft_model(model_mistral, config_mistral) + model_mistral = model_mistral.to(self.torch_device) - new_adapter_config_mistral = AdaptionPromptConfig(adapter_layers=2, adapter_len=4, task_type="CAUSAL_LM") - model_mistral.add_adapter("new_adapter", new_adapter_config_mistral) + new_adapter_config_mistral = 
AdaptionPromptConfig(adapter_layers=2, adapter_len=4, task_type="CAUSAL_LM") + model_mistral.add_adapter("new_adapter", new_adapter_config_mistral) - with tempfile.TemporaryDirectory() as tmp_dirname: - model_mistral.save_pretrained(tmp_dirname) + with tempfile.TemporaryDirectory() as tmp_dirname: + model_mistral.save_pretrained(tmp_dirname) - torch.manual_seed(seed) - model_from_pretrained_mistral = MistralForCausalLM(self._create_test_mistral_config()) - model_from_pretrained_mistral = PeftModel.from_pretrained(model_from_pretrained_mistral, tmp_dirname) + torch.manual_seed(seed) + model_from_pretrained_mistral = MistralForCausalLM(self._create_test_mistral_config()) + model_from_pretrained_mistral = PeftModel.from_pretrained(model_from_pretrained_mistral, tmp_dirname) - model_from_pretrained_mistral.load_adapter(tmp_dirname, "new_adapter") + model_from_pretrained_mistral.load_adapter(tmp_dirname, "new_adapter") - # check if the state dicts are equal - state_dict = get_peft_model_state_dict(model_mistral) - state_dict_from_pretrained = get_peft_model_state_dict(model_from_pretrained_mistral) + # check if the state dicts are equal + state_dict = get_peft_model_state_dict(model_mistral) + state_dict_from_pretrained = get_peft_model_state_dict(model_from_pretrained_mistral) - # check if same keys - self.assertEqual(state_dict.keys(), state_dict_from_pretrained.keys()) + # check if same keys + self.assertEqual(state_dict.keys(), state_dict_from_pretrained.keys()) - # Check that the number of saved parameters is 4 -- 2 layers of (tokens and gate). - self.assertEqual(len(list(state_dict.keys())), 4) + # Check that the number of saved parameters is 4 -- 2 layers of (tokens and gate). + self.assertEqual(len(list(state_dict.keys())), 4) - # check if tensors equal - for key in state_dict.keys(): - self.assertTrue( - torch.allclose( - state_dict[key].to(self.torch_device), state_dict_from_pretrained[key].to(self.torch_device) - ) + # check if tensors equal + for key in state_dict.keys(): + self.assertTrue( + torch.allclose( + state_dict[key].to(self.torch_device), state_dict_from_pretrained[key].to(self.torch_device) ) + ) - # check if `adapter_model.safetensors` is present - self.assertTrue(os.path.exists(os.path.join(tmp_dirname, "adapter_model.safetensors"))) - - # check if `adapter_config.json` is present - self.assertTrue(os.path.exists(os.path.join(tmp_dirname, "adapter_config.json"))) - - # check if `model.safetensors` is not present - self.assertFalse(os.path.exists(os.path.join(tmp_dirname, "model.safetensors"))) + # check if `adapter_model.safetensors` is present + self.assertTrue(os.path.exists(os.path.join(tmp_dirname, "adapter_model.safetensors"))) - # check if `config.json` is not present - self.assertFalse(os.path.exists(os.path.join(tmp_dirname, "config.json"))) + # check if `adapter_config.json` is present + self.assertTrue(os.path.exists(os.path.join(tmp_dirname, "adapter_config.json"))) - + # check if `model.safetensors` is not present + self.assertFalse(os.path.exists(os.path.join(tmp_dirname, "model.safetensors"))) - + # check if `config.json` is not present + self.assertFalse(os.path.exists(os.path.join(tmp_dirname, "config.json"))) def test_generate(self) -> None: - #Test Llama model = LlamaForCausalLM(self._create_test_llama_config()) config = AdaptionPromptConfig(adapter_layers=2, adapter_len=4, task_type="CAUSAL_LM") model = get_peft_model(model, config) @@ -510,22 +492,21 @@ def test_generate(self) -> None: # check if `generate` works if positional arguments are passed _ 
= model.generate(input_ids, attention_mask=attention_mask) - #Test Mistral - if self.mistral_available: - model_mistral = MistralForCausalLM(self._create_test_mistral_config()) - config_mistral = AdaptionPromptConfig(adapter_layers=2, adapter_len=4, task_type="CAUSAL_LM") - model_mistral = get_peft_model(model_mistral, config_mistral) - model_mistral = model_mistral.to(self.torch_device) - - input_ids = torch.LongTensor([[1, 1, 1], [2, 1, 2]]).to(self.torch_device) - attention_mask = torch.LongTensor([[1, 1, 1], [1, 0, 1]]).to(self.torch_device) + @unittest.skipIf(not is_mistral_available(), "Mistral is not available") + def test_generate_mistral(self) -> None: + model_mistral = MistralForCausalLM(self._create_test_mistral_config()) + config_mistral = AdaptionPromptConfig(adapter_layers=2, adapter_len=4, task_type="CAUSAL_LM") + model_mistral = get_peft_model(model_mistral, config_mistral) + model_mistral = model_mistral.to(self.torch_device) - # check if `generate` works - _ = model_mistral.generate(input_ids=input_ids, attention_mask=attention_mask) + input_ids = torch.LongTensor([[1, 1, 1], [2, 1, 2]]).to(self.torch_device) + attention_mask = torch.LongTensor([[1, 1, 1], [1, 0, 1]]).to(self.torch_device) - # check if `generate` works if positional arguments are passed - _ = model_mistral.generate(input_ids, attention_mask=attention_mask) + # check if `generate` works + _ = model_mistral.generate(input_ids=input_ids, attention_mask=attention_mask) + # check if `generate` works if positional arguments are passed + _ = model_mistral.generate(input_ids, attention_mask=attention_mask) def test_sequence_adapter_ops(self) -> None: """Test sequence of adapter operations.""" @@ -596,70 +577,76 @@ def test_sequence_adapter_ops(self) -> None: self.assertFalse(torch.allclose(original_before.logits, default_after_set.logits)) self.assertFalse(torch.allclose(adapter_1_after.logits, default_after_set.logits)) - #Test Mistral - if self.mistral_available: - model_mistral = MistralForCausalLM(self._create_test_mistral_config()) - model_mistral = model_mistral.to(self.torch_device) - original_before = model_mistral(input_ids=input_ids, attention_mask=attention_mask) - - # Get AdaptionPrompt model. - adapted_mistral = get_peft_model( - model_mistral, AdaptionPromptConfig(adapter_layers=2, adapter_len=4, task_type="CAUSAL_LM") - ) - adapted_mistral = adapted_mistral.to(self.torch_device) - default_before = adapted_mistral(input_ids=input_ids, attention_mask=attention_mask, labels=target_ids) + @unittest.skipIf(not is_mistral_available(), "Mistral is not available") + def test_sequence_adapter_ops_mistral(self) -> None: + # Test input data. + input_ids = torch.LongTensor([[1, 1, 1], [2, 1, 2]]).to(self.torch_device) + target_ids = torch.LongTensor([[0, 0, 0], [0, 0, 0]]).to(self.torch_device) + attention_mask = torch.LongTensor([[1, 1, 1], [1, 0, 1]]).to(self.torch_device) - # Test zero-init: The logits should be exactly the same. - assert_close(original_before.logits, default_before.logits, rtol=0, atol=0) + # Create original mistral model. + model_mistral = MistralForCausalLM(self._create_test_mistral_config()) + model_mistral = model_mistral.to(self.torch_device) + original_before = model_mistral(input_ids=input_ids, attention_mask=attention_mask) - # Single fine-tuning step on "default" adapter. - optimizer = torch.optim.SGD(adapted_mistral.parameters(), lr=1) - optimizer.zero_grad() - default_before.loss.backward() - optimizer.step() + # Get AdaptionPrompt model. 
+ adapted_mistral = get_peft_model( + model_mistral, AdaptionPromptConfig(adapter_layers=2, adapter_len=4, task_type="CAUSAL_LM") + ) + adapted_mistral = adapted_mistral.to(self.torch_device) + default_before = adapted_mistral(input_ids=input_ids, attention_mask=attention_mask, labels=target_ids) - # Test that the output changed. - default_after = adapted_mistral(input_ids=input_ids, attention_mask=attention_mask, labels=target_ids) - self.assertFalse(torch.allclose(default_before.logits, default_after.logits)) + # Test zero-init: The logits should be exactly the same. + assert_close(original_before.logits, default_before.logits, rtol=0, atol=0) - with adapted_mistral.disable_adapter(): - # Test that the output is the same as the original output. - default_disabled = adapted_mistral(input_ids=input_ids, attention_mask=attention_mask, labels=target_ids) - assert_close(original_before.logits, default_disabled.logits, rtol=0, atol=0) + # Single fine-tuning step on "default" adapter. + optimizer = torch.optim.SGD(adapted_mistral.parameters(), lr=1) + optimizer.zero_grad() + default_before.loss.backward() + optimizer.step() - # Add new adapter 1. - adapted_mistral.add_adapter("adapter 1", AdaptionPromptConfig(adapter_layers=3, adapter_len=8, task_type="CAUSAL_LM")) - # Test zero-init - adapter_1_before = adapted_mistral(input_ids=input_ids, attention_mask=attention_mask, labels=target_ids) - assert_close(original_before.logits, adapter_1_before.logits, rtol=0, atol=0) + # Test that the output changed. + default_after = adapted_mistral(input_ids=input_ids, attention_mask=attention_mask, labels=target_ids) + self.assertFalse(torch.allclose(default_before.logits, default_after.logits)) - # Single fine-tuning step on adapter 1. - optimizer = torch.optim.SGD(adapted_mistral.parameters(), lr=1) - optimizer.zero_grad() - adapter_1_before.loss.backward() - optimizer.step() + with adapted_mistral.disable_adapter(): + # Test that the output is the same as the original output. + default_disabled = adapted_mistral(input_ids=input_ids, attention_mask=attention_mask, labels=target_ids) + assert_close(original_before.logits, default_disabled.logits, rtol=0, atol=0) - # Test that adapter 1 output changed. - adapter_1_after = adapted_mistral(input_ids=input_ids, attention_mask=attention_mask, labels=target_ids) - self.assertFalse(torch.allclose(adapter_1_before.logits, adapter_1_after.logits)) - self.assertFalse(torch.allclose(original_before.logits, adapter_1_after.logits)) - self.assertFalse(torch.allclose(default_after.logits, adapter_1_after.logits)) + # Add new adapter 1. + adapted_mistral.add_adapter( + "adapter 1", AdaptionPromptConfig(adapter_layers=3, adapter_len=8, task_type="CAUSAL_LM") + ) + # Test zero-init + adapter_1_before = adapted_mistral(input_ids=input_ids, attention_mask=attention_mask, labels=target_ids) + assert_close(original_before.logits, adapter_1_before.logits, rtol=0, atol=0) - with adapted_mistral.disable_adapter(): - # Test that the output is the same as the original output. - adapter_1_disabled = adapted_mistral(input_ids=input_ids, attention_mask=attention_mask, labels=target_ids) - assert_close(original_before.logits, adapter_1_disabled.logits, rtol=0, atol=0) + # Single fine-tuning step on adapter 1. + optimizer = torch.optim.SGD(adapted_mistral.parameters(), lr=1) + optimizer.zero_grad() + adapter_1_before.loss.backward() + optimizer.step() - # Set adapter back to default. - adapted_mistral.set_adapter("default") + # Test that adapter 1 output changed. 
+ adapter_1_after = adapted_mistral(input_ids=input_ids, attention_mask=attention_mask, labels=target_ids) + self.assertFalse(torch.allclose(adapter_1_before.logits, adapter_1_after.logits)) + self.assertFalse(torch.allclose(original_before.logits, adapter_1_after.logits)) + self.assertFalse(torch.allclose(default_after.logits, adapter_1_after.logits)) - # Test that the output is the same as the default output after training. - default_after_set = adapted_mistral(input_ids=input_ids, attention_mask=attention_mask, labels=target_ids) - assert_close(default_after.logits, default_after_set.logits, rtol=0, atol=0) - self.assertFalse(torch.allclose(original_before.logits, default_after_set.logits)) - self.assertFalse(torch.allclose(adapter_1_after.logits, default_after_set.logits)) + with adapted_mistral.disable_adapter(): + # Test that the output is the same as the original output. + adapter_1_disabled = adapted_mistral(input_ids=input_ids, attention_mask=attention_mask, labels=target_ids) + assert_close(original_before.logits, adapter_1_disabled.logits, rtol=0, atol=0) + # Set adapter back to default. + adapted_mistral.set_adapter("default") + # Test that the output is the same as the default output after training. + default_after_set = adapted_mistral(input_ids=input_ids, attention_mask=attention_mask, labels=target_ids) + assert_close(default_after.logits, default_after_set.logits, rtol=0, atol=0) + self.assertFalse(torch.allclose(original_before.logits, default_after_set.logits)) + self.assertFalse(torch.allclose(adapter_1_after.logits, default_after_set.logits)) def test_add_and_set_while_disabled(self): """Test that adding and setting adapters while disabled works as intended.""" @@ -706,48 +693,50 @@ def test_add_and_set_while_disabled(self): adapter_1_after_set = adapted(input_ids=input_ids, attention_mask=attention_mask, labels=target_ids) assert_close(adapter_1_after.logits, adapter_1_after_set.logits, rtol=0, atol=0) + @unittest.skipIf(not is_mistral_available(), "Mistral is not available") + def test_add_and_set_while_disabled_mistral(self): + # Test input data. + input_ids = torch.LongTensor([[1, 1, 1], [2, 1, 2]]).to(self.torch_device) + target_ids = torch.LongTensor([[0, 0, 0], [0, 0, 0]]).to(self.torch_device) + attention_mask = torch.LongTensor([[1, 1, 1], [1, 0, 1]]).to(self.torch_device) - #Test Mistral - if self.mistral_available: - model_mistral = MistralForCausalLM(self._create_test_mistral_config()) - model_mistral = model_mistral.to(self.torch_device) - original_before = model_mistral(input_ids=input_ids, attention_mask=attention_mask) - - # Get AdaptionPrompt model. - adapted_mistral = get_peft_model( - model_mistral, AdaptionPromptConfig(adapter_layers=2, adapter_len=4, task_type="CAUSAL_LM") - ) - adapted_mistral = adapted_mistral.to(self.torch_device) - - with adapted_mistral.disable_adapter(): - adapted_mistral.add_adapter( - "adapter 1", AdaptionPromptConfig(adapter_layers=3, adapter_len=8, task_type="CAUSAL_LM") - ) - - # Test that the output is the same as the original output. - adapter_1_before = adapted_mistral(input_ids=input_ids, attention_mask=attention_mask, labels=target_ids) - assert_close(original_before.logits, adapter_1_before.logits, rtol=0, atol=0) + # Create original mistral model. + model_mistral = MistralForCausalLM(self._create_test_mistral_config()) + model_mistral = model_mistral.to(self.torch_device) + original_before = model_mistral(input_ids=input_ids, attention_mask=attention_mask) - # Single fine-tuning step on adapter 1. 
- optimizer = torch.optim.SGD(adapted_mistral.parameters(), lr=1) - optimizer.zero_grad() - adapter_1_before.loss.backward() - optimizer.step() + # Get AdaptionPrompt model. + adapted_mistral = get_peft_model( + model_mistral, AdaptionPromptConfig(adapter_layers=2, adapter_len=4, task_type="CAUSAL_LM") + ) + adapted_mistral = adapted_mistral.to(self.torch_device) - # Test that adapter 1 output changed. - adapter_1_after = adapted_mistral(input_ids=input_ids, attention_mask=attention_mask, labels=target_ids) - self.assertFalse(torch.allclose(original_before.logits, adapter_1_after.logits)) + with adapted_mistral.disable_adapter(): + adapted_mistral.add_adapter( + "adapter 1", AdaptionPromptConfig(adapter_layers=3, adapter_len=8, task_type="CAUSAL_LM") + ) - adapted_mistral.set_adapter("default") - with adapted_mistral.disable_adapter(): - adapted_mistral.set_adapter("adapter 1") + # Test that the output is the same as the original output. + adapter_1_before = adapted_mistral(input_ids=input_ids, attention_mask=attention_mask, labels=target_ids) + assert_close(original_before.logits, adapter_1_before.logits, rtol=0, atol=0) - # Test that adapter 1 is active again. - adapter_1_after_set = adapted_mistral(input_ids=input_ids, attention_mask=attention_mask, labels=target_ids) - assert_close(adapter_1_after.logits, adapter_1_after_set.logits, rtol=0, atol=0) + # Single fine-tuning step on adapter 1. + optimizer = torch.optim.SGD(adapted_mistral.parameters(), lr=1) + optimizer.zero_grad() + adapter_1_before.loss.backward() + optimizer.step() + # Test that adapter 1 output changed. + adapter_1_after = adapted_mistral(input_ids=input_ids, attention_mask=attention_mask, labels=target_ids) + self.assertFalse(torch.allclose(original_before.logits, adapter_1_after.logits)) + adapted_mistral.set_adapter("default") + with adapted_mistral.disable_adapter(): + adapted_mistral.set_adapter("adapter 1") + # Test that adapter 1 is active again. + adapter_1_after_set = adapted_mistral(input_ids=input_ids, attention_mask=attention_mask, labels=target_ids) + assert_close(adapter_1_after.logits, adapter_1_after_set.logits, rtol=0, atol=0) def test_use_cache(self) -> None: """Test that AdaptionPrompt works when Llama config use_cache=True.""" @@ -774,31 +763,31 @@ def test_use_cache(self) -> None: actual = adapted.generate(input_ids=input_ids, max_length=8) assert_close(expected, actual, rtol=0, atol=0) - #Test Mistral - if self.mistral_available: - original = MistralForCausalLM( - MistralConfig( - vocab_size=16, - hidden_size=8, - intermediate_size=8, - num_hidden_layers=8, - num_attention_heads=4, - num_key_value_heads=2, - use_cache=False, - ) - ).eval() - adapted = get_peft_model( - original, AdaptionPromptConfig(adapter_layers=2, adapter_len=4, task_type="CAUSAL_LM") + @unittest.skipIf(not is_mistral_available(), "Mistral is not available") + def test_use_cache_mistral(self) -> None: + torch.manual_seed(0) + input_ids = torch.LongTensor([[1, 1, 1], [2, 1, 2]]).to(self.torch_device) + original = MistralForCausalLM( + MistralConfig( + vocab_size=16, + hidden_size=8, + intermediate_size=8, + num_hidden_layers=8, + num_attention_heads=4, + num_key_value_heads=2, + use_cache=False, ) - adapted = adapted.to(self.torch_device) - expected = adapted.generate(input_ids=input_ids, max_length=8) - - # Set use_cache = True and generate output again. 
- adapted.base_model.config.use_cache = True - actual = adapted.generate(input_ids=input_ids, max_length=8) - assert_close(expected, actual, rtol=0, atol=0) - + ).eval() + adapted = get_peft_model( + original, AdaptionPromptConfig(adapter_layers=2, adapter_len=4, task_type="CAUSAL_LM") + ) + adapted = adapted.to(self.torch_device) + expected = adapted.generate(input_ids=input_ids, max_length=8) + # Set use_cache = True and generate output again. + adapted.base_model.config.use_cache = True + actual = adapted.generate(input_ids=input_ids, max_length=8) + assert_close(expected, actual, rtol=0, atol=0) def test_bf16_inference(self) -> None: """Test that AdaptionPrompt works when Llama using a half-precision model.""" From 53607459e47c2fb0631a0557331eb0c5d47492a2 Mon Sep 17 00:00:00 2001 From: PrakharSaxena24 Date: Tue, 20 Feb 2024 00:31:35 +0900 Subject: [PATCH 7/8] missed self.assert --- tests/test_adaption_prompt.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_adaption_prompt.py b/tests/test_adaption_prompt.py index ce3e0fbf1a..0a66eeb487 100644 --- a/tests/test_adaption_prompt.py +++ b/tests/test_adaption_prompt.py @@ -174,7 +174,7 @@ def test_prepare_model_for_int8_training_mistral(self) -> None: model_mistral = model_mistral.to(self.torch_device) for param in model_mistral.parameters(): - self.assertTrue(not param.requires_grad) + assert not param.requires_grad config_mistral = AdaptionPromptConfig(adapter_layers=1, adapter_len=4, task_type="CAUSAL_LM") model_mistral = get_peft_model(model_mistral, config_mistral) From a97e7ead47e9d3eca07c333b30c4f563024bef07 Mon Sep 17 00:00:00 2001 From: PrakharSaxena24 Date: Tue, 20 Feb 2024 00:36:56 +0900 Subject: [PATCH 8/8] ruff formatting --- tests/test_adaption_prompt.py | 62 +++++++++++++++++------------------ 1 file changed, 31 insertions(+), 31 deletions(-) diff --git a/tests/test_adaption_prompt.py b/tests/test_adaption_prompt.py index 0a66eeb487..ec3deb150a 100644 --- a/tests/test_adaption_prompt.py +++ b/tests/test_adaption_prompt.py @@ -282,47 +282,47 @@ def test_save_pretrained_regression_mistral(self) -> None: assert not os.path.exists(os.path.join(tmp_dirname, "config.json")) def test_save_pretrained(self) -> None: - seed = 420 - torch.manual_seed(seed) - model = LlamaForCausalLM(self._create_test_llama_config()) - config = AdaptionPromptConfig(adapter_layers=2, adapter_len=4, task_type="CAUSAL_LM") - model = get_peft_model(model, config) - model = model.to(self.torch_device) + seed = 420 + torch.manual_seed(seed) + model = LlamaForCausalLM(self._create_test_llama_config()) + config = AdaptionPromptConfig(adapter_layers=2, adapter_len=4, task_type="CAUSAL_LM") + model = get_peft_model(model, config) + model = model.to(self.torch_device) - with tempfile.TemporaryDirectory() as tmp_dirname: - model.save_pretrained(tmp_dirname) + with tempfile.TemporaryDirectory() as tmp_dirname: + model.save_pretrained(tmp_dirname) - torch.manual_seed(seed) - model_from_pretrained = LlamaForCausalLM(self._create_test_llama_config()) - model_from_pretrained = PeftModel.from_pretrained(model_from_pretrained, tmp_dirname) + torch.manual_seed(seed) + model_from_pretrained = LlamaForCausalLM(self._create_test_llama_config()) + model_from_pretrained = PeftModel.from_pretrained(model_from_pretrained, tmp_dirname) - # check if the state dicts are equal - state_dict = get_peft_model_state_dict(model) - state_dict_from_pretrained = get_peft_model_state_dict(model_from_pretrained) + # check if the state dicts are equal + state_dict = 
get_peft_model_state_dict(model) + state_dict_from_pretrained = get_peft_model_state_dict(model_from_pretrained) - # check if same keys - assert state_dict.keys() == state_dict_from_pretrained.keys() + # check if same keys + assert state_dict.keys() == state_dict_from_pretrained.keys() - # Check that the number of saved parameters is 4 -- 2 layers of (tokens and gate). - assert len(state_dict) == 4 + # Check that the number of saved parameters is 4 -- 2 layers of (tokens and gate). + assert len(state_dict) == 4 - # check if tensors equal - for key in state_dict.keys(): - assert torch.allclose( - state_dict[key].to(self.torch_device), state_dict_from_pretrained[key].to(self.torch_device) - ) + # check if tensors equal + for key in state_dict.keys(): + assert torch.allclose( + state_dict[key].to(self.torch_device), state_dict_from_pretrained[key].to(self.torch_device) + ) - # check if `adapter_model.bin` is present - assert os.path.exists(os.path.join(tmp_dirname, "adapter_model.safetensors")) + # check if `adapter_model.bin` is present + assert os.path.exists(os.path.join(tmp_dirname, "adapter_model.safetensors")) - # check if `adapter_config.json` is present - assert os.path.exists(os.path.join(tmp_dirname, "adapter_config.json")) + # check if `adapter_config.json` is present + assert os.path.exists(os.path.join(tmp_dirname, "adapter_config.json")) - # check if `model.safetensors` is not present - assert not os.path.exists(os.path.join(tmp_dirname, "model.safetensors")) + # check if `model.safetensors` is not present + assert not os.path.exists(os.path.join(tmp_dirname, "model.safetensors")) - # check if `config.json` is not present - assert not os.path.exists(os.path.join(tmp_dirname, "config.json")) + # check if `config.json` is not present + assert not os.path.exists(os.path.join(tmp_dirname, "config.json")) @unittest.skipIf(not is_mistral_available(), "Mistral is not available") def test_save_pretrained_mistral(self) -> None:
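As a point of reference for the save/load tests above and below, a minimal usage sketch of the feature this series adds is shown here. It is not part of the patch: the checkpoint and output directory names are illustrative, but `AdaptionPromptConfig`, `get_peft_model`, and `save_pretrained` are the same public PEFT entry points the tests exercise.

    from transformers import AutoModelForCausalLM
    from peft import AdaptionPromptConfig, get_peft_model

    # Illustrative checkpoint; any Mistral causal-LM checkpoint is expected to work the same way.
    base_model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1")
    peft_config = AdaptionPromptConfig(adapter_layers=2, adapter_len=4, task_type="CAUSAL_LM")
    model = get_peft_model(base_model, peft_config)
    model.print_trainable_parameters()  # only the adaption prompts and gates are trainable

    # Illustrative output directory; saving writes adapter_model.safetensors and adapter_config.json,
    # which is exactly what the save_pretrained tests assert.
    model.save_pretrained("mistral-adaption-prompt")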