Commit f735b5f

Guang Yang committed
export cache_position dynamically
1 parent 22ea304 commit f735b5f

File tree: 2 files changed (+7, -4 lines)


optimum/exporters/executorch/integrations.py

Lines changed: 3 additions & 1 deletion
@@ -75,7 +75,9 @@ def export(
         )
 
         with torch.no_grad():
-            exported_program = exportable_module.export(example_input_ids, example_cache_position)
+            exported_program = exportable_module.export(
+                example_input_ids, example_cache_position, dynamic_shapes, strict
+            )
             # Apply RemoveTransposes pass to remove
             # any back-to-back transpose ops that are not needed
             # e.g. output of update_cache is transposed and
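
For context, here is a minimal, hypothetical sketch of how a wrapper like exportable_module.export can forward the new dynamic_shapes and strict arguments through to torch.export.export. The ToyDecoder module and export_module helper below are illustrative assumptions, not the repository's actual classes; only torch.export.export and its parameters are real PyTorch API.

import torch


class ToyDecoder(torch.nn.Module):
    # Hypothetical stand-in for a decoder that consumes input_ids and a cache_position tensor.
    def forward(self, input_ids, cache_position):
        # Trivial computation that touches the sequence dimension of both inputs.
        return input_ids.float() + cache_position.float()


def export_module(module, example_input_ids, example_cache_position, dynamic_shapes=None, strict=True):
    # Assumed wrapper mirroring the new call site: dynamic_shapes and strict are
    # passed through to torch.export.export instead of being fixed internally.
    with torch.no_grad():
        return torch.export.export(
            module,
            (example_input_ids, example_cache_position),
            dynamic_shapes=dynamic_shapes,
            strict=strict,
        )

Passing the shape spec in from the caller means each export recipe can decide for itself which dimensions become dynamic.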

optimum/exporters/executorch/recipes/xnnpack.py

Lines changed: 4 additions & 3 deletions
@@ -97,10 +97,11 @@ def _lower_to_executorch(
         return et_progs
 
     # Make the sequence length dim dynamic in order to leverage parallel prefill in ExecuTorch runtime.
-    seq_length = 7
+    seq_length = 3
     input_ids = torch.zeros((1, seq_length), dtype=torch.long)
-    cache_position = torch.tensor([0], dtype=torch.long)
-    dynamic_shapes = {"input_ids": {1: torch.export.Dim.DYNAMIC}, "cache_position": None}
+    cache_position = torch.tensor([0, 1, 2], dtype=torch.long).unsqueeze(0)  # llama runner expects cache_pos to be 2d
+    seq_len_dim = torch.export.Dim("seq_length_dim", max=128 - 1)
+    dynamic_shapes = {"input_ids": {1: seq_len_dim}, "cache_position": {1: seq_len_dim}}
     strict = parse(torch.__version__) != parse("2.7.0")  # Due to bug https://github.com/pytorch/pytorch/issues/150994
     exported_progs = model.export(
         input_ids=input_ids,
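
As a standalone sketch of what the new spec expresses: one named Dim shared by input_ids and a 2D cache_position, bounded by the static cache length (128 here simply matches the hard-coded max=128 - 1 in the diff). Previously cache_position was exported with a static single-element shape, which presumably prevented the exported graph from consuming a multi-token prefill; tying both sequence dimensions to the same Dim lets it accept any prompt length up to the cache bound.

import torch

max_cache_len = 128  # assumed static cache length; the diff hard-codes max=128 - 1
seq_length = 3       # small example prompt used to trace the export

input_ids = torch.zeros((1, seq_length), dtype=torch.long)
# 2D cache_position of shape (1, seq_length), matching what the llama runner expects.
cache_position = torch.arange(seq_length, dtype=torch.long).unsqueeze(0)

# One named Dim shared by both inputs keeps their sequence dimensions equal,
# and the upper bound keeps prefill lengths within the static cache.
seq_len_dim = torch.export.Dim("seq_length_dim", max=max_cache_len - 1)
dynamic_shapes = {
    "input_ids": {1: seq_len_dim},
    "cache_position": {1: seq_len_dim},
}

These example tensors and the dynamic_shapes dict are what the recipe then hands to model.export(...) above, together with the strict flag.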
