diff --git a/notdiamond/_init.py b/notdiamond/_init.py index 37fc7f6..bcd21ed 100644 --- a/notdiamond/_init.py +++ b/notdiamond/_init.py @@ -19,6 +19,7 @@ def init( model_messages: Dict[str, OpenAIMessagesType] = None, api_key: Union[str, None] = None, async_mode: bool = False, + backoff: Union[float, Dict[str, float]] = 2.0, ) -> RetryManager: """ Entrypoint for fallback and retry features without changing existing code. @@ -51,6 +52,8 @@ def init( Not Diamond API key for authentication. Unused for now - will offer logging and metrics in the future. async_mode (bool): Whether to manage clients as async. + backoff (Union[float, Dict[str, float]]): + Backoff factor for exponential backoff per each retry. Can be configured globally or per model. Returns: RetryManager: Manager object that handles retries and fallbacks. Not required for usage. @@ -90,6 +93,7 @@ def init( "azure/gpt-4o": [{"role": "user", "content": "Here is a prompt for Azure."}], }, api_key="sk-...", + backoff=2.0, ) # ...continue existing workflow code... @@ -121,6 +125,7 @@ def init( timeout=timeout, model_messages=model_messages, api_key=api_key, + backoff=backoff, ) ] else: @@ -132,6 +137,7 @@ def init( timeout=timeout, model_messages=model_messages, api_key=api_key, + backoff=backoff, ) for cc in client ] diff --git a/notdiamond/toolkit/_retry.py b/notdiamond/toolkit/_retry.py index 41dd091..dbcb964 100644 --- a/notdiamond/toolkit/_retry.py +++ b/notdiamond/toolkit/_retry.py @@ -87,7 +87,7 @@ def __init__( timeout: Union[float, Dict[str, float]] = 60.0, model_messages: OpenAIMessagesType = {}, api_key: Union[str, None] = None, - backoff: float = 2.0, + backoff: Union[float, Dict[str, float]] = 2.0, ): """ Args: @@ -103,8 +103,8 @@ def __init__( The messages to send to each model. Prepended to any messages passed to the `create` method. api_key: str | None Not Diamond API key to use for logging. Currently unused. - backoff: float - The backoff factor for the retry logic. + backoff: float | Dict[str, float] + Exponential backoff factor per each retry. Can be configured globally or per model. """ self._client = client @@ -152,7 +152,15 @@ def __init__( if model_messages else {} ) - self._backoff = backoff + self._backoff = ( + { + m.split("/")[-1]: b + for m, b in backoff.items() + if self._model_client_match(m) + } + if isinstance(backoff, dict) + else backoff + ) self._api_key = api_key self._nd_client = None @@ -210,6 +218,19 @@ def get_max_retries(self, target_model: Optional[str] = None) -> int: out = self._max_retries.get(target_model) return out + def get_backoff(self, target_model: Optional[str] = None) -> float: + """ + Get the configured backoff (if per-model, for the target model). + """ + out = self._backoff + if isinstance(self._backoff, dict): + if not target_model: + raise ValueError( + "target_model must be provided if backoff is a dict" + ) + out = self._backoff.get(target_model) + return out + def _update_model_kwargs( self, kwargs: Dict[str, Any], @@ -258,7 +279,7 @@ async def async_wrapper(*args, **kwargs): raise _RetryWrapperException([target_model], e) attempt += 1 - await asyncio.sleep(self._backoff**attempt) + await asyncio.sleep(self.get_backoff(target_model) ** attempt) @wraps(func) def sync_wrapper(*args, **kwargs): @@ -279,7 +300,7 @@ def sync_wrapper(*args, **kwargs): raise _RetryWrapperException([target_model], e) attempt += 1 - time.sleep(self._backoff**attempt) + time.sleep(self.get_backoff(target_model) ** attempt) if isinstance(self, AsyncRetryWrapper): return async_wrapper diff --git a/tests/test_toolkit/cassettes/test_retry/test_multi_model_backoff_config.yaml b/tests/test_toolkit/cassettes/test_retry/test_multi_model_backoff_config.yaml new file mode 100644 index 0000000..3fccdd4 --- /dev/null +++ b/tests/test_toolkit/cassettes/test_retry/test_multi_model_backoff_config.yaml @@ -0,0 +1,297 @@ +interactions: +- request: + body: '{"messages": [{"role": "user", "content": "Hello, how are you?"}], "model": + "gpt-4o-mini"}' + headers: + accept: + - application/json + accept-encoding: + - gzip, deflate + connection: + - keep-alive + content-length: + - '90' + content-type: + - application/json + host: + - api.openai.com + user-agent: + - OpenAI/Python 1.55.1 + x-stainless-arch: + - arm64 + x-stainless-async: + - 'false' + x-stainless-lang: + - python + x-stainless-os: + - MacOS + x-stainless-package-version: + - 1.55.1 + x-stainless-retry-count: + - '0' + x-stainless-runtime: + - CPython + x-stainless-runtime-version: + - 3.11.10 + method: POST + uri: https://api.openai.com/v1/chat/completions + response: + body: + string: "{\n \"error\": {\n \"message\": \"Incorrect API key provided: + broken-a**-key. You can find your API key at https://platform.openai.com/account/api-keys.\",\n + \ \"type\": \"invalid_request_error\",\n \"param\": null,\n \"code\": + \"invalid_api_key\"\n }\n}\n" + headers: + CF-Cache-Status: + - DYNAMIC + CF-RAY: + - 8f437524583adb1d-MIA + Connection: + - keep-alive + Content-Length: + - '264' + Content-Type: + - application/json; charset=utf-8 + Date: + - Thu, 19 Dec 2024 01:00:46 GMT + Server: + - cloudflare + Set-Cookie: + - __cf_bm=Z1WRLdGzN7h0s2NBGBRjVat3bEygGCIzrUj8yy1S01o-1734570046-1.0.1.1-8fphVW32naBztpk7fd1ayypn_LdmuDVBXRA6y44lVoosGqJ4Q4U6eddkhDuiolKTPmYy2CVwAuwePaZnBfZ4dQ; + path=/; expires=Thu, 19-Dec-24 01:30:46 GMT; domain=.api.openai.com; HttpOnly; + Secure; SameSite=None + - _cfuvid=3CPZcOj.Dsl.txLCvGozjseBU6qsGQPB_p4jTDHkkSk-1734570046258-0.0.1.1-604800000; + path=/; domain=.api.openai.com; HttpOnly; Secure; SameSite=None + X-Content-Type-Options: + - nosniff + alt-svc: + - h3=":443"; ma=86400 + strict-transport-security: + - max-age=31536000; includeSubDomains; preload + vary: + - Origin + x-request-id: + - req_cc0d2fc697ff6e5cc778b49127d22867 + status: + code: 401 + message: Unauthorized +- request: + body: '{"messages": [{"role": "user", "content": "Hello, how are you?"}], "model": + "gpt-4o-mini"}' + headers: + accept: + - application/json + accept-encoding: + - gzip, deflate + connection: + - keep-alive + content-length: + - '90' + content-type: + - application/json + cookie: + - __cf_bm=Z1WRLdGzN7h0s2NBGBRjVat3bEygGCIzrUj8yy1S01o-1734570046-1.0.1.1-8fphVW32naBztpk7fd1ayypn_LdmuDVBXRA6y44lVoosGqJ4Q4U6eddkhDuiolKTPmYy2CVwAuwePaZnBfZ4dQ; + _cfuvid=3CPZcOj.Dsl.txLCvGozjseBU6qsGQPB_p4jTDHkkSk-1734570046258-0.0.1.1-604800000 + host: + - api.openai.com + user-agent: + - OpenAI/Python 1.55.1 + x-stainless-arch: + - arm64 + x-stainless-async: + - 'false' + x-stainless-lang: + - python + x-stainless-os: + - MacOS + x-stainless-package-version: + - 1.55.1 + x-stainless-retry-count: + - '0' + x-stainless-runtime: + - CPython + x-stainless-runtime-version: + - 3.11.10 + method: POST + uri: https://api.openai.com/v1/chat/completions + response: + body: + string: "{\n \"error\": {\n \"message\": \"Incorrect API key provided: + broken-a**-key. You can find your API key at https://platform.openai.com/account/api-keys.\",\n + \ \"type\": \"invalid_request_error\",\n \"param\": null,\n \"code\": + \"invalid_api_key\"\n }\n}\n" + headers: + CF-Cache-Status: + - DYNAMIC + CF-RAY: + - 8f43752b8a65db1d-MIA + Connection: + - keep-alive + Content-Length: + - '264' + Content-Type: + - application/json; charset=utf-8 + Date: + - Thu, 19 Dec 2024 01:00:47 GMT + Server: + - cloudflare + X-Content-Type-Options: + - nosniff + alt-svc: + - h3=":443"; ma=86400 + strict-transport-security: + - max-age=31536000; includeSubDomains; preload + vary: + - Origin + x-request-id: + - req_b662a40a4eed6b1293a08976609311e8 + status: + code: 401 + message: Unauthorized +- request: + body: '{"messages": [{"role": "user", "content": "Hello, how are you?"}], "model": + "gpt-4o"}' + headers: + accept: + - application/json + accept-encoding: + - gzip, deflate + connection: + - keep-alive + content-length: + - '85' + content-type: + - application/json + cookie: + - __cf_bm=Z1WRLdGzN7h0s2NBGBRjVat3bEygGCIzrUj8yy1S01o-1734570046-1.0.1.1-8fphVW32naBztpk7fd1ayypn_LdmuDVBXRA6y44lVoosGqJ4Q4U6eddkhDuiolKTPmYy2CVwAuwePaZnBfZ4dQ; + _cfuvid=3CPZcOj.Dsl.txLCvGozjseBU6qsGQPB_p4jTDHkkSk-1734570046258-0.0.1.1-604800000 + host: + - api.openai.com + user-agent: + - OpenAI/Python 1.55.1 + x-stainless-arch: + - arm64 + x-stainless-async: + - 'false' + x-stainless-lang: + - python + x-stainless-os: + - MacOS + x-stainless-package-version: + - 1.55.1 + x-stainless-retry-count: + - '0' + x-stainless-runtime: + - CPython + x-stainless-runtime-version: + - 3.11.10 + method: POST + uri: https://api.openai.com/v1/chat/completions + response: + body: + string: "{\n \"error\": {\n \"message\": \"Incorrect API key provided: + broken-a**-key. You can find your API key at https://platform.openai.com/account/api-keys.\",\n + \ \"type\": \"invalid_request_error\",\n \"param\": null,\n \"code\": + \"invalid_api_key\"\n }\n}\n" + headers: + CF-Cache-Status: + - DYNAMIC + CF-RAY: + - 8f43752c4b87db1d-MIA + Connection: + - keep-alive + Content-Length: + - '264' + Content-Type: + - application/json; charset=utf-8 + Date: + - Thu, 19 Dec 2024 01:00:47 GMT + Server: + - cloudflare + X-Content-Type-Options: + - nosniff + alt-svc: + - h3=":443"; ma=86400 + strict-transport-security: + - max-age=31536000; includeSubDomains; preload + vary: + - Origin + x-request-id: + - req_a2c04db895a2cf6098c53b60f35813bf + status: + code: 401 + message: Unauthorized +- request: + body: '{"messages": [{"role": "user", "content": "Hello, how are you?"}], "model": + "gpt-4o"}' + headers: + accept: + - application/json + accept-encoding: + - gzip, deflate + connection: + - keep-alive + content-length: + - '85' + content-type: + - application/json + cookie: + - __cf_bm=Z1WRLdGzN7h0s2NBGBRjVat3bEygGCIzrUj8yy1S01o-1734570046-1.0.1.1-8fphVW32naBztpk7fd1ayypn_LdmuDVBXRA6y44lVoosGqJ4Q4U6eddkhDuiolKTPmYy2CVwAuwePaZnBfZ4dQ; + _cfuvid=3CPZcOj.Dsl.txLCvGozjseBU6qsGQPB_p4jTDHkkSk-1734570046258-0.0.1.1-604800000 + host: + - api.openai.com + user-agent: + - OpenAI/Python 1.55.1 + x-stainless-arch: + - arm64 + x-stainless-async: + - 'false' + x-stainless-lang: + - python + x-stainless-os: + - MacOS + x-stainless-package-version: + - 1.55.1 + x-stainless-retry-count: + - '0' + x-stainless-runtime: + - CPython + x-stainless-runtime-version: + - 3.11.10 + method: POST + uri: https://api.openai.com/v1/chat/completions + response: + body: + string: "{\n \"error\": {\n \"message\": \"Incorrect API key provided: + broken-a**-key. You can find your API key at https://platform.openai.com/account/api-keys.\",\n + \ \"type\": \"invalid_request_error\",\n \"param\": null,\n \"code\": + \"invalid_api_key\"\n }\n}\n" + headers: + CF-Cache-Status: + - DYNAMIC + CF-RAY: + - 8f4375399e0cdb1d-MIA + Connection: + - keep-alive + Content-Length: + - '264' + Content-Type: + - application/json; charset=utf-8 + Date: + - Thu, 19 Dec 2024 01:00:49 GMT + Server: + - cloudflare + X-Content-Type-Options: + - nosniff + alt-svc: + - h3=":443"; ma=86400 + strict-transport-security: + - max-age=31536000; includeSubDomains; preload + vary: + - Origin + x-request-id: + - req_400b1cd4a4a2f49a63fb593bf45fb889 + status: + code: 401 + message: Unauthorized +version: 1 diff --git a/tests/test_toolkit/test_retry.py b/tests/test_toolkit/test_retry.py index cc3f8dc..e7fa772 100644 --- a/tests/test_toolkit/test_retry.py +++ b/tests/test_toolkit/test_retry.py @@ -728,6 +728,63 @@ def test_multi_model_timeout_config(model_messages, api_key): ) +@pytest.mark.vcr +def test_multi_model_backoff_config(model_messages, api_key): + oai_client = OpenAI(api_key="broken-api-key") + models = ["openai/gpt-4o-mini", "openai/gpt-4o"] + backoff = {"openai/gpt-4o-mini": 1.0, "openai/gpt-4o": 2.0} + + with patch.object( + oai_client.chat.completions, + "create", + wraps=oai_client.chat.completions.create, + ) as mock_create: + wrapper = RetryWrapper( + client=oai_client, + models=models, + max_retries=2, + timeout=60.0, + backoff=backoff, + model_messages=model_messages, + api_key=api_key, + ) + + with patch.object( + wrapper, + "get_backoff", + wraps=wrapper.get_backoff, + ) as mock_get_backoff: + _ = RetryManager(models, [wrapper]) + + with pytest.raises(AuthenticationError): + oai_client.chat.completions.create( + model="gpt-4o-mini", + messages=model_messages["gpt-4o-mini"], + ) + + mock_create.assert_has_calls( + [ + call( + model="gpt-4o-mini", + messages=model_messages["gpt-4o-mini"], + timeout=60.0, + ), + call( + model="gpt-4o", + messages=model_messages["gpt-4o"], + timeout=60.0, + ), + ] + ) + assert mock_get_backoff.call_count == 2 + mock_get_backoff.assert_has_calls( + [ + call("gpt-4o-mini"), + call("gpt-4o"), + ] + ) + + @pytest.mark.asyncio @pytest.mark.vcr @pytest.mark.timeout(10)