[tx] chunked logprobs computation for memory efficiency #902
Conversation
Code Review
This pull request introduces a memory-efficient chunked cross-entropy loss computation, which is a significant improvement for training large models. The implementation correctly avoids materializing the full logits tensor by processing lm_head in chunks using jax.lax.map. The fallback mechanism for train_unembed=True is well-handled, ensuring LoRA is applied to lm_head when required. The new tests for chunked loss are comprehensive, covering various batch and sequence lengths, and verifying numerical equivalence with the non-chunked approach. Overall, the changes are well-designed, correctly implemented, and thoroughly tested, contributing positively to the backend's efficiency and robustness.
Compute the lm_head projection in chunks to avoid materializing the full [B*T, V] logits tensor.

Key changes:
- Add a compute_logits flag to model.__call__ (skip lm_head when False)
- Add the lm_head weight to CausalLMOutput for external computation
- Implement chunked logprobs with jax.lax.map (default chunk_size=1024)
- Add a loss_chunk_size config option

Memory savings: O(B*T*V) -> O(chunk_size*V) for the logits tensor. For Qwen3-4B with V=151k and an 8k sequence: ~19 GB -> ~300 MB peak logits memory.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
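A minimal sketch of the chunking idea (not the exact code in this PR): jax.lax.map runs the per-chunk scoring sequentially, so only one [chunk_size, V] logits block is live at a time; with chunk_size=1024 and V≈152k that block is roughly 1024 × 151,936 × 2 bytes ≈ 0.3 GB in bf16, consistent with the ~300 MB figure above. All names and signatures below are illustrative.

```python
import jax
import jax.numpy as jnp

def chunked_logprobs(hidden, lm_head_w, target_ids, chunk_size=1024):
    """Per-token log-probabilities without materializing [B*T, V] logits.

    hidden:     [B, T, H] final hidden states
    lm_head_w:  [H, V] unembedding weight
    target_ids: [B, T] token ids to score
    """
    B, T, H = hidden.shape
    flat_h = hidden.reshape(B * T, H)
    flat_ids = target_ids.reshape(B * T)

    # Pad the flattened token dimension up to a multiple of chunk_size.
    pad = (-(B * T)) % chunk_size
    flat_h = jnp.pad(flat_h, ((0, pad), (0, 0)))
    flat_ids = jnp.pad(flat_ids, (0, pad))

    def per_chunk(chunk):
        h_chunk, id_chunk = chunk                  # [C, H], [C]
        logits = h_chunk @ lm_head_w               # [C, V] block of logits
        logz = jax.nn.logsumexp(logits, axis=-1)   # normalizer per token
        tok = jnp.take_along_axis(logits, id_chunk[:, None], axis=-1)[:, 0]
        return tok - logz                          # log p(target token)

    n_chunks = (B * T + pad) // chunk_size
    out = jax.lax.map(                             # sequential over chunks
        per_chunk,
        (flat_h.reshape(n_chunks, chunk_size, H),
         flat_ids.reshape(n_chunks, chunk_size)),
    )
    return out.reshape(-1)[: B * T].reshape(B, T)
```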
…ze<=0

The chunked cross-entropy path computes logits via a direct matmul with the lm_head weight, bypassing LoRA adapters. This is incorrect when train_unembed=True, since LoRA should be applied to lm_head.

Changes:
- Rename is_training to skip_logits for clarity
- Add a _use_chunked_loss flag to the backend
- Automatically switch to non-chunked mode when:
  - train_unembed=True (requires LoRA on lm_head)
  - loss_chunk_size <= 0 (config-based disable)
- The non-chunked path uses pre-computed logits with LoRA correctly applied
- Resolve conflicts in llama3.py and qwen3.py
- Integrate LogitsProcessor from main
- Move the chunked logprobs computation to LogitsProcessor.compute_chunked_logprobs
- Add LogitsProcessor.compute_logprobs() that handles both chunked and non-chunked paths
- Add _logits_to_logprobs() and _compute_chunked_logprobs() as private helpers
- Simplify jax.py to a single compute_logprobs call
- LogitsProcessor is now a standalone utility with three static methods: compute_logits(), compute_logprobs(), logits_to_logprobs()
- Model forward() returns only hidden_states (removed the logits computation)
- Simplified CausalLMOutput: removed the logits and lm_head fields
- Generator uses LogitsProcessor for all logits/logprobs computation
- Backend uses LogitsProcessor.compute_logprobs() with chunking
- Updated tests to use the new LogitsProcessor API

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
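A rough sketch of what that standalone utility could look like; the method names come from the commit message, but the bodies here are simplified stand-ins (the real compute_logprobs also has the chunked branch shown earlier):

```python
import jax
import jax.numpy as jnp

class LogitsProcessor:
    """Simplified stand-in for the utility described above."""

    @staticmethod
    def compute_logits(hidden_states, lm_head):
        # Project hidden states through the unembedding callable.
        return lm_head(hidden_states)

    @staticmethod
    def logits_to_logprobs(logits, target_ids):
        # Log-softmax over the vocabulary, then gather the target tokens.
        logprobs = jax.nn.log_softmax(logits, axis=-1)
        return jnp.take_along_axis(logprobs, target_ids[..., None], axis=-1)[..., 0]

    @staticmethod
    def compute_logprobs(hidden_states, lm_head, target_ids):
        # Non-chunked path only; the real method can also chunk the tokens.
        logits = LogitsProcessor.compute_logits(hidden_states, lm_head)
        return LogitsProcessor.logits_to_logprobs(logits, target_ids)
```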
- Create a CausalLMBase class with compute_logits/compute_logprobs methods
- Models expose wrapper methods instead of direct LogitsProcessor access
- Update the generator and the jax.py backend to use the model methods
- LogitsProcessor is now an internal implementation detail

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
- Replace the _has_train_unembed flag with a _train_unembed_mask array
- Check at runtime whether any adapter in the batch needs LoRA on lm_head
- Use jax.lax.cond to choose the chunked vs non-chunked path
- Handle adapter reuse correctly (reset mask on delete)
- Remove the unused _use_chunked_loss flag

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
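A hypothetical sketch of that runtime dispatch: the mask is indexed by the batch's adapter indices and jax.lax.cond selects the path, since (at this point in the PR) the chunked path bypasses LoRA on lm_head. Function and argument names are assumptions.

```python
import jax
import jax.numpy as jnp

def dispatch_logprobs(train_unembed_mask, adapter_indices,
                      non_chunked_fn, chunked_fn, hidden, target_ids):
    """Choose the logprob path at runtime with jax.lax.cond.

    train_unembed_mask: [num_adapters] bool, True where the adapter applies
        LoRA to lm_head.
    adapter_indices:    [B] adapter id per sequence in the batch.
    Both *_fn callables take (hidden, target_ids) and must return arrays of
    identical shape and dtype, e.g. [B, T] logprobs.
    """
    # True if any adapter used by this batch trains the unembedding; the
    # chunked path would then be incorrect, so fall back to full logits.
    needs_lora_on_lm_head = jnp.any(train_unembed_mask[adapter_indices])
    return jax.lax.cond(
        needs_lora_on_lm_head,
        non_chunked_fn,   # branch taken when the predicate is True
        chunked_fn,       # branch taken when it is False
        hidden, target_ids,
    )
```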
- Replace the abstract property with __init__(lm_head) in the base class
- Subclasses explicitly call CausalLMBase.__init__(self, lm_head)
- Fix test to support multiple adapters for the mixed train_unembed test

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
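A minimal sketch of the explicit-init pattern, with an invented subclass name standing in for the real llama3/qwen3 models:

```python
class CausalLMBase:
    """Sketch of the base class; real signatures may differ."""

    def __init__(self, lm_head):
        # Subclasses hand over their unembedding callable explicitly.
        self.lm_head = lm_head

    def compute_logits(self, hidden_states):
        return self.lm_head(hidden_states)


class TinyCausalLM(CausalLMBase):
    """Invented subclass for illustration."""

    def __init__(self, lm_head_weight):
        # Explicit base-class call, as described in the commit above.
        CausalLMBase.__init__(self, lambda h: h @ lm_head_weight.T)
```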
When tie_word_embeddings=True, lm_head is a lambda from LoRAEmbed.T.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
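For illustration only (the variable names are made up): with tied word embeddings there is no separate lm_head weight, so the unembedding is a thin callable over the transposed embedding table.

```python
import jax.numpy as jnp

# Stand-in embedding table; in the model this would be the (possibly
# LoRA-wrapped) word embedding, hence "a lambda from LoRAEmbed.T".
embed_weight = jnp.zeros((151_936, 4096))          # [V, H]
lm_head = lambda hidden: hidden @ embed_weight.T   # [..., H] -> [..., V]
```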
This reverts the chunked logprobs feature while keeping the CausalLMBase refactoring. Changes removed:
- _compute_chunked_logprobs method
- lm_head_weight property
- loss_chunk_size config
- _train_unembed_mask runtime check
- TestChunkedCrossEntropyLoss tests

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
- test_compute_logits: compare with HuggingFace logits
- test_compute_logprobs: verify equivalence with manual computation
- Remove generation tests (they belong in the generator tests)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
The prompt_logprobs computation was not passing adapter_indices to compute_logprobs, which would cause incorrect results when using LoRA adapters. Added test coverage for this case. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
- Remove the get_lm_head_weight() abstract method (no longer needed)
- Chunked path now uses lm_head() directly instead of a raw matmul
- Expand adapter_indices from [B] to [B*T] for per-token handling
- Remove the restriction that disabled chunking with adapter_indices

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
- Remove _train_unembed_mask tracking from JaxBackend
- Simplify _model_forward to always pass adapter_indices to compute_logprobs
- Fix the chunked path to reshape hidden states to [chunk_size, 1, H] for LoRA compatibility

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Instead of allocating a [B*T] array via jnp.repeat, compute adapter indices per chunk using only a [chunk_size] buffer. This shrinks the index buffer from O(B*T) to O(chunk_size), which matters for long sequences.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
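A minimal sketch of that per-chunk index computation, assuming row-major [B, T] flattening and one adapter id per sequence; names are illustrative.

```python
import jax.numpy as jnp

def chunk_adapter_indices(adapter_indices, chunk_idx, chunk_size, seq_len):
    """Adapter ids for one chunk, without materializing a [B*T] array.

    adapter_indices: [B] adapter id per sequence
    chunk_idx:       index of the chunk in the flattened [B*T] token stream
    Returns a [chunk_size] buffer of adapter ids.
    """
    # Flattened token positions covered by this chunk.
    token_pos = chunk_idx * chunk_size + jnp.arange(chunk_size)   # [chunk_size]
    # Which sequence each token belongs to under row-major flattening.
    seq_idx = token_pos // seq_len                                # [chunk_size]
    # Clip so padding positions past B*T map to a valid (ignored) sequence.
    seq_idx = jnp.clip(seq_idx, 0, adapter_indices.shape[0] - 1)
    return adapter_indices[seq_idx]                               # [chunk_size]
```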
Load HF model, get logits, save weights, delete HF model, then load our model. This avoids having both models in memory simultaneously. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
- Add loss_chunk_size and gradient_checkpointing to the config in tests
- Restore the test_chunked_logprobs test that was lost during the merge

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Resolved conflicts:
- tx/utils/logits_processor.py: keep the chunked logprobs implementation
- tx/utils/generator.py: keep the left-padded sequence handling from main
- tests/models/test_models_common.py: keep the chunked logprobs test

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Restructure test to avoid OOM by loading and deleting models sequentially instead of having two models in memory simultaneously. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Restructure test to avoid OOM by creating and deleting backends sequentially instead of having two in memory simultaneously. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
/gemini review
Code Review
This pull request introduces a memory-efficient chunked log-probability computation, a significant optimization for training. The core logic is encapsulated within a new LogitsProcessorMixin, which conditionally routes to a chunked or non-chunked computation path based on the loss_chunk_size configuration, leveraging jax.lax.map for efficient token processing. A security audit focusing on memory safety, access control, injection, and LLM-specific safety found no high or critical severity vulnerabilities. The implementation is robust, well-tested, and appears well-executed for performance and memory efficiency.
Test coverage for:
- Chunk boundary cases (padding, exact division, chunk larger than total)
- Adapter indices handling (None, per-batch, same for all)
- Gradient checkpointing flag

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
/gemini review
Code Review
This pull request introduces a memory-efficient chunked computation for log probabilities, which is a valuable optimization. The implementation is clean, and the new logic is well-supported by a comprehensive set of unit and integration tests that cover various edge cases for chunking and adapter handling. I have one suggestion to enhance the robustness of the new chunked computation path to ensure its behavior is perfectly aligned with the non-chunked path.
Summary
- Add _compute_chunked_logprobs to LogitsProcessorMixin that computes log probabilities in chunks, avoiding materialization of the full [B*T, V] logits tensor
- Add a loss_chunk_size config field (default 1024) to control the chunk size; set to 0 to disable (an illustrative config sketch follows this list)
- Add a gradient_checkpointing config field to enable recomputation during the backward pass
- Adapter indices are computed on-the-fly per chunk
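An illustrative (not verbatim) shape for the two config fields named above:

```python
from dataclasses import dataclass

@dataclass
class BackendConfig:
    """Illustrative only; the real config class and defaults live in the repo."""

    # Tokens per chunk for the chunked logprob computation; a value <= 0
    # disables chunking and uses the full-logits path.
    loss_chunk_size: int = 1024
    # Recompute activations during the backward pass to trade compute for memory.
    gradient_checkpointing: bool = False
```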
Test plan
- test_chunked_logprobs verifies that the chunked and non-chunked paths produce identical results across different chunk sizes (8, 16, 32)
- test_mixed_train_unembed_adapters verifies the chunked path works correctly with mixed train_unembed adapters
- test_logits_processor.py
- Benchmark results: #891 (comment)