Commit e2f4cd2

fix generate by extracting seq_len from the method meta

Author: Guang Yang (committed)
Parent: 49d4b1a

1 file changed (+13, −13 lines)

optimum/executorch/modeling.py

Lines changed: 13 additions & 13 deletions
@@ -24,7 +24,6 @@
 import torch
 from huggingface_hub import hf_hub_download
 from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
-from packaging.version import parse
 from transformers import (
     AutoModelForCausalLM,
     AutoModelForImageClassification,
@@ -37,7 +36,6 @@
 )
 from transformers.utils import is_offline_mode
 
-from executorch import version as executorch_version
 from executorch.extension.pybindings.portable_lib import ExecuTorchModule, _load_for_executorch
 from executorch.kernels import quantized  # noqa
 
@@ -676,10 +674,20 @@ def generate(
         )
         max_seq_len = self.max_cache_size
         generated_tokens = []
+        seq_len = self.model.method_meta("forward").input_tensor_meta(1).sizes()[0]
 
-        if parse(executorch_version.__version__).base_version <= "0.6.0":
-            # TODO: Sequential prefill is preserved for backwards compatibility in order to run PTE generated w/o dynamic shapes.
-            # We can remove this block once the executorch runtime supports `cache_position`.
+        if seq_len > 1:
+            # The model is exported with dynamic shapes. Can support parallel prefill.
+            self.stats.on_sampling_begin()
+            logits = self.forward(
+                input_ids=torch.tensor(prompt_tokens, dtype=torch.long, device=self.device).unsqueeze(0),
+                cache_position=torch.arange(len(prompt_tokens), dtype=torch.long, device=self.device),
+            )
+            self.stats.on_sampling_end()
+            next_token = torch.argmax(logits, dim=-1)[0, -1].item()
+        else:
+            # Sequential prefill is preserved for backwards compatibility in order to run PTE generated w/o dynamic shapes.
+            # TODO: We can remove this block once the executorch runtime supports `cache_position`.
             for i, prompt_token in enumerate(prompt_tokens):
                 self.stats.on_sampling_begin()
                 logits = self.forward(
@@ -688,14 +696,6 @@ def generate(
                 )
                 self.stats.on_sampling_end()
                 next_token = torch.argmax(logits, dim=-1).item()
-        else:
-            self.stats.on_sampling_begin()
-            logits = self.forward(
-                input_ids=torch.tensor(prompt_tokens, dtype=torch.long, device=self.device).unsqueeze(0),
-                cache_position=torch.arange(len(prompt_tokens), dtype=torch.long, device=self.device),
-            )
-            self.stats.on_sampling_end()
-            next_token = torch.argmax(logits, dim=-1)[0, -1].item()
         self.stats.on_prompt_eval_end()
         first_token_generated = False
 
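The heart of the fix: instead of gating prefill behavior on the installed ExecuTorch version, generate() now inspects the exported program's method metadata and reads the size of input tensor 1 of the "forward" method (the cache_position tensor). If that dimension is greater than 1, the PTE was exported with dynamic shapes and the whole prompt can be prefilled in a single forward call; otherwise the old token-by-token prefill is kept. Below is a minimal standalone sketch of that check, not part of this commit; the "model.pte" path is a placeholder, and treating input index 1 as cache_position is an assumption carried over from the diff above.

from executorch.extension.pybindings.portable_lib import _load_for_executorch

# Load an exported ExecuTorch program (placeholder path).
module = _load_for_executorch("model.pte")

# Input tensor 1 of "forward" is assumed to be `cache_position`; its first
# dimension is the sequence length the program was exported with.
seq_len = module.method_meta("forward").input_tensor_meta(1).sizes()[0]

if seq_len > 1:
    # Exported with dynamic shapes: the full prompt can be prefilled in one call.
    print(f"parallel prefill (up to {seq_len} tokens per forward)")
else:
    # Static single-token input: fall back to sequential, token-by-token prefill.
    print("sequential prefill")

Tying the decision to the artifact's own exported shape, rather than to a runtime version string, lets the same runtime serve both older single-token PTEs and newer dynamically shaped ones.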