Fix a bug for llm serving #2326

Merged: 20 commits, Dec 14, 2023
llm/fastdeploy_llm/serving/serving_model.py — 44 changes: 20 additions & 24 deletions
@@ -44,30 +44,26 @@ def add_request(self, task):
         logger.debug("Receive task = {}".format(task))
         assert not self.stop_queue, "The serving model is stopped, cannot accept new requests now."
         assert task.text.strip() != "", "The request's text cannot be empty."
-        try:
-            if not hasattr(task, "token_ids"):
-                task.token_ids, task.position_ids = self.model.data_processor.encode(
-                    task.text, padding=True)
-            else:
-                task.position_ids = None
-            if self.config.is_ptuning:
-                assert len(
-                    task.token_ids
-                ) + self.config.max_prefix_len <= self.config.max_seq_len, "The request's token number({}) + max_prefix_len({}) = {} is exceeding the setting max_seq_len({}).".format(
-                    len(task.token_ids), self.config.max_prefix_len,
-                    len(task.token_ids) + self.config.max_prefix_len,
-                    self.config.max_seq_len)
-            else:
-                assert len(
-                    task.token_ids
-                ) <= self.config.max_seq_len, "The request's token number({}) is exceed the setting max_seq_len({}).".format(
-                    len(task.token_ids), self.config.max_seq_len)
-            self.requests_queue.put(task, timeout=0.5)
-            logger.debug("Task with task_id={} is put into requests queue.".
-                         format(task.task_id))
-        except Exception as e:
-            raise Exception(
-                "There's error while inserting request, error={}.".format(e))
+        if not hasattr(task, "token_ids"):
+            task.token_ids, task.position_ids = self.model.data_processor.encode(
+                task.text, padding=True)
+        else:
+            task.position_ids = None
+        if self.config.is_ptuning:
+            assert len(
+                task.token_ids
+            ) + self.config.max_prefix_len <= self.config.max_seq_len, "The request's token number({}) + max_prefix_len({}) = {} is exceeding the setting max_seq_len({}).".format(
+                len(task.token_ids), self.config.max_prefix_len,
+                len(task.token_ids) + self.config.max_prefix_len,
+                self.config.max_seq_len)
+        else:
+            assert len(
+                task.token_ids
+            ) <= self.config.max_seq_len, "The request's token number({}) is exceed the setting max_seq_len({}).".format(
+                len(task.token_ids), self.config.max_seq_len)
+        self.requests_queue.put(task, timeout=0.5)
+        logger.debug("Task with task_id={} is put into requests queue.".
+                     format(task.task_id))
 
     def runner(self):
         batch_tasks = BatchTask(self.config.max_batch_size)
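The effect of this change is that `add_request` no longer wraps its body in a blanket `try/except` that re-raises every failure as a plain `Exception`. In particular, the `queue.Full` raised by `self.requests_queue.put(task, timeout=0.5)` when the queue is at capacity now reaches the caller with its type intact, which is what the new handler in triton_model.py below relies on. A minimal, self-contained sketch of the difference, using a stand-in `queue.Queue` in place of the real serving model:

```python
import queue

requests_queue = queue.Queue(maxsize=1)   # stand-in for self.requests_queue
requests_queue.put("task-0")              # the queue is now full

def add_request_old(task):
    # Old behaviour: any failure is re-raised as a plain Exception,
    # so the caller can no longer tell that the queue was full.
    try:
        requests_queue.put(task, timeout=0.5)
    except Exception as e:
        raise Exception(
            "There's error while inserting request, error={}.".format(e))

def add_request_new(task):
    # New behaviour: queue.Full (or any other error) propagates unchanged.
    requests_queue.put(task, timeout=0.5)

try:
    add_request_old("task-1")
except queue.Full:
    print("never reached with the old code")
except Exception as e:
    print("old code, generic error only:", e)

try:
    add_request_new("task-2")
except queue.Full:
    print("new code: the caller can handle queue.Full specifically")
```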
llm/fastdeploy_llm/serving/triton_model.py — 25 changes: 23 additions & 2 deletions
@@ -13,6 +13,7 @@
 # limitations under the License.
 
 import json
+import queue
 import os
 import uuid
 import threading
@@ -123,7 +124,7 @@ def execute(self, requests):
             except Exception as e:
                 error_type = ErrorType.Query
                 error_code = ErrorCode.C0000
-                error_info = "Cannot load json data from request, received data = {} error={}.".format(request_tensor, e)
+                error_info = "Cannot load json data from request, received data = {} error={}.".format(request_tensor.as_numpy(), e)
                 error_msg = error_format.format(error_type.name, error_code.name, error_info)
                 warning_logger.error(error_msg)
                 error_res = pb_utils.InferenceResponse(
@@ -190,7 +191,7 @@ def execute(self, requests):
             if self.model.requests_queue.qsize() > self.config.max_queue_num:
                 error_type = ErrorType.Server
                 error_code = ErrorCode.S0000
-                error_info = "The queue is full now(size={}), please wait for a while.".format(self.model.max_queue_num)
+                error_info = "The queue is full now(size={}), please wait for a while.".format(self.config.max_queue_num)
                 error_msg = error_format.format(error_type.name, error_code.name, error_info)
                 warning_logger.error(error_msg)
                 error_res = pb_utils.InferenceResponse(error=pb_utils.TritonError(error_msg))
@@ -220,6 +221,26 @@ def execute(self, requests):
             task.call_back_func = stream_call_back
             try:
                 self.model.add_request(task)
+            except queue.Full as e:
+                # Log error for Server
+                error_type = ErrorType.Server
+                error_code = ErrorCode.S0000
+                error_info = "The queue is full now(size={}), please scale service.".format(self.config.max_queue_num)
+                error_msg = error_format.format(error_type.name, error_code.name, error_info)
+                warning_logger.error(error_msg)
+                # Log error for query
+                error_type = ErrorType.Query
+                error_code = ErrorCode.C0001
+                error_info = "There's error while inserting new request, task={} error={}".format(task, "service too busy")
+                error_msg = error_format.format(error_type.name, error_code.name, error_info)
+                warning_logger.error(error_msg)
+                error_res = pb_utils.InferenceResponse(error=pb_utils.TritonError(error_msg))
+                res_sender = request.get_response_sender()
+                res_sender.send(
+                    error_res,
+                    flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL)
+                continue
+
             except Exception as e:
                 error_type = ErrorType.Query
                 error_code = ErrorCode.C0001
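The new `except queue.Full` branch reuses the error-reply pattern that the other error paths in `execute()` already follow: format the message, log it, and send a final `InferenceResponse` carrying a `TritonError`. Below is a condensed sketch of that pattern as a standalone helper; `send_error_response` is a hypothetical name used only for illustration, `warning_logger` and `error_format` are stand-ins for the module-level names in triton_model.py, and `triton_python_backend_utils` is importable only inside Triton's Python backend.

```python
import logging

import triton_python_backend_utils as pb_utils  # available only inside the Triton Python backend

warning_logger = logging.getLogger("fastdeploy_llm")  # stand-in for the module's logger
error_format = "error_type: {}, error_code: {}, error_info: {}"  # stand-in format string

def send_error_response(request, error_type, error_code, error_info):
    """Hypothetical helper mirroring the error-reply pattern used in execute()."""
    error_msg = error_format.format(error_type.name, error_code.name, error_info)
    warning_logger.error(error_msg)
    error_res = pb_utils.InferenceResponse(error=pb_utils.TritonError(error_msg))
    # The FINAL flag closes the (decoupled) response stream for this request.
    res_sender = request.get_response_sender()
    res_sender.send(error_res, flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL)
```

With `add_request` letting `queue.Full` propagate, `execute()` can now distinguish a saturated request queue (logged as a Server error and answered with "service too busy") from other insertion failures, which still fall through to the generic `except Exception` branch.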