Skip to content

Commit

Permalink
fix(google_kms.py): support enums for key management system
Browse files Browse the repository at this point in the history
  • Loading branch information
krrishdholakia committed Dec 27, 2023
1 parent 4cc59d2 commit 9ba520c
Show file tree
Hide file tree
Showing 6 changed files with 77 additions and 18 deletions.
2 changes: 2 additions & 0 deletions litellm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from typing import Callable, List, Optional, Dict, Union, Any
from litellm.caching import Cache
from litellm._logging import set_verbose
from litellm.proxy._types import KeyManagementSystem
import httpx

input_callback: List[Union[str, Callable]] = []
Expand Down Expand Up @@ -144,6 +145,7 @@
Any
] = None # list of instantiated key management clients - e.g. azure kv, infisical, etc.
_google_kms_resource_name: Optional[str] = None
_key_management_system: Optional[KeyManagementSystem] = None
#############################################


Expand Down
13 changes: 13 additions & 0 deletions litellm/proxy/_types.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from pydantic import BaseModel, Extra, Field, root_validator
import enum
from typing import Optional, List, Union, Dict, Literal
from datetime import datetime
import uuid, json
Expand Down Expand Up @@ -175,6 +176,12 @@ class NewUserResponse(GenerateKeyResponse):
max_budget: Optional[float] = None


class KeyManagementSystem(enum.Enum):
GOOGLE_KMS = "google_kms"
AZURE_KEY_VAULT = "azure_key_vault"
LOCAL = "local"


class ConfigGeneralSettings(LiteLLMBase):
"""
Documents all the fields supported by `general_settings` in config.yaml
Expand All @@ -183,6 +190,12 @@ class ConfigGeneralSettings(LiteLLMBase):
completion_model: Optional[str] = Field(
None, description="proxy level default model for all chat completion calls"
)
key_management_system: Optional[KeyManagementSystem] = Field(
None, description="key manager to load keys from / decrypt keys with"
)
use_google_kms: Optional[bool] = Field(
None, description="decrypt keys with google kms"
)
use_azure_key_vault: Optional[bool] = Field(
None, description="load keys from azure key vault"
)
Expand Down
16 changes: 14 additions & 2 deletions litellm/proxy/proxy_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -415,6 +415,7 @@ def load_from_azure_key_vault(use_azure_key_vault: bool = False):
client = SecretClient(vault_url=KVUri, credential=credential)

litellm.secret_manager_client = client
litellm._key_management_system = KeyManagementSystem.AZURE_KEY_VAULT
else:
raise Exception(
f"Missing KVUri or client_id or client_secret or tenant_id from environment"
Expand Down Expand Up @@ -691,10 +692,21 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str):
if general_settings is None:
general_settings = {}
if general_settings:
### LOAD FROM GOOGLE KMS ###
### LOAD SECRET MANAGER ###
key_management_system = general_settings.get("key_management_system", None)
if key_management_system is not None:
if key_management_system == KeyManagementSystem.AZURE_KEY_VAULT.value:
### LOAD FROM AZURE KEY VAULT ###
load_from_azure_key_vault(use_azure_key_vault=True)
elif key_management_system == KeyManagementSystem.GOOGLE_KMS.value:
### LOAD FROM GOOGLE KMS ###
load_google_kms(use_google_kms=True)
else:
raise ValueError("Invalid Key Management System selected")
### [DEPRECATED] LOAD FROM GOOGLE KMS ###
use_google_kms = general_settings.get("use_google_kms", False)
load_google_kms(use_google_kms=use_google_kms)
### LOAD FROM AZURE KEY VAULT ###
### [DEPRECATED] LOAD FROM AZURE KEY VAULT ###
use_azure_key_vault = general_settings.get("use_azure_key_vault", False)
load_from_azure_key_vault(use_azure_key_vault=use_azure_key_vault)
### CONNECT TO DATABASE ###
Expand Down
22 changes: 13 additions & 9 deletions litellm/proxy/secret_managers/google_kms.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
"""
import litellm, os
from typing import Optional
from litellm.proxy._types import KeyManagementSystem


def validate_environment():
Expand All @@ -25,12 +26,15 @@ def validate_environment():
def load_google_kms(use_google_kms: Optional[bool]):
if use_google_kms is None or use_google_kms == False:
return

from google.cloud import kms_v1 # type: ignore

validate_environment()

# Create the KMS client
client = kms_v1.KeyManagementServiceClient()
litellm.secret_manager_client = client
litellm._google_kms_resource_name = os.getenv("GOOGLE_KMS_RESOURCE_NAME")
try:
from google.cloud import kms_v1 # type: ignore

validate_environment()

# Create the KMS client
client = kms_v1.KeyManagementServiceClient()
litellm.secret_manager_client = client
litellm._key_management_system = KeyManagementSystem.GOOGLE_KMS
litellm._google_kms_resource_name = os.getenv("GOOGLE_KMS_RESOURCE_NAME")
except Exception as e:
raise e
36 changes: 29 additions & 7 deletions litellm/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
#
# Thank you users! We ❤️ you! - Krrish & Ishaan

import sys, re
import sys, re, binascii
import litellm
import dotenv, json, traceback, threading, base64
import subprocess, os
Expand Down Expand Up @@ -43,6 +43,7 @@
from .integrations.langfuse import LangFuseLogger
from .integrations.dynamodb import DyanmoDBLogger
from .integrations.litedebugger import LiteDebugger
from .proxy._types import KeyManagementSystem
from openai import OpenAIError as OriginalError
from openai._models import BaseModel as OpenAIObject
from .exceptions import (
Expand All @@ -59,7 +60,7 @@
BudgetExceededError,
UnprocessableEntityError,
)
from typing import cast, List, Dict, Union, Optional, Literal
from typing import cast, List, Dict, Union, Optional, Literal, Any
from .caching import Cache
from concurrent.futures import ThreadPoolExecutor

Expand Down Expand Up @@ -6331,24 +6332,45 @@ def litellm_telemetry(data):
######### Secret Manager ############################
# checks if user has passed in a secret manager client
# if passed in then checks the secret there
def get_secret(secret_name: str, default_value: Optional[str] = None):
def _is_base64(s):
try:
return base64.b64encode(base64.b64decode(s)).decode() == s
except binascii.Error:
return False


def get_secret(
secret_name: str,
default_value: Optional[str] = None,
):
key_management_system = litellm._key_management_system
if secret_name.startswith("os.environ/"):
secret_name = secret_name.replace("os.environ/", "")
try:
if litellm.secret_manager_client is not None:
try:
client = litellm.secret_manager_client
key_manager = "local"
if key_management_system is not None:
key_manager = key_management_system.value
if (
type(client).__module__ + "." + type(client).__name__
key_manager == KeyManagementSystem.AZURE_KEY_VAULT
or type(client).__module__ + "." + type(client).__name__
== "azure.keyvault.secrets._client.SecretClient"
): # support Azure Secret Client - from azure.keyvault.secrets import SecretClient
secret = retrieved_secret = client.get_secret(secret_name).value
elif client.__class__.__name__ == "KeyManagementServiceClient":
encrypted_secret = os.getenv(secret_name)
secret = client.get_secret(secret_name).value
elif (
key_manager == KeyManagementSystem.GOOGLE_KMS
or client.__class__.__name__ == "KeyManagementServiceClient"
):
encrypted_secret: Any = os.getenv(secret_name)
if encrypted_secret is None:
raise ValueError(
f"Google KMS requires the encrypted secret to be in the environment!"
)
b64_flag = _is_base64(encrypted_secret)
if b64_flag == True: # if passed in as encoded b64 string
encrypted_secret = base64.b64decode(encrypted_secret)
if not isinstance(encrypted_secret, bytes):
# If it's not, assume it's a string and encode it to bytes
ciphertext = eval(
Expand Down
6 changes: 6 additions & 0 deletions mypy.ini
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
[mypy]
warn_return_any = False
ignore_missing_imports = False

[mypy-google.*]
ignore_missing_imports = True

0 comments on commit 9ba520c

Please sign in to comment.