Merge pull request #2541 from Wanglongzhi2001/speculate_decoding
Support speculate decoding
Jiang-Jia-Jun authored Jan 10, 2025
2 parents 947f230 + fe35dc5 commit 19e752c
Showing 3 changed files with 144 additions and 43 deletions.
29 changes: 29 additions & 0 deletions llm/server/server/engine/config.py
@@ -19,6 +19,7 @@

from paddlenlp.generation import GenerationConfig
from server.utils import model_server_logger
from dataclasses import dataclass


class Config:
@@ -203,6 +204,27 @@ def get_model_config(self):
model_config_json = json.load(open(self.model_config_path, 'r', encoding='utf-8'))
return model_config_json

def get_speculate_config(self):
"""
get speculate decoding related config
Returns:
SpeculateConfig: the speculate decoding related config
"""
speculate_config = SpeculateConfig()
model_cfg = self.get_model_config()
if model_cfg.get("speculate_method", "None") != "None":
speculate_config.speculate_method = str(model_cfg["speculate_method"])
speculate_config.speculate_max_draft_token_num = model_cfg[
"speculate_max_draft_token_num"]
speculate_config.speculate_max_ngram_size = model_cfg[
"speculate_max_ngram_size"]

if speculate_config.speculate_method not in ["None", "inference_with_reference"]:
model_server_logger.error(f"Unsupported speculate method: {speculate_config.speculate_method}")

return speculate_config

def read_from_config(self):
"""
reset model config from json file
@@ -234,3 +256,10 @@ def get_unique_name(self, name):

def __str__(self) -> str:
return json.dumps(self.__dict__, indent=4)


@dataclass
class SpeculateConfig:
speculate_method: str = "None"
speculate_max_draft_token_num: int = 1
speculate_max_ngram_size: int = 1
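
For context, a minimal sketch of how this resolution behaves when run against a plain model-config dict. The dict contents and printed values below are hypothetical, not taken from a real model_config.json:

from dataclasses import dataclass

@dataclass
class SpeculateConfig:
    speculate_method: str = "None"
    speculate_max_draft_token_num: int = 1
    speculate_max_ngram_size: int = 1

def resolve_speculate_config(model_cfg: dict) -> SpeculateConfig:
    # Same fallback logic as Config.get_speculate_config above, without the server object.
    cfg = SpeculateConfig()
    if model_cfg.get("speculate_method", "None") != "None":
        cfg.speculate_method = str(model_cfg["speculate_method"])
        cfg.speculate_max_draft_token_num = model_cfg["speculate_max_draft_token_num"]
        cfg.speculate_max_ngram_size = model_cfg["speculate_max_ngram_size"]
    return cfg

# Hypothetical config enabling reference-based speculative decoding.
print(resolve_speculate_config({
    "speculate_method": "inference_with_reference",
    "speculate_max_draft_token_num": 5,
    "speculate_max_ngram_size": 2,
}))
# SpeculateConfig(speculate_method='inference_with_reference', speculate_max_draft_token_num=5, speculate_max_ngram_size=2)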
47 changes: 46 additions & 1 deletion llm/server/server/engine/infer.py
@@ -29,6 +29,7 @@
from paddlenlp_ops import step_paddle
from server.data.processor import DataProcessor
from server.engine.config import Config
from paddlenlp.experimental.transformers import InferenceWithReferenceProposer
from server.utils import get_logger
from task_queue_manager import TaskQueueManager

@@ -46,6 +47,8 @@ def __init__(self, args):

self.config = Config()
self.model_cfg = self.config.get_model_config()
self.speculate_config = self.config.get_speculate_config()
self.is_speculate_decoding = self.speculate_config.speculate_method != "None"
self.format_print_configuration()

self.args.num_layers = self.get_value(self.model_cfg, ["num_hidden_layers", "num_layers"])
@@ -67,6 +70,17 @@ def __init__(self, args):
self.cache_kvs = {}
self.init_inputs()

if self.is_speculate_decoding:
logger.info(f'Using speculate decoding, method: {self.speculate_config.speculate_method}.')
if self.speculate_config.speculate_method == "inference_with_reference":
self.proposer = InferenceWithReferenceProposer(
self.speculate_config.speculate_max_draft_token_num,
self.speculate_config.speculate_max_ngram_size,
self.args.max_batch_size,
self.args.max_seq_len)
else:
self.proposer = None

self.infer_queue = TaskQueueManager(rank=self.rank, mp_num=self.nranks, port=self.config.infer_port)

model_rank_path = os.path.join(self.args.model_dir, f"rank_{self.rank}")
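
The actual proposer is InferenceWithReferenceProposer from paddlenlp.experimental.transformers; its internals are not shown in this diff. As a rough, assumption-laden illustration of the underlying "inference_with_reference" idea (draft tokens by matching the newest n-gram of the generated text against the prompt and copying what followed it), here is a toy stand-in, not the real class:

def propose_from_reference(prompt_ids, generated_ids, max_ngram_size=2, max_draft_tokens=5):
    # Try the longest n-gram first, falling back to shorter ones.
    for n in range(max_ngram_size, 0, -1):
        if len(generated_ids) < n:
            continue
        suffix = generated_ids[-n:]
        for start in range(len(prompt_ids) - n):
            if prompt_ids[start:start + n] == suffix:
                # Copy the tokens that followed this n-gram in the prompt as drafts;
                # the target model verifies them in a single forward pass.
                return prompt_ids[start + n:start + n + max_draft_tokens]
    return []  # no match: degrade to ordinary one-token-per-step decoding

print(propose_from_reference([5, 6, 7, 8, 9, 10], [1, 2, 6, 7]))  # -> [8, 9, 10]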
@@ -263,6 +277,18 @@ def init_inputs(self):
shape=[self.args.max_batch_size, 1], fill_value=-1, dtype="int64")
self.share_inputs["ori_seq_lens_encoder"] = paddle.full(
shape=[self.args.max_batch_size, 1], fill_value=0, dtype="int32")
# speculate decoding input
if self.is_speculate_decoding:
self.share_inputs["accept_tokens"] = paddle.full(
shape=[self.args.max_batch_size, self.speculate_config.speculate_max_draft_token_num + 1], fill_value=0, dtype="int64"
)
self.share_inputs["accept_num"] = paddle.full(shape=[self.args.max_batch_size], fill_value=0, dtype="int32")
self.share_inputs["draft_tokens"] = paddle.full(
shape=[self.args.max_batch_size, self.speculate_config.speculate_max_draft_token_num + 1], fill_value=0, dtype="int64"
)
self.share_inputs["actual_draft_token_num"] = paddle.full(
shape=[self.args.max_batch_size], fill_value=self.speculate_config.speculate_max_draft_token_num, dtype="int32"
)
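
A note on the shapes above: the "+ 1" in the two (batch, draft + 1) buffers presumably leaves room for the extra token the target model commits itself after verifying the drafts, so one step can emit up to speculate_max_draft_token_num + 1 tokens per request. A quick sanity sketch with hypothetical sizes:

# Hypothetical sizes, only to make the buffer shapes concrete.
max_batch_size = 8
speculate_max_draft_token_num = 5

draft_tokens_shape  = (max_batch_size, speculate_max_draft_token_num + 1)  # (8, 6): proposed tokens per step
accept_tokens_shape = (max_batch_size, speculate_max_draft_token_num + 1)  # (8, 6): verified tokens per step
accept_num_shape    = (max_batch_size,)   # how many accept_tokens entries are valid per request
actual_draft_token_num_shape = (max_batch_size,)  # per-request draft budget, reset in dy_input_preprocess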

def dy_input_preprocess(self, tasks):
"""
@@ -318,10 +344,21 @@ def dy_input_preprocess(self, tasks):
task["stop_seqs_len"], dtype="int32")
self.share_inputs['stop_seqs'][:stop_seqs_num, :len(task['stop_seqs'][0])] = np.array(
task["stop_seqs"], dtype="int64")

if self.is_speculate_decoding:
self.share_inputs["draft_tokens"][idx:idx + 1] = np.zeros([self.speculate_config.speculate_max_draft_token_num + 1])
self.share_inputs["actual_draft_token_num"][idx:idx + 1] = np.array([self.speculate_config.speculate_max_draft_token_num])

def step_cuda(self, seq_lens_this_time):
"""
step cuda
"""
# extra token slots to advance per step when speculate decoding is enabled
if self.is_speculate_decoding:
speculate_step_token_num = self.speculate_config.speculate_max_draft_token_num + 1
else:
speculate_step_token_num = 0

step_paddle(self.share_inputs['stop_flags'], seq_lens_this_time,
self.share_inputs['step_seq_lens_encoder'],
self.share_inputs['seq_lens_encoder'],
@@ -334,7 +371,8 @@ def step_cuda(self, seq_lens_this_time):
self.share_inputs['free_list'], self.share_inputs['free_list_len'],
self.share_inputs['input_ids'], self.share_inputs['pre_ids'],
self.share_inputs['step_idx'], self.share_inputs['next_tokens'],
self.args.block_size, self.args.enc_dec_block_num, self.args.first_token_id)
self.args.block_size, self.args.enc_dec_block_num, self.args.first_token_id,
speculate_step_token_num)

def initialize_engine_ready_check_flag(self):
"""
@@ -459,6 +497,13 @@ def run(self):
time.sleep(0.001)
continue

if self.proposer is not None:
self.proposer.run(
self.share_inputs,
real_batch_size=seq_lens_this_time.shape[0],
seq_lens_this_time=seq_lens_this_time,
)

self.infer_engine.predictor.run()
self.share_inputs['infer_seed'].add_(infer_seed_increment)
self.share_inputs['infer_seed'][:] %= self.MAX_INFER_SEED
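
Putting the pieces of run() together, the per-step flow with speculative decoding enabled looks roughly like this. This is a hedged sketch, not the engine's actual control flow; it only restates how the calls above appear to interact:

def speculative_step(proposer, predictor, share_inputs, seq_lens_this_time):
    # 1. The proposer writes its guesses into share_inputs["draft_tokens"]
    #    (up to actual_draft_token_num tokens per running request).
    proposer.run(
        share_inputs,
        real_batch_size=seq_lens_this_time.shape[0],
        seq_lens_this_time=seq_lens_this_time,
    )
    # 2. One forward pass of the target model verifies the drafts; the custom
    #    ops are expected to fill "accept_tokens" and "accept_num" with the
    #    verified prefix (up to draft_num + 1 tokens) for each request.
    predictor.run()
    # 3. TokenProcessor later reads those results via speculate_get_output
    #    and streams them to the client.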
111 changes: 69 additions & 42 deletions llm/server/server/engine/token_processor.py
@@ -20,8 +20,9 @@
from datetime import datetime

import numpy as np
from paddlenlp_ops import get_output
from paddlenlp_ops import get_output, speculate_get_output
from server.utils import datetime_diff, model_server_logger, monitor_logger
from paddlenlp.utils.env import MAX_DRAFT_TOKENS, SPECULATE_MAX_BSZ


class TokenProcessor(object):
@@ -37,7 +38,12 @@ def __init__(self, cfg):
self.all_tokens = [[] for _ in range(self.cfg.max_batch_size)]

self.tokens_counter = Counter()
self.output_tokens = paddle.full(shape=[self.cfg.max_batch_size + 2, 1], fill_value=2, dtype="int64")

self.is_speculate_decoding = self.cfg.get_speculate_config().speculate_method != "None"
if self.is_speculate_decoding:
self.output_tokens = paddle.full(shape=[SPECULATE_MAX_BSZ * MAX_DRAFT_TOKENS + SPECULATE_MAX_BSZ + 2, 1], fill_value=2, dtype="int64")
else:
self.output_tokens = paddle.full(shape=[self.cfg.max_batch_size + 2, 1], fill_value=2, dtype="int64")
self.worker = None

self.record_time_interval = int(os.getenv("RECORD_TIME_INTERVAL", "600"))
@@ -77,10 +83,14 @@ def process_sampling_results(self):
try:
rank_id = 0
is_blocking = True
get_output(self.output_tokens, rank_id, is_blocking)
if self.is_speculate_decoding:
speculate_get_output(self.output_tokens, rank_id, is_blocking)
else:
get_output(self.output_tokens, rank_id, is_blocking)

if self.output_tokens[0, 0] == -2:
continue

self._process_batch_output()
except Exception as e:
model_server_logger.info("while get input_data error: {0} {1}".format(e, str(traceback.format_exc())))
@@ -101,14 +111,14 @@ def postprocess(self, batch_result, exist_finished_task=False):
with open(result_file, "a") as f:
f.write("{}\n".format(result))

def _get_single_result(self, i, task_id, token_id, task):
def _get_single_result(self, i, task_id, token_ids, task):
"""
processing single results
Args:
i (int): batch index
task_id (str): task id
token_id (int): token id
token_ids (list): token ids
task (dict): task information
Returns:
@@ -121,7 +131,7 @@ def _get_single_result(self, i, task_id, token_id, task):
result = {
"req_id": task_id,
"is_end": 0,
"token_ids": [token_id],
"token_ids": token_ids,
"send_idx": self.tokens_counter[task_id],
"inference_time_cost": inference_time_cost,
"infer_seed": task["infer_seed"],
Expand All @@ -137,26 +147,31 @@ def _get_single_result(self, i, task_id, token_id, task):
result[key] = str(task[key])

# fill some extra information
if token_id in task["eos_token_ids"]:
result["is_end"] = 1
result["token_ids"] = []
result["tokens_all_num"] = len(self.all_tokens[i]) + 1
result["tokens_all_ids"] = self.all_tokens[i]

info_dict = {}
info_dict["req_id"] = task["req_id"]
info_dict["input_token_num"] = len(task["input_ids"])
info_dict["output_token_num"] = len(self.all_tokens[i])
if hasattr(task, "preprocess_start_time") and hasattr(task, "preprocess_end_time"):
info_dict["preprocess_cost_time"] = datetime_diff(task["preprocess_start_time"],
task["preprocess_end_time"])
if hasattr(task, "preprocess_end_time") and hasattr(task, "schedule_start_time"):
info_dict["cache_waiting_cost_time"] = datetime_diff(task["preprocess_end_time"],
task["schedule_start_time"])
info_dict["inference_time_cost"] = task["inference_time_cost"]
info_dict["version"] = "4.6"
info_dict["timestamp"] = time.time()
monitor_logger.info(f"{info_dict}")
result["token_ids"] = []
for token_id in token_ids:
if token_id in task["eos_token_ids"]:
result["is_end"] = 1
result["token_ids"] = []
result["tokens_all_num"] = len(self.all_tokens[i]) + 1
result["tokens_all_ids"] = self.all_tokens[i]

info_dict = {}
info_dict["req_id"] = task["req_id"]
info_dict["input_token_num"] = len(task["input_ids"])
info_dict["output_token_num"] = len(self.all_tokens[i])
if hasattr(task, "preprocess_start_time") and hasattr(task, "preprocess_end_time"):
info_dict["preprocess_cost_time"] = datetime_diff(task["preprocess_start_time"],
task["preprocess_end_time"])
if hasattr(task, "preprocess_end_time") and hasattr(task, "schedule_start_time"):
info_dict["cache_waiting_cost_time"] = datetime_diff(task["preprocess_end_time"],
task["schedule_start_time"])
info_dict["inference_time_cost"] = task["inference_time_cost"]
info_dict["version"] = "OpenSource"
info_dict["timestamp"] = time.time()
monitor_logger.info(f"{info_dict}")
break
else:
result["token_ids"].append(token_id)

return result

@@ -177,33 +192,42 @@ def _process_batch_output(self):
"""
tokens = self.output_tokens.numpy()
batch = self.output_tokens[1, 0]
tokens = tokens[2:batch + 2]
if not self.is_speculate_decoding:
tokens = tokens[2:batch + 2]
else:
accept_num = tokens[2:batch + 2]

batch_result = list()
exist_finished_task = False
for i in range(batch):
if self.resource_manager.stop_flags[i]:
continue

token_id = int(tokens[i, 0])
if token_id < 0:
if not self.is_speculate_decoding:
token_ids = [int(tokens[i, 0])]
else:
token_ids = tokens[2 + SPECULATE_MAX_BSZ + i * MAX_DRAFT_TOKENS: 2 + SPECULATE_MAX_BSZ + i * MAX_DRAFT_TOKENS + accept_num[i, 0], 0].tolist()

if any(token_id < 0 for token_id in token_ids):
continue

task = self.resource_manager.tasks_list[i]

task_id = task["req_id"]
result = self._get_single_result(i, task_id, token_id, task)

self.tokens_counter[task_id] += 1
if token_id not in task["eos_token_ids"]:
self.all_tokens[i].append(token_id)

self.number_of_output_tokens += 1
if token_id in task["eos_token_ids"]:
self._recycle_resources(task_id, i, task)
model_server_logger.info("req_id: {0} finished".format(task_id))
model_server_logger.info(f"{self.resource_manager.info()}")
exist_finished_task = True
result = self._get_single_result(i, task_id, token_ids, task)

for token_id in token_ids:
self.tokens_counter[task_id] += 1
if token_id not in task["eos_token_ids"]:
self.all_tokens[i].append(token_id)

self.number_of_output_tokens += 1
if token_id in task["eos_token_ids"]:
self._recycle_resources(task_id, i, task)
model_server_logger.info("req_id: {0} finished".format(task_id))
model_server_logger.info(f"{self.resource_manager.info()}")
exist_finished_task = True
break
batch_result.append(result)

self.postprocess(batch_result, exist_finished_task)
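
For reference, the indexing above implies a flat layout for the buffer that speculate_get_output fills (sized SPECULATE_MAX_BSZ * MAX_DRAFT_TOKENS + SPECULATE_MAX_BSZ + 2 in __init__). A hedged decoding sketch of that assumed layout, using a 1-D list and made-up constants instead of the paddlenlp.utils.env values:

SPECULATE_MAX_BSZ = 4   # stand-in values, not the real paddlenlp constants
MAX_DRAFT_TOKENS = 6

def accepted_tokens_for_slot(tokens, i):
    # Assumed layout, matching the slicing in _process_batch_output:
    #   tokens[0]                         status flag (-2 means nothing ready yet)
    #   tokens[1]                         number of active batch slots
    #   tokens[2 : 2 + SPECULATE_MAX_BSZ] accept_num per slot
    #   tokens[2 + SPECULATE_MAX_BSZ :]   SPECULATE_MAX_BSZ rows of MAX_DRAFT_TOKENS ids
    accept_num = int(tokens[2 + i])
    base = 2 + SPECULATE_MAX_BSZ + i * MAX_DRAFT_TOKENS
    return [int(t) for t in tokens[base:base + accept_num]]

buf = [1, 1] + [2, 0, 0, 0] + [11, 12, -1, -1, -1, -1] + [0] * (3 * MAX_DRAFT_TOKENS)
print(accepted_tokens_for_slot(buf, 0))  # -> [11, 12]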
@@ -228,7 +252,10 @@ def process_sampling_results(self):
while self._is_running:
try:
rank_id = 0
get_output(self.output_tokens, rank_id, self._is_blocking)
if self.is_speculate_decoding:
speculate_get_output(self.output_tokens, rank_id, self._is_blocking)
else:
get_output(self.output_tokens, rank_id, self._is_blocking)

if self.output_tokens[0, 0] == -2:
continue
