From bf98d48bbab65811873f13b11ce1a7709554b27a Mon Sep 17 00:00:00 2001 From: mc-marcocheng Date: Wed, 25 Oct 2023 11:19:45 +0800 Subject: [PATCH] Feature: Router aembedding --- litellm/router.py | 15 ++++++++++++++- litellm/tests/test_router.py | 26 ++++++++++++++++++++++++++ 2 files changed, 40 insertions(+), 1 deletion(-) diff --git a/litellm/router.py b/litellm/router.py index 9756e7714aa6..e8eb12b24d3f 100644 --- a/litellm/router.py +++ b/litellm/router.py @@ -112,7 +112,20 @@ def embedding(self, data["caching"] = self.cache_responses # call via litellm.embedding() return litellm.embedding(**{**data, **kwargs}) - + + async def aembedding(self, + model: str, + input: Union[str, List], + is_async: Optional[bool] = True, + **kwargs) -> Union[List[float], None]: + # pick the one that is available (lowest TPM/RPM) + deployment = self.get_available_deployment(model=model, input=input) + + data = deployment["litellm_params"] + data["input"] = input + data["caching"] = self.cache_responses + return await litellm.aembedding(**{**data, **kwargs}) + def set_model_list(self, model_list: list): self.model_list = model_list diff --git a/litellm/tests/test_router.py b/litellm/tests/test_router.py index 1e44b5d3110c..4fb8d6cae340 100644 --- a/litellm/tests/test_router.py +++ b/litellm/tests/test_router.py @@ -181,3 +181,29 @@ async def get_response(): except Exception as e: traceback.print_exc() pytest.fail(f"Error occurred: {e}") + + +def test_aembedding_on_router(): + try: + model_list = [ + { + "model_name": "text-embedding-ada-002", + "litellm_params": { + "model": "text-embedding-ada-002", + }, + "tpm": 100000, + "rpm": 10000, + }, + ] + + async def embedding_call(): + router = Router(model_list=model_list) + response = await router.aembedding( + model="text-embedding-ada-002", + input=["good morning from litellm", "this is another item"], + ) + print(response) + asyncio.run(embedding_call()) + except Exception as e: + traceback.print_exc() + pytest.fail(f"Error occurred: {e}")