
Commit d0e3764

Enable prefill for running CausalLM using ET runtime

1 parent 6dc9aa2
3 files changed: +34 -13 lines changed


optimum/executorch/modeling.py

Lines changed: 7 additions & 8 deletions

@@ -623,6 +623,7 @@ def forward(
             torch.Tensor: Logits output from the model.
         """
         self.stats.on_model_execution_start()
+        print(f"DEBUG: {self.model.method_meta('forward')}")
         logits = self.model.forward((input_ids, cache_position))[0]
         self.stats.on_model_execution_end()
         return logits
@@ -667,14 +668,12 @@ def generate(
         max_seq_len = self.max_cache_size
         generated_tokens = []

-        # prefill
-        for i, prompt_token in enumerate(prompt_tokens):
-            self.stats.on_sampling_begin()
-            logits = self.forward(
-                input_ids=torch.tensor([prompt_token], dtype=torch.long, device=self.device).unsqueeze(0),
-                cache_position=torch.tensor([i], dtype=torch.long, device=self.device),
-            )
-            self.stats.on_sampling_end()
+        self.stats.on_sampling_begin()
+        logits = self.forward(
+            input_ids=torch.tensor(prompt_tokens, dtype=torch.long, device=self.device).unsqueeze(0),
+            cache_position=torch.tensor([0], dtype=torch.long, device=self.device),
+        )
+        self.stats.on_sampling_end()

         self.stats.on_prompt_eval_end()
         first_token_generated = False
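
For context, the change above replaces per-token prefill with a single batched forward over the whole prompt. A minimal sketch of the before/after behavior (illustrative only; model_forward stands in for the module's self.forward, and the stats hooks are omitted):

    import torch

    def prefill_per_token(model_forward, prompt_tokens, device="cpu"):
        # Old behavior: one forward call per prompt token, writing one
        # KV-cache slot at position i per call.
        logits = None
        for i, token in enumerate(prompt_tokens):
            logits = model_forward(
                input_ids=torch.tensor([token], dtype=torch.long, device=device).unsqueeze(0),
                cache_position=torch.tensor([i], dtype=torch.long, device=device),
            )
        return logits  # logits for the last prompt token

    def prefill_batched(model_forward, prompt_tokens, device="cpu"):
        # New behavior: one forward call over the full prompt, starting at
        # cache position 0, so the runtime can fill all prompt positions in
        # a single execution.
        return model_forward(
            input_ids=torch.tensor(prompt_tokens, dtype=torch.long, device=device).unsqueeze(0),
            cache_position=torch.tensor([0], dtype=torch.long, device=device),
        )

The batched variant requires the exported program to accept a variable-length sequence dimension, which is what the two changes below set up.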

optimum/exporters/executorch/integrations.py

Lines changed: 14 additions & 4 deletions

@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from typing import Dict
+from typing import Dict, Optional

 import torch
 from torch.export import ExportedProgram
@@ -43,7 +43,13 @@ def __init__(self, model):
         self.config = model.config
         self.metadata = save_config_to_constant_methods(model.config, model.generation_config)

-    def export(self, input_ids=None, cache_position=None) -> Dict[str, ExportedProgram]:
+    def export(
+        self,
+        input_ids=None,
+        cache_position=None,
+        dynamic_shapes: Optional[dict] = None,
+        strict: Optional[bool] = None,
+    ) -> Dict[str, ExportedProgram]:
         example_input_ids = input_ids if input_ids is not None else torch.tensor([[1]], dtype=torch.long)
         example_cache_position = cache_position if cache_position is not None else torch.tensor([0], dtype=torch.long)

@@ -57,13 +63,17 @@ def export(self, input_ids=None, cache_position=None) -> Dict[str, ExportedProgram]:
             exportable_module = TorchExportableModuleForDecoderOnlyLM(self.model, max_batch_size, max_cache_len)

             with torch.no_grad():
-                exported_program = exportable_module.export(example_input_ids, example_cache_position)
+                exported_program = exportable_module.export(
+                    example_input_ids, example_cache_position, dynamic_shapes, strict
+                )
         else:
             from transformers.integrations.executorch import (
                 convert_and_export_with_cache,
             )

-            exported_program = convert_and_export_with_cache(self.model, example_input_ids, example_cache_position)
+            exported_program = convert_and_export_with_cache(
+                self.model, example_input_ids, example_cache_position, dynamic_shapes, strict
+            )

         return {"model": exported_program}
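
The new dynamic_shapes argument follows the torch.export convention: a per-input spec mapping dimension indices to dynamic-dim markers, with None meaning fully static. A standalone sketch of that convention on a toy module (illustrative only, not this repo's code):

    import torch

    class Toy(torch.nn.Module):
        def forward(self, input_ids, cache_position):
            return input_ids.float().mean(dim=1) + cache_position.float()

    example_input_ids = torch.zeros((1, 7), dtype=torch.long)
    example_cache_position = torch.tensor([0], dtype=torch.long)

    # Dim 1 of input_ids (the sequence length) is dynamic; cache_position
    # keeps a static shape.
    dynamic_shapes = {"input_ids": {1: torch.export.Dim.DYNAMIC}, "cache_position": None}

    exported = torch.export.export(
        Toy(),
        (example_input_ids, example_cache_position),
        dynamic_shapes=dynamic_shapes,
    )
    print(exported)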

optimum/exporters/executorch/recipes/xnnpack.py

Lines changed: 13 additions & 1 deletion

@@ -15,6 +15,7 @@
 import logging
 from typing import Dict, Union

+import torch
 from packaging.version import parse
 from tabulate import tabulate
 from torch.export import ExportedProgram
@@ -95,7 +96,18 @@ def _lower_to_executorch(
            )
        return et_progs

-    exported_progs = model.export()
+    # Make the sequence length dim dynamic in order to leverage parallel prefill in the ExecuTorch runtime.
+    seq_length = 7
+    input_ids = torch.zeros((1, seq_length), dtype=torch.long)
+    cache_position = torch.tensor([0], dtype=torch.long)
+    dynamic_shapes = {"input_ids": {1: torch.export.Dim.DYNAMIC}, "cache_position": None}
+    strict = parse(torch.__version__) != parse("2.7.0")  # Due to bug https://github.com/pytorch/pytorch/issues/150994
+    exported_progs = model.export(
+        input_ids=input_ids,
+        cache_position=cache_position,
+        dynamic_shapes=dynamic_shapes,
+        strict=strict,
+    )

     if model.config._attn_implementation == "custom_sdpa":
         # Sanity check to make sure the exported program contains the custom sdpa operator.
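
Because the sequence dimension is exported as dynamic, the same program can be re-run with prompts of different lengths, which is what enables the single-call prefill in modeling.py. A self-contained sketch of both the version-gated strict flag and the multi-length behavior (illustrative only; it reuses the toy module from the previous sketch rather than the real CausalLM export):

    import torch
    from packaging.version import parse

    class Toy(torch.nn.Module):
        def forward(self, input_ids, cache_position):
            return input_ids.float().sum(dim=1, keepdim=True) + cache_position.float()

    # Non-strict export only on torch 2.7.0, where strict mode hits
    # https://github.com/pytorch/pytorch/issues/150994.
    strict = parse(torch.__version__) != parse("2.7.0")

    exported = torch.export.export(
        Toy(),
        (torch.zeros((1, 7), dtype=torch.long), torch.tensor([0], dtype=torch.long)),
        dynamic_shapes={"input_ids": {1: torch.export.Dim.DYNAMIC}, "cache_position": None},
        strict=strict,
    )

    # The exported module now accepts varying sequence lengths.
    for seq_len in (2, 8, 32):
        out = exported.module()(torch.zeros((1, seq_len), dtype=torch.long), torch.tensor([0], dtype=torch.long))
        print(seq_len, tuple(out.shape))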
