diff --git a/.github/workflows/pr-test.yml b/.github/workflows/pr-test.yml index f68229368..666c5888d 100644 --- a/.github/workflows/pr-test.yml +++ b/.github/workflows/pr-test.yml @@ -156,3 +156,47 @@ jobs: - name: Execute shell: bash run: python tests/ci/gpu_lock_exec.py --count ${{ matrix.info.num_gpus }} -- python tests/${{ matrix.info.test_file }} + + e2e-test-ckpt: + if: (github.event_name == 'workflow_dispatch') || (github.event.pull_request && contains(github.event.pull_request.labels.*.name, 'run-ci-ckpt')) + runs-on: self-hosted + container: + image: slimerl/slime:latest + options: > + --gpus all + --ipc=host + --shm-size=16g + --ulimit memlock=-1 + --ulimit stack=67108864 + --memory=0 + --memory-swap=0 + -e http_proxy=$http_proxy + -e https_proxy=$https_proxy + -e HTTP_PROXY=$HTTP_PROXY + -e HTTPS_PROXY=$HTTPS_PROXY + -v /mnt/nvme0n1/slime_ci:/data/slime_ci + -v /mnt/nvme0n1/slime_ci/models:/root/models + -v /mnt/nvme0n1/slime_ci/datasets:/root/datasets + strategy: + fail-fast: false + matrix: + info: [{"num_gpus": 8, "test_file": "test_qwen3_4B_ckpt.py"}, {"num_gpus": 8, "test_file": "test_qwen3_4B_ckpt.py --async-save"}] + defaults: + run: + working-directory: ${{ github.workspace }} + env: + GITHUB_COMMIT_NAME: ${{ github.sha }}_${{ github.event.pull_request.number || 'non-pr' }} + WANDB_API_KEY: ${{ secrets.WANDB_API_KEY }} + SLIME_TEST_ENABLE_INFINITE_RUN: ${{ (github.event_name == 'workflow_dispatch' && github.event.inputs.infinite_run) || 'false' }} + + steps: + - name: Checkout repository + uses: actions/checkout@v4 + + - name: Install + shell: bash + run: cd $GITHUB_WORKSPACE && pip install -e . --break-system-packages + + - name: Execute + shell: bash + run: python tests/ci/gpu_lock_exec.py --count ${{ matrix.info.num_gpus }} -- python tests/${{ matrix.info.test_file }} diff --git a/.github/workflows/pr-test.yml.j2 b/.github/workflows/pr-test.yml.j2 index de53fe4e4..7ceac55c4 100644 --- a/.github/workflows/pr-test.yml.j2 +++ b/.github/workflows/pr-test.yml.j2 @@ -24,6 +24,13 @@ {'test_file': 'test_qwen3_0.6B_parallel_check.py', 'num_gpus': 8}, ], }, + 'e2e-test-ckpt': { + 'label': 'run-ci-ckpt', + 'tests': [ + {'test_file': 'test_qwen3_4B_ckpt.py', 'num_gpus': 8}, + {'test_file': 'test_qwen3_4B_ckpt.py --async-save', 'num_gpus': 8}, + ], + }, } %> name: PR Test diff --git a/build_conda.sh b/build_conda.sh index 548fe1d96..a1d19543d 100644 --- a/build_conda.sh +++ b/build_conda.sh @@ -21,13 +21,13 @@ micromamba install -n slime cuda cuda-nvtx cuda-nvtx-dev nccl -c nvidia/label/cu micromamba install -n slime -c conda-forge cudnn -y # prevent installing cuda 13.0 for sglang -pip install cuda-python==12.9.1 -pip install torch==2.8.0 torchvision==0.23.0 torchaudio==2.8.0 --index-url https://download.pytorch.org/whl/cu129 +pip install cuda-python==13.1.0 +pip install torch==2.9.1 torchvision==0.24.1 torchaudio==2.9.1 --index-url https://download.pytorch.org/whl/cu129 # install sglang git clone https://github.com/sgl-project/sglang.git cd sglang -git checkout 303cc957e62384044dfa8e52d7d8af8abe12f0ac +git checkout 5e2cda6158e670e64b926a9985d65826c537ac82 # Install the python packages pip install -e "python[all]" @@ -39,7 +39,7 @@ pip install cmake ninja MAX_JOBS=64 pip -v install flash-attn==2.7.4.post1 --no-build-isolation pip install git+https://github.com/ISEEKYAN/mbridge.git@89eb10887887bc74853f89a4de258c0702932a1c --no-deps -pip install --no-build-isolation "transformer_engine[pytorch]==2.8.0" +pip install --no-build-isolation "transformer_engine[pytorch]==2.10.0" pip 
install flash-linear-attention==0.4.0 NVCC_APPEND_FLAGS="--threads 4" \ pip -v install --disable-pip-version-check --no-cache-dir \ @@ -50,7 +50,7 @@ git clone https://github.com/NVIDIA/Megatron-LM.git --recursive && \ cd Megatron-LM && git checkout ${MEGATRON_COMMIT} && \ pip install -e . -pip install git+https://github.com/fzyzcjy/torch_memory_saver.git@9b8b788fdeb9c2ee528183214cef65a99b71e7d5 --no-cache-dir --force-reinstall +pip install git+https://github.com/fzyzcjy/torch_memory_saver.git@dc6876905830430b5054325fa4211ff302169c6b --no-cache-dir --force-reinstall pip install git+https://github.com/fzyzcjy/Megatron-Bridge.git@dev_rl --no-build-isolation pip install nvidia-modelopt[torch]>=0.37.0 --no-build-isolation @@ -60,6 +60,9 @@ git clone https://github.com/NVIDIA/Megatron-LM.git --recursive && \ cd Megatron-LM/ && git checkout core_v0.14.0 && \ pip install -e . +# https://github.com/pytorch/pytorch/issues/168167 +pip install nvidia-cudnn-cu12==9.16.0.29 + # install slime and apply patches # if slime does not exist locally, clone it @@ -76,6 +79,6 @@ fi # apply patch cd $BASE_DIR/sglang -git apply $SLIME_DIR/docker/patch/v0.5.5.post1/sglang.patch +git apply $SLIME_DIR/docker/patch/v0.5.6/sglang.patch cd $BASE_DIR/Megatron-LM -git apply $SLIME_DIR/docker/patch/v0.5.5.post1/megatron.patch +git apply $SLIME_DIR/docker/patch/v0.5.6/megatron.patch \ No newline at end of file diff --git a/docker/Dockerfile b/docker/Dockerfile index 4cc49b718..5c99ecfae 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -71,25 +71,11 @@ RUN if [ "$ENABLE_CUDA_13" = "1" ]; then \ python3 -m pip install https://github.com/sgl-project/whl/releases/download/v${SGL_KERNEL_VERSION}/sgl_kernel-${SGL_KERNEL_VERSION}+cu130-cp310-abi3-manylinux2014_$(uname -m).whl --force-reinstall --no-deps; \ fi -# AMEM -# we need to create a fake libcuda.so.1 to make the linker happy when building AMEM -ENV CUDA_DIR=/usr/local/cuda -ENV CUDA_STUBS=${CUDA_DIR}/lib64/stubs -RUN ln -s ${CUDA_STUBS}/libcuda.so ${CUDA_STUBS}/libcuda.so.1 && \ - echo "${CUDA_STUBS}" > /etc/ld.so.conf.d/z-cuda-stubs.conf && \ - ldconfig -RUN git clone https://github.com/inclusionAI/asystem-amem.git && \ - cd asystem-amem && git checkout 6483bb17c9a98b51c3a94b7048467d5b50fbad4b && \ - git submodule init && git submodule update && \ - MPI_HOME=/usr/lib/x86_64-linux-gnu/openmpi/ ./build.sh && \ - mv /usr/local/lib/python3.12/dist-packages/nvidia/nccl/lib/libnccl.so.2 /usr/local/lib/python3.12/dist-packages/nvidia/nccl/lib/libnccl.so.2.bak && \ - cp -r third_party/nccl/build/lib/* /usr/local/lib/python3.12/dist-packages/nvidia/nccl/lib/ - # https://github.com/pytorch/pytorch/issues/168167 RUN pip install nvidia-cudnn-cu12==9.16.0.29 RUN rm /root/.tmux.conf -RUN rm -rf /root/.cache/pip /root/asystem-amem /root/flash-attention +RUN rm -rf /root/.cache/pip /root/flash-attention # ====================================== Patches ============================================ diff --git a/docker/README.md b/docker/README.md index 92f559e72..156169c72 100644 --- a/docker/README.md +++ b/docker/README.md @@ -5,10 +5,10 @@ We will publish 2 kinds of docker images: 2. latest version, which aligns to `lmsysorg/sglang:latest`. 
current stable version is: -- sglang v0.5.5.post1 (303cc957e62384044dfa8e52d7d8af8abe12f0ac), megatron v0.14.0 (23e00ed0963c35382dfe8a5a94fb3cda4d21e133) +- sglang nightly-dev-20251208-5e2cda61 (5e2cda6158e670e64b926a9985d65826c537ac82), megatron v0.14.0 (23e00ed0963c35382dfe8a5a94fb3cda4d21e133) history versions: -- sglang v0.5.0rc0-cu126 (8ecf6b9d2480c3f600826c7d8fef6a16ed603c3f), megatron 48406695c4efcf1026a7ed70bb390793918dd97b +- sglang v0.5.5.post1 (303cc957e62384044dfa8e52d7d8af8abe12f0ac), megatron v0.14.0 (23e00ed0963c35382dfe8a5a94fb3cda4d21e133) The command to build: diff --git a/docker/patch/latest/megatron.patch b/docker/patch/latest/megatron.patch index 81fb3f814..9d0d6011c 100644 --- a/docker/patch/latest/megatron.patch +++ b/docker/patch/latest/megatron.patch @@ -219,14 +219,14 @@ index 6aec66e6d..6ca48b55f 100644 mtp_loss = loss_mask * mtp_loss if self.training: diff --git a/megatron/core/optimizer/distrib_optimizer.py b/megatron/core/optimizer/distrib_optimizer.py -index a36b67364..8739270f2 100644 +index a36b67364..ed8883e32 100644 --- a/megatron/core/optimizer/distrib_optimizer.py +++ b/megatron/core/optimizer/distrib_optimizer.py @@ -657,6 +657,8 @@ class DistributedOptimizer(MixedPrecisionOptimizer): # TE FusedAdam will not accumulate step for empty param groups, so we need to # align the step across param groups. param_group["step"] = int(step) -+ if param_group["step"] is None: ++ if "step" in param_group and param_group["step"] is None: + del param_group["step"] # Grad scaler state. diff --git a/docker/patch/latest/sglang.patch b/docker/patch/latest/sglang.patch index de12cdd43..8aaad7e1f 100644 --- a/docker/patch/latest/sglang.patch +++ b/docker/patch/latest/sglang.patch @@ -58,55 +58,6 @@ index 952374ed5..239ac2571 100644 class SchedulerDisaggregationPrefillMixin: """ -diff --git a/python/sglang/srt/distributed/device_communicators/pynccl.py b/python/sglang/srt/distributed/device_communicators/pynccl.py -index 86c53f26b..52acf95b9 100644 ---- a/python/sglang/srt/distributed/device_communicators/pynccl.py -+++ b/python/sglang/srt/distributed/device_communicators/pynccl.py -@@ -380,3 +380,9 @@ class PyNcclCommunicator: - - self.disabled = old_disable - self.stream = old_stream -+ -+ def nccl_pause(self): -+ self.nccl.ncclPause(self.comm) -+ -+ def nccl_resume(self): -+ self.nccl.ncclResume(self.comm) -diff --git a/python/sglang/srt/distributed/device_communicators/pynccl_wrapper.py b/python/sglang/srt/distributed/device_communicators/pynccl_wrapper.py -index 6b12f2922..7028a4e46 100644 ---- a/python/sglang/srt/distributed/device_communicators/pynccl_wrapper.py -+++ b/python/sglang/srt/distributed/device_communicators/pynccl_wrapper.py -@@ -304,6 +304,17 @@ class NCCLLibrary: - Function("ncclGroupEnd", ncclResult_t, []), - ] - -+ if os.environ.get("AMEM_ENABLE", "0") == "1": -+ exported_functions.extend( -+ [ -+ # ncclResult_t ncclPause(ncclComm_t comm); -+ Function("ncclPause", ncclResult_t, [ncclComm_t]), -+ # ncclResult_t ncclResume(ncclComm_t comm); -+ Function("ncclResume", ncclResult_t, [ncclComm_t]), -+ Function("ncclSetGroupID", ncclResult_t, [ctypes.c_int]), -+ ] -+ ) -+ - exported_functions_symm_mem = [ - # ncclResult_t ncclCommWindowRegister(ncclComm_t comm, void* buff, size_t size, ncclWindow_t* win, int winFlags); - Function( -@@ -551,6 +562,12 @@ class NCCLLibrary: - def ncclGroupEnd(self) -> None: - self.NCCL_CHECK(self._funcs["ncclGroupEnd"]()) - -+ def ncclPause(self, comm: ncclComm_t) -> None: -+ self.NCCL_CHECK(self._funcs["ncclPause"](comm)) -+ 
-+ def ncclResume(self, comm: ncclComm_t) -> None: -+ self.NCCL_CHECK(self._funcs["ncclResume"](comm)) -+ - - __all__ = [ - "NCCLLibrary", diff --git a/python/sglang/srt/distributed/parallel_state.py b/python/sglang/srt/distributed/parallel_state.py index cf90f6fe0..11d26df81 100644 --- a/python/sglang/srt/distributed/parallel_state.py @@ -192,17 +143,26 @@ index 9f556a885..992843285 100644 bsz, s, _ = x_shape head = self.num_attention_heads_per_partition diff --git a/python/sglang/srt/layers/communicator.py b/python/sglang/srt/layers/communicator.py -index 932f52aeb..79c6b664f 100644 +index 932f52aeb..ee52f4c94 100644 --- a/python/sglang/srt/layers/communicator.py +++ b/python/sglang/srt/layers/communicator.py @@ -372,6 +372,7 @@ class LayerCommunicator: residual: torch.Tensor, forward_batch: ForwardBatch, quant_format: str = "", -+ post_residual_addition: Optional[torch.Tensor] = None, ++ **kwargs, ): if get_attn_tp_context().input_scattered: hidden_states, residual = self._tp_reduce_scatter( +@@ -421,7 +422,7 @@ class LayerCommunicator: + ) + + else: +- hidden_states = self.input_layernorm(hidden_states) ++ hidden_states = self.input_layernorm(hidden_states, **kwargs) + else: + + if _use_aiter and _is_gfx95_supported and ("mxfp4" in quant_format): @@ -453,7 +454,9 @@ class LayerCommunicator: ) else: @@ -210,12 +170,12 @@ index 932f52aeb..79c6b664f 100644 - hidden_states, residual + hidden_states, + residual, -+ post_residual_addition, ++ **kwargs, ) hidden_states = self._communicate_simple_fn( diff --git a/python/sglang/srt/layers/layernorm.py b/python/sglang/srt/layers/layernorm.py -index 3293a8a59..a075b71ce 100644 +index 3293a8a59..ea6b30d73 100644 --- a/python/sglang/srt/layers/layernorm.py +++ b/python/sglang/srt/layers/layernorm.py @@ -84,15 +84,12 @@ class RMSNorm(CustomOp): @@ -236,45 +196,73 @@ index 3293a8a59..a075b71ce 100644 self.variance_epsilon = eps self.hidden_size = hidden_size self.variance_size_override = ( -@@ -105,21 +102,26 @@ class RMSNorm(CustomOp): +@@ -105,21 +102,29 @@ class RMSNorm(CustomOp): self, x: torch.Tensor, residual: Optional[torch.Tensor] = None, -+ post_residual_addition: Optional[torch.Tensor] = None, ++ **kwargs, ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: if self.variance_size_override is not None: - return self.forward_native(x, residual) -+ return self.forward_native(x, residual, post_residual_addition) ++ return self.forward_native(x, residual, **kwargs) if is_batch_invariant_mode_enabled(): if ( residual is not None or get_global_server_args().rl_on_policy_target == "fsdp" ): - return self.forward_native(x, residual) -+ return self.forward_native(x, residual, post_residual_addition) ++ return self.forward_native(x, residual, **kwargs) return rms_norm_batch_invariant( x, self.weight.data, self.variance_epsilon, ) if residual is not None: -+ # TODO: Ideally we want to have (a+b)+c. but right now we can only have a+(b+c). -+ # (a+b)+c != a+(b+c), we probably need to add another parameter to fused_add_rmsnorm ++ # TODO: Ideally we want to have (hidden_states+residual)+post_residual_addition. ++ # but right now we can only have hidden_states+(residual+post_residual_addition). 
++ # (hidden_states+residual)+post_residual_addition != hidden_states+(residual+post_residual_addition), ++ # we probably need to add another parameter to fused_add_rmsnorm ++ post_residual_addition = kwargs.get("post_residual_addition") + if post_residual_addition is not None: + residual = residual + post_residual_addition fused_add_rmsnorm(x, residual, self.weight.data, self.variance_epsilon) return x, residual out = rmsnorm(x, self.weight.data, self.variance_epsilon) -@@ -179,17 +181,35 @@ class RMSNorm(CustomOp): +@@ -129,6 +134,7 @@ class RMSNorm(CustomOp): + self, + x: torch.Tensor, + residual: Optional[torch.Tensor] = None, ++ **kwargs, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + if residual is not None: + out, _, residual_out = torch_npu.npu_add_rms_norm( +@@ -141,6 +147,7 @@ class RMSNorm(CustomOp): + self, + x: torch.Tensor, + residual: Optional[torch.Tensor] = None, ++ **kwargs, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + if residual is not None: + residual_out = torch.empty_like(x) +@@ -160,6 +167,7 @@ class RMSNorm(CustomOp): self, x: torch.Tensor, residual: Optional[torch.Tensor] = None, -+ post_residual_addition: Optional[torch.Tensor] = None, ++ **kwargs, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + if not x.is_contiguous(): + # NOTE: Remove this if aiter kernel supports discontinuous input +@@ -179,17 +187,36 @@ class RMSNorm(CustomOp): + self, + x: torch.Tensor, + residual: Optional[torch.Tensor] = None, ++ **kwargs, ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: if not x.is_contiguous(): x = x.contiguous() - orig_dtype = self.override_orig_dtype or x.dtype + orig_dtype = x.dtype ++ post_residual_addition = kwargs.get("post_residual_addition") + + if residual is not None and not self.fp32_residual: + x = ( @@ -308,6 +296,148 @@ index 3293a8a59..a075b71ce 100644 hidden_size = x.shape[-1] if hidden_size != self.hidden_size: +@@ -226,6 +253,7 @@ class RMSNorm(CustomOp): + self, + x: torch.Tensor, + residual: Optional[torch.Tensor] = None, ++ **kwargs, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + if _is_cpu_amx_available: + if residual is not None: +@@ -237,15 +265,16 @@ class RMSNorm(CustomOp): + x, self.weight.data, self.variance_epsilon + ) + else: +- return self.forward_native(x, residual) ++ return self.forward_native(x, residual, **kwargs) + + def forward_xpu( + self, + x: torch.Tensor, + residual: Optional[torch.Tensor] = None, ++ **kwargs, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + if self.variance_size_override is not None: +- return self.forward_native(x, residual) ++ return self.forward_native(x, residual, **kwargs) + if residual is not None: + fused_add_rmsnorm(x, residual, self.weight.data, self.variance_epsilon) + return x, residual +@@ -307,6 +336,7 @@ class LayerNorm(CustomOp): + def forward_cuda( + self, + x: torch.Tensor, ++ **kwargs, + ) -> torch.Tensor: + if ( + _flashinfer_layernorm_available +@@ -315,11 +345,12 @@ class LayerNorm(CustomOp): + ): + return layernorm(x, self.weight, self.bias, self.variance_epsilon) + else: +- return self.forward_native(x) ++ return self.forward_native(x, **kwargs) + + def forward_native( + self, + x: torch.Tensor, ++ **kwargs, + ) -> torch.Tensor: + weight = self.weight if self.elementwise_affine else None + bias = self.bias if self.use_bias else None +@@ -336,12 +367,14 @@ class LayerNorm(CustomOp): + def forward_hip( + self, + x: torch.Tensor, ++ **kwargs, + ) -> torch.Tensor: +- return 
self.forward_native(x) ++ return self.forward_native(x, **kwargs) + + def forward_npu( + self, + x: torch.Tensor, ++ **kwargs, + ) -> torch.Tensor: + orig_dtype = x.dtype + x = x.to(self.dtype) +@@ -360,8 +393,9 @@ class LayerNorm(CustomOp): + def forward_cpu( + self, + x: torch.Tensor, ++ **kwargs, + ) -> torch.Tensor: +- return self.forward_native(x) ++ return self.forward_native(x, **kwargs) + + + class GemmaRMSNorm(CustomOp): +@@ -382,6 +416,7 @@ class GemmaRMSNorm(CustomOp): + self, + x: torch.Tensor, + residual: Optional[torch.Tensor] = None, ++ **kwargs, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + if residual is not None: + gemma_fused_add_rmsnorm( +@@ -395,6 +430,7 @@ class GemmaRMSNorm(CustomOp): + self, + x: torch.Tensor, + residual: Optional[torch.Tensor] = None, ++ **kwargs, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + orig_dtype = x.dtype + if residual is not None: +@@ -412,13 +448,15 @@ class GemmaRMSNorm(CustomOp): + self, + x: torch.Tensor, + residual: Optional[torch.Tensor] = None, ++ **kwargs, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: +- return self._forward_impl(x, residual) ++ return self._forward_impl(x, residual, **kwargs) + + def forward_npu( + self, + x: torch.Tensor, + residual: Optional[torch.Tensor] = None, ++ **kwargs, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + if residual is not None: + x = x + residual +@@ -431,8 +469,9 @@ class GemmaRMSNorm(CustomOp): + self, + x: torch.Tensor, + residual: Optional[torch.Tensor] = None, ++ **kwargs, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: +- return self._forward_impl(x, residual) ++ return self._forward_impl(x, residual, **kwargs) + + + class Gemma3RMSNorm(CustomOp): +@@ -445,17 +484,17 @@ class Gemma3RMSNorm(CustomOp): + def _norm(self, x): + return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) + +- def forward_native(self, x): ++ def forward_native(self, x, **kwargs): + output = self._norm(x.float()) + # Llama does x.to(float16) * w whilst Gemma3 is (x * w).to(float16) + # See https://github.com/huggingface/transformers/pull/29402 + output = output * (1.0 + self.weight.float()) + return output.type_as(x) + +- def forward_cuda(self, x): +- return self.forward_native(x) ++ def forward_cuda(self, x, **kwargs): ++ return self.forward_native(x, **kwargs) + +- def forward_npu(self, x): ++ def forward_npu(self, x, **kwargs): + output, _ = torch_npu.npu_gemma_rms_norm(x, self.weight, self.eps) + return output + diff --git a/python/sglang/srt/layers/logits_processor.py b/python/sglang/srt/layers/logits_processor.py index 522865765..733bad5f2 100644 --- a/python/sglang/srt/layers/logits_processor.py @@ -350,10 +480,10 @@ index e7d5a67cc..639e47163 100644 out_hidden_states[begin_chunk_idx:end_chunk_idx], diff --git a/python/sglang/srt/layers/moe/routed_experts_capturer.py b/python/sglang/srt/layers/moe/routed_experts_capturer.py new file mode 100644 -index 000000000..e16817f1f +index 000000000..11adcaa77 --- /dev/null +++ b/python/sglang/srt/layers/moe/routed_experts_capturer.py -@@ -0,0 +1,279 @@ +@@ -0,0 +1,305 @@ +import logging +from abc import ABC +from contextlib import contextmanager @@ -364,12 +494,17 @@ index 000000000..e16817f1f + +from sglang.srt.configs.model_config import ModelConfig +from sglang.srt.layers.dp_attention import ( ++ attn_tp_all_gather_into_tensor, + get_attention_dp_rank, ++ get_attention_tp_size, + get_dp_local_info, + is_dp_attention_enabled, +) +from sglang.srt.mem_cache.memory_pool 
import ReqToTokenPool +from sglang.srt.server_args import get_global_server_args ++from sglang.srt.layers.moe import ( ++ get_moe_a2a_backend, ++) + +logger = logging.getLogger(__name__) + @@ -446,9 +581,6 @@ index 000000000..e16817f1f + assert hasattr(self, "buffer") + return get_tensor_size_bytes(self.buffer) + -+ def set_experts_buffer(self, layer_id: int, loc: torch.Tensor, top_k: torch.Tensor): -+ self.buffer[layer_id, loc, :] = top_k.to(device="cpu", non_blocking=True) -+ + def _finalize_allocation_log(self): + """Common logging and memory usage computation for captured experts buffers.""" + buffer_size_GB = self.get_buffer_size_bytes() / _GB @@ -539,7 +671,24 @@ index 000000000..e16817f1f + device=device, + ) + ++ if get_moe_a2a_backend().is_deepep(): ++ attn_tp_size = get_attention_tp_size() if is_dp_attention_enabled() else 1 ++ self.gather_buffer = torch.empty( ++ ( ++ self.device_cache.buffer.shape[0] * attn_tp_size, ++ self.device_cache.buffer.shape[2], ++ ), ++ dtype=torch.int32, ++ device=device, ++ ) ++ + def capture(self, layer_id: int, topk_ids: torch.Tensor): ++ if get_moe_a2a_backend().is_deepep(): ++ local_topk_ids = topk_ids ++ topk_ids = self.gather_buffer[ ++ : local_topk_ids.size(0) * get_attention_tp_size() ++ ] ++ attn_tp_all_gather_into_tensor(topk_ids, local_topk_ids) + self.device_cache.capture_fwd_routed_experts(layer_id, topk_ids) + + def sync_fwd_experts_buffer_DtoH( @@ -549,7 +698,9 @@ index 000000000..e16817f1f + can_run_graph: bool, + cuda_graph_batch: int, + ): -+ if is_dp_attention_enabled(): ++ # When DeepEP is enabled, capture() already does all_gather, so device_cache.buffer ++ # contains data from all DP ranks. We should not slice by DP rank in this case. ++ if is_dp_attention_enabled() and not get_moe_a2a_backend().is_deepep(): + local_start_pos, local_num_tokens = get_dp_local_info(self.forward_batch) + # handle with cuda graph padding + if can_run_graph: @@ -561,6 +712,11 @@ index 000000000..e16817f1f + local_start_pos = 0 + local_end_pos = device_loc.shape[0] + ++ if self.forward_batch.num_token_non_padded is not None: ++ assert local_end_pos - local_start_pos >= self.forward_batch.num_token_non_padded ++ local_end_pos = local_start_pos + self.forward_batch.num_token_non_padded ++ cpu_loc = cpu_loc[: self.forward_batch.num_token_non_padded] ++ + self.host_cache.buffer[cpu_loc] = self.device_cache.buffer[ + local_start_pos:local_end_pos, :, : self.num_experts_per_tok + ].cpu() @@ -838,27 +994,23 @@ index 7f6f6a010..c4a673145 100644 if not get_global_server_args().sampling_backend == "ascend" or ( return_logprob and not SGLANG_RETURN_ORIGINAL_LOGPROB diff --git a/python/sglang/srt/managers/detokenizer_manager.py b/python/sglang/srt/managers/detokenizer_manager.py -index 87922077e..8cb6bad8d 100644 +index 87922077e..6507d8bf5 100644 --- a/python/sglang/srt/managers/detokenizer_manager.py +++ b/python/sglang/srt/managers/detokenizer_manager.py -@@ -247,6 +247,16 @@ class DetokenizerManager(MultiHttpWorkerDetokenizerMixin): +@@ -247,6 +247,12 @@ class DetokenizerManager(MultiHttpWorkerDetokenizerMixin): s.sent_offset = len(output_str) output_strs.append(incremental_output) + output_routed_experts = [] + if recv_obj.output_routed_experts is not None: + output_routed_experts = [ -+ ( -+ output_routed_experts.tolist() -+ if output_routed_experts is not None -+ else [] -+ ) ++ output_routed_experts + for output_routed_experts in recv_obj.output_routed_experts + ] return BatchStrOutput( rids=recv_obj.rids, http_worker_ipcs=recv_obj.http_worker_ipcs, -@@ 
-272,6 +282,7 @@ class DetokenizerManager(MultiHttpWorkerDetokenizerMixin): +@@ -272,6 +278,7 @@ class DetokenizerManager(MultiHttpWorkerDetokenizerMixin): output_token_ids_logprobs_idx=recv_obj.output_token_ids_logprobs_idx, output_token_entropy_val=recv_obj.output_token_entropy_val, output_hidden_states=recv_obj.output_hidden_states, @@ -927,7 +1079,7 @@ index e34736cc4..5e5997a1a 100644 # idx is the index of the token in the prompt after expansion. # val is the length of padded tokens after expansion. diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py -index c4c5a9ebb..1450c5fd8 100644 +index c4c5a9ebb..3650ba881 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -450,6 +450,7 @@ class Req: @@ -977,7 +1129,15 @@ index c4c5a9ebb..1450c5fd8 100644 is_prefill_only=all(req.is_prefill_only for req in reqs), chunked_req=chunked_req, dllm_config=dllm_config, -@@ -1457,6 +1469,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): +@@ -1282,6 +1294,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): + ) + else: + self.out_cache_loc = torch.cat(decoder_out_cache_loc) ++ self.out_cache_loc_cpu = self.out_cache_loc.to("cpu", non_blocking=True) + + if not encoder_out_cache_loc: + self.encoder_out_cache_loc = torch.zeros(0, dtype=torch.int64).to( +@@ -1457,6 +1470,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): self.req_pool_indices = req_pool_indices_tensor self.orig_seq_lens = orig_seq_lens_tensor self.out_cache_loc = out_cache_loc @@ -985,7 +1145,7 @@ index c4c5a9ebb..1450c5fd8 100644 self.input_embeds = ( torch.tensor(input_embeds).to(self.device, non_blocking=True) if input_embeds -@@ -1508,10 +1521,14 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): +@@ -1508,10 +1522,14 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): input_ids = torch.cat([self.input_ids, running_batch.input_ids]) out_cache_loc = torch.cat([self.out_cache_loc, running_batch.out_cache_loc]) @@ -1000,7 +1160,7 @@ index c4c5a9ebb..1450c5fd8 100644 # For overlap scheduler, the output_ids has one step delay delta = 0 if self.enable_overlap else -1 -@@ -1677,6 +1694,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): +@@ -1677,6 +1695,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): self.seq_lens_cpu = torch.empty(0, dtype=torch.int64) self.orig_seq_lens = torch.empty(0, dtype=torch.int32, device=self.device) self.out_cache_loc = torch.empty(0, dtype=torch.int64, device=self.device) @@ -1008,7 +1168,7 @@ index c4c5a9ebb..1450c5fd8 100644 self.req_pool_indices = torch.empty(0, dtype=torch.int32, device=self.device) self.seq_lens_sum = 0 self.extend_num_tokens = 0 -@@ -1736,6 +1754,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): +@@ -1736,6 +1755,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): # Allocate memory self.out_cache_loc = alloc_for_decode(self, token_per_req=1) @@ -1016,7 +1176,7 @@ index c4c5a9ebb..1450c5fd8 100644 # Update req-level memory management fields for req in self.reqs: -@@ -1807,6 +1826,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): +@@ -1807,6 +1827,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): self.seq_lens_cpu = self.seq_lens_cpu[keep_indices] self.orig_seq_lens = self.orig_seq_lens[keep_indices_device] self.out_cache_loc = None @@ -1024,7 +1184,7 @@ index c4c5a9ebb..1450c5fd8 100644 self.seq_lens_sum = 
self.seq_lens.sum().item() self.output_ids = self.output_ids[keep_indices_device] self.return_logprob = any(req.return_logprob for req in self.reqs) -@@ -1852,6 +1872,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): +@@ -1852,6 +1873,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): self.seq_lens_cpu = torch.cat([self.seq_lens_cpu, other.seq_lens_cpu]) self.orig_seq_lens = torch.cat([self.orig_seq_lens, other.orig_seq_lens]) self.out_cache_loc = None @@ -1032,7 +1192,7 @@ index c4c5a9ebb..1450c5fd8 100644 self.seq_lens_sum += other.seq_lens_sum if self.output_ids is not None: self.output_ids = torch.cat([self.output_ids, other.output_ids]) -@@ -1903,6 +1924,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): +@@ -1903,6 +1925,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): seq_lens=self.seq_lens, orig_seq_lens=self.orig_seq_lens, out_cache_loc=self.out_cache_loc, @@ -1040,7 +1200,7 @@ index c4c5a9ebb..1450c5fd8 100644 seq_lens_cpu=seq_lens_cpu, seq_lens_sum=self.seq_lens_sum, return_logprob=self.return_logprob, -@@ -1983,7 +2005,8 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): +@@ -1983,7 +2006,8 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): def __str__(self): return ( f"ScheduleBatch(forward_mode={self.forward_mode.name if self.forward_mode else 'None'}, " @@ -1050,7 +1210,7 @@ index c4c5a9ebb..1450c5fd8 100644 ) -@@ -2038,6 +2061,9 @@ class ModelWorkerBatch: +@@ -2038,6 +2062,9 @@ class ModelWorkerBatch: # Sampling info sampling_info: SamplingBatchInfo @@ -1140,7 +1300,7 @@ index c48f5f893..a9796c25f 100644 placeholder_tokens_val=None, retraction_counts=retraction_counts, diff --git a/python/sglang/srt/managers/scheduler_update_weights_mixin.py b/python/sglang/srt/managers/scheduler_update_weights_mixin.py -index f8ebfc1f4..a05449fac 100644 +index f8ebfc1f4..48b9a1a3b 100644 --- a/python/sglang/srt/managers/scheduler_update_weights_mixin.py +++ b/python/sglang/srt/managers/scheduler_update_weights_mixin.py @@ -1,6 +1,7 @@ @@ -1175,49 +1335,7 @@ index f8ebfc1f4..a05449fac 100644 if GPU_MEMORY_TYPE_WEIGHTS in tags: self.stashed_model_static_state = _export_static_state( self.tp_worker.model_runner.model -@@ -137,6 +148,20 @@ class SchedulerUpdateWeightsMixin: - if GPU_MEMORY_TYPE_CUDA_GRAPH in tags: - self.memory_saver_adapter.pause(GPU_MEMORY_TYPE_CUDA_GRAPH) - -+ if os.environ.get("AMEM_ENABLE", "0") == "1": -+ tp_group = get_tp_group() -+ if tp_group is not None and tp_group.pynccl_comm is not None: -+ tp_group.pynccl_comm.nccl_pause() -+ attn_tp_group = get_attention_tp_group() -+ if attn_tp_group is not None and attn_tp_group.pynccl_comm is not None: -+ attn_tp_group.pynccl_comm.nccl_pause() -+ moe_ep_group = get_moe_ep_group() -+ if moe_ep_group is not None and moe_ep_group.pynccl_comm is not None: -+ moe_ep_group.pynccl_comm.nccl_pause() -+ moe_tp_group = get_moe_tp_group() -+ if moe_tp_group is not None and moe_tp_group.pynccl_comm is not None: -+ moe_tp_group.pynccl_comm.nccl_pause() -+ - torch.get_device_module().synchronize() - - return ReleaseMemoryOccupationReqOutput() -@@ -155,6 +180,20 @@ class SchedulerUpdateWeightsMixin: - if GPU_MEMORY_TYPE_CUDA_GRAPH in tags: - self.memory_saver_adapter.resume(GPU_MEMORY_TYPE_CUDA_GRAPH) - -+ if os.environ.get("AMEM_ENABLE", "0") == "1": -+ tp_group = get_tp_group() -+ if tp_group is not None and tp_group.pynccl_comm is not None: -+ tp_group.pynccl_comm.nccl_resume() -+ attn_tp_group = get_attention_tp_group() -+ if attn_tp_group is not 
None and attn_tp_group.pynccl_comm is not None: -+ attn_tp_group.pynccl_comm.nccl_resume() -+ moe_ep_group = get_moe_ep_group() -+ if moe_ep_group is not None and moe_ep_group.pynccl_comm is not None: -+ moe_ep_group.pynccl_comm.nccl_resume() -+ moe_tp_group = get_moe_tp_group() -+ if moe_tp_group is not None and moe_tp_group.pynccl_comm is not None: -+ moe_tp_group.pynccl_comm.nccl_resume() -+ - if GPU_MEMORY_TYPE_WEIGHTS in tags: - self.memory_saver_adapter.resume(GPU_MEMORY_TYPE_WEIGHTS) - torch.distributed.barrier(self.tp_cpu_group) -@@ -167,6 +206,13 @@ class SchedulerUpdateWeightsMixin: +@@ -167,6 +178,13 @@ class SchedulerUpdateWeightsMixin: if GPU_MEMORY_TYPE_KV_CACHE in tags: self.memory_saver_adapter.resume(GPU_MEMORY_TYPE_KV_CACHE) @@ -1231,11 +1349,47 @@ index f8ebfc1f4..a05449fac 100644 return ResumeMemoryOccupationReqOutput() def check_weights(self: Scheduler, recv_req: CheckWeightsReqInput): +diff --git a/python/sglang/srt/managers/tokenizer_communicator_mixin.py b/python/sglang/srt/managers/tokenizer_communicator_mixin.py +index edbc52526..2cdc42755 100644 +--- a/python/sglang/srt/managers/tokenizer_communicator_mixin.py ++++ b/python/sglang/srt/managers/tokenizer_communicator_mixin.py +@@ -421,6 +421,11 @@ class TokenizerCommunicatorMixin: + result = (await self.update_weights_from_distributed_communicator(obj))[ + 0 + ] ++ if result.success and obj.weight_version is not None: ++ self._update_weight_version_if_provided(obj.weight_version) ++ result.message += ( ++ f" Weight version updated to {obj.weight_version}." ++ ) + return result.success, result.message + + # This means that weight sync +@@ -480,6 +485,11 @@ class TokenizerCommunicatorMixin: + async with self.is_pause_cond: + if self.is_pause: + result = (await self.update_weights_from_tensor_communicator(obj))[0] ++ if result.success and obj.weight_version is not None: ++ self._update_weight_version_if_provided(obj.weight_version) ++ result.message += ( ++ f" Weight version updated to {obj.weight_version}." 
++ ) + return result.success, result.message + + # This means that weight sync diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py -index b90cf0616..98d71d896 100644 +index b90cf0616..8a5cbdbed 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py -@@ -888,6 +888,7 @@ class TokenizerManager(TokenizerCommunicatorMixin): +@@ -20,6 +20,7 @@ import logging + import math + import os + import pickle ++import pybase64 + import signal + import sys + import threading +@@ -888,6 +889,7 @@ class TokenizerManager(TokenizerCommunicatorMixin): session_params=session_params, custom_logit_processor=obj.custom_logit_processor, return_hidden_states=obj.return_hidden_states, @@ -1243,17 +1397,22 @@ index b90cf0616..98d71d896 100644 data_parallel_rank=obj.data_parallel_rank, priority=obj.priority, extra_key=obj.extra_key, -@@ -1621,6 +1622,9 @@ class TokenizerManager(TokenizerCommunicatorMixin): +@@ -1621,6 +1623,14 @@ class TokenizerManager(TokenizerCommunicatorMixin): if getattr(recv_obj, "output_hidden_states", None): meta_info["hidden_states"] = recv_obj.output_hidden_states[i] + if getattr(recv_obj, "output_routed_experts", None): -+ meta_info["routed_experts"] = recv_obj.output_routed_experts[i] ++ if recv_obj.output_routed_experts[i] is not None: ++ meta_info["routed_experts"] = pybase64.b64encode( ++ recv_obj.output_routed_experts[i].contiguous().numpy().tobytes(order="C") ++ ).decode("ascii") ++ else: ++ meta_info["routed_experts"] = None + if isinstance(recv_obj, BatchStrOutput): state.text += recv_obj.output_strs[i] if self.server_args.stream_output and state.obj.stream: -@@ -1747,12 +1751,13 @@ class TokenizerManager(TokenizerCommunicatorMixin): +@@ -1747,12 +1757,13 @@ class TokenizerManager(TokenizerCommunicatorMixin): return if len(recv_obj.input_token_logprobs_val) > 0: @@ -1274,7 +1433,7 @@ index b90cf0616..98d71d896 100644 recv_obj.output_token_logprobs_val[recv_obj_index] ) diff --git a/python/sglang/srt/model_executor/forward_batch_info.py b/python/sglang/srt/model_executor/forward_batch_info.py -index 3a85e6a7e..2859dafa1 100644 +index 3a85e6a7e..d2560e79b 100644 --- a/python/sglang/srt/model_executor/forward_batch_info.py +++ b/python/sglang/srt/model_executor/forward_batch_info.py @@ -51,6 +51,7 @@ from sglang.srt.layers.dp_attention import ( @@ -1338,8 +1497,16 @@ index 3a85e6a7e..2859dafa1 100644 if self.encoder_lens is not None: self.encoder_lens = self._pad_tensor_to_size(self.encoder_lens, bs) self.positions = self._pad_tensor_to_size(self.positions, num_tokens) +@@ -906,6 +921,7 @@ class ForwardBatch: + self.spec_info.hidden_states = self.hidden_states_backup + if hasattr(self, "output_cache_loc_backup"): + self.out_cache_loc = self.output_cache_loc_backup ++ self.out_cache_loc_cpu = self.out_cache_loc.to("cpu", non_blocking=True) + + elif self.forward_mode.is_decode() or self.forward_mode.is_idle(): + logits_output.next_token_logits = logits_output.next_token_logits[:bs] diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py -index 4d58278b7..8f50dc430 100644 +index 4d58278b7..5965c481e 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -94,6 +94,11 @@ from sglang.srt.layers.dp_attention import ( @@ -1354,18 +1521,19 @@ index 4d58278b7..8f50dc430 100644 from sglang.srt.layers.pooler import EmbeddingPoolerOutput from 
sglang.srt.layers.sampler import Sampler from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model -@@ -502,6 +507,10 @@ class ModelRunner: +@@ -502,6 +507,11 @@ class ModelRunner: server_args.max_running_requests, server_args.max_total_tokens, ) + + # Init routed experts capturer -+ self.init_routed_experts_capturer() ++ if not self.is_draft_worker: ++ self.init_routed_experts_capturer() + if self.device == "cuda": self.init_cublas() self.init_attention_backend() -@@ -545,6 +554,40 @@ class ModelRunner: +@@ -545,6 +555,40 @@ class ModelRunner: # Initialize piecewise CUDA graph self.init_piecewise_cuda_graphs() @@ -1406,7 +1574,7 @@ index 4d58278b7..8f50dc430 100644 def model_specific_adjustment(self): server_args = self.server_args -@@ -792,7 +835,11 @@ class ModelRunner: +@@ -792,7 +836,11 @@ class ModelRunner: ) with self.memory_saver_adapter.region( GPU_MEMORY_TYPE_WEIGHTS, @@ -1419,7 +1587,7 @@ index 4d58278b7..8f50dc430 100644 ): self.model = get_model( model_config=self.model_config, -@@ -2645,9 +2692,12 @@ class ModelRunner: +@@ -2645,9 +2693,12 @@ class ModelRunner: ) -> Tuple[Union[LogitsProcessorOutput, PPProxyTensors], bool]: self.forward_pass_id += 1 @@ -1435,17 +1603,18 @@ index 4d58278b7..8f50dc430 100644 ): output = self._forward_raw( forward_batch, -@@ -2656,6 +2706,13 @@ class ModelRunner: +@@ -2656,6 +2707,14 @@ class ModelRunner: reinit_attn_backend, split_forward_count, ) + # Copy cached routing experts' buffers back to CPU cache -+ get_global_experts_capturer().sync_fwd_experts_buffer_DtoH( -+ device_loc=forward_batch.out_cache_loc, -+ cpu_loc=forward_batch.out_cache_loc_cpu, -+ can_run_graph=output[1], -+ cuda_graph_batch=getattr(self.graph_runner, "bs", None), -+ ) ++ if not self.is_draft_worker: ++ get_global_experts_capturer().sync_fwd_experts_buffer_DtoH( ++ device_loc=forward_batch.out_cache_loc, ++ cpu_loc=forward_batch.out_cache_loc_cpu, ++ can_run_graph=output[1], ++ cuda_graph_batch=getattr(self.graph_runner, "bs", None), ++ ) if self.eplb_manager is not None: self.eplb_manager.on_forward_pass_end() @@ -1474,27 +1643,17 @@ index ab1b6576b..dffd8f09a 100644 use_grouped_topk=False, correction_bias=self.gate.e_score_correction_bias, diff --git a/python/sglang/srt/models/glm4_moe.py b/python/sglang/srt/models/glm4_moe.py -index a9689b8f2..bc8538da8 100644 +index a9689b8f2..0a6c467b1 100644 --- a/python/sglang/srt/models/glm4_moe.py +++ b/python/sglang/srt/models/glm4_moe.py -@@ -379,6 +379,17 @@ class Glm4MoeSparseMoeBlock(nn.Module): +@@ -393,6 +393,7 @@ class Glm4MoeSparseMoeBlock(nn.Module): - self.gate = Glm4MoeGate(config=config, prefix=add_prefix("gate", prefix)) - -+ self.topk = TopK( -+ top_k=self.top_k, + self.topk = TopK( + top_k=self.top_k + self.num_fused_shared_experts, + layer_id=self.layer_id, -+ renormalize=config.norm_topk_prob, -+ use_grouped_topk=True, -+ num_expert_group=config.n_group, -+ topk_group=config.topk_group, -+ correction_bias=self.gate.e_score_correction_bias, -+ routed_scaling_factor=self.routed_scaling_factor, -+ ) -+ - self.experts = get_moe_impl_class(quant_config)( - num_experts=config.n_routed_experts + self.num_fused_shared_experts, - num_fused_shared_experts=self.num_fused_shared_experts, + renormalize=config.norm_topk_prob, + use_grouped_topk=True, + num_expert_group=config.n_group, diff --git a/python/sglang/srt/models/gpt_oss.py b/python/sglang/srt/models/gpt_oss.py index 9474700c4..398d622ff 100644 --- a/python/sglang/srt/models/gpt_oss.py @@ -1613,7 +1772,7 @@ index ea33e81ef..561934dce 100644 
self.norm = PPMissingLayer(return_tuple=True) diff --git a/python/sglang/srt/models/qwen3.py b/python/sglang/srt/models/qwen3.py -index 30b92acbd..a0d14895f 100644 +index 30b92acbd..0d28e0f2b 100644 --- a/python/sglang/srt/models/qwen3.py +++ b/python/sglang/srt/models/qwen3.py @@ -90,8 +90,8 @@ class Qwen3Attention(nn.Module): @@ -1642,7 +1801,7 @@ index 30b92acbd..a0d14895f 100644 hidden_states: torch.Tensor, forward_batch: ForwardBatch, residual: Optional[torch.Tensor], -+ post_residual_addition: Optional[torch.Tensor] = None, ++ **kwargs, ) -> Tuple[torch.Tensor, torch.Tensor]: # Self Attention hidden_states, residual = self.layer_communicator.prepare_attn( @@ -1650,7 +1809,7 @@ index 30b92acbd..a0d14895f 100644 + hidden_states, + residual, + forward_batch, -+ post_residual_addition=post_residual_addition, ++ **kwargs, ) if hidden_states.shape[0] != 0: hidden_states = self.self_attn( @@ -2009,10 +2168,34 @@ index 8e7753dab..323788f39 100644 "--scheduler-recv-interval", type=int, diff --git a/python/sglang/srt/speculative/eagle_info.py b/python/sglang/srt/speculative/eagle_info.py -index b3d72df05..ddfe0b178 100644 +index b3d72df05..09a1634e0 100644 --- a/python/sglang/srt/speculative/eagle_info.py +++ b/python/sglang/srt/speculative/eagle_info.py -@@ -746,6 +746,10 @@ class EagleDraftInput(SpecInput, EagleDraftInputV2Mixin): +@@ -135,6 +135,7 @@ class EagleVerifyInput(SpecInput, EagleVerifyInputV2Mixin): + len(batch.input_ids), + ) + self.last_loc = last_loc ++ batch.out_cache_loc_cpu = batch.out_cache_loc.to("cpu", non_blocking=True) + + bs = batch.batch_size() + assign_req_to_token_pool_func( +@@ -492,6 +493,7 @@ class EagleVerifyInput(SpecInput, EagleVerifyInputV2Mixin): + batch.out_cache_loc = tgt_cache_loc + batch.seq_lens.add_(accept_length + 1) + batch.seq_lens_cpu.add_(accept_length_cpu + 1) ++ batch.out_cache_loc_cpu = batch.out_cache_loc.to("cpu", non_blocking=True) + + draft_input = EagleDraftInput( + hidden_states=batch.spec_info.hidden_states[accept_index], +@@ -575,6 +577,7 @@ class EagleVerifyInput(SpecInput, EagleVerifyInputV2Mixin): + topk=self.topk, + capture_hidden_mode=CaptureHiddenMode.LAST, + ) ++ batch.out_cache_loc_cpu = batch.out_cache_loc.to("cpu", non_blocking=True) + + return EagleVerifyOutput( + draft_input=draft_input, +@@ -746,6 +749,10 @@ class EagleDraftInput(SpecInput, EagleDraftInputV2Mixin): self.topk_index = self.topk_index[: len(new_indices)] self.hidden_states = self.hidden_states[: len(new_indices)] self.verified_id = self.verified_id[: len(new_indices)] @@ -2023,7 +2206,7 @@ index b3d72df05..ddfe0b178 100644 else: # in some cases(e.g draft_extend), we have not filtered the batch by `unfinished_index` self.topk_p = self.topk_p[new_indices] -@@ -777,6 +781,27 @@ class EagleDraftInput(SpecInput, EagleDraftInputV2Mixin): +@@ -777,6 +784,27 @@ class EagleDraftInput(SpecInput, EagleDraftInputV2Mixin): self.verified_id = torch.cat([self.verified_id, spec_info.verified_id], axis=0) self.topk_p = torch.cat([self.topk_p, spec_info.topk_p]) self.topk_index = torch.cat([self.topk_index, spec_info.topk_index]) diff --git a/docker/patch/v0.5.6/megatron.patch b/docker/patch/v0.5.6/megatron.patch new file mode 100644 index 000000000..9d0d6011c --- /dev/null +++ b/docker/patch/v0.5.6/megatron.patch @@ -0,0 +1,869 @@ +diff --git a/megatron/core/dist_checkpointing/strategies/common.py b/megatron/core/dist_checkpointing/strategies/common.py +index 41c21d93d..ef80f72d6 100644 +--- a/megatron/core/dist_checkpointing/strategies/common.py ++++ 
b/megatron/core/dist_checkpointing/strategies/common.py +@@ -86,7 +86,7 @@ class TorchCommonLoadStrategy(LoadCommonStrategy): + msc = MultiStorageClientFeature.import_package() + return msc.torch.load(load_path, map_location='cpu') + else: +- return torch.load(load_path, map_location='cpu') ++ return torch.load(load_path, map_location='cpu', weights_only=False) + except FileNotFoundError as e: + err_msg = f'Common file {load_path} does not exist' + if MultiStorageClientFeature.is_enabled(): +diff --git a/megatron/core/dist_checkpointing/strategies/torch.py b/megatron/core/dist_checkpointing/strategies/torch.py +index ccf5242a2..9b6d3e31f 100644 +--- a/megatron/core/dist_checkpointing/strategies/torch.py ++++ b/megatron/core/dist_checkpointing/strategies/torch.py +@@ -427,6 +427,15 @@ def _restore_dict_types(x: Union[dict, list, Any], keys_template: Union[dict, li + _restore_dict_types(x_val, templ_val) + + ++@dataclass ++class MCoreMetadata(Metadata): ++ """Metadata with mcore specific data.""" ++ ++ # holds data related to flattened_range ++ # TODO: remove when flattened_range is properly removed ++ mcore_data: Optional[Dict[str, Dict[str, Any]]] = None # Mcore related data about each tensor ++ ++ + @dataclass(frozen=True) + class MCoreSavePlan(SavePlan): + """SavePlan with MCore specific data.""" +@@ -499,9 +508,10 @@ class MCoreSavePlanner(DefaultSavePlanner): + def create_global_plan(self, all_plans: List[MCoreSavePlan]) -> Tuple[List[SavePlan], Metadata]: + """Merges MCore data for all plans.""" + global_plan, metadata = super().create_global_plan(all_plans) +- metadata.mcore_data = dict( ++ mcore_data = dict( + ChainMap(*(plan.mcore_data for plan in all_plans)) # type: ignore[arg-type] + ) ++ metadata = MCoreMetadata(mcore_data=mcore_data, **vars(metadata)) + return global_plan, metadata + + def create_decentralized_global_plan(self, local_plan: SavePlan) -> SavePlan: +@@ -556,10 +566,12 @@ class MCoreLoadPlanner(DefaultLoadPlanner): + def _validate_global_shapes(self, metadata, sharded_tensors): + for sh_ten in sharded_tensors: + if sh_ten.key not in metadata.state_dict_metadata: +- raise KeyError( +- f"{sh_ten.key} from model not in state dict:" +- f" {sorted(metadata.state_dict_metadata.keys())}" +- ) ++ # raise KeyError( ++ # f"{sh_ten.key} from model not in state dict:" ++ # f" {sorted(metadata.state_dict_metadata.keys())}" ++ # ) ++ print(f"{sh_ten.key} from model not in state dict, will skip") ++ continue + loaded_shape = metadata.state_dict_metadata[sh_ten.key].size + expected_shape = self._expected_shape(sh_ten) + if loaded_shape != expected_shape: +@@ -589,7 +601,7 @@ class MCoreLoadPlanner(DefaultLoadPlanner): + tensor_metadata = self.metadata.state_dict_metadata + metadata_with_sizes = [ + (tensor_metadata[key], tensor_metadata[key].size, sharded_tensor) +- for key, sharded_tensor in self.allow_shape_mismatch_sharded_tensors.items() ++ for key, sharded_tensor in self.allow_shape_mismatch_sharded_tensors.items() if key in tensor_metadata + ] + try: + # Temporarily set sizes to expected shapes +@@ -918,6 +930,7 @@ class TorchDistLoadShardedStrategy(LoadShardedStrategy): + planner=MCoreLoadPlanner( + shapes_validation_sharded_tensors=flexible_shape_sharded_tensors, + allow_shape_mismatch_sharded_tensors=allow_shape_mismatch_sharded_tensors, ++ allow_partial_load=True, + ), + ) + +diff --git a/megatron/core/distributed/__init__.py b/megatron/core/distributed/__init__.py +index fe26e8b43..4451f2776 100644 +--- a/megatron/core/distributed/__init__.py ++++ 
b/megatron/core/distributed/__init__.py +@@ -11,3 +11,15 @@ from .finalize_model_grads import finalize_model_grads + from .fsdp.mcore_fsdp_adapter import FullyShardedDataParallel + from .torch_fully_sharded_data_parallel import TorchFullyShardedDataParallel + from .torch_fully_sharded_data_parallel_config import TorchFullyShardedDataParallelConfig ++ ++# Backward compatibility patch for FSDP module reorganization ++import sys ++import importlib.util ++ ++spec = importlib.util.find_spec('megatron.core.distributed.fsdp.src.megatron_fsdp') ++if spec: ++ custom_fsdp = importlib.util.module_from_spec(spec) ++ spec.loader.exec_module(custom_fsdp) ++ sys.modules['megatron.core.distributed.custom_fsdp'] = custom_fsdp ++ if hasattr(custom_fsdp, 'MegatronFSDP'): ++ custom_fsdp.FullyShardedDataParallel = custom_fsdp.MegatronFSDP +diff --git a/megatron/core/extensions/transformer_engine.py b/megatron/core/extensions/transformer_engine.py +index 7727efe1e..966fe652a 100644 +--- a/megatron/core/extensions/transformer_engine.py ++++ b/megatron/core/extensions/transformer_engine.py +@@ -366,6 +366,7 @@ class TELinear(te.pytorch.Linear): + ) + + for param in self.parameters(): ++ setattr(param, "parallel_mode", parallel_mode) + if is_expert: + # Reduce the gradient on the expert_data_parallel group for expert linear layers + setattr(param, "allreduce", not self.expert_parallel) +diff --git a/megatron/core/models/gpt/gpt_layer_specs.py b/megatron/core/models/gpt/gpt_layer_specs.py +index 860ee64a9..80944b702 100755 +--- a/megatron/core/models/gpt/gpt_layer_specs.py ++++ b/megatron/core/models/gpt/gpt_layer_specs.py +@@ -79,6 +79,8 @@ def get_gpt_layer_with_transformer_engine_spec( + qk_l2_norm: Optional[bool] = False, + use_te_op_fuser: Optional[bool] = False, + use_kitchen: bool = False, ++ post_self_attn_layernorm: bool = False, ++ post_mlp_layernorm: bool = False, + ) -> ModuleSpec: + """Use this spec to use lower-level Transformer Engine modules (required for fp8 training). 
+ +@@ -178,9 +180,11 @@ def get_gpt_layer_with_transformer_engine_spec( + ), + ), + self_attn_bda=get_bias_dropout_add, ++ post_self_attn_layernorm=TENorm if post_self_attn_layernorm else IdentityOp, + pre_mlp_layernorm=backend.layer_norm() if num_experts else IdentityOp, + mlp=mlp, + mlp_bda=get_bias_dropout_add, ++ post_mlp_layernorm=TENorm if post_mlp_layernorm else IdentityOp, + sharded_state_dict_keys_map={ + "mlp.0.weight": "mlp.linear_fc1.layer_norm_weight", + "mlp.0.bias": "mlp.linear_fc1.layer_norm_bias", +diff --git a/megatron/core/models/gpt/gpt_model.py b/megatron/core/models/gpt/gpt_model.py +index 6aec66e6d..6ca48b55f 100644 +--- a/megatron/core/models/gpt/gpt_model.py ++++ b/megatron/core/models/gpt/gpt_model.py +@@ -355,6 +355,7 @@ class GPTModel(LanguageModule): + *, + inference_params: Optional[BaseInferenceContext] = None, + loss_mask: Optional[Tensor] = None, ++ mtp_kwargs: Optional[dict] = {}, + ) -> Tensor: + """Forward function of the GPT Model This function passes the input tensors + through the embedding layer, and then the decoeder and finally into the post +@@ -410,6 +411,7 @@ class GPTModel(LanguageModule): + runtime_gather_output=runtime_gather_output, + extra_block_kwargs=extra_block_kwargs, + inference_context=inference_context, ++ mtp_kwargs=mtp_kwargs, + ) + + def _postprocess( +@@ -431,6 +433,7 @@ class GPTModel(LanguageModule): + runtime_gather_output=None, + extra_block_kwargs=None, + inference_context=None, ++ mtp_kwargs={}, + ): + """Postprocesses decoder hidden states to generate logits or compute loss. + +@@ -446,7 +449,7 @@ class GPTModel(LanguageModule): + if self.share_embeddings_and_output_weights: + output_weight = self.shared_embedding_or_output_weight() + +- if mtp_in_postprocess: ++ if mtp_in_postprocess and mtp_kwargs.get('mtp_labels', None) is not None: + hidden_states = self.mtp( + input_ids=input_ids, + position_ids=position_ids, +@@ -465,25 +468,37 @@ class GPTModel(LanguageModule): + if not self.post_process: + return hidden_states + +- if self.mtp_process: +- mtp_labels = labels.clone() ++ if self.mtp_process and mtp_kwargs.get('mtp_labels', None) is not None: ++ mtp_labels = mtp_kwargs['mtp_labels'].clone() ++ mtp_labels, _ = roll_tensor(mtp_labels, shifts=-1, dims=-1, cp_group=self.cp_group, packed_seq_params=packed_seq_params) ++ + hidden_states_list = torch.chunk(hidden_states, 1 + self.config.mtp_num_layers, dim=0) + hidden_states = hidden_states_list[0] + if loss_mask is None: + # if loss_mask is not provided, use all ones as loss_mask + loss_mask = torch.ones_like(mtp_labels) ++ else: ++ # Otherwise, roll the loss_mask to keep up with the mtp_labels ++ loss_mask, _ = roll_tensor(loss_mask, shifts=-1, dims=-1, cp_group=self.cp_group, packed_seq_params=packed_seq_params) + for mtp_layer_number in range(self.config.mtp_num_layers): + # output +- mtp_logits, _ = self.output_layer( +- hidden_states_list[mtp_layer_number + 1], +- weight=output_weight, +- runtime_gather_output=runtime_gather_output, ++ output_layer_params = {k: v.detach() for k, v in self.output_layer.named_parameters()} ++ output_layer_buffers = dict(self.output_layer.named_buffers()) ++ mtp_logits, _ = torch.func.functional_call( ++ self.output_layer, ++ {**output_layer_params, **output_layer_buffers}, ++ (hidden_states_list[mtp_layer_number + 1],), ++ { ++ "weight": output_weight.detach() if output_weight else None, ++ "runtime_gather_output": runtime_gather_output, ++ }, + ) + # Calc loss for the current Multi-Token Prediction (MTP) layers. 
+- mtp_labels, _ = roll_tensor(mtp_labels, shifts=-1, dims=-1, cp_group=self.cp_group) +- loss_mask, num_tokens = roll_tensor( +- loss_mask, shifts=-1, dims=-1, cp_group=self.cp_group ++ mtp_labels, _ = roll_tensor(mtp_labels, shifts=-1, dims=-1, cp_group=self.cp_group, packed_seq_params=packed_seq_params) ++ new_loss_mask, num_tokens = roll_tensor( ++ loss_mask, shifts=-1, dims=-1, cp_group=self.cp_group, packed_seq_params=packed_seq_params + ) ++ loss_mask = new_loss_mask * loss_mask + mtp_loss = self.compute_language_model_loss(mtp_labels, mtp_logits) + mtp_loss = loss_mask * mtp_loss + if self.training: +diff --git a/megatron/core/optimizer/distrib_optimizer.py b/megatron/core/optimizer/distrib_optimizer.py +index a36b67364..ed8883e32 100644 +--- a/megatron/core/optimizer/distrib_optimizer.py ++++ b/megatron/core/optimizer/distrib_optimizer.py +@@ -657,6 +657,8 @@ class DistributedOptimizer(MixedPrecisionOptimizer): + # TE FusedAdam will not accumulate step for empty param groups, so we need to + # align the step across param groups. + param_group["step"] = int(step) ++ if "step" in param_group and param_group["step"] is None: ++ del param_group["step"] + + # Grad scaler state. + if self.grad_scaler: +diff --git a/megatron/core/parallel_state.py b/megatron/core/parallel_state.py +index a40c85a88..86688c331 100644 +--- a/megatron/core/parallel_state.py ++++ b/megatron/core/parallel_state.py +@@ -9,6 +9,7 @@ from typing import Callable, List, Optional + + import numpy as np + import torch ++import torch.distributed as dist + + from .utils import GlobalMemoryBuffer, is_torch_min_version + +diff --git a/megatron/core/pipeline_parallel/p2p_communication.py b/megatron/core/pipeline_parallel/p2p_communication.py +index 63ee9d1f5..b90b744c1 100644 +--- a/megatron/core/pipeline_parallel/p2p_communication.py ++++ b/megatron/core/pipeline_parallel/p2p_communication.py +@@ -26,22 +26,22 @@ def _batched_p2p_ops( + ops = [] + if tensor_send_prev is not None: + send_prev_op = torch.distributed.P2POp( +- torch.distributed.isend, tensor_send_prev, prev_pipeline_rank, group ++ torch.distributed.isend, tensor_send_prev, prev_pipeline_rank, + ) + ops.append(send_prev_op) + if tensor_recv_prev is not None: + recv_prev_op = torch.distributed.P2POp( +- torch.distributed.irecv, tensor_recv_prev, prev_pipeline_rank, group ++ torch.distributed.irecv, tensor_recv_prev, prev_pipeline_rank, + ) + ops.append(recv_prev_op) + if tensor_send_next is not None: + send_next_op = torch.distributed.P2POp( +- torch.distributed.isend, tensor_send_next, next_pipeline_rank, group ++ torch.distributed.isend, tensor_send_next, next_pipeline_rank, + ) + ops.append(send_next_op) + if tensor_recv_next is not None: + recv_next_op = torch.distributed.P2POp( +- torch.distributed.irecv, tensor_recv_next, next_pipeline_rank, group ++ torch.distributed.irecv, tensor_recv_next, next_pipeline_rank, + ) + ops.append(recv_next_op) + if len(ops) > 0: +diff --git a/megatron/core/transformer/attention.py b/megatron/core/transformer/attention.py +index c749bac43..dde8d50e7 100644 +--- a/megatron/core/transformer/attention.py ++++ b/megatron/core/transformer/attention.py +@@ -670,7 +670,10 @@ class Attention(MegatronModule, ABC): + # Get the query, key and value tensors based on the type of attention - + # self or cross attn. 
+ nvtx_range_push(suffix="qkv") +- query, key, value = self.get_query_key_value_tensors(hidden_states, key_value_states) ++ if self.config.use_gated_attention: ++ query, gate, key, value = self.get_query_gate_key_value_tensors(hidden_states, key_value_states) ++ else: ++ query, key, value = self.get_query_key_value_tensors(hidden_states, key_value_states) + nvtx_range_pop(suffix="qkv") + + # =================================================== +@@ -842,6 +845,11 @@ class Attention(MegatronModule, ABC): + # Output. [sq, b, h] + # ================= + ++ if self.config.use_gated_attention: ++ nvtx_range_push(suffix="sigmoid_gate") ++ core_attn_out = core_attn_out * torch.sigmoid(gate) ++ nvtx_range_pop(suffix="sigmoid_gate") ++ + nvtx_range_push(suffix="linear_proj") + output, bias = self.linear_proj(core_attn_out) + nvtx_range_pop(suffix="linear_proj") +@@ -879,19 +887,34 @@ class SelfAttention(Attention): + model_comm_pgs=model_comm_pgs, + ) + +- self.linear_qkv = build_module( +- submodules.linear_qkv, +- self.config.hidden_size, +- self.query_projection_size + 2 * self.kv_projection_size, +- config=self.config, +- init_method=self.config.init_method, +- gather_output=False, +- bias=self.config.add_bias_linear or self.config.add_qkv_bias, +- skip_bias_add=False, +- is_expert=False, +- tp_comm_buffer_name='qkv', +- tp_group=self.model_comm_pgs.tp, +- ) ++ if self.config.use_gated_attention: ++ self.linear_qgkv = build_module( ++ submodules.linear_qkv, ++ self.config.hidden_size, ++ 2 * (self.query_projection_size + self.kv_projection_size), ++ config=self.config, ++ init_method=self.config.init_method, ++ gather_output=False, ++ bias=self.config.add_bias_linear or self.config.add_qkv_bias, ++ skip_bias_add=False, ++ is_expert=False, ++ tp_comm_buffer_name='qkv', ++ tp_group=self.model_comm_pgs.tp, ++ ) ++ else: ++ self.linear_qkv = build_module( ++ submodules.linear_qkv, ++ self.config.hidden_size, ++ self.query_projection_size + 2 * self.kv_projection_size, ++ config=self.config, ++ init_method=self.config.init_method, ++ gather_output=False, ++ bias=self.config.add_bias_linear or self.config.add_qkv_bias, ++ skip_bias_add=False, ++ is_expert=False, ++ tp_comm_buffer_name='qkv', ++ tp_group=self.model_comm_pgs.tp, ++ ) + + if submodules.q_layernorm is not None: + self.q_layernorm = build_module( +@@ -1036,6 +1059,65 @@ class SelfAttention(Attention): + + return query, key, value + ++ # adapt from https://github.com/alibaba/Pai-Megatron-Patch/blob/8e6cbb0556ba09933ab4a4edb23c0af1d19d9960/megatron_patch/model/qwen3_next/gated_attention.py#L192 ++ def get_query_gate_key_value_tensors(self, hidden_states, key_value_states=None): ++ """ ++ Derives `query`, `key` and `value` tensors from `hidden_states`. 
++ """ ++ # Attention heads [sq, b, h] --> [sq, b, ng * 2 * (np/ng + 1) * hn)] ++ mixed_qgkv, _ = self.linear_qgkv(hidden_states) ++ ++ # [sq, b, hp] --> [sq, b, ng, 2 * (np/ng + 1) * hn] ++ new_tensor_shape = mixed_qgkv.size()[:-1] + ( ++ self.num_query_groups_per_partition, ++ ( ++ 2 * (self.num_attention_heads_per_partition // self.num_query_groups_per_partition + 1) ++ * self.hidden_size_per_attention_head ++ ), ++ ) ++ mixed_qgkv = mixed_qgkv.view(*new_tensor_shape) ++ ++ split_arg_list = [ ++ ( ++ self.num_attention_heads_per_partition ++ // self.num_query_groups_per_partition ++ * self.hidden_size_per_attention_head ++ ), ++ ( ++ self.num_attention_heads_per_partition ++ // self.num_query_groups_per_partition ++ * self.hidden_size_per_attention_head ++ ), ++ self.hidden_size_per_attention_head, ++ self.hidden_size_per_attention_head, ++ ] ++ ++ if SplitAlongDim is not None: ++ ++ # [sq, b, ng, (np/ng + 2) * hn] ++ # --> [sq, b, ng, np/ng * hn], [sq, b, ng, hn], [sq, b, ng, hn] ++ (query, gate, key, value) = SplitAlongDim(mixed_qgkv, 3, split_arg_list) ++ else: ++ ++ # [sq, b, ng, (np/ng + 2) * hn] ++ # --> [sq, b, ng, np/ng * hn], [sq, b, ng, hn], [sq, b, ng, hn] ++ (query, gate, key, value) = torch.split(mixed_qgkv, split_arg_list, dim=3) ++ ++ # [sq, b, ng, np/ng * hn] -> [sq, b, np, hn] ++ query = query.reshape(query.size(0), query.size(1), -1, self.hidden_size_per_attention_head) ++ gate = gate.reshape(query.size(0), query.size(1), -1) ++ ++ if self.q_layernorm is not None: ++ query = self.q_layernorm(query) ++ ++ if self.k_layernorm is not None: ++ key = self.k_layernorm(key) ++ ++ if self.config.test_mode: ++ self.run_realtime_tests() ++ ++ return query, gate, key, value ++ + def backward_dw(self) -> NoReturn: + """Execute weight update operations""" + self._backward_qkv_proj() +diff --git a/megatron/core/transformer/moe/moe_utils.py b/megatron/core/transformer/moe/moe_utils.py +index 235b6f6af..fbcffe278 100644 +--- a/megatron/core/transformer/moe/moe_utils.py ++++ b/megatron/core/transformer/moe/moe_utils.py +@@ -566,6 +566,9 @@ def topk_routing_with_score_function( + else: + return torch.topk(scores, k=topk, dim=1) + ++ from slime.utils.routing_replay import get_routing_replay_compute_topk ++ compute_topk = get_routing_replay_compute_topk(compute_topk) ++ + if score_function == "softmax": + if use_pre_softmax: + scores = torch.softmax(logits, dim=-1, dtype=torch.float32).type_as(logits) +diff --git a/megatron/core/transformer/moe/router.py b/megatron/core/transformer/moe/router.py +index 6b20b8622..459e65921 100644 +--- a/megatron/core/transformer/moe/router.py ++++ b/megatron/core/transformer/moe/router.py +@@ -156,6 +156,9 @@ class TopKRouter(Router): + self.local_tokens_per_expert = None + self.expert_bias = None + ++ from slime.utils.routing_replay import register_routing_replay ++ register_routing_replay(self) ++ + def _maintain_float32_expert_bias(self): + """ + Maintain the expert bias in float32. 
+diff --git a/megatron/core/transformer/multi_token_prediction.py b/megatron/core/transformer/multi_token_prediction.py +index b7884e18e..f0104f861 100755 +--- a/megatron/core/transformer/multi_token_prediction.py ++++ b/megatron/core/transformer/multi_token_prediction.py +@@ -6,6 +6,7 @@ from typing import Callable, List, Optional, Union + + import torch + from torch import Tensor ++import warnings + + from megatron.core import InferenceParams, mpu, parallel_state, tensor_parallel + from megatron.core.dist_checkpointing.mapping import ShardedStateDict +@@ -105,17 +106,21 @@ def tie_output_layer_state_dict( + ) + + +-def roll_tensor(tensor, shifts=-1, dims=-1, cp_group=None): +- """Roll the tensor input along the sequence dimension with Context Parallelism (CP) support. + +- This function extends the original roll_tensor to support Context Parallelism, which allows +- MTP to work with CP > 1. When CP is enabled, the sequence dimension is split across CP ranks, +- and tensor rolling requires communication between adjacent CP ranks to properly handle the +- boundary conditions. ++def roll_tensor(tensor, shifts=-1, dims=-1, cp_group=None, packed_seq_params=None): ++ """Roll the tensor input along the sequence dimension with Context Parallelism (CP) and Packed Sequence support. ++ ++ This function extends the original roll_tensor to support Context Parallelism and Packed Sequences. ++ When CP is enabled, the sequence dimension is split across CP ranks, and tensor rolling requires ++ communication between adjacent CP ranks to properly handle the boundary conditions. ++ When packed sequences are used, rolling is performed within each individual sequence boundary ++ to prevent mixing tokens between different packed sequences. + + For CP=1 (default behavior): Uses standard torch.roll with zero padding + For CP>1: Splits tensor into chunks, performs rolling within each chunk, then exchanges + boundary elements between adjacent CP ranks to maintain sequence continuity. ++ For packed sequences: Rolls tensors within sequence boundaries defined by cu_seqlens. ++ + + Args: + tensor (Tensor): The input tensor to roll. +@@ -123,9 +128,15 @@ def roll_tensor(tensor, shifts=-1, dims=-1, cp_group=None): + dims (int): The dimension to roll (typically -1 for sequence dimension). + cp_group (ProcessGroup): The context parallelism process group. If None or size=1, + falls back to standard rolling behavior. ++ packed_seq_params (PackedSeqParams): Parameters for packed sequence processing. ++ If provided, rolling respects sequence boundaries. + Returns: + tuple: (rolled_tensor, sum_of_rolled_tensor) + """ ++ ++ if packed_seq_params is not None: ++ return _roll_tensor_packed_seq(tensor, shifts, dims, packed_seq_params, cp_group) ++ + # Standard rolling behavior when CP is not enabled (cp_group is None or size=1) + if cp_group is None or cp_group.size() == 1: + rolled_tensor = torch.roll(tensor, shifts=shifts, dims=dims) +@@ -193,6 +204,103 @@ def roll_tensor(tensor, shifts=-1, dims=-1, cp_group=None): + + return rolled_tensor, rolled_tensor.sum() + ++def _roll_tensor_packed_seq(tensor, shifts, dims, packed_seq_params, cp_group=None): ++ """Roll tensor with packed sequence support. ++ ++ This function handles rolling for packed sequences by respecting sequence boundaries ++ defined in packed_seq_params.cu_seqlens. Rolling is performed within each individual ++ sequence to prevent mixing tokens between different packed sequences. 
When Context ++ Parallelism (CP) is enabled, each CP rank still receives the full `cu_seqlens` metadata ++ so we slice out the portion of every packed sequence that lives on the current rank and ++ reuse the standard CP boundary exchange to populate the rolling window. ++ ++ Args: ++ tensor (Tensor): The input tensor to roll. ++ shifts (int): The shift of the tensor (typically -1 for MTP). ++ dims (int): The dimension to roll (typically -1 for sequence dimension). ++ packed_seq_params (PackedSeqParams): Parameters for packed sequence processing. ++ cp_group (ProcessGroup): The context parallelism process group. ++ ++ Returns: ++ tuple: (rolled_tensor, sum_of_rolled_tensor) ++ """ ++ ++ # Notice: This is a naive implementation to test the correctness, a better solution will only sync the boundary tokens once. ++ assert dims == -1 or dims == tensor.dim() - 1, "Packed sequence roll only supports the last dimension." ++ assert shifts == -1, "Packed sequence roll only supports a single-token left shift." ++ cu_seqlens = packed_seq_params.cu_seqlens_q ++ assert cu_seqlens is not None, "Packed sequence parameters must provide cu_seqlens_q." ++ ++ rolled_tensor = tensor.clone() ++ ++ cp_size = cp_group.size() if cp_group is not None else 1 ++ if cp_size == 1: ++ # CP disabled: simply roll inside each packed sequence boundary. ++ for i in range(len(cu_seqlens) - 1): ++ start_idx = cu_seqlens[i] ++ end_idx = cu_seqlens[i + 1] ++ seq_slice = tensor[..., start_idx:end_idx] ++ rolled_seq = torch.roll(seq_slice, shifts=shifts, dims=dims) ++ rolled_seq[..., shifts:] = 0 ++ rolled_tensor[..., start_idx:end_idx] = rolled_seq ++ return rolled_tensor, rolled_tensor.sum() ++ ++ # CP enabled: each rank owns two chunks per sequence (front and mirrored tail). ++ local_rank = torch.distributed.get_rank(group=cp_group) ++ global_ranks = torch.distributed.get_process_group_ranks(group=cp_group) ++ next_rank = global_ranks[(local_rank + 1) % cp_size] ++ prev_rank = global_ranks[(local_rank - 1) % cp_size] ++ ++ # iterate over each sequence individually ++ for i in range(len(cu_seqlens) - 1): ++ start_idx = cu_seqlens[i] ++ end_idx = cu_seqlens[i + 1] ++ ++ # the idx has been multiplied by cp_size, so we need to divide it by cp_size to get the local idx ++ local_start_idx = start_idx // cp_size ++ local_end_idx = end_idx // cp_size ++ tensor_slice = rolled_tensor[..., local_start_idx:local_end_idx].clone() ++ ++ # The following code is very similar as the code in roll_tensor function ++ local_chunks = tensor_slice.chunk(2, dim=dims) ++ rolled_chunks = [ ++ torch.roll(chunk, shifts=shifts, dims=dims) for chunk in local_chunks ++ ] ++ ++ tensor_send_list = [] ++ tensor_recv_list = [] ++ for chunk in rolled_chunks: ++ boundary = chunk.select(dims, shifts).contiguous().clone() ++ tensor_send_list.append(boundary) ++ tensor_recv_list.append(torch.empty_like(boundary)) ++ ++ ops = [] ++ if local_rank != 0: ++ ops.append(torch.distributed.isend(tensor=tensor_send_list[0], dst=prev_rank)) ++ ops.append(torch.distributed.irecv(tensor=tensor_recv_list[1], src=prev_rank)) ++ else: ++ tensor_recv_list[1].zero_() ++ ++ if local_rank != cp_size - 1: ++ ops.append(torch.distributed.irecv(tensor=tensor_recv_list[0], src=next_rank)) ++ ops.append(torch.distributed.isend(tensor=tensor_send_list[1], dst=next_rank)) ++ else: ++ tensor_recv_list[0].copy_(tensor_send_list[1]) ++ ++ for op in ops: ++ op.wait() ++ ++ index = [slice(None)] * rolled_chunks[0].dim() ++ index[dims] = shifts ++ for chunk, recv in zip(rolled_chunks, 
tensor_recv_list): ++ chunk[tuple(index)] = recv ++ ++ seq_result = torch.cat(rolled_chunks, dim=dims) ++ ++ # update the rolled tensor ++ rolled_tensor[..., local_start_idx:local_end_idx] = seq_result ++ ++ return rolled_tensor, rolled_tensor.sum() + + class MTPLossLoggingHelper: + """Helper class for logging MTP losses.""" +@@ -480,9 +588,10 @@ class MultiTokenPredictionLayer(MegatronModule): + def _get_embeddings( + self, + input_ids: torch.Tensor, +- position_ids: torch.Tensor, + embedding: Callable, + hidden_states: torch.Tensor, ++ position_ids: Optional[torch.Tensor] = None, ++ packed_seq_params: Optional[PackedSeqParams] = None, + ): + """ + Preprocesses input data for the Multi-Token Prediction (MTP) layers. +@@ -499,12 +608,23 @@ class MultiTokenPredictionLayer(MegatronModule): + sequence length, b is the batch size, and h is the hidden size. + """ + # Calc logits for the current Multi-Token Prediction (MTP) layers. +- input_ids, _ = roll_tensor(input_ids, shifts=-1, dims=-1, cp_group=self.cp_group) +- position_ids, _ = roll_tensor(position_ids, shifts=-1, dims=-1, cp_group=self.cp_group) ++ input_ids, _ = roll_tensor(input_ids, shifts=-1, dims=-1, cp_group=self.cp_group, packed_seq_params=packed_seq_params) ++ ++ # Prepare/roll position ids only when applicable. ++ if position_ids is None: ++ # Fallback position ids for learned absolute embedding. ++ seq_len = input_ids.size(-1) ++ position_ids = torch.arange(seq_len, dtype=torch.long, device=input_ids.device) ++ position_ids = position_ids.unsqueeze(0).expand_as(input_ids) ++ ++ position_ids, _ = roll_tensor( ++ position_ids, shifts=-1, dims=-1, cp_group=self.cp_group, packed_seq_params=packed_seq_params ++ ) + # embedding + decoder_input = embedding(input_ids=input_ids, position_ids=position_ids) ++ decoder_input = decoder_input.detach() + +- hidden_states = make_viewless_tensor(inp=hidden_states, requires_grad=True, keep_graph=True) ++ hidden_states = make_viewless_tensor(inp=hidden_states, requires_grad=True, keep_graph=False) + + return input_ids, position_ids, decoder_input, hidden_states + +@@ -604,22 +724,66 @@ class MultiTokenPredictionLayer(MegatronModule): + return hidden_states + + def _checkpointed_forward(self, forward_func, *args, **kwargs): ++ """Wrap `forward_func` with activation checkpointing while only passing tensors. ++ ++ Non-tensor arguments (e.g., configuration objects, None) are captured via closure so ++ that checkpoint implementations never receive them directly, avoiding save_for_backward ++ issues with non-tensor inputs. ++ """ ++ ++ # TODO(jiajun): Is there any better implementation here? 
++ positional_specs = [] ++ kw_specs = [] ++ tensor_args: List[torch.Tensor] = [] ++ ++ for arg in args: ++ if torch.is_tensor(arg): ++ positional_specs.append(('tensor', len(tensor_args))) ++ tensor_args.append(arg) ++ else: ++ positional_specs.append(('const', arg)) ++ ++ for key, value in kwargs.items(): ++ if torch.is_tensor(value): ++ kw_specs.append((key, ('tensor', len(tensor_args)))) ++ tensor_args.append(value) ++ else: ++ kw_specs.append((key, ('const', value))) ++ ++ def run(*flat_tensor_args): ++ rebuilt_args = [] ++ for spec_type, payload in positional_specs: ++ if spec_type == 'tensor': ++ rebuilt_args.append(flat_tensor_args[payload]) ++ else: ++ rebuilt_args.append(payload) ++ ++ rebuilt_kwargs = {} ++ for key, (spec_type, payload) in kw_specs: ++ if spec_type == 'tensor': ++ rebuilt_kwargs[key] = flat_tensor_args[payload] ++ else: ++ rebuilt_kwargs[key] = payload ++ ++ return forward_func(*rebuilt_args, **rebuilt_kwargs) ++ ++ tensor_args_tuple = tuple(tensor_args) ++ + def checkpoint_handler(): +- """Determines whether to use the `te_checkpoint` or `tensor_parallel.checkpoint`""" ++ """Determines whether to use the `te_checkpoint` or `tensor_parallel.checkpoint`.""" + if self.config.fp8: + from megatron.core.extensions.transformer_engine import te_checkpoint + + return te_checkpoint( +- forward_func, ++ run, + self.config.distribute_saved_activations, + tensor_parallel.random.get_cuda_rng_tracker, + parallel_state.get_tensor_model_parallel_group(), +- *args, +- **kwargs, ++ *tensor_args_tuple, + ) + else: + return tensor_parallel.checkpoint( +- forward_func, self.config.distribute_saved_activations, *args, *kwargs.values() ++ run, self.config.distribute_saved_activations, *tensor_args_tuple + ) + + if self.config.recompute_method == 'uniform': +@@ -681,15 +845,13 @@ class MultiTokenPredictionLayer(MegatronModule): + [s, b, h], and optionally the updated context tensor if cross-attention is used. + """ + assert context is None, f"multi token prediction + cross attention is not yet supported." +- assert ( +- packed_seq_params is None +- ), f"multi token prediction + sequence packing is not yet supported." 
+ + input_ids, position_ids, decoder_input, hidden_states = self._get_embeddings( + input_ids=input_ids, + position_ids=position_ids, + embedding=embedding, + hidden_states=hidden_states, ++ packed_seq_params=packed_seq_params, + ) + + if self.config.recompute_granularity == 'full' and self.training: +diff --git a/megatron/core/transformer/transformer_config.py b/megatron/core/transformer/transformer_config.py +index d55bebe7e..1eecbbd38 100644 +--- a/megatron/core/transformer/transformer_config.py ++++ b/megatron/core/transformer/transformer_config.py +@@ -173,6 +173,10 @@ class TransformerConfig(ModelParallelConfig): + qk_layernorm: bool = False + """Whether to apply `normalization` type of normalization to the query and key embeddings.""" + ++ post_self_attn_layernorm: bool = False ++ post_mlp_layernorm: bool = False ++ use_gated_attention: bool = False ++ + test_mode: bool = False + """Whether to run real-time tests.""" + +diff --git a/megatron/core/transformer/transformer_layer.py b/megatron/core/transformer/transformer_layer.py +index 84f22bdea..f0f3f8e86 100644 +--- a/megatron/core/transformer/transformer_layer.py ++++ b/megatron/core/transformer/transformer_layer.py +@@ -224,6 +224,7 @@ class TransformerLayerSubmodules: + input_layernorm: Union[ModuleSpec, type] = IdentityOp + self_attention: Union[ModuleSpec, type] = IdentityOp + self_attn_bda: Union[ModuleSpec, type] = IdentityFuncOp ++ post_self_attn_layernorm: Union[ModuleSpec, type] = IdentityOp + + pre_cross_attn_layernorm: Union[ModuleSpec, type] = IdentityOp + cross_attention: Union[ModuleSpec, type] = IdentityOp +@@ -232,6 +233,7 @@ class TransformerLayerSubmodules: + pre_mlp_layernorm: Union[ModuleSpec, type] = IdentityOp + mlp: Union[ModuleSpec, type] = IdentityOp + mlp_bda: Union[ModuleSpec, type] = IdentityFuncOp ++ post_mlp_layernorm: Union[ModuleSpec, type] = IdentityOp + + # Mapping for sharded tensor keys to be applied in `sharded_state_dict` method + sharded_state_dict_keys_map: Dict[str, str] = field(default_factory=dict) +@@ -336,6 +338,13 @@ class TransformerLayer(MegatronModule, BaseTransformerLayer): + # [Module 3: BiasDropoutFusion] + self.self_attn_bda = build_module(submodules.self_attn_bda) + ++ self.post_self_attn_layernorm = build_module( ++ submodules.post_self_attn_layernorm, ++ config=self.config, ++ hidden_size=self.config.hidden_size, ++ eps=self.config.layernorm_epsilon, ++ ) ++ + # [Module 4: Post SelfAttention] Optional Layernorm after self-attn + self.pre_cross_attn_layernorm = build_module( + submodules.pre_cross_attn_layernorm, +@@ -399,6 +408,13 @@ class TransformerLayer(MegatronModule, BaseTransformerLayer): + # [Module 9: BiasDropoutFusion] + self.mlp_bda = build_module(submodules.mlp_bda) + ++ self.post_mlp_layernorm = build_module( ++ submodules.post_mlp_layernorm, ++ config=self.config, ++ hidden_size=self.config.hidden_size, ++ eps=self.config.layernorm_epsilon ++ ) ++ + self.recompute_input_layernorm = False + self.recompute_pre_mlp_layernorm = False + self.recompute_mlp = False +@@ -535,6 +551,10 @@ class TransformerLayer(MegatronModule, BaseTransformerLayer): + attention_output_with_bias[0] + ) + ++ attention_output, attention_output_bias = attention_output_with_bias ++ attention_output = self.post_self_attn_layernorm(attention_output) ++ attention_output_with_bias = (attention_output, attention_output_bias) ++ + # TODO: could we move `bias_dropout_add_exec_handler` itself + # inside the module provided in the `bias_dropout_add_spec` module? 
+ nvtx_range_push(suffix="self_attn_bda") +@@ -635,6 +655,10 @@ class TransformerLayer(MegatronModule, BaseTransformerLayer): + else: + mlp_output_with_bias = self.mlp(pre_mlp_layernorm_output) + ++ mlp_output, mlp_output_bias = mlp_output_with_bias ++ mlp_output = self.post_mlp_layernorm(mlp_output) ++ mlp_output_with_bias = (mlp_output, mlp_output_bias) ++ + if self.recompute_pre_mlp_layernorm: + # discard the output of the pre-mlp layernorm and register the recompute + # as a gradient hook of mlp_output_with_bias[0] +diff --git a/megatron/training/arguments.py b/megatron/training/arguments.py +index e3459c5ee..7346bf35b 100644 +--- a/megatron/training/arguments.py ++++ b/megatron/training/arguments.py +@@ -937,8 +937,6 @@ def validate_args(args, defaults={}): + # MoE Spec check + if args.num_experts == 0: + args.num_experts = None +- if args.num_experts is not None: +- assert args.spec is None, "Model Spec must be None when using MoEs" + if args.num_experts is not None and args.moe_ffn_hidden_size is None: + args.moe_ffn_hidden_size = args.ffn_hidden_size + print("Warning: moe_ffn_hidden_size is not set, using ffn_hidden_size for MoE instead.") +@@ -1198,6 +1196,10 @@ def core_transformer_config_from_args(args, config_class=None): + if args.is_hybrid_model: + kw_args['is_hybrid_model'] = args.is_hybrid_model + ++ kw_args['post_self_attn_layernorm'] = args.post_self_attn_layernorm ++ kw_args['post_mlp_layernorm'] = args.post_mlp_layernorm ++ kw_args['use_gated_attention'] = args.use_gated_attention ++ + # handle quantization config + # NOTE: Kitchen arguments are only added to the namespace when + # Kitchen library is available. +@@ -1488,6 +1490,12 @@ def _add_network_size_args(parser): + action='store_true', + help='If set, use original BERT residula connection ' + 'ordering.') ++ group.add_argument('--post-self-attn-layernorm', action='store_true', ++ help='If set, use post self attention layernorm.') ++ group.add_argument('--post-mlp-layernorm', action='store_true', ++ help='If set, use post MLP layernorm.') ++ group.add_argument('--use-gated-attention', action='store_true', ++ help='If set, use gated attention as in Qwen3Next') + group.add_argument('--openai-gelu', action='store_true', + help='Use OpenAIs GeLU implementation. 
This option' + 'should not be used unless for backward compatibility' +diff --git a/megatron/training/tokenizer/tokenizer.py b/megatron/training/tokenizer/tokenizer.py +index 5cf222ccc..d1554ca4c 100644 +--- a/megatron/training/tokenizer/tokenizer.py ++++ b/megatron/training/tokenizer/tokenizer.py +@@ -138,6 +138,8 @@ class _HuggingFaceTokenizer(MegatronTokenizer): + f"The transformers library must be installed to use huggingface_tokenizer_provider" + ) + ++ if "trust_remote_code" not in kwargs: ++ kwargs["trust_remote_code"] = True + # TODO(bnorick): download tokenizer once to lustre and use force offline to make sure all tasks read it from there + self._tokenizer = transformers.AutoTokenizer.from_pretrained( + pretrained_model_name_or_path=pretrained_model_name_or_path, **kwargs diff --git a/docker/patch/v0.5.6/sglang.patch b/docker/patch/v0.5.6/sglang.patch new file mode 100644 index 000000000..de12cdd43 --- /dev/null +++ b/docker/patch/v0.5.6/sglang.patch @@ -0,0 +1,2053 @@ +diff --git a/python/sglang/srt/disaggregation/decode.py b/python/sglang/srt/disaggregation/decode.py +index ef52bda7f..537d892dc 100644 +--- a/python/sglang/srt/disaggregation/decode.py ++++ b/python/sglang/srt/disaggregation/decode.py +@@ -296,6 +296,13 @@ class DecodePreallocQueue: + ) + return kv_manager + ++ def release_memory_occupation(self): ++ if hasattr(self.kv_manager, "close"): ++ self.kv_manager.close() ++ ++ def resume_memory_occupation(self): ++ self.kv_manager = self._init_kv_manager() ++ + def add(self, req: Req, is_retracted: bool = False) -> None: + """Add a request to the pending queue.""" + if self._check_if_req_exceed_kv_capacity(req): +diff --git a/python/sglang/srt/disaggregation/mooncake/conn.py b/python/sglang/srt/disaggregation/mooncake/conn.py +index d4414d084..c5fb10155 100644 +--- a/python/sglang/srt/disaggregation/mooncake/conn.py ++++ b/python/sglang/srt/disaggregation/mooncake/conn.py +@@ -1074,6 +1074,19 @@ class MooncakeKVManager(CommonKVManager): + f"Losing connection with prefill instance (bootstrap_addr: {failed_bootstrap_addr}), {len(affected_rooms)} requests affected" + ) + ++ def close(self): ++ # Batch deregister KV data buffers ++ if self.kv_args.kv_data_ptrs: ++ self.engine.batch_deregister(self.kv_args.kv_data_ptrs) ++ ++ # Batch deregister auxiliary data buffers ++ if self.kv_args.aux_data_ptrs: ++ self.engine.batch_deregister(self.kv_args.aux_data_ptrs) ++ ++ # Batch deregister state/extra pool data buffers ++ if self.kv_args.state_data_ptrs: ++ self.engine.batch_deregister(self.kv_args.state_data_ptrs) ++ + + class MooncakeKVSender(CommonKVSender): + +diff --git a/python/sglang/srt/disaggregation/prefill.py b/python/sglang/srt/disaggregation/prefill.py +index 952374ed5..239ac2571 100644 +--- a/python/sglang/srt/disaggregation/prefill.py ++++ b/python/sglang/srt/disaggregation/prefill.py +@@ -305,6 +305,13 @@ class PrefillBootstrapQueue: + else: + return bootstrapped_reqs, failed_reqs + ++ def release_memory_occupation(self): ++ if hasattr(self.kv_manager, "close"): ++ self.kv_manager.close() ++ ++ def resume_memory_occupation(self): ++ self.kv_manager = self._init_kv_manager() ++ + + class SchedulerDisaggregationPrefillMixin: + """ +diff --git a/python/sglang/srt/distributed/device_communicators/pynccl.py b/python/sglang/srt/distributed/device_communicators/pynccl.py +index 86c53f26b..52acf95b9 100644 +--- a/python/sglang/srt/distributed/device_communicators/pynccl.py ++++ b/python/sglang/srt/distributed/device_communicators/pynccl.py +@@ -380,3 +380,9 @@ class 
PyNcclCommunicator: + + self.disabled = old_disable + self.stream = old_stream ++ ++ def nccl_pause(self): ++ self.nccl.ncclPause(self.comm) ++ ++ def nccl_resume(self): ++ self.nccl.ncclResume(self.comm) +diff --git a/python/sglang/srt/distributed/device_communicators/pynccl_wrapper.py b/python/sglang/srt/distributed/device_communicators/pynccl_wrapper.py +index 6b12f2922..7028a4e46 100644 +--- a/python/sglang/srt/distributed/device_communicators/pynccl_wrapper.py ++++ b/python/sglang/srt/distributed/device_communicators/pynccl_wrapper.py +@@ -304,6 +304,17 @@ class NCCLLibrary: + Function("ncclGroupEnd", ncclResult_t, []), + ] + ++ if os.environ.get("AMEM_ENABLE", "0") == "1": ++ exported_functions.extend( ++ [ ++ # ncclResult_t ncclPause(ncclComm_t comm); ++ Function("ncclPause", ncclResult_t, [ncclComm_t]), ++ # ncclResult_t ncclResume(ncclComm_t comm); ++ Function("ncclResume", ncclResult_t, [ncclComm_t]), ++ Function("ncclSetGroupID", ncclResult_t, [ctypes.c_int]), ++ ] ++ ) ++ + exported_functions_symm_mem = [ + # ncclResult_t ncclCommWindowRegister(ncclComm_t comm, void* buff, size_t size, ncclWindow_t* win, int winFlags); + Function( +@@ -551,6 +562,12 @@ class NCCLLibrary: + def ncclGroupEnd(self) -> None: + self.NCCL_CHECK(self._funcs["ncclGroupEnd"]()) + ++ def ncclPause(self, comm: ncclComm_t) -> None: ++ self.NCCL_CHECK(self._funcs["ncclPause"](comm)) ++ ++ def ncclResume(self, comm: ncclComm_t) -> None: ++ self.NCCL_CHECK(self._funcs["ncclResume"](comm)) ++ + + __all__ = [ + "NCCLLibrary", +diff --git a/python/sglang/srt/distributed/parallel_state.py b/python/sglang/srt/distributed/parallel_state.py +index cf90f6fe0..11d26df81 100644 +--- a/python/sglang/srt/distributed/parallel_state.py ++++ b/python/sglang/srt/distributed/parallel_state.py +@@ -1780,7 +1780,10 @@ def get_tensor_model_parallel_world_size(): + + def get_tensor_model_parallel_rank(): + """Return my rank for the tensor model parallel group.""" +- return get_tp_group().rank_in_group ++ try: ++ return get_tp_group().rank_in_group ++ except Exception: ++ return 0 + + + def get_pipeline_model_parallel_world_size(): +diff --git a/python/sglang/srt/entrypoints/engine.py b/python/sglang/srt/entrypoints/engine.py +index 67a082ea6..390365864 100644 +--- a/python/sglang/srt/entrypoints/engine.py ++++ b/python/sglang/srt/entrypoints/engine.py +@@ -183,6 +183,7 @@ class Engine(EngineBase): + lora_path: Optional[List[Optional[str]]] = None, + custom_logit_processor: Optional[Union[List[str], str]] = None, + return_hidden_states: bool = False, ++ return_routed_experts: bool = False, + stream: bool = False, + bootstrap_host: Optional[Union[List[str], str]] = None, + bootstrap_port: Optional[Union[List[int], int]] = None, +@@ -218,6 +219,7 @@ class Engine(EngineBase): + lora_path=lora_path, + custom_logit_processor=custom_logit_processor, + return_hidden_states=return_hidden_states, ++ return_routed_experts=return_routed_experts, + stream=stream, + bootstrap_host=bootstrap_host, + bootstrap_port=bootstrap_port, +diff --git a/python/sglang/srt/layers/attention/vision.py b/python/sglang/srt/layers/attention/vision.py +index 9f556a885..992843285 100644 +--- a/python/sglang/srt/layers/attention/vision.py ++++ b/python/sglang/srt/layers/attention/vision.py +@@ -518,11 +518,25 @@ class VisionAttention(nn.Module): + self.dummy_dim = (num_dummy_heads + num_heads) * self.head_size + + if self.qk_normalization: ++ norm_kwargs = ( ++ dict( ++ weight_dtype=torch.float32, ++ cast_x_before_out_mul=True, ++ ) ++ if 
get_global_server_args().rl_on_policy_target is not None ++ else {} ++ ) + self.q_norm = RMSNorm( +- self.dummy_dim, eps=layer_norm_eps, var_hidden_size=embed_dim ++ self.dummy_dim, ++ eps=layer_norm_eps, ++ var_hidden_size=embed_dim, ++ **norm_kwargs, + ) + self.k_norm = RMSNorm( +- self.dummy_dim, eps=layer_norm_eps, var_hidden_size=embed_dim ++ self.dummy_dim, ++ eps=layer_norm_eps, ++ var_hidden_size=embed_dim, ++ **norm_kwargs, + ) + + # Select attention backend via a unified method +@@ -648,6 +662,15 @@ class VisionAttention(nn.Module): + if x.dim() == 2: + x = x.unsqueeze(0) + assert x.dim() == 3, x.shape ++ if ( ++ get_global_server_args().rl_on_policy_target is not None ++ and position_embeddings is not None ++ ): ++ assert isinstance(position_embeddings, tuple), ( ++ "expected position_embeddings to be a tuple of two tensors,\n" ++ f"but got {type(position_embeddings)}, change if needed" ++ ) ++ position_embeddings = tuple(p.to(x.dtype) for p in position_embeddings) + x_shape = x.shape + bsz, s, _ = x_shape + head = self.num_attention_heads_per_partition +diff --git a/python/sglang/srt/layers/communicator.py b/python/sglang/srt/layers/communicator.py +index 932f52aeb..79c6b664f 100644 +--- a/python/sglang/srt/layers/communicator.py ++++ b/python/sglang/srt/layers/communicator.py +@@ -372,6 +372,7 @@ class LayerCommunicator: + residual: torch.Tensor, + forward_batch: ForwardBatch, + quant_format: str = "", ++ post_residual_addition: Optional[torch.Tensor] = None, + ): + if get_attn_tp_context().input_scattered: + hidden_states, residual = self._tp_reduce_scatter( +@@ -453,7 +454,9 @@ class LayerCommunicator: + ) + else: + hidden_states, residual = self.input_layernorm( +- hidden_states, residual ++ hidden_states, ++ residual, ++ post_residual_addition, + ) + + hidden_states = self._communicate_simple_fn( +diff --git a/python/sglang/srt/layers/layernorm.py b/python/sglang/srt/layers/layernorm.py +index 3293a8a59..a075b71ce 100644 +--- a/python/sglang/srt/layers/layernorm.py ++++ b/python/sglang/srt/layers/layernorm.py +@@ -84,15 +84,12 @@ class RMSNorm(CustomOp): + eps: float = 1e-6, + var_hidden_size: Optional[int] = None, + cast_x_before_out_mul: bool = False, +- fp32_residual: bool = False, +- weight_dtype: Optional = None, +- override_orig_dtype: Optional = None, ++ fp32_residual: bool = True, + ) -> None: + super().__init__() + self.cast_x_before_out_mul = cast_x_before_out_mul + self.fp32_residual = fp32_residual +- self.override_orig_dtype = override_orig_dtype +- self.weight = nn.Parameter(torch.ones(hidden_size, dtype=weight_dtype)) ++ self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + self.hidden_size = hidden_size + self.variance_size_override = ( +@@ -105,21 +102,26 @@ class RMSNorm(CustomOp): + self, + x: torch.Tensor, + residual: Optional[torch.Tensor] = None, ++ post_residual_addition: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + if self.variance_size_override is not None: +- return self.forward_native(x, residual) ++ return self.forward_native(x, residual, post_residual_addition) + if is_batch_invariant_mode_enabled(): + if ( + residual is not None + or get_global_server_args().rl_on_policy_target == "fsdp" + ): +- return self.forward_native(x, residual) ++ return self.forward_native(x, residual, post_residual_addition) + return rms_norm_batch_invariant( + x, + self.weight.data, + self.variance_epsilon, + ) + if residual is not None: ++ # TODO: Ideally we want to have (a+b)+c. 
but right now we can only have a+(b+c). ++ # (a+b)+c != a+(b+c), we probably need to add another parameter to fused_add_rmsnorm ++ if post_residual_addition is not None: ++ residual = residual + post_residual_addition + fused_add_rmsnorm(x, residual, self.weight.data, self.variance_epsilon) + return x, residual + out = rmsnorm(x, self.weight.data, self.variance_epsilon) +@@ -179,17 +181,35 @@ class RMSNorm(CustomOp): + self, + x: torch.Tensor, + residual: Optional[torch.Tensor] = None, ++ post_residual_addition: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + if not x.is_contiguous(): + x = x.contiguous() +- orig_dtype = self.override_orig_dtype or x.dtype ++ orig_dtype = x.dtype ++ ++ if residual is not None and not self.fp32_residual: ++ x = ( ++ x ++ + residual ++ + ( ++ post_residual_addition ++ if post_residual_addition is not None ++ else 0.0 ++ ) ++ ) ++ residual = x.clone() + x = x.to(torch.float32) +- if residual is not None: +- x = x + residual.to(torch.float32) +- if self.fp32_residual: +- residual = x.clone() +- else: +- residual = x.to(orig_dtype) ++ if residual is not None and self.fp32_residual: ++ x = ( ++ x ++ + residual.to(torch.float32) ++ + ( ++ post_residual_addition.to(torch.float32) ++ if post_residual_addition is not None ++ else 0.0 ++ ) ++ ) ++ residual = x.to(orig_dtype) + + hidden_size = x.shape[-1] + if hidden_size != self.hidden_size: +diff --git a/python/sglang/srt/layers/logits_processor.py b/python/sglang/srt/layers/logits_processor.py +index 522865765..733bad5f2 100644 +--- a/python/sglang/srt/layers/logits_processor.py ++++ b/python/sglang/srt/layers/logits_processor.py +@@ -841,11 +841,6 @@ class LogitsProcessor(nn.Module): + None, # bias + True, # is_vnni + ) +- elif get_global_server_args().rl_on_policy_target is not None: +- # Due to tie-weight, we may not be able to change lm_head's weight dtype +- logits = torch.matmul( +- hidden_states.bfloat16(), lm_head.weight.T.bfloat16() +- ) + else: + logits = torch.matmul( + hidden_states.to(lm_head.weight.dtype), lm_head.weight.T +diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py b/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +index e7d5a67cc..639e47163 100644 +--- a/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py ++++ b/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +@@ -14,6 +14,7 @@ import torch.nn.functional as F + import triton.language as tl + + from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig ++from sglang.srt.server_args import get_global_server_args + from sglang.srt.utils import ( + cpu_has_amx_support, + direct_register_custom_op, +@@ -626,7 +627,10 @@ def fused_experts_impl( + ).squeeze(dim=1) + else: + # According to micro benchmark results, torch.compile can get better performance for small token. 
+- if tokens_in_chunk <= 32: ++ if ( ++ not get_global_server_args().enable_deterministic_inference ++ and tokens_in_chunk <= 32 ++ ): + moe_sum_reduce_torch_compile( + intermediate_cache3.view(*intermediate_cache3.shape), + out_hidden_states[begin_chunk_idx:end_chunk_idx], +diff --git a/python/sglang/srt/layers/moe/routed_experts_capturer.py b/python/sglang/srt/layers/moe/routed_experts_capturer.py +new file mode 100644 +index 000000000..e16817f1f +--- /dev/null ++++ b/python/sglang/srt/layers/moe/routed_experts_capturer.py +@@ -0,0 +1,279 @@ ++import logging ++from abc import ABC ++from contextlib import contextmanager ++from typing import Optional ++ ++import numpy as np ++import torch ++ ++from sglang.srt.configs.model_config import ModelConfig ++from sglang.srt.layers.dp_attention import ( ++ get_attention_dp_rank, ++ get_dp_local_info, ++ is_dp_attention_enabled, ++) ++from sglang.srt.mem_cache.memory_pool import ReqToTokenPool ++from sglang.srt.server_args import get_global_server_args ++ ++logger = logging.getLogger(__name__) ++ ++_GB = 1024 * 1024 * 1024 ++_MB = 1024 * 1024 ++ ++ ++def get_tensor_size_bytes(t: torch.Tensor): ++ return np.prod(t.shape) * t.dtype.itemsize ++ ++ ++class _RoutedExpertsDeviceCache: ++ def __init__( ++ self, ++ max_running_requests: int, ++ num_hidden_layers: int, ++ num_experts_per_tok: int, ++ num_fused_shared_experts: int, ++ device: str, ++ ) -> None: ++ self.buffer = torch.zeros( ++ ( ++ max( ++ get_global_server_args().chunked_prefill_size ++ * get_global_server_args().dp_size, ++ max_running_requests, ++ ), ++ num_hidden_layers, ++ num_experts_per_tok + num_fused_shared_experts, ++ ), ++ dtype=torch.int32, ++ device=device, ++ ) ++ self._finalize_allocation_log() ++ ++ def get_buffer_size_bytes(self): ++ assert hasattr(self, "buffer") ++ return get_tensor_size_bytes(self.buffer) ++ ++ def capture_fwd_routed_experts(self, layer_id: int, topk_ids: torch.Tensor): ++ assert layer_id is not None, "capturing routing experts but get layer_id None" ++ batch, _ = topk_ids.shape ++ self.buffer[:batch, layer_id, :] = topk_ids ++ ++ def _finalize_allocation_log(self): ++ """Common logging and memory usage computation for captured experts buffers.""" ++ buffer_size_MB = self.get_buffer_size_bytes() / _MB ++ logger.info( ++ f"Routing experts device buffer allocated. #shape: {tuple(self.buffer.shape)}, size: {buffer_size_MB:.2f} MB" ++ ) ++ ++ ++class _RoutedExpertsHostCache: ++ def __init__( ++ self, ++ num_tokens: int, ++ num_hidden_layers: int, ++ num_experts_per_tok: int, ++ ) -> None: ++ self.num_tokens = num_tokens ++ self.buffer = torch.zeros( ++ ( ++ num_tokens, ++ num_hidden_layers, ++ num_experts_per_tok, ++ ), ++ dtype=torch.int32, ++ device="cpu", ++ pin_memory=True, ++ ) ++ self._finalize_allocation_log() ++ ++ def get_buffer_size_bytes(self): ++ assert hasattr(self, "buffer") ++ return get_tensor_size_bytes(self.buffer) ++ ++ def set_experts_buffer(self, layer_id: int, loc: torch.Tensor, top_k: torch.Tensor): ++ self.buffer[layer_id, loc, :] = top_k.to(device="cpu", non_blocking=True) ++ ++ def _finalize_allocation_log(self): ++ """Common logging and memory usage computation for captured experts buffers.""" ++ buffer_size_GB = self.get_buffer_size_bytes() / _GB ++ logger.info( ++ f"Routing experts host buffer allocated. 
#tokens: {self.num_tokens}, size: {buffer_size_GB:.2f} GB" ++ ) ++ ++ ++class RoutedExpertsCapturer(ABC): ++ @staticmethod ++ def create( ++ enable: bool, ++ model_config: ModelConfig, ++ num_fused_shared_experts: int, ++ num_tokens: int, ++ max_running_requests: int, ++ device: str, ++ ): ++ if enable: ++ return _RoutedExpertsCapturerReal( ++ model_config, ++ num_tokens=num_tokens, ++ max_running_requests=max_running_requests, ++ num_fused_shared_experts=num_fused_shared_experts, ++ device=device, ++ ) ++ else: ++ return _RoutedExpertsCapturerNoop() ++ ++ def capture(self, layer_id: int, topk_ids: torch.Tensor): ++ raise NotImplementedError ++ ++ def get_routed_experts( ++ self, ++ req_pool_idx: int, ++ seqlen: int, ++ req_to_token_pool: ReqToTokenPool, ++ ): ++ raise NotImplementedError ++ ++ def sync_fwd_experts_buffer_DtoH( ++ self, ++ device_loc: torch.Tensor, ++ cpu_loc: torch.Tensor, ++ can_run_graph: bool, ++ cuda_graph_batch: int, ++ ): ++ raise NotImplementedError ++ ++ @contextmanager ++ def with_forward(self, forward_batch): ++ yield ++ ++ def get_host_cache(self): ++ raise NotImplementedError ++ ++ def get_device_cache(self): ++ raise NotImplementedError ++ ++ ++class _RoutedExpertsCapturerReal(RoutedExpertsCapturer): ++ """Capturer for routed experts with host buffer""" ++ ++ def __init__( ++ self, ++ model_config: ModelConfig, ++ num_tokens: int, ++ max_running_requests: int, ++ num_fused_shared_experts: int, ++ device: str, ++ ): ++ self.forward_batch = None ++ self.num_fused_shared_experts = num_fused_shared_experts ++ self.num_hidden_layers = model_config.hf_text_config.num_hidden_layers ++ self.num_experts_per_tok = model_config.hf_text_config.num_experts_per_tok ++ ++ self.host_cache = _RoutedExpertsHostCache( ++ num_tokens=num_tokens, ++ num_hidden_layers=self.num_hidden_layers, ++ num_experts_per_tok=self.num_experts_per_tok, ++ ) ++ ++ self.device_cache = _RoutedExpertsDeviceCache( ++ max_running_requests=max_running_requests, ++ num_hidden_layers=self.num_hidden_layers, ++ num_experts_per_tok=self.num_experts_per_tok, ++ num_fused_shared_experts=self.num_fused_shared_experts, ++ device=device, ++ ) ++ ++ def capture(self, layer_id: int, topk_ids: torch.Tensor): ++ self.device_cache.capture_fwd_routed_experts(layer_id, topk_ids) ++ ++ def sync_fwd_experts_buffer_DtoH( ++ self, ++ device_loc: torch.Tensor, ++ cpu_loc: torch.Tensor, ++ can_run_graph: bool, ++ cuda_graph_batch: int, ++ ): ++ if is_dp_attention_enabled(): ++ local_start_pos, local_num_tokens = get_dp_local_info(self.forward_batch) ++ # handle with cuda graph padding ++ if can_run_graph: ++ local_start_pos = get_attention_dp_rank() * cuda_graph_batch ++ local_end_pos = local_start_pos + local_num_tokens ++ else: ++ local_end_pos = local_start_pos + local_num_tokens ++ else: ++ local_start_pos = 0 ++ local_end_pos = device_loc.shape[0] ++ ++ self.host_cache.buffer[cpu_loc] = self.device_cache.buffer[ ++ local_start_pos:local_end_pos, :, : self.num_experts_per_tok ++ ].cpu() ++ ++ def get_routed_experts( ++ self, ++ req_pool_idx: int, ++ seqlen: int, ++ req_to_token_pool: ReqToTokenPool, ++ ): ++ cache_pool_idx = ( ++ req_to_token_pool.req_to_token[req_pool_idx][: seqlen - 1].cpu().clone() ++ ) ++ return self.get_host_cache().buffer[cache_pool_idx] ++ ++ @contextmanager ++ def with_forward(self, forward_batch): ++ self.forward_batch = forward_batch ++ yield ++ ++ def get_host_cache(self): ++ return self.host_cache ++ ++ def get_device_cache(self): ++ return self.device_cache ++ ++ ++class 
_RoutedExpertsCapturerNoop(RoutedExpertsCapturer): ++ def __init__(self): ++ pass ++ ++ def capture(self, layer_id: int, topk_ids: torch.Tensor): ++ pass ++ ++ def get_routed_experts( ++ self, ++ req_pool_idx: int, ++ seqlen: int, ++ req_to_token_pool: ReqToTokenPool, ++ ): ++ pass ++ ++ def sync_fwd_experts_buffer_DtoH( ++ self, ++ device_loc: torch.Tensor, ++ cpu_loc: torch.Tensor, ++ can_run_graph: bool, ++ cuda_graph_batch: int, ++ ): ++ pass ++ ++ @contextmanager ++ def with_forward(self, forward_batch): ++ yield ++ ++ def get_host_cache(self): ++ pass ++ ++ def get_device_cache(self): ++ pass ++ ++ ++_global_expert_capturer: Optional[RoutedExpertsCapturer] = _RoutedExpertsCapturerNoop() ++ ++ ++def get_global_experts_capturer(): ++ return _global_expert_capturer ++ ++ ++def set_global_experts_capturer(capturer: RoutedExpertsCapturer): ++ global _global_expert_capturer ++ _global_expert_capturer = capturer +\ No newline at end of file +diff --git a/python/sglang/srt/layers/moe/topk.py b/python/sglang/srt/layers/moe/topk.py +index a802647e8..0fd550c0c 100644 +--- a/python/sglang/srt/layers/moe/topk.py ++++ b/python/sglang/srt/layers/moe/topk.py +@@ -48,6 +48,7 @@ from sglang.srt.eplb.expert_location_dispatch import ( + ) + from sglang.srt.layers.dp_attention import is_allocation_symmetric + from sglang.srt.layers.moe import get_moe_runner_backend ++from sglang.srt.layers.moe.routed_experts_capturer import get_global_experts_capturer + from sglang.srt.utils import ( + cpu_has_amx_support, + get_bool_env_var, +@@ -212,6 +213,7 @@ class TopK(CustomOp): + self, + top_k: int, + *, ++ layer_id: Optional[int] = None, + use_grouped_topk: bool = False, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, +@@ -233,6 +235,7 @@ class TopK(CustomOp): + if use_grouped_topk: + assert num_expert_group is not None and topk_group is not None + ++ self.layer_id = layer_id + self.topk_config = TopKConfig( + top_k=top_k, + use_grouped_topk=use_grouped_topk, +@@ -260,6 +263,7 @@ class TopK(CustomOp): + self.topk_config.torch_native = True + return select_experts( + hidden_states=hidden_states, ++ layer_id=self.layer_id, + router_logits=router_logits, + topk_config=self.topk_config, + num_token_non_padded=num_token_non_padded, +@@ -309,6 +313,7 @@ class TopK(CustomOp): + ): + topk_output = select_experts( + hidden_states=hidden_states, ++ layer_id=self.layer_id, + router_logits=router_logits, + topk_config=self.topk_config, + num_token_non_padded=num_token_non_padded, +@@ -326,6 +331,7 @@ class TopK(CustomOp): + ) -> TopKOutput: + return select_experts( + hidden_states=hidden_states, ++ layer_id=self.layer_id, + router_logits=router_logits, + topk_config=self.topk_config, + num_token_non_padded=num_token_non_padded, +@@ -856,6 +862,7 @@ def select_experts( + router_logits: torch.Tensor, + topk_config: TopKConfig, + *, ++ layer_id: Optional[int] = None, + num_token_non_padded: Optional[torch.Tensor] = None, + expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None, + ) -> StandardTopKOutput: +@@ -983,7 +990,10 @@ def select_experts( + ) + + get_global_expert_distribution_recorder().on_select_experts(topk_ids=topk_ids) +- ++ get_global_experts_capturer().capture( ++ layer_id=layer_id, ++ topk_ids=topk_ids, ++ ) + return StandardTopKOutput(topk_weights, topk_ids, router_logits) + + +diff --git a/python/sglang/srt/layers/moe/utils.py b/python/sglang/srt/layers/moe/utils.py +index 70466bb20..cd85fc2f2 100644 +--- a/python/sglang/srt/layers/moe/utils.py ++++ 
b/python/sglang/srt/layers/moe/utils.py +@@ -284,7 +284,7 @@ def speculative_moe_a2a_backend_context(): + global MOE_A2A_BACKEND + original_backend = MOE_A2A_BACKEND + try: +- MOE_A2A_BACKEND = MoeA2ABackend.NONE ++ MOE_A2A_BACKEND = get_speculative_moe_a2a_backend() + yield + finally: + MOE_A2A_BACKEND = original_backend +diff --git a/python/sglang/srt/layers/rotary_embedding.py b/python/sglang/srt/layers/rotary_embedding.py +index 0cdb7e1ae..df8860409 100644 +--- a/python/sglang/srt/layers/rotary_embedding.py ++++ b/python/sglang/srt/layers/rotary_embedding.py +@@ -15,7 +15,6 @@ from sglang.srt.server_args import get_global_server_args + from sglang.srt.utils import ( + cpu_has_amx_support, + get_bool_env_var, +- get_compiler_backend, + is_cpu, + is_cuda, + is_hip, +@@ -132,9 +131,7 @@ class RotaryEmbedding(CustomOp): + + if get_global_server_args().rl_on_policy_target is not None: + self._forward_method = self.forward_native +- self._apply_rotary_emb_wrapped = torch.compile(dynamic=True)( +- self._apply_rotary_emb_wrapped +- ) ++ + self.position_cos, self.position_sin = None, None + + def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor: +@@ -1423,6 +1420,9 @@ class MRotaryEmbedding(RotaryEmbedding): + f"Corrected mrope_section: {self.mrope_section} (sum={sum(self.mrope_section)})" + ) + ++ if get_global_server_args().rl_on_policy_target is not None: ++ self._forward_method = self.forward_native ++ + def _match_cos_sin_cache_dtype(self, query: torch.Tensor) -> None: + # __setattr__ in nn.Module (called by `self.cos_sin_cache = ...`) + # is expensive, so avoid calling it if possible +@@ -1432,8 +1432,7 @@ class MRotaryEmbedding(RotaryEmbedding): + ): + self.cos_sin_cache = self.cos_sin_cache.to(query.device, dtype=query.dtype) + +- @torch.compile(dynamic=True, backend=get_compiler_backend()) +- def _forward_native( ++ def forward_native( + self, + positions: torch.Tensor, + query: torch.Tensor, +@@ -1490,7 +1489,7 @@ class MRotaryEmbedding(RotaryEmbedding): + key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape) + return query, key + +- def forward( ++ def forward_cuda( + self, + positions: torch.Tensor, + query: torch.Tensor, +@@ -1507,14 +1506,12 @@ class MRotaryEmbedding(RotaryEmbedding): + """ + assert positions.ndim == 1 or positions.ndim == 2 + +- if positions.ndim == 2 and self.mrope_section and _is_cuda: +- return self._forward_triton(positions, query, key) +- elif _is_npu: +- return self._forward_npu(positions, query, key) +- else: +- return self._forward_native(positions, query, key) ++ # Use Triton kernel for multimodal (2D positions) with mrope ++ if positions.ndim == 2 and self.mrope_section: ++ return self.forward_triton(positions, query, key) ++ return self.forward_native(positions, query, key, fused_set_kv_buffer_arg) + +- def _forward_triton( ++ def forward_triton( + self, + positions: torch.Tensor, + query: torch.Tensor, +@@ -1563,15 +1560,19 @@ class MRotaryEmbedding(RotaryEmbedding): + key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape) + return query, key + +- def _forward_npu( ++ def forward_npu( + self, + positions: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, ++ fused_set_kv_buffer_arg: Optional[FusedSetKVBufferArg] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: ++ assert ( ++ fused_set_kv_buffer_arg is None ++ ), "fused_set_kv_buffer_arg is not supported for npu implementation" + # TODO: remove this when npu_mrope supports QNumHeads * QHeadSize > 4096 + if query.shape[1] > 4096: +- return 
self._forward_native(positions, query, key) ++ return self.forward_native(positions, query, key, fused_set_kv_buffer_arg) + rotary_mode = "half" + if self.is_neox_style: + rotary_mode = "half" +diff --git a/python/sglang/srt/layers/sampler.py b/python/sglang/srt/layers/sampler.py +index 7f6f6a010..c4a673145 100644 +--- a/python/sglang/srt/layers/sampler.py ++++ b/python/sglang/srt/layers/sampler.py +@@ -105,16 +105,11 @@ class Sampler(nn.Module): + if return_logprob and SGLANG_RETURN_ORIGINAL_LOGPROB: + probs_without_temp_scaling = torch.softmax(logits, dim=-1) + +- if get_global_server_args().rl_on_policy_target is not None: +- logits_div_temperature = ( +- logits.bfloat16().div(sampling_info.temperatures).bfloat16() +- ) +- logprobs_via_logsoftmax_kernel = torch.log_softmax( +- logits_div_temperature, dim=-1 +- ) +- + # Post process logits + logits.div_(sampling_info.temperatures) ++ if get_global_server_args().rl_on_policy_target is not None: ++ logprobs_via_logsoftmax_kernel = torch.log_softmax(logits, dim=-1) ++ + # For ascend backend, softmax is not needed before sampling + if not get_global_server_args().sampling_backend == "ascend" or ( + return_logprob and not SGLANG_RETURN_ORIGINAL_LOGPROB +diff --git a/python/sglang/srt/managers/detokenizer_manager.py b/python/sglang/srt/managers/detokenizer_manager.py +index 87922077e..8cb6bad8d 100644 +--- a/python/sglang/srt/managers/detokenizer_manager.py ++++ b/python/sglang/srt/managers/detokenizer_manager.py +@@ -247,6 +247,16 @@ class DetokenizerManager(MultiHttpWorkerDetokenizerMixin): + s.sent_offset = len(output_str) + output_strs.append(incremental_output) + ++ output_routed_experts = [] ++ if recv_obj.output_routed_experts is not None: ++ output_routed_experts = [ ++ ( ++ output_routed_experts.tolist() ++ if output_routed_experts is not None ++ else [] ++ ) ++ for output_routed_experts in recv_obj.output_routed_experts ++ ] + return BatchStrOutput( + rids=recv_obj.rids, + http_worker_ipcs=recv_obj.http_worker_ipcs, +@@ -272,6 +282,7 @@ class DetokenizerManager(MultiHttpWorkerDetokenizerMixin): + output_token_ids_logprobs_idx=recv_obj.output_token_ids_logprobs_idx, + output_token_entropy_val=recv_obj.output_token_entropy_val, + output_hidden_states=recv_obj.output_hidden_states, ++ output_routed_experts=output_routed_experts, + placeholder_tokens_idx=None, + placeholder_tokens_val=None, + retraction_counts=recv_obj.retraction_counts, +diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py +index e34736cc4..5e5997a1a 100644 +--- a/python/sglang/srt/managers/io_struct.py ++++ b/python/sglang/srt/managers/io_struct.py +@@ -23,6 +23,8 @@ from dataclasses import dataclass, field + from enum import Enum + from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Union + ++import torch ++ + from sglang.srt.lora.lora_registry import LoRARef + from sglang.srt.managers.schedule_batch import BaseFinishReason + from sglang.srt.multimodal.mm_utils import has_valid_data +@@ -175,6 +177,8 @@ class GenerateReqInput(BaseReq): + log_metrics: bool = True + # Whether to return hidden states + return_hidden_states: Union[List[bool], bool] = False ++ # Whether to return captured routed experts ++ return_routed_experts: bool = False + + # The modalities of the image data [image, multi-images, video] + modalities: Optional[List[str]] = None +@@ -592,6 +596,7 @@ class GenerateReqInput(BaseReq): + if isinstance(self.return_hidden_states, list) + else self.return_hidden_states + ), ++ 
return_routed_experts=self.return_routed_experts, + modalities=self.modalities[i] if self.modalities else None, + session_params=self.session_params, + lora_path=self.lora_path[i] if self.lora_path is not None else None, +@@ -655,6 +660,9 @@ class TokenizedGenerateReqInput(BaseReq): + # Whether to return hidden states + return_hidden_states: bool = False + ++ # Whether to return captured routed experts ++ return_routed_experts: bool = False ++ + # The input embeds + input_embeds: Optional[Union[List[List[List[float]]], List[List[float]]]] = None + +@@ -910,6 +918,9 @@ class BatchTokenIDOutput( + # Hidden states + output_hidden_states: List[List[float]] + ++ # The routed experts for each output token ++ output_routed_experts: List[torch.Tensor] ++ + # The information of placeholder tokens (e.g., image token) + # idx is the index of the token in the prompt after expansion. + # val is the length of padded tokens after expansion. +@@ -989,6 +1000,9 @@ class BatchStrOutput( + # Hidden states + output_hidden_states: List[List[float]] + ++ # The routed experts for each output token ++ output_routed_experts: List[List[int]] ++ + # The information of placeholder tokens (e.g., image token) + # idx is the index of the token in the prompt after expansion. + # val is the length of padded tokens after expansion. +diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py +index c4c5a9ebb..1450c5fd8 100644 +--- a/python/sglang/srt/managers/schedule_batch.py ++++ b/python/sglang/srt/managers/schedule_batch.py +@@ -450,6 +450,7 @@ class Req: + session_id: Optional[str] = None, + custom_logit_processor: Optional[str] = None, + return_hidden_states: bool = False, ++ return_routed_experts: bool = False, + eos_token_ids: Optional[Set[int]] = None, + bootstrap_host: Optional[str] = None, + bootstrap_port: Optional[int] = None, +@@ -629,6 +630,12 @@ class Req: + self.output_topk_p = None + self.output_topk_index = None + ++ # capture routed experts ++ self.return_routed_experts = return_routed_experts ++ self.routed_experts: Optional[torch.Tensor] = ( ++ None # cpu tensor: shape (seqlen, topk) ++ ) ++ + # Embedding (return values) + self.embedding = None + +@@ -992,6 +999,7 @@ class Req: + self.retraction_count += 1 + + self.prefix_indices = torch.empty((0,), dtype=torch.int64) ++ self.routed_experts = [] + self.last_node = None + self.swa_uuid_for_lock = None + self.extend_input_len = 0 +@@ -1159,6 +1167,9 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): + # Whether to return hidden states + return_hidden_states: bool = False + ++ # Whether to return captured experts ++ return_routed_experts: bool = False ++ + # Whether this batch is prefill-only (no token generation needed) + is_prefill_only: bool = False + +@@ -1206,6 +1217,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): + device=req_to_token_pool.device, + spec_algorithm=spec_algorithm, + return_hidden_states=any(req.return_hidden_states for req in reqs), ++ return_routed_experts=any(req.return_routed_experts for req in reqs), + is_prefill_only=all(req.is_prefill_only for req in reqs), + chunked_req=chunked_req, + dllm_config=dllm_config, +@@ -1457,6 +1469,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): + self.req_pool_indices = req_pool_indices_tensor + self.orig_seq_lens = orig_seq_lens_tensor + self.out_cache_loc = out_cache_loc ++ self.out_cache_loc_cpu = out_cache_loc.cpu() + self.input_embeds = ( + torch.tensor(input_embeds).to(self.device, 
non_blocking=True) + if input_embeds +@@ -1508,10 +1521,14 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): + + input_ids = torch.cat([self.input_ids, running_batch.input_ids]) + out_cache_loc = torch.cat([self.out_cache_loc, running_batch.out_cache_loc]) ++ out_cache_loc_cpu = torch.cat( ++ [self.out_cache_loc_cpu, running_batch.out_cache_loc_cpu] ++ ) + + self.merge_batch(running_batch) + self.input_ids = input_ids + self.out_cache_loc = out_cache_loc ++ self.out_cache_loc_cpu = out_cache_loc_cpu + + # For overlap scheduler, the output_ids has one step delay + delta = 0 if self.enable_overlap else -1 +@@ -1677,6 +1694,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): + self.seq_lens_cpu = torch.empty(0, dtype=torch.int64) + self.orig_seq_lens = torch.empty(0, dtype=torch.int32, device=self.device) + self.out_cache_loc = torch.empty(0, dtype=torch.int64, device=self.device) ++ self.out_cache_loc_cpu = torch.empty(0, dtype=torch.int64, device="cpu") + self.req_pool_indices = torch.empty(0, dtype=torch.int32, device=self.device) + self.seq_lens_sum = 0 + self.extend_num_tokens = 0 +@@ -1736,6 +1754,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): + + # Allocate memory + self.out_cache_loc = alloc_for_decode(self, token_per_req=1) ++ self.out_cache_loc_cpu = self.out_cache_loc.to("cpu", non_blocking=True) + + # Update req-level memory management fields + for req in self.reqs: +@@ -1807,6 +1826,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): + self.seq_lens_cpu = self.seq_lens_cpu[keep_indices] + self.orig_seq_lens = self.orig_seq_lens[keep_indices_device] + self.out_cache_loc = None ++ self.out_cache_loc_cpu = None + self.seq_lens_sum = self.seq_lens.sum().item() + self.output_ids = self.output_ids[keep_indices_device] + self.return_logprob = any(req.return_logprob for req in self.reqs) +@@ -1852,6 +1872,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): + self.seq_lens_cpu = torch.cat([self.seq_lens_cpu, other.seq_lens_cpu]) + self.orig_seq_lens = torch.cat([self.orig_seq_lens, other.orig_seq_lens]) + self.out_cache_loc = None ++ self.out_cache_loc_cpu = None + self.seq_lens_sum += other.seq_lens_sum + if self.output_ids is not None: + self.output_ids = torch.cat([self.output_ids, other.output_ids]) +@@ -1903,6 +1924,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): + seq_lens=self.seq_lens, + orig_seq_lens=self.orig_seq_lens, + out_cache_loc=self.out_cache_loc, ++ out_cache_loc_cpu=self.out_cache_loc_cpu, + seq_lens_cpu=seq_lens_cpu, + seq_lens_sum=self.seq_lens_sum, + return_logprob=self.return_logprob, +@@ -1983,7 +2005,8 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): + def __str__(self): + return ( + f"ScheduleBatch(forward_mode={self.forward_mode.name if self.forward_mode else 'None'}, " +- f"#req={(len(self.reqs))})" ++ f"#req={(len(self.reqs))}), " ++ f"#out_cache_loc={self.out_cache_loc})" + ) + + +@@ -2038,6 +2061,9 @@ class ModelWorkerBatch: + # Sampling info + sampling_info: SamplingBatchInfo + ++ # cpu copy of out_cache_loc ++ out_cache_loc_cpu: Optional[torch.Tensor] = None ++ + # The original sequence lengths, Qwen-1M related + orig_seq_lens: Optional[torch.Tensor] = None + +diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py +index b801fd8f8..9e27cc825 100644 +--- a/python/sglang/srt/managers/scheduler.py ++++ b/python/sglang/srt/managers/scheduler.py +@@ -1305,6 +1305,7 @@ class Scheduler( + 
input_embeds=recv_req.input_embeds, + custom_logit_processor=recv_req.custom_logit_processor, + return_hidden_states=recv_req.return_hidden_states, ++ return_routed_experts=recv_req.return_routed_experts, + eos_token_ids=self.model_config.hf_eos_token_id, + bootstrap_host=recv_req.bootstrap_host, + bootstrap_port=recv_req.bootstrap_port, +diff --git a/python/sglang/srt/managers/scheduler_output_processor_mixin.py b/python/sglang/srt/managers/scheduler_output_processor_mixin.py +index c48f5f893..a9796c25f 100644 +--- a/python/sglang/srt/managers/scheduler_output_processor_mixin.py ++++ b/python/sglang/srt/managers/scheduler_output_processor_mixin.py +@@ -9,6 +9,7 @@ import torch + from sglang.srt.disaggregation.utils import DisaggregationMode + from sglang.srt.environ import envs + from sglang.srt.layers.logits_processor import LogitsProcessorOutput ++from sglang.srt.layers.moe.routed_experts_capturer import get_global_experts_capturer + from sglang.srt.managers.io_struct import ( + AbortReq, + BatchEmbeddingOutput, +@@ -112,6 +113,14 @@ class SchedulerOutputProcessorMixin: + req.check_finished() + + if req.finished(): ++ req.routed_experts = ( ++ get_global_experts_capturer().get_routed_experts( ++ req_pool_idx=req.req_pool_idx, ++ seqlen=req.seqlen, ++ req_to_token_pool=self.req_to_token_pool, ++ ) ++ ) ++ + release_kv_cache(req, self.tree_cache) + req.time_stats.completion_time = time.perf_counter() + elif not batch.decoding_reqs or req not in batch.decoding_reqs: +@@ -362,6 +371,12 @@ class SchedulerOutputProcessorMixin: + req.check_finished(new_accepted_len) + + if req.finished(): ++ req.routed_experts = get_global_experts_capturer().get_routed_experts( ++ req_pool_idx=req.req_pool_idx, ++ seqlen=req.seqlen, ++ req_to_token_pool=self.req_to_token_pool, ++ ) ++ + if self.server_args.disaggregation_decode_enable_offload_kvcache: + # Asynchronously offload KV cache; release_kv_cache will be called after Device->Host transfer completes + if not self.decode_offload_manager.offload_kv_cache(req): +@@ -756,6 +771,7 @@ class SchedulerOutputProcessorMixin: + spec_accepted_tokens = [] + retraction_counts = [] + output_hidden_states = None ++ output_routed_experts = None + + queue_times = [] + forward_entry_times = [] +@@ -946,6 +962,10 @@ class SchedulerOutputProcessorMixin: + if output_hidden_states is None: + output_hidden_states = [] + output_hidden_states.append(req.hidden_states) ++ if req.return_routed_experts: ++ if output_routed_experts is None: ++ output_routed_experts = [] ++ output_routed_experts.append(req.routed_experts) + + if ( + req.finished() +@@ -994,6 +1014,7 @@ class SchedulerOutputProcessorMixin: + output_token_ids_logprobs_idx=output_token_ids_logprobs_idx, + output_token_entropy_val=None, + output_hidden_states=output_hidden_states, ++ output_routed_experts=output_routed_experts, + placeholder_tokens_idx=None, + placeholder_tokens_val=None, + retraction_counts=retraction_counts, +diff --git a/python/sglang/srt/managers/scheduler_update_weights_mixin.py b/python/sglang/srt/managers/scheduler_update_weights_mixin.py +index f8ebfc1f4..a05449fac 100644 +--- a/python/sglang/srt/managers/scheduler_update_weights_mixin.py ++++ b/python/sglang/srt/managers/scheduler_update_weights_mixin.py +@@ -1,6 +1,7 @@ + from __future__ import annotations + + import logging ++import os + import traceback + from typing import TYPE_CHECKING, Tuple + +@@ -12,6 +13,9 @@ from sglang.srt.constants import ( + GPU_MEMORY_TYPE_KV_CACHE, + GPU_MEMORY_TYPE_WEIGHTS, + ) ++from 
sglang.srt.disaggregation.utils import DisaggregationMode ++from sglang.srt.distributed import get_moe_ep_group, get_moe_tp_group, get_tp_group ++from sglang.srt.layers.dp_attention import get_attention_tp_group + from sglang.srt.managers.io_struct import ( + CheckWeightsReqInput, + CheckWeightsReqOutput, +@@ -127,6 +131,13 @@ class SchedulerUpdateWeightsMixin: + self.memory_saver_adapter.pause(GPU_MEMORY_TYPE_KV_CACHE) + self.flush_cache() + ++ if self.disaggregation_mode == DisaggregationMode.DECODE: ++ if hasattr(self, "disagg_decode_prealloc_queue"): ++ self.disagg_decode_prealloc_queue.release_memory_occupation() ++ elif self.disaggregation_mode == DisaggregationMode.PREFILL: ++ if hasattr(self, "disagg_prefill_bootstrap_queue"): ++ self.disagg_prefill_bootstrap_queue.release_memory_occupation() ++ + if GPU_MEMORY_TYPE_WEIGHTS in tags: + self.stashed_model_static_state = _export_static_state( + self.tp_worker.model_runner.model +@@ -137,6 +148,20 @@ class SchedulerUpdateWeightsMixin: + if GPU_MEMORY_TYPE_CUDA_GRAPH in tags: + self.memory_saver_adapter.pause(GPU_MEMORY_TYPE_CUDA_GRAPH) + ++ if os.environ.get("AMEM_ENABLE", "0") == "1": ++ tp_group = get_tp_group() ++ if tp_group is not None and tp_group.pynccl_comm is not None: ++ tp_group.pynccl_comm.nccl_pause() ++ attn_tp_group = get_attention_tp_group() ++ if attn_tp_group is not None and attn_tp_group.pynccl_comm is not None: ++ attn_tp_group.pynccl_comm.nccl_pause() ++ moe_ep_group = get_moe_ep_group() ++ if moe_ep_group is not None and moe_ep_group.pynccl_comm is not None: ++ moe_ep_group.pynccl_comm.nccl_pause() ++ moe_tp_group = get_moe_tp_group() ++ if moe_tp_group is not None and moe_tp_group.pynccl_comm is not None: ++ moe_tp_group.pynccl_comm.nccl_pause() ++ + torch.get_device_module().synchronize() + + return ReleaseMemoryOccupationReqOutput() +@@ -155,6 +180,20 @@ class SchedulerUpdateWeightsMixin: + if GPU_MEMORY_TYPE_CUDA_GRAPH in tags: + self.memory_saver_adapter.resume(GPU_MEMORY_TYPE_CUDA_GRAPH) + ++ if os.environ.get("AMEM_ENABLE", "0") == "1": ++ tp_group = get_tp_group() ++ if tp_group is not None and tp_group.pynccl_comm is not None: ++ tp_group.pynccl_comm.nccl_resume() ++ attn_tp_group = get_attention_tp_group() ++ if attn_tp_group is not None and attn_tp_group.pynccl_comm is not None: ++ attn_tp_group.pynccl_comm.nccl_resume() ++ moe_ep_group = get_moe_ep_group() ++ if moe_ep_group is not None and moe_ep_group.pynccl_comm is not None: ++ moe_ep_group.pynccl_comm.nccl_resume() ++ moe_tp_group = get_moe_tp_group() ++ if moe_tp_group is not None and moe_tp_group.pynccl_comm is not None: ++ moe_tp_group.pynccl_comm.nccl_resume() ++ + if GPU_MEMORY_TYPE_WEIGHTS in tags: + self.memory_saver_adapter.resume(GPU_MEMORY_TYPE_WEIGHTS) + torch.distributed.barrier(self.tp_cpu_group) +@@ -167,6 +206,13 @@ class SchedulerUpdateWeightsMixin: + if GPU_MEMORY_TYPE_KV_CACHE in tags: + self.memory_saver_adapter.resume(GPU_MEMORY_TYPE_KV_CACHE) + ++ if self.disaggregation_mode == DisaggregationMode.DECODE: ++ if hasattr(self, "disagg_decode_prealloc_queue"): ++ self.disagg_decode_prealloc_queue.resume_memory_occupation() ++ elif self.disaggregation_mode == DisaggregationMode.PREFILL: ++ if hasattr(self, "disagg_prefill_bootstrap_queue"): ++ self.disagg_prefill_bootstrap_queue.resume_memory_occupation() ++ + return ResumeMemoryOccupationReqOutput() + + def check_weights(self: Scheduler, recv_req: CheckWeightsReqInput): +diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py 
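The hunks above and below route `return_routed_experts` through the scheduler, `ModelRunner`, and the per-layer `TopK` modules, but the capturer itself (`sglang/srt/layers/moe/routed_experts_capturer.py`) is not part of this diff. The sketch below only illustrates the buffer-indexing scheme the call sites imply — top-k expert ids written at `out_cache_loc` during a forward pass, mirrored to a host buffer via `out_cache_loc_cpu`, and gathered for a finished request through `req_to_token_pool`. The class name, buffer shape, and method bodies are assumptions, not the real implementation.

```python
# Hedged sketch only: the real RoutedExpertsCapturer is not shown in this diff.
from contextlib import contextmanager
from typing import Optional

import torch


class RoutedExpertsCapturerSketch:
    """Illustrative stand-in assuming a (num_tokens, topk) buffer keyed by KV-cache slot."""

    def __init__(self, num_tokens: int, topk: int, device: str):
        # One row per KV-cache slot; MoE TopK layers scatter their selected ids here.
        self.device_buf = torch.zeros(num_tokens, topk, dtype=torch.int32, device=device)
        # Host mirror, so finished requests can be read without touching the GPU.
        self.host_buf = torch.zeros(num_tokens, topk, dtype=torch.int32)
        self._batch_loc: Optional[torch.Tensor] = None

    @contextmanager
    def with_forward(self, out_cache_loc: torch.Tensor):
        # Remember where this batch's tokens live for the duration of the forward pass.
        self._batch_loc = out_cache_loc
        try:
            yield
        finally:
            self._batch_loc = None

    def capture(self, topk_ids: torch.Tensor) -> None:
        # Would be called from a TopK layer with shape (num_batch_tokens, topk).
        self.device_buf[self._batch_loc] = topk_ids.to(torch.int32)

    def sync_fwd_experts_buffer_DtoH(self, device_loc: torch.Tensor, cpu_loc: torch.Tensor) -> None:
        # Copy only the rows touched by this forward pass back to the host mirror.
        self.host_buf[cpu_loc] = self.device_buf[device_loc].cpu()

    def get_routed_experts(self, token_indices_cpu: torch.Tensor) -> torch.Tensor:
        # token_indices_cpu would come from req_to_token_pool[req_pool_idx, :seqlen].
        return self.host_buf[token_indices_cpu].clone()
```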
+index b90cf0616..98d71d896 100644 +--- a/python/sglang/srt/managers/tokenizer_manager.py ++++ b/python/sglang/srt/managers/tokenizer_manager.py +@@ -888,6 +888,7 @@ class TokenizerManager(TokenizerCommunicatorMixin): + session_params=session_params, + custom_logit_processor=obj.custom_logit_processor, + return_hidden_states=obj.return_hidden_states, ++ return_routed_experts=obj.return_routed_experts, + data_parallel_rank=obj.data_parallel_rank, + priority=obj.priority, + extra_key=obj.extra_key, +@@ -1621,6 +1622,9 @@ class TokenizerManager(TokenizerCommunicatorMixin): + if getattr(recv_obj, "output_hidden_states", None): + meta_info["hidden_states"] = recv_obj.output_hidden_states[i] + ++ if getattr(recv_obj, "output_routed_experts", None): ++ meta_info["routed_experts"] = recv_obj.output_routed_experts[i] ++ + if isinstance(recv_obj, BatchStrOutput): + state.text += recv_obj.output_strs[i] + if self.server_args.stream_output and state.obj.stream: +@@ -1747,12 +1751,13 @@ class TokenizerManager(TokenizerCommunicatorMixin): + return + + if len(recv_obj.input_token_logprobs_val) > 0: +- state.input_token_logprobs_val.extend( +- recv_obj.input_token_logprobs_val[recv_obj_index] +- ) +- state.input_token_logprobs_idx.extend( +- recv_obj.input_token_logprobs_idx[recv_obj_index] +- ) ++ if recv_obj.input_token_logprobs_val[recv_obj_index]: ++ state.input_token_logprobs_val.extend( ++ recv_obj.input_token_logprobs_val[recv_obj_index] ++ ) ++ state.input_token_logprobs_idx.extend( ++ recv_obj.input_token_logprobs_idx[recv_obj_index] ++ ) + state.output_token_logprobs_val.extend( + recv_obj.output_token_logprobs_val[recv_obj_index] + ) +diff --git a/python/sglang/srt/model_executor/forward_batch_info.py b/python/sglang/srt/model_executor/forward_batch_info.py +index 3a85e6a7e..2859dafa1 100644 +--- a/python/sglang/srt/model_executor/forward_batch_info.py ++++ b/python/sglang/srt/model_executor/forward_batch_info.py +@@ -51,6 +51,7 @@ from sglang.srt.layers.dp_attention import ( + set_dp_buffer_len, + set_is_extend_in_batch, + ) ++from sglang.srt.server_args import get_global_server_args + from sglang.srt.utils import get_compiler_backend, is_npu, support_triton + from sglang.srt.utils.common import ceil_align + +@@ -214,6 +215,9 @@ class ForwardBatch: + # The sum of all sequence lengths + seq_lens_sum: int + ++ # cpu copy of out_cache_loc ++ out_cache_loc_cpu: Optional[torch.Tensor] = None ++ + # The original sequence length without being chunked. Qwen-1M related. 
+ orig_seq_lens: Optional[torch.Tensor] = None + +@@ -368,6 +372,7 @@ class ForwardBatch: + req_pool_indices=batch.req_pool_indices, + seq_lens=batch.seq_lens, + out_cache_loc=batch.out_cache_loc, ++ out_cache_loc_cpu=batch.out_cache_loc_cpu, + mm_inputs=batch.multimodal_inputs, + encoder_cached=batch.encoder_cached, + encoder_lens=batch.encoder_lens, +@@ -623,7 +628,10 @@ class ForwardBatch: + mm_input = batch.multimodal_inputs[batch_idx] + if self.forward_mode.is_decode(): + # 3 * N +- if mm_input is None: ++ if ( ++ mm_input is None ++ or get_global_server_args().rl_on_policy_target is not None ++ ): + mrope_positions_list[batch_idx] = torch.full( + (3, 1), + self.seq_lens[batch_idx] - 1, +@@ -640,7 +648,10 @@ class ForwardBatch: + batch.extend_seq_lens[batch_idx], + batch.extend_prefix_lens[batch_idx], + ) +- if mm_input is None: ++ if ( ++ mm_input is None ++ or get_global_server_args().rl_on_policy_target is not None ++ ): + # text only + mrope_positions = torch.tensor( + [ +@@ -823,6 +834,10 @@ class ForwardBatch: + ) + + self.out_cache_loc = self._pad_tensor_to_size(self.out_cache_loc, num_tokens) ++ if self.out_cache_loc_cpu is not None: ++ self.out_cache_loc_cpu = self._pad_tensor_to_size( ++ self.out_cache_loc_cpu, num_tokens ++ ) + if self.encoder_lens is not None: + self.encoder_lens = self._pad_tensor_to_size(self.encoder_lens, bs) + self.positions = self._pad_tensor_to_size(self.positions, num_tokens) +diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py +index 4d58278b7..8f50dc430 100644 +--- a/python/sglang/srt/model_executor/model_runner.py ++++ b/python/sglang/srt/model_executor/model_runner.py +@@ -94,6 +94,11 @@ from sglang.srt.layers.dp_attention import ( + set_is_extend_in_batch, + ) + from sglang.srt.layers.logits_processor import LogitsProcessorOutput ++from sglang.srt.layers.moe.routed_experts_capturer import ( ++ RoutedExpertsCapturer, ++ get_global_experts_capturer, ++ set_global_experts_capturer, ++) + from sglang.srt.layers.pooler import EmbeddingPoolerOutput + from sglang.srt.layers.sampler import Sampler + from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model +@@ -502,6 +507,10 @@ class ModelRunner: + server_args.max_running_requests, + server_args.max_total_tokens, + ) ++ ++ # Init routed experts capturer ++ self.init_routed_experts_capturer() ++ + if self.device == "cuda": + self.init_cublas() + self.init_attention_backend() +@@ -545,6 +554,40 @@ class ModelRunner: + # Initialize piecewise CUDA graph + self.init_piecewise_cuda_graphs() + ++ def init_routed_experts_capturer(self): ++ # TODO: the redundant logic with TpModelWorker ++ max_running_requests = min( ++ ( ++ self.max_total_num_tokens // 2 ++ if self.server_args.max_running_requests is None ++ else self.server_args.max_running_requests ++ // ( ++ self.server_args.dp_size ++ if self.server_args.enable_dp_attention ++ else 1 ++ ) ++ ), ++ self.req_to_token_pool.size, ++ ) ++ ++ if not self.server_args.disable_shared_experts_fusion and hasattr( ++ self.model, "num_fused_shared_experts" ++ ): ++ num_fused_shared_experts = self.model.num_fused_shared_experts ++ else: ++ num_fused_shared_experts = 0 ++ ++ set_global_experts_capturer( ++ RoutedExpertsCapturer.create( ++ enable=get_global_server_args().enable_return_routed_experts, ++ model_config=self.model_config, ++ num_fused_shared_experts=num_fused_shared_experts, ++ num_tokens=self.max_total_num_tokens + self.page_size, ++ max_running_requests=max_running_requests, ++ 
device=self.device, ++ ) ++ ) ++ + def model_specific_adjustment(self): + server_args = self.server_args + +@@ -792,7 +835,11 @@ class ModelRunner: + ) + with self.memory_saver_adapter.region( + GPU_MEMORY_TYPE_WEIGHTS, +- enable_cpu_backup=enable_cpu_backup, ++ enable_cpu_backup=( ++ self.server_args.enable_weights_cpu_backup ++ if not self.is_draft_worker ++ else True ++ ), + ): + self.model = get_model( + model_config=self.model_config, +@@ -2645,9 +2692,12 @@ class ModelRunner: + ) -> Tuple[Union[LogitsProcessorOutput, PPProxyTensors], bool]: + self.forward_pass_id += 1 + +- with get_global_expert_distribution_recorder().with_forward_pass( +- self.forward_pass_id, +- forward_batch, ++ with ( ++ get_global_expert_distribution_recorder().with_forward_pass( ++ self.forward_pass_id, ++ forward_batch, ++ ), ++ get_global_experts_capturer().with_forward(forward_batch), + ): + output = self._forward_raw( + forward_batch, +@@ -2656,6 +2706,13 @@ class ModelRunner: + reinit_attn_backend, + split_forward_count, + ) ++ # Copy cached routing experts' buffers back to CPU cache ++ get_global_experts_capturer().sync_fwd_experts_buffer_DtoH( ++ device_loc=forward_batch.out_cache_loc, ++ cpu_loc=forward_batch.out_cache_loc_cpu, ++ can_run_graph=output[1], ++ cuda_graph_batch=getattr(self.graph_runner, "bs", None), ++ ) + + if self.eplb_manager is not None: + self.eplb_manager.on_forward_pass_end() +diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py +index dc30b4f0a..f29dc4b71 100644 +--- a/python/sglang/srt/models/deepseek_v2.py ++++ b/python/sglang/srt/models/deepseek_v2.py +@@ -667,6 +667,7 @@ class DeepseekV2MoE(nn.Module): + + self.topk = TopK( + top_k=config.num_experts_per_tok + self.num_fused_shared_experts, ++ layer_id=self.layer_id, + renormalize=config.norm_topk_prob, + use_grouped_topk=True, + num_expert_group=config.n_group, +diff --git a/python/sglang/srt/models/ernie4.py b/python/sglang/srt/models/ernie4.py +index ab1b6576b..dffd8f09a 100644 +--- a/python/sglang/srt/models/ernie4.py ++++ b/python/sglang/srt/models/ernie4.py +@@ -87,6 +87,7 @@ class Ernie4Moe(nn.Module): + + self.topk = TopK( + top_k=config.moe_k, ++ layer_id=layer_id, + renormalize=True, + use_grouped_topk=False, + correction_bias=self.gate.e_score_correction_bias, +diff --git a/python/sglang/srt/models/glm4_moe.py b/python/sglang/srt/models/glm4_moe.py +index a9689b8f2..bc8538da8 100644 +--- a/python/sglang/srt/models/glm4_moe.py ++++ b/python/sglang/srt/models/glm4_moe.py +@@ -379,6 +379,17 @@ class Glm4MoeSparseMoeBlock(nn.Module): + + self.gate = Glm4MoeGate(config=config, prefix=add_prefix("gate", prefix)) + ++ self.topk = TopK( ++ top_k=self.top_k, ++ layer_id=self.layer_id, ++ renormalize=config.norm_topk_prob, ++ use_grouped_topk=True, ++ num_expert_group=config.n_group, ++ topk_group=config.topk_group, ++ correction_bias=self.gate.e_score_correction_bias, ++ routed_scaling_factor=self.routed_scaling_factor, ++ ) ++ + self.experts = get_moe_impl_class(quant_config)( + num_experts=config.n_routed_experts + self.num_fused_shared_experts, + num_fused_shared_experts=self.num_fused_shared_experts, +diff --git a/python/sglang/srt/models/gpt_oss.py b/python/sglang/srt/models/gpt_oss.py +index 9474700c4..398d622ff 100644 +--- a/python/sglang/srt/models/gpt_oss.py ++++ b/python/sglang/srt/models/gpt_oss.py +@@ -113,6 +113,7 @@ class GptOssSparseMoeBlock(nn.Module): + self.topk = TopK( + top_k=config.num_experts_per_tok, + renormalize=True, ++ layer_id=layer_id, + ) + + self.top_k 
= config.num_experts_per_tok +diff --git a/python/sglang/srt/models/grok.py b/python/sglang/srt/models/grok.py +index fd513060a..a089475b7 100644 +--- a/python/sglang/srt/models/grok.py ++++ b/python/sglang/srt/models/grok.py +@@ -142,6 +142,7 @@ class Grok1MoE(nn.Module): + self.topk = TopK( + top_k=top_k, + renormalize=False, ++ layer_id=layer_id, + custom_routing_function=custom_routing_function, + ) + +diff --git a/python/sglang/srt/models/hunyuan.py b/python/sglang/srt/models/hunyuan.py +index 7c6fd9e48..b20d28544 100644 +--- a/python/sglang/srt/models/hunyuan.py ++++ b/python/sglang/srt/models/hunyuan.py +@@ -150,6 +150,7 @@ class HunYuanSparseMoeBlock(nn.Module): + + self.topk = TopK( + top_k=top_k, ++ layer_id=layer_id, + renormalize=True if top_k > 1 else False, + ) + +diff --git a/python/sglang/srt/models/longcat_flash.py b/python/sglang/srt/models/longcat_flash.py +index 3530609ba..01c89e893 100644 +--- a/python/sglang/srt/models/longcat_flash.py ++++ b/python/sglang/srt/models/longcat_flash.py +@@ -245,6 +245,7 @@ class LongcatFlashMoE(nn.Module): + renormalize=False, + use_grouped_topk=False, + correction_bias=self.router.e_score_correction_bias.data, ++ layer_id=layer_id, + ) + self.topk.forward = self.topk.forward_native + +diff --git a/python/sglang/srt/models/qwen2.py b/python/sglang/srt/models/qwen2.py +index a7dbadec6..c83a41338 100644 +--- a/python/sglang/srt/models/qwen2.py ++++ b/python/sglang/srt/models/qwen2.py +@@ -90,9 +90,6 @@ class Qwen2MLP(nn.Module): + self.act_fn = SiluAndMul() + + def forward(self, x): +- if get_global_server_args().rl_on_policy_target is not None: +- x = x.bfloat16() +- + gate_up, _ = self.gate_up_proj(x) + x = self.act_fn(gate_up) + x, _ = self.down_proj(x) +@@ -279,11 +276,6 @@ class Qwen2Model(nn.Module): + quant_config=quant_config, + enable_tp=not is_dp_attention_enabled(), + prefix=add_prefix("embed_tokens", prefix), +- params_dtype=( +- torch.float32 +- if get_global_server_args().rl_on_policy_target is not None +- else None +- ), + ) + else: + self.embed_tokens = PPMissingLayer() +@@ -306,10 +298,8 @@ class Qwen2Model(nn.Module): + if self.pp_group.is_last_rank: + norm_kwargs = ( + dict( +- weight_dtype=torch.float32, + cast_x_before_out_mul=True, +- override_orig_dtype=torch.float32, +- fp32_residual=True, ++ fp32_residual=False, + ) + if get_global_server_args().rl_on_policy_target is not None + else {} +diff --git a/python/sglang/srt/models/qwen2_moe.py b/python/sglang/srt/models/qwen2_moe.py +index ea33e81ef..561934dce 100644 +--- a/python/sglang/srt/models/qwen2_moe.py ++++ b/python/sglang/srt/models/qwen2_moe.py +@@ -161,6 +161,7 @@ class Qwen2MoeSparseMoeBlock(nn.Module): + self.topk = TopK( + top_k=config.num_experts_per_tok, + renormalize=config.norm_topk_prob, ++ layer_id=layer_id, + ) + + self.experts = get_moe_impl_class(quant_config)( +@@ -581,7 +582,17 @@ class Qwen2MoeModel(nn.Module): + prefix=add_prefix("layers", prefix), + ) + if self.pp_group.is_last_rank: +- self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) ++ norm_kwargs = ( ++ dict( ++ cast_x_before_out_mul=True, ++ fp32_residual=False, ++ ) ++ if get_global_server_args().rl_on_policy_target is not None ++ else {} ++ ) ++ self.norm = RMSNorm( ++ config.hidden_size, eps=config.rms_norm_eps, **norm_kwargs ++ ) + else: + self.norm = PPMissingLayer(return_tuple=True) + +diff --git a/python/sglang/srt/models/qwen3.py b/python/sglang/srt/models/qwen3.py +index 30b92acbd..a0d14895f 100644 +--- a/python/sglang/srt/models/qwen3.py ++++ 
b/python/sglang/srt/models/qwen3.py +@@ -90,8 +90,8 @@ class Qwen3Attention(nn.Module): + + norm_kwargs = ( + dict( +- weight_dtype=torch.float32, + cast_x_before_out_mul=True, ++ fp32_residual=False, + ) + if get_global_server_args().rl_on_policy_target is not None + else {} +@@ -256,10 +256,8 @@ class Qwen3DecoderLayer(nn.Module): + + norm_kwargs = ( + dict( +- weight_dtype=torch.float32, + cast_x_before_out_mul=True, +- override_orig_dtype=torch.float32, +- fp32_residual=True, ++ fp32_residual=False, + ) + if get_global_server_args().rl_on_policy_target is not None + else {} +@@ -289,10 +287,14 @@ class Qwen3DecoderLayer(nn.Module): + hidden_states: torch.Tensor, + forward_batch: ForwardBatch, + residual: Optional[torch.Tensor], ++ post_residual_addition: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + # Self Attention + hidden_states, residual = self.layer_communicator.prepare_attn( +- hidden_states, residual, forward_batch ++ hidden_states, ++ residual, ++ forward_batch, ++ post_residual_addition=post_residual_addition, + ) + if hidden_states.shape[0] != 0: + hidden_states = self.self_attn( +diff --git a/python/sglang/srt/models/qwen3_moe.py b/python/sglang/srt/models/qwen3_moe.py +index 9737ac719..09c756918 100644 +--- a/python/sglang/srt/models/qwen3_moe.py ++++ b/python/sglang/srt/models/qwen3_moe.py +@@ -22,6 +22,7 @@ import math + from typing import Any, Dict, Iterable, List, Optional, Tuple, TypeVar + + import torch ++import torch.nn.functional as F + from torch import nn + from transformers import PretrainedConfig + +@@ -50,7 +51,7 @@ from sglang.srt.layers.moe import ( + ) + from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class + from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE +-from sglang.srt.layers.moe.topk import TopK ++from sglang.srt.layers.moe.topk import StandardTopKOutput, TopK + from sglang.srt.layers.moe.utils import RoutingMethodType + from sglang.srt.layers.quantization.base_config import QuantizationConfig + from sglang.srt.layers.radix_attention import RadixAttention +@@ -227,7 +228,9 @@ class Qwen3MoeSparseMoeBlock(nn.Module): + top_k=config.num_experts_per_tok, + renormalize=config.norm_topk_prob, + use_grouped_topk=False, ++ layer_id=layer_id, + ) ++ self.top_k = config.num_experts_per_tok + + self.experts = get_moe_impl_class(quant_config)( + num_experts=config.num_experts +@@ -293,7 +296,22 @@ class Qwen3MoeSparseMoeBlock(nn.Module): + + # router_logits: (num_tokens, n_experts) + router_logits, _ = self.gate(hidden_states) +- topk_output = self.topk(hidden_states, router_logits) ++ ++ if get_global_server_args().rl_on_policy_target is not None: ++ routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) ++ routing_weights, selected_experts = torch.topk( ++ routing_weights, self.top_k, dim=-1 ++ ) ++ routing_weights /= routing_weights.sum(dim=-1, keepdim=True) ++ routing_weights = routing_weights.to(hidden_states.dtype) ++ topk_output = StandardTopKOutput( ++ topk_weights=routing_weights, ++ topk_ids=selected_experts, ++ router_logits=router_logits, ++ ) ++ else: ++ topk_output = self.topk(hidden_states, router_logits) ++ + final_hidden_states = self.experts(hidden_states, topk_output) + if ( + self.tp_size > 1 +@@ -474,13 +492,14 @@ class Qwen3MoeAttention(nn.Module): + ) + self.compatible_with_fused_kv_buffer = ( + False if isinstance(self.rotary_emb, MRotaryEmbedding) else True +- ) ++ ) and (get_global_server_args().rl_on_policy_target is None) + self.compatible_with_fused_qk_norm_rope 
= ( + not isinstance(self.rotary_emb, MRotaryEmbedding) + ) and self.head_dim in (64, 128, 256) + self.use_fused_qk_norm_rope = ( + get_global_server_args().enable_fused_qk_norm_rope + and self.compatible_with_fused_qk_norm_rope ++ and (get_global_server_args().rl_on_policy_target is None) + ) + self._used_fused_qk_norm_rope_last_call = False + +@@ -493,8 +512,16 @@ class Qwen3MoeAttention(nn.Module): + prefix=add_prefix("attn", prefix), + ) + +- self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps) +- self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps) ++ norm_kwargs = ( ++ dict( ++ cast_x_before_out_mul=True, ++ fp32_residual=False, ++ ) ++ if get_global_server_args().rl_on_policy_target is not None ++ else {} ++ ) ++ self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps, **norm_kwargs) ++ self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps, **norm_kwargs) + self.alt_stream = alt_stream + + def _apply_qk_norm( +@@ -751,9 +778,19 @@ class Qwen3MoeDecoderLayer(nn.Module): + quant_config=quant_config, + prefix=add_prefix("mlp", prefix), + ) +- self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) ++ norm_kwargs = ( ++ dict( ++ cast_x_before_out_mul=True, ++ fp32_residual=False, ++ ) ++ if get_global_server_args().rl_on_policy_target is not None ++ else {} ++ ) ++ self.input_layernorm = RMSNorm( ++ config.hidden_size, eps=config.rms_norm_eps, **norm_kwargs ++ ) + self.post_attention_layernorm = RMSNorm( +- config.hidden_size, eps=config.rms_norm_eps ++ config.hidden_size, eps=config.rms_norm_eps, **norm_kwargs + ) + + self.layer_communicator = LayerCommunicator( +diff --git a/python/sglang/srt/models/qwen3_vl.py b/python/sglang/srt/models/qwen3_vl.py +index ed52f7ff4..8ce9fab9d 100644 +--- a/python/sglang/srt/models/qwen3_vl.py ++++ b/python/sglang/srt/models/qwen3_vl.py +@@ -18,7 +18,6 @@ import re + from functools import lru_cache, partial + from typing import Callable, Iterable, List, Optional, Tuple, Union + +-import numpy as np + import torch + import torch.nn as nn + from einops import rearrange +@@ -349,83 +348,65 @@ class Qwen3VLMoeVisionModel(nn.Module, RotaryPosMixin): + return rotary_pos_emb + + def fast_pos_embed_interpolate(self, grid_thw): ++ grid_ts, grid_hs, grid_ws = grid_thw[:, 0], grid_thw[:, 1], grid_thw[:, 2] + num_grid_per_side = int(self.num_position_embeddings**0.5) ++ device = self.pos_embed.weight.device + + idx_list = [[] for _ in range(4)] + weight_list = [[] for _ in range(4)] + +- # TODO: use torch instand of np +- for t, h, w in grid_thw: +- h_idxs = np.linspace(0, num_grid_per_side - 1, h) +- w_idxs = np.linspace(0, num_grid_per_side - 1, w) ++ for t, h, w in zip(grid_ts, grid_hs, grid_ws): ++ h_idxs = torch.linspace(0, num_grid_per_side - 1, h) ++ w_idxs = torch.linspace(0, num_grid_per_side - 1, w) + +- h_idxs_floor = h_idxs.astype(int) +- w_idxs_floor = w_idxs.astype(int) +- h_idxs_ceil = (h_idxs.astype(int) + 1).clip(max=num_grid_per_side - 1) +- w_idxs_ceil = (w_idxs.astype(int) + 1).clip(max=num_grid_per_side - 1) ++ h_idxs_floor = h_idxs.int() ++ w_idxs_floor = w_idxs.int() ++ h_idxs_ceil = (h_idxs.int() + 1).clip(max=num_grid_per_side - 1) ++ w_idxs_ceil = (w_idxs.int() + 1).clip(max=num_grid_per_side - 1) + + dh = h_idxs - h_idxs_floor + dw = w_idxs - w_idxs_floor + +- idx_list[0].extend( +- ((h_idxs_floor * num_grid_per_side)[None].T + w_idxs_floor[None]) +- .flatten() +- .tolist() +- * t +- ) +- idx_list[1].extend( +- ((h_idxs_floor * num_grid_per_side)[None].T + w_idxs_ceil[None]) +- .flatten() +- .tolist() +- * t +- ) 
+- idx_list[2].extend( +- ((h_idxs_ceil * num_grid_per_side)[None].T + w_idxs_floor[None]) +- .flatten() +- .tolist() +- * t +- ) +- idx_list[3].extend( +- ((h_idxs_ceil * num_grid_per_side)[None].T + w_idxs_ceil[None]) +- .flatten() +- .tolist() +- * t +- ) ++ base_h = h_idxs_floor * num_grid_per_side ++ base_h_ceil = h_idxs_ceil * num_grid_per_side + +- weight_list[0].extend( +- ((1 - dh)[None].T * (1 - dw)[None]).flatten().tolist() * t +- ) +- weight_list[1].extend(((1 - dh)[None].T * dw[None]).flatten().tolist() * t) +- weight_list[2].extend((dh[None].T * (1 - dw)[None]).flatten().tolist() * t) +- weight_list[3].extend((dh[None].T * dw[None]).flatten().tolist() * t) ++ indices = [ ++ (base_h[None].T + w_idxs_floor[None]).flatten(), ++ (base_h[None].T + w_idxs_ceil[None]).flatten(), ++ (base_h_ceil[None].T + w_idxs_floor[None]).flatten(), ++ (base_h_ceil[None].T + w_idxs_ceil[None]).flatten(), ++ ] + +- device = self.pos_embed.weight.device +- dtype = self.pos_embed.weight.dtype ++ weights = [ ++ ((1 - dh)[None].T * (1 - dw)[None]).flatten(), ++ ((1 - dh)[None].T * dw[None]).flatten(), ++ (dh[None].T * (1 - dw)[None]).flatten(), ++ (dh[None].T * dw[None]).flatten(), ++ ] + +- p0 = ( +- self.pos_embed(torch.tensor(idx_list[0], dtype=torch.long, device=device)) +- * torch.tensor(weight_list[0], dtype=dtype, device=device)[:, None] +- ) +- p1 = ( +- self.pos_embed(torch.tensor(idx_list[1], dtype=torch.long, device=device)) +- * torch.tensor(weight_list[1], dtype=dtype, device=device)[:, None] +- ) +- p2 = ( +- self.pos_embed(torch.tensor(idx_list[2], dtype=torch.long, device=device)) +- * torch.tensor(weight_list[2], dtype=dtype, device=device)[:, None] ++ for i in range(4): ++ idx_list[i].extend(indices[i].tolist()) ++ weight_list[i].extend(weights[i].tolist()) ++ ++ idx_tensor = torch.tensor(idx_list, dtype=torch.long, device=device) ++ weight_tensor = torch.tensor( ++ weight_list, dtype=self.pos_embed.weight.dtype, device=device + ) +- p3 = ( +- self.pos_embed(torch.tensor(idx_list[3], dtype=torch.long, device=device)) +- * torch.tensor(weight_list[3], dtype=dtype, device=device)[:, None] ++ pos_embeds = self.pos_embed(idx_tensor).to(device) * weight_tensor[:, :, None] ++ patch_pos_embeds = pos_embeds[0] + pos_embeds[1] + pos_embeds[2] + pos_embeds[3] ++ ++ patch_pos_embeds = patch_pos_embeds.split( ++ [h * w for h, w in zip(grid_hs, grid_ws)] + ) + +- patch_pos_embeds = p0 + p1 + p2 + p3 +- patch_pos_embeds = patch_pos_embeds.split([t * h * w for t, h, w in grid_thw]) + patch_pos_embeds_permute = [] +- m_size = self.spatial_merge_size +- for pos_embed, (t, h, w) in zip(patch_pos_embeds, grid_thw): ++ merge_size = self.spatial_merge_size ++ for pos_embed, t, h, w in zip(patch_pos_embeds, grid_ts, grid_hs, grid_ws): ++ pos_embed = pos_embed.repeat(t, 1) + pos_embed = ( +- pos_embed.view(t, h // m_size, m_size, w // m_size, m_size, -1) ++ pos_embed.view( ++ t, h // merge_size, merge_size, w // merge_size, merge_size, -1 ++ ) + .permute(0, 1, 3, 2, 4, 5) + .flatten(0, 4) + ) +@@ -555,21 +536,27 @@ class Qwen3LLMModel(Qwen3Model): + hidden_states + residual if residual is not None else hidden_states + ) + ++ deepstack_embeds = None ++ if input_deepstack_embeds is not None: ++ prev_layer_idx = layer_idx - 1 ++ if prev_layer_idx in self.deepstack_embed_to_decoder_layer: ++ sep = self.hidden_size * prev_layer_idx ++ deepstack_embeds = input_deepstack_embeds[ ++ :, sep : sep + self.hidden_size ++ ] ++ ++ # SGLang applies residual at the START of the next layer, not at the END like HuggingFace. 
++ # See: https://github.com/huggingface/transformers/blob/v5.0.0rc0/src/transformers/models/qwen3_vl/modeling_qwen3_vl.py#L549 ++ # To match HF behavior, deepstack must be added AFTER residual: (hidden_states + residual) + deepstack ++ # The order matters because addition with different tensors is not associative in practice. + hidden_states, residual = layer( + positions, + hidden_states, + forward_batch, + residual, ++ post_residual_addition=deepstack_embeds, + ) + +- # process deepstack +- if ( +- input_deepstack_embeds is not None +- and layer_idx in self.deepstack_embed_to_decoder_layer +- ): +- sep = self.hidden_size * layer_idx +- hidden_states += input_deepstack_embeds[:, sep : sep + self.hidden_size] +- + if not self.pp_group.is_last_rank: + return PPProxyTensors( + { +diff --git a/python/sglang/srt/models/step3_vl.py b/python/sglang/srt/models/step3_vl.py +index 4474f62d5..0e537c398 100644 +--- a/python/sglang/srt/models/step3_vl.py ++++ b/python/sglang/srt/models/step3_vl.py +@@ -129,6 +129,7 @@ class Step3TextMoEMLP(nn.Module): + top_k=config.moe_top_k, + renormalize=config.norm_expert_weight, + use_grouped_topk=False, ++ layer_id=layer_id, + ) + + self.experts = get_moe_impl_class(quant_config)( +diff --git a/python/sglang/srt/multimodal/processors/base_processor.py b/python/sglang/srt/multimodal/processors/base_processor.py +index 370aec2b6..47666d8f3 100644 +--- a/python/sglang/srt/multimodal/processors/base_processor.py ++++ b/python/sglang/srt/multimodal/processors/base_processor.py +@@ -13,6 +13,7 @@ from PIL import Image + from transformers import BaseImageProcessorFast + + from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem ++from sglang.srt.server_args import get_global_server_args + from sglang.srt.utils import ( + get_bool_env_var, + is_npu, +@@ -260,7 +261,9 @@ class BaseMultimodalProcessor(ABC): + and isinstance(processor.image_processor, BaseImageProcessorFast) + and not self.server_args.disable_fast_image_processor + ): +- if not _is_npu: ++ if get_global_server_args().rl_on_policy_target is not None: ++ kwargs["device"] = "cpu" ++ elif not _is_npu: + kwargs["device"] = "cuda" + elif processor.__class__.__name__ not in { + "Qwen2_5_VLProcessor", +diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py +index 8e7753dab..323788f39 100644 +--- a/python/sglang/srt/server_args.py ++++ b/python/sglang/srt/server_args.py +@@ -535,6 +535,7 @@ class ServerArgs: + disable_fast_image_processor: bool = False + keep_mm_feature_on_device: bool = False + enable_return_hidden_states: bool = False ++ enable_return_routed_experts: bool = False + scheduler_recv_interval: int = 1 + numa_node: Optional[List[int]] = None + enable_deterministic_inference: bool = False +@@ -1966,6 +1967,9 @@ class ServerArgs: + "Enable deterministic inference because of rl_on_policy_target." 
+ ) + self.enable_deterministic_inference = True ++ ++ # For VLM ++ os.environ["SGLANG_VLM_CACHE_SIZE_MB"] = "0" + # TODO remove this environment variable as a whole + os.environ["SGLANG_ENABLE_DETERMINISTIC_INFERENCE"] = "1" + +@@ -3705,6 +3709,11 @@ class ServerArgs: + action="store_true", + help="Enable returning hidden states with responses.", + ) ++ parser.add_argument( ++ "--enable-return-routed-experts", ++ action="store_true", ++ help="Enable returning routed experts of each layer with responses.", ++ ) + parser.add_argument( + "--scheduler-recv-interval", + type=int, +diff --git a/python/sglang/srt/speculative/eagle_info.py b/python/sglang/srt/speculative/eagle_info.py +index b3d72df05..ddfe0b178 100644 +--- a/python/sglang/srt/speculative/eagle_info.py ++++ b/python/sglang/srt/speculative/eagle_info.py +@@ -746,6 +746,10 @@ class EagleDraftInput(SpecInput, EagleDraftInputV2Mixin): + self.topk_index = self.topk_index[: len(new_indices)] + self.hidden_states = self.hidden_states[: len(new_indices)] + self.verified_id = self.verified_id[: len(new_indices)] ++ if self.accept_length is not None: ++ self.accept_length = self.accept_length[: len(new_indices)] ++ if self.accept_length_cpu is not None: ++ self.accept_length_cpu = self.accept_length_cpu[: len(new_indices)] + else: + # in some cases(e.g draft_extend), we have not filtered the batch by `unfinished_index` + self.topk_p = self.topk_p[new_indices] +@@ -777,6 +781,27 @@ class EagleDraftInput(SpecInput, EagleDraftInputV2Mixin): + self.verified_id = torch.cat([self.verified_id, spec_info.verified_id], axis=0) + self.topk_p = torch.cat([self.topk_p, spec_info.topk_p]) + self.topk_index = torch.cat([self.topk_index, spec_info.topk_index]) ++ if self.accept_length is not None and spec_info.accept_length is not None: ++ self.accept_length = torch.cat( ++ [self.accept_length, spec_info.accept_length] ++ ) ++ self.accept_length_cpu = self.accept_length.tolist() ++ elif self.accept_length is not None: ++ zeros = torch.zeros( ++ [spec_info.verified_id.shape[0]], ++ dtype=self.accept_length.dtype, ++ device=self.accept_length.device, ++ ) ++ self.accept_length = torch.cat([self.accept_length, zeros]) ++ self.accept_length_cpu = self.accept_length.tolist() ++ elif spec_info.accept_length is not None: ++ zeros = torch.zeros( ++ [self.verified_id.shape[0]], ++ dtype=self.accept_length.dtype, ++ device=self.accept_length.device, ++ ) ++ self.accept_length = torch.cat([zeros, spec_info.accept_length]) ++ self.accept_length_cpu = self.accept_length.tolist() + + + @dataclass diff --git a/docker/version.txt b/docker/version.txt index 40fee9d81..fcd7ad62f 100644 --- a/docker/version.txt +++ b/docker/version.txt @@ -1 +1 @@ -nightly-dev-20251209d \ No newline at end of file +nightly-dev-20251216b \ No newline at end of file diff --git a/docs/README.md b/docs/README.md index 606273c07..4402f36fe 100644 --- a/docs/README.md +++ b/docs/README.md @@ -1,6 +1,6 @@ # slime Documentation -We recommend new contributors start from writing documentation, which helps you quickly understand SGLang codebase. +We recommend new contributors start from writing documentation, which helps you quickly understand slime codebase. Most documentation files are located under the `docs/` folder. 
## Docs Workflow diff --git a/docs/en/blogs/introducing_slime.md b/docs/en/blogs/introducing_slime.md index 8e76cd97d..f35c7d65b 100644 --- a/docs/en/blogs/introducing_slime.md +++ b/docs/en/blogs/introducing_slime.md @@ -36,7 +36,7 @@ It wasn’t always like this: no one forks PyTorch just for a new dataloader. We slime views the data sampling in RL differently. We manage all SGLang servers within slime with [sgl-router](https://github.com/sgl-project/sglang/tree/main/sgl-router) and provide an interface for the data generation component, **allowing users to inject custom logic and freely interact with SGLang servers**. Unleash their creativity. - + With the sgl-router, users only need to send HTTP requests to a single endpoint. By exposing this endpoint, complex agent environments can directly interact with slime through an OpenAI-compatible API — no need to modify the environment, and training-deployment consistency is preserved. diff --git a/docs/en/examples/deepseek-r1.md b/docs/en/examples/deepseek-r1.md index 3bb29d01b..129b5c1d2 100644 --- a/docs/en/examples/deepseek-r1.md +++ b/docs/en/examples/deepseek-r1.md @@ -11,12 +11,12 @@ Regarding parallelism, for sglang we will enable EP64, activate dp attention, an ## Environment Setup -For instructions on setting up the environment and downloading data, please refer to [Example: Qwen3-4B](./qwen3-4B.md). +For instructions on setting up the environment and downloading data, please refer to [Example: Qwen3-4B](qwen3-4B.md). To prepare the DeepSeek R1 checkpoint, first you will need to download DeepSeek-R1 to a directory accessible by all machines (hereinafter referred to as `$BASE_DIR`): ```bash -huggingface-cli download deepseek-ai/DeepSeek-R1 --local-dir $BASE_DIR/DeepSeek-R1 +hf download deepseek-ai/DeepSeek-R1 --local-dir $BASE_DIR/DeepSeek-R1 ``` The Hugging Face checkpoint for DeepSeek-R1 is in a block-quantized fp8 format. To convert it into a torch_dist format that Megatron can load, you first need to convert it to a bf16 Hugging Face checkpoint: @@ -85,7 +85,7 @@ SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &>/dev/null && pwd)" source "${SCRIPT_DIR}/models/deepseek-v3.sh" ``` -This reads the model's config from [scripts/models/deepseek-v3.sh](../../../scripts/models/deepseek-v3.sh). These configs are all Megatron parameters. When training with Megatron, it cannot read the model config from the checkpoint, so we need to configure it ourselves. We provide some examples in [scripts/models](../../../scripts/models/). +This reads the model's config from [scripts/models/deepseek-v3.sh](https://github.com/THUDM/slime/blob/main/scripts/models/deepseek-v3.sh). These configs are all Megatron parameters. When training with Megatron, it cannot read the model config from the checkpoint, so we need to configure it ourselves. We provide some examples in [scripts/models](https://github.com/THUDM/slime/tree/main/scripts/models/). 
#### CKPT\_ARGS diff --git a/docs/en/examples/glm4-9B.md b/docs/en/examples/glm4-9B.md index 5728dacd2..854b62e59 100644 --- a/docs/en/examples/glm4-9B.md +++ b/docs/en/examples/glm4-9B.md @@ -15,14 +15,14 @@ Download the model and data: ```bash # hf checkpoint -huggingface-cli download zai-org/GLM-Z1-9B-0414 --local-dir /root/GLM-Z1-9B-0414 +hf download zai-org/GLM-Z1-9B-0414 --local-dir /root/GLM-Z1-9B-0414 # train data -huggingface-cli download --repo-type dataset zhuzilin/dapo-math-17k \ +hf download --repo-type dataset zhuzilin/dapo-math-17k \ --local-dir /root/dapo-math-17k # eval data -huggingface-cli download --repo-type dataset zhuzilin/aime-2024 \ +hf download --repo-type dataset zhuzilin/aime-2024 \ --local-dir /root/aime-2024 ``` @@ -49,7 +49,7 @@ bash scripts/run-glm4-9B.sh ### Parameter Introduction -Here, we will briefly introduce the various components of the [run-glm4-9B.sh](../../../scripts/run-glm4-9B.sh) script: +Here, we will briefly introduce the various components of the [run-glm4-9B.sh](https://github.com/THUDM/slime/blob/main/scripts/run-glm4-9B.sh) script: #### MODEL\_ARGS @@ -58,7 +58,7 @@ SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &>/dev/null && pwd)" source "${SCRIPT_DIR}/models/glm4-9B.sh" ``` -Reads the model's config from [scripts/models/glm4-9B.sh](../../../scripts/models/glm4-9B.sh). These configs are all Megatron parameters. When training with Megatron, it cannot read the model config from the checkpoint, so we need to configure it ourselves. We provide some examples in [scripts/models](../../../scripts/models/). +Reads the model's config from [scripts/models/glm4-9B.sh](https://github.com/THUDM/slime/blob/main/scripts/models/glm4-9B.sh). These configs are all Megatron parameters. When training with Megatron, it cannot read the model config from the checkpoint, so we need to configure it ourselves. We provide some examples in [scripts/models](https://github.com/THUDM/slime/tree/main/scripts/models/). ⚠️ Ensure that settings such as `--rotary-base` in the model configuration file match the settings of the model you are currently training. This is because different models, even with the same architecture, might use different values. If needed, you can override these parameters in your script after loading the model weights. For instance: diff --git a/docs/en/examples/glm4.5-355B-A32B.md b/docs/en/examples/glm4.5-355B-A32B.md index 03149063e..1be67a98c 100644 --- a/docs/en/examples/glm4.5-355B-A32B.md +++ b/docs/en/examples/glm4.5-355B-A32B.md @@ -5,12 +5,12 @@ This is an example of doing GLM-4.5 RL training using 64xH100 GPUs. ## Environment Setup -For instructions on setting up the environment and downloading data, please refer to [Example: Qwen3-4B](./qwen3-4B.md). +For instructions on setting up the environment and downloading data, please refer to [Example: Qwen3-4B](qwen3-4B.md). First, you will need to download GLM-4.5 to a directory accessible by all machines (hereinafter referred to as `$BASE_DIR`): ```bash -huggingface-cli download zai-org/GLM-4.5 --local-dir $BASE_DIR/GLM-4.5-355B-A32B +hf download zai-org/GLM-4.5 --local-dir $BASE_DIR/GLM-4.5-355B-A32B ``` Next, we need to convert the huggingface checkpoint into the torch_dist format with 2 nodes, each with 8 GPUs: @@ -66,7 +66,7 @@ SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &>/dev/null && pwd)" source "${SCRIPT_DIR}/models/glm4.5-355B-A32B.sh" ``` -This reads the model's config from [scripts/models/glm4.5-355B-A32B.sh](../../../scripts/models/glm4.5-355B-A32B.sh). 
These configs are all Megatron parameters. When training with Megatron, it cannot read the model config from the checkpoint, so we need to configure it ourselves. We provide some examples in [scripts/models](../../../scripts/models/). +This reads the model's config from [scripts/models/glm4.5-355B-A32B.sh](https://github.com/THUDM/slime/blob/main/scripts/models/glm4.5-355B-A32B.sh). These configs are all Megatron parameters. When training with Megatron, it cannot read the model config from the checkpoint, so we need to configure it ourselves. We provide some examples in [scripts/models](https://github.com/THUDM/slime/tree/main/scripts/models/). #### PERF\_ARGS diff --git a/docs/en/examples/qwen3-30B-A3B.md b/docs/en/examples/qwen3-30B-A3B.md index 895a400f8..2c8395be0 100644 --- a/docs/en/examples/qwen3-30B-A3B.md +++ b/docs/en/examples/qwen3-30B-A3B.md @@ -3,7 +3,7 @@ ## Environment Preparation -The environment setup, model download, data, and checkpoint conversion are the same as for the Qwen3-4B model. You can refer to [Example: Qwen3-4B Model](./qwen3-4B.md), replacing mentions of Qwen3-4B with Qwen3-30B-A3B. +The environment setup, model download, data, and checkpoint conversion are the same as for the Qwen3-4B model. You can refer to [Example: Qwen3-4B Model](qwen3-4B.md), replacing mentions of Qwen3-4B with Qwen3-30B-A3B. To convert huggingface checkpoint to torch_dist, please try: @@ -29,7 +29,7 @@ bash scripts/run-qwen3-30B-A3B.sh ### Parameter Introduction -Here, we will briefly introduce the MoE-related parts in the [run-qwen3-30B-A3B.sh](../../../scripts/run-qwen3-30B-A3B.sh) script. +Here, we will briefly introduce the MoE-related parts in the [run-qwen3-30B-A3B.sh](https://github.com/THUDM/slime/blob/main/scripts/run-qwen3-30B-A3B.sh) script. 1. To support running Qwen3-30B-A3B in an 8xH800 environment, we need to enable Megatron's CPU Adam to save GPU memory. The corresponding configuration is: @@ -79,7 +79,7 @@ Here, we will briefly introduce the MoE-related parts in the [run-qwen3-30B-A3B. slime also supports BF16 training with FP8 inference. 
For the Qwen3-30B-A3B model, you just need to download the following model: ```bash -huggingface-cli download Qwen/Qwen3-30B-A3B-FP8 --local-dir /root/Qwen3-30B-A3B-FP8 +hf download Qwen/Qwen3-30B-A3B-FP8 --local-dir /root/Qwen3-30B-A3B-FP8 ``` And replace `--hf-checkpoint` with: diff --git a/docs/en/examples/qwen3-4B.md b/docs/en/examples/qwen3-4B.md index 4678325fe..9039c9e5c 100644 --- a/docs/en/examples/qwen3-4B.md +++ b/docs/en/examples/qwen3-4B.md @@ -15,14 +15,14 @@ Download the model and data: ```bash # hf checkpoint -huggingface-cli download Qwen/Qwen3-4B --local-dir /root/Qwen3-4B +hf download Qwen/Qwen3-4B --local-dir /root/Qwen3-4B # train data -huggingface-cli download --repo-type dataset zhuzilin/dapo-math-17k \ +hf download --repo-type dataset zhuzilin/dapo-math-17k \ --local-dir /root/dapo-math-17k # eval data -huggingface-cli download --repo-type dataset zhuzilin/aime-2024 \ +hf download --repo-type dataset zhuzilin/aime-2024 \ --local-dir /root/aime-2024 ``` @@ -49,7 +49,7 @@ bash scripts/run-qwen3-4B.sh ### Parameter Introduction -Here, we will briefly introduce the various components of the [run-qwen3-4B.sh](../../../scripts/run-qwen3-4B.sh) script: +Here, we will briefly introduce the various components of the [run-qwen3-4B.sh](https://github.com/THUDM/slime/blob/main/scripts/run-qwen3-4B.sh) script: #### MODEL\_ARGS @@ -58,7 +58,7 @@ SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &>/dev/null && pwd)" source "${SCRIPT_DIR}/models/qwen3-4B.sh" ``` -This reads the model's configuration from [scripts/models/qwen3-4B.sh](../../../scripts/models/qwen3-4B.sh). These are all Megatron parameters. When training with Megatron, it cannot read the model config from the checkpoint, so we need to configure it ourselves. We provide some examples in [scripts/models](../../../scripts/models/). +This reads the model's configuration from [scripts/models/qwen3-4B.sh](https://github.com/THUDM/slime/blob/main/scripts/models/qwen3-4B.sh). These are all Megatron parameters. When training with Megatron, it cannot read the model config from the checkpoint, so we need to configure it ourselves. We provide some examples in [scripts/models](https://github.com/THUDM/slime/tree/main/scripts/models/). ⚠️ Ensure that settings such as `--rotary-base` in the model configuration file match the settings of the model you are currently training. This is because different models, even with the same architecture, might use different values. If needed, you can override these parameters in your script after loading the model weights. For instance: diff --git a/docs/en/examples/qwen3-4b-base-openhermes.md b/docs/en/examples/qwen3-4b-base-openhermes.md index 8a9519071..1ae784902 100644 --- a/docs/en/examples/qwen3-4b-base-openhermes.md +++ b/docs/en/examples/qwen3-4b-base-openhermes.md @@ -3,7 +3,7 @@ ## Environment Preparation -First, we need to create a mirror environment and convert the `Qwen3-4B-Base` model by following the [Example: Qwen3-4B Model](./models/qwen3-4B.md). +First, we need to create a mirror environment and convert the `Qwen3-4B-Base` model by following the [Example: Qwen3-4B Model](qwen3-4B.md). After that, we will process the SFT data. Here, we use the classic [OpenHermes-2.5](https://huggingface.co/datasets/teknium/OpenHermes-2.5) as an example. First, we process the data into a format suitable for `slime` to load. You can use the following script to add a column that conforms to the OpenAI message format and save it to `/root/openhermes2_5.parquet`. 
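The repository ships the actual preprocessing script referenced above; as a self-contained illustration of the same step, a minimal sketch could look like the following. It assumes the ShareGPT-style `conversations` field (`from`/`value` keys) used by OpenHermes-2.5; the `messages` column name and output path simply follow the text above.

```python
# Minimal sketch: add an OpenAI-message-format column and save to parquet.
from datasets import load_dataset

ROLE_MAP = {"system": "system", "human": "user", "gpt": "assistant"}


def to_openai_messages(example):
    # Convert each ShareGPT-style turn into {"role": ..., "content": ...}.
    example["messages"] = [
        {"role": ROLE_MAP.get(turn["from"], "user"), "content": turn["value"]}
        for turn in example["conversations"]
    ]
    return example


ds = load_dataset("teknium/OpenHermes-2.5", split="train")
ds = ds.map(to_openai_messages)
ds.to_parquet("/root/openhermes2_5.parquet")
```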
@@ -50,7 +50,7 @@ bash script/run-qwen3-4B-base-sft.sh ### Parameter Introduction -You can compare [run-qwen3-4B-base-sft.sh](../../scripts/run-qwen3-4B.sh) with [run-qwen3-4B.sh](../../scripts/run-qwen3-4B.sh). You will find that besides changing the model from the instruct version to the base model, the main adjustments are as follows: +You can compare [run-qwen3-4B-base-sft.sh](https://github.com/THUDM/slime/blob/main/scripts/run-qwen3-4B-base-sft.sh) with [run-qwen3-4B.sh](https://github.com/THUDM/slime/blob/main/scripts/run-qwen3-4B.sh). You will find that besides changing the model from the instruct version to the base model, the main adjustments are as follows: 1. Removed `SGLANG_ARGS` and `GRPO_ARGS`. This is because it is not necessary to start SGLang or configure GRPO-related settings during the SFT process. diff --git a/docs/en/get_started/qa.md b/docs/en/get_started/qa.md index 81e987928..e4bb7cc9e 100644 --- a/docs/en/get_started/qa.md +++ b/docs/en/get_started/qa.md @@ -49,7 +49,7 @@ 9. **My gradient norm is very high and the training crashes. What should I do?** - First, ensure that your data and model are compatible. For example, if your data already uses a chat template, check if this template matches the one used by the original model. If the data is correct, please refer to our [Debug Guide](./debug.md) for a more in-depth analysis. + First, ensure that your data and model are compatible. For example, if your data already uses a chat template, check if this template matches the one used by the original model. If the data is correct, please refer to our [Debug Guide](../developer_guide/debug.md) for a more in-depth analysis. 10. **My sglang generation takes an extremely long time, GPU power is maxed out, and there's no output for a long while. Why?** diff --git a/docs/en/get_started/quick_start.md b/docs/en/get_started/quick_start.md index 7aa9abeda..cd33235c4 100644 --- a/docs/en/get_started/quick_start.md +++ b/docs/en/get_started/quick_start.md @@ -108,6 +108,14 @@ PYTHONPATH=/root/Megatron-LM python tools/convert_torch_dist_to_hf.py \ Note that as Megatron will do padding to embedding for better performance, it may happen that the converted embedding is not correct. In that case, please manually set `--vocab-size` during convertion. +For FSDP checkpoints (without `common.pt`), use the dedicated conversion script. Point `--input-dir` to the checkpoint directory (e.g. `iter_xxx` or `iter_xxx/model`) and provide the original Hugging Face directory: + +```bash +python tools/convert_fsdp_to_hf.py \ + --input-dir /path/to/fsdp_ckpt/iter_xxx \ + --output-dir /root/fsdp-converted \ + --origin-hf-dir /root/GLM-Z1-9B-0414 +``` ## Training Script and Parameter Overview @@ -571,5 +579,5 @@ ray job submit --address="http://127.0.0.1:8265" \ slime has been deeply optimized for distributed training of large-scale Mixture of Experts (MoE) models. We provide some end-to-end training cases for reference: -- [Example: 64xH100 Training GLM-4.5](models/glm4.5-355B-A32B.md) -- [Example: 128xH100 Training DeepSeek-R1](models/deepseek-r1.md) +- [Example: 64xH100 Training GLM-4.5](../examples/glm4.5-355B-A32B.md) +- [Example: 128xH100 Training DeepSeek-R1](../examples/deepseek-r1.md) diff --git a/docs/en/get_started/usage.md b/docs/en/get_started/usage.md index 26310cb9f..276d6ceff 100644 --- a/docs/en/get_started/usage.md +++ b/docs/en/get_started/usage.md @@ -6,7 +6,7 @@ When using slime, parameters are primarily passed for the following purposes: 1. 
To allocate a portion of the GPUs in the cluster for training and another portion for inference. -2. To load Megatron for the training portion. +2. To load Megatron or FSDP for the training portion. 3. To load SGLang for the inference portion. 4. To configure the hyperparameters required for RL training. @@ -67,7 +67,7 @@ MODEL_ARGS=( ) ``` -We provide configurations for common models in [scripts/models](../../scripts/models), which you can reuse directly. If you are also using Megatron for pre-training/SFT, you can directly reuse the model configurations from your pre-training/SFT setup. +We provide configurations for common models in [scripts/models](../../../scripts/models), which you can reuse directly. If you are also using Megatron for pre-training/SFT, you can directly reuse the model configurations from your pre-training/SFT setup. Note: @@ -99,7 +99,7 @@ Megatron supports several of its custom checkpoint formats. Here are two of the The `torch` format is Megatron's older storage format. Its structure consists of directories like `mp_rank_xxx`, where each directory corresponds to the checkpoint stored by each rank under a specific parallel partitioning. Because of this, when loading a `torch` format checkpoint, you must ensure that the checkpoint's parallelism strategy matches that of the training task. -We recommend using the `torch_dist` format because it supports automatic parallel sharding, meaning that training tasks with different parallelism settings can share the same checkpoint, which is much more convenient. `torch_dist` is also the default format in the open-source Megatron. A `torch_dist` format checkpoint typically contains a set of `.distcp` files. When using `torch_dist`, you can convert from Hugging Face to `torch_dist` and vice versa using the checkpoint conversion method described in the [README](../../README.md). +We recommend using the `torch_dist` format because it supports automatic parallel sharding, meaning that training tasks with different parallelism settings can share the same checkpoint, which is much more convenient. `torch_dist` is also the default format in the open-source Megatron. A `torch_dist` format checkpoint typically contains a set of `.distcp` files. When using `torch_dist`, you can convert from Hugging Face to `torch_dist` and vice versa using the checkpoint conversion method described in the [README](../../../README.md). In terms of storage structure, a Megatron checkpoint typically looks like this, assuming the storage path is `/ckpt/`: @@ -183,7 +183,7 @@ Additionally, we provide a `metadata_key`, which defaults to `"metadata"`. When slime supports customizing data generation (rollout) to various degrees. - - By default, it uses the `generate_rollout` function from [slime/rollout/sglang\_example.py](../../slime/rollout/sglang_rollout.py) for data generation. This file implements an asynchronous (asyncio) data generation flow based on SGLang and supports features like dynamic sampling and partial rollout. + - By default, it uses the `generate_rollout` function from [slime/rollout/sglang_rollout.py](https://github.com/THUDM/slime/blob/main/slime/rollout/sglang_rollout.py) for data generation. This file implements an asynchronous (asyncio) data generation flow based on SGLang and supports features like dynamic sampling and partial rollout. - You can completely replace the `generate_rollout` in sglang\_example.py by using the `--rollout-function-path` parameter. 
You just need to ensure that the function signature passed via `--rollout-function-path` is as follows: @@ -213,7 +213,7 @@ slime supports customizing data generation (rollout) to various degrees. - `evaluation`: A boolean indicating if the rollout is for evaluation. You can configure a separate evaluation function using `--eval-function-path`. - - The returned `Sample` type is defined in [slime/utils/types.py](../../slime/utils/types.py). When implementing, you need to ensure the following fields are correctly set: + - The returned `Sample` type is defined in [slime/utils/types.py](https://github.com/THUDM/slime/blob/main/slime/utils/types.py). When implementing, you need to ensure the following fields are correctly set: - `tokens`: The tokens for the prompt + response. - `response_length`: The total length of the response. For multi-turn tasks, this is the length of the tokens remaining after the first-turn prompt. @@ -254,7 +254,7 @@ slime supports customizing data generation (rollout) to various degrees. return sample ``` - For a more complete version, please refer to [slime/rollout/sglang\_example.py](../../slime/rollout/sglang_rollout.py). + For a more complete version, please refer to [slime/rollout/sglang_rollout.py](https://github.com/THUDM/slime/blob/main/slime/rollout/sglang_rollout.py). - Sometimes, you may also need to support a custom reward model. This can be configured by setting `--custom-rm-path`. @@ -275,7 +275,7 @@ Some parameters related to slime's resource scheduling are configured by slime i - `--tp-size` in slime is set using `--rollout-num-gpus-per-engine`. - `--model-path` in slime is set using `--hf-checkpoint`. -The way SGLang parameters are integrated into slime can be found in [slime/backends/sglang\_utils/arguments.py](../../slime/backends/sglang_utils/arguments.py). +The way SGLang parameters are integrated into slime can be found in [slime/backends/sglang_utils/arguments.py](https://github.com/THUDM/slime/blob/main/slime/backends/sglang_utils/arguments.py). ### How to Use the Router @@ -291,7 +291,7 @@ slime supports different and lightly modified versions of Megatron by reusing co ### Parameter Configuration -slime directly imports all parameters of the Megatron in the current environment by using `from megatron.training.arguments import parse_args`. If the version of Megatron you are using has parameters defined outside of `parse_args`, you can configure them by passing them in, similar to how it's done in [train.py](../../train.py), for example: +slime directly imports all parameters of the Megatron in the current environment by using `from megatron.training.arguments import parse_args`. If the version of Megatron you are using has parameters defined outside of `parse_args`, you can configure them by passing them in, similar to how it's done in [train.py](https://github.com/THUDM/slime/blob/main/train.py), for example: ```python if __name__ == "__main__": @@ -309,4 +309,64 @@ In some customized Megatron implementations, special operations need to be perfo - `--custom-megatron-init-path`: Adds some initialization calls. - `--custom-megatron-before-log-prob-hook-path`: Is called before calculating the log probability. - - `--custom-megatron-before-train-step-hook-path`: Is called before each training step. You could use this to mix in special training losses, for example. \ No newline at end of file + - `--custom-megatron-before-train-step-hook-path`: Is called before each training step. You could use this to mix in special training losses, for example. 
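All three `--custom-megatron-*-path` hooks above take the import path of a user-supplied callable. A purely hypothetical sketch of such a module is shown below; slime's Megatron backend defines the exact arguments each hook receives, so the `*args, **kwargs` form and the module/function names here are placeholders only.

```python
# my_megatron_hooks.py — hypothetical module; its import path would be passed via
# --custom-megatron-before-train-step-hook-path. The real argument list is defined
# by slime's Megatron backend, so only *args/**kwargs is used in this sketch.
import logging

logger = logging.getLogger(__name__)


def before_train_step(*args, **kwargs):
    # Runs before each training step; a real hook could adjust state here or mix
    # in an additional training loss, as the bullet above suggests.
    logger.info("before_train_step hook invoked")
```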
+
+## How to Use FSDP
+
+slime also supports FSDP2 as the training backend; see the docs [here](https://lmsys.org/blog/2025-12-03-miles-fsdp/).
+
+> FSDP reads the full model architecture from `config.json` via `AutoModelForCausalLM.from_pretrained()`, so no manual architecture flags are needed and the weight format conversion step is skipped entirely. Megatron, by contrast, requires the architecture parameters to be configured manually.
+
+To use FSDP as the training backend, pass `--train-backend fsdp`.
+
+### Parameters
+
+The parameters supported by the FSDP backend are listed below alongside their Megatron counterparts; more options are on the way.
+
+| Configuration Category | Megatron Parameter | FSDP Parameter | Description |
+| --- | --- | --- | --- |
+| **Model Loading** | `--load` (Megatron checkpoint) + architecture args (`--num-layers`, `--hidden-size` etc.) | `--hf-checkpoint` (Required) | **FSDP**: Directly uses HuggingFace format, no weight conversion needed, architecture inferred via `AutoConfig` |
+| **Tensor Parallel** | `--tensor-model-parallel-size` | Coming Soon | |
+| **Pipeline Parallel** | `--pipeline-model-parallel-size` | Coming Soon | |
+| **Expert Parallel** | `--expert-model-parallel-size` | Coming Soon | |
+| **Context Parallel** | `--context-parallel-size` | `--context-parallel-size` | Both support CP |
+| **Initial Learning Rate** | `--lr` | `--lr` | Same parameter |
+| **Learning Rate Decay** | `--lr-decay-style` (linear/cosine etc.) | `--lr-decay-style` | Same parameter |
+| **Warmup** | `--lr-warmup-iters` (steps) | `--lr-warmup-iters` | Same parameter |
+| **Min Learning Rate** | `--min-lr` | `--min-lr` | Same parameter |
+| **Optimizer Type** | `--optimizer` (adam/sgd etc.) | `--optimizer` (default adam) | Basically same |
+| **Distributed Optimizer** | `--use-distributed-optimizer` | Built-in to FSDP | FSDP uses distributed optimizer by default |
+| **Gradient Checkpoint** | `--recompute-granularity`, `--recompute-method` | `--gradient-checkpointing` | **FSDP**: Simplified to boolean switch |
+| **CPU Offload** | Implemented via distributed optimizer | `--fsdp-cpu-offload` | **FSDP**: Offload parameters/gradients/optimizer states to CPU |
+| **CPU Backend** | Implemented via distributed optimizer | `--fsdp-cpu-backend` | **FSDP**: Specifies the CPU backend; a hybrid backend is used when CPU offload is enabled |
+| **Attention Backend** | Decided by Megatron Core | `--attn-implementation` (flash_attention_2/sdpa/eager) | **FSDP**: Directly passed to HuggingFace |
+| **Mixed Precision** | `--fp16` or `--bf16` | `--fp16` (bf16 inferred automatically) | Basically same |
+| **Training Backend** | Default or `--train-backend megatron` | `--train-backend fsdp` (Required) | Used to switch backend |
+| **Config** | | `--config` | **FSDP**: Sets additional parameters for the FSDP backend |
+
+### Quick Start
+
+```bash
+# If you need to use WANDB, set the environment variable WANDB_API_KEY in advance
+# Download model weights (Qwen3-4B)
+hf download Qwen/Qwen3-4B --local-dir /root/Qwen3-4B
+
+# Download training dataset (dapo-math-17k)
+hf download --repo-type dataset zhuzilin/dapo-math-17k \
+    --local-dir /root/dapo-math-17k
+
+# Download evaluation dataset (aime-2024)
+hf download --repo-type dataset zhuzilin/aime-2024 \
+    --local-dir /root/aime-2024
+
+# Clone code and install dependencies
+git clone https://github.com/THUDM/slime.git
+cd slime
+pip install -e .
+ + +# FSDP does not require weight conversion, natively supports huggingface format +# Enable reference model, train Qwen3-4B in colocate mode +source /root/slime/scripts/run-qwen3-4B-fsdp.sh +``` + diff --git a/docs/en/platform_support/amd_tutorial.md b/docs/en/platform_support/amd_tutorial.md index 4a41fae2a..a8dbd8c6d 100644 --- a/docs/en/platform_support/amd_tutorial.md +++ b/docs/en/platform_support/amd_tutorial.md @@ -85,7 +85,7 @@ Note: We implemented a dedicated AMD conversion script that forces a CPU-only co ### Example: Qwen3-4B We provide examples to use [Qwen3-4B](https://huggingface.co/Qwen/Qwen3-4B), please refer to: -- [Example: Qwen3-4B Model](../../../scripts/run-qwen3-4B-amd.sh): Just run `scripts/run-qwen3-4B-amd.sh` +- [Example: Qwen3-4B Model](https://github.com/THUDM/slime/blob/main/scripts/run-qwen3-4B-amd.sh): Just run `scripts/run-qwen3-4B-amd.sh` ⚠️ TODO: ROCM seems to not support `apex` yet. Thus, we need to disable gradient accumulation fusionby adding the `--no-gradient-accumulation-fusion` flag in the training script currently. We will continue investigating how to enable this. diff --git a/docs/zh/examples/deepseek-r1.md b/docs/zh/examples/deepseek-r1.md index 83ea5ca64..87314efc5 100644 --- a/docs/zh/examples/deepseek-r1.md +++ b/docs/zh/examples/deepseek-r1.md @@ -10,12 +10,12 @@ ## 环境准备 -搭建环境与下载数据的方法可以参考 [示例:Qwen3-4B](./qwen3-4B.md)。 +搭建环境与下载数据的方法可以参考 [示例:Qwen3-4B](qwen3-4B.md)。 准备 DeepSeek R1 的 ckpt 首先需要在多机均可访问到的地址(下记为 `$BASE_DIR`)上下载 DeepSeek-R1: ```bash -huggingface-cli download deepseek-ai/DeepSeek-R1 --local-dir $BASE_DIR/DeepSeek-R1 +hf download deepseek-ai/DeepSeek-R1 --local-dir $BASE_DIR/DeepSeek-R1 ``` DeepSeek-R1 的 huggingface ckpt 为 block-quant 的 fp8 格式,为了转换一个 Megatron 可以加载的 torch dist 格式,需要先转化一个 bf16 的 huggingface ckpt: @@ -84,7 +84,7 @@ SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &>/dev/null && pwd)" source "${SCRIPT_DIR}/models/deepseek-v3.sh" ``` -从 [scripts/models/deepseek-v3.sh](../../../scripts/models/deepseek-v3.sh) 读取模型的 config。这些 config 都是 megatron 的参数。在使用 megatron 进行训练的时候,megatron 无法从 ckpt 中读取模型 config,需要我们自行配置。我们在 [scripts/models](../../../scripts/models/) 中提供了一些样例。 +从 [scripts/models/deepseek-v3.sh](https://github.com/THUDM/slime/blob/main/scripts/models/deepseek-v3.sh) 读取模型的 config。这些 config 都是 megatron 的参数。在使用 megatron 进行训练的时候,megatron 无法从 ckpt 中读取模型 config,需要我们自行配置。我们在 [scripts/models](https://github.com/THUDM/slime/tree/main/scripts/models/) 中提供了一些样例。 #### CKPT_ARGS diff --git a/docs/zh/examples/glm4-9B.md b/docs/zh/examples/glm4-9B.md index a005c3d84..99d834bd2 100644 --- a/docs/zh/examples/glm4-9B.md +++ b/docs/zh/examples/glm4-9B.md @@ -15,14 +15,14 @@ pip install -e . 
```bash # hf checkpoint -huggingface-cli download zai-org/GLM-Z1-9B-0414 --local-dir /root/GLM-Z1-9B-0414 +hf download zai-org/GLM-Z1-9B-0414 --local-dir /root/GLM-Z1-9B-0414 # train data -huggingface-cli download --repo-type dataset zhuzilin/dapo-math-17k \ +hf download --repo-type dataset zhuzilin/dapo-math-17k \ --local-dir /root/dapo-math-17k # eval data -huggingface-cli download --repo-type dataset zhuzilin/aime-2024 \ +hf download --repo-type dataset zhuzilin/aime-2024 \ --local-dir /root/aime-2024 ``` @@ -49,7 +49,7 @@ bash script/run-glm4-9B.sh ### 参数简介 -这里我们简单介绍一下脚本 [run-glm4-9B.sh](../../../scripts/run-glm4-9B.sh) 中的各个组成部分: +这里我们简单介绍一下脚本 [run-glm4-9B.sh](https://github.com/THUDM/slime/blob/main/scripts/run-glm4-9B.sh) 中的各个组成部分: #### MODEL_ARGS @@ -58,7 +58,7 @@ SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &>/dev/null && pwd)" source "${SCRIPT_DIR}/models/glm4-9B.sh" ``` -从 [scripts/models/glm4-9B.sh](../../../scripts/models/glm4-9B.sh) 读取模型的 config。这些 config 都是 megatron 的参数。在使用 megatron 进行训练的时候,megatron 无法从 ckpt 中读取模型 config,需要我们自行配置。我们在 [scripts/models](../../../scripts/models/) 中提供了一些样例。 +从 [scripts/models/glm4-9B.sh](https://github.com/THUDM/slime/blob/main/scripts/models/glm4-9B.sh) 读取模型的 config。这些 config 都是 megatron 的参数。在使用 megatron 进行训练的时候,megatron 无法从 ckpt 中读取模型 config,需要我们自行配置。我们在 [scripts/models](https://github.com/THUDM/slime/tree/main/scripts/models/) 中提供了一些样例。 ⚠️ 注意检查模型文件中的 `--rotary-base` 等配置是否对应你当前训练模型的配置,因为同一个模型结构的不同模型可能有不同的取值。在这种情况下,你可以在导入模型参数后在脚本里进行覆盖,例如: diff --git a/docs/zh/examples/glm4.5-355B-A32B.md b/docs/zh/examples/glm4.5-355B-A32B.md index edfd1529a..8a9db6d69 100644 --- a/docs/zh/examples/glm4.5-355B-A32B.md +++ b/docs/zh/examples/glm4.5-355B-A32B.md @@ -4,12 +4,12 @@ ## 环境准备 -搭建环境与下载数据的方法可以参考 [示例:Qwen3-4B](./qwen3-4B.md)。 +搭建环境与下载数据的方法可以参考 [示例:Qwen3-4B](qwen3-4B.md)。 首先需要在多机均可访问到的地址(下记为 `$BASE_DIR`)上下载 GLM-4.5: ```bash -huggingface-cli download zai-org/GLM-4.5 --local-dir $BASE_DIR/GLM-4.5-355B-A32B +hf download zai-org/GLM-4.5 --local-dir $BASE_DIR/GLM-4.5-355B-A32B ``` 通过如下方式通过 2 机 16 卡将 huggingface checkpoint 转换为 torch dist 格式: @@ -65,7 +65,7 @@ SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &>/dev/null && pwd)" source "${SCRIPT_DIR}/models/glm4.5-355B-A32B.sh" ``` -从 [scripts/models/glm4.5-355B-A32B.sh](../../../scripts/models/glm4.5-355B-A32B.sh) 读取模型的 config。这些 config 都是 megatron 的参数。在使用 megatron 进行训练的时候,megatron 无法从 ckpt 中读取模型 config,需要我们自行配置。我们在 [scripts/models](../../../scripts/models/) 中提供了一些样例。 +从 [scripts/models/glm4.5-355B-A32B.sh](https://github.com/THUDM/slime/blob/main/scripts/models/glm4.5-355B-A32B.sh) 读取模型的 config。这些 config 都是 megatron 的参数。在使用 megatron 进行训练的时候,megatron 无法从 ckpt 中读取模型 config,需要我们自行配置。我们在 [scripts/models](https://github.com/THUDM/slime/tree/main/scripts/models/) 中提供了一些样例。 #### PERF_ARGS diff --git a/docs/zh/examples/qwen3-30B-A3B.md b/docs/zh/examples/qwen3-30B-A3B.md index 3dcdd1ecb..1abdaa032 100644 --- a/docs/zh/examples/qwen3-30B-A3B.md +++ b/docs/zh/examples/qwen3-30B-A3B.md @@ -2,7 +2,7 @@ ## 环境准备 -搭建环境、下载模型、数据与 ckpt 转换均与 Qwen3-4B 模型相同,可以参考 [示例:Qwen3-4B](./qwen3-4B.md),将文中 Qwen3-4B 的部分转换为 Qwen3-30B-A3B 即可。 +搭建环境、下载模型、数据与 ckpt 转换均与 Qwen3-4B 模型相同,可以参考 [示例:Qwen3-4B](qwen3-4B.md),将文中 Qwen3-4B 的部分转换为 Qwen3-30B-A3B 即可。 可以用如下方法把 huggingface checkpoint 转化为 torch_dist 格式: @@ -28,7 +28,7 @@ bash scripts/run-qwen3-30B-A3B.sh ### 参数简介 -这里我们简单介绍一下脚本 [run-qwen3-30B-A3B.sh](../../../scripts/run-qwen3-30B-A3B.sh) 中与 MoE 相关的部分。 +这里我们简单介绍一下脚本 
[run-qwen3-30B-A3B.sh](https://github.com/THUDM/slime/blob/main/scripts/run-qwen3-30B-A3B.sh) 中与 MoE 相关的部分。 1. 为了支持在 8xH800 环境中运行 Qwen3-30B-A3B,我们需要开启 megatron 的 CPU Adam 以节省显存,对应配置为: @@ -78,7 +78,7 @@ bash scripts/run-qwen3-30B-A3B.sh slime 还支持 bf16 训练,fp8 推理。对于 Qwen3-30B-A3B 模型,只需要下载如下模型: ```bash -huggingface-cli download Qwen/Qwen3-30B-A3B-FP8 --local-dir /root/Qwen3-30B-A3B-FP8 +hf download Qwen/Qwen3-30B-A3B-FP8 --local-dir /root/Qwen3-30B-A3B-FP8 ``` 并将 `--hf-checkpoint` 替换为: diff --git a/docs/zh/examples/qwen3-4B.md b/docs/zh/examples/qwen3-4B.md index a67ba925c..03ca850a1 100644 --- a/docs/zh/examples/qwen3-4B.md +++ b/docs/zh/examples/qwen3-4B.md @@ -15,14 +15,14 @@ pip install -e . ```bash # hf checkpoint -huggingface-cli download Qwen/Qwen3-4B --local-dir /root/Qwen3-4B +hf download Qwen/Qwen3-4B --local-dir /root/Qwen3-4B # train data -huggingface-cli download --repo-type dataset zhuzilin/dapo-math-17k \ +hf download --repo-type dataset zhuzilin/dapo-math-17k \ --local-dir /root/dapo-math-17k # eval data -huggingface-cli download --repo-type dataset zhuzilin/aime-2024 \ +hf download --repo-type dataset zhuzilin/aime-2024 \ --local-dir /root/aime-2024 ``` @@ -49,7 +49,7 @@ bash scripts/run-qwen3-4B.sh ### 参数简介 -这里我们简单介绍一下脚本 [run-qwen3-4B.sh](../../../scripts/run-qwen3-4B.sh) 中的各个组成部分: +这里我们简单介绍一下脚本 [run-qwen3-4B.sh](https://github.com/THUDM/slime/blob/main/scripts/run-qwen3-4B.sh) 中的各个组成部分: #### MODEL_ARGS @@ -58,7 +58,7 @@ SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &>/dev/null && pwd)" source "${SCRIPT_DIR}/models/qwen3-4B.sh" ``` -从 [scripts/models/qwen3-4B.sh](../../../scripts/models/qwen3-4B.sh) 读取模型的 config。这些 config 都是 megatron 的参数。在使用 megatron 进行训练的时候,megatron 无法从 ckpt 中读取模型 config,需要我们自行配置。我们在 [scripts/models](../../../scripts/models/) 中提供了一些样例。 +从 [scripts/models/qwen3-4B.sh](https://github.com/THUDM/slime/blob/main/scripts/models/qwen3-4B.sh) 读取模型的 config。这些 config 都是 megatron 的参数。在使用 megatron 进行训练的时候,megatron 无法从 ckpt 中读取模型 config,需要我们自行配置。我们在 [scripts/models](https://github.com/THUDM/slime/tree/main/scripts/models/) 中提供了一些样例。 ⚠️ 注意检查模型文件中的 `--rotary-base` 等配置是否对应你当前训练模型的配置,因为同一个模型结构的不同模型可能有不同的取值。在这种情况下,你可以在导入模型参数后在脚本里进行覆盖,例如: diff --git a/docs/zh/examples/qwen3-4b-base-openhermes.md b/docs/zh/examples/qwen3-4b-base-openhermes.md index 1b759db56..6abb6ac64 100644 --- a/docs/zh/examples/qwen3-4b-base-openhermes.md +++ b/docs/zh/examples/qwen3-4b-base-openhermes.md @@ -2,7 +2,7 @@ ## 环境准备 -首先需要我们仿照 [示例:Qwen3-4B 模型](./models/qwen3-4B.md) 创建镜像环境与转换 `Qwen3-4B-Base` 模型。 +首先需要我们仿照 [示例:Qwen3-4B 模型](qwen3-4B.md) 创建镜像环境与转换 `Qwen3-4B-Base` 模型。 之后,我们处理 sft 数据。这里我们以经典的 [OpenHermes-2.5](https://huggingface.co/datasets/teknium/OpenHermes-2.5) 为例,首先把数据处理成适合 slime 加载的格式,可以用如下的脚本进行处理,增加一个符合 openai message 格式的列,并保存在 `/root/openhermes2_5.parquet`。 @@ -49,7 +49,7 @@ bash script/run-qwen3-4B-base-sft.sh ### 参数简介 -可以将 [run-qwen3-4B-base-sft.sh](../../scripts/run-qwen3-4B-base-sft.sh) 与 [run-qwen3-4B.sh](../../scripts/run-qwen3-4B.sh) 进行对比。会发现除了我们将模型由 instruct 模型换为了 base 模型之外,主要进行了如下的几个调整: +可以将 [run-qwen3-4B-base-sft.sh](https://github.com/THUDM/slime/blob/main/scripts/run-qwen3-4B-base-sft.sh) 与 [run-qwen3-4B.sh](https://github.com/THUDM/slime/blob/main/scripts/run-qwen3-4B.sh) 进行对比。会发现除了我们将模型由 instruct 模型换为了 base 模型之外,主要进行了如下的几个调整: 1. 
移除了 `SGLANG_ARGS` 和 `GRPO_ARGS`。这是因为 sft 的过程中不需要启动 sglang 或者做 grpo 相关的配置; diff --git a/docs/zh/examples/qwen3-next-80B-A3B.md b/docs/zh/examples/qwen3-next-80B-A3B.md new file mode 100644 index 000000000..a831c4de0 --- /dev/null +++ b/docs/zh/examples/qwen3-next-80B-A3B.md @@ -0,0 +1,108 @@ +# 8xH100 训练 Qwen3-30B-A3B + +## 环境准备 + +搭建环境、下载模型、数据与 ckpt 转换均与 Qwen3-4B 模型相同,可以参考 [示例:Qwen3-4B](./qwen3-4B.md),将文中 Qwen3-4B 的部分转换为 +Qwen3-next-80B-A3B-Instruct 即可。 + +可以用如下完整方法把 huggingface checkpoint 转化为 torch_dist 格式: + +```bash +export BASE_FOLDER=./models/ +# 下载模型权重 (Qwen3-Next-80B-A3B-Thinking) +hf download Qwen/Qwen3-Next-80B-A3B-Thinking --local-dir ${BASE_FOLDER}/Qwen3-Next-80B-A3B-Thinking +``` + +```shell +cd slime/ +pip install -e . + +# (for acceleration) +cd .. # and find a proper folder +git clone https://github.com/fla-org/flash-linear-attention +cd flash-linear-attention +git checkout 9714c595 +pip install -e . + +wget https://github.com/Dao-AILab/causal-conv1d/releases/download/v1.5.4/causal_conv1d-1.5.4+cu12torch2.8cxx11abiTRUE-cp312-cp312-linux_x86_64.whl +pip install ./causal_conv1d-1.5.4+cu12torch2.8cxx11abiTRUE-cp312-cp312-linux_x86_64.whl +``` + +## [Optional] Fix a bug in triton compilation on Blackwell (sm100) + +see discussion here https://github.com/triton-lang/triton/issues/8695 +and https://github.com/fla-org/flash-linear-attention/issues/638 + +We need to apply a patch to fix the bug. +Go to the flash-linear-attention folder you just installed, and apply the following patch: + +```diff +diff --git a/fla/ops/gated_delta_rule/wy_fast.py b/fla/ops/gated_delta_rule/wy_fast.py +index c5119dcf..838f5e4e 100644 +--- a/fla/ops/gated_delta_rule/wy_fast.py ++++ b/fla/ops/gated_delta_rule/wy_fast.py +@@ -198,7 +198,14 @@ def prepare_wy_repr_bwd_kernel( + b_A += tl.dot(b_kb, tl.trans(b_k)) + b_dkb = tl.dot(b_dA, b_k) + b_db += tl.sum(b_dkb * b_k, 1) +- b_dk += tl.dot(tl.trans(b_dA), b_kb) ++ b_dk += tl.inline_asm_elementwise( ++ asm="mov.f32 $0, $1;", ++ constraints="=r,r", ++ args=[tl.dot(tl.trans(b_dA), b_kb)], ++ dtype=tl.float32, ++ is_pure=True, ++ pack=1, ++ ) + b_dk += b_dkb * b_b[:, None] + tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_db, b_db.to(p_db.dtype.element_ty), boundary_check=(0,)) + +``` + +save it as `patch.diff` (Please remember to copy the last empty line to the file!) and do `git apply patch.diff` + +## 执行训练 (Megatron) + +**当前暂不支持Blackwell** + +转换模型权重: + +```bash +source scripts/models/qwen3-next-80B-A3B.sh +PYTHONPATH=/root/Megatron-LM/ torchrun --nproc-per-node 8 \ + tools/convert_hf_to_torch_dist.py \ + ${MODEL_ARGS[@]} \ + --hf-checkpoint /root/Qwen3-Next-80B-A3B-Thinking/ \ + --save /root/Qwen3-Next-80B-A3B-Thinking_torch_dist/ +``` + +单机8卡 + +```bash +cd /root/slime +export BASE_FOLDER=/root +export MASTER_ADDR=127.0.0.1 +bash scripts/run-qwen3-next-80B-A3B-8gpus.sh +``` +如果显存不够,考虑disable `--accumulate-allreduce-grads-in-fp32`,enable `--grad-reduce-in-bf16` + + +多机(4x8) + +```bash +cd /root/slime +export BASE_FOLDER=/root +export MASTER_ADDR=your_master_addr +bash scripts/run-qwen3-next-80B-A3B.sh +``` + +## 执行训练 (FSDP) + +```bash +export BASE_FOLDER=./models/ +export MASTER_ADDR=127.0.0.1 + +bash scripts/run-qwen3-next-80B-A3B-fsdp.sh +``` + diff --git a/docs/zh/get_started/qa.md b/docs/zh/get_started/qa.md index b3b5627c8..05c73890b 100644 --- a/docs/zh/get_started/qa.md +++ b/docs/zh/get_started/qa.md @@ -49,7 +49,7 @@ 1. 
**grad norm 好高,训练训崩了怎么办?** - 首先请确保数据和模型是匹配的,例如说,如果数据是实现已经做好 chat template 的了,这个 chat template 是否和原模型一致。如果数据正确的话,可以参考 [debug 指南](./debug.md) 进行更深入的分析。 + 首先请确保数据和模型是匹配的,例如说,如果数据是实现已经做好 chat template 的了,这个 chat template 是否和原模型一致。如果数据正确的话,可以参考 [debug 指南](../developer_guide/debug.md) 进行更深入的分析。 1. **我的 sglang 生成时间特别特别久,gpu 功率都打满了,跑了好久好没有输出是为什么?** diff --git a/docs/zh/get_started/quick_start.md b/docs/zh/get_started/quick_start.md index 46002a6ed..2c3b21e1f 100644 --- a/docs/zh/get_started/quick_start.md +++ b/docs/zh/get_started/quick_start.md @@ -108,6 +108,15 @@ PYTHONPATH=/root/Megatron-LM python tools/convert_torch_dist_to_hf.py \ 由于 Megatron 会对 embedding 做 padding,可能会出现转换出来的权重的 embedding 形状不匹配的问题。这时需要在转换时设置 `--vocab-size`。 +对于使用 FSDP 后端训练并保存的检查点(目录中没有 `common.pt` 的情况),请使用专门的转换脚本。将 `--input-dir` 指向检查点目录(例如 `iter_xxx` 或 `iter_xxx/model`),并提供原始 Hugging Face 模型路径: + +```bash +python tools/convert_fsdp_to_hf.py \ + --input-dir /path/to/fsdp_ckpt/iter_xxx \ + --output-dir /root/fsdp-converted \ + --origin-hf-dir /root/GLM-Z1-9B-0414 +``` + ## 训练脚本与参数概览 完成上述准备工作后,即可运行训练脚本。 @@ -577,5 +586,5 @@ ray job submit --address="http://127.0.0.1:8265" \ slime 针对大规模混合专家(MoE)模型的分布式训练进行了深度优化。我们提供了一些端到端的训练案例以供参考: -- [示例:64xH100 训练 GLM-4.5](models/glm4.5-355B-A32B.md) -- [示例:128xH100 训练 DeepSeek-R1](models/deepseek-r1.md) +- [示例:64xH100 训练 GLM-4.5](../examples/glm4.5-355B-A32B.md) +- [示例:128xH100 训练 DeepSeek-R1](../examples/deepseek-r1.md) diff --git a/docs/zh/get_started/usage.md b/docs/zh/get_started/usage.md index 332be9d41..a27d1009a 100644 --- a/docs/zh/get_started/usage.md +++ b/docs/zh/get_started/usage.md @@ -5,7 +5,7 @@ 在使用 slime 时,传参主要是为了如下几件事: 1. 把集群中一部分 GPU 分配做训练,一部分分配做推理; -2. 训练的部分加载 megatron; +2. 训练的部分加载 megatron或者FSDP; 3. 推理部分加载 sglang; 4. 配置 RL 训练需要的超参。 @@ -70,7 +70,7 @@ MODEL_ARGS=( ) ``` -我们在 [scripts/models](../../scripts/models) 提供了常用模型的配置,可以直接复用。如果你也在使用 megatron 进行 pretrain/sft 的话,可以直接复用 pretrain/sft 中的模型配置。 +我们在 [scripts/models](../../../scripts/models) 提供了常用模型的配置,可以直接复用。如果你也在使用 megatron 进行 pretrain/sft 的话,可以直接复用 pretrain/sft 中的模型配置。 注意: @@ -103,7 +103,7 @@ megatron 支持多种其自定义的 ckpt 格式,这里介绍 2 种比较主 torch 格式是 megatron 的老存储格式,里面的结构大约是一些 `mp_rank_xxx` 的文件夹,每个文件夹对应了在对应的并行划分下,每个 rank 存储的 ckpt。也是因为如此,在加载 torch 格式的 ckpt 的时候,需要保证 ckpt 的并行策略和训练任务的并行策略是相同的。 -我们推荐使用 torch_dist 格式 ckpt,因为 torch_dist 格式可以支持自动并行切分,也就是不同并行的训练任务都可以共用同一个 ckpt,会方便很多。torch_dist 这也是开源 megatron 目前的默认格式。torch_dist 格式的 ckpt 中一般是一堆 `.distcp` 文件。在使用 torch_dist 时,可以使用 [README](../../README_zh.md) 中介绍的 ckpt 转化方法从 huggingface 转化为 torch_dist,反之亦然。 +我们推荐使用 torch_dist 格式 ckpt,因为 torch_dist 格式可以支持自动并行切分,也就是不同并行的训练任务都可以共用同一个 ckpt,会方便很多。torch_dist 这也是开源 megatron 目前的默认格式。torch_dist 格式的 ckpt 中一般是一堆 `.distcp` 文件。在使用 torch_dist 时,可以使用 [README](../../../README_zh.md) 中介绍的 ckpt 转化方法从 huggingface 转化为 torch_dist,反之亦然。 在存储结构上,megatron 的 ckpt 一般是这样的结构,这里假设存储的路径为 `/ckpt/`: @@ -187,7 +187,7 @@ sglang 的加载非常简单,只需要: slime 支持不同程度的自定义数据生成(rollout)。 -- 默认会使用 [slime/rollout/sglang_rollout.py](../../slime/rollout/sglang_rollout.py) 中的 `generate_rollout` 函数进行数据生成。这个文件中实现了基于 sglang 的异步(asyncio)数据生成流程,并支持了例如 dynamic sampling,partial rollout 等功能; +- 默认会使用 [slime/rollout/sglang_rollout.py](https://github.com/THUDM/slime/blob/main/slime/rollout/sglang_rollout.py) 中的 `generate_rollout` 函数进行数据生成。这个文件中实现了基于 sglang 的异步(asyncio)数据生成流程,并支持了例如 dynamic sampling,partial rollout 等功能; - 可以通过 `--rollout-function-path` 参数,完全替换 sglang_rollout.py 中的 `generate_rollout`,只需要保证 `--rollout-function-path` 传入的函数签名满足: @@ -213,7 +213,7 @@ slime 支持不同程度的自定义数据生成(rollout)。 - `rollout_id` 
对应的是当前是第几次数据生成,用作保证续训时的数据顺序; - `data_buffer` 是 slime 中全局唯一的数据 buffer,可以用来获取初始 prompt,数据 id,将生成至一半的 sample 存储下来下次留作下次使用等; - `evaluation` 是否是当做 evaluation 使用。可以通过 `--eval-function-path` 单独配置 eval 的函数; - - 返回的 `Sample` 类型见 [slime/utils/types.py](../../slime/utils/types.py),在实现时,需要保证 + - 返回的 `Sample` 类型见 [slime/utils/types.py](https://github.com/THUDM/slime/blob/main/slime/utils/types.py),在实现时,需要保证 - `tokens`:prompt + response 的 token; - `response_length`:response 的总长。对于多轮任务,则是除去第一轮 prompt,剩余的 token 长度; - `reward`:这条数据的 reward; @@ -253,7 +253,7 @@ slime 支持不同程度的自定义数据生成(rollout)。 return sample ``` - 更完备的版本请查看 [slime/rollout/sglang_rollout.py](../../slime/rollout/sglang_rollout.py)。 + 更完备的版本请查看 [slime/rollout/sglang_rollout.py](https://github.com/THUDM/slime/blob/main/slime/rollout/sglang_rollout.py)。 - 有的时候,我们还需要支持自定义的 reward model,可以通过配置 `--custom-rm-path` 来进行配置。 @@ -274,7 +274,7 @@ slime 通过引入 sglang 的 `ServerArgs.add_cli_args`,从而引入了几乎 - `--tp-size` 在 slime 中会使用 `--rollout-num-gpus-per-engine` - `--model-path` 在 slime 中会使用 `--hf-checkpoint` -sglang 参数引入 slime 的方式可以参考 [slime/backends/sglang_utils/arguments.py](../../slime/backends/sglang_utils/arguments.py)。 +sglang 参数引入 slime 的方式可以参考 [slime/backends/sglang_utils/arguments.py](https://github.com/THUDM/slime/blob/main/slime/backends/sglang_utils/arguments.py)。 ### router 使用方法 @@ -290,7 +290,7 @@ slime 通过复用 `megatron.training` 目录下的常规函数,如 `parse_arg ### 参数配置 -slime 通过直接引入 `from megatron.training.arguments import parse_args` 引入了当前环境中 megatron 的所有参数。如果当前使用的 megatron 有在 `parse_args` 之外的参数,可以通过像 [train.py](../../train.py) 中传入参数来进行配置,例如: +slime 通过直接引入 `from megatron.training.arguments import parse_args` 引入了当前环境中 megatron 的所有参数。如果当前使用的 megatron 有在 `parse_args` 之外的参数,可以通过像 [train.py](https://github.com/THUDM/slime/blob/main/train.py) 中传入参数来进行配置,例如: ```python if __name__ == "__main__": @@ -308,4 +308,62 @@ if __name__ == "__main__": - `--custom-megatron-init-path`:会增加一些 init 的调用; - `--custom-megatron-before-log-prob-hook-path`:会在计算 log prob 之前调用; -- `--custom-megatron-before-train-step-hook-path`:会在每个训练步之前调用。可以考虑用这种方式混入特殊的训练 loss 之类的。 \ No newline at end of file +- `--custom-megatron-before-train-step-hook-path`:会在每个训练步之前调用。可以考虑用这种方式混入特殊的训练 loss 之类的。 + +## FSDP 使用方法 + +slime 同样也支持FSDP2作为训练后端,可以参考[文档](https://github.com/zhaochenyang20/Awesome-ML-SYS-Tutorial/blob/main/rlhf/slime/fsdp/readme.md)。 + +> FSDP 通过 `AutoModelForCausalLM.from_pretrained()` 自动读取所有架构信息,无需手动指定。Megatron 需要手动配置参数读取 model 架构信息,FSDP可以全部从 `config.json` 自动读取,可以直接避免权重格式转换步骤。 + +可以通过在命令行传递 `--train-backend fsdp` 来启动 FSDP 作为训练后端。 + +### 参数 + +FSDP和Megatron后端支持的参数的对比如下表所示,接下来FSDP会有更多的支持。 + +| 配置类别 | Megatron 参数 | FSDP 参数 | 说明 | +| --- | --- | --- | --- | +| **模型加载** | `--load` (Megatron checkpoint) + 架构参数 (`--num-layers`, `--hidden-size` 等) | `--hf-checkpoint` (必需) | **FSDP**: 直接使用 HuggingFace 格式,无需转换权重,通过 `AutoConfig` 自动推断架构 | +| **张量并行** | `--tensor-model-parallel-size` | Coming Soon | | +| **流水线并行** | `--pipeline-model-parallel-size` | Coming Soon | | +| **专家并行** | `--expert-model-parallel-size` | Coming Soon | | +| **上下文并行** | `--context-parallel-size` | `--context-parallel-size` | 两者都支持 CP | +| **初始学习率** | `--lr` | `--lr` | 参数相同 | +| **学习率衰减** | `--lr-decay-style` (linear/cosine 等) | `--lr-decay-style` | 参数相同 | +| **Warmup** | `--lr-warmup-iters` (步数) | `--lr-warmup-iters` | 参数相同 | +| **最小学习率** | `--min-lr` | `--min-lr` | 参数相同 | +| **优化器类型** | `--optimizer` (adam/sgd 等) | `--optimizer` (默认 adam) | 基本相同 | +| **分布式优化器** | `--use-distributed-optimizer` | 内置于 FSDP | FSDP 默认使用分布式优化器 | +| 
**梯度检查点** | `--recompute-granularity`, `--recompute-method` | `--gradient-checkpointing` | **FSDP**: 简化为布尔开关 | +| **CPU Offload** | 通过分布式优化器实现 | `--fsdp-cpu-offload` | **FSDP**: 将参数/梯度/优化器状态卸载到 CPU | +| **CPU 后端** | 通过分布式优化器实现 | `--fsdp-cpu-backend` | **FSDP**: 指定CPU的后端并且当CPU offload时使用混合后端 | +| **Attention 后端** | 由 Megatron Core 决定 | `--attn-implementation` (flash_attention_2/sdpa/eager) | **FSDP**: 直接透传给 HuggingFace | +| **混合精度** | `--fp16` 或 `--bf16` | `--fp16` (bf16 自动推断) | 基本相同 | +| **训练后端** | 默认或 `--train-backend megatron` | `--train-backend fsdp` (必需) | 用于切换后端 | +| **参数配置** | | `--config` | **FSDP**: 为FSDP设置额外的参数 | + +### FSDP 一键启动 + +```bash +# 如果需要使用 WANDB,需要提前设置好环境变量 WANDB_API_KEY +# 下载模型权重 (Qwen3-4B) +hf download Qwen/Qwen3-4B --local-dir /root/Qwen3-4B + +# 下载训练数据集 (dapo-math-17k) +hf download --repo-type dataset zhuzilin/dapo-math-17k \ + --local-dir /root/dapo-math-17k + +# 下载评估数据集 (aime-2024) +hf download --repo-type dataset zhuzilin/aime-2024 \ + --local-dir /root/aime-2024 + +# 克隆代码并安装依赖 +git clone https://github.com/THUDM/slime.git +cd slime +pip install -e . + + +# FSDP不用进行权重转换,native 支持 huggingface 格式 +# 开启 reference model,在 colocated 模式下训练 Qwen3-4B +source /root/slime/scripts/run-qwen3-4B-fsdp.sh \ No newline at end of file diff --git a/docs/zh/index.rst b/docs/zh/index.rst index 0e4096a97..bc4e769f1 100644 --- a/docs/zh/index.rst +++ b/docs/zh/index.rst @@ -41,7 +41,7 @@ slime 是 GLM-4.5 与 GLM-4.6 背后的 RL 训练框架。除此之外,slime _examples_synced/reproducibility/README.md advanced/speculative-decoding.md - advanced/fault-tolerance.md + advanced/fault-torlance.md advanced/arch-support-beyond-megatron.md .. toctree:: diff --git a/examples/geo3k_vlm/README.md b/examples/geo3k_vlm/README.md index 7cc16a313..d6766ebe4 100644 --- a/examples/geo3k_vlm/README.md +++ b/examples/geo3k_vlm/README.md @@ -1,9 +1,9 @@ -# FSDP + VLM Single-Turn RL +# VLM Single-Turn RL (FSDP & Megatron) -Training VLMs with FSDP on single-turn reasoning task using GRPO on the [GEO3K dataset](https://huggingface.co/datasets/hiyouga/geometry3k). We used processed version [here](https://huggingface.co/datasets/chenhegu/geo3k_imgurl). +Training VLMs with FSDP or Megatron on single-turn reasoning task using GRPO on the [GEO3K dataset](https://huggingface.co/datasets/hiyouga/geometry3k). We used processed version [here](https://huggingface.co/datasets/chenhegu/geo3k_imgurl).
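
Since the FSDP backend above works directly with Hugging Face-format checkpoints, and `tools/convert_fsdp_to_hf.py` produces one, a quick sanity check after training or conversion is simply to load the result with `transformers`. This is a minimal sketch, assuming the conversion wrote its output to `/root/fsdp-converted` (adjust to your own output directory):

```python
# Minimal sanity check for an HF-format checkpoint, e.g. the output of
# tools/convert_fsdp_to_hf.py. The path below is an assumption for illustration.
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

ckpt = "/root/fsdp-converted"

config = AutoConfig.from_pretrained(ckpt)        # architecture comes straight from config.json
tokenizer = AutoTokenizer.from_pretrained(ckpt)
model = AutoModelForCausalLM.from_pretrained(ckpt)

num_params = sum(p.numel() for p in model.parameters())
print(f"{config.model_type}: {num_params / 1e9:.2f}B parameters loaded")
```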
-
+