diff --git a/python/ray/_private/collections_utils.py b/python/ray/_private/collections_utils.py new file mode 100644 index 000000000000..9b7ff6f3c310 --- /dev/null +++ b/python/ray/_private/collections_utils.py @@ -0,0 +1,10 @@ +from typing import List, Any + + +def split(items: List[Any], chunk_size: int): + """Splits provided list into chunks of given size""" + + assert chunk_size > 0, "Chunk size has to be > 0" + + for i in range(0, len(items), chunk_size): + yield items[i : i + chunk_size] diff --git a/python/ray/dashboard/agent.py b/python/ray/dashboard/agent.py index 48bdad05fa62..11b2d5397473 100644 --- a/python/ray/dashboard/agent.py +++ b/python/ray/dashboard/agent.py @@ -7,7 +7,6 @@ import pathlib import signal import sys -from concurrent.futures import ThreadPoolExecutor import ray import ray._private.ray_constants as ray_constants @@ -49,10 +48,6 @@ def __init__( # Public attributes are accessible for all agent modules. self.ip = node_ip_address self.minimal = minimal - self.thread_pool_executor = ThreadPoolExecutor( - max_workers=dashboard_consts.RAY_AGENT_THREAD_POOL_MAX_WORKERS, - thread_name_prefix="dashboard_agent_tpe", - ) assert gcs_address is not None self.gcs_address = gcs_address diff --git a/python/ray/dashboard/consts.py b/python/ray/dashboard/consts.py index 6ab2b547dda3..0542b1296503 100644 --- a/python/ray/dashboard/consts.py +++ b/python/ray/dashboard/consts.py @@ -26,7 +26,7 @@ "RAY_DASHBOARD_STATS_PURGING_INTERVAL", 60 * 10 ) RAY_DASHBOARD_STATS_UPDATING_INTERVAL = env_integer( - "RAY_DASHBOARD_STATS_UPDATING_INTERVAL", 2 + "RAY_DASHBOARD_STATS_UPDATING_INTERVAL", 15 ) DASHBOARD_RPC_ADDRESS = "dashboard_rpc" DASHBOARD_RPC_PORT = env_integer("RAY_DASHBOARD_RPC_PORT", 0) @@ -49,12 +49,6 @@ # Example: "your.module.ray_cluster_activity_hook". RAY_CLUSTER_ACTIVITY_HOOK = "RAY_CLUSTER_ACTIVITY_HOOK" -# Works in the thread pool should not starve the main thread loop, so we default to 1. -RAY_DASHBOARD_THREAD_POOL_MAX_WORKERS = env_integer( - "RAY_DASHBOARD_THREAD_POOL_MAX_WORKERS", 1 -) -RAY_AGENT_THREAD_POOL_MAX_WORKERS = env_integer("RAY_AGENT_THREAD_POOL_MAX_WORKERS", 1) - # The number of candidate agents CANDIDATE_AGENT_NUMBER = max(env_integer("CANDIDATE_AGENT_NUMBER", 1), 1) # when head receive JobSubmitRequest, maybe not any agent is available, diff --git a/python/ray/dashboard/dashboard_metrics.py b/python/ray/dashboard/dashboard_metrics.py index c1027ec47c53..7f1b6f2b22a9 100644 --- a/python/ray/dashboard/dashboard_metrics.py +++ b/python/ray/dashboard/dashboard_metrics.py @@ -75,6 +75,14 @@ def __init__(self, registry: Optional[CollectorRegistry] = None): namespace="ray", registry=self.registry, ) + self.metrics_event_loop_tasks = Gauge( + "dashboard_event_loop_tasks", + "Number of tasks currently pending in the event loop's queue.", + tuple(COMPONENT_METRICS_TAG_KEYS), + unit="tasks", + namespace="ray", + registry=self.registry, + ) self.metrics_event_loop_lag = Gauge( "dashboard_event_loop_lag", "Event loop lag in seconds.", diff --git a/python/ray/dashboard/datacenter.py b/python/ray/dashboard/datacenter.py index b0c663733f2b..96988a93fbc3 100644 --- a/python/ray/dashboard/datacenter.py +++ b/python/ray/dashboard/datacenter.py @@ -1,4 +1,3 @@ -import asyncio import logging from typing import Any, List, Optional @@ -72,39 +71,47 @@ async def organize(cls, thread_pool_executor): to make sure it's on the main event loop thread. To avoid blocking the main event loop, we yield after each node processed. 
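For reference, the offload-and-yield pattern this docstring describes looks roughly like the minimal sketch below; the node dictionary and merge function are illustrative stand-ins, not the actual DataSource structures.

import asyncio
from concurrent.futures import ThreadPoolExecutor


def merge_node_stats(physical_stats: dict, node_stats: dict) -> dict:
    # Stand-in for the CPU-bound per-node merge work.
    return {**physical_stats, **node_stats}


async def organize(nodes: dict, executor: ThreadPoolExecutor) -> dict:
    loop = asyncio.get_running_loop()
    merged = {}
    # Copy the keys first so concurrent updates to `nodes` don't break iteration.
    for node_id in list(nodes.keys()):
        physical_stats, node_stats = nodes[node_id]
        # run_in_executor moves the blocking merge off the event loop, and because
        # it is awaited, control is yielded back to the loop once per node.
        merged[node_id] = await loop.run_in_executor(
            executor, merge_node_stats, physical_stats, node_stats
        )
    return merged


# asyncio.run(organize({"n1": ({"cpu": 1}, {"mem": 2})}, ThreadPoolExecutor(max_workers=1)))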
""" + loop = get_or_create_event_loop() + node_workers = {} core_worker_stats = {} - # nodes may change during process, so we create a copy of keys(). + + # NOTE: We copy keys of the `DataSource.nodes` to make sure + # it doesn't change during the iteration (since its being updated + # from another async task) for node_id in list(DataSource.nodes.keys()): node_physical_stats = DataSource.node_physical_stats.get(node_id, {}) node_stats = DataSource.node_stats.get(node_id, {}) # Offloads the blocking operation to a thread pool executor. This also # yields to the event loop. - workers = await get_or_create_event_loop().run_in_executor( + workers = await loop.run_in_executor( thread_pool_executor, - cls.merge_workers_for_node, + cls._extract_workers_for_node, node_physical_stats, node_stats, ) + for worker in workers: for stats in worker.get("coreWorkerStats", []): worker_id = stats["workerId"] core_worker_stats[worker_id] = stats + node_workers[node_id] = workers + DataSource.node_workers.reset(node_workers) DataSource.core_worker_stats.reset(core_worker_stats) @classmethod - def merge_workers_for_node(cls, node_physical_stats, node_stats): + def _extract_workers_for_node(cls, node_physical_stats, node_stats): workers = [] # Merge coreWorkerStats (node stats) to workers (node physical stats) pid_to_worker_stats = {} pid_to_language = {} pid_to_job_id = {} - pids_on_node = set() + for core_worker_stats in node_stats.get("coreWorkersStats", []): pid = core_worker_stats["pid"] - pids_on_node.add(pid) + pid_to_worker_stats[pid] = core_worker_stats pid_to_language[pid] = core_worker_stats["language"] pid_to_job_id[pid] = core_worker_stats["jobId"] @@ -112,6 +119,7 @@ def merge_workers_for_node(cls, node_physical_stats, node_stats): for worker in node_physical_stats.get("workers", []): worker = dict(worker) pid = worker["pid"] + core_worker_stats = pid_to_worker_stats.get(pid) # Empty list means core worker stats is not available. 
worker["coreWorkerStats"] = [core_worker_stats] if core_worker_stats else [] @@ -121,6 +129,7 @@ def merge_workers_for_node(cls, node_physical_stats, node_stats): worker["jobId"] = pid_to_job_id.get(pid, dashboard_consts.DEFAULT_JOB_ID) workers.append(worker) + return workers @classmethod @@ -156,10 +165,14 @@ async def get_node_info(cls, node_id, get_summary=False): ) if not get_summary: + actor_table_entries = DataSource.node_actors.get(node_id, {}) + # Merge actors to node physical stats - node_info["actors"] = await DataOrganizer._get_all_actors( - DataSource.node_actors.get(node_id, {}) - ) + node_info["actors"] = { + actor_id: await DataOrganizer._get_actor_info(actor_table_entry) + for actor_id, actor_table_entry in actor_table_entries.items() + } + # Update workers to node physical stats node_info["workers"] = DataSource.node_workers.get(node_id, []) @@ -168,6 +181,8 @@ async def get_node_info(cls, node_id, get_summary=False): @classmethod async def get_all_node_summary(cls): return [ + # NOTE: We're intentionally awaiting in a loop to avoid excessive + # concurrency spinning up excessive # of tasks for large clusters await DataOrganizer.get_node_info(node_id, get_summary=True) for node_id in DataSource.nodes.keys() ] @@ -209,28 +224,25 @@ def _create_agent_info(node_id: str): return {node_id: _create_agent_info(node_id) for node_id in target_node_ids} @classmethod - async def get_all_actors(cls): - return await cls._get_all_actors(DataSource.actors) + async def get_actor_infos(cls, actor_ids: Optional[List[str]] = None): + target_actor_table_entries: dict[str, Optional[dict]] + if actor_ids is not None: + target_actor_table_entries = { + actor_id: DataSource.actors.get(actor_id) for actor_id in actor_ids + } + else: + target_actor_table_entries = DataSource.actors - @staticmethod - async def _get_all_actors(actors): - result = {} - for index, (actor_id, actor) in enumerate(actors.items()): - result[actor_id] = await DataOrganizer._get_actor(actor) - # There can be thousands of actors including dead ones. Processing - # them all can take many seconds, which blocks all other requests - # to the dashboard. The ideal solution might be to implement - # pagination. For now, use a workaround to yield to the event loop - # periodically, so other request handlers have a chance to run and - # avoid long latencies. 
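The workaround being removed here is the canonical cooperative-yield idiom; as a standalone illustration (dummy work items, not the dashboard's actor table):

import asyncio


async def process_many(items):
    results = []
    for index, item in enumerate(items):
        results.append(item * 2)  # stand-in for per-item processing
        # Awaiting sleep(0) suspends this coroutine for exactly one loop
        # iteration, letting other pending handlers run between batches.
        if index % 1000 == 0 and index > 0:
            await asyncio.sleep(0)
    return results


# asyncio.run(process_many(range(10_000)))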
- if index % 1000 == 0 and index > 0: - # Canonical way to yield to the event loop: - # https://github.com/python/asyncio/issues/284 - await asyncio.sleep(0) - return result + return { + actor_id: await DataOrganizer._get_actor_info(actor_table_entry) + for actor_id, actor_table_entry in target_actor_table_entries.items() + } @staticmethod - async def _get_actor(actor): + async def _get_actor_info(actor): + if actor is None: + return None + actor = dict(actor) worker_id = actor["address"]["workerId"] core_worker_stats = DataSource.core_worker_stats.get(worker_id, {}) diff --git a/python/ray/dashboard/head.py b/python/ray/dashboard/head.py index e6d905e32de0..fe6e6cb4365f 100644 --- a/python/ray/dashboard/head.py +++ b/python/ray/dashboard/head.py @@ -9,6 +9,7 @@ import ray.experimental.internal_kv as internal_kv from ray._private import ray_constants from ray._private.gcs_utils import GcsAioClient +from ray._private.ray_constants import env_integer from ray._private.usage.usage_lib import TagKey, record_extra_usage_tag from ray._raylet import GcsClient from ray.dashboard.consts import DASHBOARD_METRIC_PORT @@ -30,6 +31,13 @@ ("grpc.max_receive_message_length", ray_constants.GRPC_CPP_MAX_MESSAGE_SIZE), ) +# NOTE: Executor in this head is intentionally constrained to just 1 thread by +# default to limit its concurrency, therefore reducing potential for +# GIL contention +RAY_DASHBOARD_DASHBOARD_HEAD_TPE_MAX_WORKERS = env_integer( + "RAY_DASHBOARD_DASHBOARD_HEAD_TPE_MAX_WORKERS", 1 +) + def initialize_grpc_port_and_server(grpc_ip, grpc_port): try: @@ -98,11 +106,9 @@ def __init__( self._modules_to_load = modules_to_load self._modules_loaded = False - # A TPE holding background, compute-heavy, latency-tolerant jobs, typically - # state updates. - self._thread_pool_executor = ThreadPoolExecutor( - max_workers=dashboard_consts.RAY_DASHBOARD_THREAD_POOL_MAX_WORKERS, - thread_name_prefix="dashboard_head_tpe", + self._executor = ThreadPoolExecutor( + max_workers=RAY_DASHBOARD_DASHBOARD_HEAD_TPE_MAX_WORKERS, + thread_name_prefix="dashboard_head_executor", ) self.gcs_address = None @@ -326,7 +332,7 @@ async def _async_notify(): self._gcs_check_alive(), _async_notify(), DataOrganizer.purge(), - DataOrganizer.organize(self._thread_pool_executor), + DataOrganizer.organize(self._executor), ] for m in modules: concurrent_tasks.append(m.run(self.server)) diff --git a/python/ray/dashboard/modules/actor/actor_head.py b/python/ray/dashboard/modules/actor/actor_head.py index 4ce372a751a3..043564d305ec 100644 --- a/python/ray/dashboard/modules/actor/actor_head.py +++ b/python/ray/dashboard/modules/actor/actor_head.py @@ -1,9 +1,10 @@ +import abc import asyncio import logging import os -import time -from collections import deque -from typing import Dict +from collections import defaultdict, deque +from concurrent.futures import ThreadPoolExecutor +from typing import Any, Dict, Optional import aiohttp.web @@ -11,7 +12,9 @@ import ray.dashboard.utils as dashboard_utils from ray import ActorID from ray._private.gcs_pubsub import GcsAioActorSubscriber +from ray._private.utils import get_or_create_event_loop from ray.core.generated import gcs_pb2, gcs_service_pb2, gcs_service_pb2_grpc +from ray.dashboard.consts import GCS_RPC_TIMEOUT_SECONDS from ray.dashboard.datacenter import DataOrganizer, DataSource from ray.dashboard.modules.actor import actor_consts @@ -22,6 +25,19 @@ ACTOR_CLEANUP_FREQUENCY = 1 # seconds +ACTOR_TABLE_STATE_COLUMNS = ( + "state", + "address", + "numRestarts", + "timestamp", + "pid", + 
"exitDetail", + "startTime", + "endTime", + "reprName", +) + + def actor_table_data_to_dict(message): orig_message = dashboard_utils.message_to_dict( message, @@ -79,171 +95,215 @@ def actor_table_data_to_dict(message): return light_message -class GetAllActorInfo: +class GetAllActorInfoClient(abc.ABC): """ Gets all actor info from GCS via gRPC ActorInfoGcsService.GetAllActorInfo. It makes the call via GcsAioClient or a direct gRPC stub, depends on the env var RAY_USE_OLD_GCS_CLIENT. """ - def __new__(cls, *args, **kwargs): + @classmethod + def create(cls, *args, **kwargs): use_old_client = os.getenv("RAY_USE_OLD_GCS_CLIENT") == "1" if use_old_client: return GetAllActorInfoFromGrpc(*args, **kwargs) else: return GetAllActorInfoFromNewGcsClient(*args, **kwargs) + @abc.abstractmethod + def __call__(self, *, timeout_s: int, state: Optional[str] = None): + pass -class GetAllActorInfoFromNewGcsClient: + +class GetAllActorInfoFromNewGcsClient(GetAllActorInfoClient): def __init__(self, dashboard_head): self.gcs_aio_client = dashboard_head.gcs_aio_client - async def __call__(self, timeout) -> Dict[ActorID, gcs_pb2.ActorTableData]: - return await self.gcs_aio_client.get_all_actor_info(timeout=timeout) + async def __call__( + self, *, timeout, state: Optional[str] = None + ) -> Dict[ActorID, gcs_pb2.ActorTableData]: + return await self.gcs_aio_client.get_all_actor_info( + timeout=timeout, actor_state_name=state + ) -class GetAllActorInfoFromGrpc: +class GetAllActorInfoFromGrpc(GetAllActorInfoClient): def __init__(self, dashboard_head): gcs_channel = dashboard_head.aiogrpc_gcs_channel self._gcs_actor_info_stub = gcs_service_pb2_grpc.ActorInfoGcsServiceStub( gcs_channel ) - async def __call__(self, timeout) -> Dict[ActorID, gcs_pb2.ActorTableData]: + async def __call__( + self, *, timeout, state: Optional[str] = None + ) -> Dict[ActorID, gcs_pb2.ActorTableData]: request = gcs_service_pb2.GetAllActorInfoRequest() + reply = await self._gcs_actor_info_stub.GetAllActorInfo( - request, timeout=timeout + request, + timeout=timeout, + filters=gcs_service_pb2.GetAllActorInfoRequest.Filters(state=state) + if state + else None, ) + if reply.status.code != 0: raise Exception(f"Failed to GetAllActorInfo: {reply.status.message}") - actors = {} - for message in reply.actor_table_data: - actors[ActorID(message.actor_id)] = message - return actors + + return { + ActorID(message.actor_id): message for message in reply.actor_table_data + } class ActorHead(dashboard_utils.DashboardHeadModule): def __init__(self, dashboard_head): super().__init__(dashboard_head) - self.get_all_actor_info = None + self._gcs_actor_channel_subscriber = None + self.get_all_actor_info_client = None # A queue of dead actors in order of when they died self.dead_actors_queue = deque() - # -- Internal states -- + # -- Internal state -- self.total_published_events = 0 - self.subscriber_queue_size = 0 - self.accumulative_event_processing_s = 0 + + self._loop = get_or_create_event_loop() + # NOTE: This executor is intentionally constrained to just 1 thread to + # limit its concurrency, therefore reducing potential for GIL contention + self._executor = ThreadPoolExecutor( + max_workers=1, thread_name_prefix="actor_head_executor" + ) async def _update_actors(self): """ Processes actor info. First gets all actors from GCS, then subscribes to actor updates. For each actor update, updates DataSource.node_actors and DataSource.actors. + """ - To prevent Time-of-check to time-of-use issue [1], the get-all-actor-info - happens after the subscription. 
That is, an update between get-all-actor-info - and the subscription is not missed. + # To prevent Time-of-check to time-of-use issue [1], the get-all-actor-info + # happens after the subscription. That is, an update between get-all-actor-info + # and the subscription is not missed. + # # [1] https://en.wikipedia.org/wiki/Time-of-check_to_time-of-use - """ - # Receive actors from channel. gcs_addr = self._dashboard_head.gcs_address - subscriber = GcsAioActorSubscriber(address=gcs_addr) - await subscriber.subscribe() + actor_channel_subscriber = GcsAioActorSubscriber(address=gcs_addr) + await actor_channel_subscriber.subscribe() # Get all actor info. while True: try: logger.info("Getting all actor info from GCS.") - actors = await self.get_all_actor_info(timeout=5) - actor_dicts: Dict[str, dict] = { - actor_id.hex(): actor_table_data_to_dict(actor_table_data) - for actor_id, actor_table_data in actors.items() - } - # Update actors. + actor_dicts = await self._get_all_actors() + # Update actors DataSource.actors.reset(actor_dicts) + # Update node actors and job actors. - node_actors = {} - for actor_id, actor_table_data in actor_dicts.items(): - node_id = actor_table_data["address"]["rayletId"] + node_actors = defaultdict(dict) + for actor_id_bytes, updated_actor_table in actor_dicts.items(): + node_id = updated_actor_table["address"]["rayletId"] # Update only when node_id is not Nil. if node_id != actor_consts.NIL_NODE_ID: - node_actors.setdefault(node_id, {})[actor_id] = actor_table_data + node_actors[node_id][actor_id_bytes] = updated_actor_table + + # Update node's actor info DataSource.node_actors.reset(node_actors) - logger.info("Received %d actor info from GCS.", len(actors)) - break # breaks the while True. - except Exception: - logger.exception("Error Getting all actor info from GCS.") + + logger.info("Received %d actor info from GCS.", len(actor_dicts)) + + # Break, once all initial actors are successfully fetched + break + except Exception as e: + logger.exception("Error Getting all actor info from GCS", exc_info=e) await asyncio.sleep( actor_consts.RETRY_GET_ALL_ACTOR_INFO_INTERVAL_SECONDS ) - state_keys = ( - "state", - "address", - "numRestarts", - "timestamp", - "pid", - "exitDetail", - "startTime", - "endTime", - "reprName", - ) - - def process_actor_data_from_pubsub(actor_id, actor_table_data): - actor_table_data = actor_table_data_to_dict(actor_table_data) - # If actor is not new registered but updated, we only update - # states related fields. - if actor_table_data["state"] != "DEPENDENCIES_UNREADY": - actors = DataSource.actors[actor_id] - for k in state_keys: - if k in actor_table_data: - actors[k] = actor_table_data[k] - actor_table_data = actors - actor_id = actor_table_data["actorId"] - node_id = actor_table_data["address"]["rayletId"] - if actor_table_data["state"] == "DEAD": - self.dead_actors_queue.append(actor_id) - # Update actors. - DataSource.actors[actor_id] = actor_table_data - # Update node actors (only when node_id is not Nil). - if node_id != actor_consts.NIL_NODE_ID: - node_actors = DataSource.node_actors.get(node_id, {}) - node_actors[actor_id] = actor_table_data - DataSource.node_actors[node_id] = node_actors - + # Pull incremental updates from the GCS channel while True: try: - published = await subscriber.poll(batch_size=200) - start = time.monotonic() - for actor_id, actor_table_data in published: - if actor_id is not None: - # Convert to lower case hex ID. 
- actor_id = actor_id.hex() - process_actor_data_from_pubsub(actor_id, actor_table_data) - - # Yield so that we can give time for - # user-facing APIs to reply to the frontend. - elapsed = time.monotonic() - start - await asyncio.sleep(elapsed) - - # Update the internal states for debugging. - self.accumulative_event_processing_s += elapsed - self.total_published_events += len(published) - self.subscriber_queue_size = subscriber.queue_size + updated_actor_table_entries = await self._poll_updated_actor_table_data( + actor_channel_subscriber + ) + + for ( + actor_id, + updated_actor_table, + ) in updated_actor_table_entries.items(): + self._process_updated_actor_table(actor_id, updated_actor_table) + + self.total_published_events += len(updated_actor_table_entries) + + # TODO emit metrics logger.debug( - f"Processing takes {elapsed}. Total process: " f"{len(published)}" + f"Total events processed: {len(updated_actor_table_entries)}, " + f"queue size: {actor_channel_subscriber.queue_size}" ) - if self.accumulative_event_processing_s > 0: - logger.debug( - "Processing throughput: " - f"{self.total_published_events / self.accumulative_event_processing_s}" # noqa - " / s" - ) - logger.debug(f"queue size: {self.subscriber_queue_size}") - except Exception: - logger.exception("Error processing actor info from GCS.") + + except Exception as e: + logger.exception("Error processing actor info from GCS.", exc_info=e) + + async def _poll_updated_actor_table_data( + self, actor_channel_subscriber: GcsAioActorSubscriber + ) -> Dict[str, Dict[str, Any]]: + # TODO make batch size configurable + batch = await actor_channel_subscriber.poll(batch_size=200) + + # NOTE: We're offloading conversion to a TPE to make sure we're not + # blocking the event-loop for prolonged period of time irrespective + # of the batch size + def _convert_to_dict(): + return { + actor_id_bytes.hex(): actor_table_data_to_dict(actor_table_data_message) + for actor_id_bytes, actor_table_data_message in batch + if actor_id_bytes is not None + } + + return await self._loop.run_in_executor(self._executor, _convert_to_dict) + + def _process_updated_actor_table( + self, actor_id: str, actor_table_data: Dict[str, Any] + ): + """NOTE: This method has to be executed on the event-loop, provided that it accesses + DataSource data structures (to follow its thread-safety model)""" + + # If actor is not new registered but updated, we only update + # states related fields. + actor = DataSource.actors.get(actor_id) + + if actor and actor_table_data["state"] != "DEPENDENCIES_UNREADY": + for k in ACTOR_TABLE_STATE_COLUMNS: + if k in actor_table_data: + actor[k] = actor_table_data[k] + actor_table_data = actor + + actor_id = actor_table_data["actorId"] + node_id = actor_table_data["address"]["rayletId"] + + if actor_table_data["state"] == "DEAD": + self.dead_actors_queue.append(actor_id) + + # Update actors. + DataSource.actors[actor_id] = actor_table_data + # Update node actors (only when node_id is not Nil). 
+ if node_id != actor_consts.NIL_NODE_ID: + node_actors = DataSource.node_actors.get(node_id, {}) + node_actors[actor_id] = actor_table_data + DataSource.node_actors[node_id] = node_actors + + async def _get_all_actors(self) -> Dict[str, dict]: + actors = await self.get_all_actor_info_client(timeout=GCS_RPC_TIMEOUT_SECONDS) + + # NOTE: We're offloading conversion to a TPE to make sure we're not + # blocking the event-loop for prolonged period of time for large clusters + def _convert_to_dict(): + return { + actor_id.hex(): actor_table_data_to_dict(actor_table_data) + for actor_id, actor_table_data in actors.items() + } + + return await self._loop.run_in_executor(self._executor, _convert_to_dict) async def _cleanup_actors(self): while True: @@ -267,23 +327,10 @@ async def _cleanup_actors(self): except Exception: logger.exception("Error cleaning up actor info from GCS.") - def get_internal_states(self): - states = { - "total_published_events": self.total_published_events, - "total_dead_actors": len(self.dead_actors_queue), - "total_actors": len(DataSource.actors), - "queue_size": self.subscriber_queue_size, - } - if self.accumulative_event_processing_s > 0: - states["event_processing_per_s"] = ( - self.total_published_events / self.accumulative_event_processing_s - ) - return states - @routes.get("/logical/actors") @dashboard_optional_utils.aiohttp_cache async def get_all_actors(self, req) -> aiohttp.web.Response: - actors = await DataOrganizer.get_all_actors() + actors = await DataOrganizer.get_actor_infos() return dashboard_optional_utils.rest_response( success=True, message="All actors fetched.", @@ -298,13 +345,16 @@ async def get_all_actors(self, req) -> aiohttp.web.Response: @dashboard_optional_utils.aiohttp_cache async def get_actor(self, req) -> aiohttp.web.Response: actor_id = req.match_info.get("actor_id") - actors = await DataOrganizer.get_all_actors() + actors = await DataOrganizer.get_actor_infos(actor_ids=[actor_id]) return dashboard_optional_utils.rest_response( success=True, message="Actor details fetched.", detail=actors[actor_id] ) async def run(self, server): - self.get_all_actor_info = GetAllActorInfo(self._dashboard_head) + self.get_all_actor_info_client = GetAllActorInfoClient.create( + self._dashboard_head + ) + await asyncio.gather(self._update_actors(), self._cleanup_actors()) @staticmethod diff --git a/python/ray/dashboard/modules/event/event_agent.py b/python/ray/dashboard/modules/event/event_agent.py index f098d870cd35..9ea08c32a7c0 100644 --- a/python/ray/dashboard/modules/event/event_agent.py +++ b/python/ray/dashboard/modules/event/event_agent.py @@ -2,6 +2,7 @@ import logging import os import time +from concurrent.futures import ThreadPoolExecutor from typing import Union import ray._private.ray_constants as ray_constants @@ -16,6 +17,14 @@ logger = logging.getLogger(__name__) +# NOTE: Executor in this head is intentionally constrained to just 1 thread by +# default to limit its concurrency, therefore reducing potential for +# GIL contention +RAY_DASHBOARD_EVENT_AGENT_TPE_MAX_WORKERS = ray_constants.env_integer( + "RAY_DASHBOARD_EVENT_AGENT_TPE_MAX_WORKERS", 1 +) + + class EventAgent(dashboard_utils.DashboardAgentModule): def __init__(self, dashboard_agent): super().__init__(dashboard_agent) @@ -31,6 +40,11 @@ def __init__(self, dashboard_agent): self.total_request_sent = 0 self.module_started = time.monotonic() + self._executor = ThreadPoolExecutor( + max_workers=RAY_DASHBOARD_EVENT_AGENT_TPE_MAX_WORKERS, + thread_name_prefix="event_agent_executor", + ) + 
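Each module in this change constructs its own executor sized by an env_integer constant; a generic sketch of that configuration pattern is below (the module name and env var are illustrative placeholders, not names added by this diff).

from concurrent.futures import ThreadPoolExecutor

from ray._private.ray_constants import env_integer

# Hypothetical module-level knob: defaults to a single worker so offloaded jobs
# run off the event loop but cannot fan out and compete for the GIL.
EXAMPLE_MODULE_TPE_MAX_WORKERS = env_integer("EXAMPLE_MODULE_TPE_MAX_WORKERS", 1)


class ExampleModule:
    def __init__(self):
        self._executor = ThreadPoolExecutor(
            max_workers=EXAMPLE_MODULE_TPE_MAX_WORKERS,
            thread_name_prefix="example_module_executor",
        )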
logger.info("Event agent cache buffer size: %s", self._cached_events.maxsize) async def _connect_to_dashboard(self): @@ -107,7 +121,7 @@ async def run(self, server): self._monitor = monitor_events( self._event_dir, lambda data: create_task(self._cached_events.put(data)), - self._dashboard_agent.thread_pool_executor, + self._executor, ) await asyncio.gather( diff --git a/python/ray/dashboard/modules/event/event_head.py b/python/ray/dashboard/modules/event/event_head.py index c2408b716b39..f82a70607f8a 100644 --- a/python/ray/dashboard/modules/event/event_head.py +++ b/python/ray/dashboard/modules/event/event_head.py @@ -3,12 +3,14 @@ import os import time from collections import OrderedDict, defaultdict +from concurrent.futures import ThreadPoolExecutor from typing import Union import aiohttp.web import ray.dashboard.optional_utils as dashboard_optional_utils import ray.dashboard.utils as dashboard_utils +from ray._private.ray_constants import env_integer from ray.core.generated import event_pb2, event_pb2_grpc from ray.dashboard.datacenter import DataSource from ray.dashboard.modules.event.event_utils import monitor_events, parse_event_strings @@ -21,6 +23,13 @@ MAX_EVENTS_TO_CACHE = int(os.environ.get("RAY_DASHBOARD_MAX_EVENTS_TO_CACHE", 10000)) +# NOTE: Executor in this head is intentionally constrained to just 1 thread by +# default to limit its concurrency, therefore reducing potential for +# GIL contention +RAY_DASHBOARD_EVENT_HEAD_TPE_MAX_WORKERS = env_integer( + "RAY_DASHBOARD_EVENT_HEAD_TPE_MAX_WORKERS", 1 +) + class EventHead( dashboard_utils.DashboardHeadModule, event_pb2_grpc.ReportEventServiceServicer @@ -34,6 +43,11 @@ def __init__(self, dashboard_head): self.total_events_received = 0 self.module_started = time.monotonic() + self._executor = ThreadPoolExecutor( + max_workers=RAY_DASHBOARD_EVENT_HEAD_TPE_MAX_WORKERS, + thread_name_prefix="event_head_executor", + ) + @staticmethod def _update_events(event_list): # {job_id: {event_id: event}} @@ -107,7 +121,7 @@ async def run(self, server): self._monitor = monitor_events( self._event_dir, lambda data: self._update_events(parse_event_strings(data)), - self._dashboard_head._thread_pool_executor, + self._executor, ) @staticmethod diff --git a/python/ray/dashboard/modules/job/job_head.py b/python/ray/dashboard/modules/job/job_head.py index b2c436d29805..2e5f7ae6f4a7 100644 --- a/python/ray/dashboard/modules/job/job_head.py +++ b/python/ray/dashboard/modules/job/job_head.py @@ -313,7 +313,7 @@ async def upload_package(self, req: Request): try: data = await req.read() await get_or_create_event_loop().run_in_executor( - self._dashboard_head._thread_pool_executor, + None, upload_package_to_gcs, package_uri, data, diff --git a/python/ray/dashboard/modules/metrics/metrics_head.py b/python/ray/dashboard/modules/metrics/metrics_head.py index b12a282b1777..04c80255fecb 100644 --- a/python/ray/dashboard/modules/metrics/metrics_head.py +++ b/python/ray/dashboard/modules/metrics/metrics_head.py @@ -12,6 +12,7 @@ import ray.dashboard.utils as dashboard_utils from ray._private.async_utils import enable_monitor_loop_lag from ray._private.ray_constants import env_integer +from ray._private.utils import get_or_create_event_loop from ray.dashboard.consts import ( AVAILABLE_COMPONENT_NAMES_FOR_METRICS, METRICS_INPUT_ROOT, @@ -324,6 +325,12 @@ async def record_dashboard_metrics(self): float(self._dashboard_proc.memory_full_info().rss) / 1.0e6 ) + loop = get_or_create_event_loop() + + 
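The gauge set just below is driven by asyncio.all_tasks; a self-contained illustration of what that call counts (toy tasks, not dashboard ones):

import asyncio


async def main():
    # Spawn a few tasks that are still pending when we sample.
    pending = [asyncio.create_task(asyncio.sleep(1)) for _ in range(3)]
    # all_tasks() returns every not-yet-finished task on the running loop,
    # including the current one, so this prints 4 (three sleepers plus main).
    print(len(asyncio.all_tasks()))
    await asyncio.gather(*pending)


asyncio.run(main())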
self._dashboard_head.metrics.metrics_event_loop_tasks.labels(**labels).set( + len(asyncio.all_tasks(loop)) + ) + # Report the max lag since the last export, if any. if self._event_loop_lag_s_max is not None: self._dashboard_head.metrics.metrics_event_loop_lag.labels(**labels).set( diff --git a/python/ray/dashboard/modules/node/node_consts.py b/python/ray/dashboard/modules/node/node_consts.py index c3939cd66e3b..c70d86f86dfd 100644 --- a/python/ray/dashboard/modules/node/node_consts.py +++ b/python/ray/dashboard/modules/node/node_consts.py @@ -1,7 +1,7 @@ from ray._private.ray_constants import env_integer NODE_STATS_UPDATE_INTERVAL_SECONDS = env_integer( - "NODE_STATS_UPDATE_INTERVAL_SECONDS", 5 + "NODE_STATS_UPDATE_INTERVAL_SECONDS", 15 ) RAY_DASHBOARD_HEAD_NODE_REGISTRATION_TIMEOUT = env_integer( "RAY_DASHBOARD_HEAD_NODE_REGISTRATION_TIMEOUT", 10 diff --git a/python/ray/dashboard/modules/node/node_head.py b/python/ray/dashboard/modules/node/node_head.py index caf098a95b27..e146d9409cb6 100644 --- a/python/ray/dashboard/modules/node/node_head.py +++ b/python/ray/dashboard/modules/node/node_head.py @@ -1,11 +1,13 @@ +import abc import asyncio import json import logging import os import time from collections import deque +from concurrent.futures import ThreadPoolExecutor from itertools import chain -from typing import AsyncGenerator, Dict, List, Tuple +from typing import AsyncGenerator, Dict, Iterable, List, Optional import aiohttp.web import grpc @@ -16,8 +18,13 @@ import ray.dashboard.utils as dashboard_utils from ray import NodeID from ray._private import ray_constants +from ray._private.collections_utils import split from ray._private.gcs_pubsub import GcsAioNodeInfoSubscriber -from ray._private.ray_constants import DEBUG_AUTOSCALING_ERROR, DEBUG_AUTOSCALING_STATUS +from ray._private.ray_constants import ( + DEBUG_AUTOSCALING_ERROR, + DEBUG_AUTOSCALING_STATUS, + env_integer, +) from ray._private.utils import get_or_create_event_loop from ray.autoscaler._private.util import ( LoadMetricsSummary, @@ -43,24 +50,20 @@ routes = dashboard_optional_utils.DashboardHeadRouteTable +# NOTE: Executor in this head is intentionally constrained to just 1 thread by +# default to limit its concurrency, therefore reducing potential for +# GIL contention +RAY_DASHBOARD_NODE_HEAD_TPE_MAX_WORKERS = env_integer( + "RAY_DASHBOARD_NODE_HEAD_TPE_MAX_WORKERS", 1 +) + + def _gcs_node_info_to_dict(message: gcs_pb2.GcsNodeInfo) -> dict: return dashboard_utils.message_to_dict( message, {"nodeId"}, always_print_fields_with_no_presence=True ) -def _map_batch_node_info_to_dict( - messages: Dict[NodeID, gcs_pb2.GcsNodeInfo] -) -> List[dict]: - return [_gcs_node_info_to_dict(message) for message in messages.values()] - - -def _list_gcs_node_info_to_dict( - messages: List[Tuple[bytes, gcs_pb2.GcsNodeInfo]] -) -> List[dict]: - return [_gcs_node_info_to_dict(node_info) for _, node_info in messages] - - def node_stats_to_dict(message): decode_keys = { "actorId", @@ -88,38 +91,55 @@ def node_stats_to_dict(message): message.core_workers_stats.extend(core_workers_stats) -class GetAllNodeInfo: +class GetAllNodeInfoClient(abc.ABC): """ Gets all node info from GCS via gRPC NodeInfoGcsService.GetAllNodeInfo. It makes the call via GcsAioClient or a direct gRPC stub, depending on the env var RAY_USE_OLD_GCS_CLIENT. 
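A condensed sketch of this factory-over-__new__ shape (with simplified constructors and stub return values) is:

import abc
import os


class GetAllInfoClient(abc.ABC):
    @classmethod
    def create(cls, *args, **kwargs) -> "GetAllInfoClient":
        # Pick the implementation once, based on the env var, instead of
        # overriding __new__ on the base class.
        if os.getenv("RAY_USE_OLD_GCS_CLIENT") == "1":
            return FromGrpc(*args, **kwargs)
        return FromNewGcsClient(*args, **kwargs)

    @abc.abstractmethod
    async def __call__(self, *, timeout=None) -> dict:
        ...


class FromNewGcsClient(GetAllInfoClient):
    async def __call__(self, *, timeout=None) -> dict:
        return {}  # stand-in for a GcsAioClient call


class FromGrpc(GetAllInfoClient):
    async def __call__(self, *, timeout=None) -> dict:
        return {}  # stand-in for a direct gRPC stub call


# client = GetAllInfoClient.create()
# asyncio.run(client(timeout=5))  # requires `import asyncio`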
""" - def __new__(cls, *args, **kwargs): + @classmethod + def create(cls, *args, **kwargs): use_old_client = os.getenv("RAY_USE_OLD_GCS_CLIENT") == "1" if use_old_client: return GetAllNodeInfoFromGrpc(*args, **kwargs) else: return GetAllNodeInfoFromNewGcsClient(*args, **kwargs) + async def __call__( + self, + *, + timeout: Optional[int] = None, + ) -> Dict[NodeID, gcs_pb2.GcsNodeInfo]: + pass -class GetAllNodeInfoFromNewGcsClient: + +class GetAllNodeInfoFromNewGcsClient(GetAllNodeInfoClient): def __init__(self, dashboard_head): self.gcs_aio_client = dashboard_head.gcs_aio_client - async def __call__(self, timeout) -> Dict[NodeID, gcs_pb2.GcsNodeInfo]: + async def __call__( + self, + *, + timeout: Optional[int] = None, + ) -> Dict[NodeID, gcs_pb2.GcsNodeInfo]: return await self.gcs_aio_client.get_all_node_info(timeout=timeout) -class GetAllNodeInfoFromGrpc: +class GetAllNodeInfoFromGrpc(GetAllNodeInfoClient): def __init__(self, dashboard_head): gcs_channel = dashboard_head.aiogrpc_gcs_channel self._gcs_node_info_stub = gcs_service_pb2_grpc.NodeInfoGcsServiceStub( gcs_channel ) - async def __call__(self, timeout) -> Dict[NodeID, gcs_pb2.GcsNodeInfo]: + async def __call__( + self, + *, + timeout: Optional[int] = None, + ) -> Dict[NodeID, gcs_pb2.GcsNodeInfo]: request = gcs_service_pb2.GetAllNodeInfoRequest() + reply = await self._gcs_node_info_stub.GetAllNodeInfo(request, timeout=timeout) if reply.status.code != 0: raise Exception(f"Failed to GetAllNodeInfo: {reply.status.message}") @@ -132,9 +152,11 @@ async def __call__(self, timeout) -> Dict[NodeID, gcs_pb2.GcsNodeInfo]: class NodeHead(dashboard_utils.DashboardHeadModule): def __init__(self, dashboard_head): super().__init__(dashboard_head) + self._stubs = {} - self.get_all_node_info = None + self._get_all_node_info_client: GetAllNodeInfoClient = None self._collect_memory_info = False + DataSource.nodes.signal.append(self._update_stubs) # The time where the module is started. self._module_start_time = time.time() @@ -146,6 +168,11 @@ def __init__(self, dashboard_head): self._gcs_aio_client = dashboard_head.gcs_aio_client self._gcs_address = dashboard_head.gcs_address + self._executor = ThreadPoolExecutor( + max_workers=RAY_DASHBOARD_NODE_HEAD_TPE_MAX_WORKERS, + thread_name_prefix="node_head_executor", + ) + async def _update_stubs(self, change): if change.old: node_id, node_info = change.old @@ -186,27 +213,38 @@ async def _subscribe_for_node_updates(self) -> AsyncGenerator[dict, None]: # it happens after the subscription. That is, an update between # get-all-node-info and the subscription is not missed. 
# [1] https://en.wikipedia.org/wiki/Time-of-check_to_time-of-use - all_node_info = await self.get_all_node_info(timeout=None) + all_node_info = await self._get_all_node_info_client(timeout=None) + + def _convert_to_dict(messages: Iterable[gcs_pb2.GcsNodeInfo]) -> List[dict]: + return [_gcs_node_info_to_dict(m) for m in messages] - all_node_dicts = await get_or_create_event_loop().run_in_executor( - self._dashboard_head._thread_pool_executor, - _map_batch_node_info_to_dict, - all_node_info, + all_node_infos = await get_or_create_event_loop().run_in_executor( + self._executor, + _convert_to_dict, + all_node_info.values(), ) - for node in all_node_dicts: + + for node in all_node_infos: yield node while True: try: - published = await subscriber.poll( + node_id_updated_info_tuples = await subscriber.poll( batch_size=node_consts.RAY_DASHBOARD_NODE_SUBSCRIBER_POLL_SIZE ) - updated_dicts = await get_or_create_event_loop().run_in_executor( - self._dashboard_head._thread_pool_executor, - _list_gcs_node_info_to_dict, - published, + + if node_id_updated_info_tuples: + _, updated_infos_proto = zip(*node_id_updated_info_tuples) + else: + updated_infos_proto = [] + + updated_infos = await get_or_create_event_loop().run_in_executor( + self._executor, + _convert_to_dict, + updated_infos_proto, ) - for node in updated_dicts: + + for node in updated_infos: yield node except Exception: logger.exception("Failed handling updated nodes.") @@ -397,66 +435,100 @@ async def get_node(self, req) -> aiohttp.web.Response: @async_loop_forever(node_consts.NODE_STATS_UPDATE_INTERVAL_SECONDS) async def _update_node_stats(self): - # Copy self._stubs to avoid `dictionary changed size during iteration`. + timeout = max(2, node_consts.NODE_STATS_UPDATE_INTERVAL_SECONDS - 1) + + # NOTE: We copy stubs to make sure + # it doesn't change during the iteration (since its being updated + # from another async task) + current_stub_node_id_tuples = list(self._stubs.items()) + + if current_stub_node_id_tuples: + node_ids, _ = zip(*current_stub_node_id_tuples) + else: + node_ids = [] + get_node_stats_tasks = [] - nodes = list(self._stubs.items()) - TIMEOUT = node_consts.NODE_STATS_UPDATE_INTERVAL_SECONDS - 1 - for node_id, stub in nodes: + for i, (node_id, stub) in enumerate(current_stub_node_id_tuples): node_info = DataSource.nodes.get(node_id) if node_info["state"] != "ALIVE": continue + get_node_stats_tasks.append( stub.GetNodeStats( node_manager_pb2.GetNodeStatsRequest( include_memory_info=self._collect_memory_info ), - timeout=min(2, TIMEOUT), + timeout=timeout, ) ) - replies = await asyncio.gather( - *get_node_stats_tasks, - return_exceptions=True, - ) + responses = [] + + # NOTE: We're chunking up fetching of the stats to run in batches of no more + # than 100 nodes at a time to avoid flooding the event-loop's queue + # with potentially a large, uninterrupted sequence of tasks updating + # the node stats for very large clusters. 
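The chunking below relies on the new ray._private.collections_utils.split generator added at the top of this diff; for reference, it behaves like this (values are illustrative):

from ray._private.collections_utils import split

tasks = list(range(7))
# split() lazily yields consecutive chunks of at most chunk_size items.
assert list(split(tasks, 3)) == [[0, 1, 2], [3, 4, 5], [6]]
assert list(split(tasks, 10)) == [tasks]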
+ for get_node_stats_tasks_chunk in split(get_node_stats_tasks, 100): + current_chunk_responses = await asyncio.gather( + *get_node_stats_tasks_chunk, + return_exceptions=True, + ) - def postprocess(nodes, replies): + responses.extend(current_chunk_responses) + + # We're doing short (25ms) yield after every chunk to make sure + # - We're not overloading the event-loop with excessive # of tasks + # - Allowing 10k nodes stats fetches be sent out performed in 2.5s + await asyncio.sleep(0.025) + + def postprocess(node_id_response_tuples): """Pure function reorganizing the data into {node_id: stats}.""" new_node_stats = {} - for node_info, reply in zip(nodes, replies): - node_id, _ = node_info - if isinstance(reply, asyncio.CancelledError): + + for node_id, response in node_id_response_tuples: + if isinstance(response, asyncio.CancelledError): pass - elif isinstance(reply, grpc.RpcError): - if reply.code() == grpc.StatusCode.DEADLINE_EXCEEDED: - logger.exception( + elif isinstance(response, grpc.RpcError): + if response.code() == grpc.StatusCode.DEADLINE_EXCEEDED: + message = ( f"Cannot reach the node, {node_id}, after timeout " - f" {TIMEOUT}. This node may have been overloaded, " + f" {timeout}. This node may have been overloaded, " "terminated, or the network is slow." ) - elif reply.code() == grpc.StatusCode.UNAVAILABLE: - logger.exception( + elif response.code() == grpc.StatusCode.UNAVAILABLE: + message = ( f"Cannot reach the node, {node_id}. " "The node may have been terminated." ) else: - logger.exception(f"Error updating node stats of {node_id}.") - logger.exception(reply) - elif isinstance(reply, Exception): - logger.exception(f"Error updating node stats of {node_id}.") - logger.exception(reply) + message = f"Error updating node stats of {node_id}." 
+ + logger.error(message, exc_info=response) + elif isinstance(response, Exception): + logger.error( + f"Error updating node stats of {node_id}.", exc_info=response + ) else: - new_node_stats[node_id] = node_stats_to_dict(reply) + new_node_stats[node_id] = node_stats_to_dict(response) + return new_node_stats + # NOTE: Zip will silently truncate to shorter argument that potentially + # could lead to subtle hard to catch issues, hence the assertion + assert len(node_ids) == len(responses) + new_node_stats = await get_or_create_event_loop().run_in_executor( - self._dashboard_head._thread_pool_executor, postprocess, nodes, replies + self._executor, postprocess, zip(node_ids, responses) ) + for node_id, new_stat in new_node_stats.items(): DataSource.node_stats[node_id] = new_stat async def run(self, server): - self.get_all_node_info = GetAllNodeInfo(self._dashboard_head) + self._get_all_node_info_client = GetAllNodeInfoClient.create( + self._dashboard_head + ) await asyncio.gather( self._update_nodes(), self._update_node_stats(), diff --git a/python/ray/dashboard/modules/node/tests/test_node.py b/python/ray/dashboard/modules/node/tests/test_node.py index 8c655bfff4ea..cb71105046e3 100644 --- a/python/ray/dashboard/modules/node/tests/test_node.py +++ b/python/ray/dashboard/modules/node/tests/test_node.py @@ -17,6 +17,7 @@ wait_until_server_available, ) from ray.cluster_utils import Cluster +from ray.dashboard.consts import RAY_DASHBOARD_STATS_UPDATING_INTERVAL from ray.dashboard.tests.conftest import * # noqa logger = logging.getLogger(__name__) @@ -79,7 +80,9 @@ def getpid(self): webui_url = format_web_url(webui_url) node_id = ray_start_with_dashboard["node_id"] - timeout_seconds = 10 + # NOTE: Leaving sum buffer time for data to get refreshed + timeout_seconds = RAY_DASHBOARD_STATS_UPDATING_INTERVAL * 1.5 + start_time = time.time() last_ex = None while True: diff --git a/python/ray/dashboard/modules/reporter/reporter_agent.py b/python/ray/dashboard/modules/reporter/reporter_agent.py index b6cb1d11af2d..84cdf5803be4 100644 --- a/python/ray/dashboard/modules/reporter/reporter_agent.py +++ b/python/ray/dashboard/modules/reporter/reporter_agent.py @@ -7,6 +7,7 @@ import sys import traceback from collections import defaultdict +from concurrent.futures import ThreadPoolExecutor from typing import List, Optional, Tuple, TypedDict, Union from opencensus.stats import stats as stats_module @@ -20,7 +21,7 @@ import ray.dashboard.utils as dashboard_utils from ray._private import utils from ray._private.metrics_agent import Gauge, MetricsAgent, Record -from ray._private.ray_constants import DEBUG_AUTOSCALING_STATUS +from ray._private.ray_constants import DEBUG_AUTOSCALING_STATUS, env_integer from ray._raylet import WorkerID from ray.core.generated import reporter_pb2, reporter_pb2_grpc from ray.dashboard import k8s_utils @@ -47,6 +48,13 @@ # Using existence of /sys/fs/cgroup as the criterion is consistent with # Ray's existing resource logic, see e.g. ray._private.utils.get_num_cpus(). 
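Circling back to the zip note in the node-head postprocess above: a plain zip would silently drop any unmatched tail, which is what the length assertion guards against. A toy illustration (generic lists, not the dashboard's data):

node_ids = ["a", "b", "c"]
responses = ["ok", "ok"]  # one response missing

# zip() stops at the shorter input, silently dropping "c".
assert list(zip(node_ids, responses)) == [("a", "ok"), ("b", "ok")]

# Guarding with an up-front length check (or zip(..., strict=True) on
# Python 3.10+) turns the silent truncation into a loud failure.
if len(node_ids) != len(responses):
    print("mismatched lengths, refusing to pair up results")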
+# NOTE: Executor in this head is intentionally constrained to just 1 thread by +# default to limit its concurrency, therefore reducing potential for +# GIL contention +RAY_DASHBOARD_REPORTER_AGENT_TPE_MAX_WORKERS = env_integer( + "RAY_DASHBOARD_REPORTER_AGENT_TPE_MAX_WORKERS", 1 +) + def recursive_asdict(o): if isinstance(o, tuple) and hasattr(o, "_asdict"): @@ -393,6 +401,11 @@ def __init__(self, dashboard_agent): f"{reporter_consts.REPORTER_PREFIX}" f"{self._dashboard_agent.node_id}" ) + self._executor = ThreadPoolExecutor( + max_workers=RAY_DASHBOARD_REPORTER_AGENT_TPE_MAX_WORKERS, + thread_name_prefix="reporter_agent_executor", + ) + async def GetTraceback(self, request, context): pid = request.pid native = request.native @@ -1223,9 +1236,9 @@ async def _run_loop(self, publisher): ) # NOTE: Stats collection is executed inside the thread-pool - # executor (TPE) to avoid blocking the Dashboard's event-loop + # executor (TPE) to avoid blocking the Agent's event-loop json_payload = await loop.run_in_executor( - self._dashboard_agent.thread_pool_executor, + self._executor, self._compose_stats_payload, autoscaler_status_json_bytes, ) diff --git a/python/ray/dashboard/modules/reporter/reporter_head.py b/python/ray/dashboard/modules/reporter/reporter_head.py index 9ca0894d3c79..6553f561ea50 100644 --- a/python/ray/dashboard/modules/reporter/reporter_head.py +++ b/python/ray/dashboard/modules/reporter/reporter_head.py @@ -1,6 +1,7 @@ import asyncio import json import logging +from concurrent.futures import ThreadPoolExecutor from typing import List, Optional, Tuple import aiohttp.web @@ -15,6 +16,7 @@ DEBUG_AUTOSCALING_STATUS_LEGACY, GLOBAL_GRPC_OPTIONS, KV_NAMESPACE_CLUSTER, + env_integer, ) from ray._private.usage.usage_constants import CLUSTER_METADATA_KEY from ray._private.utils import get_or_create_event_loop, init_grpc_channel @@ -43,6 +45,13 @@ } \n""" +# NOTE: Executor in this head is intentionally constrained to just 1 thread by +# default to limit its concurrency, therefore reducing potential for +# GIL contention +RAY_DASHBOARD_REPORTER_HEAD_TPE_MAX_WORKERS = env_integer( + "RAY_DASHBOARD_REPORTER_HEAD_TPE_MAX_WORKERS", 1 +) + class ReportHead(dashboard_utils.DashboardHeadModule): def __init__(self, dashboard_head): @@ -63,6 +72,11 @@ def __init__(self, dashboard_head): self._gcs_aio_client = dashboard_head.gcs_aio_client self._state_api = None + self._executor = ThreadPoolExecutor( + max_workers=RAY_DASHBOARD_REPORTER_HEAD_TPE_MAX_WORKERS, + thread_name_prefix="reporter_head_executor", + ) + async def _update_stubs(self, change): if change.old: node_id, port = change.old @@ -612,7 +626,7 @@ async def run(self, server): # TODO(ryw): unify the StateAPIManager in reporter_head and state_head. self._state_api = StateAPIManager( self._state_api_data_source_client, - self._dashboard_head._thread_pool_executor, + self._executor, ) # Need daemon True to avoid dashboard hangs at exit. 
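The hunk that follows keeps the existing pattern of parsing large JSON payloads off the event loop, now against the module-local executor; in isolation the pattern is roughly this (payload and names below are stand-ins):

import asyncio
import json
from concurrent.futures import ThreadPoolExecutor

executor = ThreadPoolExecutor(max_workers=1, thread_name_prefix="example_parser")


async def parse_payload(raw: bytes) -> dict:
    loop = asyncio.get_running_loop()
    # json.loads of a multi-megabyte report can take long enough to stall the
    # event loop, so it runs on the executor thread instead.
    return await loop.run_in_executor(executor, json.loads, raw)


# asyncio.run(parse_payload(b'{"nodes": []}'))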
@@ -640,7 +654,7 @@ async def run(self, server): # NOTE: Every iteration is executed inside the thread-pool executor # (TPE) to avoid blocking the Dashboard's event-loop parsed_data = await loop.run_in_executor( - self._dashboard_head._thread_pool_executor, json.loads, data + self._executor, json.loads, data ) node_id = key.split(":")[-1] diff --git a/python/ray/dashboard/modules/state/state_head.py b/python/ray/dashboard/modules/state/state_head.py index f4c9515b6517..74f7871068b3 100644 --- a/python/ray/dashboard/modules/state/state_head.py +++ b/python/ray/dashboard/modules/state/state_head.py @@ -2,6 +2,7 @@ import functools import logging from abc import ABC, abstractmethod +from concurrent.futures import ThreadPoolExecutor from dataclasses import asdict from datetime import datetime from typing import AsyncIterable, Callable, List, Optional, Tuple @@ -11,6 +12,7 @@ import ray.dashboard.optional_utils as dashboard_optional_utils import ray.dashboard.utils as dashboard_utils +from ray._private.ray_constants import env_integer from ray._private.usage.usage_lib import TagKey, record_extra_usage_tag from ray.dashboard.consts import ( RAY_STATE_SERVER_MAX_HTTP_REQUEST, @@ -41,6 +43,13 @@ logger = logging.getLogger(__name__) routes = dashboard_optional_utils.DashboardHeadRouteTable +# NOTE: Executor in this head is intentionally constrained to just 1 thread by +# default to limit its concurrency, therefore reducing potential for +# GIL contention +RAY_DASHBOARD_STATE_HEAD_TPE_MAX_WORKERS = env_integer( + "RAY_DASHBOARD_STATE_HEAD_TPE_MAX_WORKERS", 1 +) + class RateLimitedModule(ABC): """Simple rate limiter @@ -151,6 +160,11 @@ def __init__( self._state_api = None self._log_api = None + self._executor = ThreadPoolExecutor( + max_workers=RAY_DASHBOARD_STATE_HEAD_TPE_MAX_WORKERS, + thread_name_prefix="state_head_executor", + ) + DataSource.nodes.signal.append(self._update_raylet_stubs) DataSource.agents.signal.append(self._update_agent_stubs) @@ -550,7 +564,7 @@ async def run(self, server): ) self._state_api = StateAPIManager( self._state_api_data_source_client, - self._dashboard_head._thread_pool_executor, + self._executor, ) self._log_api = LogsManager(self._state_api_data_source_client) diff --git a/python/ray/dashboard/modules/train/train_head.py b/python/ray/dashboard/modules/train/train_head.py index fd7c7e1d5eb9..b23995a82421 100644 --- a/python/ray/dashboard/modules/train/train_head.py +++ b/python/ray/dashboard/modules/train/train_head.py @@ -98,20 +98,21 @@ async def _add_actor_status_and_update_run_status(self, train_runs): TrainWorkerInfoWithDetails, ) - try: - logger.info("Getting all actor info from GCS.") - actors = await DataOrganizer.get_all_actors() - - except Exception: - logger.exception("Error Getting all actor info from GCS.") - train_runs_with_details: List[TrainRunInfoWithDetails] = [] for train_run in train_runs.values(): worker_infos_with_details: List[TrainWorkerInfoWithDetails] = [] + actor_ids = [worker.actor_id for worker in train_run.workers] + + logger.info(f"Getting all actor info from GCS (actor_ids={actor_ids})") + + train_run_actors = await DataOrganizer.get_actor_infos( + actor_ids=actor_ids, + ) + for worker_info in train_run.workers: - actor = actors.get(worker_info.actor_id, None) + actor = train_run_actors.get(worker_info.actor_id, None) # Add hardware metrics to API response if actor: gpus = [ @@ -161,9 +162,8 @@ async def _add_actor_status_and_update_run_status(self, train_runs): # function (e.g., system failure or user interruption) that crashed the 
# train controller. # We need to detect this case and mark the train run as ABORTED. - controller_actor_status = actors.get( - train_run.controller_actor_id, None - ).get("state") + actor = train_run_actors.get(train_run.controller_actor_id) + controller_actor_status = actor.get("state") if actor else None if ( controller_actor_status == ActorStatusEnum.DEAD and train_run.run_status == RunStatusEnum.RUNNING diff --git a/python/ray/tests/BUILD b/python/ray/tests/BUILD index 0f3df9703133..06641d0d5777 100644 --- a/python/ray/tests/BUILD +++ b/python/ray/tests/BUILD @@ -690,6 +690,14 @@ py_test( deps = ["//:ray_lib", ":conftest"], ) +py_test( + name = "test_collections_utils", + size = "small", + srcs = ["test_collections_utils.py"], + tags = ["exclusive", "small_size_python_tests", "team:core"], + deps = ["//:ray_lib", ":conftest"] +) + py_test( name = "test_runtime_env_validation", size = "small", diff --git a/python/ray/tests/test_collections_utils.py b/python/ray/tests/test_collections_utils.py new file mode 100644 index 000000000000..75b0be429097 --- /dev/null +++ b/python/ray/tests/test_collections_utils.py @@ -0,0 +1,18 @@ +import sys + +import pytest + +from ray._private.collections_utils import split + + +def test_split(): + ints = list(range(0, 5)) + + assert list(split(ints, 5)) == [ints] + assert list(split(ints, 10)) == [ints] + assert list(split(ints, 1)) == [[e] for e in ints] + assert list(split(ints, 2)) == [[0, 1], [2, 3], [4]] + + +if __name__ == "__main__": + sys.exit(pytest.main(["-v", __file__])) diff --git a/python/ray/tests/test_object_store_metrics.py b/python/ray/tests/test_object_store_metrics.py index 7e72919e761b..489e86ce0fd6 100644 --- a/python/ray/tests/test_object_store_metrics.py +++ b/python/ray/tests/test_object_store_metrics.py @@ -11,6 +11,7 @@ wait_for_condition, ) from ray._private.worker import RayContext +from ray.dashboard.consts import RAY_DASHBOARD_STATS_UPDATING_INTERVAL KiB = 1 << 10 MiB = 1 << 20 @@ -381,7 +382,7 @@ def verify(): assert object_store_memory_bytes_from_dashboard == 500 * MiB return True - wait_for_condition(verify) + wait_for_condition(verify, timeout=RAY_DASHBOARD_STATS_UPDATING_INTERVAL * 1.5) if __name__ == "__main__": diff --git a/python/ray/tests/test_output.py b/python/ray/tests/test_output.py index 563b8aa282f2..6d8c3a3be421 100644 --- a/python/ray/tests/test_output.py +++ b/python/ray/tests/test_output.py @@ -101,16 +101,18 @@ def test_spill_logs(): for _ in range(10): x.append(ray.put(bytes(100 * 1024 * 1024))) -""" - proc = run_string_as_driver_nonblocking(script, env={"RAY_verbose_spill_logs": "1"}) - out_str = proc.stdout.read().decode("ascii") + proc.stderr.read().decode("ascii") - print(out_str) +""" + stdout_str, stderr_str = run_string_as_driver_stdout_stderr( + script, env={"RAY_verbose_spill_logs": "1"} + ) + out_str = stdout_str + stderr_str assert "Spilled " in out_str - proc = run_string_as_driver_nonblocking(script, env={"RAY_verbose_spill_logs": "0"}) - out_str = proc.stdout.read().decode("ascii") + proc.stderr.read().decode("ascii") - print(out_str) + stdout_str, stderr_str = run_string_as_driver_stdout_stderr( + script, env={"RAY_verbose_spill_logs": "0"} + ) + out_str = stdout_str + stderr_str assert "Spilled " not in out_str diff --git a/release/nightly_tests/stress_tests/test_state_api_scale.py b/release/nightly_tests/stress_tests/test_state_api_scale.py index 5d5658216f39..081066f24b15 100644 --- a/release/nightly_tests/stress_tests/test_state_api_scale.py +++ 
b/release/nightly_tests/stress_tests/test_state_api_scale.py @@ -188,7 +188,7 @@ def create_objs(self, num_objects): for i in range(num_objects): # Object size shouldn't matter here. self.objs.append(ray.put(bytearray(os.urandom(1024)))) - if i + 1 % 100 == 0: + if (i + 1) % 100 == 0: logger.info(f"Created object {i+1}...") return self.objs
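The final hunk is a pure operator-precedence fix; without the parentheses the modulo binds before the addition, so the periodic log line effectively never fired:

i = 99
# Old condition: '%' binds tighter than '+', so this evaluates to i + (1 % 100),
# i.e. i + 1, which is never 0 for i >= 0, so the log statement never ran.
assert i + 1 % 100 == 100
# Fixed condition: the intended "every 100th object" check.
assert (i + 1) % 100 == 0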