
Commit f1e2548

Fix max seq length bug
Differential Revision: D84562463
Pull Request resolved: #15084
1 parent 6e0c9f6 commit f1e2548

3 files changed, +4 -8 lines changed

extension/llm/custom_ops/custom_ops.py

Lines changed: 2 additions & 2 deletions
@@ -207,8 +207,8 @@ def _validate_update_cache_params(
         1
     ), f"Start position {start_pos} must be less than sequence length {cache.size(1)}"
 
-    torch._check((start_pos + seq_len) < cache.size(1))
-    assert (start_pos + seq_len) < cache.size(
+    torch._check((start_pos + seq_len) <= cache.size(1))
+    assert (start_pos + seq_len) <= cache.size(
         1
     ), f"Start position + length = {start_pos + seq_len} must be less than sequence length {cache.size(1)}"
 
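The relaxed bound allows the boundary case where an update fills the cache exactly to its last slot. A minimal sketch of that case with hypothetical sizes (a cache of length 8 and a 2-token update at position 6; not the actual op implementation):

import torch

# Hypothetical sizes, for illustration only.
max_seq_len = 8
cache = torch.zeros(1, max_seq_len, 4)   # (batch, seq_len, head_dim)
start_pos, seq_len = 6, 2                # update writes positions 6 and 7

# Old bound: (start_pos + seq_len) <  cache.size(1)  ->  8 <  8, rejected.
# New bound: (start_pos + seq_len) <= cache.size(1)  ->  8 <= 8, accepted,
# and the slice below still stays within the cache.
assert (start_pos + seq_len) <= cache.size(1)
cache[:, start_pos : start_pos + seq_len, :] = torch.ones(1, seq_len, 4)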

extension/llm/export/builder.py

Lines changed: 1 addition & 5 deletions
@@ -144,12 +144,8 @@ def __init__(
         else:
             # Two input arguments: tokens and input_pos but input_pos is static shape.
 
-            # A runtime assertion is added by torch.ops.llama.update_cache requires that
-            # L['tokens'].size()[1] + input_pos[0].item() < self.max_seq_len
-            # This consttaint L['tokens'].size()[1] to be elf.max_seq_len-1
-            # run with TORCH_LOGS=+dynamic for details
             self.dynamic_shapes = (
-                {1: torch.export.Dim("token_dim", max=self.max_seq_len - 1)},
+                {1: torch.export.Dim("token_dim", max=self.max_seq_len)},
                 {"input_pos": {0: 1}},
             )
 
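For reference, a minimal sketch of how a dynamic_shapes spec like this is consumed by torch.export.export. The TokensAndPos module, the max_seq_len value, and the tensor shapes below are illustrative assumptions, not the builder's actual model; the builder also keys the second entry by the input_pos kwarg name, whereas this sketch passes both inputs positionally.

import torch

class TokensAndPos(torch.nn.Module):
    # Toy stand-in for a forward(tokens, input_pos) signature; assumption only.
    def forward(self, tokens: torch.Tensor, input_pos: torch.Tensor) -> torch.Tensor:
        return tokens + input_pos

max_seq_len = 128  # hypothetical value

ep = torch.export.export(
    TokensAndPos(),
    (torch.zeros(1, 16, dtype=torch.long), torch.zeros(1, dtype=torch.long)),
    dynamic_shapes=(
        # tokens: dim 1 may now range up to max_seq_len (was max_seq_len - 1)
        {1: torch.export.Dim("token_dim", max=max_seq_len)},
        # input_pos: dim 0 is static with size 1
        {0: 1},
    ),
)
print(ep)

Exporting with max=max_seq_len rather than max_seq_len - 1 lets the exported program accept token sequences up to the full cache length, matching the relaxed runtime check above.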

extension/llm/export/test/test_builder.py

Lines changed: 1 addition & 1 deletion
@@ -88,7 +88,7 @@ def test_get_dynamic_shape_with_dynamic_shape_enabled_with_kv_cache(self) -> None:
         # Check first element (tokens dimension)
         self.assertIsInstance(result[0], dict)
         self.assertIn(1, result[0])
-        self.assertEqual(result[0][1].max, self.max_seq_len - 1)
+        self.assertEqual(result[0][1].max, self.max_seq_len)
 
         # Check second element (input_pos dimension)
         self.assertIsInstance(result[1], dict)
