Commit 4a27b91

[Feature] Documentation
ghstack-source-id: f17f26c Pull-Request: #3192
1 parent ffa464a commit 4a27b91

1 file changed: docs/source/reference/collectors.rst (+242, -0)

@@ -459,6 +459,248 @@ transformed, and applied, ensuring seamless integration with their existing infr
    RPCWeightUpdater
    DistributedWeightUpdater

Weight Synchronization API
~~~~~~~~~~~~~~~~~~~~~~~~~~

The weight synchronization API provides a simple, modular approach to updating model weights across
distributed collectors. This system is designed to handle the complexities of modern RL setups where multiple
models may need to be synchronized independently.

Overview
^^^^^^^^

In reinforcement learning, particularly with multi-process data collection, it's essential to keep the inference
policies synchronized with the latest trained weights. The API addresses this challenge through a clean
separation of concerns involving four classes:

- **Configuration**: :class:`~torchrl.weight_update.weight_sync_schemes.WeightSyncScheme` objects define *what* to synchronize and *how*. For DataCollectors, this is
  your main entry point for configuring weight synchronization.
- **Sending**: :class:`~torchrl.weight_update.weight_sync_schemes.WeightSender` handles distributing weights from the main process to workers.
- **Receiving**: :class:`~torchrl.weight_update.weight_sync_schemes.WeightReceiver` handles applying weights in worker processes.
- **Transport**: Backend-specific communication mechanisms (pipes, shared memory, Ray, RPC).

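Of these four, only the scheme is constructed by the user; the collector derives the sender, receiver,
and transport from it. A minimal sketch of the configuration side (reusing the ``strategy`` keyword
from the usage example below):

.. code-block:: python

    from torchrl.weight_update.weight_sync_schemes import MultiProcessWeightSyncScheme

    # One entry per model that must stay in sync; the collector turns each
    # scheme into a WeightSender (main process), WeightReceivers (workers),
    # and a matching transport.
    weight_sync_schemes = {"policy": MultiProcessWeightSyncScheme(strategy="tensordict")}
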
The following diagram shows the different classes involved in the weight synchronization process:

.. aafig::
    :aspect: 60
    :scale: 130
    :proportional:

    INITIALIZATION PHASE
    ====================

             WeightSyncScheme
            +------------------+
            |                  |
            | Configuration:   |
            | - strategy       |
            | - transport_type |
            |                  |
            +--------+---------+
                     |
        +------------+-------------+
        |                          |
     creates                    creates
        |                          |
        v                          v
     Main Process            Worker Process
    +--------------+        +---------------+
    | WeightSender |        | WeightReceiver|
    |              |        |               |
    | - strategy   |        | - strategy    |
    | - transports |        | - transport   |
    | - model_ref  |        | - model_ref   |
    |              |        |               |
    | Registers:   |        | Registers:    |
    | - model      |        | - model       |
    | - workers    |        | - transport   |
    +--------------+        +---------------+
           |                        |
           |    Transport Layer     |
           |   +----------------+   |
           +-->+ MPTransport    |<--+
           |   | (pipes)        |   |
           |   +----------------+   |
           |   +----------------+   |
           +-->+ SharedMemTrans |<--+
           |   | (shared mem)   |   |
           |   +----------------+   |
           |   +----------------+   |
           +-->+ RayTransport   |<--+
               | (Ray store)    |
               +----------------+


    SYNCHRONIZATION PHASE
    =====================

    Main Process                                  Worker Process

    +-------------------+                        +-------------------+
    | WeightSender      |                        | WeightReceiver    |
    |                   |                        |                   |
    | 1. Extract        |                        | 4. Poll transport |
    |    weights from   |                        |    for weights    |
    |    model using    |                        |                   |
    |    strategy       |                        |                   |
    |                   |    2. Send via         |                   |
    | +-------------+   |       Transport        | +--------------+  |
    | | Strategy    |   |    +------------+      | | Strategy     |  |
    | | extract()   |   |    |            |      | | apply()      |  |
    | +-------------+ +------+ Transport  +------->+--------------+  |
    |        |          |    |            |       |        |         |
    |        v          |    +------------+       |        v         |
    | +-------------+   |                         | +--------------+ |
    | | Model       |   |                         | | Model        | |
    | | (source)    |   |   3. Ack (optional)     | | (dest)       | |
    | +-------------+   |<------------------------+ +--------------+ |
    |                   |                         |                  |
    +-------------------+                         | 5. Apply weights |
                                                  |    to model using|
                                                  |    strategy      |
                                                  +------------------+

Key Challenges Addressed
^^^^^^^^^^^^^^^^^^^^^^^^

Modern RL training often involves multiple models that need independent synchronization:

1. **Multiple Models Per Collector**: A collector might need to update:

   - The main policy network
   - A value network in a Ray actor within the replay buffer
   - Models embedded in the environment itself
   - Separate world models or auxiliary networks

2. **Different Update Strategies**: Each model may require a different synchronization approach (see the sketch after this list):

   - Full state_dict transfer vs. TensorDict-based updates
   - Different transport mechanisms (multiprocessing pipes, shared memory, Ray object store, collective communication, RDMA, etc.)
   - Varied update frequencies

3. **Worker-Agnostic Updates**: Some models (like those in shared Ray actors) shouldn't be tied to
   specific worker indices, requiring a more flexible update mechanism.

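Challenges 1 and 2 are addressed by keying one scheme per model, each with its own strategy and
transport. A minimal sketch (the no-argument ``SharedMemWeightSyncScheme()`` constructor is an
assumption here, used purely for illustration):

.. code-block:: python

    from torchrl.weight_update.weight_sync_schemes import (
        MultiProcessWeightSyncScheme,
        SharedMemWeightSyncScheme,
    )

    # Each model ID gets its own scheme: shared memory for the policy,
    # pipe-based state_dict transfer for the value network.
    weight_sync_schemes = {
        "policy": SharedMemWeightSyncScheme(),
        "value_net": MultiProcessWeightSyncScheme(strategy="state_dict"),
    }
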
Architecture
^^^^^^^^^^^^

The API follows a scheme-based design where users specify synchronization requirements upfront,
and the collector handles the orchestration transparently:

.. aafig::
    :aspect: 60
    :scale: 130
    :proportional:

     Main Process                Worker Process 1        Worker Process 2

    +-----------------+         +---------------+        +---------------+
    | Collector       |         | Collector     |        | Collector     |
    |                 |         |               |        |               |
    | Models:         |         | Models:       |        | Models:       |
    | +----------+    |         | +--------+    |        | +--------+    |
    | | Policy A |    |         | |Policy A|    |        | |Policy A|    |
    | +----------+    |         | +--------+    |        | +--------+    |
    | +----------+    |         | +--------+    |        | +--------+    |
    | | Model B  |    |         | |Model B |    |        | |Model B |    |
    | +----------+    |         | +--------+    |        | +--------+    |
    |                 |         |               |        |               |
    | Weight Senders: |         | Weight        |        | Weight        |
    | +----------+    |         | Receivers:    |        | Receivers:    |
    | | Sender A +----+---------+->Receiver A   |        |  Receiver A   |
    | +----------+    |         |               |        |               |
    | +----------+    |         | +--------+    |        | +--------+    |
    | | Sender B +----+---------+->Receiver B   |        |  Receiver B   |
    | +----------+    |  Pipes  |               | Pipes  |               |
    +-----------------+         +-------+-------+        +-------+-------+
             ^                          ^                        ^
             |                          |                        |
             | update_policy_weights_() |      Apply weights     |
             |                          |                        |
      +------+-------+                  |                        |
      | User Code    |                  |                        |
      | (Training)   |                  |                        |
      +--------------+                  +------------------------+

The weight synchronization flow:

1. **Initialization**: The user creates a ``weight_sync_schemes`` dict mapping model IDs to schemes
2. **Registration**: The collector creates a ``WeightSender`` for each model in the main process
3. **Worker Setup**: Each worker creates corresponding ``WeightReceiver`` instances
4. **Synchronization**: Calling ``update_policy_weights_()`` triggers all senders to push weights
5. **Application**: Receivers automatically apply weights to their registered models

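To make these five steps concrete, here is a deliberately simplified, self-contained toy version of
the protocol. It uses a raw ``torch.multiprocessing`` pipe in place of TorchRL's transport classes
and a plain ``state_dict`` strategy; it mirrors the extract/send/ack/poll/apply cycle but is not the
library's actual implementation:

.. code-block:: python

    from torch import multiprocessing as mp
    from torch import nn


    def worker(pipe):
        # Worker side, standing in for WeightReceiver:
        model = nn.Linear(2, 1)
        state_dict = pipe.recv()           # 4. poll the transport for weights
        model.load_state_dict(state_dict)  # 5. apply weights via the strategy
        pipe.send("ok")                    # 3. optional acknowledgement


    if __name__ == "__main__":
        parent_end, child_end = mp.Pipe()
        proc = mp.Process(target=worker, args=(child_end,))
        proc.start()

        # Main-process side, standing in for WeightSender:
        trained = nn.Linear(2, 1)
        weights = trained.state_dict()     # 1. extract weights via the strategy
        parent_end.send(weights)           # 2. send them over the transport
        assert parent_end.recv() == "ok"   # 3. wait for the acknowledgement
        proc.join()
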
Available Classes
^^^^^^^^^^^^^^^^^

**Synchronization Schemes** (User-Facing Configuration):

- :class:`~torchrl.weight_update.weight_sync_schemes.WeightSyncScheme`: Base class for schemes
- :class:`~torchrl.weight_update.weight_sync_schemes.MultiProcessWeightSyncScheme`: For multiprocessing with pipes
- :class:`~torchrl.weight_update.weight_sync_schemes.SharedMemWeightSyncScheme`: For shared-memory synchronization
- :class:`~torchrl.weight_update.weight_sync_schemes.RayWeightSyncScheme`: For Ray-based distribution
- :class:`~torchrl.weight_update.weight_sync_schemes.NoWeightSyncScheme`: Dummy scheme for no synchronization

**Internal Classes** (Automatically Managed):

- :class:`~torchrl.weight_update.weight_sync_schemes.WeightSender`: Sends weights to all workers for one model
- :class:`~torchrl.weight_update.weight_sync_schemes.WeightReceiver`: Receives and applies weights in a worker
- :class:`~torchrl.weight_update.weight_sync_schemes.TransportBackend`: Communication-layer abstraction

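Schemes can be mixed freely within one collector; for example, ``NoWeightSyncScheme`` lets a model
opt out of synchronization entirely. A sketch (the no-argument constructor and the
``"frozen_encoder"`` model ID are assumptions for illustration):

.. code-block:: python

    from torchrl.weight_update.weight_sync_schemes import (
        MultiProcessWeightSyncScheme,
        NoWeightSyncScheme,
    )

    weight_sync_schemes = {
        "policy": MultiProcessWeightSyncScheme(strategy="tensordict"),
        # Hypothetical model ID: a frozen network that never needs updates.
        "frozen_encoder": NoWeightSyncScheme(),
    }
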
Usage Example
^^^^^^^^^^^^^

.. code-block:: python

    from torchrl.collectors import MultiSyncDataCollector
    from torchrl.weight_update.weight_sync_schemes import MultiProcessWeightSyncScheme

    # Define synchronization for multiple models
    weight_sync_schemes = {
        "policy": MultiProcessWeightSyncScheme(strategy="tensordict"),
        "value_net": MultiProcessWeightSyncScheme(strategy="state_dict"),
    }

    collector = MultiSyncDataCollector(
        create_env_fn=[make_env] * 4,
        policy=policy,
        frames_per_batch=1000,
        weight_sync_schemes=weight_sync_schemes,  # Pass schemes dict
    )

    # Single call updates all registered models across all workers
    for i, batch in enumerate(collector):
        # Training step
        loss = train(batch)

        # Sync all models with one call
        collector.update_policy_weights_(policy)

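In this example the trained ``policy`` module is passed explicitly to ``update_policy_weights_``;
step 4 of the flow above shows the same call in its no-argument form, which pushes the weights of
the models already registered with the collector.
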
The collector automatically:

- Creates ``WeightSender`` instances in the main process for each model
- Creates ``WeightReceiver`` instances in each worker process
- Resolves models by ID (e.g., ``"policy"`` → ``collector.policy``)
- Handles transport setup and communication
- Applies weights using the appropriate strategy (state_dict vs. tensordict)

API Reference
^^^^^^^^^^^^^

.. currentmodule:: torchrl.weight_update.weight_sync_schemes

.. autosummary::
    :toctree: generated/
    :template: rl_template.rst

    WeightSyncScheme
    MultiProcessWeightSyncScheme
    SharedMemWeightSyncScheme
    RayWeightSyncScheme
    NoWeightSyncScheme
    WeightSender
    WeightReceiver

Collectors and replay buffers interoperability
----------------------------------------------
