diff --git a/docs/source/en/generation_strategies.md b/docs/source/en/generation_strategies.md index 37c90ee43fa5..5c7d27192292 100644 --- a/docs/source/en/generation_strategies.md +++ b/docs/source/en/generation_strategies.md @@ -225,29 +225,6 @@ outputs = model.generate(**inputs, assistant_model=assistant_model, tokenizer=to tokenizer.batch_decode(outputs, skip_special_tokens=True) ['Alice and Bob are sitting in a bar. Alice is drinking a beer and Bob is drinking a'] ``` - -### Contrastive search - -[Contrastive search](https://huggingface.co/papers/2202.06417) is a decoding strategy that aims to reduce repetition even while generating longer sequences. This strategy compares how similar a generated token is against previous tokens, and if they're more similar, a penalty is applied. - -Enable contrastive search with the `penalty_alpha` and `top_k` parameters. The `penalty_alpha` manages the penalty applied and `top_k` is the number of most likely tokens to return. - -```py -import torch -from transformers import AutoModelForCausalLM, AutoTokenizer, infer_device - -device = infer_device() - -tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf") -inputs = tokenizer("Hugging Face is an open-source company", return_tensors="pt").to(device) - -model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf", dtype=torch.float16).to(device) -# explicitly set to 100 because Llama2 generation length is 4096 -outputs = model.generate(**inputs, max_new_tokens=100, penalty_alpha=0.6, top_k=4) -tokenizer.batch_decode(outputs, skip_special_tokens=True) -'Hugging Face is an open-source company that provides a platform for building and deploying AI models.\nHugging Face is an open-source company that provides a platform for building and deploying AI models. The platform allows developers to build and deploy AI models, as well as collaborate with other developers.\nHugging Face was founded in 2019 by Thibault Wittemberg and Clément Delangue. The company is based in Paris, France.\nHugging Face has' -``` - ### Diverse beam search [Diverse beam search](https://hf.co/papers/1610.02424) is a variant of beam search that produces more diverse output candidates to choose from. This strategy measures the dissimilarity of sequences and a penalty is applied if sequences are too similar. To avoid high computation costs, the number of beams is divided into groups. 
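> Editor's note on the removed section above: after this patch, contrastive search is no longer a built-in decoding mode and is instead loaded from the Hub via `custom_generate`. The sketch below mirrors the kwargs used in the updated GPT-2/BART tests later in this diff (`penalty_alpha`, `top_k`, `trust_remote_code=True`, `custom_generate="transformers-community/contrastive-search"`); the checkpoint and prompt are illustrative only, and it assumes the Hub repo keeps accepting the same `penalty_alpha`/`top_k` arguments the built-in strategy used.

```py
# Minimal sketch: contrastive search via the Hub `custom_generate` repo
# (kwargs taken from the updated tests in this diff; checkpoint/prompt are illustrative).
from transformers import AutoModelForCausalLM, AutoTokenizer

checkpoint = "openai-community/gpt2-large"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForCausalLM.from_pretrained(checkpoint)

inputs = tokenizer("Hugging Face Company is", return_tensors="pt")
outputs = model.generate(
    **inputs,
    penalty_alpha=0.6,
    top_k=4,
    max_new_tokens=100,
    trust_remote_code=True,  # required: the decoding loop is fetched from the Hub
    custom_generate="transformers-community/contrastive-search",
)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
```

The same pattern applies to DoLa, which this diff likewise routes through `custom_generate="transformers-community/dola"` in the updated Gemma test and the parameterized `test_hub_gen_strategies`.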
diff --git a/docs/source/ja/generation_strategies.md b/docs/source/ja/generation_strategies.md index a93ef3d36440..856c4856c52f 100644 --- a/docs/source/ja/generation_strategies.md +++ b/docs/source/ja/generation_strategies.md @@ -168,29 +168,6 @@ An increasing sequence: one, two, three, four, five, six, seven, eight, nine, te ['I look forward to seeing you all again!\n\n\n\n\n\n\n\n\n\n\n'] ``` -### Contrastive search - -コントラスティブ検索デコーディング戦略は、2022年の論文[A Contrastive Framework for Neural Text Generation](https://huggingface.co/papers/2202.06417)で提案されました。 -これは、非反復的でありながら一貫性のある長い出力を生成するために優れた結果を示しています。コントラスティブ検索の動作原理を学ぶには、[このブログポスト](https://huggingface.co/blog/introducing-csearch)をご覧ください。 -コントラスティブ検索の動作を有効にし、制御する2つの主要なパラメータは「penalty_alpha」と「top_k」です: - -```python ->>> from transformers import AutoTokenizer, AutoModelForCausalLM - ->>> checkpoint = "openai-community/gpt2-large" ->>> tokenizer = AutoTokenizer.from_pretrained(checkpoint) ->>> model = AutoModelForCausalLM.from_pretrained(checkpoint) - ->>> prompt = "Hugging Face Company is" ->>> inputs = tokenizer(prompt, return_tensors="pt") - ->>> outputs = model.generate(**inputs, penalty_alpha=0.6, top_k=4, max_new_tokens=100) ->>> tokenizer.batch_decode(outputs, skip_special_tokens=True) -['Hugging Face Company is a family owned and operated business. We pride ourselves on being the best -in the business and our customer service is second to none.\n\nIf you have any questions about our -products or services, feel free to contact us at any time. We look forward to hearing from you!'] -``` - ### Multinomial sampling 常に最高確率のトークンを次のトークンとして選択する貪欲検索とは異なり、多項分布サンプリング(または祖先サンプリングとも呼ばれます)はモデルによって提供される語彙全体の確率分布に基づいて次のトークンをランダムに選択します。ゼロ以外の確率を持つすべてのトークンには選択される可能性があり、これにより繰り返しのリスクが減少します。 diff --git a/docs/source/ko/generation_strategies.md b/docs/source/ko/generation_strategies.md index f45fea5b2280..da38e4f418f2 100644 --- a/docs/source/ko/generation_strategies.md +++ b/docs/source/ko/generation_strategies.md @@ -68,7 +68,7 @@ GenerationConfig { - `max_new_tokens`: 생성할 최대 토큰 수입니다. 즉, 프롬프트에 있는 토큰을 제외한 출력 시퀀스의 크기입니다. 출력의 길이를 중단 기준으로 사용하는 대신, 전체 생성물이 일정 시간을 초과할 때 생성을 중단하기로 선택할 수도 있습니다. 더 알아보려면 [`StoppingCriteria`]를 확인하세요. - `num_beams`: 1보다 큰 수의 빔을 지정함으로써, 탐욕 탐색(greedy search)에서 빔 탐색(beam search)으로 전환하게 됩니다. 이 전략은 각 시간 단계에서 여러 가설을 평가하고 결국 전체 시퀀스에 대해 가장 높은 확률을 가진 가설을 선택합니다. 이는 초기 토큰의 확률이 낮아 탐욕 탐색에 의해 무시되었을 높은 확률의 시퀀스를 식별할 수 있는 장점을 가집니다. - `do_sample`: 이 매개변수를 `True`로 설정하면, 다항 샘플링, 빔 탐색 다항 샘플링, Top-K 샘플링 및 Top-p 샘플링과 같은 디코딩 전략을 활성화합니다. 이러한 전략들은 전체 어휘에 대한 확률 분포에서 다음 토큰을 선택하며, 전략별로 특정 조정이 적용됩니다. -- `num_return_sequences`: 각 입력에 대해 반환할 시퀀스 후보의 수입니다. 이 옵션은 빔 탐색(beam search)의 변형과 샘플링과 같이 여러 시퀀스 후보를 지원하는 디코딩 전략에만 사용할 수 있습니다. 탐욕 탐색(greedy search)과 대조 탐색(contrastive search) 같은 디코딩 전략은 단일 출력 시퀀스를 반환합니다. +- `num_return_sequences`: 각 입력에 대해 반환할 시퀀스 후보의 수입니다. 이 옵션은 빔 탐색(beam search)의 변형과 샘플링과 같이 여러 시퀀스 후보를 지원하는 디코딩 전략에만 사용할 수 있습니다. 탐욕 탐색(greedy search) 같은 디코딩 전략은 단일 출력 시퀀스를 반환합니다. ## 모델에 사용자 정의 디코딩 전략 저장[[save-a-custom-decoding-strategy-with-your-model]] @@ -165,27 +165,6 @@ An increasing sequence: one, two, three, four, five, six, seven, eight, nine, te ['I look forward to seeing you all again!\n\n\n\n\n\n\n\n\n\n\n'] ``` -### 대조 탐색(Contrastive search)[[contrastive-search]] - -2022년 논문 [A Contrastive Framework for Neural Text Generation](https://huggingface.co/papers/2202.06417)에서 제안된 대조 탐색 디코딩 전략은 반복되지 않으면서도 일관된 긴 출력을 생성하는 데 있어 우수한 결과를 보였습니다. 대조 탐색이 작동하는 방식을 알아보려면 [이 블로그 포스트](https://huggingface.co/blog/introducing-csearch)를 확인하세요. 
대조 탐색의 동작을 가능하게 하고 제어하는 두 가지 주요 매개변수는 `penalty_alpha`와 `top_k`입니다: - -```python ->>> from transformers import AutoTokenizer, AutoModelForCausalLM - ->>> checkpoint = "openai-community/gpt2-large" ->>> tokenizer = AutoTokenizer.from_pretrained(checkpoint) ->>> model = AutoModelForCausalLM.from_pretrained(checkpoint) - ->>> prompt = "Hugging Face Company is" ->>> inputs = tokenizer(prompt, return_tensors="pt") - ->>> outputs = model.generate(**inputs, penalty_alpha=0.6, top_k=4, max_new_tokens=100) ->>> tokenizer.batch_decode(outputs, skip_special_tokens=True) -['Hugging Face Company is a family owned and operated business. We pride ourselves on being the best -in the business and our customer service is second to none.\n\nIf you have any questions about our -products or services, feel free to contact us at any time. We look forward to hearing from you!'] -``` - ### 다항 샘플링(Multinomial sampling)[[multinomial-sampling]] 탐욕 탐색(greedy search)이 항상 가장 높은 확률을 가진 토큰을 다음 토큰으로 선택하는 것과 달리, 다항 샘플링(multinomial sampling, 조상 샘플링(ancestral sampling)이라고도 함)은 모델이 제공하는 전체 어휘에 대한 확률 분포를 기반으로 다음 토큰을 무작위로 선택합니다. 0이 아닌 확률을 가진 모든 토큰은 선택될 기회가 있으므로, 반복의 위험을 줄일 수 있습니다. diff --git a/examples/pytorch/text-generation/run_generation_contrastive_search.py b/examples/pytorch/text-generation/run_generation_contrastive_search.py deleted file mode 100755 index 879229c062e3..000000000000 --- a/examples/pytorch/text-generation/run_generation_contrastive_search.py +++ /dev/null @@ -1,146 +0,0 @@ -#!/usr/bin/env python -# Copyright 2022 University of Cambridge, Tencent AI Lab, DeepMind and The University of Hong Kong Authors and The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -# /// script -# dependencies = [ -# "transformers @ git+https://github.com/huggingface/transformers.git", -# "accelerate >= 0.21.0", -# "sentencepiece != 0.1.92", -# "protobuf", -# "torch >= 1.3", -# ] -# /// - -"""The examples of running contrastive search on the auto-APIs; - -Running this example: -python run_generation_contrastive_search.py --model_name_or_path=openai-community/gpt2-large --penalty_alpha=0.6 --k=4 --length=256 -""" - -import argparse -import logging - -from accelerate import PartialState -from accelerate.utils import set_seed - -from transformers import AutoModelForCausalLM, AutoTokenizer - - -logging.basicConfig( - format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", - datefmt="%m/%d/%Y %H:%M:%S", - level=logging.INFO, -) -logger = logging.getLogger(__name__) - - -def main(): - parser = argparse.ArgumentParser() - parser.add_argument( - "--model_name_or_path", - default=None, - type=str, - required=True, - ) - parser.add_argument("--prompt", type=str, default="") - parser.add_argument("--length", type=int, default=20) - parser.add_argument("--stop_token", type=str, default=None, help="Token at which text generation is stopped") - parser.add_argument( - "--temperature", - type=float, - default=1.0, - help="temperature of 1.0 has no effect, lower tend toward greedy sampling", - ) - parser.add_argument( - "--repetition_penalty", type=float, default=1.0, help="primarily useful for CTRL model; in that case, use 1.2" - ) - parser.add_argument("--k", type=int, default=0) - parser.add_argument("--penalty_alpha", type=float, default=0.0) - parser.add_argument("--p", type=float, default=0.9) - - parser.add_argument("--prefix", type=str, default="", help="Text added prior to input.") - parser.add_argument("--padding_text", type=str, default="", help="Deprecated, the use of `--prefix` is preferred.") - parser.add_argument("--xlm_language", type=str, default="", help="Optional language when used with the XLM model.") - - parser.add_argument("--seed", type=int, default=42, help="random seed for initialization") - parser.add_argument( - "--use_cpu", - action="store_true", - help="Whether or not to use cpu. If set to False, we will use gpu/npu or mps device if available", - ) - parser.add_argument( - "--fp16", - action="store_true", - help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit", - ) - args = parser.parse_args() - - # Initialize the distributed state. 
- distributed_state = PartialState(cpu=args.use_cpu) - - logger.warning(f"device: {distributed_state.device}, 16-bits inference: {args.fp16}") - - if args.seed is not None: - set_seed(args.seed) - - # Initialize the model and tokenizer - tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path) - model = AutoModelForCausalLM.from_pretrained(args.model_name_or_path) - - # tokenizer = GPT2Tokenizer.from_pretrained(args.model_name_or_path) - # model = OPTForCausalLM.from_pretrained(args.model_name_or_path) - # Set the model to the right device - model.to(distributed_state.device) - - if args.fp16: - model.half() - - logger.info(args) - prompt_text = args.prompt if args.prompt else input("Model prompt >>> ") - - inputs = tokenizer(prompt_text, return_tensors="pt", add_special_tokens=False) - inputs = {key: value.to(distributed_state.device) for key, value in inputs.items()} - - output_sequences = model.generate( - **inputs, - max_length=args.length + len(inputs["input_ids"][0]), - penalty_alpha=args.penalty_alpha, - top_k=args.k, - ) - - generated_sequences = [] - for generated_sequence_idx, generated_sequence in enumerate(output_sequences): - print(f"=== GENERATED SEQUENCE {generated_sequence_idx + 1} ===") - generated_sequence = generated_sequence.tolist() - - # Decode text - text = tokenizer.decode(generated_sequence, clean_up_tokenization_spaces=True, add_special_tokens=False) - - # Remove all text after the stop token - text = text[: text.find(args.stop_token) if args.stop_token else None] - - # Add the prompt at the beginning of the sequence. Remove the excess text that was used for pre-processing - total_sequence = ( - prompt_text + text[len(tokenizer.decode(inputs["input_ids"][0], clean_up_tokenization_spaces=True)) :] - ) - - generated_sequences.append(total_sequence) - print(total_sequence) - - return generated_sequences - - -if __name__ == "__main__": - main() diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index d1e1b441f67f..e1e38fc76da6 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -1358,7 +1358,7 @@ def check_dynamic_cache(self, method: str): def crop(self, maximum_length: int): """ Crop the past key values up to a new `maximum_length` in terms of tokens. `maximum_length` can also be - negative to remove `maximum_length` tokens. This is used in assisted decoding and contrastive search. + negative to remove `maximum_length` tokens. This is used in assisted decoding and contrastive search (on the Hub). """ self.check_dynamic_cache(self.crop.__name__) self.self_attention_cache.crop(maximum_length) @@ -1378,13 +1378,13 @@ def batch_split(self, full_batch_size: int, split_size: int) -> "list[EncoderDec return out def batch_repeat_interleave(self, repeats: int): - """Repeat the cache `repeats` times in the batch dimension. Used in contrastive search.""" + """Repeat the cache `repeats` times in the batch dimension. Used in contrastive search (on the Hub).""" self.check_dynamic_cache(self.batch_repeat_interleave.__name__) self.self_attention_cache.batch_repeat_interleave(repeats) self.cross_attention_cache.batch_repeat_interleave(repeats) def batch_select_indices(self, indices: torch.Tensor): - """Only keep the `indices` in the batch dimension of the cache. Used in contrastive search.""" + """Only keep the `indices` in the batch dimension of the cache. 
Used in contrastive search (on the Hub).""" self.check_dynamic_cache(self.batch_select_indices.__name__) self.self_attention_cache.batch_select_indices(indices) self.cross_attention_cache.batch_select_indices(indices) diff --git a/src/transformers/generation/configuration_utils.py b/src/transformers/generation/configuration_utils.py index 82332b9e7809..1edaf19948e8 100644 --- a/src/transformers/generation/configuration_utils.py +++ b/src/transformers/generation/configuration_utils.py @@ -44,7 +44,7 @@ logger = logging.get_logger(__name__) METADATA_FIELDS = ("_from_model_config", "_commit_hash", "_original_object_hash", "transformers_version") STATIC_CACHE_IMPLEMENTATIONS = ("static", "offloaded_static") -DYNAMIC_CACHE_IMPLEMENTATIONS = ("dynamic", "offloaded", "quantized") +DYNAMIC_CACHE_IMPLEMENTATIONS = ("dynamic", "dynamic_full", "offloaded", "quantized") # All the following are redundant and deprecated, but kept for BC DEPRECATED_STATIC_CACHE_IMPLEMENTATIONS = ( "sliding_window", @@ -86,7 +86,6 @@ class GenerationConfig(PushToHubMixin): for text-decoder, text-to-text, speech-to-text, and vision-to-text models: - *greedy decoding* if `num_beams=1` and `do_sample=False` - - *contrastive search* if `penalty_alpha>0.` and `top_k>1` - *multinomial sampling* if `num_beams=1` and `do_sample=True` - *beam-search decoding* if `num_beams>1` and `do_sample=False` - *beam-search multinomial sampling* if `num_beams>1` and `do_sample=True` @@ -138,8 +137,6 @@ class GenerationConfig(PushToHubMixin): num_beam_groups (`int`, *optional*, defaults to 1): Number of groups to divide `num_beams` into in order to ensure diversity among different groups of beams. [this paper](https://huggingface.co/papers/1610.02424) for more details. - penalty_alpha (`float`, *optional*): - The values balance the model confidence and the degeneration penalty in contrastive search decoding. > Parameters that control the cache @@ -255,9 +252,6 @@ class GenerationConfig(PushToHubMixin): The guidance scale for classifier free guidance (CFG). CFG is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages the model to generate samples that are more closely linked to the input prompt, usually at the expense of poorer quality. - low_memory (`bool`, *optional*): - Switch to sequential beam search and sequential topk for contrastive search to reduce peak memory. - Used with beam search and contrastive search. watermarking_config (`BaseWatermarkingConfig` or `dict`, *optional*): Arguments used to watermark the model outputs by adding a small bias to randomly selected set of "green" tokens. 
See the docs of [`SynthIDTextWatermarkingConfig`] and [`WatermarkingConfig`] for more @@ -366,8 +360,6 @@ def __init__(self, **kwargs): self.do_sample = kwargs.pop("do_sample", False) self.num_beams = kwargs.pop("num_beams", 1) self.num_beam_groups = kwargs.pop("num_beam_groups", 1) - self.penalty_alpha = kwargs.pop("penalty_alpha", None) - self.dola_layers = kwargs.pop("dola_layers", None) # Parameters that control the cache self.use_cache = kwargs.pop("use_cache", True) @@ -403,7 +395,7 @@ def __init__(self, **kwargs): self.sequence_bias = kwargs.pop("sequence_bias", None) self.token_healing = kwargs.pop("token_healing", False) self.guidance_scale = kwargs.pop("guidance_scale", None) - self.low_memory = kwargs.pop("low_memory", None) + watermarking_config = kwargs.pop("watermarking_config", None) if watermarking_config is None: self.watermarking_config = None @@ -445,6 +437,11 @@ def __init__(self, **kwargs): self.compile_config = kwargs.pop("compile_config", None) self.disable_compile = kwargs.pop("disable_compile", False) + # Deprecated (moved to the Hub). TODO joao, manuel: remove in v4.62.0 + self.low_memory = kwargs.pop("low_memory", None) + self.penalty_alpha = kwargs.pop("penalty_alpha", None) + self.dola_layers = kwargs.pop("dola_layers", None) + # The remaining attributes do not parametrize `.generate()`, but are informative and/or used by the hub # interface. self._from_model_config = kwargs.pop("_from_model_config", False) @@ -610,9 +607,7 @@ def validate(self, strict=False): minor_issues["typical_p"] = greedy_wrong_parameter_msg.format( flag_name="typical_p", flag_value=self.typical_p ) - if ( - self.top_k is not None and self.top_k != 50 and self.penalty_alpha is None - ): # contrastive search uses top_k + if self.top_k is not None and self.top_k != 50: minor_issues["top_k"] = greedy_wrong_parameter_msg.format(flag_name="top_k", flag_value=self.top_k) if self.epsilon_cutoff is not None and self.epsilon_cutoff != 0.0: minor_issues["epsilon_cutoff"] = greedy_wrong_parameter_msg.format( diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 3e387c0808d6..e03ad600deb3 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -43,7 +43,6 @@ from ..integrations.deepspeed import is_deepspeed_zero3_enabled from ..integrations.fsdp import is_fsdp_managed_module from ..masking_utils import create_masks_for_generate -from ..modeling_outputs import CausalLMOutputWithPast, Seq2SeqLMOutput from ..pytorch_utils import isin_mps_friendly from ..tokenization_utils import ExtensionsTrie from ..utils import ( @@ -369,7 +368,6 @@ class GenerationMixin(ContinuousMixin): The class exposes [`~generation.GenerationMixin.generate`], which can be used for: - *greedy decoding* if `num_beams=1` and `do_sample=False` - - *contrastive search* if `penalty_alpha>0` and `top_k>1` - *multinomial sampling* if `num_beams=1` and `do_sample=True` - *beam-search decoding* if `num_beams>1` and `do_sample=False` - *beam-search multinomial sampling* if `num_beams>1` and `do_sample=True` @@ -1946,15 +1944,17 @@ def _prepare_cache_for_generation( ) generation_config.cache_implementation = None - # assisted decoding and contrastive search need to roll-back the Cache, which is not supported if - # it has sliding layers - so if we use any of those 2, do not pass the config to DynamicCache, which - # will result in creating a Cache with only full layers even if model uses sliding window + # Assisted decoding and contrastive search require cache 
rollback, which is incompatible with sliding layers. + # To handle this, we skip passing the model config to DynamicCache (forcing a full-layer cache). + # The "dynamic_full" option is a shortcut for generate() users to avoid sliding layers on their own. generation_mode = generation_config.get_generation_mode(assistant_model) - dynamic_cache_kwargs = ( - {"config": self.config} - if generation_mode not in (GenerationMode.ASSISTED_GENERATION, GenerationMode.CONTRASTIVE_SEARCH) - else {} - ) + if ( + generation_mode in (GenerationMode.ASSISTED_GENERATION, GenerationMode.CONTRASTIVE_SEARCH) + or generation_config.cache_implementation == "dynamic_full" + ): + dynamic_cache_kwargs = {} + else: + dynamic_cache_kwargs = {"config": self.config} if generation_config.cache_implementation is not None: if generation_config.cache_implementation in ALL_STATIC_CACHE_IMPLEMENTATIONS: if generation_config.cache_implementation in DEPRECATED_STATIC_CACHE_IMPLEMENTATIONS: @@ -1995,7 +1995,7 @@ def _prepare_cache_for_generation( model_kwargs[cache_name] = QuantizedCache(backend=backend, **cache_config) elif generation_config.cache_implementation == "offloaded": model_kwargs[cache_name] = DynamicCache(**dynamic_cache_kwargs, offloading=True) - elif generation_config.cache_implementation == "dynamic": + elif "dynamic" in generation_config.cache_implementation: model_kwargs[cache_name] = DynamicCache(**dynamic_cache_kwargs) # Use DynamicCache() instance by default. This will avoid back and forth from legacy format that @@ -2512,29 +2512,26 @@ def generate( trust_remote_code=trust_remote_code, **kwargs, ) - + # TODO joao, manuel: remove this in v4.62.0 elif generation_mode == GenerationMode.CONTRASTIVE_SEARCH: + logger.warning_once( + "Contrastive search was moved to a `custom_generate` repo: https://hf.co/transformers-community/contrastive-search. " + "To prevent loss of backward compatibility, add `custom_generate='transformers-community/contrastive-search'` " + "to your `generate` call before v4.62.0." + ) if not trust_remote_code: logger.warning_once( - "Contrastive Search is scheduled to be moved to a `custom_generate` repository in v4.55.0. " - "To prevent loss of backward compatibility, add `trust_remote_code=True` to your `generate` call." + "Contrastive search requires `trust_remote_code=True` in your `generate` call, since " + "it loads https://hf.co/transformers-community/contrastive-search." ) - if not model_kwargs["use_cache"]: - raise ValueError("Contrastive search requires `use_cache=True`") - if self._is_stateful: - # Just like assisted generation, we need to be able to rollback to a previous state (see comment above) - raise ValueError( - f"contrastive search is not supported with stateful models, such as {self.__class__.__name__}" - ) - - result = self._contrastive_search( - input_ids, - logits_processor=prepared_logits_processor, - stopping_criteria=prepared_stopping_criteria, + # Avoid calling the model-defined `generate` method, since some models (e.g. Janus, Whisper) override it. 
+ return GenerationMixin.generate( + self, + inputs, + custom_generate="transformers-community/contrastive-search", generation_config=generation_config, - synced_gpus=synced_gpus, - streamer=streamer, - **model_kwargs, + trust_remote_code=trust_remote_code, + **kwargs, ) elif generation_mode in (GenerationMode.SAMPLE, GenerationMode.GREEDY_SEARCH): @@ -2765,421 +2762,6 @@ def heal_tokens( return input_ids - @torch.no_grad() - def _contrastive_search( - self, - input_ids: torch.LongTensor, - logits_processor: LogitsProcessorList, - stopping_criteria: StoppingCriteriaList, - generation_config: GenerationConfig, - synced_gpus: bool, - streamer: Optional["BaseStreamer"], - **model_kwargs, - ) -> Union[GenerateNonBeamOutput, torch.LongTensor]: - r""" - Generates sequences of token ids for models with a language modeling head using **contrastive search** and can - be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models. - - Parameters: - input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): - The sequence used as a prompt for the generation. - logits_processor (`LogitsProcessorList`): - An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`] - used to modify the prediction scores of the language modeling head applied at each generation step. - stopping_criteria (`StoppingCriteriaList`): - An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`] - used to tell if the generation loop should stop. - generation_config ([`~generation.GenerationConfig`]): - The generation configuration to be used as parametrization of the decoding method. - synced_gpus (`bool`): - Whether to continue running the while loop until max_length (needed to avoid deadlocking with - `FullyShardedDataParallel` and DeepSpeed ZeRO Stage 3). - streamer (`BaseStreamer`, *optional*): - Streamer object that will be used to stream the generated sequences. Generated tokens are passed - through `streamer.put(token_ids)` and the streamer is responsible for any further processing. - model_kwargs: - Additional model specific keyword arguments will be forwarded to the `forward` function of the model. - If model is an encoder-decoder model the kwargs should include `encoder_outputs`. - - Return: - [`~generation.GenerateDecoderOnlyOutput`], [`~generation.GenerateEncoderDecoderOutput`] - or `torch.LongTensor`: A `torch.LongTensor` containing the generated tokens (default behaviour) or a - [`~generation.GenerateDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and - `return_dict_in_generate=True` or a [`~generation.GenerateEncoderDecoderOutput`] if - `model.config.is_encoder_decoder=True`. 
- """ - # init values - has_eos_stopping_criteria = any(hasattr(criteria, "eos_token_id") for criteria in stopping_criteria) - top_k = generation_config.top_k - penalty_alpha = generation_config.penalty_alpha - pad_token_id = generation_config._pad_token_tensor - output_attentions = generation_config.output_attentions - output_hidden_states = generation_config.output_hidden_states - output_scores = generation_config.output_scores - output_logits = generation_config.output_logits - return_dict_in_generate = generation_config.return_dict_in_generate - sequential = generation_config.low_memory - - # init attention / hidden states / scores tuples - raw_logits = () if (return_dict_in_generate and output_logits) else None - scores = () if (return_dict_in_generate and output_scores) else None - decoder_attentions = () if (return_dict_in_generate and output_attentions) else None - cross_attentions = () if (return_dict_in_generate and output_attentions) else None - decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None - - # if model is an encoder-decoder, retrieve encoder attention weights and hidden states - if return_dict_in_generate and self.config.is_encoder_decoder: - encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None - encoder_hidden_states = ( - model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None - ) - - # keep track of which sequences are already finished - batch_size, cur_len = input_ids.shape[:2] - unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device) - model_kwargs = self._get_initial_cache_position(cur_len, input_ids.device, model_kwargs) - - # Create cosine_matrix_mask based on the attention_mask - cosine_matrix_mask = torch.ones_like(input_ids, dtype=torch.long) - if self.config.is_encoder_decoder: - if "decoder_attention_mask" in model_kwargs and model_kwargs["decoder_attention_mask"] is not None: - cosine_matrix_mask = model_kwargs["decoder_attention_mask"] - else: - cosine_matrix_mask = model_kwargs["attention_mask"] - cosine_matrix_mask = cosine_matrix_mask.repeat_interleave(top_k, dim=0) - - this_peer_finished = False - - while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device): - # if the first step in the loop, encode all the prefix and obtain: (1) past_key_values; - # (2) last_hidden_states; (3) logit_for_next_step; (4) update model kwargs for the next step - if model_kwargs.get("past_key_values") is None or ( - isinstance(model_kwargs["past_key_values"], (Cache, EncoderDecoderCache)) - and model_kwargs["past_key_values"].get_seq_length() == 0 - ): - # prepare inputs - model_kwargs["use_cache"] = True - model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) - - # encode the given prefix and prepare model inputs; encoder-decoder model process the prefix and save - # the `encoder_outputs` - outputs = self( - **model_inputs, return_dict=True, output_hidden_states=True, output_attentions=output_attentions - ) - - # last decoder hidden states will be used to compute the degeneration penalty (cosine similarity with - # previous tokens) - if self.config.is_encoder_decoder: - last_hidden_states = outputs.decoder_hidden_states[-1] - else: - last_hidden_states = outputs.hidden_states[-1] - - # next logit for contrastive search to select top-k candidate tokens - # Copy is needed to avoid keeping a hanging ref to outputs.logits which may be very large for this first 
iteration - # (the clone itself is always small) - # torch.float32 is needed to retain precision for later logits manipulations - logit_for_next_step = outputs.logits[:, -1, :].to( - copy=True, dtype=torch.float32, device=input_ids.device - ) - - model_kwargs = self._update_model_kwargs_for_generation( - outputs, - model_kwargs, - is_encoder_decoder=self.config.is_encoder_decoder, - ) - - if not sequential: - # Expands model inputs top_k times, for batched forward passes (akin to beam search). - # input_ids is required for expanding visual inputs in qwen2vl - _, model_kwargs = self._expand_inputs_for_generation( - input_ids=input_ids, - expand_size=top_k, - is_encoder_decoder=self.config.is_encoder_decoder, - **model_kwargs, - ) - - past_key_values = model_kwargs.get("past_key_values") - if past_key_values is None: - raise ValueError( - f"{self.__class__.__name__} does not support caching and therefore **can't** be used " - "for contrastive search." - ) - elif ( - not isinstance(past_key_values[0], (tuple, torch.Tensor)) - or past_key_values[0][0].shape[0] != batch_size - ): - raise ValueError( - f"{self.__class__.__name__} does not have a standard cache format and therefore **can't** be " - "used for contrastive search without further modifications." - ) - - # contrastive_search main logic start: - # contrastive search decoding consists of two steps: (1) candidate tokens recall; (2) candidate re-rank by - # degeneration penalty - processed_logit_for_next_step = logits_processor(input_ids, logit_for_next_step) - next_probs = nn.functional.softmax(processed_logit_for_next_step, dim=-1) - - top_k_probs, top_k_ids = torch.topk(next_probs, dim=-1, k=top_k) - - # Store scores, attentions and hidden_states when required - if return_dict_in_generate: - if output_logits: - raw_logits += (logit_for_next_step,) - if output_scores: - scores += (processed_logit_for_next_step,) - if output_attentions: - decoder_attentions += ( - (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,) - ) - if self.config.is_encoder_decoder: - cross_attentions += (outputs.cross_attentions,) - - if output_hidden_states: - decoder_hidden_states += ( - (outputs.decoder_hidden_states,) - if self.config.is_encoder_decoder - else (outputs.hidden_states,) - ) - - # This is needed to properly delete outputs.logits which may be very large for this first iteration - # Otherwise a reference to outputs.logits is kept all along until after the next call to self.forward() - del outputs - - if not sequential: - # Replicates the new past_key_values to match the `top_k` candidates - past = model_kwargs["past_key_values"] - # If it is a static cache, modify it in-place layer after layer to save memory - if isinstance(past, DynamicCache) or ( - isinstance(past, EncoderDecoderCache) and isinstance(past.self_attention_cache, DynamicCache) - ): - past.batch_repeat_interleave(top_k) - else: - new_key_values = [] - for layer in past: - items = [] - # item is either the key or the value matrix - for item in layer: - items.append(item.repeat_interleave(top_k, dim=0)) - new_key_values.append(tuple(items)) - - past = tuple(new_key_values) - - model_kwargs["past_key_values"] = past - - if sequential: - all_outputs = [] - for i in range(top_k): - # compute the candidate tokens by the language model and collect their hidden_states - next_model_inputs = self.prepare_inputs_for_generation(top_k_ids[:, i].view(-1, 1), **model_kwargs) - - outputs = self( - **next_model_inputs, - return_dict=True, - 
output_hidden_states=True, - output_attentions=output_attentions, - ) - if isinstance(outputs["past_key_values"], DynamicCache) or ( - isinstance(outputs["past_key_values"], EncoderDecoderCache) - and isinstance(outputs["past_key_values"].self_attention_cache, DynamicCache) - ): - # Remove past K-V from output since we don't need to stack later - outputs["past_key_values"] = None - # Remove last token from past K-V since we don't want to append it at this point - model_kwargs["past_key_values"].crop(-1) - - all_outputs.append(outputs) - outputs = stack_model_outputs(all_outputs, self.config.get_text_config()) - - else: - # compute the candidate tokens by the language model and collect their hidden_states - # assembles top_k_ids into batch of size k - next_model_inputs = self.prepare_inputs_for_generation(top_k_ids.view(-1, 1), **model_kwargs) - - outputs = self( - **next_model_inputs, - return_dict=True, - output_hidden_states=True, - output_attentions=output_attentions, - ) - - # This is essential to avoid having a last reference to the big past K-V and double the necessary memory - # in the next loop - del next_model_inputs - - # name is different for encoder-decoder and decoder-only models - if self.config.is_encoder_decoder: - next_hidden = outputs.decoder_hidden_states[-1] - full_hidden_states = outputs.decoder_hidden_states - else: - next_hidden = outputs.hidden_states[-1] - full_hidden_states = outputs.hidden_states - - # .float() is needed to retain precision for later logits manipulations - logits = outputs.logits[:, -1, :].float() - context_hidden = last_hidden_states.repeat_interleave(top_k, dim=0) - - # compute the degeneration penalty and re-rank the candidates based on the degeneration penalty and the - # model confidence. Keeping `selected_idx` on CPU enables multi-device contrastive search and doesn't - # introduce (noticeable) slowdowns on single-device runs. 
- selected_idx = _ranking_fast( - context_hidden, next_hidden, top_k_probs, cosine_matrix_mask, penalty_alpha, top_k - ) - cosine_matrix_mask = torch.cat( - [cosine_matrix_mask, cosine_matrix_mask.new_ones((cosine_matrix_mask.shape[0], 1))], dim=-1 - ) - selected_idx = selected_idx.to("cpu") - - # This will be used instead of the previous inneficient torch.stack(torch.split()) - augmented_idx = torch.tensor([x + i * top_k for i, x in enumerate(selected_idx)]) - - # prepare for the next step: (1) next token_id; (2) past_key_values; (3) last_hidden_states for computing - # the degeneration penalty; (4) logits for selecting next top-k candidates; (5) selected tokens scores - # (model confidence minus degeneration penalty); (6) decoder hidden_states - next_tokens = top_k_ids[range(len(top_k_ids)), selected_idx] - next_hidden = torch.stack(torch.split(next_hidden.squeeze(dim=1), top_k)) - next_hidden = next_hidden[range(batch_size), selected_idx, :] - last_hidden_states = torch.cat([last_hidden_states, next_hidden.unsqueeze(1)], dim=1) - - next_decoder_hidden_states = () - for layer in full_hidden_states: - layer = torch.stack(torch.split(layer, top_k))[range(batch_size), selected_idx, :] - next_decoder_hidden_states += (layer,) - - # generate past_key_values cache of only the selected token - if sequential: - next_model_input = self.prepare_inputs_for_generation( - top_k_ids[:, selected_idx].view(-1, 1), **model_kwargs - ) - - selected_outputs = self( - **next_model_input, - return_dict=True, - output_hidden_states=False, - output_attentions=False, - ) - next_past_key_values = selected_outputs["past_key_values"] - - else: - next_past_key_values = None - for possible_cache_name in ALL_CACHE_NAMES: - next_past_key_values = next_past_key_values or getattr(outputs, possible_cache_name, None) - # Do it in-place layer per layer to save memory - if isinstance(next_past_key_values, DynamicCache) or ( - isinstance(next_past_key_values, EncoderDecoderCache) - and isinstance(next_past_key_values.self_attention_cache, DynamicCache) - ): - next_past_key_values.batch_select_indices(augmented_idx) - else: - new_key_values = [] - for layer in next_past_key_values: - items = [] - # item is either the key or the value matrix - for item in layer: - items.append(item[augmented_idx, ...]) - new_key_values.append(tuple(items)) - - next_past_key_values = tuple(new_key_values) - - logit_for_next_step = torch.stack(torch.split(logits, top_k))[range(batch_size), selected_idx, :] - logit_for_next_step = logit_for_next_step.to(input_ids.device) - - # Rebuilds the relevant parts of the model output for the selected token, for use in the next iteration - if self.config.is_encoder_decoder: - next_step_cross_attentions = () - next_step_decoder_attentions = () - if output_attentions: - for layer in outputs.cross_attentions: - layer = torch.stack(torch.split(layer, top_k, dim=0))[range(batch_size), selected_idx, ...] - next_step_cross_attentions += (layer,) - for layer in outputs.decoder_attentions: - layer = torch.stack(torch.split(layer, top_k, dim=0))[range(batch_size), selected_idx, ...] 
- next_step_decoder_attentions += (layer,) - outputs = Seq2SeqLMOutput( - past_key_values=next_past_key_values, - decoder_hidden_states=next_decoder_hidden_states, - decoder_attentions=next_step_decoder_attentions or None, - cross_attentions=next_step_cross_attentions or None, - ) - else: - next_step_attentions = () - if output_attentions: - for layer in outputs.attentions: - layer = torch.stack(torch.split(layer, top_k, dim=0))[range(batch_size), selected_idx, ...] - next_step_attentions += (layer,) - outputs = CausalLMOutputWithPast( - past_key_values=next_past_key_values, - hidden_states=next_decoder_hidden_states, - attentions=next_step_attentions or None, - ) - # contrastive_search main logic end - - # synced_gpus: don't waste resources running the code we don't need; kwargs must be updated before skipping - model_kwargs = self._update_model_kwargs_for_generation( - outputs, - model_kwargs, - is_encoder_decoder=self.config.is_encoder_decoder, - ) - if synced_gpus and this_peer_finished: - continue - - # finished sentences should have their next token be a padding token - if has_eos_stopping_criteria: - next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences) - - # update generated ids, model inputs, and length for next step - input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1) - if streamer is not None: - streamer.put(next_tokens.cpu()) - - # stop when each sentence is finished - unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores) - this_peer_finished = unfinished_sequences.max() == 0 - - if streamer is not None: - streamer.end() - - if return_dict_in_generate: - # Contrastive search works by forward looking at the next token, so we need to exclude it from - # `past_key_values` to be consistent with the other decoding methods - if model_kwargs.get("past_key_values") is not None: - if isinstance(model_kwargs["past_key_values"], DynamicCache) or ( - isinstance(model_kwargs["past_key_values"], EncoderDecoderCache) - and isinstance(model_kwargs["past_key_values"].self_attention_cache, DynamicCache) - ): - model_kwargs["past_key_values"].crop(-1) - else: - past_key_values = [] - for layer in model_kwargs["past_key_values"]: - layer_past_key_values = [] - for item in layer: - layer_past_key_values.append(item[..., :-1, :]) - past_key_values.append(tuple(layer_past_key_values)) - model_kwargs["past_key_values"] = tuple(past_key_values) - - if self.config.is_encoder_decoder: - return GenerateEncoderDecoderOutput( - sequences=input_ids, - scores=scores, - logits=raw_logits, - encoder_attentions=encoder_attentions, - encoder_hidden_states=encoder_hidden_states, - decoder_attentions=decoder_attentions, - cross_attentions=cross_attentions, - decoder_hidden_states=decoder_hidden_states, - past_key_values=model_kwargs.get("past_key_values"), - ) - else: - return GenerateDecoderOnlyOutput( - sequences=input_ids, - scores=scores, - logits=raw_logits, - attentions=decoder_attentions, - hidden_states=decoder_hidden_states, - past_key_values=model_kwargs.get("past_key_values"), - ) - else: - return input_ids - def _sample( self, input_ids: torch.LongTensor, @@ -4873,37 +4455,6 @@ def _split_model_outputs(outputs, new_outputs, cur_len, added_len, is_decoder_at return outputs -def _ranking_fast( - context_hidden: torch.FloatTensor, - next_hidden: torch.FloatTensor, - next_top_k_probs: torch.FloatTensor, - cosine_matrix_mask: torch.LongTensor, - alpha: float, - beam_width: int, -) -> torch.FloatTensor: - """ - Reranks the 
top_k candidates based on a degeneration penalty (cosine similarity with previous tokens), as described - in the paper "A Contrastive Framework for Neural Text Generation". Returns the index of the best candidate for each - row in the batch. - """ - norm_context_hidden = context_hidden / context_hidden.norm(dim=2, keepdim=True) - norm_next_hidden = next_hidden / next_hidden.norm(dim=2, keepdim=True) - cosine_matrix = torch.matmul(norm_context_hidden, norm_next_hidden.transpose(1, 2)).squeeze(-1) # [B*K, S] - - # Penalize cosine_matrix based on the cosine_matrix_mask (ignore padding positions) - # Using a large negative value for masked positions - cosine_matrix_mask = cosine_matrix_mask.to(dtype=cosine_matrix.dtype) - cosine_matrix_mask = (1 - cosine_matrix_mask) * torch.finfo(cosine_matrix.dtype).min - cosine_matrix = cosine_matrix + cosine_matrix_mask - - degeneration_penalty, _ = torch.max(cosine_matrix, dim=-1) # [B*K] - next_top_k_probs = next_top_k_probs.view(-1) # [B*K] - contrastive_score = (1.0 - alpha) * next_top_k_probs - alpha * degeneration_penalty - contrastive_score = torch.stack(torch.split(contrastive_score, beam_width)) # [B, K] - _, selected_idx = contrastive_score.max(dim=-1) # [B] - return selected_idx - - def stack_model_outputs(model_outputs: list[ModelOutput], config: PretrainedConfig) -> ModelOutput: """ Stack a list of ModelOutput objects (or its subclasses) along the batch_size dimension. The function infers the diff --git a/tests/generation/test_configuration_utils.py b/tests/generation/test_configuration_utils.py index 765019f2e5f2..429d61cbd26a 100644 --- a/tests/generation/test_configuration_utils.py +++ b/tests/generation/test_configuration_utils.py @@ -289,6 +289,7 @@ def test_generation_mode(self): config = GenerationConfig(num_beams=2) self.assertEqual(config.get_generation_mode(), GenerationMode.BEAM_SEARCH) + # TODO joao, manuel: remove this in v4.62.0 config = GenerationConfig(top_k=10, do_sample=False, penalty_alpha=0.6) self.assertEqual(config.get_generation_mode(), GenerationMode.CONTRASTIVE_SEARCH) diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index 92d770ff10d2..449d8122c12b 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -62,7 +62,6 @@ if is_torch_available(): import torch - import torch.nn.functional as F from transformers import ( AutoModelForCausalLM, @@ -76,7 +75,6 @@ GPT2Tokenizer, ImageGPTForCausalImageModeling, SpeechEncoderDecoderModel, - T5ForConditionalGeneration, ) from transformers.cache_utils import ( Cache, @@ -415,41 +413,6 @@ def _constrained_beam_search_generate( return output_generate - def _contrastive_generate( - self, - model, - inputs_dict, - output_scores=False, - output_logits=False, - output_attentions=False, - output_hidden_states=False, - return_dict_in_generate=False, - use_cache=True, - ): - contrastive_search_kwargs = { - "penalty_alpha": 0.6, - "top_k": 5, - } - - logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=False, config=model.config) - output_generate = model.generate( - do_sample=False, - num_beams=1, - max_new_tokens=self.max_new_tokens, - min_new_tokens=self.max_new_tokens, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - output_scores=output_scores, - output_logits=output_logits, - return_dict_in_generate=return_dict_in_generate, - use_cache=use_cache, - **logits_processor_kwargs, - **contrastive_search_kwargs, - **inputs_dict, - ) - - return output_generate - @pytest.mark.generate 
def test_greedy_generate(self): for model_class in self.all_generative_model_classes: @@ -964,108 +927,6 @@ def test_constrained_beam_search_generate_dict_output(self): num_beams=beam_kwargs["num_beams"], ) - @pytest.mark.generate - def test_contrastive_generate(self): - for model_class in self.all_generative_model_classes: - if model_class._is_stateful: - self.skipTest(reason="Stateful models don't support contrastive search generation") - - # won't fix: FSMT and Reformer have a different cache variable type (and format). - if any(model_name in model_class.__name__.lower() for model_name in ["reformer"]): - self.skipTest(reason="Won't fix: old model with different cache format") - - config, inputs_dict = self.prepare_config_and_inputs_for_generate() - - # NOTE: contrastive search only works with cache on at the moment. - if not hasattr(config.get_text_config(), "use_cache"): - self.skipTest(reason=f"{model_class.__name__} doesn't support caching") - config.is_decoder = True - - # test old generation output for backwards compatibility - model = model_class(config).to(torch_device).eval() - output_generate = self._contrastive_generate( - model=model, - inputs_dict=inputs_dict, - use_cache=True, # Enable cache - ) - if model.config.get_text_config(decoder=True).is_encoder_decoder: - self.assertTrue(output_generate.shape[1] == self.max_new_tokens + 1) - else: - self.assertTrue(output_generate.shape[1] == self.max_new_tokens + inputs_dict["input_ids"].shape[1]) - - @pytest.mark.generate - def test_contrastive_generate_dict_outputs_use_cache(self): - for model_class in self.all_generative_model_classes: - if model_class._is_stateful: - self.skipTest(reason="Stateful models don't support contrastive search generation") - - # won't fix: FSMT and Reformer have a different cache variable type (and format). - if any(model_name in model_class.__name__.lower() for model_name in ["reformer"]): - self.skipTest(reason="Won't fix: old model with different cache format") - - config, inputs_dict = self.prepare_config_and_inputs_for_generate() - - # NOTE: contrastive search only works with cache on at the moment. 
- if not hasattr(config.get_text_config(), "use_cache"): - self.skipTest(reason=f"{model_class.__name__} doesn't support caching") - config.is_decoder = True - if self.has_attentions: - config._attn_implementation = "eager" # can't output attentions otherwise - - model = model_class(config).to(torch_device).eval() - output_generate = self._contrastive_generate( - model=model, - inputs_dict=inputs_dict, - output_scores=True, - output_logits=True, - output_hidden_states=True, - output_attentions=self.has_attentions, - return_dict_in_generate=True, - use_cache=True, # Enable cache - ) - - if model.config.get_text_config(decoder=True).is_encoder_decoder: - self.assertTrue(output_generate.sequences.shape[1] == self.max_new_tokens + 1) - else: - self.assertTrue( - output_generate.sequences.shape[1] == self.max_new_tokens + inputs_dict["input_ids"].shape[1] - ) - - self._check_generate_outputs(output_generate, model.config, use_cache=True) - - @pytest.mark.generate - def test_contrastive_generate_low_memory(self): - # Check that choosing 'low_memory' does not change the model output - for model_class in self.all_generative_model_classes: - if model_class._is_stateful: - self.skipTest(reason="Stateful models don't support contrastive search generation") - - if any(model_name in model_class.__name__.lower() for model_name in ["reformer"]): - self.skipTest(reason="Won't fix: old model with different cache format") - - config, inputs_dict = self.prepare_config_and_inputs_for_generate(batch_size=1) - - # NOTE: contrastive search only works with cache on at the moment. - if not hasattr(config.get_text_config(), "use_cache"): - self.skipTest(reason=f"{model_class.__name__} doesn't support caching") - - config.is_decoder = True - - # test output equality of low versus high memory - model = model_class(config).to(torch_device).eval() - generate_kwargs = { - "top_k": 4, - "penalty_alpha": 0.6, - "max_new_tokens": self.max_new_tokens, - "use_cache": True, - "return_dict_in_generate": True, - "output_scores": True, - } - - low_output = model.generate(**inputs_dict, **generate_kwargs, low_memory=True) - high_output = model.generate(**inputs_dict, **generate_kwargs, low_memory=False) - self.assertTrue(has_similar_generate_outputs(low_output, high_output)) - @parameterized.expand([("random",), ("same",)]) @pytest.mark.generate def test_assisted_decoding_matches_greedy_search(self, assistant_type): @@ -3443,31 +3304,6 @@ def test_decoder_start_id_from_config(self): with self.assertRaises(ValueError): outputs = bart_model.generate(input_ids, generation_config=GenerationConfig(do_sample=False)) - def test_contrastive_search_batched(self): - # Tests that contrastive search works with batched inputs (i.e. 
has the same output as for non-batched inputs) - articles = ["Foo", "Bar Baz"] - tokenizer = BartTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart") - model = BartForConditionalGeneration.from_pretrained("hf-internal-testing/tiny-random-bart").to(torch_device) - - model.config.eos_token_id = None - input_ids_batched = tokenizer(articles, padding=True, return_tensors="pt").input_ids.to(torch_device) - input_ids = tokenizer(articles[1], return_tensors="pt").input_ids.to(torch_device) - - output_sequences_batched = model.generate( - input_ids=input_ids_batched, penalty_alpha=0.6, top_k=4, return_dict_in_generate=True, output_scores=True - ) - output_sequences = model.generate( - input_ids=input_ids, penalty_alpha=0.6, top_k=4, return_dict_in_generate=True, output_scores=True - ) - - batched_out = tokenizer.decode(output_sequences_batched.sequences[1], skip_special_tokens=True) - out = tokenizer.decode(output_sequences.sequences[0], skip_special_tokens=True) - self.assertEqual(batched_out, out) - - # output_sequences_batched.scores[0][1] -> 1st set of logits, 2nd sequence - max_score_diff = (output_sequences_batched.scores[0][1] - output_sequences.scores[0][0]).abs().max() - self.assertTrue(max_score_diff < 1e-5) - def test_logits_processor_not_inplace(self): article = "Today a dragon flew over Paris." model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device) @@ -4052,139 +3888,6 @@ def test_init_static_cache_multi_accelerator(self): values_1 = results.past_key_values.layers[1].values self.assertTrue(keys_1.device == values_1.device == torch.device(1)) - @slow - def test_padding_input_contrastive_search_gpt2(self): - # Load the pre-trained GPT-2 model and tokenizer - model = GPT2LMHeadModel.from_pretrained("openai-community/gpt2") - model.to(torch_device) - tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2", clean_up_tokenization_spaces=True) - - # Set the tokenizer to left-pad the sequences - tokenizer.padding_side = "left" - - # Define the PAD token as the EOS token - tokenizer.pad_token = tokenizer.eos_token - model.generation_config.pad_token_id = model.generation_config.eos_token_id - - # Define the input prompt - prompt_text = "The whispered legends of the haunted mansion spoke" - - # Tokenize the input prompt - encoded_prompt = tokenizer(prompt_text, return_tensors="pt", padding=True) - input_ids = encoded_prompt.input_ids.to(torch_device) - attention_mask = encoded_prompt.attention_mask.to(torch_device) - - # Define the contrastive search params - penalty_alpha = 0.6 - top_k = 4 - - # Define the padding length to add to the input IDs and attention mask - padding_length = 10 - - # Generate text without padding - outputs = model.generate( - input_ids=input_ids, - attention_mask=attention_mask, - do_sample=False, - penalty_alpha=penalty_alpha, - top_k=top_k, - max_new_tokens=64, - ) - generated_text_no_padding = tokenizer.decode(outputs[0], skip_special_tokens=True) - - # Pad the input IDs and attention mask on the left - padded_input_ids = F.pad( - input_ids, (padding_length, 0), "constant", value=model.generation_config.pad_token_id - ) - padded_attention_mask = F.pad(attention_mask, (padding_length, 0), "constant", value=0) - - # Generate text with padded inputs - outputs_with_padding = model.generate( - input_ids=padded_input_ids, - attention_mask=padded_attention_mask, - do_sample=False, - penalty_alpha=penalty_alpha, - top_k=top_k, - max_new_tokens=64, - ) - generated_text_with_padding = 
tokenizer.decode(outputs_with_padding[0], skip_special_tokens=True) - - # Assert that the generated texts are identical for padded and non-padded inputs - self.assertEqual(generated_text_no_padding, generated_text_with_padding) - self.assertEqual( - generated_text_with_padding, - 'The whispered legends of the haunted mansion spoke of the "souls of the dead" who were "falling ' - 'out of the sky" and "falling into the sea."\n\nThe ghostly apparitions were said to have been ' - 'created by the spirits of the dead, who were "falling out of the sky" and "falling into the sea', - ) - - @slow - def test_padding_input_contrastive_search_t5(self): - # Load the pre-trained T5 model and tokenizer - model = T5ForConditionalGeneration.from_pretrained("google-t5/t5-small") - model.to(torch_device) - tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-small", clean_up_tokenization_spaces=True) - - # Define the input prompt - prompt_text = "translate English to German: I need to finish this task before the end of the day." - - # Tokenize the input prompt - encoded_prompt = tokenizer(prompt_text, return_tensors="pt") - input_ids = encoded_prompt.input_ids.to(torch_device) - attention_mask = encoded_prompt.attention_mask.to(torch_device) - - # Define the decoder prompt - decoder_prompt_text = "Ich muss diese Aufgabe" - encoded_decoder_prompt = tokenizer(decoder_prompt_text, add_special_tokens=False, return_tensors="pt") - decoder_input_ids = encoded_decoder_prompt.input_ids.to(torch_device) - decoder_attention_mask = encoded_decoder_prompt.attention_mask.to(torch_device) - - # Define the contrastive search params - penalty_alpha = 0.6 - top_k = 4 - - # Generate text without padding - outputs = model.generate( - input_ids=input_ids, - attention_mask=attention_mask, - decoder_input_ids=decoder_input_ids, - decoder_attention_mask=decoder_attention_mask, - do_sample=False, - penalty_alpha=penalty_alpha, - top_k=top_k, - max_new_tokens=64, - ) - generated_text_no_padding = tokenizer.decode(outputs[0], skip_special_tokens=True) - - # Define the padding length to add to the input IDs and attention mask - padding_length = 10 - - # Pad the decoder input IDs and attention mask on the left - padded_decoder_input_ids = F.pad( - decoder_input_ids, (padding_length, 0), "constant", value=model.generation_config.pad_token_id - ) - padded_decoder_attention_mask = F.pad(decoder_attention_mask, (padding_length, 0), "constant", value=0) - # Since the decoder_start_token_id is the same as the pad_token_id, - # the last padded token represents the decoder start token. - # Set the attention mask for the decoder_start_token_id to True (1). 
-        padded_decoder_attention_mask[:, padding_length - 1] = 1
-        # Generate text with padded inputs
-        outputs_with_padding = model.generate(
-            input_ids=input_ids,
-            attention_mask=attention_mask,
-            decoder_input_ids=padded_decoder_input_ids,
-            decoder_attention_mask=padded_decoder_attention_mask,
-            do_sample=False,
-            penalty_alpha=penalty_alpha,
-            top_k=top_k,
-            max_new_tokens=64,
-        )
-        generated_text_with_padding = tokenizer.decode(outputs_with_padding[0], skip_special_tokens=True)
-
-        # Assert that the generated texts are identical for padded and non-padded inputs
-        self.assertEqual(generated_text_no_padding, generated_text_with_padding)
-        self.assertEqual(generated_text_no_padding, "Ich muss diese Aufgabe vor Ende des Tages beenden.")
-
     def test_prepare_inputs_for_generation_decoder_llm(self):
         """Tests GenerationMixin.prepare_inputs_for_generation against expected usage with decoder-only llms."""
@@ -5113,7 +4816,13 @@ def test_generate_custom_cache_position(self):
         )
 
     @pytest.mark.generate
-    def test_dola_hub_runs(self):
+    @parameterized.expand(
+        [
+            ("transformers-community/dola", {"dola_layers": "low"}),
+            ("transformers-community/contrastive-search", {"penalty_alpha": 0.6, "top_k": 4}),
+        ]
+    )
+    def test_hub_gen_strategies(self, custom_generate, extra_kwargs):
         model = AutoModelForCausalLM.from_pretrained(
             "hf-internal-testing/tiny-random-MistralForCausalLM",
             device_map=torch_device,
@@ -5123,7 +4832,7 @@ def test_dola_hub_runs(self):
             "input_ids": torch.tensor([[1, 22557, 28725, 1526, 28808]], device=torch_device),
             "attention_mask": torch.tensor([[1, 1, 1, 1, 1]], device=torch_device),
         }
-        # Sets dola generation arguments such that:
+        # Sets generation arguments such that:
         # a) no EOS is generated, to ensure generation doesn't break early
         # b) there are at least two forward passes in the main model, to ensure the input preparation of
         # the main model is correct
@@ -5138,13 +4847,13 @@ def test_dola_hub_runs(self):
             "output_attentions": True,
             "return_dict_in_generate": True,
             "use_cache": True,
-            "dola_layers": "low",
             "trust_remote_code": True,
-            "custom_generate": "transformers-community/dola",
+            "custom_generate": custom_generate,
         }
+        generation_kwargs.update(extra_kwargs)
         torch.manual_seed(0)
-        output_dola = model.generate(**generation_kwargs, **model_inputs)
-        self.assertEqual(output_dola.sequences.shape, (1, 9))
+        output = model.generate(**generation_kwargs, **model_inputs)
+        self.assertEqual(output.sequences.shape, (1, 9))
 
 
 @require_torch
diff --git a/tests/models/bart/test_modeling_bart.py b/tests/models/bart/test_modeling_bart.py
index ded8d5f0a8e8..9d887895b941 100644
--- a/tests/models/bart/test_modeling_bart.py
+++ b/tests/models/bart/test_modeling_bart.py
@@ -1205,6 +1205,7 @@ def test_cnn_summarization_same_as_fairseq(self):
         generated_summaries = tok.batch_decode(hypotheses_batch.tolist())
         assert generated_summaries == EXPECTED
 
+    # TODO joao, manuel: remove this in v4.62.0
     @slow
     def test_contrastive_search_bart(self):
         article = (
@@ -1238,7 +1239,15 @@ def test_contrastive_search_bart(self):
             article, add_special_tokens=False, truncation=True, max_length=512, return_tensors="pt"
         ).input_ids.to(torch_device)
 
-        outputs = bart_model.generate(input_ids, penalty_alpha=0.5, top_k=5, max_length=64, num_beams=1)
+        outputs = bart_model.generate(
+            input_ids,
+            penalty_alpha=0.5,
+            top_k=5,
+            max_length=64,
+            num_beams=1,
+            trust_remote_code=True,
+            custom_generate="transformers-community/contrastive-search",
+        )
         generated_text = bart_tokenizer.batch_decode(outputs, skip_special_tokens=True)
 
         self.assertListEqual(
diff --git a/tests/models/csm/test_modeling_csm.py b/tests/models/csm/test_modeling_csm.py
index fa29ec4df6e7..f81685abd091 100644
--- a/tests/models/csm/test_modeling_csm.py
+++ b/tests/models/csm/test_modeling_csm.py
@@ -292,21 +292,6 @@ def test_constrained_beam_search_generate(self):
     def test_constrained_beam_search_generate_dict_output(self):
         pass
 
-    @pytest.mark.generate
-    @unittest.skip(reason="CSM does not support contrastive search.")
-    def test_contrastive_generate(self):
-        pass
-
-    @pytest.mark.generate
-    @unittest.skip(reason="CSM does not support contrastive search.")
-    def test_contrastive_generate_dict_outputs_use_cache(self):
-        pass
-
-    @pytest.mark.generate
-    @unittest.skip(reason="CSM does not support contrastive search.")
-    def test_contrastive_generate_low_memory(self):
-        pass
-
     @pytest.mark.generate
     @unittest.skip(reason="CSM does not support prompt lookup decoding.")
     def test_prompt_lookup_decoding_matches_greedy_search(self):
         pass
diff --git a/tests/models/gemma/test_modeling_gemma.py b/tests/models/gemma/test_modeling_gemma.py
index 284cd4c19909..7c6a322d99cb 100644
--- a/tests/models/gemma/test_modeling_gemma.py
+++ b/tests/models/gemma/test_modeling_gemma.py
@@ -516,6 +516,7 @@ def test_model_2b_bf16_dola(self):
             dola_layers="low",
             repetition_penalty=1.2,
             trust_remote_code=True,
+            custom_generate="transformers-community/dola",
         )
         output_text = tokenizer.batch_decode(output, skip_special_tokens=True)
         self.assertEqual(output_text, EXPECTED_TEXTS)
diff --git a/tests/models/gpt2/test_modeling_gpt2.py b/tests/models/gpt2/test_modeling_gpt2.py
index 1aac2069b084..072bbd081643 100644
--- a/tests/models/gpt2/test_modeling_gpt2.py
+++ b/tests/models/gpt2/test_modeling_gpt2.py
@@ -837,6 +837,7 @@ def test_gpt2_sample(self):
             all(output_seq_strs[idx] != output_seq_tt_strs[idx] for idx in range(len(output_seq_tt_strs)))
         )  # token_type_ids should change output
 
+    # TODO joao, manuel: remove this in v4.62.0
     @slow
     def test_contrastive_search_gpt2(self):
         article = (
@@ -848,7 +849,14 @@ def test_contrastive_search_gpt2(self):
         gpt2_model = GPT2LMHeadModel.from_pretrained("openai-community/gpt2-large").to(torch_device)
         input_ids = gpt2_tokenizer(article, return_tensors="pt").input_ids.to(torch_device)
 
-        outputs = gpt2_model.generate(input_ids, penalty_alpha=0.6, top_k=4, max_length=256)
+        outputs = gpt2_model.generate(
+            input_ids,
+            penalty_alpha=0.6,
+            top_k=4,
+            max_length=256,
+            trust_remote_code=True,
+            custom_generate="transformers-community/contrastive-search",
+        )
 
         generated_text = gpt2_tokenizer.batch_decode(outputs, skip_special_tokens=True)
 
diff --git a/tests/models/gpt_bigcode/test_modeling_gpt_bigcode.py b/tests/models/gpt_bigcode/test_modeling_gpt_bigcode.py
index a57dc883f3ca..b24f47c32bca 100644
--- a/tests/models/gpt_bigcode/test_modeling_gpt_bigcode.py
+++ b/tests/models/gpt_bigcode/test_modeling_gpt_bigcode.py
@@ -424,14 +424,6 @@ def test_config(self):
     def test_retain_grad_hidden_states_attentions(self):
         pass
 
-    @unittest.skip(reason="Contrastive search not supported due to non-standard caching mechanism")
-    def test_contrastive_generate(self):
-        pass
-
-    @unittest.skip(reason="Contrastive search not supported due to non-standard caching mechanism")
-    def test_contrastive_generate_dict_outputs_use_cache(self):
-        pass
-
     @unittest.skip(reason="CPU offload seems to be broken for some reason - tiny models keep hitting corner cases")
     def test_cpu_offload(self):
         pass
diff --git a/tests/models/gptj/test_modeling_gptj.py b/tests/models/gptj/test_modeling_gptj.py
index d4cf398a6da9..073660b49cf0 100644
--- a/tests/models/gptj/test_modeling_gptj.py
+++ b/tests/models/gptj/test_modeling_gptj.py
@@ -541,6 +541,7 @@ def test_gptj_sample(self):
             all(output_seq_strs[idx] != output_seq_tt_strs[idx] for idx in range(len(output_seq_tt_strs)))
         )  # token_type_ids should change output
 
+    # TODO joao, manuel: remove this in v4.62.0
     @tooslow
     def test_contrastive_search_gptj(self):
         article = (
@@ -554,7 +555,14 @@ def test_contrastive_search_gptj(self):
         )
         input_ids = tokenizer(article, return_tensors="pt").input_ids.to(torch_device)
 
-        outputs = model.generate(input_ids, penalty_alpha=0.6, top_k=4, max_length=256)
+        outputs = model.generate(
+            input_ids,
+            penalty_alpha=0.6,
+            top_k=4,
+            max_length=256,
+            trust_remote_code=True,
+            custom_generate="transformers-community/contrastive-search",
+        )
         generated_text = tokenizer.batch_decode(outputs, skip_special_tokens=True)
 
         self.assertListEqual(
diff --git a/tests/models/idefics/test_modeling_idefics.py b/tests/models/idefics/test_modeling_idefics.py
index 15ab5ef22b47..454b38975cdd 100644
--- a/tests/models/idefics/test_modeling_idefics.py
+++ b/tests/models/idefics/test_modeling_idefics.py
@@ -844,18 +844,6 @@ def _check_attentions_for_generate(
         """
         pass
 
-    @unittest.skip(reason="Contrastive search is not implemented for VLMs that do cross-attn")
-    def test_contrastive_generate(self):
-        pass
-
-    @unittest.skip(reason="Contrastive search is not implemented for VLMs that do cross-attn")
-    def test_contrastive_generate_dict_outputs_use_cache(self):
-        pass
-
-    @unittest.skip(reason="Contrastive search is not implemented for VLMs that do cross-attn")
-    def test_contrastive_generate_low_memory(self):
-        pass
-
     @unittest.skip(reason="We only test the model that takes in multiple images")
     def test_custom_4d_attention_mask(self):
         pass
diff --git a/tests/models/idefics2/test_modeling_idefics2.py b/tests/models/idefics2/test_modeling_idefics2.py
index 1a31eb852ec1..199664a73d85 100644
--- a/tests/models/idefics2/test_modeling_idefics2.py
+++ b/tests/models/idefics2/test_modeling_idefics2.py
@@ -390,18 +390,6 @@ def test_flash_attn_2_generate_padding_right(self):
     def test_flash_attn_2_inference_padding_right(self):
         pass
 
-    @unittest.skip(reason="Contrastive search is not implemented for VLMs that do cross-attn")
-    def test_contrastive_generate(self):
-        pass
-
-    @unittest.skip(reason="Contrastive search is not implemented for VLMs that do cross-attn")
-    def test_contrastive_generate_dict_outputs_use_cache(self):
-        pass
-
-    @unittest.skip(reason="Contrastive search is not implemented for VLMs that do cross-attn")
-    def test_contrastive_generate_low_memory(self):
-        pass
-
     @unittest.skip(
         reason="Prompt lookup decoding needs a way to indicate `bad_word_ids` that should not be suggested as candidates"
     )
diff --git a/tests/models/idefics3/test_modeling_idefics3.py b/tests/models/idefics3/test_modeling_idefics3.py
index 234f6ceb8b01..97cff53643bc 100644
--- a/tests/models/idefics3/test_modeling_idefics3.py
+++ b/tests/models/idefics3/test_modeling_idefics3.py
@@ -351,18 +351,6 @@ def test_inputs_embeds():
     def test_flash_attn_2_inference_padding_right(self):
         pass
 
-    @unittest.skip(reason="Contrastive search is not implemented for VLMs that do cross-attn")
-    def test_contrastive_generate(self):
-        pass
-
-    @unittest.skip(reason="Contrastive search is not implemented for VLMs that do cross-attn")
-    def test_contrastive_generate_dict_outputs_use_cache(self):
-        pass
-
-    @unittest.skip(reason="Contrastive search is not implemented for VLMs that do cross-attn")
-    def test_contrastive_generate_low_memory(self):
-        pass
-
     @unittest.skip(
         reason="Prompt lookup decoding needs a way to indicate `bad_word_ids` that should not be suggested as candidates"
     )
diff --git a/tests/models/kosmos2_5/test_modeling_kosmos2_5.py b/tests/models/kosmos2_5/test_modeling_kosmos2_5.py
index d9c2476a1fce..62ee1be2dbe6 100644
--- a/tests/models/kosmos2_5/test_modeling_kosmos2_5.py
+++ b/tests/models/kosmos2_5/test_modeling_kosmos2_5.py
@@ -585,14 +585,6 @@ def test_flash_attn_2_generate_reuse_cache(self):
     def test_generate_from_inputs_embeds(self):
         pass
 
-    # TODO: ydshieh
-    @pytest.mark.generate
-    @unittest.skip(
-        "Kosmos2_5ForConditionalGeneration returns `vision_model_output` which is currently not working with `stack_model_outputs`",
-    )
-    def test_beam_search_low_memory(self):
-        pass
-
     @pytest.mark.generate
     def test_left_padding_compatibility(self):
         # Overwrite because Kosmos-2.5 need to padd pixel values and pad image-attn-mask
diff --git a/tests/models/lfm2/test_modeling_lfm2.py b/tests/models/lfm2/test_modeling_lfm2.py
index 4603f54dc7f7..52d4b4d6fce1 100644
--- a/tests/models/lfm2/test_modeling_lfm2.py
+++ b/tests/models/lfm2/test_modeling_lfm2.py
@@ -75,18 +75,6 @@ def test_attention_outputs(self):
     def test_past_key_values_format(self):
         pass
 
-    @unittest.skip("Lfm2 has a special cache format which is not compatible with contrastive search")
-    def test_contrastive_generate(self):
-        pass
-
-    @unittest.skip("Lfm2 has a special cache format which is not compatible with contrastive search")
-    def test_contrastive_generate_dict_outputs_use_cache(self):
-        pass
-
-    @unittest.skip("Lfm2 has a special cache format which is not compatible with contrastive search")
-    def test_contrastive_generate_low_memory(self):
-        pass
-
     @unittest.skip(
         "Lfm2 has a special cache format which is not compatible with compile as it has static address for conv cache"
     )
diff --git a/tests/models/llama/test_modeling_llama.py b/tests/models/llama/test_modeling_llama.py
index 03458d53a37b..9217510fb0b0 100644
--- a/tests/models/llama/test_modeling_llama.py
+++ b/tests/models/llama/test_modeling_llama.py
@@ -249,7 +249,14 @@ def test_model_7b_dola_generation(self):
 
         # greedy generation outputs
         generated_ids = model.generate(
-            **model_inputs, max_new_tokens=64, top_p=None, temperature=1, do_sample=False, dola_layers="low"
+            **model_inputs,
+            max_new_tokens=64,
+            top_p=None,
+            temperature=1,
+            do_sample=False,
+            dola_layers="low",
+            trust_remote_code=True,
+            custom_generate="transformers-community/dola",
         )
         text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
         self.assertEqual(EXPECTED_TEXT_COMPLETION, text)
diff --git a/tests/models/mistral/test_modeling_mistral.py b/tests/models/mistral/test_modeling_mistral.py
index 4a62f2bb1a9b..26f793a7d814 100644
--- a/tests/models/mistral/test_modeling_mistral.py
+++ b/tests/models/mistral/test_modeling_mistral.py
@@ -192,6 +192,7 @@ def test_model_7b_dola_generation(self):
             dola_layers="low",
             repetition_penalty=1.2,
             trust_remote_code=True,
+            custom_generate="transformers-community/dola",
         )
         text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
         self.assertEqual(EXPECTED_TEXT_COMPLETION, text)
diff --git a/tests/models/opt/test_modeling_opt.py b/tests/models/opt/test_modeling_opt.py
index 8ff45ff9678b..331d1aba498b 100644
--- a/tests/models/opt/test_modeling_opt.py
+++ b/tests/models/opt/test_modeling_opt.py
@@ -543,6 +543,7 @@ def test_batched_nan_fp16(self):
             torch.isnan(outputs.logits[0]).any().item()
         )  # the first logits could contain NaNs if it fails
 
+    # TODO joao, manuel: remove this in v4.62.0
     @slow
     def test_contrastive_search_opt(self):
         article = (
@@ -555,7 +556,14 @@ def test_contrastive_search_opt(self):
         opt_model = OPTForCausalLM.from_pretrained("facebook/opt-1.3b").to(torch_device)
         input_ids = opt_tokenizer(article, return_tensors="pt").input_ids.to(torch_device)
 
-        outputs = opt_model.generate(input_ids, penalty_alpha=0.6, top_k=5, max_length=256)
+        outputs = opt_model.generate(
+            input_ids,
+            penalty_alpha=0.6,
+            top_k=5,
+            max_length=256,
+            trust_remote_code=True,
+            custom_generate="transformers-community/contrastive-search",
+        )
         generated_text = opt_tokenizer.batch_decode(outputs, skip_special_tokens=True)
 
         self.assertListEqual(
diff --git a/tests/models/paligemma2/test_modeling_paligemma2.py b/tests/models/paligemma2/test_modeling_paligemma2.py
index f4c211a5a6c5..ad345e70e03e 100644
--- a/tests/models/paligemma2/test_modeling_paligemma2.py
+++ b/tests/models/paligemma2/test_modeling_paligemma2.py
@@ -273,10 +273,6 @@ def test_feed_forward_chunking(self):
     def test_flash_attention_2_padding_matches_padding_free_with_position_ids(self):
         pass
 
-    @unittest.skip("Low memory will be removed soon so no need to fix it")
-    def test_beam_search_low_memory(self):
-        pass
-
     @parameterized.expand([("random",), ("same",)])
     @pytest.mark.generate
     @unittest.skip("Paligemma2 does not seem to be compatible with assisted decoding")
diff --git a/tests/models/smolvlm/test_modeling_smolvlm.py b/tests/models/smolvlm/test_modeling_smolvlm.py
index c485fb92d8d3..45aec1da4ba9 100644
--- a/tests/models/smolvlm/test_modeling_smolvlm.py
+++ b/tests/models/smolvlm/test_modeling_smolvlm.py
@@ -345,18 +345,6 @@ def setUp(self):
     def test_flash_attn_2_inference_padding_right(self):
         pass
 
-    @unittest.skip(reason="Contrastive search is not implemented for VLMs that do cross-attn")
-    def test_contrastive_generate(self):
-        pass
-
-    @unittest.skip(reason="Contrastive search is not implemented for VLMs that do cross-attn")
-    def test_contrastive_generate_dict_outputs_use_cache(self):
-        pass
-
-    @unittest.skip(reason="Contrastive search is not implemented for VLMs that do cross-attn")
-    def test_contrastive_generate_low_memory(self):
-        pass
-
     @unittest.skip(
         reason="Prompt lookup decoding needs a way to indicate `bad_word_ids` that should not be suggested as candidates"
     )
diff --git a/tests/models/t5/test_modeling_t5.py b/tests/models/t5/test_modeling_t5.py
index 1a0780089779..9de1467fa061 100644
--- a/tests/models/t5/test_modeling_t5.py
+++ b/tests/models/t5/test_modeling_t5.py
@@ -1569,6 +1569,7 @@ def test_translation_en_to_ro(self):
         translation = tok.decode(output[0])
         self.assertEqual(translation, expected_translation)
 
+    # TODO joao, manuel: remove this in v4.62.0
     @slow
     def test_contrastive_search_t5(self):
         article = (
@@ -1603,7 +1604,14 @@ def test_contrastive_search_t5(self):
             article, add_special_tokens=False, truncation=True, max_length=512, return_tensors="pt"
         ).input_ids.to(torch_device)
 
-        outputs = t5_model.generate(input_ids, penalty_alpha=0.5, top_k=5, max_length=64)
+        outputs = t5_model.generate(
+            input_ids,
+            penalty_alpha=0.5,
+            top_k=5,
+            max_length=64,
+            trust_remote_code=True,
+            custom_generate="transformers-community/contrastive-search",
+        )
         generated_text = t5_tokenizer.batch_decode(outputs, skip_special_tokens=True)
 
         # TODO: @arthur?