Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

DOC TST Document and test reproducibility with models using batch norm #1734

Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 38 additions & 0 deletions src/peft/utils/save_and_load.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,35 @@ def get_embedding_layer_name(model, layer, is_embedding_in_target_modules):
return None


def get_buffers_to_save(model: torch.nn.Module, config) -> dict[str, torch.Tensor]:
"""
Get the buffers to save for the given model and config.

Although the base model weights are not updated when using PEFT, some buffers may still be updated and therefore
need to be saved as part of the PEFT checkpoint. An example of this are the running stats of batch norm layers.

Args:
model (torch.nn.Module):
The PEFT model.
config (PeftConfig):
The PEFT config.

Returns:
dict[str, torch.Tensor]:
The buffers to save.
"""
# note: config is currently not used but that may change in the future
buffers_to_save = {}
for module_name, module in model.named_modules():
# currently, we only deal with running stats from BatchNorm* modules
if not hasattr(module, "track_running_stats"):
continue
for buffer_name, buffer in module.named_buffers():
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do we wanna check if track_running_stats is set to True?

Copy link
Member Author

@BenjaminBossan BenjaminBossan May 15, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point. I considered this, but was hesitant. Say someone trains a model with tracking enabled, and then turns it off for some reason before saving. Then we would not save these buffers even though they're needed, do I see that right? Ideally, we would have a check if they changed vis-à-vis the base model, but I don't see a way to monitor this except for storing a copy of all these buffers, requiring extra memory.

I guess we could decide only to save them if getattr(module, "track_running_stats", None) is True, and issue a warning that they're not saved if getattr(module, "track_running_stats", None) is False (and for None, just ignore).

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes... true.. ok its safer to save them you are right yeah someone can then turn it off and it starts to use the batch's summary statistics in inference mode...

if buffer is not None:
buffers_to_save[module_name + "." + buffer_name] = buffer
return buffers_to_save


def get_peft_model_state_dict(
model, state_dict=None, adapter_name="default", unwrap_compiled=False, save_embedding_layers="auto"
):
Expand All @@ -71,6 +100,8 @@ def get_peft_model_state_dict(
config = model.peft_config[adapter_name]
if state_dict is None:
state_dict = model.state_dict()

# TUNER SPECIFIC CODE
if config.peft_type in (PeftType.LORA, PeftType.ADALORA):
# to_return = lora_state_dict(model, bias=model.peft_config.bias)
# adapted from `https://github.com/microsoft/LoRA/blob/main/loralib/utils.py`
Expand Down Expand Up @@ -165,11 +196,13 @@ def get_peft_model_state_dict(
else:
raise ValueError(f"Unknown PEFT type passed: {config.peft_type}")

# MODULES TO SAVE
if getattr(model, "modules_to_save", None) is not None:
for key, value in state_dict.items():
if any(f"{module_name}.modules_to_save.{adapter_name}" in key for module_name in model.modules_to_save):
to_return[key.replace("modules_to_save.", "")] = value

# DEAL WITH EMBEDDINGS
# check the common embedding layers in `target_modules` to reset `save_embedding_layers` if necessary
is_embedding_in_target_modules = False
if (
Expand Down Expand Up @@ -223,6 +256,11 @@ def get_peft_model_state_dict(
elif save_embedding_layers:
warnings.warn("Could not identify embedding layer(s) because the model is not a 🤗 transformers model.")

# DEAL WITH BUFFERS
buffers_to_save = get_buffers_to_save(model, config)
to_return.update(buffers_to_save)

# REMOVE ADAPTER NAME
to_return = {k.replace(f".{adapter_name}", ""): v for k, v in to_return.items()}
return to_return

Expand Down
115 changes: 115 additions & 0 deletions tests/test_vision_models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
# Copyright 2024-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# This is not a full on test suite of vision models, since we already run many tests on dummy models with Conv2d layers
# and on stable diffusion models. Instead, this file contains specific tests for bugs that have been found in the past.
import gc

import pytest
import torch
from datasets import load_dataset
from safetensors.torch import load_file
from transformers import AutoImageProcessor, AutoModelForImageClassification

from peft import LoHaConfig, LoKrConfig, LoraConfig, OFTConfig, PeftModel, get_peft_model


CONFIGS = {
"lora": LoraConfig(target_modules=["convolution"], modules_to_save=["classifier"]),
"loha": LoHaConfig(target_modules=["convolution"], modules_to_save=["classifier"]),
"lokr": LoKrConfig(target_modules=["convolution"], modules_to_save=["classifier"]),
"oft": OFTConfig(target_modules=["convolution"], modules_to_save=["classifier"]),
# TODO: cannot use BOFT because some convolutional kernel dimensions are even (64) and others odd (147). There is no
# common denominator for the boft_block_size except 1, but using 1 results in an error in the fbd_cuda kernel:
# > Error in forward_fast_block_diag_cuda_kernel: an illegal memory access was encountered
# "boft": BOFTConfig(target_modules=["convolution"], modules_to_save=["classifier"], boft_block_size=2),
}


class TestResnet:
model_id = "microsoft/resnet-18"

@pytest.fixture(autouse=True)
def teardown(self):
r"""
Efficient mechanism to free GPU memory after each test. Based on
https://github.com/huggingface/transformers/issues/21094
"""
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
gc.collect()

@pytest.fixture(scope="class")
def image_processor(self):
image_processor = AutoImageProcessor.from_pretrained(self.model_id)
return image_processor

@pytest.fixture(scope="class")
def data(self, image_processor):
dataset = load_dataset("huggingface/cats-image", trust_remote_code=True)
image = dataset["test"]["image"][0]
return image_processor(image, return_tensors="pt")

@pytest.mark.parametrize("config", CONFIGS.values(), ids=CONFIGS.keys())
def test_model_with_batchnorm_reproducibility(self, config, tmp_path, data):
# see 1732
torch.manual_seed(0)
model = AutoModelForImageClassification.from_pretrained(self.model_id)
model = get_peft_model(model, config)

# record outputs before training
model.eval()
with torch.inference_mode():
output_before = model(**data)
model.train()

# train the model
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
batch_size = 4
max_steps = 5 * batch_size
labels = torch.zeros(1, 1000)
labels[0, 283] = 1
for i in range(0, max_steps, batch_size):
optimizer.zero_grad()
outputs = model(**data, labels=labels)
loss = outputs.loss
loss.backward()
optimizer.step()

# record outputs after training
model.eval()
with torch.inference_mode():
output_after = model(**data)
assert torch.isfinite(output_after.logits).all()
atol, rtol = 1e-4, 1e-4
# sanity check: model was updated
assert not torch.allclose(output_before.logits, output_after.logits, atol=atol, rtol=rtol)

# check saving the model and loading it
model.save_pretrained(tmp_path)
del model

torch.manual_seed(0)
model = AutoModelForImageClassification.from_pretrained(self.model_id)
model = PeftModel.from_pretrained(model, tmp_path).eval()
with torch.inference_mode():
output_loaded = model(**data)
assert torch.allclose(output_after.logits, output_loaded.logits, atol=atol, rtol=rtol)

# ensure that the checkpoint file contains the buffers
model_running_mean = len([k for k in model.state_dict().keys() if "running_mean" in k])
state_dict = load_file(tmp_path / "adapter_model.safetensors")
checkpoint_running_mean = len([k for k in state_dict.keys() if "running_mean" in k])
assert model_running_mean == checkpoint_running_mean
Loading