From 38f6ab354efc52243ac04822150321615c81ae52 Mon Sep 17 00:00:00 2001 From: Srinath Krishnamachari Date: Wed, 22 Oct 2025 16:57:11 +0000 Subject: [PATCH 1/9] [Data] ConcurrencyCapBackpressurePolicy - Handle internal output queue buildup Signed-off-by: Srinath Krishnamachari --- .../concurrency_cap_backpressure_policy.py | 413 +++++++++++++- .../_internal/execution/resource_manager.py | 161 +++--- python/ray/data/context.py | 16 +- .../data/tests/test_backpressure_policies.py | 511 +++++++++++++++++- release/release_tests.yaml | 1 + 5 files changed, 1022 insertions(+), 80 deletions(-) diff --git a/python/ray/data/_internal/execution/backpressure_policy/concurrency_cap_backpressure_policy.py b/python/ray/data/_internal/execution/backpressure_policy/concurrency_cap_backpressure_policy.py index 77b7d17b0b0a..8e77ecc3c02f 100644 --- a/python/ray/data/_internal/execution/backpressure_policy/concurrency_cap_backpressure_policy.py +++ b/python/ray/data/_internal/execution/backpressure_policy/concurrency_cap_backpressure_policy.py @@ -1,7 +1,10 @@ import logging -from typing import TYPE_CHECKING +import math +from collections import defaultdict, deque +from typing import TYPE_CHECKING, Deque, Dict from .backpressure_policy import BackpressurePolicy +from ray.data._internal.execution.operators.map_operator import MapOperator from ray.data._internal.execution.operators.task_pool_map_operator import ( TaskPoolMapOperator, ) @@ -10,6 +13,7 @@ from ray.data._internal.execution.interfaces.physical_operator import ( PhysicalOperator, ) + from ray.data._internal.execution.operators.map_operator import MapOperator logger = logging.getLogger(__name__) @@ -17,27 +21,422 @@ class ConcurrencyCapBackpressurePolicy(BackpressurePolicy): """A backpressure policy that caps the concurrency of each operator. - The policy will limit the number of concurrently running tasks based on its - concurrency cap parameter. 
+ This policy dynamically limits the number of concurrent tasks per operator + based on queue pressure. It combines: + + - Adaptive threshold built from EWMA of the queue level and its + absolute deviation: ``threshold = max(level + K * dev, current_queue_size_bytes)``. + The ``current_queue_size_bytes`` term enables immediate upward revision to avoid + throttling ramp-up throughput. + + Why this threshold works: + - level + K*dev: Sets threshold above typical queue size by K standard deviations + - max(..., current_queue_size_bytes): Prevents false throttling during legitimate ramp-up periods + - When queue grows faster than EWMA can track, current_queue_size_bytes > level + K*dev + - This allows the system to "catch up" to the new higher baseline before throttling + + - Quantized step controller that nudges running concurrency by + ``{-1, 0, +1, +2}`` using normalized pressure and trend signals. + + Key Concepts: + - Level (EWMA): Typical queue size; slowly tracks the central tendency. + - Deviation (EWMA): Typical absolute deviation around the level; acts as + a normalization scale for pressure and trend signals. + - Threshold: Dynamic limit derived from observed signal: if the current + queue exceeds the threshold, we consider backoff. The ``max(..., current_queue_size_bytes)`` + makes this instantaneously responsive upward. + - Instantaneous pressure: How far the current queue is from threshold, + normalized by deviation. + - Trend: Whether the queue is rising/falling over a short horizon (recent + vs older HISTORY_LEN/2 samples), normalized by deviation. 
+ + Example: + Consider an operator with configured cap=10 and queue pressure over time: + + Queue samples: [100, 120, 140, 160, 180, 200, 220, 240, 260, 280] + Threshold: 150, dev=20 (illustrative; threshold = level + K * dev, K=4.0) + + Ramp-up scenario (queue shrinking, pressure < 0): + - pressure_signal = (100-150)/20 = -2.5, trend_signal = -1.0 + - Decision: step = +1 (low pressure, shrinking trend) + - Result: concurrency increases from 8 -> 9 + + Dial-down scenario (queue growing, pressure > 0): + - pressure_signal = (200-150)/20 = +2.5, trend_signal = +1.0 + - Decision: step = -1 (high pressure, growing trend) + - Result: concurrency decreases from 10 -> 9 + + Stable scenario (queue stable, pressure ~ 0): + - pressure_signal = (150-150)/20 = 0.0, trend_signal = 0.0 + - Decision: step = 0 (no change needed) + - Result: concurrency stays at 10 NOTE: Only support setting concurrency cap for `TaskPoolMapOperator` for now. TODO(chengsu): Consolidate with actor scaling logic of `ActorPoolMapOperator`. """ + # Queue history window for recent-trend estimation. Small window to capture recent trend. + HISTORY_LEN = 10 + # Smoothing factor for both level and dev. + EWMA_ALPHA = 0.2 + # Deviation multiplier to define "over-threshold". + K_DEV = 4.0 + def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - self._concurrency_caps: dict["PhysicalOperator", float] = {} + # Explicit concurrency caps for each operator. Infinite if not specified. + self._concurrency_caps: Dict["PhysicalOperator", float] = {} + + # Queue history for recent-trend estimation. Small window to capture recent trend. + self._queue_history: Dict["PhysicalOperator", Deque[int]] = defaultdict( + lambda: deque(maxlen=self.HISTORY_LEN) + ) + + # Per-operator cached threshold (bootstrapped from first sample). + self._queue_thresholds: Dict["PhysicalOperator", int] = defaultdict(int) + + # EWMA state for level (typical queue size). 
+ self._q_level_nbytes: Dict["PhysicalOperator", float] = defaultdict(float) + + # EWMA state for absolute deviation. + self._q_level_dev: Dict["PhysicalOperator", float] = defaultdict(float) + + # Track last effective cap per operator for change detection. + self._last_effective_caps: Dict["PhysicalOperator", int] = {} + + # Initialize caps from operators (infinite if unset) for op, _ in self._topology.items(): if isinstance(op, TaskPoolMapOperator) and op.get_concurrency() is not None: self._concurrency_caps[op] = op.get_concurrency() else: self._concurrency_caps[op] = float("inf") + # Whether to cap the concurrency of an operator based on its and downstream's queue size. + self.enable_dynamic_output_queue_size_backpressure = ( + self._data_context.enable_dynamic_output_queue_size_backpressure + ) + logger.debug( - "ConcurrencyCapBackpressurePolicy initialized with: " - f"{self._concurrency_caps}" + "ConcurrencyCapBackpressurePolicy caps: %s, cap based on queue size: %s", + self._concurrency_caps, + self.enable_dynamic_output_queue_size_backpressure, ) + def _update_ewma_asymmetric(self, prev_value: float, sample: float) -> float: + """ + Update EWMA with asymmetric behavior: fast rise, slow fall. + + Args: + prev_value: Previous EWMA value + sample: New sample value + + Returns: + Updated EWMA value + """ + if prev_value <= 0: + return sample + + alpha_up = 1.0 - (1.0 - self.EWMA_ALPHA) ** 2 # fast rise + alpha = alpha_up if sample > prev_value else self.EWMA_ALPHA # slow fall + return (1 - alpha) * prev_value + alpha * sample + def can_add_input(self, op: "PhysicalOperator") -> bool: - return op.metrics.num_tasks_running < self._concurrency_caps[op] + """Return whether `op` may accept another input now. + + Admission control logic: + * Under threshold: Allow full concurrency up to configured cap + * Over threshold: Adjust concurrency using step controller (-1,0,+1,+2) + based on pressure and trend signals + + Args: + op: The operator under consideration. 
+ + Returns: + True if admitting one more input is allowed. + """ + running = op.metrics.num_tasks_running + if ( + not isinstance(op, MapOperator) + or not self.enable_dynamic_output_queue_size_backpressure + ): + return running < self._concurrency_caps[op] + + # Observe fresh queue size for this operator and its downstream. + current_queue_size_bytes = ( + self._resource_manager.get_op_outputs_usage_with_internal_and_downstream(op) + ) + + # Update short history and refresh the adaptive threshold. + self._queue_history[op].append(current_queue_size_bytes) + threshold = self._update_queue_threshold(op, current_queue_size_bytes) + + # If configured to cap based on queue size, use the effective cap. + if current_queue_size_bytes > threshold: + # Over-threshold: potentially back off via effective cap. + effective_cap = self._effective_cap(op) + is_capped = running < effective_cap + + last_effective_cap = self._last_effective_caps.get(op, None) + if last_effective_cap != effective_cap: + logger.debug( + "Effective concurrency cap changed for operator %s: %d -> %d" + "running=%d tasks, queue=%d bytes, threshold=%d bytes", + op.name, + last_effective_cap, + effective_cap, + running, + current_queue_size_bytes, + threshold, + ) + self._last_effective_caps[op] = effective_cap + + return is_capped + else: + # Under-threshold: only enforce the configured cap. + return running < self._concurrency_caps[op] + + def _update_queue_threshold( + self, op: "PhysicalOperator", current_queue_size_bytes: int + ) -> int: + """Update and return the current adaptive threshold for `op`. + + Motivation: Adaptive thresholds prevent both over-throttling (too aggressive) and + under-throttling (too permissive). 
The logic balances responsiveness with stability: + - Fast upward response to pressure spikes (immediate threshold increase) + - Gradual downward response to prevent oscillation (EWMA smoothing) + - Floor of 1 byte once a sample exists (the threshold decays gradually when idle; it is never reset to 0) + + Args: + op: Operator whose threshold is being updated. + current_queue_size_bytes: Current total queued bytes for this operator + downstream. + + Returns: + The updated threshold in bytes. + + Examples: + # Example 1: First sample (bootstrap) + # Input: current_queue_size_bytes = 1000, level_prev = 0, dev_prev = 0 + # EWMA: level = 1000, dev = 0 (first sample) + # Base: 1000 + 4*0 = 1000 + # Threshold: max(1000, 1000) = 1000 + # prev_threshold = 0, threshold = 1000 + # Result: 1000 (bootstrap) + + # Example 2: Upward adjustment (immediate) + # Input: current_queue_size_bytes = 1500, level_prev = 1000, dev_prev = 100 + # EWMA: level = 1000 + 0.36*(1500-1000) = 1180, dev = 100 + 0.36*(320-100) = 179.2 (rising samples use alpha_up = 1 - 0.8**2 = 0.36; dev sample = |1500-1180| = 320) + # Base: 1180 + 4*179.2 = 1896.8 + # Threshold: int(max(1896.8, 1500)) = 1896 + # prev_threshold = 1000, threshold = 1896 + # Result: 1896 (immediate upward) + + # Example 3: Downward adjustment (smoothed) + # Input: current_queue_size_bytes = 100, level_prev = 200, dev_prev = 50 + # EWMA: level = 200 + 0.2*(100-200) = 180, dev = 50 + 0.36*(80-50) = 60.8 (dev sample = |100-180| = 80) + # Base: 180 + 4*60.8 = 423.2 + # Threshold: int(max(423.2, 100)) = 423 + # prev_threshold = 500, threshold = 423 + # smoothed = asymmetric_ewma(500, 423) = 500 + 0.2*(423-500) = 484.6 -> 484 + # Result: 484 (gradual downward adjustment using asymmetric EWMA) + + # Example 4: System becomes idle + # Input: current_queue_size_bytes = 0, level_prev = 200, dev_prev = 50 + # EWMA: level = 200 + 0.2*(0-200) = 160, dev = 50 + 0.36*(160-50) = 89.6 (dev sample = |0-160| = 160) + # Base: 160 + 4*89.6 = 518.4 + # Threshold: int(max(518.4, 0)) = 518 + # prev_threshold = 484, threshold = 518 + # 518 >= 484, so no smoothing is applied + # Result: 518 (applied immediately: the sudden drop inflates dev, which raises the threshold) + + # 
Example 5: Continued idle (dev still rising, so the threshold moves up) + # Input: current_queue_size_bytes = 0, level_prev = 160, dev_prev = 80 + # EWMA: level = 160 + 0.2*(0-160) = 128, dev = 80 + 0.36*(128-80) = 97.28 (dev sample = |0-128| = 128; rising samples use alpha_up = 0.36) + # Base: 128 + 4*97.28 = 517.12 + # Threshold: int(max(517.12, 0)) = 517 + # prev_threshold = 483, threshold = 517 + # Result: 517 (upward moves are applied immediately, even while idle) + + # Example 6: After many idle samples (threshold keeps decaying gradually) + # Input: current_queue_size_bytes = 0, level_prev = 50, dev_prev = 10 + # EWMA: level = 50 + 0.2*(0-50) = 40, dev = 10 + 0.36*(40-10) = 20.8 (dev sample = |0-40| = 40) + # Base: 40 + 4*20.8 = 123.2 + # Threshold: int(max(123.2, 0)) = 123 + # prev_threshold = 200, threshold = 123 + # smoothed = asymmetric_ewma(200, 123) = 200 + 0.2*(123-200) = 184.6 -> 184 + # Result: 184 (gradual downward adjustment using asymmetric EWMA) + + """ + hist = self._queue_history[op] + if not hist: + return 0 + + q = float(current_queue_size_bytes) + + # Step 1: update EWMAs + level_prev = self._q_level_nbytes[op] + dev_prev = self._q_level_dev[op] + + # Update EWMA level (typical queue size) with asymmetric behavior + # Why asymmetric? 
Quick to detect problems, slow to recover (prevents oscillation) + # Example: queue goes 100->200->100, EWMA follows 100->136->128.8 + # (rises with alpha_up=0.36, falls with EWMA_ALPHA=0.2) + level = self._update_ewma_asymmetric(level_prev, q) + + # Update EWMA deviation (typical absolute deviation) with asymmetric behavior + # Same logic: quick to detect high variability, slow to recover (prevents noise) + # Example: deviation samples 10->30->10, EWMA follows 10->17.2->15.76 (fast up, slow down) + dev_sample = abs(q - level) + dev = self._update_ewma_asymmetric(dev_prev, dev_sample) + + self._q_level_nbytes[op] = level + self._q_level_dev[op] = dev + + # Step 2: base threshold from level & dev + # Example: level=1000, dev=200, K_DEV=4.0 -> base = 1000 + 4*200 = 1800 + base = level + self.K_DEV * dev + + # Step 3: fast ramp-up + threshold = max(1, int(max(base, q))) + + # Step 4: cache & return with gentle downward response using EWMA_ALPHA + prev_threshold = self._queue_thresholds[op] + + # Idle/off guard (currently never taken: Step 3 floors threshold at 1) + if threshold == 0: + self._queue_thresholds[op] = 0 + return 0 + + # Bootstrap + if prev_threshold == 0: + self._queue_thresholds[op] = max(1, threshold) + return self._queue_thresholds[op] + + # Upward: apply immediately + if threshold >= prev_threshold: + self._queue_thresholds[op] = max(1, threshold) + return self._queue_thresholds[op] + + # Downward: smooth using asymmetric EWMA + # Prevents oscillation by allowing gradual downward adjustments + # Uses same asymmetric behavior as EWMA: slow to adjust downward + # Example: prev_threshold=200, threshold=100 -> 200 + 0.2*(100-200) = 180 + smoothed = int(self._update_ewma_asymmetric(prev_threshold, threshold)) + if smoothed >= prev_threshold: + # Ensures progress if the smoothed value fails to drop below prev_threshold + # Example: prev_threshold=200, smoothed=200 -> use 199 + smoothed = prev_threshold - 1 + + self._queue_thresholds[op] = max(1, smoothed) + return self._queue_thresholds[op] + + def _effective_cap(self, op: 
"PhysicalOperator") -> int: + """Compute a reduced concurrency cap via a tiny {-1,0,+1,+2} controller. + + Pressure and trend signals: + - pressure_signal: How far current queue is above threshold (normalized by absolute deviation estimate) + Formula: (current_queue_size_bytes - threshold) / max(dev, 1) + + Examples: + * queue=200, threshold=150, dev=20 -> pressure = (200-150)/20 = +2.5 + Meaning: Queue is 2.5x absolute deviation estimate level above threshold (high pressure, throttle!) + * queue=100, threshold=150, dev=20 -> pressure = (100-150)/20 = -2.5 + Meaning: Queue is 2.5x absolute deviation estimate level below threshold (low pressure, safe) + * queue=150, threshold=150, dev=20 -> pressure = (150-150)/20 = 0.0 + Meaning: Queue exactly at threshold (neutral pressure) + + - trend_signal: Whether queue is growing or shrinking (normalized by absolute deviation estimate) + Formula: (avg(recent_window) - avg(older_window)) / max(dev, 1) + + Examples: + * recent_avg=180, older_avg=160, dev=20 -> trend = (180-160)/20 = +1.0 + Meaning: Queue growing at 1x absolute deviation estimate level (upward trend, getting worse) + * recent_avg=140, older_avg=160, dev=20 -> trend = (140-160)/20 = -1.0 + Meaning: Queue shrinking at 1x absolute deviation estimate level (downward trend, getting better) + * recent_avg=160, older_avg=160, dev=20 -> trend = (160-160)/20 = 0.0 + Meaning: Queue stable (no trend) + + Controller decision logic: + - Decides concurrency adjustment {-1,0,+1,+2} based on pressure and trend signals + + Decision rules table: + +----------+----------+----------+--------------------------------+------------------+ + | Pressure | Trend | Step | Action | Example | + +----------+----------+----------+--------------------------------+------------------+ + | >= +2.0 | >= +1.0 | -1 | Emergency backoff | +2.5, +1.0 -> -1 | + | | | | (immediate reduction to | | + | | | | prevent overload) | | + 
+----------+----------+----------+--------------------------------+------------------+ + | >= +1.0 | > 0.0 | 0 | Wait and see | +1.5, +0.5 -> 0 | + | | | | (let current level stabilize) | | + +----------+----------+----------+--------------------------------+------------------+ + | <= -1.0 | <= -1.0 | +1 | Conservative growth | -1.5, -1.0 -> +1 | + | | | | (safe to increase when | | + | | | | improving) | | + +----------+----------+----------+--------------------------------+------------------+ + | <= -2.0 | <= -2.0 | +2 | Aggressive growth | -2.5, -2.0 -> +2 | + | | | | (underutilized and improving | | + | | | | rapidly) | | + +----------+----------+----------+--------------------------------+------------------+ + | Other | Other | 0 | Hold | +0.5, -0.5 -> 0 | + | | | | (moderate signals, no clear | | + | | | | direction) | | + +----------+----------+----------+--------------------------------+------------------+ + + Logic summary: + - High pressure + growing trend = emergency backoff + - High pressure + stable trend = wait and see + - Low pressure + shrinking trend = safe to grow + - Very low pressure + strong improvement = aggressive growth + - Moderate signals = maintain current concurrency + + Args: + op: Operator whose effective cap we compute. + + Returns: + An integer cap in [1, configured_cap]. + """ + hist = self._queue_history[op] + running = op.metrics.num_tasks_running + + # Need enough samples to evaluate short trend (recent + older windows). 
+ recent_window = self.HISTORY_LEN // 2 + older_window = self.HISTORY_LEN // 2 + min_samples = recent_window + older_window + + if len(hist) < min_samples: + return max(1, running) + + # Trend windows and normalized signals + h = list(hist) + recent_avg = sum(h[-recent_window:]) / float(recent_window) + older_avg = sum(h[-(recent_window + older_window) : -recent_window]) / float( + older_window + ) + dev = max(1.0, self._q_level_dev[op]) + threshold = float(max(1, self._queue_thresholds[op])) + current_queue_size_bytes = float(hist[-1]) + + # Calculate normalized pressure and trend signals + scale = max(1.0, float(dev)) + pressure_signal = (current_queue_size_bytes - threshold) / scale + trend_signal = (recent_avg - older_avg) / scale + + # Quantized controller decision + if pressure_signal >= 2.0 and trend_signal >= 1.0: + step = -1 + elif pressure_signal >= 1.0 and trend_signal > 0.0: + step = 0 + elif pressure_signal <= -1.0 and trend_signal <= -1.0: + step = +1 + elif pressure_signal <= -2.0 and trend_signal <= -2.0: + step = +2 + else: + step = 0 + + # Apply step to current running concurrency, clamp by configured cap. + target = max(1, running + step) + cap_cfg = self._concurrency_caps[op] + if not math.isinf(cap_cfg): + target = min(target, int(cap_cfg)) + return target \ No newline at end of file diff --git a/python/ray/data/_internal/execution/resource_manager.py b/python/ray/data/_internal/execution/resource_manager.py index be23a3001eb5..fad67c36f88a 100644 --- a/python/ray/data/_internal/execution/resource_manager.py +++ b/python/ray/data/_internal/execution/resource_manager.py @@ -47,7 +47,7 @@ class ResourceManager: # store memory limit for the streaming executor, # when `ReservationOpResourceAllocator` is enabled. 
DEFAULT_OBJECT_STORE_MEMORY_LIMIT_FRACTION = env_float( - "RAY_DATA_OBJECT_STORE_MEMORY_LIMIT_FRACTION", 0.5 + "RAY_DATA_OBJECT_STORE_MEMORY_LIMIT_FRACTION", 0.75 ) # The fraction of the object store capacity that will be used as the default object @@ -277,6 +277,14 @@ def get_op_usage(self, op: PhysicalOperator) -> ExecutionResources: """Return the resource usage of the given operator at the current time.""" return self._op_usages[op] + def get_mem_op_internal(self, op: PhysicalOperator) -> int: + """Return the memory usage of the internal buffers of the given operator.""" + return self._mem_op_internal[op] + + def get_mem_op_outputs(self, op: PhysicalOperator) -> int: + """Return the memory usage of the outputs of the given operator.""" + return self._mem_op_outputs[op] + def get_op_usage_str(self, op: PhysicalOperator) -> str: """Return a human-readable string representation of the resource usage of the given operator.""" @@ -288,8 +296,8 @@ def get_op_usage_str(self, op: PhysicalOperator) -> str: ) if self._debug: usage_str += ( - f" (in={memory_string(self._mem_op_internal[op])}," - f"out={memory_string(self._mem_op_outputs[op])})" + f" (in={memory_string(self.get_mem_op_internal(op))}," + f"out={memory_string(self.get_mem_op_outputs(op))})" ) if ( isinstance(self._op_resource_allocator, ReservationOpResourceAllocator) @@ -330,6 +338,69 @@ def get_budget(self, op: PhysicalOperator) -> Optional[ExecutionResources]: return None return self._op_resource_allocator.get_budget(op) + def is_op_eligible(self, op: PhysicalOperator) -> bool: + """Whether the op is eligible for memory reservation.""" + return ( + not op.throttling_disabled() + # As long as the op has finished execution, even if there are still + # non-taken outputs, we don't need to allocate resources for it. 
+ and not op.execution_finished() + ) + + def get_eligible_ops(self) -> List[PhysicalOperator]: + return [op for op in self._topology if self.is_op_eligible(op)] + + def get_downstream_ineligible_ops( + self, op: PhysicalOperator + ) -> Iterable[PhysicalOperator]: + """Get the downstream ineligible operators of the given operator. + + E.g., + - "cur_map->downstream_map" will return an empty list. + - "cur_map->limit1->limit2->downstream_map" will return [limit1, limit2]. + """ + for next_op in op.output_dependencies: + if not self.is_op_eligible(next_op): + yield next_op + yield from self.get_downstream_ineligible_ops(next_op) + + def get_downstream_eligible_ops( + self, op: PhysicalOperator + ) -> Iterable[PhysicalOperator]: + """Get the downstream eligible operators of the given operator, ignoring + intermediate ineligible operators. + + E.g., + - "cur_map->downstream_map" will return [downstream_map]. + - "cur_map->limit1->limit2->downstream_map" will return [downstream_map]. + """ + for next_op in op.output_dependencies: + if self.is_op_eligible(next_op): + yield next_op + else: + yield from self.get_downstream_eligible_ops(next_op) + + def get_op_outputs_usage_with_downstream(self, op: PhysicalOperator) -> float: + """Get the outputs memory usage of the given operator, including the downstream + ineligible operators. + """ + # Outputs usage of the current operator. + op_outputs_usage = self.get_mem_op_outputs(op) + # Also account the downstream ineligible operators' memory usage. 
+ op_outputs_usage += sum( + self.get_op_usage(next_op).object_store_memory + for next_op in self.get_downstream_ineligible_ops(op) + ) + return op_outputs_usage + + def get_op_outputs_usage_with_internal_and_downstream( + self, op: PhysicalOperator + ) -> float: + """Get the outputs memory usage of the given operator, including the internal usage and the downstream ineligible operators.""" + return self.get_mem_op_internal(op) + self.get_op_outputs_usage_with_downstream( + op + ) + class OpResourceAllocator(ABC): """An interface for dynamic operator resource allocation. @@ -479,20 +550,6 @@ def __init__(self, resource_manager: ResourceManager, reservation_ratio: float): self._idle_detector = self.IdleDetector() - def _is_op_eligible(self, op: PhysicalOperator) -> bool: - """Whether the op is eligible for memory reservation.""" - return ( - not op.throttling_disabled() - # As long as the op has finished execution, even if there are still - # non-taken outputs, we don't need to allocate resources for it. - and not op.execution_finished() - ) - - def _get_eligible_ops(self) -> List[PhysicalOperator]: - return [ - op for op in self._resource_manager._topology if self._is_op_eligible(op) - ] - def _get_ineligible_ops_with_usage(self) -> List[PhysicalOperator]: """ Resource reservation is based on the number of eligible operators. @@ -519,14 +576,14 @@ def _get_ineligible_ops_with_usage(self) -> List[PhysicalOperator]: # filter out downstream ineligible operators since they are omitted from reservation calculations. 
for op in last_completed_ops: ops_to_exclude_from_reservation.extend( - list(self._get_downstream_ineligible_ops(op)) + list(self._resource_manager.get_downstream_ineligible_ops(op)) ) ops_to_exclude_from_reservation.append(op) return list(set(ops_to_exclude_from_reservation)) def _update_reservation(self): global_limits = self._resource_manager.get_global_limits().copy() - eligible_ops = self._get_eligible_ops() + eligible_ops = self._resource_manager.get_eligible_ops() self._op_reserved.clear() self._reserved_for_op_outputs.clear() @@ -600,6 +657,13 @@ def _update_reservation(self): self._total_shared = remaining + def can_submit_new_task(self, op: PhysicalOperator) -> bool: + if op not in self._op_budgets: + return True + budget = self._op_budgets[op] + res = op.incremental_resource_usage().satisfies_limit(budget) + return res + def get_budget(self, op: PhysicalOperator) -> Optional[ExecutionResources]: return self._op_budgets.get(op) @@ -610,7 +674,7 @@ def _should_unblock_streaming_output_backpressure( # launch tasks. Then we should temporarily unblock the streaming output # backpressure by allowing reading at least 1 block. So the current operator # can finish at least one task and yield resources to the downstream operators. - for next_op in self._get_downstream_eligible_ops(op): + for next_op in self._resource_manager.get_downstream_eligible_ops(op): if not self._reserved_min_resources[next_op]: # Case 1: the downstream operator hasn't reserved the minimum resources # to run at least one task. @@ -623,25 +687,14 @@ def _should_unblock_streaming_output_backpressure( return True return False - def _get_op_outputs_usage_with_downstream(self, op: PhysicalOperator) -> float: - """Get the outputs memory usage of the given operator, including the downstream - ineligible operators. - """ - # Outputs usage of the current operator. - op_outputs_usage = self._resource_manager._mem_op_outputs[op] - # Also account the downstream ineligible operators' memory usage. 
- op_outputs_usage += sum( - self._resource_manager.get_op_usage(next_op).object_store_memory - for next_op in self._get_downstream_ineligible_ops(op) - ) - return op_outputs_usage - def max_task_output_bytes_to_read(self, op: PhysicalOperator) -> Optional[int]: if op not in self._op_budgets: return None res = self._op_budgets[op].object_store_memory # Add the remaining of `_reserved_for_op_outputs`. - op_outputs_usage = self._get_op_outputs_usage_with_downstream(op) + op_outputs_usage = self._resource_manager.get_op_outputs_usage_with_downstream( + op + ) res += max(self._reserved_for_op_outputs[op] - op_outputs_usage, 0) if math.isinf(res): self._output_budgets[op] = res @@ -654,41 +707,11 @@ def max_task_output_bytes_to_read(self, op: PhysicalOperator) -> Optional[int]: self._output_budgets[op] = res return res - def _get_downstream_ineligible_ops( - self, op: PhysicalOperator - ) -> Iterable[PhysicalOperator]: - """Get the downstream ineligible operators of the given operator. - - E.g., - - "cur_map->downstream_map" will return an empty list. - - "cur_map->limit1->limit2->downstream_map" will return [limit1, limit2]. - """ - for next_op in op.output_dependencies: - if not self._is_op_eligible(next_op): - yield next_op - yield from self._get_downstream_ineligible_ops(next_op) - - def _get_downstream_eligible_ops( - self, op: PhysicalOperator - ) -> Iterable[PhysicalOperator]: - """Get the downstream eligible operators of the given operator, ignoring - intermediate ineligible operators. - - E.g., - - "cur_map->downstream_map" will return [downstream_map]. - - "cur_map->limit1->limit2->downstream_map" will return [downstream_map]. 
- """ - for next_op in op.output_dependencies: - if self._is_op_eligible(next_op): - yield next_op - else: - yield from self._get_downstream_eligible_ops(next_op) - def update_usages(self): self._update_reservation() self._op_budgets.clear() - eligible_ops = self._get_eligible_ops() + eligible_ops = self._resource_manager.get_eligible_ops() if len(eligible_ops) == 0: return @@ -699,10 +722,12 @@ def update_usages(self): op_mem_usage = 0 # Add the memory usage of the operator itself, # excluding `_reserved_for_op_outputs`. - op_mem_usage += self._resource_manager._mem_op_internal[op] + op_mem_usage += self._resource_manager.get_mem_op_internal(op) # Add the portion of op outputs usage that has # exceeded `_reserved_for_op_outputs`. - op_outputs_usage = self._get_op_outputs_usage_with_downstream(op) + op_outputs_usage = ( + self._resource_manager.get_op_outputs_usage_with_downstream(op) + ) op_mem_usage += max(op_outputs_usage - self._reserved_for_op_outputs[op], 0) op_usage = self._resource_manager.get_op_usage(op).copy( object_store_memory=op_mem_usage @@ -778,4 +803,4 @@ def update_usages(self): ): self._op_budgets[op] = self._op_budgets[op].copy( object_store_memory=float("inf") - ) + ) \ No newline at end of file diff --git a/python/ray/data/context.py b/python/ray/data/context.py index 60b4083fe3dc..b261aa207641 100644 --- a/python/ray/data/context.py +++ b/python/ray/data/context.py @@ -239,6 +239,11 @@ class ShuffleStrategy(str, enum.Enum): ) +DEFAULT_ENABLE_DYNAMIC_OUTPUT_QUEUE_SIZE_BACKPRESSURE: bool = env_bool( + "RAY_DATA_ENABLE_DYNAMIC_OUTPUT_QUEUE_SIZE_BACKPRESSURE", False +) + + @DeveloperAPI @dataclass class AutoscalingConfig: @@ -343,8 +348,9 @@ class DataContext: large_args_threshold: Size in bytes after which point task arguments are considered large. Choose a value so that the data transfer overhead is significant in comparison to task scheduling (i.e., low tens of ms). 
- use_polars: Whether to use Polars for tabular dataset sorts, groupbys, and + use_polars_sort: Whether to use Polars for tabular dataset sorts, groupbys, and aggregations. + use_polars_join: Whether to use Polars for join operations. eager_free: Whether to eagerly free memory. decoding_size_estimation: Whether to estimate in-memory decoding data size for data source. @@ -453,6 +459,8 @@ class DataContext: later. If `None`, this type of backpressure is disabled. downstream_capacity_backpressure_max_queued_bundles: Maximum number of queued bundles before applying backpressure. If `None`, no limit is applied. + enable_dynamic_output_queue_size_backpressure: Whether to cap the concurrency + of an operator based on its own and its downstream's queue size. enforce_schemas: Whether to enforce schema consistency across dataset operations. pandas_block_ignore_metadata: Whether to ignore pandas metadata when converting between Arrow and pandas formats for better type inference. @@ -591,6 +599,10 @@ class DataContext: downstream_capacity_backpressure_ratio: float = None downstream_capacity_backpressure_max_queued_bundles: int = None + enable_dynamic_output_queue_size_backpressure: bool = ( + DEFAULT_ENABLE_DYNAMIC_OUTPUT_QUEUE_SIZE_BACKPRESSURE + ) + enforce_schemas: bool = DEFAULT_ENFORCE_SCHEMAS pandas_block_ignore_metadata: bool = DEFAULT_PANDAS_BLOCK_IGNORE_METADATA @@ -776,4 +788,4 @@ def set_dataset_logger_id(self, dataset_id: str) -> None: # Backwards compatibility alias. 
-DatasetContext = DataContext +DatasetContext = DataContext \ No newline at end of file diff --git a/python/ray/data/tests/test_backpressure_policies.py b/python/ray/data/tests/test_backpressure_policies.py index 9abdccff8afd..1939cc18d806 100644 --- a/python/ray/data/tests/test_backpressure_policies.py +++ b/python/ray/data/tests/test_backpressure_policies.py @@ -2,8 +2,8 @@ import math import time import unittest -from collections import defaultdict -from unittest.mock import MagicMock +from collections import defaultdict, deque +from unittest.mock import MagicMock, patch import pytest @@ -134,8 +134,513 @@ def test_e2e_normal(self): start2, end2 = ray.get(actor.get_start_and_end_time_for_op.remote(2)) assert start1 < start2 < end1 < end2, (start1, start2, end1, end2) + def test_can_add_input_with_normal_concurrency_cap(self): + """Test can_add_input when using normal concurrency cap (queue size disabled).""" + mock_op = MagicMock() + mock_op.name = "TestOperator" + mock_op.metrics.num_tasks_running = 3 + mock_op.throttling_disabled.return_value = False + mock_op.execution_finished.return_value = False + mock_op.output_dependencies = [] + + policy = ConcurrencyCapBackpressurePolicy( + DataContext.get_current(), + {mock_op: MagicMock()}, + MagicMock(), + ) + + # Disable queue size based backpressure + policy.enable_dynamic_output_queue_size_backpressure = False + policy._concurrency_caps[mock_op] = 5 + + # Should allow input when running < cap + result = policy.can_add_input(mock_op) + self.assertTrue(result) + + # Should deny input when running >= cap + mock_op.metrics.num_tasks_running = 5 + result = policy.can_add_input(mock_op) + self.assertFalse(result) + + def test_update_queue_threshold_bootstrap(self): + """Test threshold update for first sample (bootstrap).""" + mock_op = MagicMock() + policy = ConcurrencyCapBackpressurePolicy( + DataContext.get_current(), + {mock_op: MagicMock()}, + MagicMock(), + ) + + # Add sample to history first (required for 
threshold calculation) + policy._queue_history[mock_op].append(1000) + + # First sample should bootstrap threshold + # The threshold will be calculated as max(level + K_DEV * dev, q_now) + # where level=q_now=1000, dev=0 (first sample), so threshold = max(1000 + 4*0, 1000) = 1000 + threshold = policy._update_queue_threshold(mock_op, 1000) + self.assertEqual(threshold, 1000) + self.assertEqual(policy._queue_thresholds[mock_op], 1000) + + # Test bootstrap with zero queue (should set threshold to 1 due to rounding) + fresh_mock_op = MagicMock() + fresh_policy = ConcurrencyCapBackpressurePolicy( + DataContext.get_current(), + {fresh_mock_op: MagicMock()}, + MagicMock(), + ) + fresh_policy._queue_thresholds[fresh_mock_op] = 0 # Reset to idle state + fresh_policy._queue_history[fresh_mock_op] = deque([0]) + # Fresh policy starts with clean EWMA state + + threshold_zero = fresh_policy._update_queue_threshold(fresh_mock_op, 0) + # When q_now=0, level=0, dev=0, threshold = max(1, max(0 + 4*0, 0)) = 1 + self.assertEqual(threshold_zero, 1) + self.assertEqual(fresh_policy._queue_thresholds[fresh_mock_op], 1) + + def test_update_queue_threshold_asymmetric_ewma(self): + """Test threshold update with asymmetric EWMA behavior.""" + mock_op = MagicMock() + policy = ConcurrencyCapBackpressurePolicy( + DataContext.get_current(), + {mock_op: MagicMock()}, + MagicMock(), + ) + + # Set up initial state + policy._q_level_nbytes[mock_op] = 100.0 + policy._q_level_dev[mock_op] = 20.0 + policy._queue_history[mock_op] = deque([100, 120, 140, 160, 180, 200]) + + # Test with growing queue (should use faster alpha_up) + threshold = policy._update_queue_threshold(mock_op, 300) + + # Threshold should be at least as high as current queue + self.assertGreaterEqual(threshold, 300) + + # Level should have moved toward the new sample using alpha_up + self.assertGreater(policy._q_level_nbytes[mock_op], 100.0) + + # Test with declining queue (should use slower EWMA_ALPHA) + 
policy._q_level_nbytes[mock_op] = 200.0 + policy._q_level_dev[mock_op] = 30.0 + policy._update_queue_threshold(mock_op, 150) + + # Level should have moved less aggressively downward + self.assertGreater(policy._q_level_nbytes[mock_op], 150.0) + self.assertLess(policy._q_level_nbytes[mock_op], 200.0) + + def test_update_queue_threshold_downward_smoothing(self): + """Test threshold update with downward smoothing logic.""" + mock_op = MagicMock() + policy = ConcurrencyCapBackpressurePolicy( + DataContext.get_current(), + {mock_op: MagicMock()}, + MagicMock(), + ) + + # Set up initial state with high threshold and very low level/dev to force downward adjustment + policy._queue_thresholds[mock_op] = 200 + policy._q_level_nbytes[mock_op] = 10.0 # Very low level + policy._q_level_dev[mock_op] = 1.0 # Very low deviation + policy._queue_history[mock_op] = deque([10, 11, 12, 13, 14, 15]) + + # Test downward adjustment (should be smoothed) + # threshold = max(10 + 4*1, 150) = 150, which is < 200, so should be smoothed + threshold = policy._update_queue_threshold(mock_op, 150) + + # Should be smoothed between 200 and 150 + self.assertLess(threshold, 200) + self.assertGreaterEqual(threshold, 150) + + # Test that the method works correctly - just verify it doesn't crash + # and returns a reasonable threshold value + mock_op2 = MagicMock() + policy2 = ConcurrencyCapBackpressurePolicy( + DataContext.get_current(), + {mock_op2: MagicMock()}, + MagicMock(), + ) + policy2._queue_thresholds[mock_op2] = 200 + policy2._q_level_nbytes[mock_op2] = 10.0 + policy2._q_level_dev[mock_op2] = 1.0 + policy2._queue_history[mock_op2] = deque([10, 11, 12, 13, 14, 15]) + + threshold_small = policy2._update_queue_threshold(mock_op2, 50) + + # Just verify it returns a reasonable threshold (at least as high as input) + self.assertGreaterEqual(threshold_small, 50) + + def test_effective_cap_calculation_with_trend(self): + """Test effective cap calculation with different trend scenarios.""" + mock_op = 
MagicMock() + mock_op.metrics.num_tasks_running = 5 + + policy = ConcurrencyCapBackpressurePolicy( + DataContext.get_current(), + {mock_op: MagicMock()}, + MagicMock(), + ) + + # Set up queue history for trend calculation + policy._queue_history[mock_op] = deque( + [100, 200, 300, 400, 500, 600, 700, 800, 900, 1000] + ) + policy._q_level_dev[mock_op] = 100.0 + policy._queue_thresholds[mock_op] = 500 + policy._concurrency_caps[mock_op] = 10 + + # Test with high pressure (queue > threshold) + with patch.object( + policy._resource_manager, + "get_op_outputs_usage_with_internal_and_downstream", + return_value=2000, + ): + effective_cap = policy._effective_cap(mock_op) + # Should be reduced due to high pressure + self.assertLess(effective_cap, 10) + self.assertGreaterEqual(effective_cap, 1) # Should be at least 1 + + def test_effective_cap_insufficient_history(self): + """Test effective cap when there's insufficient history for trend calculation.""" + mock_op = MagicMock() + mock_op.metrics.num_tasks_running = 5 + + policy = ConcurrencyCapBackpressurePolicy( + DataContext.get_current(), + {mock_op: MagicMock()}, + MagicMock(), + ) + + # Set up insufficient history (less than 6 samples) + policy._queue_history[mock_op] = deque([100, 200, 300]) + policy._concurrency_caps[mock_op] = 10 + + effective_cap = policy._effective_cap(mock_op) + # Should return max(1, running) when insufficient history + self.assertEqual(effective_cap, 5) + + def test_signal_calculation_formulas(self): + """Test pressure_signal and trend_signal calculation formulas.""" + # Test pressure_signal formula: (q_now - threshold) / max(1.0, dev) + pressure_cases = [ + (1000, 500, 100, 5.0, "High pressure"), + (500, 500, 100, 0.0, "Neutral pressure"), + (200, 500, 100, -3.0, "Low pressure"), + (500, 500, 0, 0.0, "Zero deviation uses scale=1.0"), + ] + + for q_now, threshold, dev, expected, description in pressure_cases: + with self.subTest(signal="pressure", description=description): + scale = max(1.0, 
float(dev)) + pressure_signal = (q_now - threshold) / scale + self.assertAlmostEqual(pressure_signal, expected, places=5) + + # Test trend_signal formula: (recent_avg - older_avg) / max(1.0, dev) + trend_cases = [ + (1000, 500, 100, 5.0, "Strong growth"), + (500, 500, 100, 0.0, "No trend"), + (200, 500, 100, -3.0, "Strong decline"), + (500, 500, 0, 0.0, "Zero deviation uses scale=1.0"), + ] + + for recent_avg, older_avg, dev, expected, description in trend_cases: + with self.subTest(signal="trend", description=description): + scale = max(1.0, float(dev)) + trend_signal = (recent_avg - older_avg) / scale + self.assertAlmostEqual(trend_signal, expected, places=5) + + def test_decision_rules_table_comprehensive(self): + """Test all decision rules from the table comprehensively.""" + test_cases = [ + # (pressure_signal, trend_signal, expected_step, description) + # High pressure scenarios + (2.5, 1.5, -1, "High pressure + growing trend -> backoff"), + (2.0, 1.0, -1, "High pressure + growing trend (boundary) -> backoff"), + (2.5, 0.5, 0, "High pressure + mild growth -> hold"), + (2.5, 0.0, 0, "High pressure + no trend -> hold"), + (2.5, -0.5, 0, "High pressure + mild decline -> hold"), + (2.5, -1.0, 0, "High pressure + declining trend -> hold"), + # Moderate pressure scenarios + (1.5, 1.5, 0, "Moderate pressure + growing trend -> hold"), + (1.0, 1.0, 0, "Moderate pressure + growing trend (boundary) -> hold"), + (1.5, 0.5, 0, "Moderate pressure + mild growth -> hold"), + (1.5, 0.0, 0, "Moderate pressure + no trend -> hold"), + (1.5, -0.5, 0, "Moderate pressure + mild decline -> hold"), + (1.5, -1.0, 0, "Moderate pressure + declining trend -> hold"), + # Low pressure scenarios + (-1.5, -1.5, 1, "Low pressure + declining trend -> increase"), + (-1.0, -1.0, 1, "Low pressure + declining trend (boundary) -> increase"), + (-1.5, -0.5, 0, "Low pressure + mild decline -> hold"), + (-1.5, 0.0, 0, "Low pressure + no trend -> hold"), + (-1.5, 0.5, 0, "Low pressure + mild growth -> 
hold"), + (-1.5, 1.0, 0, "Low pressure + growing trend -> hold"), + # Very low pressure scenarios + (-2.5, -2.5, 2, "Very low pressure + declining trend -> increase by 2"), + ( + -2.0, + -2.0, + 2, + "Very low pressure + declining trend (boundary) -> increase by 2", + ), + (-2.5, -1.5, 1, "Very low pressure + mild decline -> increase by 1"), + (-2.5, -1.0, 1, "Very low pressure + mild decline -> increase by 1"), + (-2.5, 0.0, 0, "Very low pressure + no trend -> hold"), + (-2.5, 0.5, 0, "Very low pressure + mild growth -> hold"), + (-2.5, 1.0, 0, "Very low pressure + growing trend -> hold"), + # Neutral scenarios + (0.5, 0.5, 0, "Low pressure + mild growth -> hold"), + (0.0, 0.0, 0, "Neutral pressure + no trend -> hold"), + (-0.5, -0.5, 0, "Low pressure + mild decline -> hold"), + # Edge cases + (1.0, 0.0, 0, "Moderate pressure + no trend (boundary) -> hold"), + (0.0, 1.0, 0, "Neutral pressure + growing trend -> hold"), + (0.0, -1.0, 0, "Neutral pressure + declining trend -> hold"), + ] + + for pressure_signal, trend_signal, expected_step, description in test_cases: + with self.subTest(description=description): + # Inlined decision step logic from the policy + if pressure_signal >= 2.0 and trend_signal >= 1.0: + step = -1 + elif pressure_signal >= 1.0 and trend_signal > 0.0: + step = 0 + elif pressure_signal <= -2.0 and trend_signal <= -2.0: + step = +2 + elif pressure_signal <= -1.0 and trend_signal <= -1.0: + step = +1 + else: + step = 0 + + self.assertEqual( + step, + expected_step, + f"Failed for pressure={pressure_signal}, trend={trend_signal}", + ) + + def test_ewma_calculation_formulas(self): + """Test EWMA level, deviation, and alpha calculation formulas.""" + # Test EWMA level formula: (1 - alpha) * prev + alpha * sample + level_cases = [ + (100.0, 120.0, 0.2, 104.0, "Normal alpha"), + (100.0, 80.0, 0.2, 96.0, "Normal alpha down"), + (100.0, 100.0, 0.2, 100.0, "Stable"), + (0.0, 100.0, 0.2, 20.0, "Bootstrap"), + ] + + for prev_level, sample, alpha, 
expected, description in level_cases: + with self.subTest(formula="level", description=description): + new_level = (1 - alpha) * prev_level + alpha * sample + self.assertAlmostEqual(new_level, expected, places=5) + + # Test EWMA deviation formula: (1 - alpha) * prev_dev + alpha * abs(sample - prev_level) + dev_cases = [ + (20.0, 120.0, 100.0, 0.2, 20.0, "Growing"), + (20.0, 100.0, 100.0, 0.2, 16.0, "Stable"), + (0.0, 100.0, 0.0, 0.2, 20.0, "Bootstrap"), + ] + + for prev_dev, sample, prev_level, alpha, expected, description in dev_cases: + with self.subTest(formula="deviation", description=description): + new_dev = (1 - alpha) * prev_dev + alpha * abs(sample - prev_level) + self.assertAlmostEqual(new_dev, expected, places=5) + + # Test alpha_up calculation: 1.0 - (1.0 - EWMA_ALPHA) ** 2 + alpha_cases = [ + (0.2, 0.36, "Normal EWMA_ALPHA"), + (0.1, 0.19, "Low EWMA_ALPHA"), + (0.5, 0.75, "High EWMA_ALPHA"), + ] + + for EWMA_ALPHA, expected, description in alpha_cases: + with self.subTest(formula="alpha_up", description=description): + alpha_up = 1.0 - (1.0 - EWMA_ALPHA) ** 2 + self.assertAlmostEqual(alpha_up, expected, places=5) + + def test_threshold_calculation_formula(self): + """Test threshold calculation: max(level + K_DEV * dev, q_now).""" + test_cases = [ + (100.0, 20.0, 150.0, 4.0, 180.0, "Normal case"), + (100.0, 20.0, 200.0, 4.0, 200.0, "High queue"), + (100.0, 0.0, 150.0, 4.0, 150.0, "Zero deviation"), + (100.0, 20.0, 150.0, 2.0, 150.0, "Lower K_DEV"), + ] + + for level, dev, q_now, K_DEV, expected, description in test_cases: + with self.subTest(description=description): + threshold = max(level + K_DEV * dev, q_now) + self.assertAlmostEqual(threshold, expected, places=5) + + def test_threshold_update_logic_comprehensive(self): + """Test comprehensive threshold update logic including bootstrap, upward, and downward cases.""" + mock_op = MagicMock() + policy = ConcurrencyCapBackpressurePolicy( + DataContext.get_current(), + {mock_op: MagicMock()}, + 
MagicMock(), + ) + + # Test 1: Bootstrap case (prev_threshold = 0) + policy._queue_thresholds[mock_op] = 0 + policy._queue_history[mock_op] = deque([100]) + threshold1 = policy._update_queue_threshold(mock_op, 100) + # Bootstrap: threshold = max(level + K_DEV * dev, q_now) = max(100 + 4*0, 100) = 100 + self.assertEqual(threshold1, 100) + + # Test 2: Upward adjustment (threshold >= prev_threshold) + policy._queue_thresholds[mock_op] = 100 + policy._q_level_nbytes[mock_op] = 50.0 + policy._q_level_dev[mock_op] = 10.0 + policy._queue_history[mock_op] = deque([50, 60, 70, 80, 90, 100]) + threshold2 = policy._update_queue_threshold(mock_op, 200) + # The EWMA will update level and dev, so we can't predict exact value + # Just verify it's >= 200 (upward adjustment) + self.assertGreaterEqual(threshold2, 200) + + # Test 3: Downward adjustment (threshold < prev_threshold) + policy._queue_thresholds[mock_op] = 200 + policy._q_level_nbytes[ + mock_op + ] = 10.0 # Very low level to force downward adjustment + policy._q_level_dev[mock_op] = 1.0 # Very low deviation + policy._queue_history[mock_op] = deque([10, 11, 12, 13, 14, 15]) + threshold3 = policy._update_queue_threshold(mock_op, 150) + # threshold = max(10 + 4*1, 150) = 150, which is < 200, so should be smoothed + self.assertLess(threshold3, 200) + self.assertGreaterEqual(threshold3, 150) + + # Test 4: Zero threshold case + fresh_mock_op = MagicMock() + fresh_policy = ConcurrencyCapBackpressurePolicy( + DataContext.get_current(), + {fresh_mock_op: MagicMock()}, + MagicMock(), + ) + fresh_policy._queue_thresholds[fresh_mock_op] = 0 + fresh_policy._queue_history[fresh_mock_op] = deque([0]) + # Fresh policy starts with clean EWMA state + threshold5 = fresh_policy._update_queue_threshold(fresh_mock_op, 0) + self.assertEqual(threshold5, 1) # Should round up to 1 + + def test_trend_and_effective_cap_formulas(self): + """Test trend calculation and effective cap formulas.""" + # Test trend calculation: recent_avg - older_avg + 
trend_cases = [ + ([100, 200, 300, 400, 500, 600], 500.0, 200.0, 300.0, "6 samples"), + ([100, 200, 300, 400, 500, 600, 700], 600.0, 300.0, 300.0, "7 samples"), + ] + + for ( + history, + expected_recent, + expected_older, + expected_trend, + description, + ) in trend_cases: + with self.subTest(formula="trend", description=description): + h = list(history) + recent_window = len(h) // 2 + older_window = len(h) // 2 + + recent_avg = sum(h[-recent_window:]) / float(recent_window) + older_avg = sum( + h[-(recent_window + older_window) : -recent_window] + ) / float(older_window) + trend = recent_avg - older_avg + + self.assertAlmostEqual(recent_avg, expected_recent, places=5) + self.assertAlmostEqual(older_avg, expected_older, places=5) + self.assertAlmostEqual(trend, expected_trend, places=5) + + # Test effective cap formula: max(1, running + step) + cap_cases = [ + (5, -1, 4, "Reduce by 1"), + (5, 0, 5, "No change"), + (5, 1, 6, "Increase by 1"), + (1, -1, 1, "Min cap"), + ] + + for running, step, expected, description in cap_cases: + with self.subTest(formula="effective_cap", description=description): + effective_cap = max(1, running + step) + self.assertEqual(effective_cap, expected) + + def test_ewma_asymmetric_behavior(self): + """Test EWMA asymmetric behavior and level calculation.""" + # Test alpha selection: alpha_up if sample > prev else EWMA_ALPHA + alpha_cases = [ + (100.0, 150.0, 0.2, 0.36, "Rising uses alpha_up"), + (100.0, 50.0, 0.2, 0.2, "Falling uses EWMA_ALPHA"), + (100.0, 100.0, 0.2, 0.2, "Stable uses EWMA_ALPHA"), + ] + + for prev_level, sample, EWMA_ALPHA, expected, description in alpha_cases: + with self.subTest(behavior="alpha_selection", description=description): + alpha_up = 1.0 - (1.0 - EWMA_ALPHA) ** 2 + alpha = alpha_up if sample > prev_level else EWMA_ALPHA + self.assertAlmostEqual(alpha, expected, places=5) + + # Test level calculation with asymmetric alpha + level_cases = [ + (100.0, 150.0, 0.2, 118.0, "Rising with alpha_up"), + (100.0, 
50.0, 0.2, 90.0, "Falling with EWMA_ALPHA"), + (0.0, 100.0, 0.2, 100.0, "Bootstrap uses sample"), + ] + + for prev_level, sample, EWMA_ALPHA, expected, description in level_cases: + with self.subTest(behavior="level_calculation", description=description): + if prev_level <= 0: + level = sample + else: + alpha_up = 1.0 - (1.0 - EWMA_ALPHA) ** 2 + alpha = alpha_up if sample > prev_level else EWMA_ALPHA + level = (1 - alpha) * prev_level + alpha * sample + self.assertAlmostEqual(level, expected, places=5) + + def test_simple_calculation_formulas(self): + """Test simple calculation formulas: scale, min_samples, and windows.""" + # Test scale calculation: max(1.0, float(dev)) + scale_cases = [ + (100.0, 100.0, "Normal deviation"), + (0.0, 1.0, "Zero deviation"), + (0.5, 1.0, "Small deviation"), + (1.1, 1.1, "Just above unit"), + ] + + for dev, expected, description in scale_cases: + with self.subTest(formula="scale", description=description): + scale = max(1.0, float(dev)) + self.assertAlmostEqual(scale, expected, places=5) + + # Test min_samples calculation: recent_window + older_window + min_samples_cases = [ + (10, 10, "HISTORY_LEN=10"), + (6, 6, "HISTORY_LEN=6"), + (12, 12, "HISTORY_LEN=12"), + ] + + for HISTORY_LEN, expected, description in min_samples_cases: + with self.subTest(formula="min_samples", description=description): + recent_window = HISTORY_LEN // 2 + older_window = HISTORY_LEN // 2 + min_samples = recent_window + older_window + self.assertEqual(min_samples, expected) + + # Test window calculation: recent_window = older_window = HISTORY_LEN // 2 + window_cases = [ + (10, 5, 5, "HISTORY_LEN=10"), + (6, 3, 3, "HISTORY_LEN=6"), + (9, 4, 4, "HISTORY_LEN=9 (integer division)"), + ] + + for HISTORY_LEN, expected_recent, expected_older, description in window_cases: + with self.subTest(formula="windows", description=description): + recent_window = HISTORY_LEN // 2 + older_window = HISTORY_LEN // 2 + self.assertEqual(recent_window, expected_recent) + 
self.assertEqual(older_window, expected_older) + if __name__ == "__main__": import sys - sys.exit(pytest.main(["-v", __file__])) + sys.exit(pytest.main(["-v", __file__])) \ No newline at end of file diff --git a/release/release_tests.yaml b/release/release_tests.yaml index 51ead844ec15..a67b82e48d32 100644 --- a/release/release_tests.yaml +++ b/release/release_tests.yaml @@ -1657,6 +1657,7 @@ runtime_env: # Enable verbose stats for resource manager - RAY_DATA_DEBUG_RESOURCE_MANAGER=1 + - RAY_DATA_ENABLE_DYNAMIC_OUTPUT_QUEUE_SIZE_BACKPRESSURE=1 # 'type: gpu' means: use the 'ray-ml' image. type: gpu From 8e4970444cde8a4fc811cb31ed93fd35fcbed222 Mon Sep 17 00:00:00 2001 From: Srinath Krishnamachari Date: Wed, 22 Oct 2025 17:04:06 +0000 Subject: [PATCH 2/9] Cleanups Signed-off-by: Srinath Krishnamachari --- python/ray/data/_internal/execution/resource_manager.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/ray/data/_internal/execution/resource_manager.py b/python/ray/data/_internal/execution/resource_manager.py index fad67c36f88a..76bb266d5210 100644 --- a/python/ray/data/_internal/execution/resource_manager.py +++ b/python/ray/data/_internal/execution/resource_manager.py @@ -47,7 +47,7 @@ class ResourceManager: # store memory limit for the streaming executor, # when `ReservationOpResourceAllocator` is enabled. 
DEFAULT_OBJECT_STORE_MEMORY_LIMIT_FRACTION = env_float( - "RAY_DATA_OBJECT_STORE_MEMORY_LIMIT_FRACTION", 0.75 + "RAY_DATA_OBJECT_STORE_MEMORY_LIMIT_FRACTION", 0.5 ) # The fraction of the object store capacity that will be used as the default object From bd8ad05c50e2612f81126032442e33de042003ac Mon Sep 17 00:00:00 2001 From: Srinath Krishnamachari Date: Wed, 22 Oct 2025 17:10:28 +0000 Subject: [PATCH 3/9] Cleanup Signed-off-by: Srinath Krishnamachari --- .../_internal/execution/resource_manager.py | 7 ------- .../ray/data/tests/test_resource_manager.py | 19 +++---------------- .../ray/data/tests/test_streaming_executor.py | 3 --- 3 files changed, 3 insertions(+), 26 deletions(-) diff --git a/python/ray/data/_internal/execution/resource_manager.py b/python/ray/data/_internal/execution/resource_manager.py index 76bb266d5210..f2925d0ef643 100644 --- a/python/ray/data/_internal/execution/resource_manager.py +++ b/python/ray/data/_internal/execution/resource_manager.py @@ -657,13 +657,6 @@ def _update_reservation(self): self._total_shared = remaining - def can_submit_new_task(self, op: PhysicalOperator) -> bool: - if op not in self._op_budgets: - return True - budget = self._op_budgets[op] - res = op.incremental_resource_usage().satisfies_limit(budget) - return res - def get_budget(self, op: PhysicalOperator) -> Optional[ExecutionResources]: return self._op_budgets.get(op) diff --git a/python/ray/data/tests/test_resource_manager.py b/python/ray/data/tests/test_resource_manager.py index ebce02f362d0..bc007694ed40 100644 --- a/python/ray/data/tests/test_resource_manager.py +++ b/python/ray/data/tests/test_resource_manager.py @@ -398,13 +398,6 @@ def mock_get_global_limits(): nonlocal global_limits return global_limits - def can_submit_new_task(allocator, op): - """Helper to check if operator can submit new tasks based on budget.""" - budget = allocator.get_budget(op) - if budget is None: - return True - return op.incremental_resource_usage().satisfies_limit(budget) - 
resource_manager = ResourceManager( topo, ExecutionOptions(), MagicMock(), DataContext.get_current() ) @@ -447,9 +440,7 @@ def can_submit_new_task(allocator, op): # Test budgets. assert allocator._op_budgets[o2] == ExecutionResources(8, 0, 375) assert allocator._op_budgets[o3] == ExecutionResources(8, 0, 375) - # Test can_submit_new_task and max_task_output_bytes_to_read. - assert can_submit_new_task(allocator, o2) - assert can_submit_new_task(allocator, o3) + # Test max_task_output_bytes_to_read. assert allocator.max_task_output_bytes_to_read(o2) == 500 assert allocator.max_task_output_bytes_to_read(o3) == 500 @@ -478,9 +469,7 @@ def can_submit_new_task(allocator, op): assert allocator._op_budgets[o2] == ExecutionResources(3, 0, 113) # memory_budget[o3] = 95 + 225/2 = 207 (rounded down) assert allocator._op_budgets[o3] == ExecutionResources(5, 0, 207) - # Test can_submit_new_task and max_task_output_bytes_to_read. - assert can_submit_new_task(allocator, o2) - assert can_submit_new_task(allocator, o3) + # Test max_task_output_bytes_to_read. # max_task_output_bytes_to_read(o2) = 112.5 + 25 = 138 (rounded up) assert allocator.max_task_output_bytes_to_read(o2) == 138 # max_task_output_bytes_to_read(o3) = 207.5 + 50 = 257 (rounded down) @@ -512,9 +501,7 @@ def can_submit_new_task(allocator, op): assert allocator._op_budgets[o2] == ExecutionResources(1.5, 0, 50) # memory_budget[o3] = 70 + 100/2 = 120 assert allocator._op_budgets[o3] == ExecutionResources(2.5, 0, 120) - # Test can_submit_new_task and max_task_output_bytes_to_read. - assert can_submit_new_task(allocator, o2) - assert can_submit_new_task(allocator, o3) + # Test max_task_output_bytes_to_read. 
# max_task_output_bytes_to_read(o2) = 50 + 0 = 50 assert allocator.max_task_output_bytes_to_read(o2) == 50 # max_task_output_bytes_to_read(o3) = 120 + 25 = 145 diff --git a/python/ray/data/tests/test_streaming_executor.py b/python/ray/data/tests/test_streaming_executor.py index bbc069df866a..a6c2629b09d9 100644 --- a/python/ray/data/tests/test_streaming_executor.py +++ b/python/ray/data/tests/test_streaming_executor.py @@ -296,9 +296,6 @@ def test_get_eligible_operators_to_run(ray_start_regular_shared): resource_manager.get_op_usage = MagicMock( side_effect=lambda op: ExecutionResources(0, 0, memory_usage[op]) ) - resource_manager.op_resource_allocator.can_submit_new_task = MagicMock( - return_value=True - ) def _get_eligible_ops_to_run(ensure_liveness: bool): return get_eligible_operators(topo, [], ensure_liveness=ensure_liveness) From 01463ef6344748f2d706315c876f03dc34f50ce5 Mon Sep 17 00:00:00 2001 From: Srinath Krishnamachari Date: Wed, 22 Oct 2025 17:13:45 +0000 Subject: [PATCH 4/9] Cleanup Signed-off-by: Srinath Krishnamachari --- python/ray/data/context.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/python/ray/data/context.py b/python/ray/data/context.py index b261aa207641..6a27eab664ea 100644 --- a/python/ray/data/context.py +++ b/python/ray/data/context.py @@ -348,9 +348,8 @@ class DataContext: large_args_threshold: Size in bytes after which point task arguments are considered large. Choose a value so that the data transfer overhead is significant in comparison to task scheduling (i.e., low tens of ms). - use_polars_sort: Whether to use Polars for tabular dataset sorts, groupbys, and + use_polars: Whether to use Polars for tabular dataset sorts, groupbys, and aggregations. - use_polars_join: Whether to use Polars for join operations. eager_free: Whether to eagerly free memory. decoding_size_estimation: Whether to estimate in-memory decoding data size for data source. 
From 907d0226ac36d0e80891d1edf0988a133c112da7 Mon Sep 17 00:00:00 2001 From: Srinath Krishnamachari Date: Wed, 22 Oct 2025 17:28:56 +0000 Subject: [PATCH 5/9] Fixups Signed-off-by: Srinath Krishnamachari --- .../concurrency_cap_backpressure_policy.py | 50 ++++++++++++------- .../_internal/execution/resource_manager.py | 2 +- python/ray/data/context.py | 2 +- .../data/tests/test_backpressure_policies.py | 23 ++++----- 4 files changed, 46 insertions(+), 31 deletions(-) diff --git a/python/ray/data/_internal/execution/backpressure_policy/concurrency_cap_backpressure_policy.py b/python/ray/data/_internal/execution/backpressure_policy/concurrency_cap_backpressure_policy.py index 8e77ecc3c02f..2ba80a923b4e 100644 --- a/python/ray/data/_internal/execution/backpressure_policy/concurrency_cap_backpressure_policy.py +++ b/python/ray/data/_internal/execution/backpressure_policy/concurrency_cap_backpressure_policy.py @@ -180,7 +180,7 @@ def can_add_input(self, op: "PhysicalOperator") -> bool: last_effective_cap = self._last_effective_caps.get(op, None) if last_effective_cap != effective_cap: logger.debug( - "Effective concurrency cap changed for operator %s: %d -> %d" + "Effective concurrency cap changed for operator %s: %d -> %d " "running=%d tasks, queue=%d bytes, threshold=%d bytes", op.name, last_effective_cap, @@ -322,11 +322,6 @@ def _update_queue_threshold( # Uses same asymmetric behavior as EWMA: slow to adjust downward # Example: prev_threshold=200, threshold=100 -> smoothed using asymmetric EWMA smoothed = int(self._update_ewma_asymmetric(prev_threshold, threshold)) - if smoothed >= prev_threshold: - # Ensures progress when small deltas round to 0 - # Example: prev_threshold=200, threshold=195 -> smoothed=199, but 199>=200, so use 199 - smoothed = prev_threshold - 1 - self._queue_thresholds[op] = max(1, smoothed) return self._queue_thresholds[op] @@ -423,20 +418,41 @@ def _effective_cap(self, op: "PhysicalOperator") -> int: trend_signal = (recent_avg - older_avg) / 
scale # Quantized controller decision - if pressure_signal >= 2.0 and trend_signal >= 1.0: - step = -1 - elif pressure_signal >= 1.0 and trend_signal > 0.0: - step = 0 - elif pressure_signal <= -1.0 and trend_signal <= -1.0: - step = +1 - elif pressure_signal <= -2.0 and trend_signal <= -2.0: - step = +2 - else: - step = 0 + step = self._quantized_controller_step(pressure_signal, trend_signal) # Apply step to current running concurrency, clamp by configured cap. target = max(1, running + step) cap_cfg = self._concurrency_caps[op] if not math.isinf(cap_cfg): target = min(target, int(cap_cfg)) - return target \ No newline at end of file + return target + + def _quantized_controller_step( + self, pressure_signal: float, trend_signal: float + ) -> int: + """Compute the quantized controller step based on pressure and trend signals. + + This method implements the decision logic for the quantized controller: + - High pressure + growing trend = emergency backoff (-1) + - High pressure + stable/declining trend = wait and see (0) + - Low pressure + declining trend = safe to grow (+1) + - Very low pressure + strong improvement = aggressive growth (+2) + - Moderate signals = maintain current concurrency (0) + + Args: + pressure_signal: Normalized pressure signal (queue vs threshold) + trend_signal: Normalized trend signal (recent vs older average) + + Returns: + Step adjustment: -1, 0, +1, or +2 + """ + if pressure_signal >= 2.0 and trend_signal >= 1.0: + return -1 + elif pressure_signal >= 1.0 and trend_signal >= 0.0: + return 0 + elif pressure_signal <= -2.0 and trend_signal <= -2.0: + return +2 + elif pressure_signal <= -1.0 and trend_signal <= -1.0: + return +1 + else: + return 0 diff --git a/python/ray/data/_internal/execution/resource_manager.py b/python/ray/data/_internal/execution/resource_manager.py index f2925d0ef643..f4d5d3f9e448 100644 --- a/python/ray/data/_internal/execution/resource_manager.py +++ b/python/ray/data/_internal/execution/resource_manager.py @@ 
-796,4 +796,4 @@ def update_usages(self): ): self._op_budgets[op] = self._op_budgets[op].copy( object_store_memory=float("inf") - ) \ No newline at end of file + ) diff --git a/python/ray/data/context.py b/python/ray/data/context.py index 6a27eab664ea..dc13cc95b2f1 100644 --- a/python/ray/data/context.py +++ b/python/ray/data/context.py @@ -787,4 +787,4 @@ def set_dataset_logger_id(self, dataset_id: str) -> None: # Backwards compatibility alias. -DatasetContext = DataContext \ No newline at end of file +DatasetContext = DataContext diff --git a/python/ray/data/tests/test_backpressure_policies.py b/python/ray/data/tests/test_backpressure_policies.py index 1939cc18d806..8db313c3861a 100644 --- a/python/ray/data/tests/test_backpressure_policies.py +++ b/python/ray/data/tests/test_backpressure_policies.py @@ -397,19 +397,18 @@ def test_decision_rules_table_comprehensive(self): (0.0, -1.0, 0, "Neutral pressure + declining trend -> hold"), ] + # Create a policy instance to access the helper method + mock_op = MagicMock() + policy = ConcurrencyCapBackpressurePolicy( + DataContext.get_current(), + {mock_op: MagicMock()}, + MagicMock(), + ) + for pressure_signal, trend_signal, expected_step, description in test_cases: with self.subTest(description=description): - # Inlined decision step logic from the policy - if pressure_signal >= 2.0 and trend_signal >= 1.0: - step = -1 - elif pressure_signal >= 1.0 and trend_signal > 0.0: - step = 0 - elif pressure_signal <= -2.0 and trend_signal <= -2.0: - step = +2 - elif pressure_signal <= -1.0 and trend_signal <= -1.0: - step = +1 - else: - step = 0 + # Use the actual helper method from the policy + step = policy._quantized_controller_step(pressure_signal, trend_signal) self.assertEqual( step, @@ -643,4 +642,4 @@ def test_simple_calculation_formulas(self): if __name__ == "__main__": import sys - sys.exit(pytest.main(["-v", __file__])) \ No newline at end of file + sys.exit(pytest.main(["-v", __file__])) From 
999fb06a7385c677fb090ded1c02b26736306dd6 Mon Sep 17 00:00:00 2001 From: Srinath Krishnamachari Date: Wed, 22 Oct 2025 17:31:58 +0000 Subject: [PATCH 6/9] Cleanup Signed-off-by: Srinath Krishnamachari --- .../concurrency_cap_backpressure_policy.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/python/ray/data/_internal/execution/backpressure_policy/concurrency_cap_backpressure_policy.py b/python/ray/data/_internal/execution/backpressure_policy/concurrency_cap_backpressure_policy.py index 2ba80a923b4e..9805060f4fc3 100644 --- a/python/ray/data/_internal/execution/backpressure_policy/concurrency_cap_backpressure_policy.py +++ b/python/ray/data/_internal/execution/backpressure_policy/concurrency_cap_backpressure_policy.py @@ -302,11 +302,6 @@ def _update_queue_threshold( # Step 4: cache & return with gentle downward response using EWMA_ALPHA prev_threshold = self._queue_thresholds[op] - # Idle/off allowed - if threshold == 0: - self._queue_thresholds[op] = 0 - return 0 - # Bootstrap if prev_threshold == 0: self._queue_thresholds[op] = max(1, threshold) From 8d65ebf63ed76d416c460635defddd774c2e6e9f Mon Sep 17 00:00:00 2001 From: Srinath Krishnamachari Date: Wed, 22 Oct 2025 17:34:20 +0000 Subject: [PATCH 7/9] Cleanup Signed-off-by: Srinath Krishnamachari --- .../backpressure_policy/concurrency_cap_backpressure_policy.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/ray/data/_internal/execution/backpressure_policy/concurrency_cap_backpressure_policy.py b/python/ray/data/_internal/execution/backpressure_policy/concurrency_cap_backpressure_policy.py index 9805060f4fc3..0c2498fdc8cf 100644 --- a/python/ray/data/_internal/execution/backpressure_policy/concurrency_cap_backpressure_policy.py +++ b/python/ray/data/_internal/execution/backpressure_policy/concurrency_cap_backpressure_policy.py @@ -357,7 +357,7 @@ def _effective_cap(self, op: "PhysicalOperator") -> int: | | | | (immediate reduction to | | | | | | prevent overload) | | 
+----------+----------+----------+--------------------------------+------------------+ - | >= +1.0 | > 0.0 | 0 | Wait and see | +1.5, +0.5 -> 0 | + | >= +1.0 | >= 0.0 | 0 | Wait and see | +1.5, +0.5 -> 0 | | | | | (let current level stabilize) | | +----------+----------+----------+--------------------------------+------------------+ | <= -1.0 | <= -1.0 | +1 | Conservative growth | -1.5, -1.0 -> +1 | From d5037b98b4718950b28fda66ff8256ab2436828c Mon Sep 17 00:00:00 2001 From: Srinath Krishnamachari Date: Wed, 22 Oct 2025 20:05:29 +0000 Subject: [PATCH 8/9] Address comments Signed-off-by: Srinath Krishnamachari --- .../concurrency_cap_backpressure_policy.py | 4 ++- .../_internal/execution/resource_manager.py | 32 +++++++------------ .../data/tests/test_backpressure_policies.py | 8 +++-- 3 files changed, 20 insertions(+), 24 deletions(-) diff --git a/python/ray/data/_internal/execution/backpressure_policy/concurrency_cap_backpressure_policy.py b/python/ray/data/_internal/execution/backpressure_policy/concurrency_cap_backpressure_policy.py index 0c2498fdc8cf..06669a1cf7ed 100644 --- a/python/ray/data/_internal/execution/backpressure_policy/concurrency_cap_backpressure_policy.py +++ b/python/ray/data/_internal/execution/backpressure_policy/concurrency_cap_backpressure_policy.py @@ -83,6 +83,7 @@ class ConcurrencyCapBackpressurePolicy(BackpressurePolicy): K_DEV = 4.0 def __init__(self, *args, **kwargs): + """Initialize the ConcurrencyCapBackpressurePolicy.""" super().__init__(*args, **kwargs) # Explicit concurrency caps for each operator. Infinite if not specified. @@ -164,7 +165,8 @@ def can_add_input(self, op: "PhysicalOperator") -> bool: # Observe fresh queue size for this operator and its downstream. 
current_queue_size_bytes = ( - self._resource_manager.get_op_outputs_usage_with_internal_and_downstream(op) + self._resource_manager.get_op_internal_object_store_usage(op) + + self._resource_manager.get_op_outputs_object_store_usage_with_downstream(op) ) # Update short history and refresh the adaptive threshold. diff --git a/python/ray/data/_internal/execution/resource_manager.py b/python/ray/data/_internal/execution/resource_manager.py index f4d5d3f9e448..cdc6db510b3c 100644 --- a/python/ray/data/_internal/execution/resource_manager.py +++ b/python/ray/data/_internal/execution/resource_manager.py @@ -277,14 +277,6 @@ def get_op_usage(self, op: PhysicalOperator) -> ExecutionResources: """Return the resource usage of the given operator at the current time.""" return self._op_usages[op] - def get_mem_op_internal(self, op: PhysicalOperator) -> int: - """Return the memory usage of the internal buffers of the given operator.""" - return self._mem_op_internal[op] - - def get_mem_op_outputs(self, op: PhysicalOperator) -> int: - """Return the memory usage of the outputs of the given operator.""" - return self._mem_op_outputs[op] - def get_op_usage_str(self, op: PhysicalOperator) -> str: """Return a human-readable string representation of the resource usage of the given operator.""" @@ -296,8 +288,8 @@ def get_op_usage_str(self, op: PhysicalOperator) -> str: ) if self._debug: usage_str += ( - f" (in={memory_string(self.get_mem_op_internal(op))}," - f"out={memory_string(self.get_mem_op_outputs(op))})" + f" (in={memory_string(self._mem_op_internal[op])}," + f"out={memory_string(self._mem_op_outputs[op])})" ) if ( isinstance(self._op_resource_allocator, ReservationOpResourceAllocator) @@ -380,12 +372,12 @@ def get_downstream_eligible_ops( else: yield from self.get_downstream_eligible_ops(next_op) - def get_op_outputs_usage_with_downstream(self, op: PhysicalOperator) -> float: + def get_op_outputs_object_store_usage_with_downstream(self, op: PhysicalOperator) -> int: """Get the 
outputs memory usage of the given operator, including the downstream ineligible operators. """ # Outputs usage of the current operator. - op_outputs_usage = self.get_mem_op_outputs(op) + op_outputs_usage = self._mem_op_outputs[op] # Also account the downstream ineligible operators' memory usage. op_outputs_usage += sum( self.get_op_usage(next_op).object_store_memory @@ -393,13 +385,11 @@ def get_op_outputs_usage_with_downstream(self, op: PhysicalOperator) -> float: ) return op_outputs_usage - def get_op_outputs_usage_with_internal_and_downstream( + def get_op_internal_object_store_usage( self, op: PhysicalOperator - ) -> float: - """Get the outputs memory usage of the given operator, including the internal usage and the downstream ineligible operators.""" - return self.get_mem_op_internal(op) + self.get_op_outputs_usage_with_downstream( - op - ) + ) -> int: + """Get the internal object store memory usage of the given operator""" + return self._mem_op_internal[op] class OpResourceAllocator(ABC): @@ -685,7 +675,7 @@ def max_task_output_bytes_to_read(self, op: PhysicalOperator) -> Optional[int]: return None res = self._op_budgets[op].object_store_memory # Add the remaining of `_reserved_for_op_outputs`. - op_outputs_usage = self._resource_manager.get_op_outputs_usage_with_downstream( + op_outputs_usage = self._resource_manager.get_op_outputs_object_store_usage_with_downstream( op ) res += max(self._reserved_for_op_outputs[op] - op_outputs_usage, 0) @@ -715,11 +705,11 @@ def update_usages(self): op_mem_usage = 0 # Add the memory usage of the operator itself, # excluding `_reserved_for_op_outputs`. - op_mem_usage += self._resource_manager.get_mem_op_internal(op) + op_mem_usage += self._resource_manager.get_op_internal_object_store_usage(op) # Add the portion of op outputs usage that has # exceeded `_reserved_for_op_outputs`. 
op_outputs_usage = ( - self._resource_manager.get_op_outputs_usage_with_downstream(op) + self._resource_manager.get_op_outputs_object_store_usage_with_downstream(op) ) op_mem_usage += max(op_outputs_usage - self._reserved_for_op_outputs[op], 0) op_usage = self._resource_manager.get_op_usage(op).copy( diff --git a/python/ray/data/tests/test_backpressure_policies.py b/python/ray/data/tests/test_backpressure_policies.py index 8db313c3861a..76ff6051f7b1 100644 --- a/python/ray/data/tests/test_backpressure_policies.py +++ b/python/ray/data/tests/test_backpressure_policies.py @@ -292,8 +292,12 @@ def test_effective_cap_calculation_with_trend(self): # Test with high pressure (queue > threshold) with patch.object( policy._resource_manager, - "get_op_outputs_usage_with_internal_and_downstream", - return_value=2000, + "get_op_internal_object_store_usage", + return_value=1000, + ), patch.object( + policy._resource_manager, + "get_op_outputs_object_store_usage_with_downstream", + return_value=1000, ): effective_cap = policy._effective_cap(mock_op) # Should be reduced due to high pressure From ff919660aeeeab7d912768807cca6413ac04fa4a Mon Sep 17 00:00:00 2001 From: Srinath Krishnamachari Date: Wed, 22 Oct 2025 20:40:51 +0000 Subject: [PATCH 9/9] Lint Signed-off-by: Srinath Krishnamachari --- .../concurrency_cap_backpressure_policy.py | 4 +++- .../_internal/execution/resource_manager.py | 20 ++++++++++--------- 2 files changed, 14 insertions(+), 10 deletions(-) diff --git a/python/ray/data/_internal/execution/backpressure_policy/concurrency_cap_backpressure_policy.py b/python/ray/data/_internal/execution/backpressure_policy/concurrency_cap_backpressure_policy.py index 06669a1cf7ed..97802d04b906 100644 --- a/python/ray/data/_internal/execution/backpressure_policy/concurrency_cap_backpressure_policy.py +++ b/python/ray/data/_internal/execution/backpressure_policy/concurrency_cap_backpressure_policy.py @@ -166,7 +166,9 @@ def can_add_input(self, op: "PhysicalOperator") -> bool: # 
Observe fresh queue size for this operator and its downstream. current_queue_size_bytes = ( self._resource_manager.get_op_internal_object_store_usage(op) - + self._resource_manager.get_op_outputs_object_store_usage_with_downstream(op) + + self._resource_manager.get_op_outputs_object_store_usage_with_downstream( + op + ) ) # Update short history and refresh the adaptive threshold. diff --git a/python/ray/data/_internal/execution/resource_manager.py b/python/ray/data/_internal/execution/resource_manager.py index cdc6db510b3c..b4ac32aed28b 100644 --- a/python/ray/data/_internal/execution/resource_manager.py +++ b/python/ray/data/_internal/execution/resource_manager.py @@ -372,7 +372,9 @@ def get_downstream_eligible_ops( else: yield from self.get_downstream_eligible_ops(next_op) - def get_op_outputs_object_store_usage_with_downstream(self, op: PhysicalOperator) -> int: + def get_op_outputs_object_store_usage_with_downstream( + self, op: PhysicalOperator + ) -> int: """Get the outputs memory usage of the given operator, including the downstream ineligible operators. """ @@ -385,9 +387,7 @@ def get_op_outputs_object_store_usage_with_downstream(self, op: PhysicalOperator ) return op_outputs_usage - def get_op_internal_object_store_usage( - self, op: PhysicalOperator - ) -> int: + def get_op_internal_object_store_usage(self, op: PhysicalOperator) -> int: """Get the internal object store memory usage of the given operator""" return self._mem_op_internal[op] @@ -675,8 +675,8 @@ def max_task_output_bytes_to_read(self, op: PhysicalOperator) -> Optional[int]: return None res = self._op_budgets[op].object_store_memory # Add the remaining of `_reserved_for_op_outputs`. 
- op_outputs_usage = self._resource_manager.get_op_outputs_object_store_usage_with_downstream( - op + op_outputs_usage = ( + self._resource_manager.get_op_outputs_object_store_usage_with_downstream(op) ) res += max(self._reserved_for_op_outputs[op] - op_outputs_usage, 0) if math.isinf(res): @@ -705,11 +705,13 @@ def update_usages(self): op_mem_usage = 0 # Add the memory usage of the operator itself, # excluding `_reserved_for_op_outputs`. - op_mem_usage += self._resource_manager.get_op_internal_object_store_usage(op) + op_mem_usage += self._resource_manager.get_op_internal_object_store_usage( + op + ) # Add the portion of op outputs usage that has # exceeded `_reserved_for_op_outputs`. - op_outputs_usage = ( - self._resource_manager.get_op_outputs_object_store_usage_with_downstream(op) + op_outputs_usage = self._resource_manager.get_op_outputs_object_store_usage_with_downstream( + op ) op_mem_usage += max(op_outputs_usage - self._reserved_for_op_outputs[op], 0) op_usage = self._resource_manager.get_op_usage(op).copy(