[WIP] ngram spec #2886

Open · wants to merge 8 commits into main
133 changes: 118 additions & 15 deletions python/sglang/srt/layers/attention/torch_native_backend.py
@@ -91,7 +91,6 @@ def _run_sdpa_forward_extend(
per_req_tokens = req_to_token[req_pool_idx, :seq_len_kv]
per_req_key = k_cache[per_req_tokens].movedim(0, query.dim() - 2)
per_req_value = v_cache[per_req_tokens].movedim(0, query.dim() - 2)

per_req_out_redudant = (
scaled_dot_product_attention(
per_req_query_redudant.unsqueeze(0),
@@ -108,6 +107,94 @@ def _run_sdpa_forward_extend(
start_q, start_kv = end_q, end_kv
return output

def _run_sdpa_forward_target_verify(
self,
query: torch.Tensor,
output: torch.Tensor,
k_cache: torch.Tensor,
v_cache: torch.Tensor,
req_to_token: torch.Tensor,
req_pool_indices: torch.Tensor,
seq_lens: torch.Tensor,
extend_prefix_lens: torch.Tensor,
extend_seq_lens: torch.Tensor,
scaling=None,
enable_gqa=False,
causal=False,
):
"""Run the target verify forward by using torch native sdpa op.

Args:
query: [num_tokens, num_heads, head_size]
output: [num_tokens, num_heads, head_size]
k_cache: [max_total_num_tokens, num_heads, head_size]
v_cache: [max_total_num_tokens, num_heads, head_size]
req_to_token: [max_num_reqs, max_context_len]
req_pool_indices: [num_seqs]
seq_lens: [num_seqs]
extend_prefix_lens: [num_seqs]
extend_seq_lens: [num_seqs]
scaling: float or None
enable_gqa: bool
causal: bool

Returns:
output: [num_tokens, num_heads, head_size]
"""
assert seq_lens.shape[0] == extend_prefix_lens.shape[0]
assert seq_lens.shape[0] == extend_seq_lens.shape[0]

# [num_tokens, num_heads, head_size] -> [num_heads, num_tokens, head_size]
query = query.movedim(0, query.dim() - 2)

start_q, start_kv = 0, 0
for seq_idx in range(seq_lens.shape[0]):
# TODO: this loop processes one sequence per iteration, which is inefficient.
# Optimize the performance later.

extend_seq_len_q = extend_seq_lens[seq_idx]

seq_len_kv = seq_lens[seq_idx]
end_q = start_q + (seq_len_kv - extend_seq_len_q)
end_kv = start_kv + seq_len_kv

per_req_query = query[:, start_q:end_q, :]
per_req_query_redudant = torch.zeros(
(per_req_query.shape[0], seq_len_kv, per_req_query.shape[2]),
dtype=per_req_query.dtype,
device=per_req_query.device,
)

per_req_query_redudant[:, -per_req_query.shape[1] :, :] = per_req_query

# get key and value from cache. per_req_tokens contains the kv cache
# index for each token in the sequence.
req_pool_idx = req_pool_indices[seq_idx]
per_req_tokens = req_to_token[req_pool_idx, :seq_len_kv]
per_req_key = k_cache[per_req_tokens].movedim(0, query.dim() - 2)
per_req_value = v_cache[per_req_tokens].movedim(0, query.dim() - 2)

assert per_req_key.shape[1] == per_req_value.shape[1]
assert per_req_query_redudant.shape[1] == per_req_key.shape[1]

per_req_out_redudant = (
scaled_dot_product_attention(
per_req_query_redudant.unsqueeze(0),
per_req_key.unsqueeze(0),
per_req_value.unsqueeze(0),
enable_gqa=enable_gqa,
scale=scaling,
is_causal=causal,
)
.squeeze(0)
.movedim(query.dim() - 2, 0)
)
output[start_q:end_q, :, :] = per_req_out_redudant[
-per_req_query.shape[1] :, :, :
]
start_q, start_kv = end_q, end_kv
return output

def _run_sdpa_forward_decode(
self,
query: torch.Tensor,
@@ -202,20 +289,36 @@ def forward_extend(
q_ = q.view(-1, layer.tp_q_head_num, layer.qk_head_dim)
o_ = o.view(-1, layer.tp_q_head_num, layer.v_head_dim)

self._run_sdpa_forward_extend(
q_,
o_,
forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id),
forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id),
forward_batch.req_to_token_pool.req_to_token,
forward_batch.req_pool_indices,
forward_batch.seq_lens,
forward_batch.extend_prefix_lens,
forward_batch.extend_seq_lens,
scaling=layer.scaling,
enable_gqa=use_gqa,
causal=not layer.is_cross_attention,
)
if forward_batch.forward_mode.is_target_verify():
self._run_sdpa_forward_target_verify(
q_,
o_,
forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id),
forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id),
forward_batch.req_to_token_pool.req_to_token,
forward_batch.req_pool_indices,
forward_batch.seq_lens,
forward_batch.extend_prefix_lens,
forward_batch.extend_seq_lens,
scaling=layer.scaling,
enable_gqa=use_gqa,
causal=not layer.is_cross_attention,
)
else:
self._run_sdpa_forward_extend(
q_,
o_,
forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id),
forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id),
forward_batch.req_to_token_pool.req_to_token,
forward_batch.req_pool_indices,
forward_batch.seq_lens,
forward_batch.extend_prefix_lens,
forward_batch.extend_seq_lens,
scaling=layer.scaling,
enable_gqa=use_gqa,
causal=not layer.is_cross_attention,
)
return o

def forward_decode(
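Note on the new `_run_sdpa_forward_target_verify` added above: the draft tokens being verified are right-aligned into a zero buffer of length `seq_len_kv`, a single causal SDPA call then attends over the full KV history, and only the trailing draft positions of the result are copied back into `output`. A minimal standalone sketch of that padding trick, with toy shapes that are not taken from the PR:

```python
import torch
from torch.nn.functional import scaled_dot_product_attention

num_heads, head_size = 2, 8
seq_len_kv, num_draft = 6, 3  # 3 cached prefix tokens + 3 draft tokens

query = torch.randn(num_heads, num_draft, head_size)   # queries for the draft tokens only
key = torch.randn(num_heads, seq_len_kv, head_size)    # full KV history incl. draft tokens
value = torch.randn(num_heads, seq_len_kv, head_size)

# Right-align the draft-token queries into a seq_len_kv-long buffer of zeros,
# mirroring per_req_query_redudant in the PR.
query_padded = torch.zeros(num_heads, seq_len_kv, head_size)
query_padded[:, -num_draft:, :] = query

out = scaled_dot_product_attention(
    query_padded.unsqueeze(0),
    key.unsqueeze(0),
    value.unsqueeze(0),
    is_causal=True,
).squeeze(0)

# Keep only the outputs at the draft-token positions.
draft_out = out[:, -num_draft:, :]
print(draft_out.shape)  # torch.Size([2, 3, 8])
```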
4 changes: 3 additions & 1 deletion python/sglang/srt/layers/logits_processor.py
@@ -74,7 +74,9 @@ class LogitsMetadata:

@classmethod
def from_forward_batch(cls, forward_batch: ForwardBatch):
if forward_batch.spec_info:
if forward_batch.spec_info and hasattr(
forward_batch.spec_info, "capture_hidden_mode"
):
capture_hidden_mode = forward_batch.spec_info.capture_hidden_mode
else:
capture_hidden_mode = CaptureHiddenMode.NULL
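The `hasattr` guard above exists because not every `spec_info` type carries a `capture_hidden_mode` attribute (the ngram spec info constructed later in this PR apparently does not). A self-contained illustration of the guard, using stand-in types rather than sglang's real classes:

```python
from dataclasses import dataclass
from enum import Enum, auto
from typing import List, Optional

class CaptureHiddenMode(Enum):  # stand-in for sglang's enum, values illustrative
    NULL = auto()
    FULL = auto()

@dataclass
class EagleLikeSpecInfo:
    capture_hidden_mode: CaptureHiddenMode = CaptureHiddenMode.FULL

@dataclass
class NGramLikeSpecInfo:
    draft_tokens: Optional[List[int]] = None  # no capture_hidden_mode attribute

for spec_info in (None, EagleLikeSpecInfo(), NGramLikeSpecInfo()):
    if spec_info and hasattr(spec_info, "capture_hidden_mode"):
        mode = spec_info.capture_hidden_mode
    else:
        mode = CaptureHiddenMode.NULL
    print(type(spec_info).__name__, mode)
```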
2 changes: 1 addition & 1 deletion python/sglang/srt/managers/schedule_batch.py
@@ -1010,7 +1010,7 @@ def prepare_for_idle(self):

def prepare_for_decode(self):
self.forward_mode = ForwardMode.DECODE
if self.spec_algorithm.is_eagle():
if self.spec_algorithm.is_eagle() or self.spec_algorithm.is_ngram():
return

self.input_ids = self.output_ids
7 changes: 7 additions & 0 deletions python/sglang/srt/managers/scheduler.py
@@ -220,6 +220,13 @@ def __init__(
target_worker=self.tp_worker,
dp_rank=dp_rank,
)
elif self.spec_algorithm.is_ngram():
from sglang.srt.speculative.ngram_worker import NGramWorker

self.draft_worker = NGramWorker(
target_worker=self.tp_worker,
server_args=server_args,
)
else:
self.draft_worker = None

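For context on what the new `NGramWorker` is wired in for: n-gram (prompt-lookup) speculation drafts tokens by matching the most recent `window_size` tokens against the sequence's own history and proposing whatever followed the last match. The PR's `ngram_worker.py` is not part of this diff, so the sketch below only illustrates the general technique under that assumption, with hypothetical names:

```python
from typing import List

def propose_ngram_draft(tokens: List[int], window_size: int, num_draft: int) -> List[int]:
    """Return up to num_draft tokens predicted by n-gram lookup over `tokens`."""
    if len(tokens) <= window_size:
        return []
    pattern = tokens[-window_size:]
    # Scan backwards for the most recent earlier occurrence of the pattern.
    for start in range(len(tokens) - window_size - 1, -1, -1):
        if tokens[start : start + window_size] == pattern:
            # Propose the tokens that followed that occurrence.
            return tokens[start + window_size : start + window_size + num_draft]
    return []

# Example: the suffix [3, 4] previously appeared followed by 5, 6, 2.
history = [1, 2, 3, 4, 5, 6, 2, 3, 4]
print(propose_ngram_draft(history, window_size=2, num_draft=3))  # [5, 6, 2]
```

The proposed tokens would then be checked against the target model in a single TARGET_VERIFY forward pass, which is what the attention-backend and CUDA-graph changes in this PR support.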
16 changes: 15 additions & 1 deletion python/sglang/srt/model_executor/cuda_graph_runner.py
@@ -163,6 +163,12 @@ def __init__(self, model_runner: "ModelRunner"):
self.model_runner.server_args.speculative_num_draft_tokens
)

if model_runner.spec_algorithm.is_ngram():
self.capture_forward_mode = ForwardMode.TARGET_VERIFY
self.num_tokens_per_bs = (
self.model_runner.server_args.speculative_num_draft_tokens + 1
)

self.compile_bs = (
[
bs
@@ -343,7 +349,6 @@ def capture_one_batch_size(self, bs: int, forward: Callable):
spec_algorithm=self.model_runner.spec_algorithm,
spec_info=self.get_spec_info(num_tokens, positions),
)

# Attention backend
self.model_runner.attn_backend.init_forward_metadata_capture_cuda_graph(
bs,
@@ -467,4 +472,13 @@ def get_spec_info(self, num_tokens: int, positions: torch.Tensor):
)
spec_info.capture_hidden_mode = CaptureHiddenMode.FULL

if self.model_runner.spec_algorithm.is_ngram():
from sglang.srt.speculative.ngram_worker import NGramSpecInfo

spec_info = NGramSpecInfo(
max_num_draft_tokens=self.model_runner.server_args.speculative_num_draft_tokens,
verified_ids=None,
draft_tokens=None,
positions=None,
)
return spec_info
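The `NGramSpecInfo` constructed here is defined in `ngram_worker.py`, which is outside this diff. Inferring only from the keyword arguments at this call site, its shape is presumably something like the sketch below (field types are guesses):

```python
from dataclasses import dataclass
from typing import Optional

import torch

@dataclass
class NGramSpecInfo:
    # Inferred from the call site above; the real definition may differ.
    max_num_draft_tokens: int
    verified_ids: Optional[torch.Tensor]
    draft_tokens: Optional[torch.Tensor]
    positions: Optional[torch.Tensor]
```

Notably it carries no `capture_hidden_mode` attribute, which is why the `hasattr` guard was added in `logits_processor.py` above.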
6 changes: 4 additions & 2 deletions python/sglang/srt/model_executor/model_runner.py
@@ -712,9 +712,11 @@ def forward(self, forward_batch: ForwardBatch) -> LogitsProcessorOutput:
return self.cuda_graph_runner.replay(forward_batch)

if forward_batch.forward_mode.is_decode():
return self.forward_decode(forward_batch)
res = self.forward_decode(forward_batch)
return res
elif forward_batch.forward_mode.is_extend():
return self.forward_extend(forward_batch)
res = self.forward_extend(forward_batch)
return res
elif forward_batch.forward_mode.is_idle():
return self.forward_idle(forward_batch)
else:
10 changes: 7 additions & 3 deletions python/sglang/srt/server_args.py
@@ -124,6 +124,7 @@ class ServerArgs:
speculative_num_steps: int = 5
speculative_num_draft_tokens: int = 64
speculative_eagle_topk: int = 8
speculative_ngram_window_size: int = 5

# Double Sparsity
enable_double_sparsity: bool = False
@@ -249,14 +250,17 @@ def __post_init__(self):
)

# Speculative Decoding
if self.speculative_algorithm == "EAGLE":
if (
self.speculative_algorithm == "EAGLE"
or self.speculative_algorithm == "NGRAM"
):
self.prefill_only_one_req = True
self.disable_cuda_graph_padding = True
self.disable_radix_cache = True
self.disable_overlap_schedule = True
self.chunked_prefill_size = -1
logger.info(
"The radix cache, chunked prefill, and overlap scheduler are disabled because of using eagle speculative decoding."
"The radix cache, chunked prefill, and overlap scheduler are disabled because of using eagle/ngram speculative decoding."
)

# GGUF
@@ -664,7 +668,7 @@ def add_cli_args(parser: argparse.ArgumentParser):
parser.add_argument(
"--speculative-algorithm",
type=str,
choices=["EAGLE"],
choices=["EAGLE", "NGRAM"],
help="Speculative algorithm.",
)
parser.add_argument(
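Assuming the new `speculative_ngram_window_size` field is wired through like the other speculative options (an argparse flag for it is not shown in this diff), enabling the mode programmatically might look like the sketch below; the model path is a placeholder:

```python
from sglang.srt.server_args import ServerArgs

args = ServerArgs(
    model_path="/path/to/model",          # placeholder
    speculative_algorithm="NGRAM",
    speculative_num_draft_tokens=8,
    speculative_ngram_window_size=5,
)
# __post_init__ then disables the radix cache, chunked prefill, CUDA graph
# padding, and the overlap scheduler, as shown in the diff above.
```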
5 changes: 4 additions & 1 deletion python/sglang/srt/speculative/eagle_utils.py
@@ -361,7 +361,6 @@ def prepare_for_verify(self, batch: ScheduleBatch):
top_scores = torch.topk(score_list, self.num_verify_token - 1, dim=-1)
top_scores_index = top_scores.indices
top_scores_index = torch.sort(top_scores_index).values

draft_tokens = torch.gather(ss_token_list, index=top_scores_index, dim=1)
scores = torch.gather(origin_token_list, index=top_scores_index, dim=1)
draft_tokens = torch.cat((self.verified_id.unsqueeze(1), draft_tokens), dim=1)
@@ -541,12 +540,15 @@ def verify(self, batch: ScheduleBatch, logits_output: torch.Tensor) -> torch.Tensor:
[self.draft_token, torch.full([1], -1, dtype=torch.long, device="cuda")],
dim=-1,
)

target_predict = predict[self.retrive_index]
candidates = draft_token[self.retrive_index]

# logits = logits_output.next_token_logits[self.retrive_index]
# target_predict = torch.argmax(logits[:, :-1], dim=-1)
accept_mask = candidates[:, 1:] == target_predict[:, :-1]
accept_mask = (torch.cumprod(accept_mask, dim=1)).sum(dim=1)
accept_length_cpu = accept_mask.tolist()
bs = self.retrive_cum_len.numel() - 1

max_draft_len = self.retrive_index.shape[-1]
@@ -555,6 +557,7 @@ def verify(self, batch: ScheduleBatch, logits_output: torch.Tensor) -> torch.Tensor:
)
accept_length = torch.empty((bs,), dtype=torch.int, device="cuda")
extract_index = torch.full((bs * 2,), 0, dtype=torch.int, device="cuda")

eagle_verify_retrive[(bs,)](
self.retrive_index.contiguous(),
accept_mask.contiguous(),
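The acceptance arithmetic in `verify()` can be checked on a toy batch: a draft token is kept only if it matches the target prediction and every earlier draft token in the sequence was also kept, which is exactly what the cumulative product over the match mask encodes. A small sketch (the explicit `.int()` cast is added here only for portability across dtypes):

```python
import torch

# candidates = [verified_id, draft_0, draft_1, draft_2]; target_predict holds
# the target model's argmax at each of those positions.
candidates = torch.tensor([[7, 3, 9, 4]])
target_predict = torch.tensor([[3, 9, 5, 8]])

accept_mask = candidates[:, 1:] == target_predict[:, :-1]        # [[True, True, False]]
accept_length = torch.cumprod(accept_mask.int(), dim=1).sum(dim=1)
print(accept_length.item())  # 2 -> the first two draft tokens are accepted
```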