
Commit b44e2d1

Cleanup and minor fixes (#61)
Authored by guangy10 (Guang Yang)
Co-authored-by: Guang Yang <guangyang@fb.com>
1 parent 32d066f commit b44e2d1

File tree: 11 files changed (+47, -100 lines)


README.md

Lines changed: 1 addition & 1 deletion
@@ -144,7 +144,7 @@ We currently support a wide range of popular transformer models, including encod
 #### Decoder-only models
 - [Gemma](https://huggingface.co/google/gemma-2b): `Gemma-2b` and its variants
 - [Gemma2](https://huggingface.co/google/gemma-2-2b): `Gemma-2-2b` and its variants
-- [Gemma3](https://huggingface.co/google/gemma-3-1b-it): `Gemma-3-1b` and its variants
+- [Gemma3](https://huggingface.co/google/gemma-3-1b-it): `Gemma-3-1b` and its variants *(requires installing the latest `transformers (4.52.0.dev0)` manually from source)*
 - [Llama](https://huggingface.co/meta-llama/Llama-3.2-1B): `Llama-3.2-1B` and its variants
 - [Qwen2](https://huggingface.co/Qwen/Qwen2.5-0.5B): `Qwen2.5-0.5B` and its variants
 - [Qwen3](https://huggingface.co/Qwen/Qwen3-0.6B): `Qwen3-0.6B` and its variants
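A quick way to satisfy the Gemma3 requirement is the standard transformers source install, followed by a guard before exporting. A minimal sketch; the version string simply mirrors the README note above:

```python
# Install the development build first (standard transformers source install):
#   pip install git+https://github.com/huggingface/transformers.git
from packaging import version

import transformers

# The README note asks for 4.52.0.dev0 or newer for Gemma-3 support.
assert version.parse(transformers.__version__) >= version.parse("4.52.0.dev0"), (
    f"Gemma-3 export needs transformers >= 4.52.0.dev0, found {transformers.__version__}"
)
```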

optimum/commands/export/executorch.py

Lines changed: 0 additions & 5 deletions
@@ -17,9 +17,6 @@
 from pathlib import Path
 from typing import TYPE_CHECKING
 
-from executorch import version as executorch_version
-from packaging import version as pkg_version
-
 from ...exporters import TasksManager
 from ..base import BaseOptimumCLICommand, CommandInfo
 
@@ -74,8 +71,6 @@ def run(self):
 
         kwargs = {}
         if self.args.use_custom_sdpa:
-            if pkg_version.parse(executorch_version.__version__) < pkg_version.parse("0.6.0"):
-                raise ValueError("custom_sdpa is not supported for executorch < 0.6.0")
             kwargs["use_custom_sdpa"] = self.args.use_custom_sdpa
 
         main_export(
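With the version gate gone, the flag is forwarded to `main_export` whenever it is set, on any ExecuTorch version. A hedged invocation sketch mirroring the subprocess call in `tests/models/test_modeling_common.py`; the recipe name and output layout are assumptions borrowed from that test:

```python
import subprocess
import tempfile

# Hypothetical CLI invocation; --use_custom_sdpa comes from the command diff
# above, the model id and other flags from the test suite.
with tempfile.TemporaryDirectory() as tempdir:
    subprocess.run(
        "optimum-cli export executorch "
        "--model optimum-internal-testing/tiny-random-llama "
        "--task text-generation "
        "--recipe xnnpack "
        "--use_custom_sdpa "
        f"--output_dir {tempdir}/executorch",
        shell=True,
        check=True,
    )
    # On success the exported program lands at {tempdir}/executorch/model.pte.
```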

optimum/executorch/attentions/custom_sdpa.py

Lines changed: 43 additions & 47 deletions
@@ -15,56 +15,52 @@
 from typing import Optional, Tuple, Union
 
 import torch
-from executorch import version as executorch_version
-from packaging import version as pkg_version
+from executorch.extension.llm.custom_ops.custom_ops import custom_sdpa  # noqa
 
 
-if pkg_version.parse(executorch_version.__version__) >= pkg_version.parse("0.6.0"):
-    from executorch.extension.llm.custom_ops.custom_ops import custom_sdpa  # noqa
+def custom_sdpa_with_start_pos_forward(
+    module: torch.nn.Module,
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    attention_mask: Union[torch.Tensor, "BlockMask"],  # noqa
+    scaling: Optional[float] = None,
+    softcap: Optional[float] = None,
+    head_mask: Optional[torch.Tensor] = None,
+    **kwargs,
+) -> Tuple[torch.Tensor, None]:
+    # This is before the transpose
+    max_seq_len = key.shape[2]
 
-    def custom_sdpa_with_start_pos_forward(
-        module: torch.nn.Module,
-        query: torch.Tensor,
-        key: torch.Tensor,
-        value: torch.Tensor,
-        attention_mask: Union[torch.Tensor, "BlockMask"],  # noqa
-        scaling: Optional[float] = None,
-        softcap: Optional[float] = None,
-        head_mask: Optional[torch.Tensor] = None,
-        **kwargs,
-    ) -> Tuple[torch.Tensor, None]:
-        # This is before the transpose
-        max_seq_len = key.shape[2]
+    # FA2 uses non-transposed inputs
+    query = query.transpose(1, 2)
+    key = key.transpose(1, 2)
+    value = value.transpose(1, 2)
 
-        # FA2 uses non-transposed inputs
-        query = query.transpose(1, 2)
-        key = key.transpose(1, 2)
-        value = value.transpose(1, 2)
+    # Convert the hell out of the inputs to fp32 and back
+    input_dtype = query.dtype
+    query = query.to(torch.float32)
+    key = key.to(torch.float32)
+    value = value.to(torch.float32)
 
-        # Convert the hell out of the inputs to fp32 and back
-        input_dtype = query.dtype
-        query = query.to(torch.float32)
-        key = key.to(torch.float32)
-        value = value.to(torch.float32)
+    # Ignore the causal flag from kwargs but use the one in module
+    kwargs.pop("is_causal", None)
 
-        # Ignore the causal flag from kwargs but use the one in module
-        kwargs.pop("is_causal", None)
-
-        # Calculate the input pos from attention mask.
-        # Branch out for float vs bool mask
-        # assert attention_mask.dim() == 2, f"attention_mask must be a 2D matrix."
-        attention_mask = attention_mask.reshape(-1, max_seq_len)
-        first_row_mask = attention_mask[0, :]
-        # [0, 0, 0, 0, -inf, -inf, -inf, -inf], start_pos = 3
-        start_pos = torch.argmin(first_row_mask).item() - 1
-        output = torch.ops.llama.custom_sdpa(
-            query,
-            key,
-            value,
-            start_pos=start_pos,
-            attn_mask=None,
-            drpout_p=0.0,
-            is_causal=module.is_causal,
-            scale=scaling,
-        )
-        return output.to(input_dtype), None
+    # Calculate the input pos from attention mask.
+    # Branch out for float vs bool mask
+    # assert attention_mask.dim() == 2, f"attention_mask must be a 2D matrix."
+    attention_mask = attention_mask.reshape(-1, max_seq_len)
+    first_row_mask = attention_mask[0, :]
+    # [0, 0, 0, 0, -inf, -inf, -inf, -inf], start_pos = 3
+    start_pos = torch.argmin(first_row_mask).item() - 1
+    output = torch.ops.llama.custom_sdpa(
+        query,
+        key,
+        value,
+        start_pos=start_pos,
+        attn_mask=None,
+        drpout_p=0.0,
+        is_causal=module.is_causal,
+        scale=scaling,
+    )
+    return output.to(input_dtype), None
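The `start_pos` computation is the crux of this function: it assumes a float mask whose first row is zeros up to the current position and -inf afterwards. (One quirk kept verbatim above: the `drpout_p` spelling appears to match the custom op's own schema, so "fixing" it would break the call.) A minimal, self-contained sketch of the argmin trick, in plain PyTorch with no ExecuTorch runtime needed:

```python
import torch

# Float attention-mask row over a max_seq_len of 8: zeros mark visible
# positions, -inf marks masked ones. Four tokens are currently visible.
first_row_mask = torch.tensor([0.0, 0.0, 0.0, 0.0] + [float("-inf")] * 4)

# argmin lands on the first -inf entry (index 4); the last visible token
# therefore sits at index 3 -- exactly the "start_pos = 3" case in the comment.
start_pos = torch.argmin(first_row_mask).item() - 1
assert start_pos == 3
```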

optimum/exporters/executorch/convert.py

Lines changed: 3 additions & 9 deletions
@@ -19,23 +19,17 @@
 from pathlib import Path
 from typing import Union
 
-from packaging import version as pkg_version
 from transformers.modeling_utils import AttentionInterface
 
-from executorch import version as executorch_version
+from optimum.executorch.attentions.custom_sdpa import custom_sdpa_with_start_pos_forward
 
 from .recipe_registry import discover_recipes, recipe_registry
 
 
-if pkg_version.parse(executorch_version.__version__) >= pkg_version.parse("0.6.0"):
-    from optimum.executorch.attentions.custom_sdpa import custom_sdpa_with_start_pos_forward
-
-    # Register custom sdpa via `AttentionInterface` for executorch>=0.6.0
-    AttentionInterface.register("custom_sdpa", custom_sdpa_with_start_pos_forward)
-
-
 logger = logging.getLogger(__name__)
 
+AttentionInterface.register("custom_sdpa", custom_sdpa_with_start_pos_forward)
+
 
 def export_to_executorch(
     model,
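With the version gate removed, the registration now runs unconditionally at import time. A hedged eager-mode sketch of how a transformers model can then opt into it (the model id is borrowed from the test files below; `attn_implementation` dispatch through a registered `AttentionInterface` name is standard behavior in recent transformers):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Importing the exporter module triggers the AttentionInterface.register(...)
# call shown above, making "custom_sdpa" a selectable implementation.
import optimum.exporters.executorch.convert  # noqa: F401

model_id = "HuggingFaceTB/SmolLM2-135M"  # model id borrowed from the test suite
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    attn_implementation="custom_sdpa",  # routes to custom_sdpa_with_start_pos_forward
)

inputs = tokenizer("My favourite condiment is ", return_tensors="pt")
with torch.no_grad():
    output = model.generate(**inputs, max_new_tokens=16)
print(tokenizer.decode(output[0], skip_special_tokens=True))
```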

tests/models/test_modeling_common.py

Lines changed: 0 additions & 8 deletions
@@ -22,10 +22,8 @@
 from tempfile import TemporaryDirectory
 
 import torch
-from executorch import version as executorch_version
 from executorch.extension.pybindings.portable_lib import ExecuTorchModule
 from huggingface_hub import HfApi
-from packaging import version as pkg_version
 from transformers import (
     AutoModelForCausalLM,
     AutoTokenizer,
@@ -113,9 +111,6 @@ def test_find_files_matching_pattern(self):
         self.assertTrue(len(pte_files) == 0 if revision == "main" else len(pte_files) > 0)
 
     def test_export_with_custom_sdpa(self):
-        if pkg_version.parse(executorch_version.__version__) < pkg_version.parse("0.6.0"):
-            self.skipTest(reason="This test requires executorch >= 0.6 to run.")
-
         model_id = "optimum-internal-testing/tiny-random-llama"
         with tempfile.TemporaryDirectory() as tempdir:
             subprocess.run(
@@ -130,9 +125,6 @@ def test_export_with_custom_sdpa(self):
         self.assertTrue(os.path.exists(f"{tempdir}/executorch/model.pte"))
 
     def test_eager_text_generation_with_custom_sdpa(self):
-        if pkg_version.parse(executorch_version.__version__) < pkg_version.parse("0.6.0"):
-            self.skipTest(reason="This test requires executorch >= 0.6 to run.")
-
         model_id = "HuggingFaceTB/SmolLM2-135M"
         prompt = "My favourite condiment is "
         max_seq_len = 32

tests/models/test_modeling_gemma.py

Lines changed: 0 additions & 5 deletions
@@ -22,9 +22,7 @@
 import unittest
 
 import pytest
-from executorch import version as executorch_version
 from executorch.extension.pybindings.portable_lib import ExecuTorchModule
-from packaging import version as pkg_version
 from transformers import AutoTokenizer
 from transformers.testing_utils import slow
 
@@ -65,9 +63,6 @@ def test_gemma_export_to_executorch(self):
     @slow
     @pytest.mark.run_slow
     def test_gemma_text_generation(self):
-        if pkg_version.parse(executorch_version.__version__) < pkg_version.parse("0.6.0"):
-            self.skipTest(reason="Support of float16 requires executorch >= 0.6 to run.")
-
         # TODO: Switch to use google/gemma-2b once https://github.com/huggingface/optimum/issues/2127 is fixed
         # model_id = "google/gemma-2b"
         model_id = "weqweasdas/RM-Gemma-2B"

tests/models/test_modeling_gemma2.py

Lines changed: 0 additions & 5 deletions
@@ -22,9 +22,7 @@
 import unittest
 
 import pytest
-from executorch import version as executorch_version
 from executorch.extension.pybindings.portable_lib import ExecuTorchModule
-from packaging import version as pkg_version
 from transformers import AutoTokenizer
 from transformers.testing_utils import slow
 
@@ -65,9 +63,6 @@ def test_gemma2_export_to_executorch(self):
     @slow
     @pytest.mark.run_slow
     def test_gemma2_text_generation(self):
-        if pkg_version.parse(executorch_version.__version__) < pkg_version.parse("0.6.0"):
-            self.skipTest(reason="Support of float16 requires executorch >= 0.6 to run.")
-
         # TODO: Switch to use google/gemma-2-2b once https://github.com/huggingface/optimum/issues/2127 is fixed
         # model_id = "google/gemma-2-2b"
         model_id = "unsloth/gemma-2-2b-it"

tests/models/test_modeling_llama.py

Lines changed: 0 additions & 5 deletions
@@ -22,9 +22,7 @@
 import unittest
 
 import pytest
-from executorch import version as executorch_version
 from executorch.extension.pybindings.portable_lib import ExecuTorchModule
-from packaging import version as pkg_version
 from transformers import AutoTokenizer
 from transformers.testing_utils import slow
 
@@ -91,9 +89,6 @@ def test_llama3_2_1b_text_generation(self):
     @slow
     @pytest.mark.run_slow
     def test_llama_text_generation_with_custom_sdpa(self):
-        if pkg_version.parse(executorch_version.__version__) < pkg_version.parse("0.6.0"):
-            self.skipTest(reason="This test requires executorch >= 0.6 to run.")
-
         # ExecuTorch model + custom sdpa
         model_id = "NousResearch/Llama-3.2-1B"
         model = ExecuTorchModelForCausalLM.from_pretrained(
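For the exported path, the same option is exposed through `ExecuTorchModelForCausalLM`. A sketch under the assumption that the `recipe` and `attn_implementation` kwargs shown in this test carry through `from_pretrained`, and that `text_generation` is the repository's generation helper:

```python
from optimum.executorch import ExecuTorchModelForCausalLM
from transformers import AutoTokenizer

model_id = "NousResearch/Llama-3.2-1B"  # model id from the test above
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Export with the XNNPACK recipe and the custom SDPA op; kwarg names follow
# the test usage above and should be treated as assumptions.
model = ExecuTorchModelForCausalLM.from_pretrained(
    model_id,
    recipe="xnnpack",
    attn_implementation="custom_sdpa",
)
generated = model.text_generation(
    tokenizer=tokenizer,
    prompt="My favourite condiment is ",
    max_seq_len=32,
)
print(generated)
```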

tests/models/test_modeling_olmo.py

Lines changed: 0 additions & 5 deletions
@@ -21,9 +21,7 @@
 import unittest
 
 import pytest
-from executorch import version as executorch_version
 from executorch.extension.pybindings.portable_lib import ExecuTorchModule
-from packaging import version as pkg_version
 from transformers import AutoTokenizer
 from transformers.testing_utils import slow
 
@@ -77,9 +75,6 @@ def test_olmo_text_generation_with_xnnpack(self):
     @slow
     @pytest.mark.run_slow
     def test_olmo_text_generation_with_custom_sdpa(self):
-        if pkg_version.parse(executorch_version.__version__) < pkg_version.parse("0.6.0"):
-            self.skipTest(reason="This test requires executorch >= 0.6 to run.")
-
         # ExecuTorch model + custom sdpa
         model_id = "allenai/OLMo-1B-hf"
         model = ExecuTorchModelForCausalLM.from_pretrained(

tests/models/test_modeling_qwen2.py

Lines changed: 0 additions & 5 deletions
@@ -21,9 +21,7 @@
 import unittest
 
 import pytest
-from executorch import version as executorch_version
 from executorch.extension.pybindings.portable_lib import ExecuTorchModule
-from packaging import version as pkg_version
 from transformers import AutoTokenizer
 from transformers.testing_utils import slow
 
@@ -80,9 +78,6 @@ def test_qwen2_5_text_generation(self):
     @slow
     @pytest.mark.run_slow
     def test_qwen2_5_text_generation_with_custom_sdpa(self):
-        if pkg_version.parse(executorch_version.__version__) < pkg_version.parse("0.6.0"):
-            self.skipTest(reason="This test requires executorch >= 0.6 to run.")
-
         model_id = "Qwen/Qwen2.5-0.5B"
         prompt = "My favourite condiment is "
         max_seq_len = 32
