Commit a7707ca
[Feature] Collectors - Weight Sync Scheme Integration
ghstack-source-id: c1cc5c4 Pull-Request: #3187
1 parent d0c8b7e commit a7707ca

15 files changed (+1727 −400 lines)

docs/source/reference/llms.rst

Lines changed: 160 additions & 9 deletions
@@ -633,7 +633,7 @@ Collectors
 .. _Collectors:
 
 TorchRL offers specialized collector classes (:class:`~torchrl.collectors.llm.LLMCollector` and :class:`~torchrl.collectors.llm.RayLLMCollector`)
-that are tailored for LLM use cases. We also provide dedicated updaters for some inference engines.
+that are tailored for LLM use cases. We also provide weight synchronization schemes for vLLM inference engines.
 
 See :ref:`ref_collectors` for more details on the collector API. In brief, the idea of a collector is to isolate the inference part of the pipeline
 in a dedicated class.
@@ -649,8 +649,126 @@ Collectors are defined by the following parameters and features:
   In other cases, the collector can be iterated over to collect data.
 - **Steps**: A collector is built with a certain number of steps budget, as well as a number of steps to be
   included in each batch yield during collection.
-- **Weight Updater**: Weight updaters are the classes that update the policy weights. Isolating the weight update
-  in a dedicated class allows to easily implement different weight update strategies depending on the policy specification.
+- **Weight Synchronization Schemes**: Weight sync schemes handle the synchronization of weights between the training model
+  and the inference engine. The new scheme-based approach provides flexible, high-performance weight updates for vLLM and
+  other inference backends.
+
+vLLM Weight Synchronization Schemes
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+TorchRL provides two weight synchronization schemes for vLLM engines, offering different trade-offs between
+performance and simplicity:
+
+**1. NCCL-Based Synchronization** (:class:`~torchrl.weight_update.llm.VLLMWeightSyncScheme`)
+
+Uses NCCL collectives for high-bandwidth GPU-to-GPU weight transfers. Best for:
+
+- High-frequency weight updates
+- Large models where transfer speed is critical
+- Setups with GPU interconnect (NVLink, InfiniBand)
+
+**2. Double-Buffer Synchronization** (:class:`~torchrl.weight_update.llm.VLLMDoubleBufferSyncScheme`)
+
+Uses memory-mapped file storage for asynchronous weight transfers. Best for:
+
+- Simpler setup without NCCL coordination
+- Distributed setups with shared filesystems (NFS)
+- Cases where update frequency is lower
+
+**Usage Example with NCCL:**
+
+.. code-block:: python
+
+    from torchrl.collectors.llm import RayLLMCollector
+    from torchrl.weight_update.llm import VLLMWeightSyncScheme, get_model_metadata
+    from torchrl.modules.llm import AsyncVLLM, vLLMWrapper
+
+    # Create vLLM engine
+    vllm_engine = AsyncVLLM.from_pretrained(
+        "Qwen/Qwen2.5-7B",
+        num_devices=2,
+        num_replicas=2,
+    )
+    policy = vLLMWrapper(vllm_engine, input_mode="history")
+
+    # Create NCCL weight sync scheme
+    weight_sync_scheme = VLLMWeightSyncScheme(
+        master_address="localhost",
+        master_port=29500,
+        gpus_per_replica=2,  # tp_size × dp_size × pp_size
+        num_replicas=2,
+        strategy="state_dict",
+    )
+
+    # Create collector with weight sync scheme
+    collector = RayLLMCollector(
+        env=make_env,
+        policy=policy,
+        dialog_turns_per_batch=256,
+        total_dialog_turns=10000,
+        weight_sync_schemes={"policy": weight_sync_scheme},
+        track_policy_version=True,
+    )
+
+    # During training, get the sender and update weights
+    sender = collector._weight_senders["policy"]
+    sender.register_model(training_model)
+
+    # Initialize collective group (must be called before first update)
+    metadata = get_model_metadata(training_model)
+    sender.init_all_workers_group(metadata, vllm_engine=vllm_engine)
+
+    # Update weights during training
+    for i, data in enumerate(collector):
+        # ... training step ...
+        if i % 10 == 0:
+            sender.update_weights()  # Broadcasts via NCCL
+
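Because the NCCL path needs the sender to be fetched, the training model registered, and the collective
group initialized before the first broadcast, it can help to bundle these steps. The sketch below only
re-arranges the calls from the example above; ``make_nccl_sender`` is a hypothetical helper name, not a
TorchRL API.

.. code-block:: python

    from torchrl.weight_update.llm import get_model_metadata

    def make_nccl_sender(collector, training_model, vllm_engine):
        """Hypothetical convenience wrapper around the setup calls shown above."""
        sender = collector._weight_senders["policy"]
        sender.register_model(training_model)
        metadata = get_model_metadata(training_model)
        # Must run once before the first update_weights() call
        sender.init_all_workers_group(metadata, vllm_engine=vllm_engine)
        return sender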
+**Usage Example with Double-Buffer:**
+
+.. code-block:: python
+
+    from torchrl.collectors.llm import RayLLMCollector
+    from torchrl.weight_update.llm import VLLMDoubleBufferSyncScheme
+    from torchrl.modules.llm import AsyncVLLM, vLLMWrapper
+
+    # Create vLLM engine
+    vllm_engine = AsyncVLLM.from_pretrained(
+        "Qwen/Qwen2.5-7B",
+        num_devices=2,
+        num_replicas=1,
+    )
+    policy = vLLMWrapper(vllm_engine, input_mode="history")
+
+    # Create double-buffer weight sync scheme
+    weight_sync_scheme = VLLMDoubleBufferSyncScheme(
+        remote_addr="/tmp/weights",  # Or "/mnt/shared/weights" for NFS
+        num_threads=128,
+        strategy="state_dict",
+    )
+
+    # Create collector with weight sync scheme
+    collector = RayLLMCollector(
+        env=make_env,
+        policy=policy,
+        dialog_turns_per_batch=256,
+        total_dialog_turns=10000,
+        weight_sync_schemes={"policy": weight_sync_scheme},
+        track_policy_version=True,
+    )
+
+    # During training, get the sender and register the training model
+    sender = collector._weight_senders["policy"]
+    sender.register_model(training_model)
+
+    # No initialization needed for the double-buffer scheme
+
+    # Update weights during training
+    for i, data in enumerate(collector):
+        # ... training step ...
+        if i % 10 == 0:
+            sender.update_weights()  # Writes to shared storage
+            # vLLM workers can poll and apply: receiver.poll_and_apply()
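Note the difference in setup cost between the two examples: the NCCL scheme requires registering the
training model and calling ``init_all_workers_group`` once before the first update, after which
``update_weights()`` broadcasts synchronously to every replica over collectives, whereas the
double-buffer scheme needs no initialization because the trainer simply writes weights to shared
storage and the vLLM workers apply them whenever they poll.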
 
 Policy Version Tracking
 ~~~~~~~~~~~~~~~~~~~~~~~
@@ -662,19 +780,52 @@ transform, or a boolean to the collector constructor.
 
     >>> from torchrl.envs.llm.transforms import PolicyVersion
     >>> from torchrl.collectors.llm import LLMCollector
-    >>> from torchrl.collectors.llm.weight_update import vLLMUpdater
+    >>> from torchrl.weight_update.llm import VLLMWeightSyncScheme, get_model_metadata
     >>> env = make_env()  # place your code here
     >>> policy = make_policy()  # place your code here
-    >>> collector = LLMCollector(env, policy=policy, weight_updater=vLLMUpdater(), track_policy_version=True)
-    >>> # init the updater
-    >>> collector.weight_updater.init(...)
-    >>> # the version is incremented after each weight update
-    >>> collector.update_policy_weights_(state_dict=...)
+    >>> scheme = VLLMWeightSyncScheme(master_port=29500, gpus_per_replica=1, num_replicas=1)
+    >>> collector = LLMCollector(env, policy=policy, weight_sync_schemes={"policy": scheme}, track_policy_version=True)
+    >>> # Get the sender and register the model
+    >>> sender = collector._weight_senders["policy"]
+    >>> sender.register_model(training_model)
+    >>> # Initialize the collective group
+    >>> metadata = get_model_metadata(training_model)
+    >>> sender.init_all_workers_group(metadata, vllm_engine=policy.model)
+    >>> # Update weights
+    >>> sender.update_weights()
     >>> print(collector.policy_version_tracker.version)
     >>> # the policy version is written in the data
     >>> for data in collector:
     ...     print(data["policy_version"])
 
+.. currentmodule:: torchrl.weight_update.llm
+
+.. autosummary::
+    :toctree: generated/
+    :template: rl_template.rst
+
+    VLLMWeightSyncScheme
+    VLLMWeightSender
+    VLLMWeightReceiver
+    VLLMCollectiveTransport
+    VLLMDoubleBufferSyncScheme
+    VLLMDoubleBufferWeightSender
+    VLLMDoubleBufferWeightReceiver
+    VLLMDoubleBufferTransport
+    get_model_metadata
+
+Legacy Weight Updaters (Deprecated)
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. deprecated:: 0.11
+    The `vLLMUpdater` and `vLLMUpdaterV2` classes are deprecated in favor of the new weight synchronization schemes
+    (:class:`~torchrl.weight_update.llm.VLLMWeightSyncScheme` and :class:`~torchrl.weight_update.llm.VLLMDoubleBufferSyncScheme`).
+    These schemes provide better performance, more flexibility, and cleaner integration with collectors.
+    The legacy updaters will be removed in a future release.
+
+The legacy weight updaters (`vLLMUpdater` and `vLLMUpdaterV2`) are still available but are no longer recommended.
+Please migrate to the new weight synchronization schemes shown above.
+
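As a quick migration reference, the before/after below is assembled purely from the removed and added
doctest lines in this diff; it is a sketch, not an exhaustive mapping of the legacy updater API.

.. code-block:: python

    # Before (legacy, deprecated)
    from torchrl.collectors.llm.weight_update import vLLMUpdater
    collector = LLMCollector(env, policy=policy, weight_updater=vLLMUpdater(), track_policy_version=True)
    collector.weight_updater.init(...)
    collector.update_policy_weights_(state_dict=...)

    # After (weight sync schemes)
    from torchrl.weight_update.llm import VLLMWeightSyncScheme, get_model_metadata
    scheme = VLLMWeightSyncScheme(master_port=29500, gpus_per_replica=1, num_replicas=1)
    collector = LLMCollector(env, policy=policy, weight_sync_schemes={"policy": scheme}, track_policy_version=True)
    sender = collector._weight_senders["policy"]
    sender.register_model(training_model)
    sender.init_all_workers_group(get_model_metadata(training_model), vllm_engine=policy.model)
    sender.update_weights()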
 .. currentmodule:: torchrl.collectors.llm
 
 .. autosummary::
Lines changed: 115 additions & 0 deletions
@@ -0,0 +1,115 @@ (new file)

"""Example of updating weights of several models at once in a multiprocessed data collector.

This example demonstrates:
1. Using different weight sync schemes for different models
2. Updating the policy (via pipes with MultiProcessWeightSyncScheme)
3. Updating Ray-based transforms in env and replay buffer (via RayModuleTransformScheme)
4. Atomic multi-model weight updates using weights_dict

Note:
    - Ray actors are shared across all workers, so RayModuleTransformScheme uses a
      single transport rather than per-worker pipes.
    - When using transform_factory with a replay buffer, delayed_init automatically defaults
      to True for proper serialization in multiprocessing contexts.
    - extend_buffer defaults to True in all collectors, extending the buffer with entire
      rollouts rather than individual frames for better compatibility with postprocessing.
"""

from functools import partial

import torch.nn as nn
from tensordict import TensorDict
from tensordict.nn import TensorDictModule

from torchrl.collectors import MultiSyncDataCollector
from torchrl.data import LazyTensorStorage, ReplayBuffer
from torchrl.envs.libs.gym import GymEnv
from torchrl.envs.transforms.module import ModuleTransform
from torchrl.weight_update.weight_sync_schemes import MultiProcessWeightSyncScheme


def make_module():
    # A module that transforms the observations
    return TensorDictModule(
        nn.Linear(3, 3), in_keys=["observation"], out_keys=["observation"]
    )


def policy_factory():
    # A module that produces the actions
    return TensorDictModule(
        nn.Linear(3, 1), in_keys=["observation"], out_keys=["action"]
    )


def make_env():
    env_module = ModuleTransform(
        module_factory=make_module, inverse=False, no_grad=True
    )
    return GymEnv("Pendulum-v1").append_transform(env_module)


def main():
    rb = ReplayBuffer(
        storage=LazyTensorStorage(10000, shared_init=True),
        transform_factory=partial(
            ModuleTransform,
            module_factory=make_module,
            inverse=True,
            no_grad=True,
        ),
        # delayed_init automatically defaults to True when transform_factory is provided
    )

    policy = policy_factory()

    weight_sync_schemes = {
        "policy": MultiProcessWeightSyncScheme(strategy="state_dict"),
        "replay_buffer.transform[0].module": MultiProcessWeightSyncScheme(
            strategy="tensordict"
        ),
        "env.transform[0].module": MultiProcessWeightSyncScheme(strategy="tensordict"),
    }

    collector = MultiSyncDataCollector(
        create_env_fn=[make_env, make_env],
        policy_factory=policy_factory,
        total_frames=2000,
        max_frames_per_traj=50,
        frames_per_batch=200,
        init_random_frames=-1,
        device="cpu",
        storing_device="cpu",
        weight_sync_schemes=weight_sync_schemes,
        replay_buffer=rb,
        local_init_rb=True,
        # extend_buffer=True is the default for MultiSyncDataCollector
    )

    policy_weights = TensorDict.from_module(policy).data
    env_module_weights = TensorDict.from_module(make_module()).data
    rb_module_weights = TensorDict.from_module(make_module()).data

    for i, _data in enumerate(collector):
        env_module_weights.zero_()
        rb_module_weights.zero_()
        policy_weights.zero_()

        collector.update_policy_weights_(
            weights_dict={
                "policy": policy_weights,
                "env.transform[0].module": env_module_weights,
                "replay_buffer.transform[0].module": rb_module_weights,
            }
        )

        assert len(rb) == i * 200 + 200

        if i >= 10:
            break

    collector.shutdown()


if __name__ == "__main__":
    main()

sota-implementations/expert-iteration/ei_utils.py

Lines changed: 24 additions & 26 deletions
@@ -15,10 +15,10 @@
 from torch import device as torch_device, dtype as torch_dtype
 
 from torchrl._utils import logger as torchrl_logger
-from torchrl.collectors.llm.weight_update.vllm import vLLMUpdater
 from torchrl.envs.llm import RetrieveLogProb
 from torchrl.envs.llm.datasets.ifeval import IFEvalEnv
 from torchrl.modules.llm import TransformersWrapper, vLLMWrapper
+from torchrl.weight_update.llm import VLLMWeightSyncScheme
 from transformers.models.auto.modeling_auto import AutoModelForCausalLM
 from transformers.tokenization_utils import PreTrainedTokenizer
 
@@ -479,42 +479,40 @@ def get_hf_model(
     torch.set_default_dtype(original_dtype)
 
 
-def make_weight_updater(
-    policy_training=None,
+def make_weight_sync_scheme(
     master_address=None,
     master_port=None,
-    model_metadata=None,
-    vllm_tp_size=None,
-) -> vLLMUpdater:
-    """Creates a vLLM weight updater for the policy.
+    vllm_tp_size=1,
+) -> VLLMWeightSyncScheme:
+    """Creates a vLLM weight synchronization scheme using NCCL collectives.
 
-    This function can be used in two ways:
-    1. Synchronous mode (expert-iteration-sync.py): Pass policy_training to get an initialized updater with metadata
-    2. Async mode (expert-iteration-async.py): Pass master_address, master_port, model_metadata, and remote_actor
+    This function creates a weight sync scheme that uses NCCL for high-performance
+    GPU-to-GPU weight transfers from the training model to vLLM inference workers.
 
     Args:
-        policy_training (Optional[TransformersWrapper]): The training policy model. Required for sync mode.
-        master_address (Optional[str]): Ray master address for async mode.
-        master_port (Optional[int]): Ray master port for async mode.
-        model_metadata (Optional[dict]): Model metadata for async mode. If not provided but policy_training is,
-            it will be extracted from the policy.
-        vllm_tp_size (Optional[int]): vLLM tensor parallel size. If not provided, will be set to 1.
+        master_address (Optional[str]): Address of the master node for distributed init.
+            Defaults to "localhost".
+        master_port (Optional[int]): Port of the master node for distributed init.
+            If None, will auto-assign.
+        vllm_tp_size (int): vLLM tensor parallel size (gpus_per_replica). Defaults to 1.
 
     Returns:
-        vLLMUpdater: An instance of the weight updater configured to update
-            the vLLM worker's weights.
+        VLLMWeightSyncScheme: A weight sync scheme configured for the vLLM engine.
     """
-    if model_metadata is None and policy_training is not None:
-        # Extract metadata from training policy
-        model_metadata = {
-            k: (v.dtype, v.shape) for k, v in policy_training.model.state_dict().items()
-        }
+    if master_address is None:
+        master_address = "localhost"
+
+    torchrl_logger.info(
+        f"Creating VLLMWeightSyncScheme with tp_size={vllm_tp_size}, "
+        f"master_address={master_address}, master_port={master_port}"
+    )
 
-    return vLLMUpdater(
+    return VLLMWeightSyncScheme(
         master_address=master_address,
         master_port=master_port,
-        model_metadata=model_metadata,
-        vllm_tp_size=vllm_tp_size,
+        gpus_per_replica=vllm_tp_size,
+        num_replicas=1,  # For expert iteration, typically 1 replica
+        strategy="state_dict",
     )
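A sketch of how the returned scheme might be handed to a collector, following the pattern documented
in llms.rst above; the collector arguments below are illustrative assumptions, and the surrounding
expert-iteration scripts may wire things differently.

    from torchrl.collectors.llm import RayLLMCollector

    # Illustrative only: argument values are assumptions, not taken from this diff.
    scheme = make_weight_sync_scheme(master_port=29500, vllm_tp_size=2)
    collector = RayLLMCollector(
        env=make_env,        # assumed environment factory
        policy=policy,       # assumed vLLMWrapper policy
        dialog_turns_per_batch=256,
        total_dialog_turns=10_000,
        weight_sync_schemes={"policy": scheme},
        track_policy_version=True,
    )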
