2 changes: 1 addition & 1 deletion .gitmodules
@@ -1,7 +1,7 @@
 [submodule "3rdparty/NeMo"]
 	path = 3rdparty/NeMo-workspace/NeMo
 	url = https://github.com/NVIDIA/NeMo.git
-	branch = ashors/nemorl-qwen3
+	branch = zhiyul/yukih/prepare-refit-info
 	shallow = true
 [submodule "3rdparty/Megatron-LM"]
 	path = 3rdparty/Megatron-LM-workspace/Megatron-LM
2 changes: 1 addition & 1 deletion 3rdparty/NeMo-workspace/NeMo
Submodule NeMo updated from 33259f to 8ddf43
83 changes: 46 additions & 37 deletions nemo_rl/algorithms/grpo.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 import os
 import warnings
+from contextlib import nullcontext
 from pathlib import Path
 from typing import Any, NotRequired, Optional, TypedDict, TypeVar, cast
 
@@ -400,6 +401,7 @@ def refit_policy_generation(
     policy_generation: GenerationInterface,
     colocated_inference: bool,
     _refit_buffer_size_gb: Optional[int] = None,
+    timer: Optional[Timer] = None,
 ) -> None:
     """Refit the policy generation interface with the latest policy weights.
 
@@ -414,43 +416,50 @@
         policy.offload_before_refit()
         policy_generation.prepare_for_generation(tags=["weights"])
 
-    # update weights
-    update_success = False
-    if colocated_inference:
-        # get model param keys, which is grouped by size
-        grouped_param_keys = policy.prepare_weights_for_ipc(
-            _refit_buffer_size_gb=_refit_buffer_size_gb
-        )
-        total_num_keys = sum(len(k) for k in grouped_param_keys)
-        print(
-            f"[Refit] Split {total_num_keys} keys into {len(grouped_param_keys)} groups"
-        )
-        # do update
-        for keys in grouped_param_keys:
-            ipc_handles = policy.get_weights_ipc_handles(keys)
-            update_success = policy_generation.update_weights_from_ipc_handles(
-                ipc_handles
-            )
-            if not update_success:
-                break
-    else:
-        # update weights through nccl
-        futures_train = policy.broadcast_weights_for_collective()
-        futures_inference = policy_generation.update_weights_from_collective()
-        # wait for all futures to complete
-        ray.get(futures_train)
-        results = ray.get(futures_inference)
-        update_success = all(result for result in results if result is not None)
-
-    # check if update is successful
-    if not update_success:
-        error_tag = "cuda-ipc" if colocated_inference else "nccl"
-        error_message = (
-            "❌ Error: Updating weights for the generation policy failed during refit.\n"
-            f"This often indicates an issue with {error_tag} or "
-            "a problem within the generation backend (e.g., vLLM worker).\n"
-        )
-        raise RuntimeError(error_message)
+    # Create a context manager that does nothing when timer is None
+    timer_context = (
+        timer.time("prepare_for_generation/transfer_and_update_weights")
+        if timer is not None
+        else nullcontext()
+    )
+    with timer_context:
+        # update weights
+        update_success = False
+        if colocated_inference:
+            # get model param keys, which is grouped by size
+            grouped_param_keys = policy.prepare_weights_for_ipc(
+                _refit_buffer_size_gb=_refit_buffer_size_gb
+            )
+            total_num_keys = sum(len(k) for k in grouped_param_keys)
+            print(
+                f"[Refit] Split {total_num_keys} keys into {len(grouped_param_keys)} groups"
+            )
+            # do update
+            for keys in grouped_param_keys:
+                ipc_handles = policy.get_weights_ipc_handles(keys)
+                update_success = policy_generation.update_weights_from_ipc_handles(
+                    ipc_handles
+                )
+                if not update_success:
+                    break
+        else:
+            # update weights through nccl
+            futures_train = policy.broadcast_weights_for_collective()
+            futures_inference = policy_generation.update_weights_from_collective()
+            # wait for all futures to complete
+            ray.get(futures_train)
+            results = ray.get(futures_inference)
+            update_success = all(result for result in results if result is not None)
+
+        # check if update is successful
+        if not update_success:
+            error_tag = "cuda-ipc" if colocated_inference else "nccl"
+            error_message = (
+                "❌ Error: Updating weights for the generation policy failed during refit.\n"
+                f"This often indicates an issue with {error_tag} or "
+                "a problem within the generation backend (e.g., vLLM worker).\n"
+            )
+            raise RuntimeError(error_message)
 
     if colocated_inference:
         policy.offload_after_refit()
@@ -544,7 +553,7 @@ def grpo_train(
         with timer.time("prepare_for_generation"):
             if NEED_REFIT and POLICY_GENERATION_STALE:
                 refit_policy_generation(
-                    policy, policy_generation, colocated_inference
+                    policy, policy_generation, colocated_inference, timer=timer
                 )
                 POLICY_GENERATION_STALE = False
             else:
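Note on the timer plumbing above: timer.time(...) is used as a context manager throughout grpo.py, and nullcontext() keeps the untimed path on the same code shape; the call site in grpo_train passes timer=timer, so the new measurement nests under the existing prepare_for_generation timing. A minimal, self-contained sketch of the pattern (the Timer class here is a hypothetical stand-in, not NeMo RL's actual Timer):

import time
from contextlib import contextmanager, nullcontext
from typing import Optional

class Timer:
    # Hypothetical stand-in whose .time(label) is a context manager,
    # mirroring how timer.time(...) is used in grpo.py.
    def __init__(self) -> None:
        self.measurements: dict[str, float] = {}

    @contextmanager
    def time(self, label: str):
        start = time.perf_counter()
        try:
            yield
        finally:
            self.measurements[label] = time.perf_counter() - start

def refit(timer: Optional[Timer] = None) -> None:
    # nullcontext() makes the with-block a no-op when no timer is given,
    # so timed and untimed callers share one code path.
    timer_context = (
        timer.time("transfer_and_update_weights")
        if timer is not None
        else nullcontext()
    )
    with timer_context:
        pass  # ... transfer and update weights ...

refit()            # untimed: nullcontext() is used
t = Timer()
refit(timer=t)     # timed: duration recorded in t.measurements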
47 changes: 28 additions & 19 deletions nemo_rl/models/generation/vllm_backend.py
@@ -12,9 +12,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import os
+from collections import defaultdict
 from typing import Any, Iterable, Optional
 
 import torch
+from torch.multiprocessing.reductions import rebuild_cuda_tensor
 
 try:
     import vllm  # noqa: F401
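Background for the new module-level rebuild_cuda_tensor import: for a CUDA tensor, torch's reduce_tensor returns (rebuild_fn, rebuild_args) where rebuild_fn is always rebuild_cuda_tensor, so a receiver can hardcode the function and ship only the args. A small introspection sketch of that invariant, assuming a CUDA device is available:

import torch
from torch.multiprocessing.reductions import rebuild_cuda_tensor, reduce_tensor

if torch.cuda.is_available():
    t = torch.randn(4, device="cuda")
    func, args = reduce_tensor(t)
    # The rebuild function for a CUDA tensor is rebuild_cuda_tensor, which is
    # why update_weights_from_local_ipc_handles can hardcode it below.
    assert func is rebuild_cuda_tensor
    # args[6] is the storage's device index; the receiver patches it
    # (list_args[6] = device_id) so the tensor is rebuilt on the local GPU.
    print(f"device index in handle: {args[6]}")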
@@ -136,7 +138,7 @@ def update_weights_from_local_ipc_handles(self, local_device_ipc_handles):
         try:
             is_tensor_packed = local_device_ipc_handles[0]
             if is_tensor_packed:
-                _, all_handles, tensor_metadata = local_device_ipc_handles
+                _, all_handles, list_keys = local_device_ipc_handles
             else:
                 _, name_and_handle_list = local_device_ipc_handles
 
@@ -152,33 +154,40 @@ def update_weights_from_local_ipc_handles(self, local_device_ipc_handles):
                 # Extract packed tensor from IPC handle
                 dtype_to_packed_tensor = {}
                 for dtype, tensor_handle in all_handles:
-                    func, args = tensor_handle
+                    func = rebuild_cuda_tensor
+                    args = tensor_handle[0]
                     list_args = list(args)
                     list_args[6] = device_id
                     tensor = func(*list_args)
                     dtype_to_packed_tensor[dtype] = tensor
 
                 # Unpack tensor to weights. Here we only return a view of the tensor to avoid
                 # using extra memory.
-                for key, metadata in tensor_metadata.items():
-                    # dtype for the 1st and 2nd steps may be different (e.g. e_score_correction_bias)
-                    if isinstance(metadata, tuple):
-                        # use dtype of current step
-                        offset, dtype = metadata
-                        shape, _, size = self.state_dict_info[key]
-                        # update record
-                        self.state_dict_info[key] = (shape, dtype, size)
-                    else:
-                        offset = metadata
-                        shape, dtype, size = self.state_dict_info[key]
-                    tensor = dtype_to_packed_tensor[dtype][offset : offset + size].view(
-                        *shape
-                    )
-                    weights.append((key, tensor))
+                weights = []
+                dtype_to_offset = defaultdict(lambda: 0)
+                for key in list_keys:
+                    shape, dtype, size = self.state_dict_info[key]
+                    weights.append(
+                        (
+                            key,
+                            dtype_to_packed_tensor[dtype][
+                                dtype_to_offset[dtype] : dtype_to_offset[dtype] + size
+                            ].view(*shape),
+                        )
+                    )
+                    dtype_to_offset[dtype] += size
+
+                expected_sizes = {
+                    dtype: tensor.numel()
+                    for dtype, tensor in dtype_to_packed_tensor.items()
+                }
+                assert dtype_to_offset == expected_sizes, (
+                    f"Packed tensor size mismatch: expected sizes from keys list {expected_sizes} != actual packed tensor sizes {dtype_to_offset}. "
+                    f"This indicates the keys list order doesn't match the order used when packing tensors."
+                )
             else:
                 # Process each handle to get the tensor
                 for name, handle in name_and_handle_list:
-                    func, args = handle
+                    func = rebuild_cuda_tensor
+                    args = handle[0]
                     list_args = list(args)
                     list_args[6] = device_id
                     tensor = func(*list_args)
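The packed-tensor unpacking above replaces per-key offset metadata with an ordered key list: the receiver replays list_keys with a running per-dtype offset to carve zero-copy views out of each packed buffer, and the final offsets must consume every buffer exactly (the new assert). A minimal CPU-only sketch of that contract, with torch.cat standing in for the real IPC transfer and names like state_dict_info mirroring the diff:

from collections import defaultdict

import torch

# Sender: flatten parameters into one packed buffer per dtype, in key order.
params = {
    "w1": torch.randn(2, 3),
    "bias": torch.arange(4, dtype=torch.int32),
    "w2": torch.randn(3),
}
state_dict_info = {k: (tuple(v.shape), v.dtype, v.numel()) for k, v in params.items()}
list_keys = list(params)  # the shared key order is the entire protocol

packed = {
    dtype: torch.cat([params[k].reshape(-1) for k in list_keys if params[k].dtype == dtype])
    for dtype in {v.dtype for v in params.values()}
}

# Receiver: rebuild views by replaying the same key order with running offsets.
weights = []
dtype_to_offset = defaultdict(lambda: 0)
for key in list_keys:
    shape, dtype, size = state_dict_info[key]
    start = dtype_to_offset[dtype]
    weights.append((key, packed[dtype][start : start + size].view(*shape)))
    dtype_to_offset[dtype] += size

# The assert in the diff: offsets must exactly consume each packed buffer.
assert dict(dtype_to_offset) == {d: t.numel() for d, t in packed.items()}
assert all(torch.equal(dict(weights)[k], params[k]) for k in params)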
1 change: 0 additions & 1 deletion nemo_rl/models/megatron/refit_utils.py
@@ -156,7 +156,6 @@ def gather_params(model, keys: list[str], key_to_global_keys: dict[str, list[str
         if k is not None:
             gathered_params[k] = p
 
-    print(f"Time taken to gather params: {time.perf_counter() - st}")
     return gathered_params
 
 
5 changes: 2 additions & 3 deletions nemo_rl/models/policy/dtensor_policy_worker.py
@@ -63,6 +63,7 @@
 from nemo_rl.models.policy.utils import (
     configure_expandable_segments,
     get_gpu_info,
+    get_handle_from_tensor,
     get_runtime_env_for_policy_worker,
     import_class_from_path,
     is_vllm_v1_engine_enabled,
@@ -1186,8 +1187,6 @@ def prepare_weights_for_ipc(self) -> tuple[list[tuple[str, int]], float]:
 
     @torch.no_grad()
     def get_weights_ipc_handles(self, keys: Iterable[str]) -> dict[str, Any]:
-        from torch.multiprocessing.reductions import reduce_tensor
-
         assert self._held_sharded_state_dict_reference is not None, (
             "prepare_weights_for_ipc must be called before get_weights_ipc_handles"
         )
@@ -1217,7 +1216,7 @@ def get_weights_ipc_handles(self, keys: Iterable[str]) -> dict[str, Any]:
         # Create handles for the tensors
         all_handles = []
         for key, p in converted_params.items():
-            handle = reduce_tensor(p.detach())
+            handle = get_handle_from_tensor(p)
             all_handles.append((key, handle))
 
         # (pack_tensor_for_ipc: bool, handles: list)
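get_handle_from_tensor is imported from nemo_rl.models.policy.utils, but its body is not part of this diff. Given that the deleted call site used reduce_tensor(p.detach()) and the vLLM side now reads the rebuild args out of handle[0], a plausible minimal sketch (an assumption, not the actual helper) is:

from typing import Any

import torch
from torch.multiprocessing.reductions import reduce_tensor

def get_handle_from_tensor(tensor: torch.Tensor) -> tuple[Any, ...]:
    # Hypothetical sketch. reduce_tensor returns (rebuild_fn, rebuild_args);
    # for CUDA tensors rebuild_fn is always rebuild_cuda_tensor, so only the
    # args need to cross the process boundary. The receiver hardcodes the
    # function and reads the args back out via handle[0]. Detaching first
    # keeps autograd state out of the shared handle, as the old call did.
    return (reduce_tensor(tensor.detach())[1],)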