diff --git a/wren-ai-service/src/web/v1/routers/ask.py b/wren-ai-service/src/web/v1/routers/ask.py index 46a7bd54f..6d360baa3 100644 --- a/wren-ai-service/src/web/v1/routers/ask.py +++ b/wren-ai-service/src/web/v1/routers/ask.py @@ -53,7 +53,7 @@ - **Path Parameter**: - `query_id`: The unique identifier of the query. - **Response**: - - `status`: The current status of the query (`"understanding"`, `"searching"`, `"generating"`, `"finished"`, `"failed"`, or `"stopped"`). + - `status`: The current status of the query (`"understanding"`, `"searching"`, `"generating"`, `"correcting"`, `"finished"`, `"failed"`, or `"stopped"`). - `response`: (Optional) A list of SQL results, each containing: - `sql`: The generated SQL statement. - `summary`: A summary of the SQL statement. diff --git a/wren-ai-service/src/web/v1/services/ask.py b/wren-ai-service/src/web/v1/services/ask.py index 4f03a759b..8d4acf908 100644 --- a/wren-ai-service/src/web/v1/services/ask.py +++ b/wren-ai-service/src/web/v1/services/ask.py @@ -98,7 +98,13 @@ class AskResultRequest(BaseModel): class AskResultResponse(BaseModel): status: Literal[ - "understanding", "searching", "generating", "finished", "failed", "stopped" + "understanding", + "searching", + "generating", + "correcting", + "finished", + "failed", + "stopped", ] response: Optional[List[AskResult]] = None error: Optional[AskError] = None @@ -267,6 +273,10 @@ async def ask( "invalid_generation_results" ] ): + self._ask_results[query_id] = AskResultResponse( + status="correcting", + ) + sql_correction_results = await self._pipelines[ "sql_correction" ].run(