9 changes: 5 additions & 4 deletions skyrl-train/skyrl_train/entrypoints/main_base.py
@@ -6,7 +6,7 @@

from ray.util.placement_group import placement_group, PlacementGroup

from transformers import AutoTokenizer
from transformers import AutoTokenizer, PreTrainedTokenizerBase
from skyrl_train.dataset import PromptDataset
from skyrl_train.utils import validate_cfg

@@ -36,7 +36,7 @@
__all__ = ["BasePPOExp", "config_dir"]


def create_ray_wrapped_inference_engines_from_config(cfg: DictConfig, colocate_pg, tokenizer):
def create_ray_wrapped_inference_engines_from_config(cfg: DictConfig, colocate_pg, tokenizer: PreTrainedTokenizerBase):
from skyrl_train.inference_engines.ray_wrapped_inference_engine import create_ray_wrapped_inference_engines

return create_ray_wrapped_inference_engines(
@@ -61,12 +61,13 @@ def create_ray_wrapped_inference_engines_from_config(cfg: DictConfig, colocate_p
)


def create_remote_inference_engines_from_config(cfg: DictConfig):
def create_remote_inference_engines_from_config(cfg: DictConfig, tokenizer: PreTrainedTokenizerBase):
# TODO(tgriggs): We may want a separate config for the model name in case it's different from the name used in the OpenAI API
return create_remote_inference_engines(
urls=cfg.generator.remote_inference_engine_urls,
model_name=cfg.trainer.policy.model.path,
engine_backend=cfg.generator.backend,
tokenizer=tokenizer,
tensor_parallel_size=cfg.generator.inference_engine_tensor_parallel_size,
sampling_params=get_sampling_params_for_backend(cfg.generator.backend, cfg.generator.sampling_params),
)
@@ -247,7 +248,7 @@ def _setup_trainer(self):
if self.cfg.generator.run_engines_locally:
inference_engines = create_ray_wrapped_inference_engines_from_config(self.cfg, self.colocate_pg, tokenizer)
else:
inference_engines = create_remote_inference_engines_from_config(self.cfg)
inference_engines = create_remote_inference_engines_from_config(self.cfg, tokenizer)

inference_engine_client = InferenceEngineClient(inference_engines)

249 changes: 158 additions & 91 deletions skyrl-train/skyrl_train/generators/skyrl_gym_generator.py

Large diffs are not rendered by default.

8 changes: 7 additions & 1 deletion skyrl-train/skyrl_train/inference_engines/base.py
@@ -14,9 +14,15 @@ class InferenceEngineInput(TypedDict):


class InferenceEngineOutput(TypedDict):
[Reviewer comment] I have small nits on this paragraph just for brevity / clarity, but will draft something in next round of updates :)

# We always return both token and text outputs. The tokens are the raw output
# of the inference engine, and the text is their decoded form, so it is
# guaranteed that tokenizer.decode(response_ids[i]) == responses[i]. The reverse
# is not guaranteed, since the same text can be represented by multiple token
# sequences. For multi-turn generation, use token-in-token-out to ensure
# correctness.
responses: List[str]
response_ids: List[List[int]]
stop_reasons: List[str]
response_ids: Optional[List[List[int]]]
response_logprobs: Optional[List[List[float]]]


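Note: a minimal sketch of the decode/encode asymmetry described in the `InferenceEngineOutput` comment above. The tokenizer name is illustrative and not part of this PR; any Hugging Face tokenizer shows the same behavior.

```python
from transformers import AutoTokenizer

# Illustrative tokenizer, not taken from this PR.
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")

# Pretend these token ids came back from the inference engine.
response_ids = tokenizer.encode("Hello,  world!", add_special_tokens=False)

# The text output is defined as the decoded tokens, so this always holds.
responses = tokenizer.decode(response_ids)
assert tokenizer.decode(response_ids) == responses

# The reverse is not guaranteed: re-encoding the text may yield a different
# (but equally valid) token sequence, which is why multi-turn generation
# should stay token-in-token-out.
reencoded = tokenizer.encode(responses, add_special_tokens=False)
print(response_ids == reencoded)  # may be False for some strings/tokenizers
```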
@@ -47,7 +47,9 @@ async def generate(self, input_batch: InferenceEngineInput) -> InferenceEngineOu
# Split evenly across engines
return await self._generate_batched(prompts, prompt_token_ids, sampling_params)

async def _generate_with_trajectory_routing(self, prompts, prompt_token_ids, trajectory_ids, sampling_params):
async def _generate_with_trajectory_routing(
self, prompts, prompt_token_ids, trajectory_ids, sampling_params
) -> InferenceEngineOutput:
"""
Route prompts to engines based on trajectory_ids and return results in the original order of the prompts.
"""
@@ -80,37 +82,27 @@ async def _generate_with_trajectory_routing(self, prompts, prompt_token_ids, tra
responses: list[str] = [""] * n
stop_reasons: list[str] = [""] * n
response_logprobs: List[Optional[List[float]]] = [None for _ in range(n)]
response_ids: List[Optional[List[float]]] = [None for _ in range(n)]
response_ids: List[List[int]] = [[] for _ in range(n)]
# a bit hacky for now
add_resp_ids = False
add_resp_logprobs = False

for indices, result in zip(indices_list, results):
for local_idx, original_idx in enumerate(indices):
responses[original_idx] = result["responses"][local_idx]
stop_reasons[original_idx] = result["stop_reasons"][local_idx]
if result.get("response_ids", None):
add_resp_ids = True
response_ids[original_idx] = result["response_ids"][local_idx]
response_ids[original_idx] = result["response_ids"][local_idx]
if result.get("response_logprobs", None):
add_resp_logprobs = True
response_logprobs[original_idx] = result["response_logprobs"][local_idx]

if any([len(response) == 0 for response in responses]) or (
add_resp_ids and not all([isinstance(sample_ids, list) for sample_ids in response_ids])
):
raise RuntimeError(
"Did not receive responses / response ids for some prompts. This should never happen. There is likely something wrong with the inference engine"
)

return InferenceEngineOutput(
responses=responses,
stop_reasons=stop_reasons,
response_ids=response_ids if add_resp_ids else None,
response_ids=response_ids,
response_logprobs=response_logprobs if add_resp_logprobs else None,
)

async def _generate_batched(self, prompts, prompt_token_ids, sampling_params):
async def _generate_batched(self, prompts, prompt_token_ids, sampling_params) -> InferenceEngineOutput:
"""
Split prompts evenly across engines and return results in the original order of the prompts.
"""
@@ -144,15 +136,14 @@ async def _generate_batched(self, prompts, prompt_token_ids, sampling_params):
for output in all_outputs:
responses.extend(output["responses"])
stop_reasons.extend(output["stop_reasons"])
if output.get("response_ids", None):
response_ids.extend(output["response_ids"])
response_ids.extend(output["response_ids"])
if output.get("response_logprobs", None):
response_logprobs.extend(output["response_logprobs"])

return InferenceEngineOutput(
responses=responses,
stop_reasons=stop_reasons,
response_ids=response_ids if len(response_ids) else None,
response_ids=response_ids,
response_logprobs=response_logprobs if len(response_logprobs) else None,
)

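Note: a minimal, self-contained sketch of the index bookkeeping that `_generate_with_trajectory_routing` and `_generate_batched` use to return outputs in the original prompt order. The names and the engine-selection rule below are illustrative, not the client's actual internals.

```python
from collections import defaultdict

prompts = ["p0", "p1", "p2", "p3"]
trajectory_ids = [7, 4, 7, 4]
num_engines = 2

# Group prompt indices by engine (here: a toy hash of the trajectory id).
groups: dict[int, list[int]] = defaultdict(list)
for idx, tid in enumerate(trajectory_ids):
    groups[tid % num_engines].append(idx)

# Pretend each engine echoes its inputs back as "responses".
results = {engine: [prompts[i] for i in idxs] for engine, idxs in groups.items()}

# Scatter each engine's outputs back into the original positions.
responses = [""] * len(prompts)
for engine, idxs in groups.items():
    for local_idx, original_idx in enumerate(idxs):
        responses[original_idx] = results[engine][local_idx]

assert responses == prompts  # original order preserved
```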
100 changes: 73 additions & 27 deletions skyrl-train/skyrl_train/inference_engines/remote_inference_engine.py
@@ -7,7 +7,7 @@
)
from typing import List, Optional, Dict, Any
import json
import asyncio
from transformers import PreTrainedTokenizerBase


class RemoteInferenceEngine(InferenceEngineInterface):
@@ -20,6 +20,7 @@ def __init__(
url: str,
model_name: str,
engine_backend: str,
tokenizer: PreTrainedTokenizerBase,
tp_size: Optional[int] = None,
sampling_params: Optional[Dict[str, Any]] = None,
):
@@ -29,45 +30,88 @@ def __init__(
self.engine_backend = engine_backend
self.tp_size = tp_size
self.sampling_params = sampling_params if sampling_params is not None else {}
self.tokenizer = tokenizer

async def generate(self, input_batch: InferenceEngineInput) -> InferenceEngineOutput:
# 1. Prepare inputs
prompts = input_batch.get("prompts")
prompt_token_ids = input_batch.get("prompt_token_ids")
prompt_token_ids: Optional[List[List[int]]] = input_batch.get("prompt_token_ids")
request_sampling_params = input_batch.get("sampling_params")

# For token-in-token-out, convert prompts to token ids if needed
if (prompts is None and prompt_token_ids is None) or (prompts is not None and prompt_token_ids is not None):
raise ValueError("Either `prompts` or `prompt_token_ids` must be provided, but not both.")
if prompt_token_ids is None:
prompt_token_ids = self.tokenizer.apply_chat_template(
prompts,
add_generation_prompt=True,
add_special_tokens=False,
return_dict=True,
tokenize=True,
)["input_ids"]

sampling_params = request_sampling_params if request_sampling_params is not None else self.sampling_params
if "n" in sampling_params and sampling_params["n"] > 1:
raise ValueError(
"n is not supported yet for remote inference engines. "
"You can set `config.generator.n_samples_per_prompt` instead."
)

output_tasks = []
# 2. Send a batched request to the server
response = None
async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=None)) as session:
headers = {"Content-Type": "application/json"}
payload = sampling_params.copy()
payload["model"] = self.model_name

if prompts is not None:
for prompt in prompts:
payload["messages"] = prompt
output_tasks.append(session.post(f"{self.url}/v1/chat/completions", json=payload, headers=headers))
else: # prompt_token_ids is not None
for p_ids in prompt_token_ids:
payload["prompt"] = p_ids
output_tasks.append(session.post(f"{self.url}/v1/completions", json=payload, headers=headers))

request_outputs = await asyncio.gather(*output_tasks)

outputs = []
finish_reasons = []
# TODO (sumanthrh): This is creating a flattened list of outputs. If sampling n > 1, we should fix this.
for request_output in request_outputs:
response = await request_output.json()
for choice in response.get("choices", []):
text = choice.get("message", {}).get("content", "")
outputs.append(text)
finish_reasons.append(choice.get("finish_reason"))
payload = {}
request_url = ""
if self.engine_backend == "vllm":
# vLLM does not support /generate, so we use /v1/completions, which supports batch generation.
payload = sampling_params.copy()
payload["model"] = self.model_name
payload["prompt"] = prompt_token_ids
request_url = f"{self.url}/v1/completions"
elif self.engine_backend == "sglang":
# SGLang supports /generate, which works like its Python `async_generate()` method
# and supports batch generation.
payload = {
"input_ids": prompt_token_ids,
"sampling_params": sampling_params,
}
request_url = f"{self.url}/generate"
else:
raise ValueError(f"Invalid engine backend: {self.engine_backend}")
async with session.post(request_url, json=payload, headers=headers) as resp:
response = await resp.json()

# 3. Parse outputs
outputs = []
output_ids = []
finish_reasons = []

if self.engine_backend == "vllm":
for i, choice in enumerate(response.get("choices", [])):
# Since n=1, index i corresponds to the output for `prompt_token_ids[i]`
assert choice["index"] == i, "Expect the choices to be ordered by index."
text = choice["text"]
outputs.append(text)
finish_reasons.append(choice["finish_reason"])
# TODO(Charlie): this is not token-in-token-out because vLLM does not support
# returning token IDs via HTTP requests. Fix after this vLLM PR is merged:
# https://github.com/vllm-project/vllm/pull/22587
output_ids.append(self.tokenizer.encode(text, add_special_tokens=False))
elif self.engine_backend == "sglang":
# since prompt_token_ids is a list of lists, response is a list of dicts
for output in response:
cur_output_ids = output["output_ids"]
output_ids.append(cur_output_ids)
# With skip_tokenizer_init=True, SGLang returns only token ids (no text), so
# we decode them manually.
outputs.append(self.tokenizer.decode(cur_output_ids))
finish_reasons.append(output["meta_info"]["finish_reason"]["type"])
else:
raise ValueError(f"Invalid engine backend: {self.engine_backend}")

return InferenceEngineOutput(
responses=outputs, stop_reasons=finish_reasons, response_ids=None, response_logprobs=None
responses=outputs, stop_reasons=finish_reasons, response_ids=output_ids, response_logprobs=None
)

async def wake_up(self, *args: Any, **kwargs: Any):
@@ -173,13 +217,15 @@ def create_remote_inference_engines(
urls: List[str],
model_name: str,
engine_backend: str,
tokenizer: PreTrainedTokenizerBase,
tensor_parallel_size: Optional[int] = None,
sampling_params: Optional[Dict[str, Any]] = None,
):
return [
RemoteInferenceEngine(
url=url,
model_name=model_name,
tokenizer=tokenizer,
engine_backend=engine_backend,
tp_size=tensor_parallel_size,
sampling_params=sampling_params,
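Note: a minimal sketch of the request preparation in `generate()` above, assuming an engine server at `http://localhost:8000`; the tokenizer, model name, and sampling params are illustrative. It shows the prompt-to-token-ids conversion and the two backend-specific payload shapes built in the diff.

```python
from transformers import AutoTokenizer

# Illustrative tokenizer, not taken from this PR.
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")

# Chat prompts are converted to token ids with the same flags as in `generate()`.
prompts = [[{"role": "user", "content": "Say hi."}]]
prompt_token_ids = tokenizer.apply_chat_template(
    prompts,
    add_generation_prompt=True,
    add_special_tokens=False,
    return_dict=True,
    tokenize=True,
)["input_ids"]

sampling_params = {"temperature": 0.7}  # illustrative

# vLLM: batched token-id prompts go to the OpenAI-style completions route.
vllm_url = "http://localhost:8000/v1/completions"
vllm_payload = {**sampling_params, "model": "my-model", "prompt": prompt_token_ids}

# SGLang: token ids go to /generate, mirroring its Python async_generate().
sglang_url = "http://localhost:8000/generate"
sglang_payload = {"input_ids": prompt_token_ids, "sampling_params": sampling_params}
```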
13 changes: 10 additions & 3 deletions skyrl-train/skyrl_train/inference_engines/sglang/sglang_engine.py
@@ -189,6 +189,11 @@ def __init__(self, *args, bundle_indices: Optional[List[int]] = None, **kwargs):
# Add custom weight loader
kwargs["custom_weight_loader"] = CUSTOM_WEIGHT_LOADER_PATH

# Always use token-in-token-out SGLang engine
# NOTE(Charlie): unlike vLLM, SGLang cannot do token-in-token-out and
# token-in-text-out in the same engine config.
kwargs["skip_tokenizer_init"] = True

# Create the SGLang engine (signal handler issue is now fixed by patching)
self.engine = Engine(**kwargs)
print(f"Created SGLang engine with kwargs: {kwargs}")
@@ -224,16 +229,18 @@ def _postprocess_outputs(self, outputs):
"""Process SGLang outputs to match expected format."""
responses: List[str] = []
stop_reasons: List[str] = []
response_ids: List[List[int]] = []

for output in outputs:
responses.append(output["text"])
response_ids.append(output["output_ids"])
responses.append(self.tokenizer.decode(output["output_ids"]))
stop_reasons.append(output["meta_info"]["finish_reason"]["type"])

return InferenceEngineOutput(
responses=responses,
# not supported with sglang yet
response_ids=None,
response_ids=response_ids,
stop_reasons=stop_reasons,
response_logprobs=None,
)

async def generate(self, input_batch: InferenceEngineInput) -> InferenceEngineOutput:
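Note: a minimal sketch of the mapping `_postprocess_outputs` performs when the engine runs with `skip_tokenizer_init=True`; the output dict is hand-written in the shape the code reads, and the tokenizer name is illustrative.

```python
from transformers import AutoTokenizer

# Illustrative tokenizer, not taken from this PR.
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")

# Hand-written engine output: with skip_tokenizer_init=True, SGLang returns
# token ids plus metadata, but no decoded text.
outputs = [
    {
        "output_ids": tokenizer.encode("Hi there!", add_special_tokens=False),
        "meta_info": {"finish_reason": {"type": "stop"}},
    }
]

responses = [tokenizer.decode(o["output_ids"]) for o in outputs]
response_ids = [o["output_ids"] for o in outputs]
stop_reasons = [o["meta_info"]["finish_reason"]["type"] for o in outputs]
print(responses, stop_reasons)  # e.g. ['Hi there!'] ['stop']
```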
@@ -18,6 +18,10 @@ def run_server(self) -> None:


if __name__ == "__main__":
server_args = prepare_server_args(sys.argv[1:])
args = sys.argv[1:]
# SGLang requires `--skip-tokenizer-init` for token-in-token-out via the `/generate` endpoint
if "--skip-tokenizer-init" not in args:
args.append("--skip-tokenizer-init")
server_args = prepare_server_args(args)
sglang_server = SGLangServer(server_args)
sglang_server.run_server()