From e575ba29f56283dc774fc91e22e5f8284d04a4fd Mon Sep 17 00:00:00 2001 From: thucpn Date: Mon, 25 Mar 2024 16:03:07 +0700 Subject: [PATCH] feat(python): add chat request route --- .../streaming/fastapi/app/api/routers/chat.py | 44 +++++++++++++------ 1 file changed, 31 insertions(+), 13 deletions(-) diff --git a/templates/types/streaming/fastapi/app/api/routers/chat.py b/templates/types/streaming/fastapi/app/api/routers/chat.py index 278a9a753..b1e1484bd 100644 --- a/templates/types/streaming/fastapi/app/api/routers/chat.py +++ b/templates/types/streaming/fastapi/app/api/routers/chat.py @@ -18,20 +18,18 @@ class _ChatData(BaseModel): messages: List[_Message] -@r.post("") -async def chat( - request: Request, - data: _ChatData, - chat_engine: BaseChatEngine = Depends(get_chat_engine), -): - # check preconditions and get last message +class _Result(BaseModel): + result: _Message + + +async def preprocess_request(data: _ChatData) -> tuple: if len(data.messages) == 0: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail="No messages provided", ) - lastMessage = data.messages.pop() - if lastMessage.role != MessageRole.USER: + last_message = data.messages.pop() + if last_message.role != MessageRole.USER: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail="Last message must be from user", @@ -44,16 +42,36 @@ async def chat( ) for m in data.messages ] + return last_message, messages - # query chat engine - response = await chat_engine.astream_chat(lastMessage.content, messages) - # stream response +@r.post("") +async def chat( + request: Request, + data: _ChatData, + chat_engine: BaseChatEngine = Depends(get_chat_engine), +): + last_message, messages = await preprocess_request(data) + + response = await chat_engine.astream_chat(last_message.content, messages) + async def event_generator(): async for token in response.async_response_gen(): - # If client closes connection, stop sending events if await request.is_disconnected(): break yield token return StreamingResponse(event_generator(), media_type="text/plain") + + +@r.post("/request") +async def chat_request( + data: _ChatData, + chat_engine: BaseChatEngine = Depends(get_chat_engine), +) -> _Result: + last_message, messages = await preprocess_request(data) + + response = await chat_engine.achat(last_message.content, messages) + return _Result( + result=_Message(role=MessageRole.ASSISTANT, content=response.response) + )