39 changes: 31 additions & 8 deletions nemo_rl/models/generation/vllm_backend.py
@@ -59,17 +59,40 @@ def update_weights_from_ipc_handles(self, ipc_handles):
         # Get handles for this device
         device_uuid = self.report_device_id()
         handles = ipc_handles[device_uuid]
+        is_tensor_packed = handles[0]
+        if is_tensor_packed:
+            _, all_handles, tensor_metadata = handles
+        else:
+            _, name_and_handle_list = handles
+
         device_id = self.device.index
         weights = []
 
-        # Process each handle to get the tensor
-        for name, handle in handles:
-            func, args = handle
-            list_args = list(args)
-            # Update device ID to match the current device
-            list_args[6] = device_id
-            tensor = func(*list_args)
-            weights.append((name, tensor))
+        if is_tensor_packed:
+            # Extract packed tensor from IPC handle
+            dtype_to_packed_tensor = {}
+            for dtype, tensor_handle in all_handles:
+                func, args = tensor_handle
+                list_args = list(args)
+                list_args[6] = device_id
+                tensor = func(*list_args)
+                dtype_to_packed_tensor[dtype] = tensor
+
+            # Unpack tensor to weights. Here we only return a view of the tensor to avoid
+            # using extra memory.
+            for key, (shape, dtype, offset, size) in tensor_metadata.items():
+                tensor = dtype_to_packed_tensor[dtype][offset : offset + size].view(
+                    *shape
+                )
+                weights.append((key, tensor))
+        else:
+            # Process each handle to get the tensor
+            for name, handle in name_and_handle_list:
+                func, args = handle
+                list_args = list(args)
+                list_args[6] = device_id
+                tensor = func(*list_args)
+                weights.append((name, tensor))
 
         # Load weights into the model
         self.model_runner.model.load_weights(weights=weights)
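Note (reviewer sketch): in the packed branch above, each weight is rebuilt as a zero-copy view into a flat per-dtype buffer. A minimal standalone illustration of that metadata/view contract, using made-up parameter names and a plain CPU tensor in place of a tensor rebuilt from a CUDA IPC handle:

import torch

# Stand-in for the rebuilt per-dtype packed buffer (normally obtained via the IPC handle).
dtype_to_packed_tensor = {torch.float32: torch.arange(10, dtype=torch.float32)}

# Metadata sent alongside the handles: name -> (shape, dtype, offset, numel).
tensor_metadata = {
    "layer.weight": ((2, 3), torch.float32, 0, 6),
    "layer.bias": ((4,), torch.float32, 6, 4),
}

# Each weight is a view into the packed buffer, so no extra memory is allocated.
weights = []
for name, (shape, dtype, offset, size) in tensor_metadata.items():
    weights.append((name, dtype_to_packed_tensor[dtype][offset : offset + size].view(*shape)))

assert weights[0][1].shape == (2, 3) and weights[1][1].shape == (4,)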
5 changes: 4 additions & 1 deletion nemo_rl/models/policy/dtensor_policy_worker.py
@@ -943,7 +943,10 @@ def get_weights_ipc_handles(self, keys: Iterable[str]) -> dict[str, Any]:
             handle = reduce_tensor(p.detach())
             all_handles.append((key, handle))
 
-        return {device_uuid: all_handles}
+        # (pack_tensor_for_ipc: bool, handles: list)
+        serialized = (False, all_handles)
+
+        return {device_uuid: serialized}
 
     @torch.no_grad()
     def prepare_info_for_collective(self) -> dict[str, Any]:
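Note (reviewer sketch): every worker now returns a tagged tuple instead of a bare handle list, so the vLLM side can branch on the first element. A minimal illustration of that contract, with placeholder names that are not taken from this PR:

def unpack_payload(handles):
    # Minimal dispatch mirroring update_weights_from_ipc_handles: the leading bool
    # tags whether a per-dtype packed layout or a per-parameter handle list follows.
    if handles[0]:
        _, all_handles, tensor_metadata = handles
        return "packed", all_handles, tensor_metadata
    _, name_and_handle_list = handles
    return "unpacked", name_and_handle_list, None

# The non-packed shape this worker now returns (the handle value here is a placeholder
# string, not a real CUDA IPC handle).
kind, items, _ = unpack_payload((False, [("layer.weight", "ipc-handle-placeholder")]))
assert kind == "unpacked" and items[0][0] == "layer.weight"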
5 changes: 4 additions & 1 deletion nemo_rl/models/policy/fsdp1_policy_worker.py
@@ -901,7 +901,10 @@ def get_weights_ipc_handles(self, keys: list[str]) -> dict[str, Any]:
             handle = reduce_tensor(p.detach())
             all_handles.append((key, handle))
 
-        return {device_uuid: all_handles}
+        # (pack_tensor_for_ipc: bool, handles: list)
+        serialized = (False, all_handles)
+
+        return {device_uuid: serialized}
 
     def prepare_for_lp_inference(self) -> None:
         self.model = self.manual_load_to_gpu(self.model)
74 changes: 62 additions & 12 deletions nemo_rl/models/policy/megatron_policy_worker.py
@@ -1338,18 +1338,68 @@ def get_weights_ipc_handles(self, *, keys: list[str]) -> dict[str, Any]:
         from torch.multiprocessing.reductions import reduce_tensor
 
-        # Create IPC handles for each parameter
-        all_handles = []
-        for key, tensor in gathered_hf_params.items():
-            handle = reduce_tensor(tensor.detach())
-            all_handles.append((key, handle))
-
-        # Store references to avoid premature garbage collection
-        self._held_gather_buffer = gathered_hf_params
-        shapes = {}
-        for key, tensor in gathered_hf_params.items():
-            shapes[key] = tensor.shape
 
-        return {device_uuid: all_handles}
+        tensor_number_threshold = os.getenv(
+            "NEMO_RL_MEGATRON_IPC_TENSOR_PACKING_THRESHOLD", "32"
+        )  # an arbitrary threshold
+        if len(gathered_hf_params) >= int(tensor_number_threshold):
+            pack_tensor_for_ipc = True
+        else:
+            pack_tensor_for_ipc = False
+
+        if pack_tensor_for_ipc:
+            # Pack tensors in gathered_hf_params into consolidated tensors by dtype
+            # First calculate total size needed for each dtype
+            type_to_total_size = defaultdict(lambda: 0)
+            tensor_metadata = dict()
+
+            for key, tensor in gathered_hf_params.items():
+                tensor_metadata[key] = (
+                    tensor.shape,  # shape of the tensor
+                    tensor.dtype,  # dtype of the tensor
+                    type_to_total_size[tensor.dtype],  # offset of the tensor
+                    # in packed buffer
+                    tensor.numel(),  # size of the tensor
+                )
+                type_to_total_size[tensor.dtype] += tensor.numel()
+
+            # Allocate consolidated tensors for each dtype
+            packed_tensors = {
+                dtype: torch.empty(
+                    total_size,
+                    device=next(iter(gathered_hf_params.values())).device,
+                    dtype=dtype,
+                    requires_grad=False,
+                )
+                for dtype, total_size in type_to_total_size.items()
+            }
+
+            # Copy tensors into consolidated buffers
+            for key, tensor in gathered_hf_params.items():
+                metadata = tensor_metadata[key]
+                _, dtype, offset, size = metadata
+                packed_tensors[dtype][offset : offset + size].copy_(
+                    tensor.detach().view(-1)
+                )
+
+            # Create IPC handles for consolidated tensors
+            all_handles = [
+                (dtype, reduce_tensor(tensor.detach()))
+                for dtype, tensor in packed_tensors.items()
+            ]
+
+            # Store reference to prevent garbage collection
+            self._held_gather_buffer = packed_tensors
+
+            serialized = (pack_tensor_for_ipc, all_handles, tensor_metadata)
+        else:
+            all_handles = []
+            for key, tensor in gathered_hf_params.items():
+                handle = reduce_tensor(tensor.detach())
+                all_handles.append((key, handle))
+            self._held_gather_buffer = gathered_hf_params
+            serialized = (False, all_handles)
+
+        return {device_uuid: serialized}
 
     def prepare_for_lp_inference(self):
         self.model = self.move_model(self.model, "cuda", move_grads=False)
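Note (reviewer sketch): the packing path is a running-offset bucketing of parameters by dtype, followed by one flat buffer per dtype. A small standalone version of that bookkeeping with made-up parameters on CPU (the real code additionally shares the buffers across processes via reduce_tensor):

from collections import defaultdict
import torch

params = {
    "w1": torch.randn(2, 3),                    # float32
    "w2": torch.randn(5),                       # float32
    "b1": torch.randn(4, dtype=torch.float64),  # float64
}

# First pass: record (shape, dtype, offset, numel) per parameter; the offset is the
# running total for that dtype, so parameters of one dtype land back to back.
type_to_total_size = defaultdict(lambda: 0)
tensor_metadata = {}
for key, tensor in params.items():
    tensor_metadata[key] = (tensor.shape, tensor.dtype, type_to_total_size[tensor.dtype], tensor.numel())
    type_to_total_size[tensor.dtype] += tensor.numel()

# Second pass: allocate one flat buffer per dtype and copy each parameter into its slot.
packed = {dtype: torch.empty(total, dtype=dtype) for dtype, total in type_to_total_size.items()}
for key, tensor in params.items():
    _, dtype, offset, size = tensor_metadata[key]
    packed[dtype][offset : offset + size].copy_(tensor.detach().view(-1))

# Round-trip check: views into the packed buffers match the originals.
for key, (shape, dtype, offset, size) in tensor_metadata.items():
    assert torch.equal(packed[dtype][offset : offset + size].view(*shape), params[key])

Below the NEMO_RL_MEGATRON_IPC_TENSOR_PACKING_THRESHOLD cutoff (default 32 tensors), the worker keeps the original one-handle-per-parameter path, presumably because packing only pays off when it replaces many individual IPC handles.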