From 98b83fa78041200d9a971032de0168427e43e400 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Wed, 17 Jan 2024 15:45:31 -0800 Subject: [PATCH 1/3] feat(proxy_server.py): support model access groups --- litellm/proxy/_types.py | 3 +++ litellm/proxy/db/dynamo_db.py | 2 +- litellm/proxy/proxy_server.py | 48 ++++++++++++++++++++++++++++++----- litellm/proxy/schema.prisma | 3 ++- litellm/proxy/utils.py | 18 +++++-------- schema.prisma | 1 + 6 files changed, 56 insertions(+), 19 deletions(-) diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py index 039db119cb0b..891e54a47985 100644 --- a/litellm/proxy/_types.py +++ b/litellm/proxy/_types.py @@ -307,9 +307,12 @@ class LiteLLM_UserTable(LiteLLMBase): max_budget: Optional[float] spend: float = 0.0 user_email: Optional[str] + models: list = [] @root_validator(pre=True) def set_model_info(cls, values): if values.get("spend") is None: values.update({"spend": 0.0}) + if values.get("models") is None: + values.update({"models", []}) return values diff --git a/litellm/proxy/db/dynamo_db.py b/litellm/proxy/db/dynamo_db.py index 73e200549764..1ebcf97232b3 100644 --- a/litellm/proxy/db/dynamo_db.py +++ b/litellm/proxy/db/dynamo_db.py @@ -171,7 +171,7 @@ async def insert_data( if isinstance(v, datetime): value[k] = v.isoformat() - await table.put_item(item=value) + return await table.put_item(item=value) async def get_data(self, key: str, table_name: Literal["user", "key", "config"]): from aiodynamo.client import Client diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index 0503d2aad840..0647ef0f753f 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -325,7 +325,28 @@ async def user_api_key_auth( model = data.get("model", None) if model in litellm.model_alias_map: model = litellm.model_alias_map[model] - if model and model not in valid_token.models: + + ## check if model in allowed model names + verbose_proxy_logger.debug( + f"LLM Model List pre access group check: {llm_model_list}" + ) + access_groups = [] + for m in llm_model_list: + for group in m.get("model_info", {}).get("access_groups", []): + access_groups.append((m["model_name"], group)) + + allowed_models = valid_token.models + if ( + len(access_groups) > 0 + ): # check if token contains any model access groups + for m in valid_token.models: + for model_name, group in access_groups: + if m == group: + allowed_models.append(model_name) + verbose_proxy_logger.debug( + f"model: {model}; allowed_models: {allowed_models}" + ) + if model is not None and model not in allowed_models: raise ValueError( f"API Key not allowed to access model. This token can only access models={valid_token.models}. Tried to access {model}" ) @@ -1057,6 +1078,7 @@ def _duration_in_seconds(duration: str): "user_email": user_email, "user_id": user_id, "spend": spend, + "models": models, } key_data = { "token": token, @@ -1070,14 +1092,28 @@ def _duration_in_seconds(duration: str): "metadata": metadata_json, } if prisma_client is not None: - verification_token_data = dict(key_data) - verification_token_data.update(user_data) - verbose_proxy_logger.debug("PrismaClient: Before Insert Data") - await prisma_client.insert_data(data=verification_token_data) + ## CREATE USER (If necessary) + verbose_proxy_logger.debug(f"CustomDBClient: Creating User={user_data}") + user_row = await prisma_client.insert_data( + data=user_data, table_name="user" + ) + + ## use default user model list if no key-specific model list provided + if len(user_row.models) > 0 and len(key_data["models"]) == 0: # type: ignore + key_data["models"] = user_row.models + ## CREATE KEY + verbose_proxy_logger.debug(f"CustomDBClient: Creating Key={key_data}") + await prisma_client.insert_data(data=key_data, table_name="key") elif custom_db_client is not None: ## CREATE USER (If necessary) verbose_proxy_logger.debug(f"CustomDBClient: Creating User={user_data}") - await custom_db_client.insert_data(value=user_data, table_name="user") + user_row = await custom_db_client.insert_data( + value=user_data, table_name="user" + ) + + ## use default user model list if no key-specific model list provided + if len(user_row["models"]) > 0 and len(key_data["models"]) == 0: # type: ignore + key_data["models"] = user_row["models"] ## CREATE KEY verbose_proxy_logger.debug(f"CustomDBClient: Creating Key={key_data}") await custom_db_client.insert_data(value=key_data, table_name="key") diff --git a/litellm/proxy/schema.prisma b/litellm/proxy/schema.prisma index d12cac8f20f3..aa45a8818658 100644 --- a/litellm/proxy/schema.prisma +++ b/litellm/proxy/schema.prisma @@ -12,6 +12,7 @@ model LiteLLM_UserTable { max_budget Float? spend Float @default(0.0) user_email String? + models String[] @default([]) } // required for token gen @@ -19,7 +20,7 @@ model LiteLLM_VerificationToken { token String @unique spend Float @default(0.0) expires DateTime? - models String[] + models String[] @default([]) aliases Json @default("{}") config Json @default("{}") user_id String? diff --git a/litellm/proxy/utils.py b/litellm/proxy/utils.py index 80c05e78da49..3d12eb874bfb 100644 --- a/litellm/proxy/utils.py +++ b/litellm/proxy/utils.py @@ -409,19 +409,17 @@ async def get_data( on_backoff=on_backoff, # specifying the function to call on backoff ) async def insert_data( - self, data: dict, table_name: Literal["user+key", "config"] = "user+key" + self, data: dict, table_name: Literal["user", "key", "config"] ): """ Add a key to the database. If it already exists, do nothing. """ try: - if table_name == "user+key": + if table_name == "key": token = data["token"] hashed_token = self.hash_token(token=token) db_data = self.jsonify_object(data=data) db_data["token"] = hashed_token - max_budget = db_data.pop("max_budget", None) - user_email = db_data.pop("user_email", None) print_verbose( "PrismaClient: Before upsert into litellm_verificationtoken" ) @@ -434,19 +432,17 @@ async def insert_data( "update": {}, # don't do anything if it already exists }, ) - + return new_verification_token + elif table_name == "user": + db_data = self.jsonify_object(data=data) new_user_row = await self.db.litellm_usertable.upsert( where={"user_id": data["user_id"]}, data={ - "create": { - "user_id": data["user_id"], - "max_budget": max_budget, - "user_email": user_email, - }, + "create": {**db_data}, # type: ignore "update": {}, # don't do anything if it already exists }, ) - return new_verification_token + return new_user_row elif table_name == "config": """ For each param, diff --git a/schema.prisma b/schema.prisma index d12cac8f20f3..704ada42c980 100644 --- a/schema.prisma +++ b/schema.prisma @@ -12,6 +12,7 @@ model LiteLLM_UserTable { max_budget Float? spend Float @default(0.0) user_email String? + models String[] } // required for token gen From cff9f7fee694d848cec4f18c8d6b0e418fd4eab7 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Wed, 17 Jan 2024 17:28:23 -0800 Subject: [PATCH 2/3] fix(proxy_server.py): handle empty insert_data response --- litellm/proxy/db/dynamo_db.py | 2 +- litellm/proxy/proxy_server.py | 9 +++++++-- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/litellm/proxy/db/dynamo_db.py b/litellm/proxy/db/dynamo_db.py index 1ebcf97232b3..eb1c0852861d 100644 --- a/litellm/proxy/db/dynamo_db.py +++ b/litellm/proxy/db/dynamo_db.py @@ -171,7 +171,7 @@ async def insert_data( if isinstance(v, datetime): value[k] = v.isoformat() - return await table.put_item(item=value) + return await table.put_item(item=value, return_values=ReturnValues.all_old) async def get_data(self, key: str, table_name: Literal["user", "key", "config"]): from aiodynamo.client import Client diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index 0647ef0f753f..6d74607b9bdd 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -1110,10 +1110,15 @@ def _duration_in_seconds(duration: str): user_row = await custom_db_client.insert_data( value=user_data, table_name="user" ) + if user_row is None: + # GET USER ROW + user_row = await custom_db_client.get_data( + key=user_id, table_name="user" + ) ## use default user model list if no key-specific model list provided - if len(user_row["models"]) > 0 and len(key_data["models"]) == 0: # type: ignore - key_data["models"] = user_row["models"] + if len(user_row.models) > 0 and len(key_data["models"]) == 0: # type: ignore + key_data["models"] = user_row.models ## CREATE KEY verbose_proxy_logger.debug(f"CustomDBClient: Creating Key={key_data}") await custom_db_client.insert_data(value=key_data, table_name="key") From 73daee7e07d2257257516cfabab2507396162db0 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Wed, 17 Jan 2024 17:37:59 -0800 Subject: [PATCH 3/3] fix(proxy_cli.py): ensure proxy always retries if db push fails to connect to db --- litellm/proxy/proxy_cli.py | 28 +++++++++++++++------------- 1 file changed, 15 insertions(+), 13 deletions(-) diff --git a/litellm/proxy/proxy_cli.py b/litellm/proxy/proxy_cli.py index 918e2ecebe28..19c8e1b7e12f 100644 --- a/litellm/proxy/proxy_cli.py +++ b/litellm/proxy/proxy_cli.py @@ -405,19 +405,21 @@ def _make_openai_completion(): is_prisma_runnable = False if is_prisma_runnable: - # run prisma db push, before starting server - # Save the current working directory - original_dir = os.getcwd() - # set the working directory to where this script is - abspath = os.path.abspath(__file__) - dname = os.path.dirname(abspath) - os.chdir(dname) - try: - subprocess.run( - ["prisma", "db", "push", "--accept-data-loss"] - ) # this looks like a weird edge case when prisma just wont start on render. we need to have the --accept-data-loss - finally: - os.chdir(original_dir) + for _ in range(4): + # run prisma db push, before starting server + # Save the current working directory + original_dir = os.getcwd() + # set the working directory to where this script is + abspath = os.path.abspath(__file__) + dname = os.path.dirname(abspath) + os.chdir(dname) + try: + subprocess.run(["prisma", "db", "push", "--accept-data-loss"]) + break # Exit the loop if the subprocess succeeds + except subprocess.CalledProcessError as e: + print(f"Error: {e}") + finally: + os.chdir(original_dir) else: print( f"Unable to connect to DB. DATABASE_URL found in environment, but prisma package not found."