3 changes: 2 additions & 1 deletion skyrl-train/examples/async/async_trainer.py
@@ -4,7 +4,8 @@
from loguru import logger
from skyrl_train.trainer import RayPPOTrainer
from tqdm import tqdm
from skyrl_train.utils import Timer, normalize_advantages_dict
from skyrl_train.utils import Timer
from skyrl_train.utils.ppo_utils import normalize_advantages_dict
from skyrl_train.training_batch import TrainingInputBatch
from skyrl_train.generators.base import GeneratorOutput
from skyrl_train.utils.trainer_utils import ResumeMode
3 changes: 3 additions & 0 deletions skyrl-train/skyrl_train/config/ppo_base_config.yaml
@@ -115,6 +115,7 @@ trainer:
type: null # filter, replace, or null
max_sample_batches: 30 # sample at most this many batches before stopping, -1 to sample forever
min_replace_ratio: 0.3 # minimum proportion of good samples with which to replace bad samples (for replace strategy only)


gradient_checkpointing: true
gradient_checkpointing_use_reentrant: false
@@ -158,6 +159,8 @@ generator:
num_inference_engines: 1
backend: "vllm"
weight_sync_backend: "nccl"
# if using cuda_ipc, weights are sent in batches of at most this size (in GB)
weight_transfer_threshold_cuda_ipc_GB: 1.0
inference_engine_tensor_parallel_size: 4
n_samples_per_prompt: 5
async_engine: true
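Note on the new setting: weight_transfer_threshold_cuda_ipc_GB caps how much tensor data is grouped into a single CUDA IPC transfer. A minimal sketch of the batching idea; the helper name and the flush callback are illustrative, not the PR's actual implementation:

    from typing import Any, Callable, Dict, List, Tuple

    import torch

    def batch_weights_by_size(
        named_params: List[Tuple[str, torch.Tensor]],
        threshold_gb: float,
        flush: Callable[[Dict[str, Any]], None],
    ) -> None:
        """Group weights into NamedWeightsUpdateRequest-style batches of roughly
        threshold_gb, handing each full batch to `flush` (e.g. an IPC send)."""
        threshold_bytes = threshold_gb * 1024**3
        names: List[str] = []
        dtypes: List[str] = []
        shapes: List[List[int]] = []
        batch_bytes = 0
        for name, tensor in named_params:
            names.append(name)
            # The real code converts dtypes via a torch-dtype-to-string helper.
            dtypes.append(str(tensor.dtype).removeprefix("torch."))
            shapes.append(list(tensor.shape))
            batch_bytes += tensor.numel() * tensor.element_size()
            if batch_bytes >= threshold_bytes:
                flush({"names": names, "dtypes": dtypes, "shapes": shapes, "extras": None})
                names, dtypes, shapes, batch_bytes = [], [], [], 0
        if names:  # flush the remainder below the threshold
            flush({"names": names, "dtypes": dtypes, "shapes": shapes, "extras": None})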
12 changes: 6 additions & 6 deletions skyrl-train/skyrl_train/inference_engines/base.py
@@ -20,11 +20,11 @@ class InferenceEngineOutput(TypedDict):
response_logprobs: Optional[List[List[float]]]


class NamedWeightUpdateRequest(TypedDict):
name: str
dtype: str
shape: List[int]
extras: Optional[Dict[str, Any]]
class NamedWeightsUpdateRequest(TypedDict):
names: List[str]
dtypes: List[str]
shapes: List[List[int]]
extras: Optional[List[Dict[str, Any]]]


class InferenceEngineInterface(ABC):
@@ -48,7 +48,7 @@ async def init_weight_update_communicator(
raise NotImplementedError()

@abstractmethod
async def update_named_weight(self, request: NamedWeightUpdateRequest):
async def update_named_weights(self, request: NamedWeightsUpdateRequest):
raise NotImplementedError()

@abstractmethod
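Note: after the rename, every field of the request is a parallel list, so entry i of names/dtypes/shapes/extras describes the i-th weight. A sketch of constructing a batched request; the weight names and shapes are placeholder values:

    from skyrl_train.inference_engines.base import NamedWeightsUpdateRequest

    request: NamedWeightsUpdateRequest = {
        "names": ["model.embed_tokens.weight", "lm_head.weight"],
        "dtypes": ["bfloat16", "bfloat16"],
        "shapes": [[32000, 4096], [32000, 4096]],
        # Either None, or one dict per weight (e.g. carrying "ipc_handles").
        "extras": None,
    }
    # Any InferenceEngineInterface implementation then takes the batched form:
    #     await engine.update_named_weights(request)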
@@ -2,7 +2,7 @@
InferenceEngineInterface,
InferenceEngineInput,
InferenceEngineOutput,
NamedWeightUpdateRequest,
NamedWeightsUpdateRequest,
)
import asyncio
from typing import List, Any, Optional
@@ -96,9 +96,8 @@ async def _generate_with_trajectory_routing(self, prompts, prompt_token_ids, tra
add_resp_logprobs = True
response_logprobs[original_idx] = result["response_logprobs"][local_idx]

# something went wrong
if any([len(response) == 0 for response in responses]) or not all(
[isinstance(sample_ids, list) for sample_ids in response_ids]
if any([len(response) == 0 for response in responses]) or (
add_resp_ids and not all([isinstance(sample_ids, list) for sample_ids in response_ids])
):
raise RuntimeError(
"Did not receive responses / response ids for some prompts. This should never happen. There is likely something wrong with the inference engine"
@@ -192,8 +191,8 @@ async def init_weight_update_communicator(
rank_offset_count += engine.tp_size
await asyncio.gather(*tasks)

async def update_named_weight(self, request: NamedWeightUpdateRequest):
return await self._run_on_all_engines("update_named_weight", request=request)
async def update_named_weights(self, request: NamedWeightsUpdateRequest):
return await self._run_on_all_engines("update_named_weights", request=request)

async def reset_prefix_cache(self):
return await self._run_on_all_engines("reset_prefix_cache")
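Note: the fixed guard above only validates response_ids when token ids were actually requested. A condensed restatement of the predicate (the function name is illustrative):

    def generation_looks_broken(responses, response_ids, add_resp_ids: bool) -> bool:
        # Empty responses are always an error; malformed response_ids are an
        # error only when ids were requested, since the slots stay unset otherwise.
        missing_response = any(len(r) == 0 for r in responses)
        bad_ids = add_resp_ids and not all(isinstance(ids, list) for ids in response_ids)
        return missing_response or bad_ids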
@@ -8,7 +8,7 @@
InferenceEngineInterface,
InferenceEngineInput,
InferenceEngineOutput,
NamedWeightUpdateRequest,
NamedWeightsUpdateRequest,
)


@@ -41,8 +41,8 @@ async def init_weight_update_communicator(
master_addr, master_port, rank_offset, world_size, group_name, backend, override_existing
)

async def update_named_weight(self, request: NamedWeightUpdateRequest):
return await self.inference_engine_actor.update_named_weight.remote(request)
async def update_named_weights(self, request: NamedWeightsUpdateRequest):
return await self.inference_engine_actor.update_named_weights.remote(request)

async def teardown(self):
return await self.inference_engine_actor.teardown.remote()
@@ -3,7 +3,7 @@
InferenceEngineInterface,
InferenceEngineInput,
InferenceEngineOutput,
NamedWeightUpdateRequest,
NamedWeightsUpdateRequest,
)
from typing import List, Optional, Dict, Any
import json
@@ -103,25 +103,36 @@ async def init_weight_update_communicator(
) as response:
return await response.json()

async def update_named_weight(self, request: NamedWeightUpdateRequest):
if request.get("extras") and "ipc_handles" in request["extras"]:
async def update_named_weights(self, request: NamedWeightsUpdateRequest):
if "names" not in request:
raise ValueError(f"Expected update weight request with 'names' entry, got keys: {request.keys()}")

assert (
len(request["names"]) == 1
), f"Remote inference engines support only requests with a single named weight at a time , got request with {len(request['names'])} entries"

if request.get("extras") and "ipc_handles" in request["extras"][0]:
raise ValueError(
"Remote inference engines do not support CUDA IPC weight updates. Only local engines support IPC."
)
if self.engine_backend == "vllm":
weight_update_method = "update_weight"
weight_update_method = "update_weights"
elif self.engine_backend == "sglang":
weight_update_method = "update_weights_from_distributed"
else:
raise ValueError(f"Invalid engine backend: {self.engine_backend}")

async with aiohttp.ClientSession() as session:
name = request["names"][0]
dtype = request["dtypes"][0]
shape = request["shapes"][0]

resp = await session.post(
f"{self.url}/{weight_update_method}",
json={
"name": request["name"],
"dtype": request["dtype"],
"shape": request["shape"],
"name": name,
"dtype": dtype,
"shape": shape,
},
)
return await resp.json()
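Note: the remote HTTP path still accepts exactly one weight per call, so a caller holding a batched request has to fan it out itself. A hedged sketch of that splitting step; the helper is illustrative, not part of this PR:

    from skyrl_train.inference_engines.base import NamedWeightsUpdateRequest

    async def send_batched_request_remotely(engine, request: NamedWeightsUpdateRequest):
        """Split a multi-weight request into the single-weight requests that the
        remote engine's update_named_weights asserts on."""
        extras = request.get("extras")
        for i, name in enumerate(request["names"]):
            single: NamedWeightsUpdateRequest = {
                "names": [name],
                "dtypes": [request["dtypes"][i]],
                "shapes": [request["shapes"][i]],
                "extras": [extras[i]] if extras else None,
            }
            await engine.update_named_weights(single)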
72 changes: 42 additions & 30 deletions skyrl-train/skyrl_train/inference_engines/sglang/sglang_engine.py
@@ -29,7 +29,7 @@
InferenceEngineInterface,
InferenceEngineInput,
InferenceEngineOutput,
NamedWeightUpdateRequest,
NamedWeightsUpdateRequest,
)
from skyrl_train.utils import torch_dtype_to_str

@@ -104,12 +104,12 @@ def setup_envvars_for_sglang(kwargs, bundle_indices):
os.environ["CUDA_VISIBLE_DEVICES"] = str(ray.get_gpu_ids()[0])


def update_weight_cuda_ipc(model, named_tensors):
def update_weights_cuda_ipc(model, named_tensors):
"""
Custom weight loader for SGLang that handles IPC handles.

This function is called by SGLang's model runner to load weights.
It reconstructs tensors from SkyRL's NamedWeightUpdateRequest that contains IPC handles
It reconstructs tensors from SkyRL's NamedWeightsUpdateRequest that contains IPC handles
and loads them into the model.
"""
import torch
@@ -128,36 +128,40 @@ def update_weight_cuda_ipc(model, named_tensors):
request_data = tensor_bytes[:end_index]
try:
request_data_decoded = base64.b64decode(request_data)
request = pickle.loads(request_data_decoded)
request: NamedWeightsUpdateRequest = pickle.loads(request_data_decoded)
except Exception as e:
raise ValueError(f"Failed to deserialize request data: {e}")

# Extract the request data
ipc_handles = request["extras"]["ipc_handles"]
dtype = request["dtype"]
_ = request["shape"]
weight_name = request["name"]
weights_to_load = []
for i in range(len(request["names"])):
# Extract the request data
ipc_handles = request["extras"][i]["ipc_handles"]
dtype = request["dtypes"][i]
_ = request["shapes"][i]
weight_name = request["names"][i]

device = torch.cuda.current_device()
props = torch.cuda.get_device_properties(device)
physical_gpu_id = str(props.uuid)
device = torch.cuda.current_device()
props = torch.cuda.get_device_properties(device)
physical_gpu_id = str(props.uuid)

# Infer model dtype and device index from first parameter
model_dtype = torch_dtype_to_str(next(model.parameters()).dtype)
assert dtype == model_dtype, f"mismatch dtype: src {dtype}, dst {model_dtype}"
device_id = next(model.parameters()).device.index
# Infer model dtype and device index from first parameter
model_dtype = torch_dtype_to_str(next(model.parameters()).dtype)
assert dtype == model_dtype, f"mismatch dtype: src {dtype}, dst {model_dtype}"
device_id = next(model.parameters()).device.index

handle = ipc_handles[physical_gpu_id]
func, args = handle
list_args = list(args)
# the key is to change device id to the current device id
# in case two processes have different CUDA_VISIBLE_DEVICES
list_args[6] = device_id
weight = func(*list_args)
model.load_weights([(weight_name, weight)])
handle = ipc_handles[physical_gpu_id]
func, args = handle
list_args = list(args)
# the key is to change device id to the current device id
# in case two processes have different CUDA_VISIBLE_DEVICES
list_args[6] = device_id
weight = func(*list_args)
weights_to_load.append((weight_name, weight))

model.load_weights(weights_to_load)

CUSTOM_WEIGHT_LOADER_PATH = "skyrl_train.inference_engines.sglang.sglang_engine.update_weight_cuda_ipc"

CUSTOM_WEIGHT_LOADER_PATH = "skyrl_train.inference_engines.sglang.sglang_engine.update_weights_cuda_ipc"


class SGLangInferenceEngine(InferenceEngineInterface):
@@ -256,12 +260,15 @@ async def init_weight_update_communicator(
success, message = await self.engine.tokenizer_manager.init_weights_update_group(obj, None)
return success, message

async def update_named_weight(self, request: NamedWeightUpdateRequest) -> Tuple[bool, str]:
async def update_named_weights(self, request: NamedWeightsUpdateRequest) -> Tuple[bool, str]:
"""Update named weights in SGLang engine."""
if "names" not in request:
raise ValueError(f"Expected update weight request with 'names' entry, got keys: {request.keys()}")

extras = request.get("extras")
if extras is not None and "ipc_handles" in extras:
if extras is not None and "ipc_handles" in extras[0]:
# CUDA IPC -- Here we reuse SGLang's update_weights_from_tensor, but actually load the
# weight from our request data. This will use the update_weight_cuda_ipc defined above.
# weight from our request data. This will use the update_weights_cuda_ipc defined above.
# This is a bit hacky, but the only way as of now, since there is no other way to
# write per-TP worker code besides using `custom_weight_loader`, unlike in vLLM,
# where we can use `WorkerWrap`.
@@ -293,14 +300,19 @@ async def update_named_weight(self, request: NamedWeightUpdateRequest) -> Tuple[
success, message = await self.engine.tokenizer_manager.update_weights_from_tensor(obj, None)
return success, message
else:
assert (
len(request["names"]) == 1
), f"Update weights without cuda IPC only supports a single named weight at a time , got request with {len(request['names'])} entries"
# Broadcast
obj = UpdateWeightsFromDistributedReqInput(
name=request["name"], dtype=request["dtype"], shape=request["shape"]
name=request["names"][0], dtype=request["dtypes"][0], shape=request["shapes"][0]
)

# Call the underlying async method for the same reason as in `init_weight_update_communicator`
success, message = await self.engine.tokenizer_manager.update_weights_from_distributed(obj, None)
return success, message
if not success:
raise RuntimeError(f"Update weight request failed with message: {message}")
return
Comment on lines +313 to +315
Contributor

medium

The function update_named_weights is type-hinted to return Tuple[bool, str], but this path returns None on success, which violates the function's contract. The IPC path correctly returns a tuple. To be consistent and correct, this should return a tuple on success, for example (True, "").

Suggested change:
             if not success:
                 raise RuntimeError(f"Update weight request failed with message: {message}")
-            return
+            return True, ""

async def wake_up(self, tags: Optional[List[str]] = None):
"""Wake up the engine. For multi-stage waking up, pass in `"weight"` or `"kv_cache"` to tags."""
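Note: update_weights_cuda_ipc above recovers the request with pickle.loads(base64.b64decode(...)) from the bytes SGLang hands to the custom weight loader. A minimal sketch of the matching encode step on the sender side; how the payload is framed and padded in the actual PR is not shown here, so treat this as the round-trip idea only:

    import base64
    import pickle

    def encode_ipc_request(request) -> bytes:
        """Serialize a NamedWeightsUpdateRequest so that the worker-side loader
        can recover it with pickle.loads(base64.b64decode(payload))."""
        return base64.b64encode(pickle.dumps(request))

    # Round trip mirrors the loader's decode path:
    payload = encode_ipc_request(
        {"names": ["w"], "dtypes": ["bfloat16"], "shapes": [[2, 2]],
         "extras": [{"ipc_handles": {}}]}
    )
    assert pickle.loads(base64.b64decode(payload))["names"] == ["w"]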