diff --git a/docs/source/developer_guides/troubleshooting.md b/docs/source/developer_guides/troubleshooting.md index 179258aca9..91dbbde266 100644 --- a/docs/source/developer_guides/troubleshooting.md +++ b/docs/source/developer_guides/troubleshooting.md @@ -69,6 +69,12 @@ trainer = Trainer(model=peft_model, fp16=True, ...) trainer.train() ``` + + +Starting from PEFT version v0.11.0, PEFT automatically promotes the dtype of adapter weights from `torch.float16` and `torch.bfloat16` to `torch.float32` where appropriate. To _prevent_ this behavior, you can pass `autocast_adapter_dtype=False` to [`~get_peft_model`], to [`~PeftModel.from_pretrained`], and to [`~PeftModel.load_adapter`]. + + + ## Bad results from a loaded PEFT model There can be several reasons for getting a poor result from a loaded PEFT model which are listed below. If you're still unable to troubleshoot the problem, see if anyone else had a similar [issue](https://github.com/huggingface/peft/issues) on GitHub, and if you can't find any, open a new issue. diff --git a/src/peft/mapping.py b/src/peft/mapping.py index dfea91c9ac..77b6b66eea 100644 --- a/src/peft/mapping.py +++ b/src/peft/mapping.py @@ -122,6 +122,7 @@ def get_peft_model( peft_config: PeftConfig, adapter_name: str = "default", mixed: bool = False, + autocast_adapter_dtype: bool = True, revision: Optional[str] = None, ) -> PeftModel | PeftMixedModel: """ @@ -136,6 +137,10 @@ def get_peft_model( The name of the adapter to be injected, if not provided, the default adapter name is used ("default"). mixed (`bool`, `optional`, defaults to `False`): Whether to allow mixing different (compatible) adapter types. + autocast_adapter_dtype (`bool`, *optional*): + Whether to autocast the adapter dtype. Defaults to `True`. Right now, this will only cast adapter weights + using float16 or bfloat16 to float32, as this is typically required for stable training, and only affect + select PEFT tuners. revision (`str`, `optional`, defaults to `main`): The revision of the base model. If this isn't set, the saved peft model will load the `main` revision for the base model @@ -154,14 +159,17 @@ def get_peft_model( peft_config.revision = revision if mixed: + # note: PeftMixedModel does not support autocast_adapter_dtype, so don't pass it return PeftMixedModel(model, peft_config, adapter_name=adapter_name) if peft_config.task_type not in MODEL_TYPE_TO_PEFT_MODEL_MAPPING.keys() and not peft_config.is_prompt_learning: - return PeftModel(model, peft_config, adapter_name=adapter_name) + return PeftModel(model, peft_config, adapter_name=adapter_name, autocast_adapter_dtype=autocast_adapter_dtype) if peft_config.is_prompt_learning: peft_config = _prepare_prompt_learning_config(peft_config, model_config) - return MODEL_TYPE_TO_PEFT_MODEL_MAPPING[peft_config.task_type](model, peft_config, adapter_name=adapter_name) + return MODEL_TYPE_TO_PEFT_MODEL_MAPPING[peft_config.task_type]( + model, peft_config, adapter_name=adapter_name, autocast_adapter_dtype=autocast_adapter_dtype + ) def inject_adapter_in_model( diff --git a/src/peft/peft_model.py b/src/peft/peft_model.py index adaa23cd6e..509705ef0a 100644 --- a/src/peft/peft_model.py +++ b/src/peft/peft_model.py @@ -102,6 +102,10 @@ class PeftModel(PushToHubMixin, torch.nn.Module): model ([`~transformers.PreTrainedModel`]): The base transformer model used for Peft. peft_config ([`PeftConfig`]): The configuration of the Peft model. adapter_name (`str`, *optional*): The name of the adapter, defaults to `"default"`. 
+ autocast_adapter_dtype (`bool`, *optional*): + Whether to autocast the adapter dtype. Defaults to `True`. Right now, this will only cast adapter weights + using float16 and bfloat16 to float32, as this is typically required for stable training, and only affect + select PEFT tuners. **Attributes**: - **base_model** ([`torch.nn.Module`]) -- The base transformer model used for Peft. @@ -118,7 +122,13 @@ class PeftModel(PushToHubMixin, torch.nn.Module): in the base model if using [`PromptLearningConfig`]. """ - def __init__(self, model: PreTrainedModel, peft_config: PeftConfig, adapter_name: str = "default") -> None: + def __init__( + self, + model: PreTrainedModel, + peft_config: PeftConfig, + adapter_name: str = "default", + autocast_adapter_dtype: bool = True, + ) -> None: super().__init__() self.modules_to_save = None self.active_adapter = adapter_name @@ -138,6 +148,11 @@ def __init__(self, model: PreTrainedModel, peft_config: PeftConfig, adapter_name self.base_model = cls(model, {adapter_name: peft_config}, adapter_name) self.set_additional_trainable_modules(peft_config, adapter_name) + if hasattr(self.base_model, "_cast_adapter_dtype"): + self.base_model._cast_adapter_dtype( + adapter_name=adapter_name, autocast_adapter_dtype=autocast_adapter_dtype + ) + if getattr(model, "is_gradient_checkpointing", True): model = self._prepare_model_for_gradient_checkpointing(model) @@ -335,6 +350,7 @@ def from_pretrained( adapter_name: str = "default", is_trainable: bool = False, config: Optional[PeftConfig] = None, + autocast_adapter_dtype: bool = True, **kwargs: Any, ) -> PeftModel: r""" @@ -361,6 +377,8 @@ def from_pretrained( The configuration object to use instead of an automatically loaded configuration. This configuration object is mutually exclusive with `model_id` and `kwargs`. This is useful when configuration is already loaded before calling `from_pretrained`. + autocast_adapter_dtype (`bool`, *optional*): + Whether to autocast the adapter dtype. Defaults to `True`. Only relevant for specific adapter types. kwargs: (`optional`): Additional keyword arguments passed along to the specific PEFT configuration class. """ @@ -424,10 +442,15 @@ def from_pretrained( config.inference_mode = not is_trainable if config.task_type not in MODEL_TYPE_TO_PEFT_MODEL_MAPPING.keys(): - model = cls(model, config, adapter_name) + model = cls(model, config, adapter_name, autocast_adapter_dtype=autocast_adapter_dtype) else: - model = MODEL_TYPE_TO_PEFT_MODEL_MAPPING[config.task_type](model, config, adapter_name) - model.load_adapter(model_id, adapter_name, is_trainable=is_trainable, **kwargs) + model = MODEL_TYPE_TO_PEFT_MODEL_MAPPING[config.task_type]( + model, config, adapter_name, autocast_adapter_dtype=autocast_adapter_dtype + ) + model.load_adapter( + model_id, adapter_name, is_trainable=is_trainable, autocast_adapter_dtype=autocast_adapter_dtype, **kwargs + ) + return model def _setup_prompt_encoder(self, adapter_name: str): @@ -935,6 +958,7 @@ def load_adapter( adapter_name: str, is_trainable: bool = False, torch_device: Optional[str] = None, + autocast_adapter_dtype: bool = True, **kwargs: Any, ): """ @@ -955,6 +979,10 @@ def load_adapter( used for inference. torch_device (`str`, *optional*, defaults to None): The device to load the adapter on. If `None`, the device will be inferred. + autocast_adapter_dtype (`bool`, *optional*, defaults to `True`): + Whether to autocast the adapter dtype. Defaults to `True`. 
Right now, this will only cast adapter + weights using float16 and bfloat16 to float32, as this is typically required for stable training, and + only affect select PEFT tuners. kwargs: (`optional`): Additional arguments to modify the way the adapter is loaded, e.g. the token for Hugging Face Hub. """ @@ -1034,6 +1062,11 @@ def load_adapter( remove_hook_from_submodules(self.prompt_encoder) add_hook_to_module(self.get_base_model(), hook) + if hasattr(self.base_model, "_cast_adapter_dtype"): + self.base_model._cast_adapter_dtype( + adapter_name=adapter_name, autocast_adapter_dtype=autocast_adapter_dtype + ) + # Set model in evaluation mode to deactivate Dropout modules by default if not is_trainable: self.eval() @@ -1133,6 +1166,11 @@ class PeftModelForSequenceClassification(PeftModel): Args: model ([`~transformers.PreTrainedModel`]): Base transformer model. peft_config ([`PeftConfig`]): Peft config. + adapter_name (`str`, *optional*): The name of the adapter, defaults to `"default"`. + autocast_adapter_dtype (`bool`, *optional*): + Whether to autocast the adapter dtype. Defaults to `True`. Right now, this will only cast adapter weights + using float16 and bfloat16 to float32, as this is typically required for stable training, and only affect + select PEFT tuners. **Attributes**: - **config** ([`~transformers.PretrainedConfig`]) -- The configuration object of the base model. @@ -1166,8 +1204,10 @@ class PeftModelForSequenceClassification(PeftModel): ``` """ - def __init__(self, model: torch.nn.Module, peft_config: PeftConfig, adapter_name: str = "default") -> None: - super().__init__(model, peft_config, adapter_name) + def __init__( + self, model: torch.nn.Module, peft_config: PeftConfig, adapter_name: str = "default", **kwargs + ) -> None: + super().__init__(model, peft_config, adapter_name, **kwargs) classifier_module_names = ["classifier", "score"] if self.modules_to_save is None: @@ -1361,7 +1401,11 @@ class PeftModelForCausalLM(PeftModel): Args: model ([`~transformers.PreTrainedModel`]): Base transformer model. peft_config ([`PeftConfig`]): Peft config. - + adapter_name (`str`, *optional*): The name of the adapter, defaults to `"default"`. + autocast_adapter_dtype (`bool`, *optional*): + Whether to autocast the adapter dtype. Defaults to `True`. Right now, this will only cast adapter weights + using float16 and bfloat16 to float32, as this is typically required for stable training, and only affect + select PEFT tuners. Example: @@ -1391,8 +1435,10 @@ class PeftModelForCausalLM(PeftModel): ``` """ - def __init__(self, model: torch.nn.Module, peft_config: PeftConfig, adapter_name: str = "default") -> None: - super().__init__(model, peft_config, adapter_name) + def __init__( + self, model: torch.nn.Module, peft_config: PeftConfig, adapter_name: str = "default", **kwargs + ) -> None: + super().__init__(model, peft_config, adapter_name, **kwargs) self.base_model_prepare_inputs_for_generation = self.base_model.prepare_inputs_for_generation def forward( @@ -1566,7 +1612,11 @@ class PeftModelForSeq2SeqLM(PeftModel): Args: model ([`~transformers.PreTrainedModel`]): Base transformer model. peft_config ([`PeftConfig`]): Peft config. - + adapter_name (`str`, *optional*): The name of the adapter, defaults to `"default"`. + autocast_adapter_dtype (`bool`, *optional*): + Whether to autocast the adapter dtype. Defaults to `True`. Right now, this will only cast adapter weights + using float16 and bfloat16 to float32, as this is typically required for stable training, and only affect + select PEFT tuners. 
Example: @@ -1595,8 +1645,10 @@ class PeftModelForSeq2SeqLM(PeftModel): ``` """ - def __init__(self, model: torch.nn.Module, peft_config: PeftConfig, adapter_name: str = "default") -> None: - super().__init__(model, peft_config, adapter_name) + def __init__( + self, model: torch.nn.Module, peft_config: PeftConfig, adapter_name: str = "default", **kwargs + ) -> None: + super().__init__(model, peft_config, adapter_name, **kwargs) self.base_model_prepare_inputs_for_generation = self.base_model.prepare_inputs_for_generation self.base_model_prepare_encoder_decoder_kwargs_for_generation = ( self.base_model._prepare_encoder_decoder_kwargs_for_generation @@ -1820,6 +1872,11 @@ class PeftModelForTokenClassification(PeftModel): Args: model ([`~transformers.PreTrainedModel`]): Base transformer model. peft_config ([`PeftConfig`]): Peft config. + adapter_name (`str`, *optional*): The name of the adapter, defaults to `"default"`. + autocast_adapter_dtype (`bool`, *optional*): + Whether to autocast the adapter dtype. Defaults to `True`. Right now, this will only cast adapter weights + using float16 and bfloat16 to float32, as this is typically required for stable training, and only affect + select PEFT tuners. **Attributes**: - **config** ([`~transformers.PretrainedConfig`]) -- The configuration object of the base model. @@ -1853,8 +1910,10 @@ class PeftModelForTokenClassification(PeftModel): ``` """ - def __init__(self, model: torch.nn.Module, peft_config: PeftConfig = None, adapter_name: str = "default") -> None: - super().__init__(model, peft_config, adapter_name) + def __init__( + self, model: torch.nn.Module, peft_config: PeftConfig = None, adapter_name: str = "default", **kwargs + ) -> None: + super().__init__(model, peft_config, adapter_name, **kwargs) classifier_module_names = ["classifier", "score"] if self.modules_to_save is None: @@ -2032,6 +2091,11 @@ class PeftModelForQuestionAnswering(PeftModel): Args: model ([`~transformers.PreTrainedModel`]): Base transformer model. peft_config ([`PeftConfig`]): Peft config. + adapter_name (`str`, *optional*): The name of the adapter, defaults to `"default"`. + autocast_adapter_dtype (`bool`, *optional*): + Whether to autocast the adapter dtype. Defaults to `True`. Right now, this will only cast adapter weights + using float16 and bfloat16 to float32, as this is typically required for stable training, and only affect + select PEFT tuners. **Attributes**: - **config** ([`~transformers.PretrainedConfig`]) -- The configuration object of the base model. @@ -2063,8 +2127,10 @@ class PeftModelForQuestionAnswering(PeftModel): ``` """ - def __init__(self, model: torch.nn.Module, peft_config: PeftConfig, adapter_name: str = "default") -> None: - super().__init__(model, peft_config, adapter_name) + def __init__( + self, model: torch.nn.Module, peft_config: PeftConfig, adapter_name: str = "default", **kwargs + ) -> None: + super().__init__(model, peft_config, adapter_name, **kwargs) qa_module_names = ["qa_outputs"] if self.modules_to_save is None: @@ -2265,6 +2331,11 @@ class PeftModelForFeatureExtraction(PeftModel): Args: model ([`~transformers.PreTrainedModel`]): Base transformer model. peft_config ([`PeftConfig`]): Peft config. + adapter_name (`str`, *optional*): The name of the adapter, defaults to `"default"`. + autocast_adapter_dtype (`bool`, *optional*): + Whether to autocast the adapter dtype. Defaults to `True`. 
Right now, this will only cast adapter weights + using float16 and bfloat16 to float32, as this is typically required for stable training, and only affect + select PEFT tuners. **Attributes**: - **config** ([`~transformers.PretrainedConfig`]) -- The configuration object of the base model. @@ -2293,8 +2364,8 @@ class PeftModelForFeatureExtraction(PeftModel): ``` """ - def __init__(self, model: torch.nn.Module, peft_config: PeftConfig, adapter_name: str = "default"): - super().__init__(model, peft_config, adapter_name) + def __init__(self, model: torch.nn.Module, peft_config: PeftConfig, adapter_name: str = "default", **kwargs): + super().__init__(model, peft_config, adapter_name, **kwargs) def forward( self, diff --git a/src/peft/tuners/tuners_utils.py b/src/peft/tuners/tuners_utils.py index f4a6ba53dc..8f96ce1092 100644 --- a/src/peft/tuners/tuners_utils.py +++ b/src/peft/tuners/tuners_utils.py @@ -304,6 +304,44 @@ def _check_new_adapter_config(self, config: PeftConfig) -> None: """ pass + def _cast_adapter_dtype(self, adapter_name: str, autocast_adapter_dtype: bool = True) -> None: + """ + A helper method to cast the adapter weights to the correct dtype. + + Currently, this only upcasts float16 and bfloat16 to float32. + + Args: + adapter_name (`str`): + The adapter name. + autocast_adapter_dtype (`bool`, *optional*): + Whether to autocast the adapter dtype. Defaults to `True`. + + """ + if not autocast_adapter_dtype: + return + + dtypes_to_convert_to_fp32 = {torch.float16, torch.bfloat16} + + for module in self.model.modules(): + if not isinstance(module, BaseTunerLayer): + continue + + for submodule in module.modules(): + if not isinstance(submodule, (nn.ModuleDict, nn.ParameterDict)): + continue + + if adapter_name not in submodule: + continue + + if isinstance(submodule[adapter_name], nn.Parameter): + if submodule[adapter_name].dtype in dtypes_to_convert_to_fp32: + submodule[adapter_name].data = submodule[adapter_name].data.to(torch.float32) + continue + + for param in submodule[adapter_name].parameters(): + if param.dtype in dtypes_to_convert_to_fp32: + param.data = param.data.to(torch.float32) + def _check_merge_allowed(self): """Helper method to check whether the adapter can be merged. @@ -311,7 +349,7 @@ def _check_merge_allowed(self): """ pass - def inject_adapter(self, model: nn.Module, adapter_name: str): + def inject_adapter(self, model: nn.Module, adapter_name: str, autocast_adapter_dtype: bool = True) -> None: r""" Creates adapter layers and replaces the target modules with the adapter layers. This method is called under the hood by `peft.mapping.get_peft_model` if a non-prompt tuning adapter class is passed. @@ -323,6 +361,8 @@ def inject_adapter(self, model: nn.Module, adapter_name: str): The model to be tuned. adapter_name (`str`): The adapter name. + autocast_adapter_dtype (`bool`, *optional*): + Whether to autocast the adapter dtype. Defaults to `True`. """ peft_config = self.peft_config[adapter_name] # Note: If possible, all checks should be performed *at the start of this method*. 
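To illustrate the user-facing effect of the `autocast_adapter_dtype` flag introduced above and implemented by `_cast_adapter_dtype`, here is a minimal sketch (not part of the patch itself). It assumes a LoRA adapter on an fp16 base model; the model id and `r=16` mirror the GPU tests below, while the explicit `target_modules` and the remaining default settings are illustrative choices.

```python
import torch
from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model

# LoRA settings mirroring the tests; target_modules made explicit for clarity
config = LoraConfig(r=16, target_modules=["q_proj", "v_proj"])

# default: fp16 adapter weights are promoted to float32 for stable training
base = AutoModelForCausalLM.from_pretrained("facebook/opt-125m", torch_dtype=torch.float16)
model = get_peft_model(base, config)  # autocast_adapter_dtype=True is the default
lora_A = model.base_model.model.model.decoder.layers[0].self_attn.v_proj.lora_A["default"]
print(lora_A.weight.dtype)  # torch.float32

# opting out keeps the adapter in the base model's half-precision dtype
base = AutoModelForCausalLM.from_pretrained("facebook/opt-125m", torch_dtype=torch.float16)
model = get_peft_model(base, config, autocast_adapter_dtype=False)
lora_A = model.base_model.model.model.decoder.layers[0].self_attn.v_proj.lora_A["default"]
print(lora_A.weight.dtype)  # torch.float16
```

In both cases the base model weights stay in float16; only the adapter parameters on tuner layers are touched, which is what the `_cast_adapter_dtype` helper above restricts itself to.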
diff --git a/tests/test_gpu_examples.py b/tests/test_gpu_examples.py index 2316c76db5..1486b004a5 100644 --- a/tests/test_gpu_examples.py +++ b/tests/test_gpu_examples.py @@ -16,6 +16,7 @@ import os import tempfile import unittest +from collections import Counter from copy import deepcopy from dataclasses import dataclass from typing import Any, Dict, List, Union @@ -2002,7 +2003,7 @@ def test_notebook_launcher(self): @require_torch_gpu class MixedPrecisionTests(unittest.TestCase): def setUp(self): - self.causal_lm_model_id = "facebook/opt-350m" + self.causal_lm_model_id = "facebook/opt-125m" self.tokenizer = AutoTokenizer.from_pretrained(self.causal_lm_model_id) self.config = LoraConfig( r=16, @@ -2024,15 +2025,14 @@ def tearDown(self): gc.collect() @pytest.mark.single_gpu_tests - def test_model_loaded_in_float16_raises(self): - # This test shows the issue with loading the model in fp16 and then trying to use it with mixed precision - # training, which should not use fp16. If this is ever automated in PEFT, this test should fail. In that case, - # remove this test, adjust the next one, and remove the entry about FP16 usage from troubleshooting.md. + def test_model_using_float16_with_amp_raises(self): + # This test shows the issue with using a model in fp16 and then trying to use it with mixed precision training, + # which should not use fp16. model = AutoModelForCausalLM.from_pretrained( self.causal_lm_model_id, torch_dtype=torch.float16, ) - model = get_peft_model(model, self.config) + model = get_peft_model(model, self.config, autocast_adapter_dtype=False) with tempfile.TemporaryDirectory() as tmp_dir: trainer = Trainer( @@ -2040,8 +2040,8 @@ def test_model_loaded_in_float16_raises(self): train_dataset=self.data["train"], args=TrainingArguments( fp16=True, # <= this is required for the error to be raised - logging_steps=1, output_dir=tmp_dir, + max_steps=3, ), data_collator=DataCollatorForLanguageModeling(self.tokenizer, mlm=False), ) @@ -2049,32 +2049,216 @@ def test_model_loaded_in_float16_raises(self): trainer.train() @pytest.mark.single_gpu_tests - def test_model_loaded_in_float16_working(self): - # Same test as before but containing the fix to make it work + def test_model_using_float16_autocast_dtype(self): + # Here we use autocast_adapter_dtype=True (the default) to automatically promote the adapter weights to float32. + # No exception should be raised. + model = AutoModelForCausalLM.from_pretrained( + self.causal_lm_model_id, + torch_dtype=torch.float16, + ) + model = get_peft_model(model, self.config, autocast_adapter_dtype=True) + + with tempfile.TemporaryDirectory() as tmp_dir: + trainer = Trainer( + model=model, + train_dataset=self.data["train"], + args=TrainingArguments( + fp16=True, # <= this is required for the error to be raised + output_dir=tmp_dir, + max_steps=3, + ), + data_collator=DataCollatorForLanguageModeling(self.tokenizer, mlm=False), + ) + trainer.train() # does not raise + + @pytest.mark.single_gpu_tests + def test_model_using_float16_explicit_cast(self): + # Same test as above but containing the fix to make it work model = AutoModelForCausalLM.from_pretrained( self.causal_lm_model_id, torch_dtype=torch.float16, ) - model = get_peft_model(model, self.config) + model = get_peft_model(model, self.config, autocast_adapter_dtype=False) - # for now, this is unfortunately necessary to avoid the error: - # ValueError: Attempting to unscale FP16 gradients. 
+ # here we manually promote the adapter weights to float32 for param in model.parameters(): if param.requires_grad: param.data = param.data.float() + dtype_counts_before = Counter(p.dtype for p in model.parameters()) + model = AutoModelForCausalLM.from_pretrained( + self.causal_lm_model_id, + torch_dtype=torch.float16, + ) + + model = get_peft_model(model, self.config, autocast_adapter_dtype=True) + dtype_counts_after = Counter(p.dtype for p in model.parameters()) + assert dtype_counts_before == dtype_counts_after + with tempfile.TemporaryDirectory() as tmp_dir: trainer = Trainer( model=model, train_dataset=self.data["train"], args=TrainingArguments( - fp16=True, + fp16=True, # <= this is required for the error to be raised max_steps=3, output_dir=tmp_dir, ), data_collator=DataCollatorForLanguageModeling(self.tokenizer, mlm=False), ) - trainer.train() + trainer.train() # does not raise + + @pytest.mark.single_gpu_tests + def test_load_model_using_float16_with_amp_raises(self): + # Same as previous tests, but loading the adapter with PeftModel.from_pretrained instead + model = AutoModelForCausalLM.from_pretrained( + self.causal_lm_model_id, + torch_dtype=torch.float16, + ) + model = get_peft_model(model, self.config, autocast_adapter_dtype=False) + + with tempfile.TemporaryDirectory() as tmp_dir: + model.save_pretrained(tmp_dir) + model = AutoModelForCausalLM.from_pretrained(self.causal_lm_model_id, torch_dtype=torch.float16) + model = PeftModel.from_pretrained(model, tmp_dir, autocast_adapter_dtype=False, is_trainable=True) + + trainer = Trainer( + model=model, + train_dataset=self.data["train"], + args=TrainingArguments( + fp16=True, # <= this is required for the error to be raised + output_dir=tmp_dir, + max_steps=3, + ), + data_collator=DataCollatorForLanguageModeling(self.tokenizer, mlm=False), + ) + with pytest.raises(ValueError, match="Attempting to unscale FP16 gradients."): + trainer.train() + + @pytest.mark.single_gpu_tests + def test_load_model_using_float16_autocast_dtype(self): + # Same as previous tests, but loading the adapter with PeftModel.from_pretrained instead + model = AutoModelForCausalLM.from_pretrained( + self.causal_lm_model_id, + torch_dtype=torch.float16, + ) + # Below, we purposefully set autocast_adapter_dtype=False so that the saved adapter uses float16. We still want + # the loaded adapter to use float32 when we load it with autocast_adapter_dtype=True. 
+ model = get_peft_model(model, self.config, autocast_adapter_dtype=False) + # sanity check: this should have float16 adapter weights: + assert ( + model.base_model.model.model.decoder.layers[0].self_attn.v_proj.lora_A["default"].weight.dtype + == torch.float16 + ) + + with tempfile.TemporaryDirectory() as tmp_dir: + model.save_pretrained(tmp_dir) + model = AutoModelForCausalLM.from_pretrained(self.causal_lm_model_id, torch_dtype=torch.float16) + model = PeftModel.from_pretrained(model, tmp_dir, autocast_adapter_dtype=True, is_trainable=True) + # sanity check: this should NOT have float16 adapter weights: + assert ( + model.base_model.model.model.decoder.layers[0].self_attn.v_proj.lora_A["default"].weight.dtype + == torch.float32 + ) + + trainer = Trainer( + model=model, + train_dataset=self.data["train"], + args=TrainingArguments( + fp16=True, # <= this is required for the error to be raised + output_dir=tmp_dir, + max_steps=3, + ), + data_collator=DataCollatorForLanguageModeling(self.tokenizer, mlm=False), + ) + trainer.train() # does not raise + + @pytest.mark.single_gpu_tests + def test_load_adapter_using_float16_autocast_dtype(self): + # Here we test the load_adapter method with autocast_adapter_dtype. We show that autocasting is prevented when + # calling load_adapter(..., autocast_adapter_dtype=False) and that it is enabled when calling + # load_adapter(..., autocast_adapter_dtype=True) (the default). + model = AutoModelForCausalLM.from_pretrained( + self.causal_lm_model_id, + torch_dtype=torch.float16, + ) + # Below, we purposefully set autocast_adapter_dtype=False so that the saved adapter uses float16. We still want + # the loaded adapter to use float32 when we load it with autocast_adapter_dtype=True. + model = get_peft_model(model, self.config, autocast_adapter_dtype=False) + # sanity check: this should have float16 adapter weights: + assert ( + model.base_model.model.model.decoder.layers[0].self_attn.v_proj.lora_A["default"].weight.dtype + == torch.float16 + ) + + with tempfile.TemporaryDirectory() as tmp_dir: + model.save_pretrained(tmp_dir) + model = AutoModelForCausalLM.from_pretrained(self.causal_lm_model_id, torch_dtype=torch.float16) + # the default adapter is now in float16 + model = get_peft_model(model, self.config, autocast_adapter_dtype=False) + # sanity check: this should still have float16 adapter weights: + assert ( + model.base_model.model.model.decoder.layers[0].self_attn.v_proj.lora_A["default"].weight.dtype + == torch.float16 + ) + + # now load the first adapter in float16 using the adapter name "loaded16" + model.load_adapter(tmp_dir, "loaded16", autocast_adapter_dtype=False) + assert ( + model.base_model.model.model.decoder.layers[0].self_attn.v_proj.lora_A["loaded16"].weight.dtype + == torch.float16 + ) + + # now load the first adapter in float32 using the adapter name "loaded32" + model.load_adapter(tmp_dir, "loaded32", autocast_adapter_dtype=True) + assert ( + model.base_model.model.model.decoder.layers[0].self_attn.v_proj.lora_A["loaded32"].weight.dtype + == torch.float32 + ) + + # training with the default adapter, which is in float16, should raise + model.set_adapter("default") + trainer = Trainer( + model=model, + train_dataset=self.data["train"], + args=TrainingArguments( + fp16=True, # <= this is required for the error to be raised + output_dir=tmp_dir, + max_steps=3, + ), + data_collator=DataCollatorForLanguageModeling(self.tokenizer, mlm=False), + ) + with pytest.raises(ValueError, match="Attempting to unscale FP16 gradients."): + trainer.train() + + # 
training the model with the adapter "loaded16", which is in float16, should also raise + model.set_adapter("loaded16") + trainer = Trainer( + model=model, + train_dataset=self.data["train"], + args=TrainingArguments( + fp16=True, # <= this is required for the error to be raised + output_dir=tmp_dir, + max_steps=3, + ), + data_collator=DataCollatorForLanguageModeling(self.tokenizer, mlm=False), + ) + with pytest.raises(ValueError, match="Attempting to unscale FP16 gradients."): + trainer.train() + + # training the model with the adapter "loaded32", which is in float32, should not raise + model.set_adapter("loaded32") + trainer = Trainer( + model=model, + train_dataset=self.data["train"], + args=TrainingArguments( + fp16=True, # <= this is required for the error to be raised + output_dir=tmp_dir, + max_steps=3, + ), + data_collator=DataCollatorForLanguageModeling(self.tokenizer, mlm=False), + ) + trainer.train() # does not raise @require_torch_gpu
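As a companion to the loading tests above, here is a hedged sketch (not part of the patch) of the same flag on the loading paths. The adapter directory is a placeholder, and the checkpoint is assumed to have been saved in float16, as in the tests.

```python
import torch
from transformers import AutoModelForCausalLM
from peft import PeftModel

# placeholder path to an adapter checkpoint that was saved in float16
adapter_path = "path/to/fp16-adapter"

# default: the float16 adapter weights are promoted to float32 on load
base = AutoModelForCausalLM.from_pretrained("facebook/opt-125m", torch_dtype=torch.float16)
model = PeftModel.from_pretrained(base, adapter_path, is_trainable=True)

# opt out of the promotion when loading
base = AutoModelForCausalLM.from_pretrained("facebook/opt-125m", torch_dtype=torch.float16)
model = PeftModel.from_pretrained(base, adapter_path, is_trainable=True, autocast_adapter_dtype=False)

# load_adapter accepts the same flag when adding further adapters
model.load_adapter(adapter_path, adapter_name="upcast", autocast_adapter_dtype=True)
```

Under AMP training (`fp16=True` in `TrainingArguments`), only the adapters promoted to float32 would train without hitting the "Attempting to unscale FP16 gradients." error, which is exactly what the assertions in the tests above check.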