From 0f511492218cfedcda7ae315679e034dbfd264d3 Mon Sep 17 00:00:00 2001
From: Joao Gante
Date: Tue, 6 Sep 2022 11:40:04 +0000
Subject: [PATCH 1/3] add model class validation

---
 src/transformers/generation_flax_utils.py | 13 ++++++++++++-
 src/transformers/generation_tf_utils.py   | 13 ++++++++++++-
 src/transformers/generation_utils.py      | 13 ++++++++++++-
 3 files changed, 36 insertions(+), 3 deletions(-)

diff --git a/src/transformers/generation_flax_utils.py b/src/transformers/generation_flax_utils.py
index 1c052aae7bafb6..0bd57ff9b73413 100644
--- a/src/transformers/generation_flax_utils.py
+++ b/src/transformers/generation_flax_utils.py
@@ -161,6 +161,16 @@ def _adapt_logits_for_beam_search(self, logits):
         """
         return logits
 
+    def _validate_model_class(self):
+        """Confirms that the model class is compatible with generation."""
+        if not hasattr(self, "prepare_inputs_for_generation"):
+            model_class = self.__class__.__name__
+            raise TypeError(
+                f"The current model class ({model_class}) is not compatible with `.generate()`, as it doesn't have a "
+                "language model head. The following AutoModel classes are compatible: `FlaxAutoModelForCausalLM`, "
+                "`FlaxAutoModelForSeq2SeqLM`, `FlaxAutoModelForVision2Seq`."
+            )
+
     def _validate_model_kwargs(self, model_kwargs: Dict[str, Any]):
         """Validates model kwargs for generation. Generate argument typos will also be caught here."""
         unused_model_args = []
@@ -281,7 +291,8 @@ def generate(
         >>> outputs = model.generate(input_ids=input_ids, max_length=20, top_k=30, do_sample=True)
         >>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
         ```"""
-        # Validate model kwargs
+        # Validate the `.generate()` call
+        self._validate_model_class()
         self._validate_model_kwargs(model_kwargs.copy())
 
         # set init values
diff --git a/src/transformers/generation_tf_utils.py b/src/transformers/generation_tf_utils.py
index d5f92b51e722ec..d706b9aa29a0a3 100644
--- a/src/transformers/generation_tf_utils.py
+++ b/src/transformers/generation_tf_utils.py
@@ -1290,6 +1290,16 @@ def adjust_logits_during_generation(
         else:
             return logits
 
+    def _validate_model_class(self):
+        """Confirms that the model class is compatible with generation."""
+        if not hasattr(self, "prepare_inputs_for_generation"):
+            model_class = self.__class__.__name__
+            raise TypeError(
+                f"The current model class ({model_class}) is not compatible with `.generate()`, as it doesn't have a "
+                "language model head. The following AutoModel classes are compatible: `TFAutoModelForCausalLM`, "
+                "`TFAutoModelForSeq2SeqLM`, `TFAutoModelForVision2Seq`, `TFAutoModelForSpeechSeq2Seq`."
+            )
+
     def _validate_model_kwargs(self, model_kwargs: Dict[str, Any]):
         """Validates model kwargs for generation. Generate argument typos will also be caught here."""
         # Excludes arguments that are handled before calling any model function
@@ -1508,7 +1518,8 @@ def _generate(
         # generate sequences without allowing bad_words to be generated
         outputs = model.generate(input_ids=input_ids, max_length=100, do_sample=True, bad_words_ids=bad_words_ids)
         ```"""
-        # 0. Validate model kwargs
+        # 0. Validate the `.generate()` call
+        self._validate_model_class()
         self._validate_model_kwargs(model_kwargs.copy())
 
         # 1. Set generation parameters if not already defined
diff --git a/src/transformers/generation_utils.py b/src/transformers/generation_utils.py
index dcbe6e5946d24f..94c024a46bed43 100644
--- a/src/transformers/generation_utils.py
+++ b/src/transformers/generation_utils.py
@@ -841,6 +841,16 @@ def compute_transition_beam_scores(
 
         return transition_scores
 
+    def _validate_model_class(self):
+        """Confirms that the model class is compatible with generation."""
+        if not hasattr(self, "prepare_inputs_for_generation"):
+            model_class = self.__class__.__name__
+            raise TypeError(
+                f"The current model class ({model_class}) is not compatible with `.generate()`, as it doesn't have a "
+                "language model head. The following AutoModel classes are compatible: `AutoModelForCausalLM`, "
+                "`AutoModelForSeq2SeqLM`, `AutoModelForVision2Seq`, `AutoModelForSpeechSeq2Seq`."
+            )
+
     def _validate_model_kwargs(self, model_kwargs: Dict[str, Any]):
         """Validates model kwargs for generation. Generate argument typos will also be caught here."""
         # Excludes arguments that are handled before calling any model function
@@ -1143,7 +1153,8 @@ def generate(
         >>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
         ['Paris ist eines der dichtesten besiedelten Gebiete Europas.']
         ```"""
-        # 0. Validate model kwargs
+        # 0. Validate the `.generate()` call
+        self._validate_model_class()
         self._validate_model_kwargs(model_kwargs.copy())
 
         # 1. Set generation parameters if not already defined
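The validation relies on a simple contract: only classes with a language model head implement `prepare_inputs_for_generation`, so its absence marks a model as incompatible with `.generate()`. At this point the PyTorch and TF base classes still define a stub for that method, so the check only becomes effective once those stubs are removed in the next commit. A minimal sketch of the intended user-facing behavior once the full series is applied (the checkpoint name is illustrative; any encoder-only model behaves the same way):

```python
from transformers import AutoModel, AutoTokenizer

# An encoder-only model loaded through the base AutoModel class has no
# language model head and, with the base-class stub removed, no
# `prepare_inputs_for_generation` either.
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
model = AutoModel.from_pretrained("bert-base-uncased")

inputs = tokenizer("The capital of France is", return_tensors="pt")
model.generate(**inputs)  # TypeError: ... not compatible with `.generate()`
```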
From 4b5fa5419f693eb1f83b0a47ab3733a92de5e2b2 Mon Sep 17 00:00:00 2001
From: Joao Gante
Date: Tue, 6 Sep 2022 16:15:00 +0000
Subject: [PATCH 2/3] tentative commit (let's see if tests break)

---
 src/transformers/generation_tf_utils.py |  6 ----
 src/transformers/generation_utils.py    | 43 +++++++++++++++++-------
 2 files changed, 30 insertions(+), 19 deletions(-)

diff --git a/src/transformers/generation_tf_utils.py b/src/transformers/generation_tf_utils.py
index d706b9aa29a0a3..4f503819282920 100644
--- a/src/transformers/generation_tf_utils.py
+++ b/src/transformers/generation_tf_utils.py
@@ -357,12 +357,6 @@ def seed_generator(self):
 
     supports_xla_generation = True
 
-    def prepare_inputs_for_generation(self, inputs, **kwargs):
-        """
-        Implement in subclasses of [`TFPreTrainedModel`] for custom behavior to prepare inputs in the generate method.
-        """
-        return {"input_ids": inputs}
-
     def _use_cache(self, outputs, use_cache):
         """During generation, decide whether to pass the `past` variable to the next forward pass."""
         use_cache = getattr(self.config, "use_cache", False)
diff --git a/src/transformers/generation_utils.py b/src/transformers/generation_utils.py
index 94c024a46bed43..5e033ca696d40b 100644
--- a/src/transformers/generation_utils.py
+++ b/src/transformers/generation_utils.py
@@ -51,6 +51,13 @@
     StoppingCriteriaList,
     validate_stopping_criteria,
 )
+from .models.auto import (
+    MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING,
+    MODEL_FOR_CAUSAL_LM_MAPPING,
+    MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
+    MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING,
+    MODEL_FOR_VISION_2_SEQ_MAPPING,
+)
 from .pytorch_utils import torch_int_div
 from .utils import ModelOutput, logging
 
@@ -463,12 +470,6 @@ def _can_retrieve_inputs_from_name(
 
         return can_retrieve_inputs
 
-    def prepare_inputs_for_generation(self, input_ids: torch.LongTensor, **kwargs) -> Dict[str, Any]:
-        """
-        Implement in subclasses of [`PreTrainedModel`] for custom behavior to prepare inputs in the generate method.
-        """
-        return {"input_ids": input_ids}
-
     def adjust_logits_during_generation(self, logits: torch.FloatTensor, **kwargs) -> torch.FloatTensor:
         """
         Implement in subclasses of [`PreTrainedModel`] for custom behavior to adjust the logits in the generate method.
@@ -842,14 +843,30 @@ def compute_transition_beam_scores(
         return transition_scores
 
     def _validate_model_class(self):
-        """Confirms that the model class is compatible with generation."""
+        """
+        Confirms that the model class is compatible with generation. If not, raises an exception that points to the
+        right class to use.
+        """
         if not hasattr(self, "prepare_inputs_for_generation"):
-            model_class = self.__class__.__name__
-            raise TypeError(
-                f"The current model class ({model_class}) is not compatible with `.generate()`, as it doesn't have a "
-                "language model head. The following AutoModel classes are compatible: `AutoModelForCausalLM`, "
-                "`AutoModelForSeq2SeqLM`, `AutoModelForVision2Seq`, `AutoModelForSpeechSeq2Seq`."
-            )
+            generate_compatible_mappings = [
+                MODEL_FOR_CAUSAL_LM_MAPPING,
+                MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING,
+                MODEL_FOR_VISION_2_SEQ_MAPPING,
+                MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
+                MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING,
+            ]
+            generate_compatible_classes = set()
+            for model_mapping in generate_compatible_mappings:
+                supported_models = model_mapping.get(type(self.config), default=None)
+                if supported_models is not None:
+                    generate_compatible_classes.add(supported_models.__name__)
+            exception_message = (
+                f"The current model class ({self.__class__.__name__}) is not compatible with `.generate()`, as "
+                "it doesn't have a language model head."
+            )
+            if generate_compatible_classes:
+                exception_message += f" Please use one of the following classes instead: {generate_compatible_classes}"
+            raise TypeError(exception_message)
 
     def _validate_model_kwargs(self, model_kwargs: Dict[str, Any]):
         """Validates model kwargs for generation. Generate argument typos will also be caught here."""
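The rewritten check builds its suggestion from the task-specific auto mappings, each of which maps a config class to the model class registered for that task. Looking up `type(self.config)` therefore recovers the generation-capable class for the same architecture, and the suggestion is simply omitted when the architecture has no such head. A rough sketch of that lookup in isolation (BERT is used as an example; the `default=None` keyword mirrors the call in the diff):

```python
from transformers import BertConfig, MODEL_FOR_CAUSAL_LM_MAPPING

# The mappings are keyed by config class. BERT registers a causal LM head
# class, so the lookup yields a `.generate()`-compatible alternative.
supported = MODEL_FOR_CAUSAL_LM_MAPPING.get(BertConfig, default=None)
print(supported.__name__ if supported is not None else "no match")
# Expected: "BertLMHeadModel"
```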
- """ - return {"input_ids": input_ids} - def adjust_logits_during_generation(self, logits: torch.FloatTensor, **kwargs) -> torch.FloatTensor: """ Implement in subclasses of [`PreTrainedModel`] for custom behavior to adjust the logits in the generate method. @@ -842,14 +843,30 @@ def compute_transition_beam_scores( return transition_scores def _validate_model_class(self): - """Confirms that the model class is compatible with generation.""" + """ + Confirms that the model class is compatible with generation. If not, raises an exception that points to the + right class to use. + """ if not hasattr(self, "prepare_inputs_for_generation"): - model_class = self.__class__.__name__ - raise TypeError( - f"The current model class ({model_class}) is not compatible with `.generate()`, as it doesn't have a ", - "language model head. The following AutoModel classes are compatible: `AutoModelForCausalLM`, ", - "`AutoModelForSeq2SeqLM`, `AutoModelForVision2Seq`, `AutoModelForSpeechSeq2Seq`.", - ) + generate_compatible_mappings = [ + MODEL_FOR_CAUSAL_LM_MAPPING, + MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING, + MODEL_FOR_VISION_2_SEQ_MAPPING, + MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING, + MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING, + ] + generate_compatible_classes = set() + for model_mapping in generate_compatible_mappings: + supported_models = model_mapping.get(type(self.config), default=None) + if supported_models is not None: + generate_compatible_classes.add(supported_models.__name__) + exception_message = ( + f"The current model class ({self.__class__.__name__}) is not compatible with `.generate()`, as " + "it doesn't have a language model head." + ) + if generate_compatible_classes: + exception_message += f" Please use one of the following classes instead: {generate_compatible_classes}" + raise TypeError(exception_message) def _validate_model_kwargs(self, model_kwargs: Dict[str, Any]): """Validates model kwargs for generation. Generate argument typos will also be caught here.""" From 3df6f5f943c93492cdba3e7148b52c2d9bf6327b Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Tue, 6 Sep 2022 18:32:02 +0000 Subject: [PATCH 3/3] exception points to the right class --- src/transformers/generation_flax_utils.py | 31 +++++++++++++---- src/transformers/generation_tf_utils.py | 33 +++++++++++++++---- .../models/openai/modeling_openai.py | 5 ++- .../models/openai/modeling_tf_openai.py | 3 ++ 4 files changed, 59 insertions(+), 13 deletions(-) diff --git a/src/transformers/generation_flax_utils.py b/src/transformers/generation_flax_utils.py index 0bd57ff9b73413..6883b8cf37f10e 100644 --- a/src/transformers/generation_flax_utils.py +++ b/src/transformers/generation_flax_utils.py @@ -36,6 +36,11 @@ FlaxTopKLogitsWarper, FlaxTopPLogitsWarper, ) +from .models.auto import ( + FLAX_MODEL_FOR_CAUSAL_LM_MAPPING, + FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING, + FLAX_MODEL_FOR_VISION_2_SEQ_MAPPING, +) from .utils import ModelOutput, logging @@ -162,14 +167,28 @@ def _adapt_logits_for_beam_search(self, logits): return logits def _validate_model_class(self): - """Confirms that the model class is compatible with generation.""" + """ + Confirms that the model class is compatible with generation. If not, raises an exception that points to the + right class to use. + """ if not hasattr(self, "prepare_inputs_for_generation"): - model_class = self.__class__.__name__ - raise TypeError( - f"The current model class ({model_class}) is not compatible with `.generate()`, as it doesn't have a ", - "language model head. 
diff --git a/src/transformers/generation_tf_utils.py b/src/transformers/generation_tf_utils.py
index 4f503819282920..d2b4ef746cc227 100644
--- a/src/transformers/generation_tf_utils.py
+++ b/src/transformers/generation_tf_utils.py
@@ -35,6 +35,12 @@
     TFTopKLogitsWarper,
     TFTopPLogitsWarper,
 )
+from .models.auto import (
+    TF_MODEL_FOR_CAUSAL_LM_MAPPING,
+    TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
+    TF_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING,
+    TF_MODEL_FOR_VISION_2_SEQ_MAPPING,
+)
 from .tf_utils import shape_list, stable_softmax
 from .utils import ModelOutput, logging
 
@@ -1285,14 +1291,29 @@ def adjust_logits_during_generation(
         return logits
 
     def _validate_model_class(self):
-        """Confirms that the model class is compatible with generation."""
+        """
+        Confirms that the model class is compatible with generation. If not, raises an exception that points to the
+        right class to use.
+        """
         if not hasattr(self, "prepare_inputs_for_generation"):
-            model_class = self.__class__.__name__
-            raise TypeError(
-                f"The current model class ({model_class}) is not compatible with `.generate()`, as it doesn't have a "
-                "language model head. The following AutoModel classes are compatible: `TFAutoModelForCausalLM`, "
-                "`TFAutoModelForSeq2SeqLM`, `TFAutoModelForVision2Seq`, `TFAutoModelForSpeechSeq2Seq`."
+            generate_compatible_mappings = [
+                TF_MODEL_FOR_CAUSAL_LM_MAPPING,
+                TF_MODEL_FOR_VISION_2_SEQ_MAPPING,
+                TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
+                TF_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING,
+            ]
+            generate_compatible_classes = set()
+            for model_mapping in generate_compatible_mappings:
+                supported_models = model_mapping.get(type(self.config), default=None)
+                if supported_models is not None:
+                    generate_compatible_classes.add(supported_models.__name__)
+            exception_message = (
+                f"The current model class ({self.__class__.__name__}) is not compatible with `.generate()`, as "
+                "it doesn't have a language model head."
             )
+            if generate_compatible_classes:
+                exception_message += f" Please use one of the following classes instead: {generate_compatible_classes}"
+            raise TypeError(exception_message)
 
     def _validate_model_kwargs(self, model_kwargs: Dict[str, Any]):
         """Validates model kwargs for generation. Generate argument typos will also be caught here."""
diff --git a/src/transformers/models/openai/modeling_openai.py b/src/transformers/models/openai/modeling_openai.py
index e5e5da5da0c9f6..2bd634abeb1154 100644
--- a/src/transformers/models/openai/modeling_openai.py
+++ b/src/transformers/models/openai/modeling_openai.py
@@ -20,7 +20,7 @@
 import math
 import os
 from dataclasses import dataclass
-from typing import Optional, Tuple, Union
+from typing import Any, Dict, Optional, Tuple, Union
 
 import torch
 from torch import nn
@@ -607,6 +607,9 @@ def forward(
             attentions=transformer_outputs.attentions,
         )
 
+    def prepare_inputs_for_generation(self, input_ids: torch.LongTensor, **kwargs) -> Dict[str, Any]:
+        return {"input_ids": input_ids}
+
 
 @add_start_docstrings(
     """
diff --git a/src/transformers/models/openai/modeling_tf_openai.py b/src/transformers/models/openai/modeling_tf_openai.py
index 8a176190862816..101f16931bd0ab 100644
--- a/src/transformers/models/openai/modeling_tf_openai.py
+++ b/src/transformers/models/openai/modeling_tf_openai.py
@@ -633,6 +633,9 @@ def serving_output(self, output: TFCausalLMOutput) -> TFCausalLMOutput:
 
         return TFCausalLMOutput(logits=output.logits, hidden_states=hs, attentions=attns)
 
+    def prepare_inputs_for_generation(self, inputs, **kwargs):
+        return {"input_ids": inputs}
+
 
 @add_start_docstrings(
     """
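With the base-class stub gone, every model that supports generation has to carry its own `prepare_inputs_for_generation`; the OpenAI GPT head models get the minimal version here, re-passing the full prefix at every step since this architecture exposes no cache through the hook. A quick smoke test that generation still works end to end (a sketch, assuming the standard `openai-gpt` checkpoint is available):

```python
from transformers import OpenAIGPTLMHeadModel, OpenAIGPTTokenizer

tokenizer = OpenAIGPTTokenizer.from_pretrained("openai-gpt")
model = OpenAIGPTLMHeadModel.from_pretrained("openai-gpt")

inputs = tokenizer("the weather today is", return_tensors="pt")
# `generate()` now feeds the dict returned by the model's own
# `prepare_inputs_for_generation` into each forward pass.
outputs = model.generate(**inputs, max_length=20)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
```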