feat(proxy_server.py): save abbreviated key name if allow_user_auth enabled #1642

Merged · 4 commits · Jan 27, 2024
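This PR teaches the proxy to persist a display-safe, abbreviated key name (`sk-...` plus the key's last four characters) whenever `allow_user_auth` is enabled, and adds an optional `key_alias` field to key generation. A minimal sketch of the resulting end-to-end behavior, using the same helpers the new tests call (a configured `prisma_client`, master key, and async context are assumed, as in the test fixtures):

```python
import litellm.proxy.proxy_server as proxy_server
from litellm.proxy._types import GenerateKeyRequest
from litellm.proxy.proxy_server import generate_key_fn, info_key_fn


async def show_masked_key_name() -> None:
    # Assumes prisma_client/master_key are already set up, as in the tests below.
    proxy_server.general_settings = {"allow_user_auth": True}

    key = await generate_key_fn(GenerateKeyRequest())
    info = await info_key_fn(key=key.key)
    # With allow_user_auth on, key_name is the masked form, e.g. "sk-...1234";
    # with it off, key_name stays None.
    print(info["info"]["key_name"])
```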
3 changes: 3 additions & 0 deletions litellm/proxy/_types.py
@@ -140,6 +140,7 @@ class GenerateRequestBase(LiteLLMBase):


class GenerateKeyRequest(GenerateRequestBase):
    key_alias: Optional[str] = None
    duration: Optional[str] = "1h"
    aliases: Optional[dict] = {}
    config: Optional[dict] = {}
@@ -304,6 +305,8 @@ class Config:

class LiteLLM_VerificationToken(LiteLLMBase):
    token: str
    key_name: Optional[str] = None
    key_alias: Optional[str] = None
    spend: float = 0.0
    max_budget: Optional[float] = None
    expires: Union[str, None]
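A quick sketch of the new request field in isolation (the alias value is illustrative; `GenerateKeyRequest` is the real model from this diff):

```python
from litellm.proxy._types import GenerateKeyRequest

# key_alias is the new optional, human-readable label; it defaults to None
# and rides alongside the existing duration/aliases/config fields.
req = GenerateKeyRequest(key_alias="docs-bot", duration="24h")
print(req.key_alias)  # -> docs-bot
```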
6 changes: 6 additions & 0 deletions litellm/proxy/proxy_server.py
@@ -243,6 +243,8 @@ async def user_api_key_auth(
        response = await user_custom_auth(request=request, api_key=api_key)
        return UserAPIKeyAuth.model_validate(response)
    ### LITELLM-DEFINED AUTH FUNCTION ###
    if isinstance(api_key, str):
        assert api_key.startswith("sk-")  # prevent token hashes from being used
    if master_key is None:
        if isinstance(api_key, str):
            return UserAPIKeyAuth(api_key=api_key)
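The new assertion guards against replaying stored token hashes as credentials: raw litellm keys always begin with `sk-`, while the hashes persisted in the DB do not. A tiny illustration of the invariant (both strings hypothetical):

```python
def is_raw_key(api_key: str) -> bool:
    # Mirrors the new guard in user_api_key_auth: only "sk-"-prefixed
    # strings are accepted as presentable keys.
    return api_key.startswith("sk-")


assert is_raw_key("sk-AbCdEfGh1234")   # raw key -> allowed through
assert not is_raw_key("a1b2c3d4e5f6")  # token hash -> assertion trips
```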
@@ -1239,6 +1241,7 @@ async def generate_key_helper_fn(
    rpm_limit: Optional[int] = None,
    query_type: Literal["insert_data", "update_data"] = "insert_data",
    update_key_values: Optional[dict] = None,
    key_alias: Optional[str] = None,
):
    global prisma_client, custom_db_client

@@ -1312,6 +1315,7 @@ def _duration_in_seconds(duration: str):
    }
    key_data = {
        "token": token,
        "key_alias": key_alias,
        "expires": expires,
        "models": models,
        "aliases": aliases_json,
@@ -1327,6 +1331,8 @@ def _duration_in_seconds(duration: str):
"budget_duration": key_budget_duration,
"budget_reset_at": key_reset_at,
}
if general_settings.get("allow_user_auth", False) == True:
key_data["key_name"] = f"sk-...{token[-4:]}"
if prisma_client is not None:
## CREATE USER (If necessary)
verbose_proxy_logger.debug(f"prisma_client: Creating User={user_data}")
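The masking itself is plain string slicing on the generated token. A minimal sketch of that branch in isolation, with a hypothetical token value:

```python
token = "sk-AbCdEfGh1234"  # hypothetical generated key
general_settings = {"allow_user_auth": True}

key_data = {"token": token, "key_alias": None}
if general_settings.get("allow_user_auth", False):
    # Keep only the last four characters, so the stored name is safe
    # to surface in /key/info responses.
    key_data["key_name"] = f"sk-...{token[-4:]}"

print(key_data["key_name"])  # -> sk-...1234
```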
2 changes: 2 additions & 0 deletions litellm/proxy/schema.prisma
@@ -24,6 +24,8 @@ model LiteLLM_UserTable {
// required for token gen
model LiteLLM_VerificationToken {
  token     String    @unique
  key_name  String?
  key_alias String?
  spend     Float     @default(0.0)
  expires   DateTime?
  models    String[]
62 changes: 34 additions & 28 deletions litellm/tests/test_configs/test_config_no_auth.yaml
@@ -53,9 +53,9 @@ model_list:
      api_key: os.environ/AZURE_API_KEY
      api_version: 2023-07-01-preview
      model: azure/azure-embedding-model
    model_name: azure-embedding-model
    model_info:
      mode: "embedding"
      mode: embedding
    model_name: azure-embedding-model
  - litellm_params:
      model: gpt-3.5-turbo
    model_info:
@@ -80,43 +80,49 @@ model_list:
      description: this is a test openai model
      id: 9b1ef341-322c-410a-8992-903987fef439
    model_name: test_openai_models
  - model_name: amazon-embeddings
    litellm_params:
      model: "bedrock/amazon.titan-embed-text-v1"
  - litellm_params:
      model: bedrock/amazon.titan-embed-text-v1
    model_info:
      mode: embedding
  - model_name: "GPT-J 6B - Sagemaker Text Embedding (Internal)"
    litellm_params:
      model: "sagemaker/berri-benchmarking-gpt-j-6b-fp16"
    model_name: amazon-embeddings
  - litellm_params:
      model: sagemaker/berri-benchmarking-gpt-j-6b-fp16
    model_info:
      mode: embedding
  - model_name: dall-e-3
    litellm_params:
    model_name: GPT-J 6B - Sagemaker Text Embedding (Internal)
  - litellm_params:
      model: dall-e-3
    model_info:
      mode: image_generation
  - model_name: dall-e-3
    litellm_params:
      model: "azure/dall-e-3-test"
      api_version: "2023-12-01-preview"
      api_base: "os.environ/AZURE_SWEDEN_API_BASE"
      api_key: "os.environ/AZURE_SWEDEN_API_KEY"
    model_name: dall-e-3
  - litellm_params:
      api_base: os.environ/AZURE_SWEDEN_API_BASE
      api_key: os.environ/AZURE_SWEDEN_API_KEY
      api_version: 2023-12-01-preview
      model: azure/dall-e-3-test
    model_info:
      mode: image_generation
  - model_name: dall-e-2
    litellm_params:
      model: "azure/"
      api_version: "2023-06-01-preview"
      api_base: "os.environ/AZURE_API_BASE"
      api_key: "os.environ/AZURE_API_KEY"
    model_name: dall-e-3
  - litellm_params:
      api_base: os.environ/AZURE_API_BASE
      api_key: os.environ/AZURE_API_KEY
      api_version: 2023-06-01-preview
      model: azure/
    model_info:
      mode: image_generation
  - model_name: text-embedding-ada-002
    litellm_params:
    model_name: dall-e-2
  - litellm_params:
      api_base: os.environ/AZURE_API_BASE
      api_key: os.environ/AZURE_API_KEY
      api_version: 2023-07-01-preview
      model: azure/azure-embedding-model
      api_base: "os.environ/AZURE_API_BASE"
      api_key: "os.environ/AZURE_API_KEY"
      api_version: "2023-07-01-preview"
    model_info:
      base_model: text-embedding-ada-002
      mode: embedding
      base_model: text-embedding-ada-002
    model_name: text-embedding-ada-002
  - litellm_params:
      model: gpt-3.5-turbo
    model_info:
      description: this is a test openai model
      id: 34cb2419-7c63-44ae-a189-53f1d1ce5953
    model_name: test_openai_models
48 changes: 48 additions & 0 deletions litellm/tests/test_key_generate_prisma.py
@@ -12,6 +12,8 @@
# 11. Generate a Key, call key/info, call key/update, call key/info
# 12. Make a call with key over budget, expect to fail
# 14. Make a streaming chat/completions call with key over budget, expect to fail
# 15. Generate key, when `allow_user_auth`=False - check if `/key/info` returns key_name=null
# 16. Generate key, when `allow_user_auth`=True - check if `/key/info` returns key_name=sk...<last-4-digits>


# function to call to generate key - async def new_user(data: NewUserRequest):
@@ -86,6 +88,7 @@ def prisma_client():
    litellm.proxy.proxy_server.litellm_proxy_budget_name = (
        f"litellm-proxy-budget-{time.time()}"
    )
    litellm.proxy.proxy_server.user_custom_key_generate = None

    return prisma_client

@@ -1140,3 +1143,48 @@ async def test_view_spend_per_key(prisma_client):
    except Exception as e:
        print("Got Exception", e)
        pytest.fail(f"Got exception {e}")


@pytest.mark.asyncio()
async def test_key_name_null(prisma_client):
    """
    - create key
    - get key info
    - assert key_name is null
    """
    setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
    setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
    await litellm.proxy.proxy_server.prisma_client.connect()
    try:
        request = GenerateKeyRequest()
        key = await generate_key_fn(request)
        generated_key = key.key
        result = await info_key_fn(key=generated_key)
        print("result from info_key_fn", result)
        assert result["info"]["key_name"] is None
    except Exception as e:
        print("Got Exception", e)
        pytest.fail(f"Got exception {e}")


@pytest.mark.asyncio()
async def test_key_name_set(prisma_client):
    """
    - create key
    - get key info
    - assert key_name is not null
    """
    setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
    setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
    setattr(litellm.proxy.proxy_server, "general_settings", {"allow_user_auth": True})
    await litellm.proxy.proxy_server.prisma_client.connect()
    try:
        request = GenerateKeyRequest()
        key = await generate_key_fn(request)
        generated_key = key.key
        result = await info_key_fn(key=generated_key)
        print("result from info_key_fn", result)
        assert isinstance(result["info"]["key_name"], str)
    except Exception as e:
        print("Got Exception", e)
        pytest.fail(f"Got exception {e}")
2 changes: 1 addition & 1 deletion litellm/tests/test_proxy_pass_user_config.py
@@ -32,7 +32,7 @@
) # Replace with the actual module where your FastAPI router is defined

# Your bearer token
token = ""
token = "sk-1234"

headers = {"Authorization": f"Bearer {token}"}

2 changes: 1 addition & 1 deletion litellm/tests/test_proxy_server.py
@@ -31,7 +31,7 @@
) # Replace with the actual module where your FastAPI router is defined

# Your bearer token
token = ""
token = "sk-1234"

headers = {"Authorization": f"Bearer {token}"}

2 changes: 1 addition & 1 deletion litellm/tests/test_proxy_server_caching.py
@@ -33,7 +33,7 @@
) # Replace with the actual module where your FastAPI router is defined

# Your bearer token
token = ""
token = "sk-1234"

headers = {"Authorization": f"Bearer {token}"}

2 changes: 2 additions & 0 deletions schema.prisma
@@ -25,6 +25,8 @@ model LiteLLM_UserTable {
// Generate Tokens for Proxy
model LiteLLM_VerificationToken {
  token     String    @unique
  key_name  String?
  key_alias String?
  spend     Float     @default(0.0)
  expires   DateTime?
  models    String[]