feat(proxy_server.py): allow user to override api key auth
krrishdholakia committed Dec 5, 2023
1 parent 51cddf1 commit 030bd22
Showing 9 changed files with 274 additions and 118 deletions.
14 changes: 14 additions & 0 deletions litellm/proxy/custom_auth.py
@@ -0,0 +1,14 @@
from litellm.proxy.types import UserAPIKeyAuth
from fastapi import Request
from dotenv import load_dotenv
import os

load_dotenv()
async def user_api_key_auth(request: Request, api_key: str) -> UserAPIKeyAuth:
    try:
        modified_master_key = f"{os.getenv('PROXY_MASTER_KEY')}-1234"
        if api_key == modified_master_key:
            return UserAPIKeyAuth(api_key=api_key)
        raise Exception
    except:
        raise Exception
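For illustration, a minimal sketch (not part of this commit) of exercising the new auth hook directly. It assumes litellm is installed so the module is importable, and that PROXY_MASTER_KEY is set to "sk-1234"; only the suffixed key f"{PROXY_MASTER_KEY}-1234" passes. In the proxy config, this module would be referenced under general_settings as something like `custom_auth: custom_auth.user_api_key_auth` (a module.attribute path; see the get_instance_fn wiring in proxy_server.py below).

    # sketch only: calling the custom auth function directly
    import asyncio
    import os
    from fastapi import Request

    from litellm.proxy.custom_auth import user_api_key_auth

    os.environ["PROXY_MASTER_KEY"] = "sk-1234"  # assumed value for the demo

    async def main():
        # the request object is unused by this particular check,
        # so a bare HTTP scope is enough to construct one
        request = Request(scope={"type": "http"})
        result = await user_api_key_auth(request=request, api_key="sk-1234-1234")
        print(result)  # UserAPIKeyAuth(api_key='sk-1234-1234')

    asyncio.run(main())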
158 changes: 42 additions & 116 deletions litellm/proxy/proxy_server.py
@@ -92,12 +92,16 @@ def generate_feedback_box():

import litellm
from litellm.proxy.utils import (
    PrismaClient
    PrismaClient,
    get_instance_fn
)
import pydantic
from litellm.proxy.types import *
from litellm.caching import DualCache
litellm.suppress_debug_info = True
from fastapi import FastAPI, Request, HTTPException, status, Depends, BackgroundTasks
from fastapi.routing import APIRouter
from fastapi.security import OAuth2PasswordBearer
from fastapi.encoders import jsonable_encoder
from fastapi.responses import StreamingResponse, FileResponse, ORJSONResponse
from fastapi.middleware.cors import CORSMiddleware
@@ -163,70 +167,8 @@ def log_input_output(request, response, custom_logger=None):
    return True

from typing import Dict
from pydantic import BaseModel
######### Request Class Definition ######
class ProxyChatCompletionRequest(BaseModel):
    model: str
    messages: List[Dict[str, str]]
    temperature: Optional[float] = None
    top_p: Optional[float] = None
    n: Optional[int] = None
    stream: Optional[bool] = None
    stop: Optional[List[str]] = None
    max_tokens: Optional[int] = None
    presence_penalty: Optional[float] = None
    frequency_penalty: Optional[float] = None
    logit_bias: Optional[Dict[str, float]] = None
    user: Optional[str] = None
    response_format: Optional[Dict[str, str]] = None
    seed: Optional[int] = None
    tools: Optional[List[str]] = None
    tool_choice: Optional[str] = None
    functions: Optional[List[str]] = None # soon to be deprecated
    function_call: Optional[str] = None # soon to be deprecated

    # Optional LiteLLM params
    caching: Optional[bool] = None
    api_base: Optional[str] = None
    api_version: Optional[str] = None
    api_key: Optional[str] = None
    num_retries: Optional[int] = None
    context_window_fallback_dict: Optional[Dict[str, str]] = None
    fallbacks: Optional[List[str]] = None
    metadata: Optional[Dict[str, str]] = {}
    deployment_id: Optional[str] = None
    request_timeout: Optional[int] = None

    class Config:
        extra='allow' # allow params not defined here, these fall in litellm.completion(**kwargs)

class ModelParams(BaseModel):
    model_name: str
    litellm_params: dict
    model_info: Optional[dict]
    class Config:
        protected_namespaces = ()

class GenerateKeyRequest(BaseModel):
    duration: str = "1h"
    models: list = []
    aliases: dict = {}
    config: dict = {}
    spend: int = 0
    user_id: Optional[str] = None

class GenerateKeyResponse(BaseModel):
    key: str
    expires: datetime
    user_id: str

class _DeleteKeyObject(BaseModel):
    key: str

class DeleteKeyRequest(BaseModel):
    keys: List[_DeleteKeyObject]


oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
user_api_base = None
user_model = None
user_debug = False
@@ -249,6 +191,7 @@ class DeleteKeyRequest(BaseModel):
otel_logging = False
prisma_client: Optional[PrismaClient] = None
user_api_key_cache = DualCache()
user_custom_auth = None
### REDIS QUEUE ###
async_result = None
celery_app_conn = None
@@ -268,31 +211,29 @@ def usage_telemetry(
            target=litellm.utils.litellm_telemetry, args=(data,), daemon=True
        ).start()

api_key_header = APIKeyHeader(name="Authorization", auto_error=False)

async def user_api_key_auth(request: Request, api_key: str = fastapi.Security(api_key_header)):
    global master_key, prisma_client, llm_model_list
    print(f"master_key - {master_key}; api_key - {api_key}")
    if master_key is None:
        if isinstance(api_key, str):
            return {
                "api_key": api_key.replace("Bearer ", "")
            }
        else:
            return {
                "api_key": api_key
            }

async def user_api_key_auth(request: Request, api_key: str = Depends(oauth2_scheme)) -> UserAPIKeyAuth:
    global master_key, prisma_client, llm_model_list, user_custom_auth
    try:
        ### USER-DEFINED AUTH FUNCTION ###
        if user_custom_auth:
            response = await user_custom_auth(request=request, api_key=api_key)
            return UserAPIKeyAuth.model_validate(response)

        if master_key is None:
            if isinstance(api_key, str):
                return UserAPIKeyAuth(api_key=api_key.replace("Bearer ", ""))
            else:
                return UserAPIKeyAuth()
        if api_key is None:
            raise Exception("No api key passed in.")
        route = request.url.path

        # note: never string compare api keys, this is vulnerable to a timing attack. Use secrets.compare_digest instead
        is_master_key_valid = secrets.compare_digest(api_key, master_key) or secrets.compare_digest(api_key, "Bearer " + master_key)
        if is_master_key_valid:
            return {
                "api_key": master_key
            }
            return UserAPIKeyAuth(api_key=master_key)

        if (route == "/key/generate" or route == "/key/delete" or route == "/key/info") and not is_master_key_valid:
            raise Exception(f"If master key is set, only master key can be used to generate, delete or get info for new keys")
@@ -318,7 +259,7 @@ async def user_api_key_auth(request: Request, api_key: str = fastapi.Security(ap
            return_dict = {"api_key": valid_token.token}
            if valid_token.user_id:
                return_dict["user_id"] = valid_token.user_id
            return return_dict
            return UserAPIKeyAuth(**return_dict)
        else:
            data = await request.json()
            model = data.get("model", None)
@@ -329,14 +270,14 @@ async def user_api_key_auth(request: Request, api_key: str = fastapi.Security(ap
            return_dict = {"api_key": valid_token.token}
            if valid_token.user_id:
                return_dict["user_id"] = valid_token.user_id
            return return_dict
            return UserAPIKeyAuth(**return_dict)
        else:
            raise Exception(f"Invalid token")
    except Exception as e:
        print(f"An exception occurred - {traceback.format_exc()}")
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail={"error": "invalid user key"},
            detail="invalid user key",
        )

def prisma_setup(database_url: Optional[str]):
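With the custom auth hook active, any key the user-defined function accepts works against the proxy's endpoints. A hedged client-side sketch of that flow, assuming a proxy running locally on port 8000 and PROXY_MASTER_KEY=sk-1234 (both assumptions, not taken from this diff):

    # sketch only: calling the proxy with the suffixed key that the
    # example custom_auth.py above accepts
    import openai

    client = openai.OpenAI(
        api_key="sk-1234-1234",          # f"{PROXY_MASTER_KEY}-1234" passes the custom check
        base_url="http://0.0.0.0:8000",  # assumed local proxy address
    )

    response = client.chat.completions.create(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "hello"}],
    )
    print(response.choices[0].message.content)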
@@ -464,7 +405,7 @@ def run_ollama_serve():
""")

def load_router_config(router: Optional[litellm.Router], config_file_path: str):
    global master_key, user_config_file_path, otel_logging
    global master_key, user_config_file_path, otel_logging, user_custom_auth
    config = {}
    try:
        if os.path.exists(config_file_path):
@@ -499,7 +440,6 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str):
        ### LOAD FROM AZURE KEY VAULT ###
        use_azure_key_vault = general_settings.get("use_azure_key_vault", False)
        load_from_azure_key_vault(use_azure_key_vault=use_azure_key_vault)

        ### CONNECT TO DATABASE ###
        database_url = general_settings.get("database_url", None)
        if database_url and database_url.startswith("os.environ/"):
@@ -514,12 +454,14 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str):
        master_key = general_settings.get("master_key", None)
        if master_key and master_key.startswith("os.environ/"):
            master_key = litellm.get_secret(master_key)

        #### OpenTelemetry Logging (OTEL) ########
        otel_logging = general_settings.get("otel", False)
        if otel_logging == True:
            print("\nOpenTelemetry Logging Activated")

        ### CUSTOM API KEY AUTH ###
        custom_auth = general_settings.get("custom_auth", None)
        if custom_auth:
            user_custom_auth = get_instance_fn(value=custom_auth, config_file_path=config_file_path)
    ## LITELLM MODULE SETTINGS (e.g. litellm.drop_params=True,..)
    litellm_settings = config.get('litellm_settings', None)
    if litellm_settings:
@@ -549,23 +491,7 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str):
                    password=cache_password
                )
            elif key == "callbacks":
                print(f"{blue_color_code}\nSetting custom callbacks on Proxy")
                passed_module, instance_name = value.split(".")

                # Dynamically import the module
                module = importlib.import_module(passed_module)
                # Get the instance from the module
                instance = getattr(module, instance_name)

                methods = [method for method in dir(instance) if callable(getattr(instance, method))]
                # Print the methods
                print("Methods in the custom callbacks instance:")
                for method in methods:
                    print(method)

                litellm.callbacks = [instance]
                print()

                litellm.callbacks = [get_instance_fn(value=value)]
            else:
                setattr(litellm, key, value)
@@ -844,7 +770,7 @@ def model_list():
@router.post("/v1/completions", dependencies=[Depends(user_api_key_auth)])
@router.post("/completions", dependencies=[Depends(user_api_key_auth)])
@router.post("/engines/{model:path}/completions", dependencies=[Depends(user_api_key_auth)])
async def completion(request: Request, model: Optional[str] = None, user_api_key_dict: dict = Depends(user_api_key_auth)):
async def completion(request: Request, model: Optional[str] = None, user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth)):
    try:
        body = await request.body()
        body_str = body.decode()
@@ -853,7 +779,7 @@ async def completion(request: Request, model: Optional[str] = None, user_api_key
        except:
            data = json.loads(body_str)

        data["user"] = user_api_key_dict.get("user_id", None)
        data["user"] = user_api_key_dict.user_id
        data["model"] = (
            general_settings.get("completion_model", None) # server default
            or user_model # model name passed via cli args
@@ -864,9 +790,9 @@ async def completion(request: Request, model: Optional[str] = None, user_api_key
data["model"] = user_model
data["call_type"] = "text_completion"
if "metadata" in data:
data["metadata"]["user_api_key"] = user_api_key_dict["api_key"]
data["metadata"]["user_api_key"] = user_api_key_dict.api_key
else:
data["metadata"] = {"user_api_key": user_api_key_dict["api_key"]}
data["metadata"] = {"user_api_key": user_api_key_dict.api_key}

return litellm_completion(
**data
@@ -888,7 +814,7 @@ async def completion(request: Request, model: Optional[str] = None, user_api_key
@router.post("/v1/chat/completions", dependencies=[Depends(user_api_key_auth)], tags=["chat/completions"])
@router.post("/chat/completions", dependencies=[Depends(user_api_key_auth)], tags=["chat/completions"])
@router.post("/openai/deployments/{model:path}/chat/completions", dependencies=[Depends(user_api_key_auth)], tags=["chat/completions"]) # azure compatible endpoint
async def chat_completion(request: Request, model: Optional[str] = None, user_api_key_dict: dict = Depends(user_api_key_auth), background_tasks: BackgroundTasks = BackgroundTasks()):
async def chat_completion(request: Request, model: Optional[str] = None, user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), background_tasks: BackgroundTasks = BackgroundTasks()):
    global general_settings, user_debug
    try:
        data = {}
@@ -905,13 +831,13 @@ async def chat_completion(request: Request, model: Optional[str] = None, user_ap
        # users can pass in 'user' param to /chat/completions. Don't override it
        if data.get("user", None) is None:
            # if users are using user_api_key_auth, set `user` in `data`
            data["user"] = user_api_key_dict.get("user_id", None)
            data["user"] = user_api_key_dict.user_id

        if "metadata" in data:
            data["metadata"]["user_api_key"] = user_api_key_dict["api_key"]
            data["metadata"]["user_api_key"] = user_api_key_dict.api_key
            data["metadata"]["headers"] = request.headers
        else:
            data["metadata"] = {"user_api_key": user_api_key_dict["api_key"]}
            data["metadata"] = {"user_api_key": user_api_key_dict.api_key}
            data["metadata"]["headers"] = request.headers
        global user_temperature, user_request_timeout, user_max_tokens, user_api_base
        # override with user settings, these are params passed via cli
@@ -962,14 +888,14 @@ async def chat_completion(request: Request, model: Optional[str] = None, user_ap

@router.post("/v1/embeddings", dependencies=[Depends(user_api_key_auth)], response_class=ORJSONResponse)
@router.post("/embeddings", dependencies=[Depends(user_api_key_auth)], response_class=ORJSONResponse)
async def embeddings(request: Request, user_api_key_dict: dict = Depends(user_api_key_auth), background_tasks: BackgroundTasks = BackgroundTasks()):
async def embeddings(request: Request, user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), background_tasks: BackgroundTasks = BackgroundTasks()):
    try:

        # Use orjson to parse JSON data, orjson speeds up requests significantly
        body = await request.body()
        data = orjson.loads(body)

        data["user"] = user_api_key_dict.get("user_id", None)
        data["user"] = user_api_key_dict.user_id
        data["model"] = (
            general_settings.get("embedding_model", None) # server default
            or user_model # model name passed via cli args
@@ -978,9 +904,9 @@ async def embeddings(request: Request, user_api_key_dict: dict = Depends(user_ap
        if user_model:
            data["model"] = user_model
        if "metadata" in data:
            data["metadata"]["user_api_key"] = user_api_key_dict["api_key"]
            data["metadata"]["user_api_key"] = user_api_key_dict.api_key
        else:
            data["metadata"] = {"user_api_key": user_api_key_dict["api_key"]}
            data["metadata"] = {"user_api_key": user_api_key_dict.api_key}

        ## ROUTE TO CORRECT ENDPOINT ##
        router_model_names = [m["model_name"] for m in llm_model_list] if llm_model_list is not None else []