Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Data] AlltoAll OP, Update Data progress bars to use row as the iteration unit #46924

Merged
merged 15 commits into from
Aug 12, 2024
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ def __init__(
self._bulk_fn = bulk_fn
self._next_task_index = 0
self._num_outputs = num_outputs
self._output_rows = 0
self._sub_progress_bar_names = sub_progress_bar_names
self._sub_progress_bar_dict = None
self._input_buffer: List[RefBundle] = []
Expand All @@ -79,6 +80,13 @@ def num_outputs_total(self) -> Optional[int]:
else self.input_dependencies[0].num_outputs_total()
)

def num_output_rows_total(self) -> Optional[int]:
return (
self._output_rows
if self._output_rows
else self.input_dependencies[0].num_output_rows_total()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is self.input_dependencies[0].num_output_rows_total() something that is static? Should we cache this value with some call like self._output_rows = self.input_dependencies[0].num_output_rows_total()?

If this total is a live total that is updated as execution continues, it makes sense to leave it as is.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Right! Here self._output_rows is not static, but it's our primary option, as it will be updated here:
image

)

def _add_input_inner(self, refs: RefBundle, input_index: int) -> None:
assert not self.completed()
assert input_index == 0, input_index
Expand All @@ -99,7 +107,9 @@ def has_next(self) -> bool:
return len(self._output_buffer) > 0

def _get_next_inner(self) -> RefBundle:
return self._output_buffer.pop(0)
bundle = self._output_buffer.pop(0)
self._output_rows += bundle.num_rows()
return bundle

def get_stats(self) -> StatsDict:
return self._stats
Expand All @@ -108,7 +118,7 @@ def get_transformation_fn(self) -> AllToAllTransformFn:
return self._bulk_fn

def progress_str(self) -> str:
return f"{len(self._output_buffer)} output"
return f"{self.num_output_rows_total() or 0} rows output"

def initialize_sub_progress_bars(self, position: int) -> int:
"""Initialize all internal sub progress bars, and return the number of bars."""
Expand All @@ -117,8 +127,8 @@ def initialize_sub_progress_bars(self, position: int) -> int:
for name in self._sub_progress_bar_names:
bar = ProgressBar(
name,
self.num_outputs_total() or 1,
unit="bundle",
self.num_output_rows_total() or 1,
unit="row",
position=position,
)
# NOTE: call `set_description` to trigger the initial print of progress
Expand Down
4 changes: 3 additions & 1 deletion python/ray/data/_internal/execution/streaming_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,9 @@ def get_next(self, output_split_idx: Optional[int] = None) -> RefBundle:
output_split_idx
)
if self._outer._global_info:
self._outer._global_info.update(1, dag.num_outputs_total())
self._outer._global_info.update(
item.num_rows(), dag.num_output_rows_total()
)
return item
# Needs to be BaseException to catch KeyboardInterrupt. Otherwise we
# can leave dangling progress bars by skipping shutdown.
Expand Down
16 changes: 14 additions & 2 deletions python/ray/data/_internal/progress_bar.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,13 @@ def block_until_complete(self, remaining: List[ObjectRef]) -> None:
done, remaining = ray.wait(
remaining, num_returns=len(remaining), fetch_local=False, timeout=0.1
)
self.update(len(done))
total_rows_processed = 0
for _, result in zip(done, ray.get(done)):
num_rows = (
result.num_rows if hasattr(result, "num_rows") else 1
) # Default to 1 if no row count is available
total_rows_processed += num_rows
self.update(total_rows_processed)

with _canceled_threads_lock:
if t in _canceled_threads:
Expand All @@ -158,9 +164,15 @@ def fetch_until_complete(self, refs: List[ObjectRef]) -> List[Any]:
)
if fetch_local:
fetch_local = False
total_rows_processed = 0
for ref, result in zip(done, ray.get(done)):
ref_to_result[ref] = result
self.update(len(done))
num_rows = (
result.num_rows if hasattr(result, "num_rows") else 1
) # Default to 1 if no row count is available
total_rows_processed += num_rows
# TODO(zhilong): Change the total to total_row when init progress bar
self.update(total_rows_processed)
Comment on lines +170 to +175
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fixed


with _canceled_threads_lock:
if t in _canceled_threads:
Expand Down