33 changes: 31 additions & 2 deletions skyrl-train/skyrl_train/dataset/preprocess.py
@@ -1,7 +1,7 @@
from typing import List, Tuple, Optional
import torch
from transformers import AutoTokenizer
from jaxtyping import Float
from jaxtyping import Float, Integer


def _verify_inputs(
@@ -32,13 +32,15 @@ def convert_prompts_responses_to_batch_tensors(
rewards: List[List[float]],
loss_masks: List[List[int]],
logprobs: Optional[List[List[float]]] = None,
sampling_masks: Optional[List[List[List[int]]]] = None,
) -> Tuple[
Float[torch.Tensor, "batch seq_len"],
Float[torch.Tensor, "batch seq_len"],
Float[torch.Tensor, "batch response_len"],
Float[torch.Tensor, "batch response_len"],
Float[torch.Tensor, "batch response_len"],
Optional[Float[torch.Tensor, "batch response_len"]],
Optional[Integer[torch.Tensor, "batch response_len mask_size"]],
]:
"""
Convert prompts and responses to batch tensors for training.
@@ -59,13 +61,15 @@ def convert_prompts_responses_to_batch_tensors(
rewards: List of rewards for each response
loss_masks: List of loss masks for each response
logprobs: List of rollout log probs for each response
sampling_masks: Optional list of sampling masks (top-k/top-p valid token indices) for each response

Returns:
sequences: Full trajectories (padded and concatenated prompts and responses). Size: (batch, seq_len).
attention_mask: Attention mask for the model. Size: (batch, seq_len)
action_mask: Response mask for the model. Size: (batch, response_len)
rewards: Rewards for each output. Size: (batch, response_len)
loss_masks: Loss masks for each output. Size: (batch, response_len)
sampling_masks_tensor: Sampling masks tensor. Size: (batch, response_len, max_k) with -1 padding
"""
_verify_inputs(prompts, responses, rewards, loss_masks)

@@ -129,4 +133,29 @@ def convert_prompts_responses_to_batch_tensors(
]
logprobs_tensor = torch.tensor(padded_logprobs, dtype=torch.float)

return sequences, attention_mask, action_mask, ret_rewards, ret_loss_masks, logprobs_tensor
sampling_masks_tensor = None
if sampling_masks:
batch_size = len(sampling_masks)
max_seq_len = action_mask.size(1)

max_k = 0
for sample_masks in sampling_masks:
for step_mask in sample_masks:
max_k = max(max_k, len(step_mask))

if max_k > 0:
# shape: (batch_size, seq_len, max_k)
sampling_masks_tensor = torch.full(
(batch_size, max_seq_len, max_k),
fill_value=-1,
dtype=torch.int64,
)

for i, sample_masks in enumerate(sampling_masks):
for j, step_mask in enumerate(sample_masks):
if j < max_seq_len:
num_valid = len(step_mask)
if num_valid > 0:
sampling_masks_tensor[i, j, :num_valid] = torch.tensor(step_mask, dtype=torch.int64)

return sequences, attention_mask, action_mask, ret_rewards, ret_loss_masks, logprobs_tensor, sampling_masks_tensor
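
As a quick illustration of the padding behavior documented above, the following standalone sketch reproduces the ragged-to-dense conversion for sampling_masks with made-up token indices; it assumes only torch and mirrors the -1 fill used in convert_prompts_responses_to_batch_tensors.

import torch

# Hypothetical per-step valid-token indices for two responses (values are illustrative only).
sampling_masks = [
    [[5, 9, 12], [3]],        # response 0: two generated steps
    [[7, 11], [2, 4, 6, 8]],  # response 1: two generated steps
]

batch_size = len(sampling_masks)
max_seq_len = 3  # padded response length, i.e. action_mask.size(1)
max_k = max(len(step) for sample in sampling_masks for step in sample)

# Dense (batch, response_len, max_k) tensor padded with -1, as in the function above.
dense = torch.full((batch_size, max_seq_len, max_k), fill_value=-1, dtype=torch.int64)
for i, sample in enumerate(sampling_masks):
    for j, step in enumerate(sample):
        if j < max_seq_len:
            dense[i, j, : len(step)] = torch.tensor(step, dtype=torch.int64)

print(dense.shape)  # torch.Size([2, 3, 4])
print(dense[0, 1])  # tensor([ 3, -1, -1, -1])

Downstream code treats any -1 entry as padding, so a fully -1 row means no sampling mask was recorded for that step.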
1 change: 1 addition & 0 deletions skyrl-train/skyrl_train/dataset/replay_buffer.py
@@ -66,6 +66,7 @@ class Experience:
loss_mask: Optional[Integer[torch.LongTensor, "batch response_len"]]
action_mask: Optional[Integer[torch.Tensor, "batch response_len"]]
rollout_logprobs: Optional[Float[torch.Tensor, "batch response_len"]]
sampling_mask: Optional[Integer[torch.Tensor, "batch seq_len mask_size"]]
num_actions: int
info: Optional[dict]
kl: Optional[Float[torch.Tensor, "batch response_len"]] = None
1 change: 1 addition & 0 deletions skyrl-train/skyrl_train/generators/base.py
@@ -41,6 +41,7 @@ class GeneratorOutput(TypedDict):
trajectory_ids: Optional[List[TrajectoryID]]
# Applicable only for step-wise training
is_last_step: Optional[List[bool]]
sampling_masks: Optional[List[List[List[int]]]]


class MetricsOutput(TypedDict):
1 change: 1 addition & 0 deletions skyrl-train/skyrl_train/inference_engines/base.py
@@ -29,6 +29,7 @@ class InferenceEngineOutput(TypedDict):
response_ids: List[List[int]]
stop_reasons: List[str]
response_logprobs: Optional[List[List[float]]]
sampling_masks: Optional[List[List[List[int]]]]


class InferenceEngineInterface(ABC):
@@ -137,6 +137,8 @@ async def generate(self, input_batch: InferenceEngineInput) -> InferenceEngineOu
# a bit hacky for now
add_resp_logprobs = False

sampling_masks: List[List[List[int]]] = [[] for _ in range(n)]

for indices, result in zip(indices_list, results):
for local_idx, original_idx in enumerate(indices):
responses[original_idx] = result["responses"][local_idx]
if result.get("response_logprobs", None):
add_resp_logprobs = True
response_logprobs[original_idx] = result["response_logprobs"][local_idx]
# TODO(devpatel): see patch in vllm_engine.py for more details.
if result.get("sampling_masks", None):
sampling_masks[original_idx] = result["sampling_masks"][local_idx]

return InferenceEngineOutput(
responses=responses,
stop_reasons=stop_reasons,
response_ids=response_ids,
response_logprobs=response_logprobs if add_resp_logprobs else None,
sampling_masks=sampling_masks,
)

async def _generate_single_with_retry(
65 changes: 65 additions & 0 deletions skyrl-train/skyrl_train/inference_engines/vllm/vllm_engine.py
@@ -1,5 +1,6 @@
import os
from typing import List, Any, Dict, Optional, TYPE_CHECKING
import threading

if TYPE_CHECKING:
from skyrl_train.weight_sync.transfer_strategy import WeightSyncInitInfo
@@ -37,6 +38,55 @@
from packaging import version


# TODO(devpatel): This is a hack to get the sampling masks. We should find a better way to do this... fast
_sampling_masks = threading.local()
_sampler_patched = False


def _reset_sampling_masks() -> None:
_sampling_masks.items = []


def _append_sampling_mask(mask: torch.Tensor) -> None:
if not hasattr(_sampling_masks, "items"):
_sampling_masks.items = []
_sampling_masks.items.append(mask)


def _consume_sampling_masks() -> Optional[List[torch.Tensor]]:
masks = getattr(_sampling_masks, "items", None)
_sampling_masks.items = []
return masks


def _patch_vllm_sampler() -> None:
global _sampler_patched
if _sampler_patched:
return
try:
from vllm.v1.sample.ops import topk_topp_sampler as sampler
except Exception as exc:
logger.warning(f"Could not import vLLM topk_topp_sampler op and/or Sampler class: {exc}")
return

original_top_k_top_p = sampler.apply_top_k_top_p
original_top_k_only = sampler.apply_top_k_only

def _wrapped_top_k_top_p(logits: torch.Tensor, k: torch.Tensor | None, p: torch.Tensor | None) -> torch.Tensor:
output = original_top_k_top_p(logits, k, p)
_append_sampling_mask(torch.isfinite(output).to(dtype=torch.bool).cpu())
return output

def _wrapped_top_k_only(logits: torch.Tensor, k: torch.Tensor) -> torch.Tensor:
output = original_top_k_only(logits, k)
_append_sampling_mask(torch.isfinite(output).to(dtype=torch.bool).cpu())
return output

sampler.apply_top_k_top_p = _wrapped_top_k_top_p
sampler.apply_top_k_only = _wrapped_top_k_only
_sampler_patched = True
Comment on lines +41 to +87
high

The dynamic patching of vLLM's internal sampler functions (apply_top_k_top_p and apply_top_k_only) using threading.local() is a highly fragile approach. This relies on specific internal implementation details of vLLM, which are not part of its public API and can change without warning in future updates. This could lead to unexpected behavior, crashes, or incorrect sampling mask generation if vLLM's internal structure changes. While the TODO comment acknowledges this is a hack, it poses a significant risk to the maintainability and stability of the system. It would be preferable to find a more robust and officially supported way to extract this information from vLLM, or to encapsulate this hack more thoroughly with version checks and fallback mechanisms.
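
One way to act on the version-check suggestion is to gate the patch on a vLLM version range it was actually tested against and fall back to running without mask capture otherwise. The sketch below reuses the _patch_vllm_sampler defined in this file; the version bounds are placeholders, not verified limits.

import warnings

import vllm
from packaging import version

# Placeholder bounds: substitute the vLLM versions the sampler patch was verified against.
_TESTED_MIN = version.parse("0.8.0")
_TESTED_MAX = version.parse("0.9.0")


def _patch_vllm_sampler_if_supported() -> bool:
    """Apply the sampler patch only on vLLM versions it was written for.

    Returns True if the patch was applied, False if we fall back to running
    without sampling-mask capture.
    """
    current = version.parse(vllm.__version__)
    if not (_TESTED_MIN <= current < _TESTED_MAX):
        warnings.warn(
            f"vLLM {vllm.__version__} is outside the tested range for the sampler patch; "
            "sampling masks will not be captured."
        )
        return False
    _patch_vllm_sampler()
    return True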


Comment on lines +62 to +88
high

The monkey-patching of vllm's internal sampler functions is a clever way to extract the sampling masks, but it's very fragile and creates a significant maintenance burden. This implementation is tightly coupled to a specific version and internal structure of vllm. Any changes in future vllm versions, even minor ones, could break this functionality silently, leading to incorrect sampling masks being used in training without any errors being raised. While the try-except block handles import failures, it doesn't protect against changes in the patched function's behavior or signature.

A more robust, long-term solution would be to contribute to the vllm project to expose sampling metadata through a public API. For now, consider adding assertions or checks to verify the shape and content of the captured masks to catch potential issues early.
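
Concretely, the assertions suggested here could live in a small helper run on the captured masks (e.g. on the output of _consume_sampling_masks) before they are turned into token-index lists. The helper name and the assumption that every step mask is a 2-D boolean (batch, vocab) tensor are illustrative, not part of the patch.

from typing import List

import torch


def validate_captured_masks(masks: List[torch.Tensor], vocab_size: int) -> None:
    """Fail fast if the patched sampler produced masks with unexpected shapes or contents."""
    for step_idx, mask in enumerate(masks):
        if mask.dtype != torch.bool:
            raise TypeError(f"step {step_idx}: expected bool mask, got {mask.dtype}")
        if mask.dim() != 2:
            raise ValueError(f"step {step_idx}: expected (batch, vocab) mask, got shape {tuple(mask.shape)}")
        if mask.shape[1] != vocab_size:
            raise ValueError(f"step {step_idx}: vocab dim {mask.shape[1]} != engine vocab size {vocab_size}")
        if not mask.any(dim=-1).all():
            raise ValueError(f"step {step_idx}: some rows allow no tokens after top-k/top-p")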


@dataclass
class Logprob:
logprob: float
@@ -137,6 +187,7 @@ class BaseVLLMInferenceEngine(InferenceEngineInterface):

def __init__(self, *args, bundle_indices: list = None, **kwargs):
setup_envvars_for_vllm(kwargs, bundle_indices)
_patch_vllm_sampler()
vllm_v1_disable_multiproc = kwargs.pop("vllm_v1_disable_multiproc", False)
if vllm_v1_disable_multiproc or vllm.__version__ == "0.8.2":
# https://github.com/vllm-project/vllm/blob/effc5d24fae10b29996256eb7a88668ff7941aed/examples/offline_inference/reproduciblity.py#L11
@@ -169,6 +220,7 @@ def _create_engine(self, *args, **kwargs):

def _preprocess_prompts(self, input_batch: InferenceEngineInput):
"""Common prompt preprocessing logic."""
_reset_sampling_masks()
prompts = input_batch.get("prompts")
prompt_token_ids = input_batch.get("prompt_token_ids")
request_sampling_params = input_batch.get("sampling_params")
@@ -213,11 +265,24 @@ def _postprocess_outputs(self, outputs):
if len(response_logprobs) and response_logprobs[0] is None:
response_logprobs = None # hack: assume uniform sampling params

sampling_masks = None
masks = _consume_sampling_masks()
if masks:
sampling_masks = []
# TODO(devpatel): We don't have the request_ids in the sampling metadata, so order by index.
for output_idx in range(len(outputs)):
per_request = []
for step_mask in masks:
if output_idx < step_mask.shape[0]:
per_request.append(step_mask[output_idx].nonzero(as_tuple=False).squeeze(-1).tolist())
sampling_masks.append(per_request)
Comment on lines +272 to +278
medium

The comment "We don't have the request_ids in the sampling metadata, so order by index" indicates an implicit assumption about the order of masks matching the order of outputs. If vLLM processes requests asynchronously or reorders them internally, this assumption could lead to incorrect sampling masks being associated with the wrong outputs. This needs to be explicitly guaranteed by vLLM's behavior or handled more robustly (e.g., by associating request_id with sampling masks if possible).
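
Until the masks can be keyed by request id, one cheap guard is to compare the number of per-request step masks against the number of generated tokens for that request, since each sampled token should correspond to exactly one sampler call. The helper below is a sketch; its name and placement are hypothetical, and it would run on the sampling_masks and response_ids built in _postprocess_outputs.

from typing import List


def check_mask_output_consistency(
    sampling_masks: List[List[List[int]]], response_ids: List[List[int]]
) -> None:
    """Flag mismatches between per-request mask counts and generated token counts."""
    for idx, (masks_for_req, ids_for_req) in enumerate(zip(sampling_masks, response_ids)):
        if len(masks_for_req) != len(ids_for_req):
            raise ValueError(
                f"request {idx}: {len(masks_for_req)} sampling-step masks but "
                f"{len(ids_for_req)} generated tokens; index-based pairing may be off"
            )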


return InferenceEngineOutput(
responses=responses,
stop_reasons=stop_reasons,
response_ids=response_ids,
response_logprobs=response_logprobs,
sampling_masks=sampling_masks,
)

def _get_engine(self):
6 changes: 5 additions & 1 deletion skyrl-train/skyrl_train/model_wrapper.py
@@ -13,7 +13,7 @@
from transformers import AutoConfig, AutoModel, AutoModelForCausalLM, BitsAndBytesConfig
import numpy as np
from skyrl_train.distributed.ulysses.utils import ulysses_pad_and_slice_inputs, gather_outputs_and_unpad
from skyrl_train.utils.torch_utils import chunked_entropy_from_logits, logprobs_from_logits
from skyrl_train.utils.torch_utils import chunked_entropy_from_logits, logprobs_from_logits, apply_sampling_mask
from flash_attn.bert_padding import pad_input, unpad_input
from packaging.version import Version

@@ -267,6 +267,7 @@ def forward(
return_output=False,
compute_entropy=False,
entropy_requires_grad=True,
sampling_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Returns action log probs"""
position_ids = attention_mask.long().cumsum(-1) - 1
@@ -313,6 +314,9 @@
logits_BSV = output["logits"]
logits_BSV.div_(temperature)

if sampling_mask is not None:
logits_BSV = apply_sampling_mask(logits_BSV, sampling_mask)

# NOTE: this is slightly inaccurate with sample packing because last token from nth seq -> first token of n+1th seq loss is added.
log_probs = logprobs_from_logits(
logits_BSV,
Expand Down
10 changes: 10 additions & 0 deletions skyrl-train/skyrl_train/trainer.py
@@ -569,6 +569,8 @@ def convert_to_training_input(self, generator_output: GeneratorOutput, uids: Lis
rewards: List[List[float]] = generator_output["rewards"]
loss_masks: List[List[int]] = generator_output["loss_masks"]

# TODO(devpatel): test if handoff is working correctly for batching.
sampling_masks: Optional[List[List[List[int]]]] = generator_output.get("sampling_masks", None)
logprobs: Optional[List[List[float]]] = generator_output.get("rollout_logprobs", None)

(
rewards_tensor,
loss_masks_tensor,
rollout_logprobs_tensor,
sampling_masks_tensor,
) = convert_prompts_responses_to_batch_tensors(
self.tokenizer,
prompt_ids,
response_ids,
rewards,
loss_masks,
logprobs,
sampling_masks,
)
# sanity check for tis
if self.cfg.trainer.algorithm.use_tis:
assert (
rollout_logprobs_tensor is not None
), "expected non-null rollout logprobs tensor with `trainer.algorithm.use_tis` as `True`"
assert rollout_logprobs_tensor.shape == loss_masks_tensor.shape, "Logprobs should look like responses"

training_input = TrainingInputBatch(
{
"sequences": sequences_tensor, # Full trajectories (padded and concatenated prompts and responses)
Expand All @@ -605,6 +610,7 @@ def convert_to_training_input(self, generator_output: GeneratorOutput, uids: Lis
if generator_output.get("is_last_step", None) is not None
else None
),
"sampling_mask": sampling_masks_tensor,
},
)
training_input.metadata = {"uids": uids}
@@ -861,6 +867,10 @@ def pad_batch(self, training_input: TrainingInputBatch) -> TrainingInputBatch:
elif key == "loss_mask":
# ensures that padding tensors don't count towards the loss
padding_tensor = torch.zeros(pad_size, *additional_dims, dtype=tensor.dtype, device=tensor.device)
elif key == "sampling_mask":
padding_tensor = torch.full(
(pad_size, *additional_dims), fill_value=-1, dtype=tensor.dtype, device=tensor.device
)
else:
# ensures all padding tensors are in a valid format by cloning `pad_size` from the original input
# `pad_size` is guaranteed to be smaller than batch_size
3 changes: 3 additions & 0 deletions skyrl-train/skyrl_train/training_batch.py
@@ -333,6 +333,9 @@ class TrainingInput(TypedDict, total=False):
kl: Float[torch.Tensor, "batch_size seq_len"]
rewards: Optional[Float[torch.Tensor, "batch_size seq_len"]]
rollout_logprobs: Optional[Float[torch.Tensor, "batch_size seq_len"]]
sampling_mask: Optional[
Integer[torch.Tensor, "batch_size seq_len mask_size"]
] ## logits mask for sampling truncation, see https://arxiv.org/pdf/2512.02556 (3.1)
Comment on lines +336 to +338
medium

The comment includes a link to an arXiv paper from the year 2512 (https://arxiv.org/pdf/2512.02556). This appears to be a typo. Could you please verify the correct paper and update the link? It's important for code clarity and future reference.

Suggested change
sampling_mask: Optional[
Integer[torch.Tensor, "batch_size seq_len mask_size"]
] ## logits mask for sampling truncation, see https://arxiv.org/pdf/2512.02556 (3.1)
sampling_mask: Optional[
Integer[torch.Tensor, "batch_size seq_len mask_size"]
] ## logits mask for sampling truncation, see e.g. https://arxiv.org/pdf/2305.12256



class TrainingInputBatch(TensorBatch[TrainingInput]):
32 changes: 32 additions & 0 deletions skyrl-train/skyrl_train/utils/torch_utils.py
@@ -175,3 +175,35 @@ def logprobs_from_logits_v2(
logprobs_labels.append(row_logprobs_labels)
logprobs_labels = torch.stack(logprobs_labels)
return logprobs_labels


# def compute_sampling_mask(
# logits: Float[torch.Tensor, "batch_size seqlen vocab_size"],
# top_k: int = None,
# top_p: float = None,
# min_p: float = None,
# ) -> Float[torch.Tensor, "batch_size seqlen vocab_size"]:
# pass


def apply_sampling_mask(
logits: Float[torch.Tensor, "batch_size seqlen top_tokens"],
sampling_mask: Integer[torch.Tensor, "batch_size seqlen mask_size"],
) -> Float[torch.Tensor, "batch_size seqlen top_tokens"]:
Comment on lines +190 to +192
medium

The type hint for logits is Float[torch.Tensor, "batch_size seqlen top_tokens"]. However, the implementation assumes logits.shape[2] represents the full vocab_size when creating valid_token_mask (line 201). If logits has already been truncated to top_tokens (a subset of the full vocabulary), and sampling_mask contains indices from the full vocabulary, then sampling_mask indices could exceed logits.shape[2], leading to an out-of-bounds error during the scatter_ operation. Please clarify the expected shape of logits and ensure consistency between the type hint and the actual vocab_size used for masking.

def apply_sampling_mask(
    logits: Float[torch.Tensor, "batch_size seqlen vocab_size"],
    sampling_mask: Integer[torch.Tensor, "batch_size seqlen mask_size"],
) -> Float[torch.Tensor, "batch_size seqlen vocab_size"]:


if sampling_mask is None:
return logits

batch_size, seqlen, vocab_size = logits.shape
device = logits.device

# TODO(devpatel) if we sort the tokens, then indices might be wrong
valid_token_mask = torch.zeros((batch_size, seqlen, vocab_size), dtype=torch.bool, device=device)
valid = sampling_mask >= 0
idx = sampling_mask.clamp(min=0)
valid_token_mask.scatter_(dim=2, index=idx, src=valid)

masked_logits = logits.clone()
masked_logits[~valid_token_mask] = float("-inf")

return masked_logits
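
A small usage sketch of apply_sampling_mask on toy tensors: -1 entries are treated as padding, and every vocabulary index outside a step's mask is driven to -inf (the logits values are synthetic).

import torch

from skyrl_train.utils.torch_utils import apply_sampling_mask

batch_size, seqlen, vocab_size = 1, 2, 6
logits = torch.zeros(batch_size, seqlen, vocab_size)

# Step 0 allows tokens {1, 4}; step 1 allows only token {2}; -1 is padding.
sampling_mask = torch.tensor([[[1, 4], [2, -1]]])  # shape (1, 2, 2)

masked = apply_sampling_mask(logits, sampling_mask)
print(masked[0, 0])  # tensor([-inf, 0., -inf, -inf, 0., -inf])
print(masked[0, 1])  # tensor([-inf, -inf, 0., -inf, -inf, -inf])

If the mask indices come from the full vocabulary while the logits passed in have already been truncated, an explicit bounds check such as assert int(sampling_mask.max()) < logits.shape[-1] before the scatter would surface the mismatch raised in the review comment above.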
2 changes: 2 additions & 0 deletions skyrl-train/skyrl_train/workers/worker.py
@@ -864,6 +864,7 @@ def _forward_micro_batch(self, micro_batch: TrainingInputBatch) -> TrainingOutpu
sequences = micro_batch["sequences"]
response_length = micro_batch.metadata["response_length"]
attention_mask = micro_batch["attention_mask"]
sampling_mask = micro_batch.get("sampling_mask", None)

with torch.no_grad(), torch.autocast(dtype=torch.bfloat16, device_type="cuda"):
policy_logprob = self.model(
attention_mask,
return_output=False,
temperature=self.cfg.generator.sampling_params.temperature,
sampling_mask=sampling_mask,
)
policy_logprob = policy_logprob.to("cpu")
output = TrainingOutputBatch(
Expand Down
3 changes: 3 additions & 0 deletions skyrl-train/skyrl_train/workers/worker_utils.py
@@ -60,6 +60,9 @@ def batch_to_experience(batch: TrainingInputBatch):
action_mask=batch["response_mask"],
num_actions=batch.metadata["response_length"], # int
rollout_logprobs=batch["rollout_logprobs"] if "rollout_logprobs" in batch else None,
sampling_mask=(
batch["sampling_mask"] if "sampling_mask" in batch else None
), # shape: (batch_size, seq_len, top_tokens)
# additional info
# can be used to log metrics etc for micro-batches in the worker
info={},