This repository was archived by the owner on Jan 11, 2025. It is now read-only.

Commit f907a34

Test inference endpoint model config parsing from path (huggingface#434)
* Add example model config for existing endpoint
* Test InferenceEndpointModelConfig.from_path
* Comment default main branch in example
* Fix typo
* Delete unused add_special_tokens param in endpoint example config
* Fix typo
* Implement InferenceEndpointModelConfig.from_path
* Use InferenceEndpointModelConfig.from_path
* Refactor InferenceEndpointModelConfig.from_path
* Align docs
1 parent de8dba3 commit f907a34
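
As a quick illustration of the new parsing path introduced by this commit (a minimal sketch, assuming lighteval at this commit is importable and the code is run from the repository root):

```python
from lighteval.models.endpoints.endpoint_model import InferenceEndpointModelConfig

# from_path reads the "model" section of the YAML file, renames the YAML key
# `dtype` to the dataclass field `model_dtype`, and forwards `base_params`
# plus the optional `instance` settings to the config constructor.
config = InferenceEndpointModelConfig.from_path("examples/model_configs/endpoint_model.yaml")

print(config.model_name)   # "meta-llama/Llama-2-7b-hf"
print(config.model_dtype)  # "float16"
```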

File tree: 6 files changed (+105, −31 lines)


docs/source/evaluate-the-model-on-a-server-or-container.mdx

Lines changed: 1 addition & 3 deletions
@@ -31,7 +31,7 @@ model:
     # endpoint_name: "llama-2-7B-lighteval" # needs to be lower case without special characters
     # reuse_existing: true # defaults to false; if true, ignore all params in instance, and don't delete the endpoint after evaluation
     model_name: "meta-llama/Llama-2-7b-hf"
-    revision: "main"
+    # revision: "main" # defaults to "main"
     dtype: "float16" # can be any of "awq", "eetq", "gptq", "4bit' or "8bit" (will use bitsandbytes), "bfloat16" or "float16"
   instance:
     accelerator: "gpu"
@@ -45,8 +45,6 @@ model:
     image_url: null # Optionally specify the docker image to use when launching the endpoint model. E.g., launching models with later releases of the TGI container with support for newer models.
     env_vars:
       null # Optional environment variables to include when launching the endpoint. e.g., `MAX_INPUT_LENGTH: 2048`
-generation:
-  add_special_tokens: true
 ```
 
 ### Text Generation Inference (TGI)

examples/model_configs/endpoint_model.yaml

Lines changed: 2 additions & 4 deletions
@@ -4,7 +4,7 @@ model:
     # endpoint_name: "llama-2-7B-lighteval" # needs to be lower case without special characters
     # reuse_existing: true # defaults to false; if true, ignore all params in instance, and don't delete the endpoint after evaluation
     model_name: "meta-llama/Llama-2-7b-hf"
-    revision: "main"
+    # revision: "main" # defaults to "main"
     dtype: "float16" # can be any of "awq", "eetq", "gptq", "4bit' or "8bit" (will use bitsandbytes), "bfloat16" or "float16"
   instance:
     accelerator: "gpu"
@@ -14,9 +14,7 @@ model:
     instance_size: "x1"
     framework: "pytorch"
     endpoint_type: "protected"
-    namespace: null # The namespace under which to launch the endopint. Defaults to the current user's namespace
+    namespace: null # The namespace under which to launch the endpoint. Defaults to the current user's namespace
     image_url: null # Optionally specify the docker image to use when launching the endpoint model. E.g., launching models with later releases of the TGI container with support for newer models.
     env_vars:
       null # Optional environment variables to include when launching the endpoint. e.g., `MAX_INPUT_LENGTH: 2048`
-generation:
-  add_special_tokens: true

examples/model_configs/endpoint_model_reuse_existing.yaml

Lines changed: 5 additions & 0 deletions
@@ -0,0 +1,5 @@
+model:
+  base_params:
+    # Pass either model_name, or endpoint_name and true reuse_existing
+    endpoint_name: "llama-2-7B-lighteval" # needs to be lower case without special characters
+    reuse_existing: true # defaults to false; if true, ignore all params in instance, and don't delete the endpoint after evaluation
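
A minimal usage sketch for this new example config (assuming the file path above and an importable lighteval): the classmethod only receives the endpoint fields and leaves every other parameter at its default, mirroring the expectations in the new test further down.

```python
from lighteval.models.endpoints.endpoint_model import InferenceEndpointModelConfig

# Only base_params is set in this config, so from_path passes just
# endpoint_name and reuse_existing to the constructor.
config = InferenceEndpointModelConfig.from_path(
    "examples/model_configs/endpoint_model_reuse_existing.yaml"
)

print(config.endpoint_name)   # "llama-2-7B-lighteval"
print(config.reuse_existing)  # True
```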

src/lighteval/main_endpoint.py

Lines changed: 2 additions & 23 deletions
@@ -198,7 +198,6 @@ def inference_endpoint(
     """
     Evaluate models using inference-endpoints as backend.
     """
-    import yaml
 
     from lighteval.logging.evaluation_tracker import EvaluationTracker
     from lighteval.models.endpoints.endpoint_model import (
@@ -220,31 +219,11 @@ def inference_endpoint(
 
     parallelism_manager = ParallelismManager.NONE # since we're using inference endpoints in remote
 
-    with open(model_config_path, "r") as f:
-        config = yaml.safe_load(f)["model"]
-
     # Find a way to add this back
     # if config["base_params"].get("endpoint_name", None):
     #     return InferenceModelConfig(model=config["base_params"]["endpoint_name"])
-    all_params = {
-        "model_name": config["base_params"].get("model_name", None),
-        "endpoint_name": config["base_params"].get("endpoint_name", None),
-        "model_dtype": config["base_params"].get("dtype", None),
-        "revision": config["base_params"].get("revision", None) or "main",
-        "reuse_existing": config["base_params"].get("reuse_existing"),
-        "accelerator": config.get("instance", {}).get("accelerator", None),
-        "region": config.get("instance", {}).get("region", None),
-        "vendor": config.get("instance", {}).get("vendor", None),
-        "instance_size": config.get("instance", {}).get("instance_size", None),
-        "instance_type": config.get("instance", {}).get("instance_type", None),
-        "namespace": config.get("instance", {}).get("namespace", None),
-        "image_url": config.get("instance", {}).get("image_url", None),
-        "env_vars": config.get("instance", {}).get("env_vars", None),
-    }
-    model_config = InferenceEndpointModelConfig(
-        # We only initialize params which have a non default value
-        **{k: v for k, v in all_params.items() if v is not None},
-    )
+
+    model_config = InferenceEndpointModelConfig.from_path(model_config_path)
 
     pipeline_params = PipelineParameters(
         launcher_type=parallelism_manager,

src/lighteval/models/endpoints/endpoint_model.py

Lines changed: 10 additions & 1 deletion
@@ -103,12 +103,21 @@ def __post_init__(self):
         # xor operator, one is None but not the other
         if (self.instance_size is None) ^ (self.instance_type is None):
             raise ValueError(
-                "When creating an inference endpoint, you need to specify explicitely both instance_type and instance_size, or none of them for autoscaling."
+                "When creating an inference endpoint, you need to specify explicitly both instance_type and instance_size, or none of them for autoscaling."
             )
 
         if not (self.endpoint_name is None) ^ int(self.model_name is None):
             raise ValueError("You need to set either endpoint_name or model_name (but not both).")
 
+    @classmethod
+    def from_path(cls, path: str) -> "InferenceEndpointModelConfig":
+        import yaml
+
+        with open(path, "r") as f:
+            config = yaml.safe_load(f)["model"]
+        config["base_params"]["model_dtype"] = config["base_params"].pop("dtype", None)
+        return cls(**config["base_params"], **config.get("instance", {}))
+
     def get_dtype_args(self) -> Dict[str, str]:
         if self.model_dtype is None:
             return {}

tests/models/test_endpoint_model.py

Lines changed: 85 additions & 0 deletions
@@ -0,0 +1,85 @@
+# MIT License
+
+# Copyright (c) 2024 The HuggingFace Team
+
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+
+import pytest
+
+from lighteval.models.endpoints.endpoint_model import InferenceEndpointModelConfig
+
+
+# "examples/model_configs/endpoint_model.yaml"
+
+
+class TestInferenceEndpointModelConfig:
+    @pytest.mark.parametrize(
+        "config_path, expected_config",
+        [
+            (
+                "examples/model_configs/endpoint_model.yaml",
+                {
+                    "model_name": "meta-llama/Llama-2-7b-hf",
+                    "revision": "main",
+                    "model_dtype": "float16",
+                    "endpoint_name": None,
+                    "reuse_existing": False,
+                    "accelerator": "gpu",
+                    "region": "eu-west-1",
+                    "vendor": "aws",
+                    "instance_type": "nvidia-a10g",
+                    "instance_size": "x1",
+                    "framework": "pytorch",
+                    "endpoint_type": "protected",
+                    "namespace": None,
+                    "image_url": None,
+                    "env_vars": None,
+                },
+            ),
+            (
+                "examples/model_configs/endpoint_model_lite.yaml",
+                {
+                    "model_name": "meta-llama/Llama-3.1-8B-Instruct",
+                    # Defaults:
+                    "revision": "main",
+                    "model_dtype": None,
+                    "endpoint_name": None,
+                    "reuse_existing": False,
+                    "accelerator": "gpu",
+                    "region": "us-east-1",
+                    "vendor": "aws",
+                    "instance_type": None,
+                    "instance_size": None,
+                    "framework": "pytorch",
+                    "endpoint_type": "protected",
+                    "namespace": None,
+                    "image_url": None,
+                    "env_vars": None,
+                },
+            ),
+            (
+                "examples/model_configs/endpoint_model_reuse_existing.yaml",
+                {"endpoint_name": "llama-2-7B-lighteval", "reuse_existing": True},
+            ),
+        ],
+    )
+    def test_from_path(self, config_path, expected_config):
+        config = InferenceEndpointModelConfig.from_path(config_path)
+        for key, value in expected_config.items():
+            assert getattr(config, key) == value
