diff --git a/doc/source/train/doc_code/collate_utils.py b/doc/source/train/doc_code/collate_utils.py new file mode 100644 index 000000000000..9dd00d1a9afa --- /dev/null +++ b/doc/source/train/doc_code/collate_utils.py @@ -0,0 +1,118 @@ +from dataclasses import dataclass +from typing import Dict, List, Tuple, Union +import torch +from ray import cloudpickle as pickle +import pyarrow as pa + +# (dtype, shape, offset) +FEATURE_TYPE = Tuple[torch.dtype, torch.Size, int] +TORCH_BYTE_ELEMENT_TYPE = torch.uint8 + +def _create_binary_array_from_buffer(buffer: bytes) -> pa.BinaryArray: + """Zero-copy create a binary array from a buffer.""" + data_buffer = pa.py_buffer(buffer) + return pa.Array.from_buffers( + pa.binary(), + 1, + [ + None, + pa.array([0, data_buffer.size], type=pa.int32()).buffers()[1], + data_buffer, + ], + ) + +@dataclass +class _Metadata: + features: Dict[str, List[FEATURE_TYPE]] + total_buffer_size: int + +@dataclass +class _TensorBatch: + """Internal class for serializing/deserializing tensor batches.""" + buffer: torch.Tensor + metadata: _Metadata + + @classmethod + def from_batch(cls, batch: Dict[str, Union[List[torch.Tensor], torch.Tensor]]) -> '_TensorBatch': + """Serialize a batch of tensors into a single buffer.""" + features: Dict[str, List[FEATURE_TYPE]] = {} + flattened_binary_tensors = [] + total_buffer_size = 0 + + for name, tensors in batch.items(): + features[name] = [] + if not isinstance(tensors, list): + tensors = [tensors] + for tensor in tensors: + flattened_tensor = tensor.flatten().contiguous().view(TORCH_BYTE_ELEMENT_TYPE) + flattened_binary_tensors.append(flattened_tensor) + features[name].append((tensor.dtype, tensor.shape, total_buffer_size)) + total_buffer_size += flattened_tensor.shape[0] + + buffer = torch.empty(total_buffer_size, dtype=TORCH_BYTE_ELEMENT_TYPE) + cur_offset = 0 + for flattened_tensor in flattened_binary_tensors: + buffer[cur_offset:cur_offset + flattened_tensor.shape[0]] = flattened_tensor + cur_offset += flattened_tensor.shape[0] + + return _TensorBatch( + buffer=buffer, + metadata=_Metadata( + features=features, + total_buffer_size=total_buffer_size, + ), + ) + + def to_table(self) -> pa.Table: + """Convert to a single-row PyArrow table.""" + buffer_array = _create_binary_array_from_buffer(self.buffer.numpy().data) + metadata_array = _create_binary_array_from_buffer(pickle.dumps(self.metadata)) + return pa.Table.from_arrays( + arrays=[buffer_array, metadata_array], + names=["_buffer", "_metadata"], + ) + + @classmethod + def from_table(cls, table: pa.Table) -> '_TensorBatch': + """Deserialize from a single-row PyArrow table.""" + return _TensorBatch( + buffer=torch.frombuffer( + table["_buffer"].chunks[0].buffers()[2], + dtype=TORCH_BYTE_ELEMENT_TYPE + ), + metadata=pickle.loads(table["_metadata"].chunks[0].buffers()[2]), + ) + + def to_batch(self, pin_memory: bool = False) -> Dict[str, List[torch.Tensor]]: + """Deserialize back to a batch of tensors.""" + batch = {} + storage_buffer = self.buffer.untyped_storage() + offsets = [] + for name, features in self.metadata.features.items(): + for _, _, offset in features: + offsets.append(offset) + offsets.append(self.metadata.total_buffer_size) + + offset_id = 0 + for name, features in self.metadata.features.items(): + batch[name] = [] + for dtype, shape, _ in features: + # Create a zero-copy view of the byte slice. + byte_slice = self.buffer[offsets[offset_id]:offsets[offset_id + 1]] + tensor = torch.frombuffer( + byte_slice.numpy().data, dtype=dtype + ).view(shape) + if pin_memory: + tensor = tensor.pin_memory() + batch[name].append(tensor) + offset_id += 1 + return batch + +# Helper functions for use in your code +def serialize_tensors_to_table(batch: Dict[str, Union[List[torch.Tensor], torch.Tensor]]) -> pa.Table: + """Serialize a batch of tensors to a PyArrow table.""" + return _TensorBatch.from_batch(batch).to_table() + +def deserialize_table_to_tensors(table: pa.Table, pin_memory: bool = False) -> Dict[str, List[torch.Tensor]]: + """Deserialize a PyArrow table back to tensors.""" + return _TensorBatch.from_table(table).to_batch(pin_memory=pin_memory) diff --git a/doc/source/train/doc_code/random_text_generator.py b/doc/source/train/doc_code/random_text_generator.py new file mode 100644 index 000000000000..4dd8809b5b0f --- /dev/null +++ b/doc/source/train/doc_code/random_text_generator.py @@ -0,0 +1,58 @@ +import random +import string +import ray + +def random_text(length: int) -> str: + """Generate random text of specified length.""" + if length <= 0: + return "" + + if length <= 3: + return "".join(random.choices(string.ascii_lowercase, k=length)) + + words = [] + current_length = 0 + + while current_length < length: + remaining = length - current_length + + if remaining <= 4: + word_length = remaining + word = "".join(random.choices(string.ascii_lowercase, k=word_length)) + words.append(word) + break + else: + max_word_length = min(10, remaining - 1) + if max_word_length >= 3: + word_length = random.randint(3, max_word_length) + else: + word_length = remaining + word = "".join(random.choices(string.ascii_lowercase, k=word_length)) + words.append(word) + current_length += len(word) + 1 + + text = " ".join(words) + return text[:length] + +def random_label() -> int: + """Pick a random label.""" + labels = [0, 1, 2, 3, 4, 5, 6, 7] + return random.choice(labels) + +def create_mock_ray_text_dataset(dataset_size: int = 96, min_len: int = 5, max_len: int = 100): + """Create a mock Ray dataset with random text and labels.""" + numbers = random.choices(range(min_len, max_len + 1), k=dataset_size) + ray_dataset = ray.data.from_items(numbers) + + def map_to_text_and_label(item): + length = item['item'] + text = random_text(length) + label = random_label() + return { + "length": length, + "text": text, + "label": label + } + + text_dataset = ray_dataset.map(map_to_text_and_label) + return text_dataset \ No newline at end of file diff --git a/doc/source/train/user-guides.rst b/doc/source/train/user-guides.rst index 4aca0b2c6b46..35c8d855c498 100644 --- a/doc/source/train/user-guides.rst +++ b/doc/source/train/user-guides.rst @@ -19,3 +19,4 @@ Ray Train User Guides user-guides/monitor-your-application user-guides/reproducibility Hyperparameter Optimization + user-guides/scaling-collation-functions diff --git a/doc/source/train/user-guides/_collate_utils.rst b/doc/source/train/user-guides/_collate_utils.rst new file mode 100644 index 000000000000..7c4c053ec94e --- /dev/null +++ b/doc/source/train/user-guides/_collate_utils.rst @@ -0,0 +1,20 @@ +:orphan: + +.. _train-collate-utils: + +Collate Utilities +================= + +.. literalinclude:: ../doc_code/collate_utils.py + :language: python + + +.. _random-text-generator: + +Random Text Generator +===================== + +The following helper functions generate random text samples with labels: + +.. literalinclude:: ../doc_code/random_text_generator.py + :language: python \ No newline at end of file diff --git a/doc/source/train/user-guides/data-loading-preprocessing.rst b/doc/source/train/user-guides/data-loading-preprocessing.rst index e52a39a73983..0a806b5d6c80 100644 --- a/doc/source/train/user-guides/data-loading-preprocessing.rst +++ b/doc/source/train/user-guides/data-loading-preprocessing.rst @@ -603,7 +603,7 @@ For example, the following code prefetches 10 batches at a time for each trainin Avoid heavy transformation in collate_fn ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -The ``collate_fn`` parameter in :meth:`iter_batches ` or :meth:`iter_torch_batches ` allows you to transform data before feeding it to the model. This operation happens locally in the training workers. Avoid adding a heavy transformation in this function as it may become the bottleneck. Instead, :ref:`apply the transformation with map or map_batches ` before passing the dataset to the Trainer. +The ``collate_fn`` parameter in :meth:`iter_batches ` or :meth:`iter_torch_batches ` allows you to transform data before feeding it to the model. This operation happens locally in the training workers. Avoid adding a heavy transformation in this function as it may become the bottleneck. Instead, :ref:`apply the transformation with map or map_batches ` before passing the dataset to the Trainer. When your expensive transformation requires batch_size as input, such as text tokenization, you can :ref:`scale it out to Ray Data ` for better performance. .. _dataset_cache_performance: diff --git a/doc/source/train/user-guides/scaling-collation-functions.rst b/doc/source/train/user-guides/scaling-collation-functions.rst new file mode 100644 index 000000000000..5b60d68e7d5e --- /dev/null +++ b/doc/source/train/user-guides/scaling-collation-functions.rst @@ -0,0 +1,377 @@ +.. _train-scaling-collation-functions: + +Advanced: Scaling out expensive collate functions +================================================= + +By default, the collate function executes on the training worker when you call :meth:`ray.data.DataIterator.iter_torch_batches`. This approach has two main drawbacks: + +- **Low scalability**: The collate function runs sequentially on each training worker, limiting parallelism. +- **Resource competition**: The collate function consumes CPU and memory resources from the training worker, potentially slowing down model training. + +Scaling out the collate function to Ray Data allows you to scale collation across multiple CPU nodes independently of training workers, improving better overall pipeline throughput, especially with heavy collate functions. + +This optimization is particularly effective when the collate function is computationally expensive (such as tokenization, image augmentation, or complex feature engineering) and you have additional CPU resources available for data preprocessing. + +Moving the collate function to Ray Data +--------------------------------------- + +The following example shows a typical collate function that runs on the training worker: + +.. code-block:: python + + train_dataset = read_parquet().map(...) + + def train_func(): + for batch in ray.train.get_dataset_shard("train").iter_torch_batches( + collate_fn=collate_fn, + batch_size=BATCH_SIZE + ): + # Training logic here + pass + + trainer = TorchTrainer( + train_func, + datasets={"train": train_dataset}, + scaling_config=ScalingConfig(num_workers=4, use_gpu=True) + ) + + result = trainer.fit() + +If the collate function is time/compute intensive and you'd like to scale it out,you should: + +* Create a custom collate function that runs in Ray Data and use :meth:`ray.data.Dataset.map_batches` to scale it out. +* Use :meth:`ray.data.Dataset.repartition` to ensure the batch size alignment. + + +Creating a custom collate function that runs in Ray Data +-------------------------------------------------------- + +To scale out, you'll want to move the ``collate_fn`` into a Ray Data ``map_batches`` operation: + +.. code-block:: python + + def collate_fn(batch: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]: + return batch + + train_dataset = train_dataset.map_batches(collate_fn, batch_size=BATCH_SIZE) + + def train_func(): + for batch in ray.train.get_dataset_shard("train").iter_torch_batches( + collate_fn=None, + batch_size=BATCH_SIZE, + ): + # Training logic here + pass + + trainer = TorchTrainer( + train_func, + datasets={"train": train_dataset}, + scaling_config=ScalingConfig(num_workers=4, use_gpu=True) + ) + + result = trainer.fit() + +A couple of things to note: + +- The ``collate_fn`` returns a dictionary of NumPy arrays, which is a standard Ray Data batch format. +- The ``iter_torch_batches`` method uses ``collate_fn=None``, which reduces the amount of work is done on the training worker process. + +Ensuring batch size alignment +----------------------------- + +Typically, collate functions are used to create complete batches of data with a target batch size. +However, if you move the collate function to Ray Data using :meth:`ray.data.Dataset.map_batches`, by default, it will not guarantee the batch size for each function call. + +There are two common problems that you may encounter. + +1. The collate function requires a certain number of rows provided as an input to work properly. +2. You want to avoid any reformatting / rebatching of the data on the training worker process. + +To solve these problems, you can use :meth:`ray.data.Dataset.repartition` with ``target_num_rows_per_block`` to ensure the batch size alignment. + +By calling ``repartition`` before ``map_batches``, you ensure that the input blocks contain the desired number of rows. + +.. code-block:: python + + # Note: If you only use map_batches(batch_size=BATCH_SIZE), you are not guaranteed to get the desired number of rows as an input. + dataset = dataset.repartition(target_num_rows_per_block=BATCH_SIZE).map_batches(collate_fn, batch_size=BATCH_SIZE) + +By calling ``repartition`` after ``map_batches``, you ensure that the output blocks contain the desired number of rows. This avoids any reformatting / rebatching of the data on the training worker process. + +.. code-block:: python + + dataset = dataset.map_batches(collate_fn, batch_size=BATCH_SIZE).repartition(target_num_rows_per_block=BATCH_SIZE) + + def train_func(): + for batch in ray.train.get_dataset_shard("train").iter_torch_batches( + collate_fn=None, + batch_size=BATCH_SIZE, + ): + # Training logic here + pass + + trainer = TorchTrainer( + train_func, + datasets={"train": train_dataset}, + scaling_config=ScalingConfig(num_workers=4, use_gpu=True) + ) + + result = trainer.fit() + +Putting things together +----------------------- + +Throughout this guide, we use a mock text dataset to demonstrate the optimization. You can find the implementation of the mock dataset in :ref:`random-text-generator`. + +.. tab-set:: + .. tab-item:: Baseline implementation + + The following example shows a typical collate function that runs on the training worker: + + .. testcode:: + :skipif: True + + from transformers import AutoTokenizer + import torch + import numpy as np + from typing import Dict + from ray.train.torch import TorchTrainer + from ray.train import ScalingConfig + from mock_dataset import create_mock_ray_text_dataset + + BATCH_SIZE = 10000 + + def vanilla_collate_fn(tokenizer: AutoTokenizer, batch: Dict[str, np.ndarray]) -> Dict[str, torch.Tensor]: + outputs = tokenizer( + list(batch["text"]), + truncation=True, + padding="longest", + return_tensors="pt", + ) + outputs["labels"] = torch.LongTensor(batch["label"]) + return outputs + + def train_func(): + tokenizer = AutoTokenizer.from_pretrained("bert-base-cased") + collate_fn = lambda x: vanilla_collate_fn(tokenizer, x) + + # Collate function runs on the training worker + for batch in ray.train.get_dataset_shard("train").iter_torch_batches( + collate_fn=collate_fn, + batch_size=BATCH_SIZE + ): + # Training logic here + pass + + train_dataset = create_mock_ray_text_dataset( + dataset_size=1000000, + min_len=1000, + max_len=3000 + ) + + trainer = TorchTrainer( + train_func, + datasets={"train": train_dataset}, + scaling_config=ScalingConfig(num_workers=4, use_gpu=True) + ) + + result = trainer.fit() + + .. tab-item:: Optimized implementation + + The following example moves the collate function to Ray Data preprocessing: + + .. testcode:: + :skipif: True + + from transformers import AutoTokenizer + import numpy as np + from typing import Dict + from ray.train.torch import TorchTrainer + from ray.train import ScalingConfig + from mock_dataset import create_mock_ray_text_dataset + import pyarrow as pa + + BATCH_SIZE = 10000 + + class CollateFnRayData: + def __init__(self): + self.tokenizer = AutoTokenizer.from_pretrained("bert-base-cased") + + def __call__(self, batch: pa.Table) -> Dict[str, np.ndarray]: + results = self.tokenizer( + batch["text"].to_pylist(), + truncation=True, + padding="longest", + return_tensors="np", + ) + results["labels"] = np.array(batch["label"]) + return results + + def train_func(): + # Collate function already ran in Ray Data + for batch in ray.train.get_dataset_shard("train").iter_torch_batches( + collate_fn=None, + batch_size=BATCH_SIZE, + ): + # Training logic here + pass + + # Apply preprocessing in Ray Data + train_dataset = ( + create_mock_ray_text_dataset( + dataset_size=1000000, + min_len=1000, + max_len=3000 + ) + .map_batches( + CollateFnRayData, + batch_size=BATCH_SIZE, + batch_format="pyarrow", + ) + .repartition(target_num_rows_per_block=BATCH_SIZE) # Ensure batch size alignment + ) + + trainer = TorchTrainer( + train_func, + datasets={"train": train_dataset}, + scaling_config=ScalingConfig(num_workers=4, use_gpu=True) + ) + + result = trainer.fit() + +The optimized implementation makes these changes: + +- **Preprocessing in Ray Data**: The tokenization logic moves from ``train_func`` to ``CollateFnRayData``, which runs in ``map_batches``. +- **NumPy output**: The collate function returns ``Dict[str, np.ndarray]`` instead of PyTorch tensors, which Ray Data natively supports. +- **Batch alignment**: ``repartition(target_num_rows_per_block=BATCH_SIZE)`` after ``map_batches`` ensures the collate function receives exact batch sizes and output blocks align with the batch size. +- **No collate_fn in iterator**: ``iter_torch_batches`` uses ``collate_fn=None`` because preprocessing already happened in Ray Data. + +Benchmark results +~~~~~~~~~~~~~~~~~ + +The following benchmarks demonstrate the performance improvement from scaling out the collate function. The test uses text tokenization with a batch size of 10,000 on a dataset of 1 million rows with text lengths between 1,000 and 3,000 characters. + +**Single node (g4dn.12xlarge: 48 vCPU, 4 NVIDIA T4 GPUs, 192 GiB memory)** + +.. list-table:: + :header-rows: 1 + + * - Configuration + - Throughput + * - Collate in iterator (baseline) + - 1,588 rows/s + * - Collate in Ray Data + - 3,437 rows/s + +**With 2 additional CPU nodes (m5.8xlarge: 32 vCPU, 128 GiB memory each)** + +.. list-table:: + :header-rows: 1 + + * - Configuration + - Throughput + * - Collate in iterator (baseline) + - 1,659 rows/s + * - Collate in Ray Data + - 10,717 rows/s + +The results show that scaling out the collate function to Ray Data provides a 2x speedup on a single node and a 6x speedup when adding CPU-only nodes for preprocessing. + +Advanced: Handling custom data types +------------------------------------ + +The optimized implementation above returns ``Dict[str, np.ndarray]``, which Ray Data natively supports. However, if your collate function needs to return PyTorch tensors or other custom data types that :meth:`ray.data.Dataset.map_batches` doesn't directly support, you need to serialize them. + +.. _train-tensor-serialization-utility: + +Tensor serialization utility +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +The following utility serializes PyTorch tensors into PyArrow format. It flattens all tensors in a batch into a single binary buffer, stores metadata about tensor shapes and dtypes, and packs everything into a single-row PyArrow table. On the training side, it deserializes the table back into the original tensor structure. + +The serialization and deserialization operations are typically lightweight compared to the actual collate function work (such as tokenization or image processing), so the overhead is minimal relative to the performance gains from scaling the collate function. + +You can use :ref:`train-collate-utils` as a reference implementation and adapt it to your needs. + +Example with tensor serialization +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +The following example demonstrates using tensor serialization when your collate function must return PyTorch tensors. This approach requires ``repartition`` before ``map_batches`` because the collate function changes the number of output rows (each batch becomes a single serialized row). + +.. testcode:: + :skipif: True + + from transformers import AutoTokenizer + import torch + from typing import Dict + from ray.data.collate_fn import ArrowBatchCollateFn + import pyarrow as pa + from collate_utils import serialize_tensors_to_table, deserialize_table_to_tensors + from ray.train.torch import TorchTrainer + from ray.train import ScalingConfig + from mock_dataset import create_mock_ray_text_dataset + + BATCH_SIZE = 10000 + + class TextTokenizerCollateFn: + """Collate function that runs in Ray Data preprocessing.""" + def __init__(self): + self.tokenizer = AutoTokenizer.from_pretrained("bert-base-cased") + + def __call__(self, batch: pa.Table) -> pa.Table: + # Tokenize the batch + outputs = self.tokenizer( + batch["text"].to_pylist(), + truncation=True, + padding="longest", + return_tensors="pt", + ) + outputs["labels"] = torch.LongTensor(batch["label"].to_numpy()) + + # Serialize to single-row table using the utility + return serialize_tensors_to_table(outputs) + + class IteratorCollateFn(ArrowBatchCollateFn): + """Collate function for iter_torch_batches that deserializes the batch.""" + def __init__(self, pin_memory=False): + self._pin_memory = pin_memory + + def __call__(self, batch: pa.Table) -> Dict[str, torch.Tensor]: + # Deserialize from single-row table using the utility + return deserialize_table_to_tensors(batch, pin_memory=self._pin_memory) + + def train_func(): + collate_fn = IteratorCollateFn() + + # Collate function only deserializes on the training worker + for batch in ray.train.get_dataset_shard("train").iter_torch_batches( + collate_fn=collate_fn, + batch_size=1 # Each "row" is actually a full batch + ): + # Training logic here + pass + + # Apply preprocessing in Ray Data + # Use repartition BEFORE map_batches because output row count changes + train_dataset = ( + create_mock_ray_text_dataset( + dataset_size=1000000, + min_len=1000, + max_len=3000 + ) + .repartition(target_num_rows_per_block=BATCH_SIZE) + .map_batches( + TextTokenizerCollateFn, + batch_size=BATCH_SIZE, + batch_format="pyarrow", + ) + ) + + trainer = TorchTrainer( + train_func, + datasets={"train": train_dataset}, + scaling_config=ScalingConfig(num_workers=4, use_gpu=True) + ) + + result = trainer.fit()