[Train] [Data][Doc] Scaling out expensive collation functions doc #58993
Conversation
The GIL makes checking `self._serialize_cache is not None` atomic, so we don't need a lock. Signed-off-by: Xinyuan <43737116+xinyuangui2@users.noreply.github.com>
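The pattern being described is a lock-free lazy serialization cache. A minimal sketch of the idea (the class name and stored value here are illustrative; only the `_serialize_cache` attribute comes from the PR):

```python
import pickle


class SerializedPayload:
    """Lazily caches its serialized form without a lock.

    Reading or writing a single attribute is atomic under CPython's GIL,
    so the `is not None` check can never observe a half-written value.
    The worst case is a benign race: two threads both serialize, but they
    produce identical bytes, so whichever write lands last is still correct.
    """

    def __init__(self, value):
        self.value = value
        self._serialize_cache = None

    def serialize(self) -> bytes:
        # No lock needed: the attribute read and write are each atomic.
        if self._serialize_cache is None:
            self._serialize_cache = pickle.dumps(self.value)
        return self._serialize_cache
```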
Signed-off-by: xgui <xgui@anyscale.com>
Code Review
This pull request adds a new user guide on an important performance optimization: moving the collate_fn from Ray Train workers to Ray Data. The documentation is comprehensive and well-structured, with a clear explanation of the problem, solution, and a complete runnable example.
I've identified a few areas for improvement in the provided code examples:
- A recurring typo in a variable name.
- An inefficient and likely incorrect tensor deserialization method in the utility class.
- An overly complex function for mock data generation that could be simplified for better readability.
These changes will improve the clarity and correctness of the example code for users.
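As an illustration of the third point, the mock-text generator could be reduced to something like this (a sketch of the simplification, not the code the PR ships):

```python
import random
import string


def random_text(length: int) -> str:
    """Random lowercase words joined by spaces, exactly `length` characters."""
    if length <= 0:
        return ""
    text = ""
    while len(text) < length:
        # Append one random word of 3-10 lowercase letters.
        word = "".join(random.choices(string.ascii_lowercase, k=random.randint(3, 10)))
        text = word if not text else text + " " + word
    # Trim the overshoot; this may cut the last word, which is fine for mock data.
    return text[:length]
```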
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> Signed-off-by: Xinyuan <43737116+xinyuangui2@users.noreply.github.com>
Signed-off-by: xgui <xgui@anyscale.com>
doc/source/train/user-guides.rst
Outdated
    user-guides/fault-tolerance
    user-guides/monitor-your-application
    user-guides/reproducibility
    user-guides/move-collate-to-data
this should probably go into the training ingest section?
https://docs.ray.io/en/latest/train/user-guides/data-loading-preprocessing.html
This doc is long, so I linked to it from https://docs.ray.io/en/latest/train/user-guides/data-loading-preprocessing.html instead.
.. testcode::
   :skipif: True

   from dataclasses import dataclass
   from typing import Dict, List, Tuple, Union
   import torch
   from ray import cloudpickle as pickle
   import pyarrow as pa

   # (dtype, shape, offset)
   FEATURE_TYPE = Tuple[torch.dtype, torch.Size, int]
   TORCH_BYTE_ELEMENT_TYPE = torch.uint8

   def _create_binary_array_from_buffer(buffer: bytes) -> pa.BinaryArray:
       """Zero-copy create a binary array from a buffer."""
       data_buffer = pa.py_buffer(buffer)
       return pa.Array.from_buffers(
           pa.binary(),
           1,
           [
               None,
               pa.array([0, data_buffer.size], type=pa.int32()).buffers()[1],
               data_buffer,
           ],
       )

   @dataclass
   class _Metadata:
       features: Dict[str, List[FEATURE_TYPE]]
       total_buffer_size: int

   @dataclass
   class _TensorBatch:
we don't plan to provide these out of the box?
I moved it to the advanced section since users might have their own approach.
Signed-off-by: xgui <xgui@anyscale.com>
class CollateFnRayData(ArrowBatchCollateFn):
    def __init__(self):
        self.tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")

    def __call__(self, batch: pa.Table) -> Dict[str, np.ndarray]:
        results = self.tokenizer(
            batch["text"].to_pylist(),
            truncation=True,
            padding="longest",
            return_tensors="np",
        )
        results["labels"] = np.array(batch["label"])
        return results
asking a couple questions while I rewrite this - do you have to inherit from ArrowBatchCollateFn? What does it do?
This tells the iterator that this function receives pyarrow.Table as input.
ray/python/ray/data/collate_fn.py
Lines 173 to 179 in 1180868
@DeveloperAPI
class ArrowBatchCollateFn(CollateFn["pyarrow.Table"]):
    """Collate function that takes pyarrow.Table as the input batch type.

    Arrow tables with chunked arrays can be efficiently transferred to GPUs without
    combining the chunks with the `arrow_batch_to_tensors` utility function.
    See `DefaultCollateFn` for example.
    """
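In other words, the subclass acts as a type marker the batch iterator can dispatch on. A toy sketch of the idea (illustrative only, not Ray's actual dispatch code):

```python
class CollateFn:
    """Base marker for collate functions."""


class ArrowBatchCollateFn(CollateFn):
    """Marker: __call__ expects a pyarrow.Table batch."""


class NumpyBatchCollateFn(CollateFn):
    """Marker: __call__ expects a dict of numpy arrays."""


def batch_format_for(fn: CollateFn) -> str:
    # The iterator checks the marker type to decide which batch
    # format to hand to the user's collate function.
    if isinstance(fn, ArrowBatchCollateFn):
        return "pyarrow"
    if isinstance(fn, NumpyBatchCollateFn):
        return "numpy"
    return "default"
```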
Yeah, but how is the iterator aware of this if you move it into the map_batches operator?
import random
import string
import ray

def random_text(length: int) -> str:
    """Generate random text of specified length."""
    if length <= 0:
        return ""

    if length <= 3:
        return "".join(random.choices(string.ascii_lowercase, k=length))

    words = []
    current_length = 0

    while current_length < length:
        remaining = length - current_length

        if remaining <= 4:
            word_length = remaining
            word = "".join(random.choices(string.ascii_lowercase, k=word_length))
            words.append(word)
            break
        else:
            max_word_length = min(10, remaining - 1)
            if max_word_length >= 3:
                word_length = random.randint(3, max_word_length)
            else:
                word_length = remaining
            word = "".join(random.choices(string.ascii_lowercase, k=word_length))
            words.append(word)
            current_length += len(word) + 1

    text = " ".join(words)
    return text[:length]

def random_label() -> int:
    """Pick a random label."""
    labels = [0, 1, 2, 3, 4, 5, 6, 7]
    return random.choice(labels)

def create_mock_ray_text_dataset(dataset_size: int = 96, min_len: int = 5, max_len: int = 100):
    """Create a mock Ray dataset with random text and labels."""
    numbers = random.choices(range(min_len, max_len + 1), k=dataset_size)
    ray_dataset = ray.data.from_items(numbers)

    def map_to_text_and_label(item):
        length = item['item']
        text = random_text(length)
        label = random_label()
        return {
            "length": length,
            "text": text,
            "label": label
        }

    text_dataset = ray_dataset.map(map_to_text_and_label)
    return text_dataset
can we just hide this as a utility that users can look at, instead of displaying it in the docs?
For this entire section: can we just hide it as a utility that users can look at, instead of displaying it in the docs? Like just link to it.
you should be able to put it in doc_code
OK, I've hidden it now.
Signed-off-by: Richard Liaw <rliaw@berkeley.edu>
Signed-off-by: xgui <xgui@anyscale.com>
…ate-fn-doc Signed-off-by: Richard Liaw <rliaw@berkeley.edu>
Two last things to do (for rliaw):

Add instructions on using

    ds.repartition(target_num_rows=batch_size).map_batches(collate_fn, batch_size=batch_size)

or

    ds.map_batches(collate_fn, batch_size=batch_size).repartition(target_num_rows=batch_size)

to scale out the collate function inside Ray Data.
Docs for #58837