Skip to content

Commit

Permalink
/health endpoint for K8 Liveness and Readiness probes (#3855)
Browse files Browse the repository at this point in the history
* Added API Endpoint

* Added API Endpoint

* Added Unit Tests

* Added Unit Tests

* main

* Apply suggestions from Code Review

* Fix Ruff Formatting

* Update Socket Events

* Async Functions
  • Loading branch information
samarth9201 authored Sep 4, 2024
1 parent 15a9f0a commit 5904730
Show file tree
Hide file tree
Showing 5 changed files with 189 additions and 2 deletions.
37 changes: 35 additions & 2 deletions reflex/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@

from fastapi import FastAPI, HTTPException, Request, UploadFile
from fastapi.middleware import cors
from fastapi.responses import StreamingResponse
from fastapi.responses import JSONResponse, StreamingResponse
from fastapi.staticfiles import StaticFiles
from rich.progress import MofNCompleteColumn, Progress, TimeElapsedColumn
from socketio import ASGIApp, AsyncNamespace, AsyncServer
Expand Down Expand Up @@ -65,7 +65,7 @@
from reflex.components.radix import themes
from reflex.config import get_config
from reflex.event import Event, EventHandler, EventSpec, window_alert
from reflex.model import Model
from reflex.model import Model, get_db_status
from reflex.page import (
DECORATED_PAGES,
)
Expand Down Expand Up @@ -377,6 +377,7 @@ def _add_default_endpoints(self):
"""Add default api endpoints (ping)."""
# To test the server.
self.api.get(str(constants.Endpoint.PING))(ping)
self.api.get(str(constants.Endpoint.HEALTH))(health)

def _add_optional_endpoints(self):
"""Add optional api endpoints (_upload)."""
Expand Down Expand Up @@ -1319,6 +1320,38 @@ async def ping() -> str:
return "pong"


async def health() -> JSONResponse:
"""Health check endpoint to assess the status of the database and Redis services.
Returns:
JSONResponse: A JSON object with the health status:
- "status" (bool): Overall health, True if all checks pass.
- "db" (bool or str): Database status - True, False, or "NA".
- "redis" (bool or str): Redis status - True, False, or "NA".
"""
health_status = {"status": True}
status_code = 200

db_status, redis_status = await asyncio.gather(
get_db_status(), prerequisites.get_redis_status()
)

health_status["db"] = db_status

if redis_status is None:
health_status["redis"] = False
else:
health_status["redis"] = redis_status

if not health_status["db"] or (
not health_status["redis"] and redis_status is not None
):
health_status["status"] = False
status_code = 503

return JSONResponse(content=health_status, status_code=status_code)


def upload(app: App):
"""Upload a file.
Expand Down
1 change: 1 addition & 0 deletions reflex/constants/event.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ class Endpoint(Enum):
EVENT = "_event"
UPLOAD = "_upload"
AUTH_CODESPACE = "auth-codespace"
HEALTH = "_health"

def __str__(self) -> str:
"""Get the string representation of the endpoint.
Expand Down
22 changes: 22 additions & 0 deletions reflex/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import alembic.script
import alembic.util
import sqlalchemy
import sqlalchemy.exc
import sqlalchemy.orm

from reflex import constants
Expand Down Expand Up @@ -51,6 +52,27 @@ def get_engine(url: str | None = None) -> sqlalchemy.engine.Engine:
return sqlmodel.create_engine(url, echo=echo_db_query, connect_args=connect_args)


async def get_db_status() -> bool:
"""Checks the status of the database connection.
Attempts to connect to the database and execute a simple query to verify connectivity.
Returns:
bool: The status of the database connection:
- True: The database is accessible.
- False: The database is not accessible.
"""
status = True
try:
engine = get_engine()
with engine.connect() as connection:
connection.execute(sqlalchemy.text("SELECT 1"))
except sqlalchemy.exc.OperationalError:
status = False

return status


SQLModelOrSqlAlchemy = Union[
Type[sqlmodel.SQLModel], Type[sqlalchemy.orm.DeclarativeBase]
]
Expand Down
25 changes: 25 additions & 0 deletions reflex/utils/prerequisites.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from alembic.util.exc import CommandError
from packaging import version
from redis import Redis as RedisSync
from redis import exceptions
from redis.asyncio import Redis

from reflex import constants, model
Expand Down Expand Up @@ -344,6 +345,30 @@ def parse_redis_url() -> str | dict | None:
return dict(host=redis_url, port=int(redis_port), db=0)


async def get_redis_status() -> bool | None:
"""Checks the status of the Redis connection.
Attempts to connect to Redis and send a ping command to verify connectivity.
Returns:
bool or None: The status of the Redis connection:
- True: Redis is accessible and responding.
- False: Redis is not accessible due to a connection error.
- None: Redis not used i.e redis_url is not set in rxconfig.
"""
try:
status = True
redis_client = get_redis_sync()
if redis_client is not None:
redis_client.ping()
else:
status = None
except exceptions.RedisError:
status = False

return status


def validate_app_name(app_name: str | None = None) -> str:
"""Validate the app name.
Expand Down
106 changes: 106 additions & 0 deletions tests/test_health_endpoint.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
import json
from unittest.mock import MagicMock, Mock

import pytest
import sqlalchemy
from redis.exceptions import RedisError

from reflex.app import health
from reflex.model import get_db_status
from reflex.utils.prerequisites import get_redis_status


@pytest.mark.asyncio
@pytest.mark.parametrize(
"mock_redis_client, expected_status",
[
# Case 1: Redis client is available and responds to ping
(Mock(ping=lambda: None), True),
# Case 2: Redis client raises RedisError
(Mock(ping=lambda: (_ for _ in ()).throw(RedisError)), False),
# Case 3: Redis client is not used
(None, None),
],
)
async def test_get_redis_status(mock_redis_client, expected_status, mocker):
# Mock the `get_redis_sync` function to return the mock Redis client
mock_get_redis_sync = mocker.patch(
"reflex.utils.prerequisites.get_redis_sync", return_value=mock_redis_client
)

# Call the function
status = await get_redis_status()

# Verify the result
assert status == expected_status
mock_get_redis_sync.assert_called_once()


@pytest.mark.asyncio
@pytest.mark.parametrize(
"mock_engine, execute_side_effect, expected_status",
[
# Case 1: Database is accessible
(MagicMock(), None, True),
# Case 2: Database connection error (OperationalError)
(
MagicMock(),
sqlalchemy.exc.OperationalError("error", "error", "error"),
False,
),
],
)
async def test_get_db_status(mock_engine, execute_side_effect, expected_status, mocker):
# Mock get_engine to return the mock_engine
mock_get_engine = mocker.patch("reflex.model.get_engine", return_value=mock_engine)

# Mock the connection and its execute method
if mock_engine:
mock_connection = mock_engine.connect.return_value.__enter__.return_value
if execute_side_effect:
# Simulate execute method raising an exception
mock_connection.execute.side_effect = execute_side_effect
else:
# Simulate successful execute call
mock_connection.execute.return_value = None

# Call the function
status = await get_db_status()

# Verify the result
assert status == expected_status
mock_get_engine.assert_called_once()


@pytest.mark.asyncio
@pytest.mark.parametrize(
"db_status, redis_status, expected_status, expected_code",
[
# Case 1: Both services are connected
(True, True, {"status": True, "db": True, "redis": True}, 200),
# Case 2: Database not connected, Redis connected
(False, True, {"status": False, "db": False, "redis": True}, 503),
# Case 3: Database connected, Redis not connected
(True, False, {"status": False, "db": True, "redis": False}, 503),
# Case 4: Both services not connected
(False, False, {"status": False, "db": False, "redis": False}, 503),
# Case 5: Database Connected, Redis not used
(True, None, {"status": True, "db": True, "redis": False}, 200),
],
)
async def test_health(db_status, redis_status, expected_status, expected_code, mocker):
# Mock get_db_status and get_redis_status
mocker.patch("reflex.app.get_db_status", return_value=db_status)
mocker.patch(
"reflex.utils.prerequisites.get_redis_status", return_value=redis_status
)

# Call the async health function
response = await health()

print(json.loads(response.body))
print(expected_status)

# Verify the response content and status code
assert response.status_code == expected_code
assert json.loads(response.body) == expected_status

0 comments on commit 5904730

Please sign in to comment.