
[LLM INFER] Support speculative decoding (llama) #9180

Merged
39 commits merged into PaddlePaddle:develop from specu_decoding_inner on Nov 25, 2024

Conversation

Contributor

@Wanglongzhi2001 commented Sep 23, 2024

PR types

New features

PR changes

Others

Description

Support speculative decoding (inference with reference) with dynamic batching.

Usage:

cd llm
python predict/predictor.py --model_name_or_path meta-llama/Meta-Llama-3-8B-Instruct --dtype bfloat16 --mode dynamic --inference_model 1 --speculate_method inference_with_reference --speculate_max_draft_token_num 5 --speculate_max_ngram_size 2
  • NOTE:
    Since inference with reference matches draft tokens against the input prompt, users can speed up inference in many practical generation scenarios where there is significant overlap between the in-context reference and the output (e.g., long-document queries, search engines, and multi-turn conversations). A minimal sketch of the matching step follows the example below.

For example, a speed-up can be achieved with the following prompt:

LONG_PROMPT = "You are a helpful assistant that recognizes the content of tables in markdown format. Here is a table as follows.\n# Table\n" + """
| ID  | Name          | Age | Occupation    | Country       | Email                  | Phone Number   | Address                       |
|-----|---------------|-----|---------------|---------------|------------------------|----------------|------------------------------|
| 1   | John Doe      | 29  | Engineer      | USA           | john.doe@example.com   | 555-1234       | 123 Elm St, Springfield, IL  |
| 2   | Jane Smith    | 34  | Doctor        | Canada        | jane.smith@example.com | 555-5678       | 456 Oak St, Toronto, ON      |
| 3   | Alice Johnson | 27  | Teacher       | UK            | alice.j@example.com    | 555-8765       | 789 Pine St, London, UK      |
| 4   | Bob Brown     | 45  | Artist        | Australia     | bob.b@example.com      | 555-4321       | 321 Maple St, Sydney, NSW    |
| 5   | Carol White   | 31  | Scientist     | New Zealand   | carol.w@example.com    | 555-6789       | 654 Birch St, Wellington, NZ |
| 6   | Dave Green    | 28  | Lawyer        | Ireland       | dave.g@example.com     | 555-3456       | 987 Cedar St, Dublin, IE     |
| 7   | Emma Black    | 40  | Musician      | USA           | emma.b@example.com     | 555-1111       | 246 Ash St, New York, NY     |
"""
prompt = LONG_PROMPT + "Question: what is the age of John Doe? Your answer: The age of John Doe is "
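
To make the matching step concrete, below is a minimal, self-contained sketch of the idea. Note that propose_draft_tokens is a hypothetical helper written for illustration, not the InferenceWithReferenceProposer API added by this PR; the real implementation works on batched Paddle tensors inside the predictor.

# Illustrative sketch of n-gram draft proposal ("inference with reference").
# propose_draft_tokens is a hypothetical helper, NOT the actual
# InferenceWithReferenceProposer API in paddlenlp.
def propose_draft_tokens(context, generated, max_ngram_size=2, max_draft_tokens=5):
    """Look for the longest suffix n-gram of `generated` (up to
    max_ngram_size tokens) that also appears in `context`, and propose
    the tokens following the first match as draft tokens."""
    for n in range(min(max_ngram_size, len(generated)), 0, -1):
        suffix = generated[-n:]
        for i in range(len(context) - n):
            if context[i:i + n] == suffix:
                return context[i + n:i + n + max_draft_tokens]
    return []  # no match: fall back to normal one-token-per-step decoding

# Toy example with integer token IDs:
context = [7, 3, 9, 4, 11, 3, 9, 5, 2]   # e.g. the tokenized table prompt
generated = [3, 9]                        # suffix of the output so far
print(propose_draft_tokens(context, generated))  # [4, 11, 3, 9, 5]

The target model then verifies all proposed draft tokens in a single forward pass and keeps the longest accepted prefix, which is where the speed-up comes from.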


CLAassistant commented Sep 23, 2024

CLA assistant check
All committers have signed the CLA.


codecov bot commented Oct 16, 2024

Codecov Report

Attention: Patch coverage is 1.92308% with 153 lines in your changes missing coverage. Please review.

Project coverage is 52.90%. Comparing base (195fde3) to head (367237e).
Report is 2 commits behind head on develop.

Current head 367237e differs from pull request most recent head 1bca939

Please upload reports for the commit 1bca939 to get more accurate results.

Files with missing lines Patch % Lines
...dlenlp/experimental/transformers/llama/modeling.py 0.00% 64 Missing ⚠️
paddlenlp/trl/llm_utils.py 0.00% 30 Missing ⚠️
paddlenlp/experimental/transformers/proposers.py 0.00% 27 Missing ⚠️
...enlp/experimental/transformers/generation_utils.py 0.00% 8 Missing ⚠️
...erimental/transformers/fused_transformer_layers.py 0.00% 7 Missing ⚠️
paddlenlp/transformers/model_utils.py 0.00% 6 Missing ⚠️
paddlenlp/transformers/auto/modeling.py 0.00% 5 Missing ⚠️
paddlenlp/experimental/transformers/__init__.py 0.00% 1 Missing ⚠️
...dlenlp/experimental/transformers/bloom/modeling.py 0.00% 1 Missing ⚠️
...p/experimental/transformers/chatglm_v2/modeling.py 0.00% 1 Missing ⚠️
... and 3 more
Additional details and impacted files
@@             Coverage Diff             @@
##           develop    #9180      +/-   ##
===========================================
+ Coverage    52.87%   52.90%   +0.03%     
===========================================
  Files          687      686       -1     
  Lines       109186   109047     -139     
===========================================
- Hits         57727    57693      -34     
+ Misses       51459    51354     -105     

@Wanglongzhi2001 changed the title from “[WIP] Support speculative decoding” to “Support speculative decoding” on Oct 23, 2024
@Wanglongzhi2001
Contributor Author

The dim_head of the tiny model is too small (I chose tiny-random-llama), and append_attn currently only supports dim_head = 128, so a unit test has not been written yet.


#include "paddle/extension.h"

void UpdateInputIdsCPU(const paddle::Tensor& input_ids_cpu,
Contributor


For these simple custom ops, please consider whether new ones are really necessary. Alternatively, there should be some way to reuse the existing custom ops.

@Wanglongzhi2001 force-pushed the specu_decoding_inner branch 2 times, most recently from 43c4efa to 4547242 on November 19, 2024 at 18:10
yuanlehome previously approved these changes Nov 21, 2024
Collaborator

@yuanlehome left a comment


LGTM

qingqing01 previously approved these changes Nov 22, 2024
@yuanlehome changed the title from “Support speculative decoding” to “[LLM INFER] Support speculative decoding (llama)” on Nov 22, 2024
Comment on lines +1019 to +1028
# init speculate components
if config.speculate_method == "inference_with_reference":
self.proposer = InferenceWithReferenceProposer(
config.speculate_max_draft_token_num,
config.speculate_max_ngram_size,
config.batch_size,
config.max_length,
)
else:
self.proposer = None
Collaborator


I see this in both the dynamic-graph and static-graph predictors. Could it be moved into the base class BlockInferencePredictorMixin?

Contributor Author


That feels semantically less clear to me. self.proposer is at the same level as self.model, so initializing them together seems clearer, whereas BlockInferencePredictorMixin only initializes the model's input parameters, so it doesn't feel like the right place. It would also be harder to find there, and others might not easily understand what the proposer means.

Collaborator


Sounds good~

Comment on lines 1044 to 1063
# whether speculative decoding
if self.proposer is None:
read_res_process = mp.Process(
target=llm_utils.read_res, args=[self.model_name_or_path, tensor_queue, result_queue, done_event]
)
if self.tensor_parallel_rank == 0:
read_res_process.start()

output_tensor = paddle.full(shape=[MAX_BSZ + 2, 1], fill_value=2, dtype="int64").cpu()
else:
read_res_process = mp.Process(
target=llm_utils.speculate_read_res,
args=[self.model_name_or_path, tensor_queue, result_queue, done_event],
)
if self.tensor_parallel_rank == 0:
read_res_process.start()

output_tensor = paddle.full(
shape=[SPECULATE_MAX_BSZ * MAX_DRAFT_TOKENS + SPECULATE_MAX_BSZ + 2, 1], fill_value=2, dtype="int64"
).cpu()
Collaborator

@yuanlehome Nov 25, 2024


This block can be rewritten as:

# whether speculative decoding
if self.proposer is None:
    read_res_func = llm_utils.read_res
    output_tensor_shape = [MAX_BSZ + 2, 1]
else:
    read_res_func = llm_utils.speculate_read_res
    output_tensor_shape = [SPECULATE_MAX_BSZ * MAX_DRAFT_TOKENS + SPECULATE_MAX_BSZ + 2, 1]

read_res_process = mp.Process(
    target=read_res_func, args=[self.model_name_or_path, tensor_queue, result_queue, done_event]
)
if self.tensor_parallel_rank == 0:
    read_res_process.start()

output_tensor = paddle.full(shape=output_tensor_shape, fill_value=2, dtype="int64").cpu()

Collaborator


Use the essential difference as the condition for the branch logic.

Contributor Author


OK, I'll change it.

Contributor Author


Done

Collaborator

@yuanlehome left a comment


LGTM

@yuanlehome merged commit d68a385 into PaddlePaddle:develop on Nov 25, 2024
10 of 12 checks passed
@Wanglongzhi2001 deleted the specu_decoding_inner branch on November 25, 2024 at 08:38
@Wanglongzhi2001 restored the specu_decoding_inner branch on November 25, 2024 at 08:38