diff --git a/optimum/onnxruntime/modeling_decoder.py b/optimum/onnxruntime/modeling_decoder.py
index f47204df48..72c01b2707 100644
--- a/optimum/onnxruntime/modeling_decoder.py
+++ b/optimum/onnxruntime/modeling_decoder.py
@@ -21,7 +21,7 @@
 
 import numpy as np
 import torch
-from transformers import AutoModelForCausalLM
+from transformers import AutoModelForCausalLM, GenerationConfig
 from transformers.file_utils import add_start_docstrings_to_model_forward
 from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
 
@@ -339,6 +339,7 @@ def __init__(
         use_io_binding: Optional[bool] = None,
         model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
         preprocessors: Optional[List] = None,
+        generation_config: Optional[GenerationConfig] = None,
         **kwargs
     ):
         """
@@ -357,6 +358,9 @@
                 The directory under which the model exported to ONNX was saved.
             preprocessors (`Optional[List]`, defaults to `None`):
                 The list of the preprocessors (tokenizer, processor, feature_extractor) to save alongside the ORTModel.
+            generation_config (`Optional[GenerationConfig]`, defaults to `None`):
+                The generation configuration used by default when calling `generate()`.
+                Refer to https://huggingface.co/docs/transformers/main/en/main_classes/text_generation#transformers.GenerationMixin.generate.
         """
         # TODO: remove at version 2.0
         def show_deprecated_argument(arg_name):
@@ -399,6 +403,10 @@ def show_deprecated_argument(arg_name):
             self.decoder_with_past_model_path = Path(decoder_with_past_session._model_path)
             self.decoder_with_past_model_name = self.decoder_with_past_model_path.name
 
+        if generation_config is None:
+            generation_config = GenerationConfig.from_model_config(config)
+        self.generation_config = generation_config
+
     @staticmethod
     def load_model(
         decoder_path: Union[str, Path],
@@ -626,6 +634,20 @@ def _from_pretrained(
         if model_save_dir is None:
             model_save_dir = new_model_save_dir
 
+        generation_config = None
+        try:
+            generation_config = GenerationConfig.from_pretrained(
+                model_id,
+                cache_dir=cache_dir,
+                force_download=force_download,
+                local_files_only=local_files_only,
+                use_auth_token=use_auth_token,
+                revision=revision,
+                subfolder=subfolder,
+            )
+        except OSError:
+            logger.info("Generation config file not found, using a generation config created from the model config.")
+
         return cls(
             model[0],
             config,
@@ -633,6 +655,7 @@
             use_io_binding=use_io_binding,
             model_save_dir=model_save_dir,
             preprocessors=preprocessors,
+            generation_config=generation_config,
         )
 
     @classmethod
@@ -784,3 +807,7 @@ def _reorder_cache(past: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor) ->
             tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
             for layer_past in past
         )
+
+    def can_generate(self):
+        """Returns True to validate the check that the model using `GenerationMixin.generate()` can indeed generate."""
+        return True
diff --git a/optimum/onnxruntime/modeling_seq2seq.py b/optimum/onnxruntime/modeling_seq2seq.py
index 9f6e9a2dec..a4ca8be4e5 100644
--- a/optimum/onnxruntime/modeling_seq2seq.py
+++ b/optimum/onnxruntime/modeling_seq2seq.py
@@ -17,7 +17,6 @@
 """
 
 import logging
-import re
 import shutil
 from abc import ABC, abstractmethod
 from pathlib import Path
@@ -26,7 +25,7 @@
 
 import numpy as np
 import torch
-from transformers import AutoModelForSeq2SeqLM, AutoModelForSpeechSeq2Seq
+from transformers import AutoModelForSeq2SeqLM, AutoModelForSpeechSeq2Seq, GenerationConfig
 from transformers.file_utils import add_start_docstrings_to_model_forward
 from transformers.modeling_outputs import BaseModelOutput, Seq2SeqLMOutput
 
@@ -37,7 +36,7 @@
 from ..exporters.tasks import TasksManager
 from ..onnx.utils import _get_external_data_paths
 from ..utils import NormalizedConfigManager, check_if_transformers_greater
-from ..utils.file_utils import find_files_matching_pattern, validate_file_exists
+from ..utils.file_utils import validate_file_exists
 from ..utils.save_utils import maybe_load_preprocessors, maybe_save_preprocessors
 from .io_binding import TypeHelper
 from .modeling_decoder import ORTDecoder
@@ -728,6 +727,7 @@ def __init__(
         use_io_binding: Optional[bool] = None,
         model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
         preprocessors: Optional[List] = None,
+        generation_config: Optional[GenerationConfig] = None,
         **kwargs,
     ):
         """
@@ -748,6 +748,9 @@
                 The directory under which the model exported to ONNX was saved.
             preprocessors (`Optional[List]`, defaults to `None`):
                 The list of the preprocessors (tokenizer, processor, feature_extractor) to save alongside the ORTModel.
+            generation_config (`Optional[GenerationConfig]`, defaults to `None`):
+                The generation configuration used by default when calling `generate()`.
+                Refer to https://huggingface.co/docs/transformers/main/en/main_classes/text_generation#transformers.GenerationMixin.generate.
         """
         # TODO: remove at version 2.0
         def show_deprecated_argument(arg_name):
@@ -804,6 +807,10 @@ def show_deprecated_argument(arg_name):
             self.decoder_with_past_model_path = Path(decoder_with_past_session._model_path)
             self.decoder_with_past_model_name = self.decoder_with_past_model_path.name
 
+        if generation_config is None:
+            generation_config = GenerationConfig.from_model_config(config)
+        self.generation_config = generation_config
+
     @abstractmethod
     def _initialize_encoder(
         self,
@@ -1076,6 +1083,20 @@ def _from_pretrained(
         if model_save_dir is None:
             model_save_dir = new_model_save_dir
 
+        generation_config = None
+        try:
+            generation_config = GenerationConfig.from_pretrained(
+                model_id,
+                cache_dir=cache_dir,
+                force_download=force_download,
+                local_files_only=local_files_only,
+                use_auth_token=use_auth_token,
+                revision=revision,
+                subfolder=subfolder,
+            )
+        except OSError:
+            logger.info("Generation config file not found, using a generation config created from the model config.")
+
         return cls(
             *model[:2],
             config,
@@ -1083,6 +1104,7 @@
             use_io_binding=use_io_binding,
             model_save_dir=model_save_dir,
             preprocessors=preprocessors,
+            generation_config=generation_config,
         )
 
     @classmethod
@@ -1178,6 +1200,12 @@ def to(self, device: Union[torch.device, str, int]):
 
         return self
 
+    def can_generate(self):
+        logger.warning(
+            "ORTModelForConditionalGeneration is an abstract class and is not meant to be used for generation. Please use ORTModelForSeq2SeqLM or ORTModelForSpeechSeq2Seq."
+        )
+        return False
+
 
 class ORTModelForSeq2SeqLM(ORTModelForConditionalGeneration, GenerationMixin):
     """
@@ -1286,6 +1314,10 @@ def _reorder_cache(past, beam_idx) -> Tuple[Tuple[torch.FloatTensor]]:
             )
         return reordered_past
 
+    def can_generate(self):
+        """Returns True to validate the check that the model using `GenerationMixin.generate()` can indeed generate."""
+        return True
+
 
 class ORTModelForSpeechSeq2Seq(ORTModelForConditionalGeneration, GenerationMixin):
     """
@@ -1398,3 +1430,7 @@ def _reorder_cache(past, beam_idx) -> Tuple[Tuple[torch.FloatTensor]]:
                 tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:],
             )
         return reordered_past
+
+    def can_generate(self):
+        """Returns True to validate the check that the model using `GenerationMixin.generate()` can indeed generate."""
+        return True
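For context, a minimal sketch of the user-facing behavior this diff enables. It is illustrative only: the checkpoint name and generation parameters are placeholders, and the export keyword accepted by `from_pretrained` may differ across Optimum versions (older releases use `from_transformers=True` rather than `export=True`).

```python
from transformers import AutoTokenizer, GenerationConfig

from optimum.onnxruntime import ORTModelForCausalLM

# _from_pretrained() now tries to fetch generation_config.json from the repo;
# if none exists (OSError), __init__ falls back to
# GenerationConfig.from_model_config(config).
# "gpt2" and export=True are illustrative, not part of the diff.
model = ORTModelForCausalLM.from_pretrained("gpt2", export=True)
tokenizer = AutoTokenizer.from_pretrained("gpt2")

# can_generate() returning True satisfies the GenerationMixin check, and
# generate() then reads its defaults from model.generation_config.
inputs = tokenizer("Hello, my dog is", return_tensors="pt")
outputs = model.generate(**inputs)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))

# Since generation_config is a plain attribute on the ORT model,
# the default used by generate() can be swapped out wholesale:
model.generation_config = GenerationConfig(max_new_tokens=20, do_sample=True)
```

Note that `ORTModelForConditionalGeneration.can_generate()` deliberately returns `False` with a warning, so only the concrete subclasses (`ORTModelForSeq2SeqLM`, `ORTModelForSpeechSeq2Seq`) pass the `generate()` check.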