diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index ecac3865662..a5fe4bddf59 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -1049,6 +1049,11 @@ def prepare_for_decode(self): self.forward_mode = ForwardMode.DECODE if self.spec_algorithm.is_eagle(): return + elif self.spec_algorithm.is_lookahead(): + self.spec_info.prepare_for_verify(self) + # overwrite the forward_mode + self.forward_mode = ForwardMode.TARGET_VERIFY + return self.input_ids = self.output_ids self.output_ids = None diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 9c00c8b2521..8d78936e85d 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -146,7 +146,7 @@ def __init__( ) self.decode_mem_cache_buf_multiplier = ( self.server_args.speculative_num_draft_tokens - if not self.spec_algorithm.is_none() + if self.spec_algorithm.is_eagle() else 1 ) self.enable_hierarchical_cache = server_args.enable_hierarchical_cache @@ -257,6 +257,17 @@ def __init__( target_worker=self.tp_worker, dp_rank=dp_rank, ) + elif self.spec_algorithm.is_lookahead(): + from sglang.srt.speculative.lookahead_worker import LOOKAHEADWorker + + self.draft_worker = LOOKAHEADWorker( + gpu_id=gpu_id, + tp_rank=tp_rank, + server_args=server_args, + dp_rank=dp_rank, + nccl_port=port_args.nccl_port, + target_worker=self.tp_worker, + ) else: self.draft_worker = None @@ -1064,6 +1075,8 @@ def update_running_batch(self, batch: ScheduleBatch) -> Optional[ScheduleBatch]: if batch.batch_size() < initial_bs: self.batch_is_full = False + if self.spec_algorithm.is_lookahead(): + batch.spec_info = self.draft_worker.prepare_for_verify(batch) # Update batch tensors batch.prepare_for_decode() return batch diff --git a/python/sglang/srt/model_executor/cuda_graph_runner.py b/python/sglang/srt/model_executor/cuda_graph_runner.py index db103162f23..7eaba057f85 100644 --- a/python/sglang/srt/model_executor/cuda_graph_runner.py +++ b/python/sglang/srt/model_executor/cuda_graph_runner.py @@ -106,7 +106,7 @@ def set_torch_compile_config(): torch._dynamo.config.cache_size_limit = 1024 -def get_batch_sizes_to_capture(model_runner: ModelRunner): +def get_batch_sizes_to_capture(model_runner: ModelRunner, is_spec=False): server_args = model_runner.server_args capture_bs = server_args.cuda_graph_bs if capture_bs is None: @@ -126,12 +126,17 @@ def get_batch_sizes_to_capture(model_runner: ModelRunner): ) ) ) - capture_bs = [ - bs - for bs in capture_bs - if bs <= model_runner.req_to_token_pool.size - and bs <= server_args.cuda_graph_max_bs - ] + if is_spec: + # For speculative inference, large batch sizes are not effective. 
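+        # NOTE: this short fixed list is a heuristic. In TARGET_VERIFY mode each
+        # request contributes speculative_num_draft_tokens tokens per replay, and
+        # LOOKAHEADWorker.prepare_for_verify falls back to normal decoding for
+        # batch sizes > 4, so larger speculative graphs would never be replayed.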
+ capture_bs = [1, 2, 3, 4] + else: + capture_bs = [ + bs + for bs in capture_bs + if bs <= model_runner.req_to_token_pool.size + and bs <= server_args.cuda_graph_max_bs + ] + if is_hip_: capture_bs += [i * 8 for i in range(21, 33)] compile_bs = ( @@ -158,7 +163,7 @@ def set_global_graph_memory_pool(val): class CudaGraphRunner: """A CudaGraphRunner runs the forward pass of a model with cuda graph and torch.compile.""" - def __init__(self, model_runner: ModelRunner): + def __init__(self, model_runner: "ModelRunner", is_spec=False): # Parse args self.model_runner = model_runner self.graphs = {} @@ -169,9 +174,11 @@ def __init__(self, model_runner: ModelRunner): self.enable_dp_attention = model_runner.server_args.enable_dp_attention self.tp_size = model_runner.server_args.tp_size self.dp_size = model_runner.server_args.dp_size - + self.is_spec = is_spec # Batch sizes to capture - self.capture_bs, self.compile_bs = get_batch_sizes_to_capture(model_runner) + self.capture_bs, self.compile_bs = get_batch_sizes_to_capture( + model_runner, is_spec + ) self.capture_forward_mode = ForwardMode.DECODE self.num_tokens_per_bs = 1 if model_runner.spec_algorithm.is_eagle(): @@ -183,6 +190,12 @@ def __init__(self, model_runner: ModelRunner): self.model_runner.server_args.speculative_num_draft_tokens ) + if model_runner.spec_algorithm.is_lookahead() and self.is_spec: + self.capture_forward_mode = ForwardMode.TARGET_VERIFY + self.num_tokens_per_bs = ( + self.model_runner.server_args.speculative_num_draft_tokens + ) + # Attention backend self.max_bs = max(self.capture_bs) self.max_num_token = self.max_bs * self.num_tokens_per_bs @@ -213,6 +226,11 @@ def __init__(self, model_runner: ModelRunner): (self.max_num_token, self.model_runner.model_config.hidden_size), dtype=self.model_runner.dtype, ) + if self.is_spec: + self.draft_token_num = ( + torch.ones((len(self.capture_bs),), dtype=torch.int32) + * self.num_tokens_per_bs + ) if self.is_encoder_decoder: # NOTE: encoder_lens can influence the full_text_row_masked_out_mask tensor when doing mixed batch @@ -281,7 +299,14 @@ def can_run(self, forward_batch: ForwardBatch): if self.is_encoder_decoder else True ) - return is_bs_supported and is_encoder_lens_supported + + is_token_num_supported = True + if self.is_spec: + is_token_num_supported = ( + forward_batch.batch_size * self.num_tokens_per_bs + == forward_batch.input_ids.numel() + ) + return is_token_num_supported and is_bs_supported and is_encoder_lens_supported def capture(self): with graph_capture() as graph_capture_context: @@ -481,5 +506,23 @@ def get_spec_info(self, num_tokens: int): spec_steps=self.model_runner.server_args.speculative_num_steps, capture_hidden_mode=CaptureHiddenMode.FULL, ) + if self.model_runner.spec_algorithm.is_lookahead() and self.is_spec: + from sglang.srt.speculative.lookahead_utils import LookaheadVerifyInput + + bs = int(num_tokens / self.num_tokens_per_bs) + spec_info = LookaheadVerifyInput( + None, + None, + None, + None, + None, + self.draft_token_num[:bs], + ) + spec_info.capture_hidden_mode = CaptureHiddenMode.NULL + spec_info.custom_mask = torch.zeros( + (num_tokens * self.model_runner.model_config.context_len), + dtype=torch.bool, + device="cuda", + ) return spec_info diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index d125868b09a..7f21959b571 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -595,7 +595,7 @@ def init_memory_pool( 4096, ) - 
if not self.spec_algorithm.is_none(): + if self.spec_algorithm.is_eagle(): if self.is_draft_worker: self.max_total_num_tokens = self.server_args.draft_runner_cache_size else: @@ -717,6 +717,7 @@ def init_double_sparsity_channel_config(self, selected_channel): def init_cuda_graphs(self): """Capture cuda graphs.""" self.cuda_graph_runner = None + self.cuda_graph_runner_spec = None if not self.is_generation: # TODO: Currently, cuda graph only captures decode steps, which only exists for generation models @@ -728,6 +729,10 @@ def init_cuda_graphs(self): tic = time.time() logger.info("Capture cuda graph begin. This can take up to several minutes.") self.cuda_graph_runner = CudaGraphRunner(self) + if self.spec_algorithm.is_lookahead(): + # in case look_ahead failed to match any draft token, fallback to normal cuda graph decode + self.cuda_graph_runner_spec = CudaGraphRunner(self, is_spec=True) + logger.info(f"Capture cuda graph end. Time elapsed: {time.time() - tic:.2f} s") def apply_torch_tp(self): @@ -772,12 +777,16 @@ def forward_idle(self, forward_batch: ForwardBatch): ) def forward(self, forward_batch: ForwardBatch) -> LogitsProcessorOutput: + cuda_graph_runner = self.cuda_graph_runner + if forward_batch.spec_algorithm.is_lookahead(): + cuda_graph_runner = self.cuda_graph_runner_spec + if ( forward_batch.forward_mode.is_cuda_graph() - and self.cuda_graph_runner - and self.cuda_graph_runner.can_run(forward_batch) + and cuda_graph_runner + and cuda_graph_runner.can_run(forward_batch) ): - return self.cuda_graph_runner.replay(forward_batch) + return cuda_graph_runner.replay(forward_batch) if forward_batch.forward_mode.is_decode(): return self.forward_decode(forward_batch) diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 93f797087f2..bfc9e7ecbe6 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -126,6 +126,8 @@ class ServerArgs: speculative_num_steps: int = 5 speculative_num_draft_tokens: int = 64 speculative_eagle_topk: int = 8 + speculative_lookahead_path: str = None + speculative_one_branch: bool = False # Double Sparsity enable_double_sparsity: bool = False @@ -270,6 +272,18 @@ def __post_init__(self): "The radix cache, chunked prefill, and overlap scheduler are disabled because of using eagle speculative decoding." ) + if self.speculative_algorithm == "LOOKAHEAD": + self.disable_overlap_schedule = True + self.chunked_prefill_size = -1 + self.disable_mla = True + self.enable_double_sparsity = False + assert ( + self.attention_backend == "flashinfer" + ), "Lookahead speculative decoding only support flashinfer for now." + logger.info( + "The mla, chunked_prefill, overlap scheduler and double_sparsity are disabled because of lookahead speculative decoding." + ) + # GGUF if ( self.load_format == "auto" or self.load_format == "gguf" @@ -698,7 +712,7 @@ def add_cli_args(parser: argparse.ArgumentParser): parser.add_argument( "--speculative-algorithm", type=str, - choices=["EAGLE"], + choices=["EAGLE", "LOOKAHEAD"], help="Speculative algorithm.", ) parser.add_argument( @@ -706,6 +720,12 @@ def add_cli_args(parser: argparse.ArgumentParser): type=str, help="The path of the draft model weights. This can be a local folder or a Hugging Face repo ID.", ) + parser.add_argument( + "--speculative-lookahead-path", + type=str, + help="The path of the lookahead. If provided, the lookahead will be inited from this path. 
You can `lookahead_cache.save_mem('lookahrad.pkl')` to save the lookahead for later use.", + required=False, + ) parser.add_argument( "--speculative-num-steps", type=int, @@ -725,6 +745,11 @@ def add_cli_args(parser: argparse.ArgumentParser): choices=[1, 2, 4, 8], default=ServerArgs.speculative_eagle_topk, ) + parser.add_argument( + "--speculative-one-branch", + action="store_true", + help="Whether to use one branch in Lookahead Speculative Decoding.", + ) # Double Sparsity parser.add_argument( diff --git a/python/sglang/srt/speculative/lookahead_cache.py b/python/sglang/srt/speculative/lookahead_cache.py new file mode 100644 index 00000000000..0309e88393f --- /dev/null +++ b/python/sglang/srt/speculative/lookahead_cache.py @@ -0,0 +1,654 @@ +# -*- coding: utf-8 -*- +""" +Adapted from: +https://github.com/alipay/PainlessInferenceAcceleration/blob/main/pia/lookahead/common/lookahead_cache.py +""" +import json +import pickle +import time +from collections import defaultdict + +import numpy as np + + +class Node: + __slots__ = ["freqs", "children"] + + def __init__(self, children, freqs): + self.children = children + self.freqs = freqs + + def __repr__(self): + return f"{list(self.children.keys())}:{self.freqs}" + + +class Tree: + def __init__(self, token_id, max_node=65536, max_output_node=512): + self.token_id = token_id + self.max_node = max_node + self.max_output_node = max_output_node + self.n_node = 0 + self.n_output_node = 0 + self.nodes = {} + + def put(self, token_ids, mode="output", idx=0, freq=1.0): + assert mode in ("input", "output") + if mode == "output": + idx = -1 + self._put(token_ids, self.nodes, mode=mode, idx=idx, freq=freq) + + def _put(self, token_ids, nodes, mode="output", freq=1.0, idx=-1): + for t in token_ids: + if t not in nodes: + nodes[t] = Node({}, {idx: freq}) + self.n_node += 1 + if mode == "output": + self.n_output_node += 1 + else: + nodes[t].freqs[idx] = nodes[t].freqs.get(idx, 0.0) + freq + nodes = nodes[t].children + + def get( + self, + token_ids, + max_size=64, + max_length=8, + min_input_size=0, + min_output_size=0, + output_weight=1e-4, + mode="mix", + idx=0, + ): + assert mode in ("input", "output", "mix") + + match_token_id, nodes = self._match(token_ids, mode=mode, idx=idx) + if not nodes: + token_id = token_ids[-1] if token_ids else self.token_id + return [token_id], np.ones((1, 1), dtype=np.int64), [0, 0] + + freqs = [] + self._dfs_get_freqs(nodes, freqs, idx, output_weight) + + min_mix_freq = min_input_freq = min_output_freq = 1e9 + if mode == "input": + output_weight = 0.0 + size = len([x for x in freqs if x[1] > 0]) + min_input_freq = ( + sorted(freqs, key=lambda x: x[1], reverse=True)[min_input_size - 1][1] + if size > max_size + else 0.0 + ) + elif mode == "output": + output_weight = 1.0 + size = len([x for x in freqs if x[2] > 0]) + min_output_freq = ( + sorted(freqs, key=lambda x: x[2], reverse=True)[min_output_size - 1][2] + if size > max_size + else 0.0 + ) + else: + size = len([x for x in freqs if x[1] > 0 or x[2] > 0]) + if size > max_size: + indices = set() + if min_input_size > 0: + input_freqs = sorted(freqs, key=lambda x: x[1], reverse=True) + min_input_freq = input_freqs[min_input_size - 1][1] + indices.update([x[0] for x in input_freqs[:min_input_size]]) + + if min_output_size > 0: + output_freqs = sorted(freqs, key=lambda x: x[2], reverse=True) + min_output_freq = output_freqs[min_output_size - 1][2] + indices.update([x[0] for x in output_freqs[:min_output_size]]) + + if len(indices) < max_size: + mix_freqs = sorted(freqs, 
key=lambda x: x[3], reverse=True) + rest_size = max_size - len(indices) + indices.update([x[0] for x in mix_freqs[:rest_size]]) + cur_size = len(indices) + for i in range(rest_size, min(rest_size + max_size, size)): + if mix_freqs[i][0] in indices: + continue + cur_size += 1 + if cur_size >= max_size: + min_mix_freq = mix_freqs[i][3] + break + else: + min_mix_freq = 0.0 + + mask = np.zeros((max_size, max_size), dtype=np.int64) + mask[:, 0] = 1 + ids = [match_token_id or self.token_id] + sizes = [0, 0] + self._ravel( + nodes, + ids, + mask, + -1, + max_size=max_size, + max_length=max_length, + min_output_freq=min_output_freq, + min_input_freq=min_input_freq, + min_mix_freq=min_mix_freq, + sizes=sizes, + output_weight=output_weight, + mode=mode, + idx=idx, + ) + size = len(ids) + + mask = mask[:size, :size] + return ids, mask, sizes + + def _dfs_get_freqs(self, nodes, freqs, idx, output_weight): + for node in nodes.values(): + fo = node.freqs.get(-1, 0.0) + fi = node.freqs.get(idx, 0.0) + if fo > 0 or fi > 0: + fm = (1.0 - output_weight) * fi + output_weight * fo + freqs.append([None, fi, fo, fm]) + if node.children: + self._dfs_get_freqs(node.children, freqs, idx, output_weight) + + def get_one_branch(self, token_ids, max_length=8, mode="mix", idx=0): + assert mode in ("input", "output", "mix") + + match_token_id, nodes = self._match(token_ids, mode=mode, idx=idx) + if len(nodes) == 0: + token_id = token_ids[-1] if len(token_ids) > 0 else self.token_id + return [token_id], [0, 0] + + ids = [match_token_id or self.token_id] + length = 0 + while True: + if len(nodes) == 0 or length >= max_length: + break + max_freq = 0.0 + max_node = None + max_id = None + if mode == "mix": + for t, node in nodes.items(): + freqs = node.freqs + fo = freqs.get(idx, 0.0) + fi = freqs.get(-1, 0.0) + if fo > 0 or fi > 0: + freq = 10000 * fi + fo + if freq > max_freq: + max_freq = freq + max_node = node + max_id = t + elif mode == "input": + for t, node in nodes.items(): + freqs = node.freqs + freq = freqs.get(idx, 0.0) + if freq > 0: + if freq > max_freq: + max_freq = freq + max_node = node + max_id = t + else: + for t, node in nodes.items(): + freqs = node.freqs + freq = freqs.get(-1, 0.0) + if freq > 0: + if freq > max_freq: + max_freq = freq + max_node = node + max_id = t + if max_node is None: + break + ids.append(max_id) + nodes = max_node.children + length += 1 + + return ids, [length] + + def _match(self, token_ids, mode="mix", idx=0): + nodes = self.nodes + token_id = None + for token_id in token_ids: + node = nodes.get(token_id, None) + nodes = {} + if node is None: + break + if mode == "input" and node.freqs.get(idx, 0.0) > 0: + nodes = node.children + elif mode == "output" and node.freqs.get(-1, 0.0) > 0: + nodes = node.children + elif node.freqs.get(idx, 0.0) > 0 or node.freqs.get(-1, 0.0) > 0: + nodes = node.children + return token_id, nodes + + def _ravel( + self, + nodes, + ids, + mask, + pid, + max_size=64, + max_length=8, + min_output_freq=1.0, + min_input_freq=1.0, + min_mix_freq=1.0, + output_weight=1e-4, + sizes=None, + mode="mix", + idx=0, + ): + if len(ids) >= max_size or max_length <= 0: + return + + sorts = sorted( + [ + ( + k, + v, + (1.0 - output_weight) * v.freqs.get(idx, 0.0) + + output_weight * v.freqs.get(-1, 0.0), + ) + for k, v in nodes.items() + ], + key=lambda x: x[2], + reverse=True, + ) + for tid, node, fm in sorts: + if len(ids) >= max_size: + return + fi = node.freqs.get(idx, 0.0) + fo = node.freqs.get(-1, 0.0) + if ( + mode == "mix" + and fi < min_input_freq + and fo < 
min_output_freq + and fm < min_mix_freq + ): + continue + elif mode == "input" and fi < min_input_freq: + continue + elif mode == "output" and fo < min_output_freq: + continue + if fi > 0.0: + sizes[0] += 1 + if fo > 0.0: + sizes[1] += 1 + ids.append(tid) + rid = len(ids) - 1 + + if pid > -1: + mask[rid] = mask[pid] + mask[rid, rid] = 1 + if node.children: + self._ravel( + node.children, + ids, + mask, + rid, + max_size=max_size, + max_length=max_length - 1, + min_output_freq=min_output_freq, + min_input_freq=min_input_freq, + min_mix_freq=min_mix_freq, + output_weight=output_weight, + sizes=sizes, + mode=mode, + idx=idx, + ) + + def squeeze(self): + if self.n_node > self.max_node or self.n_output_node > self.max_output_node: + self._squeeze(self.nodes) + sizes = [0] + self._count_node(self.nodes, sizes) + self.n_node = sizes[0] + self.n_output_node = sizes[0] + + def _squeeze(self, nodes): + for t, p in list(nodes.items()): + fo = p.freqs.get(-1, 0.0) + if fo > 1.0: + p.freqs[-1] *= 0.5 + if p.children: + self._squeeze(p.children) + else: + nodes.pop(t) + + def _count_node(self, nodes, sizes): + sizes[0] += len(nodes) + for n in nodes.values(): + if n.children: + self._count_node(n.children, sizes) + + def reset_input_freq(self, idx): + if self.nodes: + self._reset_input_freq(self.nodes, idx) + + def _reset_input_freq(self, nodes, idx): + for node in nodes.values(): + if node.freqs.get(idx, 0.0) > 0: + node.freqs[idx] = 0.0 + if node.children: + self._reset_input_freq(node.children, idx) + + +class LookaheadCache: + def __init__( + self, + debug=False, + eos_ids=(2,), + stop_words=None, + max_node=65536, + max_output_node=512, + gpu_id=0, + ): + self.debug = debug + self.eos_ids = eos_ids if eos_ids is not None else [None] + self.max_node = max_node + self.max_output_node = max_output_node + self.gpu_id = gpu_id + self.mem = {} + self._output_ids = defaultdict(list) + self._update_trees = set() + self._update_input_trees = set() + self.stop_words = stop_words if stop_words is not None else {} + self.default_mask = np.ones((1, 1), dtype=np.int64) + + def put(self, token_ids, branch_length=8, final=False, mode="output", idx=0): + for eos in self.eos_ids: + if eos in token_ids: + token_ids = token_ids[: token_ids.index(eos)] + if len(token_ids) >= 2: + ts = len(token_ids) # ts: token_ids size + for i in range(ts - 1): + token_id = token_ids[i] + tup = token_ids[i + 1 : i + branch_length + 1] + if self.debug: + print(f"input token:{token_id} tokens:{tup}") + tree = self.mem.get(token_id, None) + if tree is not None: + tree.put(tup, mode=mode, idx=idx) + self._update_trees.add(tree) + else: + tree = Tree( + token_id, + max_node=self.max_node, + max_output_node=self.max_output_node, + ) + tree.put(tup, mode=mode, idx=idx) + self.mem[token_id] = tree + + if mode == "input": + self._update_input_trees.add(tree) + + if final: + self.reset_input_freqs(idx) + self.squeeze_branch_counts() + + def stream_put(self, token_ids, branch_length=8, final=False, mode="output", idx=0): + # idx is only used for caching output_ids + assert mode == "output" and idx >= 0 + for eos in self.eos_ids: + if eos in token_ids: + token_ids = token_ids[: token_ids.index(eos)] + self._output_ids[idx].extend(token_ids) + output_ids = self._output_ids[idx] + ts = len(output_ids) + min_branch_length = 1 if final else branch_length + if ts > min_branch_length: + for i in range(ts - min_branch_length): + token_id = output_ids[i] + if token_id in self.stop_words: + continue + tup = output_ids[i + 1 : i + branch_length + 1] + if 
self.debug: + print(f"input token:{token_id} tokens:{tup}") + tree = self.mem.get(token_id, None) + if tree: + tree.put(tup, mode="output", idx=idx) + else: + tree = Tree( + token_id, + max_node=self.max_node, + max_output_node=self.max_output_node, + ) + tree.put(tup, mode="output", idx=idx) + self.mem[token_id] = tree + self._update_trees.add(tree) + if not final: + self._output_ids[idx] = output_ids[ts - branch_length :] + if final: + self._output_ids[idx] = [] + self.reset_input_freqs(idx) + self.squeeze_branch_counts() + + def hier_get( + self, + token_ids, + decoding_length=64, + branch_length=8, + min_input_size=0, + min_output_size=0, + mode="mix", + idx=0, + ): + assert mode in ("input", "output", "mix") + + decoding_masks = self.default_mask + if decoding_length <= 1 or branch_length == 0: + return token_ids[-1:], decoding_masks, [] + + decoding_ids = None + sizes = [0, 0] + for i, t in enumerate(token_ids): + tree = self.mem.get(t, None) + if tree is not None: + ids = token_ids[i + 1 :] + if t in self.stop_words and len(ids) == 0: + continue + decoding_ids, decoding_masks, sizes = tree.get( + ids, + max_size=decoding_length, + max_length=branch_length - 1, + min_input_size=min_input_size, + min_output_size=min_output_size, + mode=mode, + idx=idx, + ) + s = len(decoding_ids) + # token count is enough, not need retrieve again + if s >= branch_length: + break + + if decoding_ids is None: + decoding_ids = token_ids[-1:] + + return decoding_ids, decoding_masks, sizes + + def par_get( + self, + token_ids, + decoding_length=16, + branch_length=8, + min_input_size=0, + min_output_size=0, + mode="mix", + idx=0, + ): + + output_ids, decoding_masks, decoding_lengths = self.hier_get( + token_ids, + decoding_length=decoding_length, + branch_length=branch_length, + min_input_size=min_input_size, + min_output_size=min_output_size, + mode=mode, + idx=idx, + ) + sets = [] + true_decoding_length = len(output_ids) - 1 + for i in range(true_decoding_length, 0, -1): + (indices,) = np.nonzero(decoding_masks[i, 1:]) + indices = set(indices) + flag = True + for ss in sets: + if len(indices - ss) == 0: + flag = False + break + if flag: + sets.append(indices) + + sets.reverse() + count = 0 + max_decoding_length = true_decoding_length + branches = [] + for indices in sets: + indices = sorted(list(indices)) + rest_count = max_decoding_length - count + indices = indices[:rest_count] + count += len(indices) + branch = [] + for i in indices: + branch.append(output_ids[i + 1]) + branches.append(branch) + if count >= max_decoding_length: + break + ids = [output_ids[0]] + masks = np.tril(np.ones((count + 1, count + 1)), 0) + count = 1 + for branch in branches: + ids.extend(branch) + length = len(branch) + masks[count : count + length, 1:count] = 0 + count += length + + return ids, masks, [count - 1] + + def one_get( + self, + token_ids, + decoding_length=64, + branch_length=8, + min_input_size=0, + min_output_size=0, + mode="mix", + idx=0, + ): + assert mode in ("input", "output", "mix") + + max_decoding_masks = self.default_mask + if decoding_length <= 1 or branch_length == 0: + return token_ids[-1:], max_decoding_masks, [] + + max_decoding_ids = None + max_sizes = [0, 0] + for i, t in enumerate(token_ids): + tree = self.mem.get(t, None) + if tree is not None: + ids = token_ids[i + 1 :] + if t in self.stop_words and len(ids) == 0: + continue + decoding_ids, sizes = tree.get_one_branch( + ids, max_length=branch_length - 1, mode=mode, idx=idx + ) + s = len(decoding_ids) + decoding_masks = np.tril(np.ones((s, s), 
dtype=np.int64), 0) + + if max_decoding_ids is None: + max_decoding_ids = decoding_ids + max_decoding_masks = decoding_masks + max_sizes = sizes + if s > len(max_decoding_ids): + max_decoding_ids = decoding_ids + max_decoding_masks = decoding_masks + max_sizes = sizes + # token count is enough, not need retrieve again + if s >= branch_length // 2: + break + if max_decoding_ids is None: + max_decoding_ids = token_ids[-1:] + + return max_decoding_ids, max_decoding_masks, max_sizes + + def bat_get( + self, + token_id_list, + decoding_length=64, + branch_length=8, + decoding_cursors=None, + mode="output", + indices=None, + decoding_mode="hier", + ): + assert mode in ("input", "output", "mix") + assert decoding_mode in ("hier", "one") + bs = len(token_id_list) + assert bs == len(decoding_cursors) and bs == len( + indices + ), f"{bs=} {len(decoding_cursors)=} {len(indices)=}" + + decoding_id_list = [] + decoding_mask_list = [] + size_list = [] + + min_cur = min(decoding_cursors) + max_cur = max(decoding_cursors) + bs = len(decoding_cursors) + for sub_idx, token_ids in enumerate(token_id_list): + update_decoding_length = decoding_length // bs + min_input_size = 0 + min_output_size = max(update_decoding_length // 2, 1) + method_name = decoding_mode + "_get" + decoding_ids, decoding_masks, sizes = getattr(self, method_name)( + token_ids, + decoding_length=update_decoding_length, + branch_length=branch_length, + min_input_size=min_input_size, + min_output_size=min_output_size, + mode=mode, + idx=indices[sub_idx], + ) + decoding_id_list.append(decoding_ids) + decoding_mask_list.append(decoding_masks) + size_list.append(sizes) + + bs = len(token_id_list) + max_size = max([len(x) for x in decoding_id_list]) + + decoding_masks = np.zeros( + (bs, max_size, max_cur - min_cur + max_size), dtype=np.int64 + ) + for i, decoding_ids in enumerate(decoding_id_list): + org_size = len(decoding_ids) + gap = max_size - org_size + if gap > 0: + decoding_ids.extend([0] * gap) + cur = decoding_cursors[i] + decoding_masks[i, :org_size, cur - min_cur : cur - min_cur + org_size] = ( + decoding_mask_list[i] + ) + decoding_masks[i, :, : cur - min_cur + 1] = 1 + return decoding_id_list, decoding_masks, size_list + + def fresh(self): + self.mem = {} + + def reset_input_freqs(self, idx): + if len(self._update_input_trees) > 0: + for t in self._update_input_trees: + t.reset_input_freq(idx) + self._update_input_trees.clear() + + def squeeze_branch_counts(self): + if len(self._update_trees) >= 1024: + for t in self._update_trees: + t.squeeze() + self._update_trees.clear() + + def save_mem(self, save_dir): + serialized_object = pickle.dumps(self.mem) + json_string = json.dumps(serialized_object.decode("latin-1")) + with open(save_dir, "w") as f: + json.dump(json_string, f) + + def load_mem(self, load_dir): + with open(load_dir, "r") as f: + json_string = json.load(f) + self.mem = pickle.loads(json.loads(json_string).encode("latin-1")) diff --git a/python/sglang/srt/speculative/lookahead_utils.py b/python/sglang/srt/speculative/lookahead_utils.py new file mode 100644 index 00000000000..8255732bb01 --- /dev/null +++ b/python/sglang/srt/speculative/lookahead_utils.py @@ -0,0 +1,182 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, List, Type + +import numpy as np +import torch +import triton +import triton.language as tl + +from sglang.srt.speculative.spec_info import SpecInfo + +if TYPE_CHECKING: + from python.sglang.srt.managers.schedule_batch import ScheduleBatch + +from sglang.srt.speculative.eagle_utils 
import ( + assign_req_to_token_pool, + create_flashinfer_kv_indices_triton, + eagle_verify_retrive, +) + + +class LookaheadVerifyInput(SpecInfo): + def __init__( + self, + draft_token: torch.Tensor, + tree_mask: torch.Tensor, + positions: torch.Tensor, + retrive_index: torch.Tensor, + retrive_cum_len: torch.Tensor, + draft_token_num: torch.Tensor, + ): + self.draft_token = draft_token + self.custom_mask = tree_mask + self.positions = positions + self.retrive_index = retrive_index + self.retrive_cum_len = retrive_cum_len + self.draft_token_num = draft_token_num + self.draft_token_num_sum = draft_token_num.sum().item() + + def prepare_for_verify(self, batch: ScheduleBatch): + batch.input_ids = self.draft_token + batch.out_cache_loc = batch.alloc_token_slots(batch.input_ids.numel()) + bs = batch.seq_lens.numel() + assign_req_to_token_pool[(bs,)]( + batch.req_pool_indices, + batch.req_to_token_pool.req_to_token, + batch.seq_lens, + batch.seq_lens + self.draft_token_num, + batch.out_cache_loc, + batch.req_to_token_pool.req_to_token.shape[1], + triton.next_power_of_2(bs), + ) + + def generate_attn_arg_prefill( + self, + req_pool_indices: torch.Tensor, + paged_kernel_lens: torch.Tensor, + req_to_token: torch.Tensor, + ): + batch_size = len(req_pool_indices) + + cum_kv_seq_len = torch.zeros( + (batch_size + 1,), dtype=torch.int32, device="cuda" + ) + + paged_kernel_lens = paged_kernel_lens + self.draft_token_num + cum_kv_seq_len[1:] = torch.cumsum(paged_kernel_lens, dim=0) + + self.qo_indptr = torch.zeros( + (batch_size + 1,), dtype=torch.int32, device="cuda" + ) + self.qo_indptr[1:] = torch.cumsum(self.draft_token_num, dim=0) + + kv_indices = torch.empty(cum_kv_seq_len[-1], dtype=torch.int32, device="cuda") + + create_flashinfer_kv_indices_triton[(batch_size,)]( + req_to_token, + req_pool_indices, + paged_kernel_lens, + cum_kv_seq_len, + None, + kv_indices, + req_to_token.size(1), + ) + return kv_indices, cum_kv_seq_len, self.qo_indptr, self.custom_mask + + def verify(self, batch: ScheduleBatch, logits_output: torch.Tensor) -> torch.Tensor: + bs = self.retrive_cum_len.numel() - 1 + predict = torch.argmax(logits_output.next_token_logits, dim=-1) + predict = torch.cat( + [predict, torch.full([1], -1, dtype=torch.long, device="cuda")], dim=-1 + ) + draft_token = torch.cat( + [self.draft_token, torch.full([1], -1, dtype=torch.long, device="cuda")], + dim=-1, + ) + target_predict = predict[self.retrive_index] + candidates = draft_token[self.retrive_index] + accept_mask = candidates[:, 1:] == target_predict[:, :-1] + accept_mask = (torch.cumprod(accept_mask, dim=1)).sum(dim=1) + + max_draft_len = self.retrive_index.shape[-1] + accept_index = torch.full( + (bs, max_draft_len), -1, dtype=torch.long, device="cuda" + ) + accept_length = torch.empty((bs,), dtype=torch.int, device="cuda") + extract_index = torch.full((bs * 2,), 0, dtype=torch.int, device="cuda") + eagle_verify_retrive[(bs,)]( + self.retrive_index.contiguous(), + accept_mask.contiguous(), + self.retrive_cum_len, + accept_index, + accept_length, + extract_index, + max_draft_len, + triton.next_power_of_2(self.draft_token_num.max().item()), + triton.next_power_of_2(max_draft_len), + ) + + eos_token_id = batch.reqs[0].tokenizer.eos_token_id + # TODO: check other criteria for end check + mask_tensor = (predict[accept_index] == eos_token_id).int() + first_true_indices = torch.argmax(mask_tensor, dim=1) + has_true = torch.any(mask_tensor, dim=1) + if torch.any(mask_tensor): + batch_size, seq_length = accept_index.shape + range_vec = ( + 
torch.arange(seq_length, device="cuda") + .unsqueeze(0) + .repeat(batch_size, 1) + ) # shape: (batch_size, seq_length) + threshold = first_true_indices.unsqueeze(1) + 1 # shape: (batch_size, 1) + mask = (range_vec >= threshold) & has_true.unsqueeze( + 1 + ) # shape: (batch_size, seq_length) + accept_index[mask] = -1 + + accept_length = (accept_index != -1).sum(dim=1) + + accept_index_flatten = accept_index[accept_index != -1] + + assign_req_to_token_pool[(bs,)]( + batch.req_pool_indices, + batch.req_to_token_pool.req_to_token, + batch.seq_lens, + batch.seq_lens + accept_length, + batch.out_cache_loc[accept_index_flatten], + batch.req_to_token_pool.req_to_token.shape[1], + triton.next_power_of_2(bs), + ) + batch.seq_lens.add_(accept_length) # TODO: mcheck the case for normal decoding + batch.seq_lens_sum = batch.seq_lens.sum().item() + + evict_mask = torch.full_like(self.draft_token, True, dtype=torch.bool) + evict_mask[accept_index_flatten] = False + mem_need_free_idx = batch.out_cache_loc[evict_mask] + batch.token_to_kv_pool.free(mem_need_free_idx) + + last_verified_ids = [] + accept_token_bs = [] + for i in range(bs): + req = batch.reqs[i] + accept_ids = torch.where(accept_index[i] != -1)[0] + accept_token = predict[accept_index[i][accept_ids]] + accept_token_cpu = accept_token.tolist() + req.output_ids.extend(accept_token_cpu) + accept_token_bs.append(accept_token) + req.check_finished() + # need to append the token for scheduler process_batch_result_decode to work + last_verified_ids.append(req.output_ids[-1]) + + verified_id = predict[accept_index_flatten] + verified_id_cpu = verified_id.tolist() + + last_verified_ids = torch.tensor(last_verified_ids, device="cuda") + logits_output.next_token_logits = logits_output.next_token_logits[ + accept_index_flatten + ] + return logits_output, last_verified_ids, accept_length.sum().item() + + def merge_batch(self, spec_info: LookaheadVerifyInput): + return diff --git a/python/sglang/srt/speculative/lookahead_worker.py b/python/sglang/srt/speculative/lookahead_worker.py new file mode 100644 index 00000000000..feeff02f57b --- /dev/null +++ b/python/sglang/srt/speculative/lookahead_worker.py @@ -0,0 +1,213 @@ +import logging +import threading +import time +from typing import TYPE_CHECKING, List, Optional, Union + +import numpy as np +import torch + +from sglang.srt.layers.logits_processor import LogitsProcessorOutput +from sglang.srt.managers.schedule_batch import ModelWorkerBatch, Req, ScheduleBatch +from sglang.srt.model_executor.forward_batch_info import ForwardMode +from sglang.srt.server_args import ServerArgs +from sglang.srt.speculative.lookahead_cache import LookaheadCache +from sglang.srt.speculative.lookahead_utils import LookaheadVerifyInput +from sglang.srt.speculative.spec_info import SpecInfo, SpeculativeAlgorithm +from sglang.srt.utils import broadcast_pyobj + +if TYPE_CHECKING: + from sglang.srt.managers.tp_worker import TpModelWorker + from sglang.srt.model_executor.model_runner import ModelRunner + +logger = logging.getLogger(__name__) + + +class LOOKAHEADWorker: + def __init__( + self, + server_args: ServerArgs, + gpu_id: int, + tp_rank: int, + dp_rank: Optional[int], + nccl_port: int, + target_worker: "TpModelWorker", + ): + self.target_worker = target_worker + self.model_runner = target_worker.model_runner + self.tp_rank = tp_rank + self.num_branch_token: int = server_args.speculative_num_draft_tokens + self.one_branch = server_args.speculative_one_branch + self.lookahead_cache = None + if tp_rank == 0: + self.lookahead_cache 
= LookaheadCache( + debug=False, eos_ids=None, gpu_id=tp_rank + ) + if server_args.speculative_lookahead_path is not None: + logger.info( + f"Load lookahead from: {server_args.speculative_lookahead_path}" + ) + self.lookahead_cache.load_mem(server_args.speculative_lookahead_path) + self.rids = {} + + def prepare_for_verify(self, batch: ScheduleBatch): + bs = len(batch.reqs) + if bs > 4: + # if batch size is too large, fallback to normal decode + batch.spec_algorithm = SpeculativeAlgorithm.NONE + return None + ( + seq_lens, + leaf_nums, + look_ahead_res, + drafts, + tree_mask, + positions, + retrive_indexes, + draft_token_nums, + ) = ([], [], [], [], [], [], [], []) + for req in batch.reqs: + fill_ids = req.origin_input_ids + req.output_ids + seq_len = len(fill_ids) + seq_lens.append(seq_len) + if self.lookahead_cache is not None: + check_token = fill_ids[-self.num_branch_token :] + # make total_draft_len 2^n + total_draft_len = ( + 2 ** ((self.num_branch_token * 4 // bs) - 1).bit_length() + ) + is_one_branch = ( + self.one_branch or total_draft_len <= self.num_branch_token + ) + if is_one_branch: + req_drafts, mask, _ = self.lookahead_cache.one_get( + check_token, + branch_length=self.num_branch_token, + idx=self.rids[req.rid], + ) + else: + req_drafts, mask, _ = self.lookahead_cache.hier_get( + check_token, + idx=self.rids[req.rid], + branch_length=self.num_branch_token, + decoding_length=total_draft_len, + ) + data = broadcast_pyobj( + [req_drafts, mask], + self.tp_rank, + self.model_runner.tp_group.cpu_group, + ) + + else: + (req_drafts, mask) = broadcast_pyobj( + [], + self.tp_rank, + self.model_runner.tp_group.cpu_group, + ) + look_ahead_res.append((req_drafts, mask)) + # number of draft tokens might be different for each req + draft_token_nums.append(len(req_drafts)) + + # check the draft_token_nums all 1 s, if no match just normal decode + if np.sum(draft_token_nums) == bs: + batch.spec_algorithm = SpeculativeAlgorithm.NONE + return None + + cum_draft_token_nums = np.cumsum([0] + draft_token_nums) + for i, (req_drafts, mask_) in enumerate(look_ahead_res): + seq_len = seq_lens[i] + mask = torch.from_numpy(mask_).cuda() + req_mask = torch.ones( + (len(req_drafts), seq_len - 1) + ).cuda() # TODO: check the new generated token + req_mask = torch.cat((req_mask, mask), dim=1).to(torch.bool) + tree_mask.append(req_mask.flatten()) + + leaf_mask = mask_[np.argmax(mask_[:, mask_.sum(0) == 1], axis=0), :] + leaf_num = leaf_mask.shape[0] + leaf_nums.append(leaf_num) + row_indices, col_indices = np.nonzero(leaf_mask) + retrieve_index = [[] for _ in range(leaf_num)] + for row, col in zip(row_indices, col_indices): + retrieve_index[row].append(col + cum_draft_token_nums[i]) + for idxs in retrieve_index: + idxs.extend([-1] * (self.num_branch_token - len(idxs))) + + retrieve_index = torch.tensor(retrieve_index, device="cuda") + retrive_indexes.append(retrieve_index) + position = mask.sum(1) + seq_len - 1 + positions.append(position) + + drafts.extend(req_drafts) + + # only one row for each req for one branch case + leaf_nums = torch.tensor(leaf_nums, device="cuda") + cum_len = torch.cumsum(leaf_nums, dim=0) + retrive_cum_len = torch.zeros( + (leaf_nums.numel() + 1,), dtype=torch.int32, device="cuda" + ) + retrive_cum_len[1:] = cum_len + + draft_tokens = torch.tensor(drafts, device="cuda") + self.draft_token_nums = torch.tensor(draft_token_nums, device="cuda") + retrive_indexes = torch.vstack(retrive_indexes).to(torch.long).cuda() + positions = torch.cat(positions, axis=0).to(torch.long) + tree_mask 
= torch.cat(tree_mask, axis=0) + batch.spec_algorithm = SpeculativeAlgorithm.LOOKAHEAD + return LookaheadVerifyInput( + draft_tokens, + tree_mask, + positions, + retrive_indexes, + retrive_cum_len, + self.draft_token_nums, + ) + + def forward_batch_speculative_generation(self, batch: ScheduleBatch): + if batch.forward_mode.is_target_verify(): + verify_input = batch.spec_info + model_worker_batch = batch.get_model_worker_batch() + logits_output, _ = self.target_worker.forward_batch_generation( + model_worker_batch, skip_sample=True + ) + batch.forward_mode = ForwardMode.DECODE + logits_output, verified_id, accept_length_sum = verify_input.verify( + batch, logits_output + ) + return logits_output, verified_id, model_worker_batch, accept_length_sum + + else: + model_worker_batch = batch.get_model_worker_batch() + logits_output, next_token_ids = self.target_worker.forward_batch_generation( + model_worker_batch + ) + if self.lookahead_cache is not None: + next_token_ids_cpu = next_token_ids.tolist() + for r, token in zip(batch.reqs, next_token_ids_cpu): + self.rids[r.rid] = len(self.rids) + put_ids = r.fill_ids + [token] + self.lookahead_cache.put( + put_ids[1:], + branch_length=self.num_branch_token * 2, + mode="input", + idx=self.rids[r.rid], + ) + return logits_output, next_token_ids, model_worker_batch, 0 + + def finish_request(self, reqs: Union[Req, List[Req]]): + if not isinstance(reqs, List): + reqs = [reqs] + for req in reqs: + if self.lookahead_cache is not None: + put_ids = ( + req.origin_input_ids[-self.num_branch_token :] + req.output_ids + ) + # update the lookahead_cache after the request is finished, and do the clean up + self.lookahead_cache.put( + put_ids, + branch_length=self.num_branch_token * 2, + mode="output", + idx=self.rids[req.rid], + final=True, + ) + if len(self.rids) >= 1000: + self.rids = dict(list(self.rids.items())[-500:]) diff --git a/python/sglang/srt/speculative/spec_info.py b/python/sglang/srt/speculative/spec_info.py index 5f156b837f9..611f2786684 100644 --- a/python/sglang/srt/speculative/spec_info.py +++ b/python/sglang/srt/speculative/spec_info.py @@ -4,6 +4,7 @@ class SpeculativeAlgorithm(IntEnum): NONE = auto() EAGLE = auto() + LOOKAHEAD = auto() def is_none(self): return self == SpeculativeAlgorithm.NONE @@ -11,10 +12,14 @@ def is_none(self): def is_eagle(self): return self == SpeculativeAlgorithm.EAGLE + def is_lookahead(self): + return self == SpeculativeAlgorithm.LOOKAHEAD + @staticmethod def from_string(name: str): name_map = { "EAGLE": SpeculativeAlgorithm.EAGLE, + "LOOKAHEAD": SpeculativeAlgorithm.LOOKAHEAD, None: SpeculativeAlgorithm.NONE, } return name_map[name] diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py index d263bc11369..e0a28763fb4 100644 --- a/test/srt/run_suite.py +++ b/test/srt/run_suite.py @@ -17,6 +17,7 @@ "test_custom_allreduce.py", "test_double_sparsity.py", "test_eagle_infer.py", + "test_lookahead_infer.py", "test_embedding_openai_server.py", "test_eval_accuracy_mini.py", "test_gguf.py", diff --git a/test/srt/test_lookahead_infer.py b/test/srt/test_lookahead_infer.py new file mode 100644 index 00000000000..62a483063ed --- /dev/null +++ b/test/srt/test_lookahead_infer.py @@ -0,0 +1,36 @@ +import unittest + +import sglang as sgl + + +class TestLOOKAHEADEngine(unittest.TestCase): + + def test_lookahead_accuracy(self): + prompt = "[INST] <>\\nYou are a helpful assistant.\\n<>\\nWho are you? 
[/INST]" + target_model_path = "meta-llama/Llama-2-7b-chat-hf" + + sampling_params = {"temperature": 0.0001, "max_new_tokens": 20, "top_k": 1} + + engine = sgl.Engine( + model_path=target_model_path, + speculative_algorithm="LOOKAHEAD", + speculative_num_draft_tokens=4, + speculative_one_branch=True, + ) + out1 = engine.generate(prompt, sampling_params)["text"] + engine.shutdown() + + engine = sgl.Engine(model_path=target_model_path) + out2 = engine.generate(prompt, sampling_params)["text"] + engine.shutdown() + + print("==== Answer 1 ====") + print(out1) + + print("==== Answer 2 ====") + print(out2) + self.assertEqual(out1, out2) + + +if __name__ == "__main__": + unittest.main()