[tx] DeepseekV3 implementation #889
base: main
Conversation
Code Review
This pull request introduces the JAX implementation for the DeepseekV3 model. The implementation is comprehensive and covers the model's unique features like Multi-Head Latent Attention and Mixture of Experts with shared experts. The code is well-structured.
My review focuses on a critical bug that will prevent the model from running, along with some suggestions to improve maintainability by reducing code duplication and avoiding magic numbers. Addressing these points will make the implementation more robust and easier to maintain.
skyrl-tx/tx/models/deepseekv3.py
Outdated
```python
# Precompute RoPE frequencies
# qk_rope_head_dim = config.qk_rope_head_dim
# original_seq_len = getattr(config, "original_seq_len", config.max_position_embeddings)
# rope_factor = getattr(config, "rope_factor", 1.0)
# beta_fast = getattr(config, "beta_fast", 32)
# beta_slow = getattr(config, "beta_slow", 1)

# TODO: Swap out like llama's rope?
# self.freqs_cis = precompute_freqs_cis(
#     dim=qk_rope_head_dim,
#     max_seq_len=config.max_position_embeddings,
#     rope_theta=config.rope_theta,
#     original_seq_len=original_seq_len,
#     rope_factor=rope_factor,
#     beta_fast=beta_fast,
#     beta_slow=beta_slow,
# )
```
This block for precomputing RoPE frequencies is commented out, but self.freqs_cis is used in DeepseekV3Model.__call__ at line 571. This will raise an AttributeError at runtime.
Looking at the DeepseekV3MLA implementation, the freqs_cis parameter is not used. Instead, apply_rope is called, which computes the frequencies on the fly.
To fix this, you should remove the freqs_cis parameter from the entire call chain, as it appears to be unused. This involves:
- Removing `freqs_cis: jax.Array` from the signature of `DeepseekV3MLA.__call__`.
- Removing `freqs_cis: jax.Array` from the signature of `DeepseekV3DecoderLayer.__call__`.
- Removing the `freqs_cis=self.freqs_cis` argument from the `layer()` call within `DeepseekV3Model.__call__`.
This will resolve the crash and align the code with the current apply_rope implementation. You can then address the TODO about swapping the RoPE implementation in a separate change.
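To illustrate why no precomputed `freqs_cis` buffer is needed, here is a minimal numpy sketch of computing rotary frequencies on the fly, in the spirit of the on-the-fly `apply_rope` the review refers to. The function names and signatures here are illustrative, not the actual ones in `deepseekv3.py`:

```python
import numpy as np

def rope_frequencies(positions, head_dim, theta=10000.0):
    """Compute RoPE cos/sin tables on the fly for the given positions.

    Because these are derived from `positions` at call time, no
    precomputed `freqs_cis` buffer has to be stored on the module.
    """
    # One inverse frequency per rotated pair of channels.
    inv_freq = 1.0 / (theta ** (np.arange(0, head_dim, 2) / head_dim))
    angles = np.outer(positions, inv_freq)  # (seq_len, head_dim // 2)
    return np.cos(angles), np.sin(angles)

def apply_rope(x, positions, theta=10000.0):
    """Rotate each even/odd channel pair of `x` (shape: seq_len, head_dim)."""
    cos, sin = rope_frequencies(positions, x.shape[-1], theta)
    x1, x2 = x[..., 0::2], x[..., 1::2]
    rotated = np.empty_like(x)
    rotated[..., 0::2] = x1 * cos - x2 * sin
    rotated[..., 1::2] = x1 * sin + x2 * cos
    return rotated
```

Since the rotation at position 0 is the identity and every pair rotation is norm-preserving, this is easy to sanity-check in a unit test.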
skyrl-tx/tx/models/deepseekv3.py
Outdated
```python
# Bias only for specific model sizes (7168 hidden_size in original)
self.use_bias = config.hidden_size == 7168
```
Hardcoding the magic number 7168 to determine self.use_bias makes the code brittle and less maintainable. If a new model variant is introduced that also requires this bias, this line would need to be updated. A better approach would be to introduce a dedicated boolean flag in the DeepseekV3Config, such as use_router_bias, to control this behavior explicitly.
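A minimal sketch of the suggested change, assuming a dataclass-style config (the field names besides `hidden_size` are illustrative; the real `DeepseekV3Config` has many more fields):

```python
from dataclasses import dataclass

@dataclass
class DeepseekV3Config:
    """Illustrative subset of the config."""
    hidden_size: int = 7168
    use_router_bias: bool = True  # explicit flag replaces `hidden_size == 7168`

class Router:
    def __init__(self, config):
        # Read the explicit flag instead of comparing against a magic number.
        self.use_bias = config.use_router_bias
```

New model variants can then opt in or out of the router bias directly in their config, with no code change to the router itself.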
skyrl-tx/tx/models/deepseekv3.py
Outdated
```python
# Bias only for specific model sizes (7168 hidden_size in original)
self.use_bias = config.hidden_size == 7168
if self.use_bias:
    from tx.layers.util import Param
```
This import statement is located inside a conditional block within __init__. According to PEP 8, all imports should be at the top of the file. This improves code readability and avoids potential circular import issues or unexpected behavior. Please move from tx.layers.util import Param to the top of the file with the other imports.
skyrl-tx/tx/models/deepseekv3.py
Outdated
```python
class DeepseekV3SharedMLP(nnx.Module):
    """Always active shared experts."""
```
The DeepseekV3SharedMLP class is nearly identical to DeepseekV3MLP, with the only significant difference being the intermediate_size. This creates code duplication, which can make maintenance harder.
To improve this, consider refactoring them into a single, more generic MLP class (e.g., SwiGLU) that accepts intermediate_size as a parameter in its __init__ method. You could then instantiate this class with config.intermediate_size for the standard MLP and with the calculated shared_inter_dim for the shared MLP part.
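A numpy sketch of the proposed refactor (the class name `SwiGLUMLP` and the weight-initialization details are illustrative; in the actual file this would be an `nnx.Module` with `nnx.Linear` layers):

```python
import numpy as np

def silu(x):
    """SiLU activation: x * sigmoid(x)."""
    return x / (1.0 + np.exp(-x))

class SwiGLUMLP:
    """Generic SwiGLU MLP parameterized by `intermediate_size`.

    One class serves both roles: instantiate with
    `config.intermediate_size` for the routed-expert MLP, or with the
    calculated `shared_inter_dim` for the shared-expert MLP.
    """
    def __init__(self, hidden_size, intermediate_size, rng):
        self.gate = rng.normal(0.0, 0.02, (hidden_size, intermediate_size))
        self.up = rng.normal(0.0, 0.02, (hidden_size, intermediate_size))
        self.down = rng.normal(0.0, 0.02, (intermediate_size, hidden_size))

    def __call__(self, x):
        # SwiGLU: down-project the gated elementwise product.
        return (silu(x @ self.gate) * (x @ self.up)) @ self.down
```

Both instantiations produce outputs of shape `(..., hidden_size)`, so the calling code in the decoder layer does not change.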
@pcmoritz The PR is open for review now. In the first test case I've added a TODO: there seems to be some kind of drift which requires the absolute tolerance to be around ~6e-3 for the tests to pass. I'll investigate a little more; nothing has caught my eye so far.
Fixed the source of the drift; there was a default config mismatch.
This is awesome! Have you already gotten some end-to-end training working with it? It would be great to add one to https://github.com/NovaSky-AI/SkyRL/blob/main/skyrl-tx/README.md. If you haven't, I'm also more than happy to help with it :)
Looks like the tests are failing; I'm unable to replicate this on my machine somehow. Some qwen tests also seem to be failing - is this expected? I have not been able to train end-to-end yet, but will give it a shot over the weekend (with any further fixes required). Added a task for it in the PR description.
Failing tests root cause: Hugging Face outputs are not consistent between macOS and Ubuntu (Accelerate vs MKL).

Linux
- OS: Linux 6.8.0-90-generic
- Machine: x86_64
- Python: 3.12.12
- PyTorch: 2.10.0+cu128
- PyTorch BLAS: mkl
- CUDA available: False
- Transformers: 4.57.6
- DEEPSEEK V3 TEST — HF hidden_states[-1] first 10 values (sample 0, pos 0):

macOS
- OS: Darwin 25.2.0
- Machine: arm64
- PyTorch BLAS: accelerate
- Transformers: 4.57.6
- DEEPSEEK V3 TEST

Bumping thresholds.
@tanmaysachan is attempting to deploy a commit to the Tyler's projects Team on Vercel. A member of the Team first needs to authorize it.
- Add LogitsProcessorMixin to DeepseekV3ForCausalLM
- Add get_lm_head() method for logits computation
- Fix broken compute_positions import
- Fix init_lora_adapter to handle n_routed_experts attribute
- Add test_deepseekv3_lora_training.py with MoE rank normalization tests

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Force-pushed from 23cfe46 to 7208b55
Accidental deployment attempt due to rebase ^
End-to-end training successful on an A100. /api/v1/healthz -> {"status":"ok"}

Added GPU tests (these need Anyscale creds to run). GPU tests on A100:
Addresses #865