From 2e38ee7ed9e5bc305969ea6d1ea70cac6e7deadc Mon Sep 17 00:00:00 2001
From: Benjamin Bossan
Date: Wed, 15 May 2024 13:07:08 +0200
Subject: [PATCH 1/3] FIX Store batch norm buffers in PEFT checkpoint

Fixes #1732

After loading a model that was trained with PEFT on a base model that
contains some kind of batch norm layer, the loaded model should produce the
same output as it did right before saving. Right now, this does not happen.

The reason is that during training, buffers for the running mean etc. are
updated, but they are not saved when calling save_pretrained on the PeftModel
instance. Normally in PEFT, we assume that the base model parameters are kept
constant during training, which is not the case with batch norm. We only save
the PEFT parameters and assume that when the user loads the base model, all
base model parameters are restored exactly. As a result, the information in
the buffers is lost completely.

This PR fixes the issue by saving the buffers of the batch norm layers. They
are identified by checking for the presence of the track_running_stats
attribute.

Note: One test for BOFT is currently failing, see the comment in the test
file.
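To illustrate the problem, here is a minimal, self-contained sketch (not part
of the patch itself) showing that the running stats of a batch norm layer
change during a forward pass in training mode, even when all of its
parameters are frozen:

    import torch

    bn = torch.nn.BatchNorm2d(3)
    bn.weight.requires_grad_(False)  # freeze all parameters, as PEFT does
    bn.bias.requires_grad_(False)
    print(bn.running_mean)  # tensor([0., 0., 0.])
    bn.train()
    bn(torch.randn(8, 3, 4, 4) + 5.0)  # one forward pass in training mode
    print(bn.running_mean)  # no longer all zeros: the buffer was updated

Since running_mean and running_var are buffers rather than parameters, the
optimizer never touches them, yet they drift away from the values of the
pretrained base model during training.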
---
 src/peft/utils/save_and_load.py |  38 +++++++++++
 tests/test_vision_models.py     | 116 ++++++++++++++++++++++++++++++++
 2 files changed, 154 insertions(+)
 create mode 100644 tests/test_vision_models.py

diff --git a/src/peft/utils/save_and_load.py b/src/peft/utils/save_and_load.py
index dc77d0f660..8123126eaa 100644
--- a/src/peft/utils/save_and_load.py
+++ b/src/peft/utils/save_and_load.py
@@ -45,6 +45,35 @@ def get_embedding_layer_name(model, layer, is_embedding_in_target_modules):
     return None
 
 
+def get_buffers_to_save(model: torch.nn.Module, config) -> dict[str, torch.Tensor]:
+    """
+    Get the buffers to save for the given model and config.
+
+    Although the base model weights are not updated when using PEFT, some buffers may still be updated and therefore
+    need to be saved as part of the PEFT checkpoint. An example of this is the running stats of batch norm layers.
+
+    Args:
+        model (torch.nn.Module):
+            The PEFT model.
+        config (PeftConfig):
+            The PEFT config.
+
+    Returns:
+        dict[str, torch.Tensor]:
+            The buffers to save.
+    """
+    # note: config is currently not used but that may change in the future
+    buffers_to_save = {}
+    for module_name, module in model.named_modules():
+        # currently, we only deal with running stats from BatchNorm* modules
+        if not hasattr(module, "track_running_stats"):
+            continue
+        for buffer_name, buffer in module.named_buffers():
+            if buffer is not None:
+                buffers_to_save[module_name + "." + buffer_name] = buffer
+    return buffers_to_save
+
+
 def get_peft_model_state_dict(
     model, state_dict=None, adapter_name="default", unwrap_compiled=False, save_embedding_layers="auto"
 ):
@@ -71,6 +100,8 @@ def get_peft_model_state_dict(
     config = model.peft_config[adapter_name]
     if state_dict is None:
         state_dict = model.state_dict()
+
+    # TUNER SPECIFIC CODE
     if config.peft_type in (PeftType.LORA, PeftType.ADALORA):
         # to_return = lora_state_dict(model, bias=model.peft_config.bias)
         # adapted from `https://github.com/microsoft/LoRA/blob/main/loralib/utils.py`
@@ -165,11 +196,13 @@ def get_peft_model_state_dict(
     else:
         raise ValueError(f"Unknown PEFT type passed: {config.peft_type}")
 
+    # MODULES TO SAVE
     if getattr(model, "modules_to_save", None) is not None:
         for key, value in state_dict.items():
             if any(f"{module_name}.modules_to_save.{adapter_name}" in key for module_name in model.modules_to_save):
                 to_return[key.replace("modules_to_save.", "")] = value
 
+    # DEAL WITH EMBEDDINGS
     # check the common embedding layers in `target_modules` to reset `save_embedding_layers` if necessary
     is_embedding_in_target_modules = False
     if (
@@ -223,6 +256,11 @@ def get_peft_model_state_dict(
     elif save_embedding_layers:
         warnings.warn("Could not identify embedding layer(s) because the model is not a 🤗 transformers model.")
 
+    # DEAL WITH BUFFERS
+    buffers_to_save = get_buffers_to_save(model, config)
+    to_return.update(buffers_to_save)
+
+    # REMOVE ADAPTER NAME
     to_return = {k.replace(f".{adapter_name}", ""): v for k, v in to_return.items()}
     return to_return
 
diff --git a/tests/test_vision_models.py b/tests/test_vision_models.py
new file mode 100644
index 0000000000..8afd237231
--- /dev/null
+++ b/tests/test_vision_models.py
@@ -0,0 +1,116 @@
+# coding=utf-8
+# Copyright 2024-present the HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# This is not a full-fledged test suite of vision models, since we already run many tests on dummy models with Conv2d
+# layers and on stable diffusion models. Instead, this file contains specific tests for bugs found in the past.
+import gc
+
+import pytest
+import torch
+from datasets import load_dataset
+from safetensors.torch import load_file
+from transformers import AutoImageProcessor, AutoModelForImageClassification
+
+from peft import LoHaConfig, LoKrConfig, LoraConfig, OFTConfig, PeftModel, get_peft_model
+
+
+CONFIGS = {
+    "lora": LoraConfig(target_modules=["convolution"], modules_to_save=["classifier"]),
+    "loha": LoHaConfig(target_modules=["convolution"], modules_to_save=["classifier"]),
+    "lokr": LoKrConfig(target_modules=["convolution"], modules_to_save=["classifier"]),
+    "oft": OFTConfig(target_modules=["convolution"], modules_to_save=["classifier"]),
+    # TODO: cannot use BOFT because some convolutional kernel dimensions are even (64) and others odd (147). There is no
+    # common divisor for the boft_block_size except 1, but using 1 results in an error in the fbd_cuda kernel:
+    # > Error in forward_fast_block_diag_cuda_kernel: an illegal memory access was encountered
+    # "boft": BOFTConfig(target_modules=["convolution"], modules_to_save=["classifier"], boft_block_size=2),
+}
+
+
+class TestResnet:
+    model_id = "microsoft/resnet-18"
+
+    @pytest.fixture(autouse=True)
+    def teardown(self):
+        r"""
+        Efficient mechanism to free GPU memory after each test. Based on
+        https://github.com/huggingface/transformers/issues/21094
+        """
+        gc.collect()
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+        gc.collect()
+
+    @pytest.fixture(scope="class")
+    def image_processor(self):
+        image_processor = AutoImageProcessor.from_pretrained(self.model_id)
+        return image_processor
+
+    @pytest.fixture(scope="class")
+    def data(self, image_processor):
+        dataset = load_dataset("huggingface/cats-image", trust_remote_code=True)
+        image = dataset["test"]["image"][0]
+        return image_processor(image, return_tensors="pt")
+
+    @pytest.mark.parametrize("config", CONFIGS.values(), ids=CONFIGS.keys())
+    def test_model_with_batchnorm_reproducibility(self, config, tmp_path, data):
+        # see issue #1732
+        torch.manual_seed(0)
+        model = AutoModelForImageClassification.from_pretrained(self.model_id)
+        model = get_peft_model(model, config)
+
+        # record outputs before training
+        model.eval()
+        with torch.inference_mode():
+            output_before = model(**data)
+        model.train()
+
+        # train the model
+        optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
+        batch_size = 4
+        max_steps = 5 * batch_size
+        labels = torch.zeros(1, 1000)
+        labels[0, 283] = 1
+        for i in range(0, max_steps, batch_size):
+            optimizer.zero_grad()
+            outputs = model(**data, labels=labels)
+            loss = outputs.loss
+            loss.backward()
+            optimizer.step()
+
+        # record outputs after training
+        model.eval()
+        with torch.inference_mode():
+            output_after = model(**data)
+        assert torch.isfinite(output_after.logits).all()
+        atol, rtol = 1e-4, 1e-4
+        # sanity check: model was updated
+        assert not torch.allclose(output_before.logits, output_after.logits, atol=atol, rtol=rtol)
+
+        # check saving the model and loading it
+        model.save_pretrained(tmp_path)
+        del model
+
+        torch.manual_seed(0)
+        model = AutoModelForImageClassification.from_pretrained(self.model_id)
+        model = PeftModel.from_pretrained(model, tmp_path).eval()
+        with torch.inference_mode():
+            output_loaded = model(**data)
+        assert torch.allclose(output_after.logits, output_loaded.logits, atol=atol, rtol=rtol)
+
+        # ensure that the checkpoint file contains the buffers
+        model_running_mean = len([k for k in model.state_dict().keys() if "running_mean" in k])
+        state_dict = load_file(tmp_path / "adapter_model.safetensors")
+        checkpoint_running_mean = len([k for k in state_dict.keys() if "running_mean" in k])
+        assert model_running_mean == checkpoint_running_mean

From 6337b00b01bc99ce8756bb152e3d96bd732a14c2 Mon Sep 17 00:00:00 2001
From: Benjamin Bossan
Date: Wed, 15 May 2024 13:21:30 +0200
Subject: [PATCH 2/3] Make style

---
 tests/test_vision_models.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/tests/test_vision_models.py b/tests/test_vision_models.py
index 8afd237231..bb4029c901 100644
--- a/tests/test_vision_models.py
+++ b/tests/test_vision_models.py
@@ -1,4 +1,3 @@
-# coding=utf-8
 # Copyright 2024-present the HuggingFace Inc. team.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");

From 0f14aee1cdbfa45079ffd038f67f8af57b3ed2a6 Mon Sep 17 00:00:00 2001
From: Benjamin Bossan
Date: Tue, 21 May 2024 11:50:19 +0200
Subject: [PATCH 3/3] Use modules_to_save solution

No need to add extra code to save buffers in the checkpoint.
---
 .../developer_guides/troubleshooting.md | 33 +++++++++++++++++
 src/peft/utils/save_and_load.py         | 33 ------------------
 tests/test_vision_models.py             | 14 ++++----
 3 files changed, 41 insertions(+), 39 deletions(-)

diff --git a/docs/source/developer_guides/troubleshooting.md b/docs/source/developer_guides/troubleshooting.md
index 179258aca9..abc4132b84 100644
--- a/docs/source/developer_guides/troubleshooting.md
+++ b/docs/source/developer_guides/troubleshooting.md
@@ -240,3 +240,36 @@ TunerModelStatus(
     available_adapters=['adapter-1', 'adapter-2'],
 )
 ```
+
+## Reproducibility
+
+### Models using batch norm
+
+When loading a trained PEFT model where the base model uses batch norm (e.g. `torch.nn.BatchNorm1d` or `torch.nn.BatchNorm2d`), you may find that you cannot reproduce the exact same outputs. This is because the batch norm layers keep track of running stats during training, but these stats are not part of the PEFT checkpoint. Therefore, when you load the PEFT model, the running stats of the base model will be used (i.e. those from before training with PEFT).
+
+Depending on your use case, this may not be a big deal. If, however, you need your outputs to be 100% reproducible, you can achieve this by adding the batch norm layers to `modules_to_save`. Below is an example of this using resnet and LoRA. Notice that we set `modules_to_save=["classifier", "normalization"]`. We need the `"classifier"` entry because our task is image classification, and we add the `"normalization"` entry to ensure that the batch norm layers are saved in the PEFT checkpoint.
+
+```python
+from transformers import AutoModelForImageClassification
+from peft import LoraConfig, get_peft_model
+
+model_id = "microsoft/resnet-18"
+base_model = AutoModelForImageClassification.from_pretrained(model_id)
+config = LoraConfig(
+    target_modules=["convolution"],
+    modules_to_save=["classifier", "normalization"],
+)
+model = get_peft_model(base_model, config)
+```
+
+Depending on the type of model you use, the batch norm layers could have different names than `"normalization"`, so please ensure that the name matches your model architecture.
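+
+If you are unsure which names apply to your model, a small sketch like the following can help: batch norm layers are the submodules that track running stats, so you can print their names and types and pick the matching entries for `modules_to_save`:
+
+```python
+# print the name of every submodule that tracks running stats, i.e. the batch
+# norm layers whose names need to be covered by modules_to_save
+for name, module in base_model.named_modules():
+    if hasattr(module, "track_running_stats"):
+        print(name, type(module).__name__)
+```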
diff --git a/src/peft/utils/save_and_load.py b/src/peft/utils/save_and_load.py
index 8123126eaa..c993afc9bb 100644
--- a/src/peft/utils/save_and_load.py
+++ b/src/peft/utils/save_and_load.py
@@ -45,35 +45,6 @@ def get_embedding_layer_name(model, layer, is_embedding_in_target_modules):
     return None
 
 
-def get_buffers_to_save(model: torch.nn.Module, config) -> dict[str, torch.Tensor]:
-    """
-    Get the buffers to save for the given model and config.
-
-    Although the base model weights are not updated when using PEFT, some buffers may still be updated and therefore
-    need to be saved as part of the PEFT checkpoint. An example of this is the running stats of batch norm layers.
-
-    Args:
-        model (torch.nn.Module):
-            The PEFT model.
-        config (PeftConfig):
-            The PEFT config.
-
-    Returns:
-        dict[str, torch.Tensor]:
-            The buffers to save.
-    """
-    # note: config is currently not used but that may change in the future
-    buffers_to_save = {}
-    for module_name, module in model.named_modules():
-        # currently, we only deal with running stats from BatchNorm* modules
-        if not hasattr(module, "track_running_stats"):
-            continue
-        for buffer_name, buffer in module.named_buffers():
-            if buffer is not None:
-                buffers_to_save[module_name + "." + buffer_name] = buffer
-    return buffers_to_save
-
-
 def get_peft_model_state_dict(
     model, state_dict=None, adapter_name="default", unwrap_compiled=False, save_embedding_layers="auto"
 ):
@@ -256,10 +227,6 @@ def get_peft_model_state_dict(
     elif save_embedding_layers:
         warnings.warn("Could not identify embedding layer(s) because the model is not a 🤗 transformers model.")
 
-    # DEAL WITH BUFFERS
-    buffers_to_save = get_buffers_to_save(model, config)
-    to_return.update(buffers_to_save)
-
     # REMOVE ADAPTER NAME
     to_return = {k.replace(f".{adapter_name}", ""): v for k, v in to_return.items()}
     return to_return

diff --git a/tests/test_vision_models.py b/tests/test_vision_models.py
index bb4029c901..8cb913707d 100644
--- a/tests/test_vision_models.py
+++ b/tests/test_vision_models.py
@@ -26,14 +26,14 @@
 
 
 CONFIGS = {
-    "lora": LoraConfig(target_modules=["convolution"], modules_to_save=["classifier"]),
-    "loha": LoHaConfig(target_modules=["convolution"], modules_to_save=["classifier"]),
-    "lokr": LoKrConfig(target_modules=["convolution"], modules_to_save=["classifier"]),
-    "oft": OFTConfig(target_modules=["convolution"], modules_to_save=["classifier"]),
+    "lora": LoraConfig(target_modules=["convolution"], modules_to_save=["classifier", "normalization"]),
+    "loha": LoHaConfig(target_modules=["convolution"], modules_to_save=["classifier", "normalization"]),
+    "lokr": LoKrConfig(target_modules=["convolution"], modules_to_save=["classifier", "normalization"]),
+    "oft": OFTConfig(target_modules=["convolution"], modules_to_save=["classifier", "normalization"]),
     # TODO: cannot use BOFT because some convolutional kernel dimensions are even (64) and others odd (147). There is no
     # common divisor for the boft_block_size except 1, but using 1 results in an error in the fbd_cuda kernel:
     # > Error in forward_fast_block_diag_cuda_kernel: an illegal memory access was encountered
-    # "boft": BOFTConfig(target_modules=["convolution"], modules_to_save=["classifier"], boft_block_size=2),
+    # "boft": BOFTConfig(target_modules=["convolution"], modules_to_save=["classifier", "normalization"], boft_block_size=2),
 }
 
 
@@ -112,4 +112,6 @@ def test_model_with_batchnorm_reproducibility(self, config, tmp_path, data):
         model_running_mean = len([k for k in model.state_dict().keys() if "running_mean" in k])
         state_dict = load_file(tmp_path / "adapter_model.safetensors")
         checkpoint_running_mean = len([k for k in state_dict.keys() if "running_mean" in k])
-        assert model_running_mean == checkpoint_running_mean
+        # note that the model has twice as many "running_mean" buffers: each ModulesToSaveWrapper keeps the original
+        # module in addition to the trainable copy, whereas the checkpoint only stores the copy, hence the factor of 2
+        assert model_running_mean == checkpoint_running_mean * 2
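To make the factor of 2 concrete: with modules_to_save, each batch norm layer
is wrapped in a ModulesToSaveWrapper that holds both the untouched original
module and the trainable copy, so the model's state_dict contains two sets of
running stats, while the checkpoint only contains those of the copy. A minimal
sketch (assuming the `model` from the test above) to inspect this:

    keys = [k for k in model.state_dict() if "running_mean" in k]
    # one entry per wrapped layer under "original_module" ...
    print([k for k in keys if "original_module" in k])
    # ... and one per wrapped layer under "modules_to_save.default"
    print([k for k in keys if "modules_to_save" in k])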