-
Notifications
You must be signed in to change notification settings - Fork 859
chore(wren-ai-service): Refactor Services to use Resource Classes #848
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Large diffs are not rendered by default.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,156 +1,107 @@ | ||
import logging | ||
from typing import Dict, Literal, Optional | ||
|
||
from cachetools import TTLCache | ||
from langfuse.decorators import observe | ||
from pydantic import BaseModel | ||
|
||
from src.core.pipeline import BasicPipeline | ||
from src.utils import async_timer, trace_metadata | ||
|
||
logger = logging.getLogger("wren-ai-service") | ||
|
||
class SQLAnswerRequest(BaseModel): | ||
class Input(BaseModel): | ||
query: str | ||
sql: str | ||
sql_summary: str | ||
thread_id: Optional[str] = None | ||
user_id: Optional[str] = None | ||
|
||
# POST /v1/sql-answers | ||
class SqlAnswerRequest(BaseModel): | ||
_query_id: str | None = None | ||
query: str | ||
sql: str | ||
sql_summary: str | ||
thread_id: Optional[str] = None | ||
user_id: Optional[str] = None | ||
|
||
@property | ||
def query_id(self) -> str: | ||
return self._query_id | ||
|
||
@query_id.setter | ||
def query_id(self, query_id: str): | ||
self._query_id = query_id | ||
|
||
|
||
class SqlAnswerResponse(BaseModel): | ||
query_id: str | ||
|
||
|
||
# GET /v1/sql-answers/{query_id}/result | ||
class SqlAnswerResultRequest(BaseModel): | ||
query_id: str | ||
|
||
|
||
class SqlAnswerResultResponse(BaseModel): | ||
class SqlAnswerError(BaseModel): | ||
code: Literal["OTHERS"] | ||
message: str | ||
|
||
status: Literal["understanding", "processing", "finished", "failed"] | ||
response: Optional[str] = None | ||
error: Optional[SqlAnswerError] = None | ||
class Resource(BaseModel): | ||
class Error(BaseModel): | ||
code: Literal["OTHERS"] | ||
message: str | ||
|
||
query_id: str | ||
status: Literal["understanding", "processing", "finished", "failed"] = None | ||
response: Optional[str] = None | ||
error: Optional[Error] = None | ||
|
||
class SqlAnswerService: | ||
def __init__( | ||
self, | ||
pipelines: Dict[str, BasicPipeline], | ||
maxsize: int = 1_000_000, | ||
ttl: int = 120, | ||
): | ||
self._pipelines = pipelines | ||
self._sql_answer_results: Dict[str, SqlAnswerResultResponse] = TTLCache( | ||
self._cache: Dict[str, SQLAnswerRequest.Resource] = TTLCache( | ||
maxsize=maxsize, ttl=ttl | ||
) | ||
|
||
@async_timer | ||
@observe(name="SQL Answer") | ||
@trace_metadata | ||
async def sql_answer( | ||
self, | ||
sql_answer_request: SqlAnswerRequest, | ||
**kwargs, | ||
): | ||
results = { | ||
"sql_answer_result": {}, | ||
"metadata": { | ||
"error": { | ||
"type": "", | ||
"message": "", | ||
} | ||
}, | ||
} | ||
|
||
async def generate(self, request: Input, **kwargs) -> Resource: | ||
try: | ||
query_id = sql_answer_request.query_id | ||
|
||
self._sql_answer_results[query_id] = SqlAnswerResultResponse( | ||
status="understanding", | ||
self[request.query_id] = self.Resource( | ||
query_id=request.query_id, | ||
status="understanding" | ||
) | ||
|
||
self._sql_answer_results[query_id] = SqlAnswerResultResponse( | ||
status="processing", | ||
self[request.query_id] = self.Resource( | ||
query_id=request.query_id, | ||
status="processing" | ||
) | ||
|
||
data = await self._pipelines["sql_answer"].run( | ||
query=sql_answer_request.query, | ||
sql=sql_answer_request.sql, | ||
sql_summary=sql_answer_request.sql_summary, | ||
project_id=sql_answer_request.thread_id, | ||
query=request.query, | ||
sql=request.sql, | ||
sql_summary=request.sql_summary, | ||
project_id=request.thread_id, | ||
) | ||
api_results = data["post_process"]["results"] | ||
if answer := api_results["answer"]: | ||
self._sql_answer_results[query_id] = SqlAnswerResultResponse( | ||
self[request.query_id] = self.Resource( | ||
query_id=request.query_id, | ||
status="finished", | ||
response=answer, | ||
response=answer | ||
) | ||
else: | ||
self._sql_answer_results[query_id] = SqlAnswerResultResponse( | ||
self[request.query_id] = self.Resource( | ||
query_id=request.query_id, | ||
status="failed", | ||
error=SqlAnswerResultResponse.SqlAnswerError( | ||
error=self.Resource.Error( | ||
code="OTHERS", | ||
message=api_results["error"], | ||
), | ||
message=api_results["error"] | ||
) | ||
) | ||
|
||
results["metadata"]["error_type"] = "OTHERS" | ||
results["metadata"]["error_message"] = api_results["error"] | ||
|
||
results["sql_answer_result"] = { | ||
"answer": api_results["answer"], | ||
"reasoning": api_results["reasoning"], | ||
} | ||
return results | ||
except Exception as e: | ||
logger.exception(f"sql answer pipeline - OTHERS: {e}") | ||
|
||
self._sql_answer_results[ | ||
sql_answer_request.query_id | ||
] = SqlAnswerResultResponse( | ||
self[request.query_id] = self.Resource( | ||
query_id=request.query_id, | ||
status="failed", | ||
error=SqlAnswerResultResponse.SqlAnswerError( | ||
error=self.Resource.Error( | ||
code="OTHERS", | ||
message=str(e), | ||
), | ||
message=str(e) | ||
) | ||
) | ||
|
||
results["metadata"]["error_type"] = "OTHERS" | ||
results["metadata"]["error_message"] = str(e) | ||
return results | ||
return self[request.query_id] | ||
|
||
def get_sql_answer_result( | ||
self, | ||
sql_answer_result_request: SqlAnswerResultRequest, | ||
) -> SqlAnswerResultResponse: | ||
if ( | ||
result := self._sql_answer_results.get(sql_answer_result_request.query_id) | ||
) is None: | ||
logger.exception( | ||
f"sql answer pipeline - OTHERS: {sql_answer_result_request.query_id} is not found" | ||
) | ||
return SqlAnswerResultResponse( | ||
def __getitem__(self, query_id: str) -> Resource: | ||
response = self._cache.get(query_id) | ||
if response is None: | ||
message = f"SQL Answer Resource with ID '{query_id}' not found." | ||
logger.exception(message) | ||
return self.Resource( | ||
query_id=query_id, | ||
status="failed", | ||
error=SqlAnswerResultResponse.SqlAnswerError( | ||
error=self.Resource.Error( | ||
code="OTHERS", | ||
message=f"{sql_answer_result_request.query_id} is not found", | ||
), | ||
message=message | ||
) | ||
) | ||
return response | ||
|
||
return result | ||
def __setitem__(self, query_id: str, value: Resource): | ||
self._cache[query_id] = value |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,104 +1,71 @@ | ||
import asyncio | ||
import logging | ||
from typing import Dict, List, Literal, Optional | ||
|
||
from cachetools import TTLCache | ||
from haystack import Pipeline | ||
from pydantic import BaseModel | ||
|
||
from src.utils import async_timer | ||
|
||
logger = logging.getLogger("wren-ai-service") | ||
|
||
|
||
# POST /v1/sql-explanations | ||
class StepWithAnalysisResult(BaseModel): | ||
sql: str | ||
summary: str | ||
sql_analysis_results: List[Dict] | ||
|
||
|
||
class SQLExplanationRequest(BaseModel): | ||
_query_id: str | None = None | ||
question: str | ||
steps_with_analysis_results: List[StepWithAnalysisResult] | ||
mdl_hash: Optional[str] = None | ||
thread_id: Optional[str] = None | ||
project_id: Optional[str] = None | ||
user_id: Optional[str] = None | ||
|
||
@property | ||
def query_id(self) -> str: | ||
return self._query_id | ||
|
||
@query_id.setter | ||
def query_id(self, query_id: str): | ||
self._query_id = query_id | ||
|
||
|
||
class SQLExplanationResponse(BaseModel): | ||
query_id: str | ||
|
||
|
||
# GET /v1/sql-explanations/{query_id}/result | ||
class SQLExplanationResultRequest(BaseModel): | ||
query_id: str | ||
|
||
|
||
class SQLExplanationResultResponse(BaseModel): | ||
class SQLExplanationResultError(BaseModel): | ||
code: Literal["OTHERS"] | ||
message: str | ||
|
||
status: Literal["understanding", "generating", "finished", "failed"] | ||
response: Optional[List[List[Dict]]] = None | ||
error: Optional[SQLExplanationResultError] = None | ||
|
||
class Input(BaseModel): | ||
question: str | ||
steps_with_analysis_results: List[ | ||
Dict[ | ||
"step_sql": str, | ||
"step_summary": str, | ||
"step_sql_analysis_results": List[Dict] | ||
] | ||
Comment on lines
+15
to
+19
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. gently ping to mention here has the same issue as the previous comment. |
||
] | ||
mdl_hash: Optional[str] = None | ||
thread_id: Optional[str] = None | ||
project_id: Optional[str] = None | ||
user_id: Optional[str] = None | ||
|
||
class Resource(BaseModel): | ||
class Error(BaseModel): | ||
code: Literal["OTHERS"] | ||
message: str | ||
|
||
query_id: str | ||
status: Literal["understanding", "generating", "finished", "failed"] = None | ||
response: Optional[List[List[Dict]]] = None | ||
error: Optional[Error] = None | ||
|
||
class SQLExplanationService: | ||
def __init__( | ||
self, | ||
pipelines: Dict[str, Pipeline], | ||
maxsize: int = 1_000_000, | ||
ttl: int = 120, | ||
): | ||
self._pipelines = pipelines | ||
self._sql_explanation_results: Dict[ | ||
str, SQLExplanationResultResponse | ||
] = TTLCache(maxsize=maxsize, ttl=ttl) | ||
self._cache: Dict[str, SQLExplanationRequest.Resource] = TTLCache( | ||
maxsize=maxsize, ttl=ttl | ||
) | ||
|
||
@async_timer | ||
async def sql_explanation( | ||
self, | ||
sql_explanation_request: SQLExplanationRequest, | ||
**kwargs, | ||
): | ||
async def generate(self, request: Input, **kwargs) -> Resource: | ||
try: | ||
query_id = sql_explanation_request.query_id | ||
|
||
self._sql_explanation_results[query_id] = SQLExplanationResultResponse( | ||
status="understanding", | ||
self[request.query_id] = self.Resource( | ||
query_id=request.query_id, | ||
status="understanding" | ||
) | ||
|
||
self._sql_explanation_results[query_id] = SQLExplanationResultResponse( | ||
status="generating", | ||
self[request.query_id] = self.Resource( | ||
query_id=request.query_id, | ||
status="generating" | ||
) | ||
|
||
async def _task( | ||
question: str, | ||
step_with_analysis_results: StepWithAnalysisResult, | ||
): | ||
async def _task(question: str, step_with_analysis_results: Dict): | ||
return await self._pipelines["sql_explanation"].run( | ||
question=question, | ||
step_with_analysis_results=step_with_analysis_results, | ||
step_with_analysis_results=step_with_analysis_results | ||
) | ||
|
||
tasks = [ | ||
_task( | ||
sql_explanation_request.question, | ||
step_with_analysis_results, | ||
) | ||
for step_with_analysis_results in sql_explanation_request.steps_with_analysis_results | ||
_task(request.question, step) | ||
for step in request.steps_with_analysis_results | ||
] | ||
generation_results = await asyncio.gather(*tasks) | ||
|
||
|
@@ -108,39 +75,47 @@ async def _task( | |
] | ||
|
||
if sql_explanation_results: | ||
self._sql_explanation_results[query_id] = SQLExplanationResultResponse( | ||
self[request.query_id] = self.Resource( | ||
query_id=request.query_id, | ||
status="finished", | ||
response=sql_explanation_results, | ||
response=sql_explanation_results | ||
) | ||
else: | ||
self._sql_explanation_results[query_id] = SQLExplanationResultResponse( | ||
self[request.query_id] = self.Resource( | ||
query_id=request.query_id, | ||
status="failed", | ||
error=SQLExplanationResultResponse.SQLExplanationResultError( | ||
error=self.Resource.Error( | ||
code="OTHERS", | ||
message="No SQL explanation is found", | ||
), | ||
message="No SQL explanation is found" | ||
) | ||
) | ||
except Exception as e: | ||
logger.exception( | ||
f"sql explanation pipeline - Failed to provide SQL explanation: {e}" | ||
) | ||
self._sql_explanation_results[ | ||
sql_explanation_request.query_id | ||
] = SQLExplanationResultResponse( | ||
logger.exception(f"sql explanation pipeline - Failed to provide SQL explanation: {e}") | ||
self[request.query_id] = self.Resource( | ||
query_id=request.query_id, | ||
status="failed", | ||
error=SQLExplanationResultResponse.SQLExplanationResultError( | ||
error=self.Resource.Error( | ||
code="OTHERS", | ||
message=str(e), | ||
), | ||
message=str(e) | ||
) | ||
) | ||
|
||
def get_sql_explanation_result( | ||
self, sql_explanation_result_request: SQLExplanationResultRequest | ||
) -> SQLExplanationResultResponse: | ||
if sql_explanation_result_request.query_id not in self._sql_explanation_results: | ||
return SQLExplanationResultResponse( | ||
return self[request.query_id] | ||
|
||
def __getitem__(self, query_id: str) -> Resource: | ||
response = self._cache.get(query_id) | ||
if response is None: | ||
message = f"SQL Explanation Resource with ID '{query_id}' not found." | ||
logger.exception(message) | ||
return self.Resource( | ||
query_id=query_id, | ||
status="failed", | ||
error=f"{sql_explanation_result_request.query_id} is not found", | ||
error=self.Resource.Error( | ||
code="OTHERS", | ||
message=message | ||
) | ||
) | ||
return response | ||
|
||
return self._sql_explanation_results[sql_explanation_result_request.query_id] | ||
def __setitem__(self, query_id: str, value: Resource): | ||
self._cache[query_id] = value |
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
@@ -1,163 +1,133 @@ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
import logging | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
from typing import Dict, List, Literal, Optional | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
from cachetools import TTLCache | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
from haystack import Pipeline | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
from pydantic import BaseModel | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
from src.utils import async_timer | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
from src.web.v1.services.ask_details import SQLBreakdown | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
logger = logging.getLogger("wren-ai-service") | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
# POST /v1/sql-regenerations | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
class DecisionPoint(BaseModel): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
type: Literal["filter", "selectItems", "relation", "groupByKeys", "sortings"] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
value: str | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
class CorrectionPoint(BaseModel): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
type: Literal[ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
"sql_expression", "nl_expression" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
] # nl_expression is natural language expression | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
value: str | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
class UserCorrection(BaseModel): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
before: DecisionPoint | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
after: CorrectionPoint | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
class SQLExplanationWithUserCorrections(BaseModel): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
summary: str | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
sql: str | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
cte_name: str | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
corrections: List[UserCorrection] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
class SQLRegenerationRequest(BaseModel): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
_query_id: str | None = None | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
description: str | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
steps: List[SQLExplanationWithUserCorrections] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
mdl_hash: Optional[str] = None | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
thread_id: Optional[str] = None | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
project_id: Optional[str] = None | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
user_id: Optional[str] = None | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
@property | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
def query_id(self) -> str: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
return self._query_id | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
@query_id.setter | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
def query_id(self, query_id: str): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
self._query_id = query_id | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
class SQLRegenerationResponse(BaseModel): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
query_id: str | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
# GET /v1/sql-regenerations/{query_id}/result | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
class SQLRegenerationResultRequest(BaseModel): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
query_id: str | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
class SQLRegenerationResultResponse(BaseModel): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
class SQLRegenerationResponseDetails(BaseModel): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
class Input(BaseModel): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
description: str | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
steps: List[SQLBreakdown] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
class SQLRegenerationError(BaseModel): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
code: Literal["NO_RELEVANT_SQL", "OTHERS"] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
message: str | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
status: Literal["understanding", "generating", "finished", "failed"] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
response: Optional[SQLRegenerationResponseDetails] = None | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
error: Optional[SQLRegenerationError] = None | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
steps: List[ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
Dict[ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
"summary": str, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
"sql": str, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
"cte_name": str, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
"corrections": List[ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
Dict[ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
"before": Dict[ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
"type": Literal["filter", "selectItems", "relation", "groupByKeys", "sortings"], | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
"value": str | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
], | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
"after": Dict[ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
"type": Literal["sql_expression", "nl_expression"], | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
"value": str | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
Comment on lines
+15
to
+32
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this type declaration seems be wrong, could you help to fix? thanks! There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
mdl_hash: Optional[str] = None | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
thread_id: Optional[str] = None | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
project_id: Optional[str] = None | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
user_id: Optional[str] = None | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
class Resource(BaseModel): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
class Error(BaseModel): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
code: Literal["NO_RELEVANT_SQL", "OTHERS"] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
message: str | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
class ResponseDetails(BaseModel): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
description: str | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
steps: List[SQLBreakdown] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
query_id: str | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
status: Literal["understanding", "generating", "finished", "failed"] = None | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
response: Optional[ResponseDetails] = None | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
error: Optional[Error] = None | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
class SQLRegenerationService: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
def __init__( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
self, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
pipelines: Dict[str, Pipeline], | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
maxsize: int = 1_000_000, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
ttl: int = 120, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
self._pipelines = pipelines | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
self._sql_regeneration_results: Dict[ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
str, SQLRegenerationResultResponse | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
] = TTLCache(maxsize=maxsize, ttl=ttl) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
self._cache: Dict[str, SQLRegenerationRequest.Resource] = TTLCache( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
maxsize=maxsize, ttl=ttl | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
@async_timer | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
async def sql_regeneration( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
self, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
sql_regeneration_request: SQLRegenerationRequest, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
**kwargs, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
async def generate(self, request: Input, **kwargs) -> Resource: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
try: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
query_id = sql_regeneration_request.query_id | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
self._sql_regeneration_results[query_id] = SQLRegenerationResultResponse( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
status="understanding", | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
self[request.query_id] = self.Resource( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
query_id=request.query_id, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
status="understanding" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
self._sql_regeneration_results[query_id] = SQLRegenerationResultResponse( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
status="generating", | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
self[request.query_id] = self.Resource( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
query_id=request.query_id, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
status="generating" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
generation_result = await self._pipelines["sql_regeneration"].run( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
description=sql_regeneration_request.description, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
steps=sql_regeneration_request.steps, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
project_id=sql_regeneration_request.project_id, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
description=request.description, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
steps=request.steps, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
project_id=request.project_id, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
sql_regeneration_result = generation_result[ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
"sql_regeneration_post_process" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
]["results"] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
if not sql_regeneration_result["steps"]: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
self._sql_regeneration_results[ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
query_id | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
] = SQLRegenerationResultResponse( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
self[request.query_id] = self.Resource( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
query_id=request.query_id, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
status="failed", | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
error=SQLRegenerationResultResponse.SQLRegenerationError( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
error=self.Resource.Error( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
code="NO_RELEVANT_SQL", | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
message="SQL is not executable", | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
message="SQL is not executable" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
else: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
self._sql_regeneration_results[ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
query_id | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
] = SQLRegenerationResultResponse( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
self[request.query_id] = self.Resource( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
query_id=request.query_id, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
status="finished", | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
response=sql_regeneration_result, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
response=self.Resource.ResponseDetails( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
description=sql_regeneration_result["description"], | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
steps=sql_regeneration_result["steps"] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
except Exception as e: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
logger.exception(f"sql regeneration pipeline - OTHERS: {e}") | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
self._sql_regeneration_results[ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
sql_regeneration_request.query_id | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
] = SQLRegenerationResultResponse( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
self[request.query_id] = self.Resource( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
query_id=request.query_id, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
status="failed", | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
error=SQLRegenerationResultResponse.SQLRegenerationError( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
error=self.Resource.Error( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
code="OTHERS", | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
message=str(e), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
message=str(e) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
def get_sql_regeneration_result( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
self, sql_regeneration_result_request: SQLRegenerationResultRequest | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
) -> SQLRegenerationResultResponse: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
if ( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
sql_regeneration_result_request.query_id | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
not in self._sql_regeneration_results | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
return SQLRegenerationResultResponse( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
return self[request.query_id] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
def __getitem__(self, query_id: str) -> Resource: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
response = self._cache.get(query_id) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
if response is None: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
message = f"SQL Regeneration Resource with ID '{query_id}' not found." | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
logger.exception(message) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
return self.Resource( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
query_id=query_id, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
status="failed", | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
error=SQLRegenerationResultResponse.SQLRegenerationError( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
error=self.Resource.Error( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
code="OTHERS", | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
message=f"{sql_regeneration_result_request.query_id} is not found", | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
message=message | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
return response | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
return self._sql_regeneration_results[sql_regeneration_result_request.query_id] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
def __setitem__(self, query_id: str, value: Resource): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
self._cache[query_id] = value |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
gently ping to mention here has the same issue as the previous comment.