From 47aacb506246f3e441fc76682a8ab8a13e50f5b1 Mon Sep 17 00:00:00 2001 From: Wanglongzhi2001 <583087864@qq.com> Date: Tue, 29 Oct 2024 02:26:19 +0000 Subject: [PATCH 01/13] add speculate_decoding framework --- llm/server/server/engine/config.py | 5 + llm/server/server/engine/infer.py | 106 +++++++++++++-- llm/server/server/engine/proposers.py | 95 +++++++++++++ llm/server/server/engine/token_processor.py | 143 +++++++++++++++++++- 4 files changed, 331 insertions(+), 18 deletions(-) create mode 100644 llm/server/server/engine/proposers.py diff --git a/llm/server/server/engine/config.py b/llm/server/server/engine/config.py index 6f0e1964e2..4508c7e0c6 100644 --- a/llm/server/server/engine/config.py +++ b/llm/server/server/engine/config.py @@ -91,6 +91,11 @@ def read_from_env(self): self.block_size = int(env.get("BLOCK_SIZE", 64)) self.use_cache_kv_int8 = int(os.getenv("USE_CACHE_KV_INT8", 0)) self.use_cache_kv_int4 = int(os.getenv("USE_CACHE_KV_INT4", 0)) + + # speculate decoding config + self.speculate_method = str(env.get("SPECULATE_METHOD", None)) + self.speculate_max_draft_token_num = int(os.getenv("SPECULATE_MAX_DRAFT_TOKEN_NUM", 5)) + self.speculate_max_ngram_size = int(os.getenv("SPECULATE_MAX_NGRAM_SIZE", 2)) # infer config self.max_batch_size = int(env.get("BATCH_SIZE", 50)) diff --git a/llm/server/server/engine/infer.py b/llm/server/server/engine/infer.py index ac006bf4ae..51d33d3bec 100644 --- a/llm/server/server/engine/infer.py +++ b/llm/server/server/engine/infer.py @@ -26,9 +26,10 @@ import paddle.distributed as dist import paddle.distributed.fleet as fleet from paddlenlp.trl.llm_utils import get_rotary_position_embedding -from paddlenlp_ops import step_paddle +from paddlenlp_ops import step_paddle, speculate_step_paddle from server.data.processor import DataProcessor from server.engine.config import Config +from server.engine.proposers import InferenceWithReferenceProposer from server.utils import get_logger from task_queue_manager import TaskQueueManager @@ -67,6 +68,15 @@ def __init__(self, args): self.cache_kvs = {} self.init_inputs() + # whether use speculate decoding + if self.config.speculate_method is not None and self.config.speculate_method == "inference_with_reference": + self.proposer = InferenceWithReferenceProposer( + self.config.speculate_max_draft_token_num, + self.config.speculate_max_ngram_size, + self.args.max_batch_size) + else: + self.proposer = None + self.infer_queue = TaskQueueManager(rank=self.rank, mp_num=self.nranks, port=self.config.infer_port) model_rank_path = os.path.join(self.args.model_dir, f"rank_{self.rank}") @@ -263,6 +273,20 @@ def init_inputs(self): shape=[self.args.max_batch_size, 1], fill_value=-1, dtype="int64") self.share_inputs["ori_seq_lens_encoder"] = paddle.full( shape=[self.args.max_batch_size, 1], fill_value=0, dtype="int32") + # speculate decoding input + if self.config.speculate_method is not None: + self.share_inputs["input_ids_cpu"] = paddle.full( + shape=[self.args.max_batch_size, self.args.max_seq_len], fill_value=1, dtype='int64').cpu() + self.share_inputs["accept_tokens"] = paddle.full( + shape=[self.args.max_batch_size, self.config.speculate_max_draft_token_num + 1], fill_value=0, dtype="int64" + ) + self.share_inputs["accept_num"] = paddle.full(shape=[self.args.max_batch_size], fill_value=0, dtype="int32") + self.share_inputs["draft_tokens"] = paddle.full( + shape=[self.args.max_batch_size, self.config.speculate_max_draft_token_num + 1], fill_value=0, dtype="int64" + ) + self.share_inputs["actual_draft_token_num"] = 
paddle.full( + shape=[self.args.max_batch_size], fill_value=self.config.speculate_max_draft_token_num, dtype="int32" + ) def dy_input_preprocess(self, tasks): """ @@ -318,23 +342,46 @@ def dy_input_preprocess(self, tasks): task["stop_seqs_len"], dtype="int32") self.share_inputs['stop_seqs'][:stop_seqs_num, :len(task['stop_seqs'][0])] = np.array( task["stop_seqs"], dtype="int64") + if self.proposer is not None: + if self.config.speculate_method == "inference_with_reference": + speculate_update_input_ids_cpu(self.share_inputs['input_ids_cpu'], task['input_ids'], idx, self.args.max_seq_len) + self.share_inputs["draft_tokens"][idx:idx + 1] = np.zeros([self.config.speculate_max_draft_token_num + 1]) + self.share_inputs["actual_draft_token_num"][idx:idx + 1] = np.array([self.config.speculate_max_draft_token_num]) + self.proposer.update(idx, length) + def step_cuda(self, seq_lens_this_time): """ step cuda """ - step_paddle(self.share_inputs['stop_flags'], seq_lens_this_time, - self.share_inputs['step_seq_lens_encoder'], - self.share_inputs['seq_lens_encoder'], - self.share_inputs['seq_lens_decoder'], self.share_inputs["block_tables"], - self.share_inputs['encoder_block_lens'], - self.share_inputs["is_block_step"], self.share_inputs['step_block_list'], - self.share_inputs['step_lens'], self.share_inputs['recover_block_list'], - self.share_inputs['recover_lens'], self.share_inputs['need_block_list'], - self.share_inputs['need_block_len'], self.share_inputs['used_list_len'], - self.share_inputs['free_list'], self.share_inputs['free_list_len'], - self.share_inputs['input_ids'], self.share_inputs['pre_ids'], - self.share_inputs['step_idx'], self.share_inputs['next_tokens'], - self.args.block_size, self.args.enc_dec_block_num, self.args.first_token_id) + if self.config.speculate_method is None: + step_paddle(self.share_inputs['stop_flags'], seq_lens_this_time, + self.share_inputs['step_seq_lens_encoder'], + self.share_inputs['seq_lens_encoder'], + self.share_inputs['seq_lens_decoder'], self.share_inputs["block_tables"], + self.share_inputs['encoder_block_lens'], + self.share_inputs["is_block_step"], self.share_inputs['step_block_list'], + self.share_inputs['step_lens'], self.share_inputs['recover_block_list'], + self.share_inputs['recover_lens'], self.share_inputs['need_block_list'], + self.share_inputs['need_block_len'], self.share_inputs['used_list_len'], + self.share_inputs['free_list'], self.share_inputs['free_list_len'], + self.share_inputs['input_ids'], self.share_inputs['pre_ids'], + self.share_inputs['step_idx'], self.share_inputs['next_tokens'], + self.args.block_size, self.args.enc_dec_block_num, self.args.first_token_id) + else: + speculate_step_paddle(self.share_inputs['stop_flags'], seq_lens_this_time, + self.share_inputs['step_seq_lens_encoder'], + self.share_inputs['seq_lens_encoder'], + self.share_inputs['seq_lens_decoder'], self.share_inputs["block_tables"], + self.share_inputs['encoder_block_lens'], + self.share_inputs["is_block_step"], self.share_inputs['step_block_list'], + self.share_inputs['step_lens'], self.share_inputs['recover_block_list'], + self.share_inputs['recover_lens'], self.share_inputs['need_block_list'], + self.share_inputs['need_block_len'], self.share_inputs['used_list_len'], + self.share_inputs['free_list'], self.share_inputs['free_list_len'], + self.share_inputs['input_ids'], self.share_inputs['pre_ids'], + self.share_inputs['step_idx'], self.share_inputs['next_tokens'], + self.args.block_size, self.args.enc_dec_block_num, self.args.first_token_id, + 
self.config.speculate_max_draft_token_num) def initialize_engine_ready_check_flag(self): """ @@ -434,6 +481,9 @@ def run(self): self.share_inputs["seq_lens_this_time"][:real_bsz] = seq_lens_this_time tasks, read_finish = self.infer_queue.get() + logger.info(f'tasks: {tasks}') + logger.info(f'read_finish: {read_finish}') + if read_finish: flag_broadcast_array[0] = 0 @@ -442,7 +492,7 @@ def run(self): real_bsz = int(bsz) req_dicts.extend(req_dict) logger.info( - f'rank: {self.rank}, real_bsz: {real_bsz}, query_num: {len(req_dicts)}' + f'req_dict: {req_dict} rank: {self.rank}, real_bsz: {real_bsz}, query_num: {len(req_dicts)}' ) self.dy_input_preprocess(req_dicts) @@ -459,10 +509,36 @@ def run(self): time.sleep(0.001) continue + if self.proposer is not None: + logger.info("start run proposer") + logger.info(f'before draft_tokens: {self.share_inputs["draft_tokens"]}') + logger.info(f'before accept_tokens: {self.share_inputs["accept_tokens"]}') + + self.proposer.run( + self.share_inputs, + real_batch_size=self.args.max_batch_size, + seq_lens_this_time=self.share_inputs["seq_lens_this_time"], + ) + logger.info(f'after draft_tokens: {self.share_inputs["draft_tokens"]}') + logger.info("finish run proposer") + logger.info(f'input_ids: {self.share_inputs["input_ids"]}') + logger.info(f'input_ids_cpu: {self.share_inputs["input_ids_cpu"]}') + logger.info(f'seq_lens_this_time: {self.share_inputs["seq_lens_this_time"]}') + logger.info(f'seq_lens_encoder: {self.share_inputs["seq_lens_encoder"]}') + logger.info(f'seq_lens_decoder: {self.share_inputs["seq_lens_decoder"]}') + logger.info(f'step_idx: {self.share_inputs["step_idx"]}') + logger.info(f'next_tokens: {self.share_inputs["next_tokens"]}') + logger.info(f'before block_tables: {self.share_inputs["block_tables"]}') + self.infer_engine.predictor.run() + logger.info(f'after accept_tokens: {self.share_inputs["accept_tokens"]}') + logger.info(f'after accept_num: {self.share_inputs["accept_num"]}') + logger.info(f'after block_tables: {self.share_inputs["block_tables"]}') + self.share_inputs['infer_seed'].add_(infer_seed_increment) self.share_inputs['infer_seed'][:] %= self.MAX_INFER_SEED if self.free_list_len > 0: + logger.info(f'free_list_len > 0') self.step_cuda(seq_lens_this_time) diff --git a/llm/server/server/engine/proposers.py b/llm/server/server/engine/proposers.py new file mode 100644 index 0000000000..68d2b41c9e --- /dev/null +++ b/llm/server/server/engine/proposers.py @@ -0,0 +1,95 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import annotations + +from abc import ABC, abstractmethod + +import paddle +from paddlenlp_ops import ngram_match + + +class Proposer(ABC): + """ + Abstract base class for all proposers that can be used in the speculative decoding framework. + The subclasses of this class must implement the run method to get the draft tokens that are + generated by the proposer. 
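+    A proposer only fills the draft tokens in the shared model inputs; verifying and
+    accepting those drafts is left to the target model's own forward pass.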
+ """ + + def __init__(self, **kwargs): + pass + + @abstractmethod + def run(self, model_inputs: dict[str, paddle.Tensor], **kargs): + """ + Get the draft tokens that are generated by the proposer. + """ + raise NotImplementedError() + + +class InferenceWithReferenceProposer(Proposer): + """ + InferenceWithReference(https://arxiv.org/pdf/2304.04487) is one of the speculative decoding method. + It match tokens in the input and output as draft tokens. + """ + + def __init__(self, max_draft_token_num: int, max_ngram_size: int, max_batch_size: int): + """ + Args: + max_draft_token_num (int): + Maximum number of tokens a proposer can generate at one time. + The hyperparameter of k in the paper. + max_ngram_size (int): + The maximum size of the window used to match inputs and outputs. + The hyperparameter of n in the paper. + max_batch_size (int): + The maximum batch size. + """ + super().__init__() + self.max_ngram_size = max_ngram_size + self.input_ids_len = paddle.zeros(shape=[max_batch_size, 1], dtype="int64").cpu() + self.max_batch_size = max_batch_size + self.max_draft_token_num = max_draft_token_num + # self.input_ids_cpu = paddle.full(shape=[max_batch_size, max_seq_len], fill_value=1, dtype="int64").cpu() + + def update(self, bid: int, seq_len: int): + """ + Used when inserting a new query to update the length of the input_ids. + """ + self.input_ids_len[bid] = seq_len + + def run(self, share_inputs: dict[str, paddle.Tensor], **kargs): + """ + Use ngram_match to get draft tokens from the input and output. + """ + draft_tokens = share_inputs["draft_tokens"].cpu() + seq_lens_this_time = kargs["seq_lens_this_time"].cpu() + seq_lens_encoder = share_inputs["seq_lens_encoder"].cpu() + seq_lens_decoder = share_inputs["seq_lens_decoder"].cpu() + ngram_match( + share_inputs["input_ids_cpu"], + self.input_ids_len.cpu(), + share_inputs["pre_ids"].cpu(), + share_inputs["step_idx"].cpu(), + share_inputs["actual_draft_token_num"].cpu(), + draft_tokens, + seq_lens_this_time, + seq_lens_encoder, + seq_lens_decoder, + kargs["real_batch_size"], + self.max_ngram_size, + self.max_draft_token_num, + ) + share_inputs["draft_tokens"][:] = draft_tokens.cuda() + share_inputs["seq_lens_encoder"][:] = seq_lens_encoder.cuda() + kargs["seq_lens_this_time"][:] = seq_lens_this_time.cuda() diff --git a/llm/server/server/engine/token_processor.py b/llm/server/server/engine/token_processor.py index 507a3d43bd..1b2d6d596f 100644 --- a/llm/server/server/engine/token_processor.py +++ b/llm/server/server/engine/token_processor.py @@ -20,9 +20,11 @@ from datetime import datetime import numpy as np -from paddlenlp_ops import get_output +from paddlenlp_ops import get_output, speculate_get_output from server.utils import datetime_diff, model_server_logger, monitor_logger +SPECULATE_MAX_BSZ = 256 +MAX_DRAFT_TOKEN_NUM = 6 class TokenProcessor(object): """ @@ -37,7 +39,11 @@ def __init__(self, cfg): self.all_tokens = [[] for _ in range(self.cfg.max_batch_size)] self.tokens_counter = Counter() - self.output_tokens = paddle.full(shape=[self.cfg.max_batch_size + 2, 1], fill_value=2, dtype="int64") + + if self.cfg.speculate_method is not None: + self.output_tokens = paddle.full(shape=[SPECULATE_MAX_BSZ * MAX_DRAFT_TOKEN_NUM + MAX_DRAFT_TOKEN_NUM + 2], fill_value=2, dtype="int64") + else: + self.output_tokens = paddle.full(shape=[self.cfg.max_batch_size + 2, 1], fill_value=2, dtype="int64") self.worker = None self.record_time_interval = int(os.getenv("RECORD_TIME_INTERVAL", "600")) @@ -65,7 +71,10 @@ def run(self): if self.worker is not 
None: raise Exception("Worker is already running!") - self.worker = threading.Thread(target=self.process_sampling_results, args=()) + if self.cfg.speculate_method is not None: + self.worker = threading.Thread(target=self.process_speculate_results, args=()) + else: + self.worker = threading.Thread(target=self.process_sampling_results, args=()) self.worker.daemon = True self.worker.start() @@ -85,6 +94,22 @@ def process_sampling_results(self): except Exception as e: model_server_logger.info("while get input_data error: {0} {1}".format(e, str(traceback.format_exc()))) + def process_speculate_results(self): + """ + read tokens from paddle inference engine and process + """ + while True: + try: + rank_id = 0 + is_blocking = True + speculate_get_output(self.output_tokens, rank_id, is_blocking) + + if self.output_tokens[0] == -2: + continue + self._process_speculate_output() + except Exception as e: + model_server_logger.info("while get input_data error: {0} {1}".format(e, str(traceback.format_exc()))) + def postprocess(self, batch_result, exist_finished_task=False): """ single post-processing function @@ -160,6 +185,70 @@ def _get_single_result(self, i, task_id, token_id, task): return result + def _get_speculate_result(self, i, task_id, token_ids, task): + """ + processing single speculate results + + Args: + i (int): batch index + task_id (str): task id + token_ids (int): tokens id + task (dict): task information + + Returns: + dict: result + """ + inference_time_cost = time.time() - task["inference_start_time"] + task["inference_time_cost"] = inference_time_cost + task["tokens_all_num"] = len(self.all_tokens[i]) + task["inference_current_step_time"] = datetime.now() + result = { + "req_id": task_id, + "is_end": 0, + "token_ids": token_ids, + "send_idx": self.tokens_counter[task_id], + "inference_time_cost": inference_time_cost, + "infer_seed": task["infer_seed"], + "return_all_tokens": task.get("return_all_tokens", False), + } + + # get benchmark msg + if task.get("benchmark"): + keys = ["preprocess_start_time", "preprocess_end_time", "schedule_start_time", + "inference_start_time", "inference_current_step_time"] + for key in keys: + if key in task: + result[key] = str(task[key]) + + + # fill some extra information when generate eos token + result["token_ids"] = [] + for token_id in token_ids: + if token_id in task["eos_token_ids"]: + result["is_end"] = 1 + result["tokens_all_num"] = len(self.all_tokens[i]) + 1 + result["tokens_all_ids"] = self.all_tokens[i] + + info_dict = {} + info_dict["req_id"] = task["req_id"] + info_dict["input_token_num"] = len(task["input_ids"]) + info_dict["output_token_num"] = len(self.all_tokens[i]) + if "preprocess_start_time" in task and "preprocess_end_time" in task: + info_dict["preprocess_cost_time"] = datetime_diff(task["preprocess_start_time"], + task["preprocess_end_time"]) + if "preprocess_end_time" in task and "schedule_start_time" in task: + info_dict["cache_waiting_cost_time"] = datetime_diff(task["preprocess_end_time"], + task["schedule_start_time"]) + info_dict["inference_time_cost"] = task["inference_time_cost"] + info_dict["version"] = "4.6" + info_dict["timestamp"] = time.time() + monitor_logger.info(f"{info_dict}") + break + else: + result["token_ids"].append(token_id) + + return result + def _recycle_resources(self, task_id, index, task): """ recycle resources @@ -208,6 +297,54 @@ def _process_batch_output(self): self.postprocess(batch_result, exist_finished_task) + def _process_speculate_output(self): + """ + batch post-processing function + """ 
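+        # Layout of the flat speculative output buffer consumed below:
+        #   output_tokens[0]            -> output message id
+        #   output_tokens[1]            -> batch size
+        #   output_tokens[2 : batch+2]  -> accepted-token count for each batch slot
+        #   output_tokens[2 + SPECULATE_MAX_BSZ + i * MAX_DRAFT_TOKEN_NUM : ... + accept_num[i]]
+        #                               -> accepted token ids for batch slot i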
+ tokens = self.output_tokens.numpy() + batch = self.output_tokens[1] + output_token_msg_id = int(self.output_tokens[0]) + accept_num = tokens[2 : batch + 2] + batch_result = list() + # 用于判断当前此批结果中是否存在已完成的任务 + exist_finished_task = False + prefill_mode = False + tasks_prefill = [] + + for i in range(batch): + # 对应task如若已结束,跳过 + if self.resource_manager.stop_flags[i]: + continue + + token_ids = tokens[2 + SPECULATE_MAX_BSZ + i * MAX_DRAFT_TOKEN_NUM: 2 + SPECULATE_MAX_BSZ + i * MAX_DRAFT_TOKEN_NUM + accept_num[i]].tolist() + # 跳过非法token + if len(token_ids) == 0 or token_ids[-1] == 0: + continue + + task = self.resource_manager.tasks_list[i] + + # 将会移至data server解决 + task_id = task["req_id"] + result = self._get_speculate_result(i, task_id, token_ids, task) + + for token_id in token_ids: + self.tokens_counter[task_id] += 1 + if token_id not in task["eos_token_ids"]: + self.all_tokens[i].append(token_id) + + self.number_of_output_tokens += 1 + # 生成结束符时,重置相应变量 + if token_id in task["eos_token_ids"]: + self._recycle_resources(task_id, i, task) + model_server_logger.info("req_id: {0} finished".format(task_id)) + model_server_logger.info(f"{self.resource_manager.info()}") + exist_finished_task = True + break + batch_result.append(result) + + # 后处理函数调用 + self.postprocess(batch_result, exist_finished_task) + class WarmUpTokenProcessor(TokenProcessor): """ From e52155f07ead29e2f699dad1b27b7f381cfd70d8 Mon Sep 17 00:00:00 2001 From: Wanglongzhi2001 <583087864@qq.com> Date: Tue, 17 Dec 2024 09:37:05 +0000 Subject: [PATCH 02/13] v1.0 align accuracy --- llm/server/server/engine/config.py | 4 +- llm/server/server/engine/infer.py | 51 +++++++++++---------- llm/server/server/engine/proposers.py | 37 ++++++++------- llm/server/server/engine/token_processor.py | 18 +++++++- 4 files changed, 62 insertions(+), 48 deletions(-) diff --git a/llm/server/server/engine/config.py b/llm/server/server/engine/config.py index 4508c7e0c6..25106eb833 100644 --- a/llm/server/server/engine/config.py +++ b/llm/server/server/engine/config.py @@ -93,9 +93,7 @@ def read_from_env(self): self.use_cache_kv_int4 = int(os.getenv("USE_CACHE_KV_INT4", 0)) # speculate decoding config - self.speculate_method = str(env.get("SPECULATE_METHOD", None)) - self.speculate_max_draft_token_num = int(os.getenv("SPECULATE_MAX_DRAFT_TOKEN_NUM", 5)) - self.speculate_max_ngram_size = int(os.getenv("SPECULATE_MAX_NGRAM_SIZE", 2)) + self.speculate_method = str(os.getenv("SPECULATE_METHOD", None)) # infer config self.max_batch_size = int(env.get("BATCH_SIZE", 50)) diff --git a/llm/server/server/engine/infer.py b/llm/server/server/engine/infer.py index 51d33d3bec..0f804b393f 100644 --- a/llm/server/server/engine/infer.py +++ b/llm/server/server/engine/infer.py @@ -69,11 +69,16 @@ def __init__(self, args): self.init_inputs() # whether use speculate decoding - if self.config.speculate_method is not None and self.config.speculate_method == "inference_with_reference": - self.proposer = InferenceWithReferenceProposer( - self.config.speculate_max_draft_token_num, - self.config.speculate_max_ngram_size, - self.args.max_batch_size) + logger.info(f'speculate_method: {self.config.speculate_method}') + if self.config.speculate_method is not None: + if self.config.speculate_method == "inference_with_reference": + self.proposer = InferenceWithReferenceProposer( + self.model_cfg["speculate_max_draft_token_num"], + self.model_cfg["speculate_max_ngram_size"], + self.args.max_batch_size, + self.args.max_seq_len) + else: + raise NotImplementedError(f'Not support 
{self.config.speculate_method}, only support inference_with_reference now.') else: self.proposer = None @@ -274,18 +279,17 @@ def init_inputs(self): self.share_inputs["ori_seq_lens_encoder"] = paddle.full( shape=[self.args.max_batch_size, 1], fill_value=0, dtype="int32") # speculate decoding input + logger.info(f'Speculative method: {self.config.speculate_method}') if self.config.speculate_method is not None: - self.share_inputs["input_ids_cpu"] = paddle.full( - shape=[self.args.max_batch_size, self.args.max_seq_len], fill_value=1, dtype='int64').cpu() self.share_inputs["accept_tokens"] = paddle.full( - shape=[self.args.max_batch_size, self.config.speculate_max_draft_token_num + 1], fill_value=0, dtype="int64" + shape=[self.args.max_batch_size, self.model_cfg["speculate_max_draft_token_num"] + 1], fill_value=0, dtype="int64" ) self.share_inputs["accept_num"] = paddle.full(shape=[self.args.max_batch_size], fill_value=0, dtype="int32") self.share_inputs["draft_tokens"] = paddle.full( - shape=[self.args.max_batch_size, self.config.speculate_max_draft_token_num + 1], fill_value=0, dtype="int64" + shape=[self.args.max_batch_size, self.model_cfg["speculate_max_draft_token_num"] + 1], fill_value=0, dtype="int64" ) self.share_inputs["actual_draft_token_num"] = paddle.full( - shape=[self.args.max_batch_size], fill_value=self.config.speculate_max_draft_token_num, dtype="int32" + shape=[self.args.max_batch_size], fill_value=self.model_cfg["speculate_max_draft_token_num"], dtype="int32" ) def dy_input_preprocess(self, tasks): @@ -344,10 +348,8 @@ def dy_input_preprocess(self, tasks): task["stop_seqs"], dtype="int64") if self.proposer is not None: if self.config.speculate_method == "inference_with_reference": - speculate_update_input_ids_cpu(self.share_inputs['input_ids_cpu'], task['input_ids'], idx, self.args.max_seq_len) - self.share_inputs["draft_tokens"][idx:idx + 1] = np.zeros([self.config.speculate_max_draft_token_num + 1]) - self.share_inputs["actual_draft_token_num"][idx:idx + 1] = np.array([self.config.speculate_max_draft_token_num]) - self.proposer.update(idx, length) + self.share_inputs["draft_tokens"][idx:idx + 1] = np.zeros([self.model_cfg["speculate_max_draft_token_num"] + 1]) + self.share_inputs["actual_draft_token_num"][idx:idx + 1] = np.array([self.model_cfg["speculate_max_draft_token_num"]]) def step_cuda(self, seq_lens_this_time): """ @@ -381,7 +383,7 @@ def step_cuda(self, seq_lens_this_time): self.share_inputs['input_ids'], self.share_inputs['pre_ids'], self.share_inputs['step_idx'], self.share_inputs['next_tokens'], self.args.block_size, self.args.enc_dec_block_num, self.args.first_token_id, - self.config.speculate_max_draft_token_num) + self.model_cfg["speculate_max_draft_token_num"]) def initialize_engine_ready_check_flag(self): """ @@ -512,7 +514,6 @@ def run(self): if self.proposer is not None: logger.info("start run proposer") logger.info(f'before draft_tokens: {self.share_inputs["draft_tokens"]}') - logger.info(f'before accept_tokens: {self.share_inputs["accept_tokens"]}') self.proposer.run( self.share_inputs, @@ -521,19 +522,19 @@ def run(self): ) logger.info(f'after draft_tokens: {self.share_inputs["draft_tokens"]}') logger.info("finish run proposer") - logger.info(f'input_ids: {self.share_inputs["input_ids"]}') - logger.info(f'input_ids_cpu: {self.share_inputs["input_ids_cpu"]}') - logger.info(f'seq_lens_this_time: {self.share_inputs["seq_lens_this_time"]}') - logger.info(f'seq_lens_encoder: {self.share_inputs["seq_lens_encoder"]}') - logger.info(f'seq_lens_decoder: 
{self.share_inputs["seq_lens_decoder"]}') - logger.info(f'step_idx: {self.share_inputs["step_idx"]}') - logger.info(f'next_tokens: {self.share_inputs["next_tokens"]}') - logger.info(f'before block_tables: {self.share_inputs["block_tables"]}') + # logger.info(f'input_ids: {self.share_inputs["input_ids"]}') + # logger.info(f'input_ids_cpu: {self.share_inputs["input_ids_cpu"]}') + # logger.info(f'seq_lens_this_time: {self.share_inputs["seq_lens_this_time"]}') + # logger.info(f'seq_lens_encoder: {self.share_inputs["seq_lens_encoder"]}') + # logger.info(f'seq_lens_decoder: {self.share_inputs["seq_lens_decoder"]}') + # logger.info(f'step_idx: {self.share_inputs["step_idx"]}') + # logger.info(f'next_tokens: {self.share_inputs["next_tokens"]}') + # logger.info(f'before block_tables: {self.share_inputs["block_tables"]}') self.infer_engine.predictor.run() logger.info(f'after accept_tokens: {self.share_inputs["accept_tokens"]}') logger.info(f'after accept_num: {self.share_inputs["accept_num"]}') - logger.info(f'after block_tables: {self.share_inputs["block_tables"]}') + # logger.info(f'after block_tables: {self.share_inputs["block_tables"]}') self.share_inputs['infer_seed'].add_(infer_seed_increment) self.share_inputs['infer_seed'][:] %= self.MAX_INFER_SEED diff --git a/llm/server/server/engine/proposers.py b/llm/server/server/engine/proposers.py index 68d2b41c9e..f2a1d2b0a5 100644 --- a/llm/server/server/engine/proposers.py +++ b/llm/server/server/engine/proposers.py @@ -16,7 +16,6 @@ from abc import ABC, abstractmethod import paddle -from paddlenlp_ops import ngram_match class Proposer(ABC): @@ -43,7 +42,7 @@ class InferenceWithReferenceProposer(Proposer): It match tokens in the input and output as draft tokens. """ - def __init__(self, max_draft_token_num: int, max_ngram_size: int, max_batch_size: int): + def __init__(self, max_draft_token_num: int, max_ngram_size: int, max_batch_size: int, max_seq_len: int, **kwargs): """ Args: max_draft_token_num (int): @@ -54,34 +53,33 @@ def __init__(self, max_draft_token_num: int, max_ngram_size: int, max_batch_size The hyperparameter of n in the paper. max_batch_size (int): The maximum batch size. + max_seq_len (int): + The maximum sequence length. """ super().__init__() self.max_ngram_size = max_ngram_size self.input_ids_len = paddle.zeros(shape=[max_batch_size, 1], dtype="int64").cpu() + self.input_ids_cpu = paddle.zeros(shape=[max_batch_size, max_seq_len], dtype="int64").cpu() self.max_batch_size = max_batch_size self.max_draft_token_num = max_draft_token_num - # self.input_ids_cpu = paddle.full(shape=[max_batch_size, max_seq_len], fill_value=1, dtype="int64").cpu() - def update(self, bid: int, seq_len: int): - """ - Used when inserting a new query to update the length of the input_ids. - """ - self.input_ids_len[bid] = seq_len - - def run(self, share_inputs: dict[str, paddle.Tensor], **kargs): + def run(self, model_inputs: dict[str, paddle.Tensor], **kargs): """ Use ngram_match to get draft tokens from the input and output. 
""" - draft_tokens = share_inputs["draft_tokens"].cpu() + draft_tokens = model_inputs["draft_tokens"].cpu() seq_lens_this_time = kargs["seq_lens_this_time"].cpu() - seq_lens_encoder = share_inputs["seq_lens_encoder"].cpu() - seq_lens_decoder = share_inputs["seq_lens_decoder"].cpu() + seq_lens_encoder = model_inputs["seq_lens_encoder"].cpu() + seq_lens_decoder = model_inputs["seq_lens_decoder"].cpu() + + from paddlenlp_ops import ngram_match + ngram_match( - share_inputs["input_ids_cpu"], + self.input_ids_cpu, self.input_ids_len.cpu(), - share_inputs["pre_ids"].cpu(), - share_inputs["step_idx"].cpu(), - share_inputs["actual_draft_token_num"].cpu(), + model_inputs["pre_ids"].cpu(), + model_inputs["step_idx"].cpu(), + model_inputs["actual_draft_token_num"].cpu(), draft_tokens, seq_lens_this_time, seq_lens_encoder, @@ -90,6 +88,7 @@ def run(self, share_inputs: dict[str, paddle.Tensor], **kargs): self.max_ngram_size, self.max_draft_token_num, ) - share_inputs["draft_tokens"][:] = draft_tokens.cuda() - share_inputs["seq_lens_encoder"][:] = seq_lens_encoder.cuda() + + model_inputs["draft_tokens"][:] = draft_tokens.cuda() + model_inputs["seq_lens_encoder"][:] = seq_lens_encoder.cuda() kargs["seq_lens_this_time"][:] = seq_lens_this_time.cuda() diff --git a/llm/server/server/engine/token_processor.py b/llm/server/server/engine/token_processor.py index 1b2d6d596f..88a09c9005 100644 --- a/llm/server/server/engine/token_processor.py +++ b/llm/server/server/engine/token_processor.py @@ -41,7 +41,7 @@ def __init__(self, cfg): self.tokens_counter = Counter() if self.cfg.speculate_method is not None: - self.output_tokens = paddle.full(shape=[SPECULATE_MAX_BSZ * MAX_DRAFT_TOKEN_NUM + MAX_DRAFT_TOKEN_NUM + 2], fill_value=2, dtype="int64") + self.output_tokens = paddle.full(shape=[SPECULATE_MAX_BSZ * MAX_DRAFT_TOKEN_NUM + SPECULATE_MAX_BSZ + 2], fill_value=2, dtype="int64") else: self.output_tokens = paddle.full(shape=[self.cfg.max_batch_size + 2, 1], fill_value=2, dtype="int64") self.worker = None @@ -302,6 +302,7 @@ def _process_speculate_output(self): batch post-processing function """ tokens = self.output_tokens.numpy() + model_server_logger.info(f"speculate_result tokens: {self.output_tokens.tolist()}") batch = self.output_tokens[1] output_token_msg_id = int(self.output_tokens[0]) accept_num = tokens[2 : batch + 2] @@ -373,6 +374,21 @@ def process_sampling_results(self): except Exception as e: model_server_logger.info("while get input_data error: {0} {1}".format(e, str(traceback.format_exc()))) + def process_speculate_results(self): + """ + read tokens from paddle inference engine and process + """ + while self._is_running: + try: + rank_id = 0 + speculate_get_output(self.output_tokens, rank_id, self._is_blocking) + + if self.output_tokens[0] == -2: + continue + self._process_speculate_output() + except Exception as e: + model_server_logger.info("while get input_data error: {0} {1}".format(e, str(traceback.format_exc()))) + def stop(self): """ stop warm up thread From d5b6499a94b622c2fce277db3c3d221b1b4869ef Mon Sep 17 00:00:00 2001 From: Wanglongzhi2001 <583087864@qq.com> Date: Tue, 17 Dec 2024 09:53:03 +0000 Subject: [PATCH 03/13] remove debug log --- llm/server/server/engine/infer.py | 20 -------------------- 1 file changed, 20 deletions(-) diff --git a/llm/server/server/engine/infer.py b/llm/server/server/engine/infer.py index 0f804b393f..d592c304ae 100644 --- a/llm/server/server/engine/infer.py +++ b/llm/server/server/engine/infer.py @@ -69,7 +69,6 @@ def __init__(self, args): self.init_inputs() # 
whether use speculate decoding - logger.info(f'speculate_method: {self.config.speculate_method}') if self.config.speculate_method is not None: if self.config.speculate_method == "inference_with_reference": self.proposer = InferenceWithReferenceProposer( @@ -279,7 +278,6 @@ def init_inputs(self): self.share_inputs["ori_seq_lens_encoder"] = paddle.full( shape=[self.args.max_batch_size, 1], fill_value=0, dtype="int32") # speculate decoding input - logger.info(f'Speculative method: {self.config.speculate_method}') if self.config.speculate_method is not None: self.share_inputs["accept_tokens"] = paddle.full( shape=[self.args.max_batch_size, self.model_cfg["speculate_max_draft_token_num"] + 1], fill_value=0, dtype="int64" @@ -512,34 +510,16 @@ def run(self): continue if self.proposer is not None: - logger.info("start run proposer") - logger.info(f'before draft_tokens: {self.share_inputs["draft_tokens"]}') - self.proposer.run( self.share_inputs, real_batch_size=self.args.max_batch_size, seq_lens_this_time=self.share_inputs["seq_lens_this_time"], ) - logger.info(f'after draft_tokens: {self.share_inputs["draft_tokens"]}') - logger.info("finish run proposer") - # logger.info(f'input_ids: {self.share_inputs["input_ids"]}') - # logger.info(f'input_ids_cpu: {self.share_inputs["input_ids_cpu"]}') - # logger.info(f'seq_lens_this_time: {self.share_inputs["seq_lens_this_time"]}') - # logger.info(f'seq_lens_encoder: {self.share_inputs["seq_lens_encoder"]}') - # logger.info(f'seq_lens_decoder: {self.share_inputs["seq_lens_decoder"]}') - # logger.info(f'step_idx: {self.share_inputs["step_idx"]}') - # logger.info(f'next_tokens: {self.share_inputs["next_tokens"]}') - # logger.info(f'before block_tables: {self.share_inputs["block_tables"]}') self.infer_engine.predictor.run() - logger.info(f'after accept_tokens: {self.share_inputs["accept_tokens"]}') - logger.info(f'after accept_num: {self.share_inputs["accept_num"]}') - # logger.info(f'after block_tables: {self.share_inputs["block_tables"]}') - self.share_inputs['infer_seed'].add_(infer_seed_increment) self.share_inputs['infer_seed'][:] %= self.MAX_INFER_SEED if self.free_list_len > 0: - logger.info(f'free_list_len > 0') self.step_cuda(seq_lens_this_time) From 389015bf0417b92e9227462d1e4974466fa67349 Mon Sep 17 00:00:00 2001 From: Wanglongzhi2001 <583087864@qq.com> Date: Tue, 17 Dec 2024 09:55:12 +0000 Subject: [PATCH 04/13] remove debug log --- llm/server/server/engine/infer.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/llm/server/server/engine/infer.py b/llm/server/server/engine/infer.py index d592c304ae..8542204cb8 100644 --- a/llm/server/server/engine/infer.py +++ b/llm/server/server/engine/infer.py @@ -481,9 +481,6 @@ def run(self): self.share_inputs["seq_lens_this_time"][:real_bsz] = seq_lens_this_time tasks, read_finish = self.infer_queue.get() - logger.info(f'tasks: {tasks}') - logger.info(f'read_finish: {read_finish}') - if read_finish: flag_broadcast_array[0] = 0 From 08877a985d84753caf807e85d7508b85d584988c Mon Sep 17 00:00:00 2001 From: Wanglongzhi2001 <583087864@qq.com> Date: Tue, 17 Dec 2024 13:12:51 +0000 Subject: [PATCH 05/13] refactor code --- llm/server/server/engine/config.py | 3 --- llm/server/server/engine/infer.py | 21 +++++++++++---------- llm/server/server/engine/token_processor.py | 13 ++++++------- 3 files changed, 17 insertions(+), 20 deletions(-) diff --git a/llm/server/server/engine/config.py b/llm/server/server/engine/config.py index 25106eb833..6f0e1964e2 100644 --- a/llm/server/server/engine/config.py +++ 
b/llm/server/server/engine/config.py @@ -91,9 +91,6 @@ def read_from_env(self): self.block_size = int(env.get("BLOCK_SIZE", 64)) self.use_cache_kv_int8 = int(os.getenv("USE_CACHE_KV_INT8", 0)) self.use_cache_kv_int4 = int(os.getenv("USE_CACHE_KV_INT4", 0)) - - # speculate decoding config - self.speculate_method = str(os.getenv("SPECULATE_METHOD", None)) # infer config self.max_batch_size = int(env.get("BATCH_SIZE", 50)) diff --git a/llm/server/server/engine/infer.py b/llm/server/server/engine/infer.py index 8542204cb8..ec2816c08c 100644 --- a/llm/server/server/engine/infer.py +++ b/llm/server/server/engine/infer.py @@ -47,6 +47,7 @@ def __init__(self, args): self.config = Config() self.model_cfg = self.config.get_model_config() + self.is_speculate_decoding = self.model_cfg.get("speculate_method") is not None self.format_print_configuration() self.args.num_layers = self.get_value(self.model_cfg, ["num_hidden_layers", "num_layers"]) @@ -68,16 +69,16 @@ def __init__(self, args): self.cache_kvs = {} self.init_inputs() - # whether use speculate decoding - if self.config.speculate_method is not None: - if self.config.speculate_method == "inference_with_reference": + if self.is_speculate_decoding: + logger.info(f'Using speculating decoding, method: {self.model_cfg["speculate_method"]}.') + if self.model_cfg["speculate_method"] == "inference_with_reference": self.proposer = InferenceWithReferenceProposer( self.model_cfg["speculate_max_draft_token_num"], self.model_cfg["speculate_max_ngram_size"], self.args.max_batch_size, self.args.max_seq_len) else: - raise NotImplementedError(f'Not support {self.config.speculate_method}, only support inference_with_reference now.') + raise NotImplementedError(f'Not support {self.model_cfg["speculate_method"]}, only support inference_with_reference now.') else: self.proposer = None @@ -278,7 +279,7 @@ def init_inputs(self): self.share_inputs["ori_seq_lens_encoder"] = paddle.full( shape=[self.args.max_batch_size, 1], fill_value=0, dtype="int32") # speculate decoding input - if self.config.speculate_method is not None: + if self.is_speculate_decoding: self.share_inputs["accept_tokens"] = paddle.full( shape=[self.args.max_batch_size, self.model_cfg["speculate_max_draft_token_num"] + 1], fill_value=0, dtype="int64" ) @@ -344,16 +345,16 @@ def dy_input_preprocess(self, tasks): task["stop_seqs_len"], dtype="int32") self.share_inputs['stop_seqs'][:stop_seqs_num, :len(task['stop_seqs'][0])] = np.array( task["stop_seqs"], dtype="int64") - if self.proposer is not None: - if self.config.speculate_method == "inference_with_reference": - self.share_inputs["draft_tokens"][idx:idx + 1] = np.zeros([self.model_cfg["speculate_max_draft_token_num"] + 1]) - self.share_inputs["actual_draft_token_num"][idx:idx + 1] = np.array([self.model_cfg["speculate_max_draft_token_num"]]) + + if self.is_speculate_decoding: + self.share_inputs["draft_tokens"][idx:idx + 1] = np.zeros([self.model_cfg["speculate_max_draft_token_num"] + 1]) + self.share_inputs["actual_draft_token_num"][idx:idx + 1] = np.array([self.model_cfg["speculate_max_draft_token_num"]]) def step_cuda(self, seq_lens_this_time): """ step cuda """ - if self.config.speculate_method is None: + if not self.is_speculate_decoding: step_paddle(self.share_inputs['stop_flags'], seq_lens_this_time, self.share_inputs['step_seq_lens_encoder'], self.share_inputs['seq_lens_encoder'], diff --git a/llm/server/server/engine/token_processor.py b/llm/server/server/engine/token_processor.py index 88a09c9005..9abae1e993 100644 --- 
a/llm/server/server/engine/token_processor.py +++ b/llm/server/server/engine/token_processor.py @@ -22,9 +22,8 @@ import numpy as np from paddlenlp_ops import get_output, speculate_get_output from server.utils import datetime_diff, model_server_logger, monitor_logger +from paddlenlp.utils.env import MAX_DRAFT_TOKENS, SPECULATE_MAX_BSZ -SPECULATE_MAX_BSZ = 256 -MAX_DRAFT_TOKEN_NUM = 6 class TokenProcessor(object): """ @@ -40,8 +39,9 @@ def __init__(self, cfg): self.tokens_counter = Counter() - if self.cfg.speculate_method is not None: - self.output_tokens = paddle.full(shape=[SPECULATE_MAX_BSZ * MAX_DRAFT_TOKEN_NUM + SPECULATE_MAX_BSZ + 2], fill_value=2, dtype="int64") + self.is_speculate_decoding = self.cfg.get_model_config().get("speculate_method") is not None + if self.is_speculate_decoding: + self.output_tokens = paddle.full(shape=[SPECULATE_MAX_BSZ * MAX_DRAFT_TOKENS + SPECULATE_MAX_BSZ + 2], fill_value=2, dtype="int64") else: self.output_tokens = paddle.full(shape=[self.cfg.max_batch_size + 2, 1], fill_value=2, dtype="int64") self.worker = None @@ -71,7 +71,7 @@ def run(self): if self.worker is not None: raise Exception("Worker is already running!") - if self.cfg.speculate_method is not None: + if self.is_speculate_decoding: self.worker = threading.Thread(target=self.process_speculate_results, args=()) else: self.worker = threading.Thread(target=self.process_sampling_results, args=()) @@ -302,7 +302,6 @@ def _process_speculate_output(self): batch post-processing function """ tokens = self.output_tokens.numpy() - model_server_logger.info(f"speculate_result tokens: {self.output_tokens.tolist()}") batch = self.output_tokens[1] output_token_msg_id = int(self.output_tokens[0]) accept_num = tokens[2 : batch + 2] @@ -317,7 +316,7 @@ def _process_speculate_output(self): if self.resource_manager.stop_flags[i]: continue - token_ids = tokens[2 + SPECULATE_MAX_BSZ + i * MAX_DRAFT_TOKEN_NUM: 2 + SPECULATE_MAX_BSZ + i * MAX_DRAFT_TOKEN_NUM + accept_num[i]].tolist() + token_ids = tokens[2 + SPECULATE_MAX_BSZ + i * MAX_DRAFT_TOKENS: 2 + SPECULATE_MAX_BSZ + i * MAX_DRAFT_TOKENS + accept_num[i]].tolist() # 跳过非法token if len(token_ids) == 0 or token_ids[-1] == 0: continue From e3bc5aac37f22e508e8ef50a4bc3a64def9e5420 Mon Sep 17 00:00:00 2001 From: Wanglongzhi2001 <583087864@qq.com> Date: Tue, 17 Dec 2024 13:14:21 +0000 Subject: [PATCH 06/13] fix typo --- llm/server/server/engine/infer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/llm/server/server/engine/infer.py b/llm/server/server/engine/infer.py index ec2816c08c..9b1db3d8f0 100644 --- a/llm/server/server/engine/infer.py +++ b/llm/server/server/engine/infer.py @@ -490,7 +490,7 @@ def run(self): real_bsz = int(bsz) req_dicts.extend(req_dict) logger.info( - f'req_dict: {req_dict} rank: {self.rank}, real_bsz: {real_bsz}, query_num: {len(req_dicts)}' + f'rank: {self.rank}, real_bsz: {real_bsz}, query_num: {len(req_dicts)}' ) self.dy_input_preprocess(req_dicts) From ce3c09d65208b6eff91d8bf2bc415ef053858f78 Mon Sep 17 00:00:00 2001 From: Wanglongzhi2001 <583087864@qq.com> Date: Sun, 22 Dec 2024 12:39:07 +0000 Subject: [PATCH 07/13] import proposer from nlp --- llm/server/server/engine/infer.py | 3 +- llm/server/server/engine/proposers.py | 94 --------------------------- 2 files changed, 2 insertions(+), 95 deletions(-) delete mode 100644 llm/server/server/engine/proposers.py diff --git a/llm/server/server/engine/infer.py b/llm/server/server/engine/infer.py index 9b1db3d8f0..f00cd05066 100644 --- a/llm/server/server/engine/infer.py +++ 
b/llm/server/server/engine/infer.py @@ -29,7 +29,7 @@ from paddlenlp_ops import step_paddle, speculate_step_paddle from server.data.processor import DataProcessor from server.engine.config import Config -from server.engine.proposers import InferenceWithReferenceProposer +from paddlenlp.experimental.transformers import InferenceWithReferenceProposer from server.utils import get_logger from task_queue_manager import TaskQueueManager @@ -518,6 +518,7 @@ def run(self): self.share_inputs['infer_seed'].add_(infer_seed_increment) self.share_inputs['infer_seed'][:] %= self.MAX_INFER_SEED if self.free_list_len > 0: + logger.info('You got into step CUDA!!!') self.step_cuda(seq_lens_this_time) diff --git a/llm/server/server/engine/proposers.py b/llm/server/server/engine/proposers.py deleted file mode 100644 index f2a1d2b0a5..0000000000 --- a/llm/server/server/engine/proposers.py +++ /dev/null @@ -1,94 +0,0 @@ -# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from __future__ import annotations - -from abc import ABC, abstractmethod - -import paddle - - -class Proposer(ABC): - """ - Abstract base class for all proposers that can be used in the speculative decoding framework. - The subclasses of this class must implement the run method to get the draft tokens that are - generated by the proposer. - """ - - def __init__(self, **kwargs): - pass - - @abstractmethod - def run(self, model_inputs: dict[str, paddle.Tensor], **kargs): - """ - Get the draft tokens that are generated by the proposer. - """ - raise NotImplementedError() - - -class InferenceWithReferenceProposer(Proposer): - """ - InferenceWithReference(https://arxiv.org/pdf/2304.04487) is one of the speculative decoding method. - It match tokens in the input and output as draft tokens. - """ - - def __init__(self, max_draft_token_num: int, max_ngram_size: int, max_batch_size: int, max_seq_len: int, **kwargs): - """ - Args: - max_draft_token_num (int): - Maximum number of tokens a proposer can generate at one time. - The hyperparameter of k in the paper. - max_ngram_size (int): - The maximum size of the window used to match inputs and outputs. - The hyperparameter of n in the paper. - max_batch_size (int): - The maximum batch size. - max_seq_len (int): - The maximum sequence length. - """ - super().__init__() - self.max_ngram_size = max_ngram_size - self.input_ids_len = paddle.zeros(shape=[max_batch_size, 1], dtype="int64").cpu() - self.input_ids_cpu = paddle.zeros(shape=[max_batch_size, max_seq_len], dtype="int64").cpu() - self.max_batch_size = max_batch_size - self.max_draft_token_num = max_draft_token_num - - def run(self, model_inputs: dict[str, paddle.Tensor], **kargs): - """ - Use ngram_match to get draft tokens from the input and output. 
- """ - draft_tokens = model_inputs["draft_tokens"].cpu() - seq_lens_this_time = kargs["seq_lens_this_time"].cpu() - seq_lens_encoder = model_inputs["seq_lens_encoder"].cpu() - seq_lens_decoder = model_inputs["seq_lens_decoder"].cpu() - - from paddlenlp_ops import ngram_match - - ngram_match( - self.input_ids_cpu, - self.input_ids_len.cpu(), - model_inputs["pre_ids"].cpu(), - model_inputs["step_idx"].cpu(), - model_inputs["actual_draft_token_num"].cpu(), - draft_tokens, - seq_lens_this_time, - seq_lens_encoder, - seq_lens_decoder, - kargs["real_batch_size"], - self.max_ngram_size, - self.max_draft_token_num, - ) - - model_inputs["draft_tokens"][:] = draft_tokens.cuda() - model_inputs["seq_lens_encoder"][:] = seq_lens_encoder.cuda() - kargs["seq_lens_this_time"][:] = seq_lens_this_time.cuda() From 934e8d846e42cb9f31192129a39f9372f6a4554d Mon Sep 17 00:00:00 2001 From: Wanglongzhi2001 <583087864@qq.com> Date: Mon, 23 Dec 2024 08:47:50 +0000 Subject: [PATCH 08/13] update --- llm/server/server/engine/config.py | 29 +++++++++++- llm/server/server/engine/infer.py | 72 +++++++++++++----------------- 2 files changed, 59 insertions(+), 42 deletions(-) diff --git a/llm/server/server/engine/config.py b/llm/server/server/engine/config.py index 6f0e1964e2..a5e757d1bf 100644 --- a/llm/server/server/engine/config.py +++ b/llm/server/server/engine/config.py @@ -19,7 +19,7 @@ from paddlenlp.generation import GenerationConfig from server.utils import model_server_logger - +from dataclasses import dataclass class Config: """ @@ -203,6 +203,26 @@ def get_model_config(self): model_config_json = json.load(open(self.model_config_path, 'r', encoding='utf-8')) return model_config_json + def get_speculate_config(self): + """ + get speculate_decoding related config + + Returns: + SpeculateConfig: the speculate related config + """ + speculate_config = SpeculateConfig() + if self.model_cfg.get("speculate_method") is not None: + speculate_config.speculate_method = self.model_cfg["speculate_method"] + speculate_config.speculate_max_draft_token_num = self.model_cfg[ + "speculate_max_draft_token_num"] + speculate_config.speculate_max_ngram_size = self.model_cfg[ + "speculate_max_ngram_size"] + + if speculate_config.speculate_method is not in ["none", "inference_with_reference"]: + model_server_logger.error(f"Unsupport speculate method: {speculate_config.speculate_method}") + + return speculate_config + def read_from_config(self): """ reset model config from json file @@ -234,3 +254,10 @@ def get_unique_name(self, name): def __str__(self) -> str: return json.dumps(self.__dict__, indent=4) + + +@dataclass +class SpeculateConfig: + speculate_method: str = None + speculate_max_draft_token_num: int = 1 + speculate_max_ngram_size: int = 1 \ No newline at end of file diff --git a/llm/server/server/engine/infer.py b/llm/server/server/engine/infer.py index f00cd05066..29cd0a9717 100644 --- a/llm/server/server/engine/infer.py +++ b/llm/server/server/engine/infer.py @@ -26,7 +26,7 @@ import paddle.distributed as dist import paddle.distributed.fleet as fleet from paddlenlp.trl.llm_utils import get_rotary_position_embedding -from paddlenlp_ops import step_paddle, speculate_step_paddle +from paddlenlp_ops import step_paddle from server.data.processor import DataProcessor from server.engine.config import Config from paddlenlp.experimental.transformers import InferenceWithReferenceProposer @@ -47,7 +47,8 @@ def __init__(self, args): self.config = Config() self.model_cfg = self.config.get_model_config() - self.is_speculate_decoding = 
self.model_cfg.get("speculate_method") is not None + self.speculate_config = self.config.get_speculate_config() + self.is_speculate_decoding = self.speculate_config.speculate_method is not None self.format_print_configuration() self.args.num_layers = self.get_value(self.model_cfg, ["num_hidden_layers", "num_layers"]) @@ -70,15 +71,13 @@ def __init__(self, args): self.init_inputs() if self.is_speculate_decoding: - logger.info(f'Using speculating decoding, method: {self.model_cfg["speculate_method"]}.') - if self.model_cfg["speculate_method"] == "inference_with_reference": + logger.info(f'Using speculating decoding, method: {self.speculate_config.speculate_method}.') + if self.speculate_config.speculate_method == "inference_with_reference": self.proposer = InferenceWithReferenceProposer( - self.model_cfg["speculate_max_draft_token_num"], - self.model_cfg["speculate_max_ngram_size"], + self.speculate_config.speculate_max_draft_token_num, + self.speculate_config.speculate_max_ngram_size, self.args.max_batch_size, self.args.max_seq_len) - else: - raise NotImplementedError(f'Not support {self.model_cfg["speculate_method"]}, only support inference_with_reference now.') else: self.proposer = None @@ -281,14 +280,14 @@ def init_inputs(self): # speculate decoding input if self.is_speculate_decoding: self.share_inputs["accept_tokens"] = paddle.full( - shape=[self.args.max_batch_size, self.model_cfg["speculate_max_draft_token_num"] + 1], fill_value=0, dtype="int64" + shape=[self.args.max_batch_size, self.speculate_config.speculate_max_draft_token_num + 1], fill_value=0, dtype="int64" ) self.share_inputs["accept_num"] = paddle.full(shape=[self.args.max_batch_size], fill_value=0, dtype="int32") self.share_inputs["draft_tokens"] = paddle.full( - shape=[self.args.max_batch_size, self.model_cfg["speculate_max_draft_token_num"] + 1], fill_value=0, dtype="int64" + shape=[self.args.max_batch_size, self.speculate_config.speculate_max_draft_token_num + 1], fill_value=0, dtype="int64" ) self.share_inputs["actual_draft_token_num"] = paddle.full( - shape=[self.args.max_batch_size], fill_value=self.model_cfg["speculate_max_draft_token_num"], dtype="int32" + shape=[self.args.max_batch_size], fill_value=self.speculate_config.speculate_max_draft_token_num, dtype="int32" ) def dy_input_preprocess(self, tasks): @@ -347,42 +346,33 @@ def dy_input_preprocess(self, tasks): task["stop_seqs"], dtype="int64") if self.is_speculate_decoding: - self.share_inputs["draft_tokens"][idx:idx + 1] = np.zeros([self.model_cfg["speculate_max_draft_token_num"] + 1]) - self.share_inputs["actual_draft_token_num"][idx:idx + 1] = np.array([self.model_cfg["speculate_max_draft_token_num"]]) + self.share_inputs["draft_tokens"][idx:idx + 1] = np.zeros([self.speculate_config.speculate_max_draft_token_num + 1]) + self.share_inputs["actual_draft_token_num"][idx:idx + 1] = np.array([self.speculate_config.speculate_max_draft_token_num]) def step_cuda(self, seq_lens_this_time): """ step cuda """ - if not self.is_speculate_decoding: - step_paddle(self.share_inputs['stop_flags'], seq_lens_this_time, - self.share_inputs['step_seq_lens_encoder'], - self.share_inputs['seq_lens_encoder'], - self.share_inputs['seq_lens_decoder'], self.share_inputs["block_tables"], - self.share_inputs['encoder_block_lens'], - self.share_inputs["is_block_step"], self.share_inputs['step_block_list'], - self.share_inputs['step_lens'], self.share_inputs['recover_block_list'], - self.share_inputs['recover_lens'], self.share_inputs['need_block_list'], - 
self.share_inputs['need_block_len'], self.share_inputs['used_list_len'], - self.share_inputs['free_list'], self.share_inputs['free_list_len'], - self.share_inputs['input_ids'], self.share_inputs['pre_ids'], - self.share_inputs['step_idx'], self.share_inputs['next_tokens'], - self.args.block_size, self.args.enc_dec_block_num, self.args.first_token_id) + # whether speculate decoding + if self.is_speculate_decoding: + speculate_step_token_num = self.speculate_config.speculate_max_draft_token_num + 1 else: - speculate_step_paddle(self.share_inputs['stop_flags'], seq_lens_this_time, - self.share_inputs['step_seq_lens_encoder'], - self.share_inputs['seq_lens_encoder'], - self.share_inputs['seq_lens_decoder'], self.share_inputs["block_tables"], - self.share_inputs['encoder_block_lens'], - self.share_inputs["is_block_step"], self.share_inputs['step_block_list'], - self.share_inputs['step_lens'], self.share_inputs['recover_block_list'], - self.share_inputs['recover_lens'], self.share_inputs['need_block_list'], - self.share_inputs['need_block_len'], self.share_inputs['used_list_len'], - self.share_inputs['free_list'], self.share_inputs['free_list_len'], - self.share_inputs['input_ids'], self.share_inputs['pre_ids'], - self.share_inputs['step_idx'], self.share_inputs['next_tokens'], - self.args.block_size, self.args.enc_dec_block_num, self.args.first_token_id, - self.model_cfg["speculate_max_draft_token_num"]) + speculate_step_token_num = 0 + + step_paddle(self.share_inputs['stop_flags'], seq_lens_this_time, + self.share_inputs['step_seq_lens_encoder'], + self.share_inputs['seq_lens_encoder'], + self.share_inputs['seq_lens_decoder'], self.share_inputs["block_tables"], + self.share_inputs['encoder_block_lens'], + self.share_inputs["is_block_step"], self.share_inputs['step_block_list'], + self.share_inputs['step_lens'], self.share_inputs['recover_block_list'], + self.share_inputs['recover_lens'], self.share_inputs['need_block_list'], + self.share_inputs['need_block_len'], self.share_inputs['used_list_len'], + self.share_inputs['free_list'], self.share_inputs['free_list_len'], + self.share_inputs['input_ids'], self.share_inputs['pre_ids'], + self.share_inputs['step_idx'], self.share_inputs['next_tokens'], + self.args.block_size, self.args.enc_dec_block_num, self.args.first_token_id + speculate_step_token_num) def initialize_engine_ready_check_flag(self): """ From ae57a2b0680e5c9a9c466fa2bdb6bebcd18da970 Mon Sep 17 00:00:00 2001 From: Wanglongzhi2001 <583087864@qq.com> Date: Mon, 23 Dec 2024 08:50:19 +0000 Subject: [PATCH 09/13] fix typo --- llm/server/server/engine/infer.py | 1 - 1 file changed, 1 deletion(-) diff --git a/llm/server/server/engine/infer.py b/llm/server/server/engine/infer.py index 29cd0a9717..d97e4934b6 100644 --- a/llm/server/server/engine/infer.py +++ b/llm/server/server/engine/infer.py @@ -508,7 +508,6 @@ def run(self): self.share_inputs['infer_seed'].add_(infer_seed_increment) self.share_inputs['infer_seed'][:] %= self.MAX_INFER_SEED if self.free_list_len > 0: - logger.info('You got into step CUDA!!!') self.step_cuda(seq_lens_this_time) From 5c42585c2ad35aa435990d86725d1d4f458f0d74 Mon Sep 17 00:00:00 2001 From: Wanglongzhi2001 <583087864@qq.com> Date: Mon, 23 Dec 2024 13:03:21 +0000 Subject: [PATCH 10/13] update --- llm/server/server/engine/config.py | 13 +- llm/server/server/engine/infer.py | 6 +- llm/server/server/engine/token_processor.py | 183 ++++---------------- 3 files changed, 39 insertions(+), 163 deletions(-) diff --git a/llm/server/server/engine/config.py 
b/llm/server/server/engine/config.py index a5e757d1bf..ba40baed9a 100644 --- a/llm/server/server/engine/config.py +++ b/llm/server/server/engine/config.py @@ -211,14 +211,15 @@ def get_speculate_config(self): SpeculateConfig: the speculate related config """ speculate_config = SpeculateConfig() - if self.model_cfg.get("speculate_method") is not None: - speculate_config.speculate_method = self.model_cfg["speculate_method"] - speculate_config.speculate_max_draft_token_num = self.model_cfg[ + model_cfg = self.get_model_config() + if model_cfg.get("speculate_method", "None") != "None": + speculate_config.speculate_method = str(model_cfg["speculate_method"]) + speculate_config.speculate_max_draft_token_num = model_cfg[ "speculate_max_draft_token_num"] - speculate_config.speculate_max_ngram_size = self.model_cfg[ + speculate_config.speculate_max_ngram_size = model_cfg[ "speculate_max_ngram_size"] - if speculate_config.speculate_method is not in ["none", "inference_with_reference"]: + if speculate_config.speculate_method not in ["None", "inference_with_reference"]: model_server_logger.error(f"Unsupport speculate method: {speculate_config.speculate_method}") return speculate_config @@ -258,6 +259,6 @@ def __str__(self) -> str: @dataclass class SpeculateConfig: - speculate_method: str = None + speculate_method: str = "None" speculate_max_draft_token_num: int = 1 speculate_max_ngram_size: int = 1 \ No newline at end of file diff --git a/llm/server/server/engine/infer.py b/llm/server/server/engine/infer.py index d97e4934b6..04c85d497d 100644 --- a/llm/server/server/engine/infer.py +++ b/llm/server/server/engine/infer.py @@ -48,7 +48,7 @@ def __init__(self, args): self.config = Config() self.model_cfg = self.config.get_model_config() self.speculate_config = self.config.get_speculate_config() - self.is_speculate_decoding = self.speculate_config.speculate_method is not None + self.is_speculate_decoding = self.speculate_config.speculate_method != "None" self.format_print_configuration() self.args.num_layers = self.get_value(self.model_cfg, ["num_hidden_layers", "num_layers"]) @@ -71,7 +71,7 @@ def __init__(self, args): self.init_inputs() if self.is_speculate_decoding: - logger.info(f'Using speculating decoding, method: {self.speculate_config.speculate_method}.') + logger.info(f'Using speculate decoding, method: {self.speculate_config.speculate_method}.') if self.speculate_config.speculate_method == "inference_with_reference": self.proposer = InferenceWithReferenceProposer( self.speculate_config.speculate_max_draft_token_num, @@ -371,7 +371,7 @@ def step_cuda(self, seq_lens_this_time): self.share_inputs['free_list'], self.share_inputs['free_list_len'], self.share_inputs['input_ids'], self.share_inputs['pre_ids'], self.share_inputs['step_idx'], self.share_inputs['next_tokens'], - self.args.block_size, self.args.enc_dec_block_num, self.args.first_token_id + self.args.block_size, self.args.enc_dec_block_num, self.args.first_token_id, speculate_step_token_num) def initialize_engine_ready_check_flag(self): diff --git a/llm/server/server/engine/token_processor.py b/llm/server/server/engine/token_processor.py index 9abae1e993..d919fb07be 100644 --- a/llm/server/server/engine/token_processor.py +++ b/llm/server/server/engine/token_processor.py @@ -39,9 +39,9 @@ def __init__(self, cfg): self.tokens_counter = Counter() - self.is_speculate_decoding = self.cfg.get_model_config().get("speculate_method") is not None + self.is_speculate_decoding = self.cfg.get_speculate_config().speculate_method != "None" if 
self.is_speculate_decoding: - self.output_tokens = paddle.full(shape=[SPECULATE_MAX_BSZ * MAX_DRAFT_TOKENS + SPECULATE_MAX_BSZ + 2], fill_value=2, dtype="int64") + self.output_tokens = paddle.full(shape=[SPECULATE_MAX_BSZ * MAX_DRAFT_TOKENS + SPECULATE_MAX_BSZ + 2, 1], fill_value=2, dtype="int64") else: self.output_tokens = paddle.full(shape=[self.cfg.max_batch_size + 2, 1], fill_value=2, dtype="int64") self.worker = None @@ -71,10 +71,7 @@ def run(self): if self.worker is not None: raise Exception("Worker is already running!") - if self.is_speculate_decoding: - self.worker = threading.Thread(target=self.process_speculate_results, args=()) - else: - self.worker = threading.Thread(target=self.process_sampling_results, args=()) + self.worker = threading.Thread(target=self.process_sampling_results, args=()) self.worker.daemon = True self.worker.start() @@ -86,30 +83,18 @@ def process_sampling_results(self): try: rank_id = 0 is_blocking = True - get_output(self.output_tokens, rank_id, is_blocking) + if self.is_speculate_decoding: + speculate_get_output(self.output_tokens, rank_id, is_blocking) + else: + get_output(self.output_tokens, rank_id, is_blocking) if self.output_tokens[0, 0] == -2: continue + self._process_batch_output() except Exception as e: model_server_logger.info("while get input_data error: {0} {1}".format(e, str(traceback.format_exc()))) - def process_speculate_results(self): - """ - read tokens from paddle inference engine and process - """ - while True: - try: - rank_id = 0 - is_blocking = True - speculate_get_output(self.output_tokens, rank_id, is_blocking) - - if self.output_tokens[0] == -2: - continue - self._process_speculate_output() - except Exception as e: - model_server_logger.info("while get input_data error: {0} {1}".format(e, str(traceback.format_exc()))) - def postprocess(self, batch_result, exist_finished_task=False): """ single post-processing function @@ -126,73 +111,14 @@ def postprocess(self, batch_result, exist_finished_task=False): with open(result_file, "a") as f: f.write("{}\n".format(result)) - def _get_single_result(self, i, task_id, token_id, task): + def _get_single_result(self, i, task_id, token_ids, task): """ processing single results Args: i (int): batch index task_id (str): task id - token_id (int): token id - task (dict): task information - - Returns: - dict: result - """ - inference_time_cost = time.time() - task["inference_start_time"] - task["inference_time_cost"] = inference_time_cost - task["tokens_all_num"] = len(self.all_tokens[i]) - task["inference_current_step_time"] = datetime.now() - result = { - "req_id": task_id, - "is_end": 0, - "token_ids": [token_id], - "send_idx": self.tokens_counter[task_id], - "inference_time_cost": inference_time_cost, - "infer_seed": task["infer_seed"], - "return_all_tokens": task.get("return_all_tokens", False), - } - - # get benchmark msg - if task.get("benchmark"): - keys = ["preprocess_start_time", "preprocess_end_time", "schedule_start_time", - "inference_start_time", "inference_current_step_time"] - for key in keys: - if key in task: - result[key] = str(task[key]) - - # fill some extra information - if token_id in task["eos_token_ids"]: - result["is_end"] = 1 - result["token_ids"] = [] - result["tokens_all_num"] = len(self.all_tokens[i]) + 1 - result["tokens_all_ids"] = self.all_tokens[i] - - info_dict = {} - info_dict["req_id"] = task["req_id"] - info_dict["input_token_num"] = len(task["input_ids"]) - info_dict["output_token_num"] = len(self.all_tokens[i]) - if hasattr(task, "preprocess_start_time") 
and hasattr(task, "preprocess_end_time"): - info_dict["preprocess_cost_time"] = datetime_diff(task["preprocess_start_time"], - task["preprocess_end_time"]) - if hasattr(task, "preprocess_end_time") and hasattr(task, "schedule_start_time"): - info_dict["cache_waiting_cost_time"] = datetime_diff(task["preprocess_end_time"], - task["schedule_start_time"]) - info_dict["inference_time_cost"] = task["inference_time_cost"] - info_dict["version"] = "4.6" - info_dict["timestamp"] = time.time() - monitor_logger.info(f"{info_dict}") - - return result - - def _get_speculate_result(self, i, task_id, token_ids, task): - """ - processing single speculate results - - Args: - i (int): batch index - task_id (str): task id - token_ids (int): tokens id + token_ids (list): token id task (dict): task information Returns: @@ -220,12 +146,12 @@ def _get_speculate_result(self, i, task_id, token_ids, task): if key in task: result[key] = str(task[key]) - - # fill some extra information when generate eos token + # fill some extra information result["token_ids"] = [] for token_id in token_ids: if token_id in task["eos_token_ids"]: result["is_end"] = 1 + result["token_ids"] = [] result["tokens_all_num"] = len(self.all_tokens[i]) + 1 result["tokens_all_ids"] = self.all_tokens[i] @@ -233,10 +159,10 @@ def _get_speculate_result(self, i, task_id, token_ids, task): info_dict["req_id"] = task["req_id"] info_dict["input_token_num"] = len(task["input_ids"]) info_dict["output_token_num"] = len(self.all_tokens[i]) - if "preprocess_start_time" in task and "preprocess_end_time" in task: + if hasattr(task, "preprocess_start_time") and hasattr(task, "preprocess_end_time"): info_dict["preprocess_cost_time"] = datetime_diff(task["preprocess_start_time"], task["preprocess_end_time"]) - if "preprocess_end_time" in task and "schedule_start_time" in task: + if hasattr(task, "preprocess_end_time") and hasattr(task, "schedule_start_time"): info_dict["cache_waiting_cost_time"] = datetime_diff(task["preprocess_end_time"], task["schedule_start_time"]) info_dict["inference_time_cost"] = task["inference_time_cost"] @@ -266,7 +192,10 @@ def _process_batch_output(self): """ tokens = self.output_tokens.numpy() batch = self.output_tokens[1, 0] - tokens = tokens[2:batch + 2] + if not self.is_speculate_decoding: + tokens = tokens[2:batch + 2] + else: + accept_num = tokens[2:batch + 2] batch_result = list() exist_finished_task = False @@ -274,66 +203,25 @@ def _process_batch_output(self): if self.resource_manager.stop_flags[i]: continue - token_id = int(tokens[i, 0]) - if token_id < 0: + if not self.is_speculate_decoding: + token_ids = [int(tokens[i, 0])] + else: + token_ids = tokens[2 + SPECULATE_MAX_BSZ + i * MAX_DRAFT_TOKENS: 2 + SPECULATE_MAX_BSZ + i * MAX_DRAFT_TOKENS + accept_num[i, 0], 0].tolist() + + if any(token_id < 0 for token_id in token_ids): continue task = self.resource_manager.tasks_list[i] task_id = task["req_id"] - result = self._get_single_result(i, task_id, token_id, task) - - self.tokens_counter[task_id] += 1 - if token_id not in task["eos_token_ids"]: - self.all_tokens[i].append(token_id) + result = self._get_single_result(i, task_id, token_ids, task) - self.number_of_output_tokens += 1 - if token_id in task["eos_token_ids"]: - self._recycle_resources(task_id, i, task) - model_server_logger.info("req_id: {0} finished".format(task_id)) - model_server_logger.info(f"{self.resource_manager.info()}") - exist_finished_task = True - batch_result.append(result) - - self.postprocess(batch_result, exist_finished_task) - - def 
_process_speculate_output(self): - """ - batch post-processing function - """ - tokens = self.output_tokens.numpy() - batch = self.output_tokens[1] - output_token_msg_id = int(self.output_tokens[0]) - accept_num = tokens[2 : batch + 2] - batch_result = list() - # 用于判断当前此批结果中是否存在已完成的任务 - exist_finished_task = False - prefill_mode = False - tasks_prefill = [] - - for i in range(batch): - # 对应task如若已结束,跳过 - if self.resource_manager.stop_flags[i]: - continue - - token_ids = tokens[2 + SPECULATE_MAX_BSZ + i * MAX_DRAFT_TOKENS: 2 + SPECULATE_MAX_BSZ + i * MAX_DRAFT_TOKENS + accept_num[i]].tolist() - # 跳过非法token - if len(token_ids) == 0 or token_ids[-1] == 0: - continue - - task = self.resource_manager.tasks_list[i] - - # 将会移至data server解决 - task_id = task["req_id"] - result = self._get_speculate_result(i, task_id, token_ids, task) - for token_id in token_ids: self.tokens_counter[task_id] += 1 if token_id not in task["eos_token_ids"]: self.all_tokens[i].append(token_id) self.number_of_output_tokens += 1 - # 生成结束符时,重置相应变量 if token_id in task["eos_token_ids"]: self._recycle_resources(task_id, i, task) model_server_logger.info("req_id: {0} finished".format(task_id)) @@ -342,7 +230,6 @@ def _process_speculate_output(self): break batch_result.append(result) - # 后处理函数调用 self.postprocess(batch_result, exist_finished_task) @@ -365,7 +252,10 @@ def process_sampling_results(self): while self._is_running: try: rank_id = 0 - get_output(self.output_tokens, rank_id, self._is_blocking) + if self.is_speculate_decoding: + speculate_get_output(self.output_tokens, rank_id, self._is_blocking) + else: + get_output(self.output_tokens, rank_id, self._is_blocking) if self.output_tokens[0, 0] == -2: continue @@ -373,21 +263,6 @@ def process_sampling_results(self): except Exception as e: model_server_logger.info("while get input_data error: {0} {1}".format(e, str(traceback.format_exc()))) - def process_speculate_results(self): - """ - read tokens from paddle inference engine and process - """ - while self._is_running: - try: - rank_id = 0 - speculate_get_output(self.output_tokens, rank_id, self._is_blocking) - - if self.output_tokens[0] == -2: - continue - self._process_speculate_output() - except Exception as e: - model_server_logger.info("while get input_data error: {0} {1}".format(e, str(traceback.format_exc()))) - def stop(self): """ stop warm up thread From ed5f65a6942c819fd770e6ceff2b1e9bd29518ad Mon Sep 17 00:00:00 2001 From: Wanglongzhi2001 <583087864@qq.com> Date: Mon, 23 Dec 2024 13:05:05 +0000 Subject: [PATCH 11/13] fix typo --- llm/server/server/engine/config.py | 1 + 1 file changed, 1 insertion(+) diff --git a/llm/server/server/engine/config.py b/llm/server/server/engine/config.py index ba40baed9a..3b9a88f0c9 100644 --- a/llm/server/server/engine/config.py +++ b/llm/server/server/engine/config.py @@ -21,6 +21,7 @@ from server.utils import model_server_logger from dataclasses import dataclass + class Config: """ initial configuration From 99d09210498b5aaeb7e4486cae1cbd54afd7624c Mon Sep 17 00:00:00 2001 From: Wanglongzhi2001 <583087864@qq.com> Date: Thu, 9 Jan 2025 07:16:13 +0000 Subject: [PATCH 12/13] update --- llm/server/server/engine/infer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/llm/server/server/engine/infer.py b/llm/server/server/engine/infer.py index 04c85d497d..2641e88994 100644 --- a/llm/server/server/engine/infer.py +++ b/llm/server/server/engine/infer.py @@ -500,8 +500,8 @@ def run(self): if self.proposer is not None: self.proposer.run( self.share_inputs, - 
-                        real_batch_size=self.args.max_batch_size,
-                        seq_lens_this_time=self.share_inputs["seq_lens_this_time"],
+                        real_batch_size=seq_lens_this_time.shape[0],
+                        seq_lens_this_time=seq_lens_this_time,
                     )
                 self.infer_engine.predictor.run()

From fe35dc5d776d0c24bc23fc9a0ae1438080a4194d Mon Sep 17 00:00:00 2001
From: Wanglongzhi2001 <583087864@qq.com>
Date: Fri, 10 Jan 2025 05:42:59 +0000
Subject: [PATCH 13/13] update version of info_dict

---
 llm/server/server/engine/token_processor.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/llm/server/server/engine/token_processor.py b/llm/server/server/engine/token_processor.py
index d919fb07be..1213a9384b 100644
--- a/llm/server/server/engine/token_processor.py
+++ b/llm/server/server/engine/token_processor.py
@@ -166,7 +166,7 @@ def _get_single_result(self, i, task_id, token_ids, task):
                     info_dict["cache_waiting_cost_time"] = datetime_diff(task["preprocess_end_time"],
                                                                          task["schedule_start_time"])
                 info_dict["inference_time_cost"] = task["inference_time_cost"]
-                info_dict["version"] = "4.6"
+                info_dict["version"] = "OpenSource"
                 info_dict["timestamp"] = time.time()
                 monitor_logger.info(f"{info_dict}")
                 break
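
Reviewer note: the buffer indexing in _process_batch_output (patch 10) is easier to follow with the layout written out. The sketch below is illustrative only and not part of the patches; it assumes a flattened 1-D view of output_tokens, and the SPECULATE_MAX_BSZ / MAX_DRAFT_TOKENS values are placeholders standing in for the constants token_processor.py already imports.

# Illustrative sketch only (not part of the patch series).
SPECULATE_MAX_BSZ = 256   # placeholder; real value comes from the imported constant
MAX_DRAFT_TOKENS = 6      # placeholder; real value comes from the imported constant

def parse_speculate_output(tokens):
    """Return {batch_slot: accepted token ids} from the flattened speculate buffer.

    Layout assumed from the code above:
      tokens[0]              message id (-2 means no new data yet)
      tokens[1]              current batch size
      tokens[2 : 2 + batch]  accept_num for each batch slot
      tokens[2 + SPECULATE_MAX_BSZ + i * MAX_DRAFT_TOKENS : ... + accept_num[i]]
                             accepted token ids for slot i
    """
    batch = int(tokens[1])
    accept_num = [int(n) for n in tokens[2:2 + batch]]
    results = {}
    for i in range(batch):
        start = 2 + SPECULATE_MAX_BSZ + i * MAX_DRAFT_TOKENS
        results[i] = [int(t) for t in tokens[start:start + accept_num[i]]]
    return results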
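
Reviewer note: proposers.py is not reproduced in this excerpt, so for readers unfamiliar with the "inference_with_reference" method, the toy function below sketches only the general idea behind n-gram lookup drafting: match the tail of the generated sequence against the reference (prompt) tokens and copy what follows as draft tokens. It is a simplified illustration, not the actual InferenceWithReferenceProposer implementation.

# Toy illustration of n-gram lookup drafting (NOT the InferenceWithReferenceProposer code).
def propose_drafts(generated, reference, max_ngram_size=2, max_draft_token_num=5):
    """Propose draft tokens by matching the last n-gram of `generated` inside `reference`."""
    for n in range(max_ngram_size, 0, -1):
        if len(generated) < n:
            continue
        tail = generated[-n:]
        for pos in range(len(reference) - n):
            if reference[pos:pos + n] == tail:
                # Copy up to max_draft_token_num tokens that follow the matched n-gram.
                return reference[pos + n:pos + n + max_draft_token_num]
    return []

# Example: the generated tail [7, 8] also occurs in the reference, so [9, 10, 11] become drafts:
# propose_drafts(generated=[3, 7, 8], reference=[1, 7, 8, 9, 10, 11])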
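
Reviewer note: after patch 10 the speculate settings are read from the model config (via get_speculate_config) rather than from environment variables. A deployment would enable the feature with keys like the hypothetical snippet below; the key names follow get_speculate_config() and the values are the defaults used earlier in the series. With the keys absent, speculate_method stays "None" and the server takes the normal decoding path.

# Hypothetical model-config entries enabling speculate decoding (illustrative values).
speculate_settings = {
    "speculate_method": "inference_with_reference",
    "speculate_max_draft_token_num": 5,
    "speculate_max_ngram_size": 2,
}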