diff --git a/docs/source/developer_guides/troubleshooting.md b/docs/source/developer_guides/troubleshooting.md
index 179258aca9..abc4132b84 100644
--- a/docs/source/developer_guides/troubleshooting.md
+++ b/docs/source/developer_guides/troubleshooting.md
@@ -240,3 +240,25 @@ TunerModelStatus(
     available_adapters=['adapter-1', 'adapter-2'],
 )
 ```
+
+## Reproducibility
+
+### Models using batch norm
+
+When loading a trained PEFT model where the base model uses batch norm (e.g. `torch.nn.BatchNorm1d` or `torch.nn.BatchNorm2d`), you may find that you cannot reproduce the exact same outputs. This is because the batch norm layers keep track of running stats during training, but these stats are not part of the PEFT checkpoint. Therefore, when you load the PEFT model, the running stats of the base model will be used (i.e. from before training with PEFT).
+
+Depending on your use case, this may not be a big deal. If, however, you need your outputs to be 100% reproducible, you can achieve this by adding the batch norm layers to `modules_to_save`. Below is an example of this using resnet and LoRA. Notice that we set `modules_to_save=["classifier", "normalization"]`. We need the `"classifier"` argument because our task is image classification, and we add the `"normalization"` argument to ensure that the batch norm layers are saved in the PEFT checkpoint.
+
+```python
+from transformers import AutoModelForImageClassification
+from peft import LoraConfig, get_peft_model
+
+model_id = "microsoft/resnet-18"
+base_model = AutoModelForImageClassification.from_pretrained(self.model_id)
+config = LoraConfig(
+    target_modules=["convolution"],
+    modules_to_save=["classifier", "normalization"],
+),
+```
+
+Depending on the type of model you use, the batch norm layers could have different names than `"normalization"`, so please ensure that the name matches your model architecture.
diff --git a/src/peft/utils/save_and_load.py b/src/peft/utils/save_and_load.py
index dc77d0f660..c993afc9bb 100644
--- a/src/peft/utils/save_and_load.py
+++ b/src/peft/utils/save_and_load.py
@@ -71,6 +71,8 @@ def get_peft_model_state_dict(
     config = model.peft_config[adapter_name]
     if state_dict is None:
         state_dict = model.state_dict()
+
+    # TUNER SPECIFIC CODE
     if config.peft_type in (PeftType.LORA, PeftType.ADALORA):
         # to_return = lora_state_dict(model, bias=model.peft_config.bias)
         # adapted from `https://github.com/microsoft/LoRA/blob/main/loralib/utils.py`
@@ -165,11 +167,13 @@ def get_peft_model_state_dict(
     else:
         raise ValueError(f"Unknown PEFT type passed: {config.peft_type}")
 
+    # MODULES TO SAVE
     if getattr(model, "modules_to_save", None) is not None:
         for key, value in state_dict.items():
             if any(f"{module_name}.modules_to_save.{adapter_name}" in key for module_name in model.modules_to_save):
                 to_return[key.replace("modules_to_save.", "")] = value
 
+    # DEAL WITH EMBEDDINGS
     # check the common embedding layers in `target_modules` to reset `save_embedding_layers` if necessary
     is_embedding_in_target_modules = False
     if (
@@ -223,6 +227,7 @@ def get_peft_model_state_dict(
     elif save_embedding_layers:
         warnings.warn("Could not identify embedding layer(s) because the model is not a 🤗 transformers model.")
 
+    # REMOVE ADAPTER NAME
     to_return = {k.replace(f".{adapter_name}", ""): v for k, v in to_return.items()}
     return to_return
 
diff --git a/tests/test_vision_models.py b/tests/test_vision_models.py
new file mode 100644
index 0000000000..8cb913707d
--- /dev/null
+++ b/tests/test_vision_models.py
@@ -0,0 +1,117 @@
+# Copyright 2024-present the HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# This is not a full on test suite of vision models, since we already run many tests on dummy models with Conv2d layers
+# and on stable diffusion models. Instead, this file contains specific tests for bugs that have been found in the past.
+import gc
+
+import pytest
+import torch
+from datasets import load_dataset
+from safetensors.torch import load_file
+from transformers import AutoImageProcessor, AutoModelForImageClassification
+
+from peft import LoHaConfig, LoKrConfig, LoraConfig, OFTConfig, PeftModel, get_peft_model
+
+
+CONFIGS = {
+    "lora": LoraConfig(target_modules=["convolution"], modules_to_save=["classifier", "normalization"]),
+    "loha": LoHaConfig(target_modules=["convolution"], modules_to_save=["classifier", "normalization"]),
+    "lokr": LoKrConfig(target_modules=["convolution"], modules_to_save=["classifier", "normalization"]),
+    "oft": OFTConfig(target_modules=["convolution"], modules_to_save=["classifier", "normalization"]),
+    # TODO: cannot use BOFT because some convolutional kernel dimensions are even (64) and others odd (147). There is no
+    # common denominator for the boft_block_size except 1, but using 1 results in an error in the fbd_cuda kernel:
+    # > Error in forward_fast_block_diag_cuda_kernel: an illegal memory access was encountered
+    # "boft": BOFTConfig(target_modules=["convolution"], modules_to_save=["classifier", "normalization"], boft_block_size=2),
+}
+
+
+class TestResnet:
+    model_id = "microsoft/resnet-18"
+
+    @pytest.fixture(autouse=True)
+    def teardown(self):
+        r"""
+        Efficient mechanism to free GPU memory after each test. Based on
+        https://github.com/huggingface/transformers/issues/21094
+        """
+        gc.collect()
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+        gc.collect()
+
+    @pytest.fixture(scope="class")
+    def image_processor(self):
+        image_processor = AutoImageProcessor.from_pretrained(self.model_id)
+        return image_processor
+
+    @pytest.fixture(scope="class")
+    def data(self, image_processor):
+        dataset = load_dataset("huggingface/cats-image", trust_remote_code=True)
+        image = dataset["test"]["image"][0]
+        return image_processor(image, return_tensors="pt")
+
+    @pytest.mark.parametrize("config", CONFIGS.values(), ids=CONFIGS.keys())
+    def test_model_with_batchnorm_reproducibility(self, config, tmp_path, data):
+        # see 1732
+        torch.manual_seed(0)
+        model = AutoModelForImageClassification.from_pretrained(self.model_id)
+        model = get_peft_model(model, config)
+
+        # record outputs before training
+        model.eval()
+        with torch.inference_mode():
+            output_before = model(**data)
+        model.train()
+
+        # train the model
+        optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
+        batch_size = 4
+        max_steps = 5 * batch_size
+        labels = torch.zeros(1, 1000)
+        labels[0, 283] = 1
+        for i in range(0, max_steps, batch_size):
+            optimizer.zero_grad()
+            outputs = model(**data, labels=labels)
+            loss = outputs.loss
+            loss.backward()
+            optimizer.step()
+
+        # record outputs after training
+        model.eval()
+        with torch.inference_mode():
+            output_after = model(**data)
+        assert torch.isfinite(output_after.logits).all()
+        atol, rtol = 1e-4, 1e-4
+        # sanity check: model was updated
+        assert not torch.allclose(output_before.logits, output_after.logits, atol=atol, rtol=rtol)
+
+        # check saving the model and loading it
+        model.save_pretrained(tmp_path)
+        del model
+
+        torch.manual_seed(0)
+        model = AutoModelForImageClassification.from_pretrained(self.model_id)
+        model = PeftModel.from_pretrained(model, tmp_path).eval()
+        with torch.inference_mode():
+            output_loaded = model(**data)
+        assert torch.allclose(output_after.logits, output_loaded.logits, atol=atol, rtol=rtol)
+
+        # ensure that the checkpoint file contains the buffers
+        model_running_mean = len([k for k in model.state_dict().keys() if "running_mean" in k])
+        state_dict = load_file(tmp_path / "adapter_model.safetensors")
+        checkpoint_running_mean = len([k for k in state_dict.keys() if "running_mean" in k])
+        # note that the model has twice as many "running_mean", as there is one copy per ModulesToSaveWrapper, we need
+        # to multiply by 2 to get the same number
+        assert model_running_mean == checkpoint_running_mean * 2