11import time
2- from typing import TYPE_CHECKING , List
2+ from dataclasses import dataclass
3+ from typing import TYPE_CHECKING , Callable , List
34
45import ray
56from ray .data ._internal .execution .operators .hash_shuffle import (
1516
1617if TYPE_CHECKING :
1718 from ray .data ._internal .execution .streaming_executor import StreamingExecutor
18- from ray .data .context import DataContext
19+ from ray .data ._internal .execution .interfaces .physical_operator import (
20+ PhysicalOperator ,
21+ )
22+
23+ @dataclass
24+ class HashShuffleAggregatorIssueDetectorConfig :
25+ """Configuration for HashShuffleAggregatorIssueDetector."""
26+ detection_time_interval_s : float = 30.0
27+ min_wait_time_s : float = 300.0
1928
2029
2130class HashShuffleAggregatorIssueDetector (IssueDetector ):
2231 """Detector for hash shuffle aggregator health issues."""
2332
24- def __init__ (self , executor : "StreamingExecutor" , ctx : "DataContext" ):
25- self ._executor = executor
26- self ._ctx = ctx
33+ def __init__ (
34+ self ,
35+ dataset_id : str ,
36+ get_operators_fn : Callable [[], List ["PhysicalOperator" ]],
37+ config : "HashShuffleAggregatorIssueDetectorConfig" ,
38+ ):
39+ self ._dataset_id = dataset_id
40+ self ._get_operators = get_operators_fn
41+ self ._detector_cfg = config
2742 self ._last_warning_times = {} # Track per-operator warning times
2843
2944 @classmethod
@@ -38,14 +53,24 @@ def from_executor(
3853 Returns:
3954 An instance of HashShuffleAggregatorIssueDetector.
4055 """
41- return cls (executor , executor ._data_context )
56+ def get_operators_fn () -> List ["PhysicalOperator" ]:
57+ if not executor ._topology :
58+ return []
59+ return list (executor ._topology .keys ())
60+
61+ ctx = executor ._data_context
62+ return cls (
63+ dataset_id = executor ._dataset_id ,
64+ get_operators_fn = get_operators_fn ,
65+ config = ctx .issue_detectors_config .hash_shuffle_detector_config ,
66+ )
4267
4368 def detect (self ) -> List [Issue ]:
4469 issues = []
4570 current_time = time .time ()
4671
4772 # Find all hash shuffle operators in the topology
48- for op in self ._executor . _topology . keys ():
73+ for op in self ._get_operators ():
4974 if not isinstance (op , HashShuffleOperator ):
5075 continue
5176
@@ -68,7 +93,7 @@ def detect(self) -> List[Issue]:
6893 message = self ._format_health_warning (aggregator_info )
6994 issues .append (
7095 Issue (
71- dataset_name = self ._executor . _dataset_id ,
96+ dataset_name = self ._dataset_id ,
7297 operator_id = op .id ,
7398 issue_type = IssueType .HANGING ,
7499 message = message ,
@@ -79,7 +104,7 @@ def detect(self) -> List[Issue]:
79104 return issues
80105
81106 def detection_time_interval_s (self ) -> float :
82- return self ._ctx . hash_shuffle_aggregator_health_warning_interval_s
107+ return self ._detector_cfg . detection_time_interval_s
83108
84109 def _should_emit_warning (
85110 self , op_id : str , current_time : float , info : AggregatorHealthInfo
@@ -93,7 +118,7 @@ def _should_emit_warning(
93118 # Check if enough time has passed since start
94119 if (
95120 current_time - info .started_at
96- < self ._ctx . min_hash_shuffle_aggregator_wait_time_in_s
121+ < self ._detector_cfg . min_wait_time_s
97122 ):
98123 return False
99124
0 commit comments