Skip to content

Commit a0254b0

Browse files
RFC: automatically use litellm if possible (#534)
## Summary This replaces the default model provider with a `MultiProvider`, which has the logic: - if the model name starts with `openai/` or doesn't contain "/", use OpenAI - if the model name starts with `litellm/`, use LiteLLM to use the appropriate model provider. It's also extensible, so users can create their own mappings. I also imagine that if we natively supported Anthropic/Gemini etc, we can add it to MultiProvider to make it work. The goal is that it should be really easy to use any model provider. Today if you pass `model="gpt-4.1"`, it works great. But `model="claude-sonnet-3.7"` doesn't. If we can make it that easy, it's a win for devx. I'm not entirely sure if this is a good idea - is it too magical? Is the API too reliant on litellm? Comments welcome. ## Test plan For now, the example. Will add unit tests if we agree its worth mergin. --------- Co-authored-by: Steven Heidel <steven@heidel.ca>
1 parent 0a3dfa0 commit a0254b0

File tree

4 files changed

+208
-2
lines changed

4 files changed

+208
-2
lines changed

Diff for: examples/model_providers/litellm_auto.py

+41
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
from __future__ import annotations
2+
3+
import asyncio
4+
5+
from agents import Agent, Runner, function_tool, set_tracing_disabled
6+
7+
"""This example uses the built-in support for LiteLLM. To use this, ensure you have the
8+
ANTHROPIC_API_KEY environment variable set.
9+
"""
10+
11+
set_tracing_disabled(disabled=True)
12+
13+
14+
@function_tool
15+
def get_weather(city: str):
16+
print(f"[debug] getting weather for {city}")
17+
return f"The weather in {city} is sunny."
18+
19+
20+
async def main():
21+
agent = Agent(
22+
name="Assistant",
23+
instructions="You only respond in haikus.",
24+
# We prefix with litellm/ to tell the Runner to use the LitellmModel
25+
model="litellm/anthropic/claude-3-5-sonnet-20240620",
26+
tools=[get_weather],
27+
)
28+
29+
result = await Runner.run(agent, "What's the weather in Tokyo?")
30+
print(result.final_output)
31+
32+
33+
if __name__ == "__main__":
34+
import os
35+
36+
if os.getenv("ANTHROPIC_API_KEY") is None:
37+
raise ValueError(
38+
"ANTHROPIC_API_KEY is not set. Please set it the environment variable and try again."
39+
)
40+
41+
asyncio.run(main())

Diff for: src/agents/extensions/models/litellm_provider.py

+21
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
from ...models.interface import Model, ModelProvider
2+
from .litellm_model import LitellmModel
3+
4+
DEFAULT_MODEL: str = "gpt-4.1"
5+
6+
7+
class LitellmProvider(ModelProvider):
8+
"""A ModelProvider that uses LiteLLM to route to any model provider. You can use it via:
9+
```python
10+
Runner.run(agent, input, run_config=RunConfig(model_provider=LitellmProvider()))
11+
```
12+
See supported models here: [litellm models](https://docs.litellm.ai/docs/providers).
13+
14+
NOTE: API keys must be set via environment variables. If you're using models that require
15+
additional configuration (e.g. Azure API base or version), those must also be set via the
16+
environment variables that LiteLLM expects. If you have more advanced needs, we recommend
17+
copy-pasting this class and making any modifications you need.
18+
"""
19+
20+
def get_model(self, model_name: str | None) -> Model:
21+
return LitellmModel(model_name or DEFAULT_MODEL)

Diff for: src/agents/models/multi_provider.py

+144
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,144 @@
1+
from __future__ import annotations
2+
3+
from openai import AsyncOpenAI
4+
5+
from ..exceptions import UserError
6+
from .interface import Model, ModelProvider
7+
from .openai_provider import OpenAIProvider
8+
9+
10+
class MultiProviderMap:
11+
"""A map of model name prefixes to ModelProviders."""
12+
13+
def __init__(self):
14+
self._mapping: dict[str, ModelProvider] = {}
15+
16+
def has_prefix(self, prefix: str) -> bool:
17+
"""Returns True if the given prefix is in the mapping."""
18+
return prefix in self._mapping
19+
20+
def get_mapping(self) -> dict[str, ModelProvider]:
21+
"""Returns a copy of the current prefix -> ModelProvider mapping."""
22+
return self._mapping.copy()
23+
24+
def set_mapping(self, mapping: dict[str, ModelProvider]):
25+
"""Overwrites the current mapping with a new one."""
26+
self._mapping = mapping
27+
28+
def get_provider(self, prefix: str) -> ModelProvider | None:
29+
"""Returns the ModelProvider for the given prefix.
30+
31+
Args:
32+
prefix: The prefix of the model name e.g. "openai" or "my_prefix".
33+
"""
34+
return self._mapping.get(prefix)
35+
36+
def add_provider(self, prefix: str, provider: ModelProvider):
37+
"""Adds a new prefix -> ModelProvider mapping.
38+
39+
Args:
40+
prefix: The prefix of the model name e.g. "openai" or "my_prefix".
41+
provider: The ModelProvider to use for the given prefix.
42+
"""
43+
self._mapping[prefix] = provider
44+
45+
def remove_provider(self, prefix: str):
46+
"""Removes the mapping for the given prefix.
47+
48+
Args:
49+
prefix: The prefix of the model name e.g. "openai" or "my_prefix".
50+
"""
51+
del self._mapping[prefix]
52+
53+
54+
class MultiProvider(ModelProvider):
55+
"""This ModelProvider maps to a Model based on the prefix of the model name. By default, the
56+
mapping is:
57+
- "openai/" prefix or no prefix -> OpenAIProvider. e.g. "openai/gpt-4.1", "gpt-4.1"
58+
- "litellm/" prefix -> LitellmProvider. e.g. "litellm/openai/gpt-4.1"
59+
60+
You can override or customize this mapping.
61+
"""
62+
63+
def __init__(
64+
self,
65+
*,
66+
provider_map: MultiProviderMap | None = None,
67+
openai_api_key: str | None = None,
68+
openai_base_url: str | None = None,
69+
openai_client: AsyncOpenAI | None = None,
70+
openai_organization: str | None = None,
71+
openai_project: str | None = None,
72+
openai_use_responses: bool | None = None,
73+
) -> None:
74+
"""Create a new OpenAI provider.
75+
76+
Args:
77+
provider_map: A MultiProviderMap that maps prefixes to ModelProviders. If not provided,
78+
we will use a default mapping. See the documentation for this class to see the
79+
default mapping.
80+
openai_api_key: The API key to use for the OpenAI provider. If not provided, we will use
81+
the default API key.
82+
openai_base_url: The base URL to use for the OpenAI provider. If not provided, we will
83+
use the default base URL.
84+
openai_client: An optional OpenAI client to use. If not provided, we will create a new
85+
OpenAI client using the api_key and base_url.
86+
openai_organization: The organization to use for the OpenAI provider.
87+
openai_project: The project to use for the OpenAI provider.
88+
openai_use_responses: Whether to use the OpenAI responses API.
89+
"""
90+
self.provider_map = provider_map
91+
self.openai_provider = OpenAIProvider(
92+
api_key=openai_api_key,
93+
base_url=openai_base_url,
94+
openai_client=openai_client,
95+
organization=openai_organization,
96+
project=openai_project,
97+
use_responses=openai_use_responses,
98+
)
99+
100+
self._fallback_providers: dict[str, ModelProvider] = {}
101+
102+
def _get_prefix_and_model_name(self, model_name: str | None) -> tuple[str | None, str | None]:
103+
if model_name is None:
104+
return None, None
105+
elif "/" in model_name:
106+
prefix, model_name = model_name.split("/", 1)
107+
return prefix, model_name
108+
else:
109+
return None, model_name
110+
111+
def _create_fallback_provider(self, prefix: str) -> ModelProvider:
112+
if prefix == "litellm":
113+
from ..extensions.models.litellm_provider import LitellmProvider
114+
115+
return LitellmProvider()
116+
else:
117+
raise UserError(f"Unknown prefix: {prefix}")
118+
119+
def _get_fallback_provider(self, prefix: str | None) -> ModelProvider:
120+
if prefix is None or prefix == "openai":
121+
return self.openai_provider
122+
elif prefix in self._fallback_providers:
123+
return self._fallback_providers[prefix]
124+
else:
125+
self._fallback_providers[prefix] = self._create_fallback_provider(prefix)
126+
return self._fallback_providers[prefix]
127+
128+
def get_model(self, model_name: str | None) -> Model:
129+
"""Returns a Model based on the model name. The model name can have a prefix, ending with
130+
a "/", which will be used to look up the ModelProvider. If there is no prefix, we will use
131+
the OpenAI provider.
132+
133+
Args:
134+
model_name: The name of the model to get.
135+
136+
Returns:
137+
A Model.
138+
"""
139+
prefix, model_name = self._get_prefix_and_model_name(model_name)
140+
141+
if prefix and self.provider_map and (provider := self.provider_map.get_provider(prefix)):
142+
return provider.get_model(model_name)
143+
else:
144+
return self._get_fallback_provider(prefix).get_model(model_name)

Diff for: src/agents/run.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434
from .logger import logger
3535
from .model_settings import ModelSettings
3636
from .models.interface import Model, ModelProvider
37-
from .models.openai_provider import OpenAIProvider
37+
from .models.multi_provider import MultiProvider
3838
from .result import RunResult, RunResultStreaming
3939
from .run_context import RunContextWrapper, TContext
4040
from .stream_events import AgentUpdatedStreamEvent, RawResponsesStreamEvent
@@ -56,7 +56,7 @@ class RunConfig:
5656
agent. The model_provider passed in below must be able to resolve this model name.
5757
"""
5858

59-
model_provider: ModelProvider = field(default_factory=OpenAIProvider)
59+
model_provider: ModelProvider = field(default_factory=MultiProvider)
6060
"""The model provider to use when looking up string model names. Defaults to OpenAI."""
6161

6262
model_settings: ModelSettings | None = None

0 commit comments

Comments
 (0)