Speculative decoding with lookahead #2790

Open

jjjjohnson wants to merge 2 commits into main
Conversation

@jjjjohnson (Contributor) commented Jan 8, 2025

Motivation

n-gram based speculative decoding is very effective in retrieval-augmented generation (RAG). The cost of generating draft tokens is low compared to EAGLE, so it has great potential for accelerating token generation in RAG. Ant Group has proposed a Trie-based retrieval and verification mechanism; they implemented lookahead on top of vLLM for the single-query case and report a 1.6x speedup in a real-life scenario. I want to bring lookahead to SGLang.

Related resources

Lookahead: An Inference Acceleration Framework for Large Language Model with Lossless Generation Accuracy

Overall workflow

[workflow diagram]

Features

  • No need to train a draft model.
  • The trie is updated with both prompt tokens and output tokens.
  • Draft tokens are generated by a frequency-based sort over the current prompt tokens and ALL historical output tokens (with eviction); a sketch follows this list.
  • Both single-branch and multi-branch drafting are supported.
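
For illustration, here is a minimal sketch of the idea behind the bullets above: a trie whose node counts are updated from prompt and output tokens, and a single-branch draft retrieved by following the most frequent child at each step. This is a hypothetical sketch, not the code in this PR; the class names, the max_ngram parameter, and the omitted eviction policy are assumptions.

from collections import defaultdict

class _TrieNode:
    def __init__(self):
        self.children = defaultdict(_TrieNode)
        self.count = 0  # how often this continuation has been seen

class LookaheadTrie:
    """Hypothetical sketch of a frequency-sorted lookahead trie."""

    def __init__(self, max_ngram=8):
        self.root = _TrieNode()
        self.max_ngram = max_ngram

    def insert(self, tokens):
        # Index every n-gram window of the prompt/output token stream.
        for start in range(len(tokens)):
            node = self.root
            for tok in tokens[start:start + self.max_ngram]:
                node = node.children[tok]
                node.count += 1

    def retrieve(self, context, num_draft=4):
        # Match the longest suffix of the context first, then greedily
        # follow the highest-count child to propose draft tokens.
        for start in range(max(0, len(context) - self.max_ngram), len(context)):
            node = self.root
            matched = True
            for tok in context[start:]:
                if tok not in node.children:
                    matched = False
                    break
                node = node.children[tok]
            if matched:
                draft = []
                while node.children and len(draft) < num_draft:
                    tok, node = max(node.children.items(), key=lambda kv: kv[1].count)
                    draft.append(tok)
                if draft:
                    return draft
        return []

The real implementation additionally supports multi-branch retrieval and eviction of stale entries, which this sketch leaves out.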

Checklist

  • Format your code according to the Contributor Guide.
  • Add unit tests as outlined in the Contributor Guide.
  • Update documentation as needed, including docstrings or example tutorials.

@jjjjohnson (Contributor, Author) commented:

import sglang as sgl
import time

def main():
    # Sample prompts.
    prompts = [
        '<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n你是谁?<|im_end|>\n<|im_start|>assistant\n'
    ]

    sampling_params = {
        "temperature": 0.7,
        "repetition_penalty": 1,
        "max_new_tokens": 256,
        "top_k": 1,
        "stop_token_ids": [151645, 151644, 151643],
    }

    model_path = "Qwen/Qwen2-7B-Instruct"

    # Create an LLM with LOOKAHEAD speculative decoding enabled.
    llm = sgl.Engine(
        model_path=model_path,
        speculative_one_branch=True,
        disable_cuda_graph=False,
        speculative_num_draft_tokens=4,
        speculative_algorithm="LOOKAHEAD",
        mem_fraction_static=0.60,
        watchdog_timeout=1e8,
        log_level="info",
    )

    for idx in range(5):
        start = time.time()
        outputs = llm.generate(prompts, sampling_params)
        elapsed = time.time() - start
        completion_tokens = 0
        # Print the outputs and per-run throughput.
        for prompt, output in zip(prompts, outputs):
            completion_tokens += output["meta_info"]["completion_tokens"]
            print(f"{output['text']}")
            print("======================")
        print(f"{idx=}!!!!!!!!! tps = {completion_tokens / elapsed}\n\n")

if __name__ == "__main__":
    main()

@zhyncs (Member) commented Jan 11, 2025

Hi @jjjjohnson Could you help resolve the conflicts? Thanks.

@jjjjohnson (Contributor, Author) replied:

Hi @jjjjohnson Could you help resolve the conflicts? Thanks.

Done

@merrymercy (Contributor) commented:

Could you share any performance results?

@merrymercy mentioned this pull request on Jan 15, 2025
@jjjjohnson (Contributor, Author) commented Jan 16, 2025

Could you share any performance results?

Sure!
Since lookahead speculative decoding caches input and output tokens, I ran sglang.bench_serving for 2 turns and disabled random.shuffle(dataset) so that both turns send identical requests, then compared the performance against normal decoding.
Note: lookahead speculative decoding is turned off when the batch size is > 4, so I limit the max concurrency and request rate.

[benchmark summary screenshot]

Start Server:

Normal decode:

python -m sglang.launch_server --model-path /mnt/workspace/model_hub/Qwen2-7B-Instruct --trust-remote-code --tp 1

Lookahead speculative decode:

python -m sglang.launch_server --model-path /mnt/workspace/model_hub/Qwen2-7B-Instruct \
      --trust-remote-code --tp 1 --speculative-num-draft-tokens 4 --speculative-algorithm LOOKAHEAD --speculative-one-branch

Benchmark:

python3 -m sglang.bench_serving --backend sglang --dataset-name sharegpt --dataset-path /oss/ShareGPT_V3_unfiltered_cleaned_split.json --num-prompts 500 --max-concurrency 3 --request-rate 2

Result:

Normal decode, first turn: [bench_serving output screenshot]

Normal decode, second turn: [bench_serving output screenshot]

Lookahead speculative decode, first turn: [bench_serving output screenshot]

Lookahead speculative decode, second turn: [bench_serving output screenshot]

Review comments were left on python/sglang/srt/server_args.py, python/sglang/srt/managers/scheduler.py, and python/sglang/srt/speculative/lookahead_cache.py.
@mpjlu (Contributor) commented Feb 7, 2025

I found that this PR cannot run DeepSeek V3. Have you tested this model?

@jjjjohnson (Contributor, Author) replied:

I found that this PR cannot run DeepSeek V3. Have you tested this model?

No. What is the error message?

@mpjlu (Contributor) commented Feb 7, 2025

I found that this PR cannot run DeepSeek V3. Have you tested this model?

No. What is the error message?

MLA crashes, and it does not show a very useful error message.

@mpjlu (Contributor) commented Feb 11, 2025

I found that this PR cannot run Llama 8B with the Triton backend; the error is:

File "/data/peng/sglang/python/sglang/srt/speculative/lookahead_utils.py", line 160, in verify
    batch.seq_lens_sum = batch.seq_lens.sum().item()
RuntimeError: CUDA error: an illegal memory access was encountered

Does this PR support the Triton backend?

@coolhok (Contributor) commented Feb 12, 2025

mla

I think MLA attention does not support the tree mask, so this PR does not work with DeepSeek.

@coolhok (Contributor) commented Feb 12, 2025

I found that this PR cannot run Llama 8B with the Triton backend; the error is:

File "/data/peng/sglang/python/sglang/srt/speculative/lookahead_utils.py", line 160, in verify
    batch.seq_lens_sum = batch.seq_lens.sum().item()
RuntimeError: CUDA error: an illegal memory access was encountered

Does this PR support the Triton backend?

Lookahead depends on FlashInfer's tree-mask attention; the Triton backend does not currently support tree masks.
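
For context, a tree mask is what lets all draft branches be verified in a single forward pass: each draft token may attend only to the shared prefix and to its ancestors in the draft tree. Below is a minimal, hypothetical illustration of such a mask; the parent indices are made up and this is not SGLang's or FlashInfer's actual mask layout.

import numpy as np

# parents[i] is the in-draft parent of draft token i, or -1 for a root that
# hangs directly off the last verified token. Two branches fork after token 0.
parents = [-1, 0, 0, 1, 2]
n = len(parents)

mask = np.zeros((n, n), dtype=bool)
for i in range(n):
    mask[i, i] = True          # each draft token attends to itself
    p = parents[i]
    while p != -1:             # ...and to every ancestor on its branch
        mask[i, p] = True
        p = parents[p]

print(mask.astype(int))
# [[1 0 0 0 0]
#  [1 1 0 0 0]
#  [1 0 1 0 0]
#  [1 1 0 1 0]
#  [1 0 1 0 1]]

Roughly speaking, the backend applies this per-request mask on top of the usual causal attention over the cached prefix; if the attention kernel cannot take such a custom mask (as noted above for the Triton backend), multi-branch verification cannot run.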

Labels: enhancement (New feature or request), high priority