diff --git a/docker/patch/latest/sglang.patch b/docker/patch/latest/sglang.patch index ebd605fba..14215177b 100644 --- a/docker/patch/latest/sglang.patch +++ b/docker/patch/latest/sglang.patch @@ -301,10 +301,10 @@ index e7d5a67cc..639e47163 100644 out_hidden_states[begin_chunk_idx:end_chunk_idx], diff --git a/python/sglang/srt/layers/moe/routed_experts_capturer.py b/python/sglang/srt/layers/moe/routed_experts_capturer.py new file mode 100644 -index 000000000..7369f9dc9 +index 000000000..11adcaa77 --- /dev/null +++ b/python/sglang/srt/layers/moe/routed_experts_capturer.py -@@ -0,0 +1,308 @@ +@@ -0,0 +1,305 @@ +import logging +from abc import ABC +from contextlib import contextmanager @@ -402,9 +402,6 @@ index 000000000..7369f9dc9 + assert hasattr(self, "buffer") + return get_tensor_size_bytes(self.buffer) + -+ def set_experts_buffer(self, layer_id: int, loc: torch.Tensor, top_k: torch.Tensor): -+ self.buffer[layer_id, loc, :] = top_k.to(device="cpu", non_blocking=True) -+ + def _finalize_allocation_log(self): + """Common logging and memory usage computation for captured experts buffers.""" + buffer_size_GB = self.get_buffer_size_bytes() / _GB @@ -903,7 +900,7 @@ index e34736cc4..5e5997a1a 100644 # idx is the index of the token in the prompt after expansion. # val is the length of padded tokens after expansion. diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py -index c4c5a9ebb..1450c5fd8 100644 +index c4c5a9ebb..3650ba881 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -450,6 +450,7 @@ class Req: @@ -953,7 +950,15 @@ index c4c5a9ebb..1450c5fd8 100644 is_prefill_only=all(req.is_prefill_only for req in reqs), chunked_req=chunked_req, dllm_config=dllm_config, -@@ -1457,6 +1469,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): +@@ -1282,6 +1294,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): + ) + else: + self.out_cache_loc = torch.cat(decoder_out_cache_loc) ++ self.out_cache_loc_cpu = self.out_cache_loc.to("cpu", non_blocking=True) + + if not encoder_out_cache_loc: + self.encoder_out_cache_loc = torch.zeros(0, dtype=torch.int64).to( +@@ -1457,6 +1470,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): self.req_pool_indices = req_pool_indices_tensor self.orig_seq_lens = orig_seq_lens_tensor self.out_cache_loc = out_cache_loc @@ -961,7 +966,7 @@ index c4c5a9ebb..1450c5fd8 100644 self.input_embeds = ( torch.tensor(input_embeds).to(self.device, non_blocking=True) if input_embeds -@@ -1508,10 +1521,14 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): +@@ -1508,10 +1522,14 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): input_ids = torch.cat([self.input_ids, running_batch.input_ids]) out_cache_loc = torch.cat([self.out_cache_loc, running_batch.out_cache_loc]) @@ -976,7 +981,7 @@ index c4c5a9ebb..1450c5fd8 100644 # For overlap scheduler, the output_ids has one step delay delta = 0 if self.enable_overlap else -1 -@@ -1677,6 +1694,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): +@@ -1677,6 +1695,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): self.seq_lens_cpu = torch.empty(0, dtype=torch.int64) self.orig_seq_lens = torch.empty(0, dtype=torch.int32, device=self.device) self.out_cache_loc = torch.empty(0, dtype=torch.int64, device=self.device) @@ -984,7 +989,7 @@ index c4c5a9ebb..1450c5fd8 100644 self.req_pool_indices = torch.empty(0, dtype=torch.int32, device=self.device) self.seq_lens_sum = 0 self.extend_num_tokens = 0 -@@ -1736,6 +1754,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): +@@ -1736,6 +1755,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): # Allocate memory self.out_cache_loc = alloc_for_decode(self, token_per_req=1) @@ -992,7 +997,7 @@ index c4c5a9ebb..1450c5fd8 100644 # Update req-level memory management fields for req in self.reqs: -@@ -1807,6 +1826,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): +@@ -1807,6 +1827,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): self.seq_lens_cpu = self.seq_lens_cpu[keep_indices] self.orig_seq_lens = self.orig_seq_lens[keep_indices_device] self.out_cache_loc = None @@ -1000,7 +1005,7 @@ index c4c5a9ebb..1450c5fd8 100644 self.seq_lens_sum = self.seq_lens.sum().item() self.output_ids = self.output_ids[keep_indices_device] self.return_logprob = any(req.return_logprob for req in self.reqs) -@@ -1852,6 +1872,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): +@@ -1852,6 +1873,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): self.seq_lens_cpu = torch.cat([self.seq_lens_cpu, other.seq_lens_cpu]) self.orig_seq_lens = torch.cat([self.orig_seq_lens, other.orig_seq_lens]) self.out_cache_loc = None @@ -1008,7 +1013,7 @@ index c4c5a9ebb..1450c5fd8 100644 self.seq_lens_sum += other.seq_lens_sum if self.output_ids is not None: self.output_ids = torch.cat([self.output_ids, other.output_ids]) -@@ -1903,6 +1924,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): +@@ -1903,6 +1925,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): seq_lens=self.seq_lens, orig_seq_lens=self.orig_seq_lens, out_cache_loc=self.out_cache_loc, @@ -1016,7 +1021,7 @@ index c4c5a9ebb..1450c5fd8 100644 seq_lens_cpu=seq_lens_cpu, seq_lens_sum=self.seq_lens_sum, return_logprob=self.return_logprob, -@@ -1983,7 +2005,8 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): +@@ -1983,7 +2006,8 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): def __str__(self): return ( f"ScheduleBatch(forward_mode={self.forward_mode.name if self.forward_mode else 'None'}, " @@ -1026,7 +1031,7 @@ index c4c5a9ebb..1450c5fd8 100644 ) -@@ -2038,6 +2061,9 @@ class ModelWorkerBatch: +@@ -2038,6 +2062,9 @@ class ModelWorkerBatch: # Sampling info sampling_info: SamplingBatchInfo @@ -1194,7 +1199,7 @@ index edbc52526..2cdc42755 100644 # This means that weight sync diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py -index b90cf0616..9b0992655 100644 +index b90cf0616..8a5cbdbed 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -20,6 +20,7 @@ import logging @@ -1213,14 +1218,12 @@ index b90cf0616..9b0992655 100644 data_parallel_rank=obj.data_parallel_rank, priority=obj.priority, extra_key=obj.extra_key, -@@ -1621,6 +1623,16 @@ class TokenizerManager(TokenizerCommunicatorMixin): +@@ -1621,6 +1623,14 @@ class TokenizerManager(TokenizerCommunicatorMixin): if getattr(recv_obj, "output_hidden_states", None): meta_info["hidden_states"] = recv_obj.output_hidden_states[i] + if getattr(recv_obj, "output_routed_experts", None): + if recv_obj.output_routed_experts[i] is not None: -+ # print(f"{recv_obj.output_routed_experts[i].shape=}, {recv_obj.output_routed_experts[i].dtype=}") -+ # torch.save(recv_obj.output_routed_experts[i], f"/root/{recv_obj.output_routed_experts[i].shape[0]}.pt") + meta_info["routed_experts"] = pybase64.b64encode( + recv_obj.output_routed_experts[i].contiguous().numpy().tobytes(order="C") + ).decode("ascii") @@ -1230,7 +1233,7 @@ index b90cf0616..9b0992655 100644 if isinstance(recv_obj, BatchStrOutput): state.text += recv_obj.output_strs[i] if self.server_args.stream_output and state.obj.stream: -@@ -1747,12 +1759,13 @@ class TokenizerManager(TokenizerCommunicatorMixin): +@@ -1747,12 +1757,13 @@ class TokenizerManager(TokenizerCommunicatorMixin): return if len(recv_obj.input_token_logprobs_val) > 0: @@ -1251,7 +1254,7 @@ index b90cf0616..9b0992655 100644 recv_obj.output_token_logprobs_val[recv_obj_index] ) diff --git a/python/sglang/srt/model_executor/forward_batch_info.py b/python/sglang/srt/model_executor/forward_batch_info.py -index 3a85e6a7e..2859dafa1 100644 +index 3a85e6a7e..d2560e79b 100644 --- a/python/sglang/srt/model_executor/forward_batch_info.py +++ b/python/sglang/srt/model_executor/forward_batch_info.py @@ -51,6 +51,7 @@ from sglang.srt.layers.dp_attention import ( @@ -1315,8 +1318,16 @@ index 3a85e6a7e..2859dafa1 100644 if self.encoder_lens is not None: self.encoder_lens = self._pad_tensor_to_size(self.encoder_lens, bs) self.positions = self._pad_tensor_to_size(self.positions, num_tokens) +@@ -906,6 +921,7 @@ class ForwardBatch: + self.spec_info.hidden_states = self.hidden_states_backup + if hasattr(self, "output_cache_loc_backup"): + self.out_cache_loc = self.output_cache_loc_backup ++ self.out_cache_loc_cpu = self.out_cache_loc.to("cpu", non_blocking=True) + + elif self.forward_mode.is_decode() or self.forward_mode.is_idle(): + logits_output.next_token_logits = logits_output.next_token_logits[:bs] diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py -index 4d58278b7..8f50dc430 100644 +index 4d58278b7..5965c481e 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -94,6 +94,11 @@ from sglang.srt.layers.dp_attention import ( @@ -1331,18 +1342,19 @@ index 4d58278b7..8f50dc430 100644 from sglang.srt.layers.pooler import EmbeddingPoolerOutput from sglang.srt.layers.sampler import Sampler from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model -@@ -502,6 +507,10 @@ class ModelRunner: +@@ -502,6 +507,11 @@ class ModelRunner: server_args.max_running_requests, server_args.max_total_tokens, ) + + # Init routed experts capturer -+ self.init_routed_experts_capturer() ++ if not self.is_draft_worker: ++ self.init_routed_experts_capturer() + if self.device == "cuda": self.init_cublas() self.init_attention_backend() -@@ -545,6 +554,40 @@ class ModelRunner: +@@ -545,6 +555,40 @@ class ModelRunner: # Initialize piecewise CUDA graph self.init_piecewise_cuda_graphs() @@ -1383,7 +1395,7 @@ index 4d58278b7..8f50dc430 100644 def model_specific_adjustment(self): server_args = self.server_args -@@ -792,7 +835,11 @@ class ModelRunner: +@@ -792,7 +836,11 @@ class ModelRunner: ) with self.memory_saver_adapter.region( GPU_MEMORY_TYPE_WEIGHTS, @@ -1396,7 +1408,7 @@ index 4d58278b7..8f50dc430 100644 ): self.model = get_model( model_config=self.model_config, -@@ -2645,9 +2692,12 @@ class ModelRunner: +@@ -2645,9 +2693,12 @@ class ModelRunner: ) -> Tuple[Union[LogitsProcessorOutput, PPProxyTensors], bool]: self.forward_pass_id += 1 @@ -1412,17 +1424,18 @@ index 4d58278b7..8f50dc430 100644 ): output = self._forward_raw( forward_batch, -@@ -2656,6 +2706,13 @@ class ModelRunner: +@@ -2656,6 +2707,14 @@ class ModelRunner: reinit_attn_backend, split_forward_count, ) + # Copy cached routing experts' buffers back to CPU cache -+ get_global_experts_capturer().sync_fwd_experts_buffer_DtoH( -+ device_loc=forward_batch.out_cache_loc, -+ cpu_loc=forward_batch.out_cache_loc_cpu, -+ can_run_graph=output[1], -+ cuda_graph_batch=getattr(self.graph_runner, "bs", None), -+ ) ++ if not self.is_draft_worker: ++ get_global_experts_capturer().sync_fwd_experts_buffer_DtoH( ++ device_loc=forward_batch.out_cache_loc, ++ cpu_loc=forward_batch.out_cache_loc_cpu, ++ can_run_graph=output[1], ++ cuda_graph_batch=getattr(self.graph_runner, "bs", None), ++ ) if self.eplb_manager is not None: self.eplb_manager.on_forward_pass_end() @@ -1976,10 +1989,34 @@ index 8e7753dab..323788f39 100644 "--scheduler-recv-interval", type=int, diff --git a/python/sglang/srt/speculative/eagle_info.py b/python/sglang/srt/speculative/eagle_info.py -index b3d72df05..ddfe0b178 100644 +index b3d72df05..09a1634e0 100644 --- a/python/sglang/srt/speculative/eagle_info.py +++ b/python/sglang/srt/speculative/eagle_info.py -@@ -746,6 +746,10 @@ class EagleDraftInput(SpecInput, EagleDraftInputV2Mixin): +@@ -135,6 +135,7 @@ class EagleVerifyInput(SpecInput, EagleVerifyInputV2Mixin): + len(batch.input_ids), + ) + self.last_loc = last_loc ++ batch.out_cache_loc_cpu = batch.out_cache_loc.to("cpu", non_blocking=True) + + bs = batch.batch_size() + assign_req_to_token_pool_func( +@@ -492,6 +493,7 @@ class EagleVerifyInput(SpecInput, EagleVerifyInputV2Mixin): + batch.out_cache_loc = tgt_cache_loc + batch.seq_lens.add_(accept_length + 1) + batch.seq_lens_cpu.add_(accept_length_cpu + 1) ++ batch.out_cache_loc_cpu = batch.out_cache_loc.to("cpu", non_blocking=True) + + draft_input = EagleDraftInput( + hidden_states=batch.spec_info.hidden_states[accept_index], +@@ -575,6 +577,7 @@ class EagleVerifyInput(SpecInput, EagleVerifyInputV2Mixin): + topk=self.topk, + capture_hidden_mode=CaptureHiddenMode.LAST, + ) ++ batch.out_cache_loc_cpu = batch.out_cache_loc.to("cpu", non_blocking=True) + + return EagleVerifyOutput( + draft_input=draft_input, +@@ -746,6 +749,10 @@ class EagleDraftInput(SpecInput, EagleDraftInputV2Mixin): self.topk_index = self.topk_index[: len(new_indices)] self.hidden_states = self.hidden_states[: len(new_indices)] self.verified_id = self.verified_id[: len(new_indices)] @@ -1990,7 +2027,7 @@ index b3d72df05..ddfe0b178 100644 else: # in some cases(e.g draft_extend), we have not filtered the batch by `unfinished_index` self.topk_p = self.topk_p[new_indices] -@@ -777,6 +781,27 @@ class EagleDraftInput(SpecInput, EagleDraftInputV2Mixin): +@@ -777,6 +784,27 @@ class EagleDraftInput(SpecInput, EagleDraftInputV2Mixin): self.verified_id = torch.cat([self.verified_id, spec_info.verified_id], axis=0) self.topk_p = torch.cat([self.topk_p, spec_info.topk_p]) self.topk_index = torch.cat([self.topk_index, spec_info.topk_index]) diff --git a/docker/version.txt b/docker/version.txt index 81449aa2b..fcd7ad62f 100644 --- a/docker/version.txt +++ b/docker/version.txt @@ -1 +1 @@ -nightly-dev-20251216a \ No newline at end of file +nightly-dev-20251216b \ No newline at end of file