Skip to content

Commit

Permalink
(fix) ProxyStartup - Check that prisma connection is healthy when sta…
Browse files Browse the repository at this point in the history
…rting an instance of LiteLLM (#6627)

* fix debug statements

* fix assert prisma_client.health_check is called on _setup

* asser that _setup_prisma_client is called on startup proxy

* fix prisma client health_check

* add test_bad_database_url

* add strict checks on db startup

* temp remove fix to validate if check works as expected

* add health_check back

* test_proxy_server_prisma_setup_invalid_db
  • Loading branch information
ishaan-jaff authored Nov 7, 2024
1 parent 8a2b6fd commit 373f9d4
Show file tree
Hide file tree
Showing 4 changed files with 86 additions and 10 deletions.
42 changes: 42 additions & 0 deletions .circleci/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -986,6 +986,41 @@ jobs:
- store_test_results:
path: test-results

test_bad_database_url:
machine:
image: ubuntu-2204:2023.10.1
resource_class: xlarge
working_directory: ~/project
steps:
- checkout
- run:
name: Build Docker image
command: |
docker build -t myapp . -f ./docker/Dockerfile.non_root
- run:
name: Run Docker container with bad DATABASE_URL
command: |
docker run --name my-app \
-p 4000:4000 \
-e DATABASE_URL="postgresql://wrong:wrong@wrong:5432/wrong" \
myapp:latest \
--port 4000 > docker_output.log 2>&1 || true
- run:
name: Display Docker logs
command: cat docker_output.log
- run:
name: Check for expected error
command: |
if grep -q "Error: P1001: Can't reach database server at" docker_output.log && \
grep -q "httpx.ConnectError: All connection attempts failed" docker_output.log && \
grep -q "ERROR: Application startup failed. Exiting." docker_output.log; then
echo "Expected error found. Test passed."
else
echo "Expected error not found. Test failed."
cat docker_output.log
exit 1
fi
workflows:
version: 2
build_and_test:
Expand Down Expand Up @@ -1082,11 +1117,18 @@ workflows:
only:
- main
- /litellm_.*/
- test_bad_database_url:
filters:
branches:
only:
- main
- /litellm_.*/
- publish_to_pypi:
requires:
- local_testing
- build_and_test
- load_testing
- test_bad_database_url
- llm_translation_testing
- logging_testing
- litellm_router_testing
Expand Down
2 changes: 2 additions & 0 deletions litellm/proxy/proxy_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -3052,6 +3052,8 @@ async def _setup_prisma_client(
prisma_client.check_view_exists()
) # check if all necessary views exist. Don't block execution

# run a health check to ensure the DB is ready
await prisma_client.health_check()
return prisma_client


Expand Down
13 changes: 3 additions & 10 deletions litellm/proxy/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1083,19 +1083,16 @@ def __init__(
proxy_logging_obj: ProxyLogging,
http_client: Optional[Any] = None,
):
verbose_proxy_logger.debug(
"LiteLLM: DATABASE_URL Set in config, trying to 'pip install prisma'"
)
## init logging object
self.proxy_logging_obj = proxy_logging_obj
self.iam_token_db_auth: Optional[bool] = str_to_bool(
os.getenv("IAM_TOKEN_DB_AUTH")
)
verbose_proxy_logger.debug("Creating Prisma Client..")
try:
from prisma import Prisma # type: ignore
except Exception:
raise Exception("Unable to find Prisma binaries.")
verbose_proxy_logger.debug("Connecting Prisma Client to DB..")
if http_client is not None:
self.db = PrismaWrapper(
original_prisma=Prisma(http=http_client),
Expand All @@ -1114,7 +1111,7 @@ def __init__(
else False
),
) # Client to connect to Prisma db
verbose_proxy_logger.debug("Success - Connected Prisma Client to DB")
verbose_proxy_logger.debug("Success - Created Prisma Client")

def hash_token(self, token: str):
# Hash the string using SHA-256
Expand Down Expand Up @@ -2348,11 +2345,7 @@ async def health_check(self):
"""
start_time = time.time()
try:
sql_query = """
SELECT 1
FROM "LiteLLM_VerificationToken"
LIMIT 1
"""
sql_query = "SELECT 1"

# Execute the raw query
# The asterisk before `user_id_list` unpacks the list into separate arguments
Expand Down
39 changes: 39 additions & 0 deletions tests/local_testing/test_proxy_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -1911,6 +1911,7 @@ async def test_proxy_server_prisma_setup():
mock_client = mock_prisma_client.return_value # This is the mocked instance
mock_client.connect = AsyncMock() # Mock the connect method
mock_client.check_view_exists = AsyncMock() # Mock the check_view_exists method
mock_client.health_check = AsyncMock() # Mock the health_check method

await ProxyStartupEvent._setup_prisma_client(
database_url=os.getenv("DATABASE_URL"),
Expand All @@ -1921,3 +1922,41 @@ async def test_proxy_server_prisma_setup():
# Verify our mocked methods were called
mock_client.connect.assert_called_once()
mock_client.check_view_exists.assert_called_once()

# Note: This is REALLY IMPORTANT to check that the health check is called
# This is how we ensure the DB is ready before proceeding
mock_client.health_check.assert_called_once()


@pytest.mark.asyncio
async def test_proxy_server_prisma_setup_invalid_db():
"""
PROD TEST: Test that proxy server startup fails when it's unable to connect to the database
Think 2-3 times before editing / deleting this test, it's important for PROD
"""
from litellm.proxy.proxy_server import ProxyStartupEvent
from litellm.proxy.utils import ProxyLogging
from litellm.caching import DualCache

user_api_key_cache = DualCache()
invalid_db_url = "postgresql://invalid:invalid@localhost:5432/nonexistent"

_old_db_url = os.getenv("DATABASE_URL")
os.environ["DATABASE_URL"] = invalid_db_url

with pytest.raises(Exception) as exc_info:
await ProxyStartupEvent._setup_prisma_client(
database_url=invalid_db_url,
proxy_logging_obj=ProxyLogging(user_api_key_cache=user_api_key_cache),
user_api_key_cache=user_api_key_cache,
)
print("GOT EXCEPTION=", exc_info)

assert "httpx.ConnectError" in str(exc_info.value)

# # Verify the error message indicates a database connection issue
# assert any(x in str(exc_info.value).lower() for x in ["database", "connection", "authentication"])

if _old_db_url:
os.environ["DATABASE_URL"] = _old_db_url

0 comments on commit 373f9d4

Please sign in to comment.