diff --git a/xinference/api/restful_api.py b/xinference/api/restful_api.py index cfc191db55..ef84f66602 100644 --- a/xinference/api/restful_api.py +++ b/xinference/api/restful_api.py @@ -25,6 +25,7 @@ from typing import Any, Dict, List, Optional, Union import gradio as gr +import tiktoken import xoscar as xo from aioprometheus import REGISTRY, MetricsMiddleware from aioprometheus.asgi.starlette import metrics @@ -1292,6 +1293,33 @@ async def create_embedding(self, request: Request) -> Response: payload = await request.json() body = CreateEmbeddingRequest.parse_obj(payload) model_uid = body.model + + # 检查 body.input 是否是一个二维整数列表 + if isinstance(body.input, list) and all( + isinstance(item, list) and all(isinstance(i, int) for i in item) + for item in body.input + ): + enc = tiktoken.get_encoding("cl100k_base") + lines_decoded = [] + + for line in body.input: + try: + # 将每个 token 解码为字节,然后连接成一个完整的字符串 + output = b"".join( + enc.decode_single_token_bytes(token) for token in line + ) + # 将字节序列转换为 UTF-8 编码的字符串 + decoded_line = output.decode("utf-8") + lines_decoded.append(decoded_line) + except Exception as e: + # 抛出 HTTP 异常,包含详细的错误信息 + raise HTTPException( + status_code=500, detail=f"Error decoding tokens: {str(e)}" + ) + + # 更新 body.input 为解码后的字符串列表 + body.input = lines_decoded + exclude = { "model", "input",