Commit ed5dd29

[ESM] support attention API (#40370)
* ESM supports attention API
* supports flags
* fix tests
* fix copies
* another fixup needed after fixing tests
* fix tests and make sure Evolla copied everything
* fix
* order
* forgot about "is_causal" for fa2
* cross attention can't be causal
1 parent 8b80431 commit ed5dd29
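For context: the attention interface this commit wires up is selected from user code through the `attn_implementation` argument at load time. A minimal, hedged sketch (the ESM-2 checkpoint name is an illustrative example, not something taken from this diff):

# Illustrative only: choosing an attention backend for an ESM model at load time.
import torch
from transformers import AutoModelForMaskedLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D")
model = AutoModelForMaskedLM.from_pretrained(
    "facebook/esm2_t6_8M_UR50D",
    attn_implementation="sdpa",  # "eager" and "flash_attention_2" are the other common choices
)

inputs = tokenizer("MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ", return_tensors="pt")
with torch.no_grad():
    logits = model(**inputs).logits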

File tree

8 files changed: +302 additions, -613 deletions


src/transformers/models/esm/modeling_esm.py

Lines changed: 149 additions & 314 deletions
Large diffs are not rendered by default.
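The rewritten attention in `modeling_esm.py` is not rendered above. As a rough sketch of the idiom this family of refactors follows elsewhere in transformers (simplified, not the literal new `EsmSelfAttention` code), the attention module computes q/k/v and then hands them to a backend-selected callable:

# Simplified sketch of the backend-dispatch idiom (not the literal EsmSelfAttention code).
import torch
import torch.nn.functional as F


def eager_attention(q, k, v, attn_mask=None, scaling=None, dropout=0.0):
    scaling = q.shape[-1] ** -0.5 if scaling is None else scaling
    scores = torch.matmul(q, k.transpose(-1, -2)) * scaling
    if attn_mask is not None:
        scores = scores + attn_mask
    weights = F.dropout(F.softmax(scores, dim=-1), p=dropout)
    return torch.matmul(weights, v), weights


def sdpa_attention(q, k, v, attn_mask=None, scaling=None, dropout=0.0):
    # `scale=` requires PyTorch >= 2.1; SDPA does not expose attention weights.
    out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=dropout, scale=scaling)
    return out, None


def attend(q, k, v, attn_implementation="sdpa", **kwargs):
    # The real code looks the backend up from config._attn_implementation.
    interface = sdpa_attention if attn_implementation == "sdpa" else eager_attention
    return interface(q, k, v, **kwargs)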

src/transformers/models/esm/modeling_esmfold.py

Lines changed: 4 additions & 0 deletions
@@ -1991,6 +1991,10 @@ def distogram(coords, min_bin, max_bin, num_bins):
 class EsmForProteinFolding(EsmPreTrainedModel):
     _no_split_modules = ["EsmFoldStructureModule", "EsmFoldTriangularSelfAttentionBlock"]
     _supports_flash_attn = False
+    _supports_sdpa = False
+    _supports_attention_backend = False
+
+    _can_record_outputs = None

     def __init__(self, config):
         super().__init__(config)
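These flags opt ESMFold out of the new backends (its folding trunk is not written against the shared attention interface), and `_can_record_outputs = None` disables the output-recording hooks. A hedged sketch of the expected behavior when a non-eager backend is requested anyway; the exact error type and message are assumptions, not taken from this diff:

# Hedged sketch: with the support flags set to False, requesting a non-eager backend
# for ESMFold is expected to be rejected at load time.
from transformers import EsmForProteinFolding

try:
    EsmForProteinFolding.from_pretrained("facebook/esmfold_v1", attn_implementation="sdpa")
except ValueError as err:
    print(f"ESMFold stays on eager attention: {err}")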

src/transformers/models/evolla/configuration_evolla.py

Lines changed: 0 additions & 2 deletions
@@ -76,7 +76,6 @@ def __init__(
         initializer_range=0.02,
         layer_norm_eps=1e-05,
         position_embedding_type="rotary",
-        use_cache=True,
         emb_layer_norm_before=False,
         token_dropout=True,
         **kwargs,
@@ -94,7 +93,6 @@ def __init__(
         self.initializer_range = initializer_range
         self.layer_norm_eps = layer_norm_eps
         self.position_embedding_type = position_embedding_type
-        self.use_cache = use_cache
         self.emb_layer_norm_before = emb_layer_norm_before
         self.token_dropout = token_dropout
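SaProt is an encoder-only protein language model, so its config no longer advertises `use_cache`. Together with the `check_model_inputs` change further down, `getattr(config, "use_cache", None)` now resolves to `None` and the decorator leaves the kwarg alone. A tiny illustrative sketch of that interaction (the class name is made up for illustration):

# Illustrative only: why dropping `use_cache` from the config matters downstream.
class EncoderOnlyConfig:
    """Stands in for the SaProt config after this commit: no `use_cache` attribute."""
    layer_norm_eps = 1e-05


config = EncoderOnlyConfig()
# check_model_inputs (see the utils/generic.py hunk below) now resolves this to None
# and skips injecting a `use_cache` kwarg into forward().
assert getattr(config, "use_cache", None) is None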

src/transformers/models/evolla/modeling_evolla.py

Lines changed: 118 additions & 279 deletions
Large diffs are not rendered by default.

src/transformers/models/evolla/modular_evolla.py

Lines changed: 19 additions & 6 deletions
@@ -37,7 +37,7 @@
     logging,
 )
 from ...utils.deprecation import deprecate_kwarg
-from ...utils.generic import check_model_inputs
+from ...utils.generic import OutputRecorder, check_model_inputs
 from ..esm.modeling_esm import (
     EsmAttention,
     EsmEmbeddings,
@@ -122,13 +122,13 @@ def forward(self, q: torch.Tensor, k: torch.Tensor) -> tuple[torch.Tensor, torch
         self._cos_cached, self._sin_cached = self._update_cos_sin_tables(k, seq_dimension=-2)

         return (
-            apply_rotary_pos_emb_esm(q, self._cos_cached, self._sin_cached),
-            apply_rotary_pos_emb_esm(k, self._cos_cached, self._sin_cached),
+            apply_rotary_pos_emb_esm(q, self._cos_cached, self._sin_cached).to(dtype=q.dtype),
+            apply_rotary_pos_emb_esm(k, self._cos_cached, self._sin_cached).to(dtype=k.dtype),
         )


 class EvollaSaProtSelfAttention(EsmSelfAttention, nn.Module):
-    def __init__(self, config, position_embedding_type=None, layer_idx=None):
+    def __init__(self, config, position_embedding_type=None, layer_idx=None, is_cross_attention=False):
         nn.Module.__init__(self)
         self.config = config
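Two things happen in the hunk above: the rotated queries and keys are cast back to their input dtype (the cached cos/sin tables may live in a higher-precision dtype, and fp16/bf16-only kernels such as FlashAttention would otherwise reject the promoted result), and `__init__` gains an `is_cross_attention` flag that feeds `is_causal` below. A small sketch of the dtype promotion being guarded against (the fp32 cache is an assumption):

# Sketch of the dtype promotion the `.to(dtype=...)` cast guards against.
import torch

q = torch.randn(1, 4, 8, 16, dtype=torch.bfloat16)
cos = torch.randn(1, 1, 8, 16, dtype=torch.float32)  # cached rotary tables, assumed fp32 here

rotated = q * cos              # type promotion: result is float32
rotated = rotated.to(q.dtype)  # cast back so bf16/fp16-only kernels (e.g. FA2) accept it
assert rotated.dtype == torch.bfloat16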

@@ -146,7 +146,7 @@ def __init__(self, config, position_embedding_type=None, layer_idx=None):
         self.key = nn.Linear(config.hidden_size, self.all_head_size)
         self.value = nn.Linear(config.hidden_size, self.all_head_size)

-        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
+        self.dropout = config.attention_probs_dropout_prob
         self.position_embedding_type = position_embedding_type or getattr(
             config, "position_embedding_type", "absolute"
         )
@@ -159,6 +159,8 @@ def __init__(self, config, position_embedding_type=None, layer_idx=None):

         self.is_decoder = config.is_decoder
         self.layer_idx = layer_idx
+        self.scaling = 1.0
+        self.is_causal = self.is_decoder and not is_cross_attention


 class EvollaSaProtSelfOutput(EsmSelfOutput):
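`scaling = 1.0` presumably reflects that ESM-style attention already folds the 1/sqrt(head_dim) factor into the query before rotary embeddings are applied, so the shared interface must not rescale again; `is_causal` is derived so that cross-attention is never causal, matching the commit's last bullet point. A simplified sketch (not the literal Evolla code) of how such a flag typically reaches the kernel:

# Simplified sketch of feeding is_causal to SDPA.
import torch
import torch.nn.functional as F

is_decoder, is_cross_attention = True, True
is_causal = is_decoder and not is_cross_attention  # cross-attention can't be causal

q = torch.randn(1, 2, 5, 8)      # decoder queries
k = v = torch.randn(1, 2, 7, 8)  # encoder states: a different sequence length
out = F.scaled_dot_product_attention(q, k, v, is_causal=is_causal)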
@@ -193,6 +195,17 @@ class EvollaSaProtPooler(EsmPooler):
 class EvollaSaProtPreTrainedModel(PreTrainedModel):
     config: SaProtConfig
     _no_split_modules = ["EvollaSaProtLayer"]
+    _supports_flash_attn = True
+    _supports_sdpa = True
+    _supports_attention_backend = True
+
+    _can_record_outputs = {
+        "hidden_states": EvollaSaProtLayer,
+        "attentions": [OutputRecorder(EvollaSaProtSelfAttention, index=1, layer_name="attention")],
+        "cross_attentions": [
+            OutputRecorder(EvollaSaProtSelfAttention, index=1, layer_name="crossattention"),
+        ],
+    }

     def _init_weights(self, module):
         """Initialize the weights"""
@@ -230,7 +243,7 @@ class PreTrainedModel
         for layer, heads in heads_to_prune.items():
             self.encoder.layer[layer].attention.prune_heads(heads)

-    @can_return_tuple
+    @check_model_inputs
     def forward(
         self,
         input_ids: Optional[torch.Tensor],

src/transformers/utils/generic.py

Lines changed: 11 additions & 11 deletions
@@ -974,22 +974,22 @@ def check_model_inputs(func):

     @wraps(func)
     def wrapper(self, *args, **kwargs):
-        use_cache = kwargs.get("use_cache")
-        if use_cache is None:
-            use_cache = getattr(self.config, "use_cache", False)
+        use_cache = (
+            kwargs["use_cache"] if kwargs.get("use_cache") is not None else getattr(self.config, "use_cache", None)
+        )
+        if use_cache is not None:
+            if getattr(self, "gradient_checkpointing", False) and self.training and use_cache:
+                logger.warning_once(
+                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
+                )
+                use_cache = False
+
+            kwargs["use_cache"] = use_cache

         return_dict = kwargs.pop("return_dict", None)
         if return_dict is None:
             return_dict = getattr(self.config, "return_dict", True)

-        if getattr(self, "gradient_checkpointing", False) and self.training and use_cache:
-            logger.warning_once(
-                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
-            )
-            use_cache = False
-
-        kwargs["use_cache"] = use_cache
-
         all_args = kwargs.copy()
         if "kwargs" in all_args:
             for k, v in all_args["kwargs"].items():
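Net effect of this rewrite: `use_cache` is only resolved and re-injected when the model's config actually defines it, and the gradient-checkpointing downgrade now lives inside that branch, so encoder-only models whose configs dropped `use_cache` (such as the SaProt config above) no longer receive a spurious `use_cache=False` kwarg. A simplified restatement, not the library code itself:

# Simplified restatement of the new wrapper logic (illustrative, not transformers code).
def resolve_use_cache(kwargs, config, gradient_checkpointing=False, training=False):
    use_cache = (
        kwargs["use_cache"] if kwargs.get("use_cache") is not None else getattr(config, "use_cache", None)
    )
    if use_cache is not None:
        if gradient_checkpointing and training and use_cache:
            use_cache = False  # caching is incompatible with gradient checkpointing
        kwargs["use_cache"] = use_cache
    return kwargs


class EncoderOnlyConfig:  # no use_cache attribute, like the SaProt config after this commit
    pass


class DecoderConfig:
    use_cache = True


assert "use_cache" not in resolve_use_cache({}, EncoderOnlyConfig())
assert resolve_use_cache({}, DecoderConfig(), gradient_checkpointing=True, training=True) == {"use_cache": False}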

tests/models/esm/test_modeling_esm.py

Lines changed: 1 addition & 0 deletions
@@ -238,6 +238,7 @@ def test_model_various_embeddings(self):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         for type in ["absolute", "relative_key", "relative_key_query"]:
             config_and_inputs[0].position_embedding_type = type
+            config_and_inputs[0]._attn_implementation = "eager"
             self.model_tester.create_and_check_model(*config_and_inputs)

     def test_for_masked_lm(self):
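The test pins the eager path because the additive relative-key position scores are not expressible through the SDPA/FlashAttention kernels. Outside the test suite, the same pinning would look roughly like this (all config values are illustrative):

# Illustrative: building an ESM model with relative-key embeddings on the eager path.
from transformers import EsmConfig, EsmModel

config = EsmConfig(
    vocab_size=33,
    hidden_size=32,
    num_hidden_layers=2,
    num_attention_heads=4,
    intermediate_size=64,
    position_embedding_type="relative_key",
    pad_token_id=1,
    mask_token_id=32,
)
config._attn_implementation = "eager"  # mirrors the test change above
model = EsmModel(config)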

tests/test_modeling_common.py

Lines changed: 0 additions & 1 deletion
@@ -4391,7 +4391,6 @@ def test_flash_attention_2_continue_generate_with_position_ids(self):
         next_token_logits_from_generate = generation_out.logits[-1]

         # acceptable numerical instability
-        # print(next_token_logits_from_generate, next_token_logits)
         tol = torch.finfo(torch.bfloat16).eps
         torch.testing.assert_close(next_token_logits_from_generate, next_token_logits, rtol=tol, atol=tol)
