
Commit 92541cc

Guang Yang (guangy10) authored and committed
Introduce 8da4w quant for decoder-only text models
1 parent b44e2d1 commit 92541cc

File tree: 7 files changed, +165 -2 lines changed


.github/workflows/test_models.yml

Lines changed: 2 additions & 1 deletion
@@ -52,11 +52,12 @@ jobs:
       - name: Install dependencies for ExecuTorch
         run: |
           if [ "${{ matrix.executorch-version }}" == "nightly" ]; then
-            export NIGHTLY_VERSION=dev20250413
+            export NIGHTLY_VERSION=dev20250422
             pip install executorch==0.7.0.${NIGHTLY_VERSION} \
               torch==2.8.0.${NIGHTLY_VERSION} \
               torchvision==0.22.0.${NIGHTLY_VERSION} \
               torchaudio==2.6.0.${NIGHTLY_VERSION} \
+              torchao==0.11.0.${NIGHTLY_VERSION} \
               --extra-index-url "https://download.pytorch.org/whl/nightly/cpu"
           else
             pip install executorch==${{ matrix.executorch-version }}

optimum/commands/export/executorch.py

Lines changed: 9 additions & 0 deletions
@@ -57,6 +57,13 @@ def parse_args_executorch(parser):
         action="store_true",
         help="For decoder-only models to use custom sdpa with static kv cache to boost performance. Defaults to False.",
     )
+    required_group.add_argument(
+        "-q",
+        "--quantize",
+        required=False,
+        choices=["8da4w"],
+        help="Quantization recipe to use. Defaults to None.",
+    )


 class ExecuTorchExportCommand(BaseOptimumCLICommand):
@@ -72,6 +79,8 @@ def run(self):
         kwargs = {}
         if self.args.use_custom_sdpa:
             kwargs["use_custom_sdpa"] = self.args.use_custom_sdpa
+        if self.args.quantize:
+            kwargs["quantize"] = self.args.quantize

         main_export(
             model_name_or_path=self.args.model,
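
The new `-q/--quantize` flag is simply forwarded to `main_export` as `quantize="8da4w"`. A minimal programmatic sketch of the same export, assuming `main_export` is importable from `optimum.exporters.executorch` and that the `task` and `output_dir` keyword names match the CLI's other options (both assumptions, not shown in this commit):

# Hedged sketch: rough programmatic equivalent of `optimum-cli export executorch ... -q 8da4w`.
# The model id, task name, and output_dir keyword are illustrative assumptions.
from optimum.exporters.executorch import main_export

main_export(
    model_name_or_path="HuggingFaceTB/SmolLM2-135M",  # illustrative model id
    task="text-generation",                           # assumed task name
    recipe="xnnpack",
    output_dir="./smollm2_8da4w",                     # assumed keyword name
    use_custom_sdpa=True,
    quantize="8da4w",                                 # the new 8da4w recipe
)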

optimum/exporters/executorch/recipes/xnnpack.py

Lines changed: 11 additions & 1 deletion
@@ -15,9 +15,11 @@
 import logging
 from typing import Dict, Union

+from tabulate import tabulate
 from torch.export import ExportedProgram

 from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
+from executorch.devtools.backend_debug import get_delegation_info
 from executorch.exir import (
     EdgeCompileConfig,
     ExecutorchBackendConfig,
@@ -73,7 +75,15 @@ def _lower_to_executorch(
                 extract_delegate_segments=True,
             ),
         )
-        logging.debug(f"Exported program for {pte_name}.pte: {et_progs[pte_name].exported_program().graph_module}")
+        logging.debug(f"\nExported program for {pte_name}.pte: {exported_program}")
+        logging.debug(
+            f"\nExecuTorch program for {pte_name}.pte: {et_progs[pte_name].exported_program().graph_module}"
+        )
+        delegation_info = get_delegation_info(et_progs[pte_name].exported_program().graph_module)
+        logging.debug(f"\nDelegation info Summary for {pte_name}.pte: {delegation_info.get_summary()}")
+        logging.debug(
+            f"\nDelegation info for {pte_name}.pte: {tabulate(delegation_info.get_operator_delegation_dataframe(), headers='keys', tablefmt='fancy_grid')}"
+        )
     return et_progs

     exported_progs = model.export()
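
The added logging surfaces delegation coverage at DEBUG level. A minimal standalone sketch of the same inspection, using the calls introduced above on a lowered ExecuTorch program object like the `et_progs` values; the helper name `print_delegation_report` is hypothetical:

# Hedged sketch mirroring the debug output added above: report which operators
# of a lowered program were delegated to the XNNPACK backend.
from tabulate import tabulate
from executorch.devtools.backend_debug import get_delegation_info

def print_delegation_report(et_prog, pte_name="model"):
    # Same graph_module object the recipe logs after lowering.
    info = get_delegation_info(et_prog.exported_program().graph_module)
    print(f"Delegation summary for {pte_name}.pte: {info.get_summary()}")
    print(tabulate(info.get_operator_delegation_dataframe(), headers="keys", tablefmt="fancy_grid"))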

optimum/exporters/executorch/tasks/causal_lm.py

Lines changed: 26 additions & 0 deletions
@@ -12,6 +12,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import torch
+import torchao
+from packaging.version import parse
 from transformers import AutoModelForCausalLM, GenerationConfig

 from ..integrations import CausalLMExportableModule
@@ -54,12 +57,14 @@ def load_causal_lm_model(model_name_or_path: str, **kwargs) -> CausalLMExportabl
     cache_implementation = kwargs.get("cache_implementation", "static")
     max_length = kwargs.get("max_length", 2048)
     config = kwargs.get("config", None)
+    quantization_recipe = kwargs.get("quantize", None)

     eager_model = AutoModelForCausalLM.from_pretrained(
         model_name_or_path,
         device_map=device,
         torch_dtype=dtype,
         config=config,
+        # quantization_config=quantization_config,
         attn_implementation=attn_implementation,
         generation_config=GenerationConfig(
             use_cache=True,
@@ -71,4 +76,25 @@ def load_causal_lm_model(model_name_or_path: str, **kwargs) -> CausalLMExportabl
             },
         ),
     )
+
+    if quantization_recipe == "8da4w":
+        if parse(torchao.__version__) < parse("0.11.0.dev0"):
+            raise RuntimeError("Quantization 8da4w requires torchao >= 0.11.0. Please upgrade torchao.")
+
+        from torchao.quantization.granularity import PerGroup
+        from torchao.quantization.quant_api import (
+            Int8DynamicActivationIntxWeightConfig,
+        )
+
+        # TODO: Should switch to TorchAoConfig once the quant issue on final lm_head layer is fixed.
+        linear_config = Int8DynamicActivationIntxWeightConfig(
+            weight_dtype=torch.int4,
+            weight_granularity=PerGroup(128),
+        )
+
+        torchao.quantize_(
+            eager_model,
+            linear_config,
+        )
+
     return CausalLMExportableModule(eager_model)
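
Here "8da4w" means 8-bit dynamically quantized activations with 4-bit weights, one scale per group of 128 weight elements, applied to each linear layer before export. A minimal sketch of the same torchao config on a toy module (the toy module itself is illustrative, not part of this commit):

# Hedged sketch: apply the 8da4w config used above to a throwaway model.
# Requires torchao >= 0.11.0.dev0, as enforced by the recipe.
import torch
import torchao
from torchao.quantization.granularity import PerGroup
from torchao.quantization.quant_api import Int8DynamicActivationIntxWeightConfig

toy = torch.nn.Sequential(torch.nn.Linear(256, 256), torch.nn.Linear(256, 128))
torchao.quantize_(
    toy,
    Int8DynamicActivationIntxWeightConfig(
        weight_dtype=torch.int4,           # 4-bit weights
        weight_granularity=PerGroup(128),  # one scale per group of 128 weights
    ),
)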

tests/models/test_modeling_gemma3.py

Lines changed: 39 additions & 0 deletions
@@ -21,7 +21,9 @@
 import unittest

 import pytest
+import torchao
 from executorch.extension.pybindings.portable_lib import ExecuTorchModule
+from packaging.version import parse
 from transformers import AutoTokenizer
 from transformers.testing_utils import slow

@@ -153,3 +155,40 @@ def test_gemma3_text_generation_with_custom_sdpa_float16(self):
         gc.collect()

         self.assertTrue(check_causal_lm_output_quality(model_id, generated_tokens))
+
+    @slow
+    @pytest.mark.run_slow
+    @pytest.mark.skipif(
+        parse(torchao.__version__) < parse("0.11.0.dev0"),
+        reason="Only available on torchao >= 0.11.0.dev0",
+    )
+    def test_gemma3_text_generation_with_custom_sdpa_8da4w(self):
+        model_id = "google/gemma-3-1b-it"
+        prompt = "Write a poem about a machine learning."
+        tokenizer = AutoTokenizer.from_pretrained(model_id)
+        kwargs = {"quantize": "8da4w"}
+
+        # ExecuTorch model + custom sdpa + 8da4w quantization
+        model = ExecuTorchModelForCausalLM.from_pretrained(
+            model_id,
+            recipe="xnnpack",
+            attn_implementation="custom_sdpa",
+            **kwargs,
+        )
+        self.assertIsInstance(model, ExecuTorchModelForCausalLM)
+        self.assertIsInstance(model.model, ExecuTorchModule)
+
+        generated_text = model.text_generation(
+            tokenizer=tokenizer,
+            prompt=prompt,
+            max_seq_len=64,
+        )
+        logging.info(f"\nGenerated text:\n\t{generated_text}")
+        generated_tokens = tokenizer(generated_text, return_tensors="pt").input_ids
+
+        # Free memory before loading eager for quality check
+        del model
+        del tokenizer
+        gc.collect()
+
+        self.assertTrue(check_causal_lm_output_quality(model_id, generated_tokens))
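
The new tests are gated as slow. A hedged way to run just the 8da4w case locally; the RUN_SLOW environment variable and the run_slow marker selection are inferred from the @slow and @pytest.mark.run_slow decorators above, not documented by this commit:

# Hedged sketch: invoke the new slow 8da4w test directly with pytest.
import os
import pytest

os.environ["RUN_SLOW"] = "1"  # assumed gate for transformers' @slow decorator
pytest.main(["-m", "run_slow", "-k", "8da4w", "tests/models/test_modeling_gemma3.py"])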

tests/models/test_modeling_qwen3.py

Lines changed: 38 additions & 0 deletions
@@ -21,7 +21,9 @@
 import unittest

 import pytest
+import torchao
 from executorch.extension.pybindings.portable_lib import ExecuTorchModule
+from packaging.version import parse
 from transformers import AutoTokenizer
 from transformers.testing_utils import slow

@@ -136,3 +138,39 @@ def test_qwen3_text_generation_with_custom_sdpa_float16(self):
         gc.collect()

         self.assertTrue(check_causal_lm_output_quality(model_id, generated_tokens))
+
+    @slow
+    @pytest.mark.run_slow
+    @pytest.mark.skipif(
+        parse(torchao.__version__) < parse("0.11.0.dev0"),
+        reason="Only available on torchao >= 0.11.0.dev0",
+    )
+    def test_qwen3_text_generation_with_custom_sdpa_8da4w(self):
+        model_id = "Qwen/Qwen3-0.6B"
+        prompt = "Give me a short introduction to large language model."
+        tokenizer = AutoTokenizer.from_pretrained(model_id)
+
+        # ExecuTorch model + custom sdpa
+        kwargs = {"quantize": "8da4w"}
+        model = ExecuTorchModelForCausalLM.from_pretrained(
+            model_id,
+            recipe="xnnpack",
+            attn_implementation="custom_sdpa",
+            **kwargs,
+        )
+        self.assertIsInstance(model, ExecuTorchModelForCausalLM)
+        self.assertIsInstance(model.model, ExecuTorchModule)
+        generated_text = model.text_generation(
+            tokenizer=tokenizer,
+            prompt=prompt,
+            max_seq_len=64,
+        )
+        logging.info(f"\nGenerated text:\n\t{generated_text}")
+        generated_tokens = tokenizer(generated_text, return_tensors="pt").input_ids
+
+        # Free memory before loading eager for quality check
+        del model
+        del tokenizer
+        gc.collect()
+
+        self.assertTrue(check_causal_lm_output_quality(model_id, generated_tokens))

tests/models/test_modeling_smollm.py

Lines changed: 40 additions & 0 deletions
@@ -21,7 +21,9 @@
 import unittest

 import pytest
+import torchao
 from executorch.extension.pybindings.portable_lib import ExecuTorchModule
+from packaging.version import parse
 from transformers import AutoTokenizer
 from transformers.testing_utils import slow

@@ -106,3 +108,41 @@ def test_smollm_text_generation_with_custom_sdpa(self):
         gc.collect()

         self.assertTrue(check_causal_lm_output_quality(model_id, generated_tokens))
+
+    @slow
+    @pytest.mark.run_slow
+    @pytest.mark.skipif(
+        parse(torchao.__version__) < parse("0.11.0.dev0"),
+        reason="Only available on torchao >= 0.11.0.dev0",
+    )
+    def test_smollm_text_generation_with_custom_sdpa_8da4w(self):
+        model_id = "HuggingFaceTB/SmolLM2-135M"
+        prompt = "My favourite condiment is "
+        max_seq_len = 32
+        tokenizer = AutoTokenizer.from_pretrained(model_id)
+
+        # ExecuTorch model + custom sdpa
+        kwargs = {"quantize": "8da4w"}
+        model = ExecuTorchModelForCausalLM.from_pretrained(
+            model_id,
+            recipe="xnnpack",
+            attn_implementation="custom_sdpa",
+            **kwargs,
+        )
+        self.assertIsInstance(model, ExecuTorchModelForCausalLM)
+        self.assertIsInstance(model.model, ExecuTorchModule)
+
+        generated_text = model.text_generation(
+            tokenizer=tokenizer,
+            prompt=prompt,
+            max_seq_len=max_seq_len,
+        )
+        logging.info(f"\nGenerated text:\n\t{generated_text}")
+        generated_tokens = tokenizer(generated_text, return_tensors="pt").input_ids
+
+        # Free memory before loading eager for quality check
+        del model
+        del tokenizer
+        gc.collect()
+
+        self.assertTrue(check_causal_lm_output_quality(model_id, generated_tokens))
