diff --git a/python/ray/data/_internal/block_batching/iter_batches.py b/python/ray/data/_internal/block_batching/iter_batches.py
index d0f19f96f983..d9ce3df8860e 100644
--- a/python/ray/data/_internal/block_batching/iter_batches.py
+++ b/python/ray/data/_internal/block_batching/iter_batches.py
@@ -4,6 +4,7 @@
 from typing import Any, Callable, Dict, Iterator, Optional
 
 import ray
+from ray._private.ray_constants import env_integer
 from ray.data._internal.block_batching.interfaces import Batch, BlockPrefetcher
 from ray.data._internal.block_batching.util import (
     ActorBlockPrefetcher,
@@ -22,6 +23,13 @@
 from ray.data.context import DataContext
 from ray.types import ObjectRef
 
+# Cap applied to the batch-formatting thread pool size in
+# ``BatchIterator._format_batches``. Read through ``env_integer`` from
+# RAY_DATA_MAX_FORMAT_THREADPOOL_NUM_WORKERS, with 4 as the fallback argument.
+DEFAULT_FORMAT_THREADPOOL_NUM_WORKERS = env_integer(
+    "RAY_DATA_MAX_FORMAT_THREADPOOL_NUM_WORKERS", 4
+)
+
 
 class BatchIterator:
     """Defines an iterator pipeline to convert a stream of block object references
@@ -174,12 +182,17 @@ def _blocks_to_batches(self, blocks: Iterator[Block]) -> Iterator[Batch]:
         )
 
     def _format_batches(self, batches: Iterator[Batch]) -> Iterator[Batch]:
+        # Size the formatting pool by the prefetch depth, but never above
+        # the module-level cap.
+        num_threadpool_workers = min(
+            DEFAULT_FORMAT_THREADPOOL_NUM_WORKERS, self._prefetch_batches
+        )
         return _format_in_threadpool(
             batch_iter=batches,
             stats=self._stats,
             batch_format=self._batch_format,
             collate_fn=self._collate_fn,
-            num_threadpool_workers=self._prefetch_batches,
+            num_threadpool_workers=num_threadpool_workers,
         )
 
     def _finalize_batches(
@@ -226,6 +239,8 @@ def _iter_batches(self) -> Iterator[DataBatch]:
             fn=self._pipeline,
             num_workers=1,
             preserve_ordering=False,
+            # Keep the buffer size at least 1 even if prefetch_batches is 0.
+            buffer_size=max(self._prefetch_batches, 1),
         )
 
         self.before_epoch_start()