[Mistral & Mixtral] Add sliding window param to sdpa after torch 2.2.0 #29220

Closed · wants to merge 55 commits

Changes from 11 commits

Commits (55)
97cff89
add sliding window param to sdpa after torch==2.2.0
ehuaa Feb 22, 2024
d0baf19
add sliding window param to sdpa after torch==2.2.0
ehuaa Feb 22, 2024
c9dacb8
revert add sliding window for qwen2 because of numerical error
ehuaa Feb 22, 2024
9c7bb07
Merge branch 'add_sliding_window_for_sdpa' of https://mirror.ghproxy.…
ehuaa Feb 22, 2024
f464d15
remove adding sliding_window param to qwen2 because of numerical error
ehuaa Feb 22, 2024
f4c21d0
fix style
ehuaa Feb 22, 2024
a9f1571
only add non-contigous mask to qwen2 due to numerical error
ehuaa Feb 22, 2024
8b84d68
revert non-contigous tensor modification
ehuaa Feb 24, 2024
6bee407
move torch version judgement to import_utils for usability
ehuaa Feb 27, 2024
cf225ff
Merge branch 'main' into add_sliding_window_for_sdpa
ehuaa Feb 27, 2024
9c4f0b0
delete deprecated is_flash_attn in import_utils.py
ehuaa Feb 27, 2024
8a1faf2
Add compatibility with skip_memory_metrics for mps device (#29264)
SunMarc Feb 27, 2024
ddf7ac4
Token level timestamps for long-form generation in Whisper (#29148)
zucchini-nlp Feb 27, 2024
227cd54
Fix a few typos in `GenerationMixin`'s docstring (#29277)
sadra-barikbin Feb 27, 2024
83ab011
[i18n-zh] Translate fsdp.md into Chinese (#29305)
windsonsea Feb 27, 2024
63caa37
Starcoder2 model - bis (#29215)
RaymondLi0 Feb 28, 2024
bd5b986
simplify get_class_in_module and fix for paths containing a dot (#29262)
cebtenzzre Feb 28, 2024
ad00c48
FIX [`Gemma` / `CI`] Make sure our runners have access to the model (…
younesbelkada Feb 28, 2024
e715c78
Remove numpy usage from owlvit (#29326)
fxmarty Feb 28, 2024
a528885
[`require_read_token`] fix typo (#29345)
ArthurZucker Feb 28, 2024
7c87f35
[`T5 and Llama Tokenizer`] remove warning (#29346)
ArthurZucker Feb 28, 2024
8a8a0a4
[`Llama ROPE`] Fix torch export but also slow downs in forward (#29198)
ArthurZucker Feb 28, 2024
2ce56d3
Disable Mixtral `output_router_logits` during inference (#29249)
LeonardoEmili Feb 28, 2024
7628b3a
Idefics: generate fix (#29320)
gante Feb 28, 2024
d3a4b47
RoPE loses precision for Llama / Gemma + Gemma logits.float() (#29285)
danielhanchen Feb 28, 2024
554e7ad
check if position_ids exists before using it (#29306)
jiqing-feng Feb 28, 2024
f54d82c
[CI] Quantization workflow (#29046)
SunMarc Feb 28, 2024
49204c1
Better SDPA unmasking implementation (#29318)
fxmarty Feb 28, 2024
2209b7a
[i18n-zh] Sync source/zh/index.md (#29331)
windsonsea Feb 28, 2024
1aee9af
FIX [`CI` / `starcoder2`] Change starcoder2 path to correct one for s…
younesbelkada Feb 29, 2024
8d8ac9c
FIX [`CI`]: Fix failing tests for peft integration (#29330)
younesbelkada Feb 29, 2024
b647acd
FIX [`CI`] `require_read_token` in the llama FA2 test (#29361)
younesbelkada Feb 29, 2024
44fe1a1
Avoid using uncessary `get_values(MODEL_MAPPING)` (#29362)
ydshieh Feb 29, 2024
bb4f816
Patch YOLOS and others (#29353)
NielsRogge Feb 29, 2024
0ad770c
Fix @require_read_token in tests (#29367)
Wauplin Feb 29, 2024
5ee0868
Expose `offload_buffers` parameter of `accelerate` to `PreTrainedMode…
notsyncing Mar 1, 2024
2858d6c
Fix Base Model Name of LlamaForQuestionAnswering (#29258)
lenglaender Mar 1, 2024
50db7ca
FIX [`quantization` / `ESM`] Fix ESM 8bit / 4bit with bitsandbytes (#…
younesbelkada Mar 1, 2024
e7b9837
[`Llama + AWQ`] fix `prepare_inputs_for_generation` 🫠 (#29381)
ArthurZucker Mar 1, 2024
0a0a279
🚨🚨[Whisper Tok] Update integration test (#29368)
sanchit-gandhi Mar 1, 2024
f1b1379
[`YOLOS`] Fix - return padded annotations (#29300)
amyeroberts Mar 1, 2024
15f8296
Support subfolder with `AutoProcessor` (#29169)
JingyaHuang Mar 1, 2024
cec7733
Fix llama + gemma accelete tests (#29380)
SunMarc Mar 1, 2024
1a7c117
Fix deprecated arg issue (#29372)
muellerzr Mar 1, 2024
831bc25
Correct zero division error in inverse sqrt scheduler (#28982)
DavidAfonsoValente Mar 1, 2024
773d8c8
add sliding window param to sdpa after torch==2.2.0
ehuaa Feb 22, 2024
4fab890
revert add sliding window for qwen2 because of numerical error
ehuaa Feb 22, 2024
074d47a
add sliding window param to sdpa after torch==2.2.0
ehuaa Feb 22, 2024
6972cdf
remove adding sliding_window param to qwen2 because of numerical error
ehuaa Feb 22, 2024
8611c2d
fix style
ehuaa Feb 22, 2024
510f24f
only add non-contigous mask to qwen2 due to numerical error
ehuaa Feb 22, 2024
cbfc413
revert non-contigous tensor modification
ehuaa Feb 24, 2024
6d590b0
move torch version judgement to import_utils for usability
ehuaa Feb 27, 2024
da327f7
upload a test for compare flash vs sdpa for sliding window in Mistral
ehuaa Mar 2, 2024
8171137
Merge branch 'add_sliding_window_for_sdpa' of https://mirror.ghproxy.…
ehuaa Mar 2, 2024
2 changes: 2 additions & 0 deletions src/transformers/models/mistral/modeling_mistral.py
@@ -39,6 +39,7 @@
     add_start_docstrings_to_model_forward,
     is_flash_attn_2_available,
     is_flash_attn_greater_or_equal_2_10,
+    is_torch_version_greater_or_equal_than_2_2_0,
     logging,
     replace_return_docstrings,
 )
@@ -1006,6 +1007,7 @@ def forward(
                 (batch_size, seq_length),
                 inputs_embeds,
                 past_key_values_length,
+                sliding_window=self.config.sliding_window if is_torch_version_greater_or_equal_than_2_2_0 else None,
             )
         else:
             # 4d mask is passed through the layers
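For context, the `sliding_window` argument restricts each position to attending only to the most recent `sliding_window` tokens, and that constraint is carried by the boolean 4D mask handed to `torch.nn.functional.scaled_dot_product_attention`. Below is a minimal, self-contained sketch of that mask pattern, not code from this PR: the shapes, the window of 4, and the eager reference are illustrative, and the parity check only mirrors the spirit of the flash-vs-SDPA test added later in the branch.

import torch
import torch.nn.functional as F

def sliding_window_causal_mask(seq_len: int, window: int) -> torch.Tensor:
    # Boolean mask, True = "may attend": position i sees positions j
    # with i - window < j <= i (causal, limited to the last `window` tokens).
    i = torch.arange(seq_len).unsqueeze(-1)
    j = torch.arange(seq_len).unsqueeze(0)
    return (j <= i) & (j > i - window)

def eager_attention(q, k, v, mask):
    # Plain softmax(QK^T / sqrt(d)) V reference using the same boolean mask.
    scores = q @ k.transpose(-2, -1) / q.shape[-1] ** 0.5
    scores = scores.masked_fill(~mask, float("-inf"))
    return torch.softmax(scores, dim=-1) @ v

torch.manual_seed(0)
q, k, v = (torch.randn(1, 8, 16, 64) for _ in range(3))  # (batch, heads, seq, head_dim)
mask = sliding_window_causal_mask(seq_len=16, window=4)   # broadcasts over batch and heads
sdpa_out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
eager_out = eager_attention(q, k, v, mask)
assert torch.allclose(sdpa_out, eager_out, atol=1e-5)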
3 changes: 3 additions & 0 deletions src/transformers/models/mixtral/modeling_mixtral.py
@@ -47,6 +47,7 @@
     add_start_docstrings_to_model_forward,
     is_flash_attn_2_available,
     is_flash_attn_greater_or_equal_2_10,
+    is_torch_version_greater_or_equal_than_2_2_0,
     logging,
     replace_return_docstrings,
 )
@@ -60,6 +61,7 @@
 
     _flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters)
 
+
 # This makes `_prepare_4d_causal_attention_mask` a leaf function in the FX graph.
 # It means that the function will not be traced through and simply appear as a node in the graph.
 if is_torch_fx_available():
@@ -1190,6 +1192,7 @@ def forward(
                 (batch_size, seq_length),
                 inputs_embeds,
                 past_key_values_length,
+                sliding_window=self.config.sliding_window if is_torch_version_greater_or_equal_than_2_2_0 else None,
             )
         else:
             # 4d mask is passed through the layers
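The `_flash_supports_window_size` line kept in the hunk above shows the feature-detection pattern the flash-attention path relies on: probe the installed kernel's signature at import time rather than pin a version. A generic sketch of the same idea follows; the `attention` stand-in and its `window_size` parameter are placeholders, not the real flash-attn API.

import inspect

def accepts_kwarg(fn, name: str) -> bool:
    # True if the callable exposes a parameter with this name.
    return name in inspect.signature(fn).parameters

def attention(q, k, v, window_size=None):
    # Stand-in for a third-party kernel whose signature may or may not
    # include the optional `window_size` argument.
    return q

kwargs = {"window_size": (4, 0)} if accepts_kwarg(attention, "window_size") else {}
out = attention("q", "k", "v", **kwargs)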
1 change: 1 addition & 0 deletions src/transformers/utils/__init__.py
@@ -188,6 +188,7 @@
     is_torch_tensorrt_fx_available,
     is_torch_tf32_available,
     is_torch_tpu_available,
+    is_torch_version_greater_or_equal_than_2_2_0,
     is_torch_xpu_available,
     is_torchaudio_available,
     is_torchdistx_available,
4 changes: 4 additions & 0 deletions src/transformers/utils/import_utils.py
@@ -665,6 +665,10 @@ def is_flash_attn_greater_or_equal_2_10():
     return version.parse(importlib.metadata.version("flash_attn")) >= version.parse("2.1.0")
 
 
+def is_torch_version_greater_or_equal_than_2_2_0():
+    return version.parse(get_torch_version()) >= version.parse("2.2.0")
+
+
 def is_torchdistx_available():
     return _torchdistx_available
 
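As a usage note, the new helper returns a bool, so callers gate optional behaviour on the result of calling it. A standalone equivalent of the check is sketched below, assuming `torch` and `packaging` are installed; the 4096 value is only an example window size (e.g. the MistralConfig default), not anything this PR sets.

import importlib.metadata
from packaging import version

def is_torch_version_greater_or_equal_than_2_2_0() -> bool:
    # Same comparison as the helper above, without transformers' get_torch_version().
    return version.parse(importlib.metadata.version("torch")) >= version.parse("2.2.0")

# Only pass a sliding window to the SDPA mask preparation on torch >= 2.2.0.
sliding_window = 4096  # example value, e.g. MistralConfig's default window size
effective_window = sliding_window if is_torch_version_greater_or_equal_than_2_2_0() else None
print(effective_window)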