
Commit 9ff47a7

Fix condition for emitting warning when generation exceeds max model length (#40775)

correct warning when generation exceeds max model length

Signed-off-by: Yannick Schnider <yannick.schnider1@ibm.com>

1 parent ae9ef2e

File tree

1 file changed: +2 -2 lines changed


src/transformers/generation/stopping_criteria.py (2 additions, 2 deletions)

@@ -76,9 +76,9 @@ def __init__(self, max_length: int, max_position_embeddings: Optional[int] = None):
     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> torch.BoolTensor:
         cur_len = input_ids.shape[1]
         is_done = cur_len >= self.max_length
-        if self.max_position_embeddings is not None and not is_done and cur_len >= self.max_position_embeddings:
+        if self.max_position_embeddings is not None and not is_done and cur_len > self.max_position_embeddings:
             logger.warning_once(
-                "This is a friendly reminder - the current text generation call will exceed the model's predefined "
+                "This is a friendly reminder - the current text generation call has exceeded the model's predefined "
                 f"maximum length ({self.max_position_embeddings}). Depending on the model, you may observe "
                 "exceptions, performance degradation, or nothing at all."
             )
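The effect of the fix can be illustrated in isolation. The sketch below is a minimal, standalone re-statement of the corrected condition (the hypothetical helper `should_warn` is not part of transformers): with the old `>=` comparison, the warning fired as soon as generation merely reached `max_position_embeddings`; with the strict `>`, it fires only once the length has actually exceeded it, matching the reworded "has exceeded" message.

```python
from typing import Optional


def should_warn(cur_len: int, max_length: int, max_position_embeddings: Optional[int]) -> bool:
    """Sketch of the corrected warning condition (hypothetical helper,
    mirroring the logic in MaxLengthCriteria.__call__ after this commit)."""
    # Generation that already hit the requested max_length is "done" and
    # never triggers the warning.
    is_done = cur_len >= max_length
    # Strict `>` after the fix: warn only once the model's predefined
    # maximum length has actually been exceeded, not when it is reached.
    return max_position_embeddings is not None and not is_done and cur_len > max_position_embeddings


# At exactly the model's limit, no warning fires (old `>=` warned here).
assert not should_warn(cur_len=4096, max_length=8192, max_position_embeddings=4096)
# One token past the limit, the warning fires.
assert should_warn(cur_len=4097, max_length=8192, max_position_embeddings=4096)
# No limit configured: never warn.
assert not should_warn(cur_len=4097, max_length=8192, max_position_embeddings=None)
```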
