[Bugfix] Fix encoder-only model support for transformers backend #28021
Conversation
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Code Review
This pull request addresses a bug that prevented encoder-only models from working with the transformers backend. The issue, as indicated by the traceback, was an assertion error during the KV cache initialization, where the attention type was expected to be DECODER. The fix involves conditionally selecting the EncoderOnlyAttention class for encoder layers, which correctly signals that no KV cache is needed. The change is precise, correct, and effectively resolves the reported bug.
    if attn_type == AttentionType.ENCODER_ONLY
    else Attention
)
attention_instances[i] = attn_cls(
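Filling in the surrounding lines, a minimal self-contained sketch of the selection pattern the review describes; the import path for EncoderOnlyAttention and the constructor arguments are assumptions and may differ between vLLM versions:

```python
from vllm.attention import Attention, AttentionType
# Assumed import path; EncoderOnlyAttention may live elsewhere in other versions.
from vllm.attention.layers.encoder_only_attention import EncoderOnlyAttention


def make_attention_layer(attn_type: str, num_heads: int, head_size: int, scale: float):
    # Encoder-only layers get EncoderOnlyAttention, which signals that no KV
    # cache is needed; all other layers keep the regular Attention class.
    # Extra constructor arguments (cache_config, quant_config, prefix, ...)
    # are omitted here for brevity.
    attn_cls = (
        EncoderOnlyAttention
        if attn_type == AttentionType.ENCODER_ONLY
        else Attention
    )
    return attn_cls(num_heads, head_size, scale)
```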
Why does passing attn_type not work? Are the two not equivalent?
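For context, the alternative this question refers to would look roughly like the sketch below; the attn_type keyword has historically been accepted by Attention, but the exact signature is an assumption:

```python
from vllm.attention import Attention, AttentionType


def make_encoder_attention(num_heads: int, head_size: int, scale: float):
    # Keep a single Attention class and pass the attention type through as a
    # keyword, instead of switching to EncoderOnlyAttention. Other constructor
    # arguments (cache_config, quant_config, prefix, ...) are omitted here.
    return Attention(
        num_heads,
        head_size,
        scale,
        attn_type=AttentionType.ENCODER_ONLY,
    )
```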
  per_layer_sliding_window = self.config.sliding_window
- attention_instances[i] = Attention(
+ attn_cls = (
@heheda12345 do you want to handle it inside the Attention class init to signal deprecation with a warning?
We need to handle it like this for now. For the deprecation warning, could we just add one line of warning in the Attention class? (Not necessary in this PR.)
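A minimal sketch of that suggested one-line follow-up, using a stand-in class rather than vLLM's actual Attention; the parameter default and the warning message are assumptions:

```python
import warnings

from vllm.attention import AttentionType


class Attention:  # stand-in for vllm.attention.Attention, for illustration only
    def __init__(self, num_heads, head_size, scale,
                 attn_type=AttentionType.DECODER, **kwargs):
        if attn_type != AttentionType.DECODER:
            # The proposed one-liner: warn when callers still pass attn_type
            # instead of using a dedicated subclass such as EncoderOnlyAttention.
            warnings.warn(
                "Passing attn_type to Attention is deprecated; use the "
                "corresponding attention subclass instead.",
                DeprecationWarning,
                stacklevel=2,
            )
        self.attn_type = attn_type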
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Purpose
Fix encoder-only model support for the transformers backend, which broke after #26587 moved `get_kv_cache_spec` into `AttentionLayerBase`.

Test Plan
Test Result
Tests should pass with nightly Transformers now.
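As an informal way to exercise the fix end to end (not the PR's stated test plan), one could load an encoder-only embedding model with the Transformers backend; the checkpoint name and arguments below are illustrative assumptions and may need adjusting for your vLLM version:

```python
from vllm import LLM

# Any BERT-style encoder-only checkpoint should do; the "task" and
# "model_impl" argument names may vary slightly across vLLM versions.
llm = LLM(
    model="BAAI/bge-base-en-v1.5",
    task="embed",
    model_impl="transformers",
)
outputs = llm.embed(["encoder-only models now work with the Transformers backend"])
print(len(outputs[0].outputs.embedding))
```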
Essential Elements of an Effective PR Description Checklist
(Optional) Documentation update, including supported_models.md and examples for a new model.