Collectors
----------
.. _Collectors:

TorchRL offers specialized collector classes (:class:`~torchrl.collectors.llm.LLMCollector` and :class:`~torchrl.collectors.llm.RayLLMCollector`)
that are tailored for LLM use cases. We also provide weight synchronization schemes for vLLM inference engines.

See :ref:`ref_collectors` for more details on the collector API. In brief, the idea of a collector is to isolate the inference part of the pipeline
in a dedicated class.
Collectors are defined by the following parameters and features:

  In other cases, the collector can be iterated over to collect data.
- **Steps**: A collector is built with a total step budget, as well as the number of steps to be
  included in each batch yielded during collection (see the short sketch after this list).
- **Weight Synchronization Schemes**: Weight sync schemes handle the synchronization of weights between the training model
  and the inference engine. The scheme-based approach provides flexible, high-performance weight updates for vLLM and
  other inference backends.
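
For illustration, here is a minimal sketch of how the step and batch budgets play out when iterating a collector. It
assumes that ``env`` and ``policy`` are built as in the examples below, and that :class:`~torchrl.collectors.llm.LLMCollector`
accepts the same batching arguments as :class:`~torchrl.collectors.llm.RayLLMCollector` does in those examples.

.. code-block:: python

    from torchrl.collectors.llm import LLMCollector

    # Sketch only: `env` and `policy` are assumed to be constructed elsewhere
    # (see the vLLM examples below for one way to build the policy).
    collector = LLMCollector(
        env,
        policy=policy,
        dialog_turns_per_batch=256,   # number of steps in each yielded batch
        total_dialog_turns=10000,     # total step budget for the collector
    )

    # Iterating the collector yields batches of 256 dialog turns until the
    # total budget of 10000 turns has been collected.
    for data in collector:
        ...  # each `data` batch holds 256 dialog turns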

vLLM Weight Synchronization Schemes
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

TorchRL provides two weight synchronization schemes for vLLM engines, offering different trade-offs between
performance and simplicity:

**1. NCCL-Based Synchronization** (:class:`~torchrl.weight_update.llm.VLLMWeightSyncScheme`)

Uses NCCL collectives for high-bandwidth GPU-to-GPU weight transfers. Best for:

- High-frequency weight updates
- Large models where transfer speed is critical
- Setups with GPU interconnect (NVLink, InfiniBand)

**2. Double-Buffer Synchronization** (:class:`~torchrl.weight_update.llm.VLLMDoubleBufferSyncScheme`)

Uses memory-mapped file storage for asynchronous weight transfers. Best for:

- Simpler setup without NCCL coordination
- Distributed setups with shared filesystems (NFS)
- Cases where update frequency is lower

**Usage Example with NCCL:**

.. code-block:: python

    from torchrl.collectors.llm import RayLLMCollector
    from torchrl.weight_update.llm import VLLMWeightSyncScheme, get_model_metadata
    from torchrl.modules.llm import AsyncVLLM, vLLMWrapper

    # Create vLLM engine
    vllm_engine = AsyncVLLM.from_pretrained(
        "Qwen/Qwen2.5-7B",
        num_devices=2,
        num_replicas=2,
    )
    policy = vLLMWrapper(vllm_engine, input_mode="history")

    # Create NCCL weight sync scheme
    weight_sync_scheme = VLLMWeightSyncScheme(
        master_address="localhost",
        master_port=29500,
        gpus_per_replica=2,  # tp_size × dp_size × pp_size
        num_replicas=2,
        strategy="state_dict"
    )

    # Create collector with weight sync scheme
    collector = RayLLMCollector(
        env=make_env,
        policy=policy,
        dialog_turns_per_batch=256,
        total_dialog_turns=10000,
        weight_sync_schemes={"policy": weight_sync_scheme},
        track_policy_version=True,
    )

    # During training, get the sender and update weights
    sender = collector._weight_senders["policy"]
    sender.register_model(training_model)

    # Initialize collective group (must be called before first update)
    metadata = get_model_metadata(training_model)
    sender.init_all_workers_group(metadata, vllm_engine=vllm_engine)

    # Update weights during training
    for i, data in enumerate(collector):
        # ... training step ...
        if i % 10 == 0:
            sender.update_weights()  # Broadcasts via NCCL

**Usage Example with Double-Buffer:**

.. code-block:: python

    from torchrl.collectors.llm import RayLLMCollector
    from torchrl.weight_update.llm import VLLMDoubleBufferSyncScheme
    from torchrl.modules.llm import AsyncVLLM, vLLMWrapper

    # Create vLLM engine
    vllm_engine = AsyncVLLM.from_pretrained(
        "Qwen/Qwen2.5-7B",
        num_devices=2,
        num_replicas=1,
    )
    policy = vLLMWrapper(vllm_engine, input_mode="history")

    # Create double-buffer weight sync scheme
    weight_sync_scheme = VLLMDoubleBufferSyncScheme(
        remote_addr="/tmp/weights",  # Or "/mnt/shared/weights" for NFS
        num_threads=128,
        strategy="state_dict"
    )

    # Create collector with weight sync scheme
    collector = RayLLMCollector(
        env=make_env,
        policy=policy,
        dialog_turns_per_batch=256,
        total_dialog_turns=10000,
        weight_sync_schemes={"policy": weight_sync_scheme},
        track_policy_version=True,
    )

    # During training, get the sender and register the training model
    sender = collector._weight_senders["policy"]
    sender.register_model(training_model)

    # No initialization needed for the double-buffer scheme!

    # Update weights during training
    for i, data in enumerate(collector):
        # ... training step ...
        if i % 10 == 0:
            sender.update_weights()  # Writes to shared storage
            # vLLM workers can poll and apply: receiver.poll_and_apply()
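
On the receiving side, the weights written by ``sender.update_weights()`` are picked up by a
:class:`~torchrl.weight_update.llm.VLLMDoubleBufferWeightReceiver`. The snippet below is only a rough sketch of that
polling loop; how the receiver instance is created and wired to the vLLM engine depends on your setup and is elided
here (``receiver`` is assumed to already be bound to the same storage path as the sender).

.. code-block:: python

    import time

    # Rough sketch, not a documented construction API: `receiver` is assumed to be a
    # VLLMDoubleBufferWeightReceiver bound to "/tmp/weights" and to the vLLM engine.
    receiver = ...

    while True:
        # Read the latest weights from shared storage and load them into the vLLM
        # workers; double buffering lets reads proceed independently of writes.
        receiver.poll_and_apply()
        time.sleep(60)  # poll at whatever cadence suits the training loop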

Policy Version Tracking
~~~~~~~~~~~~~~~~~~~~~~~

The policy version can be tracked by adding a :class:`~torchrl.envs.llm.transforms.PolicyVersion`
transform, or a boolean, to the collector constructor.

>>> from torchrl.envs.llm.transforms import PolicyVersion
>>> from torchrl.collectors.llm import LLMCollector
>>> from torchrl.weight_update.llm import VLLMWeightSyncScheme, get_model_metadata
>>> env = make_env()  # place your code here
>>> policy = make_policy()  # place your code here
>>> scheme = VLLMWeightSyncScheme(master_port=29500, gpus_per_replica=1, num_replicas=1)
>>> collector = LLMCollector(env, policy=policy, weight_sync_schemes={"policy": scheme}, track_policy_version=True)
>>> # Get the sender and register the training model
>>> sender = collector._weight_senders["policy"]
>>> sender.register_model(training_model)
>>> # Initialize the collective group
>>> metadata = get_model_metadata(training_model)
>>> sender.init_all_workers_group(metadata, vllm_engine=policy.model)
>>> # Update weights
>>> sender.update_weights()
>>> print(collector.policy_version_tracker.version)
>>> # the policy version is written in the data
>>> for data in collector:
...     print(data["policy_version"])

.. currentmodule:: torchrl.weight_update.llm

.. autosummary::
    :toctree: generated/
    :template: rl_template.rst

    VLLMWeightSyncScheme
    VLLMWeightSender
    VLLMWeightReceiver
    VLLMCollectiveTransport
    VLLMDoubleBufferSyncScheme
    VLLMDoubleBufferWeightSender
    VLLMDoubleBufferWeightReceiver
    VLLMDoubleBufferTransport
    get_model_metadata

Legacy Weight Updaters (Deprecated)
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. deprecated:: 0.11
    The `vLLMUpdater` and `vLLMUpdaterV2` classes are deprecated in favor of the new weight synchronization schemes
    (:class:`~torchrl.weight_update.llm.VLLMWeightSyncScheme` and :class:`~torchrl.weight_update.llm.VLLMDoubleBufferSyncScheme`).
    These schemes provide better performance, more flexibility, and cleaner integration with collectors.
    The legacy updaters will be removed in a future release.

The legacy weight updaters (`vLLMUpdater` and `vLLMUpdaterV2`) are still available but are no longer recommended.
Please migrate to the new weight synchronization schemes shown above.

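For quick reference, here is a rough migration sketch contrasting the legacy updater calls with the scheme-based flow
shown above. ``env``, ``policy`` and ``training_model`` are assumed to exist, and the legacy ``init`` /
``update_policy_weights_`` arguments are elided, as they were in the old examples.

.. code-block:: python

    from torchrl.collectors.llm import LLMCollector
    from torchrl.weight_update.llm import VLLMWeightSyncScheme, get_model_metadata

    # Before (deprecated): the collector owned a vLLMUpdater.
    # collector = LLMCollector(env, policy=policy, weight_updater=vLLMUpdater(), track_policy_version=True)
    # collector.weight_updater.init(...)                 # set up the transfer
    # collector.update_policy_weights_(state_dict=...)   # push new weights

    # After: pass a weight sync scheme and drive updates through the sender.
    scheme = VLLMWeightSyncScheme(master_port=29500, gpus_per_replica=1, num_replicas=1)
    collector = LLMCollector(env, policy=policy, weight_sync_schemes={"policy": scheme}, track_policy_version=True)
    sender = collector._weight_senders["policy"]
    sender.register_model(training_model)
    sender.init_all_workers_group(get_model_metadata(training_model), vllm_engine=policy.model)
    sender.update_weights()
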
.. currentmodule:: torchrl.collectors.llm

.. autosummary::