
[WIP] T5 compile-compatibility #33754

Closed

Conversation

@zucchini-nlp (Member) commented on Sep 27, 2024

What does this PR do?

Fixes #33221 and fixes #33283

T5 now supports the new cache format and is compile-compatible. In its current state, generation with the dynamic cache produces the same text as with the old tuple cache.

TODO:

  • Add tests and run existing tests

I ran all tests with static cache and compile for the touched T5-based models, but no such tests were enabled in UDOP/Pop2Piano etc. Those models are modified to support the cache class only (for now), because of fix-copies. Since they are not commonly used, I think we can enable compile for them in a later stage. I also enabled slow tests for T5 as part of CI; the slow tests pass for me locally.

@vignesh1507 left a comment:

Thanks for updating the code; it's working perfectly in my codespace. Appreciate the effort, @zucchini-nlp.

@HuggingFaceDocBuilderDev

Hey! 🤗 Thanks for your contribution to the transformers library!

Before merging this pull request, slow tests CI should be triggered. To enable this:

  • Add the run-slow label to the PR
  • When your PR is ready for merge and all reviewers' comments have been addressed, push an empty commit with the command [run-slow] followed by a comma-separated list of all the models to be tested, e.g. [run_slow] model_to_test_1, model_to_test_2
    • If the pull request affects a lot of models, put at most 10 models in the commit message
  • A transformers maintainer will then approve the workflow to start the tests

(For maintainers) The documentation for slow tests CI on PRs is here.

@@ -496,7 +496,7 @@ def _prepare_encoder_decoder_kwargs_for_generation(
add_hook_to_module(encoder, AlignDevicesHook(io_same_device=True))

# 2. Prepare encoder args and encoder kwargs from model kwargs and generation config.
-irrelevant_prefix = ["decoder_", "cross_attn", "use_cache"]
+irrelevant_prefix = ["decoder_", "cross_attn", "use_cache", "past_key_values"]
Member Author:

We don't want to populate the cache when running the encoder; otherwise the tests fail with a shape mismatch in the decoder part.

Member:

Do you know why Whisper didn't require this change, but T5 does?

Member Author:

Found it! Whisper has separate classes for the encoder and the decoder, while T5 has only T5Stack, which initializes both the encoder and the decoder.

I can change the code a bit so that T5Stack sets the cache to None when the class is used as an encoder; otherwise we pass the cache further down the line to the attention layers, and since the cache object is modified in place, it ends up with a shape mismatch later in the decoder part. I think that makes more sense, since the encoder should pass no cache.
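Roughly something like this (the exact placement and flag names are just a sketch, not the final implementation):

    # Hypothetical sketch for T5Stack.forward: only the decoder should carry a cache.
    # When the stack runs as the encoder, drop any incoming cache so it is never
    # mutated in place before the decoder uses it.
    if not self.is_decoder:
        past_key_values = None
        use_cache = False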

@@ -1477,7 +1477,7 @@ def get_layer_device_map(execution_device_map: Optional[dict] = None):

cache_kwargs = {
"config": self.config if hasattr(self.config, "text_config") else self.config,
"max_batch_size": batch_size,
"batch_size": batch_size,
Member Author:

nit: let's not show warnings that users can't control :)

"hidden_size": "d_model",
"num_attention_heads": "num_heads",
"num_hidden_layers": "num_layers",
"head_dim": "d_kv",
Member Author:

This is used later in StaticCache to infer the correct cache shape. In T5 models, d_kv is used to compute the projection dimension in the attention modules.
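Roughly, the static cache allocates its per-layer key/value buffers from these config attributes; a simplified sketch of the shape computation, assuming the mapping above resolves head_dim to d_kv (the exact internals of StaticCache may differ):

    # head_dim comes from the config if present (here it maps to T5's d_kv, which is
    # not hidden_size // num_heads in general), otherwise the usual ratio is used.
    head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
    # One pre-allocated buffer per layer, sized to the maximum generation length.
    cache_shape = (batch_size, config.num_attention_heads, max_cache_len, head_dim)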

Comment on lines 514 to 515
if torch.jit.is_tracing():
seq_length = seq_length.to(hidden_states.device)
Member Author:

Fixes ONNX tests

Comment on lines 580 to 581
causal_mask = mask[:, :, :, : key_states.shape[-2]]
position_bias = position_bias + causal_mask
Member Author:

Strangely, without slicing the mask a control-flow error is raised when fx-tracing. I think the slicing shouldn't be needed in general, as the attention mask should have the same shape as the position bias.


-if attention_mask is None:
+if attention_mask is None and not is_torchdynamo_compiling():
Member Author:

Compile won't work here because we are building a tensor on the fly, relying on a value that is not static (the past key/value length), so it complains about dynamic control flow.
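A minimal sketch of the pattern that trips compile (this is the fallback-mask construction guarded above; variable names as in the modeling code):

    # The mask size depends on the current cache length, a runtime value, so building
    # it inside the compiled graph is flagged as data-dependent.
    mask_seq_length = past_key_values.get_seq_length() + seq_length
    attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device)
    # Skipping this fallback under compile means the caller is expected to pass an
    # explicit attention mask instead.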

Collaborator:

might work with torch 2.4

Comment on lines 1532 to 1543
causal_mask = self._update_causal_mask(
attention_mask,
inputs_embeds,
cache_position,
past_key_values.self_attention_cache if past_key_values is not None else None,
output_attentions,
)
# We use local attention in encoder self-attention, otherwise standard self & cross attentions are used
elif self.config.encoder_attention_type == "local":
extended_attention_mask = _get_local_attention_mask(attention_mask, self.block_len, inputs_embeds.device)
causal_mask = _get_local_attention_mask(attention_mask, self.block_len, inputs_embeds.device)
else: # we need to use both local attention mask and standard extended mask for transient-global attention
extended_attention_mask = attention_mask
causal_mask = attention_mask
Member Author:

This part is quite different from decoder-only models and from Whisper, because T5 models have one module that can act as either the encoder or the decoder. The masks have to be prepared as causal for the decoder (same as in Llama), while the encoder keeps the previous preparation steps (for BC, I didn't change that part). The None simply accounts for cases where no attention mask is passed, so we don't throw errors.

@@ -1455,6 +1487,7 @@ def test_summarization(self):
[model.config.prefix + x for x in [FRANCE_ARTICLE, SHORTER_ARTICLE, IRAN_ARTICLE, ARTICLE_SUBWAY]],
padding="max_length",
truncation=True,
max_length=512,
Member Author:

Nit: otherwise we can't form a tensor. Apparently it stopped truncating when we don't pass max_length.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@ArthurZucker (Collaborator) left a comment:

Taking good shape! Just reviewed LongT5!

-query_length = present_key_value_state[0].shape[2]
-else:
-query_length = None
+query_length = cache_position[0]
Collaborator:

Surprised that this does not create input-dependent control flow!

Member:

+1 :o

We can use the past_key_values.get_seq_length() as a workaround if needed

Member Author:

It does, ahaha, but the catch is that for the static cache we need the max length (i.e. the length of the keys after the cache is added). So I added a check on the cache type in T5Attention to use max_cache_shape() if we're in the static cache setting.

We could also simply rely on key.shape[-2], which is cleaner. I don't think there was a particular reason to compute key_length that way instead of simply taking the shape; lmk if there was.
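For reference, a rough sketch of the check being discussed, using the get_max_length() call that appears in the diffs below (whether key.shape[-2] alone would suffice is exactly the open question):

    # The attention needs the full key length (past + current tokens). With a static
    # cache that is the pre-allocated max length; with a dynamic cache the key tensor
    # already holds the concatenated length, so its shape is enough.
    if isinstance(past_key_value, StaticCache):
        key_length = past_key_value.get_max_length()
    else:
        key_length = key_states.shape[-2]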

if isinstance(past_key_value, StaticCache):
seq_length = past_key_value.get_max_length()
elif past_key_value is not None:
seq_length += cache_position[0] if query_length is None else query_length
Collaborator:

seq_length is just cache_position[-1] in general, no?

Member Author:

Oh yeah, right, it can simply be cache_position[-1]. Answered above in more detail.

past_key_value = past_key_value.self_attention_cache

if isinstance(past_key_value, StaticCache):
seq_length = past_key_value.get_max_length()
Collaborator:

A bit confused by this, as this is the max_seq_length in a way.

Collaborator:

We should not have checks that depend on the cache type in the attention layer itself, IMO; we should just use get_seq_length or cache positions.

Member Author:

Answered above in more detail.

)
if self.gradient_checkpointing and self.training:
position_bias.requires_grad = True
else:
-position_bias = self.compute_bias(real_seq_length, key_length, device=scores.device)
+position_bias = self.compute_bias(seq_length, key_length, device=scores.device)
Collaborator:

Could this just be cache_position[0], no?

Member Author (zucchini-nlp), Oct 8, 2024:

No, that creates dynamic control flow when compiling, so we'd better stick to tensor shapes rather than values, given that the resulting length is the same.
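A tiny illustration of the distinction (the data-dependent branch is the same pattern as in the Mllama traceback quoted further down):

    # Data-dependent: branching on a tensor's value cannot be traced by dynamo.
    if cache_position[0] != 0:
        ...
    # Shape-dependent: tensor shapes are static metadata for dynamo, so this is fine.
    key_length = key_states.shape[-2]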

Collaborator:

yep good direction!



@gante (Member) left a comment:

A few questions/comments, mostly around readability and differences vs Whisper 🤗

(note: I have reviewed t5-related files and generation/utils.py, given that the others are copies/adaptations)

Comment on lines 483 to 484
if torch.jit.is_tracing():
seq_length = seq_length.to(hidden_states.device)
Member:

Two questions:

  1. If seq_length comes from hidden_states in the lines above, why does it need to be moved to its device? 🤔
  2. Instead of torch.jit.is_tracing(), could we use is_torchdynamo_compiling()? Or does this fix a torch.jit-specific issue? If so, let's note it in the comment.

Member Author (zucchini-nlp), Oct 8, 2024:

It is fx-trace specific: there the integer values (the shape in this case) are always traced as tensors, and apparently they end up on CPU where the value would have been a plain int without tracing, so I did something ugly like this.

But we won't need this anymore if we agree that key_length = key.shape[-2] is an okay way to get the shape and I'm not missing anything that could break.

Comment on lines 507 to 510
if isinstance(past_key_value, StaticCache):
seq_length = past_key_value.get_max_length()
elif past_key_value is not None:
seq_length += cache_position[0] if query_length is None else query_length
Member:

  • A few lines above we have batch_size, seq_length = hidden_states.shape[:2] and seq_length = seq_length.to(hidden_states.device), and here we might have seq_length = past_key_value.get_max_length() (which doesn't need the lines above). Can we consolidate these two pieces of logic to get seq_length?
  • Readability: I had to go back to the whole code to figure out the origin of this block, which means our code is not very readable atm. Let's work on that :D Suggestions:
    • we want to use query_length if we are in a cross-attention layer and we have cache. Let's add some if dependent on is_cross_attention
    • cache_position[0] is the length of the cache, but it might not be immediately obvious. We can do something like cache_length = cache_position[0], which should help clarify what's going on :) (see the sketch after this list)
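An illustrative sketch of the suggested renames applied to the block above (not the final code; the is_cross_attention condition is a reading of when query_length is set):

    is_cross_attention = key_value_states is not None
    cache_length = cache_position[0]  # tokens already stored in the self-attention cache
    if isinstance(past_key_value, StaticCache):
        seq_length = past_key_value.get_max_length()
    elif past_key_value is not None:
        # cross-attention layers use the decoder query length instead of the cache length
        seq_length += query_length if is_cross_attention else cache_length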


Comment on lines 1050 to 1070
if attention_mask is None and not is_torchdynamo_compiling():
# required mask seq length can be calculated via length of past
mask_seq_length = (
past_key_values.get_seq_length() + seq_length if past_key_values is not None else seq_length
)
attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device)

# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
# ourselves in which case we just need to make it broadcastable to all heads.
extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape)
if self.config.is_decoder:
causal_mask = self._update_causal_mask(
attention_mask,
inputs_embeds,
cache_position,
past_key_values.self_attention_cache if past_key_values is not None else None,
output_attentions,
)
elif attention_mask is not None:
causal_mask = attention_mask[:, None, None, :]
causal_mask = causal_mask.to(dtype=inputs_embeds.dtype)
causal_mask = (1.0 - causal_mask) * torch.finfo(inputs_embeds.dtype).min
else:
causal_mask = None
Member:

In Whisper, _update_causal_mask handles all this logic (attention_mask = None, decoder-only, encoder-decoder, ...)

Why can't we do it in T5? (I don't recall whether T5 has special masking requirements)

Member Author:

Same as above: Whisper has one class for the encoder and another for the decoder, and in the encoder class we never use the causal-mask method, as the mask should not be causal. T5, in contrast, has a single place where we decide whether it is acting as an encoder or a decoder.

Collaborator:

T5 is the bad old example!


-if len(self.self_attention_cache.key_cache) > 1 and self.self_attention_cache.key_cache[layer_idx] == []:
-return 0
-return (self.self_attention_cache.key_cache[layer_idx][0, 0].any(dim=-1)).sum()
+return self.self_attention_cache.get_seq_length(layer_idx)
Member Author:

Got an error in one of the compile runs when calling get_seq_length(); apparently we can't check tensor == [] when compiling. I think we can safely delegate it to the self-attention cache.

Collaborator:

You should be able to check the length! ([] lives on CPU, which is probably the source of the issue.) Having tensors for everything might be better, btw.

Member Author:

Yep, we should not use that check on a compiled model. It is only there for dynamic-cache BC with empty lists.

@@ -404,11 +419,14 @@ def _relative_position_bucket(relative_position, bidirectional=True, num_buckets
relative_buckets += torch.where(is_small, relative_position, relative_position_if_large)
return relative_buckets

-def compute_bias(self, query_length, key_length, device=None):
+def compute_bias(self, query_length, key_length, device=None, cache_position=None):
Member Author:

This is the best I could come up with to overcome dynamic control flow. Otherwise real_seq_length always depends on the cache length, which is currently computed from tensor values being non-zero.

Basically this is similar to using position ids, but T5 models don't pass position ids, so cache_position is exactly what we need.
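A rough sketch of how cache_position slots into compute_bias (the else branch is a reading of the change; the rest of the method is unchanged):

    def compute_bias(self, query_length, key_length, device=None, cache_position=None):
        # Query positions: a fresh arange for the no-cache/BC path, otherwise the
        # cache positions handed down from generation, which are compile-friendly.
        if cache_position is None:
            context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None]
        else:
            context_position = cache_position[:, None].to(device)
        memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :]
        relative_position = memory_position - context_position  # (query_length, key_length)
        # bucketing and the relative_attention_bias lookup continue as before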

Collaborator:

sounds good!

@zucchini-nlp (Member Author):

Done. I updated the code for readability and removed some unused things. Compile is working on T5 and other T5-based models, and the compile tests are green. The CI will also run the slow tests for T5.

@ArthurZucker (Collaborator) left a comment:

LGTM, but as T5 is the first compile-compatible model of this kind, let's make sure we have integration tests directly in T5 for:

  • cuda-graph decoding
  • cuda-graph encoder style

with expected generation!

Otherwise, great work!


Comment on lines 448 to 449
if cache_position is None:
context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None]
Collaborator:

Just a note: this will not produce correct generation when you have CUDA graphs, AFAIK! But it's good for the other cases (BC and eager modes).

Member Author:

Hmm, why not? cache_position should be the same; the previous context length doesn't account for pad tokens either.


if position_bias is None:
key_length = key_states.shape[-2]
# cache position is 0-indexed so we add 1 to get the real length of queries (aka with past)
real_seq_length = query_length if query_length is not None else cache_position[-1] + 1
Collaborator:

super surprised that compile is not complaining about the cache_position[-1]. I usually get:

torch._dynamo.exc.UserError: Dynamic control flow is not supported at the moment. Please use functorch.experimental.control_flow.cond to explicitly capture the control flow. For more information about this error, see: https://pytorch.org/docs/main/generated/exportdb/index.html#cond-operands

from user code:
   File "/raid/arthur/transformers/src/transformers/models/mllama/modeling_mllama.py", line 2188, in forward
    outputs = self.language_model(
  File "/raid/arthur/py312/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/raid/arthur/transformers/src/transformers/models/mllama/modeling_mllama.py", line 1929, in forward
    outputs = self.model(
  File "/raid/arthur/py312/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/raid/arthur/transformers/src/transformers/models/mllama/modeling_mllama.py", line 1729, in forward
    layer_outputs = decoder_layer(
  File "/raid/arthur/py312/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/raid/arthur/transformers/src/transformers/models/mllama/modeling_mllama.py", line 1006, in forward
    hidden_states, attn_weights, past_key_value = self.cross_attn(
  File "/raid/arthur/py312/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/raid/arthur/transformers/src/transformers/models/mllama/modeling_mllama.py", line 624, in forward
    elif cache_position[0] != 0:

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information

Member Author:

Oh, it's because we never use real_seq_length anymore. I'm passing cache_position to compute_bias so that compile doesn't complain, but left everything else as it was for BC.

@@ -4769,7 +4770,10 @@ def test_torch_compile(self):
n_iter = 3

tokenizer = AutoTokenizer.from_pretrained(ckpt)
model = AutoModelForCausalLM.from_pretrained(ckpt, torch_dtype=torch.float16).to(torch_device)
if self.is_encoder_decoder:
Collaborator:

For T5, are we only testing one of them in that case (probably the decoder)?

Member Author:

We test the whole model as an encoder-decoder; I added a test checkpoint, t5-small. The forward call includes calls to both the encoder and the decoder.
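A plausible sketch of the branch in test_torch_compile (only the if self.is_encoder_decoder line is visible in the hunk above; the seq2seq branch is an assumption):

    if self.is_encoder_decoder:
        model = AutoModelForSeq2SeqLM.from_pretrained(ckpt, torch_dtype=torch.float16).to(torch_device)
    else:
        model = AutoModelForCausalLM.from_pretrained(ckpt, torch_dtype=torch.float16).to(torch_device)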

@zucchini-nlp (Member Author):

Slow tests for T5 added, for both encoder-decoder and encoder-only models.
