diff --git a/skyrl-tx/README.md b/skyrl-tx/README.md index eae979497..3092c9f67 100644 --- a/skyrl-tx/README.md +++ b/skyrl-tx/README.md @@ -19,7 +19,7 @@ SkyRL tx is an open-source library that implements a backend for the [Tinker API - **Multi-User LoRA Support** — Efficient GPU sharing across users with individual adapters - **SFT & RL Support** — Supervised fine-tuning and reinforcement learning with PPO and custom loss functions - **Multi-Node Training** — FSDP and tensor parallelism for distributed training -- **Multiple Model Architectures** — Support for Qwen3 (dense & MoE) and Llama 3 +- **Multiple Model Architectures** — Support for Qwen3 (dense & MoE), Llama 3, and DeepSeek V3 - **External Inference Engine** — Optional vLLM integration for optimized inference - **Production Ready** — PostgreSQL support, cloud storage checkpoints, and database migrations @@ -229,6 +229,7 @@ CUDA_VISIBLE_DEVICES=0,1,2,3 uv run --extra gpu --extra tinker -m tx.tinker.api | Qwen3 Dense Models | ✅ | | Qwen3 MoE Models | ✅ | | Llama 3 Models | ✅ | +| DeepSeek V3 Models | ✅ | | Multi-User LoRA | ✅ | | LoRA (all layers) | ✅ | | Forward/Backward | ✅ | diff --git a/skyrl-tx/tests/models/test_deepseekv3.py b/skyrl-tx/tests/models/test_deepseekv3.py new file mode 100644 index 000000000..188917e12 --- /dev/null +++ b/skyrl-tx/tests/models/test_deepseekv3.py @@ -0,0 +1,199 @@ +import os +import tempfile + +from flax import nnx +import jax +import jax.numpy as jnp +import numpy as np +import pytest +import torch +from transformers import AutoModelForCausalLM, AutoTokenizer, PretrainedConfig +from transformers.models.deepseek_v3.modeling_deepseek_v3 import DeepseekV3MoE as HFDeepseekV3MoE +from tx.layers.lora import LoRAMixin +from tx.models.configs import DeepseekV3Config +from tx.models.deepseekv3 import DeepseekV3ForCausalLM, DeepseekV3MoE +from tx.utils.models import load_safetensors + + +@pytest.mark.parametrize("tp", [1, 2]) +def test_deepseekv3(tp: int): + if not jax._src.xla_bridge.backends_are_initialized(): + jax.config.update("jax_num_cpu_devices", 2) + + if tp > 1 and os.getenv("CI"): + pytest.skip("TP > 1 currently runs out of memory in the CI") + + model_name = "yujiepan/deepseek-v3-tiny-random" + tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) + hf_model = AutoModelForCausalLM.from_pretrained( + model_name, attn_implementation="eager", use_safetensors=True, trust_remote_code=True + ) + + inputs = ["The capital of France is", "The most popular programming language is"] + batch = tokenizer(inputs, return_tensors="pt", padding=True) + with torch.no_grad(): + hf_outputs = hf_model( + batch.input_ids, + attention_mask=batch.attention_mask, + output_hidden_states=True, + return_dict=True, + use_cache=False, + ) + + # Save the HF model checkpoint so we can load our model from it + with tempfile.TemporaryDirectory() as tmp: + hf_model.save_pretrained(tmp, safe_serialization=True) + + base_config = PretrainedConfig.from_pretrained(model_name, trust_remote_code=True) + config = DeepseekV3Config(base_config, max_lora_adapters=32, max_lora_rank=32, shard_attention_heads=True) + mesh = jax.make_mesh( + (1, tp), + ("fsdp", "tp"), + axis_types=(jax.sharding.AxisType.Auto, jax.sharding.AxisType.Auto), + ) + with jax.set_mesh(mesh): + model = DeepseekV3ForCausalLM(config, dtype=jnp.float32, rngs=nnx.Rngs(0)) + load_safetensors(tmp, config, model) + + outputs = model(batch.input_ids.numpy(), attention_mask=batch.attention_mask.numpy(), output_hidden_states=True) + + assert 
outputs.hidden_states is not None + assert np.allclose(hf_outputs.hidden_states[0], outputs.hidden_states[0], rtol=1e-6) + assert np.allclose(hf_outputs.hidden_states[1], outputs.hidden_states[1], rtol=1e-3, atol=1e-3) + assert np.allclose(hf_outputs.hidden_states[-1], outputs.hidden_states[-1], rtol=3e-2, atol=6e-2) + + +def load_moe_base_weights(jax_moe_layer: DeepseekV3MoE, hf_moe_layer: HFDeepseekV3MoE) -> None: + """Load base weights from HF MoE layer to JAX MoE layer.""" + jax_moe_layer.gate.weight[:] = hf_moe_layer.gate.weight.detach().numpy().T + jax_moe_layer.gate.e_score_correction_bias[:] = hf_moe_layer.gate.e_score_correction_bias.detach().numpy() + + for i, expert in enumerate(hf_moe_layer.experts): + jax_moe_layer.experts.gate_proj.weight[i, :, :] = expert.gate_proj.weight.detach().numpy().T + jax_moe_layer.experts.up_proj.weight[i, :, :] = expert.up_proj.weight.detach().numpy().T + jax_moe_layer.experts.down_proj.weight[i, :, :] = expert.down_proj.weight.detach().numpy().T + + jax_moe_layer.shared_experts.gate_proj.kernel[:] = hf_moe_layer.shared_experts.gate_proj.weight.detach().numpy().T + jax_moe_layer.shared_experts.up_proj.kernel[:] = hf_moe_layer.shared_experts.up_proj.weight.detach().numpy().T + jax_moe_layer.shared_experts.down_proj.kernel[:] = hf_moe_layer.shared_experts.down_proj.weight.detach().numpy().T + + +def test_deepseekv3_moe_layer(): + model_name = "yujiepan/deepseek-v3-tiny-random" + hf_model = AutoModelForCausalLM.from_pretrained(model_name, attn_implementation="eager", use_safetensors=True) + base_config = PretrainedConfig.from_pretrained(model_name) + config = DeepseekV3Config(base_config, max_lora_adapters=0, max_lora_rank=0, shard_attention_heads=True) + + # Initial deepseek layers don't have MoE + hf_moe_layer = hf_model.model.layers[1].mlp + torch.manual_seed(42) + x = torch.randn(4, 2, config.hidden_size) + with torch.no_grad(): + hf_expert_output = hf_moe_layer.forward(x) + + mesh = jax.make_mesh( + (1, 1), + ("fsdp", "tp"), + axis_types=(jax.sharding.AxisType.Auto, jax.sharding.AxisType.Auto), + ) + with jax.set_mesh(mesh): + moe_layer = DeepseekV3MoE(config, dtype=jnp.float32, rngs=nnx.Rngs(0)) + load_moe_base_weights(moe_layer, hf_moe_layer) + + jax_expert_output = moe_layer(x.numpy()) + + # Higher tolerance due to cross-platform BLAS differences + assert np.allclose(hf_expert_output.detach().numpy(), jax_expert_output, rtol=6e-3, atol=6e-3) + + +def load_lora_weights( + jax_module: LoRAMixin, + adapter_idx: int, + lora_A_weights: np.ndarray, + lora_B_weights: np.ndarray, + scaling: float, + rank: int, +) -> None: + """Load LoRA weights from numpy arrays to JAX module.""" + assert ( + jax_module.lora_A is not None + and jax_module.lora_B is not None + and jax_module.lora_scaling is not None + and jax_module.lora_ranks is not None + ) + jax_module.lora_A[...] = jax_module.lora_A[...].at[adapter_idx].set(jnp.array(lora_A_weights)) + jax_module.lora_B[...] = jax_module.lora_B[...].at[adapter_idx].set(jnp.array(lora_B_weights)) + jax_module.lora_scaling[...] = jax_module.lora_scaling[...].at[adapter_idx].set(scaling) + jax_module.lora_ranks[...] 
= jax_module.lora_ranks[...].at[adapter_idx].set(rank) + + +def test_deepseekv3_moe_layer_lora(): + """Test MoE LoRA by merging adapter into base weights and comparing outputs.""" + model_name = "yujiepan/deepseek-v3-tiny-random" + hf_model = AutoModelForCausalLM.from_pretrained(model_name, attn_implementation="eager", use_safetensors=True) + base_config = PretrainedConfig.from_pretrained(model_name) + config = DeepseekV3Config(base_config, max_lora_adapters=3, max_lora_rank=4, shard_attention_heads=True) + + hf_moe_layer = hf_model.model.layers[1].mlp + x = torch.randn(3, 4, config.hidden_size) + + mesh = jax.make_mesh( + (1, 1), + ("fsdp", "tp"), + axis_types=(jax.sharding.AxisType.Auto, jax.sharding.AxisType.Auto), + ) + with jax.set_mesh(mesh): + moe_layer = DeepseekV3MoE(config, dtype=jnp.float32, rngs=nnx.Rngs(0)) + load_moe_base_weights(moe_layer, hf_moe_layer) + + # Set LoRA weights for all adapters + rng = np.random.default_rng(42) + scaling = 2.0 + rank = config.max_lora_rank + for adapter_idx in range(config.max_lora_adapters): + for proj in [moe_layer.experts.gate_proj, moe_layer.experts.up_proj, moe_layer.experts.down_proj]: + assert proj.lora_A is not None and proj.lora_B is not None + lora_A = rng.normal(0, 1.0, proj.lora_A[...].shape[1:]) + lora_B = rng.normal(0, 1.0, proj.lora_B[...].shape[1:]) + load_lora_weights(proj, adapter_idx, lora_A, lora_B, scaling, rank) + + # Test with different adapters per sample + adapter_indices = jnp.array([0, 2, 1]) + output_with_lora = moe_layer(x.numpy(), adapter_indices=adapter_indices) + + # Test each sample by comparing with merged weights for its adapter + for sample_idx in range(len(adapter_indices)): + adapter_idx = int(adapter_indices[sample_idx]) + + # Create merged model by adding LoRA weights to base weights + moe_layer_merged = DeepseekV3MoE(config, dtype=jnp.float32, rngs=nnx.Rngs(1 + adapter_idx)) + + # Copy router weights + moe_layer_merged.gate.weight[:] = moe_layer.gate.weight[:] + moe_layer_merged.gate.e_score_correction_bias[:] = moe_layer.gate.e_score_correction_bias[:] + + # Copy shared experts weights + moe_layer_merged.shared_experts.gate_proj.kernel[:] = moe_layer.shared_experts.gate_proj.kernel[:] + moe_layer_merged.shared_experts.up_proj.kernel[:] = moe_layer.shared_experts.up_proj.kernel[:] + moe_layer_merged.shared_experts.down_proj.kernel[:] = moe_layer.shared_experts.down_proj.kernel[:] + + for proj_name in ["gate_proj", "up_proj", "down_proj"]: + proj = getattr(moe_layer.experts, proj_name) + proj_merged = getattr(moe_layer_merged.experts, proj_name) + + # For each expert, merge: base + scaling * (lora_A @ lora_B) + for expert_idx in range(config.n_routed_experts): + lora_A = proj.lora_A[adapter_idx, expert_idx, :, :] + lora_B = proj.lora_B[adapter_idx, expert_idx, :, :] + lora_delta = scaling * (lora_A @ lora_B) + + # Copy base weight AND add LoRA delta + base_weight = proj.weight[expert_idx, :, :] + merged_weight = base_weight + lora_delta + proj_merged.weight[...] 
= proj_merged.weight[...].at[expert_idx, :, :].set(merged_weight) + + # Run merged model on this sample + x_sample = x[sample_idx : sample_idx + 1].numpy() + output_merged = moe_layer_merged(x_sample) + + assert np.allclose(output_with_lora[sample_idx : sample_idx + 1], output_merged, rtol=1e-3, atol=1e-3) diff --git a/skyrl-tx/tests/models/test_deepseekv3_lora_training.py b/skyrl-tx/tests/models/test_deepseekv3_lora_training.py new file mode 100644 index 000000000..ab1038d2b --- /dev/null +++ b/skyrl-tx/tests/models/test_deepseekv3_lora_training.py @@ -0,0 +1,227 @@ +from flax import nnx +import jax +import jax.numpy as jnp +import optax +from huggingface_hub import snapshot_download +from transformers import PretrainedConfig + +from tx.models.configs import DeepseekV3Config +from tx.models.deepseekv3 import DeepseekV3ForCausalLM +from tx.utils.models import get_dtype, load_safetensors +from tx.layers.lora import init_lora_adapter +from tx.tinker.types import LoraConfig + + +def _is_routed_expert_path(path) -> bool: + """Disambiguate shared_experts and experts""" + keys = [] + for p in path: + if hasattr(p, "key"): + keys.append(str(p.key)) + elif hasattr(p, "name"): + keys.append(str(p.name)) + + for i, key in enumerate(keys): + if key == "experts" and i > 0 and keys[i - 1] == "mlp": + return True + return False + + +def _get_out_of_rank_params(params, adapter_idx: int, rank: int, num_experts: int): + """Extract out-of-rank params, using effective rank for routed expert layers.""" + + def slice_param(path, p): + path_str = str(path) + + if _is_routed_expert_path(path): + effective_rank = max(1, rank // num_experts) + else: + effective_rank = rank + + if "lora_A" in path_str: + # lora_A shape: [adapters, ..., max_rank] - slice last dim + return p[adapter_idx, ..., effective_rank:].copy() + elif "lora_B" in path_str: + # lora_B shape: [adapters, ..., max_rank, out] - slice second-to-last dim + return p[adapter_idx, ..., effective_rank:, :].copy() + return p + + return jax.tree.map_with_path(slice_param, params) + + +def test_lora_training_moe_rank_normalized(): + base_model = "yujiepan/deepseek-v3-tiny-random" + base_config = PretrainedConfig.from_pretrained(base_model, trust_remote_code=True) + config = DeepseekV3Config(base_config, max_lora_adapters=5, max_lora_rank=32, shard_attention_heads=True) + + checkpoint_path = snapshot_download(base_model, allow_patterns=["*.safetensors"]) + mesh = jax.make_mesh( + (1, 1), + ("fsdp", "tp"), + axis_types=(jax.sharding.AxisType.Auto, jax.sharding.AxisType.Auto), + ) + with jax.set_mesh(mesh): + model = DeepseekV3ForCausalLM(config, dtype=get_dtype(config.dtype), rngs=nnx.Rngs(0)) + load_safetensors(checkpoint_path, config, model) + + # Set different ranks for each adapter (0: rank 16, 1: rank 8) + # For routed experts with 256 experts: effective rank = max(1, rank // 256) = 1 + # For other layers: effective rank = configured rank + init_lora_adapter(model, adapter_index=0, lora_config=LoraConfig(rank=16, alpha=16, seed=0)) + init_lora_adapter(model, adapter_index=1, lora_config=LoraConfig(rank=8, alpha=8, seed=1)) + + optimizer = nnx.Optimizer(model, optax.adamw(1e-4), wrt=model.is_lora_param) + + batch = jnp.array([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10], [11, 12, 13, 14, 15, 16, 17, 18, 19, 20]], dtype=jnp.int32) + target_ids = batch[:, 1:] + input_ids = batch[:, :-1] + adapter_indices = jnp.array([0, 1], dtype=jnp.int32) + attention_mask = jnp.ones_like(input_ids) + + def loss_fn(model, input_ids, target_ids, attention_mask): + outputs = 
model(input_ids, attention_mask=attention_mask, adapter_indices=adapter_indices) + logits = model.compute_logits(outputs.last_hidden_state, adapter_indices) + return optax.softmax_cross_entropy_with_integer_labels(logits=logits, labels=target_ids).mean() + + graphdef, lora_params, non_lora_params = nnx.split(model, model.is_lora_param, ...) + + def get_adapter_params(params, adapter_idx): + return jax.tree.map(lambda p: p[adapter_idx].copy(), params) + + num_experts = config.n_routed_experts + + # Save initial states + initial_adapter_2_params = get_adapter_params(lora_params, 2) + initial_adapter_0_out_of_rank = _get_out_of_rank_params(lora_params, 0, 16, num_experts) + initial_adapter_1_out_of_rank = _get_out_of_rank_params(lora_params, 1, 8, num_experts) + + initial_loss = None + + # Training loop + for step in range(10): + + def loss_for_lora(lora_params): + merged_model = nnx.merge(graphdef, lora_params, non_lora_params) + return loss_fn(merged_model, input_ids, target_ids, attention_mask) + + loss_and_grad_fn = nnx.value_and_grad(loss_for_lora) + loss, lora_grads = loss_and_grad_fn(lora_params) + + if initial_loss is None: + initial_loss = float(loss) + + optimizer.update(lora_params, lora_grads) + + print(f"Step {step}: loss = {float(loss):.4f}") + + final_loss = float(loss) + + def verify_params_unchanged(initial_params, final_params, error_msg_prefix): + for (path, initial), (_, final) in zip( + jax.tree.leaves_with_path(initial_params), jax.tree.leaves_with_path(final_params) + ): + assert jnp.allclose(initial, final), f"{error_msg_prefix} for {path}" + + assert final_loss < initial_loss, f"Loss did not decrease: {initial_loss} -> {final_loss}" + + # Verify unused adapter was not modified + final_adapter_2_params = get_adapter_params(lora_params, 2) + verify_params_unchanged(initial_adapter_2_params, final_adapter_2_params, "Adapter 2 was modified") + + # Verify out-of-rank params were not modified + final_adapter_0_out_of_rank = _get_out_of_rank_params(lora_params, 0, 16, num_experts) + verify_params_unchanged( + initial_adapter_0_out_of_rank, final_adapter_0_out_of_rank, "Adapter 0 out-of-rank params modified" + ) + final_adapter_1_out_of_rank = _get_out_of_rank_params(lora_params, 1, 8, num_experts) + verify_params_unchanged( + initial_adapter_1_out_of_rank, final_adapter_1_out_of_rank, "Adapter 1 out-of-rank params modified" + ) + + +def test_lora_training_high_rank(): + base_model = "yujiepan/deepseek-v3-tiny-random" + base_config = PretrainedConfig.from_pretrained(base_model, trust_remote_code=True) + config = DeepseekV3Config(base_config, max_lora_adapters=5, max_lora_rank=32, shard_attention_heads=True) + + checkpoint_path = snapshot_download(base_model, allow_patterns=["*.safetensors"]) + mesh = jax.make_mesh( + (1, 1), + ("fsdp", "tp"), + axis_types=(jax.sharding.AxisType.Auto, jax.sharding.AxisType.Auto), + ) + with jax.set_mesh(mesh): + model = DeepseekV3ForCausalLM(config, dtype=get_dtype(config.dtype), rngs=nnx.Rngs(0)) + load_safetensors(checkpoint_path, config, model) + + init_lora_adapter(model, adapter_index=0, lora_config=LoraConfig(rank=16, alpha=16, seed=0)) + init_lora_adapter(model, adapter_index=1, lora_config=LoraConfig(rank=8, alpha=8, seed=1)) + + optimizer = nnx.Optimizer(model, optax.adamw(1e-4), wrt=model.is_lora_param) + + batch = jnp.array([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10], [11, 12, 13, 14, 15, 16, 17, 18, 19, 20]], dtype=jnp.int32) + target_ids = batch[:, 1:] + input_ids = batch[:, :-1] + adapter_indices = jnp.array([0, 1], dtype=jnp.int32) + 
attention_mask = jnp.ones_like(input_ids) + + def loss_fn(model, input_ids, target_ids, attention_mask): + outputs = model(input_ids, attention_mask=attention_mask, adapter_indices=adapter_indices) + logits = model.compute_logits(outputs.last_hidden_state, adapter_indices) + return optax.softmax_cross_entropy_with_integer_labels(logits=logits, labels=target_ids).mean() + + graphdef, lora_params, non_lora_params = nnx.split(model, model.is_lora_param, ...) + + def get_adapter_params(params, adapter_idx): + return jax.tree.map(lambda p: p[adapter_idx].copy(), params) + + num_experts = config.n_routed_experts + + # Save initial states for all unused adapters + initial_adapter_2_params = get_adapter_params(lora_params, 2) + initial_adapter_3_params = get_adapter_params(lora_params, 3) + initial_adapter_4_params = get_adapter_params(lora_params, 4) + + # Save out-of-rank params for adapters 0 and 1 + initial_adapter_0_out_of_rank = _get_out_of_rank_params(lora_params, 0, 16, num_experts) + initial_adapter_1_out_of_rank = _get_out_of_rank_params(lora_params, 1, 8, num_experts) + + # Training loop + for step in range(10): + + def loss_for_lora(lora_params): + merged_model = nnx.merge(graphdef, lora_params, non_lora_params) + return loss_fn(merged_model, input_ids, target_ids, attention_mask) + + loss_and_grad_fn = nnx.value_and_grad(loss_for_lora) + loss, lora_grads = loss_and_grad_fn(lora_params) + + optimizer.update(lora_params, lora_grads) + + print(f"Step {step}: loss = {float(loss):.4f}") + + def verify_params_unchanged(initial_params, final_params, error_msg_prefix): + for (path, initial), (_, final) in zip( + jax.tree.leaves_with_path(initial_params), jax.tree.leaves_with_path(final_params) + ): + assert jnp.allclose(initial, final), f"{error_msg_prefix} for {path}" + + # Verify unused adapters (2, 3, 4) were not modified + final_adapter_2_params = get_adapter_params(lora_params, 2) + verify_params_unchanged(initial_adapter_2_params, final_adapter_2_params, "Adapter 2 was modified") + + final_adapter_3_params = get_adapter_params(lora_params, 3) + verify_params_unchanged(initial_adapter_3_params, final_adapter_3_params, "Adapter 3 was modified") + + final_adapter_4_params = get_adapter_params(lora_params, 4) + verify_params_unchanged(initial_adapter_4_params, final_adapter_4_params, "Adapter 4 was modified") + + # Verify out-of-rank params were not modified + final_adapter_0_out_of_rank = _get_out_of_rank_params(lora_params, 0, 16, num_experts) + verify_params_unchanged( + initial_adapter_0_out_of_rank, final_adapter_0_out_of_rank, "Adapter 0 out-of-rank params modified" + ) + final_adapter_1_out_of_rank = _get_out_of_rank_params(lora_params, 1, 8, num_experts) + verify_params_unchanged( + initial_adapter_1_out_of_rank, final_adapter_1_out_of_rank, "Adapter 1 out-of-rank params modified" + ) diff --git a/skyrl-tx/tx/layers/lora.py b/skyrl-tx/tx/layers/lora.py index 911fff721..710ed620e 100644 --- a/skyrl-tx/tx/layers/lora.py +++ b/skyrl-tx/tx/layers/lora.py @@ -318,7 +318,7 @@ def init_adapter(path, value): # Following Thinking Machines' approach: divide rank by num_experts # to keep total LoRA parameters similar to non-MoE models if "experts" in normalized_path: - effective_rank = max(1, lora_config.rank // model.config.num_experts) + effective_rank = max(1, lora_config.rank // model.config.get_num_experts()) if not filter_lora(lora_config, normalized_path): effective_rank = 0 diff --git a/skyrl-tx/tx/layers/rotary_embedding.py b/skyrl-tx/tx/layers/rotary_embedding.py index 
3ff67655f..8af80ada4 100644 --- a/skyrl-tx/tx/layers/rotary_embedding.py +++ b/skyrl-tx/tx/layers/rotary_embedding.py @@ -1,10 +1,15 @@ """Rotary Position Embeddings (RoPE) implementation.""" +import math +from typing import Any, Callable + import jax from jax import numpy as jnp -def apply_rope(inputs: jax.Array, position_ids: jax.Array, head_dim: int, theta: float) -> jax.Array: +def apply_rope( + inputs: jax.Array, position_ids: jax.Array, head_dim: int, theta: float, interleave: bool = False +) -> jax.Array: """Apply Rotary Position Embeddings (RoPE). Args: @@ -12,6 +17,8 @@ def apply_rope(inputs: jax.Array, position_ids: jax.Array, head_dim: int, theta: position_ids: Position indices of shape [B, T] head_dim: Dimension of each attention head theta: Base for the geometric progression (rope_theta) + interleave: If True, use interleaved slicing (x[..., ::2], x[..., 1::2]) + instead of splitting the last dimension in half. Returns: Tensor with RoPE applied, same shape as inputs @@ -20,5 +27,56 @@ def apply_rope(inputs: jax.Array, position_ids: jax.Array, head_dim: int, theta: timescale = jnp.pow(theta, fraction) x = (position_ids[..., None] / timescale[None, None, :])[..., None, :] sin, cos = jnp.sin(x), jnp.cos(x) - a, b = jnp.split(inputs, 2, axis=-1) - return jnp.concatenate([a * cos - b * sin, b * cos + a * sin], axis=-1).astype(inputs.dtype) + + if interleave: + a, b = inputs[..., ::2], inputs[..., 1::2] + else: + a, b = jnp.split(inputs, 2, axis=-1) + + return jnp.concatenate([a * cos - b * sin, a * sin + b * cos], axis=-1).astype(inputs.dtype) + + +def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float: + if scale <= 1: + return 1.0 + return 0.1 * mscale * math.log(scale) + 1.0 + + +def get_rope( + head_dim: int, + rope_theta: float, + rope_scaling: dict[str, Any] | None = None, +) -> tuple[Callable[[jax.Array, jax.Array], jax.Array], float]: + """Factory function to create a rotary embedding function. + + Args: + head_dim: Dimension of each attention head. + rope_theta: Base for the geometric progression. + rope_scaling: Optional dict with scaling configuration. The "type" or + "rope_type" field determines the RoPE variant to use. + + Returns: + A tuple of (rotary_emb, mscale) where rotary_emb takes (inputs, positions) + and returns RoPE-applied outputs, and mscale is the attention magnitude + scale factor for YaRN-style scaling. 
+ """ + rope_scaling = rope_scaling or {} + rope_type = rope_scaling.get("rope_type", "default") + + match rope_type: + case "deepseek_yarn": + mscale = yarn_get_mscale(rope_scaling["factor"], rope_scaling["mscale_all_dim"]) + + def rotary_emb(inputs: jax.Array, positions: jax.Array) -> jax.Array: + return apply_rope(inputs, positions, head_dim, rope_theta, interleave=True) + + case "default": + mscale = 1.0 + + def rotary_emb(inputs: jax.Array, positions: jax.Array) -> jax.Array: + return apply_rope(inputs, positions, head_dim, rope_theta) + + case _: + raise ValueError(f"Unsupported rope_type: {rope_type}") + + return rotary_emb, mscale diff --git a/skyrl-tx/tx/models/configs.py b/skyrl-tx/tx/models/configs.py index 8a3ce3ae4..15e011388 100644 --- a/skyrl-tx/tx/models/configs.py +++ b/skyrl-tx/tx/models/configs.py @@ -45,7 +45,11 @@ def __init__( self.loss_chunk_size = loss_chunk_size self.gradient_checkpointing = gradient_checkpointing + def get_num_experts(self): + return getattr(self, "num_experts", None) or getattr(self, "n_routed_experts", None) + # Model-specific aliases for clarity and backwards compatibility Llama3Config = ModelConfig Qwen3Config = ModelConfig +DeepseekV3Config = ModelConfig diff --git a/skyrl-tx/tx/models/deepseekv3.py b/skyrl-tx/tx/models/deepseekv3.py new file mode 100644 index 000000000..a2e48abdf --- /dev/null +++ b/skyrl-tx/tx/models/deepseekv3.py @@ -0,0 +1,561 @@ +from flax import nnx +import jax +from jax import numpy as jnp +from jax.sharding import get_abstract_mesh + +from tx.layers.lora import LoRAEmbed, LoRAExpert, LoRALinear +from tx.layers.rotary_embedding import get_rope +from tx.layers.util import Param, prepare_routing +from tx.layers.layernorm import RMSNorm +from tx.models.configs import DeepseekV3Config +from tx.models.types import CausalLMOutput, ModelForCausalLM, ModelOutput +from tx.utils.generator import GeneratorMixin, KVCache +from tx.utils.logits_processor import LogitsProcessorMixin, LMHead + + +class DeepseekV3Attention(nnx.Module): + """Multi-Head Latent Attention (MLA) Layer.""" + + def __init__(self, config: DeepseekV3Config, *, dtype: jnp.dtype, rngs: nnx.Rngs) -> None: + self.config = config + self.num_heads = config.num_attention_heads + + tp = get_abstract_mesh().shape.get("tp", 1) + shard_attention_heads = config.shard_attention_heads + if shard_attention_heads: + assert self.num_heads % tp == 0, f"num_heads={self.num_heads} must be divisible by tp={tp}" + tp_shard = "tp" if shard_attention_heads else None + + self.q_lora_rank = config.q_lora_rank + self.kv_lora_rank = config.kv_lora_rank + self.qk_nope_head_dim = config.qk_nope_head_dim + self.qk_rope_head_dim = config.qk_rope_head_dim + self.qk_head_dim = self.qk_nope_head_dim + self.qk_rope_head_dim + self.v_head_dim = config.v_head_dim + + if self.q_lora_rank is None: + self.q_proj = LoRALinear( + in_features=config.hidden_size, + out_features=self.num_heads * self.qk_head_dim, + max_lora_adapters=config.max_lora_adapters, + max_lora_rank=config.max_lora_rank, + dtype=dtype, + param_dtype=dtype, + use_bias=False, + kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), ("fsdp", tp_shard)), + rngs=rngs, + ) + self.q_a_proj = None + self.q_a_layernorm = None + self.q_b_proj = None + else: + self.q_proj = None + self.q_a_proj = LoRALinear( + in_features=config.hidden_size, + out_features=self.q_lora_rank, + max_lora_adapters=config.max_lora_adapters, + max_lora_rank=config.max_lora_rank, + dtype=dtype, + param_dtype=dtype, + use_bias=config.attention_bias, + 
kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), ("fsdp", None)), + rngs=rngs, + ) + self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps, dtype=dtype, rngs=rngs) + self.q_b_proj = LoRALinear( + in_features=self.q_lora_rank, + out_features=self.num_heads * self.qk_head_dim, + max_lora_adapters=config.max_lora_adapters, + max_lora_rank=config.max_lora_rank, + dtype=dtype, + param_dtype=dtype, + use_bias=False, + kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), (None, tp_shard)), + rngs=rngs, + ) + + self.kv_a_proj_with_mqa = LoRALinear( + in_features=config.hidden_size, + out_features=self.kv_lora_rank + self.qk_rope_head_dim, + max_lora_adapters=config.max_lora_adapters, + max_lora_rank=config.max_lora_rank, + dtype=dtype, + param_dtype=dtype, + use_bias=config.attention_bias, + kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), ("fsdp", None)), + rngs=rngs, + ) + self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps, dtype=dtype, rngs=rngs) + + self.kv_b_proj = LoRALinear( + in_features=self.kv_lora_rank, + out_features=self.num_heads * (self.qk_nope_head_dim + self.v_head_dim), + max_lora_adapters=config.max_lora_adapters, + max_lora_rank=config.max_lora_rank, + dtype=dtype, + param_dtype=dtype, + use_bias=False, + kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), (None, tp_shard)), + rngs=rngs, + ) + + self.o_proj = LoRALinear( + in_features=self.num_heads * self.v_head_dim, + out_features=config.hidden_size, + max_lora_adapters=config.max_lora_adapters, + max_lora_rank=config.max_lora_rank, + dtype=dtype, + param_dtype=dtype, + use_bias=config.attention_bias, + kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), (tp_shard, "fsdp")), + rngs=rngs, + ) + + self.rotary_emb, mscale = get_rope(self.qk_rope_head_dim, config.rope_theta, config.rope_scaling) + self.scaling = self.qk_head_dim ** (-0.5) * mscale * mscale + + def __call__( + self, + x: jax.Array, + *, + attention_mask: jax.Array, + positions: jax.Array, + adapter_indices: jax.Array | None = None, + kv_cache: tuple[jax.Array, jax.Array] | None = None, + ) -> tuple[jax.Array, tuple[jax.Array, jax.Array]]: + B, T, _ = x.shape + + # Query projection + if self.q_lora_rank is None: + q_states = self.q_proj(x, adapter_indices=adapter_indices) + else: + y = self.q_a_proj(x, adapter_indices=adapter_indices) + q_states = self.q_b_proj(self.q_a_layernorm(y), adapter_indices=adapter_indices) + + q_states = q_states.reshape(B, T, self.num_heads, self.qk_head_dim) + q_pass, q_rot = jnp.split(q_states, [self.qk_nope_head_dim], axis=-1) + + compressed_kv = self.kv_a_proj_with_mqa(x, adapter_indices=adapter_indices) + k_pass, k_rot = jnp.split(compressed_kv, [self.kv_lora_rank], axis=-1) + k_rot = k_rot.reshape(B, T, 1, self.qk_rope_head_dim) + + k_pass = self.kv_b_proj(self.kv_a_layernorm(k_pass), adapter_indices=adapter_indices) + k_pass = k_pass.reshape(B, T, self.num_heads, self.qk_nope_head_dim + self.v_head_dim) + k_pass, v = jnp.split(k_pass, [self.qk_nope_head_dim], axis=-1) + + q_rot = self.rotary_emb(q_rot, positions) + k_rot = self.rotary_emb(k_rot, positions) + + # Expand k_rot to all heads + k_rot = jnp.broadcast_to(k_rot, (B, T, self.num_heads, self.qk_rope_head_dim)) + + q = jnp.concatenate([q_pass, q_rot], axis=-1) + k = jnp.concatenate([k_pass, k_rot], axis=-1) + + # Handle KV cache + if kv_cache is not None: + k, v = KVCache.update_layer(kv_cache, k, v, positions) + + updated_cache = (k, v) + + # Jax attention 
expects v to have the same shape as k + v = jnp.pad(v, ((0, 0), (0, 0), (0, 0), (0, self.qk_head_dim - self.v_head_dim))) + + attn_output = jax.nn.dot_product_attention( + q, + k, + v, + scale=self.scaling, + mask=attention_mask[:, None, None, :].astype(bool), + is_causal=kv_cache is None, + ) + + attn_output = attn_output[:, :, :, : self.v_head_dim].reshape(B, T, self.num_heads * self.v_head_dim) + return self.o_proj(attn_output, adapter_indices=adapter_indices), updated_cache + + +class DeepseekV3MLP(nnx.Module): + + def __init__( + self, + config: DeepseekV3Config, + *, + dtype: jnp.dtype, + rngs: nnx.Rngs, + override_intermediate_size: int | None = None, + ) -> None: + self.config = config + intermediate_size = override_intermediate_size or config.intermediate_size + self.gate_proj = LoRALinear( + config.hidden_size, + intermediate_size, + use_bias=False, + dtype=dtype, + param_dtype=dtype, + kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), ("fsdp", "tp")), + max_lora_adapters=config.max_lora_adapters, + max_lora_rank=config.max_lora_rank, + rngs=rngs, + ) + self.up_proj = LoRALinear( + config.hidden_size, + intermediate_size, + use_bias=False, + dtype=dtype, + param_dtype=dtype, + kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), ("fsdp", "tp")), + max_lora_adapters=config.max_lora_adapters, + max_lora_rank=config.max_lora_rank, + rngs=rngs, + ) + self.down_proj = LoRALinear( + intermediate_size, + config.hidden_size, + use_bias=False, + dtype=dtype, + param_dtype=dtype, + kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), ("tp", "fsdp")), + max_lora_adapters=config.max_lora_adapters, + max_lora_rank=config.max_lora_rank, + rngs=rngs, + ) + + def __call__(self, x: jax.Array, adapter_indices: jax.Array | None = None) -> jax.Array: + gate_out = self.gate_proj(x, adapter_indices) + up_out = self.up_proj(x, adapter_indices) + return self.down_proj(nnx.silu(gate_out) * up_out, adapter_indices) + + +class DeepseekV3TopkRouter(nnx.Module): + """DeepseekV3 MoE routing gate. Returns raw router logits.""" + + def __init__(self, config: DeepseekV3Config, *, dtype: jnp.dtype, rngs: nnx.Rngs) -> None: + self.config = config + + self.weight = Param( + config.hidden_size, + config.n_routed_experts, + dtype=dtype, + kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), (None, None)), + rngs=rngs, + ) + + self.e_score_correction_bias = nnx.Variable(jnp.zeros(config.n_routed_experts, dtype=dtype)) + + def __call__(self, hidden_states: jax.Array) -> jax.Array: + hidden_states = hidden_states.reshape(-1, self.config.hidden_size) + router_logits = hidden_states.astype(jnp.float32) @ self.weight[...].astype(jnp.float32) + return router_logits + + +class DeepseekV3NaiveMoe(nnx.Module): + """Run NaiveMoe on selected expert groups.""" + + def __init__(self, config: DeepseekV3Config, *, dtype: jnp.dtype, rngs: nnx.Rngs) -> None: + self.config = config + self.num_experts = config.n_routed_experts + self.hidden_dim = config.hidden_size + self.intermediate_dim = config.moe_intermediate_size + + # NOTE: Huggingface implementation uses a fused gate_up_proj, but the weights are keyed + # by gate_proj and up_proj separately. 
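+ # Each LoRAExpert below holds stacked per-expert weights with a leading expert
+ # dimension, so routed tokens can be processed with grouped matmuls after they
+ # are sorted by assigned expert in __call__.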
+ self.gate_proj = LoRAExpert( + self.num_experts, + self.hidden_dim, + self.intermediate_dim, + max_lora_adapters=config.max_lora_adapters, + max_lora_rank=config.max_lora_rank, + dtype=dtype, + kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), (None, "fsdp", "tp")), + rngs=rngs, + ) + self.up_proj = LoRAExpert( + self.num_experts, + self.hidden_dim, + self.intermediate_dim, + max_lora_adapters=config.max_lora_adapters, + max_lora_rank=config.max_lora_rank, + dtype=dtype, + kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), (None, "fsdp", "tp")), + rngs=rngs, + ) + self.down_proj = LoRAExpert( + self.num_experts, + self.intermediate_dim, + self.hidden_dim, + max_lora_adapters=config.max_lora_adapters, + max_lora_rank=config.max_lora_rank, + dtype=dtype, + kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), (None, "tp", "fsdp")), + rngs=rngs, + ) + + def __call__( + self, + hidden_states: jax.Array, + top_k_index: jax.Array, + top_k_weights: jax.Array, + adapter_indices: jax.Array | None = None, + ) -> jax.Array: + num_experts_per_tok = top_k_index.shape[1] + + # Prepare for ragged_dot by sorting tokens based on their assigned expert + selected_experts_flat = top_k_index.ravel() + hidden_states_expanded = jnp.repeat(hidden_states, num_experts_per_tok, axis=0) + adapter_indices_expanded = ( + jnp.repeat(adapter_indices, num_experts_per_tok) if adapter_indices is not None else None + ) + + hidden_states_sorted, group_sizes, unsort_indices, adapter_indices_sorted = prepare_routing( + hidden_states_expanded, + selected_experts_flat, + self.num_experts, + adapter_indices=adapter_indices_expanded, + ) + + gate_out = self.gate_proj(hidden_states_sorted, group_sizes, adapter_indices_sorted) + up_out = self.up_proj(hidden_states_sorted, group_sizes, adapter_indices_sorted) + down_out = self.down_proj(nnx.silu(gate_out) * up_out, group_sizes, adapter_indices_sorted) + + # Unsort and combine the expert outputs + unsorted_out = down_out[unsort_indices] + reshaped_out = unsorted_out.reshape(-1, num_experts_per_tok, self.hidden_dim) + return jnp.sum(reshaped_out * top_k_weights[..., None], axis=1) + + +class DeepseekV3MoE(nnx.Module): + """MoE layer for routing to top-k expert groups.""" + + def __init__(self, config: DeepseekV3Config, *, dtype: jnp.dtype, rngs: nnx.Rngs) -> None: + self.config = config + self.n_group = config.n_group + + self.gate = DeepseekV3TopkRouter(config, dtype=dtype, rngs=rngs) + self.experts = DeepseekV3NaiveMoe(config, dtype=dtype, rngs=rngs) + + inter_dim = config.moe_intermediate_size * config.n_shared_experts + self.shared_experts = DeepseekV3MLP(config, dtype=dtype, rngs=rngs, override_intermediate_size=inter_dim) + + def _compute_routing(self, router_logits: jax.Array) -> tuple[jax.Array, jax.Array]: + num_tokens = router_logits.shape[0] + num_experts = router_logits.shape[1] + + scores = nnx.sigmoid(router_logits) + scores_with_bias = scores + self.gate.e_score_correction_bias[...] 
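+ # Group-limited routing: each group is scored by the sum of its top-2 biased
+ # scores, only the best topk_group groups remain eligible, and the final
+ # num_experts_per_tok experts are picked from those groups. The routing
+ # weights themselves come from the unbiased sigmoid scores.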
+ + experts_per_group = num_experts // self.n_group + scores_grouped = scores_with_bias.reshape(num_tokens, self.n_group, experts_per_group) + + top2, _ = jax.lax.top_k(scores_grouped, 2) + group_scores = jnp.sum(top2, axis=-1) + + _, top_group_indices = jax.lax.top_k(group_scores, self.config.topk_group) + + mask = jnp.ones((num_tokens, self.n_group), dtype=bool) + batch_indices = jnp.arange(num_tokens)[:, None] + mask = mask.at[batch_indices, top_group_indices].set(False) + mask = jnp.broadcast_to(mask[:, :, None], scores_grouped.shape) + + scores_with_bias = jnp.where(mask, 0.0, scores_grouped) + scores_with_bias = scores_with_bias.reshape(num_tokens, num_experts) + + _, top_k_index = jax.lax.top_k(scores_with_bias, self.config.num_experts_per_tok) + + # Get weights from original scores + top_k_weights = jnp.take_along_axis(scores, top_k_index, axis=-1) + + if self.config.norm_topk_prob: + top_k_weights = top_k_weights / jnp.sum(top_k_weights, axis=-1, keepdims=True) + + top_k_weights = top_k_weights * self.config.routed_scaling_factor + + return top_k_weights.astype(router_logits.dtype), top_k_index + + def __call__( + self, + hidden_states: jax.Array, + *, + adapter_indices: jax.Array | None = None, + ) -> jax.Array: + batch_size, seq_len, hidden_size = hidden_states.shape + hidden_states_flat = hidden_states.reshape(-1, hidden_size) + + if adapter_indices is not None: + adapter_indices_flat = jnp.repeat(adapter_indices, seq_len) + else: + adapter_indices_flat = None + + router_logits = self.gate(hidden_states_flat) + top_k_weights, top_k_index = self._compute_routing(router_logits) + + expert_output = self.experts(hidden_states_flat, top_k_index, top_k_weights, adapter_indices_flat) + shared_output = self.shared_experts( + hidden_states_flat.reshape(batch_size, seq_len, hidden_size), adapter_indices + ).reshape(-1, hidden_size) + expert_output = expert_output + shared_output + + return expert_output.reshape(batch_size, seq_len, hidden_size) + + +class DeepseekV3DecoderLayer(nnx.Module): + + def __init__(self, config: DeepseekV3Config, layer_idx: int, *, dtype: jnp.dtype, rngs: nnx.Rngs) -> None: + self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, dtype=dtype, rngs=rngs) + self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, dtype=dtype, rngs=rngs) + self.self_attn = DeepseekV3Attention(config, dtype=dtype, rngs=rngs) + + # Use dense MLP for initial layers, MoE for the rest + if layer_idx >= config.first_k_dense_replace: + self.mlp = DeepseekV3MoE(config, dtype=dtype, rngs=rngs) + else: + self.mlp = DeepseekV3MLP(config, dtype=dtype, rngs=rngs) + + def __call__( + self, + hidden_states: jax.Array, + *, + attention_mask: jax.Array, + positions: jax.Array, + adapter_indices: jax.Array | None = None, + kv_cache: tuple[jax.Array, jax.Array] | None = None, + ) -> tuple[jax.Array, tuple[jax.Array, jax.Array]]: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + hidden_states, updated_cache = self.self_attn( + hidden_states, + attention_mask=attention_mask, + positions=positions, + adapter_indices=adapter_indices, + kv_cache=kv_cache, + ) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + mlp_output = self.mlp(hidden_states, adapter_indices=adapter_indices) + hidden_states = residual + mlp_output + + return hidden_states, updated_cache + + +class DeepseekV3Model(nnx.Module): + + def __init__(self, config: 
DeepseekV3Config, *, dtype: jnp.dtype, rngs: nnx.Rngs) -> None: + self.config = config + + self.embed_tokens = LoRAEmbed( + num_embeddings=config.vocab_size, + features=config.hidden_size, + dtype=dtype, + max_lora_adapters=config.max_lora_adapters, + max_lora_rank=config.max_lora_rank, + param_dtype=dtype, + embedding_init=nnx.with_partitioning(nnx.initializers.normal(), ("tp", None)), + rngs=rngs, + ) + self.layers = nnx.List( + [ + DeepseekV3DecoderLayer(config, layer_idx=i, dtype=dtype, rngs=rngs) + for i in range(config.num_hidden_layers) + ] + ) + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, dtype=dtype, rngs=rngs) + + def __call__( + self, + input_ids: jax.Array, + *, + attention_mask: jax.Array, + positions: jax.Array, + output_hidden_states: bool | None = None, + adapter_indices: jax.Array | None = None, + kv_cache: KVCache | None = None, + ) -> ModelOutput: + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + + hidden_states = self.embed_tokens(input_ids, adapter_indices=adapter_indices) + all_hidden_states: list[jax.Array] = [] + updated_keys, updated_values = [], [] + + for layer_idx, layer in enumerate(self.layers): + if output_hidden_states: + all_hidden_states.append(hidden_states) + + hidden_states, (k, v) = layer( + hidden_states, + attention_mask=attention_mask, + positions=positions, + adapter_indices=adapter_indices, + kv_cache=kv_cache and (kv_cache.keys[layer_idx], kv_cache.values[layer_idx]), + ) + updated_keys.append(k) + updated_values.append(v) + + hidden_states = self.norm(hidden_states) + if output_hidden_states: + all_hidden_states.append(hidden_states) + + return ModelOutput( + last_hidden_state=hidden_states, + kv_cache=KVCache.update(kv_cache, updated_keys, updated_values, positions, attention_mask), + hidden_states=all_hidden_states if output_hidden_states else None, + ) + + +class DeepseekV3ForCausalLM(nnx.Module, ModelForCausalLM, GeneratorMixin, LogitsProcessorMixin): + + def __init__(self, config: DeepseekV3Config, *, dtype: jnp.dtype, rngs: nnx.Rngs) -> None: + self.config = config + self.model = DeepseekV3Model(config, dtype=dtype, rngs=rngs) + + if not self.config.tie_word_embeddings: + self.lm_head = LoRALinear( + config.hidden_size, + config.vocab_size, + use_bias=False, + dtype=dtype, + param_dtype=dtype, + kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), (None, "tp")), + max_lora_adapters=config.max_lora_adapters, + max_lora_rank=config.max_lora_rank, + rngs=rngs, + ) + + def get_lm_head(self) -> LMHead: + """Return the lm_head callable for logits computation.""" + return self.lm_head + + @staticmethod + def is_lora_param(path: tuple, _value) -> bool: + """Return True if a parameter path corresponds to LoRA weights.""" + return any(name in path for name in ("lora_A", "lora_B")) + + def __call__( + self, + input_ids: jax.Array, + *, + attention_mask: jax.Array, + positions: jax.Array | None = None, + output_hidden_states: bool | None = None, + adapter_indices: jax.Array | None = None, + kv_cache: KVCache | None = None, + ) -> CausalLMOutput: + if positions is None: + positions = jnp.arange(attention_mask.shape[1])[None, :] + + outputs = self.model( + input_ids, + attention_mask=attention_mask, + positions=positions, + output_hidden_states=output_hidden_states, + adapter_indices=adapter_indices, + kv_cache=kv_cache, + ) + + return CausalLMOutput( + last_hidden_state=outputs.last_hidden_state, + kv_cache=outputs.kv_cache, + 
hidden_states=outputs.hidden_states, + ) diff --git a/skyrl-tx/tx/run/train.py b/skyrl-tx/tx/run/train.py index 807f959d7..398c51a59 100644 --- a/skyrl-tx/tx/run/train.py +++ b/skyrl-tx/tx/run/train.py @@ -77,7 +77,7 @@ def train( train_dataset = load_dataset(dataset, split=split) assert isinstance(train_dataset, Dataset) base_config = AutoConfig.from_pretrained(model_name) - config = Qwen3Config(base_config, max_lora_adapters=0, max_lora_rank=0, shard_attention_heads=True) + model_config = Qwen3Config(base_config, max_lora_adapters=0, max_lora_rank=0, shard_attention_heads=True) tokenizer = AutoTokenizer.from_pretrained(model_name) tracker = get_tracker(tracker_name, base_config, **tracker_args) loader = get_loader(loader_name) @@ -85,11 +85,11 @@ def train( model_class = get_model_class(base_config) mesh = jax.make_mesh((1, 1, tp_size), ("fsdp", "ep", "tp"), axis_types=(jax.sharding.AxisType.Auto,) * 3) with jax.set_mesh(mesh): - model = model_class(config, dtype=get_dtype(config.dtype), rngs=nnx.Rngs(0)) + model = model_class(model_config, dtype=get_dtype(model_config.dtype), rngs=nnx.Rngs(0)) optimizer = nnx.Optimizer(model, get_optimizer(optimizer_name, optimizer_args), wrt=nnx.Param) if load_checkpoint_path: - load_safetensors(load_checkpoint_path, base_config, model) + load_safetensors(load_checkpoint_path, model_config, model) num_steps = train_dataset.num_rows / batch_size for step, (batch, metrics) in enumerate(loader(tokenizer, train_dataset, batch_size)): @@ -102,9 +102,9 @@ def train( if step % save_steps == 0: logger.info(f"Saving checkpoint to {output_dir}") - save_safetensors(base_config, model, output_dir / "model.safetensors") + save_safetensors(model_config, model, output_dir / "model.safetensors") logger.info(f"Saving final checkpoint to {output_dir}") base_config.save_pretrained(output_dir) tokenizer.save_pretrained(output_dir) - save_safetensors(base_config, model, output_dir / "model.safetensors") + save_safetensors(model_config, model, output_dir / "model.safetensors") diff --git a/skyrl-tx/tx/utils/models.py b/skyrl-tx/tx/utils/models.py index faf1a9634..6e840febf 100644 --- a/skyrl-tx/tx/utils/models.py +++ b/skyrl-tx/tx/utils/models.py @@ -16,6 +16,7 @@ from transformers import PretrainedConfig import peft +from tx.models.configs import ModelConfig from tx.utils.log import logger from tx.utils.storage import download_and_unpack, pack_and_upload from tx.tinker.types import LoraConfig @@ -62,12 +63,15 @@ def get_model_class(config: PretrainedConfig) -> Callable[..., nnx.Module]: "Get the correct model class based on the config." 
import tx.models.llama3 import tx.models.qwen3 + import tx.models.deepseekv3 for architecture in config.architectures or []: if hasattr(tx.models.llama3, architecture): return getattr(tx.models.llama3, architecture) if hasattr(tx.models.qwen3, architecture): return getattr(tx.models.qwen3, architecture) + if hasattr(tx.models.deepseekv3, architecture): + return getattr(tx.models.deepseekv3, architecture) raise ValueError(f"None of the architectures {config.architectures} is currently supported.") @@ -89,7 +93,7 @@ def get_expert_key(path: tuple, expert_idx: int) -> str: def load_safetensors( checkpoint_dir: str | os.PathLike, - config: PretrainedConfig, + config: ModelConfig, model: nnx.Module, skip_lora: bool = True, prefix: str = "", @@ -110,7 +114,9 @@ def load_safetensors( if skip_lora and ("lora_A" in path or "lora_B" in path or "lora_scaling" in path or "lora_ranks" in path): continue if "experts" in path: - tensors[key] = np.stack([tensors[get_expert_key(path, i)].T for i in range(config.num_experts)], axis=0) + tensors[key] = np.stack( + [tensors[get_expert_key(path, i)].T for i in range(config.get_num_experts())], axis=0 + ) else: tensors[key] = tensors[key] if "embed_tokens" in path else tensors[key].T if path[-2] in {"q_proj", "k_proj", "v_proj", "o_proj"}: @@ -122,7 +128,7 @@ def load_safetensors( def save_safetensors( - config: PretrainedConfig, + config: ModelConfig, model: nnx.Module, filename: Path, prefix: str = "", @@ -137,7 +143,7 @@ def save_safetensors( continue key = get_param_key(path, prefix=prefix) if "experts" in path: - for i in range(config.num_experts): + for i in range(config.get_num_experts()): tensors[get_expert_key(path, i)] = param[i, :, :].T continue if "q_proj" in path or "k_proj" in path or "v_proj" in path: