-
Notifications
You must be signed in to change notification settings - Fork 275
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Configurable models for NeurIPS Efficiency Challenge #1861
Changes from 3 commits
8fd3433
21146fa
039524d
f644944
a97a362
d397bc3
a1dafb2
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,4 @@ | ||
import os | ||
from typing import Dict, Optional, List | ||
from dataclasses import dataclass | ||
|
||
|
@@ -6,32 +7,52 @@ | |
|
||
from helm.common.hierarchical_logger import hlog | ||
from helm.common.object_spec import ObjectSpec | ||
from helm.proxy.models import ALL_MODELS, FULL_FUNCTIONALITY_TEXT_MODEL_TAG, MODEL_NAME_TO_MODEL, TEXT_MODEL_TAG, Model | ||
|
||
|
||
MODEL_DEPLOYMENTS_FILE = "model_deployments.yaml" | ||
|
||
|
||
class ClientSpec(ObjectSpec): | ||
pass | ||
|
||
|
||
class WindowServiceSpec(ObjectSpec): | ||
pass | ||
|
||
|
||
@dataclass(frozen=True) | ||
class ModelDeployment: | ||
"""A model deployment is an accessible instance of this model (e.g. a hosted endpoint). | ||
|
||
A model can have model deployments.""" | ||
A model can have multiple model deployments.""" | ||
|
||
name: str | ||
"""Name of the model deployment.""" | ||
|
||
model_name: str | ||
"""Name of the model that this model deployment is for.""" | ||
|
||
client_spec: ClientSpec | ||
"""Specification for instantiating the client for this model deployment.""" | ||
|
||
max_sequence_length: Optional[int] | ||
"""Maximum equence length for this model deployment.""" | ||
model_name: Optional[str] = None | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm wondering if we should put an example in the docstring so people have a sense of what the difference between name and model_name is, etc. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Going to defer this until we actually implement the multi-deployments feature, which isn't on the roadmap yet. |
||
"""Name of the model that this model deployment is for. | ||
|
||
If unset, defaults to the same value as `name`.""" | ||
|
||
tokenizer_name: Optional[str] | ||
"""Tokenizer for this model deployment.""" | ||
tokenizer_name: Optional[str] = None | ||
"""Tokenizer for this model deployment. | ||
|
||
If unset, auto-inferred by the WindowService.""" | ||
|
||
window_service_spec: Optional[WindowServiceSpec] = None | ||
"""Specification for instantiating the window service for this model deplooyment""" | ||
|
||
max_sequence_length: Optional[int] = None | ||
"""Maximum sequence length for this model deployment.""" | ||
|
||
max_request_length: Optional[int] = None | ||
"""Maximum request length for this model deployment. | ||
|
||
If unset, defaults to the same value as max_sequence_length.""" | ||
|
||
|
||
@dataclass(frozen=True) | ||
|
@@ -49,8 +70,27 @@ def register_model_deployments_from_path(path: str) -> None: | |
raw = yaml.safe_load(f) | ||
model_deployments: ModelDeployments = cattrs.structure(raw, ModelDeployments) | ||
for model_deployment in model_deployments.model_deployments: | ||
hlog(f"Registered model deployment {model_deployment.name}") | ||
_name_to_model_deployment[model_deployment.name] = model_deployment | ||
|
||
# Auto-register a model with this name if none exists | ||
model_name = model_deployment.model_name or model_deployment.name | ||
if model_name not in MODEL_NAME_TO_MODEL: | ||
model = Model( | ||
group="none", | ||
name=model_name, | ||
tags=[TEXT_MODEL_TAG, FULL_FUNCTIONALITY_TEXT_MODEL_TAG], | ||
) | ||
MODEL_NAME_TO_MODEL[model_name] = model | ||
ALL_MODELS.append(model) | ||
hlog(f"Registered default metadata for model {model_name}") | ||
|
||
|
||
def maybe_register_model_deployments_from_base_path(base_path: str) -> None: | ||
path = os.path.join(base_path, MODEL_DEPLOYMENTS_FILE) | ||
if os.path.exists(path): | ||
register_model_deployments_from_path(path) | ||
|
||
|
||
def get_model_deployment(name: str) -> Optional[ModelDeployment]: | ||
return _name_to_model_deployment.get(name) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,4 @@ | ||
import os | ||
from typing import Optional, List | ||
from dataclasses import dataclass, field | ||
from datetime import date | ||
|
@@ -8,6 +9,9 @@ | |
from helm.proxy.models import ALL_MODELS, MODEL_NAME_TO_MODEL, Model | ||
|
||
|
||
MODEL_METADATA_FILE = "model_metadata.yaml" | ||
|
||
|
||
@dataclass(frozen=True) | ||
class ModelMetadata: | ||
name: str | ||
|
@@ -58,3 +62,9 @@ def register_model_metadata_from_path(path: str) -> None: | |
) | ||
MODEL_NAME_TO_MODEL[model_metadata.name] = model | ||
ALL_MODELS.append(model) | ||
|
||
|
||
def maybe_register_model_metadata_from_base_path(base_path: str) -> None: | ||
path = os.path.join(base_path, MODEL_METADATA_FILE) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Add docstring? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Added docstring. |
||
if os.path.exists(path): | ||
register_model_metadata_from_path(path) |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2516,7 +2516,13 @@ def construct_run_specs(spec: ObjectSpec) -> List[RunSpec]: | |
] | ||
|
||
def alter_run_spec(run_spec: RunSpec) -> RunSpec: | ||
model = get_model(run_spec.adapter_spec.model) | ||
try: | ||
model = get_model(run_spec.adapter_spec.model) | ||
except ValueError: | ||
# Models registered from configs cannot have expanders applied to them, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ValueError means that the model has not been loaded yet? I was a bit confused by the comment at first, maybe connect the dots a bit more. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It means the model has not been registered yet. I'll add more docs. |
||
# because the models will not have been registered yet at this point. | ||
# TODO: Figure out a cleaner way to deal with this. | ||
return run_spec | ||
# For models that strip newlines, when we're generating, we need to set | ||
# the delimiter to be '###' so we stop properly. | ||
if NO_NEWLINES_TAG in model.tags and run_spec.adapter_spec.method in ( | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,8 @@ | ||
import importlib | ||
import dataclasses | ||
from dataclasses import dataclass | ||
from typing import Any, Dict, Optional, Tuple, Hashable, Type | ||
import inspect | ||
from typing import Any, Callable, Dict, Optional, Tuple, Hashable, Type, TypeVar | ||
|
||
|
||
@dataclass(frozen=True) | ||
|
@@ -32,13 +34,44 @@ def get_class_by_name(full_class_name: str) -> Type[Any]: | |
return getattr(importlib.import_module(module_name), class_name) | ||
|
||
|
||
def create_object(spec: ObjectSpec, additional_args: Optional[Dict[str, Any]] = None): | ||
ObjectSpecT = TypeVar("ObjectSpecT", bound=ObjectSpec) | ||
|
||
|
||
def inject_object_spec_args( | ||
spec: ObjectSpecT, | ||
constant_bindings: Optional[Dict[str, Any]] = None, | ||
provider_bindings: Optional[Dict[str, Callable[[], Any]]] = None, | ||
) -> ObjectSpecT: | ||
"""Return a new ObjectSpec that is a copy of the original ObjectSpec with additional arguments. | ||
|
||
The original ObjectSpec may be missing arguments for parameters that are required by the | ||
ObjectSpec's class's constructor. | ||
This function returns a new ObjectSpec with these missing parameters filled in. | ||
To do this, for every missing parameter, look up each of the `*_bindings` arguments in order until we | ||
find one with a key matching the missing parameter's name. | ||
If found in constant_bindings, add the corresponding value to args. | ||
If found in provider_bindings, call the corresponding value and add the return value to args. | ||
|
||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could you provide an example or two of usage? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Added example. |
||
This is loosely based on instance (constant) bindings and provider bindings in Guice dependency injection.""" | ||
cls = get_class_by_name(spec.class_name) | ||
init_signature = inspect.signature(cls.__init__) | ||
args = {} | ||
args.update(spec.args) | ||
for parameter_name in init_signature.parameters.keys(): | ||
if parameter_name == "self" or parameter_name in args: | ||
continue | ||
elif constant_bindings and parameter_name in constant_bindings: | ||
args[parameter_name] = constant_bindings[parameter_name] | ||
elif provider_bindings and parameter_name in provider_bindings: | ||
args[parameter_name] = provider_bindings[parameter_name]() | ||
return dataclasses.replace(spec, args=args) | ||
|
||
|
||
def create_object(spec: ObjectSpec): | ||
"""Create the actual object given the `spec`.""" | ||
cls = get_class_by_name(spec.class_name) | ||
args = {} | ||
args.update(spec.args) | ||
if additional_args: | ||
args.update(additional_args) | ||
return cls(**args) | ||
|
||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -5,9 +5,10 @@ | |
from retrying import RetryError, Attempt | ||
|
||
from helm.benchmark.model_deployment_registry import get_model_deployment | ||
from helm.benchmark.tokenizer_config_registry import get_tokenizer_config | ||
from helm.common.cache import CacheConfig, MongoCacheConfig, SqliteCacheConfig | ||
from helm.common.hierarchical_logger import hlog | ||
from helm.common.object_spec import create_object | ||
from helm.common.object_spec import create_object, inject_object_spec_args | ||
from helm.common.request import Request, RequestResult | ||
from helm.common.tokenization_request import ( | ||
TokenizationRequest, | ||
|
@@ -70,18 +71,23 @@ def _get_client(self, model: str) -> Client: | |
# TODO: Migrate all clients to use model deployments | ||
model_deployment = get_model_deployment(model) | ||
if model_deployment: | ||
api_key = None | ||
if "deployments" not in self.credentials: | ||
raise AuthenticationError("Could not find key 'deployments' in credentials.conf") | ||
deployment_api_keys = self.credentials["deployments"] | ||
if model not in deployment_api_keys: | ||
raise AuthenticationError( | ||
f"Could not find key '{model}' under key 'deployments' in credentials.conf" | ||
) | ||
api_key = deployment_api_keys[model] | ||
client = create_object( | ||
model_deployment.client_spec, additional_args={"cache_config": cache_config, "api_key": api_key} | ||
|
||
def provide_api_key(): | ||
if "deployments" not in self.credentials: | ||
raise AuthenticationError("Could not find key 'deployments' in credentials.conf") | ||
deployment_api_keys = self.credentials["deployments"] | ||
if model not in deployment_api_keys: | ||
raise AuthenticationError( | ||
f"Could not find key '{model}' under key 'deployments' in credentials.conf" | ||
) | ||
return deployment_api_keys[model] | ||
|
||
client_spec = inject_object_spec_args( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can you write some comments on why this injection is needed? My initial impression is that it seems a bit complicated / fancy... There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Added a passage. Dependency injection is needed here for these reasons:
|
||
model_deployment.client_spec, | ||
constant_bindings={"cache_config": cache_config}, | ||
provider_bindings={"api_key": provide_api_key}, | ||
) | ||
client = create_object(client_spec) | ||
|
||
elif get_huggingface_model_config(model): | ||
from helm.proxy.clients.huggingface_client import HuggingFaceClient | ||
|
@@ -211,7 +217,14 @@ def _get_tokenizer_client(self, tokenizer: str) -> Client: | |
|
||
if client is None: | ||
cache_config: CacheConfig = self._build_cache_config(organization) | ||
if get_huggingface_model_config(tokenizer): | ||
# TODO: Migrate all clients to use tokenizer configs | ||
tokenizer_config = get_tokenizer_config(tokenizer) | ||
if tokenizer_config: | ||
tokenizer_spec = inject_object_spec_args( | ||
tokenizer_config.tokenizer_spec, constant_bindings={"cache_config": cache_config} | ||
) | ||
client = create_object(tokenizer_spec) | ||
elif get_huggingface_model_config(tokenizer): | ||
from helm.proxy.clients.huggingface_client import HuggingFaceClient | ||
|
||
client = HuggingFaceClient(cache_config=cache_config) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why is this moved down?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The ordering should be that all the required parameters come first, then all the optional parameters with default arguments.