Speculative decoding with lookahead #2790

Open · wants to merge 3 commits into main
5 changes: 5 additions & 0 deletions python/sglang/srt/managers/schedule_batch.py
@@ -1049,6 +1049,11 @@ def prepare_for_decode(self):
self.forward_mode = ForwardMode.DECODE
if self.spec_algorithm.is_eagle():
return
elif self.spec_algorithm.is_lookahead():
self.spec_info.prepare_for_verify(self)
# overwrite the forward_mode
self.forward_mode = ForwardMode.TARGET_VERIFY
return

self.input_ids = self.output_ids
self.output_ids = None
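Note: with lookahead enabled, prepare_for_decode() no longer sets up a one-token decode step; it calls prepare_for_verify() and switches the batch to ForwardMode.TARGET_VERIFY, so the target model scores a whole window of draft tokens in a single forward pass. The acceptance logic itself lives in LookaheadVerifyInput (not shown in this diff) and also supports tree-shaped drafts through a custom attention mask; the snippet below is only a simplified, greedy, linear-chain illustration of what a verify step computes, with illustrative shapes.

import torch

def greedy_verify(logits: torch.Tensor, draft_tokens: list) -> list:
    # logits: [k, vocab] target logits from one TARGET_VERIFY forward, where
    # logits[i] is the target's next-token distribution after consuming the
    # context plus draft_tokens[: i + 1]. draft_tokens[0] is the already
    # committed last token; the rest are speculative.
    committed = []
    for i in range(len(draft_tokens)):
        pred = int(logits[i].argmax())
        committed.append(pred)
        if i + 1 >= len(draft_tokens) or draft_tokens[i + 1] != pred:
            break  # first mismatch, or drafts exhausted (bonus token)
    return committed

logits = torch.randn(4, 32000)                 # 4 draft positions, toy vocab size
print(greedy_verify(logits, [11, 23, 42, 7]))  # always commits at least one token
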
15 changes: 14 additions & 1 deletion python/sglang/srt/managers/scheduler.py
@@ -146,7 +146,7 @@ def __init__(
)
self.decode_mem_cache_buf_multiplier = (
self.server_args.speculative_num_draft_tokens
if not self.spec_algorithm.is_none()
if self.spec_algorithm.is_eagle()
else 1
)
self.enable_hierarchical_cache = server_args.enable_hierarchical_cache
@@ -257,6 +257,17 @@ def __init__(
target_worker=self.tp_worker,
dp_rank=dp_rank,
)
elif self.spec_algorithm.is_lookahead():
from sglang.srt.speculative.lookahead_worker import LOOKAHEADWorker

self.draft_worker = LOOKAHEADWorker(
gpu_id=gpu_id,
tp_rank=tp_rank,
server_args=server_args,
dp_rank=dp_rank,
nccl_port=port_args.nccl_port,
target_worker=self.tp_worker,
)
else:
self.draft_worker = None

@@ -1064,6 +1075,8 @@ def update_running_batch(self, batch: ScheduleBatch) -> Optional[ScheduleBatch]:
if batch.batch_size() < initial_bs:
self.batch_is_full = False

if self.spec_algorithm.is_lookahead():
batch.spec_info = self.draft_worker.prepare_for_verify(batch)
# Update batch tensors
batch.prepare_for_decode()
return batch
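Note: unlike the EAGLE branch above, LOOKAHEADWorker does not run a learned draft model; draft tokens are retrieved from a lookahead cache of previously seen token sequences (see --speculative-lookahead-path in server_args.py below). The toy cache below is an assumed illustration of that retrieve-by-suffix idea only — the real LookaheadCache/LOOKAHEADWorker interfaces are not part of this diff and are presumably richer (multi-branch tries, frequency-based eviction).

from collections import defaultdict

class ToyNgramCache:
    # Toy stand-in for a lookahead cache: maps a 2-token suffix to one continuation.
    def __init__(self, branch_len: int = 4):
        self.branch_len = branch_len  # max draft tokens proposed per lookup
        self.table = defaultdict(list)

    def put(self, token_ids: list) -> None:
        # Index every 2-gram of a seen sequence to the tokens that followed it.
        for i in range(len(token_ids) - 2):
            key = tuple(token_ids[i : i + 2])
            continuation = token_ids[i + 2 : i + 2 + self.branch_len]
            if len(continuation) > len(self.table[key]):
                self.table[key] = continuation

    def get(self, token_ids: list) -> list:
        # Propose a continuation for the current suffix; [] means no speculation.
        return self.table.get(tuple(token_ids[-2:]), [])

cache = ToyNgramCache()
cache.put([5, 8, 13, 21, 34, 55])
print(cache.get([3, 5, 8]))  # -> [13, 21, 34, 55]

On a hit the batch is prepared as a TARGET_VERIFY step via prepare_for_verify(); on a miss it presumably degrades to an ordinary decode step (see the cuda graph fallback comment in model_runner.py below).
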
65 changes: 54 additions & 11 deletions python/sglang/srt/model_executor/cuda_graph_runner.py
@@ -106,7 +106,7 @@ def set_torch_compile_config():
torch._dynamo.config.cache_size_limit = 1024


def get_batch_sizes_to_capture(model_runner: ModelRunner):
def get_batch_sizes_to_capture(model_runner: ModelRunner, is_spec=False):
server_args = model_runner.server_args
capture_bs = server_args.cuda_graph_bs
if capture_bs is None:
@@ -126,12 +126,17 @@ def get_batch_sizes_to_capture(model_runner: ModelRunner):
)
)
)
capture_bs = [
bs
for bs in capture_bs
if bs <= model_runner.req_to_token_pool.size
and bs <= server_args.cuda_graph_max_bs
]
if is_spec:
# For speculative inference, large batch sizes are not effective.
capture_bs = [1, 2, 3, 4]
else:
capture_bs = [
bs
for bs in capture_bs
if bs <= model_runner.req_to_token_pool.size
and bs <= server_args.cuda_graph_max_bs
]

if is_hip_:
capture_bs += [i * 8 for i in range(21, 33)]
compile_bs = (
@@ -158,7 +163,7 @@ def set_global_graph_memory_pool(val):
class CudaGraphRunner:
"""A CudaGraphRunner runs the forward pass of a model with cuda graph and torch.compile."""

def __init__(self, model_runner: ModelRunner):
def __init__(self, model_runner: "ModelRunner", is_spec=False):
# Parse args
self.model_runner = model_runner
self.graphs = {}
@@ -169,9 +174,11 @@ def __init__(self, model_runner: ModelRunner):
self.enable_dp_attention = model_runner.server_args.enable_dp_attention
self.tp_size = model_runner.server_args.tp_size
self.dp_size = model_runner.server_args.dp_size

self.is_spec = is_spec
# Batch sizes to capture
self.capture_bs, self.compile_bs = get_batch_sizes_to_capture(model_runner)
self.capture_bs, self.compile_bs = get_batch_sizes_to_capture(
model_runner, is_spec
)
self.capture_forward_mode = ForwardMode.DECODE
self.num_tokens_per_bs = 1
if model_runner.spec_algorithm.is_eagle():
@@ -183,6 +190,12 @@ def __init__(self, model_runner: ModelRunner):
self.model_runner.server_args.speculative_num_draft_tokens
)

if model_runner.spec_algorithm.is_lookahead() and self.is_spec:
self.capture_forward_mode = ForwardMode.TARGET_VERIFY
self.num_tokens_per_bs = (
self.model_runner.server_args.speculative_num_draft_tokens
)

# Attention backend
self.max_bs = max(self.capture_bs)
self.max_num_token = self.max_bs * self.num_tokens_per_bs
@@ -213,6 +226,11 @@ def __init__(self, model_runner: ModelRunner):
(self.max_num_token, self.model_runner.model_config.hidden_size),
dtype=self.model_runner.dtype,
)
if self.is_spec:
self.draft_token_num = (
torch.ones((len(self.capture_bs),), dtype=torch.int32)
* self.num_tokens_per_bs
)

if self.is_encoder_decoder:
# NOTE: encoder_lens can influence the full_text_row_masked_out_mask tensor when doing mixed batch
@@ -281,7 +299,14 @@ def can_run(self, forward_batch: ForwardBatch):
if self.is_encoder_decoder
else True
)
return is_bs_supported and is_encoder_lens_supported

is_token_num_supported = True
if self.is_spec:
is_token_num_supported = (
forward_batch.batch_size * self.num_tokens_per_bs
== forward_batch.input_ids.numel()
)
return is_token_num_supported and is_bs_supported and is_encoder_lens_supported

def capture(self):
with graph_capture() as graph_capture_context:
@@ -481,5 +506,23 @@ def get_spec_info(self, num_tokens: int):
spec_steps=self.model_runner.server_args.speculative_num_steps,
capture_hidden_mode=CaptureHiddenMode.FULL,
)
if self.model_runner.spec_algorithm.is_lookahead() and self.is_spec:
from sglang.srt.speculative.lookahead_utils import LookaheadVerifyInput

bs = int(num_tokens / self.num_tokens_per_bs)
spec_info = LookaheadVerifyInput(
None,
None,
None,
None,
None,
self.draft_token_num[:bs],
)
spec_info.capture_hidden_mode = CaptureHiddenMode.NULL
spec_info.custom_mask = torch.zeros(
(num_tokens * self.model_runner.model_config.context_len),
dtype=torch.bool,
device="cuda",
)

return spec_info
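Note: the is_spec=True runner captures a second family of graphs that run in TARGET_VERIFY mode, carry speculative_num_draft_tokens tokens per request, and are only captured for small batch sizes (the [1, 2, 3, 4] list above). Replay is gated by the token-count check added to can_run(). The sketch below only restates those sizing relationships with illustrative numbers (not values from this PR), including the flattened custom_mask buffer allocated in get_spec_info(); the mask is allocated on CPU here, whereas the runner puts it on "cuda".

import torch

speculative_num_draft_tokens = 8   # illustrative value of the server argument
context_len = 4096                 # illustrative model context length
capture_bs = [1, 2, 3, 4]          # spec graphs are only captured for small batches

for bs in capture_bs:
    # can_run() requires batch_size * num_tokens_per_bs == input_ids.numel()
    num_tokens = bs * speculative_num_draft_tokens
    # Flattened boolean attention mask: one context_len row per verified token.
    custom_mask = torch.zeros(num_tokens * context_len, dtype=torch.bool)
    print(f"bs={bs}: {num_tokens} tokens per replay, mask numel={custom_mask.numel()}")
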
17 changes: 13 additions & 4 deletions python/sglang/srt/model_executor/model_runner.py
@@ -595,7 +595,7 @@ def init_memory_pool(
4096,
)

if not self.spec_algorithm.is_none():
if self.spec_algorithm.is_eagle():
if self.is_draft_worker:
self.max_total_num_tokens = self.server_args.draft_runner_cache_size
else:
@@ -717,6 +717,7 @@ def init_double_sparsity_channel_config(self, selected_channel):
def init_cuda_graphs(self):
"""Capture cuda graphs."""
self.cuda_graph_runner = None
self.cuda_graph_runner_spec = None

if not self.is_generation:
# TODO: Currently, cuda graph only captures decode steps, which only exists for generation models
@@ -728,6 +729,10 @@ def init_cuda_graphs(self):
tic = time.time()
logger.info("Capture cuda graph begin. This can take up to several minutes.")
self.cuda_graph_runner = CudaGraphRunner(self)
if self.spec_algorithm.is_lookahead():
# In case lookahead fails to match any draft tokens, fall back to the normal cuda graph decode
self.cuda_graph_runner_spec = CudaGraphRunner(self, is_spec=True)

logger.info(f"Capture cuda graph end. Time elapsed: {time.time() - tic:.2f} s")

def apply_torch_tp(self):
Expand Down Expand Up @@ -772,12 +777,16 @@ def forward_idle(self, forward_batch: ForwardBatch):
)

def forward(self, forward_batch: ForwardBatch) -> LogitsProcessorOutput:
cuda_graph_runner = self.cuda_graph_runner
if forward_batch.spec_algorithm.is_lookahead():
cuda_graph_runner = self.cuda_graph_runner_spec

if (
forward_batch.forward_mode.is_cuda_graph()
and self.cuda_graph_runner
and self.cuda_graph_runner.can_run(forward_batch)
and cuda_graph_runner
and cuda_graph_runner.can_run(forward_batch)
):
return self.cuda_graph_runner.replay(forward_batch)
return cuda_graph_runner.replay(forward_batch)

if forward_batch.forward_mode.is_decode():
return self.forward_decode(forward_batch)
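Note: init_cuda_graphs() now builds two runners (the plain decode graphs plus the TARGET_VERIFY set), and forward() picks between them per batch. The helper below is a hypothetical condensation of that dispatch, for illustration only; how a batch whose lookahead lookup matched nothing gets routed back onto the plain decode graphs is handled by the scheduler/worker code and is not fully visible in this diff.

def pick_graph_runner(cuda_graph_runner, cuda_graph_runner_spec, forward_batch):
    # Hypothetical restatement of the forward() change above.
    runner = cuda_graph_runner
    if forward_batch.spec_algorithm.is_lookahead():
        # Verify steps replay the graphs captured with is_spec=True.
        runner = cuda_graph_runner_spec
    if (
        forward_batch.forward_mode.is_cuda_graph()
        and runner is not None
        and runner.can_run(forward_batch)
    ):
        return runner  # replay this cuda graph
    return None        # fall through to the eager forward path
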
27 changes: 26 additions & 1 deletion python/sglang/srt/server_args.py
@@ -126,6 +126,8 @@ class ServerArgs:
speculative_num_steps: int = 5
speculative_num_draft_tokens: int = 64
speculative_eagle_topk: int = 8
speculative_lookahead_path: str = None
speculative_one_branch: bool = False

# Double Sparsity
enable_double_sparsity: bool = False
@@ -270,6 +272,18 @@ def __post_init__(self):
"The radix cache, chunked prefill, and overlap scheduler are disabled because of using eagle speculative decoding."
)

if self.speculative_algorithm == "LOOKAHEAD":
self.disable_overlap_schedule = True
self.chunked_prefill_size = -1
self.disable_mla = True
self.enable_double_sparsity = False
assert (
self.attention_backend == "flashinfer"
), "Lookahead speculative decoding only support flashinfer for now."
logger.info(
"The mla, chunked_prefill, overlap scheduler and double_sparsity are disabled because of lookahead speculative decoding."
)

# GGUF
if (
self.load_format == "auto" or self.load_format == "gguf"
@@ -698,14 +712,20 @@ def add_cli_args(parser: argparse.ArgumentParser):
parser.add_argument(
"--speculative-algorithm",
type=str,
choices=["EAGLE"],
choices=["EAGLE", "LOOKAHEAD"],
help="Speculative algorithm.",
)
parser.add_argument(
"--speculative-draft-model-path",
type=str,
help="The path of the draft model weights. This can be a local folder or a Hugging Face repo ID.",
)
parser.add_argument(
"--speculative-lookahead-path",
type=str,
help="The path of the lookahead. If provided, the lookahead will be inited from this path. You can `lookahead_cache.save_mem('lookahrad.pkl')` to save the lookahead for later use.",
required=False,
)
parser.add_argument(
"--speculative-num-steps",
type=int,
Expand All @@ -725,6 +745,11 @@ def add_cli_args(parser: argparse.ArgumentParser):
choices=[1, 2, 4, 8],
default=ServerArgs.speculative_eagle_topk,
)
parser.add_argument(
"--speculative-one-branch",
action="store_true",
help="Whether to use one branch in Lookahead Speculative Decoding.",
)

# Double Sparsity
parser.add_argument(
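Note: for reference, a hypothetical way to enable the new mode programmatically. model_path and attention_backend are pre-existing ServerArgs fields that do not appear in this diff, and all values are illustrative; the assert in __post_init__ above requires the flashinfer backend.

from sglang.srt.server_args import ServerArgs

server_args = ServerArgs(
    model_path="meta-llama/Llama-3.1-8B-Instruct",  # illustrative model
    attention_backend="flashinfer",                 # required by the assert above
    speculative_algorithm="LOOKAHEAD",
    speculative_num_draft_tokens=8,                 # illustrative value
    speculative_lookahead_path="lookahead.pkl",     # optional pre-built cache
    speculative_one_branch=False,
)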