This repository was archived by the owner on Jan 11, 2025. It is now read-only.

Commit 5a28b22

Authored by clefourrier, NathanHB, and albertvillanova
Homogeneize generation params (huggingface#428)
This PR does three things:

- Provides a homogenized API for supplying model generation parameters in model configs. Those parameters are passed on to every model backend that can take them (vllm, open ai, tgi, transformers, ...).
- Renames BaseModel to TransformersModel.
- Allows TransformersModel to take a transformers.GenerationConfig object directly when created programmatically.

I would put system_prompt, fewshot_seeds, and use_chat_template in the GenerationParameters too, since they are logically generation parameters, but that can be another PR.

---------

Co-authored-by: Nathan Habib <30601243+NathanHB@users.noreply.github.com>
Co-authored-by: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com>
1 parent 24afde2 · commit 5a28b22
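Since the description above explains the homogenized generation-parameter flow only in prose, here is a minimal sketch of the intended programmatic use. It assumes GenerationParameters accepts sampling fields such as temperature (the only field actually shown in this commit's example config) and that backend configs take it through the generation_parameters keyword seen in the diffs below; other field names and constructor details are assumptions, not taken from the PR.

```python
# Hedged sketch, not taken verbatim from the PR: exact field names and constructor
# signatures beyond `generation_parameters=` and `temperature` are assumptions.
from lighteval.models.model_input import GenerationParameters
from lighteval.models.vllm.vllm_model import VLLMModelConfig

params = GenerationParameters(temperature=0.5)  # temperature is the field added to the example YAML below

model_config = VLLMModelConfig(
    pretrained="HuggingFaceTB/SmolLM-1.7B",  # model name reused from the example config in this commit
    generation_parameters=params,            # keyword used by the configs touched in this PR
)
```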

File tree

20 files changed: +404 -113 lines


docs/source/package_reference/models.mdx (3 additions, 3 deletions)

@@ -6,9 +6,9 @@
 
 
 ## Accelerate and Transformers Models
-### BaseModel
-[[autodoc]] models.transformers.base_model.BaseModelConfig
-[[autodoc]] models.transformers.base_model.BaseModel
+### TransformersModel
+[[autodoc]] models.transformers.transformers_model.TransformersModelConfig
+[[autodoc]] models.transformers.transformers_model.TransformersModel
 
 ### AdapterModel
 [[autodoc]] models.transformers.adapter_model.AdapterModelConfig
Example model config YAML (2 additions, 1 deletion)

@@ -1,6 +1,6 @@
 model:
   base_params:
-    model_args: "pretrained=HuggingFaceH4/zephyr-7b-beta,revision=main" # pretrained=model_name,trust_remote_code=boolean,revision=revision_to_use,model_parallel=True ...
+    model_args: "pretrained=HuggingFaceTB/SmolLM-1.7B,revision=main" # pretrained=model_name,trust_remote_code=boolean,revision=revision_to_use,model_parallel=True ...
     dtype: "bfloat16"
     compile: true
   merged_weights: # Ignore this section if you are not using PEFT models
@@ -9,3 +9,4 @@ model:
     base_model: null # path to the base_model
   generation:
     multichoice_continuations_start_space: null # If true/false, will force multiple choice continuations to start/not start with a space. If none, will do nothing
+    temperature: 0.5
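The new generation: block is what GenerationParameters.from_dict consumes in the backend entry points further down. As a rough, self-contained illustration of that mapping (this re-implements the idea with an illustrative dataclass, not lighteval's actual class):

```python
from dataclasses import dataclass, fields
from typing import Optional


@dataclass
class GenerationParamsSketch:
    # Illustrative stand-in for lighteval's GenerationParameters; only `temperature`
    # is actually shown in the example config above.
    temperature: Optional[float] = None

    @classmethod
    def from_dict(cls, config: dict) -> "GenerationParamsSketch":
        # Read the nested "generation" section and keep only known fields.
        generation = config.get("generation", {}) or {}
        known = {f.name for f in fields(cls)}
        return cls(**{k: v for k, v in generation.items() if k in known})


# The "model" section of a YAML config like the one above, as a plain dict:
config = {
    "generation": {"multichoice_continuations_start_space": None, "temperature": 0.5},
}
print(GenerationParamsSketch.from_dict(config))  # GenerationParamsSketch(temperature=0.5)
```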

src/lighteval/__main__.py (3 additions, 2 deletions)

@@ -20,7 +20,7 @@
 # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 # SOFTWARE.
 import logging
-from logging.config import dictConfig
+import logging.config
 
 import colorlog
 import typer
@@ -57,7 +57,8 @@
     },
 )
 
-dictConfig(logging_config)
+logging.config.dictConfig(logging_config)
+logging.captureWarnings(capture=True)
 
 app.command(rich_help_panel="Evaluation Backends")(lighteval.main_accelerate.accelerate)
 app.command(rich_help_panel="Evaluation Utils")(lighteval.main_baseline.baseline)
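The change above swaps the bare dictConfig import for the fully qualified logging.config module and additionally routes Python warnings through logging. A standalone sketch of that pattern, with an illustrative config dict (lighteval's real logging_config is built with colorlog and differs):

```python
import logging
import logging.config
import warnings

# Illustrative config; the real lighteval config uses colorlog formatters.
logging_config = {
    "version": 1,
    "formatters": {"simple": {"format": "%(levelname)s %(name)s: %(message)s"}},
    "handlers": {"console": {"class": "logging.StreamHandler", "formatter": "simple"}},
    "root": {"level": "INFO", "handlers": ["console"]},
}

logging.config.dictConfig(logging_config)
logging.captureWarnings(capture=True)  # warnings are now emitted via the "py.warnings" logger

warnings.warn("this shows up as a log record rather than raw stderr output")
```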

src/lighteval/main_accelerate.py (7 additions, 4 deletions)

@@ -44,7 +44,7 @@ def accelerate( # noqa C901
     model_args: Annotated[
         str,
         Argument(
-            help="Model arguments in the form key1=value1,key2=value2,... or path to yaml config file (see examples/model_configs/base_model.yaml)"
+            help="Model arguments in the form key1=value1,key2=value2,... or path to yaml config file (see examples/model_configs/transformers_model.yaml)"
         ),
     ],
     tasks: Annotated[str, Argument(help="Comma-separated list of tasks to evaluate on.")],
@@ -107,9 +107,10 @@ def accelerate( # noqa C901
     from accelerate import Accelerator, InitProcessGroupKwargs
 
     from lighteval.logging.evaluation_tracker import EvaluationTracker
+    from lighteval.models.model_input import GenerationParameters
     from lighteval.models.transformers.adapter_model import AdapterModelConfig
-    from lighteval.models.transformers.base_model import BaseModelConfig, BitsAndBytesConfig
     from lighteval.models.transformers.delta_model import DeltaModelConfig
+    from lighteval.models.transformers.transformers_model import BitsAndBytesConfig, TransformersModelConfig
     from lighteval.pipeline import EnvConfig, ParallelismManager, Pipeline, PipelineParameters
 
     accelerator = Accelerator(kwargs_handlers=[InitProcessGroupKwargs(timeout=timedelta(seconds=3000))])
@@ -154,6 +155,8 @@ def accelerate( # noqa C901
         # We extract the model args
         args_dict = {k.split("=")[0]: k.split("=")[1] for k in config["base_params"]["model_args"].split(",")}
 
+        args_dict["generation_parameters"] = GenerationParameters.from_dict(config)
+
         # We store the relevant other args
         args_dict["base_model"] = config["merged_weights"]["base_model"]
         args_dict["compile"] = bool(config["base_params"]["compile"])
@@ -180,13 +183,13 @@ def accelerate( # noqa C901
         elif config["merged_weights"]["base_model"] not in ["", None]:
             raise ValueError("You can't specify a base model if you are not using delta/adapter weights")
         else:
-            model_config = BaseModelConfig(**args_dict)
+            model_config = TransformersModelConfig(**args_dict)
     else:
         model_args_dict: dict = {k.split("=")[0]: k.split("=")[1] if "=" in k else True for k in model_args.split(",")}
         model_args_dict["accelerator"] = accelerator
         model_args_dict["use_chat_template"] = use_chat_template
         model_args_dict["compile"] = bool(model_args_dict["compile"]) if "compile" in model_args_dict else False
-        model_config = BaseModelConfig(**model_args_dict)
+        model_config = TransformersModelConfig(**model_args_dict)
 
     pipeline = Pipeline(
         tasks=tasks,
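Both branches above build the model config from a comma-separated key=value string using the same dict comprehension; a small runnable illustration of that parsing (every value stays a string, which is why compile is explicitly cast with bool(...) in the code above):

```python
# Same parsing pattern as in main_accelerate.py above.
model_args = "pretrained=HuggingFaceTB/SmolLM-1.7B,revision=main,trust_remote_code=True"

args_dict = {k.split("=")[0]: k.split("=")[1] for k in model_args.split(",")}
print(args_dict)
# {'pretrained': 'HuggingFaceTB/SmolLM-1.7B', 'revision': 'main', 'trust_remote_code': 'True'}

# Values are strings, hence the explicit casts in the CLI code above.
compile_flag = bool(args_dict["compile"]) if "compile" in args_dict else False
print(compile_flag)  # False: the key is absent from this example string
```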

src/lighteval/main_endpoint.py (10 additions, 5 deletions)

@@ -42,8 +42,11 @@
 @app.command(rich_help_panel="Evaluation Backends")
 def openai(
     # === general ===
-    model_name: Annotated[
-        str, Argument(help="The model name to evaluate (has to be available through the openai API.")
+    model_args: Annotated[
+        str,
+        Argument(
+            help="Model name as a string (has to be available through the openai API) or path to yaml config file (see examples/model_configs/transformers_model.yaml)"
+        ),
     ],
     tasks: Annotated[str, Argument(help="Comma-separated list of tasks to evaluate on.")],
     # === Common parameters ===
@@ -96,6 +99,11 @@ def openai(
     from lighteval.models.endpoints.openai_model import OpenAIModelConfig
     from lighteval.pipeline import EnvConfig, ParallelismManager, Pipeline, PipelineParameters
 
+    if model_args.endswith(".yaml"):
+        model_config = OpenAIModelConfig.from_path(model_args)
+    else:
+        model_config = OpenAIModelConfig(model=model_args)
+
     env_config = EnvConfig(token=TOKEN, cache_dir=cache_dir)
     evaluation_tracker = EvaluationTracker(
         output_dir=output_dir,
@@ -107,7 +115,6 @@ def openai(
     )
 
     parallelism_manager = ParallelismManager.OPENAI
-    model_config = OpenAIModelConfig(model=model_name)
 
     pipeline_params = PipelineParameters(
         launcher_type=parallelism_manager,
@@ -205,7 +212,6 @@ def inference_endpoint(
     """
     Evaluate models using inference-endpoints as backend.
     """
-
     from lighteval.logging.evaluation_tracker import EvaluationTracker
     from lighteval.models.endpoints.endpoint_model import InferenceEndpointModelConfig, ServerlessEndpointModelConfig
     from lighteval.pipeline import EnvConfig, ParallelismManager, Pipeline, PipelineParameters
@@ -319,7 +325,6 @@ def tgi(
     """
     Evaluate models using TGI as backend.
     """
-
     from lighteval.logging.evaluation_tracker import EvaluationTracker
     from lighteval.models.endpoints.tgi_model import TGIModelConfig
     from lighteval.pipeline import EnvConfig, ParallelismManager, Pipeline, PipelineParameters
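The openai command (and the vllm command below) now share a simple convention: if model_args ends in .yaml it is treated as a config file path, otherwise as the model name or inline arguments. A minimal sketch of that dispatch, with an illustrative stand-in for the config class:

```python
from dataclasses import dataclass

import yaml


@dataclass
class ModelConfigSketch:
    # Illustrative stand-in for OpenAIModelConfig; only the dispatch logic mirrors the PR.
    model: str


def build_config(model_args: str) -> ModelConfigSketch:
    if model_args.endswith(".yaml"):
        # Config-file path: read the "model" section, as the PR's from_path helpers do.
        with open(model_args, "r") as f:
            config = yaml.safe_load(f)["model"]
        return ModelConfigSketch(model=config["model_name"])
    # Otherwise the string is used directly as the model name.
    return ModelConfigSketch(model=model_args)


print(build_config("gpt-4o-mini"))  # illustrative model name, passed straight through
```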

src/lighteval/main_vllm.py (18 additions, 3 deletions)

@@ -37,7 +37,12 @@
 
 def vllm(
     # === general ===
-    model_args: Annotated[str, Argument(help="Model arguments in the form key1=value1,key2=value2,...")],
+    model_args: Annotated[
+        str,
+        Argument(
+            help="Model arguments in the form key1=value1,key2=value2,... or path to yaml config file (see examples/model_configs/transformers_model.yaml)"
+        ),
+    ],
     tasks: Annotated[str, Argument(help="Comma-separated list of tasks to evaluate on.")],
     # === Common parameters ===
     use_chat_template: Annotated[
@@ -88,7 +93,10 @@ def vllm(
     """
     Evaluate models using vllm as backend.
     """
+    import yaml
+
     from lighteval.logging.evaluation_tracker import EvaluationTracker
+    from lighteval.models.model_input import GenerationParameters
     from lighteval.models.vllm.vllm_model import VLLMModelConfig
     from lighteval.pipeline import EnvConfig, ParallelismManager, Pipeline, PipelineParameters
 
@@ -118,8 +126,15 @@ def vllm(
         system_prompt=system_prompt,
     )
 
-    model_args_dict: dict = {k.split("=")[0]: k.split("=")[1] if "=" in k else True for k in model_args.split(",")}
-    model_config = VLLMModelConfig(**model_args_dict)
+    if model_args.endswith(".yaml"):
+        with open(model_args, "r") as f:
+            config = yaml.safe_load(f)["model"]
+        generation_parameters = GenerationParameters.from_dict(config)
+        model_config = VLLMModelConfig(config, generation_parameters=generation_parameters)
+
+    else:
+        model_args_dict: dict = {k.split("=")[0]: k.split("=")[1] if "=" in k else True for k in model_args.split(",")}
+        model_config = VLLMModelConfig(**model_args_dict)
 
     pipeline = Pipeline(
         tasks=tasks,

src/lighteval/models/endpoints/endpoint_model.py (29 additions, 11 deletions)

@@ -24,7 +24,7 @@
 import logging
 import re
 import time
-from dataclasses import dataclass
+from dataclasses import dataclass, replace
 from typing import Coroutine, Dict, List, Optional, Union
 
 import requests
@@ -35,6 +35,7 @@
     InferenceEndpoint,
     InferenceEndpointError,
     InferenceEndpointTimeoutError,
+    TextGenerationInputGenerateParameters,
     TextGenerationInputGrammarType,
     TextGenerationOutput,
     create_inference_endpoint,
@@ -48,6 +49,7 @@
 
 from lighteval.data import GenerativeTaskDataset, LoglikelihoodDataset
 from lighteval.models.abstract_model import LightevalModel, ModelInfo
+from lighteval.models.model_input import GenerationParameters
 from lighteval.models.model_output import GenerativeResponse, LoglikelihoodResponse, LoglikelihoodSingleTokenResponse
 from lighteval.tasks.requests import (
     GreedyUntilRequest,
@@ -78,6 +80,11 @@
 class ServerlessEndpointModelConfig:
     model_name: str
     add_special_tokens: bool = True
+    generation_parameters: GenerationParameters = None
+
+    def __post_init__(self):
+        if not self.generation_parameters:
+            self.generation_parameters = GenerationParameters()
 
     @classmethod
     def from_path(cls, path: str) -> "ServerlessEndpointModelConfig":
@@ -106,6 +113,7 @@ class InferenceEndpointModelConfig:
     namespace: str = None  # The namespace under which to launch the endpoint. Defaults to the current user's namespace
     image_url: str = None
     env_vars: dict = None
+    generation_parameters: GenerationParameters = None
 
     def __post_init__(self):
         # xor operator, one is None but not the other
@@ -117,6 +125,9 @@ def __post_init__(self):
         if not (self.endpoint_name is None) ^ int(self.model_name is None):
             raise ValueError("You need to set either endpoint_name or model_name (but not both).")
 
+        if not self.generation_parameters:
+            self.generation_parameters = GenerationParameters()
+
     @classmethod
     def from_path(cls, path: str) -> "InferenceEndpointModelConfig":
         """Load configuration for inference endpoint model from YAML file path.
@@ -305,6 +316,8 @@ def __init__( # noqa: C901
             model_dtype=getattr(config, "model_dtype", "default"),
             model_size=-1,
         )
+        self.generation_parameters = config.generation_parameters
+        self.generation_config = TextGenerationInputGenerateParameters(**self.generation_parameters.to_tgi_ie_dict())
 
     @staticmethod
     def get_larger_hardware_suggestion(cur_instance_type: str = None, cur_instance_size: str = None):
@@ -388,16 +401,17 @@ def _async_process_request(
     ) -> Coroutine[None, list[TextGenerationOutput], str]:
         # Todo: add an option to launch with conversational instead for chat prompts
         # https://huggingface.co/docs/huggingface_hub/v0.20.3/en/package_reference/inference_client#huggingface_hub.AsyncInferenceClient.conversational
-        generated_text = self.async_client.text_generation(
-            prompt=context,
+        generation_config: TextGenerationInputGenerateParameters = replace(
+            self.generation_config,
+            stop=stop_tokens,
+            max_new_tokens=max_tokens,
             details=True,
             decoder_input_details=True,
             grammar=grammar,
-            max_new_tokens=max_tokens,
-            stop_sequences=stop_tokens,
-            # truncate=,
         )
 
+        generated_text = self.async_client.text_generation(prompt=context, generation_config=generation_config)
+
         return generated_text
 
     def _process_request(
@@ -409,14 +423,18 @@ def _process_request(
     ) -> TextGenerationOutput:
         # Todo: add an option to launch with conversational instead for chat prompts
         # https://huggingface.co/docs/huggingface_hub/v0.20.3/en/package_reference/inference_client#huggingface_hub.AsyncInferenceClient.conversational
-        generated_text = self.client.text_generation(
-            prompt=context,
+        generation_config: TextGenerationInputGenerateParameters = replace(
+            self.generation_config,
+            stop=stop_tokens,
+            max_new_tokens=max_tokens,
             details=True,
             decoder_input_details=True,
             grammar=grammar,
-            max_new_tokens=max_tokens,
-            stop_sequences=stop_tokens,
-            # truncate=,
+        )
+
+        generated_text = self.client.text_generation(
+            prompt=context,
+            generation_config=generation_config,
         )
 
         return generated_text
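The two request methods above share one pattern: a base TextGenerationInputGenerateParameters object is built once in __init__ from the model config, then copied per request with dataclasses.replace so that only the request-specific stop sequences and token budget are overridden. A standalone sketch of the pattern with an illustrative dataclass:

```python
from dataclasses import dataclass, replace
from typing import List, Optional


@dataclass
class GenerateParamsSketch:
    # Illustrative stand-in for the huggingface_hub parameters dataclass used above.
    temperature: Optional[float] = None
    max_new_tokens: Optional[int] = None
    stop: Optional[List[str]] = None


# Built once from the model config, as in the endpoint model's __init__.
base_config = GenerateParamsSketch(temperature=0.5)

# Per request: copy the base config, overriding only request-specific fields.
request_config = replace(base_config, stop=["\n"], max_new_tokens=128)

print(base_config)     # GenerateParamsSketch(temperature=0.5, max_new_tokens=None, stop=None)
print(request_config)  # GenerateParamsSketch(temperature=0.5, max_new_tokens=128, stop=['\n'])
```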

src/lighteval/models/endpoints/openai_model.py (19 additions, 1 deletion)

@@ -32,6 +32,7 @@
 from lighteval.data import GenerativeTaskDataset, LoglikelihoodDataset
 from lighteval.models.abstract_model import LightevalModel
 from lighteval.models.endpoints.endpoint_model import ModelInfo
+from lighteval.models.model_input import GenerationParameters
 from lighteval.models.model_output import (
     GenerativeResponse,
     LoglikelihoodResponse,
@@ -62,14 +63,30 @@
 @dataclass
 class OpenAIModelConfig:
     model: str
+    generation_parameters: GenerationParameters = None
+
+    def __post_init__(self):
+        if not self.generation_parameters:
+            self.generation_parameters = GenerationParameters()
+
+    @classmethod
+    def from_path(cls, path: str) -> "OpenAIModelConfig":
+        import yaml
+
+        with open(path, "r") as f:
+            config = yaml.safe_load(f)["model"]
+        generation_parameters = GenerationParameters.from_dict(config)
+        return cls(model=config["model_name"], generation_parameters=generation_parameters)
 
 
 class OpenAIClient(LightevalModel):
     _DEFAULT_MAX_LENGTH: int = 4096
 
-    def __init__(self, config, env_config) -> None:
+    def __init__(self, config: OpenAIModelConfig, env_config) -> None:
         api_key = os.environ["OPENAI_API_KEY"]
         self.client = OpenAI(api_key=api_key)
+        self.generation_parameters = config.generation_parameters
+        self.sampling_params = self.generation_parameters.to_vllm_openai_dict()
 
         self.model_info = ModelInfo(
             model_name=config.model,
@@ -96,6 +113,7 @@ def __call_api(self, prompt, return_logits, max_new_tokens, num_samples, logit_bias):
                     logprobs=return_logits,
                     logit_bias=logit_bias,
                     n=num_samples,
+                    **self.sampling_params,
                 )
                 return response
             except Exception as e:
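Taken together, these changes mean an OpenAIModelConfig can be created from a YAML file whose model section holds a model_name plus generation settings, and those settings are then spread into every completion call via **self.sampling_params. A hedged usage sketch; the YAML layout mirrors the example config in this PR, but the file contents and the exact fields GenerationParameters accepts are assumptions, not shown in the diff:

```python
# Hedged usage sketch for the new from_path constructor; file contents are illustrative.
import tempfile

from lighteval.models.endpoints.openai_model import OpenAIModelConfig

yaml_text = """
model:
  model_name: "gpt-4o-mini"   # hypothetical model name
  generation:
    temperature: 0.5
"""

with tempfile.NamedTemporaryFile("w", suffix=".yaml", delete=False) as f:
    f.write(yaml_text)
    config_path = f.name

config = OpenAIModelConfig.from_path(config_path)
print(config.model)                  # "gpt-4o-mini"
print(config.generation_parameters)  # GenerationParameters carrying temperature=0.5
```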
