Commit

handle corner cases
gante committed Mar 5, 2024
1 parent fa1b49f commit f5c91b9
Showing 2 changed files with 6 additions and 6 deletions.
src/transformers/generation/utils.py (8 changes: 4 additions & 4 deletions)
@@ -1946,7 +1946,7 @@ def contrastive_search(
         )

         # keep track of which sequences are already finished
-        batch_size, cur_len = input_ids.shape
+        batch_size, cur_len = model_kwargs["attention_mask"].shape if "attention_mask" in model_kwargs else input_ids.shape
         unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
         model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device)

@@ -2397,7 +2397,7 @@ def greedy_search(
         )

         # keep track of which sequences are already finished
-        batch_size, cur_len = input_ids.shape
+        batch_size, cur_len = model_kwargs["attention_mask"].shape if "attention_mask" in model_kwargs else input_ids.shape
         unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
         model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device)

@@ -2688,7 +2688,7 @@ def sample(
         )

         # keep track of which sequences are already finished
-        batch_size, cur_len = input_ids.shape
+        batch_size, cur_len = model_kwargs["attention_mask"].shape if "attention_mask" in model_kwargs else input_ids.shape
         unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
         model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device)

@@ -4482,7 +4482,7 @@ def assisted_decoding(
         )

         # keep track of which sequences are already finished
-        batch_size, cur_len = input_ids.shape
+        batch_size, cur_len = model_kwargs["attention_mask"].shape if "attention_mask" in model_kwargs else input_ids.shape
         unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
         model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device)

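All four hunks above make the same change in contrastive_search, greedy_search, sample, and assisted_decoding: batch_size and cur_len are read from the attention mask when one is available, instead of from input_ids, so that cache_position spans the full sequence even when input_ids does not. A minimal sketch of the kind of mismatch this guards against, with made-up shapes rather than the real generate() internals:

    import torch

    # Hypothetical decoding state: the attention mask tracks the whole sequence seen
    # so far, while input_ids has been cropped to the newest token only.
    input_ids = torch.tensor([[42]])                                       # shape (1, 1)
    model_kwargs = {"attention_mask": torch.ones(1, 5, dtype=torch.long)}  # shape (1, 5)

    # Old: cur_len taken from input_ids would be 1; new: taken from the mask, it is 5.
    batch_size, cur_len = model_kwargs["attention_mask"].shape if "attention_mask" in model_kwargs else input_ids.shape
    cache_position = torch.arange(cur_len, device=input_ids.device)
    print(batch_size, cur_len, cache_position)  # 1 5 tensor([0, 1, 2, 3, 4])
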
src/transformers/models/llama/modeling_llama.py (4 changes: 2 additions & 2 deletions)
@@ -1285,11 +1285,11 @@ def prepare_inputs_for_generation(
             # TODO: use `next_tokens` directly instead.
             model_inputs = {"input_ids": input_ids.contiguous()}

+        input_length = position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1]
         if cache_position is None:
-            input_length = position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1]
             cache_position = torch.arange(past_length, past_length + input_length, device=input_ids.device)
         else:
-            cache_position = cache_position[-input_ids.shape[1] :]
+            cache_position = cache_position[-input_length:]

         if has_static_cache:
             past_key_values = None
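This hunk hoists input_length out of the `if cache_position is None` branch and uses it, rather than input_ids.shape[1], when slicing an existing cache_position, so both branches agree on the number of new positions. A minimal sketch of the difference, with hypothetical tensors rather than the real prepare_inputs_for_generation state:

    import torch

    # Hypothetical state: position_ids covers 3 new positions, but input_ids was
    # cropped to a single token, so the two slicing rules disagree.
    input_ids = torch.tensor([[7]])            # shape (1, 1)
    position_ids = torch.tensor([[5, 6, 7]])   # shape (1, 3)
    cache_position = torch.arange(8)           # positions 0..7

    input_length = position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1]
    print(cache_position[-input_ids.shape[1]:])  # old slice: tensor([7])
    print(cache_position[-input_length:])        # new slice: tensor([5, 6, 7])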
