Conversation

@raulchen commented Jan 20, 2026

Summary

  • Add _compute_chunked_logprobs to LogitsProcessorMixin that
    computes log probabilities in chunks, avoiding materialization of
    the full [B*T, V] logits tensor
  • Add loss_chunk_size config field (default 1024) to control
    chunk size, set to 0 to disable
  • Add gradient_checkpointing config field to enable
    recomputation during backward pass
  • Chunked path supports LoRA by computing adapter indices
    on-the-fly per chunk (see the sketch below)
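
A minimal, self-contained sketch of the chunked computation (illustrative only: the helper names, the plain-matmul lm_head, and the padding strategy are assumptions, not the actual LogitsProcessorMixin code; per-chunk LoRA adapter handling is omitted):

```python
import jax
import jax.numpy as jnp

def chunked_logprobs(hidden, lm_head_w, target_ids, chunk_size=1024,
                     gradient_checkpointing=False):
    """hidden: [B, T, H], lm_head_w: [H, V], target_ids: [B, T] -> [B, T] logprobs.

    Only a [chunk_size, V] logits tile exists at any time instead of the full
    [B*T, V] tensor.
    """
    B, T, H = hidden.shape
    n = B * T
    flat_h = hidden.reshape(n, H)
    flat_t = target_ids.reshape(n)

    # Pad the flattened token axis up to a multiple of chunk_size.
    pad = (-n) % chunk_size
    flat_h = jnp.pad(flat_h, ((0, pad), (0, 0)))
    flat_t = jnp.pad(flat_t, (0, pad))

    def one_chunk(args):
        h_chunk, t_chunk = args                        # [chunk_size, H], [chunk_size]
        logits = h_chunk @ lm_head_w                   # [chunk_size, V] -- the only large tile
        logprobs = jax.nn.log_softmax(logits, axis=-1)
        return jnp.take_along_axis(logprobs, t_chunk[:, None], axis=-1)[:, 0]

    if gradient_checkpointing:
        # Recompute each chunk's logits in the backward pass instead of storing them.
        one_chunk = jax.checkpoint(one_chunk)

    num_chunks = (n + pad) // chunk_size
    out = jax.lax.map(
        one_chunk,
        (flat_h.reshape(num_chunks, chunk_size, H),
         flat_t.reshape(num_chunks, chunk_size)),
    )
    return out.reshape(-1)[:n].reshape(B, T)
```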

Test plan

  • test_chunked_logprobs verifies chunked and non-chunked
    paths produce identical results across different chunk sizes (8,
    16, 32)
  • test_mixed_train_unembed_adapters verifies chunked path
    works correctly with mixed train_unembed adapters
  • unit tests in test_logits_processor.py

Benchmark results #891 (comment)


@gemini-code-assist bot left a comment

Code Review

This pull request introduces a memory-efficient chunked cross-entropy loss computation, which is a significant improvement for training large models. The implementation correctly avoids materializing the full logits tensor by processing lm_head in chunks using jax.lax.map. The fallback mechanism for train_unembed=True is well-handled, ensuring LoRA is applied to lm_head when required. The new tests for chunked loss are comprehensive, covering various batch and sequence lengths, and verifying numerical equivalence with the non-chunked approach. Overall, the changes are well-designed, correctly implemented, and thoroughly tested, contributing positively to the backend's efficiency and robustness.

raulchen and others added 6 commits January 20, 2026 18:55
Compute lm_head projection in chunks to avoid materializing the full
[B*T, V] logits tensor. Key changes:

- Add compute_logits flag to model.__call__ (skip lm_head when False)
- Add lm_head weight to CausalLMOutput for external computation
- Implement chunked logprobs with jax.lax.map (default chunk_size=1024)
- Add loss_chunk_size config option

Memory savings: O(B*T*V) -> O(chunk_size*V) for logits tensor.
For Qwen3-4B with V=151k, 8k seq: ~19GB -> ~300MB peak logits memory.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
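
As a rough sanity check on the chunk-side figure in the commit above (dtype and the exact vocab size are assumptions on my part; the full-tensor number additionally depends on the benchmark's batch size):

```python
# Peak logits memory for a single [chunk_size, V] tile, assuming bf16 (2 bytes/elem)
# and a vocab of 151,936 (the "~151k" mentioned above).
chunk_size, vocab, bytes_per_elem = 1024, 151_936, 2
chunk_bytes = chunk_size * vocab * bytes_per_elem
print(f"{chunk_bytes / 2**20:.0f} MiB")  # ~297 MiB, consistent with the ~300MB figure
```
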
…ze<=0

The chunked cross-entropy path computes logits via direct matmul with
lm_head weight, bypassing LoRA adapters. This is incorrect when
train_unembed=True since LoRA should be applied to lm_head.

Changes:
- Rename is_training to skip_logits for clarity
- Add _use_chunked_loss flag to backend
- Automatically switch to non-chunked mode when:
  - train_unembed=True (requires LoRA on lm_head)
  - loss_chunk_size <= 0 (config-based disable)
- Non-chunked path uses pre-computed logits with LoRA correctly applied
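
A schematic of that switching rule (illustrative; the real backend wires this through the _use_chunked_loss flag named above, and the exact config plumbing is assumed):

```python
def _should_use_chunked_loss(loss_chunk_size: int, train_unembed: bool) -> bool:
    # Chunking computes logits without LoRA on lm_head, so fall back whenever
    # an adapter trains the unembedding; loss_chunk_size <= 0 is an explicit off switch.
    return loss_chunk_size > 0 and not train_unembed
```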
raulchen and others added 20 commits January 21, 2026 13:28
- Resolve conflicts in llama3.py and qwen3.py
- Integrate LogitsProcessor from main
- Move chunked logprobs computation to LogitsProcessor.compute_chunked_logprobs
- Add LogitsProcessor.compute_logprobs() that handles both chunked and non-chunked paths
- Add _logits_to_logprobs() and _compute_chunked_logprobs() as private helpers
- Simplify jax.py to single compute_logprobs call
- LogitsProcessor is now a standalone utility with three static methods:
  compute_logits(), compute_logprobs(), logits_to_logprobs()
- Model forward() returns only hidden_states (removed logits computation)
- Simplified CausalLMOutput: removed logits and lm_head fields
- Generator uses LogitsProcessor for all logits/logprobs computation
- Backend uses LogitsProcessor.compute_logprobs() with chunking
- Updated tests to use new LogitsProcessor API

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
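
For reference, a minimal sketch of what a logits_to_logprobs() step of this shape typically does (illustrative only; the actual signature in tx/utils/logits_processor.py may differ):

```python
import jax
import jax.numpy as jnp

def logits_to_logprobs(logits: jnp.ndarray, target_ids: jnp.ndarray) -> jnp.ndarray:
    """logits: [N, V], target_ids: [N] -> per-token log-probabilities, shape [N]."""
    logprobs = jax.nn.log_softmax(logits, axis=-1)                    # normalize over vocab
    return jnp.take_along_axis(logprobs, target_ids[:, None], axis=-1)[:, 0]
```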
- Create CausalLMBase class with compute_logits/compute_logprobs methods
- Models expose wrapper methods instead of direct LogitsProcessor access
- Update generator and jax.py backend to use model methods
- LogitsProcessor is now internal implementation detail

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
- Replace _has_train_unembed flag with _train_unembed_mask array
- Check at runtime if any adapter in batch needs LoRA on lm_head
- Use jax.lax.cond to choose chunked vs non-chunked path
- Handle adapter reuse correctly (reset mask on delete)
- Remove unused _use_chunked_loss flag

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
- Replace abstract property with __init__(lm_head) in base class
- Subclasses explicitly call CausalLMBase.__init__(self, lm_head)
- Fix test to support multiple adapters for mixed train_unembed test

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
When tie_word_embeddings=True, lm_head is a lambda from LoRAEmbed.T

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This reverts the chunked logprobs feature while keeping the CausalLMBase
refactoring. Changes removed:
- _compute_chunked_logprobs method
- lm_head_weight property
- loss_chunk_size config
- _train_unembed_mask runtime check
- TestChunkedCrossEntropyLoss tests

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
- test_compute_logits: compare with HuggingFace logits
- test_compute_logprobs: verify equivalence with manual computation
- Remove generation tests (belong in generator tests)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
The prompt_logprobs computation was not passing adapter_indices to
compute_logprobs, which would cause incorrect results when using LoRA
adapters. Added test coverage for this case.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
raulchen and others added 11 commits January 22, 2026 17:24
- Remove get_lm_head_weight() abstract method (no longer needed)
- Chunked path now uses lm_head() directly instead of raw matmul
- Expand adapter_indices from [B] to [B*T] for per-token handling
- Remove restriction that disabled chunking with adapter_indices

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
- Remove _train_unembed_mask tracking from JaxBackend
- Simplify _model_forward to always pass adapter_indices to compute_logprobs
- Fix chunked path to reshape hidden states to [chunk_size, 1, H] for LoRA compatibility

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Instead of allocating [B*T] array via jnp.repeat, compute adapter
indices per-chunk using only a [chunk_size] buffer. This reduces
memory overhead significantly for long sequences.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
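
A rough illustration of the per-chunk index computation described above (the flat [B*T] token layout and function name are my assumptions):

```python
import jax.numpy as jnp

def chunk_adapter_indices(adapter_indices, chunk_start, chunk_size, seq_len):
    """adapter_indices: [B] per-sequence adapter ids -> [chunk_size] ids for the
    tokens in [chunk_start, chunk_start + chunk_size) of the flattened [B*T] axis,
    without materializing a [B*T] index array."""
    token_pos = chunk_start + jnp.arange(chunk_size)   # flat token positions in this chunk
    batch_idx = token_pos // seq_len                   # owning sequence for each token
    # Positions past B*T (padding in the last chunk) are clamped by JAX's default
    # out-of-bounds gather behaviour, so they just reuse the last sequence's id.
    return adapter_indices[batch_idx]
```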
Load HF model, get logits, save weights, delete HF model, then load
our model. This avoids having both models in memory simultaneously.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
- Add loss_chunk_size and gradient_checkpointing to config in tests
- Restore test_chunked_logprobs test that was lost during merge

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Resolved conflicts:
- tx/utils/logits_processor.py: Keep chunked logprobs implementation
- tx/utils/generator.py: Keep left-padded sequence handling from main
- tests/models/test_models_common.py: Keep chunked logprobs test

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Restructure test to avoid OOM by loading and deleting models sequentially
instead of having two models in memory simultaneously.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Restructure test to avoid OOM by creating and deleting backends
sequentially instead of having two in memory simultaneously.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
@raulchen (Contributor, Author)

/gemini review


@gemini-code-assist bot left a comment

Code Review

This pull request introduces a memory-efficient chunked log-probability computation, a significant optimization for training. The core logic is encapsulated within a new LogitsProcessorMixin, which conditionally routes to a chunked or non-chunked computation path based on the loss_chunk_size configuration, leveraging jax.lax.map for efficient token processing. A security audit focusing on memory safety, access control, injection, and LLM-specific safety found no high or critical severity vulnerabilities. The implementation is robust, well-tested, and appears well-executed for performance and memory efficiency.

raulchen and others added 6 commits January 23, 2026 08:46
Test coverage for:
- Chunk boundary cases (padding, exact division, larger than total)
- Adapter indices handling (None, per-batch, same for all)
- Gradient checkpointing flag

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
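
A self-contained toy version of the equivalence these boundary cases exercise (the real tests go through the backend's LogitsProcessorMixin; the helpers below are stand-ins):

```python
import jax
import jax.numpy as jnp
import pytest

def dense_logprobs(hidden, w, targets):
    lp = jax.nn.log_softmax(hidden @ w, axis=-1)               # full [N, V] logits
    return jnp.take_along_axis(lp, targets[:, None], axis=-1)[:, 0]

def chunked(hidden, w, targets, chunk_size):
    n, h = hidden.shape
    pad = (-n) % chunk_size                                     # boundary case: padding
    hp = jnp.pad(hidden, ((0, pad), (0, 0)))
    tp = jnp.pad(targets, (0, pad))
    out = jax.lax.map(
        lambda a: dense_logprobs(a[0], w, a[1]),
        (hp.reshape(-1, chunk_size, h), tp.reshape(-1, chunk_size)),
    )
    return out.reshape(-1)[:n]

@pytest.mark.parametrize("chunk_size", [8, 16, 50, 1024])  # padding, exact division, larger than total
def test_chunked_matches_dense(chunk_size):
    k0, k1, k2 = jax.random.split(jax.random.PRNGKey(0), 3)
    hidden = jax.random.normal(k0, (50, 32))
    w = jax.random.normal(k1, (32, 97))
    targets = jax.random.randint(k2, (50,), 0, 97)
    assert jnp.allclose(chunked(hidden, w, targets, chunk_size),
                        dense_logprobs(hidden, w, targets), atol=1e-5)
```
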

vercel bot commented Jan 25, 2026

@pcmoritz is attempting to deploy a commit to the Tyler's projects Team on Vercel.

A member of the Team first needs to authorize it.

@pcmoritz (Collaborator)

/gemini review


@gemini-code-assist bot left a comment

Code Review

This pull request introduces a memory-efficient chunked computation for log probabilities, which is a valuable optimization. The implementation is clean, and the new logic is well-supported by a comprehensive set of unit and integration tests that cover various edge cases for chunking and adapter handling. I have one suggestion to enhance the robustness of the new chunked computation path to ensure its behavior is perfectly aligned with the non-chunked path.

@pcmoritz merged commit e139c05 into NovaSky-AI:main Jan 26, 2026
4 of 6 checks passed
@pcmoritz mentioned this pull request Jan 27, 2026