-
Notifications
You must be signed in to change notification settings - Fork 1.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
DOC TST Document and test reproducibility with models using batch norm #1734
Merged
BenjaminBossan
merged 3 commits into
huggingface:main
from
BenjaminBossan:fix-checkpoint-for-batchnorm-models
May 22, 2024
Merged
Changes from all commits
Commits
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,117 @@ | ||
# Copyright 2024-present the HuggingFace Inc. team. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
# This is not a full on test suite of vision models, since we already run many tests on dummy models with Conv2d layers | ||
# and on stable diffusion models. Instead, this file contains specific tests for bugs that have been found in the past. | ||
import gc | ||
|
||
import pytest | ||
import torch | ||
from datasets import load_dataset | ||
from safetensors.torch import load_file | ||
from transformers import AutoImageProcessor, AutoModelForImageClassification | ||
|
||
from peft import LoHaConfig, LoKrConfig, LoraConfig, OFTConfig, PeftModel, get_peft_model | ||
|
||
|
||
CONFIGS = { | ||
"lora": LoraConfig(target_modules=["convolution"], modules_to_save=["classifier", "normalization"]), | ||
"loha": LoHaConfig(target_modules=["convolution"], modules_to_save=["classifier", "normalization"]), | ||
"lokr": LoKrConfig(target_modules=["convolution"], modules_to_save=["classifier", "normalization"]), | ||
"oft": OFTConfig(target_modules=["convolution"], modules_to_save=["classifier", "normalization"]), | ||
# TODO: cannot use BOFT because some convolutional kernel dimensions are even (64) and others odd (147). There is no | ||
# common denominator for the boft_block_size except 1, but using 1 results in an error in the fbd_cuda kernel: | ||
# > Error in forward_fast_block_diag_cuda_kernel: an illegal memory access was encountered | ||
# "boft": BOFTConfig(target_modules=["convolution"], modules_to_save=["classifier", "normalization"], boft_block_size=2), | ||
} | ||
|
||
|
||
class TestResnet: | ||
model_id = "microsoft/resnet-18" | ||
|
||
@pytest.fixture(autouse=True) | ||
def teardown(self): | ||
r""" | ||
Efficient mechanism to free GPU memory after each test. Based on | ||
https://github.com/huggingface/transformers/issues/21094 | ||
""" | ||
gc.collect() | ||
if torch.cuda.is_available(): | ||
torch.cuda.empty_cache() | ||
gc.collect() | ||
|
||
@pytest.fixture(scope="class") | ||
def image_processor(self): | ||
image_processor = AutoImageProcessor.from_pretrained(self.model_id) | ||
return image_processor | ||
|
||
@pytest.fixture(scope="class") | ||
def data(self, image_processor): | ||
dataset = load_dataset("huggingface/cats-image", trust_remote_code=True) | ||
image = dataset["test"]["image"][0] | ||
return image_processor(image, return_tensors="pt") | ||
|
||
@pytest.mark.parametrize("config", CONFIGS.values(), ids=CONFIGS.keys()) | ||
def test_model_with_batchnorm_reproducibility(self, config, tmp_path, data): | ||
# see 1732 | ||
torch.manual_seed(0) | ||
model = AutoModelForImageClassification.from_pretrained(self.model_id) | ||
model = get_peft_model(model, config) | ||
|
||
# record outputs before training | ||
model.eval() | ||
with torch.inference_mode(): | ||
output_before = model(**data) | ||
model.train() | ||
|
||
# train the model | ||
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3) | ||
batch_size = 4 | ||
max_steps = 5 * batch_size | ||
labels = torch.zeros(1, 1000) | ||
labels[0, 283] = 1 | ||
for i in range(0, max_steps, batch_size): | ||
optimizer.zero_grad() | ||
outputs = model(**data, labels=labels) | ||
loss = outputs.loss | ||
loss.backward() | ||
optimizer.step() | ||
|
||
# record outputs after training | ||
model.eval() | ||
with torch.inference_mode(): | ||
output_after = model(**data) | ||
assert torch.isfinite(output_after.logits).all() | ||
atol, rtol = 1e-4, 1e-4 | ||
# sanity check: model was updated | ||
assert not torch.allclose(output_before.logits, output_after.logits, atol=atol, rtol=rtol) | ||
|
||
# check saving the model and loading it | ||
model.save_pretrained(tmp_path) | ||
del model | ||
|
||
torch.manual_seed(0) | ||
model = AutoModelForImageClassification.from_pretrained(self.model_id) | ||
model = PeftModel.from_pretrained(model, tmp_path).eval() | ||
with torch.inference_mode(): | ||
output_loaded = model(**data) | ||
assert torch.allclose(output_after.logits, output_loaded.logits, atol=atol, rtol=rtol) | ||
|
||
# ensure that the checkpoint file contains the buffers | ||
model_running_mean = len([k for k in model.state_dict().keys() if "running_mean" in k]) | ||
state_dict = load_file(tmp_path / "adapter_model.safetensors") | ||
checkpoint_running_mean = len([k for k in state_dict.keys() if "running_mean" in k]) | ||
# note that the model has twice as many "running_mean", as there is one copy per ModulesToSaveWrapper, we need | ||
# to multiply by 2 to get the same number | ||
assert model_running_mean == checkpoint_running_mean * 2 |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@yfeng95 @Zeju1997 @YuliangXiu Any idea how I could fix the described issue?