diff --git a/src/transformers/generation_flax_utils.py b/src/transformers/generation_flax_utils.py index 1c052aae7bafb6..6883b8cf37f10e 100644 --- a/src/transformers/generation_flax_utils.py +++ b/src/transformers/generation_flax_utils.py @@ -36,6 +36,11 @@ FlaxTopKLogitsWarper, FlaxTopPLogitsWarper, ) +from .models.auto import ( + FLAX_MODEL_FOR_CAUSAL_LM_MAPPING, + FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING, + FLAX_MODEL_FOR_VISION_2_SEQ_MAPPING, +) from .utils import ModelOutput, logging @@ -161,6 +166,30 @@ def _adapt_logits_for_beam_search(self, logits): """ return logits + def _validate_model_class(self): + """ + Confirms that the model class is compatible with generation. If not, raises an exception that points to the + right class to use. + """ + if not hasattr(self, "prepare_inputs_for_generation"): + generate_compatible_mappings = [ + FLAX_MODEL_FOR_CAUSAL_LM_MAPPING, + FLAX_MODEL_FOR_VISION_2_SEQ_MAPPING, + FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING, + ] + generate_compatible_classes = set() + for model_mapping in generate_compatible_mappings: + supported_models = model_mapping.get(type(self.config), default=None) + if supported_models is not None: + generate_compatible_classes.add(supported_models.__name__) + exception_message = ( + f"The current model class ({self.__class__.__name__}) is not compatible with `.generate()`, as " + "it doesn't have a language model head." + ) + if generate_compatible_classes: + exception_message += f" Please use one of the following classes instead: {generate_compatible_classes}" + raise TypeError(exception_message) + def _validate_model_kwargs(self, model_kwargs: Dict[str, Any]): """Validates model kwargs for generation. Generate argument typos will also be caught here.""" unused_model_args = [] @@ -281,7 +310,8 @@ def generate( >>> outputs = model.generate(input_ids=input_ids, max_length=20, top_k=30, do_sample=True) >>> tokenizer.batch_decode(outputs, skip_special_tokens=True) ```""" - # Validate model kwargs + # Validate the `.generate()` call + self._validate_model_class() self._validate_model_kwargs(model_kwargs.copy()) # set init values diff --git a/src/transformers/generation_tf_utils.py b/src/transformers/generation_tf_utils.py index d5f92b51e722ec..d2b4ef746cc227 100644 --- a/src/transformers/generation_tf_utils.py +++ b/src/transformers/generation_tf_utils.py @@ -35,6 +35,12 @@ TFTopKLogitsWarper, TFTopPLogitsWarper, ) +from .models.auto import ( + TF_MODEL_FOR_CAUSAL_LM_MAPPING, + TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING, + TF_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING, + TF_MODEL_FOR_VISION_2_SEQ_MAPPING, +) from .tf_utils import shape_list, stable_softmax from .utils import ModelOutput, logging @@ -357,12 +363,6 @@ def seed_generator(self): supports_xla_generation = True - def prepare_inputs_for_generation(self, inputs, **kwargs): - """ - Implement in subclasses of [`TFPreTrainedModel`] for custom behavior to prepare inputs in the generate method. - """ - return {"input_ids": inputs} - def _use_cache(self, outputs, use_cache): """During generation, decide whether to pass the `past` variable to the next forward pass.""" use_cache = getattr(self.config, "use_cache", False) @@ -1290,6 +1290,31 @@ def adjust_logits_during_generation( else: return logits + def _validate_model_class(self): + """ + Confirms that the model class is compatible with generation. If not, raises an exception that points to the + right class to use. + """ + if not hasattr(self, "prepare_inputs_for_generation"): + generate_compatible_mappings = [ + TF_MODEL_FOR_CAUSAL_LM_MAPPING, + TF_MODEL_FOR_VISION_2_SEQ_MAPPING, + TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING, + TF_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING, + ] + generate_compatible_classes = set() + for model_mapping in generate_compatible_mappings: + supported_models = model_mapping.get(type(self.config), default=None) + if supported_models is not None: + generate_compatible_classes.add(supported_models.__name__) + exception_message = ( + f"The current model class ({self.__class__.__name__}) is not compatible with `.generate()`, as " + "it doesn't have a language model head." + ) + if generate_compatible_classes: + exception_message += f" Please use one of the following classes instead: {generate_compatible_classes}" + raise TypeError(exception_message) + def _validate_model_kwargs(self, model_kwargs: Dict[str, Any]): """Validates model kwargs for generation. Generate argument typos will also be caught here.""" # Excludes arguments that are handled before calling any model function @@ -1508,7 +1533,8 @@ def _generate( # generate sequences without allowing bad_words to be generated outputs = model.generate(input_ids=input_ids, max_length=100, do_sample=True, bad_words_ids=bad_words_ids) ```""" - # 0. Validate model kwargs + # 0. Validate the `.generate()` call + self._validate_model_class() self._validate_model_kwargs(model_kwargs.copy()) # 1. Set generation parameters if not already defined diff --git a/src/transformers/generation_utils.py b/src/transformers/generation_utils.py index dcbe6e5946d24f..5e033ca696d40b 100644 --- a/src/transformers/generation_utils.py +++ b/src/transformers/generation_utils.py @@ -51,6 +51,13 @@ StoppingCriteriaList, validate_stopping_criteria, ) +from .models.auto import ( + MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING, + MODEL_FOR_CAUSAL_LM_MAPPING, + MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING, + MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING, + MODEL_FOR_VISION_2_SEQ_MAPPING, +) from .pytorch_utils import torch_int_div from .utils import ModelOutput, logging @@ -463,12 +470,6 @@ def _can_retrieve_inputs_from_name( return can_retrieve_inputs - def prepare_inputs_for_generation(self, input_ids: torch.LongTensor, **kwargs) -> Dict[str, Any]: - """ - Implement in subclasses of [`PreTrainedModel`] for custom behavior to prepare inputs in the generate method. - """ - return {"input_ids": input_ids} - def adjust_logits_during_generation(self, logits: torch.FloatTensor, **kwargs) -> torch.FloatTensor: """ Implement in subclasses of [`PreTrainedModel`] for custom behavior to adjust the logits in the generate method. @@ -841,6 +842,32 @@ def compute_transition_beam_scores( return transition_scores + def _validate_model_class(self): + """ + Confirms that the model class is compatible with generation. If not, raises an exception that points to the + right class to use. + """ + if not hasattr(self, "prepare_inputs_for_generation"): + generate_compatible_mappings = [ + MODEL_FOR_CAUSAL_LM_MAPPING, + MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING, + MODEL_FOR_VISION_2_SEQ_MAPPING, + MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING, + MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING, + ] + generate_compatible_classes = set() + for model_mapping in generate_compatible_mappings: + supported_models = model_mapping.get(type(self.config), default=None) + if supported_models is not None: + generate_compatible_classes.add(supported_models.__name__) + exception_message = ( + f"The current model class ({self.__class__.__name__}) is not compatible with `.generate()`, as " + "it doesn't have a language model head." + ) + if generate_compatible_classes: + exception_message += f" Please use one of the following classes instead: {generate_compatible_classes}" + raise TypeError(exception_message) + def _validate_model_kwargs(self, model_kwargs: Dict[str, Any]): """Validates model kwargs for generation. Generate argument typos will also be caught here.""" # Excludes arguments that are handled before calling any model function @@ -1143,7 +1170,8 @@ def generate( >>> tokenizer.batch_decode(outputs, skip_special_tokens=True) ['Paris ist eines der dichtesten besiedelten Gebiete Europas.'] ```""" - # 0. Validate model kwargs + # 0. Validate the `.generate()` call + self._validate_model_class() self._validate_model_kwargs(model_kwargs.copy()) # 1. Set generation parameters if not already defined diff --git a/src/transformers/models/openai/modeling_openai.py b/src/transformers/models/openai/modeling_openai.py index e5e5da5da0c9f6..2bd634abeb1154 100644 --- a/src/transformers/models/openai/modeling_openai.py +++ b/src/transformers/models/openai/modeling_openai.py @@ -20,7 +20,7 @@ import math import os from dataclasses import dataclass -from typing import Optional, Tuple, Union +from typing import Any, Dict, Optional, Tuple, Union import torch from torch import nn @@ -607,6 +607,9 @@ def forward( attentions=transformer_outputs.attentions, ) + def prepare_inputs_for_generation(self, input_ids: torch.LongTensor, **kwargs) -> Dict[str, Any]: + return {"input_ids": input_ids} + @add_start_docstrings( """ diff --git a/src/transformers/models/openai/modeling_tf_openai.py b/src/transformers/models/openai/modeling_tf_openai.py index 8a176190862816..101f16931bd0ab 100644 --- a/src/transformers/models/openai/modeling_tf_openai.py +++ b/src/transformers/models/openai/modeling_tf_openai.py @@ -633,6 +633,9 @@ def serving_output(self, output: TFCausalLMOutput) -> TFCausalLMOutput: return TFCausalLMOutput(logits=output.logits, hidden_states=hs, attentions=attns) + def prepare_inputs_for_generation(self, inputs, **kwargs): + return {"input_ids": inputs} + @add_start_docstrings( """