3 files changed: +4 −8 lines

@@ -207,8 +207,8 @@ def _validate_update_cache_params(
         1
     ), f"Start position {start_pos} must be less than sequence length {cache.size(1)}"

-    torch._check((start_pos + seq_len) < cache.size(1))
-    assert (start_pos + seq_len) < cache.size(
+    torch._check((start_pos + seq_len) <= cache.size(1))
+    assert (start_pos + seq_len) <= cache.size(
         1
     ), f"Start position + length = {start_pos + seq_len} must be less than sequence length {cache.size(1)}"
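Why the inclusive bound is the right one: writing seq_len new entries starting at start_pos touches indices start_pos through start_pos + seq_len - 1, so the update fits whenever start_pos + seq_len <= cache.size(1). A minimal standalone sketch (the tensor sizes here are made up for illustration, not taken from this PR):

import torch

# Hypothetical KV-cache update illustrating the bound checked above.
cache = torch.zeros(1, 8, 4)   # (batch, max_seq_len=8, head_dim=4)
start_pos, seq_len = 6, 2      # writes rows 6 and 7 -- the last valid slots
new_vals = torch.ones(1, seq_len, 4)

# start_pos + seq_len == cache.size(1) is still in bounds because the slice
# end is exclusive; the strict `<` check rejected exactly this case and
# effectively wasted the final cache slot.
assert (start_pos + seq_len) <= cache.size(1)
cache[:, start_pos : start_pos + seq_len] = new_vals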
@@ -144,12 +144,8 @@ def __init__(
         else:
             # Two input arguments: tokens and input_pos but input_pos is static shape.

-            # A runtime assertion is added by torch.ops.llama.update_cache requires that
-            # L['tokens'].size()[1] + input_pos[0].item() < self.max_seq_len
-            # This consttaint L['tokens'].size()[1] to be elf.max_seq_len-1
-            # run with TORCH_LOGS=+dynamic for details
             self.dynamic_shapes = (
-                {1: torch.export.Dim("token_dim", max=self.max_seq_len - 1)},
+                {1: torch.export.Dim("token_dim", max=self.max_seq_len)},
                 {"input_pos": {0: 1}},
             )
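Dim's max is an inclusive upper bound on the exported dimension, so with the runtime check above relaxed to <=, the token dimension can be bounded by max_seq_len itself rather than max_seq_len - 1. A small self-contained sketch of such an export call (the Toy module and sizes are hypothetical, not the model wired up in this file):

import torch

class Toy(torch.nn.Module):
    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        return tokens * 2

max_seq_len = 8
example_tokens = torch.zeros(1, 4, dtype=torch.long)

# Mirrors the dynamic_shapes entry above: dim 1 of `tokens` is dynamic and
# may now grow to max_seq_len, not just max_seq_len - 1.
token_dim = torch.export.Dim("token_dim", max=max_seq_len)
ep = torch.export.export(Toy(), (example_tokens,), dynamic_shapes=({1: token_dim},))

# A full-length prompt is accepted by the exported program.
out = ep.module()(torch.zeros(1, max_seq_len, dtype=torch.long))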
@@ -88,7 +88,7 @@ def test_get_dynamic_shape_with_dynamic_shape_enabled_with_kv_cache(self) -> None:
         # Check first element (tokens dimension)
         self.assertIsInstance(result[0], dict)
         self.assertIn(1, result[0])
-        self.assertEqual(result[0][1].max, self.max_seq_len - 1)
+        self.assertEqual(result[0][1].max, self.max_seq_len)

         # Check second element (input_pos dimension)
         self.assertIsInstance(result[1], dict)