You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Copy file name to clipboardExpand all lines: docs/source/reference/collectors.rst
+242Lines changed: 242 additions & 0 deletions
Display the source diff
Display the rich diff
Original file line number
Diff line number
Diff line change
@@ -459,6 +459,248 @@ transformed, and applied, ensuring seamless integration with their existing infr
459
459
RPCWeightUpdater
460
460
DistributedWeightUpdater
461
461
462
+
Weight Synchronization API
463
+
~~~~~~~~~~~~~~~~~~~~~~~~~~
464
+
465
+
The weight synchronization API provides a simple, modular approach to updating model weights across
466
+
distributed collectors. This system is designed to handle the complexities of modern RL setups where multiple
467
+
models may need to be synchronized independently.
468
+
469
+
Overview
470
+
^^^^^^^^
471
+
472
+
In reinforcement learning, particularly with multi-process data collection, it's essential to keep the inference
473
+
policies synchronized with the latest trained weights. The API addresses this challenge through a clean
474
+
separation of concerns, where four classes are involved:
475
+
476
+
- **Configuration**: :class:`~torchrl.weight_update.weight_sync_schemes.WeightSyncScheme` objects define *what* to synchronize and *how*. For DataCollectors, this is
477
+
your main entrypoint to configure the weight synchronization.
478
+
- **Sending**: :class:`~torchrl.weight_update.weight_sync_schemes.WeightSender` handles distributing weights from the main process to workers.
479
+
- **Receiving**: :class:`~torchrl.weight_update.weight_sync_schemes.WeightReceiver` handles applying weights in worker processes.
480
+
- **Transport**: Backend-specific communication mechanisms (pipes, shared memory, Ray, RPC)
481
+
482
+
The following diagram shows the different classes involved in the weight synchronization process:
483
+
484
+
.. aafig::
485
+
:aspect: 60
486
+
:scale: 130
487
+
:proportional:
488
+
489
+
INITIALIZATION PHASE
490
+
====================
491
+
492
+
WeightSyncScheme
493
+
+------------------+
494
+
||
495
+
| Configuration: |
496
+
| - strategy |
497
+
| - transport_type |
498
+
||
499
+
+--------+---------+
500
+
|
501
+
+------------+-------------+
502
+
||
503
+
creates creates
504
+
||
505
+
v v
506
+
Main Process Worker Process
507
+
+--------------++---------------+
508
+
| WeightSender || WeightReceiver|
509
+
||||
510
+
| - strategy || - strategy |
511
+
| - transports || - transport |
512
+
| - model_ref || - model_ref |
513
+
||||
514
+
| Registers: || Registers: |
515
+
| - model || - model |
516
+
| - workers || - transport |
517
+
+--------------++---------------+
518
+
||
519
+
| Transport Layer |
520
+
|+----------------+|
521
+
+-->+ MPTransport |<------+
522
+
|| (pipes) ||
523
+
|+----------------+|
524
+
|+----------------+|
525
+
+-->+ SharedMemTrans |<------+
526
+
|| (shared mem) ||
527
+
|+----------------+|
528
+
|+----------------+|
529
+
+-->+ RayTransport |<------+
530
+
| (Ray store) |
531
+
+----------------+
532
+
533
+
534
+
SYNCHRONIZATION PHASE
535
+
=====================
536
+
537
+
Main Process Worker Process
538
+
539
+
+-------------------+ +-------------------+
540
+
|WeightSender | | WeightReceiver |
541
+
|| | |
542
+
|1. Extract | | 4. Poll transport |
543
+
|weights from | | for weights |
544
+
|model using | | |
545
+
|strategy | | |
546
+
|| 2. Send via | |
547
+
|+-------------+ | Transport | +--------------+ |
548
+
|| Strategy | | +------------+ | | Strategy | |
549
+
|| extract() | | | | | | apply() | |
550
+
|+-------------+ +----+ Transport +-------->+ +--------------+ |
0 commit comments