Merged
46 changes: 26 additions & 20 deletions python/ray/data/_internal/execution/operators/hash_shuffle.py
@@ -3,6 +3,7 @@
 import itertools
 import logging
 import math
+import random
 import threading
 import time
 from collections import defaultdict, deque
@@ -16,6 +17,7 @@
     Dict,
     List,
     Optional,
+    Set,
     Tuple,
     Union,
 )
@@ -601,8 +603,10 @@ def __init__(
         # aggregators (keeps track which input sequences have already broadcasted
         # their schemas)
         self._has_schemas_broadcasted: DefaultDict[int, bool] = defaultdict(bool)
-        # Id of the last partition finalization of which had already been scheduled
-        self._last_finalized_partition_id: int = -1
+        # Set of partitions still pending finalization
+        self._pending_finalization_partition_ids: Set[int] = set(
+            range(target_num_partitions)
+        )

         self._output_queue: Deque[RefBundle] = deque()

@@ -823,11 +827,6 @@ def _try_finalize(self):
         if not self._is_shuffling_done():
             return

-        logger.debug(
-            f"Scheduling next shuffling finalization batch (last finalized "
-            f"partition id is {self._last_finalized_partition_id})"
-        )
-
         def _on_bundle_ready(partition_id: int, bundle: RefBundle):
             # Add finalized block to the output queue
             self._output_queue.append(bundle)
@@ -872,10 +871,8 @@ def _on_aggregation_done(partition_id: int, exc: Optional[Exception]):
             or self._aggregator_pool.num_aggregators
         )

-        num_remaining_partitions = (
-            self._num_partitions - 1 - self._last_finalized_partition_id
-        )
         num_running_finalizing_tasks = len(self._finalizing_tasks)
+        num_remaining_partitions = len(self._pending_finalization_partition_ids)

         # Finalization is executed in batches of no more than
         # `DataContext.max_hash_shuffle_finalization_batch_size` tasks at a time.
@@ -899,12 +896,21 @@ def _on_aggregation_done(partition_id: int, exc: Optional[Exception]):
         if next_batch_size == 0:
             return

-        # Next partition to be scheduled for finalization is the one right
-        # after the last one scheduled
-        next_partition_id = self._last_finalized_partition_id + 1
-
-        target_partition_ids = list(
-            range(next_partition_id, next_partition_id + next_batch_size)
+        # We sample the next set of partitions to finalize randomly, to spread
+        # the finalization window uniformly across the nodes of the cluster and
+        # avoid a "sliding lens" effect where we'd finalize a batch of N *adjacent*
+        # partitions that may be co-located on the same node:
+        #
+        # - Adjacent partitions i and i+1 are handled by adjacent
@iamjustinhsu (Contributor) commented on Nov 7, 2025:

Wait, is this true? If modulo N = num actors, then partitions i and i + 1 must necessarily be in different actors. Oh wait, never mind, I see what you're saying.
+        #   aggregators (since membership is determined as i % num_aggregators)
+        #
+        # - Adjacent aggregators have a high likelihood of running on the
+        #   same node (when num aggregators > num nodes)
Comment on lines +907 to +908

@iamjustinhsu (Contributor) commented on Nov 7, 2025:

Is this necessarily true? Your default strategy is SPREAD, and each aggregator is scheduled with the same amount of resources, so aggregators i and i + 1 have as much of a chance of being scheduled on the same node as aggregators i and j. Please correct my assumptions if I'm wrong.
+        #
+        # NOTE: This doesn't affect determinism, since it only impacts the order
+        # of finalization (hence the sampling doesn't need to be seeded)
+        target_partition_ids = random.sample(
A Contributor commented:

So wouldn't a better strategy be to check how much each aggregator actor is currently consuming relative to the node's capacity, and schedule the finalization if there's remaining capacity? I just find the randomization strategy harder to reason about in this case.

A Contributor commented:

Also, it's a function of partition size, so ideally, if we can get metadata about the partition before scheduling the finalize(), that would be even better.
+            list(self._pending_finalization_partition_ids), next_batch_size
         )
Comment on lines +912 to +914

A Contributor commented (severity: medium):

The random.sample function can operate directly on sets, so converting self._pending_finalization_partition_ids to a list is unnecessary. Removing the list() conversion would avoid creating a new list on each call, which can be expensive when the number of pending partitions is large.

Suggested change:

-        target_partition_ids = random.sample(
-            list(self._pending_finalization_partition_ids), next_batch_size
-        )
+        target_partition_ids = random.sample(
+            self._pending_finalization_partition_ids, next_batch_size
+        )


         logger.debug(
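A note on the suggested change above: random.sample accepts a set only through a deprecated code path in Python 3.9–3.10 and rejects it with a TypeError from Python 3.11 onward, so dropping the list() conversion would break on newer interpreters. A minimal check (illustrative, not part of the PR):

```python
import random
import sys

pending_partition_ids = {0, 1, 2, 3}

# Through Python 3.10, sampling from a set only emits a DeprecationWarning;
# from 3.11 on it raises TypeError, so the explicit list() conversion in the
# diff is load-bearing rather than a style choice.
try:
    batch = random.sample(pending_partition_ids, 2)
except TypeError:
    print(f"Python {sys.version_info.major}.{sys.version_info.minor}: sets are rejected")
    batch = random.sample(list(pending_partition_ids), 2)

print(batch)
```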
@@ -941,15 +947,15 @@ def _on_aggregation_done(partition_id: int, exc: Optional[Exception]):
             ),
         )

+        # Pop partition id from remaining set
+        self._pending_finalization_partition_ids.remove(partition_id)
+
         # Update Finalize Metrics on task submission
         # NOTE: This is empty because the input is directly forwarded from the
         # output of the shuffling stage, which we don't return.
         empty_bundle = RefBundle([], schema=None, owns_blocks=False)
         self.reduce_metrics.on_task_submitted(partition_id, empty_bundle)

-        # Update last finalized partition id
-        self._last_finalized_partition_id = max(target_partition_ids)
-
     def _do_shutdown(self, force: bool = False) -> None:
         self._aggregator_pool.shutdown(force=True)
         # NOTE: It's critical for Actor Pool to release actors before calling into
@@ -1021,7 +1027,7 @@ def implements_accurate_memory_accounting(self) -> bool:
         return True

     def _is_finalized(self):
-        return self._last_finalized_partition_id == self._num_partitions - 1
+        return len(self._pending_finalization_partition_ids) == 0

     def _handle_shuffled_block_metadata(
         self,
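To make the new finalization flow concrete, here is a minimal, self-contained sketch of the pending-set pattern this diff introduces: pending partitions live in a set, each scheduling round samples a random batch bounded by the in-flight task budget, sampled ids are removed when their tasks are submitted, and the operator counts as finalized once the set is empty. This is not Ray's implementation; the class name, method signatures, and batch-budget arithmetic are illustrative assumptions.

```python
import random
from typing import List, Set


class FinalizationScheduler:
    """Illustrative stand-in for the operator's finalization bookkeeping."""

    def __init__(self, num_partitions: int, max_batch_size: int):
        self._max_batch_size = max_batch_size
        # Every partition starts out pending finalization.
        self._pending: Set[int] = set(range(num_partitions))

    def next_batch(self, num_running_finalizing_tasks: int) -> List[int]:
        # Cap the number of concurrently running finalization tasks.
        budget = max(self._max_batch_size - num_running_finalizing_tasks, 0)
        batch_size = min(budget, len(self._pending))
        if batch_size == 0:
            return []
        # Sample uniformly so a batch doesn't land on N adjacent partitions whose
        # aggregators (chosen as partition_id % num_aggregators) may be co-located
        # on the same node. random.sample needs a sequence, hence list().
        batch = random.sample(list(self._pending), batch_size)
        for partition_id in batch:
            self._pending.remove(partition_id)
        return batch

    def is_finalized(self) -> bool:
        return not self._pending


scheduler = FinalizationScheduler(num_partitions=8, max_batch_size=3)
while not scheduler.is_finalized():
    print("finalizing:", scheduler.next_batch(num_running_finalizing_tasks=0))
```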
14 changes: 5 additions & 9 deletions python/ray/data/_internal/execution/operators/join.py
@@ -166,6 +166,7 @@ def _preprocess(
         left_seq_partition: pa.Table = self._get_partition_builder(
             input_seq_id=0, partition_id=partition_id
         ).build()
+
         right_seq_partition: pa.Table = self._get_partition_builder(
             input_seq_id=1, partition_id=partition_id
         ).build()
@@ -198,7 +199,6 @@ def _preprocess(
         should_index_r = self._should_index_side("right", supported_r, unsupported_r)

         # Add index columns for back-referencing if we have unsupported columns
-        # TODO: what are the chances of a collision with the index column?
         if should_index_l:
             supported_l = self._append_index_column(
                 table=supported_l, col_name=self._index_name("left")
@@ -246,7 +246,7 @@ def _postprocess(
         return supported

     def _index_name(self, suffix: str) -> str:
-        return f"__ray_data_index_level_{suffix}__"
+        return f"__rd_index_level_{suffix}__"

     def clear(self, partition_id: int):
         self._left_input_seq_partition_builders.pop(partition_id)
@@ -263,9 +263,6 @@ def _get_partition_builder(self, *, input_seq_id: int, partition_id: int):
         )
         return partition_builder

-    def _get_index_col_name(self, index: int) -> str:
-        return f"__index_level_{index}__"
-
     def _should_index_side(
         self, side: str, supported_table: "pa.Table", unsupported_table: "pa.Table"
     ) -> bool:
@@ -318,9 +315,8 @@ def _split_unsupported_columns(
"""
supported, unsupported = [], []
for idx in range(len(table.columns)):
column: "pa.ChunkedArray" = table.column(idx)

col_type = column.type
col: "pa.ChunkedArray" = table.column(idx)
col_type: "pa.DataType" = col.type

if _is_pa_extension_type(col_type) or self._is_pa_join_not_supported(
col_type
@@ -329,7 +325,7 @@ def _split_unsupported_columns(
             else:
                 supported.append(idx)

-        return (table.select(supported), table.select(unsupported))
+        return table.select(supported), table.select(unsupported)

     def _add_back_unsupported_columns(
         self,
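The join.py changes touch the supported/unsupported column split and the back-reference index column. To illustrate that flow end to end, here is a minimal standalone sketch; it is not Ray's code, and the helper names, sample tables, and the choice to treat nested types as join-unsupported are assumptions for the demonstration (it also assumes pyarrow >= 7.0 for Table.join):

```python
import pyarrow as pa


def split_unsupported_columns(table: pa.Table):
    """Split a table into (supported, unsupported) column subsets, treating
    Arrow extension and nested types as unsupported by the native join."""
    supported, unsupported = [], []
    for idx in range(table.num_columns):
        col_type = table.column(idx).type
        if isinstance(col_type, pa.ExtensionType) or pa.types.is_nested(col_type):
            unsupported.append(idx)
        else:
            supported.append(idx)
    return table.select(supported), table.select(unsupported)


def append_index_column(table: pa.Table, col_name: str) -> pa.Table:
    # Row index used to back-reference unsupported columns after the join.
    index = pa.array(list(range(table.num_rows)), type=pa.int64())
    return table.append_column(col_name, index)


left = pa.table({"k": [1, 2, 3], "v": ["a", "b", "c"], "tags": [[1], [2, 3], []]})
right = pa.table({"k": [2, 3, 4], "w": [20.0, 30.0, 40.0]})

supported, unsupported = split_unsupported_columns(left)
indexed = append_index_column(supported, "__rd_index_level_left__")
joined = indexed.join(right, keys="k", join_type="inner")

# Re-attach the unsupported column by gathering its rows via the preserved index.
restored = joined.append_column(
    "tags", unsupported.column("tags").take(joined.column("__rd_index_level_left__"))
)
print(restored.column_names)
```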