[LLM INFER] Support speculative decoding (llama) #9180
Conversation
Force-pushed from cecdbdc to 5c93e1d
Codecov Report
Attention: patch coverage is …
Additional details and impacted files:

@@            Coverage Diff             @@
##           develop    #9180      +/-   ##
===========================================
+ Coverage    52.87%   52.90%    +0.03%
===========================================
  Files          687      686        -1
  Lines       109186   109047      -139
===========================================
- Hits         57727    57693       -34
+ Misses       51459    51354      -105
Force-pushed from d7bd81c to b4fedab
The tiny model's dim_head is too small (I picked tiny-random-llama), and append_attn currently only supports dim_head == 128, so no unit test has been written yet.
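For illustration, a future unit test could guard this constraint roughly as follows. This is a hedged sketch: the test name and the head-dim value are hypothetical; only the 128 requirement comes from the comment above.

import pytest

# Hypothetical head dim for tiny-random-llama; append_attn requires 128.
TINY_LLAMA_HEAD_DIM = 64

@pytest.mark.skipif(
    TINY_LLAMA_HEAD_DIM != 128,
    reason="append_attn currently only supports head_dim == 128",
)
def test_speculative_decoding_tiny_llama():
    ...  # would exercise the speculative decoding path once supported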
Force-pushed from 376e87a to 0b85e9e
#include "paddle/extension.h"

void UpdateInputIdsCPU(const paddle::Tensor& input_ids_cpu,
For these simple custom operators, consider whether adding new ones is really necessary, or whether there is some way to reuse the existing custom operators.
Force-pushed from b6065da to d50703f
Force-pushed from 43c4efa to 4547242
LGTM
Force-pushed from a6412a8 to e62628a
Force-pushed from 01e0eb9 to 7155ae1
# init speculate components
if config.speculate_method == "inference_with_reference":
    self.proposer = InferenceWithReferenceProposer(
        config.speculate_max_draft_token_num,
        config.speculate_max_ngram_size,
        config.batch_size,
        config.max_length,
    )
else:
    self.proposer = None
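For orientation, a hedged sketch of a config that would exercise this branch; SimpleNamespace and the literal values are illustrative, and only the field names come from the diff above:

from types import SimpleNamespace

# Illustrative values; the real config comes from the predictor's argument parsing.
config = SimpleNamespace(
    speculate_method="inference_with_reference",
    speculate_max_draft_token_num=5,
    speculate_max_ngram_size=2,
    batch_size=4,
    max_length=2048,
)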
I see this in both the dynamic-graph and static-graph predictors; could it be moved into the base class BlockInferencePredictorMixin?
That feels semantically less clear to me. self.proposer is at the same level as self.model, so initializing them together seems clearer, whereas BlockInferencePredictorMixin only initializes the model's input parameters, so it doesn't feel like the right place. Besides, buried in BlockInferencePredictorMixin it would be harder to find, and others might struggle to figure out what proposer means.
Sounds good!
llm/predict/predictor.py (Outdated)
# whether speculative decoding
if self.proposer is None:
    read_res_process = mp.Process(
        target=llm_utils.read_res, args=[self.model_name_or_path, tensor_queue, result_queue, done_event]
    )
    if self.tensor_parallel_rank == 0:
        read_res_process.start()

    output_tensor = paddle.full(shape=[MAX_BSZ + 2, 1], fill_value=2, dtype="int64").cpu()
else:
    read_res_process = mp.Process(
        target=llm_utils.speculate_read_res,
        args=[self.model_name_or_path, tensor_queue, result_queue, done_event],
    )
    if self.tensor_parallel_rank == 0:
        read_res_process.start()

    output_tensor = paddle.full(
        shape=[SPECULATE_MAX_BSZ * MAX_DRAFT_TOKENS + SPECULATE_MAX_BSZ + 2, 1], fill_value=2, dtype="int64"
    ).cpu()
This block could be written as:

# whether speculative decoding
if self.proposer is None:
    read_res_func = llm_utils.read_res
    output_tensor_shape = [MAX_BSZ + 2, 1]
else:
    read_res_func = llm_utils.speculate_read_res
    output_tensor_shape = [SPECULATE_MAX_BSZ * MAX_DRAFT_TOKENS + SPECULATE_MAX_BSZ + 2, 1]

read_res_process = mp.Process(
    target=read_res_func, args=[self.model_name_or_path, tensor_queue, result_queue, done_event]
)
if self.tensor_parallel_rank == 0:
    read_res_process.start()

output_tensor = paddle.full(shape=output_tensor_shape, fill_value=2, dtype="int64").cpu()
Use the essential difference as the condition for the code's branching logic.
OK, I'll change it.
Done
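For readers outside this thread, here is a self-contained toy version of the reader-process pattern settled on above. The reader function is a stub standing in for llm_utils.read_res / llm_utils.speculate_read_res, and MAX_BSZ is a placeholder value:

import multiprocessing as mp

import paddle

MAX_BSZ = 8  # placeholder; the real constant lives in the predictor module

def read_res_stub(model_name_or_path, tensor_queue, result_queue, done_event):
    # Stand-in for llm_utils.read_res: block until the main process signals completion.
    done_event.wait()

if __name__ == "__main__":
    tensor_queue, result_queue = mp.Queue(), mp.Queue()
    done_event = mp.Event()
    read_res_process = mp.Process(
        target=read_res_stub, args=[None, tensor_queue, result_queue, done_event]
    )
    read_res_process.start()
    # Pre-allocate the pinned output buffer on CPU, as the predictor does.
    output_tensor = paddle.full(shape=[MAX_BSZ + 2, 1], fill_value=2, dtype="int64").cpu()
    done_event.set()
    read_res_process.join()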
Force-pushed from 367237e to 1bca939
LGTM
PR types
New features
PR changes
Others
Description
Support speculative decoding (Inference with reference) with dynamic batching.
Usage:
Since inference with reference matches draft tokens against the input prompt, users can speed up the inference process in many practical generation scenarios where the output overlaps significantly with an in-context reference (e.g. long-document queries, search engines, and multi-turn conversations).
For example, a speed-up can be achieved when the given prompt already contains spans that the output is likely to reuse.
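To make the mechanism concrete, here is a minimal, hypothetical sketch of the n-gram lookup at the heart of inference with reference. propose_draft_tokens and its parameters are illustrative and are not this PR's API:

def propose_draft_tokens(tokens, max_ngram_size=3, max_draft_tokens=5):
    """Propose draft tokens by copying what followed an earlier occurrence
    of the sequence's trailing n-gram (longest n-gram tried first)."""
    for n in range(max_ngram_size, 0, -1):
        if len(tokens) <= n:
            continue
        suffix = tokens[-n:]
        for start in range(len(tokens) - n):
            if tokens[start : start + n] == suffix:
                continuation = tokens[start + n : start + n + max_draft_tokens]
                if continuation:
                    return continuation
    return []  # no match: fall back to ordinary one-token-at-a-time decoding

# The trailing bigram [5, 6] also occurs earlier in the sequence, so the
# tokens that followed it there become the draft for the target model.
print(propose_draft_tokens([1, 2, 5, 6, 7, 8, 3, 4, 5, 6]))  # -> [7, 8, 3, 4, 5]

The target model then verifies the drafted tokens in a single forward pass and keeps the longest accepted prefix, which is where the speed-up comes from: no separate draft model is needed, only this cheap lookup.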