from __future__ import annotations

import abc
import asyncio
import enum
from collections.abc import AsyncIterator, Awaitable
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Callable

from ..agent_output import AgentOutputSchema
from ..handoffs import Handoff
from ..items import ModelResponse, TResponseInputItem, TResponseStreamEvent
from ..tool import Tool

if TYPE_CHECKING:
    from ..model_settings import ModelSettings

class ModelTracing(enum.Enum):
    DISABLED = 0
    """Tracing is disabled entirely."""

    ENABLED = 1
    """Tracing is enabled, and all data is included."""

    ENABLED_WITHOUT_DATA = 2
    """Tracing is enabled, but inputs/outputs are not included."""

    def is_disabled(self) -> bool:
        return self == ModelTracing.DISABLED

    def include_data(self) -> bool:
        return self == ModelTracing.ENABLED
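

# A small illustrative helper (a sketch, not part of the original module):
# callers can branch on the enum's helper methods instead of comparing
# members directly. Note that for ENABLED_WITHOUT_DATA both helpers return
# False: spans are still emitted, but payloads are excluded.
def _tracing_mode_flags(tracing: ModelTracing) -> tuple[bool, bool]:
    # Returns (disabled, include_data); e.g. ModelTracing.ENABLED -> (False, True).
    return tracing.is_disabled(), tracing.include_data()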


@dataclass
class ModelRetrySettings:
    """Settings for retrying model calls on failure.

    This class helps manage backoff and retry logic when API calls fail.
    """

    max_retries: int = 3
    """Maximum number of retries to attempt."""

    initial_backoff_seconds: float = 1.0
    """Initial backoff time in seconds before the first retry."""

    max_backoff_seconds: float = 30.0
    """Maximum backoff time in seconds between retries."""

    backoff_multiplier: float = 2.0
    """Multiplier for backoff time after each retry."""

    retryable_status_codes: list[int] = field(default_factory=lambda: [429, 500, 502, 503, 504])
    """HTTP status codes that should trigger a retry."""

    async def execute_with_retry(
        self,
        operation: Callable[[], Awaitable[Any]],
        should_retry: Callable[[Exception], bool] | None = None,
    ) -> Any:
        """Execute an operation with retry logic.

        Args:
            operation: Async function to execute.
            should_retry: Optional function to determine if an exception should trigger a retry.

        Returns:
            The result of the operation if successful.

        Raises:
            The last exception encountered if all retries fail.
        """
        last_exception: Exception | None = None
        backoff = self.initial_backoff_seconds

        for attempt in range(self.max_retries + 1):
            try:
                return await operation()
            except Exception as e:
                last_exception = e

                # Stop once the retry budget is exhausted.
                if attempt >= self.max_retries:
                    break

                # Consult the caller-supplied predicate, if any, to decide
                # whether this exception is worth retrying.
                should_retry_exception = True
                if should_retry is not None:
                    should_retry_exception = should_retry(e)
                if not should_retry_exception:
                    break

                # Wait before retrying, with exponential backoff capped at the maximum.
                await asyncio.sleep(backoff)
                backoff = min(backoff * self.backoff_multiplier, self.max_backoff_seconds)

        if last_exception:
            raise last_exception

        # This should never happen, but just in case.
        raise RuntimeError("Retry logic failed in an unexpected way")
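

# A minimal usage sketch (illustrative, not part of the original module).
# `fetch_completion` is a hypothetical stand-in for a real model API call,
# and reading a `status_code` attribute off the exception is an assumption
# about the client library in use.
async def _retry_usage_example() -> str:
    settings = ModelRetrySettings(max_retries=5, initial_backoff_seconds=0.5)

    async def fetch_completion() -> str:
        # Replace with a real API call; an exception carrying a retryable
        # status code would trigger the backoff loop above.
        return "ok"

    def is_retryable(exc: Exception) -> bool:
        status = getattr(exc, "status_code", None)
        return status in settings.retryable_status_codes

    return await settings.execute_with_retry(fetch_completion, should_retry=is_retryable)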


class Model(abc.ABC):
    """The base interface for calling an LLM."""

    @abc.abstractmethod
    async def get_response(
        self,
        system_instructions: str | None,
        input: str | list[TResponseInputItem],
        model_settings: ModelSettings,
        tools: list[Tool],
        output_schema: AgentOutputSchema | None,
        handoffs: list[Handoff],
        tracing: ModelTracing,
    ) -> ModelResponse:
        """Get a response from the model.

        Args:
            system_instructions: The system instructions to use.
            input: The input items to the model, in OpenAI Responses format.
            model_settings: The model settings to use.
            tools: The tools available to the model.
            output_schema: The output schema to use.
            handoffs: The handoffs available to the model.
            tracing: Tracing configuration.

        Returns:
            The full model response.
        """
        pass

    @abc.abstractmethod
    def stream_response(
        self,
        system_instructions: str | None,
        input: str | list[TResponseInputItem],
        model_settings: ModelSettings,
        tools: list[Tool],
        output_schema: AgentOutputSchema | None,
        handoffs: list[Handoff],
        tracing: ModelTracing,
    ) -> AsyncIterator[TResponseStreamEvent]:
        """Stream a response from the model.

        Args:
            system_instructions: The system instructions to use.
            input: The input items to the model, in OpenAI Responses format.
            model_settings: The model settings to use.
            tools: The tools available to the model.
            output_schema: The output schema to use.
            handoffs: The handoffs available to the model.
            tracing: Tracing configuration.

        Returns:
            An iterator of response stream events, in OpenAI Responses format.
        """
        pass
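

# A consumption sketch (illustrative): given some concrete Model
# implementation, streaming is driven with `async for` over the iterator
# returned by `stream_response`. The `model` and `settings` arguments are
# placeholders supplied by the caller.
async def _stream_usage_example(model: Model, settings: ModelSettings) -> None:
    async for event in model.stream_response(
        system_instructions="You are a helpful assistant.",
        input="Hello!",
        model_settings=settings,
        tools=[],
        output_schema=None,
        handoffs=[],
        tracing=ModelTracing.ENABLED,
    ):
        print(event)  # each event is a TResponseStreamEvent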


class ModelProvider(abc.ABC):
    """The base interface for a model provider.

    A model provider is responsible for looking up Models by name.
    """

    @abc.abstractmethod
    def get_model(self, model_name: str | None) -> Model:
        """Get a model by name.

        Args:
            model_name: The name of the model to get.

        Returns:
            The model.
        """