Add SDPA support for T5 Style Models #30375
Conversation
I ran tests for T5 locally (
Is the fix for 3 & 4 making separate encoder and decoder classes? I would like to get some suggestions on this before I go ahead and fix other T5-like models.
Thanks @abdulfatir. Did you get to compare the speed-up provided by SDPA? If not, no worries. If it helps, you can maybe adapt this script:

```python
from transformers import AutoTokenizer, CLIPTextModel
import argparse
import torch
import time


def bytes_to_giga_bytes(bytes):
    return bytes / 1024 / 1024 / 1024


def load_model(args):
    model = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32", attn_implementation=args.attn_implementation).to("cuda")
    tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32")
    return model, tokenizer


def get_inputs(args, tokenizer):
    inputs = tokenizer(["a photo of a cat"] * args.batch_size, padding=True, return_tensors="pt")
    inputs = {k: v.to("cuda") for k, v in inputs.items()}
    return inputs


@torch.no_grad()
def main(args):
    model, tokenizer = load_model(args)
    inputs = get_inputs(args, tokenizer)

    # warmup
    for _ in range(5):
        _ = model(**inputs)

    start = time.time()
    for _ in range(args.num_iters):
        _ = model(**inputs)
    end = time.time()
    avg_inference_time = (end - start) / args.num_iters

    print(f"{args.attn_implementation=}, {args.batch_size=}, {args.num_iters=}")
    print(f"avg_inference_time: {avg_inference_time:.3f} seconds")
    print(
        f"Max memory allocated: {bytes_to_giga_bytes(torch.cuda.max_memory_allocated())} GB"
    )


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--attn_implementation", type=str, choices=["sdpa", "eager"], default="eager")
    parser.add_argument("--batch_size", type=int, default=2)
    parser.add_argument("--num_iters", type=int, default=10)
    args = parser.parse_args()
    main(args)
```
@sayakpaul Thanks for sharing! I was testing it but found the speedup and memory savings to be quite modest. Here's the script I used:

```python
import argparse
import time

import torch
from transformers import (
    GenerationConfig,
    T5ForConditionalGeneration,
    T5TokenizerFast,
)


def bytes_to_giga_bytes(bytes):
    return bytes / 1024 / 1024 / 1024


def load_model(args):
    model = T5ForConditionalGeneration.from_pretrained(
        "google/flan-t5-large",
        attn_implementation=args.attn_implementation,
        torch_dtype=getattr(torch, args.torch_dtype),
    ).to("cuda")
    tokenizer = T5TokenizerFast.from_pretrained("google/flan-t5-large")
    return model, tokenizer


def get_inputs(args, tokenizer):
    inputs = tokenizer(
        ["Pedro Pedro Pedro 🦝🙌" * 10] * args.batch_size,
        padding=True,
        return_tensors="pt",
    )
    inputs = {k: v.to("cuda") for k, v in inputs.items()}
    return inputs


@torch.no_grad()
def main(args):
    model, tokenizer = load_model(args)
    inputs = get_inputs(args, tokenizer)
    generation_config = GenerationConfig(max_new_tokens=20)

    # warmup
    for _ in range(5):
        _ = model.generate(**inputs, generation_config=generation_config)

    start = time.time()
    for _ in range(args.num_iters):
        _ = model.generate(**inputs, generation_config=generation_config)
    end = time.time()
    avg_inference_time = (end - start) / args.num_iters

    print(
        f"{args.attn_implementation=}, {args.torch_dtype=}, {args.batch_size=}, {args.num_iters=}"
    )
    print(f"avg_inference_time: {avg_inference_time:.3f} seconds")
    print(
        f"Max memory allocated: {bytes_to_giga_bytes(torch.cuda.max_memory_allocated())} GB"
    )


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--attn_implementation", type=str, choices=["sdpa", "eager"], default="eager"
    )
    parser.add_argument(
        "--torch_dtype",
        type=str,
        choices=["float32", "bfloat16", "float16"],
        default="float32",
    )
    parser.add_argument("--batch_size", type=int, default=2)
    parser.add_argument("--num_iters", type=int, default=10)
    args = parser.parse_args()
    main(args)
```

Upon investigation, I came across something interesting that affects T5 (not sure about other models) and silently makes SDPA use the less efficient kernel. Torch SDPA expects that the query/key/value tensors have stride 1 in the last dimension.
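A minimal sketch of the stride quirk (illustrative shapes, not code from this PR): the contiguity check ignores size-1 dimensions, so a tensor can report `is_contiguous() == True` while `stride(-1) != 1`, which is exactly what the fused SDPA kernels check for; `.clone(memory_format=torch.contiguous_format)` restores canonical strides.

```python
import torch

# Shape chosen only so that the last dimension ends up with size 1 after a transpose.
x = torch.randn(2, 4, 1, 8)              # (batch, heads, seq=1, head_dim)
y = x.transpose(-1, -2)                  # shape (2, 4, 8, 1), stride(-1) == 8

print(y.is_contiguous())                 # True: size-1 dims are ignored by the check
print(y.contiguous().stride(-1))         # still 8, .contiguous() is a no-op here
print(y.clone(memory_format=torch.contiguous_format).stride(-1))  # 1
```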
@ArthurZucker @fxmarty any thoughts on this? How should I proceed?
Thank you @abdullaholuk! For the failing tests you mention, could you try running them on the main branch and check their status there? It may be they are failing there as well (if they are slow tests).
Could you try as well: `RUN_SLOW=True pytest tests/models/t5 -k "test_eager_matches_sdpa_inference" -s -vvvvv`
@fxmarty Thanks! These tests passed (will check on a
I would like to get your thoughts on the tracing test failures. Have you seen this before and do you know why these could be failing? If not, I can take a deeper look to figure out what's wrong.
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread. Please note that issues that do not follow the contributing guidelines are likely to be ignored.
I addressed the comments. Once we fix this and everything else looks okay, I will fix the copies of this model.
@abdulfatir is there still the issue with …?
```python
# sdpa kernels require tensors to have stride=1 in the last dimension
# .contiguous() does not behave correctly for tensors with singleton dimensions
# .clone(memory_format=torch.contiguous_format) is a workaround
```
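One hedged way to surface the silent fallback (assuming a CUDA device and a PyTorch version that still exposes `torch.backends.cuda.sdp_kernel`; newer releases provide `torch.nn.attention.sdpa_kernel` instead): disable the math backend so SDPA errors out instead of quietly taking the slow path.

```python
import torch
import torch.nn.functional as F

# Illustrative shapes/dtypes only: fp16 CUDA tensors so the fused backends are eligible.
q = torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.float16)
k = torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.float16)
v = torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.float16)

# With the math kernel disabled, SDPA raises a RuntimeError if the flash and
# memory-efficient kernels reject the inputs (e.g. because of the stride
# requirement above) instead of silently falling back to the slower path.
with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_mem_efficient=True, enable_math=False):
    out = F.scaled_dot_product_attention(q, k, v)
```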
You could link to pytorch/pytorch#127523 here
Added.
For …
@ArthurZucker I changed it to use
But I believe this should also be an issue for other sdpa calls, right? Why is only this test breaking?
Tried a higher …
@abdulfatir you can remove this test. Transformers ONNX export has been deprecated for >1 year now. Or alternatively, use … The test …
Thanks @fxmarty! I skipped the test. Can you please also review the PR (and run the remaining workflows)?
LGTM as long as the slow tests pass & some real models have been validated
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
@fxmarty for slow tests do we need a label? |
@abdulfatir I don't think we can run them from GitHub Actions. You just need to run …
Thanks for all the work adding this!
cc @ylacombe regarding pop2piano
All of the models which have this added should have information added to their modeling page on how to use it and the expected speedups, e.g. like here for Mistral.
Are SDPA tests run for all of these models now?
The slow integration tests will need to be run for all the models with SDPA added. They can be triggered by pushing a commit with the message: `[run_slow] pop2piano,mt5,t5`. Another HF member or I will need to approve the workflow for it to run.
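For the model-page documentation, the usage snippet would presumably follow the same pattern as the other SDPA-enabled models; the checkpoint and dtype below are only illustrative:

```python
import torch
from transformers import T5ForConditionalGeneration

# Explicitly select the attention backend at load time.
model = T5ForConditionalGeneration.from_pretrained(
    "google/flan-t5-large",
    torch_dtype=torch.float16,
    attn_implementation="sdpa",
).to("cuda")
```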
```diff
@@ -610,7 +609,7 @@ def test_model_from_pretrained(self):
         model = Pop2PianoForConditionalGeneration.from_pretrained(model_name)
         self.assertIsNotNone(model)

     @require_onnx
     @unittest.skip("ONNX support deprecated")
```
This doesn't seem to be related to this PR?
We shouldn't be skipping tests because they're deprecated - deprecation happens with main code, but if a test is deprecated then we should just remove it
```diff
@@ -848,7 +848,7 @@ def test_export_to_onnx(self):
                 (config_and_inputs[1], config_and_inputs[3], config_and_inputs[2]),
                 f"{tmpdirname}/t5_test.onnx",
                 export_params=True,
-                opset_version=9,
+                opset_version=14,
```
What's the reason for this?
```diff
@@ -851,7 +851,7 @@ def test_export_to_onnx(self):
                 (config_and_inputs[1], config_and_inputs[3], config_and_inputs[2]),
                 f"{tmpdirname}/t5_test.onnx",
                 export_params=True,
-                opset_version=9,
+                opset_version=14,
```
And here
```python
query_states = self._shape(
    self.q(hidden_states), batch_size
)  # (batch_size, n_heads, seq_length, dim_per_head)
```
Library convention is for comments to go on the line above to avoid line splitting
```diff
-query_states = self._shape(
-    self.q(hidden_states), batch_size
-)  # (batch_size, n_heads, seq_length, dim_per_head)
+# (batch_size, n_heads, seq_length, dim_per_head)
+query_states = self._shape(self.q(hidden_states), batch_size)
```
```python
attn_output = self._unshape(
    torch.matmul(attn_weights, value_states), batch_size
)  # (batch_size, seq_length, dim)
```
```diff
-attn_output = self._unshape(
-    torch.matmul(attn_weights, value_states), batch_size
-)  # (batch_size, seq_length, dim)
+# (batch_size, seq_length, dim)
+attn_output = self._unshape(torch.matmul(attn_weights, value_states), batch_size)
```
```python
query_states = self._shape(
    self.q(hidden_states), batch_size
)  # (batch_size, n_heads, seq_length, dim_per_head)
```
```diff
-query_states = self._shape(
-    self.q(hidden_states), batch_size
-)  # (batch_size, n_heads, seq_length, dim_per_head)
+# (batch_size, n_heads, seq_length, dim_per_head)
+query_states = self._shape(self.q(hidden_states), batch_size)
```
```python
attn_output = self._unshape(
    torch.matmul(attn_weights, value_states), batch_size
)  # (batch_size, seq_length, dim)
```
```diff
-attn_output = self._unshape(
-    torch.matmul(attn_weights, value_states), batch_size
-)  # (batch_size, seq_length, dim)
+# (batch_size, seq_length, dim)
+attn_output = self._unshape(torch.matmul(attn_weights, value_states), batch_size)
```
```python
query_states = self._shape(
    self.q(hidden_states), batch_size
)  # (batch_size, n_heads, seq_length, dim_per_head)
```
```diff
-query_states = self._shape(
-    self.q(hidden_states), batch_size
-)  # (batch_size, n_heads, seq_length, dim_per_head)
+# (batch_size, n_heads, seq_length, dim_per_head)
+query_states = self._shape(self.q(hidden_states), batch_size)
```
```python
attn_output = self._unshape(
    torch.matmul(attn_weights, value_states), batch_size
)  # (batch_size, seq_length, dim)
```
```diff
-attn_output = self._unshape(
-    torch.matmul(attn_weights, value_states), batch_size
-)  # (batch_size, seq_length, dim)
+# (batch_size, seq_length, dim)
+attn_output = self._unshape(torch.matmul(attn_weights, value_states), batch_size)
```
```diff
@@ -1007,6 +1019,7 @@ def unshape(states):

+# Copied from transformers.models.t5.modeling_t5.T5LayerSelfAttention with T5->LongT5
 class LongT5LayerSelfAttention(nn.Module):
```
What's the reason for not adding support for LongT5?
Hey @amyeroberts, thanks for pinging, I'll take a closer look at pop2piano once your comments have been addressed if that's okay with you!
A small comment from my side @abdulfatir: have you been able to test speed-ups again now that you found the reasons for the apparent lack of speed-ups? Also, to properly measure inference speed-ups, you should modify the script you've used a bit, as in the version below.
I haven't tested the following script but it should work with little to no modification!

```python
import argparse
import time

import torch
from transformers import (
    GenerationConfig,
    T5ForConditionalGeneration,
    T5TokenizerFast,
)
from transformers import set_seed


def bytes_to_giga_bytes(bytes):
    return bytes / 1024 / 1024 / 1024


def load_model(args):
    model = T5ForConditionalGeneration.from_pretrained(
        "google/flan-t5-large",
        attn_implementation=args.attn_implementation,
        torch_dtype=getattr(torch, args.torch_dtype),
    ).to("cuda")
    tokenizer = T5TokenizerFast.from_pretrained("google/flan-t5-large")
    return model, tokenizer


def get_inputs(args, tokenizer):
    inputs = tokenizer(
        ["Pedro Pedro Pedro 🦝🙌" * 10] * args.batch_size,
        padding=True,
        return_tensors="pt",
    )
    inputs = {k: v.to("cuda") for k, v in inputs.items()}
    return inputs


def measure_latency_and_memory_use(model, device, inputs, generation_config, nb_loops):
    # define Events that measure start and end of the generate pass
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)

    # reset cuda memory stats and empty cache
    torch.cuda.reset_peak_memory_stats(device)
    torch.cuda.empty_cache()
    torch.cuda.synchronize()

    # get the start time
    start_event.record()

    # actually generate
    for _ in range(nb_loops):
        # set seed for reproducibility
        set_seed(0)
        generation = model.generate(**inputs, generation_config=generation_config)

    # get the end time
    end_event.record()
    torch.cuda.synchronize()

    # measure memory footprint and elapsed time
    max_memory = torch.cuda.max_memory_allocated(device)
    elapsed_time = start_event.elapsed_time(end_event) * 1.0e-3

    execution_time_in_s = elapsed_time / nb_loops
    max_memory_footprint_in_GB = bytes_to_giga_bytes(max_memory)

    return execution_time_in_s, max_memory_footprint_in_GB


@torch.no_grad()
def main(args):
    model, tokenizer = load_model(args)
    inputs = get_inputs(args, tokenizer)
    generation_config = GenerationConfig(min_new_tokens=args.num_tokens, max_new_tokens=args.num_tokens)

    # warmup
    for _ in range(5):
        _ = model.generate(**inputs, generation_config=generation_config)

    avg_inference_time, max_memory_footprint_in_GB = measure_latency_and_memory_use(model, model.device, inputs, generation_config, args.num_iters)

    print(
        f"{args.attn_implementation=}, {args.torch_dtype=}, {args.batch_size=}, {args.num_iters=}, {args.num_tokens=}"
    )
    print(f"avg_inference_time: {avg_inference_time:.3f} seconds")
    print(
        f"Max memory allocated: {max_memory_footprint_in_GB} GB"
    )


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--attn_implementation", type=str, choices=["sdpa", "eager"], default="eager"
    )
    parser.add_argument(
        "--torch_dtype",
        type=str,
        choices=["float32", "bfloat16", "float16"],
        default="float32",
    )
    parser.add_argument("--batch_size", type=int, default=2)
    parser.add_argument("--num_iters", type=int, default=10)
    parser.add_argument("--num_tokens", type=int, default=20)
    args = parser.parse_args()
    main(args)
```
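For context on the measurement changes: CUDA kernels launch asynchronously, so timing `generate` with `time.time()` alone can measure launch overhead rather than actual device time. Recording CUDA events and calling `torch.cuda.synchronize()` before reading the elapsed time avoids that, and setting `min_new_tokens` equal to `max_new_tokens` keeps the amount of generated work identical across the eager and SDPA runs.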
Thanks @amyeroberts and @ylacombe! I will probably look into this later over the weekend.
@abdulfatir any update on when this PR will be merged?
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread. Please note that issues that do not follow the contributing guidelines are likely to be ignored.
Hi 👋🏽 any updates on this?
I think #34089 enabled it!
Hi @ArthurZucker, I've tested compiling the T5 model and I do get a very modest speedup:

```python
# %%
import time

import numpy as np
import torch
from tqdm.auto import tqdm
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# %% Example prompt
tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-base")
text_input = "The theory of relativity usually encompasses two interrelated physics theories by Albert Einstein: special relativity and general relativity, proposed and published in 1905 and 1915, respectively. Special relativity applies to all physical phenomena in the absence of gravity. General relativity explains the law of gravitation and its relation to the forces of nature. It applies to the cosmological and astrophysical realm, "
input_ids = tokenizer(text_input, return_tensors="pt").input_ids.to(device)

# %% Without compilation
model = AutoModelForSeq2SeqLM.from_pretrained("google-t5/t5-base").to(device)

# First pass to initialise any caches/JITs/whatever
_ = model.generate(input_ids, max_length=500)

times = []
for _ in tqdm(range(20)):
    start = time.time()
    _ = model.generate(input_ids, max_length=500)
    times.append(time.time() - start)
print(f"Times without compilation: {np.mean(times):.4f} ± {np.std(times):.4f}")

# %% With compilation
model = AutoModelForSeq2SeqLM.from_pretrained("google-t5/t5-base").to(device)
compiled_model = torch.compile(model)

# First pass to initialise any caches/JITs/whatever
_ = compiled_model.generate(input_ids, max_length=500)

times = []
for _ in tqdm(range(20)):
    start = time.time()
    _ = compiled_model.generate(input_ids, max_length=500)
    times.append(time.time() - start)
print(f"Times with compilation: {np.mean(times):.4f} ± {np.std(times):.4f}")
```

However, I've tried using:

```python
model = AutoModelForSeq2SeqLM.from_pretrained(
    "google-t5/t5-base", device_map="cuda:1", attn_implementation="SDPA"
)
model = AutoModelForSeq2SeqLM.from_pretrained(
    "google-t5/t5-base", device_map="cuda:1", attn_implementation="flash_attention_2"
)
```

but it complains that neither SDPA nor FlashAttention are implemented for the T5 model. Any ideas? 😄
Ah sorry, it's not implemented indeed!
Thanks, will keep an eye open!
Hi @ArthurZucker, wondering what the state of affairs is for a T5 with SDPA/flash attention and working compile?
Now that the PR is merged, if you want you can open one for T5 or I'll do it in a bit!
What does this PR do?
Adds torch's `scaled_dot_product_attention` support for the T5 model. Part of #28005.
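Roughly, the op being wired in; this is a simplified sketch, not the actual modeling_t5 code, and it assumes the relative position bias (plus any attention mask) is folded into SDPA's additive `attn_mask`, with `scale=1.0` because T5 does not scale attention scores (the `scale` argument needs torch>=2.1):

```python
import torch
import torch.nn.functional as F

batch, heads, q_len, kv_len, head_dim = 2, 8, 16, 16, 64
query = torch.randn(batch, heads, q_len, head_dim)
key = torch.randn(batch, heads, kv_len, head_dim)
value = torch.randn(batch, heads, kv_len, head_dim)

# T5 adds a learned relative-position bias (and the padding/causal mask) to the
# attention scores; as a float tensor it can be passed to SDPA as an additive mask.
position_bias = torch.randn(1, heads, q_len, kv_len)

attn_output = F.scaled_dot_product_attention(
    query,
    key,
    value,
    attn_mask=position_bias,
    dropout_p=0.0,
    scale=1.0,  # T5 omits the 1/sqrt(head_dim) scaling
)
```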
Before submitting
- Did you read the contributor guideline, Pull Request section?
- Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case.
- Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.
@ArthurZucker @fxmarty @sayakpaul