Skip to content

Commit

Permalink
Merge pull request #2104 from BerriAI/litellm_fix_updating_keys
Browse files Browse the repository at this point in the history
[Fix] Unexpected Model Deletion in POST /key/update When Updating Team ID
  • Loading branch information
ishaan-jaff authored Feb 21, 2024
2 parents a436592 + 476f401 commit cb2ef26
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 3 deletions.
21 changes: 19 additions & 2 deletions litellm/proxy/proxy_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -3307,7 +3307,15 @@ async def update_key_fn(request: Request, data: UpdateKeyRequest):
if prisma_client is None:
raise Exception("Not connected to DB!")

non_default_values = {k: v for k, v in data_json.items() if v is not None}
# get non default values for key
non_default_values = {}
for k, v in data_json.items():
if v is not None and v not in (
[],
{},
0,
): # models default to [], spend defaults to 0, we should not reset these values
non_default_values[k] = v
response = await prisma_client.update_data(
token=key, data={**non_default_values, "token": key}
)
Expand Down Expand Up @@ -4116,7 +4124,16 @@ async def user_update(data: UpdateUserRequest):
if prisma_client is None:
raise Exception("Not connected to DB!")

non_default_values = {k: v for k, v in data_json.items() if v is not None}
# get non default values for key
non_default_values = {}
for k, v in data_json.items():
if v is not None and v not in (
[],
{},
0,
): # models default to [], spend defaults to 0, we should not reset these values
non_default_values[k] = v

response = await prisma_client.update_data(
user_id=data_json["user_id"],
data=non_default_values,
Expand Down
12 changes: 11 additions & 1 deletion litellm/tests/test_key_generate_prisma.py
Original file line number Diff line number Diff line change
Expand Up @@ -887,14 +887,23 @@ async def test():
request._url = URL(url="/update/key")

# update the key
await update_key_fn(
response1 = await update_key_fn(
request=Request,
data=UpdateKeyRequest(
key=generated_key,
models=["ada", "babbage", "curie", "davinci"],
),
)

print("response1=", response1)

# update the team id
response2 = await update_key_fn(
request=Request,
data=UpdateKeyRequest(key=generated_key, team_id="ishaan"),
)
print("response2=", response2)

# get info on key after update
result = await info_key_fn(key=generated_key)
print("result from info_key_fn", result)
Expand All @@ -906,6 +915,7 @@ async def test():
"project": "litellm-project3",
}
assert result["info"]["models"] == ["ada", "babbage", "curie", "davinci"]
assert result["info"]["team_id"] == "ishaan"

# cleanup - delete key
delete_key_request = KeyRequest(keys=[generated_key])
Expand Down

0 comments on commit cb2ef26

Please sign in to comment.