Support speculate decoding #2541

Merged: 13 commits, Jan 10, 2025
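
This PR wires a speculative ("speculate") decoding path through the LLM server: the config layer reads the speculate settings from the model config, the inference engine builds an InferenceWithReferenceProposer that drafts tokens before each target-model step, and the token processor reads back the variable number of accepted tokens per request. For background, the sketch below illustrates the general idea behind inference_with_reference-style drafting (an illustration only, not the PaddleNLP implementation; the function name and defaults are made up): draft tokens are copied from an earlier occurrence of the current suffix n-gram, and the target model then verifies them in a single forward pass.

def propose_draft_tokens(token_ids, max_ngram_size=3, max_draft_tokens=5):
    """Copy up to max_draft_tokens that followed the most recent earlier
    occurrence of the longest matching suffix n-gram (empty list if none)."""
    for n in range(max_ngram_size, 0, -1):
        if len(token_ids) <= n:
            continue
        suffix = token_ids[-n:]
        # scan backwards for the same n-gram earlier in the sequence
        for start in range(len(token_ids) - n - 1, -1, -1):
            if token_ids[start:start + n] == suffix:
                return token_ids[start + n:start + n + max_draft_tokens]
    return []

# The suffix [5, 6] already occurred at position 1 and was followed by 7, 8, 9:
print(propose_draft_tokens([1, 5, 6, 7, 8, 9, 2, 5, 6], max_draft_tokens=3))  # [7, 8, 9]
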
29 changes: 29 additions & 0 deletions llm/server/server/engine/config.py
@@ -19,6 +19,7 @@

from paddlenlp.generation import GenerationConfig
from server.utils import model_server_logger
from dataclasses import dataclass


class Config:
@@ -203,6 +204,27 @@ def get_model_config(self):
model_config_json = json.load(open(self.model_config_path, 'r', encoding='utf-8'))
return model_config_json

def get_speculate_config(self):
"""
Get speculate decoding related config.

Returns:
SpeculateConfig: the speculate related config
"""
speculate_config = SpeculateConfig()
model_cfg = self.get_model_config()
if model_cfg.get("speculate_method", "None") != "None":
speculate_config.speculate_method = str(model_cfg["speculate_method"])
speculate_config.speculate_max_draft_token_num = model_cfg[
"speculate_max_draft_token_num"]
speculate_config.speculate_max_ngram_size = model_cfg[
"speculate_max_ngram_size"]

if speculate_config.speculate_method not in ["None", "inference_with_reference"]:
model_server_logger.error(f"Unsupport speculate method: {speculate_config.speculate_method}")

return speculate_config

def read_from_config(self):
"""
reset model config from json file
@@ -234,3 +256,10 @@ def get_unique_name(self, name):

def __str__(self) -> str:
return json.dumps(self.__dict__, indent=4)


@dataclass
class SpeculateConfig:
speculate_method: str = "None"
speculate_max_draft_token_num: int = 1
speculate_max_ngram_size: int = 1
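
A minimal usage sketch (the key names mirror get_speculate_config above; the values are illustrative, not defaults shipped with any model):

# With entries like these in the model's config.json ...
example_entries = {
    "speculate_method": "inference_with_reference",
    "speculate_max_draft_token_num": 5,
    "speculate_max_ngram_size": 2,
}
# ... get_speculate_config() returns
# SpeculateConfig(speculate_method="inference_with_reference",
#                 speculate_max_draft_token_num=5,
#                 speculate_max_ngram_size=2)
# Omitting them (or setting speculate_method to "None") keeps the dataclass
# defaults above, which leaves the speculate decoding path disabled.
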
47 changes: 46 additions & 1 deletion llm/server/server/engine/infer.py
@@ -29,6 +29,7 @@
from paddlenlp_ops import step_paddle
from server.data.processor import DataProcessor
from server.engine.config import Config
from paddlenlp.experimental.transformers import InferenceWithReferenceProposer
from server.utils import get_logger
from task_queue_manager import TaskQueueManager

@@ -46,6 +47,8 @@ def __init__(self, args):

self.config = Config()
self.model_cfg = self.config.get_model_config()
self.speculate_config = self.config.get_speculate_config()
self.is_speculate_decoding = self.speculate_config.speculate_method != "None"
self.format_print_configuration()

self.args.num_layers = self.get_value(self.model_cfg, ["num_hidden_layers", "num_layers"])
@@ -67,6 +70,17 @@ def __init__(self, args):
self.cache_kvs = {}
self.init_inputs()

if self.is_speculate_decoding:
logger.info(f'Using speculate decoding, method: {self.speculate_config.speculate_method}.')
if self.speculate_config.speculate_method == "inference_with_reference":
self.proposer = InferenceWithReferenceProposer(
self.speculate_config.speculate_max_draft_token_num,
self.speculate_config.speculate_max_ngram_size,
self.args.max_batch_size,
self.args.max_seq_len)
else:
self.proposer = None

self.infer_queue = TaskQueueManager(rank=self.rank, mp_num=self.nranks, port=self.config.infer_port)

model_rank_path = os.path.join(self.args.model_dir, f"rank_{self.rank}")
@@ -263,6 +277,18 @@ def init_inputs(self):
shape=[self.args.max_batch_size, 1], fill_value=-1, dtype="int64")
self.share_inputs["ori_seq_lens_encoder"] = paddle.full(
shape=[self.args.max_batch_size, 1], fill_value=0, dtype="int32")
# speculate decoding input
if self.is_speculate_decoding:
self.share_inputs["accept_tokens"] = paddle.full(
shape=[self.args.max_batch_size, self.speculate_config.speculate_max_draft_token_num + 1], fill_value=0, dtype="int64"
)
self.share_inputs["accept_num"] = paddle.full(shape=[self.args.max_batch_size], fill_value=0, dtype="int32")
self.share_inputs["draft_tokens"] = paddle.full(
shape=[self.args.max_batch_size, self.speculate_config.speculate_max_draft_token_num + 1], fill_value=0, dtype="int64"
)
self.share_inputs["actual_draft_token_num"] = paddle.full(
shape=[self.args.max_batch_size], fill_value=self.speculate_config.speculate_max_draft_token_num, dtype="int32"
)

def dy_input_preprocess(self, tasks):
"""
@@ -318,10 +344,21 @@ def dy_input_preprocess(self, tasks):
task["stop_seqs_len"], dtype="int32")
self.share_inputs['stop_seqs'][:stop_seqs_num, :len(task['stop_seqs'][0])] = np.array(
task["stop_seqs"], dtype="int64")

if self.is_speculate_decoding:
self.share_inputs["draft_tokens"][idx:idx + 1] = np.zeros([self.speculate_config.speculate_max_draft_token_num + 1])
self.share_inputs["actual_draft_token_num"][idx:idx + 1] = np.array([self.speculate_config.speculate_max_draft_token_num])

def step_cuda(self, seq_lens_this_time):
"""
step cuda
"""
# per-step token budget when speculate decoding is enabled: max draft tokens + 1 bonus token; 0 disables it
if self.is_speculate_decoding:
speculate_step_token_num = self.speculate_config.speculate_max_draft_token_num + 1
else:
speculate_step_token_num = 0

step_paddle(self.share_inputs['stop_flags'], seq_lens_this_time,
self.share_inputs['step_seq_lens_encoder'],
self.share_inputs['seq_lens_encoder'],
@@ -334,7 +371,8 @@ def step_cuda(self, seq_lens_this_time):
self.share_inputs['free_list'], self.share_inputs['free_list_len'],
self.share_inputs['input_ids'], self.share_inputs['pre_ids'],
self.share_inputs['step_idx'], self.share_inputs['next_tokens'],
self.args.block_size, self.args.enc_dec_block_num, self.args.first_token_id)
self.args.block_size, self.args.enc_dec_block_num, self.args.first_token_id,
speculate_step_token_num)

def initialize_engine_ready_check_flag(self):
"""
@@ -459,6 +497,13 @@ def run(self):
time.sleep(0.001)
continue

if self.proposer is not None:
self.proposer.run(
self.share_inputs,
real_batch_size=seq_lens_this_time.shape[0],
seq_lens_this_time=seq_lens_this_time,
)

self.infer_engine.predictor.run()
self.share_inputs['infer_seed'].add_(infer_seed_increment)
self.share_inputs['infer_seed'][:] %= self.MAX_INFER_SEED
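
Taken together, the infer.py changes establish the following per-step flow (an illustrative sketch of the call order shown above, not the server loop itself; the comments describe intended roles inferred from the buffer names):

def speculate_step(proposer, predictor, share_inputs, seq_lens_this_time):
    # when speculate decoding is enabled, the proposer drafts tokens into
    # share_inputs["draft_tokens"] before the target model runs
    if proposer is not None:
        proposer.run(
            share_inputs,
            real_batch_size=seq_lens_this_time.shape[0],
            seq_lens_this_time=seq_lens_this_time,
        )
    # the target model then verifies the drafts in a single forward pass;
    # the accepted tokens reach the token processor via speculate_get_output
    predictor.run()
    # step_cuda(...) afterwards advances each sequence by up to
    # speculate_max_draft_token_num + 1 tokens instead of exactly one
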
111 changes: 69 additions & 42 deletions llm/server/server/engine/token_processor.py
@@ -20,8 +20,9 @@
from datetime import datetime

import numpy as np
from paddlenlp_ops import get_output
from paddlenlp_ops import get_output, speculate_get_output
from server.utils import datetime_diff, model_server_logger, monitor_logger
from paddlenlp.utils.env import MAX_DRAFT_TOKENS, SPECULATE_MAX_BSZ


class TokenProcessor(object):
@@ -37,7 +38,12 @@ def __init__(self, cfg):
self.all_tokens = [[] for _ in range(self.cfg.max_batch_size)]

self.tokens_counter = Counter()
self.output_tokens = paddle.full(shape=[self.cfg.max_batch_size + 2, 1], fill_value=2, dtype="int64")

self.is_speculate_decoding = self.cfg.get_speculate_config().speculate_method != "None"
if self.is_speculate_decoding:
self.output_tokens = paddle.full(shape=[SPECULATE_MAX_BSZ * MAX_DRAFT_TOKENS + SPECULATE_MAX_BSZ + 2, 1], fill_value=2, dtype="int64")
else:
self.output_tokens = paddle.full(shape=[self.cfg.max_batch_size + 2, 1], fill_value=2, dtype="int64")
self.worker = None

self.record_time_interval = int(os.getenv("RECORD_TIME_INTERVAL", "600"))
@@ -77,10 +83,14 @@ def process_sampling_results(self):
try:
rank_id = 0
is_blocking = True
get_output(self.output_tokens, rank_id, is_blocking)
if self.is_speculate_decoding:
speculate_get_output(self.output_tokens, rank_id, is_blocking)
else:
get_output(self.output_tokens, rank_id, is_blocking)

if self.output_tokens[0, 0] == -2:
continue

self._process_batch_output()
except Exception as e:
model_server_logger.info("while get input_data error: {0} {1}".format(e, str(traceback.format_exc())))
@@ -101,14 +111,14 @@ def postprocess(self, batch_result, exist_finished_task=False):
with open(result_file, "a") as f:
f.write("{}\n".format(result))

def _get_single_result(self, i, task_id, token_id, task):
def _get_single_result(self, i, task_id, token_ids, task):
"""
processing single results

Args:
i (int): batch index
task_id (str): task id
token_id (int): token id
token_ids (list): token ids
task (dict): task information

Returns:
@@ -121,7 +131,7 @@ def _get_single_result(self, i, task_id, token_id, task):
result = {
"req_id": task_id,
"is_end": 0,
"token_ids": [token_id],
"token_ids": token_ids,
"send_idx": self.tokens_counter[task_id],
"inference_time_cost": inference_time_cost,
"infer_seed": task["infer_seed"],
@@ -137,26 +147,31 @@ def _get_single_result(self, i, task_id, token_id, task):
result[key] = str(task[key])

# fill some extra information
if token_id in task["eos_token_ids"]:
result["is_end"] = 1
result["token_ids"] = []
result["tokens_all_num"] = len(self.all_tokens[i]) + 1
result["tokens_all_ids"] = self.all_tokens[i]

info_dict = {}
info_dict["req_id"] = task["req_id"]
info_dict["input_token_num"] = len(task["input_ids"])
info_dict["output_token_num"] = len(self.all_tokens[i])
if hasattr(task, "preprocess_start_time") and hasattr(task, "preprocess_end_time"):
info_dict["preprocess_cost_time"] = datetime_diff(task["preprocess_start_time"],
task["preprocess_end_time"])
if hasattr(task, "preprocess_end_time") and hasattr(task, "schedule_start_time"):
info_dict["cache_waiting_cost_time"] = datetime_diff(task["preprocess_end_time"],
task["schedule_start_time"])
info_dict["inference_time_cost"] = task["inference_time_cost"]
info_dict["version"] = "4.6"
info_dict["timestamp"] = time.time()
monitor_logger.info(f"{info_dict}")
result["token_ids"] = []
for token_id in token_ids:
if token_id in task["eos_token_ids"]:
result["is_end"] = 1
result["token_ids"] = []
result["tokens_all_num"] = len(self.all_tokens[i]) + 1
result["tokens_all_ids"] = self.all_tokens[i]

info_dict = {}
info_dict["req_id"] = task["req_id"]
info_dict["input_token_num"] = len(task["input_ids"])
info_dict["output_token_num"] = len(self.all_tokens[i])
if hasattr(task, "preprocess_start_time") and hasattr(task, "preprocess_end_time"):
info_dict["preprocess_cost_time"] = datetime_diff(task["preprocess_start_time"],
task["preprocess_end_time"])
if hasattr(task, "preprocess_end_time") and hasattr(task, "schedule_start_time"):
info_dict["cache_waiting_cost_time"] = datetime_diff(task["preprocess_end_time"],
task["schedule_start_time"])
info_dict["inference_time_cost"] = task["inference_time_cost"]
info_dict["version"] = "4.6"
Review comment (Collaborator): Please write this version value as "opensource"; the "4.6" added earlier for internal use was temporary.
Reply (Contributor, author): Done
info_dict["timestamp"] = time.time()
monitor_logger.info(f"{info_dict}")
break
else:
result["token_ids"].append(token_id)

return result

@@ -177,33 +192,42 @@ def _process_batch_output(self):
"""
tokens = self.output_tokens.numpy()
batch = self.output_tokens[1, 0]
tokens = tokens[2:batch + 2]
if not self.is_speculate_decoding:
tokens = tokens[2:batch + 2]
else:
accept_num = tokens[2:batch + 2]

batch_result = list()
exist_finished_task = False
for i in range(batch):
if self.resource_manager.stop_flags[i]:
continue

token_id = int(tokens[i, 0])
if token_id < 0:
if not self.is_speculate_decoding:
token_ids = [int(tokens[i, 0])]
else:
token_ids = tokens[2 + SPECULATE_MAX_BSZ + i * MAX_DRAFT_TOKENS: 2 + SPECULATE_MAX_BSZ + i * MAX_DRAFT_TOKENS + accept_num[i, 0], 0].tolist()

if any(token_id < 0 for token_id in token_ids):
continue

task = self.resource_manager.tasks_list[i]

task_id = task["req_id"]
result = self._get_single_result(i, task_id, token_id, task)

self.tokens_counter[task_id] += 1
if token_id not in task["eos_token_ids"]:
self.all_tokens[i].append(token_id)

self.number_of_output_tokens += 1
if token_id in task["eos_token_ids"]:
self._recycle_resources(task_id, i, task)
model_server_logger.info("req_id: {0} finished".format(task_id))
model_server_logger.info(f"{self.resource_manager.info()}")
exist_finished_task = True
result = self._get_single_result(i, task_id, token_ids, task)

for token_id in token_ids:
self.tokens_counter[task_id] += 1
if token_id not in task["eos_token_ids"]:
self.all_tokens[i].append(token_id)

self.number_of_output_tokens += 1
if token_id in task["eos_token_ids"]:
self._recycle_resources(task_id, i, task)
model_server_logger.info("req_id: {0} finished".format(task_id))
model_server_logger.info(f"{self.resource_manager.info()}")
exist_finished_task = True
break
batch_result.append(result)

self.postprocess(batch_result, exist_finished_task)
@@ -228,7 +252,10 @@ def process_sampling_results(self):
while self._is_running:
try:
rank_id = 0
get_output(self.output_tokens, rank_id, self._is_blocking)
if self.is_speculate_decoding:
speculate_get_output(self.output_tokens, rank_id, self._is_blocking)
else:
get_output(self.output_tokens, rank_id, self._is_blocking)

if self.output_tokens[0, 0] == -2:
continue
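
For reference, a sketch of the output-buffer layout implied by the indexing in _process_batch_output above (SPECULATE_MAX_BSZ and MAX_DRAFT_TOKENS come from paddlenlp.utils.env; the values below are placeholders, not the real constants):

import numpy as np

SPECULATE_MAX_BSZ = 256   # placeholder value
MAX_DRAFT_TOKENS = 6      # placeholder value

def decode_speculate_buffer(tokens: np.ndarray):
    """Recover the accepted token ids per batch slot from the flat buffer
    filled by speculate_get_output, mirroring the indexing used above:
      tokens[0, 0]            status flag (-2 means no output yet)
      tokens[1, 0]            current batch size
      tokens[2:2+batch, 0]    number of accepted tokens per batch slot
      the block starting at row 2 + SPECULATE_MAX_BSZ + i * MAX_DRAFT_TOKENS
      holds the accepted token ids for slot i.
    """
    batch = int(tokens[1, 0])
    accept_num = tokens[2:2 + batch, 0]
    accepted = []
    for i in range(batch):
        start = 2 + SPECULATE_MAX_BSZ + i * MAX_DRAFT_TOKENS
        accepted.append(tokens[start:start + int(accept_num[i]), 0].tolist())
    return accepted
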