
Commit

fix #16 by also allowing a GenerationConfig object to be passed programmatically if needed
clefourrier committed Dec 12, 2024
1 parent 4833929 commit 30bed89
Showing 1 changed file with 15 additions and 4 deletions.
src/lighteval/models/transformers/transformers_model.py (15 additions, 4 deletions)
@@ -39,7 +39,7 @@
     GPTQConfig,
     PretrainedConfig,
 )
-from transformers.generation.utils import GenerateOutput
+from transformers.generation.utils import GenerateOutput, GenerationConfig
 from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
 
 from lighteval.data import GenerativeTaskDataset, LoglikelihoodDataset, LoglikelihoodSingleTokenDataset
@@ -126,6 +126,8 @@ class TransformersModelConfig:
             model at a quantized precision. Needed for 4-bit and 8-bit precision.
         trust_remote_code (bool): Whether to trust remote code during model
            loading.
+        generation_parameters (GenerationParameters): Range of parameters which will affect the generation.
+        generation_config (GenerationConfig): GenerationConfig object (only passed during manual creation)
 
     Methods:
         __post_init__(): Performs post-initialization checks on the configuration.
@@ -154,6 +156,7 @@ class TransformersModelConfig:
     use_chat_template: bool = False
     compile: bool = False
     generation_parameters: GenerationParameters = None
+    generation_config: GenerationConfig = None
 
     def __post_init__(self):
         # Making sure this parameter is a boolean
@@ -180,7 +183,12 @@ def __post_init__(self):
         if not isinstance(self.device, str):
             raise ValueError("Current device must be passed as string.")
 
-        if not self.generation_parameters:
+        if self.generation_config and self.generation_parameters:
+            raise ValueError(
+                "Can't use both generation_config and generation_parameters argument. Pass the generation parameters to your generation config object"
+            )
+
+        if not self.generation_parameters and not self.generation_config:
             self.generation_parameters = GenerationParameters()
 
     def _init_configs(self, model_name: str, env_config: EnvConfig) -> PretrainedConfig:
@@ -275,8 +283,11 @@ def __init__(
         self.model_sha = config.get_model_sha()
 
         self.precision = _get_dtype(config.dtype, config=self._config)
-        self.generation_parameters = config.generation_parameters
-        self.generation_config_dict = self.generation_parameters.to_transformers_dict()
+        if config.generation_config is None:
+            self.generation_parameters = config.generation_parameters
+            self.generation_config_dict = self.generation_parameters.to_transformers_dict()
+        else:
+            self.generation_config_dict = config.generation_config.to_dict()
 
         if is_accelerate_available():
             model_size, _ = calculate_maximum_sizes(self.model)
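
For context, this change gives callers two mutually exclusive ways to control generation when constructing a TransformersModelConfig by hand. The sketch below illustrates both paths under stated assumptions: the model name and sampling values are placeholders, the GenerationParameters import path is a best guess at the lighteval layout rather than something shown in this diff, and field names such as pretrained follow the class as it existed around this commit.

from transformers import GenerationConfig

from lighteval.models.model_input import GenerationParameters  # assumed import path, not shown in the diff
from lighteval.models.transformers.transformers_model import TransformersModelConfig

# Existing path: lighteval-native generation parameters.
params_config = TransformersModelConfig(
    pretrained="openai-community/gpt2",  # placeholder model, not part of the commit
    generation_parameters=GenerationParameters(temperature=0.7),
)

# New path from this commit: a transformers GenerationConfig passed programmatically.
hf_config = TransformersModelConfig(
    pretrained="openai-community/gpt2",
    generation_config=GenerationConfig(temperature=0.7, max_new_tokens=256),
)

# Setting both raises ValueError in __post_init__; setting neither falls back
# to a default GenerationParameters() (see the diff above).

At model init, the first path serializes GenerationParameters via to_transformers_dict(), while the second uses the GenerationConfig's own to_dict(), so both end up as the same kind of generation_config_dict.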
