# llm_loader.py
"""LLM backend libraries loader."""
import logging
from typing import Optional
from langchain.llms.base import LLM
from ols import config
from ols.app.models.config import LLMProviders, ProviderConfig
from ols.src.llms.providers.registry import LLMProvidersRegistry
logger = logging.getLogger(__name__)


class LLMConfigurationError(Exception):
    """LLM configuration is wrong."""


class UnknownProviderError(LLMConfigurationError):
    """No configuration exists for the requested provider."""


class UnsupportedProviderError(LLMConfigurationError):
    """Provider is not supported."""


class ModelConfigMissingError(LLMConfigurationError):
    """No configuration exists for the requested model name."""


def resolve_provider_config(
    provider: str, model: str, providers_config: LLMProviders
) -> ProviderConfig:
    """Ensure the provided inputs (provider/model) are valid in config.

    Return the respective provider configuration.
    """
    if provider not in providers_config.providers:
        raise UnknownProviderError(
            f"Provider '{provider}' is not a valid provider. "
            f"Valid providers are: {list(providers_config.providers.keys())}"
        )

    provider_config = providers_config.providers.get(provider)
    if model not in provider_config.models:
        raise ModelConfigMissingError(
            f"Model '{model}' is not a valid model for provider '{provider}'. "
            f"Valid models are: {list(provider_config.models.keys())}"
        )

    return provider_config
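

# A minimal sketch of how this validation surfaces errors (the provider and
# model names below are illustrative, not taken from any real configuration):
#
#     try:
#         provider_config = resolve_provider_config(
#             "openai", "gpt-3.5-turbo", config.config.llm_providers
#         )
#     except (UnknownProviderError, ModelConfigMissingError) as e:
#         logger.error("invalid provider/model combination: %s", e)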


def load_llm(
    provider: str, model: str, generic_llm_params: Optional[dict] = None
) -> LLM:
    """Load LLM according to input provider and model.

    Args:
        provider: The provider name.
        model: The model name.
        generic_llm_params: The optional parameters that will be converted into LLM-specific ones.

    Raises:
        LLMConfigurationError: If the providers configuration is missing entirely.
        UnknownProviderError: If the provider is not known.
        ModelConfigMissingError: If the model configuration is missing.
        UnsupportedProviderError: If the provider type is not supported (implemented).

    Example:
        ```python
        # overriding specific parameters via generic_llm_params
        generic_llm_params = {'temperature': 0.02, 'top_p': 0.95}
        bare_llm = load_llm(provider="openai", model="gpt-3.5-turbo",
                            generic_llm_params=generic_llm_params)
        llm_chain = LLMChain(llm=bare_llm, prompt=prompt)
        ```
    """
    providers_config = config.config.llm_providers
    if providers_config is None:
        raise LLMConfigurationError("Providers configuration missing in rcsconfig.yaml")

    llm_providers_reg = LLMProvidersRegistry

    provider_config = resolve_provider_config(provider, model, providers_config)
    if provider_config.type not in llm_providers_reg.llm_providers:
        raise UnsupportedProviderError(
            f"Unsupported LLM provider type '{provider_config.type}'."
        )

    logger.debug("loading LLM model '%s' from provider '%s'", model, provider)

    # look up the provider implementation by its configured type, then load the model
    llm_provider = llm_providers_reg.llm_providers[provider_config.type]
    return llm_provider(model, provider_config, generic_llm_params or {}).load()
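

# Typical end-to-end usage (a sketch, assuming config.config has already been
# initialized from rcsconfig.yaml at service startup and that this module is
# importable as ols.src.llms.llm_loader; provider/model names are illustrative):
#
#     from ols.src.llms.llm_loader import LLMConfigurationError, load_llm
#
#     try:
#         bare_llm = load_llm(
#             provider="openai",
#             model="gpt-3.5-turbo",
#             generic_llm_params={"temperature": 0.02, "top_p": 0.95},
#         )
#     except LLMConfigurationError:
#         ...  # surface the configuration problem to the caller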