import io
import logging
import re
import time
from unittest.mock import MagicMock, patch

import pytest

import ray
from ray.data._internal.execution.operators.input_data_buffer import (
    InputDataBuffer,
)
from ray.data._internal.execution.operators.task_pool_map_operator import (
    MapOperator,
)
from ray.data._internal.execution.streaming_executor import StreamingExecutor
from ray.data._internal.issue_detection.detectors.hanging_detector import (
    DEFAULT_OP_TASK_STATS_MIN_COUNT,
    DEFAULT_OP_TASK_STATS_STD_FACTOR,
    HangingExecutionIssueDetector,
    HangingExecutionIssueDetectorConfig,
)
from ray.data._internal.issue_detection.detectors.high_memory_detector import (
    HighMemoryIssueDetector,
)
from ray.data.context import DataContext
from ray.tests.conftest import *  # noqa


class TestHangingExecutionIssueDetector:
    def test_hanging_detector_configuration(self, restore_data_context):
        """Test hanging detector configuration and initialization."""
        # Test default configuration from DataContext
        ctx = DataContext.get_current()
        default_config = ctx.issue_detectors_config.hanging_detector_config
        assert default_config.op_task_stats_min_count == DEFAULT_OP_TASK_STATS_MIN_COUNT
        assert (
            default_config.op_task_stats_std_factor == DEFAULT_OP_TASK_STATS_STD_FACTOR
        )

        # Test custom configuration
        min_count = 5
        std_factor = 3.0
        custom_config = HangingExecutionIssueDetectorConfig(
            op_task_stats_min_count=min_count,
            op_task_stats_std_factor=std_factor,
        )
        ctx.issue_detectors_config.hanging_detector_config = custom_config

        executor = StreamingExecutor(ctx)
        detector = HangingExecutionIssueDetector(executor, ctx)
        assert detector._op_task_stats_min_count == min_count
        assert detector._op_task_stats_std_factor_threshold == std_factor

    @patch(
        "ray.data._internal.execution.interfaces.op_runtime_metrics.TaskDurationStats"
    )
    def test_basic_hanging_detection(
        self, mock_stats_cls, ray_start_2_cpus, restore_data_context
    ):
        # Set up logging capture
        log_capture = io.StringIO()
        handler = logging.StreamHandler(log_capture)
        logger = logging.getLogger("ray.data._internal.issue_detection")
        logger.addHandler(handler)

        # Set up mock task-duration stats so the adaptive threshold is
        # computed from a stable, known distribution
        mocked_mean = 2.0  # seconds
        mocked_stddev = 0.2  # seconds
        mock_stats = mock_stats_cls.return_value
        mock_stats.count.return_value = 20  # enough samples to enable detection
        mock_stats.mean.return_value = mocked_mean
        mock_stats.stddev.return_value = mocked_stddev
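        # Note: assuming the detector flags a task once its runtime exceeds
        # roughly mean + std_factor * stddev (both knobs exist on the detector
        # config), these stats put the cutoff a little above 2 seconds: the
        # 5 s sleep below should trip it, while a no-op map should not.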

        # Zero the detection interval so issues are checked on every iteration
        ctx = DataContext.get_current()
        detector_cfg = ctx.issue_detectors_config.hanging_detector_config
        detector_cfg.detection_time_interval_s = 0.0

        # A fast no-op task should not trigger the hanging warning
        def f1(x):
            return x

        _ = ray.data.range(1).map(f1).materialize()

        log_output = log_capture.getvalue()
        warn_msg = (
            r"A task of operator .+ with task index .+ has been running for [\d\.]+s"
        )
        assert re.search(warn_msg, log_output) is None, log_output

        # A task that sleeps far past the mocked duration distribution should
        # trigger the hanging warning
        def f2(x):
            time.sleep(5.0)  # well beyond the adaptive threshold derived above
            return x

        _ = ray.data.range(1).map(f2).materialize()

        log_output = log_capture.getvalue()
        assert re.search(warn_msg, log_output) is not None, log_output

        # Detach the handler so it doesn't leak into other tests
        logger.removeHandler(handler)

    def test_hanging_detector_detects_issues(
        self, caplog, propagate_logs, restore_data_context
    ):
        """Test the hanging detector's adaptive threshold end-to-end on a
        real Ray Data pipeline with an aggressive low-threshold config."""

        ctx = DataContext.get_current()
        # Use the lowest useful thresholds so a single slow task gets flagged
        ctx.issue_detectors_config.hanging_detector_config = (
            HangingExecutionIssueDetectorConfig(
                op_task_stats_min_count=1,
                op_task_stats_std_factor=1,
                detection_time_interval_s=0,
            )
        )
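        # With op_task_stats_min_count=1, a single completed task seeds the
        # duration stats; std_factor=1 should then flag anything slower than
        # mean + 1 * stddev, and a zero interval runs detection continuously.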

        # Use several small blocks with concurrency=1 so the fast tasks
        # complete first and seed the duration stats before the slow one runs
        def sleep_task(x):
            if x["id"] == 2:
                # Detection is based on the mean + stddev of completed tasks,
                # so only one row sleeps for a while and stands out from the
                # fast rows.
                time.sleep(1)
            return x

        with caplog.at_level(logging.WARNING):
            ray.data.range(3, override_num_blocks=3).map(
                sleep_task, concurrency=1
            ).materialize()

        # Check whether hanging detection fired
        hanging_detected = (
            "has been running for" in caplog.text
            and "longer than the average task duration" in caplog.text
        )

        assert hanging_detected, caplog.text


@pytest.mark.parametrize(
    "configured_memory, actual_memory, should_return_issue",
    [
        # User has appropriately configured memory, so no issue.
        (4 * 1024**3, 4 * 1024**3, False),
        # User hasn't configured memory correctly and memory use is high, so issue.
        (None, 4 * 1024**3, True),
        (1, 4 * 1024**3, True),
        # User hasn't configured memory correctly but memory use is low, so no issue.
        (None, 4 * 1024**3 - 1, False),
    ],
)
def test_high_memory_detection(
    configured_memory, actual_memory, should_return_issue, restore_data_context
):
    """The detector should report an issue only when per-task memory use is
    high relative to the memory configured via ray_remote_args."""
    ctx = DataContext.get_current()

    input_data_buffer = InputDataBuffer(ctx, input_data=[])
    map_operator = MapOperator.create(
        map_transformer=MagicMock(),
        input_op=input_data_buffer,
        data_context=ctx,
        ray_remote_args={"memory": configured_memory},
    )
    map_operator._metrics = MagicMock(average_max_uss_per_task=actual_memory)
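    # average_max_uss_per_task (USS = unique set size) is the per-task memory
    # metric the detector appears to compare against the operator's configured
    # memory budget; mocking it lets the test pin the "actual" usage exactly.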
    topology = {input_data_buffer: MagicMock(), map_operator: MagicMock()}
    executor = MagicMock(_topology=topology)

    detector = HighMemoryIssueDetector(executor, ctx)
    issues = detector.detect()

    assert should_return_issue == bool(issues)


if __name__ == "__main__":
    import sys

    sys.exit(pytest.main(["-v", __file__]))