
Commit

derp
gante committed Mar 14, 2024
1 parent 8c29e49 commit 10360b3
Showing 1 changed file with 12 additions and 4 deletions.
16 changes: 12 additions & 4 deletions src/transformers/generation/utils.py
@@ -1932,7 +1932,9 @@ def _contrastive_search(
 
         # keep track of which sequences are already finished
         batch_size, cur_len = (
-            model_kwargs["attention_mask"].shape if model_kwargs.get("attention_mask", None) else input_ids.shape
+            model_kwargs["attention_mask"].shape
+            if model_kwargs.get("attention_mask", None) is not None
+            else input_ids.shape
         )
         unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
         model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device)
@@ -2394,7 +2396,9 @@ def _greedy_search(
 
         # keep track of which sequences are already finished
         batch_size, cur_len = (
-            model_kwargs["attention_mask"].shape if model_kwargs.get("attention_mask", None) else input_ids.shape
+            model_kwargs["attention_mask"].shape
+            if model_kwargs.get("attention_mask", None) is not None
+            else input_ids.shape
         )
         unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
         model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device)
@@ -2696,7 +2700,9 @@ def _sample(
 
         # keep track of which sequences are already finished
         batch_size, cur_len = (
-            model_kwargs["attention_mask"].shape if model_kwargs.get("attention_mask", None) else input_ids.shape
+            model_kwargs["attention_mask"].shape
+            if model_kwargs.get("attention_mask", None) is not None
+            else input_ids.shape
        )
         unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
         model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device)
@@ -4537,7 +4543,9 @@ def _assisted_decoding(
 
         # keep track of which sequences are already finished
         batch_size, cur_len = (
-            model_kwargs["attention_mask"].shape if model_kwargs.get("attention_mask", None) else input_ids.shape
+            model_kwargs["attention_mask"].shape
+            if model_kwargs.get("attention_mask", None) is not None
+            else input_ids.shape
         )
         unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
         model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device)
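The change is the same in all four decoding loops: the old expression tested the attention mask's truthiness, but PyTorch refuses to coerce a tensor with more than one element to bool, so the ternary raised a RuntimeError for any real batch instead of selecting a shape. Below is a minimal sketch of the failure and of the corrected `is not None` check; the toy tensors are illustrative and not part of the commit.

import torch

input_ids = torch.tensor([[1, 2, 3]])
model_kwargs = {"attention_mask": torch.ones_like(input_ids)}

# Old check: coercing a (1, 3) tensor to bool is ambiguous and raises.
try:
    bool(model_kwargs.get("attention_mask", None))
except RuntimeError as err:
    print(err)  # Boolean value of Tensor with more than one element is ambiguous

# Fixed check: compare against None instead of relying on tensor truthiness.
batch_size, cur_len = (
    model_kwargs["attention_mask"].shape
    if model_kwargs.get("attention_mask", None) is not None
    else input_ids.shape
)
print(batch_size, cur_len)  # prints "1 3"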
