Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Generate: add model class validation #18902

Merged
merged 3 commits into from
Sep 13, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 31 additions & 1 deletion src/transformers/generation_flax_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,11 @@
FlaxTopKLogitsWarper,
FlaxTopPLogitsWarper,
)
from .models.auto import (
FLAX_MODEL_FOR_CAUSAL_LM_MAPPING,
FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
FLAX_MODEL_FOR_VISION_2_SEQ_MAPPING,
)
from .utils import ModelOutput, logging


Expand Down Expand Up @@ -161,6 +166,30 @@ def _adapt_logits_for_beam_search(self, logits):
"""
return logits

def _validate_model_class(self):
"""
Confirms that the model class is compatible with generation. If not, raises an exception that points to the
right class to use.
"""
if not hasattr(self, "prepare_inputs_for_generation"):
generate_compatible_mappings = [
FLAX_MODEL_FOR_CAUSAL_LM_MAPPING,
FLAX_MODEL_FOR_VISION_2_SEQ_MAPPING,
FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
]
generate_compatible_classes = set()
for model_mapping in generate_compatible_mappings:
supported_models = model_mapping.get(type(self.config), default=None)
if supported_models is not None:
generate_compatible_classes.add(supported_models.__name__)
exception_message = (
f"The current model class ({self.__class__.__name__}) is not compatible with `.generate()`, as "
"it doesn't have a language model head."
)
if generate_compatible_classes:
exception_message += f" Please use one of the following classes instead: {generate_compatible_classes}"
raise TypeError(exception_message)

def _validate_model_kwargs(self, model_kwargs: Dict[str, Any]):
"""Validates model kwargs for generation. Generate argument typos will also be caught here."""
unused_model_args = []
Expand Down Expand Up @@ -281,7 +310,8 @@ def generate(
>>> outputs = model.generate(input_ids=input_ids, max_length=20, top_k=30, do_sample=True)
>>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
```"""
# Validate model kwargs
# Validate the `.generate()` call
self._validate_model_class()
self._validate_model_kwargs(model_kwargs.copy())

# set init values
Expand Down
40 changes: 33 additions & 7 deletions src/transformers/generation_tf_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,12 @@
TFTopKLogitsWarper,
TFTopPLogitsWarper,
)
from .models.auto import (
TF_MODEL_FOR_CAUSAL_LM_MAPPING,
TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
TF_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING,
TF_MODEL_FOR_VISION_2_SEQ_MAPPING,
)
from .tf_utils import shape_list, stable_softmax
from .utils import ModelOutput, logging

Expand Down Expand Up @@ -357,12 +363,6 @@ def seed_generator(self):

supports_xla_generation = True

def prepare_inputs_for_generation(self, inputs, **kwargs):
"""
Implement in subclasses of [`TFPreTrainedModel`] for custom behavior to prepare inputs in the generate method.
"""
return {"input_ids": inputs}

def _use_cache(self, outputs, use_cache):
"""During generation, decide whether to pass the `past` variable to the next forward pass."""
use_cache = getattr(self.config, "use_cache", False)
Expand Down Expand Up @@ -1290,6 +1290,31 @@ def adjust_logits_during_generation(
else:
return logits

def _validate_model_class(self):
"""
Confirms that the model class is compatible with generation. If not, raises an exception that points to the
right class to use.
"""
if not hasattr(self, "prepare_inputs_for_generation"):
generate_compatible_mappings = [
TF_MODEL_FOR_CAUSAL_LM_MAPPING,
TF_MODEL_FOR_VISION_2_SEQ_MAPPING,
TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
TF_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING,
]
generate_compatible_classes = set()
for model_mapping in generate_compatible_mappings:
supported_models = model_mapping.get(type(self.config), default=None)
if supported_models is not None:
generate_compatible_classes.add(supported_models.__name__)
exception_message = (
f"The current model class ({self.__class__.__name__}) is not compatible with `.generate()`, as "
"it doesn't have a language model head."
)
if generate_compatible_classes:
exception_message += f" Please use one of the following classes instead: {generate_compatible_classes}"
raise TypeError(exception_message)

def _validate_model_kwargs(self, model_kwargs: Dict[str, Any]):
"""Validates model kwargs for generation. Generate argument typos will also be caught here."""
# Excludes arguments that are handled before calling any model function
Expand Down Expand Up @@ -1508,7 +1533,8 @@ def _generate(
# generate sequences without allowing bad_words to be generated
outputs = model.generate(input_ids=input_ids, max_length=100, do_sample=True, bad_words_ids=bad_words_ids)
```"""
# 0. Validate model kwargs
# 0. Validate the `.generate()` call
self._validate_model_class()
self._validate_model_kwargs(model_kwargs.copy())

# 1. Set generation parameters if not already defined
Expand Down
42 changes: 35 additions & 7 deletions src/transformers/generation_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,13 @@
StoppingCriteriaList,
validate_stopping_criteria,
)
from .models.auto import (
MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING,
MODEL_FOR_CAUSAL_LM_MAPPING,
MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING,
MODEL_FOR_VISION_2_SEQ_MAPPING,
)
from .pytorch_utils import torch_int_div
from .utils import ModelOutput, logging

Expand Down Expand Up @@ -463,12 +470,6 @@ def _can_retrieve_inputs_from_name(

return can_retrieve_inputs

def prepare_inputs_for_generation(self, input_ids: torch.LongTensor, **kwargs) -> Dict[str, Any]:
"""
Implement in subclasses of [`PreTrainedModel`] for custom behavior to prepare inputs in the generate method.
"""
return {"input_ids": input_ids}

def adjust_logits_during_generation(self, logits: torch.FloatTensor, **kwargs) -> torch.FloatTensor:
"""
Implement in subclasses of [`PreTrainedModel`] for custom behavior to adjust the logits in the generate method.
Expand Down Expand Up @@ -841,6 +842,32 @@ def compute_transition_beam_scores(

return transition_scores

def _validate_model_class(self):
"""
Confirms that the model class is compatible with generation. If not, raises an exception that points to the
right class to use.
"""
if not hasattr(self, "prepare_inputs_for_generation"):
generate_compatible_mappings = [
MODEL_FOR_CAUSAL_LM_MAPPING,
MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING,
MODEL_FOR_VISION_2_SEQ_MAPPING,
MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING,
]
generate_compatible_classes = set()
for model_mapping in generate_compatible_mappings:
supported_models = model_mapping.get(type(self.config), default=None)
if supported_models is not None:
generate_compatible_classes.add(supported_models.__name__)
exception_message = (
f"The current model class ({self.__class__.__name__}) is not compatible with `.generate()`, as "
"it doesn't have a language model head."
)
if generate_compatible_classes:
exception_message += f" Please use one of the following classes instead: {generate_compatible_classes}"
raise TypeError(exception_message)

def _validate_model_kwargs(self, model_kwargs: Dict[str, Any]):
"""Validates model kwargs for generation. Generate argument typos will also be caught here."""
# Excludes arguments that are handled before calling any model function
Expand Down Expand Up @@ -1143,7 +1170,8 @@ def generate(
>>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
['Paris ist eines der dichtesten besiedelten Gebiete Europas.']
```"""
# 0. Validate model kwargs
# 0. Validate the `.generate()` call
self._validate_model_class()
self._validate_model_kwargs(model_kwargs.copy())

# 1. Set generation parameters if not already defined
Expand Down
5 changes: 4 additions & 1 deletion src/transformers/models/openai/modeling_openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import math
import os
from dataclasses import dataclass
from typing import Optional, Tuple, Union
from typing import Any, Dict, Optional, Tuple, Union

import torch
from torch import nn
Expand Down Expand Up @@ -607,6 +607,9 @@ def forward(
attentions=transformer_outputs.attentions,
)

def prepare_inputs_for_generation(self, input_ids: torch.LongTensor, **kwargs) -> Dict[str, Any]:
return {"input_ids": input_ids}


@add_start_docstrings(
"""
Expand Down
3 changes: 3 additions & 0 deletions src/transformers/models/openai/modeling_tf_openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -633,6 +633,9 @@ def serving_output(self, output: TFCausalLMOutput) -> TFCausalLMOutput:

return TFCausalLMOutput(logits=output.logits, hidden_states=hs, attentions=attns)

def prepare_inputs_for_generation(self, inputs, **kwargs):
return {"input_ids": inputs}


@add_start_docstrings(
"""
Expand Down