Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
111 changes: 74 additions & 37 deletions docker/patch/latest/sglang.patch
Original file line number Diff line number Diff line change
Expand Up @@ -301,10 +301,10 @@ index e7d5a67cc..639e47163 100644
out_hidden_states[begin_chunk_idx:end_chunk_idx],
diff --git a/python/sglang/srt/layers/moe/routed_experts_capturer.py b/python/sglang/srt/layers/moe/routed_experts_capturer.py
new file mode 100644
index 000000000..7369f9dc9
index 000000000..11adcaa77
--- /dev/null
+++ b/python/sglang/srt/layers/moe/routed_experts_capturer.py
@@ -0,0 +1,308 @@
@@ -0,0 +1,305 @@
+import logging
+from abc import ABC
+from contextlib import contextmanager
Expand Down Expand Up @@ -402,9 +402,6 @@ index 000000000..7369f9dc9
+ assert hasattr(self, "buffer")
+ return get_tensor_size_bytes(self.buffer)
+
+ def set_experts_buffer(self, layer_id: int, loc: torch.Tensor, top_k: torch.Tensor):
+ self.buffer[layer_id, loc, :] = top_k.to(device="cpu", non_blocking=True)
+
+ def _finalize_allocation_log(self):
+ """Common logging and memory usage computation for captured experts buffers."""
+ buffer_size_GB = self.get_buffer_size_bytes() / _GB
Expand Down Expand Up @@ -903,7 +900,7 @@ index e34736cc4..5e5997a1a 100644
# idx is the index of the token in the prompt after expansion.
# val is the length of padded tokens after expansion.
diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py
index c4c5a9ebb..1450c5fd8 100644
index c4c5a9ebb..3650ba881 100644
--- a/python/sglang/srt/managers/schedule_batch.py
+++ b/python/sglang/srt/managers/schedule_batch.py
@@ -450,6 +450,7 @@ class Req:
Expand Down Expand Up @@ -953,15 +950,23 @@ index c4c5a9ebb..1450c5fd8 100644
is_prefill_only=all(req.is_prefill_only for req in reqs),
chunked_req=chunked_req,
dllm_config=dllm_config,
@@ -1457,6 +1469,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
@@ -1282,6 +1294,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
)
else:
self.out_cache_loc = torch.cat(decoder_out_cache_loc)
+ self.out_cache_loc_cpu = self.out_cache_loc.to("cpu", non_blocking=True)

if not encoder_out_cache_loc:
self.encoder_out_cache_loc = torch.zeros(0, dtype=torch.int64).to(
@@ -1457,6 +1470,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
self.req_pool_indices = req_pool_indices_tensor
self.orig_seq_lens = orig_seq_lens_tensor
self.out_cache_loc = out_cache_loc
+ self.out_cache_loc_cpu = out_cache_loc.cpu()
self.input_embeds = (
torch.tensor(input_embeds).to(self.device, non_blocking=True)
if input_embeds
@@ -1508,10 +1521,14 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
@@ -1508,10 +1522,14 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):

input_ids = torch.cat([self.input_ids, running_batch.input_ids])
out_cache_loc = torch.cat([self.out_cache_loc, running_batch.out_cache_loc])
Expand All @@ -976,47 +981,47 @@ index c4c5a9ebb..1450c5fd8 100644

# For overlap scheduler, the output_ids has one step delay
delta = 0 if self.enable_overlap else -1
@@ -1677,6 +1694,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
@@ -1677,6 +1695,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
self.seq_lens_cpu = torch.empty(0, dtype=torch.int64)
self.orig_seq_lens = torch.empty(0, dtype=torch.int32, device=self.device)
self.out_cache_loc = torch.empty(0, dtype=torch.int64, device=self.device)
+ self.out_cache_loc_cpu = torch.empty(0, dtype=torch.int64, device="cpu")
self.req_pool_indices = torch.empty(0, dtype=torch.int32, device=self.device)
self.seq_lens_sum = 0
self.extend_num_tokens = 0
@@ -1736,6 +1754,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
@@ -1736,6 +1755,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):

# Allocate memory
self.out_cache_loc = alloc_for_decode(self, token_per_req=1)
+ self.out_cache_loc_cpu = self.out_cache_loc.to("cpu", non_blocking=True)

# Update req-level memory management fields
for req in self.reqs:
@@ -1807,6 +1826,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
@@ -1807,6 +1827,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
self.seq_lens_cpu = self.seq_lens_cpu[keep_indices]
self.orig_seq_lens = self.orig_seq_lens[keep_indices_device]
self.out_cache_loc = None
+ self.out_cache_loc_cpu = None
self.seq_lens_sum = self.seq_lens.sum().item()
self.output_ids = self.output_ids[keep_indices_device]
self.return_logprob = any(req.return_logprob for req in self.reqs)
@@ -1852,6 +1872,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
@@ -1852,6 +1873,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
self.seq_lens_cpu = torch.cat([self.seq_lens_cpu, other.seq_lens_cpu])
self.orig_seq_lens = torch.cat([self.orig_seq_lens, other.orig_seq_lens])
self.out_cache_loc = None
+ self.out_cache_loc_cpu = None
self.seq_lens_sum += other.seq_lens_sum
if self.output_ids is not None:
self.output_ids = torch.cat([self.output_ids, other.output_ids])
@@ -1903,6 +1924,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
@@ -1903,6 +1925,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
seq_lens=self.seq_lens,
orig_seq_lens=self.orig_seq_lens,
out_cache_loc=self.out_cache_loc,
+ out_cache_loc_cpu=self.out_cache_loc_cpu,
seq_lens_cpu=seq_lens_cpu,
seq_lens_sum=self.seq_lens_sum,
return_logprob=self.return_logprob,
@@ -1983,7 +2005,8 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
@@ -1983,7 +2006,8 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
def __str__(self):
return (
f"ScheduleBatch(forward_mode={self.forward_mode.name if self.forward_mode else 'None'}, "
Expand All @@ -1026,7 +1031,7 @@ index c4c5a9ebb..1450c5fd8 100644
)


@@ -2038,6 +2061,9 @@ class ModelWorkerBatch:
@@ -2038,6 +2062,9 @@ class ModelWorkerBatch:
# Sampling info
sampling_info: SamplingBatchInfo

Expand Down Expand Up @@ -1194,7 +1199,7 @@ index edbc52526..2cdc42755 100644

# This means that weight sync
diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py
index b90cf0616..9b0992655 100644
index b90cf0616..8a5cbdbed 100644
--- a/python/sglang/srt/managers/tokenizer_manager.py
+++ b/python/sglang/srt/managers/tokenizer_manager.py
@@ -20,6 +20,7 @@ import logging
Expand All @@ -1213,14 +1218,12 @@ index b90cf0616..9b0992655 100644
data_parallel_rank=obj.data_parallel_rank,
priority=obj.priority,
extra_key=obj.extra_key,
@@ -1621,6 +1623,16 @@ class TokenizerManager(TokenizerCommunicatorMixin):
@@ -1621,6 +1623,14 @@ class TokenizerManager(TokenizerCommunicatorMixin):
if getattr(recv_obj, "output_hidden_states", None):
meta_info["hidden_states"] = recv_obj.output_hidden_states[i]

+ if getattr(recv_obj, "output_routed_experts", None):
+ if recv_obj.output_routed_experts[i] is not None:
+ # print(f"{recv_obj.output_routed_experts[i].shape=}, {recv_obj.output_routed_experts[i].dtype=}")
+ # torch.save(recv_obj.output_routed_experts[i], f"/root/{recv_obj.output_routed_experts[i].shape[0]}.pt")
+ meta_info["routed_experts"] = pybase64.b64encode(
+ recv_obj.output_routed_experts[i].contiguous().numpy().tobytes(order="C")
+ ).decode("ascii")
Expand All @@ -1230,7 +1233,7 @@ index b90cf0616..9b0992655 100644
if isinstance(recv_obj, BatchStrOutput):
state.text += recv_obj.output_strs[i]
if self.server_args.stream_output and state.obj.stream:
@@ -1747,12 +1759,13 @@ class TokenizerManager(TokenizerCommunicatorMixin):
@@ -1747,12 +1757,13 @@ class TokenizerManager(TokenizerCommunicatorMixin):
return

if len(recv_obj.input_token_logprobs_val) > 0:
Expand All @@ -1251,7 +1254,7 @@ index b90cf0616..9b0992655 100644
recv_obj.output_token_logprobs_val[recv_obj_index]
)
diff --git a/python/sglang/srt/model_executor/forward_batch_info.py b/python/sglang/srt/model_executor/forward_batch_info.py
index 3a85e6a7e..2859dafa1 100644
index 3a85e6a7e..d2560e79b 100644
--- a/python/sglang/srt/model_executor/forward_batch_info.py
+++ b/python/sglang/srt/model_executor/forward_batch_info.py
@@ -51,6 +51,7 @@ from sglang.srt.layers.dp_attention import (
Expand Down Expand Up @@ -1315,8 +1318,16 @@ index 3a85e6a7e..2859dafa1 100644
if self.encoder_lens is not None:
self.encoder_lens = self._pad_tensor_to_size(self.encoder_lens, bs)
self.positions = self._pad_tensor_to_size(self.positions, num_tokens)
@@ -906,6 +921,7 @@ class ForwardBatch:
self.spec_info.hidden_states = self.hidden_states_backup
if hasattr(self, "output_cache_loc_backup"):
self.out_cache_loc = self.output_cache_loc_backup
+ self.out_cache_loc_cpu = self.out_cache_loc.to("cpu", non_blocking=True)

elif self.forward_mode.is_decode() or self.forward_mode.is_idle():
logits_output.next_token_logits = logits_output.next_token_logits[:bs]
diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py
index 4d58278b7..8f50dc430 100644
index 4d58278b7..5965c481e 100644
--- a/python/sglang/srt/model_executor/model_runner.py
+++ b/python/sglang/srt/model_executor/model_runner.py
@@ -94,6 +94,11 @@ from sglang.srt.layers.dp_attention import (
Expand All @@ -1331,18 +1342,19 @@ index 4d58278b7..8f50dc430 100644
from sglang.srt.layers.pooler import EmbeddingPoolerOutput
from sglang.srt.layers.sampler import Sampler
from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model
@@ -502,6 +507,10 @@ class ModelRunner:
@@ -502,6 +507,11 @@ class ModelRunner:
server_args.max_running_requests,
server_args.max_total_tokens,
)
+
+ # Init routed experts capturer
+ self.init_routed_experts_capturer()
+ if not self.is_draft_worker:
+ self.init_routed_experts_capturer()
+
if self.device == "cuda":
self.init_cublas()
self.init_attention_backend()
@@ -545,6 +554,40 @@ class ModelRunner:
@@ -545,6 +555,40 @@ class ModelRunner:
# Initialize piecewise CUDA graph
self.init_piecewise_cuda_graphs()

Expand Down Expand Up @@ -1383,7 +1395,7 @@ index 4d58278b7..8f50dc430 100644
def model_specific_adjustment(self):
server_args = self.server_args

@@ -792,7 +835,11 @@ class ModelRunner:
@@ -792,7 +836,11 @@ class ModelRunner:
)
with self.memory_saver_adapter.region(
GPU_MEMORY_TYPE_WEIGHTS,
Expand All @@ -1396,7 +1408,7 @@ index 4d58278b7..8f50dc430 100644
):
self.model = get_model(
model_config=self.model_config,
@@ -2645,9 +2692,12 @@ class ModelRunner:
@@ -2645,9 +2693,12 @@ class ModelRunner:
) -> Tuple[Union[LogitsProcessorOutput, PPProxyTensors], bool]:
self.forward_pass_id += 1

Expand All @@ -1412,17 +1424,18 @@ index 4d58278b7..8f50dc430 100644
):
output = self._forward_raw(
forward_batch,
@@ -2656,6 +2706,13 @@ class ModelRunner:
@@ -2656,6 +2707,14 @@ class ModelRunner:
reinit_attn_backend,
split_forward_count,
)
+ # Copy cached routing experts' buffers back to CPU cache
+ get_global_experts_capturer().sync_fwd_experts_buffer_DtoH(
+ device_loc=forward_batch.out_cache_loc,
+ cpu_loc=forward_batch.out_cache_loc_cpu,
+ can_run_graph=output[1],
+ cuda_graph_batch=getattr(self.graph_runner, "bs", None),
+ )
+ if not self.is_draft_worker:
+ get_global_experts_capturer().sync_fwd_experts_buffer_DtoH(
+ device_loc=forward_batch.out_cache_loc,
+ cpu_loc=forward_batch.out_cache_loc_cpu,
+ can_run_graph=output[1],
+ cuda_graph_batch=getattr(self.graph_runner, "bs", None),
+ )

if self.eplb_manager is not None:
self.eplb_manager.on_forward_pass_end()
Expand Down Expand Up @@ -1976,10 +1989,34 @@ index 8e7753dab..323788f39 100644
"--scheduler-recv-interval",
type=int,
diff --git a/python/sglang/srt/speculative/eagle_info.py b/python/sglang/srt/speculative/eagle_info.py
index b3d72df05..ddfe0b178 100644
index b3d72df05..09a1634e0 100644
--- a/python/sglang/srt/speculative/eagle_info.py
+++ b/python/sglang/srt/speculative/eagle_info.py
@@ -746,6 +746,10 @@ class EagleDraftInput(SpecInput, EagleDraftInputV2Mixin):
@@ -135,6 +135,7 @@ class EagleVerifyInput(SpecInput, EagleVerifyInputV2Mixin):
len(batch.input_ids),
)
self.last_loc = last_loc
+ batch.out_cache_loc_cpu = batch.out_cache_loc.to("cpu", non_blocking=True)

bs = batch.batch_size()
assign_req_to_token_pool_func(
@@ -492,6 +493,7 @@ class EagleVerifyInput(SpecInput, EagleVerifyInputV2Mixin):
batch.out_cache_loc = tgt_cache_loc
batch.seq_lens.add_(accept_length + 1)
batch.seq_lens_cpu.add_(accept_length_cpu + 1)
+ batch.out_cache_loc_cpu = batch.out_cache_loc.to("cpu", non_blocking=True)

draft_input = EagleDraftInput(
hidden_states=batch.spec_info.hidden_states[accept_index],
@@ -575,6 +577,7 @@ class EagleVerifyInput(SpecInput, EagleVerifyInputV2Mixin):
topk=self.topk,
capture_hidden_mode=CaptureHiddenMode.LAST,
)
+ batch.out_cache_loc_cpu = batch.out_cache_loc.to("cpu", non_blocking=True)

return EagleVerifyOutput(
draft_input=draft_input,
@@ -746,6 +749,10 @@ class EagleDraftInput(SpecInput, EagleDraftInputV2Mixin):
self.topk_index = self.topk_index[: len(new_indices)]
self.hidden_states = self.hidden_states[: len(new_indices)]
self.verified_id = self.verified_id[: len(new_indices)]
Expand All @@ -1990,7 +2027,7 @@ index b3d72df05..ddfe0b178 100644
else:
# in some cases(e.g draft_extend), we have not filtered the batch by `unfinished_index`
self.topk_p = self.topk_p[new_indices]
@@ -777,6 +781,27 @@ class EagleDraftInput(SpecInput, EagleDraftInputV2Mixin):
@@ -777,6 +784,27 @@ class EagleDraftInput(SpecInput, EagleDraftInputV2Mixin):
self.verified_id = torch.cat([self.verified_id, spec_info.verified_id], axis=0)
self.topk_p = torch.cat([self.topk_p, spec_info.topk_p])
self.topk_index = torch.cat([self.topk_index, spec_info.topk_index])
Expand Down
2 changes: 1 addition & 1 deletion docker/version.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
nightly-dev-20251216a
nightly-dev-20251216b