
Add SDPA support for T5 Style Models #30375

Closed · wants to merge 14 commits

Conversation

@abdulfatir commented Apr 21, 2024

What does this PR do?

Adds torch's scaled_dot_product_attention support for the T5 model.

Part of #28005

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guidelines, Pull Request section?
  • Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@ArthurZucker @fxmarty @sayakpaul

@abdulfatir (Author) commented Apr 21, 2024

I ran tests for T5 locally (pytest tests/models/t5/test_modeling_t5.py). Most of the tests pass except the following:

FAILED tests/models/t5/test_modeling_t5.py::T5ModelTest::test_disk_offload_bin - AssertionError: ValueError not raised
FAILED tests/models/t5/test_modeling_t5.py::T5ModelTest::test_prompt_lookup_decoding_matches_greedy_search - RuntimeError: The size of tensor a (2) must match the size of tensor b (3) at non-singleton dimension 3
FAILED tests/models/t5/test_modeling_t5.py::T5ModelTest::test_torch_fx - AssertionError: Couldn't trace module: __bool__ should return bool, returned Tensor
FAILED tests/models/t5/test_modeling_t5.py::T5ModelTest::test_torch_fx_output_loss - AssertionError: Couldn't trace module: __bool__ should return bool, returned Tensor
  1. Not sure what's going on, but this is not failing on GitHub CI.
  2. Need some help figuring out what's going on. The test fails when prompt_lookup_num_tokens is set in test_prompt_lookup_decoding_matches_greedy_search.
  3. Is this due to if self.is_decoder?
  4. Is this due to if self.is_decoder?

Is the fix for 3 & 4 making separate encoder and decoder classes? I would like to get some suggestions on this before I go ahead and fix other T5-like models.

Another question: if sdpa is supported by a model, will it automatically be covered by the current tests, or should I also write unit tests to ensure that the sdpa and eager versions match? It looks like sdpa is compared with eager in a slow test, which is currently being skipped. I ran RUN_SLOW=True pytest -rs tests/models/t5/test_modeling_t5.py::T5ModelTest::test_eager_matches_sdpa_generate and the test passed.
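
For what it's worth, the kind of quick equivalence check I have in mind looks like this (an untested sketch, assuming this PR's sdpa implementation is in place; the checkpoint, prompt, and tolerance are just illustrative):

import torch
from transformers import T5ForConditionalGeneration, T5TokenizerFast

name = "google/flan-t5-small"  # illustrative checkpoint
tokenizer = T5TokenizerFast.from_pretrained(name)
inputs = tokenizer(["translate English to German: How old are you?"], return_tensors="pt")
decoder_input_ids = torch.tensor([[0]])  # T5 starts decoding from the pad token id

with torch.no_grad():
    eager_model = T5ForConditionalGeneration.from_pretrained(name, attn_implementation="eager")
    sdpa_model = T5ForConditionalGeneration.from_pretrained(name, attn_implementation="sdpa")
    eager_logits = eager_model(**inputs, decoder_input_ids=decoder_input_ids).logits
    sdpa_logits = sdpa_model(**inputs, decoder_input_ids=decoder_input_ids).logits

# the two implementations should agree up to numerical noise in float32
print(torch.max(torch.abs(eager_logits - sdpa_logits)))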

@abdulfatir changed the title from "Initial commit" to "Add SDPA support for T5" on Apr 21, 2024
@sayakpaul (Member) commented Apr 24, 2024

Thanks @abdulfatir. Did you get to compare the speed-up provided by SDPA? If not, no worries. If it helps, you can maybe adapt this script:

from transformers import AutoTokenizer, CLIPTextModel
import argparse 
import torch
import time

def bytes_to_giga_bytes(bytes):
    return bytes / 1024 / 1024 / 1024

def load_model(args):
    model = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32", attn_implementation=args.attn_implementation).to("cuda")
    tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32")
    return model, tokenizer

def get_inputs(args, tokenizer):
    inputs = tokenizer(["a photo of a cat"] * args.batch_size, padding=True, return_tensors="pt")
    inputs = {k: v.to("cuda") for k, v in inputs.items()}
    return inputs

@torch.no_grad()
def main(args):
    model, tokenizer = load_model(args)
    inputs = get_inputs(args, tokenizer)
    # warmup
    for _ in range(5):
        _  = model(**inputs)
    
    start = time.time()
    for _ in range(args.num_iters):
        _  = model(**inputs)
    end = time.time()
    avg_inference_time = (end - start) / args.num_iters

    print(f"{args.attn_implementation=}, {args.batch_size=}, {args.num_iters=}")
    print(f"avg_inference_time: {avg_inference_time:.3f} seconds")
    print(
        f"Max memory allocated: {bytes_to_giga_bytes(torch.cuda.max_memory_allocated())} GB"
    )

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--attn_implementation", type=str, choices=["sdpa", "eager"], default="eager")
    parser.add_argument("--batch_size", type=int, default=2)
    parser.add_argument("--num_iters", type=int, default=10)
    args = parser.parse_args()
    main(args)
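
(If saved locally as, say, clip_sdpa_benchmark.py — the filename is just illustrative — it can be run as python clip_sdpa_benchmark.py --attn_implementation sdpa --batch_size 8 and again with --attn_implementation eager for the baseline.)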

@abdulfatir (Author) commented Apr 24, 2024

@sayakpaul Thanks for sharing! I was testing it but found the speedup and memory savings to be quite modest. Here's the script I used:

import argparse
import time

import torch

from transformers import (
    GenerationConfig,
    T5ForConditionalGeneration,
    T5TokenizerFast,
)


def bytes_to_giga_bytes(bytes):
    return bytes / 1024 / 1024 / 1024


def load_model(args):
    model = T5ForConditionalGeneration.from_pretrained(
        "google/flan-t5-large",
        attn_implementation=args.attn_implementation,
        torch_dtype=getattr(torch, args.torch_dtype),
    ).to("cuda")
    tokenizer = T5TokenizerFast.from_pretrained("google/flan-t5-large")
    return model, tokenizer


def get_inputs(args, tokenizer):
    inputs = tokenizer(
        ["Pedro Pedro Pedro 🦝🙌" * 10] * args.batch_size,
        padding=True,
        return_tensors="pt",
    )
    inputs = {k: v.to("cuda") for k, v in inputs.items()}
    return inputs


@torch.no_grad()
def main(args):
    model, tokenizer = load_model(args)
    inputs = get_inputs(args, tokenizer)
    generation_config = GenerationConfig(max_new_tokens=20)
    # warmup
    for _ in range(5):
        _ = model.generate(**inputs, generation_config=generation_config)

    start = time.time()
    for _ in range(args.num_iters):
        _ = model.generate(**inputs, generation_config=generation_config)
    end = time.time()
    avg_inference_time = (end - start) / args.num_iters

    print(
        f"{args.attn_implementation=}, {args.torch_dtype=}, {args.batch_size=}, {args.num_iters=}"
    )
    print(f"avg_inference_time: {avg_inference_time:.3f} seconds")
    print(
        f"Max memory allocated: {bytes_to_giga_bytes(torch.cuda.max_memory_allocated())} GB"
    )


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--attn_implementation", type=str, choices=["sdpa", "eager"], default="eager"
    )
    parser.add_argument(
        "--torch_dtype",
        type=str,
        choices=["float32", "bfloat16", "float16"],
        default="float32",
    )
    parser.add_argument("--batch_size", type=int, default=2)
    parser.add_argument("--num_iters", type=int, default=10)
    args = parser.parse_args()
    main(args)
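
(As above, assuming a hypothetical filename t5_sdpa_benchmark.py: python t5_sdpa_benchmark.py --attn_implementation sdpa --torch_dtype bfloat16 --batch_size 8 --num_iters 50, then the same command with --attn_implementation eager for comparison.)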

Upon investigation, I came across something interesting that affects T5 (not sure about other models) and silently makes SDPA use the less efficient kernel. Torch SDPA expects the attn_mask to have stride=1 in the last dimension. However, due to the permute operation in the computation of the position bias in T5, the last-dim stride gets changed and torch uses the worse kernel whenever that's the case. I thought a quick solution would be to just use .contiguous() for the SDPA version. However, .contiguous() doesn't work correctly for tensors with singleton dimensions (pytorch/pytorch#116333): such tensors are already reported as contiguous, so the call is a no-op and the stride stays wrong. Someone fixed the check recently for the flash attention kernel but not for the memory-efficient one.
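
To make the stride issue concrete, here is a minimal standalone sketch (the shapes are arbitrary; the singleton case mimics the first decoding step, and .clone(memory_format=torch.contiguous_format) is the workaround discussed below):

import torch

n_heads, q_len, k_len = 8, 1, 1  # e.g. the first decoding step, where both lengths are 1
values = torch.randn(q_len, k_len, n_heads)             # output of the relative_attention_bias embedding
position_bias = values.permute([2, 0, 1]).unsqueeze(0)  # (1, n_heads, q_len, k_len), as in T5's compute_bias

print(position_bias.stride(-1))       # n_heads, not 1 -> SDPA falls back to the slower path
print(position_bias.is_contiguous())  # True, because the offending dims have size 1
print(position_bias.contiguous().stride(-1))                                  # still n_heads: .contiguous() is a no-op here
print(position_bias.clone(memory_format=torch.contiguous_format).stride(-1))  # 1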

@abdulfatir (Author)

@ArthurZucker @fxmarty any thoughts on this? How should I proceed?

@fxmarty (Contributor) left a comment

Thank you @abdulfatir! For the failing tests you mention, could you try running them on the main branch and check their status there? It may be that they are failing there as well (if they are slow tests).

Could you try as well: RUN_SLOW=True pytest tests/models/t5 -k "test_eager_matches_sdpa_inference" -s -vvvvv

src/transformers/models/t5/modeling_t5.py — review thread (outdated, resolved)
@abdulfatir (Author)

Thank you @abdulfatir! For the failing tests you mention, could you try running them on the main branch and check their status there? It may be that they are failing there as well (if they are slow tests).

Could you try as well: RUN_SLOW=True pytest tests/models/t5 -k "test_eager_matches_sdpa_inference" -s -vvvvv

@fxmarty Thanks! These tests passed (will check on a bf16 supported device later):

tests/models/t5/test_modeling_t5.py::T5ModelTest::test_eager_matches_sdpa_inference_0_float16 <- tests/test_modeling_common.py PASSED
tests/models/t5/test_modeling_t5.py::T5ModelTest::test_eager_matches_sdpa_inference_1_bfloat16 <- tests/test_modeling_common.py SKIPPED (bfloat16 not supported on cuda (on the
specific device currently used, e.g. Nvidia T4 GPU))
tests/models/t5/test_modeling_t5.py::T5ModelTest::test_eager_matches_sdpa_inference_2_float32 <- tests/test_modeling_common.py PASSED
tests/models/t5/test_modeling_t5.py::T5EncoderOnlyModelTest::test_eager_matches_sdpa_inference_0_float16 <- tests/test_modeling_common.py PASSED
tests/models/t5/test_modeling_t5.py::T5EncoderOnlyModelTest::test_eager_matches_sdpa_inference_1_bfloat16 <- tests/test_modeling_common.py SKIPPED (bfloat16 not supported on cuda (on
the specific device currently used, e.g. Nvidia T4 GPU))
tests/models/t5/test_modeling_t5.py::T5EncoderOnlyModelTest::test_eager_matches_sdpa_inference_2_float32 <- tests/test_modeling_common.py PASSED

I would like to get your thoughts on the tracing test failures. Have you seen this before and do you know why these could be failing? If not, I can take a deeper look to figure out what's wrong.
Also, do you have any thoughts on the stride issue I mentioned in a comment above?


This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@abdulfatir (Author)

@fxmarty @ArthurZucker

I addressed the comments. The test_prompt_lookup_decoding_matches_greedy_search test is failing and I am not completely sure why. Do you have any insights?

Once we fix this and everything else looks okay, I will fix the copies of this model.

@ArthurZucker requested a review from fxmarty on June 6, 2024 06:59
@fxmarty (Contributor) commented Jun 6, 2024

@abdulfatir is there still the issue with

FAILED tests/models/t5/test_modeling_t5.py::T5ModelTest::test_torch_fx - AssertionError: Couldn't trace module: __bool__ should return bool, returned Tensor
FAILED tests/models/t5/test_modeling_t5.py::T5ModelTest::test_torch_fx_output_loss - AssertionError: Couldn't trace module: __bool__ should return bool, returned Tensor

?

Comment on lines 689 to 691
# sdpa kernels require tensors to have stride=1 in the last dimension
# .contiguous() does not behave correctly for tensors with singleton dimensions
# .clone(memory_format=torch.contiguous_format) is a workaround
Contributor

You could link to pytorch/pytorch#127523 here

Author

Added.

@fxmarty (Contributor) commented Jun 6, 2024

For test_prompt_lookup_decoding_matches_greedy_search, I am not sure... I could have a look if I have time, but otherwise you could add some prints and see where the shape mismatch occurs.
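
I haven't dug into the assisted-generation path, but for reference this is the kind of mismatch that produces that exact error, e.g. when the mask/bias has been built for a different key length than the cache actually holds (a toy example, not the actual T5 code):

import torch
import torch.nn.functional as F

query = torch.randn(1, 8, 4, 64)  # (batch, heads, q_len, head_dim): 4 candidate tokens
key = torch.randn(1, 8, 6, 64)    # 6 keys in the cache
value = torch.randn(1, 8, 6, 64)
mask = torch.zeros(1, 8, 4, 5)    # stale mask built for 5 keys instead of 6

try:
    F.scaled_dot_product_attention(query, key, value, attn_mask=mask)
except RuntimeError as e:
    print(e)  # a size-mismatch error at dimension 3, like the one in the failing test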

@abdulfatir changed the title from "Add SDPA support for T5" to "Add SDPA support for T5 Style Models" on Jun 15, 2024
@abdulfatir (Author)

@ArthurZucker I changed it to use clone instead of copy_ but now the test breaks with:

FAILED tests/models/pop2piano/test_modeling_pop2piano.py::Pop2PianoModelTest::test_export_to_onnx - torch.onnx.errors.UnsupportedOperatorError: Exporting the operator 'aten::scaled_dot_product_attention' to ONNX opset version 9 is not supported. Support for this operator was added in version 14, try exporting with this version.

But I believe this should also be an issue for other sdpa calls, right? Why is only this test breaking?

@abdulfatir (Author)

Tried a higher opset_version as suggested by the error, but something is still failing with SDPA + ONNX. This may be relevant: pytorch/pytorch#96944 cc @fxmarty

@fxmarty (Contributor) commented Jun 19, 2024

@abdulfatir you can remove this test. Transformers ONNX export has been deprecated for >1 year now. Or alternatively, use attn_implementation="eager" when loading the model with from_pretrained in the test.

The test test_export_to_onnx is apparently only left for mt5, fsmt, t5, longt5, pop2piano, switch_transformers and umt5 in transformers. The ONNX export is otherwise tested and handled in https://github.com/huggingface/optimum
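
For reference, a rough sketch of the eager fallback for that test (untested here; the checkpoint, dummy shapes, and file name are just placeholders — the real test uses the model tester's config and a temporary directory):

import torch
from transformers import T5ForConditionalGeneration

# load with eager attention so tracing never hits aten::scaled_dot_product_attention
model = T5ForConditionalGeneration.from_pretrained("t5-small", attn_implementation="eager").eval()

input_ids = torch.ones(1, 8, dtype=torch.long)
attention_mask = torch.ones(1, 8, dtype=torch.long)
decoder_input_ids = torch.zeros(1, 4, dtype=torch.long)

torch.onnx.export(
    model,
    (input_ids, attention_mask, decoder_input_ids),
    "t5_test.onnx",
    export_params=True,
    opset_version=14,  # SDPA, if it does end up in the traced graph, needs opset >= 14
)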

@abdulfatir (Author) commented Jun 19, 2024

Thanks @fxmarty! I skipped the test. Can you please also review the PR (and run the remaining workflows)?

@fxmarty (Contributor) left a comment

LGTM as long as the slow tests pass & some real models have been validated

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@abdulfatir (Author)

@fxmarty for slow tests do we need a label?

@fxmarty (Contributor) commented Jun 24, 2024

@abdulfatir I don't think we can run them from GitHub Actions. You just need to run RUN_SLOW=1 pytest tests/models/t5 -s -vvvvv and check that no more tests fail than on the main branch (some might already be failing).

@fxmarty requested a review from ArthurZucker on June 24, 2024 13:37
@fxmarty assigned and then unassigned fxmarty on Jun 24, 2024
@fxmarty requested a review from amyeroberts on June 24, 2024 13:38
@amyeroberts (Collaborator) left a comment

Thanks for all the work adding this!

cc @ylacombe regarding pop2piano

All of the models which have this added should have information added to their modeling page, on how to use it and the expected speedups, e.g. like here for Mistral.

Are SDPA tests run for all of these models now?

The slow integration tests will need to be run for all the models with SDPA added. They can be triggered by pushing a commit with the message: [run_slow] pop2piano,mt5,t5. Another HF member or I will need to approve the workflow for it to run
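
For illustration, the kind of snippet those modeling-page sections usually contain (the checkpoint and dtype below are only examples, mirroring the Mistral page):

import torch
from transformers import T5ForConditionalGeneration, T5TokenizerFast

model = T5ForConditionalGeneration.from_pretrained(
    "google/flan-t5-large",
    torch_dtype=torch.float16,
    attn_implementation="sdpa",
).to("cuda")
tokenizer = T5TokenizerFast.from_pretrained("google/flan-t5-large")

inputs = tokenizer("translate English to French: Hello, how are you?", return_tensors="pt").to("cuda")
outputs = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))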

@@ -610,7 +609,7 @@ def test_model_from_pretrained(self):
model = Pop2PianoForConditionalGeneration.from_pretrained(model_name)
self.assertIsNotNone(model)

@require_onnx
@unittest.skip("ONNX support deprecated")
Collaborator

This doesn't seem to be related to this PR?

We shouldn't be skipping tests because they're deprecated - deprecation happens with main code, but if a test is deprecated then we should just remove it

@@ -848,7 +848,7 @@ def test_export_to_onnx(self):
(config_and_inputs[1], config_and_inputs[3], config_and_inputs[2]),
f"{tmpdirname}/t5_test.onnx",
export_params=True,
opset_version=9,
opset_version=14,
Collaborator

What's the reason for this?

@@ -851,7 +851,7 @@ def test_export_to_onnx(self):
(config_and_inputs[1], config_and_inputs[3], config_and_inputs[2]),
f"{tmpdirname}/t5_test.onnx",
export_params=True,
opset_version=9,
opset_version=14,
Collaborator

And here

Comment on lines +490 to +492
query_states = self._shape(
self.q(hidden_states), batch_size
) # (batch_size, n_heads, seq_length, dim_per_head)
Collaborator

Library convention is for comments to go on the line above to avoid line splitting

Suggested change
query_states = self._shape(
self.q(hidden_states), batch_size
) # (batch_size, n_heads, seq_length, dim_per_head)
# (batch_size, n_heads, seq_length, dim_per_head)
query_states = self._shape(self.q(hidden_states), batch_size)

Comment on lines +552 to +554
attn_output = self._unshape(
torch.matmul(attn_weights, value_states), batch_size
) # (batch_size, seq_length, dim)
Collaborator

Suggested change
attn_output = self._unshape(
torch.matmul(attn_weights, value_states), batch_size
) # (batch_size, seq_length, dim)
# (batch_size, seq_length, dim)
attn_output = self._unshape(torch.matmul(attn_weights, value_states), batch_size)

Comment on lines +518 to +520
query_states = self._shape(
self.q(hidden_states), batch_size
) # (batch_size, n_heads, seq_length, dim_per_head)
Collaborator

Suggested change
query_states = self._shape(
self.q(hidden_states), batch_size
) # (batch_size, n_heads, seq_length, dim_per_head)
# (batch_size, n_heads, seq_length, dim_per_head)
query_states = self._shape(self.q(hidden_states), batch_size)

Comment on lines +580 to +582
attn_output = self._unshape(
torch.matmul(attn_weights, value_states), batch_size
) # (batch_size, seq_length, dim)
Collaborator

Suggested change
attn_output = self._unshape(
torch.matmul(attn_weights, value_states), batch_size
) # (batch_size, seq_length, dim)
# (batch_size, seq_length, dim)
attn_output = self._unshape(torch.matmul(attn_weights, value_states), batch_size)

Comment on lines +771 to +773
query_states = self._shape(
self.q(hidden_states), batch_size
) # (batch_size, n_heads, seq_length, dim_per_head)
Collaborator

Suggested change
query_states = self._shape(
self.q(hidden_states), batch_size
) # (batch_size, n_heads, seq_length, dim_per_head)
# (batch_size, n_heads, seq_length, dim_per_head)
query_states = self._shape(self.q(hidden_states), batch_size)

Comment on lines +833 to +835
attn_output = self._unshape(
torch.matmul(attn_weights, value_states), batch_size
) # (batch_size, seq_length, dim)
Collaborator

Suggested change
attn_output = self._unshape(
torch.matmul(attn_weights, value_states), batch_size
) # (batch_size, seq_length, dim)
# (batch_size, seq_length, dim)
attn_output = self._unshape(torch.matmul(attn_weights, value_states), batch_size)

@@ -1007,6 +1019,7 @@ def unshape(states):

# Copied from transformers.models.t5.modeling_t5.T5LayerSelfAttention with T5->LongT5
class LongT5LayerSelfAttention(nn.Module):
Collaborator

What's the reason for not adding support for LongT5?

@ylacombe (Contributor) commented Jun 26, 2024

Hey @amyeroberts, thanks for pinging, I'll take a closer look at pop2piano once your comments have been addressed, if that's okay with you!

A small comment from my side, @abdulfatir: have you been able to test speed-ups again, now that you have found the reason for the apparent lack of speed-ups?

Also, to properly measure inference speed-ups, you should modify the script you've used a bit to:

  1. set a fixed number of tokens to generate and benchmark over a range of generation lengths
  2. use torch.cuda.Event to properly measure the time spent in inference

I haven't tested the following script, but it should work with little to no modification!

import argparse
import time

import torch

from transformers import (
    GenerationConfig,
    T5ForConditionalGeneration,
    T5TokenizerFast,
)
from transformers import set_seed


def bytes_to_giga_bytes(bytes):
    return bytes / 1024 / 1024 / 1024


def load_model(args):
    model = T5ForConditionalGeneration.from_pretrained(
        "google/flan-t5-large",
        attn_implementation=args.attn_implementation,
        torch_dtype=getattr(torch, args.torch_dtype),
    ).to("cuda")
    tokenizer = T5TokenizerFast.from_pretrained("google/flan-t5-large")
    return model, tokenizer


def get_inputs(args, tokenizer):
    inputs = tokenizer(
        ["Pedro Pedro Pedro 🦝🙌" * 10] * args.batch_size,
        padding=True,
        return_tensors="pt",
    )
    inputs = {k: v.to("cuda") for k, v in inputs.items()}
    return inputs

def measure_latency_and_memory_use(model, device, inputs, generation_config, nb_loops):
    # define Events that measure start and end of the generate pass
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)

    # reset cuda memory stats and empty cache
    torch.cuda.reset_peak_memory_stats(device)
    torch.cuda.empty_cache()
    torch.cuda.synchronize()

    # get the start time
    start_event.record()

    # actually generate
    for _ in range(nb_loops):
        # set seed for reproducibility
        set_seed(0)
        generation = model.generate(**inputs, generation_config=generation_config)

    # get the end time
    end_event.record()
    torch.cuda.synchronize()

    # measure memory footprint and elapsed time
    max_memory = torch.cuda.max_memory_allocated(device)
    elapsed_time = start_event.elapsed_time(end_event) * 1.0e-3

    execution_time_in_s = elapsed_time / nb_loops
    max_memory_footprint_in_GB = bytes_to_giga_bytes(max_memory)

    return execution_time_in_s, max_memory_footprint_in_GB


@torch.no_grad()
def main(args):
    model, tokenizer = load_model(args)
    inputs = get_inputs(args, tokenizer)
    generation_config = GenerationConfig(min_new_tokens=args.num_tokens, max_new_tokens=args.num_tokens)
    # warmup
    for _ in range(5):
        _ = model.generate(**inputs, generation_config=generation_config)

    avg_inference_time, max_memory_footprint_in_GB = measure_latency_and_memory_use(model, model.device, inputs, generation_config, args.num_iters)

    print(
        f"{args.attn_implementation=}, {args.torch_dtype=}, {args.batch_size=}, {args.num_iters=}, {args.num_tokens=}"
    )
    print(f"avg_inference_time: {avg_inference_time:.3f} seconds")
    print(
        f"Max memory allocated: {max_memory_footprint_in_GB} GB"
    )


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--attn_implementation", type=str, choices=["sdpa", "eager"], default="eager"
    )
    parser.add_argument(
        "--torch_dtype",
        type=str,
        choices=["float32", "bfloat16", "float16"],
        default="float32",
    )
    parser.add_argument("--batch_size", type=int, default=2)
    parser.add_argument("--num_iters", type=int, default=10)
    parser.add_argument("--num_tokens", type=int, default=20)
    args = parser.parse_args()
    main(args)
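
(If saved as, say, t5_sdpa_generate_benchmark.py — a hypothetical filename — it can be run once per configuration, e.g. python t5_sdpa_generate_benchmark.py --attn_implementation sdpa --torch_dtype float16 --num_tokens 100, and again with --attn_implementation eager to compare.)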

@abdulfatir (Author)

Thanks @amyeroberts and @ylacombe! I will probably look into this later over the weekend.

@agemagician (Contributor)

@abdulfatir any update on when this PR will be merged?


This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@github-actions (bot) closed this on Aug 22, 2024
@alvaropp commented Oct 9, 2024

Hi 👋🏽 any updates on this?

@ArthurZucker (Collaborator)

I think #34089 enabled it!

@alvaropp

I think #34089 enabled it!

Hi @ArthurZucker, I've tested compiling the T5 model and I do get a very modest speedup:

# %%
import time

import numpy as np
import torch
from tqdm.auto import tqdm
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# %% Example prompt
tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-base")

text_input = "The theory of relativity usually encompasses two interrelated physics theories by Albert Einstein: special relativity and general relativity, proposed and published in 1905 and 1915, respectively. Special relativity applies to all physical phenomena in the absence of gravity. General relativity explains the law of gravitation and its relation to the forces of nature. It applies to the cosmological and astrophysical realm, "
input_ids = tokenizer(text_input, return_tensors="pt").input_ids.to(device)

# %% Without compilation
model = AutoModelForSeq2SeqLM.from_pretrained("google-t5/t5-base").to(device)

# First pass to initialise any caches/JITs/whatever
_ = model.generate(input_ids, max_length=500)

times = []
for _ in tqdm(range(20)):
    start = time.time()
    _ = model.generate(input_ids, max_length=500)
    times.append(time.time() - start)

print(f"Times without compilation: {np.mean(times):.4f} ± {np.std(times):.4f}")

# %% With compilation
model = AutoModelForSeq2SeqLM.from_pretrained("google-t5/t5-base").to(device)
compiled_model = torch.compile(model)

# First pass to initialise any caches/JITs/whatever
_ = compiled_model.generate(input_ids, max_length=500)

times = []
for _ in tqdm(range(20)):
    start = time.time()
    _ = compiled_model.generate(input_ids, max_length=500)
    times.append(time.time() - start)

print(f"Times with compilation: {np.mean(times):.4f} ± {np.std(times):.4f}")

However, I've tried using:

model = AutoModelForSeq2SeqLM.from_pretrained(
    "google-t5/t5-base", device_map="cuda:1", attn_implementation="SDPA"
)
model = AutoModelForSeq2SeqLM.from_pretrained(
    "google-t5/t5-base", device_map="cuda:1", attn_implementation="flash_attention_2"
)

but it complains that neither SDPA nor FlashAttention is implemented for the T5 model.

Any ideas? 😄

@ArthurZucker (Collaborator)

Ah sorry, it's not implemented indeed!
Will handle it after #34282!

@alvaropp

Ah sorry, it's not implemented indeed! Will handle it after #34282!

Thanks, will keep an eye open!

@rubenweitzman

Hi @ArthurZucker, wondering what the state of affairs is for T5 with SDPA/flash attention and a working compile?

@ArthurZucker (Collaborator)

Now that the PR is merged, if you want you can open one for T5 or I'll do it in a bit!
