diff --git a/llm/server/server/engine/config.py b/llm/server/server/engine/config.py
index 6f0e1964e2..3b9a88f0c9 100644
--- a/llm/server/server/engine/config.py
+++ b/llm/server/server/engine/config.py
@@ -19,6 +19,7 @@
 
 from paddlenlp.generation import GenerationConfig
 from server.utils import model_server_logger
+from dataclasses import dataclass
 
 
 class Config:
@@ -203,6 +204,27 @@ def get_model_config(self):
         model_config_json = json.load(open(self.model_config_path, 'r', encoding='utf-8'))
         return model_config_json
 
+    def get_speculate_config(self):
+        """
+        get speculate_decoding related config
+
+        Returns:
+            SpeculateConfig: the speculate related config
+        """
+        speculate_config = SpeculateConfig()
+        model_cfg = self.get_model_config()
+        if model_cfg.get("speculate_method", "None") != "None":
+            speculate_config.speculate_method = str(model_cfg["speculate_method"])
+            speculate_config.speculate_max_draft_token_num = model_cfg[
+                "speculate_max_draft_token_num"]
+            speculate_config.speculate_max_ngram_size = model_cfg[
+                "speculate_max_ngram_size"]
+
+        if speculate_config.speculate_method not in ["None", "inference_with_reference"]:
+            model_server_logger.error(f"Unsupported speculate method: {speculate_config.speculate_method}")
+
+        return speculate_config
+
     def read_from_config(self):
         """
         reset model config from json file
@@ -234,3 +256,10 @@ def get_unique_name(self, name):
 
     def __str__(self) -> str:
         return json.dumps(self.__dict__, indent=4)
+
+
+@dataclass
+class SpeculateConfig:
+    speculate_method: str = "None"
+    speculate_max_draft_token_num: int = 1
+    speculate_max_ngram_size: int = 1
\ No newline at end of file
diff --git a/llm/server/server/engine/infer.py b/llm/server/server/engine/infer.py
index ac006bf4ae..2641e88994 100644
--- a/llm/server/server/engine/infer.py
+++ b/llm/server/server/engine/infer.py
@@ -29,6 +29,7 @@
 from paddlenlp_ops import step_paddle
 from server.data.processor import DataProcessor
 from server.engine.config import Config
+from paddlenlp.experimental.transformers import InferenceWithReferenceProposer
 from server.utils import get_logger
 from task_queue_manager import TaskQueueManager
 
@@ -46,6 +47,8 @@ def __init__(self, args):
 
         self.config = Config()
         self.model_cfg = self.config.get_model_config()
+        self.speculate_config = self.config.get_speculate_config()
+        self.is_speculate_decoding = self.speculate_config.speculate_method != "None"
         self.format_print_configuration()
 
         self.args.num_layers = self.get_value(self.model_cfg, ["num_hidden_layers", "num_layers"])
@@ -67,6 +70,17 @@ def __init__(self, args):
         self.cache_kvs = {}
         self.init_inputs()
 
+        if self.is_speculate_decoding:
+            logger.info(f'Using speculate decoding, method: {self.speculate_config.speculate_method}.')
+            if self.speculate_config.speculate_method == "inference_with_reference":
+                self.proposer = InferenceWithReferenceProposer(
+                    self.speculate_config.speculate_max_draft_token_num,
+                    self.speculate_config.speculate_max_ngram_size,
+                    self.args.max_batch_size,
+                    self.args.max_seq_len)
+        else:
+            self.proposer = None
+
         self.infer_queue = TaskQueueManager(rank=self.rank, mp_num=self.nranks, port=self.config.infer_port)
 
         model_rank_path = os.path.join(self.args.model_dir, f"rank_{self.rank}")
@@ -263,6 +277,18 @@ def init_inputs(self):
             shape=[self.args.max_batch_size, 1], fill_value=-1, dtype="int64")
         self.share_inputs["ori_seq_lens_encoder"] = paddle.full(
             shape=[self.args.max_batch_size, 1], fill_value=0, dtype="int32")
+        # speculate decoding input
+        if self.is_speculate_decoding:
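+            # Buffers below are sized for speculate_max_draft_token_num draft tokens
+            # plus one extra (bonus) token per request slot.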
self.share_inputs["accept_tokens"] = paddle.full( + shape=[self.args.max_batch_size, self.speculate_config.speculate_max_draft_token_num + 1], fill_value=0, dtype="int64" + ) + self.share_inputs["accept_num"] = paddle.full(shape=[self.args.max_batch_size], fill_value=0, dtype="int32") + self.share_inputs["draft_tokens"] = paddle.full( + shape=[self.args.max_batch_size, self.speculate_config.speculate_max_draft_token_num + 1], fill_value=0, dtype="int64" + ) + self.share_inputs["actual_draft_token_num"] = paddle.full( + shape=[self.args.max_batch_size], fill_value=self.speculate_config.speculate_max_draft_token_num, dtype="int32" + ) def dy_input_preprocess(self, tasks): """ @@ -318,10 +344,21 @@ def dy_input_preprocess(self, tasks): task["stop_seqs_len"], dtype="int32") self.share_inputs['stop_seqs'][:stop_seqs_num, :len(task['stop_seqs'][0])] = np.array( task["stop_seqs"], dtype="int64") + + if self.is_speculate_decoding: + self.share_inputs["draft_tokens"][idx:idx + 1] = np.zeros([self.speculate_config.speculate_max_draft_token_num + 1]) + self.share_inputs["actual_draft_token_num"][idx:idx + 1] = np.array([self.speculate_config.speculate_max_draft_token_num]) + def step_cuda(self, seq_lens_this_time): """ step cuda """ + # whether speculate decoding + if self.is_speculate_decoding: + speculate_step_token_num = self.speculate_config.speculate_max_draft_token_num + 1 + else: + speculate_step_token_num = 0 + step_paddle(self.share_inputs['stop_flags'], seq_lens_this_time, self.share_inputs['step_seq_lens_encoder'], self.share_inputs['seq_lens_encoder'], @@ -334,7 +371,8 @@ def step_cuda(self, seq_lens_this_time): self.share_inputs['free_list'], self.share_inputs['free_list_len'], self.share_inputs['input_ids'], self.share_inputs['pre_ids'], self.share_inputs['step_idx'], self.share_inputs['next_tokens'], - self.args.block_size, self.args.enc_dec_block_num, self.args.first_token_id) + self.args.block_size, self.args.enc_dec_block_num, self.args.first_token_id, + speculate_step_token_num) def initialize_engine_ready_check_flag(self): """ @@ -459,6 +497,13 @@ def run(self): time.sleep(0.001) continue + if self.proposer is not None: + self.proposer.run( + self.share_inputs, + real_batch_size=seq_lens_this_time.shape[0], + seq_lens_this_time=seq_lens_this_time, + ) + self.infer_engine.predictor.run() self.share_inputs['infer_seed'].add_(infer_seed_increment) self.share_inputs['infer_seed'][:] %= self.MAX_INFER_SEED diff --git a/llm/server/server/engine/token_processor.py b/llm/server/server/engine/token_processor.py index 507a3d43bd..1213a9384b 100644 --- a/llm/server/server/engine/token_processor.py +++ b/llm/server/server/engine/token_processor.py @@ -20,8 +20,9 @@ from datetime import datetime import numpy as np -from paddlenlp_ops import get_output +from paddlenlp_ops import get_output, speculate_get_output from server.utils import datetime_diff, model_server_logger, monitor_logger +from paddlenlp.utils.env import MAX_DRAFT_TOKENS, SPECULATE_MAX_BSZ class TokenProcessor(object): @@ -37,7 +38,12 @@ def __init__(self, cfg): self.all_tokens = [[] for _ in range(self.cfg.max_batch_size)] self.tokens_counter = Counter() - self.output_tokens = paddle.full(shape=[self.cfg.max_batch_size + 2, 1], fill_value=2, dtype="int64") + + self.is_speculate_decoding = self.cfg.get_speculate_config().speculate_method != "None" + if self.is_speculate_decoding: + self.output_tokens = paddle.full(shape=[SPECULATE_MAX_BSZ * MAX_DRAFT_TOKENS + SPECULATE_MAX_BSZ + 2, 1], fill_value=2, dtype="int64") + else: + 
+        if self.is_speculate_decoding:
+            self.output_tokens = paddle.full(shape=[SPECULATE_MAX_BSZ * MAX_DRAFT_TOKENS + SPECULATE_MAX_BSZ + 2, 1], fill_value=2, dtype="int64")
+        else:
+            self.output_tokens = paddle.full(shape=[self.cfg.max_batch_size + 2, 1], fill_value=2, dtype="int64")
         self.worker = None
 
         self.record_time_interval = int(os.getenv("RECORD_TIME_INTERVAL", "600"))
@@ -77,10 +83,14 @@ def process_sampling_results(self):
             try:
                 rank_id = 0
                 is_blocking = True
-                get_output(self.output_tokens, rank_id, is_blocking)
+                if self.is_speculate_decoding:
+                    speculate_get_output(self.output_tokens, rank_id, is_blocking)
+                else:
+                    get_output(self.output_tokens, rank_id, is_blocking)
 
                 if self.output_tokens[0, 0] == -2:
                     continue
+
                 self._process_batch_output()
             except Exception as e:
                 model_server_logger.info("while get input_data error: {0} {1}".format(e, str(traceback.format_exc())))
@@ -101,14 +111,14 @@ def postprocess(self, batch_result, exist_finished_task=False):
             with open(result_file, "a") as f:
                 f.write("{}\n".format(result))
 
-    def _get_single_result(self, i, task_id, token_id, task):
+    def _get_single_result(self, i, task_id, token_ids, task):
         """
         processing single results
 
         Args:
             i (int): batch index
             task_id (str): task id
-            token_id (int): token id
+            token_ids (list): token ids
             task (dict): task information
 
         Returns:
@@ -121,7 +131,7 @@ def _get_single_result(self, i, task_id, token_id, task):
         result = {
             "req_id": task_id,
             "is_end": 0,
-            "token_ids": [token_id],
+            "token_ids": token_ids,
             "send_idx": self.tokens_counter[task_id],
             "inference_time_cost": inference_time_cost,
             "infer_seed": task["infer_seed"],
@@ -137,26 +147,31 @@ def _get_single_result(self, i, task_id, token_id, task):
             result[key] = str(task[key])
 
         # fill some extra information
-        if token_id in task["eos_token_ids"]:
-            result["is_end"] = 1
-            result["token_ids"] = []
-            result["tokens_all_num"] = len(self.all_tokens[i]) + 1
-            result["tokens_all_ids"] = self.all_tokens[i]
-
-            info_dict = {}
-            info_dict["req_id"] = task["req_id"]
-            info_dict["input_token_num"] = len(task["input_ids"])
-            info_dict["output_token_num"] = len(self.all_tokens[i])
-            if hasattr(task, "preprocess_start_time") and hasattr(task, "preprocess_end_time"):
-                info_dict["preprocess_cost_time"] = datetime_diff(task["preprocess_start_time"],
-                                                                  task["preprocess_end_time"])
-            if hasattr(task, "preprocess_end_time") and hasattr(task, "schedule_start_time"):
-                info_dict["cache_waiting_cost_time"] = datetime_diff(task["preprocess_end_time"],
-                                                                     task["schedule_start_time"])
-            info_dict["inference_time_cost"] = task["inference_time_cost"]
-            info_dict["version"] = "4.6"
-            info_dict["timestamp"] = time.time()
-            monitor_logger.info(f"{info_dict}")
+        result["token_ids"] = []
+        for token_id in token_ids:
+            if token_id in task["eos_token_ids"]:
+                result["is_end"] = 1
+                result["token_ids"] = []
+                result["tokens_all_num"] = len(self.all_tokens[i]) + 1
+                result["tokens_all_ids"] = self.all_tokens[i]
+
+                info_dict = {}
+                info_dict["req_id"] = task["req_id"]
+                info_dict["input_token_num"] = len(task["input_ids"])
+                info_dict["output_token_num"] = len(self.all_tokens[i])
+                if hasattr(task, "preprocess_start_time") and hasattr(task, "preprocess_end_time"):
+                    info_dict["preprocess_cost_time"] = datetime_diff(task["preprocess_start_time"],
+                                                                      task["preprocess_end_time"])
+                if hasattr(task, "preprocess_end_time") and hasattr(task, "schedule_start_time"):
+                    info_dict["cache_waiting_cost_time"] = datetime_diff(task["preprocess_end_time"],
+                                                                         task["schedule_start_time"])
+                info_dict["inference_time_cost"] = task["inference_time_cost"]
+                info_dict["version"] = "OpenSource"
+                info_dict["timestamp"] = time.time()
+                monitor_logger.info(f"{info_dict}")
+                break
+            else:
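+                # non-EOS accepted tokens are appended in order and streamed back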
result["token_ids"].append(token_id) return result @@ -177,7 +192,10 @@ def _process_batch_output(self): """ tokens = self.output_tokens.numpy() batch = self.output_tokens[1, 0] - tokens = tokens[2:batch + 2] + if not self.is_speculate_decoding: + tokens = tokens[2:batch + 2] + else: + accept_num = tokens[2:batch + 2] batch_result = list() exist_finished_task = False @@ -185,25 +203,31 @@ def _process_batch_output(self): if self.resource_manager.stop_flags[i]: continue - token_id = int(tokens[i, 0]) - if token_id < 0: + if not self.is_speculate_decoding: + token_ids = [int(tokens[i, 0])] + else: + token_ids = tokens[2 + SPECULATE_MAX_BSZ + i * MAX_DRAFT_TOKENS: 2 + SPECULATE_MAX_BSZ + i * MAX_DRAFT_TOKENS + accept_num[i, 0], 0].tolist() + + if any(token_id < 0 for token_id in token_ids): continue task = self.resource_manager.tasks_list[i] task_id = task["req_id"] - result = self._get_single_result(i, task_id, token_id, task) - - self.tokens_counter[task_id] += 1 - if token_id not in task["eos_token_ids"]: - self.all_tokens[i].append(token_id) - - self.number_of_output_tokens += 1 - if token_id in task["eos_token_ids"]: - self._recycle_resources(task_id, i, task) - model_server_logger.info("req_id: {0} finished".format(task_id)) - model_server_logger.info(f"{self.resource_manager.info()}") - exist_finished_task = True + result = self._get_single_result(i, task_id, token_ids, task) + + for token_id in token_ids: + self.tokens_counter[task_id] += 1 + if token_id not in task["eos_token_ids"]: + self.all_tokens[i].append(token_id) + + self.number_of_output_tokens += 1 + if token_id in task["eos_token_ids"]: + self._recycle_resources(task_id, i, task) + model_server_logger.info("req_id: {0} finished".format(task_id)) + model_server_logger.info(f"{self.resource_manager.info()}") + exist_finished_task = True + break batch_result.append(result) self.postprocess(batch_result, exist_finished_task) @@ -228,7 +252,10 @@ def process_sampling_results(self): while self._is_running: try: rank_id = 0 - get_output(self.output_tokens, rank_id, self._is_blocking) + if self.is_speculate_decoding: + speculate_get_output(self.output_tokens, rank_id, self._is_blocking) + else: + get_output(self.output_tokens, rank_id, self._is_blocking) if self.output_tokens[0, 0] == -2: continue