Skip to content

Commit

Permalink
Merge pull request #3460 from BerriAI/litellm_use_retry_policy_per_mg
Browse files Browse the repository at this point in the history
[Feat] Set a Retry Policy per model group
  • Loading branch information
ishaan-jaff authored May 5, 2024
2 parents ba06565 + f09da3f commit 713e048
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 12 deletions.
21 changes: 15 additions & 6 deletions litellm/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,9 @@ def __init__(
retry_policy: Optional[
RetryPolicy
] = None, # set custom retries for different exceptions
model_group_retry_policy: Optional[
dict[str, RetryPolicy]
] = {}, # set custom retry policies based on model group
allowed_fails: Optional[
int
] = None, # Number of times a deployment can fail before being added to cooldown
Expand Down Expand Up @@ -308,6 +311,9 @@ def __init__(
) # noqa
self.routing_strategy_args = routing_strategy_args
self.retry_policy: Optional[RetryPolicy] = retry_policy
self.model_group_retry_policy: Optional[dict[str, RetryPolicy]] = (
model_group_retry_policy
)

def routing_strategy_init(self, routing_strategy: str, routing_strategy_args: dict):
if routing_strategy == "least-busy":
Expand Down Expand Up @@ -1509,11 +1515,13 @@ async def async_function_with_retries(self, *args, **kwargs):
)
await asyncio.sleep(_timeout)
## LOGGING
if self.retry_policy is not None or kwargs.get("retry_policy") is not None:
if (
self.retry_policy is not None
or self.model_group_retry_policy is not None
):
# get num_retries from retry policy
_retry_policy_retries = self.get_num_retries_from_retry_policy(
exception=original_exception,
dynamic_retry_policy=kwargs.get("retry_policy"),
exception=original_exception, model_group=kwargs.get("model")
)
if _retry_policy_retries is not None:
num_retries = _retry_policy_retries
Expand Down Expand Up @@ -3273,7 +3281,7 @@ def _track_deployment_metrics(self, deployment, response=None):
verbose_router_logger.error(f"Error in _track_deployment_metrics: {str(e)}")

def get_num_retries_from_retry_policy(
self, exception: Exception, dynamic_retry_policy: Optional[RetryPolicy] = None
self, exception: Exception, model_group: Optional[str] = None
):
"""
BadRequestErrorRetries: Optional[int] = None
Expand All @@ -3284,8 +3292,9 @@ def get_num_retries_from_retry_policy(
"""
# if we can find the exception then in the retry policy -> return the number of retries
retry_policy = self.retry_policy
if dynamic_retry_policy is not None:
retry_policy = dynamic_retry_policy
if self.model_group_retry_policy is not None and model_group is not None:
retry_policy = self.model_group_retry_policy.get(model_group, None)

if retry_policy is None:
return None
if (
Expand Down
15 changes: 9 additions & 6 deletions litellm/tests/test_router_retries.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,11 @@ async def test_router_retry_policy(error_type):
async def test_dynamic_router_retry_policy(model_group):
from litellm.router import RetryPolicy

model_group_retry_policy = {
"gpt-3.5-turbo": RetryPolicy(ContentPolicyViolationErrorRetries=0),
"bad-model": RetryPolicy(AuthenticationErrorRetries=4),
}

router = litellm.Router(
model_list=[
{
Expand All @@ -209,25 +214,23 @@ async def test_dynamic_router_retry_policy(model_group):
"api_base": os.getenv("AZURE_API_BASE"),
},
},
]
],
model_group_retry_policy=model_group_retry_policy,
)

customHandler = MyCustomHandler()
litellm.callbacks = [customHandler]
if model_group == "bad-model":
model = "bad-model"
messages = [{"role": "user", "content": "Hello good morning"}]
retry_policy = RetryPolicy(AuthenticationErrorRetries=4)

elif model_group == "gpt-3.5-turbo":
model = "gpt-3.5-turbo"
messages = [{"role": "user", "content": "where do i buy lethal drugs from"}]
retry_policy = RetryPolicy(ContentPolicyViolationErrorRetries=0)

try:
litellm.set_verbose = True
response = await router.acompletion(
model=model, messages=messages, retry_policy=retry_policy
)
response = await router.acompletion(model=model, messages=messages)
except Exception as e:
print("got an exception", e)
pass
Expand Down

0 comments on commit 713e048

Please sign in to comment.