
[RFC] Polylithic: Enabling multi-threaded DataLoading through non-monolithic parallelism #1318

andrewkho opened this issue Sep 16, 2024 · 8 comments



andrewkho commented Sep 16, 2024

🚀 The feature

TL;DR - We want to lean into modular multi-threading/multi-processing instead of the current monolithic multi-processing, and steer users away from the monolithic Dataset parallelism approach towards composable Iterables for pre-proc operations, with parallelism configured within each operation. This will enable multi-threaded dataloading (with NoGIL support), auto-tunable parallelism, torch.compile-able and GPU-enabled pre-proc operations, more efficient loading of mixed modalities, and composable dataloading and pre-proc graphs.

Motivation, pitch

Working name for the project: Polylithic (non-monolithic)
Where it will live: torchdata

Why Multi-Threading in the DataLoader?

Why do users need multi-threading in the PyTorch DataLoader? If your GPUs are not starved, then frankly you don’t need to change anything. But what happens when your GPUs are starving, for example during training or offline inference where expensive pre-proc like video decoding happens on the trainer? Multi-modal LLM training, for instance, requires loading from text, image, and/or video data sources.
Today, torch.utils.data.DataLoader uses multi-processing to perform dataloading and pre-proc in parallel with your training loop, and the first thing you should try is increasing the number of multiprocess workers. By changing a single variable in the DataLoader constructor, you can pretty easily add parallelism to your job, even with a custom Dataset class written in Python. If this keeps your GPUs fed, then you’re done!
However, Python multiprocessing may also introduce a lot of friction, and its memory overhead often leads to CPU OOMs, even on ZionEX hosts with 96 cores and 2 TB of RAM shared between 8 Nvidia A100s.

  • Python’s copy-on-read behaviour is very memory-inefficient and is a likely contributor to user OOMs.
  • Dataset instances need to be sent to multiprocess workers, which is the easiest way to send dataset state and user-defined python code to the multiprocess workers.
    • However this requires the whole dataset object to be picklable, which eg disallows using lambdas as functions.
    • TTFB may increase substantially under spawn and forkserver modes.
  • Pickling batches for IPC communication can also slow down end-to-end training in unexpected ways.
  • Inspecting the dataset state becomes much more difficult as users have no obvious way to interact with objects living in the worker processes.

At this point, the user has a few options to consider:

  1. throw up your hands and just live with the GPU inefficiency;
  2. explore GPU Pre-Proc, offloading work from CPU to the idle GPUs
    • There’s a lot of promise here, as the GPUs are under-utilized anyway and accelerated pre-proc can lead to faster end-to-end training; however, users still hesitate to spend GPU time and memory on non-training work.
    • GPU Pre-Proc is not practical with multi-processing due to the current approach of passing the entire dataset to each multiprocess worker. As a workaround, users may need to move their pre-proc code to their training loop instead.
  3. migrate to a completely different dataloading system;
    • There will be research, and a learning curve for you and your team to become proficient in the new system
  4. explore ways to improve the memory efficiency of your Dataset, for example by storing metadata in some shareable memory, if you suspect you’re being hit by Python’s copy-on-read behaviour (see the sketch after this list);
    • You’ll be able to stay with the PyTorch DataLoader, but we don’t provide tooling to make this easy and it may require a complete re-write of your dataset
  5. Use offline pre-proc to move your CPU-intensive pre-proc off the box, e.g. Airstore, Mosaic, DPS tensor-caching
    • You need to pay for offline compute and storage.
  6. Use off-box online pre-proc (eg disaggregated DPP)
    • There is increased complexity in this system, the cost of disaggregated compute, and co-ordination between trainers and dataloader workers
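
As a concrete illustration of option 4 (a rough sketch only; `load_and_preproc` is a hypothetical user function, and this is not tooling we provide today): keeping per-sample metadata in a single numpy array instead of millions of small Python objects means reads never touch Python refcounts, so forked workers don't trigger copy-on-read page copies.

import numpy as np
import torch.utils.data

class MetadataEfficientDataset(torch.utils.data.Dataset):
  def __init__(self, file_paths):
    # Fixed-width numpy byte array instead of a Python list of str objects:
    # indexing it does not bump per-object refcounts, avoiding copy-on-read in forked workers.
    self.paths = np.array([p.encode("utf-8") for p in file_paths], dtype=np.bytes_)

  def __len__(self):
    return len(self.paths)

  def __getitem__(self, i):
    path = self.paths[i].decode("utf-8")
    return load_and_preproc(path)  # hypothetical user-defined loading + pre-proc
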
What about Multi-Threading?

Within the realm of on-box compute approaches, multi-threading has lower startup and memory costs than multi-processing. Many multi-processing pain points disappear in a multi-threaded world, such as Python’s copy-on-read behavior and IPC serialization costs. Subject to the same memory budget, jobs will be able to scale up the number of workers much further than with multi-processing (until hitting CPU limits).
However multi-threading introduces many new problems as well:

  • GIL - Global Interpreter Lock in CPython creates contention as only one thread at a time can run Python code, preventing true multi-threaded parallelism.
  • Thread-safety - User code and popular dataloading/pre-proc clients and libraries may not have been written to be thread-safe, since the only parallelism mode in torch.utils.data is multi-processing.
    • Users may see unexpected/hard-to-reproduce bugs if they are not careful and don’t understand how to write safe concurrent code, or don’t realize the library they’re using is not thread-safe
  • Thread-contention - even with NoGIL and thread-safe implementations, locks in user code, popular dataloading and pre-proc libraries, and even using python built-in classes may result in performance-degrading thread-contention.
    • It will take time for the community and library maintainers to improve their code for NoGIL Python. However, there is a chicken-and-egg problem: NoGIL may not gain wide adoption if users don’t see the performance improvements they expect, and library maintainers will not be motivated to improve their libraries if not enough users are asking for it.
We should still enable multi-threading in PyTorch DataLoader

Despite the problems multi-threading introduces, we should still offer a way to perform multi-threaded dataloading.

  • We should be early adopters of NoGIL/Free-Threaded Python
    • As the official PyTorch dataloading library, our early enablement/endorsement of FT Python will encourage other dataloading and pre-proc library maintainers to ensure their libraries are thread-safe and performant. Projects such as Pillow have already prepared, or are preparing, for NoGIL.
    • As early adopters, we’ll need to be prepared for a world in which not all libraries that users are using are thread-safe. We plan to allow users to isolate parallelism to particular operations (eg through a ParallelMap operator) and enable mixed multi-process and multi-threading in a pipelined-fashion.
  • PyTorch DataLoader should enable users to try different parallelism and accelerated set ups to find what works best for them, including Multi-Process, Multi-Threading, and GPU Pre-Proc, which we don’t enable today.
    • DataLoading performance may depend on a large number of factors outside of the DataLoader’s (and often user’s) control: data storage location, storage formats, host compute resources, the particular model being trained/inferred on, and the types of pre-proc users are doing. We can’t and shouldn’t try to predict all of the possible set ups, and instead enable users to experiment and find what works best for them.
    • Even with the GIL, there are scenarios where multi-threading out-performs multi-processing, albeit GIL contention currently limits this to a small number of use cases.
    • We should enable users to try out accelerated image/video decoding and pre-proc within the dataloader framework, as those optimizations may yield significant end-to-end performance gains.
    • Giving users more fine-grained control over parallelism makes tuning for optimal performance a challenge. We could introduce some sort of worker auto-tuning mechanism, a la tf.data, to make this part easy for users (see the sketch after this list).
  • When an LLM trains exclusively on text, it's easy to keep the GPUs fed. However, multi-modal LLM training and fine-tuning are expected to increase both internally and in the OSS community.
    • Existing text dataloading pipelines often don’t require parallelism (eg just a single background process) to keep the GPUs fed. Enabling more granular parallelism will allow NLP and GenAI practitioners to keep existing dataloading setups that work, but separately parallelize any multi-media dataloading and pre-proc operations.
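
To make the auto-tuning idea above concrete, a purely hypothetical sketch (not a committed design): a hill-climbing probe that increases the worker count while measured throughput keeps improving by a meaningful margin. `make_loader` is an assumed user-supplied factory.

import time

def autotune_num_workers(make_loader, max_workers=16, probe_batches=100):
  # make_loader(num_workers) is assumed to build a fresh dataloading pipeline
  # that yields at least probe_batches batches.
  best_workers, best_rate = 1, 0.0
  for num_workers in range(1, max_workers + 1):
    it = iter(make_loader(num_workers))
    start = time.perf_counter()
    for _ in range(probe_batches):
      next(it)
    rate = probe_batches / (time.perf_counter() - start)
    if rate <= best_rate * 1.05:  # stop once gains drop below ~5%
      break
    best_workers, best_rate = num_workers, rate
  return best_workers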

Supporting Multi-Modal DataLoading

Llama 3 is here, and Llama 4 will arrive soon with early-fusion multi-modality. Tasks like fine-tuning, alignment, and distillation will require multi-modal dataloading for our internal and external users. LLM training often requires reading from 10s-100s of multi-modal datasets, tokenizing them, and packing them into a “token-buffer” where tokens from individual datasets are shuffled and combined into training examples for the model.

Audio, Image, and Video datasets may also require heavy-weight decoding operations to be performed before tokenization, and the difference in the data sizes between text, image, and video may be orders of magnitude. GPU decoding of images and video is an option for users as well, and libraries like Nvidia DALI will compile the entire pre-proc pipeline into GPU operations, minimizing the overhead of transfers between CPU and GPU memory.

Existing Context and definitions

Torch.utils.data contains the following abstractions today:

  • Dataset (aka Map-Dataset): an interface that users subclass, which defines the `__getitem__(i) -> sample` method and is responsible for loading data and performing pre-proc.
    • Eg load encoded image ‘i’ into memory, performing decoding, tensorfication, cropping, random rotations, and returning the sample for training
  • Sampler: typically of type `Iterable[int]` that defines the iteration order over a Map-Dataset. Numerous built-in samplers exist which handle in-order iteration, shuffling, weighted sampling (eg if every sample has a sampling weight), and data sharding for distributed training.
  • IterableDataset: an interface that users subclass, which defines the `__iter__() -> Iterable[sample]` method and is responsible for loading data and performing pre-proc, but is also responsible for shuffling and data sharding for distributed training. IterableDatasets are not used with Samplers.
    • Eg for LLM training, holds iterators to 5 text datasets, performs weighted sampling between datasets, loads and tokenizes text, fills a token buffer and yields sets of tokens.
  • DataLoader/StatefulDataLoader: a class which takes either a) a Dataset + Sampler, or b) an IterableDataset, and may create multiple worker processes, each holding a copy of the Dataset/IterableDataset instance. The DataLoader requests data from each individual worker (through either a Sampler-provided index or next()).
    • Multi-processing is currently the only parallelism option provided by the DataLoader. The Python GIL prevents true thread-based parallelism; however, PEP 703 (NoGIL) aims to change that and enable true free-threaded parallelism in Python. Caveat: even with the GIL, there are likely use cases which would still benefit from multi-threading over multi-processing, though the pool of such use cases is probably smaller.
    • StatefulDataLoader is a drop-in replacement for DataLoader that has state_dict/load_state_dict methods.
# Example usage of torch.utils.data.DataLoader, Sampler, and Dataset, with multiprocess parallelism
dl = torch.utils.data.DataLoader(my_dataset, batch_size=batch_size, sampler=maybe_my_sampler,
                                 num_workers=multiprocessing_num_workers)
for batch in dl:
  ...  # model forward/backward
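
For reference, a minimal sketch of StatefulDataLoader's checkpoint/resume usage (it takes the same constructor arguments as DataLoader; this assumes my_dataset and batch_size are defined as above):

from torchdata.stateful_dataloader import StatefulDataLoader

sdl = StatefulDataLoader(my_dataset, batch_size=batch_size, num_workers=2)
it = iter(sdl)
for _ in range(100):
  batch = next(it)          # train on the first 100 batches
state = sdl.state_dict()    # capture mid-epoch progress

# Later, e.g. after a job restart: rebuild the loader and resume
sdl = StatefulDataLoader(my_dataset, batch_size=batch_size, num_workers=2)
sdl.load_state_dict(state)
for batch in sdl:           # iteration continues from roughly where we left off
  ...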

“Monolithic” parallelism

Currently users have a single lever to control parallelism: num_workers. When num_workers > 0, the DataLoader creates background processes, each of which holds a copy of the entire Dataset object in its own memory, treating the Dataset as a “monolithic” object to be parallelized.

Consider the scenario in the figure below, where a user has defined an iterable dataset which combines two text datasets and one image dataset. There is no parallelism in this example.
[Figure: an IterableDataset combining two text datasets and one image dataset, with no parallelism]

Now consider the common case when only the image-decoding and tokenization is a bottleneck causing GPU Starvation. With today’s tooling, users simply increase dataloader num_workers > 1. The image below depicts how this is done today, by treating the entire IterableDataset as a monolith that is forked/spawned to another process.
[Figure: the entire IterableDataset treated as a monolith and replicated into each of the num_workers processes]

A granular parallelism approach

To fix the monolithic parallelism problem, we want to introduce abstractions and tooling that expose more granular parallelism controls to users. This implies a solution where users construct their dataloading and pre-proc pipelines by defining and stitching together datasource and pre-proc nodes into a graph, in a similar fashion to tf.data and datapipes, with data passing between the nodes. The root of the graph is the node which produces batches that are passed to the model. The leaves are data-sources which produce data by reading from local disk, remote storage, or eg random number generators. Intermediate nodes may transform data, perform pre-fetching, combine data from multiple nodes, perform “enrichments” by eg fetching images from blob stores, perform decoding, schedule GPU operations etc.

Requirements and Constraints

To adequately support Multi Modal LLM training for PyTorch users, address the above pain points, and give us the best chance for wide-adoption, we want our solution to meet the following requirements and constraints:

  • Eager execution is the default behaviour. Ease of experimentation, flexibility, and debugging of experimental python code are critical to PyTorch’s success. Ensuring our solution has an “eager mode” which will dump great stack-traces will make developing and debugging easy for users.
  • Construct your graph with Python. Giving users the flexibility to write their pre-proc pipelines with a general purpose language will maximize experimentation and expressivity, lower barriers for entry, and match PyTorch conventions.
    • Example pseudo-code block: LLM training with 50 datasets, randomly sampling from datasets on each iteration
import itertools
import random
from typing import Iterable, List

class DatasetSampler:
  def __init__(self, sources: List[Iterable], sampling_weights: List[float]):
    self.sources = sources
    self.sampling_weights = sampling_weights

  def __iter__(self):
    # Cycle each source forever; on every step, draw the next sample from a randomly chosen source
    self.base_iters = [itertools.cycle(x) for x in self.sources]
    n = len(self.base_iters)
    while True:
      (ds_idx,) = random.choices(range(n), weights=self.sampling_weights)
      yield next(self.base_iters[ds_idx])
  • Backwards Compatibility with torch.utils.data
    • We want to minimize the number of new concepts/classes we introduce to users
    • We also want to provide an easy path to adoption by allowing users to reuse their existing Datasets (eg WebDS, Mosaic, HuggingFace, etc) as much as possible.
  • Support multi-process, multi-threaded, and NoGIL multi-threaded based parallelism at the node level
    • Some users may not want to move to multi-threading, may be stuck with GIL Python, or non-thread-safe code and libraries, where multi-processing gives the best performance.
    • With NoGIL, we will hopefully be in a world where Thread-based parallelism is a viable alternative to process-based parallelism, and we want to lean into this aspect as much as possible.
  • Enable GPU Pre-Proc pipelines to be defined, and enable compilability
    • Our solution should provide a path to enabling GPU Pre-Proc pipelines, and torch.compile compatibility.
    • One potential solution is to have a “TorchCompile” node that takes a sequence of operations and runs torch.compile on them. Alternatively we ask users to pass a torch.compiled pipeline to a Mapper node.
    • We don’t require that every possible graph configuration (eg multiprocess parallelism) is torch.compilable, however the following example should be possible:
      • MultiThreaded reading -> torch.compile(GPU decoding -> Crop -> mirror) -> training loop
  • Support for in-order and out-of-order iteration, and support for random transform reproducibility
    • The current dataloader provides guarantees on ordering, we should continue to support this by default as it’s an important requirement for many researchers.
    • We will provide the option to relax these constraints, which may improve throughput
    • Needs feedback: random transform reproducibility is harder to guarantee in a multi-threaded environment; if it proves too difficult, we may relax that constraint, but we should do everything we can to ensure reproducibility of iteration order.
    • Bottom line: reproducibility is an important variable to control for in experimentation, and we should do everything we can to ensure reproducibility.
  • Support for automatic tuning of workers
    • Introducing more granular parallelism controls creates a large dataloader-parameter space for users to optimize performance. Our solution should enable tf.data-style AUTOTUNE capabilities (see section 3.3) and provide good-enough results for most users.
    • Depending on how hairy this gets for multi-process, we might limit to tuning of multi-threaded workers only
  • Support for mid-epoch checkpointing and resuming
    • Training an epoch can take days or even weeks (or more) for some models and datasets. Mid-epoch checkpoint/resuming is essential to these types of workloads.
    • [caveat] We’ll need to think through how this would work for out-of-order execution.
  • Nodes will be iterable only, with no indexing support
    • Map-style datasets + samplers will be supported, but we won’t support index-based access between nodes
      • [what we won’t do] Datapipe’s MapDataPipe allowed users to pass indexes to the root of the graph and retrieve specific examples, however this requires two directions of communication between nodes, and also does not work at all for the more general case of eg sampling from multiple datasets
    • Users may still use Map Datasets + Samplers by wrapping them into an Iterable which produces samples; the indices come from the sampler inside the iterable rather than being passed in from outside.
    • Alternatively, a Sampler can be used as a source dataset and passed to a Mapper node which does something like “yield from (self.dataset[i] for i in self.source)” (see the sketch below)
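
To illustrate the wrapping described in the last two bullets, a minimal sketch (MapToIterable and the PolylithicNode base class are proposed names from this RFC, not an existing API):

from typing import Iterator
from torch.utils.data import Dataset, Sampler

class MapToIterable(PolylithicNode):
  # Adapts a map-style Dataset + Sampler into an iterable node.
  # Indices are generated *inside* the node by iterating the sampler.
  def __init__(self, dataset: Dataset, sampler: Sampler):
    super().__init__()
    self.dataset = dataset
    self.sampler = sampler  # any Iterable[int], eg a plain list of indices for debugging

  def iterator(self) -> Iterator:
    yield from (self.dataset[i] for i in self.sampler)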

How will we achieve this/what will we build? Plan of Record

We will introduce a new base class with the working name PolylithicNode(torch.utils.data.IterableDataset). Nodes in the graph will be instances of subclasses of PolylithicNode. Nodes will define an .iterator() method instead of overriding __iter__(); this is inspired by nn.Module, where users define .forward() instead of __call__(). This allows PolylithicNode to instantiate user-defined iterators and wrap them, insert queues for pipeline parallelism, and measure latency. For backwards compatibility, we’ll provide a wrapper which takes an existing IterableDataset. Users can compose their datasets by composing PolylithicNodes (ie through iter() and next()).
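
A minimal sketch of the wrapping pattern described above (illustrative only; the real base class would also provide state_dict, queue insertion, error propagation, etc.):

import time
from typing import Iterator
import torch.utils.data

class PolylithicNode(torch.utils.data.IterableDataset):
  # Proposed base class (working name). Subclasses implement .iterator(), not __iter__().
  def iterator(self) -> Iterator:
    raise NotImplementedError

  def __iter__(self) -> Iterator:
    # Because the base class owns __iter__, it can wrap the user-defined iterator,
    # eg to measure per-item latency, insert prefetch queues, or surface worker errors.
    it = self.iterator()
    while True:
      start = time.perf_counter()
      try:
        item = next(it)
      except StopIteration:
        return
      self._last_item_latency = time.perf_counter() - start  # simple instrumentation hook
      yield item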

Example of composing iterable datasets to create a multimodal dataloader. [Note that we are open to ideas on syntactic sugar]

import json
from torchdata.polylithic.nodes import PolylithicNode, Batcher, Mapper, MultiThreadedMapper, PinMemory, Prefetcher  # Note: all of these classes subclass PolylithicNode

# Note: PolylithicNode is an abstract class which provides common code for state_dict, graph traversal,
# autotuning, error propagation, etc.
# class PolylithicNode(torch.utils.data.IterableDataset): ...
#   def __iter__(self):  # PolylithicNode is still an IterableDataset
#     ...

# Some existing IterableDataset, perhaps generated through eg HuggingFace
class MyIterableDataset(torch.utils.data.IterableDataset):
  def __init__(self, json_l_file):
    self.json_l_file = json_l_file
  def __iter__(self):
    while True: # Loop forever
      with open(self.json_l_file, "r") as f:
        for line in f.readlines():
          yield json.loads(line)

# Define a Token Packer
class MyTokenPacker(PolylithicNode):
  def __init__(self, tokens_per_sample: int, sources: List[PolylithicNode], weights: List[float]):
    super().__init__()
    self.n = tokens_per_sample
    self.sources = sources
    self.weights = weights

  def iterator(self):
    self.source_iters = [iter(src) for src in self.sources]
    sample = []
    while True:
      while len(sample) < self.n:
        # weighted_sample_int: user-defined helper returning a source index drawn by weight
        src_idx = weighted_sample_int(len(self.weights), self.weights)
        tokens = next(self.source_iters[src_idx])["tokens"]
        sample.extend(tokens)
      yield sample[:self.n]
      sample = sample[self.n:]

# Set up Tokenizer UDFs
def tokenize(data):
  data["tokens"] = Tokenizer()(data["text"])
  return data

def tokenize_img_and_text(data):
  data["tokens"] = DecodeAndTokenize()(data["image"]) + Tokenizer()(data["caption"])
  return data

# Set up text reader
text_src = PolylithicNode.from_iterable(MyIterableDataset("text_data.jsonl"))
text_src = MultiThreadedMapper(text_src, udf=tokenize, num_workers="AUTOTUNE")

# Set up Text and Image dataset, with GPU Decoding 
img_src = PolylithicNode.from_iterable(MyIterableDataset("img_caption_data.jsonl"))
img_src = Mapper(img_src, udf=GpuImageDecoder(...)) # single threaded in main process
img_src = MultiThreadedMapper(img_src, udf=tokenize_img_and_text, num_workers="AUTOTUNE")
# Rest of pipeline
node = MyTokenPacker(tokens_per_sample, [img_src, text_src], weights=[0.25, 0.75])
node = Batcher(node, batch_size)
node = PinMemory(node)
node = Prefetcher(node, 2)

for tokens in node:
  ...

More complex diagram
[figure omitted: a more complex dataloading graph]

  • We will define the PolylithicNode base class and tooling and utilities to support composing multiple PolylithicNodes into pipelines/preproc graphs (DAGs), with node-level parallelism controls.
  • As a programming model, users would be chaining together iterators similar to how nn.Modules are composed: a base class (ie PolylithicNode) whose dependencies are member variables which are also PolylithicNodes.
    • We can traverse the graph with reflection, inspecting instance fields to find ancestors in the graph (ie the current datapipes approach)
    • PolylithicNode is itself a subclass of IterableDataset
  • Users will be able to define their own nodes with Python
  • Define a constructor to wrap existing implementations of IterableDatasets into PolylithicNodes
  • Build out a library of useful nodes/operators that will provide users with the same functionality they expect, some examples:
    • MapToIterable operator that takes a map-style dataset + sampler to create an iterable PolylithicNode
    • Batcher
    • Pin Memory
    • ParallelMap operator which supports thread- or process-based parallelism for UDFs (see the sketch after this list)
      • Create input/output queues, and workers
        • input_queue -> [worker, worker] -> output_queue
          • A single thread reads from the source node and puts data into the input queue
          • Workers put data on the output queue
          • __iter__() yields from the output queue
      • We could have a mode which disables parallelism and prefetching, eg:
        • with NoParallelism():
          • for batch in my_polylithic_dag: ...
    • TorchCompile’d Map operator
      • We’d want to make sure this runs in the main process and single-threaded, and tensors passed in / created are on the same device
    • Prefetch operator
    • Caching capabilities (to memory or disk)
    • Broadcast node for TensorParallel consistency
  • Include tooling to traverse graph (see datapipe’s graph traversal method for an example)
    • [Needs feedback/investigation on feasibility] Provide graph-optimizations (eg node fusion) and pipeline parallelism in dataloading graph
  • Add auto-tuning capabilities a la tf.data, and include tooling to measure throughput through each node to identify bottlenecks.
  • Include an “eager mode” that will disable parallelism and prefetching, to ease debugging and development
  • Deterministic iteration order by default, with out-of-order possible
  • Checkpointing support for iteration order with Deterministic ordering, and some limited support when iterating out-of-order
  • Provide a migration plan for users coming from datapipes who rely on composability
  • Delegate Data Parallel Sharding and Shuffling to DataSources / Bundles library (out of scope for this RFC), to ease IterableDataset Pain Points.
  • Backwards compatibility: folks can re-use existing Map and IterableDatasets, and Samplers
    • Existing IterableDatasets (or any iterable) can be wrapped/converted to a PolylithicNode
    • While we won’t support map-style access (e.g. MapDataPipe), users can control iteration order at the source, by combining Map + Sampler into a PolylithicNode.
    • Debugging will be more clunky, but can be achieved by passing a list of specified indexes as the sampler.
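
A rough sketch of the input_queue -> workers -> output_queue structure described for the ParallelMap operator above (thread-based variant; class and parameter names are illustrative, and error propagation, checkpointing, and in-order delivery are deliberately omitted):

import queue
import threading

_DONE = object()  # sentinel marking end-of-stream

class ThreadedParallelMap(PolylithicNode):
  def __init__(self, source, udf, num_workers=4, max_queue_size=32):
    super().__init__()
    self.source = source
    self.udf = udf
    self.num_workers = num_workers
    self.max_queue_size = max_queue_size

  def iterator(self):
    in_q = queue.Queue(maxsize=self.max_queue_size)
    out_q = queue.Queue(maxsize=self.max_queue_size)

    def feeder():
      # A single thread reads from the source node and puts items on the input queue
      for item in self.source:
        in_q.put(item)
      for _ in range(self.num_workers):
        in_q.put(_DONE)

    def worker():
      # Workers apply the UDF and put results on the output queue (out of order)
      while True:
        item = in_q.get()
        if item is _DONE:
          out_q.put(_DONE)
          return
        out_q.put(self.udf(item))

    threads = [threading.Thread(target=feeder, daemon=True)]
    threads += [threading.Thread(target=worker, daemon=True) for _ in range(self.num_workers)]
    for t in threads:
      t.start()

    # iterator() yields from the output queue until every worker has signalled completion
    finished = 0
    while finished < self.num_workers:
      result = out_q.get()
      if result is _DONE:
        finished += 1
      else:
        yield result

Note that this sketch yields results out of order; the real operator would need to offer in-order delivery and error propagation, as discussed above.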

Alternatives

  • Ask users to write their UDFs in C/C++
    • Python is the frontend for PyTorch and the easiest language for most users/researchers to get started with. Migrating to C/C++ is something many users are not comfortable doing, and incurs a large re-write cost, as well as environment handling for different platforms (e.g. developing on a mac, but training on Linux servers).
  • What about defining a Python DSL which wraps C++ (or some other language) operators?
    • This increases the difficulty for users to onboard and write custom operations, as they need to learn the DSL.
    • We also need to understand/predict the space of possible operations that users may want to do.
    • It also makes debugging code more difficult for users who aren’t C/C++ programmers.
  • Can users torch.compile everything?
    • We’d like to enable torch.compile of certain operations, but it’s unclear how this would interact with C++ libraries that users depend on (eg storage clients, image/video decoders)
    • As of today, issues like compilation time and performance are still not quite ready for the variety of use cases we see in dataloading. For eg an iterable interface, users would need to torch.compile the next() call on the iterator
  • Can we just swap out processes for threads in torch.utils.data.DataLoader? This would be very straightforward to implement; however, this reliance on "monolithic" parallelism still has drawbacks:
    • For Map datasets, one approach is to assume they are "stateless", allowing us to simply pass the entire object to multiple threads. However this would require everything in __getitem__ to be thread-safe and contention-free, which will be a challenge when users depend on custom libraries.
    • For Iterable datasets, having many threads hold an iterator instance is even more challenging, as you'd have the same thread-safety and thread-contention issues as Map style, but you also need to assume your Iterable dataset is not using state to manage work-sharding. Nvidia's prototype takes a lock on the dataset, which would essentially make dataloading single-threaded.
    • For both approaches, there is limited ability to control concurrency factor in GPU pre-proc, and that would need to be traded off with dataloading and other pre-proc parallelism.

Additional context

What about DataPipes and DL v2?

To avoid confusion and remove potential backwards-compatibility issues, we will still be deprecating DataPipes and DL v2; however, due to the similarity between DataPipes and the proposed approach, we believe migration should be fairly straightforward. DataLoader2 has very low adoption and we won’t be providing a direct replacement for it.

DataPipes and DL v2 were designed to address issues like composability, and there is a lot of value in what was built, however their parallelism and data sharding structure is still based on a monolithic approach (eg plug a datapipe into DL v1, or DL v2 + multiprocess reading service). They required migration/rewrite of datasets with often no improvement in performance, identifying dataloading-preproc bottlenecks was a challenge, and shuffling/sharding pain points weren’t adequately addressed.

The proposed approach improves upon DataPipes + DLv2 in the following ways:

  • Reduced resource utilization through granular parallelism, and auto-tuning
  • Improved throughput/performance through NoGIL multi-threading and GPU Pre-Proc pipelines
    • [Risk] if Python NoGIL does not gain adoption, users may need to fallback on granular process-based parallelism

We want to maintain the composable aspects of datapipes, the eager-execution, and continue our partnerships with storage and cloud providers (AWS, Azure, GCP) where they provide high-performance clients, share customer pain points, and provide recommended solutions and examples to their users.


mfbalin commented Sep 16, 2024

It looks like we can implement datapipes on top of Polylithic nodes and continue to use existing code without changing them at all.

What are your thoughts on this?

@andrewkho

@mfbalin I would not want to implement datapipes on top of the Polylithic, as there would be twice the surface area to support and test against

@ppwwyyxx

It's not clear in the example code whether Prefetch is parallel or not. In fact a Prefetch can also be either multi-thread or multi-process, and both are useful.

The semantics of "Parallel Prefetch" would be just like "monolithic" parallelism, where the entire node + its dependency nodes are replicated and produce results to a queue. It has a benefit over "Parallel Mapper" because a mapper requires two-sided communication, whereas prefetch workers take no input from the main worker and only have to produce output. It requires care to ensure no duplicates and reproducibility, similar to what people do today with torch.DataLoader + IterableDataset.

@andrewkho

Thanks @ppwwyyxx for the comment and suggestion, all implementations are up for discussion. Definitely open to a ParallelPrefetch or something similar. One path is for users to throw everything into a torch.DataLoader with multi-processing and not re-invent the wheel there. I think my main concerns are the same as yours: setting up for no-duplicates and reproducibility may be a challenge, and ideally we could find a way to reduce foot-guns, as described in eg this notebook. If we're parallelizing a monolithic node, it's hard for us to observe what's happening internally and whether a user's setup is incorrect. Putting safeguards/warnings in may be a challenge at a global level.

If the intention is just to parallelize source/producer nodes, say to interleave multiple files, then we should just find a way to parallelize there.

@andrewkho

OTOH, if we created a ParallelPrefetch with multi-processing that we could configure to have identical behaviour of current torch.DataLoader, we may be able to consolidate the implementation code while maintaining backwards compatibility
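
For concreteness, a rough sketch of one possible ParallelPrefetch shape (threads shown instead of processes for brevity; make_source, the round-robin sharding, and the class itself are all hypothetical, and reproducibility of ordering is deliberately not addressed):

import itertools
import queue
import threading

_END = object()  # sentinel marking worker completion

class ParallelPrefetch(PolylithicNode):
  def __init__(self, make_source, num_workers=2, max_queue_size=32):
    super().__init__()
    self.make_source = make_source  # factory so each worker holds its own replica
    self.num_workers = num_workers
    self.max_queue_size = max_queue_size

  def iterator(self):
    out_q = queue.Queue(maxsize=self.max_queue_size)

    def worker(worker_id):
      # Each worker iterates its own replica and keeps every num_workers-th element,
      # which avoids duplicates but makes output ordering non-deterministic.
      src_iter = iter(self.make_source())
      for item in itertools.islice(src_iter, worker_id, None, self.num_workers):
        out_q.put(item)
      out_q.put(_END)

    for i in range(self.num_workers):
      threading.Thread(target=worker, args=(i,), daemon=True).start()

    finished = 0
    while finished < self.num_workers:
      item = out_q.get()
      if item is _END:
        finished += 1
      else:
        yield item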


josiahls commented Sep 18, 2024

I'm on board with the idea of PolylithicNode. Making these nodes iterable-only seems to simplify the user experience. As someone diving into this, I initially faced confusion around the distinctions between map and iter datapipes. Questions like "Should I use map for the entire pipeline?" or "Why does the map API lack functions available in the iter API?" were quite common early on, but eventually I understood that the idea is to move to iterator-only as early as is reasonable.

Regarding debugging, I encountered several challenges, although these seem to stem more from the nature of modular chained pipelines in Python than from torchdata specifically:

  1. Stepping through Pipelines: Using a debugger with these chained iterators was quite painful. Each "step" required going through multiple lines of boilerplate code in the iterable datapipe. This may relate to the concept of "inlining(?)," but I'm not entirely sure. Do you have any existing solutions for this issue? I can provide more details if needed.

  2. Identifying Map Datapipes During Debugging: When debugging within a "map" datapipe, it was often unclear which specific map datapipe I was in since they all appear similar. Are there plans to address this? I was thinking that iter datapipes might benefit from a "prefix" or "label" indicating their role within a group or process chunk. Perhaps you have a more elegant solution in mind?

  3. Datapipe Compatibility: Determining which datapipe plugs into another was not straightforward. Given Python's lack of strict typing, I frequently had to wait until runtime to discover incompatibilities.

  4. Stack Trace Complexity: The longer the pipeline, the more convoluted the stack trace becomes, amplifying the difficulty in troubleshooting (similar to the issue described in point 1).

I appreciate the torchdata library and the idea of simple, composable nodes. After building several production pipelines with it, I've started to believe that Python's inherent lack of type safety between pipes and the inability to skip over redundant "glue" code in connections makes it inherently challenging to create easy-to-debug, composable pipelines. This might also explain why many popular data processing libraries seem to prefer more monolithic approaches (or large modules linked together), as they're less cumbersome to debug.

That said, I still like the concept of "horizontal APIs" made up of simple, linked objects that can achieve impressive things when combined. I'm curious if the issues I've mentioned can be effectively resolved within Python. Any insights or planned improvements would be greatly appreciated.

@andrewkho

@josiahls thank you for the feedback and describing your pain points! These are definitely the questions and issues that we want to prioritize and make sure the solution is debuggable for developers and users. As far as the questions go, this does sound like something inherent in functional chains of operators.

  1. definitely understand how this boilerplate can be painful. As a first thought, we're thinking of wrapping user iterators in order to be able to measure throughput, and having an "inline/debugging" mode which removes queues so everything is eager and runs synchronously. I am pretty sure we don't want to do any magic here, and will stick with plain Python stack traces; open to ideas on how we can reduce the boilerplate as much as possible.

  2. as you mentioned, I think simplifying by getting rid of map datapipes is the easiest thing to do. With multiple datasets and nodes, I think the simplest thing would be to include [optional] user-supplied "name" or "prefix" tags to the constructor, but surfacing these in stack traces may not be super straightforward. Are you thinking for stepping through in a debugger, or for stack-traces when an error is thrown?

  3. That sounds pretty rough. What I'm envisioning is chained iterators; I'm not sure if something like type hints would help us be more specific here and catch errors before runtime.

  4. definitely hear you on this one, it's a big concern of mine here as well with functional pipelines. I'm not sure there's a better way unless we reduce the number of primitive operators that users are expected to use. As a thought, I'm leaning less towards a streaming API and more towards having fewer low-level operators, leaning on ParallelMap-type operators that will execute user code.


bhack commented Sep 21, 2024

It would be nice if, with the new design, we could profile the whole data provisioning process, similar to nnstreamer:
https://nnstreamer.github.io/tools/profiling/README.html
