
Commit 608733c

Author: Guang Yang

fix generate by extracting seq_len from the method meta

1 parent 49d4b1a · commit 608733c

2 files changed: +25, -40 lines

optimum/executorch/modeling.py

Lines changed: 13 additions & 13 deletions
@@ -24,7 +24,6 @@
 import torch
 from huggingface_hub import hf_hub_download
 from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
-from packaging.version import parse
 from transformers import (
     AutoModelForCausalLM,
     AutoModelForImageClassification,
@@ -37,7 +36,6 @@
 )
 from transformers.utils import is_offline_mode
 
-from executorch import version as executorch_version
 from executorch.extension.pybindings.portable_lib import ExecuTorchModule, _load_for_executorch
 from executorch.kernels import quantized  # noqa
 
@@ -676,10 +674,20 @@ def generate(
         )
         max_seq_len = self.max_cache_size
         generated_tokens = []
+        seq_len = self.model.method_meta("forward").input_tensor_meta(1).sizes()[0]
 
-        if parse(executorch_version.__version__).base_version <= "0.6.0":
-            # TODO: Sequential prefill is preserved for backwards compatibility in order to run PTE generated w/o dynamic shapes.
-            # We can remove this block once the executorch runtime supports `cache_position`.
+        if seq_len > 1:
+            # The model is exported with dynamic shapes. Can support parallel prefill.
+            self.stats.on_sampling_begin()
+            logits = self.forward(
+                input_ids=torch.tensor(prompt_tokens, dtype=torch.long, device=self.device).unsqueeze(0),
+                cache_position=torch.arange(len(prompt_tokens), dtype=torch.long, device=self.device),
+            )
+            self.stats.on_sampling_end()
+            next_token = torch.argmax(logits, dim=-1)[0, -1].item()
+        else:
+            # Sequential prefill is preserved for backwards compatibility in order to run PTE generated w/o dynamic shapes.
+            # TODO: We can remove this block once the executorch runtime supports `cache_position`.
             for i, prompt_token in enumerate(prompt_tokens):
                 self.stats.on_sampling_begin()
                 logits = self.forward(
@@ -688,14 +696,6 @@ def generate(
                 )
                 self.stats.on_sampling_end()
                 next_token = torch.argmax(logits, dim=-1).item()
-        else:
-            self.stats.on_sampling_begin()
-            logits = self.forward(
-                input_ids=torch.tensor(prompt_tokens, dtype=torch.long, device=self.device).unsqueeze(0),
-                cache_position=torch.arange(len(prompt_tokens), dtype=torch.long, device=self.device),
-            )
-            self.stats.on_sampling_end()
-            next_token = torch.argmax(logits, dim=-1)[0, -1].item()
         self.stats.on_prompt_eval_end()
         first_token_generated = False
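
The gist of the change above: rather than branching on the installed executorch version, generate() now asks the exported program itself whether it was built with dynamic shapes, by reading the recorded size of the second input of its "forward" method. Below is a minimal, standalone sketch of that check; the model.pte path is a hypothetical placeholder, and treating input tensor 1 as the one whose first dimension is the prefill sequence length mirrors the call chain in the diff rather than a documented guarantee.

from executorch.extension.pybindings.portable_lib import _load_for_executorch

# Hypothetical path to an exported ExecuTorch program (.pte).
module = _load_for_executorch("model.pte")

# Same call chain as in the diff: read the static size recorded for the
# second input tensor of the "forward" method and take its first dimension.
seq_len = module.method_meta("forward").input_tensor_meta(1).sizes()[0]

if seq_len > 1:
    # Exported with dynamic shapes: the whole prompt can be prefilled in one
    # forward call (parallel prefill), passing a cache_position range.
    print(f"parallel prefill, seq_len = {seq_len}")
else:
    # Exported with a fixed length of 1: fall back to feeding prompt tokens
    # one at a time (sequential prefill).
    print("sequential prefill")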

tests/models/test_modeling_phi4.py

Lines changed: 12 additions & 27 deletions
@@ -16,14 +16,10 @@
 import gc
 import logging
 import os
-import sys
 import unittest
 
 import pytest
-import torchao
-import transformers
 from executorch.extension.pybindings.portable_lib import ExecuTorchModule
-from packaging.version import parse
 from transformers import AutoConfig, AutoTokenizer
 from transformers.testing_utils import slow
 
@@ -33,8 +29,8 @@
 
 
 os.environ["TOKENIZERS_PARALLELISM"] = "false"
+
 is_ci = os.environ.get("GITHUB_ACTIONS") == "true"
-is_linux_ci = sys.platform.startswith("linux") and os.environ.get("GITHUB_ACTIONS") == "true"
 
 
 class ExecuTorchModelIntegrationTest(unittest.TestCase):
@@ -44,47 +40,36 @@ def __init__(self, *args, **kwargs):
     @slow
     @pytest.mark.run_slow
     @pytest.mark.skipif(
-        is_linux_ci
-        or parse(transformers.__version__) < parse("4.52.0")
-        or parse(torchao.__version__) < parse("0.11.0"),
-        reason="Only available on transformers >= 4.52.0 and torchao >= 0.11.0. OOM on linux runner.",
+        is_ci,
+        reason="Test Phi-4-mini (3.8B) will require runner to be configured with larger RAM",
     )
-    def test_phi4_text_generation_with_custom_sdpa_and_kv_cache_8da4w_8we(self):
+    def test_phi4_text_generation(self):
         model_id = "microsoft/Phi-4-mini-instruct"
         config = AutoConfig.from_pretrained(model_id)
         # NOTE: To make the model exportable we need to set the rope scaling to default to avoid hitting
         # the data-dependent control flow in _longrope_frequency_update. Alternatively, we can rewrite
         # that function to avoid the data-dependent control flow.
         if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
             config.rope_scaling["type"] = "default"
-        model = ExecuTorchModelForCausalLM.from_pretrained(
-            model_id,
-            recipe="xnnpack",
-            config=config,
-            attn_implementation="custom_sdpa",
-            use_custom_kv_cache=True,
-            **{"qlinear": True, "qembeeding": True},
-        )
+        model = ExecuTorchModelForCausalLM.from_pretrained(model_id, recipe="xnnpack", config=config)
         self.assertIsInstance(model, ExecuTorchModelForCausalLM)
         self.assertIsInstance(model.model, ExecuTorchModule)
 
         tokenizer = AutoTokenizer.from_pretrained(model_id)
         generated_text = model.text_generation(
             tokenizer=tokenizer,
             prompt="My favourite condiment is ",
-            max_seq_len=64,
+            max_seq_len=32,
         )
         logging.info(f"\nGenerated text:\n\t{generated_text}")
+        generated_tokens = tokenizer(generated_text, return_tensors="pt").input_ids
 
-        if not is_ci:
-            generated_tokens = tokenizer(generated_text, return_tensors="pt").input_ids
-
-            # Free memory before loading eager for quality check
-            del model
-            del tokenizer
-            gc.collect()
+        # Free memory before loading eager for quality check
+        del model
+        del tokenizer
+        gc.collect()
 
-            self.assertTrue(check_causal_lm_output_quality(model_id, generated_tokens))
+        self.assertTrue(check_causal_lm_output_quality(model_id, generated_tokens))
 
     @slow
     @pytest.mark.run_slow
