Generate: torch.compile-ready generation config preparation (#29443)
gante authored Mar 6, 2024
1 parent 9322576 commit ddb4fda
Showing 2 changed files with 60 additions and 33 deletions.
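In short: this commit moves the `generation_config` handling at the top of `generate()` into a new `_prepare_generation_config()` helper so that `generate()` can run under `torch.compile`. While TorchDynamo is tracing, the helper skips `copy.deepcopy` and `hash` (which Dynamo cannot compile) and requires every generation option to arrive inside a `GenerationConfig` instance rather than as loose kwargs. A minimal sketch of the calling convention the new check enforces — the checkpoint and option values are illustrative, and whether the full generation loop compiles end-to-end also depends on the model and cache setup:

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig

    tokenizer = AutoTokenizer.from_pretrained("gpt2")  # illustrative checkpoint
    model = AutoModelForCausalLM.from_pretrained("gpt2")

    # While compiling, all generation options must live in a GenerationConfig;
    # a loose kwarg such as max_new_tokens=16 is rejected with a ValueError.
    generation_config = GenerationConfig(max_new_tokens=16, do_sample=False)

    compiled_generate = torch.compile(model.generate)  # fullgraph=False, per the TODO in the diff

    input_ids = tokenizer("Hello", return_tensors="pt").input_ids
    output_ids = compiled_generate(input_ids, generation_config=generation_config)
    print(tokenizer.decode(output_ids[0]))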
92 changes: 59 additions & 33 deletions src/transformers/generation/utils.py
@@ -34,7 +34,7 @@
     MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING,
     MODEL_FOR_VISION_2_SEQ_MAPPING,
 )
-from ..utils import ModelOutput, is_accelerate_available, logging
+from ..utils import ModelOutput, is_accelerate_available, is_torchdynamo_compiling, logging
 from .beam_constraints import DisjunctiveConstraint, PhrasalConstraint
 from .beam_search import BeamScorer, BeamSearchScorer, ConstrainedBeamSearchScorer
 from .candidate_generator import (
@@ -1162,6 +1162,59 @@ def _validate_generated_length(self, generation_config, input_ids_length, has_de
                 UserWarning,
             )
 
+    def _prepare_generation_config(
+        self, generation_config: GenerationConfig, **kwargs: Dict
+    ) -> Tuple[GenerationConfig, Dict]:
+        """
+        Prepares the base generation config, then applies any generation configuration options from kwargs.
+        """
+        # TODO joao: when we can detect `fullgraph=True` in `torch.compile` (https://github.com/pytorch/pytorch/pull/120400)
+        # replace `is_torchdynamo_compiling` by the corresponding check. As it is, we are being too restrictive with
+        # the parameterization in `fullgraph=False` so as to enable `fullgraph=True`.
+
+        # priority: `generation_config` argument > `model.generation_config` (the default generation config)
+        if generation_config is None:
+            # legacy: users may modify the model configuration to control generation. To trigger this legacy behavior,
+            # three conditions must be met
+            # 1) the generation config must have been created from the model config (`_from_model_config` field);
+            # 2) the generation config must have seen no modification since its creation (the hash is the same);
+            # 3) the user must have set generation parameters in the model config.
+            # NOTE: `torch.compile` can't compile `hash`, this legacy support is disabled with compilation.
+            if (
+                not is_torchdynamo_compiling()
+                and self.generation_config._from_model_config
+                and self.generation_config._original_object_hash == hash(self.generation_config)
+                and self.config._has_non_default_generation_parameters()
+            ):
+                new_generation_config = GenerationConfig.from_model_config(self.config)
+                if new_generation_config != self.generation_config:
+                    warnings.warn(
+                        "You have modified the pretrained model configuration to control generation. This is a"
+                        " deprecated strategy to control generation and will be removed soon, in a future version."
+                        " Please use and modify the model generation configuration (see"
+                        " https://huggingface.co/docs/transformers/generation_strategies#default-text-generation-configuration )"
+                    )
+                    self.generation_config = new_generation_config
+            generation_config = self.generation_config
+
+        # `torch.compile` can't compile `copy.deepcopy`, arguments in `kwargs` that are part of `generation_config`
+        # will mutate the object with `.update`. As such, passing these arguments through `kwargs` is disabled.
+        if is_torchdynamo_compiling():
+            model_kwargs = kwargs
+            generate_attributes_in_kwargs = [
+                key for key, value in kwargs.items() if getattr(generation_config, key, None) != value
+            ]
+            if len(generate_attributes_in_kwargs) > 0:
+                raise ValueError(
+                    "`torch.compile` exception: all generation configuration attributes must be passed within a "
+                    f"`generation_config` instance passed to `generate` (found: {generate_attributes_in_kwargs})."
+                )
+        else:
+            generation_config = copy.deepcopy(generation_config)
+            model_kwargs = generation_config.update(**kwargs)
+
+        return generation_config, model_kwargs
+
     @torch.no_grad()
     def generate(
         self,
Expand Down Expand Up @@ -1260,44 +1313,17 @@ def generate(
             - [`~generation.GenerateEncoderDecoderOutput`],
             - [`~generation.GenerateBeamEncoderDecoderOutput`]
         """
+        # 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call
+        self._validate_model_class()
+        generation_config, model_kwargs = self._prepare_generation_config(generation_config, **kwargs)
+        self._validate_model_kwargs(model_kwargs.copy())
+
+        # 2. Set generation parameters if not already defined
         if synced_gpus is None:
             if is_deepspeed_zero3_enabled() and dist.get_world_size() > 1:
                 synced_gpus = True
             else:
                 synced_gpus = False
 
-        # 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call
-        self._validate_model_class()
-
-        # priority: `generation_config` argument > `model.generation_config` (the default generation config)
-        if generation_config is None:
-            # legacy: users may modify the model configuration to control generation. To trigger this legacy behavior,
-            # three conditions must be met
-            # 1) the generation config must have been created from the model config (`_from_model_config` field);
-            # 2) the generation config must have seen no modification since its creation (the hash is the same);
-            # 3) the user must have set generation parameters in the model config.
-            if (
-                self.generation_config._from_model_config
-                and self.generation_config._original_object_hash == hash(self.generation_config)
-                and self.config._has_non_default_generation_parameters()
-            ):
-                new_generation_config = GenerationConfig.from_model_config(self.config)
-                if new_generation_config != self.generation_config:
-                    warnings.warn(
-                        "You have modified the pretrained model configuration to control generation. This is a"
-                        " deprecated strategy to control generation and will be removed soon, in a future version."
-                        " Please use and modify the model generation configuration (see"
-                        " https://huggingface.co/docs/transformers/generation_strategies#default-text-generation-configuration )"
-                    )
-                    self.generation_config = new_generation_config
-            generation_config = self.generation_config
-
-        generation_config = copy.deepcopy(generation_config)
-        model_kwargs = generation_config.update(**kwargs)  # All unused kwargs must be model kwargs
-        self._validate_model_kwargs(model_kwargs.copy())
-
-        # 2. Set generation parameters if not already defined
         logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
         stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
 
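Outside of compilation the helper preserves the previous semantics: the config is deep-copied (so `generate` never mutates the caller's object) and `GenerationConfig.update(**kwargs)` consumes the kwargs that name config attributes, returning the leftovers for `generate` to treat as model kwargs. A small sketch of that contract; `some_model_kwarg` is a placeholder name, not a real parameter:

    import copy

    from transformers import GenerationConfig

    generation_config = GenerationConfig(max_new_tokens=8)
    generation_config = copy.deepcopy(generation_config)  # caller's object stays untouched

    # Keys matching config attributes are applied in place; everything else is
    # handed back for use as model kwargs.
    unused = generation_config.update(max_new_tokens=32, some_model_kwarg=True)
    print(generation_config.max_new_tokens)  # 32
    print(unused)                            # {'some_model_kwarg': True}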
1 change: 1 addition & 0 deletions src/transformers/utils/__init__.py
@@ -193,6 +193,7 @@
     is_torchaudio_available,
     is_torchdistx_available,
     is_torchdynamo_available,
+    is_torchdynamo_compiling,
     is_torchvision_available,
     is_training_run_on_sagemaker,
     is_vision_available,
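The second file only re-exports `is_torchdynamo_compiling` from `transformers.utils`. The helper itself lives in `import_utils.py` (not shown in this diff); it is essentially a safe wrapper around TorchDynamo's tracing flag, along these lines — a sketch of the usual pattern, not the verbatim implementation:

    def is_torchdynamo_compiling() -> bool:
        # True only while TorchDynamo is tracing the surrounding code, so callers
        # can branch away from untraceable ops such as `copy.deepcopy` and `hash`.
        try:
            import torch._dynamo

            return torch._dynamo.is_compiling()
        except Exception:
            return False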
