169 changes: 160 additions & 9 deletions docs/source/reference/llms.rst
@@ -633,7 +633,7 @@ Collectors
.. _Collectors:

TorchRL offers specialized collector classes (:class:`~torchrl.collectors.llm.LLMCollector` and :class:`~torchrl.collectors.llm.RayLLMCollector`)
that are tailored for LLM use cases. We also provide dedicated updaters for some inference engines.
that are tailored for LLM use cases. We also provide weight synchronization schemes for vLLM inference engines.

See :ref:`ref_collectors` for more details on the collector API. In brief, the idea of a collector is to isolate the inference part of the pipeline
in a dedicated class.
@@ -649,8 +649,126 @@ Collectors are defined by the following parameters and features:
In other cases, the collector can be iterated over to collect data.
- **Steps**: A collector is built with a total step budget, as well as the number of steps to
include in each batch yielded during collection (a minimal sketch follows this list).
- **Weight Updater**: Weight updaters are the classes that update the policy weights. Isolating the weight update
in a dedicated class makes it easy to implement different weight update strategies depending on the policy specification.
- **Weight Synchronization Schemes**: Weight sync schemes handle the synchronization of weights between the training model
and the inference engine. The new scheme-based approach provides flexible, high-performance weight updates for vLLM and
other inference backends.
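
A minimal sketch of the step budget in practice. ``make_env`` and ``make_policy`` are placeholders,
and the ``dialog_turns_per_batch``/``total_dialog_turns`` arguments mirror the collector examples below:

.. code-block:: python

    from torchrl.collectors.llm import LLMCollector

    env = make_env()  # place your code here
    policy = make_policy()  # place your code here

    collector = LLMCollector(
        env,
        policy=policy,
        dialog_turns_per_batch=256,  # steps yielded per batch
        total_dialog_turns=10_000,   # overall collection budget
    )
    for data in collector:
        ...  # each batch holds 256 dialog turns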

vLLM Weight Synchronization Schemes
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

TorchRL provides two weight synchronization schemes for vLLM engines, offering different trade-offs between
performance and simplicity:

**1. NCCL-Based Synchronization** (:class:`~torchrl.weight_update.llm.VLLMWeightSyncScheme`)

Uses NCCL collectives for high-bandwidth GPU-to-GPU weight transfers. Best for:

- High-frequency weight updates
- Large models where transfer speed is critical
- Setups with GPU interconnect (NVLink, InfiniBand)

**2. Double-Buffer Synchronization** (:class:`~torchrl.weight_update.llm.VLLMDoubleBufferSyncScheme`)

Uses memory-mapped file storage for asynchronous weight transfers. Best for:

- Simpler setup without NCCL coordination
- Distributed setups with shared filesystems (NFS)
- Cases where update frequency is lower

**Usage Example with NCCL:**

.. code-block:: python

from torchrl.collectors.llm import RayLLMCollector
from torchrl.weight_update.llm import VLLMWeightSyncScheme, get_model_metadata
from torchrl.modules.llm import AsyncVLLM, vLLMWrapper

# Create vLLM engine
vllm_engine = AsyncVLLM.from_pretrained(
"Qwen/Qwen2.5-7B",
num_devices=2,
num_replicas=2,
)
policy = vLLMWrapper(vllm_engine, input_mode="history")

# Create NCCL weight sync scheme
weight_sync_scheme = VLLMWeightSyncScheme(
master_address="localhost",
master_port=29500,
gpus_per_replica=2, # tp_size × dp_size × pp_size
num_replicas=2,
strategy="state_dict"
)

# Create collector with weight sync scheme
collector = RayLLMCollector(
env=make_env,
policy=policy,
dialog_turns_per_batch=256,
total_dialog_turns=10000,
weight_sync_schemes={"policy": weight_sync_scheme},
track_policy_version=True,
)

# During training, get the sender and update weights
sender = collector._weight_senders["policy"]
sender.register_model(training_model)

# Initialize collective group (must be called before first update)
metadata = get_model_metadata(training_model)
sender.init_all_workers_group(metadata, vllm_engine=vllm_engine)

# Update weights during training
for i, data in enumerate(collector):
# ... training step ...
if i % 10 == 0:
sender.update_weights() # Broadcasts via NCCL

**Usage Example with Double-Buffer:**

.. code-block:: python

from torchrl.collectors.llm import RayLLMCollector
from torchrl.weight_update.llm import VLLMDoubleBufferSyncScheme
from torchrl.modules.llm import AsyncVLLM, vLLMWrapper

# Create vLLM engine
vllm_engine = AsyncVLLM.from_pretrained(
"Qwen/Qwen2.5-7B",
num_devices=2,
num_replicas=1,
)
policy = vLLMWrapper(vllm_engine, input_mode="history")

# Create double-buffer weight sync scheme
weight_sync_scheme = VLLMDoubleBufferSyncScheme(
remote_addr="/tmp/weights", # Or "/mnt/shared/weights" for NFS
num_threads=128,
strategy="state_dict"
)

# Create collector with weight sync scheme
collector = RayLLMCollector(
env=make_env,
policy=policy,
dialog_turns_per_batch=256,
total_dialog_turns=10000,
weight_sync_schemes={"policy": weight_sync_scheme},
track_policy_version=True,
)

# During training, get the sender and register the training model
sender = collector._weight_senders["policy"]
sender.register_model(training_model)

# No initialization needed for double-buffer scheme!

# Update weights during training
for i, data in enumerate(collector):
# ... training step ...
if i % 10 == 0:
sender.update_weights() # Writes to shared storage
# vLLM workers can poll and apply: receiver.poll_and_apply()

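On the inference side, the receiver reads whatever weights the sender last wrote to shared storage
and loads them into the engine. The sketch below only illustrates the polling pattern; ``create_receiver``
is an assumed helper name here, so refer to :class:`~torchrl.weight_update.llm.VLLMDoubleBufferWeightReceiver`
for the actual construction API.

.. code-block:: python

    # Receiver side (runs alongside the vLLM workers) -- a sketch, not the exact API.
    # ``create_receiver`` is a hypothetical helper; see VLLMDoubleBufferWeightReceiver.
    receiver = weight_sync_scheme.create_receiver(vllm_engine)

    # Read the latest weights from the shared storage (e.g. /tmp/weights) and apply them.
    receiver.poll_and_apply()
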
Policy Version Tracking
~~~~~~~~~~~~~~~~~~~~~~~
@@ -662,19 +780,52 @@ transform, or a boolean to the collector constructor.

>>> from torchrl.envs.llm.transforms import PolicyVersion
>>> from torchrl.collectors.llm import LLMCollector
>>> from torchrl.collectors.llm.weight_update import vLLMUpdater
>>> from torchrl.weight_update.llm import VLLMWeightSyncScheme, get_model_metadata
>>> env = make_env() # place your code here
>>> policy = make_policy() # place your code here
>>> training_model = make_training_model() # place your code here
>>> collector = LLMCollector(env, policy=policy, weight_updater=vLLMUpdater(), track_policy_version=True)
>>> # init the updater
>>> collector.weight_updater.init(...)
>>> # the version is incremented after each weight update
>>> collector.update_policy_weights_(state_dict=...)
>>> scheme = VLLMWeightSyncScheme(master_port=29500, gpus_per_replica=1, num_replicas=1)
>>> collector = LLMCollector(env, policy=policy, weight_sync_schemes={"policy": scheme}, track_policy_version=True)
>>> # Get the sender and register model
>>> sender = collector._weight_senders["policy"]
>>> sender.register_model(training_model)
>>> # Initialize the collective group
>>> metadata = get_model_metadata(training_model)
>>> sender.init_all_workers_group(metadata, vllm_engine=policy.model)
>>> # Update weights
>>> sender.update_weights()
>>> print(collector.policy_version_tracker.version)
>>> # the policy version is written in the data
>>> for data in collector:
... print(data["policy_version"])

.. currentmodule:: torchrl.weight_update.llm

.. autosummary::
:toctree: generated/
:template: rl_template.rst

VLLMWeightSyncScheme
VLLMWeightSender
VLLMWeightReceiver
VLLMCollectiveTransport
VLLMDoubleBufferSyncScheme
VLLMDoubleBufferWeightSender
VLLMDoubleBufferWeightReceiver
VLLMDoubleBufferTransport
get_model_metadata

Legacy Weight Updaters (Deprecated)
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. deprecated:: 0.11
The `vLLMUpdater` and `vLLMUpdaterV2` classes are deprecated in favor of the new weight synchronization schemes
(:class:`~torchrl.weight_update.llm.VLLMWeightSyncScheme` and :class:`~torchrl.weight_update.llm.VLLMDoubleBufferSyncScheme`).
These schemes provide better performance, more flexibility, and cleaner integration with collectors.
The legacy updaters will be removed in a future release.

The legacy weight updaters (`vLLMUpdater` and `vLLMUpdaterV2`) are still available but are no longer recommended.
Please migrate to the new weight synchronization schemes shown above.
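
A minimal migration sketch, combining the legacy calls shown in the example above with the
scheme-based flow from the examples earlier in this section (``training_model`` is a placeholder
for your Hugging Face training model):

.. code-block:: python

    # Before (deprecated): a weight updater attached to the collector
    collector = LLMCollector(env, policy=policy, weight_updater=vLLMUpdater(), track_policy_version=True)
    collector.weight_updater.init(...)
    collector.update_policy_weights_(state_dict=...)

    # After: declare a weight sync scheme and drive updates through the sender
    scheme = VLLMWeightSyncScheme(master_port=29500, gpus_per_replica=1, num_replicas=1)
    collector = LLMCollector(env, policy=policy, weight_sync_schemes={"policy": scheme}, track_policy_version=True)
    sender = collector._weight_senders["policy"]
    sender.register_model(training_model)
    sender.init_all_workers_group(get_model_metadata(training_model), vllm_engine=policy.model)
    sender.update_weights()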

.. currentmodule:: torchrl.collectors.llm

.. autosummary::
115 changes: 115 additions & 0 deletions examples/collectors/multi_weight_updates.py
@@ -0,0 +1,115 @@
"""Example of updating weights of several models at once in a multiprocessed data collector.

This example demonstrates:
1. Using different weight sync schemes for different models
2. Updating the policy (via pipes with MultiProcessWeightSyncScheme)
3. Updating Ray-based transforms in env and replay buffer (via RayModuleTransformScheme)
4. Atomic multi-model weight updates using weights_dict

Note:
- Ray actors are shared across all workers, so RayModuleTransformScheme uses a
single transport rather than per-worker pipes.
- When using transform_factory with a replay buffer, delayed_init automatically defaults
to True for proper serialization in multiprocessing contexts.
- extend_buffer defaults to True in all collectors, extending the buffer with entire
rollouts rather than individual frames for better compatibility with postprocessing.
"""

from functools import partial

import torch.nn as nn
from tensordict import TensorDict
from tensordict.nn import TensorDictModule

from torchrl.collectors import MultiSyncDataCollector
from torchrl.data import LazyTensorStorage, ReplayBuffer
from torchrl.envs.libs.gym import GymEnv
from torchrl.envs.transforms.module import ModuleTransform
from torchrl.weight_update.weight_sync_schemes import MultiProcessWeightSyncScheme


def make_module():
# A module that transforms the observations
return TensorDictModule(
nn.Linear(3, 3), in_keys=["observation"], out_keys=["observation"]
)


def policy_factory():
# A module that produces the actions
return TensorDictModule(
nn.Linear(3, 1), in_keys=["observation"], out_keys=["action"]
)


def make_env():
env_module = ModuleTransform(
module_factory=make_module, inverse=False, no_grad=True
)
return GymEnv("Pendulum-v1").append_transform(env_module)


def main():
rb = ReplayBuffer(
storage=LazyTensorStorage(10000, shared_init=True),
transform_factory=partial(
ModuleTransform,
module_factory=make_module,
inverse=True,
no_grad=True,
),
# delayed_init automatically defaults to True when transform_factory is provided
)

policy = policy_factory()

weight_sync_schemes = {
"policy": MultiProcessWeightSyncScheme(strategy="state_dict"),
"replay_buffer.transform[0].module": MultiProcessWeightSyncScheme(
strategy="tensordict"
),
"env.transform[0].module": MultiProcessWeightSyncScheme(strategy="tensordict"),
}

collector = MultiSyncDataCollector(
create_env_fn=[make_env, make_env],
policy_factory=policy_factory,
total_frames=2000,
max_frames_per_traj=50,
frames_per_batch=200,
init_random_frames=-1,
device="cpu",
storing_device="cpu",
weight_sync_schemes=weight_sync_schemes,
replay_buffer=rb,
local_init_rb=True,
# extend_buffer=True is the default for MultiSyncDataCollector
)

policy_weights = TensorDict.from_module(policy).data
env_module_weights = TensorDict.from_module(make_module()).data
rb_module_weights = TensorDict.from_module(make_module()).data

for i, _data in enumerate(collector):
env_module_weights.zero_()
rb_module_weights.zero_()
policy_weights.zero_()

collector.update_policy_weights_(
weights_dict={
"policy": policy_weights,
"env.transform[0].module": env_module_weights,
"replay_buffer.transform[0].module": rb_module_weights,
}
)

assert len(rb) == i * 200 + 200

if i >= 10:
break

collector.shutdown()


if __name__ == "__main__":
main()
50 changes: 24 additions & 26 deletions sota-implementations/expert-iteration/ei_utils.py
@@ -15,10 +15,10 @@
from torch import device as torch_device, dtype as torch_dtype

from torchrl._utils import logger as torchrl_logger
from torchrl.collectors.llm.weight_update.vllm import vLLMUpdater
from torchrl.envs.llm import RetrieveLogProb
from torchrl.envs.llm.datasets.ifeval import IFEvalEnv
from torchrl.modules.llm import TransformersWrapper, vLLMWrapper
from torchrl.weight_update.llm import VLLMWeightSyncScheme
from transformers.models.auto.modeling_auto import AutoModelForCausalLM
from transformers.tokenization_utils import PreTrainedTokenizer

@@ -479,42 +479,40 @@ def get_hf_model(
torch.set_default_dtype(original_dtype)


def make_weight_updater(
policy_training=None,
def make_weight_sync_scheme(
master_address=None,
master_port=None,
model_metadata=None,
vllm_tp_size=None,
) -> vLLMUpdater:
"""Creates a vLLM weight updater for the policy.
vllm_tp_size=1,
) -> VLLMWeightSyncScheme:
"""Creates a vLLM weight synchronization scheme using NCCL collectives.

This function can be used in two ways:
1. Synchronous mode (expert-iteration-sync.py): Pass policy_training to get an initialized updater with metadata
2. Async mode (expert-iteration-async.py): Pass master_address, master_port, model_metadata, and remote_actor
This function creates a weight sync scheme that uses NCCL for high-performance
GPU-to-GPU weight transfers from the training model to vLLM inference workers.

Args:
policy_training (Optional[TransformersWrapper]): The training policy model. Required for sync mode.
master_address (Optional[str]): Ray master address for async mode.
master_port (Optional[int]): Ray master port for async mode.
model_metadata (Optional[dict]): Model metadata for async mode. If not provided but policy_training is,
it will be extracted from the policy.
vllm_tp_size (Optional[int]): vLLM tensor parallel size. If not provided, will be set to 1.
master_address (Optional[str]): Address of the master node for distributed init.
Defaults to "localhost".
master_port (Optional[int]): Port of the master node for distributed init.
If None, will auto-assign.
vllm_tp_size (int): vLLM tensor parallel size (gpus_per_replica). Defaults to 1.

Returns:
vLLMUpdater: An instance of the weight updater configured to update
the vLLM worker's weights.
VLLMWeightSyncScheme: A weight sync scheme configured for the vLLM engine.
"""
if model_metadata is None and policy_training is not None:
# Extract metadata from training policy
model_metadata = {
k: (v.dtype, v.shape) for k, v in policy_training.model.state_dict().items()
}
if master_address is None:
master_address = "localhost"

torchrl_logger.info(
f"Creating VLLMWeightSyncScheme with tp_size={vllm_tp_size}, "
f"master_address={master_address}, master_port={master_port}"
)

return vLLMUpdater(
return VLLMWeightSyncScheme(
master_address=master_address,
master_port=master_port,
model_metadata=model_metadata,
vllm_tp_size=vllm_tp_size,
gpus_per_replica=vllm_tp_size,
num_replicas=1, # For expert iteration, typically 1 replica
strategy="state_dict",
)
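

# Example usage (a sketch -- parameter values are illustrative and the collector wiring
# mirrors docs/source/reference/llms.rst):
#
#   scheme = make_weight_sync_scheme(master_port=29500, vllm_tp_size=2)
#   collector = RayLLMCollector(
#       env=make_env,
#       policy=policy,
#       weight_sync_schemes={"policy": scheme},
#   )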