Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
de4f17f
Avoid lock if serialization result is cached
xinyuangui2 Nov 18, 2025
003b4ab
Merge branch 'ray-project:master' into master
xinyuangui2 Nov 19, 2025
93ab9d2
Merge branch 'ray-project:master' into master
xinyuangui2 Nov 20, 2025
e2cd6b8
Merge branch 'ray-project:master' into master
xinyuangui2 Nov 25, 2025
136ec12
Merge branch 'ray-project:master' into master
xinyuangui2 Nov 25, 2025
5650791
collate to ray data doc
xinyuangui2 Nov 26, 2025
bc53cb7
collate to ray data doc
xinyuangui2 Nov 26, 2025
5651ab1
Update doc/source/train/user-guides/move-collate-to-data.rst
xinyuangui2 Nov 26, 2025
0f9e445
Update doc/source/train/user-guides/move-collate-to-data.rst
xinyuangui2 Nov 26, 2025
a1d1b39
Update doc/source/train/user-guides/move-collate-to-data.rst
xinyuangui2 Nov 26, 2025
1273d7e
Update doc/source/train/user-guides/move-collate-to-data.rst
xinyuangui2 Nov 26, 2025
b3c9bfd
resolve comments
xinyuangui2 Nov 26, 2025
e9c523b
remove some redundancies
xinyuangui2 Nov 26, 2025
572a215
fix
xinyuangui2 Nov 26, 2025
e06c519
Merge branch 'master' into collate-fn-doc
xinyuangui2 Dec 3, 2025
ec8da4b
apply the new change
xinyuangui2 Dec 3, 2025
b5b76ec
add one link
xinyuangui2 Dec 3, 2025
1e961da
Merge branch 'master' into collate-fn-doc
xinyuangui2 Dec 8, 2025
447075c
adjusted-rliaw
richardliaw Dec 9, 2025
b83cc2f
resolve comments
xinyuangui2 Dec 9, 2025
d94f9ee
hide the utils
xinyuangui2 Dec 9, 2025
c0df121
Merge branch 'collate-fn-doc' of github.com:xinyuangui2/ray into coll…
richardliaw Dec 10, 2025
d469462
ok
richardliaw Dec 10, 2025
d58df9a
update
richardliaw Dec 10, 2025
521b51b
update
richardliaw Dec 10, 2025
8d825a1
Merge branch 'master' into collate-fn-doc
xinyuangui2 Dec 10, 2025
022a035
quick fix
xinyuangui2 Dec 10, 2025
54f6c9a
fix
richardliaw Dec 10, 2025
8273928
Merge branch 'collate-fn-doc' of github.com:xinyuangui2/ray into coll…
richardliaw Dec 10, 2025
425c656
update
richardliaw Dec 10, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
118 changes: 118 additions & 0 deletions doc/source/train/doc_code/collate_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
from dataclasses import dataclass
from typing import Dict, List, Tuple, Union
import torch
from ray import cloudpickle as pickle
import pyarrow as pa

# Per-tensor layout record: (dtype, shape, offset), where offset is the
# position of the tensor's first byte inside the packed byte buffer.
FEATURE_TYPE = Tuple[torch.dtype, torch.Size, int]
# Element type used to reinterpret tensors as raw bytes and to allocate the
# packed buffer.
TORCH_BYTE_ELEMENT_TYPE = torch.uint8

def _create_binary_array_from_buffer(buffer: bytes) -> pa.BinaryArray:
    """Wrap ``buffer`` in a one-element Arrow binary array without copying it.

    The array is assembled directly from its three underlying buffers: no
    validity bitmap (``None``), an int32 offsets buffer holding
    ``[0, size]``, and the data buffer itself.

    Args:
        buffer: Any object accepted by ``pa.py_buffer``.

    Returns:
        A ``pa.BinaryArray`` of length 1 whose single value is ``buffer``.
    """
    payload = pa.py_buffer(buffer)
    # One value spanning the whole payload, so the offsets are [start, end].
    value_offsets = pa.array([0, payload.size], type=pa.int32())
    offsets_buffer = value_offsets.buffers()[1]
    return pa.Array.from_buffers(
        pa.binary(),
        1,
        [None, offsets_buffer, payload],
    )

@dataclass
class _Metadata:
    """Layout information needed to rebuild tensors from a packed byte buffer."""

    # Feature name -> one (dtype, shape, offset) record per tensor.
    features: Dict[str, List[FEATURE_TYPE]]
    # Total length of the packed byte buffer.
    total_buffer_size: int

@dataclass
class _TensorBatch:
    """Internal class for serializing/deserializing tensor batches.

    Every tensor of a batch is reinterpreted as raw bytes and stored back to
    back in one flat ``uint8`` buffer; ``metadata`` records the dtype, shape
    and byte offset needed to slice each tensor back out.

    Note that a feature given as a bare tensor is treated as a one-element
    list, so a round trip always yields ``Dict[str, List[torch.Tensor]]``.
    """

    # Flat uint8 tensor holding the bytes of all tensors, in insertion order.
    buffer: torch.Tensor
    # Per-tensor layout records plus the total buffer size.
    metadata: _Metadata

    @classmethod
    def from_batch(cls, batch: Dict[str, Union[List[torch.Tensor], torch.Tensor]]) -> '_TensorBatch':
        """Serialize a batch of tensors into a single buffer.

        Args:
            batch: Mapping from feature name to a tensor or a list of tensors.

        Returns:
            A ``_TensorBatch`` whose buffer contains the bytes of every tensor
            and whose metadata describes where each one lives.
        """
        features: Dict[str, List[FEATURE_TYPE]] = {}
        byte_views = []
        total_buffer_size = 0

        for name, tensors in batch.items():
            if not isinstance(tensors, list):
                tensors = [tensors]
            records = []
            for tensor in tensors:
                # Reinterpret the tensor's memory as a 1-D run of bytes.
                byte_view = tensor.flatten().contiguous().view(TORCH_BYTE_ELEMENT_TYPE)
                byte_views.append(byte_view)
                # The running total is this tensor's start offset.
                records.append((tensor.dtype, tensor.shape, total_buffer_size))
                total_buffer_size += byte_view.shape[0]
            features[name] = records

        # Allocate once, then copy each byte view into its slot.
        buffer = torch.empty(total_buffer_size, dtype=TORCH_BYTE_ELEMENT_TYPE)
        cur_offset = 0
        for byte_view in byte_views:
            end_offset = cur_offset + byte_view.shape[0]
            buffer[cur_offset:end_offset] = byte_view
            cur_offset = end_offset

        return cls(
            buffer=buffer,
            metadata=_Metadata(
                features=features,
                total_buffer_size=total_buffer_size,
            ),
        )

    def to_table(self) -> pa.Table:
        """Convert to a single-row PyArrow table.

        Returns:
            A table with two binary columns: ``_buffer`` (the packed bytes)
            and ``_metadata`` (the pickled ``_Metadata``).
        """
        buffer_array = _create_binary_array_from_buffer(self.buffer.numpy().data)
        metadata_array = _create_binary_array_from_buffer(pickle.dumps(self.metadata))
        return pa.Table.from_arrays(
            arrays=[buffer_array, metadata_array],
            names=["_buffer", "_metadata"],
        )

    @classmethod
    def from_table(cls, table: pa.Table) -> '_TensorBatch':
        """Deserialize from a single-row PyArrow table.

        Args:
            table: A table with ``_buffer`` and ``_metadata`` binary columns,
                as produced by ``to_table``.

        Returns:
            The reconstructed ``_TensorBatch``.
        """
        # Index 2 is the data buffer of a binary array (after the validity
        # bitmap and the offsets).
        data_buffer = table["_buffer"].chunks[0].buffers()[2]
        if data_buffer is None or data_buffer.size == 0:
            # torch.frombuffer rejects zero-length buffers, so build the
            # empty byte tensor directly.
            buffer = torch.empty(0, dtype=TORCH_BYTE_ELEMENT_TYPE)
        else:
            buffer = torch.frombuffer(data_buffer, dtype=TORCH_BYTE_ELEMENT_TYPE)
        # NOTE(security): unpickling executes arbitrary code. Only pass
        # tables that come from a trusted source.
        metadata = pickle.loads(table["_metadata"].chunks[0].buffers()[2])
        return cls(buffer=buffer, metadata=metadata)

    def to_batch(self, pin_memory: bool = False) -> Dict[str, List[torch.Tensor]]:
        """Deserialize back to a batch of tensors.

        Args:
            pin_memory: If True, each tensor is passed through
                ``Tensor.pin_memory()`` before being returned.

        Returns:
            Mapping from feature name to the list of reconstructed tensors.
        """
        # Collect every start offset in metadata order and close the list
        # with the total size, so tensor i occupies
        # [offsets[i], offsets[i + 1]).
        offsets = []
        for features in self.metadata.features.values():
            for _, _, offset in features:
                offsets.append(offset)
        offsets.append(self.metadata.total_buffer_size)

        batch = {}
        offset_id = 0
        for name, features in self.metadata.features.items():
            tensors = []
            for dtype, shape, _ in features:
                start, end = offsets[offset_id], offsets[offset_id + 1]
                if start == end:
                    # Zero-element tensor: torch.frombuffer rejects empty
                    # buffers, so create the empty tensor directly.
                    tensor = torch.empty(shape, dtype=dtype)
                else:
                    # Create a zero-copy view of the byte slice.
                    byte_slice = self.buffer[start:end]
                    tensor = torch.frombuffer(
                        byte_slice.numpy().data, dtype=dtype
                    ).view(shape)
                if pin_memory:
                    tensor = tensor.pin_memory()
                tensors.append(tensor)
                offset_id += 1
            batch[name] = tensors
        return batch

# Helper functions for use in your code
def serialize_tensors_to_table(batch: Dict[str, Union[List[torch.Tensor], torch.Tensor]]) -> pa.Table:
    """Pack ``batch`` into a ``_TensorBatch`` and return it as a PyArrow table.

    Args:
        batch: Mapping from feature name to a tensor or a list of tensors.

    Returns:
        The table produced by ``_TensorBatch.to_table``.
    """
    packed = _TensorBatch.from_batch(batch)
    return packed.to_table()

def deserialize_table_to_tensors(table: pa.Table, pin_memory: bool = False) -> Dict[str, List[torch.Tensor]]:
    """Rebuild the tensors stored in a table created by ``serialize_tensors_to_table``.

    Args:
        table: The PyArrow table to decode.
        pin_memory: Forwarded to ``_TensorBatch.to_batch``.

    Returns:
        Mapping from feature name to a list of tensors.
    """
    packed = _TensorBatch.from_table(table)
    return packed.to_batch(pin_memory=pin_memory)
58 changes: 58 additions & 0 deletions doc/source/train/doc_code/random_text_generator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
import random
import string
import ray

def random_text(length: int) -> str:
    """Generate random text of exactly ``length`` characters.

    The text consists of lowercase ASCII "words" separated by single spaces.
    Inputs of three characters or fewer produce a single word.

    Args:
        length: Number of characters to generate. Values <= 0 yield ``""``.

    Returns:
        A string of exactly ``max(length, 0)`` characters that neither starts
        nor ends with a space.
    """
    if length <= 0:
        return ""

    alphabet = string.ascii_lowercase
    if length <= 3:
        return "".join(random.choices(alphabet, k=length))

    words = []
    remaining = length
    while remaining > 4:
        # Leave room for the separating space plus at least one character of
        # the next word, so the joined text never falls short of ``length``.
        word_length = random.randint(3, min(10, remaining - 2))
        words.append("".join(random.choices(alphabet, k=word_length)))
        remaining -= word_length + 1

    # The last word takes whatever is left (between 1 and 4 characters).
    words.append("".join(random.choices(alphabet, k=remaining)))
    return " ".join(words)

def random_label() -> int:
    """Return a uniformly chosen integer label between 0 and 7 inclusive."""
    return random.choice(range(8))

def create_mock_ray_text_dataset(dataset_size: int = 96, min_len: int = 5, max_len: int = 100):
    """Build a Ray dataset of random text samples with random labels.

    Args:
        dataset_size: Number of rows to generate.
        min_len: Smallest text length that may be sampled (inclusive).
        max_len: Largest text length that may be sampled (inclusive).

    Returns:
        A Ray dataset whose rows have the keys ``length``, ``text`` and
        ``label``.
    """

    def map_to_text_and_label(item):
        # Each sampled length is read from the row's "item" field.
        length = item['item']
        return {
            "length": length,
            "text": random_text(length),
            "label": random_label(),
        }

    candidate_lengths = range(min_len, max_len + 1)
    sampled_lengths = random.choices(candidate_lengths, k=dataset_size)
    return ray.data.from_items(sampled_lengths).map(map_to_text_and_label)
1 change: 1 addition & 0 deletions doc/source/train/user-guides.rst
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,4 @@ Ray Train User Guides
user-guides/monitor-your-application
user-guides/reproducibility
Hyperparameter Optimization <user-guides/hyperparameter-optimization>
user-guides/scaling-collation-functions
20 changes: 20 additions & 0 deletions doc/source/train/user-guides/_collate_utils.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
:orphan:

.. _train-collate-utils:

Collate Utilities
=================

.. literalinclude:: ../doc_code/collate_utils.py
:language: python


.. _random-text-generator:

Random Text Generator
=====================

The following helper functions generate random text samples with labels:

.. literalinclude:: ../doc_code/random_text_generator.py
:language: python
Original file line number Diff line number Diff line change
Expand Up @@ -603,7 +603,7 @@ For example, the following code prefetches 10 batches at a time for each trainin
Avoid heavy transformation in collate_fn
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

The ``collate_fn`` parameter in :meth:`iter_batches <ray.data.DataIterator.iter_batches>` or :meth:`iter_torch_batches <ray.data.DataIterator.iter_torch_batches>` allows you to transform data before feeding it to the model. This operation happens locally in the training workers. Avoid adding a heavy transformation in this function as it may become the bottleneck. Instead, :ref:`apply the transformation with map or map_batches <transforming_data>` before passing the dataset to the Trainer.
The ``collate_fn`` parameter in :meth:`iter_batches <ray.data.DataIterator.iter_batches>` or :meth:`iter_torch_batches <ray.data.DataIterator.iter_torch_batches>` allows you to transform data before feeding it to the model. This operation happens locally in the training workers. Avoid adding a heavy transformation in this function, because it can become a bottleneck. Instead, :ref:`apply the transformation with map or map_batches <transforming_data>` before passing the dataset to the Trainer. When an expensive transformation requires ``batch_size`` as input, such as text tokenization, you can :ref:`scale it out to Ray Data <train-scaling-collation-functions>` for better performance.


.. _dataset_cache_performance:
Expand Down
Loading