# Patch summary for lightllm/server/api_tgi.py (three hunks).
#
# 1. Adds `import os` at the top of the module.
# 2. Adds the module-level constant RETURN_LIST. It is True when the
#    RETURN_LIST environment variable (default "FALSE"), upper-cased, is one
#    of "ON", "TRUE" or "1". It is a module-level statement, so the variable
#    is read when the module is imported rather than per request.
# 3. In tgi_generate_impl, the final payload was previously always
#    jsonable_encoder([ret]). After this patch the one-element list is only
#    produced when RETURN_LIST is true; otherwise `ret` is encoded directly.
#
# NOTE(review): with the environment variable unset, the response body
# changes from a one-element list to a bare object. Confirm that existing
# clients of this endpoint expect the new default.
# NOTE(review): the unchanged context comment "wrap generation inside a Vec
# to match api-inference" now describes only the RETURN_LIST branch.
# NOTE(review): in this copy the patch text sits on a single line; the
# original line breaks must be restored before it can be applied.
diff --git a/lightllm/server/api_tgi.py b/lightllm/server/api_tgi.py index 3e1b3733c..f9b070c3c 100755 --- a/lightllm/server/api_tgi.py +++ b/lightllm/server/api_tgi.py @@ -1,3 +1,4 @@ +import os import collections from typing import AsyncGenerator from fastapi import BackgroundTasks, Request @@ -8,6 +9,8 @@ from .httpserver.manager import HttpServerManager import json +RETURN_LIST = os.getenv("RETURN_LIST", "FALSE").upper() in ["ON", "TRUE", "1"] + def format_tgi_params(params, num_beam: int = 1): """ @@ -132,7 +135,10 @@ async def tgi_generate_impl(request: Request, httpserver_manager: HttpServerMana if return_details: ret["details"]["beam_sequences"] = beam_sequences # wrap generation inside a Vec to match api-inference - json_compatible_item_data = jsonable_encoder([ret]) + if RETURN_LIST: + json_compatible_item_data = jsonable_encoder([ret]) + else: + json_compatible_item_data = jsonable_encoder(ret) return JSONResponse(content=json_compatible_item_data)