
Conversation

@raulchen (Contributor) commented Jan 22, 2026

Motivation

Different scenarios have different requirements for logits computation:

  • Training: Only need logprobs of target tokens, not full [B, T, V] logits
  • Prefill: Only need last position logits for sampling, plus optional prompt logprobs
  • Decode: Need full [B, 1, V] logits for sampling transforms (temperature, top_k, top_p)

By returning hidden_states and letting callers decide what to compute, we avoid unnecessary materialization of
full vocabulary logits. This also enables future chunked logprobs computation for memory efficiency.
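
For illustration, a minimal usage sketch (hypothetical: the attribute and argument names such as `output.hidden_states` and `target_ids` are assumptions based on the description above, not the actual API):

```python
# `model`, `input_ids`, and `target_ids` are placeholders for objects that
# exist elsewhere in the caller's code; names and signatures are assumed.
output = model(input_ids)                  # forward pass now returns hidden states only
hidden = output.hidden_states              # [B, T, H]

# Prefill / decode: materialize logits only at the positions sampling needs.
last_logits = model.compute_logits(hidden[:, -1:, :])        # [B, 1, V]

# Training: per-token logprobs of the targets, without keeping [B, T, V] around.
token_logprobs = model.compute_logprobs(hidden, target_ids)  # [B, T]
```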

Changes

  • Add LogitsProcessorMixin in tx/utils/logits_processor.py with:
    • Abstract get_lm_head() method
    • compute_logits(), compute_logprobs(), logits_to_logprobs()
  • Remove logits from CausalLMOutput - callers now explicitly call compute_logits()
  • Remove skip_prompt_logits parameter from model forward
  • Add validation: train_unembed=True incompatible with tie_word_embeddings=True

Code consolidation

Logits/logprobs computation was scattered across multiple locations:

  • tx/layers/logits_processor.py (deleted)
  • tx/utils/generator.py (compute_prompt_logprobs() helper, manual log_softmax)
  • tx/tinker/backends/jax.py (manual logprobs from logits)

Now unified in LogitsProcessorMixin with a single implementation.
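
Conceptually, the unified helper could look something like the sketch below. The method names come from this PR's description; the bodies, shapes, and the omitted adapter_indices handling are assumptions.

```python
import jax
import jax.numpy as jnp

class LogitsProcessorMixin:
    """Sketch only; the real implementation also threads adapter_indices through."""

    def get_lm_head(self):
        # Abstract: subclasses return the unembedding projection, shape [H, V].
        raise NotImplementedError

    def compute_logits(self, hidden_states):
        # [B, T, H] @ [H, V] -> [B, T, V]
        return hidden_states @ self.get_lm_head()

    @staticmethod
    def logits_to_logprobs(logits, target_ids):
        # Log-softmax over the vocabulary, then gather the target token: [B, T]
        logprobs = jax.nn.log_softmax(logits, axis=-1)
        return jnp.take_along_axis(logprobs, target_ids[..., None], axis=-1)[..., 0]

    def compute_logprobs(self, hidden_states, target_ids):
        return self.logits_to_logprobs(self.compute_logits(hidden_states), target_ids)
```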

Bug fix

The prompt_logprobs computation was not passing adapter_indices to compute_logprobs, causing incorrect results
with LoRA adapters. Added test coverage.
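
A hedged sketch of the shape of the fix; the surrounding variable and argument names are assumptions based on the description above:

```python
# Before (buggy): adapter_indices was dropped, so logprobs were computed
# with the base lm_head even when a LoRA adapter was active.
prompt_logprobs = model.compute_logprobs(hidden_states, prompt_ids)

# After: thread adapter_indices through so the LoRA-adapted projection is used.
prompt_logprobs = model.compute_logprobs(
    hidden_states, prompt_ids, adapter_indices=adapter_indices
)
```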

raulchen and others added 24 commits January 20, 2026 18:55
Compute lm_head projection in chunks to avoid materializing the full
[B*T, V] logits tensor. Key changes:

- Add compute_logits flag to model.__call__ (skip lm_head when False)
- Add lm_head weight to CausalLMOutput for external computation
- Implement chunked logprobs with jax.lax.map (default chunk_size=1024)
- Add loss_chunk_size config option

Memory savings: O(B*T*V) -> O(chunk_size*V) for logits tensor.
For Qwen3-4B with V=151k, 8k seq: ~19GB -> ~300MB peak logits memory.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
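
For reference, a minimal sketch of the chunked scheme this commit describes; only `jax.lax.map` and the default chunk size come from the message above, while the function name, flattened shapes, and padding strategy are assumptions. (This chunked path is reverted in a later commit below.)

```python
import jax
import jax.numpy as jnp

def chunked_logprobs(hidden, lm_head_w, target_ids, chunk_size=1024):
    """Compute per-token logprobs without materializing all [N, V] logits.

    hidden:     [N, H]  flattened (B*T) hidden states
    lm_head_w:  [H, V]  unembedding weights
    target_ids: [N]     token ids whose logprobs we want
    """
    n, h = hidden.shape
    pad = (-n) % chunk_size
    hidden = jnp.pad(hidden, ((0, pad), (0, 0)))
    targets = jnp.pad(target_ids, (0, pad))

    def one_chunk(chunk):
        h_chunk, t_chunk = chunk
        logits = h_chunk @ lm_head_w                      # only [chunk_size, V] lives at once
        logprobs = jax.nn.log_softmax(logits, axis=-1)
        return jnp.take_along_axis(logprobs, t_chunk[:, None], axis=-1)[:, 0]

    out = jax.lax.map(
        one_chunk,
        (hidden.reshape(-1, chunk_size, h), targets.reshape(-1, chunk_size)),
    )
    return out.reshape(-1)[:n]
```
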
…ze<=0

The chunked cross-entropy path computes logits via direct matmul with
lm_head weight, bypassing LoRA adapters. This is incorrect when
train_unembed=True since LoRA should be applied to lm_head.

Changes:
- Rename is_training to skip_logits for clarity
- Add _use_chunked_loss flag to backend
- Automatically switch to non-chunked mode when:
  - train_unembed=True (requires LoRA on lm_head)
  - loss_chunk_size <= 0 (config-based disable)
- Non-chunked path uses pre-computed logits with LoRA correctly applied
- Resolve conflicts in llama3.py and qwen3.py
- Integrate LogitsProcessor from main
- Move chunked logprobs computation to LogitsProcessor.compute_chunked_logprobs
- Add LogitsProcessor.compute_logprobs() that handles both chunked and non-chunked paths
- Add _logits_to_logprobs() and _compute_chunked_logprobs() as private helpers
- Simplify jax.py to single compute_logprobs call
- LogitsProcessor is now a standalone utility with three static methods:
  compute_logits(), compute_logprobs(), logits_to_logprobs()
- Model forward() returns only hidden_states (removed logits computation)
- Simplified CausalLMOutput: removed logits and lm_head fields
- Generator uses LogitsProcessor for all logits/logprobs computation
- Backend uses LogitsProcessor.compute_logprobs() with chunking
- Updated tests to use new LogitsProcessor API

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
- Create CausalLMBase class with compute_logits/compute_logprobs methods
- Models expose wrapper methods instead of direct LogitsProcessor access
- Update generator and jax.py backend to use model methods
- LogitsProcessor is now internal implementation detail

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
- Replace _has_train_unembed flag with _train_unembed_mask array
- Check at runtime if any adapter in batch needs LoRA on lm_head
- Use jax.lax.cond to choose chunked vs non-chunked path
- Handle adapter reuse correctly (reset mask on delete)
- Remove unused _use_chunked_loss flag

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
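
Roughly, the runtime dispatch might look like the sketch below; the mask name and the two branch callables are assumptions, and this mechanism was later reverted along with chunking.

```python
import jax
import jax.numpy as jnp

def dispatch_logprobs(hidden, targets, adapter_indices, train_unembed_mask,
                      chunked_fn, nonchunked_fn):
    # If any adapter in this batch applies LoRA to lm_head (train_unembed),
    # the chunked direct-matmul path would silently drop the adapter, so
    # fall back to the non-chunked path that uses LoRA-applied logits.
    needs_lora_lm_head = jnp.any(train_unembed_mask[adapter_indices])
    return jax.lax.cond(
        needs_lora_lm_head,
        lambda: nonchunked_fn(hidden, targets),
        lambda: chunked_fn(hidden, targets),
    )
```
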
- Replace abstract property with __init__(lm_head) in base class
- Subclasses explicitly call CausalLMBase.__init__(self, lm_head)
- Fix test to support multiple adapters for mixed train_unembed test

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
When tie_word_embeddings=True, lm_head is a lambda from LoRAEmbed.T

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This reverts the chunked logprobs feature while keeping the CausalLMBase
refactoring. Changes removed:
- _compute_chunked_logprobs method
- lm_head_weight property
- loss_chunk_size config
- _train_unembed_mask runtime check
- TestChunkedCrossEntropyLoss tests

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
- test_compute_logits: compare with HuggingFace logits
- test_compute_logprobs: verify equivalence with manual computation
- Remove generation tests (belong in generator tests)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
@gemini-code-assist bot (Contributor) left a comment

Code Review

This is a well-executed refactoring that improves memory efficiency and code structure by separating the model's forward pass from the logits computation. The introduction of the CausalLMBase class effectively centralizes logic for computing logits and log probabilities, leading to cleaner and more maintainable code. The changes have been consistently applied across the different model implementations, the generator, and the training backend. I've identified a high-severity issue concerning LoRA with tied embeddings where adapter weights are not being applied to the final output projection. Additionally, there's a minor type-safety concern in one of the test files. Apart from these points, this is a solid pull request with clear benefits.

if self.config.tie_word_embeddings:
    self.lm_head = self.model.embed_tokens.T

if config.tie_word_embeddings:
    lm_head = self.model.embed_tokens.T
@gemini-code-assist bot (Contributor) commented (severity: high):

There appears to be a bug here when using tied embeddings with LoRA adapters on the embed_tokens layer. The self.model.embed_tokens.T property returns a lambda function for the lm_head that computes hidden_states @ self.embedding.value.T. This uses the base embedding weights but ignores the LoRA adapter weights, as indicated by the TODO in LoRAEmbed.T. Consequently, the LoRA adaptation is not applied during the final projection to logits, leading to incorrect outputs. This should be addressed to ensure correctness for LoRA fine-tuning with tied embeddings.

@raulchen (Contributor, Author) replied:

This is an existing bug / TODO.
I added an assertion to prevent enabling tie_word_embeddings and train_unembed together.
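
A hedged sketch of the kind of check described; the exact placement inside init_lora_adapter and the error message are assumptions:

```python
def init_lora_adapter(model_config, lora_config):
    # The tied lm_head (embed_tokens.T) ignores LoRA weights, so training the
    # unembedding through LoRA would silently have no effect on the projection.
    if lora_config.train_unembed and model_config.tie_word_embeddings:
        raise ValueError(
            "train_unembed=True is incompatible with tie_word_embeddings=True"
        )
    # ... adapter initialization continues here ...
```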

@raulchen (Contributor, Author) commented:

/gemini review

@gemini-code-assist bot (Contributor) left a comment

Code Review

This is an excellent pull request that significantly refactors the logits computation logic for better flexibility and memory efficiency. The introduction of the CausalLMBase class provides a clean and consolidated API for computing logits and logprobs, effectively removing duplication and simplifying the model's forward pass by eliminating the LogitsProcessor and the skip_prompt_logits parameter.

The changes are consistently applied across all models, tests, and utility functions. The generator logic is now more explicit about when and what kind of logits are computed, which is a great improvement for both prefill and decode steps. The adaptation in the JAX backend to use compute_logprobs within the gradient-checkpointed function is a smart move for memory optimization during training.

The tests have been thoughtfully updated to validate the new implementation, and the new check in init_lora_adapter to prevent misconfiguration with tied embeddings is a valuable addition. Overall, this is a high-quality, well-executed refactoring that improves the codebase's structure and performance characteristics.

raulchen and others added 2 commits January 22, 2026 12:55
The prompt_logprobs computation was not passing adapter_indices to
compute_logprobs, which would cause incorrect results when using LoRA
adapters. Added test coverage for this case.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
@raulchen changed the title from "[tx] refactor: move logits computation from model forward to CausalLMBase methods" to "[tx] refactor: move logits computation LogitsProcessorMixin and fix prompt_logprobs" on Jan 22, 2026
@raulchen (Contributor, Author) commented:

/gemini review

@gemini-code-assist bot (Contributor) left a comment

Code Review

This pull request introduces a significant and well-executed refactoring of the logits and log-probabilities computation. By centralizing this logic into a LogitsProcessorMixin, the changes improve code clarity, remove redundancy, and fix a critical bug related to prompt_logprobs with LoRA adapters. The addition of a new test case to cover this fix and a validation check for incompatible configurations are also valuable improvements. My review includes one suggestion to optimize the prefill step to avoid redundant computations when prompt log-probabilities are requested.

raulchen and others added 5 commits January 22, 2026 15:39
- Force float32 for our model to match HF for accurate comparison
- Use per-model tolerances: 3e-2 for llama3, 5e-4 for qwen3
  (llama3 has larger numerical differences, see test_llama3.py)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
@raulchen changed the title from "[tx] refactor: move logits computation LogitsProcessorMixin and fix prompt_logprobs" to "[tx] move logits computation LogitsProcessorMixin and fix prompt_logprobs" on Jan 23, 2026
Load HF model, get logits, save weights, delete HF model, then load
our model. This avoids having both models in memory simultaneously.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
@raulchen force-pushed the refactor-logits-compute branch from 561dbd7 to a82cd53 on January 23, 2026 02:06
@pcmoritz merged commit d55c05d into NovaSky-AI:main on Jan 23, 2026
3 of 4 checks passed
pcmoritz added a commit that referenced this pull request Jan 23, 2026
pcmoritz added a commit that referenced this pull request Jan 23, 2026

This is in preparation for merging
#879.
Updated version of #918, rebased
on top of #919

The reason we do this: while it is slightly less natural for sampling
prefill to be left-aligned (right-padded), it makes things more uniform,
since that is the same alignment used during training, and the JAX cuDNN
flash attention does not currently support left padding as far as we are
aware. It is a small change and actually a little simpler. Plus, the work
on this PR uncovered a bug in the prompt logprobs.
@raulchen deleted the refactor-logits-compute branch on January 23, 2026 16:09
tanmaysachan pushed a commit to tanmaysachan/SkyRL that referenced this pull request Jan 25, 2026