Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(proxy_server.py): support model access groups #1483

Merged
merged 3 commits into from
Jan 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions litellm/proxy/_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,9 +307,12 @@ class LiteLLM_UserTable(LiteLLMBase):
max_budget: Optional[float]
spend: float = 0.0
user_email: Optional[str]
models: list = []

@root_validator(pre=True)
def set_model_info(cls, values):
if values.get("spend") is None:
values.update({"spend": 0.0})
if values.get("models") is None:
values.update({"models", []})
return values
2 changes: 1 addition & 1 deletion litellm/proxy/db/dynamo_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ async def insert_data(
if isinstance(v, datetime):
value[k] = v.isoformat()

await table.put_item(item=value)
return await table.put_item(item=value, return_values=ReturnValues.all_old)

async def get_data(self, key: str, table_name: Literal["user", "key", "config"]):
from aiodynamo.client import Client
Expand Down
28 changes: 15 additions & 13 deletions litellm/proxy/proxy_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -405,19 +405,21 @@ def _make_openai_completion():
is_prisma_runnable = False

if is_prisma_runnable:
# run prisma db push, before starting server
# Save the current working directory
original_dir = os.getcwd()
# set the working directory to where this script is
abspath = os.path.abspath(__file__)
dname = os.path.dirname(abspath)
os.chdir(dname)
try:
subprocess.run(
["prisma", "db", "push", "--accept-data-loss"]
) # this looks like a weird edge case when prisma just wont start on render. we need to have the --accept-data-loss
finally:
os.chdir(original_dir)
for _ in range(4):
# run prisma db push, before starting server
# Save the current working directory
original_dir = os.getcwd()
# set the working directory to where this script is
abspath = os.path.abspath(__file__)
dname = os.path.dirname(abspath)
os.chdir(dname)
try:
subprocess.run(["prisma", "db", "push", "--accept-data-loss"])
break # Exit the loop if the subprocess succeeds
except subprocess.CalledProcessError as e:
print(f"Error: {e}")
finally:
os.chdir(original_dir)
else:
print(
f"Unable to connect to DB. DATABASE_URL found in environment, but prisma package not found."
Expand Down
53 changes: 47 additions & 6 deletions litellm/proxy/proxy_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,7 +325,28 @@ async def user_api_key_auth(
model = data.get("model", None)
if model in litellm.model_alias_map:
model = litellm.model_alias_map[model]
if model and model not in valid_token.models:

## check if model in allowed model names
verbose_proxy_logger.debug(
f"LLM Model List pre access group check: {llm_model_list}"
)
access_groups = []
for m in llm_model_list:
for group in m.get("model_info", {}).get("access_groups", []):
access_groups.append((m["model_name"], group))

allowed_models = valid_token.models
if (
len(access_groups) > 0
): # check if token contains any model access groups
for m in valid_token.models:
for model_name, group in access_groups:
if m == group:
allowed_models.append(model_name)
verbose_proxy_logger.debug(
f"model: {model}; allowed_models: {allowed_models}"
)
if model is not None and model not in allowed_models:
raise ValueError(
f"API Key not allowed to access model. This token can only access models={valid_token.models}. Tried to access {model}"
)
Expand Down Expand Up @@ -1057,6 +1078,7 @@ def _duration_in_seconds(duration: str):
"user_email": user_email,
"user_id": user_id,
"spend": spend,
"models": models,
}
key_data = {
"token": token,
Expand All @@ -1070,14 +1092,33 @@ def _duration_in_seconds(duration: str):
"metadata": metadata_json,
}
if prisma_client is not None:
verification_token_data = dict(key_data)
verification_token_data.update(user_data)
verbose_proxy_logger.debug("PrismaClient: Before Insert Data")
await prisma_client.insert_data(data=verification_token_data)
## CREATE USER (If necessary)
verbose_proxy_logger.debug(f"CustomDBClient: Creating User={user_data}")
user_row = await prisma_client.insert_data(
data=user_data, table_name="user"
)

## use default user model list if no key-specific model list provided
if len(user_row.models) > 0 and len(key_data["models"]) == 0: # type: ignore
key_data["models"] = user_row.models
## CREATE KEY
verbose_proxy_logger.debug(f"CustomDBClient: Creating Key={key_data}")
await prisma_client.insert_data(data=key_data, table_name="key")
elif custom_db_client is not None:
## CREATE USER (If necessary)
verbose_proxy_logger.debug(f"CustomDBClient: Creating User={user_data}")
await custom_db_client.insert_data(value=user_data, table_name="user")
user_row = await custom_db_client.insert_data(
value=user_data, table_name="user"
)
if user_row is None:
# GET USER ROW
user_row = await custom_db_client.get_data(
key=user_id, table_name="user"
)

## use default user model list if no key-specific model list provided
if len(user_row.models) > 0 and len(key_data["models"]) == 0: # type: ignore
key_data["models"] = user_row.models
## CREATE KEY
verbose_proxy_logger.debug(f"CustomDBClient: Creating Key={key_data}")
await custom_db_client.insert_data(value=key_data, table_name="key")
Expand Down
3 changes: 2 additions & 1 deletion litellm/proxy/schema.prisma
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,15 @@ model LiteLLM_UserTable {
max_budget Float?
spend Float @default(0.0)
user_email String?
models String[] @default([])
}

// required for token gen
model LiteLLM_VerificationToken {
token String @unique
spend Float @default(0.0)
expires DateTime?
models String[]
models String[] @default([])
aliases Json @default("{}")
config Json @default("{}")
user_id String?
Expand Down
18 changes: 7 additions & 11 deletions litellm/proxy/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -409,19 +409,17 @@ async def get_data(
on_backoff=on_backoff, # specifying the function to call on backoff
)
async def insert_data(
self, data: dict, table_name: Literal["user+key", "config"] = "user+key"
self, data: dict, table_name: Literal["user", "key", "config"]
):
"""
Add a key to the database. If it already exists, do nothing.
"""
try:
if table_name == "user+key":
if table_name == "key":
token = data["token"]
hashed_token = self.hash_token(token=token)
db_data = self.jsonify_object(data=data)
db_data["token"] = hashed_token
max_budget = db_data.pop("max_budget", None)
user_email = db_data.pop("user_email", None)
print_verbose(
"PrismaClient: Before upsert into litellm_verificationtoken"
)
Expand All @@ -434,19 +432,17 @@ async def insert_data(
"update": {}, # don't do anything if it already exists
},
)

return new_verification_token
elif table_name == "user":
db_data = self.jsonify_object(data=data)
new_user_row = await self.db.litellm_usertable.upsert(
where={"user_id": data["user_id"]},
data={
"create": {
"user_id": data["user_id"],
"max_budget": max_budget,
"user_email": user_email,
},
"create": {**db_data}, # type: ignore
"update": {}, # don't do anything if it already exists
},
)
return new_verification_token
return new_user_row
elif table_name == "config":
"""
For each param,
Expand Down
1 change: 1 addition & 0 deletions schema.prisma
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ model LiteLLM_UserTable {
max_budget Float?
spend Float @default(0.0)
user_email String?
models String[]
}

// required for token gen
Expand Down