diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py
index 38efc95b84eff7..8dac6d77bf45f9 100644
--- a/src/transformers/generation/utils.py
+++ b/src/transformers/generation/utils.py
@@ -1946,7 +1946,7 @@ def contrastive_search(
         )

         # keep track of which sequences are already finished
-        batch_size, cur_len = input_ids.shape
+        batch_size, cur_len = model_kwargs["attention_mask"].shape if "attention_mask" in model_kwargs else input_ids.shape
         unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
         model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device)

@@ -2397,7 +2397,7 @@ def greedy_search(
         )

         # keep track of which sequences are already finished
-        batch_size, cur_len = input_ids.shape
+        batch_size, cur_len = model_kwargs["attention_mask"].shape if "attention_mask" in model_kwargs else input_ids.shape
         unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
         model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device)

@@ -2688,7 +2688,7 @@ def sample(
         )

         # keep track of which sequences are already finished
-        batch_size, cur_len = input_ids.shape
+        batch_size, cur_len = model_kwargs["attention_mask"].shape if "attention_mask" in model_kwargs else input_ids.shape
         unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
         model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device)

@@ -4482,7 +4482,7 @@ def assisted_decoding(
         )

         # keep track of which sequences are already finished
-        batch_size, cur_len = input_ids.shape
+        batch_size, cur_len = model_kwargs["attention_mask"].shape if "attention_mask" in model_kwargs else input_ids.shape
         unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
         model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device)

diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py
index 3b9e4f56aa262b..bce2e0d5d02d32 100644
--- a/src/transformers/models/llama/modeling_llama.py
+++ b/src/transformers/models/llama/modeling_llama.py
@@ -1285,11 +1285,11 @@ def prepare_inputs_for_generation(
             # TODO: use `next_tokens` directly instead.
             model_inputs = {"input_ids": input_ids.contiguous()}

+        input_length = position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1]
         if cache_position is None:
-            input_length = position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1]
             cache_position = torch.arange(past_length, past_length + input_length, device=input_ids.device)
         else:
-            cache_position = cache_position[-input_ids.shape[1] :]
+            cache_position = cache_position[-input_length:]

         if has_static_cache:
             past_key_values = None
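
The four `generation/utils.py` hunks all make the same change: `batch_size` and `cur_len` are read from the attention mask when one is present, instead of from `input_ids`. A minimal sketch of why that matters, assuming a call where the prompt is passed as `inputs_embeds` so `input_ids` arrives empty (the shapes below are hypothetical, not taken from the diff):

```python
import torch

# Hypothetical shapes: the prompt was consumed as embeddings, so
# `input_ids` carries zero tokens, while the attention mask still
# spans every attended position.
input_ids = torch.empty((2, 0), dtype=torch.long)
model_kwargs = {"attention_mask": torch.ones((2, 7), dtype=torch.long)}

# Old: cur_len == 0, so cache_position starts as an empty tensor.
# New: fall back to the attention mask when one is available.
batch_size, cur_len = (
    model_kwargs["attention_mask"].shape
    if "attention_mask" in model_kwargs
    else input_ids.shape
)
cache_position = torch.arange(cur_len, device=input_ids.device)
print(batch_size, cur_len, cache_position)
# 2 7 tensor([0, 1, 2, 3, 4, 5, 6])
```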
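The `modeling_llama.py` hunk hoists `input_length` out of the `if cache_position is None` branch so the `else` branch can slice by it as well, rather than by `input_ids.shape[1]`. A sketch of the difference under the same assumption of an empty `input_ids` with `position_ids` supplied (values are illustrative):

```python
import torch

# Hypothetical decode step fed embeddings only: `input_ids` is empty,
# but `position_ids` still describes one new position.
input_ids = torch.empty((1, 0), dtype=torch.long)
position_ids = torch.tensor([[7]])   # a single new position
cache_position = torch.arange(8)     # positions 0..7 seen so far

# Hoisted length: trust position_ids when present, not input_ids.
input_length = position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1]

# Old slice: -input_ids.shape[1] is -0, and x[-0:] == x[0:], i.e. the
# whole tensor. The new slice keeps exactly the positions being fed.
print(cache_position[-input_ids.shape[1]:])  # tensor([0, 1, ..., 7])  (wrong)
print(cache_position[-input_length:])        # tensor([7])             (right)
```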