24 changes: 15 additions & 9 deletions optimum/executorch/modeling.py
@@ -92,8 +92,12 @@ def __init__(self, models: Dict[str, "ExecuTorchModule"], config: "PretrainedCon
f"This attribute is used to identify the corresponding AutoModel class."
)

for key, value in models.items():
setattr(self, key, value)
if len(models) == 1:
# For a single PTE, always set the attribute to "model"
setattr(self, "model", next(iter(models.values())))
else:
for key, value in models.items():
setattr(self, key, value)

self.stats = Stats()

@@ -570,8 +574,8 @@ class ExecuTorchModelForCausalLM(ExecuTorchModelBase):
Data type of the model parameters.
bos_token_id (`int`):
Beginning-of-sequence token ID.
eos_token_id (`int`):
End-of-sequence token ID.
eos_token_ids (`List[int]`):
End-of-sequence token IDs.
vocab_size (`int`):
Size of the model vocabulary.
"""
@@ -594,8 +598,10 @@ def __init__(self, models: Dict[str, "ExecuTorchModule"], config: "PretrainedCon
self.dtype = self.model.run_method("get_dtype")[0]
if "get_bos_id" in metadata:
self.bos_token_id = self.model.run_method("get_bos_id")[0]
if "get_eos_id" in metadata:
self.eos_token_id = self.model.run_method("get_eos_id")[0]
for key in ("get_eos_id", "get_eos_ids"):
if key in metadata:
self.eos_token_ids = self.model.run_method(key)
break
if "get_vocab_size" in metadata:
self.vocab_size = self.model.run_method("get_vocab_size")[0]
if "use_sdpa_with_kv_cache" in metadata:
@@ -694,7 +700,7 @@ def generate(
next_token = torch.argmax(logits, dim=-1).item()
generated_tokens.append(next_token)

if next_token == self.eos_token_id:
if next_token in self.eos_token_ids:
break

self.stats.set_num_generated_tokens(len(generated_tokens) - len(prompt_tokens))
@@ -730,9 +736,9 @@ def text_generation(
raise ValueError(
f"The tokenizer's bos_token_id={self.tokenizer.bos_token_id} must be the same as the model's bos_token_id={self.bos_token_id}."
)
if self.tokenizer.eos_token_id is not None and self.tokenizer.eos_token_id != self.eos_token_id:
if self.tokenizer.eos_token_id is not None and self.tokenizer.eos_token_id not in self.eos_token_ids:
raise ValueError(
f"The tokenizer's eos_token_id={self.tokenizer.eos_token_id} must be the same as the model's eos_token_id={self.eos_token_id}."
f"The tokenizer's eos_token_id={self.tokenizer.eos_token_id} must match with the model's eos_token_ids={self.eos_token_ids}."
)

# Reset stats for a new generation
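For context, a minimal usage sketch of the generation flow these changes affect. The model id and prompt are illustrative, and the top-level `optimum.executorch` import path is assumed; the `from_pretrained` and `text_generation` signatures mirror the tests further down in this PR.

from transformers import AutoTokenizer
from optimum.executorch import ExecuTorchModelForCausalLM

model_id = "my-org/my-causal-lm"  # hypothetical model id, for illustration only
model = ExecuTorchModelForCausalLM.from_pretrained(model_id, recipe="xnnpack", export=True)
tokenizer = AutoTokenizer.from_pretrained(model_id)

# With a single exported PTE the program is exposed as `model.model`, and
# generation now stops as soon as the sampled token is in `model.eos_token_ids`.
print(model.text_generation(tokenizer=tokenizer, prompt="My favourite condiment is ", max_seq_len=64))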
22 changes: 22 additions & 0 deletions optimum/executorch/passes/remove_padding_idx_embedding_pass.py
@@ -0,0 +1,22 @@
import torch
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult


class RemovePaddingIdxEmbeddingPass(ExportPass):
"""
An ExportPass that removes the `padding_idx` argument (the third positional
argument) from every `aten.embedding.default` call in the graph.
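Note: `padding_idx` only affects how gradients flow to the embedding weight;
it does not change the forward output, so dropping it does not alter inference results.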
"""

def __init__(self) -> None:
super().__init__()

def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
for node in graph_module.graph.nodes:
if node.op == "call_function" and node.target == exir_ops.edge.aten.embedding.default:
# node.args[2] is the padding_idx
if len(node.args) == 3:
node.args = (node.args[0], node.args[1])
graph_module.recompile()
return PassResult(graph_module, True)
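For reference, a minimal sketch of applying the pass when lowering an exported program. The toy module is illustrative; the `to_edge_transform_and_lower` call mirrors the recipe change and test added below.

import torch
from executorch.exir import to_edge_transform_and_lower

from optimum.executorch.passes.remove_padding_idx_embedding_pass import RemovePaddingIdxEmbeddingPass


class ToyEmbedding(torch.nn.Module):  # illustrative module, not part of this PR
    def __init__(self):
        super().__init__()
        self.emb = torch.nn.Embedding(10, 4, padding_idx=0)

    def forward(self, x):
        return self.emb(x)


exported = torch.export.export(ToyEmbedding(), (torch.tensor([1, 2, 3]),))
# The pass drops the third positional argument (padding_idx) from every
# aten.embedding.default node before partitioning and lowering.
edge_program = to_edge_transform_and_lower(exported, transform_passes=[RemovePaddingIdxEmbeddingPass()])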
3 changes: 3 additions & 0 deletions optimum/exporters/executorch/recipes/xnnpack.py
@@ -28,6 +28,7 @@
ExecutorchProgram,
to_edge_transform_and_lower,
)
from optimum.executorch.passes.remove_padding_idx_embedding_pass import RemovePaddingIdxEmbeddingPass

from ..integrations import (
CausalLMExportableModule,
@@ -76,9 +77,11 @@ def _lower_to_executorch(
exported_program,
partitioner=[XnnpackPartitioner()],
compile_config=EdgeCompileConfig(
_check_ir_validity=False,
_skip_dim_order=True,
),
constant_methods=metadata,
transform_passes=[RemovePaddingIdxEmbeddingPass()],
).to_executorch(
config=ExecutorchBackendConfig(**backend_config_dict),
)
5 changes: 5 additions & 0 deletions optimum/exporters/executorch/tasks/causal_lm.py
@@ -77,6 +77,11 @@ def load_causal_lm_model(model_name_or_path: str, **kwargs) -> CausalLMExportabl
),
)

for param in eager_model.parameters():
# Must disable gradient for quantized checkpoint
if isinstance(param, torchao.utils.TorchAOBaseTensor):
param.requires_grad = False

# TODO: Move quantization recipe out for better composability.
# TODO: Should switch to `TorchAoConfig` once the quant issue on final lm_head layer is fixed.
qlinear_config = kwargs.get("qlinear", None)
30 changes: 30 additions & 0 deletions tests/models/test_modeling_common.py
@@ -153,3 +153,33 @@ def test_eager_text_generation_with_custom_sdpa(self):
eager_generated_text = tokenizer.batch_decode(eager_generated_ids, skip_special_tokens=True)[0]
logging.info(f"\nEager generated text:\n\t{eager_generated_text}")
self.assertTrue(check_causal_lm_output_quality(model_id, eager_generated_ids))

def test_removing_padding_idx_embedding_pass(self):
class ModuleWithEmbedding(torch.nn.Module):
def __init__(self):
super().__init__()
self.emb = torch.nn.Embedding(10, 3, padding_idx=0)

def forward(self, x):
return self.emb(x) + torch.ops.aten.embedding.default(self.emb.weight, x, padding_idx=1)

test_model = ModuleWithEmbedding()
example_inputs = (torch.tensor([1, 2, 4, 5, 4, 3, 2, 9]),)
exported_model = torch.export.export(test_model, example_inputs)

from executorch.exir import to_edge_transform_and_lower
from executorch.exir.dialects._ops import ops as exir_ops

from optimum.executorch.passes.remove_padding_idx_embedding_pass import RemovePaddingIdxEmbeddingPass

et_model = to_edge_transform_and_lower(
exported_model,
transform_passes=[RemovePaddingIdxEmbeddingPass()],
)
self.assertTrue(
all(
len(node.args) < 3
for node in et_model.exported_program().graph_module.graph.nodes
if node.op == "call_function" and node.target == exir_ops.edge.aten.embedding.default
)
)
102 changes: 101 additions & 1 deletion tests/models/test_modeling_phi4.py
@@ -15,10 +15,13 @@

import gc
import logging
import os
import unittest

import pytest
import torchao
from executorch.extension.pybindings.portable_lib import ExecuTorchModule
from packaging.version import parse
from transformers import AutoConfig, AutoTokenizer
from transformers.testing_utils import slow

@@ -27,13 +30,21 @@
from ..utils import check_causal_lm_output_quality


@pytest.mark.skip(reason="Test Phi-4-mini (3.8B) will require runner to be configured with larger RAM")
os.environ["TOKENIZERS_PARALLELISM"] = "false"

is_ci = os.environ.get("GITHUB_ACTIONS") == "true"


class ExecuTorchModelIntegrationTest(unittest.TestCase):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

@slow
@pytest.mark.run_slow
@pytest.mark.skipif(
is_ci,
reason="Test Phi-4-mini (3.8B) will require runner to be configured with larger RAM",
)
def test_phi4_text_generation(self):
model_id = "microsoft/Phi-4-mini-instruct"
config = AutoConfig.from_pretrained(model_id)
@@ -61,3 +72,92 @@ def test_phi4_text_generation(self):
gc.collect()

self.assertTrue(check_causal_lm_output_quality(model_id, generated_tokens))

@slow
@pytest.mark.run_slow
@pytest.mark.skipif(
parse(torchao.__version__) < parse("0.11.0.dev0"),
reason="Only available on torchao >= 0.11.0.dev0",
)
def test_phi4_text_generation_with_quantized_pte_from_hub(self):
model_id = "pytorch/Phi-4-mini-instruct-8da4w"
config = AutoConfig.from_pretrained(model_id)
# NOTE: To make the model exportable we need to set the rope scaling to default to avoid hitting
# the data-dependent control flow in _longrope_frequency_update. Alternatively, we can rewrite
# that function to avoid the data-dependent control flow.
if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
config.rope_scaling["type"] = "default"
model = ExecuTorchModelForCausalLM.from_pretrained(
model_id, recipe="xnnpack", config=config, file_name="phi4-mini-8da4w.pte"
)
self.assertIsInstance(model, ExecuTorchModelForCausalLM)
self.assertIsInstance(model.model, ExecuTorchModule)

tokenizer = AutoTokenizer.from_pretrained(model_id)
generated_text = model.text_generation(
tokenizer=tokenizer,
prompt="My favourite condiment is ",
max_seq_len=64,
)
logging.info(f"\nGenerated text:\n\t{generated_text}")

if not is_ci:
generated_tokens = tokenizer(generated_text, return_tensors="pt").input_ids

# Free memory before loading eager for quality check
del model
del tokenizer
gc.collect()

self.assertTrue(
check_causal_lm_output_quality(
"microsoft/Phi-4-mini-instruct",
generated_tokens,
)
)

@slow
@pytest.mark.run_slow
@pytest.mark.skipif(
parse(torchao.__version__) < parse("0.11.0.dev0"),
reason="Only available on torchao >= 0.11.0.dev0",
)
def test_phi4_text_generation_with_quantized_ckp(self):
model_id = "pytorch/Phi-4-mini-instruct-8da4w"
config = AutoConfig.from_pretrained(model_id)
# NOTE: To make the model exportable we need to set the rope scaling to default to avoid hitting
# the data-dependent control flow in _longrope_frequency_update. Alternatively, we can rewrite
# that function to avoid the data-dependent control flow.
if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
config.rope_scaling["type"] = "default"
model = ExecuTorchModelForCausalLM.from_pretrained(
model_id,
recipe="xnnpack",
config=config,
export=True,
)
self.assertIsInstance(model, ExecuTorchModelForCausalLM)
self.assertIsInstance(model.model, ExecuTorchModule)

tokenizer = AutoTokenizer.from_pretrained(model_id)
generated_text = model.text_generation(
tokenizer=tokenizer,
prompt="My favourite condiment is ",
max_seq_len=64,
)
logging.info(f"\nGenerated text:\n\t{generated_text}")

if not is_ci:
generated_tokens = tokenizer(generated_text, return_tensors="pt").input_ids

# Free memory before loading eager for quality check
del model
del tokenizer
gc.collect()

self.assertTrue(
check_causal_lm_output_quality(
"microsoft/Phi-4-mini-instruct",
generated_tokens,
)
)