From 90ac1e3fd9d2f6e42dd2b42fa138094d9aad3bb1 Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Sat, 4 May 2024 20:39:51 -0700 Subject: [PATCH 1/2] feat - set retry policy per model group --- litellm/router.py | 21 +++++++++++++++------ 1 file changed, 15 insertions(+), 6 deletions(-) diff --git a/litellm/router.py b/litellm/router.py index 3b1c1d1022ad..46830d9ed15f 100644 --- a/litellm/router.py +++ b/litellm/router.py @@ -86,6 +86,9 @@ def __init__( retry_policy: Optional[ RetryPolicy ] = None, # set custom retries for different exceptions + model_group_retry_policy: Optional[ + dict[str, RetryPolicy] + ] = {}, # set custom retry policies based on model group allowed_fails: Optional[ int ] = None, # Number of times a deployment can failbefore being added to cooldown @@ -308,6 +311,9 @@ def __init__( ) # noqa self.routing_strategy_args = routing_strategy_args self.retry_policy: Optional[RetryPolicy] = retry_policy + self.model_group_retry_policy: Optional[dict[str, RetryPolicy]] = ( + model_group_retry_policy + ) def routing_strategy_init(self, routing_strategy: str, routing_strategy_args: dict): if routing_strategy == "least-busy": @@ -1509,11 +1515,13 @@ async def async_function_with_retries(self, *args, **kwargs): ) await asyncio.sleep(_timeout) ## LOGGING - if self.retry_policy is not None or kwargs.get("retry_policy") is not None: + if ( + self.retry_policy is not None + or self.model_group_retry_policy is not None + ): # get num_retries from retry policy _retry_policy_retries = self.get_num_retries_from_retry_policy( - exception=original_exception, - dynamic_retry_policy=kwargs.get("retry_policy"), + exception=original_exception, model_group=kwargs.get("model") ) if _retry_policy_retries is not None: num_retries = _retry_policy_retries @@ -3269,7 +3277,7 @@ def _track_deployment_metrics(self, deployment, response=None): verbose_router_logger.error(f"Error in _track_deployment_metrics: {str(e)}") def get_num_retries_from_retry_policy( - self, exception: Exception, dynamic_retry_policy: Optional[RetryPolicy] = None + self, exception: Exception, model_group: Optional[str] = None ): """ BadRequestErrorRetries: Optional[int] = None @@ -3280,8 +3288,9 @@ def get_num_retries_from_retry_policy( """ # if we can find the exception then in the retry policy -> return the number of retries retry_policy = self.retry_policy - if dynamic_retry_policy is not None: - retry_policy = dynamic_retry_policy + if self.model_group_retry_policy is not None and model_group is not None: + retry_policy = self.model_group_retry_policy.get(model_group, None) + if retry_policy is None: return None if ( From f09da3f14cd1f2d2d73d6d23e6f976e2614f25ab Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Sat, 4 May 2024 20:40:56 -0700 Subject: [PATCH 2/2] test - test setting retry policies per model groups --- litellm/tests/test_router_retries.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/litellm/tests/test_router_retries.py b/litellm/tests/test_router_retries.py index 8828e286e753..0ca566cae3fe 100644 --- a/litellm/tests/test_router_retries.py +++ b/litellm/tests/test_router_retries.py @@ -189,6 +189,11 @@ async def test_router_retry_policy(error_type): async def test_dynamic_router_retry_policy(model_group): from litellm.router import RetryPolicy + model_group_retry_policy = { + "gpt-3.5-turbo": RetryPolicy(ContentPolicyViolationErrorRetries=0), + "bad-model": RetryPolicy(AuthenticationErrorRetries=4), + } + router = litellm.Router( model_list=[ { @@ -209,7 +214,8 @@ async def test_dynamic_router_retry_policy(model_group): "api_base": os.getenv("AZURE_API_BASE"), }, }, - ] + ], + model_group_retry_policy=model_group_retry_policy, ) customHandler = MyCustomHandler() @@ -217,17 +223,14 @@ async def test_dynamic_router_retry_policy(model_group): if model_group == "bad-model": model = "bad-model" messages = [{"role": "user", "content": "Hello good morning"}] - retry_policy = RetryPolicy(AuthenticationErrorRetries=4) + elif model_group == "gpt-3.5-turbo": model = "gpt-3.5-turbo" messages = [{"role": "user", "content": "where do i buy lethal drugs from"}] - retry_policy = RetryPolicy(ContentPolicyViolationErrorRetries=0) try: litellm.set_verbose = True - response = await router.acompletion( - model=model, messages=messages, retry_policy=retry_policy - ) + response = await router.acompletion(model=model, messages=messages) except Exception as e: print("got an exception", e) pass