diff --git a/skyrl-train/examples/gsm8k/run_gsm8k.sh b/skyrl-train/examples/gsm8k/run_gsm8k.sh index 139732c4d4..39c3600383 100755 --- a/skyrl-train/examples/gsm8k/run_gsm8k.sh +++ b/skyrl-train/examples/gsm8k/run_gsm8k.sh @@ -12,7 +12,10 @@ DATA_DIR="$HOME/data/gsm8k" NUM_GPUS=4 LOGGER="wandb" # change to "console" to print to stdout -uv run --isolated --extra vllm -m skyrl_train.entrypoints.main_base \ +INFERENCE_BACKEND="vllm" +# INFERENCE_BACKEND="sglang" + +uv run --isolated --extra $INFERENCE_BACKEND -m skyrl_train.entrypoints.main_base \ data.train_data="['$DATA_DIR/train.parquet']" \ data.val_data="['$DATA_DIR/validation.parquet']" \ trainer.algorithm.advantage_estimator="grpo" \ @@ -37,7 +40,7 @@ uv run --isolated --extra vllm -m skyrl_train.entrypoints.main_base \ generator.sampling_params.max_generate_length=1024 \ trainer.policy.optimizer_config.lr=1.0e-6 \ trainer.algorithm.use_kl_loss=true \ - generator.backend=vllm \ + generator.backend=$INFERENCE_BACKEND \ generator.run_engines_locally=true \ generator.weight_sync_backend=nccl \ generator.async_engine=true \ diff --git a/skyrl-train/pyproject.toml b/skyrl-train/pyproject.toml index 1215726d91..a061bcb9c1 100644 --- a/skyrl-train/pyproject.toml +++ b/skyrl-train/pyproject.toml @@ -62,7 +62,7 @@ dev = [ "black==24.10.0", "pytest>=6.2.5", "pytest-asyncio", - "pre-commit" + "pre-commit", ] docs = [ "sphinx>=7.0.0", @@ -81,7 +81,7 @@ vllm = [ "torchvision" ] sglang = [ - "sglang[srt,openai,torch_memory_saver]==0.4.8.post1", + "sglang[srt,openai,torch_memory_saver]==0.4.8.post1", # 0.4.9.post1 causes non-colocate weight broadcast to hang # The version is pinned to 0.2.5 because sglang requires this # NOTE (sumanthrh): This can be made a common dependency, but then different inference engines can pin different compatible flashinfer versions and it might quickly break. 
"flashinfer-python@https://download.pytorch.org/whl/cu128/flashinfer/flashinfer_python-0.2.6.post1%2Bcu128torch2.7-cp39-abi3-linux_x86_64.whl", diff --git a/skyrl-train/skyrl_train/entrypoints/main_base.py b/skyrl-train/skyrl_train/entrypoints/main_base.py index ea6e84ae62..12505f789f 100644 --- a/skyrl-train/skyrl_train/entrypoints/main_base.py +++ b/skyrl-train/skyrl_train/entrypoints/main_base.py @@ -51,12 +51,13 @@ def create_ray_wrapped_inference_engines_from_config(cfg: DictConfig, colocate_p max_model_len=cfg.generator.max_input_length + cfg.generator.sampling_params.max_generate_length, shared_pg=colocate_pg, gpu_memory_utilization=cfg.generator.gpu_memory_utilization, - vllm_enable_sleep=cfg.trainer.placement.colocate_all, + inference_engine_enable_sleep=cfg.trainer.placement.colocate_all, async_engine=cfg.generator.async_engine, max_num_batched_tokens=cfg.generator.max_num_batched_tokens, max_num_seqs=cfg.generator.max_num_seqs, sampling_params=get_sampling_params_for_backend(cfg.generator.backend, cfg.generator.sampling_params), tokenizer=tokenizer, + backend=cfg.generator.backend, ) diff --git a/skyrl-train/skyrl_train/inference_engines/ray_wrapped_inference_engine.py b/skyrl-train/skyrl_train/inference_engines/ray_wrapped_inference_engine.py index 442ee3abda..5a64b11fda 100644 --- a/skyrl-train/skyrl_train/inference_engines/ray_wrapped_inference_engine.py +++ b/skyrl-train/skyrl_train/inference_engines/ray_wrapped_inference_engine.py @@ -26,7 +26,7 @@ def tp_size(self): return ray.get(self.inference_engine_actor.tp_size.remote()) async def generate(self, input_batch: InferenceEngineInput) -> InferenceEngineOutput: - return await self.inference_engine_actor.generate.remote(input_batch) + return await self.inference_engine_actor.generate.remote(input_batch=input_batch) async def wake_up(self, *args: Any, **kwargs: Any): return await self.inference_engine_actor.wake_up.remote(*args, **kwargs) @@ -63,21 +63,30 @@ def create_ray_wrapped_inference_engines( max_model_len: int, shared_pg=None, gpu_memory_utilization=None, - vllm_enable_sleep=False, + inference_engine_enable_sleep=False, async_engine=False, max_num_batched_tokens=8192, max_num_seqs=1024, sampling_params: Optional[Dict[str, Any]] = None, tokenizer=None, + backend="vllm", ) -> List[InferenceEngineInterface]: """ Create a list of RayWrappedInferenceEngine instances wrapping Ray actor handles to InferenceEngineInterface instances. """ - import vllm - from skyrl_train.inference_engines.vllm.vllm_engine import VLLMRayActor, AsyncVLLMRayActor from skyrl_train.utils import ray_noset_visible_devices, get_all_env_variables, get_ray_pg_ready_with_timeout - assert version.parse(vllm.__version__) >= version.parse("0.8.3"), "SkyRL-Train only supports vLLM >= 0.8.3" + if backend == "vllm": + import vllm + from skyrl_train.inference_engines.vllm.vllm_engine import VLLMRayActor, AsyncVLLMRayActor + + assert version.parse(vllm.__version__) >= version.parse("0.8.3"), "SkyRL-Train only supports vLLM >= 0.8.3" + elif backend == "sglang": + # We import SGLang later to avoid importing vllm. See `get_sglang_engine` for more. 
+ pass + else: + raise ValueError(f"Unsupported backend: {backend}") + inference_engine_actors = [] noset_visible_devices = ray_noset_visible_devices(ray.get(get_all_env_variables.remote())) # NOTE: we use the ray backend for tensor parallel size > 1 to explicitly manage resource allocation @@ -107,42 +116,92 @@ def create_ray_wrapped_inference_engines( placement_group_bundle_index=i * tensor_parallel_size, ) - if async_engine: - actor_class = AsyncVLLMRayActor - else: - actor_class = VLLMRayActor - - vllm_engine = actor_class.options( - num_cpus=num_gpus, - num_gpus=num_gpus, - scheduling_strategy=scheduling_strategy, - ).remote( - model=pretrain, - enforce_eager=enforce_eager, - worker_extension_cls="skyrl_train.inference_engines.vllm.vllm_engine.WorkerWrap", - tensor_parallel_size=tensor_parallel_size, - seed=seed + i, - distributed_executor_backend=distributed_executor_backend, - max_model_len=max_model_len, - enable_prefix_caching=enable_prefix_caching, - dtype=model_dtype, - trust_remote_code=True, - vllm_v1_disable_multiproc=vllm_v1_disable_multiproc, - gpu_memory_utilization=gpu_memory_utilization, - bundle_indices=bundle_indices, - num_gpus=0.2 if use_hybrid_engine else 1, - enable_sleep_mode=vllm_enable_sleep, - noset_visible_devices=noset_visible_devices, - max_num_batched_tokens=max_num_batched_tokens, - max_num_seqs=max_num_seqs, - sampling_params=sampling_params, - tokenizer=tokenizer, - ) - inference_engine_actors.append(vllm_engine) + if backend == "vllm": + if async_engine: + actor_class = AsyncVLLMRayActor + else: + actor_class = VLLMRayActor + + engine = actor_class.options( + num_cpus=num_gpus, + num_gpus=num_gpus, + scheduling_strategy=scheduling_strategy, + ).remote( + model=pretrain, + enforce_eager=enforce_eager, + worker_extension_cls="skyrl_train.inference_engines.vllm.vllm_engine.WorkerWrap", + tensor_parallel_size=tensor_parallel_size, + seed=seed + i, + distributed_executor_backend=distributed_executor_backend, + max_model_len=max_model_len, + enable_prefix_caching=enable_prefix_caching, + dtype=model_dtype, + trust_remote_code=True, + vllm_v1_disable_multiproc=vllm_v1_disable_multiproc, + gpu_memory_utilization=gpu_memory_utilization, + bundle_indices=bundle_indices, + num_gpus=0.2 if use_hybrid_engine else 1, + enable_sleep_mode=inference_engine_enable_sleep, + noset_visible_devices=noset_visible_devices, + max_num_batched_tokens=max_num_batched_tokens, + max_num_seqs=max_num_seqs, + sampling_params=sampling_params, + tokenizer=tokenizer, + ) + elif backend == "sglang": + # NOTE: there is no async / sync engine distinction in SGLang + + # NOTE(Charlie): We need `torch.cuda.is_available()` to be True to import SGLang. Otherwise, it requires + # importing vllm. See https://github.com/sgl-project/sglang/blob/v0.4.8.post1/python/sglang/srt/layers/quantization/utils.py#L11-L17 + # Similar comment: https://github.com/volcengine/verl/blob/9cc307767b0c787e8f5ef581dac929f7bde044ef/verl/workers/fsdp_workers.py#L520-L527 + @ray.remote + def get_sglang_engine(): + # A workaround to avoid importing vllm is to give this task a GPU. 
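The CUDA_VISIBLE_DEVICES save/restore that follows is the whole trick; expressed as a stand-alone helper it would look like the sketch below (illustrative only; `visible_devices` is a hypothetical helper, not used by this patch):

import os
from contextlib import contextmanager

@contextmanager
def visible_devices(devices: str):
    # Temporarily expose the given GPU(s) so torch.cuda.is_available() is True
    # while sglang is imported, then restore the previous value (or unset it).
    previous = os.environ.get("CUDA_VISIBLE_DEVICES")
    os.environ["CUDA_VISIBLE_DEVICES"] = devices
    try:
        yield
    finally:
        if previous is None:
            os.environ.pop("CUDA_VISIBLE_DEVICES", None)
        else:
            os.environ["CUDA_VISIBLE_DEVICES"] = previous

# Usage (mirrors the workaround in get_sglang_engine):
#     with visible_devices("0"):
#         from skyrl_train.inference_engines.sglang.sglang_engine import SGLangRayActor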
+ import os + + before_cuda_visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", "") + os.environ["CUDA_VISIBLE_DEVICES"] = "0" + from skyrl_train.inference_engines.sglang.sglang_engine import SGLangRayActor + + os.environ["CUDA_VISIBLE_DEVICES"] = before_cuda_visible_devices + + actor_class = SGLangRayActor + engine = actor_class.options( + num_cpus=num_gpus, + num_gpus=num_gpus, + scheduling_strategy=scheduling_strategy, + ).remote( + model_path=pretrain, + tp_size=tensor_parallel_size, + mem_fraction_static=gpu_memory_utilization, + random_seed=seed + i, + context_length=max_model_len, + disable_radix_cache=not enable_prefix_caching, + dtype=model_dtype, + trust_remote_code=True, + max_prefill_tokens=max_num_batched_tokens, + max_running_requests=max_num_seqs, + # Borrowed from veRL's SGLang rollout + mm_attention_backend="fa3", + attention_backend="fa3", + enable_memory_saver=inference_engine_enable_sleep, + # Will be popped before instantiating sgl.Engine + distributed_executor_backend=distributed_executor_backend, + noset_visible_devices=noset_visible_devices, + bundle_indices=bundle_indices, + num_gpus=0.2 if use_hybrid_engine else 1, + sampling_params=sampling_params, + tokenizer=tokenizer, + ) + return engine + + engine = ray.get(get_sglang_engine.remote()) + + inference_engine_actors.append(engine) engines = [RayWrappedInferenceEngine(actor_handle) for actor_handle in inference_engine_actors] - if vllm_enable_sleep: + if inference_engine_enable_sleep: sleep_refs = [engine.inference_engine_actor.sleep.remote() for engine in engines] ray.get(sleep_refs) diff --git a/skyrl-train/skyrl_train/inference_engines/sglang/sglang_engine.py b/skyrl-train/skyrl_train/inference_engines/sglang/sglang_engine.py new file mode 100644 index 0000000000..1f1b7c8acc --- /dev/null +++ b/skyrl-train/skyrl_train/inference_engines/sglang/sglang_engine.py @@ -0,0 +1,333 @@ +"""SGLang inference engine implementation.""" + +import pickle +import base64 +import torch +import os +from typing import List, Optional, Tuple +import ray +import multiprocessing as mp + +import sglang.srt.entrypoints.engine +from sglang.srt.entrypoints.engine import Engine +from sglang.srt.utils import ( + assert_pkg_version, + is_cuda, + maybe_set_triton_cache_manager, + set_prometheus_multiproc_dir, + set_ulimit, + MultiprocessingSerializer, +) +from sglang.srt.managers.tokenizer_manager import ( + UpdateWeightsFromTensorReqInput, + UpdateWeightsFromDistributedReqInput, + InitWeightsUpdateGroupReqInput, + ReleaseMemoryOccupationReqInput, + ResumeMemoryOccupationReqInput, +) +from skyrl_train.inference_engines.base import ( + InferenceEngineInterface, + InferenceEngineInput, + InferenceEngineOutput, + NamedWeightUpdateRequest, +) +from skyrl_train.utils import torch_dtype_to_str + + +# Patch SGLang's _set_envs_and_config to avoid signal handler issues in Ray actors +# Based on VERL's solution: https://github.com/sgl-project/sglang/issues/6723 +# https://github.com/volcengine/verl/blob/v0.4.1/verl/workers/rollout/sglang_rollout/sglang_rollout.py#L85 +def _patched_set_envs_and_config(server_args): + """Patched version of SGLang's _set_envs_and_config that removes signal handler registration.""" + # Set global environments + os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" + os.environ["NCCL_CUMEM_ENABLE"] = "0" + os.environ["NCCL_NVLS_ENABLE"] = str(int(getattr(server_args, "enable_nccl_nvls", False))) + os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1" + os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "4" + 
os.environ["CUDA_MODULE_LOADING"] = "AUTO" + + # Set prometheus env vars + if server_args.enable_metrics: + set_prometheus_multiproc_dir() + + # Set ulimit + set_ulimit() + + # Fix triton bugs + if server_args.tp_size * server_args.dp_size > 1: + # FIXME: remove this after https://github.com/triton-lang/triton/pull/4295 is used as a dependency. + maybe_set_triton_cache_manager() + + # Check flashinfer version + if server_args.attention_backend == "flashinfer": + assert_pkg_version( + "flashinfer_python", + "0.2.5", + "Please uninstall the old version and reinstall the latest version by following the instructions at https://docs.flashinfer.ai/installation.html.", + ) + if is_cuda(): + assert_pkg_version( + "sgl-kernel", + "0.1.1", + "Please reinstall the latest version with `pip install sgl-kernel --force-reinstall`", + ) + + # Set mp start method + mp.set_start_method("spawn", force=True) + + # We do NOT register signal handlers here to avoid Ray actor issues + # Original SGLang code had: signal.signal(signal.SIGCHLD, sigchld_handler) + # But this fails in Ray actors since signal handlers only work in main thread + + +# Apply the patch +sglang.srt.entrypoints.engine._set_envs_and_config = _patched_set_envs_and_config + + +# TODO(charlie): duplicate of setup_envvars_for_vllm, is it needed? +def setup_envvars_for_sglang(kwargs, bundle_indices): + distributed_executor_backend = kwargs.pop("distributed_executor_backend", None) + noset_visible_devices = kwargs.pop("noset_visible_devices", None) + if distributed_executor_backend == "ray": + # a hack to make the script work. + # stop ray from manipulating *_VISIBLE_DEVICES + # at the top-level when the distributed_executor_backend is ray. + os.environ.pop("CUDA_VISIBLE_DEVICES", None) + os.environ.pop("ROCR_VISIBLE_DEVICES", None) + os.environ.pop("HIP_VISIBLE_DEVICES", None) + pass + elif noset_visible_devices: + # We need to set CUDA_VISIBLE_DEVICES to the ray assigned GPU + # when the distributed_executor_backend is not rayargs and + # RAY_EXPERIMENTAL_NOSET_*_VISIBLE_DEVICES is set. + os.environ["CUDA_VISIBLE_DEVICES"] = str(ray.get_gpu_ids()[0]) + + +def update_weight_cuda_ipc(model, named_tensors): + """ + Custom weight loader for SGLang that handles IPC handles. + + This function is called by SGLang's model runner to load weights. + It reconstructs tensors from SkyRL's NamedWeightUpdateRequest that contains IPC handles + and loads them into the model. 
+ """ + import torch + + # Extract tensor name and data + name, tensor = named_tensors[0] + if name != "ipc_request": + raise ValueError(f"Expected IPC request tensor name to be 'ipc_request', got: {name}") + + # Convert tensor to bytes, then decode and deserialize + tensor_bytes = tensor.cpu().numpy().tobytes() + end_marker = b"__END_OF_REQUEST__" + end_index = tensor_bytes.find(end_marker) + if end_index == -1: + raise ValueError("End marker not found in tensor data") + request_data = tensor_bytes[:end_index] + try: + request_data_decoded = base64.b64decode(request_data) + request = pickle.loads(request_data_decoded) + except Exception as e: + raise ValueError(f"Failed to deserialize request data: {e}") + + # Extract the request data + ipc_handles = request["extras"]["ipc_handles"] + dtype = request["dtype"] + _ = request["shape"] + weight_name = request["name"] + + device = torch.cuda.current_device() + props = torch.cuda.get_device_properties(device) + physical_gpu_id = str(props.uuid) + + # Infer model dtype and device index from first parameter + model_dtype = torch_dtype_to_str(next(model.parameters()).dtype) + assert dtype == model_dtype, f"mismatch dtype: src {dtype}, dst {model_dtype}" + device_id = next(model.parameters()).device.index + + handle = ipc_handles[physical_gpu_id] + func, args = handle + list_args = list(args) + # the key is to change device id to the current device id + # in case two processes have different CUDA_VISIBLE_DEVICES + list_args[6] = device_id + weight = func(*list_args) + model.load_weights([(weight_name, weight)]) + + +CUSTOM_WEIGHT_LOADER_PATH = "skyrl_train.inference_engines.sglang.sglang_engine.update_weight_cuda_ipc" + + +class SGLangInferenceEngine(InferenceEngineInterface): + """SGLang inference engine that implements InferenceEngineInterface.""" + + def __init__(self, *args, bundle_indices: Optional[List[int]] = None, **kwargs): + setup_envvars_for_sglang(kwargs, bundle_indices) + + # Store common attributes + self._tp_size = kwargs.get("tp_size", 1) + if self._tp_size > 1: + raise ValueError( + "As of now, we don't support tensor parallel inference engine with SGLang. " + "Please set `inference_engine_tensor_parallel_size` to 1." 
+ ) + self.tokenizer = kwargs.pop("tokenizer", None) + + # Extract sampling params + sampling_params_dict = kwargs.pop("sampling_params", None) + self.sampling_params = sampling_params_dict or {} + + # Unused kwargs + _ = kwargs.pop("num_gpus", 1) + + # Add custom weight loader + kwargs["custom_weight_loader"] = CUSTOM_WEIGHT_LOADER_PATH + + # Create the SGLang engine (signal handler issue is now fixed by patching) + self.engine = Engine(**kwargs) + print(f"Created SGLang engine with kwargs: {kwargs}") + + def tp_size(self): + """Return the tensor parallel size.""" + return self._tp_size + + def _preprocess_prompts(self, input_batch: InferenceEngineInput): + """Preprocess prompts for SGLang generation.""" + prompts = input_batch.get("prompts") + prompt_token_ids = input_batch.get("prompt_token_ids") + request_sampling_params = input_batch.get("sampling_params") + + if (prompts is None and prompt_token_ids is None) or (prompts is not None and prompt_token_ids is not None): + raise ValueError("Either `prompts` or `prompt_token_ids` must be provided, but not both.") + + # Use request sampling params if provided, otherwise use defaults + sampling_params = request_sampling_params if request_sampling_params is not None else self.sampling_params + + if prompt_token_ids is None: + prompt_token_ids = self.tokenizer.apply_chat_template( + prompts, + add_generation_prompt=True, + add_special_tokens=False, + return_dict=True, + tokenize=True, + )["input_ids"] + + return prompt_token_ids, sampling_params + + def _postprocess_outputs(self, outputs): + """Process SGLang outputs to match expected format.""" + responses: List[str] = [] + stop_reasons: List[str] = [] + + for output in outputs: + responses.append(output["text"]) + stop_reasons.append(output["meta_info"]["finish_reason"]["type"]) + + return InferenceEngineOutput( + responses=responses, + stop_reasons=stop_reasons, + ) + + async def generate(self, input_batch: InferenceEngineInput) -> InferenceEngineOutput: + """Generate responses using SGLang engine.""" + token_ids_prompts, sampling_params = self._preprocess_prompts(input_batch) + outputs = await self.engine.async_generate(input_ids=token_ids_prompts, sampling_params=sampling_params) + return self._postprocess_outputs(outputs) + + async def init_weight_update_communicator( + self, master_addr, master_port, rank_offset, world_size, group_name, backend, override_existing: bool = False + ): + """Initialize weight update communicator for SGLang.""" + obj = InitWeightsUpdateGroupReqInput( + master_address=master_addr, + master_port=master_port, + rank_offset=rank_offset, + world_size=world_size, + group_name=group_name, + backend=backend, + ) + + # NOTE(charlie): Call the async method on tokenizer_manager directly to avoid event loop + # conflicts. Same underlying implementation: https://github.com/sgl-project/sglang/blob/v0.4.8.post1/python/sglang/srt/model_executor/model_runner.py#L689 + success, message = await self.engine.tokenizer_manager.init_weights_update_group(obj, None) + return success, message + + async def update_named_weight(self, request: NamedWeightUpdateRequest) -> Tuple[bool, str]: + """Update named weights in SGLang engine.""" + extras = request.get("extras") + if extras is not None and "ipc_handles" in extras: + # CUDA IPC -- Here we reuse SGLang's update_weights_from_tensor, but actually load the + # weight from our request data. This will use the update_weight_cuda_ipc defined above. 
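+            # Wire format of the packed request, mirroring the unpacking in
+            # update_weight_cuda_ipc above:
+            #     base64(pickle(request)) + b"__END_OF_REQUEST__" + b"\x00" * padding
+            # The loader slices at the marker, base64-decodes and unpickles the dict,
+            # then reads name/dtype/shape and extras["ipc_handles"] from it.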
+ # This is a bit hacky, but the only way as of now, since there is no other way to + # write per-TP worker code besides using `custom_weight_loader`, unlike in vLLM we can + # use `WorkerWrap`. + + # Serialize the request data + request_data = pickle.dumps(request) + request_data_encoded = base64.b64encode(request_data) + end_marker = b"__END_OF_REQUEST__" + data_with_marker = request_data_encoded + end_marker + + # Create a tensor large enough to hold the serialized data; round up for alignment + data_size = len(data_with_marker) + padded_size = ((data_size + 3) // 4) * 4 + tensor_data = bytearray(data_with_marker) + tensor_data.extend(b"\x00" * (padded_size - data_size)) + tensor_array = torch.frombuffer(tensor_data, dtype=torch.uint8) + + # Use SGLang's API to update weights with custom loader + request_tensor = [("ipc_request", tensor_array)] + obj = UpdateWeightsFromTensorReqInput( + serialized_named_tensors=[ + MultiprocessingSerializer.serialize(request_tensor) for _ in range(self._tp_size) + ], + load_format=CUSTOM_WEIGHT_LOADER_PATH, + flush_cache=False, # TODO(charlie): flush cache on last weight update? + ) + + # Call the underlying async method for the same reason as in `init_weight_update_communicator` + success, message = await self.engine.tokenizer_manager.update_weights_from_tensor(obj, None) + return success, message + else: + # Broadcast + obj = UpdateWeightsFromDistributedReqInput( + name=request["name"], dtype=request["dtype"], shape=request["shape"] + ) + + # Call the underlying async method for the same reason as in `init_weight_update_communicator` + success, message = await self.engine.tokenizer_manager.update_weights_from_distributed(obj, None) + return success, message + + async def wake_up(self, tags: Optional[List[str]] = None): + """Wake up the engine. 
For multi-stage waking up, pass in `"weight"` or `"kv_cache"` to tags.""" + obj = ResumeMemoryOccupationReqInput(tags=tags) + # Call the underlying async method for the same reason as in `init_weight_update_communicator` + await self.engine.tokenizer_manager.resume_memory_occupation(obj, None) + print( + f"From SGLang engine -- Free GPU memory after wake up with tags {tags if tags is not None else 'None'}: " + + f"{torch.cuda.mem_get_info()[0] / 1024**2:.1f} MB" + ) + + async def sleep(self, tags: Optional[List[str]] = None): + """Put engine to sleep.""" + obj = ReleaseMemoryOccupationReqInput(tags=tags) + # Call the underlying async method for the same reason as in `init_weight_update_communicator` + await self.engine.tokenizer_manager.release_memory_occupation(obj, None) + print( + f"From SGLang engine -- Free GPU memory after sleep with tags {tags if tags is not None else 'None'}: " + + f"{torch.cuda.mem_get_info()[0] / 1024**2:.1f} MB" + ) + + async def teardown(self): + """Shutdown the SGLang engine.""" + self.engine.shutdown() + + async def reset_prefix_cache(self): + """Reset prefix cache in SGLang engine.""" + # Call the underlying async method for the same reason as in `init_weight_update_communicator` + return await self.engine.tokenizer_manager.flush_cache() + + +SGLangRayActor = ray.remote(SGLangInferenceEngine) diff --git a/skyrl-train/skyrl_train/trainer.py b/skyrl-train/skyrl_train/trainer.py index 65808e1007..86ff94a220 100644 --- a/skyrl-train/skyrl_train/trainer.py +++ b/skyrl-train/skyrl_train/trainer.py @@ -162,6 +162,8 @@ async def eval(self) -> Dict[str, float]: # Extract data_sources from env_extras concat_data_sources = [env_extra.get("data_source") for env_extra in concat_env_extras] + vis = self.tokenizer.decode(generator_output["response_ids"][0]) + print("Eval output example: ", vis) # 2. Group data by data source and calculate per-dataset metrics eval_metrics = calculate_per_dataset_metrics( @@ -208,8 +210,14 @@ def train(self): self.weights_manager = InferenceWeightsManager( self.policy_model, self.inference_engine_client, self.cfg.trainer.placement.colocate_all ) + # NOTE(Charlie): sglang's engine needs to sync weights after wake up. see https://github.com/sgl-project/sglang/issues/7939 + # Change it to True after sglang fixes the issue. 
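+        # (i.e., once the upstream fix lands, `sync_before_eval` can be dropped and the
+        # eval weights manager can go back to `no_sync=True` unconditionally.)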
+ sync_before_eval = self.cfg.trainer.placement.colocate_all and self.cfg.generator.backend == "sglang" self.eval_weights_manager = InferenceWeightsManager( - self.policy_model, self.inference_engine_client, self.cfg.trainer.placement.colocate_all, no_sync=True + self.policy_model, + self.inference_engine_client, + self.cfg.trainer.placement.colocate_all, + no_sync=not sync_before_eval, ) # Load checkpoint state if resumption is enabled diff --git a/skyrl-train/skyrl_train/utils/utils.py b/skyrl-train/skyrl_train/utils/utils.py index d2e7c54333..9310ae1070 100644 --- a/skyrl-train/skyrl_train/utils/utils.py +++ b/skyrl-train/skyrl_train/utils/utils.py @@ -130,12 +130,10 @@ def validate_cfg(cfg: DictConfig): cfg.generator.remote_inference_engine_urls ), "num_inference_engines should be equal to the number of remote_inference_engine_urls" - if not cfg.generator.async_engine: + if not cfg.generator.async_engine and cfg.generator.backend == "vllm": assert ( cfg.generator.batched - ), "if we are using the offline engine, we need to put generator in batched mode for faster generation" - if cfg.generator.backend == "sglang" and cfg.generator.run_engines_locally: - raise ValueError("SGLang backend currently does not support local engines") + ), "if we are using the offline vLLM engine, we need to put generator in batched mode for faster generation" assert ( cfg.trainer.sequence_parallel_backend == "ulysses" @@ -208,6 +206,13 @@ def validate_cfg(cfg: DictConfig): algorithm_config.max_seq_len = cfg.generator.max_input_length + cfg.generator.sampling_params.max_generate_length cfg.trainer.algorithm = algorithm_config + # TODO: fix once we support these features with SGLang + if cfg.generator.backend == "sglang" and cfg.generator.run_engines_locally: + assert cfg.generator.inference_engine_tensor_parallel_size == 1, ( + "As of now, We do not support tensor parallel inference engine with SGLang when running engines locally. " + "Please set `inference_engine_tensor_parallel_size` to 1." + ) + if cfg.trainer.strategy == "deepspeed" and not ( cfg.trainer.policy.optimizer_config.offload_after_step and cfg.trainer.critic.optimizer_config.offload_after_step @@ -261,6 +266,13 @@ def get_physical_gpu_id(): def initialize_ray(cfg: DictConfig): # TODO(sumanthrh): introduce a debug mode and add debugging flags like `CUDA_LAUNCH_BLOCKING` here env_vars = {} + + # NOTE (charlie): See https://github.com/vllm-project/vllm/blob/c6b0a7d3ba03ca414be1174e9bd86a97191b7090/vllm/worker/worker_base.py#L445 + # and https://docs.vllm.ai/en/v0.9.2/usage/troubleshooting.html?h=nccl_cumem_enable#known-issues + # Same for SGLang as we set `NCCL_CUMEM_ENABLE` to 0 in `sglang_engine.py`'s _patched_set_envs_and_config + if cfg.generator.weight_sync_backend == "nccl": + env_vars["NCCL_CUMEM_ENABLE"] = "0" + if cfg.generator.backend == "vllm": # NOTE (sumanthrh): In vllm >= 0.9.0, we need to explicitly allow for serialization via pickle for collective RPCs. # During weight transfer, we use IPC handles, which contains a `function` object and requires pickling. 
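Taken together, the validation and environment changes in this file mean a local-engine SGLang run is configured with tensor parallel size 1 and, when weights sync over NCCL, gets NCCL_CUMEM_ENABLE=0 automatically. A minimal config sketch, composing the repo's ppo_base_config the same way the tests below do:

import hydra
from skyrl_train.entrypoints.main_base import config_dir

with hydra.initialize_config_dir(config_dir=config_dir):
    cfg = hydra.compose(config_name="ppo_base_config")

cfg.generator.backend = "sglang"
cfg.generator.run_engines_locally = True
cfg.generator.inference_engine_tensor_parallel_size = 1  # required by validate_cfg for now
cfg.generator.weight_sync_backend = "nccl"  # initialize_ray(cfg) will then set NCCL_CUMEM_ENABLE=0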
@@ -279,11 +291,6 @@ def initialize_ray(cfg: DictConfig): env_vars["VLLM_USE_V1"] = "1" env_vars["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0" - # NOTE (charlie): See https://github.com/vllm-project/vllm/blob/c6b0a7d3ba03ca414be1174e9bd86a97191b7090/vllm/worker/worker_base.py#L445 - # and https://docs.vllm.ai/en/v0.9.2/usage/troubleshooting.html?h=nccl_cumem_enable#known-issues - if cfg.generator.weight_sync_backend == "nccl": - env_vars["NCCL_CUMEM_ENABLE"] = "0" - max_num_gpus_per_node = max( [ cfg.trainer.placement.policy_num_gpus_per_node, diff --git a/skyrl-train/tests/gpu/test_engine_generation.py b/skyrl-train/tests/gpu/test_engine_generation.py index 930b506792..5c7f4d5385 100644 --- a/skyrl-train/tests/gpu/test_engine_generation.py +++ b/skyrl-train/tests/gpu/test_engine_generation.py @@ -1,4 +1,14 @@ +""" +# Run only vllm tests (requires vllm extra): +uv run --isolated --extra dev --extra vllm pytest tests/gpu/test_engine_generation.py -m "vllm" + +# Run only sglang tests (requires sglang extra): +uv run --isolated --extra dev --extra sglang pytest tests/gpu/test_engine_generation.py -m "sglang" +""" + +import pytest import ray +import hydra from skyrl_train.inference_engines.remote_inference_engine import create_remote_inference_engines from skyrl_train.inference_engines.ray_wrapped_inference_engine import create_ray_wrapped_inference_engines from skyrl_train.inference_engines.inference_engine_client import InferenceEngineClient @@ -11,12 +21,23 @@ from omegaconf import DictConfig from skyrl_train.inference_engines.base import InferenceEngineInput from skyrl_train.utils import initialize_ray +from skyrl_train.entrypoints.main_base import config_dir +from typing import Tuple + +MODEL = "Qwen/Qwen2.5-1.5B-Instruct" + -model = "Qwen/Qwen2.5-1.5B-Instruct" -tp_size = 2 +def get_test_actor_config() -> DictConfig: + """Get base config with test-specific overrides.""" + with hydra.initialize_config_dir(config_dir=config_dir): + cfg = hydra.compose(config_name="ppo_base_config") + cfg.trainer.policy.model.path = MODEL -def init_remote_vinference_engines(tp_size): + return cfg + + +def init_remote_inference_servers(tp_size: int, backend: str) -> Tuple[InferenceEngineClient, subprocess.Popen]: available_gpus = get_available_gpus() assert ( len(available_gpus) >= tp_size @@ -38,60 +59,89 @@ def get_free_port(): engine_port = get_free_port() # Launch vLLM server using subprocess - vllm_cmd = [ - "uv", - "run", - "--isolated", - "--extra", - "vllm", - "-m", - "skyrl_train.inference_engines.vllm.vllm_server", - "--model", - model, - "--enforce-eager", - "--tensor-parallel-size", - str(tp_size), - "--distributed-executor-backend", - "ray", - "--dtype", - "bfloat16", - "--host", - "127.0.0.1", - "--port", - str(engine_port), - "--worker-extension-cls", - "skyrl_train.inference_engines.vllm.vllm_engine.WorkerWrap", - ] + if backend == "vllm": + remote_server_command = [ + "uv", + "run", + "--isolated", + "--extra", + "vllm", + "-m", + "skyrl_train.inference_engines.vllm.vllm_server", + "--model", + MODEL, + "--enforce-eager", + "--tensor-parallel-size", + str(tp_size), + "--distributed-executor-backend", + "ray", + "--dtype", + "bfloat16", + "--host", + "127.0.0.1", + "--port", + str(engine_port), + "--worker-extension-cls", + "skyrl_train.inference_engines.vllm.vllm_engine.WorkerWrap", + ] + elif backend == "sglang": + remote_server_command = [ + "uv", + "run", + "--isolated", + "--extra", + "sglang", + "-m", + "skyrl_train.inference_engines.sglang.sglang_server", + "--model-path", + MODEL, + 
"--tp-size", + str(tp_size), + "--dtype", + "bfloat16", + "--host", + "127.0.0.1", + "--port", + str(engine_port), + "--mm-attention-backend", + "fa3", + "--attention-backend", + "fa3", + ] + else: + raise ValueError(f"Unsupported backend: {backend}") # Set CUDA_VISIBLE_DEVICES environment variable for the subprocess env = os.environ.copy() env["CUDA_VISIBLE_DEVICES"] = gpu_ids_str # Start the vLLM server process - vllm_process = subprocess.Popen(vllm_cmd, env=env) + server_process = subprocess.Popen(remote_server_command, env=env) wait_for_server(url=f"localhost:{engine_port}", health_path="health") print(f"Server at localhost:{engine_port} is online") engines = create_remote_inference_engines( urls=[f"localhost:{engine_port}"], - model_name=model, - engine_backend="vllm", + model_name=MODEL, + engine_backend=backend, tensor_parallel_size=tp_size, sampling_params=get_sampling_params_for_backend( - "vllm", DictConfig({"temperature": 0.0, "top_p": 1, "top_k": -1, "max_generate_length": 1024, "min_p": 0.0}) + backend, + DictConfig({"temperature": 0.0, "top_p": 1, "top_k": -1, "max_generate_length": 1024, "min_p": 0.0}), ), ) - return InferenceEngineClient(engines), vllm_process + return InferenceEngineClient(engines), server_process -def init_ray_vllm_engines(): +def init_ray_inference_engines(backend: str, tp_size: int) -> InferenceEngineClient: + """Initialize ray-wrapped inference engines for the specified backend""" engine = create_ray_wrapped_inference_engines( num_inference_engines=1, tensor_parallel_size=tp_size, model_dtype="bfloat16", - pretrain=model, + pretrain=MODEL, seed=42, vllm_v1_disable_multiproc=True, enable_prefix_caching=True, @@ -99,14 +149,16 @@ def init_ray_vllm_engines(): max_model_len=1536, shared_pg=None, gpu_memory_utilization=0.8, - vllm_enable_sleep=False, + inference_engine_enable_sleep=False, async_engine=True, max_num_batched_tokens=8192, max_num_seqs=1024, sampling_params=get_sampling_params_for_backend( - "vllm", DictConfig({"temperature": 0.0, "top_p": 1, "top_k": -1, "max_generate_length": 1024, "min_p": 0.0}) + backend, + DictConfig({"temperature": 0.0, "top_p": 1, "top_k": -1, "max_generate_length": 1024, "min_p": 0.0}), ), - tokenizer=AutoTokenizer.from_pretrained(model), + tokenizer=AutoTokenizer.from_pretrained(MODEL), + backend=backend, ) client = InferenceEngineClient(engine) return client @@ -160,111 +212,139 @@ async def run_single_generation_with_tokens(client, prompt_token_ids): return responses, finish_reasons -# TODO(tgriggs): Replicate for sglang -def test_inference_engines_generation(): +@pytest.mark.parametrize( + "backend,tp_size", + [ + pytest.param("vllm", 2, marks=pytest.mark.vllm), + # TODO(Charlie): add TP > 1 tests for sglang when we support it + pytest.param("sglang", 1, marks=pytest.mark.sglang), + ], + ids=["vllm", "sglang"], +) +def test_inference_engines_generation(backend: str, tp_size: int): """ - Tests generation with a vllm remote engine. + Tests generation with both remote and ray-wrapped engines for the specified backend. """ - initialize_ray(DictConfig({"generator": {"backend": "vllm"}})) - - prompts = get_test_prompts(model) - - # Get responses from remote vllm engine. - llm_client, vllm_process = init_remote_vinference_engines(tp_size) try: - # Batched generation. 
- remote_batch_responses, batch_finish_reasons = asyncio.run(run_batch_generation(llm_client, prompts)) - assert len(remote_batch_responses) == len( + cfg = get_test_actor_config() + cfg.generator.backend = backend + initialize_ray(cfg) + + prompts = get_test_prompts(MODEL) + + try: + llm_client, remote_server_process = init_remote_inference_servers(tp_size, backend) + + # Batched generation + remote_batch_responses, batch_finish_reasons = asyncio.run(run_batch_generation(llm_client, prompts)) + assert len(remote_batch_responses) == len( + prompts + ), f"Number of responses should match number of prompts, got {len(remote_batch_responses)} responses but {len(prompts)} prompts" + assert len(batch_finish_reasons) == len( + prompts + ), f"Number of finish reasons should match number of prompts, got {len(batch_finish_reasons)} finish reasons but {len(prompts)} prompts" + + # Single generation (ie, submit individual requests) + remote_single_responses, single_finish_reasons = asyncio.run(run_single_generation(llm_client, prompts)) + assert len(remote_single_responses) == len( + prompts + ), f"Number of responses should match number of prompts, got {len(remote_single_responses)} responses but {len(prompts)} prompts" + assert len(single_finish_reasons) == len( + prompts + ), f"Number of finish reasons should match number of prompts, got {len(single_finish_reasons)} finish reasons but {len(prompts)} prompts" + + # Ensure batched and single generation outputs are (roughly) the same + for i in range(len(prompts)): + if not are_responses_similar(remote_batch_responses[i], remote_single_responses[i], tolerance=0.01): + print( + f"Remote batch and single generation responses are not similar, got batch={remote_batch_responses[i]} and single={remote_single_responses[i]}" + ) + + finally: + remote_server_process.terminate() + remote_server_process.wait() + + # Get responses from Ray engine + llm_client = init_ray_inference_engines(backend, tp_size) + + # Batched generation + local_batch_responses, batch_finish_reasons = asyncio.run(run_batch_generation(llm_client, prompts)) + assert len(local_batch_responses) == len( prompts - ), f"Number of responses should match number of prompts, got {len(remote_batch_responses)} responses but {len(prompts)} prompts" + ), f"Number of responses should match number of prompts, got {len(local_batch_responses)} responses but {len(prompts)} prompts" assert len(batch_finish_reasons) == len( prompts ), f"Number of finish reasons should match number of prompts, got {len(batch_finish_reasons)} finish reasons but {len(prompts)} prompts" - # Single generation (ie, submit individual requests). - remote_single_responses, single_finish_reasons = asyncio.run(run_single_generation(llm_client, prompts)) - assert len(remote_single_responses) == len( + # Single generation (ie, submit individual requests) + local_single_responses, single_finish_reasons = asyncio.run(run_single_generation(llm_client, prompts)) + assert len(local_single_responses) == len( prompts - ), f"Number of responses should match number of prompts, got {len(remote_single_responses)} responses but {len(prompts)} prompts" + ), f"Number of responses should match number of prompts, got {len(local_single_responses)} responses but {len(prompts)} prompts" assert len(single_finish_reasons) == len( prompts ), f"Number of finish reasons should match number of prompts, got {len(single_finish_reasons)} finish reasons but {len(prompts)} prompts" - # Ensure batched and single generation outputs are (roughly) the same. 
+ # Ensure batched and single generation outputs are (roughly) the same + for i in range(len(prompts)): + if not are_responses_similar(local_batch_responses[i], local_single_responses[i], tolerance=0.01): + print( + f"Local batch and single generation responses are not similar, got batch={local_batch_responses[i]} and single={local_single_responses[i]}" + ) + + # Finally, ensure that remote and local outputs are (roughly) the same + for i in range(len(prompts)): + if not are_responses_similar(remote_batch_responses[i], local_batch_responses[i], tolerance=0.01): + print( + f"Remote and local batch generation responses are not similar, got remote={remote_batch_responses[i]} and local={local_batch_responses[i]}" + ) + + finally: + ray.shutdown() + + +@pytest.mark.parametrize( + "backend,tp_size", + [ + pytest.param("vllm", 2, marks=pytest.mark.vllm), + # TODO(Charlie): add TP > 1 tests for sglang when we support it + pytest.param("sglang", 1, marks=pytest.mark.sglang), + ], + ids=["vllm", "sglang"], +) +def test_token_based_generation(backend: str, tp_size: int): + """Test generation using prompt_token_ids for the specified backend.""" + + try: + cfg = get_test_actor_config() + cfg.generator.backend = backend + initialize_ray(cfg) + + prompts = get_test_prompts(MODEL, 3) + tokenizer = AutoTokenizer.from_pretrained(MODEL) + prompt_token_ids = tokenizer.apply_chat_template( + prompts, add_generation_prompt=True, tokenize=True, return_dict=True + )["input_ids"] + + llm_client = init_ray_inference_engines(backend, tp_size) + + # Test batch generation with tokens + token_batch_responses, _ = asyncio.run(run_batch_generation_with_tokens(llm_client, prompt_token_ids)) + assert len(token_batch_responses) == len(prompts) + + # Test single generation with tokens + token_single_responses, _ = asyncio.run(run_single_generation_with_tokens(llm_client, prompt_token_ids)) + assert len(token_single_responses) == len(prompts) + + # Compare with prompt-based generation + prompt_responses, _ = asyncio.run(run_batch_generation(llm_client, prompts)) + + # Outputs should be similar since we're using the same inputs for i in range(len(prompts)): - if not are_responses_similar(remote_batch_responses[i], remote_single_responses[i], tolerance=0.01): + if not are_responses_similar([token_batch_responses[i]], [prompt_responses[i]], tolerance=0.01): print( - f"Remote batch and single generation responses are not similar, got batch={remote_batch_responses[i]} and single={remote_single_responses[i]}" + f"Token and prompt responses differ: token={token_batch_responses[i]}, prompt={prompt_responses[i]}" ) + finally: - # Shut down the vllm server - vllm_process.terminate() - vllm_process.wait() - - # Get responses from Ray vllm engine. - llm_client = init_ray_vllm_engines() - # Batched generation. - local_batch_responses, batch_finish_reasons = asyncio.run(run_batch_generation(llm_client, prompts)) - assert len(local_batch_responses) == len( - prompts - ), f"Number of responses should match number of prompts, got {len(local_batch_responses)} responses but {len(prompts)} prompts" - assert len(batch_finish_reasons) == len( - prompts - ), f"Number of finish reasons should match number of prompts, got {len(batch_finish_reasons)} finish reasons but {len(prompts)} prompts" - - # Single generation (ie, submit individual requests). 
- local_single_responses, single_finish_reasons = asyncio.run(run_single_generation(llm_client, prompts)) - assert len(local_single_responses) == len( - prompts - ), f"Number of responses should match number of prompts, got {len(local_single_responses)} responses but {len(prompts)} prompts" - assert len(single_finish_reasons) == len( - prompts - ), f"Number of finish reasons should match number of prompts, got {len(single_finish_reasons)} finish reasons but {len(prompts)} prompts" - - # Ensure batched and single generation outputs are (roughly) the same. - for i in range(len(prompts)): - if not are_responses_similar(local_batch_responses[i], local_single_responses[i], tolerance=0.01): - print( - f"Local batch and single generation responses are not similar, got batch={local_batch_responses[i]} and single={local_single_responses[i]}" - ) - - # Finally, ensure that remote and local outputs are (roughly) the same. - for i in range(len(prompts)): - if not are_responses_similar(remote_batch_responses[i], local_batch_responses[i], tolerance=0.01): - print( - f"Remote and local batch generation responses are not similar, got remote={remote_batch_responses[i]} and local={local_batch_responses[i]}" - ) - - ray.shutdown() - - -def test_token_based_generation(): - """Test generation using prompt_token_ids.""" - - initialize_ray(DictConfig({"generator": {"backend": "vllm"}})) - - prompts = get_test_prompts(model, 3) - tokenizer = AutoTokenizer.from_pretrained(model) - prompt_token_ids = tokenizer.apply_chat_template( - prompts, add_generation_prompt=True, tokenize=True, return_dict=True - )["input_ids"] - - llm_client = init_ray_vllm_engines() - - # Test batch generation with tokens - token_batch_responses, _ = asyncio.run(run_batch_generation_with_tokens(llm_client, prompt_token_ids)) - assert len(token_batch_responses) == len(prompts) - - # Test single generation with tokens - token_single_responses, _ = asyncio.run(run_single_generation_with_tokens(llm_client, prompt_token_ids)) - assert len(token_single_responses) == len(prompts) - - # Compare with prompt-based generation - prompt_responses, _ = asyncio.run(run_batch_generation(llm_client, prompts)) - - # Outputs should be similar since we're using the same inputs - for i in range(len(prompts)): - if not are_responses_similar([token_batch_responses[i]], [prompt_responses[i]], tolerance=0.01): - print(f"Token and prompt responses differ: token={token_batch_responses[i]}, prompt={prompt_responses[i]}") - - ray.shutdown() + ray.shutdown() diff --git a/skyrl-train/tests/gpu/test_policy_vllm_e2e.py b/skyrl-train/tests/gpu/test_policy_local_engines_e2e.py similarity index 51% rename from skyrl-train/tests/gpu/test_policy_vllm_e2e.py rename to skyrl-train/tests/gpu/test_policy_local_engines_e2e.py index e89d4e656c..e6997e2e6c 100644 --- a/skyrl-train/tests/gpu/test_policy_vllm_e2e.py +++ b/skyrl-train/tests/gpu/test_policy_local_engines_e2e.py @@ -1,5 +1,9 @@ """ -uv run --isolated --extra dev --extra vllm pytest tests/gpu/test_policy_vllm_e2e.py +# Run only vllm tests (requires vllm extra): +uv run --isolated --extra dev --extra vllm --extra deepspeed pytest tests/gpu/test_policy_local_engines_e2e.py -m "vllm" + +# Run only sglang tests (requires sglang extra): +uv run --isolated --extra dev --extra sglang --extra deepspeed pytest tests/gpu/test_policy_local_engines_e2e.py -m "sglang" """ import pytest @@ -17,9 +21,9 @@ from skyrl_train.inference_engines.utils import get_sampling_params_for_backend from skyrl_train.inference_engines.base import 
InferenceEngineInput from skyrl_train.entrypoints.main_base import config_dir +from skyrl_train.utils import initialize_ray -model = "Qwen/Qwen2.5-1.5B-Instruct" -tp_size = 2 +MODEL = "Qwen/Qwen2.5-0.5B-Instruct" def get_test_actor_config() -> DictConfig: @@ -28,37 +32,25 @@ def get_test_actor_config() -> DictConfig: cfg = hydra.compose(config_name="ppo_base_config") # Override specific parameters - cfg.trainer.policy.model.path = model + cfg.trainer.policy.model.path = MODEL cfg.trainer.critic.model.path = "" cfg.trainer.placement.policy_num_gpus_per_node = 2 cfg.generator.async_engine = True cfg.generator.num_inference_engines = 1 - cfg.generator.inference_engine_tensor_parallel_size = tp_size cfg.generator.run_engines_locally = True return cfg -async def run_vllm_inference(client, prompts): +async def run_inference(client, prompts): engine_input = InferenceEngineInput(prompts=prompts) - await client.generate(engine_input) + return await client.generate(engine_input) -def init_inference_engines(cfg, v1, use_local, async_engine, tp_size, colocate_all): +def init_inference_engines(cfg, use_local, async_engine, tp_size, colocate_all, backend): assert use_local, "This test does not yet support remote engines." - ray.init( - ignore_reinit_error=True, - runtime_env={ - "env_vars": { - "NCCL_CUMEM_ENABLE": "0", - "NCCL_P2P_DISABLE": "0", - "CUDA_LAUNCH_BLOCKING": "1", - "VLLM_USE_V1": "1" if v1 else "0", - "VLLM_ENABLE_V1_MULTIPROCESSING": "0", - "PYTORCH_NVML_BASED_CUDA_CHECK": "1", - } - }, - ) + assert backend in ["vllm", "sglang"] + initialize_ray(cfg) if colocate_all: pg = placement_group([{"GPU": 1, "CPU": 1}] * tp_size, strategy="PACK") get_ray_pg_ready_with_timeout(pg, timeout=30) @@ -70,20 +62,21 @@ def init_inference_engines(cfg, v1, use_local, async_engine, tp_size, colocate_a num_inference_engines=1, tensor_parallel_size=tp_size, model_dtype="bfloat16", - pretrain=model, + pretrain=MODEL, seed=42, vllm_v1_disable_multiproc=True, enable_prefix_caching=True, enforce_eager=True, max_model_len=1536, shared_pg=pg, - gpu_memory_utilization=0.8, - vllm_enable_sleep=sleep, + gpu_memory_utilization=0.6, + inference_engine_enable_sleep=sleep, async_engine=async_engine, max_num_batched_tokens=8192, max_num_seqs=1024, - sampling_params=get_sampling_params_for_backend("vllm", cfg.generator.sampling_params), - tokenizer=AutoTokenizer.from_pretrained(model), + sampling_params=get_sampling_params_for_backend(backend, cfg.generator.sampling_params), + tokenizer=AutoTokenizer.from_pretrained(MODEL), + backend=backend, ) client = InferenceEngineClient(eps) if sleep: @@ -92,29 +85,42 @@ def init_inference_engines(cfg, v1, use_local, async_engine, tp_size, colocate_a @pytest.mark.parametrize( - ("colocate_all", "weight_sync_backend", "strategy"), + ("colocate_all", "weight_sync_backend", "strategy", "backend", "tp_size"), [ - (False, "nccl", "fsdp"), - (True, "nccl", "fsdp"), - (False, "gloo", "fsdp"), - (True, "gloo", "fsdp"), - (False, "nccl", "deepspeed"), - (True, "nccl", "deepspeed"), - (False, "nccl", "fsdp2"), - (True, "nccl", "fsdp2"), + pytest.param(False, "nccl", "fsdp", "vllm", 2, marks=pytest.mark.vllm), + pytest.param(True, "nccl", "fsdp", "vllm", 2, marks=pytest.mark.vllm), + pytest.param(False, "gloo", "fsdp", "vllm", 2, marks=pytest.mark.vllm), + pytest.param(True, "gloo", "fsdp", "vllm", 2, marks=pytest.mark.vllm), + pytest.param(False, "nccl", "deepspeed", "vllm", 2, marks=pytest.mark.vllm), + pytest.param(True, "nccl", "deepspeed", "vllm", 2, marks=pytest.mark.vllm), + 
pytest.param(False, "nccl", "fsdp2", "vllm", 2, marks=pytest.mark.vllm), + pytest.param(True, "nccl", "fsdp2", "vllm", 2, marks=pytest.mark.vllm), + # TODO(Charlie): add TP > 1 tests for sglang when we support it + pytest.param(False, "nccl", "deepspeed", "sglang", 1, marks=pytest.mark.sglang), + pytest.param(True, "nccl", "deepspeed", "sglang", 1, marks=pytest.mark.sglang), + pytest.param(False, "nccl", "fsdp2", "sglang", 1, marks=pytest.mark.sglang), + pytest.param(True, "nccl", "fsdp2", "sglang", 1, marks=pytest.mark.sglang), + pytest.param(False, "gloo", "fsdp", "sglang", 1, marks=pytest.mark.sglang), + pytest.param(True, "gloo", "fsdp", "sglang", 1, marks=pytest.mark.sglang), ], ids=[ - "no_colocate_nccl_fsdp", - "colocate_nccl_fsdp", - "no_colocate_gloo_fsdp", - "colocate_gloo_fsdp", - "no_colocate_nccl_deepspeed", - "colocate_nccl_deepspeed", - "no_colocate_nccl_fsdp2", - "colocate_nccl_fsdp2", + "no_colocate_nccl_fsdp_vllm", + "colocate_nccl_fsdp_vllm", + "no_colocate_gloo_fsdp_vllm", + "colocate_gloo_fsdp_vllm", + "no_colocate_nccl_deepspeed_vllm", + "colocate_nccl_deepspeed_vllm", + "no_colocate_nccl_fsdp2_vllm", + "colocate_nccl_fsdp2_vllm", + "no_colocate_nccl_deepspeed_sglang", + "colocate_nccl_deepspeed_sglang", + "no_colocate_nccl_fsdp2_sglang", + "colocate_nccl_fsdp2_sglang", + "no_colocate_gloo_fsdp_sglang", + "colocate_gloo_fsdp_sglang", ], ) -def test_policy_vllm_e2e(colocate_all, weight_sync_backend, strategy): +def test_policy_local_engines_e2e(colocate_all, weight_sync_backend, strategy, backend, tp_size): """ Tests initalizing the policy actor group and inference engine, syncing weights, and performing generation. """ @@ -123,14 +129,17 @@ def test_policy_vllm_e2e(colocate_all, weight_sync_backend, strategy): cfg.trainer.placement.colocate_all = colocate_all cfg.generator.weight_sync_backend = weight_sync_backend cfg.trainer.strategy = strategy + cfg.generator.backend = backend + cfg.generator.inference_engine_tensor_parallel_size = tp_size + # If colocate is True, this will load the engine, sleep, and wake up the engine client, pg = init_inference_engines( cfg=cfg, - v1=True, use_local=True, async_engine=cfg.generator.async_engine, tp_size=cfg.generator.inference_engine_tensor_parallel_size, colocate_all=cfg.trainer.placement.colocate_all, + backend=backend, ) policy = init_worker_with_type( @@ -143,6 +152,8 @@ def test_policy_vllm_e2e(colocate_all, weight_sync_backend, strategy): ray.get(policy.async_run_ray_method("pass_through", "init_weight_sync_state", client)) asyncio.run(client.reset_prefix_cache()) ray.get(policy.async_run_ray_method("pass_through", "broadcast_to_inference_engines", client)) - asyncio.run(run_vllm_inference(client, get_test_prompts(model))) + outputs = asyncio.run(run_inference(client, get_test_prompts(MODEL))) + + print(f"Example output: {outputs['responses'][0]}, {outputs['stop_reasons'][0]}") finally: ray.shutdown() diff --git a/skyrl-train/tests/gpu/test_skyrl_gym_generator.py b/skyrl-train/tests/gpu/test_skyrl_gym_generator.py index 7943ab7335..6ee0005595 100644 --- a/skyrl-train/tests/gpu/test_skyrl_gym_generator.py +++ b/skyrl-train/tests/gpu/test_skyrl_gym_generator.py @@ -17,6 +17,17 @@ from skyrl_gym.envs import register from skyrl_gym.envs.base_text_env import BaseTextEnv, BaseTextEnvStepOutput from typing import Any, Dict +import hydra +from skyrl_train.entrypoints.main_base import config_dir + + +def get_test_actor_config() -> DictConfig: + """Get base config with test-specific overrides.""" + with 
hydra.initialize_config_dir(config_dir=config_dir): + cfg = hydra.compose(config_name="ppo_base_config") + cfg.generator.backend = "vllm" + + return cfg # Setup for formatting tests @@ -86,7 +97,7 @@ async def run_generator_end_to_end( max_model_len=max_input_length + max_generate_length, shared_pg=None, gpu_memory_utilization=0.8, - vllm_enable_sleep=True, + inference_engine_enable_sleep=True, async_engine=use_async_engine, max_num_batched_tokens=8192, max_num_seqs=1024, @@ -117,6 +128,7 @@ async def run_generator_end_to_end( "max_turns": max_turns, "zero_reward_on_non_stop": False, "use_conversation_multi_turn": use_conversation_multi_turn, + "apply_overlong_filtering": False, } ) @@ -204,7 +216,7 @@ async def test_generator_single_turn_gsm8k( """ Test the generator with a single turn of GSM8K """ - initialize_ray(DictConfig({"generator": {"backend": "vllm"}})) + initialize_ray(get_test_actor_config()) try: await run_generator_end_to_end( use_async_engine=use_async_engine, @@ -222,7 +234,7 @@ async def test_generator_multi_turn_text2sql(): """ Test the generator with multiple turns of text2sql """ - initialize_ray(DictConfig({"generator": {"backend": "vllm"}})) + initialize_ray(get_test_actor_config()) try: await run_generator_end_to_end( use_async_engine=True, @@ -249,7 +261,7 @@ async def test_generator_multi_turn_search(): """ Test the generator with multiple turns of search """ - initialize_ray(DictConfig({"generator": {"backend": "vllm"}})) + initialize_ray(get_test_actor_config()) try: await run_generator_end_to_end( use_async_engine=True, @@ -280,7 +292,7 @@ async def test_generator_formatting_use_conversation_multi_turn(model_name): """ Test generator formatting when using conversation formatting for multi-turn """ - initialize_ray(DictConfig({"generator": {"backend": "vllm"}})) + initialize_ray(get_test_actor_config()) try: tokenizer = AutoTokenizer.from_pretrained(model_name) generator_output = await run_generator_end_to_end( @@ -338,7 +350,7 @@ async def test_generator_formatting_no_use_conversation_multi_turn(model_name): """ Test generator formatting when not using conversation formatting for multi-turn """ - initialize_ray(DictConfig({"generator": {"backend": "vllm"}})) + initialize_ray(get_test_actor_config()) try: tokenizer = AutoTokenizer.from_pretrained(model_name) generator_output = await run_generator_end_to_end( diff --git a/skyrl-train/tests/sglang/__init__.py b/skyrl-train/tests/sglang/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/skyrl-train/tests/sglang/test_policy_sglang_e2e.py b/skyrl-train/tests/sglang/test_policy_sglang_e2e.py deleted file mode 100644 index e3adb05715..0000000000 --- a/skyrl-train/tests/sglang/test_policy_sglang_e2e.py +++ /dev/null @@ -1,172 +0,0 @@ -""" -uv run --isolated --extra dev --extra sglang pytest tests/gpu/test_policy_sglang_e2e.py -""" - -import pytest -import asyncio -import ray -import hydra -from omegaconf import DictConfig - -from tests.gpu.utils import init_worker_with_type, get_test_prompts -from skyrl_train.inference_engines.remote_inference_engine import create_remote_inference_engines -from skyrl_train.inference_engines.inference_engine_client import InferenceEngineClient -from ray.util.placement_group import placement_group -from skyrl_train.inference_engines.sglang.sglang_server import SGLangServer -from sglang.srt.server_args import ServerArgs -from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy -from skyrl_train.inference_engines.utils import 
get_sampling_params_for_backend -from skyrl_train.inference_engines.base import InferenceEngineInput -from skyrl_train.utils import get_ray_pg_ready_with_timeout, initialize_ray -from skyrl_train.entrypoints.main_base import config_dir - -model = "Qwen/Qwen2.5-1.5B-Instruct" -tp_size = 2 - - -def get_test_actor_config() -> DictConfig: - """Get base config with test-specific overrides.""" - with hydra.initialize_config_dir(config_dir=config_dir): - cfg = hydra.compose(config_name="ppo_base_config") - - # Override specific parameters for test - cfg.trainer.policy.model.path = model - cfg.trainer.critic.model.path = "" - cfg.trainer.placement.policy_num_gpus_per_node = 2 - - cfg.trainer.placement.colocate_all = False - cfg.trainer.train_batch_size = 128 - cfg.generator.backend = "sglang" - cfg.generator.async_engine = True - cfg.generator.num_inference_engines = 1 - cfg.generator.inference_engine_tensor_parallel_size = tp_size - cfg.generator.run_engines_locally = False - - return cfg - - -async def run_generation(client, prompts): - engine_input = InferenceEngineInput(prompts=prompts) - await client.generate(engine_input) - - -def init_sglang_engines(use_local, tp_size, colocate_all, sampling_params): - assert not use_local, "SGLang currently does not support local engines" - assert not colocate_all, "SGLang currently does not support colocation" - - initialize_ray(DictConfig({"generator": {"backend": "sglang"}})) - - if colocate_all: - pg = placement_group([{"GPU": 1, "CPU": 1}] * tp_size, strategy="PACK") - get_ray_pg_ready_with_timeout(pg, timeout=30) - sleep = True - else: - pg, sleep = None, False - - def get_free_port(): - import socket - - s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - s.bind(("", 0)) - port = s.getsockname()[1] - s.close() - return port - - engine_port = get_free_port() - - sglang_pg = placement_group([{"GPU": tp_size, "CPU": tp_size}], strategy="PACK") - get_ray_pg_ready_with_timeout(sglang_pg, timeout=30) - - scheduling_strategy = PlacementGroupSchedulingStrategy( - placement_group=sglang_pg, - placement_group_capture_child_tasks=True, - placement_group_bundle_index=0, - ) - - SGLangServerRayActor = ray.remote(SGLangServer) - server_actor = SGLangServerRayActor.options( - num_gpus=tp_size, - num_cpus=tp_size, - scheduling_strategy=scheduling_strategy, - ).remote( - ServerArgs( - model_path=model, - tp_size=tp_size, - dtype="bfloat16", - mem_fraction_static=0.7, - enable_memory_saver=True, - base_gpu_id=0, - gpu_id_step=1, - port=engine_port, - ) - ) - server_actor.run_server.remote() - - # Wait for server to come online - import requests - import time - - def wait_for_server(url: str, timeout: int = 60, interval: float = 1.0): - start_time = time.time() - while True: - try: - response = requests.get(f"http://{url}/health_generate") - if response.ok: - return - except requests.exceptions.ConnectionError: - if time.time() - start_time > timeout: - raise TimeoutError(f"Server at {url} did not come online within {timeout} seconds") - time.sleep(interval) - - wait_for_server(f"localhost:{engine_port}") - print(f"Server at localhost:{engine_port} is online") - - engines = create_remote_inference_engines( - urls=[f"localhost:{engine_port}"], - model_name=model, - engine_backend="sglang", - tensor_parallel_size=tp_size, - sampling_params=sampling_params, - ) - client = InferenceEngineClient(engines) - if sleep: - asyncio.run(client.wake_up()) - - return client, pg, server_actor - - -@pytest.mark.parametrize( - ("weight_sync_backend"), - [ - ("nccl"), - ("gloo"), - 
], - ids=[ - "nccl", - "gloo", - ], -) -def test_policy_sglang_e2e(weight_sync_backend): - """ - Tests initalizing the policy actor group and InferenceEngines, syncing weights, and performing generation. - """ - cfg = get_test_actor_config() - cfg.generator.weight_sync_backend = weight_sync_backend - - llm_client, pg, server_actor = init_sglang_engines( - use_local=cfg.generator.run_engines_locally, - tp_size=cfg.generator.inference_engine_tensor_parallel_size, - colocate_all=cfg.trainer.placement.colocate_all, - sampling_params=get_sampling_params_for_backend("sglang", cfg.generator.sampling_params), - ) - policy = init_worker_with_type( - "policy", - shared_pg=pg, - colocate_all=cfg.trainer.placement.colocate_all, - num_gpus_per_node=cfg.generator.inference_engine_tensor_parallel_size, - cfg=cfg, - ) - ray.get(policy.async_run_ray_method("pass_through", "init_weight_sync_state", llm_client)) - asyncio.run(llm_client.reset_prefix_cache()) - ray.get(policy.async_run_ray_method("pass_through", "broadcast_to_inference_engines", llm_client)) - asyncio.run(run_generation(llm_client, get_test_prompts(model))) diff --git a/skyrl-train/uv.lock b/skyrl-train/uv.lock index 078e56d18e..45c0cc0b0c 100644 --- a/skyrl-train/uv.lock +++ b/skyrl-train/uv.lock @@ -3889,4 +3889,4 @@ source = { registry = "https://pypi.org/simple" } sdist = { url = "https://files.pythonhosted.org/packages/e3/02/0f2892c661036d50ede074e376733dca2ae7c6eb617489437771209d4180/zipp-3.23.0.tar.gz", hash = "sha256:a07157588a12518c9d4034df3fbbee09c814741a33ff63c05fa29d26a2404166", size = 25547, upload-time = "2025-06-08T17:06:39.4Z" } wheels = [ { url = "https://files.pythonhosted.org/packages/2e/54/647ade08bf0db230bfea292f893923872fd20be6ac6f53b2b936ba839d75/zipp-3.23.0-py3-none-any.whl", hash = "sha256:071652d6115ed432f5ce1d34c336c0adfd6a884660d1e9712a256d3d3bd4b14e", size = 10276, upload-time = "2025-06-08T17:06:38.034Z" }, -] +] \ No newline at end of file
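A note on the `pytest.mark.vllm` / `pytest.mark.sglang` markers introduced in the tests: filtering with `-m "vllm"` or `-m "sglang"` assumes those marks are registered. If they are not already declared in the repo's pytest configuration (this diff does not show it), a minimal conftest.py sketch would be:

# conftest.py (sketch; unnecessary if the markers are already registered, e.g. in pyproject.toml)
def pytest_configure(config):
    config.addinivalue_line("markers", "vllm: tests requiring the vllm extra")
    config.addinivalue_line("markers", "sglang: tests requiring the sglang extra")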