[WIP] T5 compile-compatibility #33754
Conversation
Thanks for updating the code, it's working perfectly with my codespace. Appreciate the effort @zucchini-nlp
Hey! 🤗 Thanks for your contribution! Before merging this pull request, the slow tests CI should be triggered. To enable this:
(For maintainers) The documentation for slow tests CI on PRs is here.
src/transformers/generation/utils.py (outdated)

@@ -496,7 +496,7 @@ def _prepare_encoder_decoder_kwargs_for_generation(
            add_hook_to_module(encoder, AlignDevicesHook(io_same_device=True))

        # 2. Prepare encoder args and encoder kwargs from model kwargs and generation config.
-       irrelevant_prefix = ["decoder_", "cross_attn", "use_cache"]
+       irrelevant_prefix = ["decoder_", "cross_attn", "use_cache", "past_key_values"]
we don't want to populate the cache when running the encoder, otherwise the tests will fail with a shape mismatch in the decoder part
Do you know why Whisper didn't require this change, but T5 does?
Found it! Because Whisper has separate classes for the Encoder and Decoder, while T5 has only T5Stack, which inits both encoder and decoder.
I can change the code a bit so that T5Stack sets the cache to None when the class is used as an encoder; otherwise we pass the cache further down the line to attention, and since the cache object is modified in-place, it will cause a shape mismatch later in the decoder part. I think that makes more sense, since the encoder should pass no cache.
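For illustration, a minimal sketch of the fix described above (variable names are illustrative; the actual T5Stack.forward takes many more arguments):

    # Sketch only: inside T5Stack.forward, drop any incoming cache when the
    # stack is acting as the encoder, so the shared cache object is never
    # mutated in-place before the decoder runs.
    if not self.is_decoder:
        past_key_values = None
        use_cache = False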
src/transformers/generation/utils.py (outdated)

@@ -1477,7 +1477,7 @@ def get_layer_device_map(execution_device_map: Optional[dict] = None):

        cache_kwargs = {
            "config": self.config.text_config if hasattr(self.config, "text_config") else self.config,
-           "max_batch_size": batch_size,
+           "batch_size": batch_size,
nit: let's not show warnings that users can't control :)
"hidden_size": "d_model", | ||
"num_attention_heads": "num_heads", | ||
"num_hidden_layers": "num_layers", | ||
"head_dim": "d_kv", |
this is used later in StaticCache to infer the correct cache shape. In T5 models, d_kv is used to calculate the projection dim in the attention modules
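For context, a rough sketch of how such an attribute map can be used to read T5's differently named config attributes (a hypothetical helper, not the actual cache_utils implementation):

    # Hypothetical helper illustrating the mapping; the cache code does something
    # along these lines when inferring the cache shape from the model config.
    ATTRIBUTE_MAP = {
        "hidden_size": "d_model",
        "num_attention_heads": "num_heads",
        "num_hidden_layers": "num_layers",
        "head_dim": "d_kv",
    }

    def resolve_config_attr(config, name):
        # Fall back to the T5-specific name when the generic one is missing,
        # e.g. head_dim -> config.d_kv (the per-head projection dim in T5Attention).
        if hasattr(config, name):
            return getattr(config, name)
        return getattr(config, ATTRIBUTE_MAP[name])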
        if torch.jit.is_tracing():
            seq_length = seq_length.to(hidden_states.device)
Fixes ONNX tests
            causal_mask = mask[:, :, :, : key_states.shape[-2]]
            position_bias = position_bias + causal_mask
Strangely, without slicing the mask, a control flow error is raised when fx-tracing. But I think the slicing shouldn't be needed in general, as the attn mask should have the same shape as the position bias.
-       if attention_mask is None:
+       if attention_mask is None and not is_torchdynamo_compiling():
compile won't work because we are trying to build a tensor on the fly, relying on a value that is not static (past_kv_length), so it will complain about dynamic control flow
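A minimal sketch of the pattern being discussed (simplified; names follow the snippets in this thread):

    # Without the guard, the mask length depends on a value read from the cache,
    # which dynamo reports as dynamic control flow when compiling:
    if attention_mask is None and not is_torchdynamo_compiling():
        mask_seq_length = (
            past_key_values.get_seq_length() + seq_length if past_key_values is not None else seq_length
        )
        attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device)
    # Under compile, callers are expected to pass an explicit attention_mask instead.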
might work with torch 2.4
            causal_mask = self._update_causal_mask(
                attention_mask,
                inputs_embeds,
                cache_position,
                past_key_values.self_attention_cache if past_key_values is not None else None,
                output_attentions,
            )
        # We use local attention in encoder self-attention, otherwise standard self & cross attentions are used
        elif self.config.encoder_attention_type == "local":
-           extended_attention_mask = _get_local_attention_mask(attention_mask, self.block_len, inputs_embeds.device)
+           causal_mask = _get_local_attention_mask(attention_mask, self.block_len, inputs_embeds.device)
        else:  # we need to use both local attention mask and standard extended mask for transient-global attention
-           extended_attention_mask = attention_mask
+           causal_mask = attention_mask
This part is quite different from decoder-only models and from Whisper, because T5 models have one "Module" that can be either encoder or decoder. The masks have to be prepared as causal for the decoder (same as in Llama) and using the previous preparation steps for the encoder (for BC I didn't change that part). The None is simply to account for cases when no attn mask is passed, so we don't throw errors.
tests/models/t5/test_modeling_t5.py (outdated)

@@ -1455,6 +1487,7 @@ def test_summarization(self):
            [model.config.prefix + x for x in [FRANCE_ARTICLE, SHORTER_ARTICLE, IRAN_ARTICLE, ARTICLE_SUBWAY]],
            padding="max_length",
            truncation=True,
+           max_length=512,
Nit: otherwise we can't form a tensor. Apparently it stopped truncating without us passing max_length.
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Taking good shape! Just reviewed LongT5!
-           query_length = present_key_value_state[0].shape[2]
-       else:
-           query_length = None
+       query_length = cache_position[0]
Surprised that this does not create input-dependent control flow!
+1 :o
We can use past_key_values.get_seq_length() as a workaround if needed
It does ahaha, but the catch is that for static cache we need the max_length (aka the length of keys after the cache is added). So I added a check on cache type in T5Attention to use max_cache_shape() if we're in a static cache setting.
We could also simply rely on key.shape[-2], which is cleaner. I don't think we had a special reason for getting key_length like that instead of simply using the shape, lmk if there was a reason for that.
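A sketch of the two options being discussed (simplified; names follow the snippet below):

    # Option 1: branch on the cache type, since a static cache is pre-allocated
    # to its maximum length.
    if isinstance(past_key_value, StaticCache):
        key_length = past_key_value.get_max_length()
    else:
        key_length = key_states.shape[-2]

    # Option 2 (the simpler alternative mentioned above): after the cache update,
    # key_states already holds past + current keys, so its static shape is enough
    # and stays compile-friendly because no tensor *values* are inspected.
    key_length = key_states.shape[-2]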
        if isinstance(past_key_value, StaticCache):
            seq_length = past_key_value.get_max_length()
        elif past_key_value is not None:
            seq_length += cache_position[0] if query_length is None else query_length
seq_length is just cache_position[-1] in general, no?
Oh yeah, right, it can simply be -1. Answered above in more detail.
            past_key_value = past_key_value.self_attention_cache

        if isinstance(past_key_value, StaticCache):
            seq_length = past_key_value.get_max_length()
Bit confused by this, as this is the max_seq_length in a way.
we should not have checks that depend on the cache type in the AttentionLayer itself IMO and should just use get_seq_length or cache positions
Answered above in more detail.
            )
            if self.gradient_checkpointing and self.training:
                position_bias.requires_grad = True
        else:
-           position_bias = self.compute_bias(real_seq_length, key_length, device=scores.device)
+           position_bias = self.compute_bias(seq_length, key_length, device=scores.device)
could be just cache_position[0], no?
No, that creates dynamic control flow for compiling, so we'd better stick to tensor shapes and not values, given that the resulting length is the same.
yep good direction!
A few questions/comments, mostly around readability and differences vs Whisper 🤗 (note: I have reviewed the t5-related files and generation/utils.py, given that the others are copies/adaptations)
        if torch.jit.is_tracing():
            seq_length = seq_length.to(hidden_states.device)
Two questions:
- If seq_length comes from hidden_states in the lines above, why does it need to be moved to its device? 🤔
- Instead of torch.jit.is_tracing(), could we use is_torchdynamo_compiling()? Or does this fix a torch.jit-specific issue? If so, let's add it in the comment.
It is fx_trace specific: the integer values (the shape in this case) are somehow always traced as tensors. And apparently they are moved to CPU if the value was an int without tracing, so I did something ugly like this.
But we won't need this anymore now, if we agree that key_length = key.shape[-2] is an okay way to get shapes and I'm not missing anything that can break.
        if isinstance(past_key_value, StaticCache):
            seq_length = past_key_value.get_max_length()
        elif past_key_value is not None:
            seq_length += cache_position[0] if query_length is None else query_length
- A few lines above we have batch_size, seq_length = hidden_states.shape[:2] and seq_length = seq_length.to(hidden_states.device), and here we might have seq_length = past_key_value.get_max_length() (which doesn't need the lines above). Can we consolidate these two pieces of logic to get seq_length?
- Readability: I had to go back to the whole code to figure out the origin of this block, which means our code is not very readable atm. Let's work on that :D Suggestions (a rough sketch follows below):
  - we want to use query_length if we are in a cross-attention layer and we have cache. Let's add some if dependent on is_cross_attention
  - cache_position[0] is the length of the cache, but it might not be immediately obvious. We can do something like cache_length = cache_position[0], which should help understanding what's going on :)
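For illustration, a rough sketch of what the suggested renames could look like (hypothetical, not part of the PR diff; names like is_cross_attention and cache_length are illustrative):

    # Hypothetical refactor of the block above, following the readability suggestions:
    # name the cache length explicitly and branch on is_cross_attention.
    is_cross_attention = key_value_states is not None
    cache_length = cache_position[0]  # tokens already stored in the self-attention cache

    if past_key_value is not None:
        # cross-attention reuses the full encoder length (query_length);
        # self-attention extends the current query by what is already cached.
        seq_length += query_length if is_cross_attention else cache_length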
        if attention_mask is None and not is_torchdynamo_compiling():
            # required mask seq length can be calculated via length of past
            mask_seq_length = (
                past_key_values.get_seq_length() + seq_length if past_key_values is not None else seq_length
            )
            attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device)

        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
        # ourselves in which case we just need to make it broadcastable to all heads.
        extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape)
        if self.config.is_decoder:
            causal_mask = self._update_causal_mask(
                attention_mask,
                inputs_embeds,
                cache_position,
                past_key_values.self_attention_cache if past_key_values is not None else None,
                output_attentions,
            )
        elif attention_mask is not None:
            causal_mask = attention_mask[:, None, None, :]
            causal_mask = causal_mask.to(dtype=inputs_embeds.dtype)
            causal_mask = (1.0 - causal_mask) * torch.finfo(inputs_embeds.dtype).min
        else:
            causal_mask = None
In Whisper, _update_causal_mask handles all this logic (attention_mask = None, decoder-only, encoder-decoder, ...). Why can't we do it in T5? (I don't recall whether T5 has special masking requirements)
Same as above: Whisper has a separate class for the encoder and another for the decoder. In the encoder class we never use the causal_mask method, as the mask should not be causal. T5, in contrast, has one place from which we decide whether it's an encoder or a decoder.
T5 is the bad old example!
src/transformers/cache_utils.py (outdated)

-       if len(self.self_attention_cache.key_cache) > 1 and self.self_attention_cache.key_cache[layer_idx] == []:
-           return 0
-       return (self.self_attention_cache.key_cache[layer_idx][0, 0].any(dim=-1)).sum()
+       return self.self_attention_cache.get_seq_length(layer_idx)
Got an error in one of the compiles with get_seq_length(); apparently we can't check whether tensor == [] when compiling. I think we can safely pass it to the self-attn cache.
you should be able to check the length! ([] lives on CPU, which is probably the source of the issue) Having tensors for everything might be better btw
Yep, we should not go for that check in the compiled model. That one is only for dynamic cache BC with empty lists.
@@ -404,11 +419,14 @@ def _relative_position_bucket(relative_position, bidirectional=True, num_buckets
        relative_buckets += torch.where(is_small, relative_position, relative_position_if_large)
        return relative_buckets

-   def compute_bias(self, query_length, key_length, device=None):
+   def compute_bias(self, query_length, key_length, device=None, cache_position=None):
This is the best I could come up with to overcome dynamic control flow. Otherwise real_seq_length is always dependent on the cache length, which is currently computed from the tensor values being non-zero.
Basically this is similar to if we used position ids, but T5 models don't pass position ids. Cache position is exactly what we need then.
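As a reference point, a simplified sketch of compute_bias taking cache_position (the relative-position bucketing and bias lookup are omitted):

    # Simplified sketch: the query (context) positions come directly from
    # cache_position, so no length has to be recovered from cache tensor values.
    def compute_bias(self, query_length, key_length, device=None, cache_position=None):
        if cache_position is None:
            context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None]
        else:
            context_position = cache_position[:, None].to(device)
        memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :]
        relative_position = memory_position - context_position  # shape (query_length, key_length)
        # ...relative-position bucketing and bias embedding lookup omitted...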
sounds good!
Done, updated the code for readability and removed some unused things. Compile is working on T5 and other T5-based models, and the compile tests are green. The CI will also run the slow tests for T5.
LGTM, but as T5 is the first compile-compatible model, let's make sure we have integration tests in T5 directly for:
- cuda-graph decoding
- cuda-graph encoder style
with expected generation!
Otherwise, great work!
        if cache_position is None:
            context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None]
Just a note: this will not produce the correct generation when you have cudagraphs AFAIK! But it's good for the other cases (BC and eager modes)
Hmm, why not? Cache position should be the same, as the previous context length doesn't account for pad tokens either.
        if position_bias is None:
            key_length = key_states.shape[-2]
            # cache position is 0-indexed so we add 1 to get the real length of queries (aka with past)
            real_seq_length = query_length if query_length is not None else cache_position[-1] + 1
super surprised that compile is not complaining about the cache_position[-1]. I usually get:
torch._dynamo.exc.UserError: Dynamic control flow is not supported at the moment. Please use functorch.experimental.control_flow.cond to explicitly capture the control flow. For more information about this error, see: https://pytorch.org/docs/main/generated/exportdb/index.html#cond-operands
from user code:
File "/raid/arthur/transformers/src/transformers/models/mllama/modeling_mllama.py", line 2188, in forward
outputs = self.language_model(
File "/raid/arthur/py312/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
return forward_call(*args, **kwargs)
File "/raid/arthur/transformers/src/transformers/models/mllama/modeling_mllama.py", line 1929, in forward
outputs = self.model(
File "/raid/arthur/py312/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
return forward_call(*args, **kwargs)
File "/raid/arthur/transformers/src/transformers/models/mllama/modeling_mllama.py", line 1729, in forward
layer_outputs = decoder_layer(
File "/raid/arthur/py312/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
return forward_call(*args, **kwargs)
File "/raid/arthur/transformers/src/transformers/models/mllama/modeling_mllama.py", line 1006, in forward
hidden_states, attn_weights, past_key_value = self.cross_attn(
File "/raid/arthur/py312/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
return forward_call(*args, **kwargs)
File "/raid/arthur/transformers/src/transformers/models/mllama/modeling_mllama.py", line 624, in forward
elif cache_position[0] != 0:
Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information
Oh, it is because we never use real_seq_length anymore. I am passing cache_position to compute_bias so that compile doesn't complain, but left everything else as it was for BC.
tests/test_modeling_common.py (outdated)

@@ -4769,7 +4770,10 @@ def test_torch_compile(self):
        n_iter = 3

        tokenizer = AutoTokenizer.from_pretrained(ckpt)
        model = AutoModelForCausalLM.from_pretrained(ckpt, torch_dtype=torch.float16).to(torch_device)
+       if self.is_encoder_decoder:
for t5 are we only testing one of them in that case? (probably decoder?)
We test the whole model as encoder-decoder; I added a test_ckpt as t5-small, since the forward call includes a call to the encoder and to the decoder.
Slow tests for T5 added, for encoder-decoder and for encoder-only models.
What does this PR do?
Fixes #33221 and fixes #33283
T5 now supports the new cache format and is compile-compatible. The current state generates the same text with the dynamic cache as with the old tuple cache.
TODO:
I ran all tests with static cache and compile for the touched T5-based models, but no tests were enabled in UDOP/Pop2Piano etc. Those models are modified to support the cache class only (for now), and that is because of fix-copies. As those are not super commonly used, I think we can enable compile in later stages. Also enabled slow tests for T5 as part of CI; the slow tests are passing for me locally.
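For reference, a minimal sketch of the kind of usage this PR is meant to enable (a sketch only; exact generate() flags may differ, and t5-small is the checkpoint used in the tests above):

    import torch
    from transformers import AutoTokenizer, T5ForConditionalGeneration

    tokenizer = AutoTokenizer.from_pretrained("t5-small")
    model = T5ForConditionalGeneration.from_pretrained("t5-small").to("cuda")

    # Compile the forward pass; generate() reuses the compiled graph at each decoding step.
    model.forward = torch.compile(model.forward, mode="reduce-overhead")

    inputs = tokenizer("translate English to German: The house is wonderful.", return_tensors="pt").to("cuda")
    # cache_implementation="static" requests the StaticCache path exercised by this PR.
    out = model.generate(**inputs, max_new_tokens=20, cache_implementation="static")
    print(tokenizer.decode(out[0], skip_special_tokens=True))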