DOC Troubleshooting for unscaling error with fp16 (#1336)
Some users ran into errors when trying to use a model loaded in float16
with mixed precision, e.g. in issues #341 and #1249. This PR documents
a workaround for the problem.

I also added tests that demonstrate the issue, as well as the
workaround.

Notes

This is not strictly a PEFT issue, but more a general error when using
AMP with float16. Still, since PEFT users encounter this sometimes, it
is useful to document it.

When we discussed this issue in the past, I think we concluded that it's
not as straightforward as PEFT automatically casting the weights to
float32, though I cannot remember anymore what the drawbacks were.

In any case, should we ever add an automatic solution for this in PEFT,
the added test should fail, which alerts us to the fact that we need to
update the documentation.
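
For illustration only, an automatic solution would roughly amount to PEFT applying the same upcast that the documentation now recommends. The helper below is hypothetical and not part of PEFT; it merely sketches what such a step could look like.

```python
import torch

def upcast_trainable_params(peft_model: torch.nn.Module) -> torch.nn.Module:
    # Hypothetical helper, not part of PEFT: upcast only the trainable (adapter)
    # parameters to float32 so that AMP's GradScaler can unscale their gradients.
    for param in peft_model.parameters():
        if param.requires_grad:
            param.data = param.data.float()
    return peft_model
```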
BenjaminBossan authored Jan 10, 2024
1 parent e96eef9 commit c6b28a2
Showing 2 changed files with 98 additions and 0 deletions.
19 changes: 19 additions & 0 deletions docs/source/developer_guides/troubleshooting.md
@@ -39,6 +39,25 @@ Installing PEFT from source is useful for keeping up with the latest developments
python -m pip install git+https://github.com/huggingface/peft
```

## Training errors

### Getting: ValueError: Attempting to unscale FP16 gradients

This error probably occurred because the model was loaded with `torch_dtype=torch.float16` and then used in an automatic mixed precision (AMP) context, e.g. by setting `fp16=True` in the `Trainer` class from 🤗 Transformers. The reason is that when using AMP, trainable weights should never be in fp16. To make this work without loading the whole model in fp32, add the following snippet to your code. It casts only the trainable parameters (e.g. the LoRA adapter weights) to fp32, while the frozen base model weights stay in fp16:

```python
peft_model = get_peft_model(...)

# add this: upcast only the trainable parameters to fp32
for param in peft_model.parameters():
    if param.requires_grad:
        param.data = param.data.float()

# proceed as usual
trainer = Trainer(model=peft_model, fp16=True, ...)
trainer.train()
```
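
For context, here is a minimal sketch (not part of the original docs and not PEFT-specific) of why AMP raises this error: `GradScaler` refuses to unscale gradients stored in fp16, so any trainable parameter left in fp16 makes the optimizer step fail. The tiny model and training step below are illustrative placeholders.

```python
import torch

# Minimal sketch, assuming a CUDA device is available. A small fp16 layer stands
# in for a PEFT model whose trainable weights were left in float16.
model = torch.nn.Linear(8, 8).cuda().half()  # trainable params (and their grads) are fp16
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
scaler = torch.cuda.amp.GradScaler()

x = torch.randn(4, 8, device="cuda")
with torch.autocast("cuda", dtype=torch.float16):
    loss = model(x).float().sum()

scaler.scale(loss).backward()
# step() first unscales the gradients; since they are fp16, this raises:
# ValueError: Attempting to unscale FP16 gradients.
scaler.step(optimizer)
```

Upcasting the trainable parameters to fp32, as in the snippet above, avoids this because their gradients are then computed and unscaled in fp32.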

## Bad results from a loaded PEFT model

There can be several reasons for getting a poor result from a loaded PEFT model, which are listed below. If you're still unable to troubleshoot the problem, see if anyone else had a similar [issue](https://github.com/huggingface/peft/issues) on GitHub, and if you can't find any, open a new issue.
79 changes: 79 additions & 0 deletions tests/test_gpu_examples.py
@@ -1196,3 +1196,82 @@ def test_notebook_launcher(self):
cmd = ["python", script_path]
with patch_environment(omp_num_threads=1):
run_command(cmd, env=os.environ.copy())


@require_torch_gpu
class MixedPrecisionTests(unittest.TestCase):
    def setUp(self):
        self.causal_lm_model_id = "facebook/opt-350m"
        self.tokenizer = AutoTokenizer.from_pretrained(self.causal_lm_model_id)
        self.config = LoraConfig(
            r=16,
            lora_alpha=32,
            task_type="CAUSAL_LM",
        )

        data = load_dataset("ybelkada/english_quotes_copy")
        self.data = data.map(lambda samples: self.tokenizer(samples["quote"]), batched=True)

    def tearDown(self):
        r"""
        Efficient mechanism to free GPU memory after each test. Based on
        https://github.com/huggingface/transformers/issues/21094
        """
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        gc.collect()

    @pytest.mark.single_gpu_tests
    def test_model_loaded_in_float16_raises(self):
        # This test shows the issue with loading the model in fp16 and then trying to use it with mixed precision
        # training, which should not use fp16. If this is ever automated in PEFT, this test should fail. In that case,
        # remove this test, adjust the next one, and remove the entry about FP16 usage from troubleshooting.md.
        model = AutoModelForCausalLM.from_pretrained(
            self.causal_lm_model_id,
            torch_dtype=torch.float16,
        )
        model = get_peft_model(model, self.config)

        with tempfile.TemporaryDirectory() as tmp_dir:
            trainer = Trainer(
                model=model,
                train_dataset=self.data["train"],
                args=TrainingArguments(
                    fp16=True,  # <= this is required for the error to be raised
                    logging_steps=1,
                    output_dir=tmp_dir,
                ),
                data_collator=DataCollatorForLanguageModeling(self.tokenizer, mlm=False),
            )
            msg = "Attempting to unscale FP16 gradients."
            with self.assertRaisesRegex(ValueError, msg):
                trainer.train()

    @pytest.mark.single_gpu_tests
    def test_model_loaded_in_float16_working(self):
        # Same test as before but containing the fix to make it work
        model = AutoModelForCausalLM.from_pretrained(
            self.causal_lm_model_id,
            torch_dtype=torch.float16,
        )
        model = get_peft_model(model, self.config)

        # for now, this is unfortunately necessary to avoid the error:
        # ValueError: Attempting to unscale FP16 gradients.
        for param in model.parameters():
            if param.requires_grad:
                param.data = param.data.float()

        with tempfile.TemporaryDirectory() as tmp_dir:
            trainer = Trainer(
                model=model,
                train_dataset=self.data["train"],
                args=TrainingArguments(
                    fp16=True,
                    max_steps=3,
                    output_dir=tmp_dir,
                ),
                data_collator=DataCollatorForLanguageModeling(self.tokenizer, mlm=False),
            )
            trainer.train()
