diff --git a/nemo_rl/models/generation/vllm_backend.py b/nemo_rl/models/generation/vllm_backend.py
index c40aea4418..8fe204e21f 100644
--- a/nemo_rl/models/generation/vllm_backend.py
+++ b/nemo_rl/models/generation/vllm_backend.py
@@ -59,17 +59,40 @@ def update_weights_from_ipc_handles(self, ipc_handles):
         # Get handles for this device
         device_uuid = self.report_device_id()
         handles = ipc_handles[device_uuid]
+        is_tensor_packed = handles[0]
+        if is_tensor_packed:
+            _, all_handles, tensor_metadata = handles
+        else:
+            _, name_and_handle_list = handles
+
         device_id = self.device.index
         weights = []
-        # Process each handle to get the tensor
-        for name, handle in handles:
-            func, args = handle
-            list_args = list(args)
-            # Update device ID to match the current device
-            list_args[6] = device_id
-            tensor = func(*list_args)
-            weights.append((name, tensor))
+        if is_tensor_packed:
+            # Extract packed tensor from IPC handle
+            dtype_to_packed_tensor = {}
+            for dtype, tensor_handle in all_handles:
+                func, args = tensor_handle
+                list_args = list(args)
+                list_args[6] = device_id
+                tensor = func(*list_args)
+                dtype_to_packed_tensor[dtype] = tensor
+
+            # Unpack tensor to weights. Here we only return a view of the tensor to avoid
+            # using extra memory.
+            for key, (shape, dtype, offset, size) in tensor_metadata.items():
+                tensor = dtype_to_packed_tensor[dtype][offset : offset + size].view(
+                    *shape
+                )
+                weights.append((key, tensor))
+        else:
+            # Process each handle to get the tensor
+            for name, handle in name_and_handle_list:
+                func, args = handle
+                list_args = list(args)
+                list_args[6] = device_id
+                tensor = func(*list_args)
+                weights.append((name, tensor))
 
         # Load weights into the model
         self.model_runner.model.load_weights(weights=weights)
diff --git a/nemo_rl/models/policy/dtensor_policy_worker.py b/nemo_rl/models/policy/dtensor_policy_worker.py
index 6872250d10..51c88bccf6 100644
--- a/nemo_rl/models/policy/dtensor_policy_worker.py
+++ b/nemo_rl/models/policy/dtensor_policy_worker.py
@@ -943,7 +943,10 @@ def get_weights_ipc_handles(self, keys: Iterable[str]) -> dict[str, Any]:
             handle = reduce_tensor(p.detach())
             all_handles.append((key, handle))
 
-        return {device_uuid: all_handles}
+        # (pack_tensor_for_ipc: bool, handles: list)
+        serialized = (False, all_handles)
+
+        return {device_uuid: serialized}
 
     @torch.no_grad()
     def prepare_info_for_collective(self) -> dict[str, Any]:
diff --git a/nemo_rl/models/policy/fsdp1_policy_worker.py b/nemo_rl/models/policy/fsdp1_policy_worker.py
index f4ec53daa0..418f280e46 100644
--- a/nemo_rl/models/policy/fsdp1_policy_worker.py
+++ b/nemo_rl/models/policy/fsdp1_policy_worker.py
@@ -901,7 +901,10 @@ def get_weights_ipc_handles(self, keys: list[str]) -> dict[str, Any]:
             handle = reduce_tensor(p.detach())
             all_handles.append((key, handle))
 
-        return {device_uuid: all_handles}
+        # (pack_tensor_for_ipc: bool, handles: list)
+        serialized = (False, all_handles)
+
+        return {device_uuid: serialized}
 
     def prepare_for_lp_inference(self) -> None:
         self.model = self.manual_load_to_gpu(self.model)
diff --git a/nemo_rl/models/policy/megatron_policy_worker.py b/nemo_rl/models/policy/megatron_policy_worker.py
index 9113723af0..1764aeff89 100644
--- a/nemo_rl/models/policy/megatron_policy_worker.py
+++ b/nemo_rl/models/policy/megatron_policy_worker.py
@@ -1338,18 +1338,68 @@ def get_weights_ipc_handles(self, *, keys: list[str]) -> dict[str, Any]:
         from torch.multiprocessing.reductions import reduce_tensor
 
         # Create IPC handles for each parameter
-        all_handles = []
-        for key, tensor in gathered_hf_params.items():
-            handle = reduce_tensor(tensor.detach())
-            all_handles.append((key, handle))
-
-        # Store references to avoid premature garbage collection
-        self._held_gather_buffer = gathered_hf_params
-        shapes = {}
-        for key, tensor in gathered_hf_params.items():
-            shapes[key] = tensor.shape
-
-        return {device_uuid: all_handles}
+        tensor_number_threshold = os.getenv(
+            "NEMO_RL_MEGATRON_IPC_TENSOR_PACKING_THRESHOLD", "32"
+        )  # an arbitrary threshold
+        if len(gathered_hf_params) >= int(tensor_number_threshold):
+            pack_tensor_for_ipc = True
+        else:
+            pack_tensor_for_ipc = False
+
+        if pack_tensor_for_ipc:
+            # Pack tensors in gathered_hf_params into consolidated tensors by dtype
+            # First calculate total size needed for each dtype
+            type_to_total_size = defaultdict(lambda: 0)
+            tensor_metadata = dict()
+
+            for key, tensor in gathered_hf_params.items():
+                tensor_metadata[key] = (
+                    tensor.shape,  # shape of the tensor
+                    tensor.dtype,  # dtype of the tensor
+                    type_to_total_size[tensor.dtype],  # offset of the tensor
+                    # in packed buffer
+                    tensor.numel(),  # size of the tensor
+                )
+                type_to_total_size[tensor.dtype] += tensor.numel()
+
+            # Allocate consolidated tensors for each dtype
+            packed_tensors = {
+                dtype: torch.empty(
+                    total_size,
+                    device=next(iter(gathered_hf_params.values())).device,
+                    dtype=dtype,
+                    requires_grad=False,
+                )
+                for dtype, total_size in type_to_total_size.items()
+            }
+
+            # Copy tensors into consolidated buffers
+            for key, tensor in gathered_hf_params.items():
+                metadata = tensor_metadata[key]
+                _, dtype, offset, size = metadata
+                packed_tensors[dtype][offset : offset + size].copy_(
+                    tensor.detach().view(-1)
+                )
+
+            # Create IPC handles for consolidated tensors
+            all_handles = [
+                (dtype, reduce_tensor(tensor.detach()))
+                for dtype, tensor in packed_tensors.items()
+            ]
+
+            # Store reference to prevent garbage collection
+            self._held_gather_buffer = packed_tensors
+
+            serialized = (pack_tensor_for_ipc, all_handles, tensor_metadata)
+        else:
+            all_handles = []
+            for key, tensor in gathered_hf_params.items():
+                handle = reduce_tensor(tensor.detach())
+                all_handles.append((key, handle))
+            self._held_gather_buffer = gathered_hf_params
+            serialized = (False, all_handles)
+
+        return {device_uuid: serialized}
 
     def prepare_for_lp_inference(self):
         self.model = self.move_model(self.model, "cuda", move_grads=False)
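
For reference, below is a minimal single-process sketch (not part of the patch) of the pack/unpack round trip the hunks above implement. The params dict, its keys, and the CPU-only tensors are made-up examples; the real code ships the packed per-dtype buffers across processes via CUDA IPC handles (reduce_tensor) instead of passing them directly.

from collections import defaultdict

import torch

# Hypothetical example parameters; in the PR these come from the Megatron policy worker.
params = {
    "linear.weight": torch.randn(4, 8),
    "linear.bias": torch.randn(8),
    "step": torch.tensor([3], dtype=torch.int64),
}

# Pack: one flat buffer per dtype, plus (shape, dtype, offset, numel) metadata per key.
type_to_total_size = defaultdict(int)
tensor_metadata = {}
for key, tensor in params.items():
    tensor_metadata[key] = (
        tensor.shape,
        tensor.dtype,
        type_to_total_size[tensor.dtype],  # offset into the packed buffer
        tensor.numel(),
    )
    type_to_total_size[tensor.dtype] += tensor.numel()

packed_tensors = {
    dtype: torch.empty(total, dtype=dtype)
    for dtype, total in type_to_total_size.items()
}
for key, tensor in params.items():
    _, dtype, offset, size = tensor_metadata[key]
    packed_tensors[dtype][offset : offset + size].copy_(tensor.detach().view(-1))

# Unpack: each weight is reconstructed as a zero-copy view into the packed buffer.
weights = [
    (key, packed_tensors[dtype][offset : offset + size].view(*shape))
    for key, (shape, dtype, offset, size) in tensor_metadata.items()
]
for key, tensor in weights:
    assert torch.equal(tensor, params[key])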