[algorithm] Sampling mask from Inference Engine Applies to Trainer #883
@@ -1,5 +1,6 @@
import os
from typing import List, Any, Dict, Optional, TYPE_CHECKING
import threading

if TYPE_CHECKING:
    from skyrl_train.weight_sync.transfer_strategy import WeightSyncInitInfo
@@ -37,6 +38,55 @@
from packaging import version


# TODO(devpatel): This is a hack to get the sampling masks. We should find a better way to do this... fast
_sampling_masks = threading.local()
_sampler_patched = False


def _reset_sampling_masks() -> None:
    _sampling_masks.items = []


def _append_sampling_mask(mask: torch.Tensor) -> None:
    if not hasattr(_sampling_masks, "items"):
        _sampling_masks.items = []
    _sampling_masks.items.append(mask)


def _consume_sampling_masks() -> Optional[List[torch.Tensor]]:
    masks = getattr(_sampling_masks, "items", None)
    _sampling_masks.items = []
    return masks


def _patch_vllm_sampler() -> None:
    global _sampler_patched
    if _sampler_patched:
        return
    try:
        from vllm.v1.sample.ops import topk_topp_sampler as sampler
    except Exception as exc:
        logger.warning(f"Could not import vLLM topk_topp_sampler op and/or Sampler class: {exc}")
        return

    original_top_k_top_p = sampler.apply_top_k_top_p
    original_top_k_only = sampler.apply_top_k_only

    def _wrapped_top_k_top_p(logits: torch.Tensor, k: torch.Tensor | None, p: torch.Tensor | None) -> torch.Tensor:
        output = original_top_k_top_p(logits, k, p)
        _append_sampling_mask(torch.isfinite(output).to(dtype=torch.bool).cpu())
        return output

    def _wrapped_top_k_only(logits: torch.Tensor, k: torch.Tensor) -> torch.Tensor:
        output = original_top_k_only(logits, k)
        _append_sampling_mask(torch.isfinite(output).to(dtype=torch.bool).cpu())
        return output

    sampler.apply_top_k_top_p = _wrapped_top_k_top_p
    sampler.apply_top_k_only = _wrapped_top_k_only
    _sampler_patched = True
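A note on what the wrappers actually capture: they rely on the convention that vLLM's top-k/top-p ops return logits with filtered-out entries set to -inf, so torch.isfinite on the output recovers the set of tokens that could have been sampled at that step. A minimal, self-contained sketch of that idea with synthetic logits (no vLLM involved; the top-k emulation below is illustrative only):

import torch

logits = torch.tensor([[2.0, 0.5, -1.0, 0.1]])  # one request, vocabulary of 4
k = 2

# Emulate a top-k filter: keep the k largest logits, push everything else to -inf.
topk_vals, topk_idx = logits.topk(k, dim=1)
filtered = torch.full_like(logits, float("-inf"))
filtered.scatter_(dim=1, index=topk_idx, src=topk_vals)

# This boolean mask is what _append_sampling_mask stores per decode step.
allowed = torch.isfinite(filtered)                                   # [[True, True, False, False]]
allowed_token_ids = allowed[0].nonzero(as_tuple=False).squeeze(-1)   # tensor([0, 1])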
Comment on lines +62 to +88 (Contributor):

The monkey-patching of vLLM's internal sampler functions is fragile. A more robust, long-term solution would be to contribute the required hooks upstream to the vLLM project.
|
|
||
| @dataclass | ||
| class Logprob: | ||
| logprob: float | ||
|
|
@@ -137,6 +187,7 @@ class BaseVLLMInferenceEngine(InferenceEngineInterface):

    def __init__(self, *args, bundle_indices: list = None, **kwargs):
        setup_envvars_for_vllm(kwargs, bundle_indices)
        _patch_vllm_sampler()
        vllm_v1_disable_multiproc = kwargs.pop("vllm_v1_disable_multiproc", False)
        if vllm_v1_disable_multiproc or vllm.__version__ == "0.8.2":
            # https://github.com/vllm-project/vllm/blob/effc5d24fae10b29996256eb7a88668ff7941aed/examples/offline_inference/reproduciblity.py#L11
@@ -169,6 +220,7 @@ def _create_engine(self, *args, **kwargs):

    def _preprocess_prompts(self, input_batch: InferenceEngineInput):
        """Common prompt preprocessing logic."""
        _reset_sampling_masks()
        prompts = input_batch.get("prompts")
        prompt_token_ids = input_batch.get("prompt_token_ids")
        request_sampling_params = input_batch.get("sampling_params")
@@ -213,11 +265,24 @@ def _postprocess_outputs(self, outputs):
        if len(response_logprobs) and response_logprobs[0] is None:
            response_logprobs = None  # hack: assume uniform sampling params

        sampling_masks = None
        masks = _consume_sampling_masks()
        if masks:
            sampling_masks = []
            # TODO(devpatel): We don't have the request_ids in the sampling metadata, so order by index.
            for output_idx in range(len(outputs)):
                per_request = []
                for step_mask in masks:
                    if output_idx < step_mask.shape[0]:
                        per_request.append(step_mask[output_idx].nonzero(as_tuple=False).squeeze(-1).tolist())
                sampling_masks.append(per_request)
Comment on lines +272 to +278 (Contributor):

The comment "We don't have the request_ids in the sampling metadata, so order by index" indicates an implicit assumption about the order of the captured step masks: row i of each mask is assumed to correspond to outputs[i].
|
|
||
| return InferenceEngineOutput( | ||
| responses=responses, | ||
| stop_reasons=stop_reasons, | ||
| response_ids=response_ids, | ||
| response_logprobs=response_logprobs, | ||
| sampling_masks=sampling_masks, | ||
| ) | ||
|
|
||
| def _get_engine(self): | ||
|
|
||
@@ -333,6 +333,9 @@ class TrainingInput(TypedDict, total=False): | |||||||||||||
| kl: Float[torch.Tensor, "batch_size seq_len"] | ||||||||||||||
| rewards: Optional[Float[torch.Tensor, "batch_size seq_len"]] | ||||||||||||||
| rollout_logprobs: Optional[Float[torch.Tensor, "batch_size seq_len"]] | ||||||||||||||
| sampling_mask: Optional[ | ||||||||||||||
| Integer[torch.Tensor, "batch_size seq_len mask_size"] | ||||||||||||||
| ] ## logits mask for sampling truncation, see https://arxiv.org/pdf/2512.02556 (3.1) | ||||||||||||||
|
Comment on lines +336 to +338 (Contributor):

The comment includes a link to an arXiv paper from the year 2512 (2512.02556), which does not look like a valid identifier; the reference should be verified or corrected.

Suggested change
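The hunks shown here do not spell out how the per-request index lists from the inference engine are packed into this [batch_size, seq_len, mask_size] tensor. The sketch below assumes right-padding with -1, which is consistent with apply_sampling_mask later in this diff treating negative indices as padding; the packing logic and variable names are illustrative, not part of the PR:

import torch

# Hypothetical per-request, per-step lists of allowed token ids, in the format
# produced by _postprocess_outputs in the inference-engine hunk above.
sampling_masks = [
    [[0, 1], [0, 3]],       # request 0: two decode steps
    [[1, 2], [2, 3], [2]],  # request 1: three decode steps
]

batch_size = len(sampling_masks)
seq_len = max(len(steps) for steps in sampling_masks)
mask_size = max(len(ids) for steps in sampling_masks for ids in steps)

# Pad with -1 so downstream code can ignore padded slots (sampling_mask >= 0).
packed = torch.full((batch_size, seq_len, mask_size), -1, dtype=torch.long)
for i, steps in enumerate(sampling_masks):
    for t, allowed_ids in enumerate(steps):
        packed[i, t, : len(allowed_ids)] = torch.tensor(allowed_ids, dtype=torch.long)

# packed.shape == (2, 3, 2); request 0 has an all -1 row at step 2 (pure padding).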
class TrainingInputBatch(TensorBatch[TrainingInput]):
@@ -175,3 +175,35 @@ def logprobs_from_logits_v2(
        logprobs_labels.append(row_logprobs_labels)
    logprobs_labels = torch.stack(logprobs_labels)
    return logprobs_labels


# def compute_sampling_mask(
#     logits: Float[torch.Tensor, "batch_size seqlen vocab_size"],
#     top_k: int = None,
#     top_p: float = None,
#     min_p: float = None,
# ) -> Float[torch.Tensor, "batch_size seqlen vocab_size"]:
#     pass
def apply_sampling_mask(
    logits: Float[torch.Tensor, "batch_size seqlen top_tokens"],
    sampling_mask: Integer[torch.Tensor, "batch_size seqlen mask_size"],
) -> Float[torch.Tensor, "batch_size seqlen top_tokens"]:
Comment on lines +190 to +192 (Contributor):

The type hint for the logits argument (and the return type) of apply_sampling_mask uses "top_tokens", but the function unpacks and indexes the full vocabulary dimension, so "vocab_size" describes the shape more accurately:

def apply_sampling_mask(
    logits: Float[torch.Tensor, "batch_size seqlen vocab_size"],
    sampling_mask: Integer[torch.Tensor, "batch_size seqlen mask_size"],
) -> Float[torch.Tensor, "batch_size seqlen vocab_size"]:
    if sampling_mask is None:
        return logits

    batch_size, seqlen, vocab_size = logits.shape
    device = logits.device

    # TODO(devpatel) if we sort the tokens, then indices might be wrong
    valid_token_mask = torch.zeros((batch_size, seqlen, vocab_size), dtype=torch.bool, device=device)
    valid = sampling_mask >= 0
    idx = sampling_mask.clamp(min=0)
    valid_token_mask.scatter_(dim=2, index=idx, src=valid)

    masked_logits = logits.clone()
    masked_logits[~valid_token_mask] = float("-inf")

    return masked_logits
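A short usage sketch for apply_sampling_mask with tiny synthetic tensors, just to show the intended shapes and the -1 padding convention (no call site appears in the hunks shown here; the function above is assumed to be in scope):

import torch

# batch_size=1, seqlen=2, vocab_size=4 synthetic logits.
logits = torch.tensor([[[0.2, 1.0, -0.5, 0.3],
                        [0.9, 0.1,  0.4, 0.0]]])

# Step 0 allows tokens {1, 3}; step 1 allows token {2}; -1 entries are padding.
sampling_mask = torch.tensor([[[1, 3],
                               [2, -1]]])

masked = apply_sampling_mask(logits, sampling_mask)
# masked[0, 0] -> [-inf, 1.0, -inf, 0.3]
# masked[0, 1] -> [-inf, -inf, 0.4, -inf]

One possible subtlety of the scatter-based implementation: padding slots are clamped to index 0 and scatter False onto position 0, so if token id 0 is also a genuinely allowed token at the same step, the duplicate scatter indices make the outcome order-dependent.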
Contributor:

The dynamic patching of vLLM's internal sampler functions (apply_top_k_top_p and apply_top_k_only) using threading.local() is a highly fragile approach. It relies on specific internal implementation details of vLLM, which are not part of its public API and can change without warning in future updates. This could lead to unexpected behavior, crashes, or incorrect sampling mask generation if vLLM's internal structure changes. While the TODO comment acknowledges this is a hack, it poses a significant risk to the maintainability and stability of the system. It would be preferable to find a more robust and officially supported way to extract this information from vLLM, or to encapsulate this hack more thoroughly with version checks and fallback mechanisms.
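As a rough illustration of the "version checks and fallback mechanisms" idea, here is a hedged sketch built only on imports that already appear in the diff (packaging.version and vllm); the version cutoff and the helper name _patch_vllm_sampler_guarded are hypothetical, and falling back simply means sampling_masks stays None:

import logging

import vllm
from packaging import version

logger = logging.getLogger(__name__)

_TESTED_MAX_VLLM = version.parse("0.11.0")  # hypothetical cutoff, not a tested value


def _patch_vllm_sampler_guarded() -> None:
    # Only attempt the monkey-patch on vLLM versions the hack was written against;
    # otherwise skip it and let the engine return outputs without sampling masks.
    if version.parse(vllm.__version__) >= _TESTED_MAX_VLLM:
        logger.warning(
            "vLLM %s is newer than the range this patch was written against; "
            "skipping the sampler patch, sampling_masks will be None.",
            vllm.__version__,
        )
        return
    _patch_vllm_sampler()  # the patch introduced in this PR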