Skip to content

Commit

Permalink
feat: backoff as global or per-model config.
Browse files Browse the repository at this point in the history
  • Loading branch information
acompa committed Dec 19, 2024
1 parent 804e89d commit d0e1af0
Show file tree
Hide file tree
Showing 4 changed files with 387 additions and 6 deletions.
6 changes: 6 additions & 0 deletions notdiamond/_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ def init(
model_messages: Dict[str, OpenAIMessagesType] = None,
api_key: Union[str, None] = None,
async_mode: bool = False,
backoff: Union[float, Dict[str, float]] = 2.0,
) -> RetryManager:
"""
Entrypoint for fallback and retry features without changing existing code.
Expand Down Expand Up @@ -51,6 +52,8 @@ def init(
Not Diamond API key for authentication. Unused for now - will offer logging and metrics in the future.
async_mode (bool):
Whether to manage clients as async.
backoff (Union[float, Dict[str, float]]):
Backoff factor for exponential backoff. Can be configured globally or per model.
Returns:
RetryManager: Manager object that handles retries and fallbacks. Not required for usage.
Expand Down Expand Up @@ -90,6 +93,7 @@ def init(
"azure/gpt-4o": [{"role": "user", "content": "Here is a prompt for Azure."}],
},
api_key="sk-...",
backoff=2.0,
)
# ...continue existing workflow code...
Expand Down Expand Up @@ -121,6 +125,7 @@ def init(
timeout=timeout,
model_messages=model_messages,
api_key=api_key,
backoff=backoff,
)
]
else:
Expand All @@ -132,6 +137,7 @@ def init(
timeout=timeout,
model_messages=model_messages,
api_key=api_key,
backoff=backoff,
)
for cc in client
]
Expand Down
33 changes: 27 additions & 6 deletions notdiamond/toolkit/_retry.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def __init__(
timeout: Union[float, Dict[str, float]] = 60.0,
model_messages: OpenAIMessagesType = {},
api_key: Union[str, None] = None,
backoff: float = 2.0,
backoff: Union[float, Dict[str, float]] = 2.0,
):
"""
Args:
Expand All @@ -103,8 +103,8 @@ def __init__(
The messages to send to each model. Prepended to any messages passed to the `create` method.
api_key: str | None
Not Diamond API key to use for logging. Currently unused.
backoff: float
The backoff factor for the retry logic.
backoff: float | Dict[str, float]
The backoff factor for the retry logic. Configured globally or per-model.
"""
self._client = client

Expand Down Expand Up @@ -152,7 +152,15 @@ def __init__(
if model_messages
else {}
)
self._backoff = backoff
self._backoff = (
{
m.split("/")[-1]: b
for m, b in backoff.items()
if self._model_client_match(m)
}
if isinstance(backoff, dict)
else backoff
)

self._api_key = api_key
self._nd_client = None
Expand Down Expand Up @@ -210,6 +218,19 @@ def get_max_retries(self, target_model: Optional[str] = None) -> int:
out = self._max_retries.get(target_model)
return out

def get_backoff(self, target_model: Optional[str] = None) -> float:
"""
Get the configured backoff (if per-model, for the target model).
"""
out = self._backoff
if isinstance(self._backoff, dict):
if not target_model:
raise ValueError(
"target_model must be provided if backoff is a dict"
)
out = self._backoff.get(target_model)
return out

def _update_model_kwargs(
self,
kwargs: Dict[str, Any],
Expand Down Expand Up @@ -258,7 +279,7 @@ async def async_wrapper(*args, **kwargs):
raise _RetryWrapperException([target_model], e)

attempt += 1
await asyncio.sleep(self._backoff**attempt)
await asyncio.sleep(self.get_backoff(target_model) ** attempt)

@wraps(func)
def sync_wrapper(*args, **kwargs):
Expand All @@ -279,7 +300,7 @@ def sync_wrapper(*args, **kwargs):
raise _RetryWrapperException([target_model], e)

attempt += 1
time.sleep(self._backoff**attempt)
time.sleep(self.get_backoff(target_model) ** attempt)

if isinstance(self, AsyncRetryWrapper):
return async_wrapper
Expand Down
Loading

0 comments on commit d0e1af0

Please sign in to comment.