Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Generate: torch.compile-ready generation config preparation #29443

Merged
merged 3 commits into from
Mar 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 59 additions & 33 deletions src/transformers/generation/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING,
MODEL_FOR_VISION_2_SEQ_MAPPING,
)
from ..utils import ModelOutput, is_accelerate_available, logging
from ..utils import ModelOutput, is_accelerate_available, is_torchdynamo_compiling, logging
from .beam_constraints import DisjunctiveConstraint, PhrasalConstraint
from .beam_search import BeamScorer, BeamSearchScorer, ConstrainedBeamSearchScorer
from .candidate_generator import (
Expand Down Expand Up @@ -1162,6 +1162,59 @@ def _validate_generated_length(self, generation_config, input_ids_length, has_de
UserWarning,
)

def _prepare_generation_config(
self, generation_config: GenerationConfig, **kwargs: Dict
) -> Tuple[GenerationConfig, Dict]:
"""
Prepares the base generation config, then applies any generation configuration options from kwargs.
"""
# TODO joao: when we can detect `fullgraph=True` in `torch.compile` (https://github.com/pytorch/pytorch/pull/120400)
# replace `is_torchdynamo_compiling` by the corresponding check. As it is, we are being too restrictive with
# the parameterization in `fullgraph=False` so as to enable `fullgraph=True`.

# priority: `generation_config` argument > `model.generation_config` (the default generation config)
if generation_config is None:
# legacy: users may modify the model configuration to control generation. To trigger this legacy behavior,
# three conditions must be met
# 1) the generation config must have been created from the model config (`_from_model_config` field);
# 2) the generation config must have seen no modification since its creation (the hash is the same);
# 3) the user must have set generation parameters in the model config.
# NOTE: `torch.compile` can't compile `hash`, this legacy support is disabled with compilation.
if (
not is_torchdynamo_compiling()
and self.generation_config._from_model_config
and self.generation_config._original_object_hash == hash(self.generation_config)
and self.config._has_non_default_generation_parameters()
):
new_generation_config = GenerationConfig.from_model_config(self.config)
if new_generation_config != self.generation_config:
warnings.warn(
"You have modified the pretrained model configuration to control generation. This is a"
" deprecated strategy to control generation and will be removed soon, in a future version."
" Please use and modify the model generation configuration (see"
" https://huggingface.co/docs/transformers/generation_strategies#default-text-generation-configuration )"
)
self.generation_config = new_generation_config
generation_config = self.generation_config

# `torch.compile` can't compile `copy.deepcopy`, arguments in `kwargs` that are part of `generation_config`
# will mutate the object with `.update`. As such, passing these arguments through `kwargs` is disabled.
if is_torchdynamo_compiling():
model_kwargs = kwargs
generate_attributes_in_kwargs = [
key for key, value in kwargs.items() if getattr(generation_config, key, None) != value
]
if len(generate_attributes_in_kwargs) > 0:
raise ValueError(
"`torch.compile` exception: all generation configuration attributes must be passed within a "
f"`generation_config` instance passed to `generate` (found: {generate_attributes_in_kwargs})."
)
else:
generation_config = copy.deepcopy(generation_config)
model_kwargs = generation_config.update(**kwargs)

return generation_config, model_kwargs

@torch.no_grad()
def generate(
self,
Expand Down Expand Up @@ -1260,44 +1313,17 @@ def generate(
- [`~generation.GenerateEncoderDecoderOutput`],
- [`~generation.GenerateBeamEncoderDecoderOutput`]
"""
# 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call
self._validate_model_class()
generation_config, model_kwargs = self._prepare_generation_config(generation_config, **kwargs)
self._validate_model_kwargs(model_kwargs.copy())

# 2. Set generation parameters if not already defined
if synced_gpus is None:
if is_deepspeed_zero3_enabled() and dist.get_world_size() > 1:
synced_gpus = True
else:
synced_gpus = False

# 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call
self._validate_model_class()

# priority: `generation_config` argument > `model.generation_config` (the default generation config)
if generation_config is None:
# legacy: users may modify the model configuration to control generation. To trigger this legacy behavior,
# three conditions must be met
# 1) the generation config must have been created from the model config (`_from_model_config` field);
# 2) the generation config must have seen no modification since its creation (the hash is the same);
# 3) the user must have set generation parameters in the model config.
if (
self.generation_config._from_model_config
and self.generation_config._original_object_hash == hash(self.generation_config)
and self.config._has_non_default_generation_parameters()
):
new_generation_config = GenerationConfig.from_model_config(self.config)
if new_generation_config != self.generation_config:
warnings.warn(
"You have modified the pretrained model configuration to control generation. This is a"
" deprecated strategy to control generation and will be removed soon, in a future version."
" Please use and modify the model generation configuration (see"
" https://huggingface.co/docs/transformers/generation_strategies#default-text-generation-configuration )"
)
self.generation_config = new_generation_config
generation_config = self.generation_config

generation_config = copy.deepcopy(generation_config)
model_kwargs = generation_config.update(**kwargs) # All unused kwargs must be model kwargs
self._validate_model_kwargs(model_kwargs.copy())

# 2. Set generation parameters if not already defined
logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()

Expand Down
1 change: 1 addition & 0 deletions src/transformers/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,7 @@
is_torchaudio_available,
is_torchdistx_available,
is_torchdynamo_available,
is_torchdynamo_compiling,
is_torchvision_available,
is_training_run_on_sagemaker,
is_vision_available,
Expand Down
Loading