Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 17 additions & 9 deletions python/ray/dashboard/modules/state/state_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
from ray.dashboard.subprocesses.module import SubprocessModule
from ray.dashboard.subprocesses.routes import SubprocessRouteTable as routes
from ray.dashboard.subprocesses.utils import ResponseType
from ray.dashboard.utils import RateLimitedModule
from ray.dashboard.utils import HTTPStatusCode, RateLimitedModule
from ray.util.state.common import (
DEFAULT_DOWNLOAD_FILENAME,
DEFAULT_LOG_LIMIT,
Expand Down Expand Up @@ -86,7 +86,7 @@ def __init__(self, *args, **kwargs):

async def limit_handler_(self):
return do_reply(
success=False,
status_code=HTTPStatusCode.TOO_MANY_REQUESTS,
error_message=(
"Max number of in-progress requests="
f"{self.max_num_call_} reached. "
Expand All @@ -110,12 +110,16 @@ async def list_jobs(self, req: aiohttp.web.Request) -> aiohttp.web.Response:
try:
result = await self._state_api.list_jobs(option=options_from_req(req))
return do_reply(
success=True,
status_code=HTTPStatusCode.OK,
error_message="",
result=asdict(result),
)
except DataSourceUnavailable as e:
return do_reply(success=False, error_message=str(e), result=None)
return do_reply(
status_code=HTTPStatusCode.INTERNAL_ERROR,
error_message=str(e),
result=None,
)

@routes.get("/api/v0/nodes")
@RateLimitedModule.enforce_max_concurrent_calls
Expand Down Expand Up @@ -171,7 +175,7 @@ async def list_logs(self, req: aiohttp.web.Request) -> aiohttp.web.Response:

if not node_id and not node_ip:
return do_reply(
success=False,
status_code=HTTPStatusCode.BAD_REQUEST,
error_message=(
"Both node id and node ip are not provided. "
"Please provide at least one of them."
Expand All @@ -182,7 +186,7 @@ async def list_logs(self, req: aiohttp.web.Request) -> aiohttp.web.Response:
node_id = await self._log_api.ip_to_node_id(node_ip)
if not node_id:
return do_reply(
success=False,
status_code=HTTPStatusCode.NOT_FOUND,
error_message=(
f"Cannot find matching node_id for a given node ip {node_ip}"
),
Expand All @@ -195,12 +199,16 @@ async def list_logs(self, req: aiohttp.web.Request) -> aiohttp.web.Response:
)
except DataSourceUnavailable as e:
return do_reply(
success=False,
status_code=HTTPStatusCode.INTERNAL_ERROR,
error_message=str(e),
result=None,
)

return do_reply(success=True, error_message="", result=result)
return do_reply(
status_code=HTTPStatusCode.OK,
error_message="",
result=result,
)

@routes.get("/api/v0/logs/{media_type}", resp_type=ResponseType.STREAM)
@RateLimitedModule.enforce_max_concurrent_calls
Expand Down Expand Up @@ -330,7 +338,7 @@ async def delayed_response(self, req: aiohttp.web.Request):
delay = int(req.match_info.get("delay_s", 10))
await asyncio.sleep(delay)
return do_reply(
success=True,
status_code=HTTPStatusCode.OK,
error_message="",
result={},
partial_failure_warning=None,
Expand Down
22 changes: 16 additions & 6 deletions python/ray/dashboard/state_api_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,11 @@
from ray.util.state.util import convert_string_to_type


def do_reply(success: bool, error_message: str, result: ListApiResponse, **kwargs):
def do_reply(
status_code: HTTPStatusCode, error_message: str, result: ListApiResponse, **kwargs
):
return rest_response(
status_code=HTTPStatusCode.OK if success else HTTPStatusCode.INTERNAL_ERROR,
status_code=status_code,
message=error_message,
result=result,
convert_google_style=False,
Expand All @@ -40,14 +42,22 @@ async def handle_list_api(
try:
result = await list_api_fn(option=options_from_req(req))
return do_reply(
success=True,
status_code=HTTPStatusCode.OK,
error_message="",
result=asdict(result),
)
except ValueError as e:
return do_reply(success=False, error_message=str(e), result=None)
return do_reply(
status_code=HTTPStatusCode.BAD_REQUEST,
error_message=str(e),
result=None,
)
except DataSourceUnavailable as e:
return do_reply(success=False, error_message=str(e), result=None)
return do_reply(
status_code=HTTPStatusCode.INTERNAL_ERROR,
error_message=str(e),
result=None,
)


def _get_filters_from_req(
Expand Down Expand Up @@ -104,7 +114,7 @@ async def handle_summary_api(
):
result = await summary_fn(option=summary_options_from_req(req))
return do_reply(
success=True,
status_code=HTTPStatusCode.OK,
error_message="",
result=asdict(result),
)
Expand Down
2 changes: 2 additions & 0 deletions python/ray/dashboard/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,9 @@ class HTTPStatusCode(IntEnum):
OK = 200

# 4xx Client Errors
BAD_REQUEST = 400
NOT_FOUND = 404
TOO_MANY_REQUESTS = 429

# 5xx Server Errors
INTERNAL_ERROR = 500
Expand Down
55 changes: 54 additions & 1 deletion python/ray/tests/test_state_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import time
from collections import Counter
from concurrent.futures import ThreadPoolExecutor
from typing import List
from typing import List, Optional
from unittest.mock import AsyncMock, MagicMock, patch

import pytest
Expand Down Expand Up @@ -860,6 +860,59 @@ async def test_api_manager_list_workers(state_api_manager):
assert exc_info.value.args[0] == GCS_QUERY_FAILURE_WARNING


@pytest.mark.asyncio
@pytest.mark.parametrize(
("exception", "status_code"),
[
(None, 200),
(ValueError("Invalid filter parameter"), 400),
(DataSourceUnavailable("GCS connection failed"), 500),
],
)
async def test_handle_list_api_status_codes(
exception: Optional[Exception], status_code: int
):
"""Test that handle_list_api calls do_reply with correct status codes.

This directly tests the HTTP layer logic that maps exceptions to status codes:
- Success → HTTP 200 OK
- ValueError → HTTP 400 BAD_REQUEST
- DataSourceUnavailable → HTTP 500 INTERNAL_ERROR
"""
from unittest.mock import AsyncMock, MagicMock

from ray.dashboard.state_api_utils import handle_list_api
from ray.util.state.common import ListApiResponse

# 1. Mock aiohttp request with proper query interface
mock_request = MagicMock()

def mock_get(key, default=None):
return default

mock_request.query = MagicMock()
mock_request.query.get = mock_get

# 2. Mock response whether success or failure.
if exception is None:
mock_backend = AsyncMock(
return_value=ListApiResponse(
result=[],
total=0,
num_after_truncation=0,
num_filtered=0,
partial_failure_warning="",
)
)
else:
mock_backend = AsyncMock(side_effect=exception)

response = await handle_list_api(mock_backend, mock_request)

# 3. Assert status_code is correct.
assert response.status == status_code


@pytest.mark.asyncio
async def test_api_manager_list_tasks(state_api_manager):
data_source_client = state_api_manager.data_source_client
Expand Down