[DPOTrainer] Fix DPO trainer + mistral + FA2 #1290 (Merged)
Fixes: #1217
Fixes: huggingface/transformers#26877
Fixes: #1266
Simply setting `use_cache=False` circumvents all issues with FA-2 + DPO + Mistral. In fact, we should bypass that check entirely, since we are not in text-generation mode when computing the loss function. By default, `use_cache` is retrieved from the model config, which always falls back to `True`. The cache is never used when purely computing the logits, so this change is fully backward compatible.

cc @kashif @vwxyzjn
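For illustration, here is a minimal sketch of the idea (not the exact trainer code; the helper name and shapes are illustrative): when computing per-token log-probabilities for the DPO loss, the forward pass only needs the logits, so `use_cache=False` is passed explicitly instead of inheriting the config default.

```python
import torch
import torch.nn.functional as F

def get_batch_logps(model, input_ids, attention_mask, labels):
    # No generation happens here, so the KV cache is never read.
    # Passing use_cache=False avoids the FA-2 + Mistral cache issues
    # instead of falling back to the config default (True).
    outputs = model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        use_cache=False,
    )
    # Shift so that logits at position t predict the token at t+1.
    logits = outputs.logits[:, :-1, :]
    labels = labels[:, 1:].clone()
    mask = labels != -100
    labels[~mask] = 0  # dummy index so gather is valid on masked positions
    per_token_logps = torch.gather(
        F.log_softmax(logits, dim=-1), dim=2, index=labels.unsqueeze(2)
    ).squeeze(2)
    # Sum of log-probs over the response tokens, per sequence.
    return (per_token_logps * mask).sum(-1)
```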
To test this, I managed to reproduce the issue by adding `--attn_implementation "flash_attention_2"` to the DPO shell script on an A100 machine, and I confirm this PR fixes it. Unfortunately, our CI runners are not compatible with FA-2, so we cannot add a slow test for this.
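A rough Python equivalent of that repro, for anyone without the shell script at hand. This is a sketch, not the exact command used: the tiny inline dataset is a placeholder, and the exact `DPOTrainer` keyword arguments depend on the TRL version.

```python
import torch
from datasets import Dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments
from trl import DPOTrainer

model_id = "mistralai/Mistral-7B-v0.1"
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",  # triggers the bug before this fix
)
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.pad_token = tokenizer.eos_token

# Placeholder preference data in the prompt/chosen/rejected format DPOTrainer expects.
train_dataset = Dataset.from_dict(
    {
        "prompt": ["What is the capital of France?"],
        "chosen": [" Paris."],
        "rejected": [" London."],
    }
)

trainer = DPOTrainer(
    model,
    ref_model=None,  # a reference copy is created internally when None
    args=TrainingArguments(
        output_dir="dpo-fa2-repro",
        per_device_train_batch_size=1,
        remove_unused_columns=False,
    ),
    beta=0.1,
    train_dataset=train_dataset,
    tokenizer=tokenizer,
)
trainer.train()  # fails with FA-2 + Mistral before this PR, runs after
```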