Skip to content

Commit 5413b89

Browse files
authored
KV cache is no longer a model attribute (#30730)
kv_cache is no longer a model attribute
1 parent 218f441 commit 5413b89

File tree

5 files changed

+0 -28 lines changed

5 files changed

+0 -28 lines changed

src/transformers/models/cohere/modeling_cohere.py

-6
Original file line number | Diff line number | Diff line change
@@ -271,7 +271,6 @@ def forward(
271271
key_states = key_states.transpose(1, 2)
272272
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
273273

274-
past_key_value = getattr(self, "past_key_value", past_key_value)
275274
cos, sin = self.rotary_emb(value_states, position_ids)
276275
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
277276

@@ -365,8 +364,6 @@ def forward(
365364
cos, sin = self.rotary_emb(value_states, position_ids)
366365
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
367366

368-
past_key_value = getattr(self, "past_key_value", past_key_value)
369-
370367
if past_key_value is not None:
371368
# sin and cos are specific to RoPE models; position_ids needed for the static cache
372369
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
@@ -571,9 +568,6 @@ def forward(
571568
cos, sin = self.rotary_emb(value_states, position_ids)
572569
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
573570

574-
# In case static cache is used, it is an instance attribute.
575-
past_key_value = getattr(self, "past_key_value", past_key_value)
576-
577571
if past_key_value is not None:
578572
# sin and cos are specific to RoPE models; cache_position needed for the static cache
579573
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}

src/transformers/models/dbrx/modeling_dbrx.py

-5
Original file line number | Diff line number | Diff line change
@@ -287,7 +287,6 @@ def forward(
287287
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
288288
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
289289

290-
past_key_value = getattr(self, "past_key_value", past_key_value)
291290
cos, sin = self.rotary_emb(value_states, position_ids)
292291
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
293292

@@ -387,8 +386,6 @@ def forward(
387386
cos, sin = self.rotary_emb(value_states, position_ids)
388387
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
389388

390-
past_key_value = getattr(self, "past_key_value", past_key_value)
391-
392389
if past_key_value is not None:
393390
# sin and cos are specific to RoPE models; cache_position needed for the static cache
394391
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
@@ -600,8 +597,6 @@ def forward(
600597
cos, sin = self.rotary_emb(value_states, position_ids, seq_len=None)
601598
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, None)
602599

603-
past_key_value = getattr(self, "past_key_value", past_key_value)
604-
605600
if past_key_value is not None:
606601
# sin and cos are specific to RoPE models; cache_position needed for the static cache
607602
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}

src/transformers/models/gemma/modeling_gemma.py

-5
Original file line number | Diff line number | Diff line change
@@ -262,7 +262,6 @@ def forward(
262262
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
263263
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
264264

265-
past_key_value = getattr(self, "past_key_value", past_key_value)
266265
cos, sin = self.rotary_emb(value_states, position_ids, seq_len=None)
267266
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, None)
268267

@@ -353,8 +352,6 @@ def forward(
353352
cos, sin = self.rotary_emb(value_states, position_ids, seq_len=None)
354353
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, None)
355354

356-
past_key_value = getattr(self, "past_key_value", past_key_value)
357-
358355
if past_key_value is not None:
359356
# sin and cos are specific to RoPE models; cache_position needed for the static cache
360357
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
@@ -552,8 +549,6 @@ def forward(
552549
cos, sin = self.rotary_emb(value_states, position_ids, seq_len=None)
553550
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, None)
554551

555-
past_key_value = getattr(self, "past_key_value", past_key_value)
556-
557552
if past_key_value is not None:
558553
# sin and cos are specific to RoPE models; cache_position needed for the static cache
559554
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}

src/transformers/models/llama/modeling_llama.py

-6
Original file line number | Diff line number | Diff line change
@@ -356,7 +356,6 @@ def forward(
356356
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
357357
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
358358

359-
past_key_value = getattr(self, "past_key_value", past_key_value)
360359
cos, sin = self.rotary_emb(value_states, position_ids)
361360
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
362361

@@ -452,8 +451,6 @@ def forward(
452451
cos, sin = self.rotary_emb(value_states, position_ids)
453452
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
454453

455-
past_key_value = getattr(self, "past_key_value", past_key_value)
456-
457454
if past_key_value is not None:
458455
# sin and cos are specific to RoPE models; cache_position needed for the static cache
459456
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
@@ -650,9 +647,6 @@ def forward(
650647
cos, sin = self.rotary_emb(value_states, position_ids)
651648
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
652649

653-
# In case static cache is used, it is an instance attribute.
654-
past_key_value = getattr(self, "past_key_value", past_key_value)
655-
656650
if past_key_value is not None:
657651
# sin and cos are specific to RoPE models; cache_position needed for the static cache
658652
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}

src/transformers/models/olmo/modeling_olmo.py

-6
Original file line number | Diff line number | Diff line change
@@ -328,7 +328,6 @@ def forward(
328328
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
329329
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
330330

331-
past_key_value = getattr(self, "past_key_value", past_key_value)
332331
cos, sin = self.rotary_emb(value_states, position_ids)
333332
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
334333

@@ -419,8 +418,6 @@ def forward(
419418
cos, sin = self.rotary_emb(value_states, position_ids)
420419
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
421420

422-
past_key_value = getattr(self, "past_key_value", past_key_value)
423-
424421
if past_key_value is not None:
425422
# sin and cos are specific to RoPE models; cache_position needed for the static cache
426423
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
@@ -624,9 +621,6 @@ def forward(
624621
cos, sin = self.rotary_emb(value_states, position_ids)
625622
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
626623

627-
# In case static cache is used, it is an instance attribute.
628-
past_key_value = getattr(self, "past_key_value", past_key_value)
629-
630624
if past_key_value is not None:
631625
# sin and cos are specific to RoPE models; cache_position needed for the static cache
632626
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}

0 commit comments

Comments
 (0)