diff --git a/.gitignore b/.gitignore
index 2eb6e5e1a2..56e85e18b3 100644
--- a/.gitignore
+++ b/.gitignore
@@ -50,4 +50,20 @@ python/fastdeploy/code_version.py
log.txt
serving/build
serving/build.encrypt
-serving/build.encrypt.auth
\ No newline at end of file
+serving/build.encrypt.auth
+output
+res
+tmp
+log
+nohup.out
+llm/server/__pycache__
+llm/server/data/__pycache__
+llm/server/engine/__pycache__
+llm/server/http_server/__pycache__
+llm/server/log/
+llm/client/build/
+llm/client/dist/
+llm/client/fastdeploy_client.egg-info/
+llm/client/fastdeploy_client/tests/log/
+*.pyc
+*.log
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 2387c0d25b..6da9864159 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -1,6 +1,6 @@
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
- rev: a11d9314b22d8f8c7556443875b731ef05965464
+ rev: ed714747d7acbc5790b171702bb012af3b0fe145
hooks:
- id: check-merge-conflict
- id: check-symlinks
@@ -9,8 +9,8 @@ repos:
- id: detect-private-key
- id: check-symlinks
- id: check-added-large-files
-- repo: local
+- repo: local
hooks:
- id: copyright_checker
name: copyright_checker
diff --git a/llm/.dockerignore b/llm/.dockerignore
new file mode 100644
index 0000000000..96dbf3cb73
--- /dev/null
+++ b/llm/.dockerignore
@@ -0,0 +1,11 @@
+README.md
+requirements-dev.txt
+pyproject.toml
+Makefile
+
+dockerfiles/
+docs/
+server/__pycache__
+server/http_server
+server/engine
+server/data
diff --git a/llm/README.md b/llm/README.md
new file mode 100644
index 0000000000..6475ae2f96
--- /dev/null
+++ b/llm/README.md
@@ -0,0 +1,39 @@
+
+
+# FastDeploy: PaddlePaddle's High-Performance LLM Deployment Toolkit
+
+*FastDeploy is a serving solution for large language models, built on NVIDIA Triton and designed for server-side deployment. It provides gRPC and HTTP service interfaces with streaming token output. The underlying inference engine supports acceleration optimizations such as continuous batching, weight-only INT8, and post-training quantization (PTQ), delivering an easy-to-use, high-performance deployment experience.*
+
+# Quick Start
+
+ Deploy from the prebuilt image. This section uses Meta-Llama-3-8B-Instruct-A8W8C8 as an example; for more models see [LLaMA](https://github.com/PaddlePaddle/PaddleNLP/blob/develop/llm/docs/predict/llama.md), [Qwen](https://github.com/PaddlePaddle/PaddleNLP/blob/develop/llm/docs/predict/qwen.md) and [Mixtral](https://github.com/PaddlePaddle/PaddleNLP/blob/develop/llm/docs/predict/mixtral.md). For more detailed inference and quantization tutorials, refer to the [LLM inference tutorial](https://github.com/PaddlePaddle/PaddleNLP/blob/develop/llm/docs/predict/inference.md):
+
+ ```
+ # Download the model
+ wget https://paddle-qa.bj.bcebos.com/inference_model/Meta-Llama-3-8B-Instruct-A8W8C8.tar
+ mkdir Llama-3-8B-A8W8C8 && tar -xf Meta-Llama-3-8B-Instruct-A8W8C8.tar -C Llama-3-8B-A8W8C8
+
+ # Mount the model files
+ export MODEL_PATH=${PWD}/Llama-3-8B-A8W8C8
+
+ docker run --gpus all --shm-size 5G --network=host \
+ -v ${MODEL_PATH}:/models/ \
+ -dit registry.baidubce.com/paddlepaddle/fastdeploy:llm-serving-cuda123-cudnn9-v1.0 \
+ bash -c 'export USE_CACHE_KV_INT8=1 && cd /opt/output/Serving && bash start_server.sh; exec bash'
+ ```
+
+ Wait for the service to start (the first startup takes about 40 s), then test it with the following command:
+
+ ```
+ curl 127.0.0.1:9965/v1/chat/completions \
+ -H 'Content-Type: application/json' \
+ -d '{"text": "hello, llm"}'
+ ```
+
+Note:
+1. Make sure shm-size >= 5G, otherwise the service may fail to start.
+
+For more information on how to use FastDeploy, see the [serving deployment tutorial](https://github.com/PaddlePaddle/FastDeploy/blob/develop/llm/docs/FastDeploy_usage_tutorial.md).
+
+# License
+
+FastDeploy is licensed under the [Apache-2.0 License](https://github.com/PaddlePaddle/FastDeploy/blob/develop/LICENSE).
diff --git a/llm/client/README.md b/llm/client/README.md
new file mode 100644
index 0000000000..396f83cc09
--- /dev/null
+++ b/llm/client/README.md
@@ -0,0 +1,110 @@
+# Client Usage
+
+## Introduction
+
+The FastDeploy client provides a command-line interface and a Python API for quickly calling LLM model services deployed with the FastDeploy backend.
+
+## Installation
+
+Install from source:
+```
+pip install .
+```
+
+## Command-Line Interface
+
+First set the model service URL via an environment variable, then call the model service through the command-line interface.
+
+| Parameter | Description | Required | Default |
+| --- | --- | --- | --- |
+| FASTDEPLOY_MODEL_URL | IP address and port of the deployed model service, in the form `x.x.x.x:xxx`. | Yes | |
+
+```
+export FASTDEPLOY_MODEL_URL="x.x.x.x:xxx"
+
+# Streaming interface
+fdclient stream_generate "你好?"
+
+# Non-streaming interface
+fdclient generate "你好,你是谁?"
+```
+
+## Python API
+
+First set the model service URL (hostname and port) in Python, then call the model service through the Python API.
+
+| Parameter | Description | Required | Default |
+| --- | --- | --- | --- |
+| hostname, port | IP address and port of the deployed model service: hostname in the form `x.x.x.x`, plus the port number. | Yes | |
+
+
+```
+from fastdeploy_client.chatbot import ChatBot
+
+hostname = "x.x.x.x"
+port = xxx
+
+# Streaming interface; see the parameter description below for the stream_generate API
+chatbot = ChatBot(hostname=hostname, port=port)
+stream_result = chatbot.stream_generate("你好", topp=0.8)
+for res in stream_result:
+ print(res)
+
+# Non-streaming interface; see the parameter description below for the generate API
+chatbot = ChatBot(hostname=hostname, port=port)
+result = chatbot.generate("你好", topp=0.8)
+print(result)
+```
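+
+The client can also keep multi-turn context with a `ChatMessage` object from `fastdeploy_client.message`. The sketch below is only an illustration of that usage; the hostname, port and prompts are placeholders for your own deployment:
+
+```
+from fastdeploy_client.chatbot import ChatBot
+from fastdeploy_client.message import ChatMessage
+
+chatbot = ChatBot(hostname="127.0.0.1", port=8000)  # your service address and GRPC_PORT
+
+# first turn: the user prompt is stored inside the ChatMessage
+msg = ChatMessage("你好,你是谁?")
+for res in chatbot.stream_generate(msg):
+    print(res)
+# when the stream finishes, the assistant reply is appended to msg automatically
+
+# second turn: add the next user message and continue the same conversation
+msg.add_user_message("请再详细介绍一下。")
+for res in chatbot.stream_generate(msg):
+    print(res)
+```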
+
+### API Reference
+```
+ChatBot.stream_generate(message,
+ max_dec_len=1024,
+ min_dec_len=2,
+ topp=0.0,
+ temperature=1.0,
+ frequency_score=0.0,
+ penalty_score=1.0,
+ presence_score=0.0,
+ eos_token_ids=254186)
+
+# This function returns an iterator whose elements are dicts, e.g. {"token": "好的", "is_end": 0},
+# where token is the generated text piece and is_end indicates whether it is the last piece (0: no, 1: yes).
+# Note: if generation fails, an error message is returned; eos_token_ids differs between models.
+```
+
+```
+ChatBot.generate(message,
+ max_dec_len=1024,
+ min_dec_len=2,
+ topp=0.0,
+ temperature=1.0,
+ frequency_score=0.0,
+ penalty_score=1.0,
+ presence_score=0.0,
+ eos_token_ids=254186)
+
+# This function returns a dict, e.g. {"results": "好的,我知道了。"}, where results is the generated text.
+# Note: if generation fails, an error message is returned; eos_token_ids differs between models.
+```
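+
+A minimal sketch of consuming these return values, following the field descriptions above (the hostname and port are placeholders for your own deployment):
+
+```
+from fastdeploy_client.chatbot import ChatBot
+
+chatbot = ChatBot(hostname="127.0.0.1", port=8000)  # your service address and GRPC_PORT
+
+pieces = []
+for res in chatbot.stream_generate("你好", topp=0.8):
+    if res.get("error_msg"):            # error case: the dict carries error_msg
+        print("request failed:", res["error_msg"])
+        break
+    pieces.append(res["token"])         # normal case: {"token": ..., "is_end": 0 or 1}
+    if res.get("is_end") == 1:
+        break
+print("".join(pieces))
+```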
+
+### Parameters
+
+| Field | Type | Description | Required | Default | Notes |
+| :---: | :-----: | :---: | :---: | :-----: | :----: |
+| req_id | str | Request ID used to identify a request; it is recommended to set req_id and keep it unique | No | random id | If two requests with the same req_id are in the inference service at the same time, a duplicate req_id error is returned |
+| text | str | Input text of the request | Yes | None | |
+| max_dec_len | int | Maximum number of generated tokens; if the number of input tokens plus max_dec_len exceeds the model's max_seq_len, a length-exceeded error is returned | No | max_seq_len minus the number of input tokens | |
+| min_dec_len | int | Minimum number of generated tokens; the minimum is 1 | No | 1 | |
+| topp | float | Controls randomness; larger values give more randomness; range 0~1 | No | 0.7 | |
+| temperature | float | Controls randomness; smaller values give more randomness; must be greater than 0 | No | 0.95 | |
+| frequency_score | float | Frequency penalty score | No | 0 | |
+| penalty_score | float | Repetition penalty score | No | 1 | |
+| presence_score | float | Presence penalty score | No | 0 | |
+| stream | bool | Whether to stream the results | No | False | |
+| return_all_tokens | bool | Whether to return all results at once | No | False | See the notes below for how this differs from stream |
+| timeout | int | Request timeout in seconds | No | 300 | |
+
+* When PUSH_MODE_HTTP_PORT is configured correctly, the service accepts both GRPC and HTTP requests
+  * The stream parameter only takes effect for HTTP requests
+  * The return_all_tokens parameter applies to both GRPC and HTTP requests
diff --git a/llm/client/fastdeploy_client/__init__.py b/llm/client/fastdeploy_client/__init__.py
new file mode 100644
index 0000000000..83ae7a0036
--- /dev/null
+++ b/llm/client/fastdeploy_client/__init__.py
@@ -0,0 +1,20 @@
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import sys
+
+__version__ = "4.4.0"
+
+logging.basicConfig(stream=sys.stdout, level=logging.DEBUG)
diff --git a/llm/client/fastdeploy_client/chatbot.py b/llm/client/fastdeploy_client/chatbot.py
new file mode 100644
index 0000000000..5353e30001
--- /dev/null
+++ b/llm/client/fastdeploy_client/chatbot.py
@@ -0,0 +1,308 @@
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+import logging
+import queue
+import traceback
+import uuid
+from functools import partial
+
+import numpy as np
+import tritonclient.grpc as grpcclient
+from fastdeploy_client.message import ChatMessage
+from fastdeploy_client.utils import is_enable_benchmark
+from tritonclient import utils as triton_utils
+
+
+class ChatBotClass(object):
+ """
+    Initiates conversations with the model service through the tritonclient interface.
+ """
+ def __init__(self, hostname, port, timeout=120):
+ """
+ Initialization function
+
+ Args:
+ hostname (str): gRPC hostname
+ port (int): gRPC port
+ timeout (int): Request timeout, default is 120 seconds.
+
+ Returns:
+ None
+ """
+ self.url = f"{hostname}:{port}"
+ self.timeout = timeout
+
+ def stream_generate(self,
+ message,
+ max_dec_len=1024,
+ min_dec_len=1,
+ topp=0.7,
+ temperature=0.95,
+ frequency_score=0.0,
+ penalty_score=1.0,
+ presence_score=0.0,
+ system=None,
+ **kwargs):
+ """
+ Streaming interface
+
+ Args:
+            message (Union[str, List[str], ChatMessage]): the message content or a ChatMessage object
+            max_dec_len (int, optional): maximum decoding length. Defaults to 1024.
+            min_dec_len (int, optional): minimum decoding length. Defaults to 1.
+            topp (float, optional): controls randomness; larger values give more randomness; range 0~1. Defaults to 0.7.
+            temperature (float, optional): temperature value. Defaults to 0.95.
+            frequency_score (float, optional): frequency penalty score. Defaults to 0.0.
+            penalty_score (float, optional): repetition penalty score. Defaults to 1.0.
+            presence_score (float, optional): presence penalty score. Defaults to 0.0.
+            system (str, optional): system prompt. Defaults to None.
+            **kwargs: other parameters
+                req_id (str, optional): request ID used to distinguish requests. Defaults to None.
+                eos_token_ids (List[int], optional): token ids that end generation. Defaults to None.
+                benchmark (bool, optional): enables benchmark mode, in which the full response is returned. Defaults to False.
+                timeout (int, optional): request timeout in seconds; 120s is used if unset. Defaults to None.
+
+        Returns:
+            A generator that yields one dict at a time.
+            On success each dict looks like {"req_id": "xxx", "token": "好的", "is_end": 0}, where token is the generated text piece and is_end indicates whether it is the last one (0: no, 1: yes).
+            On error the dict carries the error message, e.g. {"req_id": "xxx", "error_msg": "error message"}.
+ """
+ try:
+            # prepare the inputs
+ model_name = "model"
+ inputs = [grpcclient.InferInput("IN", [1], triton_utils.np_to_triton_dtype(np.object_))]
+ outputs = [grpcclient.InferRequestedOutput("OUT")]
+ output_data = OutputData()
+
+ msg = message.message if isinstance(message, ChatMessage) else message
+ input_data = self._prepare_input_data(msg, max_dec_len, min_dec_len,
+ topp, temperature, frequency_score,
+ penalty_score, presence_score, **kwargs)
+ req_id = input_data["req_id"]
+ inputs[0].set_data_from_numpy(np.array([json.dumps([input_data])], dtype=np.object_))
+ timeout = kwargs.get("timeout", self.timeout)
+
+ with grpcclient.InferenceServerClient(url=self.url, verbose=False) as triton_client:
+                # establish the streaming connection
+ triton_client.start_stream(callback=partial(triton_callback, output_data))
+                # send the request
+ triton_client.async_stream_infer(model_name=model_name,
+ inputs=inputs,
+ request_id=req_id,
+ outputs=outputs)
+                # process the results
+ answer_str = ""
+ enable_benchmark = is_enable_benchmark(**kwargs)
+ while True:
+ try:
+ response = output_data._completed_requests.get(timeout=timeout)
+ except queue.Empty:
+ yield {"req_id": req_id, "error_msg": f"Fetch response from server timeout ({timeout}s)"}
+ break
+ if type(response) == triton_utils.InferenceServerException:
+ yield {"req_id": req_id, "error_msg": f"InferenceServerException raised by inference: {response.message()}"}
+ break
+ else:
+ if enable_benchmark:
+ response = json.loads(response.as_numpy("OUT")[0])
+ if isinstance(response, (list, tuple)):
+ response = response[0]
+ else:
+ response = self._format_response(response, req_id)
+ token = response.get("token", "")
+ if isinstance(token, list):
+ token = token[0]
+ answer_str += token
+ yield response
+ if response.get("is_end") == 1 or response.get("error_msg") is not None:
+ break
+                # close the stream explicitly
+ triton_client.stop_stream(cancel_requests=True)
+ triton_client.close()
+
+ if isinstance(message, ChatMessage):
+ message.message.append({"role": "assistant", "content": answer_str})
+ except Exception as e:
+ yield {"error_msg": f"{e}, details={str(traceback.format_exc())}"}
+
+ def generate(self,
+ message,
+ max_dec_len=1024,
+ min_dec_len=1,
+ topp=0.7,
+ temperature=0.95,
+ frequency_score=0.0,
+ penalty_score=1.0,
+ presence_score=0.0,
+ system=None,
+ **kwargs):
+ """
+        Non-streaming interface that returns the whole result at once, implemented on top of the streaming interface.
+
+ Args:
+            message (Union[str, List[str], ChatMessage]): the message content or a ChatMessage object
+            max_dec_len (int, optional): maximum decoding length. Defaults to 1024.
+            min_dec_len (int, optional): minimum decoding length. Defaults to 1.
+            topp (float, optional): controls randomness; larger values give more randomness; range 0~1. Defaults to 0.7.
+            temperature (float, optional): temperature value. Defaults to 0.95.
+            frequency_score (float, optional): frequency penalty score. Defaults to 0.0.
+            penalty_score (float, optional): repetition penalty score. Defaults to 1.0.
+            presence_score (float, optional): presence penalty score. Defaults to 0.0.
+            system (str, optional): system prompt. Defaults to None.
+            **kwargs: other parameters
+                req_id (str, optional): request ID used to distinguish requests. Defaults to None.
+                eos_token_ids (List[int], optional): token ids that end generation. Defaults to None.
+                timeout (int, optional): request timeout in seconds; 120s is used if unset. Defaults to None.
+
+        Returns:
+            A dict.
+            On success the dict looks like {"req_id": "xxx", "results": "好的,我知道了。"}.
+            On error the dict carries the error message, e.g. {"req_id": "xxx", "error_msg": "error message"}.
+ """
+ stream_response = self.stream_generate(message, max_dec_len,
+ min_dec_len, topp, temperature,
+ frequency_score, penalty_score,
+ presence_score, system, **kwargs)
+ results = ""
+ token_ids = list()
+ error_msg = None
+ for res in stream_response:
+ if "token" not in res or "error_msg" in res:
+ error_msg = {"error_msg": f"response error, please check the info: {res}"}
+ elif isinstance(res["token"], list):
+ results = res["token"]
+ token_ids = res["token_ids"]
+ else:
+ results += res["token"]
+ token_ids += res["token_ids"]
+ if error_msg:
+ return {"req_id": res["req_id"], "error_msg": error_msg}
+ else:
+ return {"req_id": res["req_id"], "results": results, "token_ids": token_ids}
+
+ def _prepare_input_data(self,
+ message,
+ max_dec_len=1024,
+ min_dec_len=2,
+ topp=0.0,
+ temperature=1.0,
+ frequency_score=0.0,
+ penalty_score=1.0,
+ presence_score=0.0,
+ system=None,
+ **kwargs):
+ """
+        Prepare the input data.
+ """
+ inputs = {
+ "max_dec_len": max_dec_len,
+ "min_dec_len": min_dec_len,
+ "topp": topp,
+ "temperature": temperature,
+ "frequency_score": frequency_score,
+ "penalty_score": penalty_score,
+ "presence_score": presence_score,
+ }
+
+ if system is not None:
+ inputs["system"] = system
+
+ inputs["req_id"] = kwargs.get("req_id", str(uuid.uuid4()))
+ if "eos_token_ids" in kwargs and kwargs["eos_token_ids"] is not None:
+ inputs["eos_token_ids"] = kwargs["eos_token_ids"]
+ inputs["response_timeout"] = kwargs.get("timeout", self.timeout)
+
+ if isinstance(message, str):
+ inputs["text"] = message
+ elif isinstance(message, list):
+ assert len(message) % 2 == 1, \
+ "The length of message should be odd while it's a list."
+ assert message[-1]["role"] == "user", \
+ "The {}-th element key should be 'user'".format(len(message) - 1)
+ for i in range(0, len(message) - 1, 2):
+ assert message[i]["role"] == "user", \
+ "The {}-th element key should be 'user'".format(i)
+ assert message[i + 1]["role"] == "assistant", \
+ "The {}-th element key should be 'assistant'".format(i + 1)
+ inputs["messages"] = message
+ else:
+ raise Exception(
+ "The message should be string or list of dict like [{'role': "
+ "'user', 'content': 'Hello, what's your name?''}]"
+ )
+
+ return inputs
+
+ def _format_response(self, response, req_id):
+ """
+        Format the fields returned by the service.
+ """
+ response = json.loads(response.as_numpy("OUT")[0])
+ if isinstance(response, (list, tuple)):
+ response = response[0]
+ is_end = response.get("is_end", False)
+
+ if "error_msg" in response:
+ return {"req_id": req_id, "error_msg": response["error_msg"]}
+ elif "choices" in response:
+ token = [x["token"] for x in response["choices"]]
+ token_ids = [x["token_ids"] for x in response["choices"]]
+ return {"req_id": req_id, "token": token, "token_ids": token_ids, "is_end": 1}
+ elif "token" not in response and "result" not in response:
+ return {"req_id": req_id, "error_msg": f"The response should contain 'token' or 'result', but got {response}"}
+ else:
+ token_ids = response.get("token_ids", [])
+ if "result" in response:
+ token = response["result"]
+ elif "token" in response:
+ token = response["token"]
+ return {"req_id": req_id, "token": token, "token_ids": token_ids, "is_end": is_end}
+
+
+class OutputData:
+ """接收Triton服务返回的数据"""
+ def __init__(self):
+ self._completed_requests = queue.Queue()
+
+
+def triton_callback(output_data, result, error):
+ """Triton客户端的回调函数"""
+ if error:
+ output_data._completed_requests.put(error)
+ else:
+ output_data._completed_requests.put(result)
+
+
+class ChatBot(object):
+ """
+    Public entry point that creates a ChatBotClass instance.
+ """
+ def __new__(cls, hostname, port, timeout=120):
+ """
+        Create a client object for the GRPC inference service.
+        Args:
+            hostname (str): server address
+            port (int): server port
+            timeout (int): request timeout in seconds, default is 120 seconds
+        Returns:
+            ChatBotClass: a ChatBotClass instance
+ """
+ if not isinstance(hostname, str) or not hostname:
+ raise ValueError("Invalid hostname")
+ if not isinstance(port, int) or port <= 0 or port > 65535:
+ raise ValueError("Invalid port number")
+
+ return ChatBotClass(hostname, port, timeout)
diff --git a/llm/client/fastdeploy_client/command.py b/llm/client/fastdeploy_client/command.py
new file mode 100644
index 0000000000..567f490cc2
--- /dev/null
+++ b/llm/client/fastdeploy_client/command.py
@@ -0,0 +1,71 @@
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import os
+import sys
+
+from fastdeploy_client.chatbot import ChatBot
+
+
+def _get_service_configuration():
+ """
+    Read the service configuration from environment variables.
+ """
+ url = os.getenv("FASTDEPLOY_MODEL_URL")
+
+ if url is None:
+ raise ValueError("Please set service url by `export FASTDEPLOY_MODEL_URL`."
+ "For example: `export FASTDEPLOY_MODEL_URL=127.0.0.1:8500`")
+ hostname, port = url.strip().split(':')
+ port = int(port)
+ if port <= 0 or port > 65535:
+ raise ValueError("Invalid port number")
+
+ return hostname, port
+
+
+def stream_generate(prompt):
+ """
+    Command-line tool: streaming generation.
+ """
+ hostname, port = _get_service_configuration()
+ chatbot = ChatBot(hostname=hostname, port=port)
+ stream_result = chatbot.stream_generate(prompt)
+ for res in stream_result:
+ print(res)
+
+
+def generate(prompt):
+ """
+    Command-line tool: non-streaming generation.
+ """
+ hostname, port = _get_service_configuration()
+ chatbot = ChatBot(hostname=hostname, port=port)
+ result = chatbot.generate(prompt)
+ print(result)
+
+
+def main():
+ """
+    Main entry point of the command-line tool.
+ """
+    if len(sys.argv) < 3 or sys.argv[1] not in ["generate", "stream_generate"]:
+        logging.error("Usage: fdclient generate \"Hello, how are you?\" "
+                      "or fdclient stream_generate \"Hello, how are you?\"")
+        return
+    prompt = sys.argv[2]
+ if sys.argv[1] == "generate":
+ return generate(prompt)
+ else:
+ return stream_generate(prompt)
diff --git a/llm/client/fastdeploy_client/message.py b/llm/client/fastdeploy_client/message.py
new file mode 100644
index 0000000000..7f1dc07326
--- /dev/null
+++ b/llm/client/fastdeploy_client/message.py
@@ -0,0 +1,51 @@
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+class ChatMessage(object):
+ """
+    Data structure for multi-turn conversations. When it is used to chat with a ChatBot,
+    the conversation history is stored in this structure, supporting multiple turns.
+ """
+ def __init__(self, prompt=None):
+ if prompt is not None:
+ self.message = [{"role": "user", "content": prompt}]
+ else:
+ self.message = []
+
+ def add_user_message(self, content):
+ """
+        Append a user message.
+ """
+ if len(self.message) > 0 and self.message[-1]["role"] != "assistant":
+ raise Exception("Cannot add user message, because the role of the "
+ f"last message is not assistant. The message is {self.message}")
+ self.message.append({"role": "user", "content": content})
+
+ def add_assistant_message(self, content):
+ """
+        Append an assistant message.
+ """
+ if len(self.message) > 0 and self.message[-1]["role"] != "user":
+ raise Exception("Cannot add user message, because the role of the "
+ f"last message is not user. The message is {self.message}")
+ self.message.append({"role": "assistant", "content": content})
+
+ def next_prompt(self, content):
+ """
+        Append a new user turn; kept for backward compatibility.
+ """
+ self.add_user_message(content)
+
+ def __str__(self):
+ return str(self.message)
diff --git a/llm/client/fastdeploy_client/utils.py b/llm/client/fastdeploy_client/utils.py
new file mode 100644
index 0000000000..935d80efc6
--- /dev/null
+++ b/llm/client/fastdeploy_client/utils.py
@@ -0,0 +1,17 @@
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+def is_enable_benchmark(**kwargs):
+ """是否是benchmark模式"""
+ return "benchmark" in kwargs and kwargs["benchmark"] == 1
diff --git a/llm/client/requirements.txt b/llm/client/requirements.txt
new file mode 100644
index 0000000000..132f7f2b0d
--- /dev/null
+++ b/llm/client/requirements.txt
@@ -0,0 +1,5 @@
+grpcio
+streamlit<=1.33.0
+streamlit_chat<=0.1.1
+protobuf==3.20.0
+tritonclient[grpc]==2.41.1
diff --git a/llm/client/setup.py b/llm/client/setup.py
new file mode 100644
index 0000000000..9075c45ae0
--- /dev/null
+++ b/llm/client/setup.py
@@ -0,0 +1,39 @@
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import setuptools
+from fastdeploy_client import __version__ as version
+
+long_description = "No description"
+with open("requirements.txt") as fin:
+ REQUIRED_PACKAGES = fin.read()
+
+setuptools.setup(
+ name="fastdeploy-client",
+ version=version,
+ author="dltp-sz",
+ author_email="dltp-sz@baidu.com",
+ description="Client for fastdeploy llm serving",
+ long_description=long_description,
+ long_description_content_type="text/plain",
+ url="https://github.com/PaddlePaddle/Paddle",
+ packages=setuptools.find_packages(),
+ install_requires=REQUIRED_PACKAGES,
+ classifiers=[
+ "Programming Language :: Python :: 3",
+ "License :: OSI Approved :: Apache Software License",
+ "Operating System :: OS Independent",
+ ],
+ license='Apache 2.0',
+ entry_points={'console_scripts': ['fdclient=fastdeploy_client.command:main', ]})
diff --git a/llm/dockerfiles/Dockerfile_serving_cuda118_cudnn8 b/llm/dockerfiles/Dockerfile_serving_cuda118_cudnn8
new file mode 100644
index 0000000000..e288057756
--- /dev/null
+++ b/llm/dockerfiles/Dockerfile_serving_cuda118_cudnn8
@@ -0,0 +1,35 @@
+FROM registry.baidubce.com/paddlepaddle/fastdeploy:llm-base-gcc12.3-cuda11.8-cudnn8-nccl2.15.5
+
+WORKDIR /opt/output/
+COPY ./server/ /opt/output/Serving
+COPY ./client/ /opt/output/client/
+
+RUN python3 -m pip install --pre paddlepaddle-gpu -i https://www.paddlepaddle.org.cn/packages/nightly/cu118/ \
+ && python3 -m pip install paddlenlp==3.0.0b0 \
+ && python3 -m pip install --no-cache-dir sentencepiece pycryptodome tritonclient[all]==2.41.1 \
+ && apt-get clean && rm -rf /var/lib/apt/lists/*
+
+ENV LD_LIBRARY_PATH "/usr/local/cuda-11.8/compat/:$LD_LIBRARY_PATH"
+RUN git clone https://gitee.com/paddlepaddle/PaddleNLP.git && cd PaddleNLP/csrc \
+ && python3 setup_cuda.py build && python3 setup_cuda.py install --user \
+ && cp -r /opt/output/PaddleNLP/paddlenlp /usr/local/lib/python3.10/dist-packages/ \
+ && cp -r /root/.local/lib/python3.10/site-packages/* /usr/local/lib/python3.10/dist-packages/ \
+ && rm -rf PaddleNLP
+
+RUN cd /opt/output/client && pip install -r requirements.txt && pip install .
+
+RUN python3 -m pip install -r /opt/output/Serving/requirements.txt && rm /opt/output/Serving/requirements.txt
+RUN mv Serving/server /usr/local/lib/python3.10/dist-packages/
+RUN mkdir -p /opt/output/Serving/llm_model/model/1 \
+ && mv /opt/output/Serving/config/config.pbtxt /opt/output/Serving/llm_model/model/ \
+ && rm -rf /opt/output/Serving/config/
+RUN echo "from server.triton_server import TritonPythonModel" >>/opt/output/Serving/llm_model/model/1/model.py
+
+RUN cd /opt/output/Serving/ \
+ && cp scripts/start_server.sh . && cp scripts/stop_server.sh . \
+ && rm -rf scripts
+
+RUN python3 -m pip install protobuf==3.20.0
+
+ENV http_proxy ""
+ENV https_proxy ""
diff --git a/llm/dockerfiles/Dockerfile_serving_cuda123_cudnn9 b/llm/dockerfiles/Dockerfile_serving_cuda123_cudnn9
new file mode 100644
index 0000000000..bb56007b6e
--- /dev/null
+++ b/llm/dockerfiles/Dockerfile_serving_cuda123_cudnn9
@@ -0,0 +1,35 @@
+FROM registry.baidubce.com/paddlepaddle/fastdeploy:llm-base-gcc12.3-cuda12.3-cudnn9-nccl2.15.5
+
+WORKDIR /opt/output/
+COPY ./server/ /opt/output/Serving
+COPY ./client/ /opt/output/client/
+
+RUN python3 -m pip install --pre paddlepaddle-gpu -i https://www.paddlepaddle.org.cn/packages/nightly/cu123/ \
+ && python3 -m pip install paddlenlp==3.0.0b0 \
+ && python3 -m pip install --no-cache-dir sentencepiece pycryptodome tritonclient[all]==2.41.1 \
+ && apt-get clean && rm -rf /var/lib/apt/lists/*
+
+ENV LD_LIBRARY_PATH "/usr/local/cuda-12.3/compat/:$LD_LIBRARY_PATH"
+RUN git clone https://gitee.com/paddlepaddle/PaddleNLP.git && cd PaddleNLP/csrc \
+ && python3 setup_cuda.py build && python3 setup_cuda.py install --user \
+ && cp -r /opt/output/PaddleNLP/paddlenlp /usr/local/lib/python3.10/dist-packages/ \
+ && cp -r /root/.local/lib/python3.10/site-packages/* /usr/local/lib/python3.10/dist-packages/ \
+ && rm -rf PaddleNLP
+
+RUN cd /opt/output/client && pip install -r requirements.txt && pip install .
+
+RUN python3 -m pip install -r /opt/output/Serving/requirements.txt && rm /opt/output/Serving/requirements.txt
+RUN mv Serving/server /usr/local/lib/python3.10/dist-packages/
+RUN mkdir -p /opt/output/Serving/llm_model/model/1 \
+ && mv /opt/output/Serving/config/config.pbtxt /opt/output/Serving/llm_model/model/ \
+ && rm -rf /opt/output/Serving/config/
+RUN echo "from server.triton_server import TritonPythonModel" >>/opt/output/Serving/llm_model/model/1/model.py
+
+RUN cd /opt/output/Serving/ \
+ && cp scripts/start_server.sh . && cp scripts/stop_server.sh . \
+ && rm -rf scripts
+
+RUN python3 -m pip install protobuf==3.20.0
+
+ENV http_proxy ""
+ENV https_proxy ""
diff --git a/llm/docs/FastDeploy_usage_tutorial.md b/llm/docs/FastDeploy_usage_tutorial.md
new file mode 100644
index 0000000000..a55f78d263
--- /dev/null
+++ b/llm/docs/FastDeploy_usage_tutorial.md
@@ -0,0 +1,213 @@
+
+## Table of Contents
+
+- [Preparing the Deployment Environment](#preparing-the-deployment-environment)
+  - [Pulling the Deployment Image](#pulling-the-deployment-image)
+  - [Preparing the Model](#preparing-the-model)
+  - [Creating the Container](#creating-the-container)
+  - [Building Your Own Image from a Dockerfile](#building-your-own-image-from-a-dockerfile)
+- [Starting the Service](#starting-the-service)
+  - [Configuration Parameters](#configuration-parameters)
+  - [Starting FastDeploy](#starting-fastdeploy)
+  - [Checking Service Status](#checking-service-status)
+- [Testing the Service](#testing-the-service)
+  - [Python Client](#python-client)
+  - [HTTP Requests](#http-requests)
+  - [Request Parameters](#request-parameters)
+  - [Response Examples](#response-examples)
+
+## Preparing the Deployment Environment
+
+### Pulling the Deployment Image
+
+For convenience, we provide a CUDA 12.3 image. Pull it directly, or [build your own image](#building-your-own-image-from-a-dockerfile) from the Dockerfile.
+```
+docker pull registry.baidubce.com/paddlepaddle/fastdeploy:llm-serving-cuda123-cudnn9-v1.0
+```
+
+### Preparing the Model
+
+Place the model in the corresponding directory, using `/home/workspace/models_dir` as an example:
+```
+cd /home/workspace/models_dir
+
+# The model directory must follow a specific layout; below is the layout for single-GPU deployment
+# /opt/output/Serving/models
+# ├── config.json                # model configuration file (required)
+# ├── xxxx.model                 # tokenizer model file (required)
+# ├── special_tokens_map.json    # tokenizer configuration file (required)
+# ├── tokenizer_config.json      # tokenizer configuration file (required)
+# ├── rank_mapping.csv           # present only for multi-GPU models; absent for single-GPU models (optional, needed only for multi-GPU deployment)
+# └── rank_0                     # directory holding the model structure and weight files (required)
+#     ├── model.pdiparams
+#     └── model.pdmodel
+```
+
+### Creating the Container
+
+```
+docker run --gpus all \
+ --name fastdeploy_serving \
+ --network=host \
+ --shm-size=10G \
+ -v /home/workspace/models_dir:/fastdeploy/models/ \
+ -dit registry.baidubce.com/paddlepaddle/fastdeploy:llm-serving-cuda123-cudnn9-v1.0 bash
+
+# Enter the container and check that the GPU environment and model mount are working
+docker exec -it fastdeploy_serving /bin/bash
+nvidia-smi
+ls /fastdeploy/models/
+```
+
+## Building Your Own Image from a Dockerfile
+
+```
+git clone https://github.com/PaddlePaddle/FastDeploy.git
+cd FastDeploy/llm
+
+docker build -f ./dockerfiles/Dockerfile_serving_cuda123_cudnn9 -t llm-serving-cu123-self .
+```
+
+Once your image is built, you can [create a container](#creating-the-container) from it.
+
+## Starting the Service
+
+### Configuration Parameters
+
+Configure the following environment variables according to your needs and hardware:
+
+```
+# Single-/multi-GPU inference configuration. Adjust as needed.
+## For single-GPU inference on GPU 0, set the following environment variables.
+export MP_NUM=1
+export CUDA_VISIBLE_DEVICES=0
+
+## For multi-GPU inference, the exported model must match the 2-GPU layout; also set the following environment variables.
+# export MP_NUM=2
+# export CUDA_VISIBLE_DEVICES=0,1
+
+# If your deployment does not need streaming token output, enable the following switch.
+# The service will then return all generated tokens of a request at once,
+# which reduces the overhead of sending tokens one by one.
+# Disabled by default.
+# export DISABLE_STREAMING=1
+
+# Configure the service ports. Adjust HTTP_PORT, GRPC_PORT, METRICS_PORT and INFER_QUEUE_PORT as needed.
+# Check that a port is free beforehand: run `netstat -tuln | grep <port>`; no output means the port is not in use.
+export HTTP_PORT="8751"            # HTTP port of the health-check service (currently only used for liveness/readiness probes)
+export GRPC_PORT="8752"            # gRPC port of the model inference service
+export METRICS_PORT="8753"         # port exposing the service metrics
+export INFER_QUEUE_PORT="8754"     # port used internally by the model service
+export PUSH_MODE_HTTP_PORT="8143"  # HTTP port for service requests; defaults to -1 if unset, i.e. the service then only supports gRPC
+
+# MAX_SEQ_LEN: requests whose input token count exceeds MAX_SEQ_LEN are rejected with an error
+# MAX_DEC_LEN: requests whose max_dec_len/min_dec_len exceed this value are rejected with an error
+export MAX_SEQ_LEN=8192
+export MAX_DEC_LEN=1024
+
+export BATCH_SIZE="48"    # maximum batch size, i.e. how many inputs the model can process concurrently; must not exceed 128
+export BLOCK_BS="5"       # maximum query batch size supported by the cache blocks; reduce this value if you hit out-of-memory errors
+export BLOCK_RATIO="0.75" # usually set to (average input tokens) / (average input + output tokens)
+
+export MAX_CACHED_TASK_NUM="128"  # maximum length of the service cache queue; new requests are rejected once the queue is full; default 128
+# Set the following parameter when the HTTP interface is enabled
+export PUSH_MODE_HTTP_WORKERS="1" # number of HTTP worker processes; effective only when PUSH_MODE_HTTP_PORT is set; at most 8; default 1
+```
+
+### Starting FastDeploy
+
+```
+cd /opt/output/Serving
+bash start_server.sh
+
+# Before restarting the service, stop it first: run "bash stop_server.sh" in /opt/output/Serving
+```
+
+### Checking Service Status
+
+```
+# HTTP_PORT is the port specified when starting the service above
+liveness endpoint (whether the service can accept requests):
+    http://{ip}:{HTTP_PORT}/v2/health/live
+readiness endpoint (whether the model is ready for inference):
+    http://{ip}:{HTTP_PORT}/v2/health/ready
+```
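+
+As a quick sanity check, the probes can also be queried from Python. This is a minimal sketch assuming the service runs on the local host with HTTP_PORT=8751 and that the `requests` package is installed (it is not part of the server requirements):
+
+```
+import requests
+
+HTTP_PORT = 8751  # the HTTP_PORT configured above (assumed value)
+
+# liveness probe: can the service accept requests?
+live = requests.get(f"http://127.0.0.1:{HTTP_PORT}/v2/health/live")
+# readiness probe: is the model ready for inference?
+ready = requests.get(f"http://127.0.0.1:{HTTP_PORT}/v2/health/ready")
+
+# both endpoints answer with HTTP 200 and an empty body when the probe passes
+print("live:", live.status_code == 200, "ready:", ready.status_code == 200)
+```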
+
+## Testing the Service
+
+### Python Client
+
+```
+from fastdeploy_client.chatbot import ChatBot
+
+hostname = "127.0.0.1" # 服务部署的hostname
+port = 8000 # 服务配置的GRPC_PORT
+
+chatbot = ChatBot(hostname=hostname, port=port)
+
+# Non-streaming interface
+result = chatbot.generate("你好", topp=0.8, max_dec_len=128, timeout=120)
+print(result)
+
+# Streaming interface
+chatbot = ChatBot(hostname=hostname, port=port)
+stream_result = chatbot.stream_generate("你好", max_dec_len=128, timeout=120)
+for res in stream_result:
+ print(res)
+```
+
+### HTTP Requests
+
+Note: HTTP requests go to the port configured by PUSH_MODE_HTTP_PORT; HTTP_PORT is only used for the health-check probes!
+
+```
+import uuid
+import json
+import requests
+
+url = f"http://0.0.0.0:{PUSH_MODE_HTTP_PORT}/v1/chat/completions"
+req_id = str(uuid.uuid1())
+data = {
+ "text": "Hello, how are you?",
+ "req_id": req_id,
+ "max_dec_len": 64,
+ "stream": True,
+ }
+# tokens are returned one by one
+res = requests.post(url, json=data, stream=True)
+for line in res.iter_lines():
+ print(json.loads(line))
+```
+
+### Request Parameters
+
+| Field | Type | Description | Required | Default | Notes |
+| :---: | :-----: | :---: | :---: | :-----: | :----: |
+| req_id | str | Request ID used to identify a request; it is recommended to set req_id and keep it unique | No | random id | If two requests with the same req_id are in the inference service at the same time, a duplicate req_id error is returned |
+| text | str | Input text of the request | Yes | None | |
+| max_dec_len | int | Maximum number of generated tokens; if the number of input tokens plus max_dec_len exceeds the model's max_seq_len, a length-exceeded error is returned | No | max_seq_len minus the number of input tokens | |
+| min_dec_len | int | Minimum number of generated tokens; the minimum is 1 | No | 1 | |
+| topp | float | Controls randomness; larger values give more randomness; range 0~1 | No | 0.7 | |
+| temperature | float | Controls randomness; smaller values give more randomness; must be greater than 0 | No | 0.95 | |
+| frequency_score | float | Frequency penalty score | No | 0 | |
+| penalty_score | float | Repetition penalty score | No | 1 | |
+| presence_score | float | Presence penalty score | No | 0 | |
+| stream | bool | Whether to stream the results | No | False | |
+| return_all_tokens | bool | Whether to return all results at once | No | False | See the notes below for how this differs from stream |
+| timeout | int | Request timeout in seconds | No | 300 | |
+
+* When PUSH_MODE_HTTP_PORT is configured correctly, the service accepts both GRPC and HTTP requests
+  * The stream parameter only takes effect for HTTP requests
+  * The return_all_tokens parameter applies to both GRPC and HTTP requests
+
+### Response Examples
+
+```
+When stream is True, results are streamed:
+    on success: {'token': xxx, 'is_end': xxx, 'send_idx': xxx, ..., 'error_msg': '', 'error_code': 0}
+    on error:   {'error_msg': xxx, 'error_code': xxx}, where error_msg is non-empty and error_code is non-zero
+
+When stream is False, all results are returned at once:
+    on success: {'tokens_all': xxx, 'tokens_all_num': xxx, ..., 'error_msg': '', 'error_code': 0}
+    on error:   {'error_msg': xxx, 'error_code': xxx}, where error_msg is non-empty and error_code is non-zero
+```
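+
+For the non-streaming case, a minimal sketch is shown below. It assumes that with "stream": False the service answers with a single JSON body containing the fields described above; only the request payload differs from the streaming example:
+
+```
+import uuid
+import requests
+
+PUSH_MODE_HTTP_PORT = 9965  # the PUSH_MODE_HTTP_PORT configured for the service (assumed default)
+url = f"http://0.0.0.0:{PUSH_MODE_HTTP_PORT}/v1/chat/completions"
+data = {
+    "text": "Hello, how are you?",
+    "req_id": str(uuid.uuid1()),
+    "max_dec_len": 64,
+    "stream": False,  # return all tokens in a single response
+}
+res = requests.post(url, json=data)
+print(res.json())  # e.g. {'tokens_all': ..., 'tokens_all_num': ..., 'error_msg': '', 'error_code': 0}
+```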
diff --git a/llm/requirements-dev.txt b/llm/requirements-dev.txt
new file mode 100644
index 0000000000..e1eec92d20
--- /dev/null
+++ b/llm/requirements-dev.txt
@@ -0,0 +1,3 @@
+black[jupyter] == 23.3.0
+isort == 5.11.5
+pre-commit
diff --git a/llm/server/config/config.pbtxt b/llm/server/config/config.pbtxt
new file mode 100644
index 0000000000..375c41d013
--- /dev/null
+++ b/llm/server/config/config.pbtxt
@@ -0,0 +1,20 @@
+backend: "python"
+max_batch_size: 0
+model_transaction_policy {
+ decoupled: True
+}
+input [
+ {
+ name: "IN"
+ data_type: TYPE_STRING
+ dims: [ 1 ]
+ }
+]
+output [
+ {
+ name: "OUT"
+ data_type: TYPE_STRING
+ dims: [ 1 ]
+ }
+]
+instance_group [{ kind: KIND_CPU }]
diff --git a/llm/server/requirements.txt b/llm/server/requirements.txt
new file mode 100644
index 0000000000..d9bfd84e8f
--- /dev/null
+++ b/llm/server/requirements.txt
@@ -0,0 +1,22 @@
+# model server
+paddlenlp==2.7.2
+sentencepiece
+pycryptodome
+tritonclient[all]==2.41.1
+opencv-python
+patchify
+transformers
+
+# http server
+fastapi
+httpx
+openai==1.9.0
+asyncio
+uvicorn
+shortuuid
+
+# parameter search
+pynvml
+
+# paddlenlp
+tiktoken
diff --git a/llm/server/scripts/start_server.sh b/llm/server/scripts/start_server.sh
new file mode 100644
index 0000000000..784d3b7c44
--- /dev/null
+++ b/llm/server/scripts/start_server.sh
@@ -0,0 +1,58 @@
+#!/usr/bin/bash
+
+export GLOG_v=0
+export GLOG_logtostderr=1
+export PYTHONIOENCODING=utf8
+export LC_ALL=C.UTF-8
+
+# PaddlePaddle environment variables
+export FLAGS_allocator_strategy=naive_best_fit
+export FLAGS_fraction_of_gpu_memory_to_use=0.96
+export FLAGS_dynamic_static_unified_comm=0
+export FLAGS_use_xqa_optim=1
+export FLAGS_gemm_use_half_precision_compute_type=0
+export NVIDIA_TF32_OVERRIDE=0
+
+# Model hyperparameters
+export MP_NUM=${MP_NUM:-"1"} # GPU num
+export CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES:-"0"} # GPU
+export MAX_SEQ_LEN=${MAX_SEQ_LEN:-"8192"}
+export MAX_DEC_LEN=${MAX_DEC_LEN:-"2048"}
+export BATCH_SIZE=${BATCH_SIZE:-"20"}
+export BLOCK_BS=${BLOCK_BS:-"4"}
+export BLOCK_SIZE=${BLOCK_SIZE:-"64"}
+export DTYPE=${DTYPE:-"bfloat16"}
+export USE_CACHE_KV_INT8=${USE_CACHE_KV_INT8:-"0"} # set to 1 for c8 (cache KV int8) models
+export BLOCK_RATIO=${BLOCK_RATIO:-"0.75"}
+export ENC_DEC_BLOCK_NUM=${ENC_DEC_BLOCK_NUM:-"4"}
+export FIRST_TOKEN_ID=${FIRST_TOKEN_ID:-"1"}
+export MAX_PREFILL_BATCH=${MAX_PREFILL_BATCH:-"4"}
+export STOP_THRESHOLD=${STOP_THRESHOLD:-"0"}
+export MODEL_DIR=${MODEL_DIR:-"/models/"}
+export DISTRIBUTED_CONFIG=${DISTRIBUTED_CONFIG:-"${MODEL_DIR}/rank_mapping.csv"}
+export CONFIG_JSON_FILE=${CONFIG_JSON_FILE:-"config.json"}
+export PUSH_MODE_HTTP_WORKERS=${PUSH_MODE_HTTP_WORKERS:-"4"}
+
+# serving port
+export HTTP_PORT=${HTTP_PORT:-"8110"}
+export GRPC_PORT=${GRPC_PORT:-"8811"}
+export METRICS_PORT=${METRICS_PORT:-"8722"}
+export INFER_QUEUE_PORT=${INFER_QUEUE_PORT:-"8813"}
+export PUSH_MODE_HTTP_PORT=${PUSH_MODE_HTTP_PORT:-"9965"}
+
+mkdir -p log
+rm -rf console.log log/*
+rm -rf /dev/shm/*
+
+# start the service
+echo "start serving ..."
+
+tritonserver --exit-timeout-secs 100 --cuda-memory-pool-byte-size 0:0 --cuda-memory-pool-byte-size 1:0 \
+ --cuda-memory-pool-byte-size 2:0 --cuda-memory-pool-byte-size 3:0 --cuda-memory-pool-byte-size 4:0 \
+ --cuda-memory-pool-byte-size 5:0 --cuda-memory-pool-byte-size 6:0 --cuda-memory-pool-byte-size 7:0 \
+ --pinned-memory-pool-byte-size 0 --model-repository llm_model/ \
+ --allow-http false \
+ --grpc-port=${GRPC_PORT} \
+ --metrics-port=${METRICS_PORT} \
+ --log-file log/server.log --log-info true > log/console.log 2>&1 &
+echo "模型服务的启动日志,请查看" ${PWD}"/log/server.log 和 "${PWD}"/log/workerlog.0 "
diff --git a/llm/server/scripts/stop_server.sh b/llm/server/scripts/stop_server.sh
new file mode 100644
index 0000000000..ad8f0e3e93
--- /dev/null
+++ b/llm/server/scripts/stop_server.sh
@@ -0,0 +1,69 @@
+#!/bin/bash
+
+pids=($(ps aux | grep -E 'tritonserver' | grep -v grep | awk '{print $2}'))
+
+if [ ${#pids[@]} -eq 0 ]; then
+ echo "未找到 tritonserver 相关进程"
+ timeout=1
+else
+ timeout=300
+fi
+
+# kill processor
+for pid in "${pids[@]}"; do
+ echo "正在中断进程 $pid"
+ kill -2 "$pid"
+done
+
+timeout_interval=$1
+if [ ! "$timeout_interval" == "" ]; then
+ timeout=$timeout_interval
+ echo $timeout
+fi
+
+start_time=$(date +%s)
+
+while : ; do
+ current_time=$(date +%s)
+
+ elapsed_time=$((current_time - start_time))
+
+ if [ $elapsed_time -ge $timeout ]; then
+ echo "tritonserver进程超时未退出"
+ echo "强制杀死所有有关进程"
+ pids=$(ps auxww | grep -E "tritonserver|triton_python_backend_stub|new_infer.py|infer|multiprocessing.resource_tracker|paddle.distributed.launch|task_queue_manager|app.py|memory_log.py|spawn_main" | grep -v grep | grep -v start_both | awk '{print $2}');
+ echo $pids;
+ for pid in ${pids[@]}; do
+ kill -9 ${pid}
+ done
+ break
+ fi
+
+ pids=$(ps auxww | grep -E "tritonserver|triton_python_backend_stub|new_infer.py|multiprocessing.resource_tracker|paddle.distributed.launch|app.py|memory_log.py|spawn_main" | grep -v grep | awk '{print $2}');
+ array=($(echo "$pids" | tr ' ' '\n'))
+
+ if [ ${#array[*]} -ne 0 ]; then
+ echo "进程还没有清理干净, 等待清理完毕"
+ sleep 1
+ else
+ echo "进程已经清理干净"
+ break
+ fi
+done
+
+manager_pids=$(ps auxww | grep "task_queue_manager" | grep -v grep | awk '{print $2}')
+echo $manager_pids
+for in_pid in ${manager_pids[@]}; do
+ kill -9 ${in_pid}
+done
+echo 'end kill queue manager'
+
+health_checker_pids=$(ps auxww | grep "health.py" | grep -v grep | awk '{print $2}')
+echo $health_checker_pids
+for in_pid in ${health_checker_pids[@]}; do
+ kill -9 ${in_pid}
+done
+echo 'end kill health checker'
+
+echo "所有进程已终止"
+exit 0
diff --git a/llm/server/server/__init__.py b/llm/server/server/__init__.py
new file mode 100644
index 0000000000..5ae9b7e8cf
--- /dev/null
+++ b/llm/server/server/__init__.py
@@ -0,0 +1,16 @@
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+__version__ = "dev"
+__commit__ = "dev"
diff --git a/llm/server/server/checker.py b/llm/server/server/checker.py
new file mode 100644
index 0000000000..55944fd1b0
--- /dev/null
+++ b/llm/server/server/checker.py
@@ -0,0 +1,129 @@
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+def check_basic_params(req_dict):
+ """
+    Basic validation of a single incoming request, used in both push and pull modes.
+    All input fields are checked and the error messages are returned to the user together; the checks for each field are independent, so duplicate error messages are avoided.
+
+    Args:
+        req_dict (dict): request data as a dict, containing fields such as text, model, sequence length and max tokens.
+
+    Returns:
+        list[str]: a list of error messages if validation fails, or an empty list if validation succeeds.
+ """
+
+ error_msg = []
+
+    # at least one of text, input_ids and messages must be set
+ bools = ("text" in req_dict, "input_ids" in req_dict, "messages" in req_dict)
+ if sum(bools) == 0:
+ error_msg.append("The input parameters should contain either `text`, `input_ids` or `messages`")
+ else:
+ if "text" in req_dict:
+ if not isinstance(req_dict["text"], str):
+ error_msg.append("The `text` in input parameters must be a string")
+ elif req_dict["text"] == "":
+ error_msg.append("The `text` in input parameters cannot be empty")
+ if "system" in req_dict and not isinstance(req_dict["system"], str):
+ error_msg.append("The `system` in input parameters must be a string")
+ if "input_ids" in req_dict and not isinstance(req_dict["input_ids"], list):
+ error_msg.append("The `input_ids` in input parameters must be a list")
+ if "messages" in req_dict:
+ msg_len = len(req_dict["messages"])
+ if msg_len % 2 == 0:
+ error_msg.append(f"The number of the message {msg_len} must be odd")
+ if not all("content" in item for item in req_dict["messages"]):
+ error_msg.append("The item in messages must include `content`")
+
+ if "req_id" not in req_dict:
+ error_msg.append("The input parameters should contain `req_id`.")
+
+ if "min_dec_len" in req_dict and \
+ (not isinstance(req_dict["min_dec_len"], int) or req_dict["min_dec_len"] < 1):
+ error_msg.append("The `min_dec_len` must be an integer and greater than 0")
+
+    # if seq_len or max_tokens is set, its value is assigned to max_dec_len
+ keys = ("max_dec_len", "seq_len", "max_tokens")
+ for key in keys:
+ if key in req_dict and (not isinstance(req_dict[key], int) or req_dict[key] < 1):
+ error_msg.append(f"The `{key}` must be an integer and greater than 0")
+ if "seq_len" in req_dict and "max_dec_len" not in req_dict:
+ req_dict["max_dec_len"] = req_dict["seq_len"]
+ if "max_tokens" in req_dict and "max_dec_len" not in req_dict:
+ req_dict["max_dec_len"] = req_dict["max_tokens"]
+
+    # for simplicity, only one of topp and top_p may be set; the value ends up in topp
+ keys = ("topp", "top_p")
+ if sum([key in req_dict for key in keys]) > 1:
+ error_msg.append(f"Only one of {keys} should be set")
+ else:
+ for key in keys:
+ if key in req_dict and not 0 <= req_dict[key] <= 1:
+ error_msg.append(f"The `{key}` must be in [0, 1]")
+ if "top_p" in req_dict and "topp" not in req_dict:
+ req_dict["topp"] = req_dict["top_p"]
+
+ if "temperature" in req_dict and not 0 <= req_dict["temperature"]:
+ error_msg.append(f"The `temperature` must be >= 0")
+
+ if "eos_token_ids" in req_dict:
+ if isinstance(req_dict["eos_token_ids"], int):
+ req_dict["eos_token_ids"] = [req_dict["eos_token_ids"]]
+ elif isinstance(req_dict["eos_token_ids"], tuple):
+ req_dict["eos_token_ids"] = list(req_dict["eos_token_ids"])
+ if not isinstance(req_dict["eos_token_ids"], list):
+ error_msg.append("The `eos_token_ids` must be an list")
+ elif len(req_dict["eos_token_ids"]) > 1:
+ error_msg.append("The length of `eos_token_ids` must be 1 if you set it")
+
+    # for simplicity, only one of infer_seed and seed may be set; the value ends up in infer_seed
+ keys = ("infer_seed", "seed")
+ if sum([key in req_dict for key in keys]) > 1:
+ error_msg.append(f"Only one of {keys} should be set")
+ else:
+ if "seed" in req_dict and "infer_seed" not in req_dict:
+ req_dict["infer_seed"] = req_dict["seed"]
+
+ if "stream" in req_dict and not isinstance(req_dict["stream"], bool):
+ error_msg.append("The `stream` must be a boolean")
+
+ if "response_type" in req_dict and (req_dict["response_type"].lower() not in ("fastdeploy", "openai")):
+ error_msg.append("The `response_type` must be either `fastdeploy` or `openai`.")
+
+    # return the collected error messages
+ return error_msg
+
+def add_default_params(req_dict):
+ """
+    Add default values to req_dict.
+    Note: infer.py also sets defaults for request parameters, but for consistency the defaults are set here as well. Make sure the defaults here match those in infer.py.
+    Returns req_dict with the defaults filled in.
+
+ """
+ assert isinstance(req_dict, dict), "The `req_dict` must be a dict."
+ if "min_dec_len" not in req_dict:
+ req_dict["min_dec_len"] = 1
+ if "topp" not in req_dict:
+ req_dict["topp"] = 0.7
+ if "temperature" not in req_dict:
+ req_dict["temperature"] = 0.95
+ if "penalty_score" not in req_dict:
+ req_dict["penalty_score"] = 1.0
+ if "frequency_score" not in req_dict:
+ req_dict["frequency_score"] = 0.0
+ if "presence_score" not in req_dict:
+ req_dict["presence_score"] = 0.0
+ return req_dict
diff --git a/llm/server/server/data/__init__.py b/llm/server/server/data/__init__.py
new file mode 100644
index 0000000000..97043fd7ba
--- /dev/null
+++ b/llm/server/server/data/__init__.py
@@ -0,0 +1,13 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/llm/server/server/data/processor.py b/llm/server/server/data/processor.py
new file mode 100644
index 0000000000..748db41ed9
--- /dev/null
+++ b/llm/server/server/data/processor.py
@@ -0,0 +1,278 @@
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+from abc import ABC, abstractmethod
+from paddlenlp.utils.llm_utils import get_eos_token_id
+from paddlenlp.transformers import (
+ LlamaTokenizer,
+ Llama3Tokenizer,
+ AutoTokenizer,
+)
+
+from server.utils import data_processor_logger
+from server.engine.config import Config
+
+
+class BaseDataProcessor(ABC):
+ """Data processor的基类"""
+
+ def __init__(self):
+ """
+ Returns:
+ None
+ """
+ self.tokenizer = self._load_tokenizer()
+ self.tokenizer.bos_token_id = self.tokenizer._convert_token_to_id(self.tokenizer.bos_token)
+ self.tokenizer.cls_token_id = self.tokenizer._convert_token_to_id(self.tokenizer.cls_token)
+ self.tokenizer.sep_token_id = self.tokenizer._convert_token_to_id(self.tokenizer.sep_token)
+ self.tokenizer.eos_token_id = self.tokenizer._convert_token_to_id(self.tokenizer.eos_token)
+ self.tokenizer.mask_token_id = self.tokenizer._convert_token_to_id(self.tokenizer.mask_token)
+ data_processor_logger.info((f"tokenizer infomation: bos_token is {self.tokenizer.bos_token}, {self.tokenizer.bos_token_id}, ",
+ f"cls_token is {self.tokenizer.cls_token}, {self.tokenizer.cls_token_id}, "
+ f"sep_token is {self.tokenizer.sep_token}, {self.tokenizer.sep_token_id}, "
+ f"eos_token is {self.tokenizer.eos_token}, {self.tokenizer.eos_token_id}, "
+ f"mask_token is {self.tokenizer.mask_token}, {self.tokenizer.mask_token_id}"))
+
+ @abstractmethod
+ def process_request(self, request, **kwargs):
+ """
+ Preprocess the request
+
+ Args:
+ request (Dict): may contain text and messages fields
+ **kwargs: others
+
+ Returns:
+ bool: Whether preprocessing is successful
+ str: error message
+ """
+ raise NotImplementedError
+
+ @abstractmethod
+ def process_response(self, response_dict):
+ """
+ Preprocess the response
+
+ Args:
+ response_dict (Dict): response for engine, contain ids fields
+
+ Returns:
+ Dict: response contain text fields
+ """
+ raise NotImplementedError
+
+ def text2ids(self, text):
+ """
+        Convert text to token ids.
+
+        Args:
+            text (str): the text to convert.
+
+        Returns:
+            List[int]: the converted token ids.
+ """
+ raise NotImplementedError
+
+ def messages2ids(self, messages):
+ """
+        Convert a multi-turn conversation into a sequence of token ids.
+
+        Args:
+            messages (List[List[Dict[str, Any]]]): the conversation; each turn is a dict.
+
+        Returns:
+            List[int]: the token id sequence of the conversation; each id is an integer.
+
+ """
+ raise NotImplementedError
+
+ def ids2tokens(self, token_ids, task_id=None):
+ """
+        Decode token ids into strings.
+
+        Args:
+            token_ids (List[int]): the list of token ids
+            task_id (str): the task ID these token ids belong to
+
+        Returns:
+            List[str]: the decoded strings
+ """
+ raise NotImplementedError
+
+ @abstractmethod
+ def _load_tokenizer(self):
+ """
+        Load the tokenizer.
+        Returns:
+            tokenizer: the tokenizer.
+ """
+ raise NotImplementedError
+
+
+class DataProcessor(BaseDataProcessor):
+ """继承自Data processor的基类"""
+
+ def __init__(self):
+ """
+        Initialization.
+ """
+ self.config = Config()
+ max_length = self.config.get_model_config().get('max_length', 1024)
+ self.src_length = max_length - self.config.seq_len_limit
+
+ self.decode_status = dict()
+ self.tokenizer = self._load_tokenizer()
+ data_processor_logger.info(f"tokenizer infomation: bos_token is {self.tokenizer.bos_token}, {self.tokenizer.bos_token_id}, "+
+ f"eos_token is {self.tokenizer.eos_token}, {self.tokenizer.eos_token_id}, ")
+
+ def process_request(self, request, max_seq_len=None):
+ """
+ Preprocess the request
+
+ Args:
+ request (Dict): may contain text and messages fields
+
+ Returns:
+ bool: Whether preprocessing is successful
+ str: error message
+ """
+ if "eos_token_ids" not in request or request["eos_token_ids"] == [None]:
+ request["eos_token_ids"] = []
+ request["eos_token_ids"].extend(get_eos_token_id(self.tokenizer, self.config.generation_config))
+
+ if "input_ids" in request:
+ input_ids = request["input_ids"]
+ else:
+ input_ids = self.text2ids(request['text'])
+
+ if max_seq_len is not None and len(input_ids) > max_seq_len:
+ input_ids = input_ids[:max_seq_len-1]
+ request["input_ids"] = input_ids
+ data_processor_logger.info(f"processed request: {request}")
+ return request
+
+ def process_response(self, response_dict, **kwargs):
+ """
+ Preprocess the response
+
+ Args:
+ response_dict (Dict): response for engine, contain ids fields
+
+ Returns:
+ Dict: response contain text fields
+ """
+ is_end = response_dict.get("is_end", 0)
+ req_id = response_dict.get("req_id")
+ if "choices" in response_dict:
+ for i in range(len(response_dict["choices"])):
+ response_dict["token"] = self.ids2tokens(response_dict["choices"][i]["token_ids"], req_id)
+ return response_dict
+
+ token_ids = response_dict.get("token_ids", [])
+ response_dict["token"] = self.ids2tokens(token_ids, response_dict["req_id"])
+
+ if is_end:
+ response_dict["tokens_all"] = self.clear_request_status(req_id)
+ return response_dict
+
+ def text2ids(self, text):
+ """
+ text to ids
+ """
+ if self.tokenizer.chat_template is not None:
+ text = [text] if isinstance(text, str) else text
+ text = [self.tokenizer.apply_chat_template(sentence, tokenize=False) for sentence in text]
+
+ tokens = self.tokenizer(
+ text,
+ return_tensors="np",
+ padding=True,
+ truncation=True,
+ max_length=self.src_length,
+ add_special_tokens=self.tokenizer.chat_template is None,
+ )
+ return tokens["input_ids"][0]
+
+ def messages2ids(self, messages):
+ """
+        Convert a multi-turn conversation into a sequence of token ids.
+
+        Args:
+            messages (List[List[Dict[str, Any]]]): the conversation; each turn is a dict.
+
+        Returns:
+            List[int]: the token id sequence of the conversation; each id is an integer.
+
+ """
+ return
+
+ def ids2tokens(self, token_id, task_id):
+ """
+ ids to tokens
+ """
+ if task_id not in self.decode_status:
+            # track the decode prefix offset, read offset, history token ids and history token strings
+ self.decode_status[task_id] = [0, 0, [], []]
+
+ prefix_offset = self.decode_status[task_id][0]
+ read_offset = self.decode_status[task_id][1]
+ previous_token_ids = self.decode_status[task_id][2]
+ decode_str, prefix_offset, read_offset = self.tokenizer.decode_token(
+ previous_token_ids + token_id, prefix_offset, read_offset)
+ self.decode_status[task_id][0] = prefix_offset
+ self.decode_status[task_id][1] = read_offset
+ self.decode_status[task_id][2] += token_id
+ self.decode_status[task_id][3].append(decode_str)
+        # decode_str is the decoded string for each streamed token; add custom handling here if needed
+ return decode_str
+
+ def _load_tokenizer(self):
+ """
+ load tokenizer
+
+ Returns:
+ tokenizer (AutoTokenizer)
+ """
+ return AutoTokenizer.from_pretrained(self.config.model_dir)
+
+ def clear_request_status(self, task_id):
+ """
+ clear request status
+ """
+ results_all = ""
+ if task_id in self.decode_status:
+ results_all = "".join(self.decode_status[task_id][3])
+ del self.decode_status[task_id]
+ return results_all
+
+ def get_eos_tokens_lens(self):
+ """
+ get eos_token_id lens
+ """
+ return len(get_eos_token_id(self.tokenizer, self.config.generation_config))
+
+ def get_eos_tokens(self):
+ """
+ get all eos_token_id
+ """
+ return get_eos_token_id(self.tokenizer, self.config.generation_config)
+
+ def get_pad_id(self):
+ """
+ get pad_token_id, if not pad_token_id, use eos_token
+ """
+ if isinstance(self.tokenizer, (LlamaTokenizer, Llama3Tokenizer)) and not self.tokenizer.pad_token_id:
+ return self.tokenizer.eos_token
+ return self.tokenizer.pad_token_id
diff --git a/llm/server/server/engine/__init__.py b/llm/server/server/engine/__init__.py
new file mode 100644
index 0000000000..fd05a92081
--- /dev/null
+++ b/llm/server/server/engine/__init__.py
@@ -0,0 +1,13 @@
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/llm/server/server/engine/config.py b/llm/server/server/engine/config.py
new file mode 100644
index 0000000000..d35b567e75
--- /dev/null
+++ b/llm/server/server/engine/config.py
@@ -0,0 +1,224 @@
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+import os
+import sys
+from datetime import datetime
+from paddlenlp.generation import GenerationConfig
+
+from server.utils import model_server_logger
+
+
+class Config:
+ """
+    Service configuration; values set via environment variables take precedence.
+ """
+
+ def __init__(self):
+ self.read_from_env()
+
+ def read_from_env(self):
+ """
+        Read parameters from environment variables.
+ """
+ env = os.environ
+ self.model_dir = env.get(
+ "MODEL_DIR", "/opt/output/Serving/models")
+ if not self.model_dir:
+ raise Exception("The parameter MODEL_DIR is None.")
+ self.mp_num = int(env.get("MP_NUM", 8))
+ self.config_json_file = env.get("CONFIG_JSON_FILE", "config.json")
+ self.model_config_path = os.path.join(self.model_dir, self.config_json_file)
+ if env.get("FD_MODEL_CONFIG_PATH", None):
+ self.model_config_path = env.get("FD_MODEL_CONFIG_PATH")
+
+        # distributed configuration file
+ self.distributed_config_path = os.path.join(self.model_dir, "rank_mapping.csv")
+ if os.getenv("DISTRIBUTED_CONFIG", None):
+ self.distributed_config_path = os.getenv("DISTRIBUTED_CONFIG")
+
+ # hardware configuration
+ self.device = env.get("DEVICE", "GPU")
+ self.device_ids = ",".join([str(i) for i in range(self.mp_num)])
+ if self.device == "GPU":
+ self.device_ids = os.getenv("CUDA_VISIBLE_DEVICES",
+ self.device_ids)
+ else:
+ raise Exception(f"unsupported device type: {self.device}")
+
+ # Triton serving-layer parameters
+ self.max_prefill_batch = int(os.getenv("MAX_PREFILL_BATCH", 1))
+ if self.max_prefill_batch <= 0:
+ raise Exception(f"MAX_PREFILL_BATCH ({self.max_prefill_batch}) must be greater than 0")
+ self.disable_streaming = int(os.getenv("DISABLE_STREAMING", 0))
+
+ # maximum number of tasks that can be cached
+ self.max_cached_task_num = int(os.getenv("MAX_CACHED_TASK_NUM", "128"))
+ # if PUSH_MODE_HTTP_PORT is not configured, only the gRPC service mode is supported
+ self.push_mode_http_port = int(os.getenv("PUSH_MODE_HTTP_PORT", "-1"))
+ if self.push_mode_http_port > 0:
+ grpc_port = os.getenv("GRPC_PORT", None)
+ if grpc_port is None:
+ raise Exception("GRPC_PORT cannot be None, while PUSH_MODE_HTTP_PORT>0")
+ self.grpc_port = int(grpc_port)
+
+ # number of workers for the HTTP service
+ self.push_mode_http_workers = int(os.getenv("PUSH_MODE_HTTP_WORKERS", "1"))
+ if self.push_mode_http_workers < 1:
+ raise Exception(f"PUSH_MODE_HTTP_WORKERS ({self.push_mode_http_workers}) must be positive")
+
+ # record the Paddle commit id for version comparison
+ import paddle
+ self.paddle_commit_id = paddle.version.commit
+
+ # interval at which the health check verifies that the engine main loop is alive
+ self.check_health_interval = int(os.getenv("CHECK_HEALTH_INTERVAL", 10))
+
+ # model-related settings (must stay consistent with the exported model, otherwise output quality may suffer)
+ self.dtype = env.get("DTYPE", "bfloat16")
+ self.block_size = int(env.get("BLOCK_SIZE", 64))
+ self.use_cache_kv_int8 = int(os.getenv("USE_CACHE_KV_INT8", 0))
+ self.use_cache_kv_int4 = int(os.getenv("USE_CACHE_KV_INT4", 0))
+
+ # inference engine configuration
+ self.max_batch_size = int(env.get("BATCH_SIZE", 50))
+ self.max_seq_len = int(env.get("MAX_SEQ_LEN", 8192))
+ self.max_dec_len = int(env.get("MAX_DEC_LEN", 1024))
+ self.enc_dec_block_num = int(os.getenv("ENC_DEC_BLOCK_NUM", 2))
+ self.block_bs = float(env.get("BLOCK_BS", 50))
+ self.block_ratio = float(os.getenv("BLOCK_RATIO", 0.75))
+ self.bad_tokens = str(env.get("BAD_TOKENS", "-1"))
+ self.first_token_id = int(os.getenv("FIRST_TOKEN_ID", 1))
+
+ # port of the engine input queue
+ self.infer_port = int(os.getenv("INFER_QUEUE_PORT", 56666))
+
+ # whether to enable the custom health checker
+ self.use_custom_health_checker = int(os.getenv("USE_CUSTOM_HEALTH_CHECKER", 1))
+
+ # MAX_SEQ_LEN and MAX_DEC_LEN from the environment are used to validate incoming requests
+ self.seq_len_limit = int(env.get("MAX_SEQ_LEN", 7168))
+ self.dec_len_limit = int(env.get("MAX_DEC_LEN", 1024))
+
+ # warmup
+ self.use_warmup = int(os.getenv("USE_WARMUP", 0)) == 1
+
+ # uuid
+ self.shm_uuid = os.getenv("SHM_UUID", '')
+
+ # load the generation config
+ try:
+ self.generation_config = GenerationConfig.from_pretrained(self.model_dir)
+ except:
+ model_server_logger.warning(
+ "Can't find generation config, so it will not use generation_config field in the model config"
+ )
+ self.generation_config = None
+
+ self.read_from_config()
+ self.postprocess()
+ self.check()
+
+ def postprocess(self):
+ """
+ Compute derived parameters from the configuration
+ """
+ if self.block_ratio >= 1.0:
+ self.enc_dec_block_num = (self.max_dec_len + self.block_size - 1) // self.block_size
+ self.max_query_block_num = (self.max_dec_len + self.max_seq_len +
+ self.block_size - 1) // self.block_size
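+ # derived quantities: dec_token_num is the number of decode tokens reserved per request,
+ # total_block_num is the number of KV-cache blocks pre-allocated at startup, and
+ # max_block_num is the block_ratio share of them handed out by the scheduler up front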
+ self.dec_token_num = self.enc_dec_block_num * self.block_size
+ self.total_block_num = int(self.block_bs * self.max_query_block_num)
+ self.max_block_num = int(self.total_block_num * self.block_ratio)
+ model_server_logger.info(f"max_block_num:{self.max_block_num}")
+
+ def check(self):
+ """
+ Validate the configuration
+ """
+ assert self.max_batch_size <= 256, (
+ "The parameter `max_batch_size` is not allowed to exceed 256, "
+ "but now it's {}.".format(self.max_batch_size)
+ )
+ assert self.seq_len_limit <= self.max_seq_len, (
+ f"The seq_len_limit shouldn't greater than max_seq_len in model, "
+ f"which means the exported MAX_SEQ_LEN should less than "
+ f"{self.max_seq_len}, but now it's {self.seq_len_limit}."
+ )
+ assert self.dec_len_limit <= self.max_seq_len, (
+ f"The dec_len_limit shouldn't greater than max_seq_len in model, "
+ f"which means the exported MAX_DEC_LEN should less than "
+ f"{self.max_seq_len}, but now it's {self.dec_len_limit}."
+ )
+
+ def print(self, file=None):
+ """
+ Print all configuration parameters
+
+ file: if a file path is given, the configuration is also appended to that file,
+ so that startup information can be traced even after the 7-day log retention window
+ """
+ model_server_logger.info(
+ "=================== Configuration Information ===============")
+ for k, v in self.__dict__.items():
+ if k == "generation_config" and v is not None:
+ for gck, gcv in v.to_dict().items():
+ model_server_logger.info("{:<20}:{:<6}{}".format(gck, "", gcv))
+ else:
+ model_server_logger.info("{:<20}:{:<6}{}".format(k, "", v))
+ model_server_logger.info(
+ "=============================================================")
+ if file is not None:
+ f = open(file, "a")
+ now_time = datetime.now()
+ f.write(f"{now_time} configuration information as below,\n")
+ for k, v in self.__dict__.items():
+ f.write("{:<20}:{:<6}{}\n".format(k, "", v))
+ f.close()
+
+ def get_model_config(self):
+ """
+ Load the model configuration file
+ """
+ model_config_json = json.load(open(self.model_config_path, 'r', encoding='utf-8'))
+ return model_config_json
+
+ def read_from_config(self):
+ """
+ Read parameters from the model configuration file
+ """
+ from server.utils import get_logger
+ logger = get_logger("model_server", "infer_config.log")
+ config = self.get_model_config()
+
+ def reset_value(self, value_name, key, config):
+ if key in config:
+ value = config[key]
+ setattr(self, value_name, value)
+ logger.info(f"Reset parameter {value_name} = {value} from configuration.")
+
+ reset_value(self, "block_size", "infer_model_block_size", config)
+ reset_value(self, "max_seq_len", "infer_model_max_seq_len", config)
+
+ assert self.seq_len_limit <= self.max_seq_len, f"The loading model requires len(input_ids) <= {self.max_seq_len}, but now the setting MAX_SEQ_LEN={self.seq_len_limit}."
+ assert self.dec_len_limit <= self.max_seq_len, f"The loading model requires MAX_DEC_LEN <= {self.max_seq_len}, but now the setting MAX_DEC_LEN={self.dec_len_limit}."
+
+ def get_unique_name(self, name):
+ return name + f"_{self.shm_uuid}"
+
+ def __str__(self) -> str:
+ return json.dumps(self.__dict__, indent=4)
diff --git a/llm/server/server/engine/engine.py b/llm/server/server/engine/engine.py
new file mode 100644
index 0000000000..57fdde1d3e
--- /dev/null
+++ b/llm/server/server/engine/engine.py
@@ -0,0 +1,365 @@
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import signal
+import subprocess
+import time
+import uuid
+import weakref
+import multiprocessing
+import numpy as np
+from datetime import datetime
+from multiprocessing import shared_memory
+
+from server.engine.task_queue_manager import (
+ TaskQueueManager,
+ launch_queue_service,
+)
+from server.engine.resource_manager import ResourceManager
+from server.engine.token_processor import TokenProcessor, WarmUpTokenProcessor
+from server.utils import model_server_logger
+
+
+class Engine(object):
+ """
+ Low-level inference engine; maintains the queues used by the engine
+ """
+ def __init__(self, cfg, token_processor):
+ self.cfg = cfg
+ self.resource_manager = ResourceManager(self.cfg)
+ self.token_processor = token_processor
+ self.token_processor.set_resource_manager(self.resource_manager)
+ self.is_started = False
+
+ self._init_engine_flags()
+ # optional: with this finalizer registered, the queue process and the
+ # inference (infer) process are terminated automatically when the engine is destroyed
+ self._finalizer = weakref.finalize(self, self._exit_sub_services)
+
+ def start(self):
+ """
+ Start the processes required by the engine
+ """
+ assert not self.is_started, "The engine is already started!"
+ start_time = time.time()
+ # start the queue service process (communication between the serving layer and the engine layer)
+ self.queue_service = self._start_tasks_queue_service()
+ self.tasks_queue = TaskQueueManager(mp_num=self.cfg.mp_num, port=self.cfg.infer_port)
+
+ # BeamSearch post-processing relies on the queue to communicate with infer.py,
+ # so the tasks_queue is shared with the TokenProcessor here
+ self.token_processor.tasks_queue = self.tasks_queue
+
+ self.infer_proc = self._start_infer_service()
+ model_server_logger.info("Waitting infer processes ready...")
+ while not self._infer_processes_ready():
+ time.sleep(1)
+ self.is_started = True
+
+ # run warmup
+ if self.cfg.use_warmup:
+ model_server_logger.info("Start warmup")
+ self._set_warmup_token_processor()
+ self.warmup()
+ self._del_warmup_token_processor()
+ model_server_logger.info("Warmup finish")
+
+ # start the TokenProcessor worker thread
+ self.token_processor.run()
+ model_server_logger.info("Infer processes are launched with {} seconds.".format(time.time() - start_time))
+
+ def warmup(self):
+ """
+ Run inference on synthetic tasks to make sure the engine can serve requests without running out of memory (OOM)
+ """
+ # get the eos_token_id list
+ from server.data.processor import DataProcessor
+ eos_token_ids = DataProcessor().get_eos_tokens()
+
+ # build the warmup tasks
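+ # two groups of warmup tasks: decode-heavy ones (minimal input, generation up to dec_len_limit)
+ # and prefill-heavy ones (input of seq_len_limit tokens, a single output token)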
+ res_task = []
+ for j in range(2 * self.cfg.max_batch_size):
+ data = {
+ "input_ids": [5],
+ "req_id": j,
+ "max_dec_len": self.cfg.dec_len_limit,
+ "min_dec_len": int(self.cfg.dec_len_limit * 0.5) + 1,
+ "eos_token_ids": eos_token_ids
+ }
+ res_task.append(data)
+ for j in range(2 * self.cfg.max_prefill_batch):
+ data = {
+ "input_ids": [5] * self.cfg.seq_len_limit,
+ "req_id": j + 2 * self.cfg.max_batch_size,
+ "max_dec_len": 1,
+ "min_dec_len": 1,
+ "eos_token_ids": eos_token_ids
+ }
+ res_task.append(data)
+
+ # insert the tasks
+ for x in res_task:
+ while self.available_batch() == 0 or not self.insert_tasks([x]):
+ time.sleep(0.0002)
+
+ self.token_processor._is_blocking = False
+ # wait until all tasks have finished
+ while not self.all_tasks_finished():
+ time.sleep(1)
+
+ def insert_tasks(self, tasks):
+ """
+ Insert tasks into the engine queue
+ """
+ if not isinstance(tasks, list):
+ tasks = [tasks]
+
+ for item in tasks:
+ item["schedule_start_time"] = datetime.now()
+
+ available_batch = np.sum(self.resource_manager.stop_flags)
+ if len(tasks) > available_batch:
+ model_server_logger.error("Inserting batch:{} exceeds the available batch:{}.".format(
+ len(tasks), available_batch))
+ model_server_logger.error("The exceeded part will be ignored!")
+ tasks = tasks[:available_batch]
+
+ for i in range(len(tasks)):
+ req_id = tasks[i]["req_id"]
+ input_token_num = len(tasks[i]["input_ids"])
+ if input_token_num >= self.cfg.max_seq_len - 1:
+ model_server_logger.warning(f"{req_id}: Input length:{input_token_num}, exceed the limits.")
+ tasks[i]["input_ids"] = tasks[i]["input_ids"][:self.cfg.max_seq_len - 1]
+ if "seq_len" in tasks[i] and "max_dec_len" not in tasks[i]:
+ tasks[i]["max_dec_len"] = tasks[i]["seq_len"]
+ # max_dec_len + input_token_num > MAX_SEQ_LEN
+ if input_token_num + tasks[i]["max_dec_len"] > self.cfg.max_seq_len:
+ tasks[i]["max_dec_len"] = self.cfg.max_seq_len - input_token_num
+ model_server_logger.warning("Force max_dec_len to be {} for req_id={}.".format(
+ tasks[i]["max_dec_len"], tasks[i]["req_id"]))
+ # min_dec_len + input_token_num > MAX_SEQ_LEN
+ if input_token_num + tasks[i]["min_dec_len"] > self.cfg.max_seq_len:
+ tasks[i]["min_dec_len"] = self.cfg.max_seq_len - input_token_num
+ model_server_logger.warning("Force min_dec_len to be {} for req_id={}.".format(
+ tasks[i]["min_dec_len"], tasks[i]["req_id"]))
+
+ tasks = self.resource_manager.allocate_resources_for_new_tasks(tasks)
+ if not tasks:
+ return False
+
+ self.token_processor.number_of_tasks += len(tasks)
+ for i in range(len(tasks)):
+ self.token_processor.number_of_input_tokens += len(tasks[i]["input_ids"])
+
+ req_ids = [t["req_id"] for t in tasks]
+ model_server_logger.info(f"Tasks are sent to engine, req_ids={req_ids}")
+ self.tasks_queue.put((tasks, self.resource_manager.real_bsz))
+ return True
+
+ def task_is_finished(self, index):
+ """
+ Check whether the task at the given index has finished
+ """
+ assert index < len(self.resource_manager.stop_flags)
+ return self.resource_manager.stop_flags[index]
+
+ def is_queue_empty(self):
+ """
+ Check whether the engine queue is empty
+ """
+ return self.tasks_queue.empty()
+
+ def is_resource_sufficient(self, input_token_num):
+ """
+ Check whether the engine has sufficient resources for the given number of input tokens
+ """
+ return self.resource_manager.is_resource_sufficient(input_token_num)
+
+ def all_tasks_finished(self):
+ """
+ Check whether all tasks currently being processed by the engine have finished
+ """
+ return np.sum(self.resource_manager.stop_flags) == len(self.resource_manager.stop_flags)
+
+ def available_batch(self):
+ """
+ Maximum batch currently available in the engine
+ """
+ return self.resource_manager.available_batch()
+
+ def available_block_num(self):
+ """
+ Number of blocks currently available in the engine
+ """
+ return self.resource_manager.availabel_block_num()
+
+ def _set_warmup_token_processor(self):
+ """
+ Swap in the token_processor used during the warmup phase
+ """
+ self.token_processor_backup = self.token_processor
+ self.token_processor = WarmUpTokenProcessor(self.cfg)
+ # set the resource_manager
+ self.token_processor.set_resource_manager(self.resource_manager)
+ self.token_processor.tasks_queue = self.tasks_queue
+ # start the TokenProcessor worker thread
+ self.token_processor.run()
+
+ def _del_warmup_token_processor(self):
+ """
+ Remove the warmup token_processor and restore the one used for normal inference
+ """
+ # stop the worker thread
+ self.token_processor.stop()
+ del self.token_processor
+ # restore the original token_processor
+ self.token_processor = self.token_processor_backup
+ del self.token_processor_backup
+
+ def _infer_processes_ready(self):
+ """
+ Check whether all inference processes have finished initializing
+ """
+ return np.sum(self.flag_ready_array) == self.cfg.mp_num
+
+ def _clear_engine_flags(self):
+ """
+ Release the shared memory segments
+ """
+ try:
+ self.shm_flag_ready.close()
+ self.shm_flag_ready.unlink()
+ self.shm_flag_has_block_step.close()
+ self.shm_flag_has_block_step.unlink()
+ except:
+ pass
+
+ def _init_engine_flags(self):
+ """
+ Initialize the shared memory segments that indicate engine status
+ """
+ # flag indicating whether the inference processes are ready
+ flag_array = np.zeros([self.cfg.mp_num], dtype=np.int32)
+ try:
+ tmp = shared_memory.SharedMemory(
+ create=False, size=flag_array.nbytes, name=self.cfg.get_unique_name("shm_flag_infer_ready")
+ )
+ tmp.close()
+ tmp.unlink()
+ except:
+ pass
+ self.shm_flag_ready = shared_memory.SharedMemory(
+ create=True, size=flag_array.nbytes, name=self.cfg.get_unique_name("shm_flag_infer_ready")
+ )
+ self.flag_ready_array = np.ndarray(
+ flag_array.shape, dtype=flag_array.dtype, buffer=self.shm_flag_ready.buf
+ )
+ self.flag_ready_array[:] = 0
+
+ # flag used to broadcast that data should be read
+ broadcast_flag_array = np.zeros([1], dtype=np.int32)
+ try:
+ tmp = shared_memory.SharedMemory(
+ create=False,
+ size=broadcast_flag_array.nbytes,
+ name=self.cfg.get_unique_name("shm_pd_infer_flag_broadcast"),
+ )
+ tmp.close()
+ tmp.unlink()
+ except:
+ pass
+ self.shm_flag_broadcast = shared_memory.SharedMemory(
+ create=True, size=broadcast_flag_array.nbytes, name=self.cfg.get_unique_name("shm_pd_infer_flag_broadcast")
+ )
+ self.flag_broadcast_array = np.ndarray(
+ broadcast_flag_array.shape,
+ dtype=broadcast_flag_array.dtype,
+ buffer=self.shm_flag_broadcast.buf,
+ )
+ self.flag_broadcast_array[0] = 0
+
+ # flag indicating whether the engine has scheduled-out (blocked) queries
+ has_block_step_flag_array = np.zeros([1], dtype=np.int32)
+ try:
+ tmp = shared_memory.SharedMemory(
+ create=False,
+ size=has_block_step_flag_array.nbytes,
+ name=self.cfg.get_unique_name("shm_flag_has_block_step"))
+ tmp.close()
+ tmp.unlink()
+ except:
+ pass
+ self.shm_flag_has_block_step = shared_memory.SharedMemory(
+ create=True,
+ size=has_block_step_flag_array.nbytes,
+ name=self.cfg.get_unique_name("shm_flag_has_block_step"))
+ self.flag_has_block_step_array = np.ndarray(
+ has_block_step_flag_array.shape,
+ dtype=has_block_step_flag_array.dtype,
+ buffer=self.shm_flag_has_block_step.buf)
+ self.flag_has_block_step_array[:] = 0
+
+ def _exit_sub_services(self):
+ if hasattr(self, "queue_service") and self.queue_service is not None:
+ self.queue_service.terminate()
+ self.queue_service.join()
+ if hasattr(self, "infer_proc") and self.infer_proc is not None:
+ os.killpg(self.infer_proc.pid, signal.SIGTERM)
+
+ def _start_tasks_queue_service(self):
+ p = multiprocessing.Process(target=launch_queue_service, args=(self.cfg.infer_port, self.cfg.mp_num))
+ p.start()
+ time.sleep(0.3)
+ if p.is_alive():
+ model_server_logger.info("start tasks queue service successfully")
+ else:
+ error_msg = "Failed to start tasks queue service, please check " \
+ "the log/task_queue_manager.log for details"
+ model_server_logger.error(error_msg)
+ raise Exception(error_msg)
+ return p
+
+ def _start_gpu_infer_service(self):
+ """
+ Launch the GPU model inference process
+ """
+ current_file_path = os.path.abspath(__file__)
+ current_dir_path = os.path.split(current_file_path)[0]
+ pd_cmd = "python3 -m paddle.distributed.launch "
+ py_script = os.path.join(current_dir_path, "infer.py")
+
+ arguments = (f" --devices {self.cfg.device_ids} {py_script} --model_dir {self.cfg.model_dir}"
+ f" --max_batch_size {self.cfg.max_batch_size} --max_seq_len {self.cfg.max_seq_len}"
+ f" --max_dec_len {self.cfg.max_dec_len}"
+ f" --max_block_num {self.cfg.total_block_num} --block_size {self.cfg.block_size}"
+ f" --use_cache_kv_int8 {self.cfg.use_cache_kv_int8}"
+ f" --enc_dec_block_num {self.cfg.enc_dec_block_num}"
+ f" --block_ratio {self.cfg.block_ratio} --dtype {self.cfg.dtype}")
+ pd_cmd = pd_cmd + arguments + " >log/launch_infer.log 2>&1"
+ model_server_logger.info("Launch infer service command: {}".format(pd_cmd))
+ p = subprocess.Popen(
+ pd_cmd,
+ shell=True,
+ preexec_fn=os.setsid,
+ )
+ return p
+
+ def _start_infer_service(self):
+ """
+ Launch the model inference process
+ """
+ return self._start_gpu_infer_service()
diff --git a/llm/server/server/engine/infer.py b/llm/server/server/engine/infer.py
new file mode 100644
index 0000000000..0d1bfaa607
--- /dev/null
+++ b/llm/server/server/engine/infer.py
@@ -0,0 +1,607 @@
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import copy
+import json
+import os
+import sys
+import time
+import numpy as np
+from multiprocessing import shared_memory
+from concurrent.futures import ThreadPoolExecutor
+
+import paddle
+import paddle.distributed as dist
+import paddle.distributed.fleet as fleet
+from paddlenlp_ops import step_paddle
+from paddlenlp.utils.llm_utils import get_rotary_position_embedding
+
+from server.utils import get_logger
+from server.engine.config import Config
+from task_queue_manager import TaskQueueManager
+from server.data.processor import DataProcessor
+
+File_Path = os.path.realpath(sys.argv[0])
+Dir_Path = os.path.dirname(File_Path)
+logger = get_logger("infer_server", "infer.log")
+
+
+class ModelRunner:
+ def __init__(self, args):
+ self.args = args
+
+ self.MAX_INFER_SEED = 9223372036854775806  # 2 ** 63 - 2
+
+ self.config = Config()
+ self.model_cfg = self.config.get_model_config()
+ self.format_print_configuration()
+
+ self.args.num_layers = self.get_value(self.model_cfg, ["num_hidden_layers", "num_layers"])
+ self.args.num_attention_heads = self.get_value(self.model_cfg, ["num_attention_heads", "n_head"])
+ self.args.hidden_size = self.model_cfg["hidden_size"]
+
+ self.nranks = dist.get_world_size()
+ self.init_dist_env()
+ self.rank = fleet.worker_index()
+
+ self.load_model_init_val()
+
+ self.share_inputs = {}
+ self.cache_kvs = {}
+ self.init_inputs()
+
+ self.infer_queue = TaskQueueManager(rank=self.rank, mp_num=self.nranks, port=self.config.infer_port)
+
+ model_rank_path = os.path.join(self.args.model_dir, f"rank_{self.rank}")
+ if not os.path.exists(model_rank_path):
+ model_rank_path = self.args.model_dir
+
+ self.infer_engine = InferenceEngine(model_dir=model_rank_path,
+ share_inputs=self.share_inputs,
+ cache_kvs=self.cache_kvs,
+ config=self.config,
+ mp_degree=self.nranks
+ )
+
+ def read_model_config(self):
+ """
+ Load the general model configuration file
+ """
+ model_config_json = json.load(open(self.config_file, 'r', encoding='utf-8'))
+ return model_config_json
+
+ def get_value(self, cfg, names):
+ if not isinstance(names, list):
+ names = [names]
+ for name in names:
+ if name in cfg:
+ return cfg[name]
+ raise Exception(
+ "Cannot find any of the keys {} in the configuration file.".format(
+ names))
+
+ def format_print_configuration(self):
+ """
+ Print configuration information
+ """
+ logger.info("=============== Model Information ==============")
+ for k, v in self.model_cfg.items():
+ logger.info("{:<20}:{:<6}{}".format(k, "", v))
+ logger.info("=============== Service Configuration ===============")
+ for k, v in vars(self.args).items():
+ logger.info("{:<20}:{:<6}{}".format(k, "", v))
+ logger.info("=====================================================\n")
+
+ def load_model_init_val(self):
+ self.top_p = self.model_cfg.get("top_p", 0.0)
+ self.temperature = self.model_cfg.get("temperature", 1.0)
+ self.rope_theta = self.model_cfg.get('rope_theta', 10000.0)
+ self.rope_scaling = self.model_cfg.get('rope_scaling', None)
+ self.penalty_score = self.model_cfg.get('penalty_score', 1.0)
+ self.frequency_score = self.model_cfg.get('frequency_score', 0.0)
+ self.presence_score = self.model_cfg.get('presence_score', 0.0)
+ self.min_length = self.model_cfg.get('min_length', 1)
+ self.max_length = self.model_cfg.get('max_length', 1024)
+
+ data_processor = DataProcessor()
+ # reserve room for one extra user-configured eos_token
+ self.eos_tokens_lens = data_processor.get_eos_tokens_lens() + 1
+ self.pad_token_id = data_processor.get_pad_id()
+
+ def init_dist_env(self, seed=20):
+ """
+ Initialize the distributed environment
+ """
+ # start to init distributed env
+ strategy = fleet.DistributedStrategy()
+
+ strategy.hybrid_configs = {
+ "dp_degree": 1,
+ "mp_degree": self.nranks,
+ "pp_degree": 1,
+ "sharding_degree": 1,
+ }
+
+ # Set control in tensor parallel
+ strategy.tensor_parallel_configs = {"tensor_init_seed": seed}
+ fleet.init(is_collective=True, strategy=strategy)
+
+ def init_inputs(self):
+ # initialize inputs; all inputs are shared with the engine (no copy)
+ if "num_key_value_heads" in self.model_cfg and \
+ self.model_cfg["num_key_value_heads"] is not None and \
+ int(self.model_cfg["num_key_value_heads"]) > 0:
+ kv_num_head = int(self.model_cfg["num_key_value_heads"]) // self.nranks
+ else:
+ kv_num_head = self.args.num_attention_heads // self.nranks
+
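+ # per-layer KV caches with shape [max_block_num, kv_num_head, block_size, head_dim];
+ # stored as uint8 when cache KV int8 quantization is enabled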
+ for i in range(self.args.num_layers):
+ if not self.args.use_cache_kv_int8:
+ cache_type = self.args.dtype
+ else:
+ cache_type = "uint8"
+
+ self.cache_kvs["key_caches_{}".format(i)] = paddle.full(shape=[
+ self.args.max_block_num, kv_num_head,
+ self.args.block_size, self.args.hidden_size // self.args.num_attention_heads
+ ], fill_value=0, dtype=cache_type)
+ self.cache_kvs["value_caches_{}".format(i)] = paddle.full(shape=[
+ self.args.max_block_num, kv_num_head,
+ self.args.block_size, self.args.hidden_size // self.args.num_attention_heads
+ ], fill_value=0, dtype=cache_type)
+
+ pre_max_block_num = (self.args.max_seq_len + self.args.block_size - 1) // self.args.block_size + self.args.enc_dec_block_num
+ self.share_inputs["block_tables"] = paddle.full(
+ shape=[self.args.max_batch_size, pre_max_block_num],
+ fill_value=-1,
+ dtype="int32")
+
+ self.share_inputs['pre_ids'] = paddle.to_tensor(
+ np.full((self.args.max_batch_size, self.args.max_dec_len), -1, dtype='int64'))
+
+ tmp_position_ids = paddle.arange(self.args.max_seq_len).reshape((1, -1))
+ self.share_inputs['rope_emb'] = get_rotary_position_embedding(tmp_position_ids,
+ self.args.hidden_size // self.args.num_attention_heads, self.rope_theta, self.rope_scaling)
+ self.share_inputs['input_ids'] = paddle.full(
+ shape=[self.args.max_batch_size, self.args.max_seq_len],
+ fill_value=self.pad_token_id,
+ dtype='int64')
+ self.share_inputs['top_p'] = paddle.full(shape=[self.args.max_batch_size, 1],
+ fill_value=self.top_p,
+ dtype="float32")
+ self.share_inputs['temperature'] = paddle.full(shape=[self.args.max_batch_size, 1],
+ fill_value=self.temperature,
+ dtype="float32")
+ self.share_inputs['eos_token_id'] = paddle.to_tensor(
+ np.zeros((self.eos_tokens_lens, 1)).reshape(-1, 1).astype("int64"))
+ self.share_inputs['penalty_score'] = paddle.full(shape=[self.args.max_batch_size, 1],
+ fill_value=self.penalty_score,
+ dtype="float32")
+ self.share_inputs['frequency_score'] = paddle.full(shape=[self.args.max_batch_size, 1],
+ fill_value=self.frequency_score,
+ dtype="float32")
+ self.share_inputs['presence_score'] = paddle.full(shape=[self.args.max_batch_size, 1],
+ fill_value=self.presence_score,
+ dtype="float32")
+ self.share_inputs['seq_lens_this_time'] = paddle.full(
+ shape=[self.args.max_batch_size, 1], fill_value=0, dtype="int32")
+ self.share_inputs['seq_lens_encoder'] = paddle.full(shape=[self.args.max_batch_size, 1],
+ fill_value=0,
+ dtype="int32")
+ self.share_inputs['step_seq_lens_encoder'] = paddle.full(
+ shape=[self.args.max_batch_size, 1], fill_value=0, dtype="int32")
+ self.share_inputs['seq_lens_decoder'] = paddle.full(shape=[self.args.max_batch_size, 1],
+ fill_value=0,
+ dtype="int32")
+ self.share_inputs['step_idx'] = paddle.full(shape=[self.args.max_batch_size, 1],
+ fill_value=0,
+ dtype="int64")
+ self.share_inputs['min_length'] = paddle.full(shape=[self.args.max_batch_size, 1],
+ fill_value=self.min_length,
+ dtype="int64")
+ self.share_inputs['max_length'] = paddle.full(shape=[self.args.max_batch_size, 1],
+ fill_value=self.max_length,
+ dtype="int64")
+ self.share_inputs['not_need_stop'] = paddle.full(shape=[1],
+ fill_value=False,
+ dtype="bool")
+ self.share_inputs['stop_flags'] = paddle.full(shape=[self.args.max_batch_size, 1],
+ fill_value=True,
+ dtype="bool")
+ self.share_inputs['stop_nums'] = paddle.full(shape=[1],
+ fill_value=self.args.max_batch_size,
+ dtype="int64")
+ self.share_inputs['bad_tokens'] = paddle.full(shape=[1],
+ fill_value=-1,
+ dtype="int64")
+ self.share_inputs['next_tokens'] = paddle.full(shape=[self.args.max_batch_size, 1],
+ fill_value=-1,
+ dtype="int64")
+ self.share_inputs['is_block_step'] = paddle.full(shape=[self.args.max_batch_size],
+ fill_value=False,
+ dtype="bool")
+ self.share_inputs['encoder_block_lens'] = paddle.full(shape=[self.args.max_batch_size],
+ fill_value=0,
+ dtype="int32")
+ self.share_inputs['step_block_list'] = paddle.full(shape=[self.args.max_batch_size],
+ fill_value=-1,
+ dtype="int32")
+ self.share_inputs['step_lens'] = paddle.full(shape=[1], fill_value=0, dtype="int32")
+ self.share_inputs['recover_block_list'] = paddle.full(shape=[self.args.max_batch_size],
+ fill_value=-1,
+ dtype="int32")
+ self.share_inputs['recover_lens'] = paddle.full(shape=[1],
+ fill_value=0,
+ dtype="int32")
+ self.share_inputs['need_block_list'] = paddle.full(shape=[self.args.max_batch_size],
+ fill_value=-1,
+ dtype="int32")
+ self.share_inputs['need_block_len'] = paddle.full(shape=[1],
+ fill_value=0,
+ dtype="int32")
+ self.share_inputs['used_list_len'] = paddle.full(shape=[self.args.max_batch_size],
+ fill_value=0,
+ dtype="int32")
+ self.share_inputs['infer_seed'] = paddle.full(shape=[self.args.max_batch_size, 1],
+ fill_value=0,
+ dtype="int64")
+ free_list = list(range(int(self.args.max_block_num * self.args.block_ratio)))
+ self.free_list_len = len(free_list)
+ self.share_inputs['free_list'] = paddle.to_tensor(free_list, dtype="int32")
+ self.share_inputs['free_list_len'] = paddle.full(shape=[1],
+ fill_value=self.free_list_len,
+ dtype="int32")
+
+ def dy_input_preprocess(self, tasks):
+ """
+ Extra preprocessing when tasks are dynamically inserted
+ """
+ for i in range(len(tasks)):
+ task = tasks[i]
+ idx = task['idx']
+ length = len(task['input_ids'])
+ self.share_inputs['input_ids'][idx:idx + 1, :length] = np.array(task['input_ids'])
+ if len(task['eos_token_ids']) < self.eos_tokens_lens:
+ task['eos_token_ids'].append(task['eos_token_ids'][0])
+ self.share_inputs['eos_token_id'][:] = np.array(task['eos_token_ids'], dtype="int64").reshape(-1, 1)
+ self.share_inputs['pre_ids'][idx:idx + 1] = -1
+ self.share_inputs['top_p'][idx:idx + 1] = task.get('topp', 0.7)
+ self.share_inputs['temperature'][idx:idx + 1] = task.get('temperature', 0.95)
+ self.share_inputs['penalty_score'][idx:idx + 1] = task.get('penalty_score', 1.0)
+ self.share_inputs['frequency_score'][idx:idx + 1] = task.get('frequency_score', 0.0)
+ self.share_inputs['presence_score'][idx:idx + 1] = task.get('presence_score', 0.0)
+ self.share_inputs['seq_lens_this_time'][idx:idx + 1] = length
+ self.share_inputs['step_seq_lens_encoder'][idx:idx + 1] = length
+ self.share_inputs['seq_lens_encoder'][idx:idx + 1] = length
+ self.share_inputs['seq_lens_decoder'][idx:idx + 1] = 0
+ self.share_inputs['step_idx'][idx:idx + 1] = 0
+ self.share_inputs['min_length'][idx:idx + 1] = task.get('min_dec_len', 1)
+ if "max_dec_len" in task:
+ max_dec_len = task['max_dec_len']
+ elif "seq_len" in task:
+ max_dec_len = task['seq_len']
+ else:
+ max_dec_len = self.args.max_dec_len
+ self.share_inputs['max_length'][idx:idx + 1] = max_dec_len
+ self.share_inputs['stop_flags'][idx:idx + 1] = False
+
+ if "infer_seed" in task:
+ self.share_inputs['infer_seed'][idx:idx + 1] = task['infer_seed']
+
+ encoder_block_num = len(task['block_tables'])
+ self.share_inputs['encoder_block_lens'][idx:idx + 1] = encoder_block_num
+ self.share_inputs["block_tables"][idx:idx + 1, :] = -1
+ self.share_inputs["block_tables"][idx:idx + 1, :encoder_block_num] = np.array(
+ task['block_tables'], dtype="int32")
+
+ def step_cuda(self, seq_lens_this_time):
+ """
+ Block scheduling
+ """
+ step_paddle(self.share_inputs['stop_flags'], seq_lens_this_time,
+ self.share_inputs['step_seq_lens_encoder'],
+ self.share_inputs['seq_lens_encoder'],
+ self.share_inputs['seq_lens_decoder'], self.share_inputs["block_tables"],
+ self.share_inputs['encoder_block_lens'],
+ self.share_inputs["is_block_step"], self.share_inputs['step_block_list'],
+ self.share_inputs['step_lens'], self.share_inputs['recover_block_list'],
+ self.share_inputs['recover_lens'], self.share_inputs['need_block_list'],
+ self.share_inputs['need_block_len'], self.share_inputs['used_list_len'],
+ self.share_inputs['free_list'], self.share_inputs['free_list_len'],
+ self.share_inputs['input_ids'], self.share_inputs['pre_ids'],
+ self.share_inputs['step_idx'], self.share_inputs['next_tokens'],
+ self.args.block_size, self.args.enc_dec_block_num, self.args.first_token_id)
+
+ def initialize_engine_ready_check_flag(self):
+ """
+ Initialize the shared-memory flag indicating that the engine is ready
+ """
+ engine_ready_check_flag = np.zeros([1], dtype=np.int32)
+ shm_engine_ready_check_flag = shared_memory.SharedMemory(
+ name=self.config.get_unique_name("engine_ready_check_flag"))
+ engine_ready_check_flag_array = np.ndarray(engine_ready_check_flag.shape,
+ dtype=engine_ready_check_flag.dtype,
+ buffer=shm_engine_ready_check_flag.buf)
+ return shm_engine_ready_check_flag, engine_ready_check_flag_array
+
+ def initialize_engine_live_flag(self):
+ """
+ Create the shared-memory flag that indicates this infer engine process is alive
+ """
+ infer_live_flag_shm = shared_memory.SharedMemory(create=True,
+ size=1,
+ name=self.config.get_unique_name("shm_flag_infer_{}_live".format(self.rank)))
+ return infer_live_flag_shm
+
+ def initialize_engine_healthy_recorded_time_flag(self):
+ """
+ Initialize the shared-memory timestamp that records engine health
+ """
+ engine_healthy_recorded_time = np.zeros([1], dtype=float)
+ shm_engine_healthy_recorded_time = shared_memory.SharedMemory(
+ name=self.config.get_unique_name("engine_healthy_recorded_time"))
+ engine_healthy_recorded_time_array = np.ndarray(engine_healthy_recorded_time.shape,
+ dtype=engine_healthy_recorded_time.dtype,
+ buffer=shm_engine_healthy_recorded_time.buf)
+ return shm_engine_healthy_recorded_time, engine_healthy_recorded_time_array
+
+ def run(self):
+ # shared memory setup #
+ flag_array = np.zeros([1], dtype=np.int32)
+ shm_flag_broadcast = shared_memory.SharedMemory(
+ name=self.config.get_unique_name("shm_pd_infer_flag_broadcast"))
+ flag_broadcast_array = np.ndarray(flag_array.shape,
+ dtype=flag_array.dtype,
+ buffer=shm_flag_broadcast.buf)
+
+ flag_array = np.zeros([self.nranks], dtype=np.int32)
+ shm_flag_ready = shared_memory.SharedMemory(name=self.config.get_unique_name("shm_flag_infer_ready"))
+ flag_ready_array = np.ndarray(flag_array.shape,
+ dtype=flag_array.dtype,
+ buffer=shm_flag_ready.buf)
+ flag_ready_array[self.rank] = 1  # this rank has finished initialization
+
+ flag_array = np.zeros([1], dtype=np.int32)
+ shm_flag_has_block_step = shared_memory.SharedMemory(name=self.config.get_unique_name("shm_flag_has_block_step"))
+ flag_has_block_step_array = np.ndarray(flag_array.shape,
+ dtype=flag_array.dtype,
+ buffer=shm_flag_has_block_step.buf)
+
+ use_custom_health_checker = self.config.use_custom_health_checker
+ if use_custom_health_checker:
+ shm_engine_ready_check_flag_array, engine_ready_check_flag_array = self.initialize_engine_ready_check_flag()
+ engine_ready_check_flag_array[0] = 1
+ shm_engine_healthy_recorded_time_array, engine_healthy_recorded_time_array = self.initialize_engine_healthy_recorded_time_flag()
+ engine_healthy_recorded_time_array[0] = time.time()
+ # create the shared flag indicating the infer process is alive
+ infer_live_flag_shm = self.initialize_engine_live_flag()
+
+ infer_seed_increment = paddle.full(shape=[self.args.max_batch_size, 1],
+ fill_value=4,
+ dtype="int64")
+
+ thread_executor = ThreadPoolExecutor(max_workers=1)
+ seq_lens_this_time = None
+ real_bsz = None
+
+ while 1:
+ if use_custom_health_checker:
+ engine_healthy_recorded_time_array[0] = time.time()
+
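+ # rank 0 polls the task queue and raises the shared broadcast flag;
+ # after the barrier, all ranks read the same batch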
+ if self.rank == 0:
+ # the queue is not empty, data can be fetched
+ if not self.infer_queue.empty():
+ flag_broadcast_array[0] = 1
+
+ if self.nranks > 1:
+ paddle.distributed.barrier()
+
+ if flag_broadcast_array[0] == 1:
+ logger.info(f'rank: {self.rank} start to get')
+ if seq_lens_this_time is not None:
+ self.share_inputs["seq_lens_this_time"][:real_bsz] = seq_lens_this_time
+
+ tasks, read_finish = self.infer_queue.get()
+ if read_finish:
+ flag_broadcast_array[0] = 0
+
+ req_dicts = []
+ for req_dict, bsz in tasks:
+ real_bsz = int(bsz)
+ req_dicts.extend(req_dict)
+ logger.info(
+ f'rank: {self.rank}, real_bsz: {real_bsz}, query_num: {len(req_dicts)}'
+ )
+
+ self.dy_input_preprocess(req_dicts)
+ # special handling for seq_lens
+ seq_lens_this_time = copy.deepcopy(
+ self.share_inputs['seq_lens_this_time'][:real_bsz])
+ self.infer_engine.seq_lens_handle.share_external_data(
+ seq_lens_this_time)
+ self.share_inputs['not_need_stop'][0] = True
+
+ if not self.share_inputs['not_need_stop']:
+ if self.nranks > 1:
+ paddle.distributed.barrier()
+
+ time.sleep(0.001)
+ continue
+ self.infer_engine.predictor.run()
+
+ # increment the random seed so that each step uses a different seed
+ self.share_inputs['infer_seed'].add_(infer_seed_increment)
+ self.share_inputs['infer_seed'][:] %= self.MAX_INFER_SEED
+
+ if self.free_list_len > 0:
+ self.step_cuda(seq_lens_this_time)
+
+
+class InferenceEngine(object):
+ """
+ Model Parallel Inference Engine
+
+ Args:
+ model_dir (string): root directory of inference model
+ mp_degree (int): model parallel size
+ """
+ def __init__(self, model_dir, share_inputs, cache_kvs, config, mp_degree=1):
+ """
+ Initialize the model directory and set up the multi-process environment.
+ """
+ self.config = config
+ self.model_dir = model_dir
+ self.mp_degree = mp_degree
+
+ self.share_inputs = share_inputs
+ self.cache_kvs = cache_kvs
+
+ if mp_degree == 1:
+ self.nranks = 1
+ self.rank = 0
+ else:
+ self.nranks = fleet.worker_num()
+ self.rank = fleet.worker_index()
+
+ self._init_predictor()
+ self.share_data()
+
+ def _init_predictor(self):
+ """predictor init"""
+ device_id = self.rank % 8
+ self.model_file = os.path.join(self.model_dir, f"model.pdmodel")
+ self.param_file = os.path.join(self.model_dir, f"model.pdiparams")
+ config = paddle.inference.Config(self.model_file, self.param_file)
+
+ # config.enable_memory_optim()
+ config.switch_ir_optim(False)
+ config.enable_use_gpu(100, device_id)
+
+ # distributed config
+ if self.mp_degree > 1:
+ trainer_endpoints = fleet.worker_endpoints()
+ current_endpoint = trainer_endpoints[self.rank]
+ dist_config = config.dist_config()
+ dist_config.set_ranks(self.nranks, self.rank)
+ dist_config.set_endpoints(trainer_endpoints, current_endpoint)
+ dist_config.enable_dist_model(True)
+ if self.config.distributed_config_path:
+ dist_config.set_comm_init_config(self.config.distributed_config_path)
+ else:
+ raise Exception("Please set DISTRIBUTED_CONFIG env variable.")
+ logger.warning(
+ f"Use default distributed config, please set env DISTRIBUTED_CONFIG"
+ )
+ dist_config.set_comm_init_config(
+ os.path.join(Dir_Path + "/config", "rank_mapping_mp{}.csv".format(self.nranks)))
+ # dist_config.set_comm_init_config(os.path.join(Dir_Path + "/config", "rank_mapping.csv"))
+ config.set_dist_config(dist_config)
+ self.predictor = paddle.inference.create_predictor(config)
+ self.input_names = self.predictor.get_input_names()
+ self.seq_lens_handle = self.predictor.get_input_handle('seq_lens_this_time')
+
+ def share_data(self):
+ """
+ Share input data with the predictor without copying
+ """
+ for name in self.input_names:
+ if "caches" in name:
+ input_tensor = self.predictor.get_input_handle(name)
+ input_tensor.share_external_data(self.cache_kvs[name])
+ continue
+ if "seq_lens_this_time" in name:
+ continue
+ input_tensor = self.predictor.get_input_handle(name)
+ input_tensor.share_external_data(self.share_inputs[name])
+
+ def predict(self, real_bsz):
+ """
+ predict
+ """
+ seq_lens_this_time = copy.deepcopy(
+ self.share_inputs['seq_lens_this_time'][:real_bsz])
+ self.seq_lens_handle.share_external_data(seq_lens_this_time)
+ self.share_inputs['not_need_stop'][0] = True
+ while self.share_inputs['not_need_stop']:
+ self.predictor.run()
+ self.share_inputs["seq_lens_this_time"][:real_bsz] = seq_lens_this_time
+
+
+def parse_args():
+ """
+ Parse command line arguments
+ """
+ parser = argparse.ArgumentParser("FastDeploy LLM Inference")
+ parser.add_argument('-m',
+ '--model_dir',
+ type=str,
+ default='./output',
+ help='model dir')
+ parser.add_argument('-mp',
+ '--mp_degree',
+ type=int,
+ default=1,
+ help='mp degree')
+ parser.add_argument('-mbs',
+ '--max_batch_size',
+ type=int,
+ default=34,
+ help='max batch size')
+ parser.add_argument('--max_block_num', type=int, default=2000)
+ parser.add_argument("--block_size", type=int, default=128)
+ parser.add_argument('--max_seq_len',
+ type=int,
+ default=3072,
+ help='max_seq_len')
+ parser.add_argument('--max_dec_len',
+ type=int,
+ default=1024,
+ help='max_dec_len')
+ parser.add_argument('--use_cache_kv_int8',
+ type=int,
+ default=0,
+ help='use cache kv int8')
+ parser.add_argument('--dtype',
+ type=str,
+ default="bfloat16",
+ help='input dtype')
+ parser.add_argument('--enc_dec_block_num',
+ type=int,
+ default=1,
+ help="encoder's decoder num")
+ parser.add_argument('--block_ratio',
+ type=float,
+ default=0.7,
+ help="block ratio")
+ parser.add_argument('--first_token_id',
+ type=int,
+ default=1,
+ help="first token id")
+ args = parser.parse_args()
+ return args
+
+
+def main():
+ """
+ Launch the inference engine and run prediction
+ """
+ args = parse_args()
+ model_runner = ModelRunner(args)
+ model_runner.run()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/llm/server/server/engine/resource_manager.py b/llm/server/server/engine/resource_manager.py
new file mode 100644
index 0000000000..d446e198e7
--- /dev/null
+++ b/llm/server/server/engine/resource_manager.py
@@ -0,0 +1,190 @@
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import copy
+import os
+import random
+import threading
+import time
+
+import numpy as np
+from server.utils import model_server_logger
+
+
+class ResourceManager(object):
+ """
+ Tracks and allocates engine resources
+ """
+ def __init__(self, cfg):
+ self.cfg = cfg
+ self.stop_flags = [True] * cfg.max_batch_size
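+ # free_list holds the ids of all schedulable KV-cache blocks; blocks are popped from the tail when allocated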
+ self.free_list = list(range(cfg.max_block_num - 1, -1, -1))
+ self.tasks_list = [None] * self.cfg.max_batch_size
+ # current real batch size of the engine
+ self.real_bsz = 0
+ model_server_logger.info(f"{self.info()}")
+
+ def get_required_block_number(self, input_token_num):
+ """
+ Compute how many blocks are required
+ """
+ block_num = (input_token_num + self.cfg.block_size - 1 + self.cfg.dec_token_num) // self.cfg.block_size
+ return block_num
+
+ def get_encoder_block_number(self, input_token_num):
+ """
+ Get the number of blocks required by the encoder
+ """
+ enc_block_num = (input_token_num + self.cfg.block_size - 1) // self.cfg.block_size
+ return enc_block_num
+
+ def get_decoder_block_number(self):
+ """
+ Get the number of blocks required by the decoder
+ """
+ return (self.cfg.dec_token_num + self.cfg.block_size - 1) // self.cfg.block_size
+
+ def total_block_number(self):
+ """
+ Return the number of blocks pre-allocated at service startup
+ """
+ return self.cfg.max_block_num
+
+ def _get_block_tables(self, input_token_num, required_type="all"):
+ """
+ Allocate GPU memory blocks
+ """
+ if required_type == "all":
+ block_num = self.get_required_block_number(input_token_num)
+ elif required_type == "encoder":
+ block_num = self.get_encoder_block_number(input_token_num)
+ elif required_type == "decoder":
+ block_num = self.get_decoder_block_number()
+ else:
+ raise ValueError('unknown required type')
+ block_num = min(block_num, self.cfg.max_query_block_num)
+ block_list = list()
+ if block_num > len(self.free_list):
+ model_server_logger.error("block_num:{0} > free_list len:{1}".format(block_num, len(self.free_list)))
+ return block_list
+ for _ in range(block_num):
+ used_block_id = self.free_list.pop()
+ block_list.append(used_block_id)
+ model_server_logger.info(f"dispatch {len(block_list)} blocks.")
+ return block_list
+
+ def _recycle_block_tables(self, block_tables):
+ """
+ Recycle GPU memory blocks
+ """
+ ori_number = len(self.free_list)
+ self.free_list.extend(block_tables)
+ # self.free_list = list(set(self.free_list + block_tables))
+ cur_number = len(self.free_list)
+ model_server_logger.info(f"recycle {cur_number - ori_number} blocks.")
+
+ def available_batch(self):
+ """
+ Maximum batch currently available in the engine
+ """
+ return np.sum(self.stop_flags)
+
+ def availabel_block_num(self):
+ """
+ Number of blocks currently available in the engine
+ """
+ return len(self.free_list)
+
+ def is_resource_sufficient(self, input_token_num):
+ """
+ Check whether the currently available resources can satisfy the new request
+ """
+ if self.available_batch() < 1:
+ return False
+ block_num = self.get_required_block_number(input_token_num)
+ if block_num > self.availabel_block_num():
+ return False
+ return True
+
+ def allocate_resources_for_new_tasks(self, tasks):
+ """
+ Allocate resources for new tasks
+ """
+
+ allocated_position = 0  # position in the batch where the new task will be inserted
+ processing_task_index = 0  # index of the task currently being processed
+ processed_tasks = list()
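+ # scan batch slots from left to right; each idle slot (stop_flag is True) receives the next pending task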
+ while allocated_position < self.cfg.max_batch_size:
+ if processing_task_index >= len(tasks):
+ break
+
+ if len(tasks[processing_task_index]["input_ids"]) > self.cfg.max_seq_len:
+ model_server_logger.error("req_id: {0} input_ids len:{1} > {2}".format(
+ tasks[
+ processing_task_index]["req_id"], len(tasks[
+ processing_task_index]["input_ids"]), self.cfg.max_seq_len
+ ))
+ processing_task_index += 1
+ continue
+
+ can_insert = False
+ while allocated_position + 1 <= self.cfg.max_batch_size:
+ if sum(self.stop_flags[allocated_position : allocated_position + 1]) == 1:
+ can_insert = True
+ break
+ allocated_position += 1
+ if can_insert:
+ if self.stop_flags[allocated_position]:
+ task = copy.deepcopy(tasks[processing_task_index])
+
+ if not isinstance(task["eos_token_ids"], list):
+ task["eos_token_ids"] = [task["eos_token_ids"]]
+
+ if "infer_seed" in task and task["infer_seed"]:
+ task["infer_seed"] = int(task["infer_seed"])
+ else:
+ task["infer_seed"] = random.randint(0, 9223372036854775807)
+ task["idx"] = allocated_position
+ task["block_tables"] = self._get_block_tables(len(task["input_ids"]))
+ if not task["block_tables"]:
+ model_server_logger.error("req_id: {0} block_tables is empty".format(task["req_id"]))
+ continue
+
+ processed_tasks.append(task)
+ self.stop_flags[allocated_position] = False
+ task["inference_start_time"] = time.time()
+ task["inference_time_cost"] = -1.0
+ task["tokens_all_num"] = int(0)
+ self.tasks_list[allocated_position] = task
+ model_server_logger.info(f"allocate req_id: {task['req_id']}, "
+ f"allocated_position:{allocated_position}, input_ids_length: {len(task['input_ids'])}")
+ allocated_position += 1
+ processing_task_index += 1
+
+ # record the real batch size while the engine is running
+ for i in range(self.cfg.max_batch_size - 1, -1, -1):
+ if not self.stop_flags[i]:
+ self.real_bsz = i + 1
+ break
+
+ model_server_logger.info("in num:{0} new task num:{1} real_bsz is:{2}".format(
+ len(tasks), len(processed_tasks), self.real_bsz))
+ model_server_logger.info(f"{self.info()}")
+ return processed_tasks
+
+ def info(self):
+ info = f"ResourceManager info, " \
+ f"total_block_number: {self.total_block_number()}, total_batch_number: {len(self.stop_flags)}, " \
+ f"availabel_block_num: {self.availabel_block_num()}, available_batch: {self.available_batch()}"
+ return info
diff --git a/llm/server/server/engine/task_queue_manager.py b/llm/server/server/engine/task_queue_manager.py
new file mode 100644
index 0000000000..678946c03c
--- /dev/null
+++ b/llm/server/server/engine/task_queue_manager.py
@@ -0,0 +1,159 @@
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import threading
+import time
+from queue import Queue
+from multiprocessing.managers import (
+ AcquirerProxy,
+ BaseManager,
+ ListProxy,
+ Value,
+ ValueProxy,
+)
+
+from server.utils import get_logger
+
+logger = get_logger("infer_server", "task_queue_manager.log")
+
+
+class QueueManager(BaseManager):
+ """
+ Base multiprocessing manager class
+ """
+
+ pass
+
+
+class TaskQueueManager(object):
+ """
+ Client-side manager for the task queue
+ """
+
+ def __init__(self, rank=0, mp_num=8, port=56666):
+ """
+ Initialize the manager and connect to the queue service.
+ """
+ self.max_get_num = int(os.getenv("ENGINE_MAX_NEED_NUM", 0))
+ QueueManager.register('get_list')
+ QueueManager.register('get_value')
+ QueueManager.register('get_lock')
+ QueueManager.register('get_barrier1')
+ QueueManager.register('get_barrier2')
+ QueueManager.register('get_queue')
+
+ self.client_manager = QueueManager(address=('127.0.0.1', port),
+ authkey=b'infer_queue'
+ )
+ self.client_manager.connect()
+ self.list = self.client_manager.get_list()
+ self.value = self.client_manager.get_value()
+ self.lock = self.client_manager.get_lock()
+ self.barrier1 = self.client_manager.get_barrier1()
+ self.barrier2 = self.client_manager.get_barrier2()
+ self.queue = self.client_manager.get_queue()
+ self.mp_num = mp_num
+ self.rank = rank
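+ # each rank owns one bit of a shared bitmask; when the value equals total_num, every rank has consumed the current batch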
+ self.position = 1 << rank
+ self.total_num = (1 << self.mp_num) - 1
+ logger.info(f"init task queue manager success, rank: {rank}")
+
+ def empty(self):
+ """
+ Exposed to the inference side to check whether the queue is empty
+ """
+ try:
+ return len(self.list) == 0
+ except Exception as e:
+ logger.error(f"empty function meets error: {e}")
+ raise e
+
+ def put(self, item):
+ """
+ Put an item into the queue
+ """
+ self.lock.acquire()
+ if 0 < self.value.get() < self.total_num:
+ self.lock.release()
+ while 0 < self.value.get() < self.total_num:
+ time.sleep(0.001)
+ logger.info("put item to queue wait finish")
+ self.lock.acquire()
+ if self.max_get_num <= 0 and self.value.get() == self.total_num:
+ self.list[:] = []
+ self.value.set(0)
+ self.list.append(item)
+ self.lock.release()
+ logger.info("put item to queue success")
+
+ def get(self):
+ """
+ Get data from the queue
+ """
+ input_list = []
+ read_finish = False
+ self.lock.acquire()
+ if self.value.get() & self.position == 0 and len(self.list) > 0:
+ # limit the number of inputs fed into the engine; by default all pending inputs are copied into the engine and processed together
+ if self.max_get_num > 0:
+ input_list.extend(self.list[: self.max_get_num])
+ else:
+ input_list.extend(self.list[:])
+ set_value = self.value.get() | self.position
+ logger.info("rank: {0} set_value: {1}".format(self.rank, set_value))
+ if set_value >= self.total_num:
+ if self.max_get_num > 0:
+ for i in range(self.max_get_num):
+ self.list.pop(0)
+ else:
+ self.list[:] = []
+ set_value = 0
+ read_finish = True
+ self.value.set(set_value)
+ self.lock.release()
+ return input_list, read_finish
+
+
+def launch_queue_service(port, num_workers):
+ """
+ Launch the inter-process communication queue service
+
+ port: listening port
+ num_workers: number of infer processes
+ """
+ try:
+ logger.info(f"start launch queue service, port:{port}")
+ value = Value("i", 0)
+ QueueManager.register("get_value", callable=lambda: value, proxytype=ValueProxy)
+ List = list()
+ QueueManager.register("get_list", callable=lambda: List, proxytype=ListProxy)
+ lock = threading.Lock()
+ QueueManager.register('get_lock',
+ callable=lambda: lock,
+ proxytype=AcquirerProxy)
+ barrier1 = threading.Barrier(num_workers)
+ QueueManager.register('get_barrier1', callable=lambda: barrier1)
+ barrier2 = threading.Barrier(num_workers)
+ QueueManager.register('get_barrier2', callable=lambda: barrier2)
+ q = Queue()
+ QueueManager.register("get_queue", callable=lambda: q)
+ m = QueueManager(address=('127.0.0.1', port), authkey=b'infer_queue')
+ s = m.get_server()
+ logger.info("launch queue service success")
+ s.serve_forever()
+ logger.info("finish queue service")
+ except Exception as e:
+ logger.error(f"launch queue service failed, error_msg: {e}")
+ raise e
diff --git a/llm/server/server/engine/token_processor.py b/llm/server/server/engine/token_processor.py
new file mode 100644
index 0000000000..d943bb3e9a
--- /dev/null
+++ b/llm/server/server/engine/token_processor.py
@@ -0,0 +1,244 @@
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import threading
+import time
+import traceback
+import numpy as np
+
+from collections import Counter
+from datetime import datetime
+from paddlenlp_ops import get_output
+from server.utils import datetime_diff, model_server_logger, monitor_logger
+
+
+class TokenProcessor(object):
+ """
+ Continuously fetch generated tokens/scores from the underlying Paddle engine queue and process them
+ """
+ def __init__(self, cfg):
+ import paddle
+ paddle.device.set_device("cpu")
+ # service configuration
+ self.cfg = cfg
+ # engine state
+ self.resource_manager = None
+ # all tokens generated so far for each request
+ self.all_tokens = [[] for _ in range(self.cfg.max_batch_size)]
+
+ self.tokens_counter = Counter()
+ self.output_tokens = paddle.full(shape=[self.cfg.max_batch_size + 2, 1], fill_value=2, dtype="int64")
+ self.worker = None
+
+ self.record_time_interval = int(os.getenv("RECORD_TIME_INTERVAL", "600"))
+ assert self.record_time_interval < 3600, "The RECORD_TIME_INTERVAL cannot exceed 3600."
+ self.statics_start_time = time.time()
+ self.number_of_tasks = 0
+ self.number_of_input_tokens = 0
+ self.number_of_output_tokens = 0
+
+ def set_resource_manager(self, resource_manager):
+ """
+ Set the ResourceManager
+ """
+ assert self.resource_manager is None, "The resource manager is not None, cannot set again."
+ self.resource_manager = resource_manager
+
+ def run(self):
+ """
+ Start a worker thread that continuously processes generated tokens
+ """
+ assert self.resource_manager is not None, "The resource manager is None, cannot run."
+ if self.worker is not None:
+ raise Exception("Worker is already running!")
+
+ self.worker = threading.Thread(target=self.process_sampling_results, args=())
+ self.worker.daemon = True
+ self.worker.start()
+
+ def process_sampling_results(self):
+ """
+ Loop forever: fetch outputs and process them
+ """
+ while True:
+ try:
+ rank_id = 0
+ is_blocking = True
+ get_output(self.output_tokens, rank_id, is_blocking)
+
+ if self.output_tokens[0, 0] == -2:
+ continue
+ self._process_batch_output()
+ except Exception as e:
+ model_server_logger.info("while get input_data error: {0} {1}".format(e, str(traceback.format_exc())))
+
+ def postprocess(self, batch_result, exist_finished_task=False):
+ """
+ Post-process the results of a single generation step
+ """
+ result_dir = "./generate_token_results"
+ if not os.path.exists(result_dir):
+ os.makedirs(result_dir)
+ for result in batch_result:
+ result_file = os.path.join(result_dir, result["req_id"])
+ with open(result_file, "a") as f:
+ f.write("{}\n".format(result))
+
+ def _get_single_result(self, i, task_id, token_id, task):
+ """
+ Build the result for a single generation step
+ """
+ inference_time_cost = time.time() - task["inference_start_time"]
+ task["inference_time_cost"] = inference_time_cost
+ task["tokens_all_num"] = len(self.all_tokens[i])
+ task["inference_current_step_time"] = datetime.now()
+ result = {
+ "req_id": task_id,
+ "is_end": 0,
+ "token_ids": [token_id],
+ "send_idx": self.tokens_counter[task_id],
+ "inference_time_cost": inference_time_cost,
+ "infer_seed": task["infer_seed"],
+ "return_all_tokens": task.get("return_all_tokens", False),
+ }
+
+ # collect benchmark information
+ if task.get("benchmark"):
+ keys = ["preprocess_start_time", "preprocess_end_time", "schedule_start_time",
+ "inference_start_time", "inference_current_step_time"]
+ for key in keys:
+ if key in task:
+ result[key] = str(task[key])
+
+ # when an end-of-sequence token is generated, fill in extra information
+ if token_id in task["eos_token_ids"]:
+ result["is_end"] = 1
+ result["token_ids"] = []
+ result["tokens_all_num"] = len(self.all_tokens[i]) + 1
+ result["tokens_all_ids"] = self.all_tokens[i]
+
+ # full request log used for platform monitoring
+ info_dict = {}
+ info_dict["req_id"] = task["req_id"]
+ info_dict["input_token_num"] = len(task["input_ids"])
+ info_dict["output_token_num"] = len(self.all_tokens[i])
+ if hasattr(task, "preprocess_start_time") and hasattr(task, "preprocess_end_time"):
+ info_dict["preprocess_cost_time"] = datetime_diff(task["preprocess_start_time"],
+ task["preprocess_end_time"])
+ if hasattr(task, "preprocess_end_time") and hasattr(task, "schedule_start_time"):
+ info_dict["cache_waiting_cost_time"] = datetime_diff(task["preprocess_end_time"],
+ task["schedule_start_time"])
+ info_dict["inference_time_cost"] = task["inference_time_cost"]
+ info_dict["version"] = "4.6"
+ info_dict["timestamp"] = time.time()
+ monitor_logger.info(f"{info_dict}")
+
+ return result
+
+ def _recycle_resources(self, task_id, index, task):
+ """
+ Recycle resources of finished tasks
+ """
+ self.resource_manager.stop_flags[index] = True
+ self.resource_manager.tasks_list[index] = None
+ self.resource_manager._recycle_block_tables(task["block_tables"])
+ if task_id in self.tokens_counter:
+ del self.tokens_counter[task_id]
+ self.all_tokens[index] = list()
+
+ def _recycle_beam_resources(self, task_id_list, index_list, block_tables):
+ assert len(task_id_list) == len(index_list), \
+ f"{len(task_id_list)} task_id don't equal to {len(index_list)} index"
+ self.resource_manager._recycle_block_tables(block_tables)
+ for i in range(len(task_id_list)):
+ task_id = task_id_list[i]
+ index = index_list[i]
+ self.resource_manager.tasks_list[index] = None
+ self.resource_manager.stop_flags[index] = True
+ if task_id in self.tokens_counter:
+ del self.tokens_counter[task_id]
+ self.all_tokens[index] = list()
+
+ def _process_batch_output(self):
+ """
+ Process the outputs of one batch
+ """
+ tokens = self.output_tokens.numpy()
+ batch = self.output_tokens[1, 0]
+ tokens = tokens[2:batch + 2]
+
+ batch_result = list()
+ # whether this batch contains any finished tasks
+ exist_finished_task = False
+ for i in range(batch):
+ if self.resource_manager.stop_flags[i]:
+ continue
+
+ token_id = int(tokens[i, 0])
+ if token_id < 0:
+ continue
+
+ task = self.resource_manager.tasks_list[i]
+
+ task_id = task["req_id"]
+ result = self._get_single_result(i, task_id, token_id, task)
+
+ self.tokens_counter[task_id] += 1
+ if token_id not in task["eos_token_ids"]:
+ self.all_tokens[i].append(token_id)
+
+ self.number_of_output_tokens += 1
+ if token_id in task["eos_token_ids"]:
+ self._recycle_resources(task_id, i, task)
+ model_server_logger.info("req_id: {0} finished".format(task_id))
+ model_server_logger.info(f"{self.resource_manager.info()}")
+ exist_finished_task = True
+ batch_result.append(result)
+
+ self.postprocess(batch_result, exist_finished_task)
+
+
+class WarmUpTokenProcessor(TokenProcessor):
+ """
+ Token processor used by the warm-up service.
+ """
+ def __init__(self, cfg):
+ super().__init__(cfg)
+ self._is_running = True
+ self._is_blocking = True
+
+ def postprocess(self, batch_result, exist_finished_task=False):
+ pass
+
+ def process_sampling_results(self):
+ """
+ Continuously fetch outputs and process the data.
+ """
+ while self._is_running:
+ try:
+ rank_id = 0
+ get_output(self.output_tokens, rank_id, self._is_blocking)
+
+ if self.output_tokens[0, 0] == -2:
+ continue
+ self._process_batch_output()
+ except Exception as e:
+ model_server_logger.info("while get input_data error: {0} {1}".format(e, str(traceback.format_exc())))
+
+ def stop(self):
+ self._is_running = False
+ self.worker.join()
+ model_server_logger.info("warm up thread stop")
+ del self.worker
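For reference, a minimal sketch of the token buffer layout that `process_sampling_results` and `_process_batch_output` assume, based only on the indexing visible above (status flag at `[0, 0]` where `-2` means "no new data", batch size at `[1, 0]`, one token id per slot starting at row 2); the buffer size, dtype, and the `-1` padding value are illustrative assumptions:

```python
import numpy as np

# Hypothetical buffer mirroring the layout read by _process_batch_output.
MAX_BATCH = 4
output_tokens = np.full((MAX_BATCH + 2, 1), -1, dtype=np.int64)
output_tokens[0, 0] = 0     # status flag: -2 would mean "no output yet"
output_tokens[1, 0] = 2     # two active batch slots in this step
output_tokens[2, 0] = 151   # token generated for slot 0
output_tokens[3, 0] = 73    # token generated for slot 1

batch = int(output_tokens[1, 0])
tokens = output_tokens[2:batch + 2]     # same slicing as _process_batch_output
for i in range(batch):
    token_id = int(tokens[i, 0])
    if token_id < 0:                    # negative ids are skipped, as above
        continue
    print(f"slot {i} produced token {token_id}")
```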
diff --git a/llm/server/server/http_server/__init__.py b/llm/server/server/http_server/__init__.py
new file mode 100644
index 0000000000..fd05a92081
--- /dev/null
+++ b/llm/server/server/http_server/__init__.py
@@ -0,0 +1,13 @@
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/llm/server/server/http_server/api.py b/llm/server/server/http_server/api.py
new file mode 100644
index 0000000000..159a9144db
--- /dev/null
+++ b/llm/server/server/http_server/api.py
@@ -0,0 +1,163 @@
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+import queue
+import time
+import uuid
+from datetime import datetime
+from functools import partial
+from typing import Dict, List, Optional
+
+import numpy as np
+import tritonclient.grpc as grpcclient
+from pydantic import BaseModel, Field
+from tritonclient import utils as triton_utils
+
+
+class Req(BaseModel):
+ """请求参数的类"""
+ # 传入模型服务的请求参数
+ req_id: str = Field(default_factory=lambda: str(uuid.uuid4()))
+ input_ids: Optional[List[int]] = None
+ text: Optional[str] = None
+ messages: Optional[List] = None
+ max_dec_len: Optional[int] = None
+ seq_len: Optional[int] = None # kept for backward compatibility
+ min_dec_len: Optional[int] = None
+ temperature: Optional[float] = None
+ topp: Optional[float] = None
+ penalty_score: Optional[float] = None
+ frequency_score: Optional[float] = None
+ presence_score: Optional[float] = None
+ system: Optional[str] = None
+ return_all_tokens: Optional[bool] = None
+ eos_token_ids: Optional[List[int]] = None
+ benchmark: bool = False
+ # parameters used by the HTTP service
+ stream: bool = False
+ timeout: int = 300
+
+ def to_dict_for_infer(self):
+ """将请求参数转化为字典,去掉为None的字段,避免传递给模型服务出错"""
+ req_dict = {}
+ for key, value in self.dict().items():
+ if value is not None:
+ req_dict[key] = value
+ return req_dict
+
+
+def chat_completion_generator(infer_grpc_url: str, req: Req, yield_json: bool) -> Dict:
+ """
+ Generator of chat completion results backed by the Triton inference service.
+ Args:
+ infer_grpc_url (str): gRPC URL of the Triton inference service.
+ req (Req): chat completion request.
+ yield_json (bool): whether to yield JSON strings; otherwise dicts are yielded.
+ Yields:
+ dict: chat completion results.
+ On success: {'token': xxx, 'is_end': xxx, 'send_idx': xxx, ..., 'error_msg': '', 'error_code': 0}
+ On failure: {'error_msg': xxx, 'error_code': xxx}, where error_msg is non-empty and error_code is non-zero
+ """
+ class _TritonOutputData:
+ """接收Triton服务返回的数据"""
+ def __init__(self):
+ self._completed_requests = queue.Queue()
+
+ def _triton_callback(output_data, result, error):
+ """Triton客户端的回调函数"""
+ if error:
+ output_data._completed_requests.put(error)
+ else:
+ output_data._completed_requests.put(result)
+
+ def _format_resp(resp_dict):
+ if yield_json:
+ return json.dumps(resp_dict, ensure_ascii=False) + "\n"
+ else:
+ return resp_dict
+
+ # prepare request data
+ timeout = req.timeout
+ req_id = req.req_id
+ req_dict = req.to_dict_for_infer()
+ http_received_time = datetime.now()
+
+ inputs = [grpcclient.InferInput("IN", [1], triton_utils.np_to_triton_dtype(np.object_))]
+ inputs[0].set_data_from_numpy(np.array([json.dumps([req_dict])], dtype=np.object_))
+ outputs = [grpcclient.InferRequestedOutput("OUT")]
+ output_data = _TritonOutputData()
+
+ # establish the connection
+ with grpcclient.InferenceServerClient(url=infer_grpc_url, verbose=False) as triton_client:
+ triton_client.start_stream(callback=partial(_triton_callback, output_data))
+
+ # send the request
+ triton_client.async_stream_infer(model_name="model",
+ inputs=inputs,
+ request_id=req_dict['req_id'],
+ outputs=outputs)
+ # process the returned results
+ while True:
+ output_item = output_data._completed_requests.get(timeout=timeout)
+ if isinstance(output_item, triton_utils.InferenceServerException):
+ error_msg = f"status is {output_item.status()}, msg is {output_item.message()}"
+ yield _format_resp({"error_msg": error_msg, "error_code": 500})
+ break
+ else:
+ result = json.loads(output_item.as_numpy("OUT")[0])
+ result = result[0] if isinstance(result, list) else result
+ result["error_msg"] = result.get("error_msg", "")
+ result["error_code"] = result.get("error_code", 0)
+ if req.benchmark:
+ result["http_received_time"] = str(http_received_time)
+ yield _format_resp(result)
+ if (result.get("error_msg") or result.get("error_code")) or result.get("is_end") == 1:
+ break
+
+ # close the connection explicitly
+ triton_client.stop_stream()
+ triton_client.close()
+
+def chat_completion_result(infer_grpc_url: str, req: Req) -> Dict:
+ """
+ Get the non-streaming generation result.
+ Args:
+ infer_grpc_url (str): gRPC address of the inference service.
+ req (Req): request parameter object.
+ Returns:
+ dict: the chat completion result.
+ On success: {'result': xxx, 'error_msg': '', 'error_code': 0}
+ On failure: {'result': '', 'error_msg': xxx, 'error_code': xxx}, where error_msg is non-empty and error_code is non-zero
+ """
+ result = None
+ error_resp = None
+ for resp in chat_completion_generator(infer_grpc_url, req, yield_json=False):
+ if resp.get("error_msg") or resp.get("error_code"):
+ error_resp = resp
+ error_resp["result"] = ""
+ else:
+ if resp.get('is_end') == 1:
+ result = resp
+ for key in ['token', 'is_end', 'send_idx', 'return_all_tokens']:
+ if key in result:
+ del result[key]
+ if not result:
+ error_resp = {
+ "error_msg": "HTTP parsing data error",
+ "error_code": 500,
+ "result": "",
+ "is_end": 1,
+ }
+ return error_resp if error_resp else result
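A brief usage sketch for these helpers, assuming a FastDeploy Triton service is already listening on a gRPC port (`localhost:8811` below is a stand-in for `localhost:${GRPC_PORT}`); the payload keys follow the docstrings above:

```python
from server.http_server.api import Req, chat_completion_generator, chat_completion_result

# Hypothetical address; substitute the GRPC_PORT exposed by your Triton service.
INFER_GRPC_URL = "localhost:8811"

# Non-streaming: returns a single dict with "result", "error_msg" and "error_code".
req = Req(text="hello, llm", max_dec_len=64)
print(chat_completion_result(INFER_GRPC_URL, req))

# Streaming: yields one dict per generated token until is_end == 1.
for chunk in chat_completion_generator(INFER_GRPC_URL, Req(text="hello, llm", stream=True), yield_json=False):
    if chunk.get("error_code"):
        raise RuntimeError(chunk["error_msg"])
    print(chunk.get("token"), end="", flush=True)
    if chunk.get("is_end") == 1:
        break
```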
diff --git a/llm/server/server/http_server/app.py b/llm/server/server/http_server/app.py
new file mode 100644
index 0000000000..d9cc879380
--- /dev/null
+++ b/llm/server/server/http_server/app.py
@@ -0,0 +1,85 @@
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import os
+
+import uvicorn
+from fastapi import FastAPI
+from fastapi.responses import StreamingResponse
+from server.http_server.api import (
+ Req,
+ chat_completion_generator,
+ chat_completion_result,
+)
+from server.utils import http_server_logger
+
+http_server_logger.info(f"create fastapi app...")
+app = FastAPI()
+
+@app.post("/v1/chat/completions")
+def create_chat_completion(req: Req):
+ """
+ Server-side route handler.
+ Returns:
+ If stream is True, a streaming response is returned:
+ On success: {'token': xxx, 'is_end': xxx, 'send_idx': xxx, ..., 'error_msg': '', 'error_code': 0}
+ On failure: {'error_msg': xxx, 'error_code': xxx}, where error_msg is non-empty and error_code is non-zero
+ If stream is False, a non-streaming response is returned:
+ On success: {'result': xxx, 'error_msg': '', 'error_code': 0}
+ On failure: {'result': '', 'error_msg': xxx, 'error_code': xxx}, where error_msg is non-empty and error_code is non-zero
+ """
+ try:
+ http_server_logger.info(f"receive request: {req.req_id}")
+ grpc_port = int(os.getenv("GRPC_PORT", 0))
+ if grpc_port == 0:
+ return {"error_msg": f"GRPC_PORT ({grpc_port}) for infer service is invalid",
+ "error_code": 400}
+ grpc_url = f"localhost:{grpc_port}"
+
+ if req.stream:
+ generator = chat_completion_generator(infer_grpc_url=grpc_url, req=req, yield_json=True)
+ resp = StreamingResponse(generator, media_type="text/event-stream")
+ else:
+ resp = chat_completion_result(infer_grpc_url=grpc_url, req=req)
+ except Exception as e:
+ resp = {'error_msg': str(e), 'error_code': 501}
+ finally:
+ http_server_logger.info(f"finish request: {req.req_id}")
+ return resp
+
+def launch_http_server(port: int, workers: int) -> None:
+ """
+ Launch the HTTP server.
+ """
+ http_server_logger.info(f"launch http server... port: {port}, workers: {workers}")
+ try:
+ uvicorn.run(app="server.http_server.app:app",
+ host='0.0.0.0',
+ port=port,
+ workers=workers,
+ log_level="error")
+ except Exception as e:
+ http_server_logger.error(f"launch http server error, {e}")
+
+def main():
+ """main函数"""
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--port", default=9904, type=int, help="port to the http server")
+ parser.add_argument("--workers", default=1, type=int, help="set the number of workers for the http service")
+ args = parser.parse_args()
+ launch_http_server(port=args.port, workers=args.workers)
+
+if __name__ == "__main__":
+ main()
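A corresponding HTTP-side sketch, assuming the server was launched with the default `--port 9904` shown above (host and timeout values are illustrative):

```python
import json

import requests

URL = "http://127.0.0.1:9904/v1/chat/completions"   # default --port of this app

# Non-streaming request: one JSON object with "result" / "error_msg" / "error_code".
print(requests.post(URL, json={"text": "hello, llm"}, timeout=300).json())

# Streaming request: the generator above yields one JSON object per line.
with requests.post(URL, json={"text": "hello, llm", "stream": True}, stream=True, timeout=300) as resp:
    for line in resp.iter_lines():
        if line:
            print(json.loads(line))
```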
diff --git a/llm/server/server/triton_server.py b/llm/server/server/triton_server.py
new file mode 100644
index 0000000000..53f2deb5c7
--- /dev/null
+++ b/llm/server/server/triton_server.py
@@ -0,0 +1,473 @@
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import codecs
+import json
+import multiprocessing
+import os
+import queue
+import subprocess
+import sys
+import threading
+import time
+import traceback
+from collections import Counter, deque
+from datetime import datetime
+
+import numpy as np
+from server.checker import (
+ add_default_params,
+ check_basic_params,
+)
+from server.engine import engine
+from server.engine.config import Config
+from server.utils import error_logger, model_server_logger
+
+import server
+
+try:
+ import triton_python_backend_utils as pb_utils
+except ImportError:
+ model_server_logger.warning(
+ "TritonPythonModel is only available inside the Triton Inference Server framework."
+ )
+
+if sys.stdout.encoding is None:
+ enc = os.environ["LANG"].split(".")[1]
+ sys.stdout = codecs.getwriter(enc)(sys.stdout)
+
+
+class TritonConfig(Config):
+ """
+ Additional configuration for the Triton Inference Server.
+ """
+ def __init__(self, base_config):
+ super().__init__()
+ for k, v in base_config.__dict__.items():
+ setattr(self, k, v)
+
+
+class TritonTokenProcessor(engine.TokenProcessor):
+ """
+ Token processor for the Triton service.
+ """
+ def __init__(self, cfg, triton_server):
+ super().__init__(cfg)
+ self.triton_server = triton_server
+ # cached generated results
+ self.cached_generated_tokens = queue.Queue()
+ # token buffer: some special tokens are accumulated before being sent
+ self.token_buffer = dict()
+ # score buffer: some special tokens are accumulated before being sent
+ self.score_buffer = dict()
+
+ self.push_mode_sender_thread = threading.Thread(target=self._push_mode_sender_thread, args=())
+ self.push_mode_sender_thread.daemon = True
+ self.push_mode_sender_thread.start()
+
+ def _push_mode_sender_thread(self):
+ while True:
+ try:
+ batch_result = self.cached_generated_tokens.get()
+ for result in batch_result:
+ req_id = result["req_id"]
+ is_end = result.get("is_end", 0)
+ return_all_tokens = result.get("return_all_tokens", False)
+ # in non-streaming mode, only the result of the last token is returned
+ if is_end == 0 and (return_all_tokens or self.cfg.disable_streaming):
+ continue
+ if return_all_tokens and "topk_tokens" in result:
+ del result["topk_tokens"]
+ result = self.triton_server.data_processor.process_response(result)
+ model_server_logger.debug(f"Send result to client under push mode: {result}")
+ with self.triton_server.thread_lock:
+ _send_result([result], self.triton_server.response_sender[req_id], is_end)
+ if is_end == 1:
+ del self.triton_server.response_sender[req_id]
+ self.triton_server._update_metrics()
+ except Exception as e:
+ model_server_logger.error("Unexcepted error happend: {}, {}".format(e, str(traceback.format_exc())))
+
+ def postprocess(self, batch_result, exist_finished_task=False):
+ """
+ Post-process the generation result of a single step.
+ """
+ try:
+ self.cached_generated_tokens.put(batch_result)
+ except Exception as e:
+ model_server_logger.info(
+ "Unexcepted problem happend while process output token: {}, {}"
+ .format(e, str(traceback.format_exc())))
+
+
+class TritonServer(object):
+ """
+ Triton service implementation.
+ """
+
+ def initialize(self, args):
+ """
+ Initialize the Triton service.
+ """
+ # start the health check service
+ use_custom_health_checker = int(os.getenv("USE_CUSTOM_HEALTH_CHECKER", 1))
+ # The environment variable USE_CUSTOM_HEALTH_CHECKER controls whether the custom health check endpoints are used.
+ # When it is 1, tritonserver's own health service must be disabled by setting --allow-http to false.
+ # When it is 0, tritonserver's own health service must be enabled by setting --http-port=${HTTP_PORT}.
+ if use_custom_health_checker:
+ http_port = os.getenv("HTTP_PORT")
+ if http_port is None:
+ raise Exception("HTTP_PORT must be set")
+ from server.triton_server_helper import start_health_checker
+ multiprocessing.Process(target=start_health_checker, args=(int(http_port), )).start()
+ time.sleep(1) # wait 1s to make sure the required shared memory has been created
+
+ model_config = json.loads(args["model_config"])
+ using_decoupled = pb_utils.using_decoupled_model_transaction_policy(
+ model_config)
+ if not using_decoupled:
+ raise pb_utils.TritonModelException(
+ """the model `{}` can generate any number of responses per request,
+ enable decoupled transaction policy in model configuration to
+ serve this model""".format(args["model_name"]))
+
+ # register metrics; service status can be queried via METRICS_PORT
+ self.metric_family = pb_utils.MetricFamily(
+ name="inference_server_metrics",
+ description="Metrics for monitoring inference server status",
+ kind=pb_utils.MetricFamily.GAUGE,
+ )
+ self.metrics = {
+ "batch_size":
+ self.metric_family.Metric(labels={"batch_size": "batch_size"}),
+ "block_num":
+ self.metric_family.Metric(labels={"block_num": "block_num"}),
+ "max_batch_size":
+ self.metric_family.Metric(
+ labels={"max_batch_size": "max_batch_size"}),
+ "max_block_num":
+ self.metric_family.Metric(
+ labels={"max_block_num": "max_block_num"}),
+ "available_resource":
+ self.metric_family.Metric(
+ labels={"available_resource": "available_resource"}),
+ }
+
+ # variables required by the Triton service
+ # lock protecting response_sender against concurrent access from multiple threads
+ self.thread_lock = threading.Lock()
+
+ base_config = Config()
+ self.cfg = TritonConfig(base_config)
+ self.cfg.print(file="log/fastdeploy_init.info")
+
+ # initialize the underlying engine
+ self.token_processor = TritonTokenProcessor(self.cfg, self)
+ self.engine = engine.Engine(self.cfg, self.token_processor)
+ model_server_logger.info("Creat engine...")
+ self.engine.start()
+ model_server_logger.info("Create engine success")
+
+ self._initialize_push_mode()
+ model_server_logger.info("Init triton server success")
+
+
+ def execute(self, requests):
+ """
+ Main Triton entry point; handles requests received by the Triton framework.
+ """
+ if len(requests) != 1:
+ raise pb_utils.TritonModelException(
+ "Only support batch=1, but now it's {}.".format(len(requests)))
+ request = requests[0]
+ current_response_sender = request.get_response_sender()
+ request_tensor = pb_utils.get_input_tensor_by_name(request, "IN")
+ tasks = json.loads(request_tensor.as_numpy()[0])
+
+ model_server_logger.info(f"receive task: {tasks}")
+ self._process_task_push_mode(tasks, current_response_sender)
+ self._update_metrics()
+
+ def finalize(self):
+ """
+ Shut down the Triton service.
+ """
+ model_server_logger.info("Triton service will be terminated...")
+ wait_time = 300
+ while not self.engine.all_tasks_finished():
+ if wait_time <= 0:
+ model_server_logger.warning(f"Ignore the unfinished tasks, force to stop.")
+ break
+ model_server_logger.info(f"There's unfinished tasks, wait {wait_time}...")
+ wait_time -= 5
+ time.sleep(5)
+ model_server_logger.info("Terminate the engine now.")
+ self.enable_insert_task_push_mode = False
+ time.sleep(1)
+ del self.engine
+ if hasattr(self, "http_process"):
+ self.http_process.kill()
+ model_server_logger.info("Triton service is terminated!")
+
+ def _initialize_push_mode(self):
+ from server.data.processor import DataProcessor
+ self.data_processor = DataProcessor()
+ model_server_logger.info("create data processor success")
+
+ # whether HTTP protocol support is enabled
+ if self.cfg.push_mode_http_port < 0:
+ model_server_logger.info("HTTP server for push mode is disabled.")
+ else:
+ model_server_logger.info("launch http server...")
+
+ current_dir_path = os.path.split(os.path.abspath(__file__))[0]
+ http_py_file = "app.py"
+ http_py_path = os.path.join(current_dir_path, "http_server", http_py_file)
+ http_cmd = f"python3 {http_py_path} --port={self.cfg.push_mode_http_port} " \
+ f"--workers={self.cfg.push_mode_http_workers} >log/launch_http.log 2>&1"
+
+ model_server_logger.info(f"Launch HTTP server for push mode, command:{http_cmd}")
+ self.http_process = subprocess.Popen(http_cmd, shell=True, preexec_fn=os.setsid)
+ time.sleep(3)
+ exit_code = self.http_process.poll()
+ if exit_code is None:
+ http_url = f"http://127.0.0.1:{self.cfg.push_mode_http_port}/v1/chat/completions"
+ model_server_logger.info(f"Launch HTTP server for push mode success, http_url:{http_url}")
+ else:
+ error_msg = "\n Launch HTTP service for push mode failed in 3 seconds. " \
+ "Please check log/launch_http.log file \n"
+ model_server_logger.error(error_msg)
+ model_server_logger.info("init push server success")
+
+ # keep the response handle of every request
+ self.response_sender = dict()
+ # request queue: append on the left, pop from the right
+ self.cached_task_deque = deque()
+ # keep monitoring the engine and the request queue; when the engine has free resources, pull tasks from the queue and insert them into the engine
+ self.enable_insert_task_push_mode = True
+ self.insert_task_to_engine_thread = threading.Thread(
+ target=self._insert_task_push_mode, args=())
+ self.insert_task_to_engine_thread.daemon = True
+ self.insert_task_to_engine_thread.start()
+
+ def _process_task_push_mode(self, tasks, current_response_sender):
+ """
+ For push mode, validate the request and, if it passes the checks, insert it into cached_task_deque.
+ """
+ try:
+ # basic checks; return an error message immediately if any check fails
+ tik = time.time()
+ req_id = tasks[0]["req_id"]
+ cached_task_num = len(self.cached_task_deque)
+ if cached_task_num >= self.cfg.max_cached_task_num:
+ error_msg = f"cached task num ({cached_task_num}) exceeds " \
+ f"the limit ({self.cfg.max_cached_task_num})"
+ _send_error(error_msg, current_response_sender, req_id=req_id)
+ return
+
+ if not tasks or len(tasks) != 1 or not tasks[0]:
+ error_msg = f"request data should not be empty and query " \
+ f"num {len(tasks)} should be 1"
+ _send_error(error_msg, current_response_sender, req_id=req_id)
+ return
+
+ task = tasks[0]
+ task["preprocess_start_time"] = datetime.now()
+
+ error_msg = check_basic_params(task)
+ if error_msg != []:
+ _send_error(error_msg, current_response_sender, req_id=req_id)
+ return
+
+ task_id = task["req_id"]
+ with self.thread_lock:
+ if task_id in self.response_sender:
+ error_msg = f"The req_id {task_id} already exists in the current batch, " \
+ f"the current request will be ignored."
+ _send_error(error_msg, current_response_sender, req_id=req_id)
+ return
+
+ # add default parameters
+ task = add_default_params(task)
+
+ # concatenation and tokenization; truncation is enabled by default
+ if int(task.get("enable_text_truncate", 1)):
+ real_seq_len = self.cfg.max_seq_len - task.get("max_dec_len", 800)
+ task = self.data_processor.process_request(task, max_seq_len=real_seq_len)
+ else:
+ task = self.data_processor.process_request(task)
+
+ # check the input length
+ input_ids_len = len(task["input_ids"])
+ if "max_dec_len" not in task:
+ task["max_dec_len"] = min(self.cfg.max_seq_len - input_ids_len, self.cfg.dec_len_limit)
+ min_dec_len = task["min_dec_len"]
+ if input_ids_len + min_dec_len >= self.cfg.max_seq_len:
+ error_msg = f"Input text is too long, input_ids_len ({input_ids_len}) " \
+ f"+ min_dec_len ({min_dec_len}) >= max_seq_len "
+ _send_error(error_msg, current_response_sender, req_id=req_id)
+ return
+
+ if input_ids_len > self.cfg.seq_len_limit:
+ error_msg = f"Length of input token({input_ids_len}) exceeds the limit MAX_SEQ_LEN({self.cfg.seq_len_limit})."
+ _send_error(error_msg, current_response_sender, req_id=req_id)
+ return
+ if task["max_dec_len"] > self.cfg.dec_len_limit:
+ error_msg = f"The parameter max_dec_len({task['max_dec_len']}) exceeds the limit MAX_DEC_LEN({self.cfg.dec_len_limit})."
+ _send_error(error_msg, current_response_sender, req_id=req_id)
+ return
+
+ required_block_num = self.engine.resource_manager.get_required_block_number(input_ids_len)
+ if required_block_num > self.engine.resource_manager.total_block_number():
+ error_msg = f"The input task required resources is exceed the limit, task={task}."
+ _send_error(error_msg, current_response_sender, req_id=req_id)
+ return
+
+ with self.thread_lock:
+ # insert into the cache queue
+ self.response_sender[task_id] = current_response_sender
+
+ task["preprocess_end_time"] = datetime.now()
+ self.cached_task_deque.appendleft(task)
+ tok = time.time()
+ model_server_logger.info(f"cache task with req_id ({task_id}), "
+ f"cost time: {tok-tik}s, cached_task_num: {len(self.cached_task_deque)}.")
+ model_server_logger.debug(f"cache task: {task}")
+ except Exception as e:
+ error_msg = "Unexcepted promblem happend while insert new task to server task queue: {}, {}".format(
+ e, str(traceback.format_exc()))
+ _send_error(error_msg, current_response_sender)
+
+ def _insert_task_push_mode(self):
+ """
+ Thread that keeps processing cached tasks in push mode, inserting them into the engine as soon as resources are available.
+ 1. Every received request is first inserted into cached_task_deque.
+ 2. The _insert_task_push_mode thread keeps monitoring the engine.
+ 3. Once resources are available, tasks are popped from cached_task_deque and submitted to the engine.
+ """
+ try:
+ while self.enable_insert_task_push_mode:
+ if not hasattr(self, "engine") or self.engine is None:
+ time.sleep(0.1)
+ continue
+ if self.engine.available_batch() == 0:
+ time.sleep(0.001)
+ continue
+ if len(self.cached_task_deque) == 0:
+ time.sleep(0.001)
+ continue
+ if not self.engine.is_queue_empty():
+ time.sleep(0.001)
+ continue
+
+ i_bs = 0
+ for _ in range(self.cfg.max_prefill_batch):
+ if len(self.cached_task_deque) == 0:
+ break
+ if self.engine.available_batch() == 0:
+ break
+ while i_bs < self.cfg.max_batch_size:
+ if self.engine.task_is_finished(i_bs):
+ break
+ i_bs += 1
+ if i_bs >= self.cfg.max_batch_size:
+ break
+ # no lock needed here: execute() appends to cached_task_deque on the opposite end from the one read at index -1
+ input_token_num = len(self.cached_task_deque[-1]["input_ids"])
+ if not self.engine.is_resource_sufficient(input_token_num):
+ break
+ task = self.cached_task_deque.pop()
+ try:
+ self.engine.insert_tasks([task])
+ except Exception as e:
+ err_msg = "Error happend while insert task to engine: {}, {}.".format(
+ e, str(traceback.format_exc()))
+ with self.thread_lock:
+ _send_result({"error_msg": err_msg},
+ self.response_sender[task["req_id"]], 1)
+ del self.response_sender[task["req_id"]]
+ model_server_logger.info("finish insert_task_push_mode thread")
+ except Exception as e:
+ model_server_logger.error("insert_task_push_mode thread exit "
+ f"unexpectedly, {e}. {str(traceback.format_exc())}")
+
+ def _update_metrics(self):
+ """
+ Update monitoring metrics.
+ """
+ block_num = self.engine.available_block_num()
+ batch_size = self.engine.available_batch()
+ self.metrics["block_num"].set(block_num)
+ self.metrics["max_batch_size"].set(self.cfg.max_batch_size)
+ self.metrics["batch_size"].set(self.cfg.max_batch_size - batch_size)
+ self.metrics["max_block_num"].set(self.cfg.max_block_num)
+ self.metrics["available_resource"].set(block_num * 1.0 /
+ self.cfg.max_block_num)
+
+ def _get_current_server_info(self):
+ """
+ Get the current resource information of the service.
+ """
+ available_batch_size = min(self.cfg.max_prefill_batch,
+ self.engine.available_batch())
+ available_block_num = self.engine.available_block_num()
+ server_info = {
+ "block_size": int(self.cfg.block_size),
+ "block_num": int(available_block_num),
+ "dec_token_num": int(self.cfg.dec_token_num),
+ "available_resource":
+ 1.0 * available_block_num / self.cfg.max_block_num,
+ "max_batch_size": int(available_batch_size),
+ }
+ return server_info
+
+
+def _send_result(result_dict, sender, end_flag=0):
+ """
+ Send an inference result back to the client.
+
+ Args:
+ result_dict (dict): inference result, stored as a dict.
+ sender: Triton response sender of the request, used to return the result.
+ end_flag (int, optional): flag marking whether this is the final response. Defaults to 0.
+ """
+ response = None
+ if result_dict:
+ result_dict = json.dumps(result_dict)
+ end_output = pb_utils.Tensor("OUT",
+ np.array([result_dict], dtype=np.object_))
+ response = pb_utils.InferenceResponse(output_tensors=[end_output])
+ if response is None and end_flag == 0:
+ return
+ sender.send(response, flags=end_flag)
+
+def _send_error(error_msg, sender, error_code=200, req_id=None):
+ """
+ Send an error message back to the client.
+
+ Args:
+ error_msg (str): error message.
+ sender: Triton response sender used to return the error.
+ error_code (int, optional): error code. Defaults to 200.
+ req_id (str, optional): id of the failed request. Defaults to None.
+ """
+ if not isinstance(error_msg, str):
+ error_msg = str(error_msg)
+ error_info = {"req_id": req_id, "error_msg": error_msg, "error_code": error_code, "version": "4.6", "timestamp": time.time()}
+ error_logger.info(f"{error_info}")
+ model_server_logger.error(error_msg)
+ _send_result(error_info, sender, 1)
+
+
+TritonPythonModel = TritonServer
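The gauges registered in `initialize` can be scraped from Triton's Prometheus endpoint; a small sketch, assuming the metrics service listens on `METRICS_PORT` (the fallback port `8002` below is an assumption):

```python
import os

import requests

metrics_port = int(os.getenv("METRICS_PORT", "8002"))  # assumed default
text = requests.get(f"http://127.0.0.1:{metrics_port}/metrics", timeout=5).text

# Only keep the gauge family registered by TritonServer.initialize().
for line in text.splitlines():
    if line.startswith("inference_server_metrics"):
        print(line)  # e.g. inference_server_metrics{available_resource="available_resource"} 1.0
```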
diff --git a/llm/server/server/triton_server_helper.py b/llm/server/server/triton_server_helper.py
new file mode 100644
index 0000000000..12435b40f8
--- /dev/null
+++ b/llm/server/server/triton_server_helper.py
@@ -0,0 +1,146 @@
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import os
+import queue
+import socket
+import subprocess
+import time
+from collections import defaultdict
+from multiprocessing import shared_memory
+
+import numpy as np
+import uvicorn
+from fastapi import FastAPI, HTTPException, Request
+from fastapi.responses import JSONResponse, Response
+from server.engine.config import Config
+from server.utils import get_logger
+
+app = FastAPI()
+
+logger = get_logger("health_checker", "health_checker.log")
+env_config = Config()
+
+@app.get("/v2/health/ready")
+def check_health():
+ """
+ 探活接口"""
+ status, error_info = check()
+ if status is True:
+ logger.info("check_health: OK")
+ return Response()
+ else:
+ logger.info("check_health: Bad")
+ return JSONResponse(
+ status_code=500,
+ content=error_info)
+
+
+@app.get("/v2/health/live")
+def check_live():
+ """
+ 探活接口"""
+ status, error_info = check()
+ if status is True:
+ logger.info("check_health: OK")
+ return Response()
+ else:
+ logger.info("check_health: Bad")
+ return JSONResponse(
+ status_code=500,
+ content=error_info)
+
+
+def check_infer_engine_process():
+ # check whether the inference engine processes are still alive
+ mp_num = int(env_config.mp_num)
+ for i in range(mp_num):
+ try:
+ infer_live_flag_shm = shared_memory.SharedMemory(name=env_config.get_unique_name("shm_flag_infer_{}_live".format(i)))
+ except Exception: # an exception is raised if the inference process has died
+ return False
+ return True
+
+
+def check():
+ """
+ Check the status of the inference service; returns (status, error_info).
+ """
+ error_info = {}
+ grpc_port = os.getenv("GRPC_PORT")
+
+ # 1. check whether the server is healthy
+ if grpc_port is not None:
+ sock = socket.socket()
+ try:
+ sock.connect(('localhost', int(grpc_port)))
+ except Exception:
+ error_info["error_code"] = 1
+ error_info["error_msg"] = "server is not ready"
+ logger.info("server is not ready")
+ return False, error_info
+ finally:
+ sock.close()
+
+ # 2. check whether the engine is healthy
+ is_engine_live = check_infer_engine_process()
+ if is_engine_live is False:
+ error_info["error_code"] = 2
+ error_info["error_msg"] = "infer engine is down"
+ logger.info("infer engine is down")
+ return False, error_info
+
+ # check whether the engine has started
+ engine_ready_checker = np.ndarray(engine_ready_check_flag.shape, dtype=engine_ready_check_flag.dtype,
+ buffer=shm_engine_ready_check_flag.buf)
+ if engine_ready_checker[0] == 0: # 0 means not started, 1 means started
+ error_info["error_code"] = 2
+ error_info["error_msg"] = "infer engine is down"
+ logger.info("infer engine is down")
+ return False, error_info
+
+ # check whether the engine is hanging
+ engine_hang_checker = np.ndarray(engine_healthy_recorded_time.shape, dtype=engine_healthy_recorded_time.dtype,
+ buffer=shm_engine_healthy_recorded_time.buf)
+ elapsed_time = time.time() - engine_hang_checker[0]
+ logger.info("engine_checker elapsed time: {}".format(elapsed_time))
+ if engine_hang_checker[0] and (elapsed_time > time_interval_threshold):
+ error_info["error_code"] = 3
+ error_info["error_msg"] = "infer engine hangs"
+ logger.info("infer engine hangs")
+ return False, error_info
+
+ return True, error_info
+
+
+def start_health_checker(http_port):
+ import sys
+ sys.stdout = open("log/health_http.log", 'w')
+ sys.stderr = sys.stdout
+ uvicorn.run(app=app, host='0.0.0.0', port=http_port, workers=1, log_level="info")
+
+
+time_interval_threshold = env_config.check_health_interval # if the infer engine has not run its main loop within this interval, it is considered hung or dead
+engine_healthy_recorded_time = np.zeros([1], dtype=float)
+shm_engine_healthy_recorded_time = shared_memory.SharedMemory(
+ create=True,
+ size=engine_healthy_recorded_time.nbytes,
+ name=env_config.get_unique_name("engine_healthy_recorded_time")) # 由推理引擎进行更新,每次读token时候就刷新一次时间,正常情况下该时间戳在30s内肯定会被刷新
+
+engine_ready_check_flag = np.zeros([1], dtype=np.int32)
+shm_engine_ready_check_flag = shared_memory.SharedMemory(
+ create=True,
+ size=engine_ready_check_flag.nbytes,
+ name=env_config.get_unique_name("engine_ready_check_flag")) # 由推理引擎更新,推理引擎初始化完毕时候置为1
diff --git a/llm/server/server/utils.py b/llm/server/server/utils.py
new file mode 100644
index 0000000000..cb36f2a6ed
--- /dev/null
+++ b/llm/server/server/utils.py
@@ -0,0 +1,196 @@
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import base64
+import codecs
+import logging
+import os
+import pickle
+import re
+import time
+from datetime import datetime
+from enum import Enum
+from logging.handlers import BaseRotatingHandler
+from pathlib import Path
+import subprocess
+
+
+class DailyRotatingFileHandler(BaseRotatingHandler):
+ """
+ - supports multiple processes
+ - only supports rotation by calendar day
+ - UTC is not supported yet
+ """
+
+ def __init__(
+ self,
+ filename,
+ backupCount=0,
+ encoding="utf-8",
+ delay=False,
+ utc=False,
+ **kwargs
+ ):
+ self.backup_count = backupCount
+ self.utc = utc
+ self.suffix = "%Y-%m-%d"
+ self.base_log_path = Path(filename)
+ self.base_filename = self.base_log_path.name
+ self.current_filename = self._compute_fn()
+ self.current_log_path = self.base_log_path.with_name(self.current_filename)
+ BaseRotatingHandler.__init__(self, filename, "a", encoding, delay)
+
+ def shouldRollover(self, record):
+ """
+ Decide whether the log should be rolled over: roll over when the file name for the current time differs from the currently opened file name.
+ """
+ if self.current_filename != self._compute_fn():
+ return True
+ return False
+
+ def doRollover(self):
+ """
+ Roll over to a new log file.
+ """
+ if self.stream:
+ self.stream.close()
+ self.stream = None
+
+ self.current_filename = self._compute_fn()
+ self.current_log_path = self.base_log_path.with_name(self.current_filename)
+
+ if not self.delay:
+ self.stream = self._open()
+
+ self.delete_expired_files()
+
+ def _compute_fn(self):
+ """
+ Compute the log file name for the current time.
+ """
+ return self.base_filename + "." + time.strftime(self.suffix, time.localtime())
+
+ def _open(self):
+ """
+ Open a new log file and update the symlink that base_filename points to; changing the symlink does not affect logging.
+ """
+ if self.encoding is None:
+ stream = open(str(self.current_log_path), self.mode)
+ else:
+ stream = codecs.open(str(self.current_log_path), self.mode, self.encoding)
+
+ # remove the stale symlink
+ if self.base_log_path.exists():
+ try:
+ if (
+ not self.base_log_path.is_symlink()
+ or os.readlink(self.base_log_path) != self.current_filename
+ ):
+ os.remove(self.base_log_path)
+ except OSError:
+ pass
+
+ try:
+ os.symlink(self.current_filename, str(self.base_log_path))
+ except OSError:
+ pass
+ return stream
+
+ def delete_expired_files(self):
+ """
+ Delete expired log files.
+ """
+ if self.backup_count <= 0:
+ return
+
+ file_names = os.listdir(str(self.base_log_path.parent))
+ result = []
+ prefix = self.base_filename + "."
+ plen = len(prefix)
+ for file_name in file_names:
+ if file_name[:plen] == prefix:
+ suffix = file_name[plen:]
+ if re.match(r"^\d{4}-\d{2}-\d{2}(\.\w+)?$", suffix):
+ result.append(file_name)
+ if len(result) < self.backup_count:
+ result = []
+ else:
+ result.sort()
+ result = result[: len(result) - self.backup_count]
+
+ for file_name in result:
+ os.remove(str(self.base_log_path.with_name(file_name)))
+
+
+def get_logger(name, file_name, without_formatter=False):
+ """
+ Create and configure a logger.
+ """
+ log_dir = os.getenv("FD_LOG_DIR", default="log")
+ is_debug = int(os.getenv("FD_DEBUG", default=0))
+ logger = logging.getLogger(name)
+ if is_debug:
+ logger.setLevel(level=logging.DEBUG)
+ else:
+ logger.setLevel(level=logging.INFO)
+
+ LOG_FILE = "{0}/{1}".format(log_dir, file_name)
+ backup_count = int(os.getenv("FD_LOG_BACKUP_COUNT", 7))
+ handler = DailyRotatingFileHandler(LOG_FILE, backupCount=backup_count)
+
+ formatter = logging.Formatter(
+ "%(levelname)-8s %(asctime)s %(process)-5s %(filename)s[line:%(lineno)d] %(message)s"
+ )
+ if not without_formatter:
+ handler.setFormatter(formatter)
+ logger.addHandler(handler)
+ logger.propagate = False
+ return logger
+
+# instantiate singleton loggers
+model_server_logger = get_logger("model_server", "infer_server.log")
+http_server_logger = get_logger("http_server", "http_server.log")
+data_processor_logger = get_logger("data_processor", "data_processor.log")
+monitor_logger = get_logger("monitor_logger", "monitor_logger.log", True)
+error_logger = get_logger("error_logger", "error_logger.log", True)
+
+
+def str_to_datetime(date_string):
+ """datetime字符串转datetime对象"""
+ if "." in date_string:
+ return datetime.strptime(date_string, "%Y-%m-%d %H:%M:%S.%f")
+ else:
+ return datetime.strptime(date_string, "%Y-%m-%d %H:%M:%S")
+
+
+def datetime_diff(datetime_start, datetime_end):
+ """
+ Compute the absolute difference between two datetimes, in seconds.
+
+ Args:
+ datetime_start (Union[str, datetime.datetime]): start time, as a string or a datetime.datetime object.
+ datetime_end (Union[str, datetime.datetime]): end time, as a string or a datetime.datetime object.
+
+ Returns:
+ float: difference between the two datetimes, in seconds.
+ """
+ if isinstance(datetime_start, str):
+ datetime_start = str_to_datetime(datetime_start)
+ if isinstance(datetime_end, str):
+ datetime_end = str_to_datetime(datetime_end)
+ if datetime_end > datetime_start:
+ cost = datetime_end - datetime_start
+ else:
+ cost = datetime_start - datetime_end
+ return cost.total_seconds()
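A short sketch of how these utilities are typically used (the logger name and file name are illustrative; logs land in `FD_LOG_DIR`, default `log/`):

```python
import os
from datetime import datetime, timedelta

# server.utils creates its singleton loggers at import time, so the log
# directory (FD_LOG_DIR, default "log") must exist before importing it.
os.makedirs(os.getenv("FD_LOG_DIR", "log"), exist_ok=True)

from server.utils import datetime_diff, get_logger  # noqa: E402

logger = get_logger("example_logger", "example.log")
logger.info("hello from the example logger")

start = datetime.now()
end = start + timedelta(seconds=1.5)
print(datetime_diff(start, end))            # 1.5
print(datetime_diff(str(start), str(end)))  # string timestamps are parsed as well
```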