Skip to content

Commit 6ac3247

Browse files
committed
refactor remove executor from constructor in HashShuffleAggregatorIssueDetector
Signed-off-by: machichima <nary12321@gmail.com>
1 parent 89242fd commit 6ac3247

File tree

3 files changed

+48
-12
lines changed

3 files changed

+48
-12
lines changed

python/ray/data/_internal/issue_detection/detectors/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
)
55
from ray.data._internal.issue_detection.detectors.hash_shuffle_detector import (
66
HashShuffleAggregatorIssueDetector,
7+
HashShuffleAggregatorIssueDetectorConfig,
78
)
89
from ray.data._internal.issue_detection.detectors.high_memory_detector import (
910
HighMemoryIssueDetector,
@@ -13,7 +14,8 @@
1314
__all__ = [
1415
"HangingExecutionIssueDetector",
1516
"HangingExecutionIssueDetectorConfig",
17+
"HashShuffleAggregatorIssueDetector",
18+
"HashShuffleAggregatorIssueDetectorConfig",
1619
"HighMemoryIssueDetector",
1720
"HighMemoryIssueDetectorConfig",
18-
"HashShuffleAggregatorIssueDetector",
1921
]

python/ray/data/_internal/issue_detection/detectors/hash_shuffle_detector.py

Lines changed: 35 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import time
2-
from typing import TYPE_CHECKING, List
2+
from dataclasses import dataclass
3+
from typing import TYPE_CHECKING, Callable, List
34

45
import ray
56
from ray.data._internal.execution.operators.hash_shuffle import (
@@ -15,15 +16,29 @@
1516

1617
if TYPE_CHECKING:
1718
from ray.data._internal.execution.streaming_executor import StreamingExecutor
18-
from ray.data.context import DataContext
19+
from ray.data._internal.execution.interfaces.physical_operator import (
20+
PhysicalOperator,
21+
)
22+
23+
@dataclass
24+
class HashShuffleAggregatorIssueDetectorConfig:
25+
"""Configuration for HashShuffleAggregatorIssueDetector."""
26+
detection_time_interval_s: float = 30.0
27+
min_wait_time_s: float = 300.0
1928

2029

2130
class HashShuffleAggregatorIssueDetector(IssueDetector):
2231
"""Detector for hash shuffle aggregator health issues."""
2332

24-
def __init__(self, executor: "StreamingExecutor", ctx: "DataContext"):
25-
self._executor = executor
26-
self._ctx = ctx
33+
def __init__(
34+
self,
35+
dataset_id: str,
36+
get_operators_fn: Callable[[], List["PhysicalOperator"]],
37+
config: "HashShuffleAggregatorIssueDetectorConfig",
38+
):
39+
self._dataset_id = dataset_id
40+
self._get_operators = get_operators_fn
41+
self._detector_cfg = config
2742
self._last_warning_times = {} # Track per-operator warning times
2843

2944
@classmethod
@@ -38,14 +53,24 @@ def from_executor(
3853
Returns:
3954
An instance of HashShuffleAggregatorIssueDetector.
4055
"""
41-
return cls(executor, executor._data_context)
56+
def get_operators_fn() -> List["PhysicalOperator"]:
57+
if not executor._topology:
58+
return []
59+
return list(executor._topology.keys())
60+
61+
ctx = executor._data_context
62+
return cls(
63+
dataset_id=executor._dataset_id,
64+
get_operators_fn=get_operators_fn,
65+
config=ctx.issue_detectors_config.hash_shuffle_detector_config,
66+
)
4267

4368
def detect(self) -> List[Issue]:
4469
issues = []
4570
current_time = time.time()
4671

4772
# Find all hash shuffle operators in the topology
48-
for op in self._executor._topology.keys():
73+
for op in self._get_operators():
4974
if not isinstance(op, HashShuffleOperator):
5075
continue
5176

@@ -68,7 +93,7 @@ def detect(self) -> List[Issue]:
6893
message = self._format_health_warning(aggregator_info)
6994
issues.append(
7095
Issue(
71-
dataset_name=self._executor._dataset_id,
96+
dataset_name=self._dataset_id,
7297
operator_id=op.id,
7398
issue_type=IssueType.HANGING,
7499
message=message,
@@ -79,7 +104,7 @@ def detect(self) -> List[Issue]:
79104
return issues
80105

81106
def detection_time_interval_s(self) -> float:
82-
return self._ctx.hash_shuffle_aggregator_health_warning_interval_s
107+
return self._detector_cfg.detection_time_interval_s
83108

84109
def _should_emit_warning(
85110
self, op_id: str, current_time: float, info: AggregatorHealthInfo
@@ -93,7 +118,7 @@ def _should_emit_warning(
93118
# Check if enough time has passed since start
94119
if (
95120
current_time - info.started_at
96-
< self._ctx.min_hash_shuffle_aggregator_wait_time_in_s
121+
< self._detector_cfg.min_wait_time_s
97122
):
98123
return False
99124

python/ray/data/_internal/issue_detection/issue_detector_configuration.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
from ray.data._internal.issue_detection.detectors import (
55
HangingExecutionIssueDetector,
66
HangingExecutionIssueDetectorConfig,
7+
HashShuffleAggregatorIssueDetector,
8+
HashShuffleAggregatorIssueDetectorConfig,
79
HighMemoryIssueDetector,
810
HighMemoryIssueDetectorConfig,
911
)
@@ -15,9 +17,16 @@ class IssueDetectorsConfiguration:
1517
hanging_detector_config: HangingExecutionIssueDetectorConfig = field(
1618
default_factory=HangingExecutionIssueDetectorConfig
1719
)
20+
hash_shuffle_detector_config: HashShuffleAggregatorIssueDetectorConfig = field(
21+
default_factory=HashShuffleAggregatorIssueDetectorConfig
22+
)
1823
high_memory_detector_config: HighMemoryIssueDetectorConfig = field(
1924
default_factory=HighMemoryIssueDetectorConfig
2025
)
2126
detectors: List[Type[IssueDetector]] = field(
22-
default_factory=lambda: [HangingExecutionIssueDetector, HighMemoryIssueDetector]
27+
default_factory=lambda: [
28+
HangingExecutionIssueDetector,
29+
HashShuffleAggregatorIssueDetector,
30+
HighMemoryIssueDetector,
31+
]
2332
)

0 commit comments

Comments
 (0)