let's not warn when someone is running a forward (#32176)
* let's not warn when someone is running a forward without cache + self.training

* more models

* fixup
ArthurZucker authored Jul 24, 2024
1 parent e0182f3 commit 8d2534c
Showing 16 changed files with 33 additions and 17 deletions.
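
The same guard lands in every file below: the legacy tuple-of-tuples `past_key_values` is only converted to a `DynamicCache` (with its deprecation warning) when the model is not in training mode, since a training forward pass normally runs without a cache and the warning was noise. A minimal sketch of the pattern, assuming the public `transformers.cache_utils` API; the helper name and the `module` argument (standing in for `self`) are illustrative, not part of the library:

import logging

from transformers.cache_utils import Cache, DynamicCache

logger = logging.getLogger(__name__)


def convert_legacy_cache_for_inference(module, past_key_values, use_cache):
    # Sketch of the guard the diffs below converge on: skip the legacy-cache
    # conversion (and the warning) whenever the model is training.
    return_legacy_cache = False
    if use_cache and not isinstance(past_key_values, Cache) and not module.training:
        # Inference-time BC path: wrap the legacy tuples in a DynamicCache so
        # downstream layers can rely on the `Cache` interface.
        return_legacy_cache = True
        past_key_values = DynamicCache.from_legacy_cache(past_key_values)
        logger.warning(
            "Passing `past_key_values` as a tuple is deprecated; "
            "pass a `Cache` instance (e.g. `DynamicCache`) instead."
        )
    return past_key_values, return_legacy_cache

The `return_legacy_cache` flag lets the caller convert the cache back to the tuple format before returning, so existing callers that still pass tuples keep receiving them.
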
4 changes: 3 additions & 1 deletion src/transformers/models/cohere/modeling_cohere.py
@@ -769,7 +769,9 @@ def forward(

         past_seen_tokens = 0
         return_legacy_cache = False
-        if use_cache and not isinstance(past_key_values, Cache):  # kept for BC (non `Cache` `past_key_values` inputs)
+        if (
+            use_cache and not isinstance(past_key_values, Cache) and not self.training
+        ):  # kept for BC (non `Cache` `past_key_values` inputs)
             return_legacy_cache = True
             past_key_values = DynamicCache.from_legacy_cache(past_key_values)

4 changes: 3 additions & 1 deletion src/transformers/models/dbrx/modeling_dbrx.py
@@ -1005,7 +1005,9 @@ def forward(
         inputs_embeds = nn.functional.dropout(inputs_embeds, p=self.emb_pdrop, training=self.training)

         return_legacy_cache = False
-        if use_cache and not isinstance(past_key_values, Cache):  # kept for BC (non `Cache` `past_key_values` inputs)
+        if (
+            use_cache and not isinstance(past_key_values, Cache) and not self.training
+        ):  # kept for BC (non `Cache` `past_key_values` inputs)
             return_legacy_cache = True
             past_key_values = DynamicCache.from_legacy_cache(past_key_values)
             logger.warning_once(
4 changes: 3 additions & 1 deletion src/transformers/models/gemma/diff_gemma.py
@@ -474,7 +474,9 @@ def forward(
             inputs_embeds = self.embed_tokens(input_ids)

         return_legacy_cache = False  # noqa: F841
-        if use_cache and not isinstance(past_key_values, Cache):  # kept for BC (non `Cache` `past_key_values` inputs)
+        if (
+            use_cache and not isinstance(past_key_values, Cache) and not self.training
+        ):  # kept for BC (non `Cache` `past_key_values` inputs)
             return_legacy_cache = True  # noqa: F841
             past_key_values = DynamicCache.from_legacy_cache(past_key_values)

8 changes: 6 additions & 2 deletions src/transformers/models/gemma/modeling_gemma.py
@@ -770,7 +770,9 @@ def forward(
             inputs_embeds = self.embed_tokens(input_ids)

         return_legacy_cache = False  # noqa: F841
-        if use_cache and not isinstance(past_key_values, Cache):  # kept for BC (non `Cache` `past_key_values` inputs)
+        if (
+            use_cache and not isinstance(past_key_values, Cache) and not self.training
+        ):  # kept for BC (non `Cache` `past_key_values` inputs)
             return_legacy_cache = True  # noqa: F841
             past_key_values = DynamicCache.from_legacy_cache(past_key_values)

@@ -795,7 +797,9 @@ def forward(
         # See https://github.com/huggingface/transformers/pull/29402
         normalizer = torch.tensor(self.config.hidden_size**0.5, dtype=hidden_states.dtype)
         hidden_states = hidden_states * normalizer
-        if use_cache and not isinstance(past_key_values, Cache):  # kept for BC (non `Cache` `past_key_values` inputs)
+        if (
+            use_cache and not isinstance(past_key_values, Cache) and not self.training
+        ):  # kept for BC (non `Cache` `past_key_values` inputs)
             return_legacy_cache = True
             past_key_values = DynamicCache.from_legacy_cache(past_key_values)
             logger.warning_once(
4 changes: 3 additions & 1 deletion src/transformers/models/jetmoe/modeling_jetmoe.py
@@ -978,7 +978,9 @@ def forward(
             inputs_embeds = self.embed_tokens(input_ids)

         return_legacy_cache = False
-        if use_cache and not isinstance(past_key_values, Cache):  # kept for BC (non `Cache` `past_key_values` inputs)
+        if (
+            use_cache and not isinstance(past_key_values, Cache) and not self.training
+        ):  # kept for BC (non `Cache` `past_key_values` inputs)
             return_legacy_cache = True
             past_key_values = DynamicCache.from_legacy_cache(past_key_values)

4 changes: 3 additions & 1 deletion src/transformers/models/llama/modeling_llama.py
@@ -894,7 +894,9 @@ def forward(
             inputs_embeds = self.embed_tokens(input_ids)

         return_legacy_cache = False
-        if use_cache and not isinstance(past_key_values, Cache):  # kept for BC (non `Cache` `past_key_values` inputs)
+        if (
+            use_cache and not isinstance(past_key_values, Cache) and not self.training
+        ):  # kept for BC (non `Cache` `past_key_values` inputs)
             return_legacy_cache = True
             past_key_values = DynamicCache.from_legacy_cache(past_key_values)
             logger.warning_once(
2 changes: 1 addition & 1 deletion src/transformers/models/mistral/modeling_mistral.py
@@ -758,7 +758,7 @@ def forward(
             inputs_embeds = self.embed_tokens(input_ids)

         return_legacy_cache = False
-        if use_cache and not isinstance(past_key_values, Cache):
+        if use_cache and not isinstance(past_key_values, Cache) and not self.training:
             past_key_values = DynamicCache.from_legacy_cache(past_key_values)
             return_legacy_cache = True
             logger.warning_once(
2 changes: 1 addition & 1 deletion src/transformers/models/mixtral/modeling_mixtral.py
@@ -960,7 +960,7 @@ def forward(
             use_cache = False

         use_legacy_cache = False
-        if use_cache and not isinstance(past_key_values, Cache):
+        if use_cache and not isinstance(past_key_values, Cache) and not self.training:
             use_legacy_cache = True
             past_key_values = DynamicCache.from_legacy_cache(past_key_values)
             logger.warning_once(
4 changes: 3 additions & 1 deletion src/transformers/models/olmo/modeling_olmo.py
@@ -811,7 +811,9 @@ def forward(
             inputs_embeds = self.embed_tokens(input_ids)

         return_legacy_cache = False
-        if use_cache and not isinstance(past_key_values, Cache):  # kept for BC (non `Cache` `past_key_values` inputs)
+        if (
+            use_cache and not isinstance(past_key_values, Cache) and not self.training
+        ):  # kept for BC (non `Cache` `past_key_values` inputs)
             return_legacy_cache = True
             past_key_values = DynamicCache.from_legacy_cache(past_key_values)
             logger.warning_once(
2 changes: 1 addition & 1 deletion src/transformers/models/persimmon/modeling_persimmon.py
@@ -626,7 +626,7 @@ def forward(
             use_cache = False

         use_legacy_cache = False
-        if use_cache and not isinstance(past_key_values, Cache):
+        if use_cache and not isinstance(past_key_values, Cache) and not self.training:
             use_legacy_cache = True
             past_key_values = DynamicCache.from_legacy_cache(past_key_values)
             logger.warning_once(
2 changes: 1 addition & 1 deletion src/transformers/models/phi/modeling_phi.py
@@ -909,7 +909,7 @@ def forward(
             use_cache = False

         use_legacy_cache = False
-        if use_cache and not isinstance(past_key_values, Cache):
+        if use_cache and not isinstance(past_key_values, Cache) and not self.training:
             use_legacy_cache = True
             past_key_values = DynamicCache.from_legacy_cache(past_key_values)
             logger.warning_once(
2 changes: 1 addition & 1 deletion src/transformers/models/phi3/modeling_phi3.py
@@ -950,7 +950,7 @@ def forward(
             use_cache = False

         use_legacy_cache = False
-        if use_cache and not isinstance(past_key_values, Cache):
+        if use_cache and not isinstance(past_key_values, Cache) and not self.training:
             use_legacy_cache = True
             past_key_values = DynamicCache.from_legacy_cache(past_key_values)
             logger.warning_once(
2 changes: 1 addition & 1 deletion src/transformers/models/qwen2/modeling_qwen2.py
@@ -808,7 +808,7 @@ def forward(
             use_cache = False

         use_legacy_cache = False
-        if use_cache and not isinstance(past_key_values, Cache):
+        if use_cache and not isinstance(past_key_values, Cache) and not self.training:
             use_legacy_cache = True
             past_key_values = DynamicCache.from_legacy_cache(past_key_values)
             logger.warning_once(
2 changes: 1 addition & 1 deletion src/transformers/models/qwen2_moe/modeling_qwen2_moe.py
@@ -970,7 +970,7 @@ def forward(
             use_cache = False

         use_legacy_cache = False
-        if use_cache and not isinstance(past_key_values, Cache):
+        if use_cache and not isinstance(past_key_values, Cache) and not self.training:
             use_legacy_cache = True
             past_key_values = DynamicCache.from_legacy_cache(past_key_values)
             logger.warning_once(
2 changes: 1 addition & 1 deletion src/transformers/models/stablelm/modeling_stablelm.py
@@ -902,7 +902,7 @@ def forward(
             use_cache = False

         use_legacy_cache = False
-        if use_cache and not isinstance(past_key_values, Cache):
+        if use_cache and not isinstance(past_key_values, Cache) and not self.training:
             use_legacy_cache = True
             past_key_values = DynamicCache.from_legacy_cache(past_key_values)
             logger.warning_once(
2 changes: 1 addition & 1 deletion src/transformers/models/starcoder2/modeling_starcoder2.py
@@ -784,7 +784,7 @@ def forward(
             use_cache = False

         use_legacy_cache = False
-        if use_cache and not isinstance(past_key_values, Cache):
+        if use_cache and not isinstance(past_key_values, Cache) and not self.training:
             use_legacy_cache = True
             past_key_values = DynamicCache.from_legacy_cache(past_key_values)
             logger.warning_once(
