[tx] move logits computation LogitsProcessorMixin and fix prompt_logprobs #919
Conversation
Compute lm_head projection in chunks to avoid materializing the full [B*T, V] logits tensor.

Key changes:
- Add compute_logits flag to model.__call__ (skip lm_head when False)
- Add lm_head weight to CausalLMOutput for external computation
- Implement chunked logprobs with jax.lax.map (default chunk_size=1024)
- Add loss_chunk_size config option

Memory savings: O(B*T*V) -> O(chunk_size*V) for the logits tensor. For Qwen3-4B with V=151k, 8k seq: ~19GB -> ~300MB peak logits memory.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
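A minimal sketch of the kind of chunked computation this commit describes, using jax.lax.map and a chunk size as mentioned above; the function and argument names (hidden, lm_head_w, target_ids) are illustrative, not the PR's actual code:

```python
import jax
import jax.numpy as jnp

def chunked_logprobs(hidden, lm_head_w, target_ids, chunk_size=1024):
    """Per-token logprobs without materializing the full [N, V] logits.

    hidden: [N, H] flattened hidden states, lm_head_w: [H, V], target_ids: [N].
    """
    N = hidden.shape[0]
    pad = (-N) % chunk_size  # pad so N divides evenly into chunks
    hidden_p = jnp.pad(hidden, ((0, pad), (0, 0)))
    targets_p = jnp.pad(target_ids, (0, pad))

    def per_chunk(args):
        h, t = args                                  # h: [chunk_size, H], t: [chunk_size]
        logits = h @ lm_head_w                       # [chunk_size, V]
        logz = jax.nn.logsumexp(logits, axis=-1)     # [chunk_size]
        picked = jnp.take_along_axis(logits, t[:, None], axis=-1)[:, 0]
        return picked - logz                         # logprob of the target token

    chunks = (N + pad) // chunk_size
    h_chunks = hidden_p.reshape(chunks, chunk_size, -1)
    t_chunks = targets_p.reshape(chunks, chunk_size)
    # lax.map evaluates one chunk at a time, so only [chunk_size, V] logits are live
    out = jax.lax.map(per_chunk, (h_chunks, t_chunks))
    return out.reshape(-1)[:N]
```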
…ze<=0

The chunked cross-entropy path computes logits via direct matmul with lm_head weight, bypassing LoRA adapters. This is incorrect when train_unembed=True since LoRA should be applied to lm_head.

Changes:
- Rename is_training to skip_logits for clarity
- Add _use_chunked_loss flag to backend
- Automatically switch to non-chunked mode when:
  - train_unembed=True (requires LoRA on lm_head)
  - loss_chunk_size <= 0 (config-based disable)
- Non-chunked path uses pre-computed logits with LoRA correctly applied
- Resolve conflicts in llama3.py and qwen3.py
- Integrate LogitsProcessor from main
- Move chunked logprobs computation to LogitsProcessor.compute_chunked_logprobs
- Add LogitsProcessor.compute_logprobs() that handles both chunked and non-chunked paths
- Add _logits_to_logprobs() and _compute_chunked_logprobs() as private helpers
- Simplify jax.py to a single compute_logprobs call
- LogitsProcessor is now a standalone utility with three static methods: compute_logits(), compute_logprobs(), logits_to_logprobs()
- Model forward() returns only hidden_states (removed logits computation)
- Simplified CausalLMOutput: removed logits and lm_head fields
- Generator uses LogitsProcessor for all logits/logprobs computation
- Backend uses LogitsProcessor.compute_logprobs() with chunking
- Updated tests to use new LogitsProcessor API

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
- Create CausalLMBase class with compute_logits/compute_logprobs methods
- Models expose wrapper methods instead of direct LogitsProcessor access
- Update generator and jax.py backend to use model methods
- LogitsProcessor is now an internal implementation detail

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
- Replace _has_train_unembed flag with _train_unembed_mask array
- Check at runtime if any adapter in batch needs LoRA on lm_head
- Use jax.lax.cond to choose chunked vs non-chunked path
- Handle adapter reuse correctly (reset mask on delete)
- Remove unused _use_chunked_loss flag

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
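A rough picture of the runtime dispatch this commit describes, using a per-adapter boolean mask and jax.lax.cond; the two branch bodies below are identical placeholders (the real ones differ in whether LoRA reaches lm_head and whether chunking is used), and all names are illustrative:

```python
import jax
import jax.numpy as jnp

def logprobs_with_runtime_dispatch(hidden, lm_head_w, targets,
                                   adapter_indices, train_unembed_mask):
    """Pick the non-chunked path if any adapter in the batch trains the unembedding."""
    needs_lora_lm_head = jnp.any(train_unembed_mask[adapter_indices])

    def full_path():
        # placeholder for the non-chunked path (LoRA would be applied to lm_head here)
        logits = hidden @ lm_head_w
        return jnp.take_along_axis(jax.nn.log_softmax(logits, axis=-1),
                                   targets[:, None], axis=-1)[:, 0]

    def chunked_path():
        # placeholder for the chunked path (see the jax.lax.map sketch above)
        logits = hidden @ lm_head_w
        return jnp.take_along_axis(jax.nn.log_softmax(logits, axis=-1),
                                   targets[:, None], axis=-1)[:, 0]

    # both branches must return identical shapes/dtypes for lax.cond
    return jax.lax.cond(needs_lora_lm_head, full_path, chunked_path)
```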
- Replace abstract property with __init__(lm_head) in base class
- Subclasses explicitly call CausalLMBase.__init__(self, lm_head)
- Fix test to support multiple adapters for mixed train_unembed test

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
When tie_word_embeddings=True, lm_head is a lambda from LoRAEmbed.T

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This reverts the chunked logprobs feature while keeping the CausalLMBase refactoring.

Changes removed:
- _compute_chunked_logprobs method
- lm_head_weight property
- loss_chunk_size config
- _train_unembed_mask runtime check
- TestChunkedCrossEntropyLoss tests

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
- test_compute_logits: compare with HuggingFace logits
- test_compute_logprobs: verify equivalence with manual computation
- Remove generation tests (belong in generator tests)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Code Review
This is a well-executed refactoring that improves memory efficiency and code structure by separating the model's forward pass from the logits computation. The introduction of the CausalLMBase class effectively centralizes logic for computing logits and log probabilities, leading to cleaner and more maintainable code. The changes have been consistently applied across the different model implementations, the generator, and the training backend. I've identified a high-severity issue concerning LoRA with tied embeddings where adapter weights are not being applied to the final output projection. Additionally, there's a minor type-safety concern in one of the test files. Apart from these points, this is a solid pull request with clear benefits.
skyrl-tx/tx/models/llama3.py (outdated)

```diff
-if self.config.tie_word_embeddings:
-    self.lm_head = self.model.embed_tokens.T
+if config.tie_word_embeddings:
+    lm_head = self.model.embed_tokens.T
```
There appears to be a bug here when using tied embeddings with LoRA adapters on the embed_tokens layer. The self.model.embed_tokens.T property returns a lambda function for the lm_head that computes hidden_states @ self.embedding.value.T. This uses the base embedding weights but ignores the LoRA adapter weights, as indicated by the TODO in LoRAEmbed.T. Consequently, the LoRA adaptation is not applied during the final projection to logits, leading to incorrect outputs. This should be addressed to ensure correctness for LoRA fine-tuning with tied embeddings.
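To make the concern concrete, here is a small illustration of the difference (helper names and LoRA factor shapes are made up for the example; they are not the repository's LoRAEmbed code):

```python
import jax.numpy as jnp

# What the tied lm_head lambda effectively does, per the comment above: project with
# the base embedding table only, so LoRA updates to embed_tokens never reach the logits.
def tied_lm_head_base_only(hidden_states, embedding):              # embedding: [V, H]
    return hidden_states @ embedding.T                             # [..., V]

# What a LoRA-aware tied projection could look like instead
# (lora_a: [V, r], lora_b: [r, H], scaling: scalar -- illustrative convention).
def tied_lm_head_with_lora(hidden_states, embedding, lora_a, lora_b, scaling):
    effective_embedding = embedding + scaling * (lora_a @ lora_b)  # [V, H]
    return hidden_states @ effective_embedding.T
```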
This is an existing bug / TODO.
Added an assertion to prevent enabling tie_word_embeddings and train_unembed together.
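The guard could look roughly like this (names, wording, and placement are illustrative; per the review summary below, the real check lives in init_lora_adapter):

```python
def check_lora_config(tie_word_embeddings: bool, train_unembed: bool) -> None:
    """Illustrative version of the misconfiguration guard."""
    if tie_word_embeddings and train_unembed:
        raise ValueError(
            "train_unembed=True is incompatible with tie_word_embeddings=True: "
            "LoRA is not applied to the tied lm_head projection"
        )
```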
/gemini review
Code Review
This is an excellent pull request that significantly refactors the logits computation logic for better flexibility and memory efficiency. The introduction of the CausalLMBase class provides a clean and consolidated API for computing logits and logprobs, effectively removing duplication and simplifying the model's forward pass by eliminating the LogitsProcessor and the skip_prompt_logits parameter.
The changes are consistently applied across all models, tests, and utility functions. The generator logic is now more explicit about when and what kind of logits are computed, which is a great improvement for both prefill and decode steps. The adaptation in the JAX backend to use compute_logprobs within the gradient-checkpointed function is a smart move for memory optimization during training.
The tests have been thoughtfully updated to validate the new implementation, and the new check in init_lora_adapter to prevent misconfiguration with tied embeddings is a valuable addition. Overall, this is a high-quality, well-executed refactoring that improves the codebase's structure and performance characteristics.
The prompt_logprobs computation was not passing adapter_indices to compute_logprobs, which would cause incorrect results when using LoRA adapters. Added test coverage for this case.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
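The essence of the fix is that the adapter index for each request has to reach the logprob computation; a toy illustration follows (all names and shapes are made up for the example, not the actual tx signatures):

```python
import jax
import jax.numpy as jnp

def compute_logprobs(hidden, target_ids, base_w, lora_delta_w, adapter_indices):
    """Toy per-request LoRA: adapter_indices selects which delta applies to each sequence.

    hidden: [B, T, H], base_w: [H, V], lora_delta_w: [num_adapters, H, V], adapter_indices: [B].
    Adapter slot 0 is assumed to be an all-zero "no adapter" entry.
    """
    w = base_w[None] + lora_delta_w[adapter_indices]           # [B, H, V]
    logits = jnp.einsum("bth,bhv->btv", hidden, w)
    logprobs = jax.nn.log_softmax(logits, axis=-1)
    return jnp.take_along_axis(logprobs, target_ids[..., None], axis=-1)[..., 0]

# Dropping adapter_indices (e.g. always using slot 0) silently yields base-model
# logprobs for every request -- the symptom this commit fixes.
```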
/gemini review
Code Review
This pull request introduces a significant and well-executed refactoring of the logits and log-probabilities computation. By centralizing this logic into a LogitsProcessorMixin, the changes improve code clarity, remove redundancy, and fix a critical bug related to prompt_logprobs with LoRA adapters. The addition of a new test case to cover this fix and a validation check for incompatible configurations are also valuable improvements. My review includes one suggestion to optimize the prefill step to avoid redundant computations when prompt log-probabilities are requested.
- Force float32 for our model to match HF for accurate comparison
- Use per-model tolerances: 3e-2 for llama3, 5e-4 for qwen3 (llama3 has larger numerical differences, see test_llama3.py)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
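Roughly the comparison pattern this commit describes, with toy tensors standing in for the two models' logits; the tolerance values come from the commit message, everything else is illustrative:

```python
import numpy as np

# stand-ins for hf_logits / tx_logits, both computed in float32
hf_logits = np.random.randn(1, 8, 100).astype(np.float32)
tx_logits = hf_logits + 1e-5 * np.random.randn(1, 8, 100).astype(np.float32)

# per-model tolerance: looser for llama3 than for qwen3
atol = {"llama3": 3e-2, "qwen3": 5e-4}["qwen3"]
np.testing.assert_allclose(tx_logits, hf_logits, atol=atol)
```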
Load HF model, get logits, save weights, delete HF model, then load our model. This avoids having both models in memory simultaneously.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
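The memory-saving order of operations, sketched with the transformers API; the checkpoint name is an arbitrary small model chosen for illustration, and the tx loading step is left as a comment since it is repository-specific:

```python
import gc
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "sshleifer/tiny-gpt2"   # arbitrary small checkpoint for illustration
tmpdir = "/tmp/hf_weights"

# 1) Load the HF reference model and capture its logits.
tokenizer = AutoTokenizer.from_pretrained(model_name)
hf_model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float32)
inputs = tokenizer("hello world", return_tensors="pt")
with torch.no_grad():
    hf_logits = hf_model(**inputs).logits

# 2) Save the weights so our model can load them later.
hf_model.save_pretrained(tmpdir)

# 3) Free the HF model before constructing ours, so both never coexist in memory.
del hf_model
gc.collect()

# 4) Load the tx model from tmpdir and compare its logits against hf_logits
#    (loading code omitted; it is repository-specific).
```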
Force-pushed 561dbd7 to a82cd53
Follow up to #919 to fix the CI
This is in preparation for merging #879. Updated version of #918, rebased on top of #919. The reason we do this is that, while it is slightly less natural for sampling prefill to be left aligned (right padded), it makes things more uniform, since that's the same alignment used during training, and the jax cudnn flash attention doesn't currently support left padding as far as we are aware. It is a small change and actually a little simpler, so what is not to like about it. Plus the work on the PR uncovered a bug in the prompt logprobs.
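For clarity on the two alignments mentioned above, a toy example (pad id and prompt values are arbitrary):

```python
import jax.numpy as jnp

pad_id = 0
seqs = [[11, 12, 13], [21, 22]]          # toy prompts of unequal length
max_len = max(len(s) for s in seqs)

# right padding (left aligned): the alignment used for training and, after this
# change, for sampling prefill as well
right_padded = jnp.array([s + [pad_id] * (max_len - len(s)) for s in seqs])

# left padding (right aligned): the alternative, which the cudnn flash attention
# path mentioned above does not currently support
left_padded = jnp.array([[pad_id] * (max_len - len(s)) + s for s in seqs])
```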
Motivation
Different scenarios have different requirements for logits computation:
- Training needs the full [B, T, V] logits
- Decoding only needs [B, 1, V] logits for sampling transforms (temperature, top_k, top_p)

By returning hidden_states and letting callers decide what to compute, we avoid unnecessary materialization of full vocabulary logits. This also enables future chunked logprobs computation for memory efficiency.
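As a toy illustration of the caller-side split (shapes and weights are dummies; in the PR the actual calls go through the model's compute_logits/compute_logprobs helpers):

```python
import jax
import jax.numpy as jnp

hidden_states = jnp.zeros((2, 8, 16))      # [B, T, H] returned by the model forward
lm_head_w = jnp.zeros((16, 100))           # [H, V]

# Decode step: only the last position's logits are needed before sampling transforms.
decode_logits = hidden_states[:, -1:, :] @ lm_head_w                # [B, 1, V]

# Training: per-token logprobs over the whole sequence; the full [B, T, V] logits
# never need to be returned from the model itself (and can later be chunked).
logprobs = jax.nn.log_softmax(hidden_states @ lm_head_w, axis=-1)   # [B, T, V]
```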
Changes
- LogitsProcessorMixin in tx/utils/logits_processor.py with:
  - get_lm_head() method
  - compute_logits(), compute_logprobs(), logits_to_logprobs()
- Removed logits from CausalLMOutput - callers now explicitly call compute_logits()
- Removed skip_prompt_logits parameter from model forward
- train_unembed=True is now incompatible with tie_word_embeddings=True

Code consolidation
Logits/logprobs computation was scattered across multiple locations:
- tx/layers/logits_processor.py (deleted)
- tx/utils/generator.py (compute_prompt_logprobs() helper, manual log_softmax)
- tx/tinker/backends/jax.py (manual logprobs from logits)

Now unified in LogitsProcessorMixin with a single implementation.

Bug fix
The prompt_logprobs computation was not passing adapter_indices to compute_logprobs, causing incorrect results
with LoRA adapters. Added test coverage.
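For orientation, a rough sketch of the consolidated API's shape with a toy model; the method names come from the list above, but bodies, signatures, and shapes are simplified placeholders rather than the code in tx/utils/logits_processor.py (the real methods also thread adapter_indices through):

```python
import jax
import jax.numpy as jnp


class LogitsProcessorMixin:
    """Toy version of the consolidated logits/logprobs helpers."""

    def get_lm_head(self):
        raise NotImplementedError  # models return their unembedding projection here

    def compute_logits(self, hidden_states):
        return self.get_lm_head()(hidden_states)           # [B, T, H] -> [B, T, V]

    def logits_to_logprobs(self, logits, target_ids):
        logprobs = jax.nn.log_softmax(logits, axis=-1)
        return jnp.take_along_axis(logprobs, target_ids[..., None], axis=-1)[..., 0]

    def compute_logprobs(self, hidden_states, target_ids):
        return self.logits_to_logprobs(self.compute_logits(hidden_states), target_ids)


class ToyCausalLM(LogitsProcessorMixin):
    def __init__(self, key, hidden=16, vocab=100):
        self.w = jax.random.normal(key, (hidden, vocab))

    def get_lm_head(self):
        return lambda h: h @ self.w


model = ToyCausalLM(jax.random.PRNGKey(0))
hidden_states = jnp.zeros((2, 8, 16))
target_ids = jnp.zeros((2, 8), dtype=jnp.int32)
print(model.compute_logprobs(hidden_states, target_ids).shape)   # (2, 8)
```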