import torch
from huggingface_hub import hf_hub_download
from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
-from packaging.version import parse
from transformers import (
    AutoModelForCausalLM,
    AutoModelForImageClassification,
)
from transformers.utils import is_offline_mode

-from executorch import version as executorch_version
from executorch.extension.pybindings.portable_lib import ExecuTorchModule, _load_for_executorch
from executorch.kernels import quantized  # noqa

@@ -676,10 +674,20 @@ def generate(
        )
        max_seq_len = self.max_cache_size
        generated_tokens = []
+        seq_len = self.model.method_meta("forward").input_tensor_meta(1).sizes()[0]

-        if parse(executorch_version.__version__).base_version <= "0.6.0":
-            # TODO: Sequential prefill is preserved for backwards compatibility in order to run PTE generated w/o dynamic shapes.
-            # We can remove this block once the executorch runtime supports `cache_position`.
+        if seq_len > 1:
+            # The model is exported with dynamic shapes, so it can support parallel prefill.
+            self.stats.on_sampling_begin()
+            logits = self.forward(
+                input_ids=torch.tensor(prompt_tokens, dtype=torch.long, device=self.device).unsqueeze(0),
+                cache_position=torch.arange(len(prompt_tokens), dtype=torch.long, device=self.device),
+            )
+            self.stats.on_sampling_end()
+            next_token = torch.argmax(logits, dim=-1)[0, -1].item()
+        else:
+            # Sequential prefill is preserved for backwards compatibility in order to run PTEs generated w/o dynamic shapes.
+            # TODO: We can remove this block once the executorch runtime supports `cache_position`.
            for i, prompt_token in enumerate(prompt_tokens):
                self.stats.on_sampling_begin()
                logits = self.forward(
@@ -688,14 +696,6 @@ def generate(
                )
                self.stats.on_sampling_end()
                next_token = torch.argmax(logits, dim=-1).item()
-        else:
-            self.stats.on_sampling_begin()
-            logits = self.forward(
-                input_ids=torch.tensor(prompt_tokens, dtype=torch.long, device=self.device).unsqueeze(0),
-                cache_position=torch.arange(len(prompt_tokens), dtype=torch.long, device=self.device),
-            )
-            self.stats.on_sampling_end()
-            next_token = torch.argmax(logits, dim=-1)[0, -1].item()
        self.stats.on_prompt_eval_end()
        first_token_generated = False

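
The new branch picks between parallel and sequential prefill by reading the sequence length baked into the exported "forward" method. Below is a minimal standalone sketch of that same probe, not part of this PR: the path "model.pte" is a placeholder, input index 1 mirrors the lookup used above, and the method_meta/input_tensor_meta accessors are assumed to behave as they do in the diff.

# Hypothetical standalone check (illustration only): inspect a .pte to see
# whether its "forward" method was exported with a prompt dimension > 1,
# which is the same signal the new `seq_len > 1` branch keys off.
from executorch.extension.pybindings.portable_lib import _load_for_executorch

module = _load_for_executorch("model.pte")  # placeholder path
seq_len = module.method_meta("forward").input_tensor_meta(1).sizes()[0]

if seq_len > 1:
    # Dynamic-shape export: the whole prompt can be prefilled in one forward call.
    print(f"seq_len={seq_len}: parallel prefill supported")
else:
    # Static export (seq_len == 1): prefill must feed one token at a time.
    print("seq_len=1: falling back to sequential prefill")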