Mistral: Sliding Window Attention with Flash Attention and Sample Packing #732
Main benefits of this PR:

- Passes `window_size` to Flash Attention.

**Memory usage**

I measured memory usage with SWA; the conclusion is that you save 3GB when using a sliding window mask. A minimal sketch of how the `window_size` tuple reaches Flash Attention is shown after this list.

- With SWA: `_prepare_decoder_attention_mask` with the `window_size=(4096, 4096)` parameter to flash attention.
- Without SWA: `_prepare_decoder_attention_mask` with the `window_size=(-1, -1)` parameter to flash attention.
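The snippet below is a minimal sketch (not the exact patch in this PR) of how such a `window_size` tuple can be forwarded to flash-attn. It assumes flash-attn >= 2.3, where `flash_attn_func` accepts a `window_size=(left, right)` argument and `(-1, -1)` disables the sliding window; the helper name `attend` is illustrative.

```python
# Minimal sketch, assuming flash-attn >= 2.3 (which added window_size to
# flash_attn_func). The helper name `attend` is illustrative, not from the PR.
import torch
from flash_attn import flash_attn_func

def attend(q, k, v, sliding_window=None):
    # q, k, v: (batch, seqlen, nheads, headdim) in fp16/bf16 on CUDA.
    if sliding_window is not None:
        # SWA enabled, e.g. (4096, 4096) for Mistral's 4k window.
        window_size = (sliding_window, sliding_window)
    else:
        # No sliding window: (-1, -1) means regular full causal attention.
        window_size = (-1, -1)
    return flash_attn_func(q, k, v, causal=True, window_size=window_size)
```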
**Long context (casperhansen/longalpaca_1k_test)**

I tested with a long-context dataset, minimum 16k tokens and maximum 32k tokens. A minimum of 48GB VRAM is needed to run this.
Results after a few steps:
**Short context (mhenrichsen/alpaca_2k_test)**
Loss on short-context datasets was tested and is unchanged. Used the default config in `examples/mistral/qlora.yml`.

**Other notes:**
- `attention_mask` and `sliding_window_mask` are not broadcastable in the first iteration after eval loss; however, this is only the case when `wandb` is enabled. This error is handled by checking `attention_mask.shape[0] != 1`
so that the broken case does not trigger (a hedged sketch of that guard follows this list).
- I tried `_expand_mask` and it did not work with Flash Attention. I tried other methods too, but hit the same problem.
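Below is a hedged sketch of the guard described in the first note. The mask names follow the wording above; how the two masks are combined (here, adding additive masks) is an assumption for illustration, not the exact code in this PR.

```python
# Hedged sketch of the shape guard described above. `attention_mask` and
# `sliding_window_mask` are assumed to be additive masks (0 / -inf) of shape
# (batch, 1, q_len, kv_len); combining them via addition is illustrative.
import torch

def combine_masks(attention_mask, sliding_window_mask):
    # Skip the degenerate batch-size-1 mask produced in the first iteration
    # after eval loss (seen only with wandb enabled), where the two masks
    # are not broadcastable.
    if attention_mask is not None and attention_mask.shape[0] != 1:
        return attention_mask + sliding_window_mask
    return attention_mask
```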