From 10360b30c33bcf386bda0fd549002c6894c31744 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Wed, 6 Mar 2024 16:34:20 +0000 Subject: [PATCH] derp --- src/transformers/generation/utils.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index bea4b383870036..a0c58749c33743 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -1932,7 +1932,9 @@ def _contrastive_search( # keep track of which sequences are already finished batch_size, cur_len = ( - model_kwargs["attention_mask"].shape if model_kwargs.get("attention_mask", None) else input_ids.shape + model_kwargs["attention_mask"].shape + if model_kwargs.get("attention_mask", None) is not None + else input_ids.shape ) unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device) model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device) @@ -2394,7 +2396,9 @@ def _greedy_search( # keep track of which sequences are already finished batch_size, cur_len = ( - model_kwargs["attention_mask"].shape if model_kwargs.get("attention_mask", None) else input_ids.shape + model_kwargs["attention_mask"].shape + if model_kwargs.get("attention_mask", None) is not None + else input_ids.shape ) unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device) model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device) @@ -2696,7 +2700,9 @@ def _sample( # keep track of which sequences are already finished batch_size, cur_len = ( - model_kwargs["attention_mask"].shape if model_kwargs.get("attention_mask", None) else input_ids.shape + model_kwargs["attention_mask"].shape + if model_kwargs.get("attention_mask", None) is not None + else input_ids.shape ) unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device) model_kwargs["cache_position"] = torch.arange(cur_len, 
device=input_ids.device) @@ -4537,7 +4543,9 @@ def _assisted_decoding( # keep track of which sequences are already finished batch_size, cur_len = ( - model_kwargs["attention_mask"].shape if model_kwargs.get("attention_mask", None) else input_ids.shape + model_kwargs["attention_mask"].shape + if model_kwargs.get("attention_mask", None) is not None + else input_ids.shape ) unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device) model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device)