Skip to content

chore(wren-ai-service): Refactor Services to use Resource Classes #848

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
442 changes: 180 additions & 262 deletions wren-ai-service/src/web/v1/services/ask.py

Large diffs are not rendered by default.

218 changes: 95 additions & 123 deletions wren-ai-service/src/web/v1/services/ask_details.py
Original file line number Diff line number Diff line change
@@ -18,164 +18,136 @@ class SQLBreakdown(BaseModel):
cte_name: str


# POST /v1/ask-details
class AskDetailsConfigurations(BaseModel):
language: str = "English"
class AskDetails:
class Input(BaseModel):
class Configurations(BaseModel):
language: str = "English"

id: str
query: str
sql: str
summary: str
mdl_hash: Optional[str] = None
thread_id: Optional[str] = None
project_id: Optional[str] = None
user_id: Optional[str] = None
configurations: Configurations = Configurations(language="English")

class Resource(BaseModel):
class Error(BaseModel):
code: Literal["NO_RELEVANT_SQL", "OTHERS", "RESOURCE_NOT_FOUND"]
message: str

class Details(BaseModel):
description: str
steps: List[SQLBreakdown]

id: str
status: Literal["understanding", "searching", "generating", "finished", "failed"]
response: Optional[Details] = None
error: Optional[Error] = None


class AskDetailsRequest(BaseModel):
_query_id: str | None = None
query: str
sql: str
summary: str
mdl_hash: Optional[str] = None
thread_id: Optional[str] = None
project_id: Optional[str] = None
user_id: Optional[str] = None
configurations: AskDetailsConfigurations = AskDetailsConfigurations(
language="English"
)

@property
def query_id(self) -> str:
return self._query_id

@query_id.setter
def query_id(self, query_id: str):
self._query_id = query_id


class AskDetailsResponse(BaseModel):
query_id: str


# GET /v1/ask-details/{query_id}/result
class AskDetailsResultRequest(BaseModel):
query_id: str


class AskDetailsResultResponse(BaseModel):
class AskDetailsResponseDetails(BaseModel):
description: str
steps: List[SQLBreakdown]

class AskDetailsError(BaseModel):
code: Literal["NO_RELEVANT_SQL", "OTHERS"]
message: str

status: Literal["understanding", "searching", "generating", "finished", "failed"]
response: Optional[AskDetailsResponseDetails] = None
error: Optional[AskDetailsError] = None


class AskDetailsService:
def __init__(
self,
pipelines: Dict[str, Pipeline],
maxsize: int = 1_000_000,
ttl: int = 120,
):
self._pipelines = pipelines
self._ask_details_results: Dict[str, AskDetailsResultResponse] = TTLCache(
self._cache: Dict[str, AskDetails.Resource] = TTLCache(
maxsize=maxsize, ttl=ttl
)

@async_timer
@observe(name="Ask Details(Breakdown SQL)")
@trace_metadata
async def ask_details(
def _handle_exception(
self,
ask_details_request: AskDetailsRequest,
**kwargs,
input: Input,
error_message: str,
code: str = "OTHERS",
):
results = {
"ask_details_result": {},
"metadata": {
"error_type": "",
"error_message": "",
},
}

try:
# ask details status can be understanding, searching, generating, finished, stopped
# we will need to handle business logic for each status
query_id = ask_details_request.query_id

self._ask_details_results[query_id] = AskDetailsResultResponse(
status="understanding",
)
self._cache[input.id] = self.Resource(
id=input.id,
status="failed",
error=self.Resource.Error(code=code, message=error_message),
)
logger.error(error_message)

self._ask_details_results[query_id] = AskDetailsResultResponse(
status="searching",
)
@async_timer
@observe(name="Ask Details (Breakdown SQL)")
@trace_metadata
async def generate(self, request: Input, **kwargs) -> Dict:
logger.info("Ask Details pipeline is running...")

self._ask_details_results[query_id] = AskDetailsResultResponse(
status="generating",
)
try:
# Update status through the pipeline stages
for status in ["understanding", "searching", "generating"]:
self._cache[request.id] = self.Resource(
id=request.id,
status=status,
)

generation_result = await self._pipelines["sql_breakdown"].run(
query=ask_details_request.query,
sql=ask_details_request.sql,
project_id=ask_details_request.project_id,
language=ask_details_request.configurations.language,
query=request.query,
sql=request.sql,
project_id=request.project_id,
language=request.configurations.language,
)

ask_details_result = generation_result["post_process"]["results"]

if not ask_details_result["steps"]:
quoted_sql, no_error = add_quotes(ask_details_request.sql)
quoted_sql, no_error = add_quotes(request.sql)
ask_details_result["steps"] = [
{
"sql": quoted_sql if no_error else ask_details_request.sql,
"summary": ask_details_request.summary,
"sql": quoted_sql if no_error else request.sql,
"summary": request.summary,
"cte_name": "",
}
]
results["metadata"]["error_type"] = "SQL_BREAKDOWN_FAILED"

self._ask_details_results[query_id] = AskDetailsResultResponse(
self._cache[request.id] = self.Resource(
id=request.id,
status="finished",
response=AskDetailsResultResponse.AskDetailsResponseDetails(
**ask_details_result
),
response=self.Resource.Details(**ask_details_result),
)

results["ask_details_result"] = ask_details_result
return {
"ask_details_result": ask_details_result,
"metadata": {
"error_type": "SQL_BREAKDOWN_FAILED" if not ask_details_result["steps"] else "",
"error_message": "",
},
}

return results
except Exception as e:
logger.exception(f"ask-details pipeline - OTHERS: {e}")

self._ask_details_results[
ask_details_request.query_id
] = AskDetailsResultResponse(
status="failed",
error=AskDetailsResultResponse.AskDetailsError(
code="OTHERS",
message=str(e),
),
self._handle_exception(
request,
f"An error occurred during Ask Details generation: {str(e)}",
)

results["metadata"]["error_type"] = "OTHERS"
results["metadata"]["error_message"] = str(e)
return results

def get_ask_details_result(
self,
ask_details_result_request: AskDetailsResultRequest,
) -> AskDetailsResultResponse:
if (
result := self._ask_details_results.get(ask_details_result_request.query_id)
) is None:
logger.exception(
f"ask-details pipeline - OTHERS: {ask_details_result_request.query_id} is not found"
)
return AskDetailsResultResponse(

return {
"ask_details_result": {},
"metadata": {
"error_type": "OTHERS",
"error_message": str(e),
},
}

def __getitem__(self, id: str) -> Resource:
response = self._cache.get(id)

if response is None:
message = f"Ask Details Resource with ID '{id}' not found."
logger.exception(message)
return self.Resource(
id=id,
status="failed",
error=AskDetailsResultResponse.AskDetailsError(
code="OTHERS",
message=f"{ask_details_result_request.query_id} is not found",
error=self.Resource.Error(
code="RESOURCE_NOT_FOUND",
message=message,
),
)

return result
return response

def __setitem__(self, id: str, value: Resource):
self._cache[id] = value
180 changes: 100 additions & 80 deletions wren-ai-service/src/web/v1/services/semantics_preparation.py
Original file line number Diff line number Diff line change
@@ -3,120 +3,140 @@

from cachetools import TTLCache
from langfuse.decorators import observe
from pydantic import AliasChoices, BaseModel, Field
from pydantic import BaseModel, Field

from src.core.pipeline import BasicPipeline
from src.utils import async_timer, trace_metadata

logger = logging.getLogger("wren-ai-service")


# POST /v1/semantics-preparations
class SemanticsPreparationRequest(BaseModel):
mdl: str
# don't recommend to use id as a field name, but it's used in the API spec
# so we need to support as a choice, and will remove it in the future
mdl_hash: str = Field(validation_alias=AliasChoices("mdl_hash", "id"))
project_id: Optional[str] = None
user_id: Optional[str] = None
class SemanticsPreparation:
"""Service for preparing and managing semantics operations."""

class Input(BaseModel):
"""Input model for semantics preparation requests."""
mdl: str
mdl_hash: str = Field(description="Unique identifier for the MDL")
project_id: Optional[str] = None
user_id: Optional[str] = None

class SemanticsPreparationResponse(BaseModel):
# don't recommend to use id as a field name, but it's used in the API spec
# so we need to support as a choice, and will remove it in the future
mdl_hash: str = Field(serialization_alias="id")
class Resource(BaseModel):
"""Resource model representing the state and result of semantics preparation."""
class Error(BaseModel):
"""Error information when preparation fails."""
code: Literal["OTHERS", "NOT_FOUND", "INDEXING_FAILED"]
message: str

mdl_hash: str
status: Literal["indexing", "finished", "failed"] = "indexing"
error: Optional[Error] = None
metadata: Optional[dict] = Field(default_factory=dict)

# GET /v1/semantics-preparations/{mdl_hash}/status
class SemanticsPreparationStatusRequest(BaseModel):
# don't recommend to use id as a field name, but it's used in the API spec
# so we need to support as a choice, and will remove it in the future
mdl_hash: str = Field(validation_alias=AliasChoices("mdl_hash", "id"))


class SemanticsPreparationStatusResponse(BaseModel):
class SemanticsPreparationError(BaseModel):
code: Literal["OTHERS"]
message: str

status: Literal["indexing", "finished", "failed"]
error: Optional[SemanticsPreparationError] = None


class SemanticsPreparationService:
def __init__(
self,
pipelines: Dict[str, BasicPipeline],
maxsize: int = 1_000_000,
ttl: int = 120,
):
"""Initialize the SemanticsPreparation service.
Args:
pipelines: Dictionary of pipeline implementations
maxsize: Maximum size of the cache
ttl: Time-to-live for cache entries in seconds
"""
self._pipelines = pipelines
self._prepare_semantics_statuses: Dict[
str, SemanticsPreparationStatusResponse
] = TTLCache(maxsize=maxsize, ttl=ttl)
self._cache: Dict[str, SemanticsPreparation.Resource] = TTLCache(
maxsize=maxsize, ttl=ttl
)

def _handle_exception(
self,
input: Input,
error_message: str,
code: str = "OTHERS",
) -> Resource:
"""Handle exceptions by creating and caching error resources.
Args:
input: The input that caused the exception
error_message: Description of the error
code: Error code identifier
"""
resource = self.Resource(
mdl_hash=input.mdl_hash,
status="failed",
error=self.Resource.Error(code=code, message=error_message),
metadata={"error_type": code, "error_message": error_message},
)
self._cache[input.mdl_hash] = resource
logger.error(f"Semantics preparation failed: {error_message}")
return resource

@async_timer
@observe(name="Prepare Semantics")
@trace_metadata
async def prepare_semantics(
self,
prepare_semantics_request: SemanticsPreparationRequest,
**kwargs,
):
results = {
"metadata": {
"error_type": "",
"error_message": "",
},
}
async def prepare(self, input: Input, **kwargs) -> Resource:
"""Prepare semantics based on the provided input.
Args:
input: The preparation request parameters
**kwargs: Additional keyword arguments for pipeline execution
Returns:
Resource object containing the preparation status and results
"""
logger.info(f"Starting semantics preparation for MDL hash: {input.mdl_hash}")

# Initialize resource in cache
self._cache[input.mdl_hash] = self.Resource(
mdl_hash=input.mdl_hash,
status="indexing"
)

try:
logger.info(f"MDL: {prepare_semantics_request.mdl}")
logger.info(f"Processing MDL: {input.mdl}")
await self._pipelines["indexing"].run(
mdl_str=prepare_semantics_request.mdl,
id=prepare_semantics_request.project_id,
mdl_str=input.mdl,
id=input.project_id,
)

self._prepare_semantics_statuses[
prepare_semantics_request.mdl_hash
] = SemanticsPreparationStatusResponse(
# Update cache with success result
self._cache[input.mdl_hash] = self.Resource(
mdl_hash=input.mdl_hash,
status="finished",
metadata={"completion_time": kwargs.get("timestamp")}
)
except Exception as e:
logger.exception(f"Failed to prepare semantics: {e}")

self._prepare_semantics_statuses[
prepare_semantics_request.mdl_hash
] = SemanticsPreparationStatusResponse(
status="failed",
error=SemanticsPreparationStatusResponse.SemanticsPreparationError(
code="OTHERS",
message=f"Failed to prepare semantics: {e}",
),
except Exception as e:
return self._handle_exception(
input,
f"Failed to prepare semantics: {str(e)}",
code="INDEXING_FAILED"
)

results["metadata"]["error_type"] = "INDEXING_FAILED"
results["metadata"]["error_message"] = str(e)
return self._cache[input.mdl_hash]

return results
def get_status(self, mdl_hash: str) -> Resource:
"""Retrieve the current status of a semantics preparation request.
def get_prepare_semantics_status(
self, prepare_semantics_status_request: SemanticsPreparationStatusRequest
) -> SemanticsPreparationStatusResponse:
if (
result := self._prepare_semantics_statuses.get(
prepare_semantics_status_request.mdl_hash
)
) is None:
logger.exception(
f"id is not found for SemanticsPreparation: {prepare_semantics_status_request.mdl_hash}"
)
return SemanticsPreparationStatusResponse(
Args:
mdl_hash: The identifier for the preparation request
Returns:
Resource object containing the current status
"""
if (resource := self._cache.get(mdl_hash)) is None:
message = f"No preparation found for MDL hash: {mdl_hash}"
logger.error(message)
return self.Resource(
mdl_hash=mdl_hash,
status="failed",
error=SemanticsPreparationStatusResponse.SemanticsPreparationError(
code="OTHERS",
message=f"{prepare_semantics_status_request.mdl_hash} is not found",
),
error=self.Resource.Error(
code="NOT_FOUND",
message=message
)
)

return result
return resource
157 changes: 54 additions & 103 deletions wren-ai-service/src/web/v1/services/sql_answer.py
Original file line number Diff line number Diff line change
@@ -1,156 +1,107 @@
import logging
from typing import Dict, Literal, Optional

from cachetools import TTLCache
from langfuse.decorators import observe
from pydantic import BaseModel

from src.core.pipeline import BasicPipeline
from src.utils import async_timer, trace_metadata

logger = logging.getLogger("wren-ai-service")

class SQLAnswerRequest(BaseModel):
class Input(BaseModel):
query: str
sql: str
sql_summary: str
thread_id: Optional[str] = None
user_id: Optional[str] = None

# POST /v1/sql-answers
class SqlAnswerRequest(BaseModel):
_query_id: str | None = None
query: str
sql: str
sql_summary: str
thread_id: Optional[str] = None
user_id: Optional[str] = None

@property
def query_id(self) -> str:
return self._query_id

@query_id.setter
def query_id(self, query_id: str):
self._query_id = query_id


class SqlAnswerResponse(BaseModel):
query_id: str


# GET /v1/sql-answers/{query_id}/result
class SqlAnswerResultRequest(BaseModel):
query_id: str


class SqlAnswerResultResponse(BaseModel):
class SqlAnswerError(BaseModel):
code: Literal["OTHERS"]
message: str

status: Literal["understanding", "processing", "finished", "failed"]
response: Optional[str] = None
error: Optional[SqlAnswerError] = None
class Resource(BaseModel):
class Error(BaseModel):
code: Literal["OTHERS"]
message: str

query_id: str
status: Optional[Literal["understanding", "processing", "finished", "failed"]] = None
response: Optional[str] = None
error: Optional[Error] = None

class SqlAnswerService:
def __init__(
self,
pipelines: Dict[str, BasicPipeline],
maxsize: int = 1_000_000,
ttl: int = 120,
):
self._pipelines = pipelines
self._sql_answer_results: Dict[str, SqlAnswerResultResponse] = TTLCache(
self._cache: Dict[str, SQLAnswerRequest.Resource] = TTLCache(
maxsize=maxsize, ttl=ttl
)

@async_timer
@observe(name="SQL Answer")
@trace_metadata
async def sql_answer(
self,
sql_answer_request: SqlAnswerRequest,
**kwargs,
):
results = {
"sql_answer_result": {},
"metadata": {
"error": {
"type": "",
"message": "",
}
},
}

async def generate(self, request: Input, **kwargs) -> Resource:
try:
query_id = sql_answer_request.query_id

self._sql_answer_results[query_id] = SqlAnswerResultResponse(
status="understanding",
self[request.query_id] = self.Resource(
query_id=request.query_id,
status="understanding"
)

self._sql_answer_results[query_id] = SqlAnswerResultResponse(
status="processing",
self[request.query_id] = self.Resource(
query_id=request.query_id,
status="processing"
)

data = await self._pipelines["sql_answer"].run(
query=sql_answer_request.query,
sql=sql_answer_request.sql,
sql_summary=sql_answer_request.sql_summary,
project_id=sql_answer_request.thread_id,
query=request.query,
sql=request.sql,
sql_summary=request.sql_summary,
project_id=request.thread_id,
)
api_results = data["post_process"]["results"]
if answer := api_results["answer"]:
self._sql_answer_results[query_id] = SqlAnswerResultResponse(
self[request.query_id] = self.Resource(
query_id=request.query_id,
status="finished",
response=answer,
response=answer
)
else:
self._sql_answer_results[query_id] = SqlAnswerResultResponse(
self[request.query_id] = self.Resource(
query_id=request.query_id,
status="failed",
error=SqlAnswerResultResponse.SqlAnswerError(
error=self.Resource.Error(
code="OTHERS",
message=api_results["error"],
),
message=api_results["error"]
)
)

results["metadata"]["error_type"] = "OTHERS"
results["metadata"]["error_message"] = api_results["error"]

results["sql_answer_result"] = {
"answer": api_results["answer"],
"reasoning": api_results["reasoning"],
}
return results
except Exception as e:
logger.exception(f"sql answer pipeline - OTHERS: {e}")

self._sql_answer_results[
sql_answer_request.query_id
] = SqlAnswerResultResponse(
self[request.query_id] = self.Resource(
query_id=request.query_id,
status="failed",
error=SqlAnswerResultResponse.SqlAnswerError(
error=self.Resource.Error(
code="OTHERS",
message=str(e),
),
message=str(e)
)
)

results["metadata"]["error_type"] = "OTHERS"
results["metadata"]["error_message"] = str(e)
return results
return self[request.query_id]

def get_sql_answer_result(
self,
sql_answer_result_request: SqlAnswerResultRequest,
) -> SqlAnswerResultResponse:
if (
result := self._sql_answer_results.get(sql_answer_result_request.query_id)
) is None:
logger.exception(
f"sql answer pipeline - OTHERS: {sql_answer_result_request.query_id} is not found"
)
return SqlAnswerResultResponse(
def __getitem__(self, query_id: str) -> Resource:
response = self._cache.get(query_id)
if response is None:
message = f"SQL Answer Resource with ID '{query_id}' not found."
logger.exception(message)
return self.Resource(
query_id=query_id,
status="failed",
error=SqlAnswerResultResponse.SqlAnswerError(
error=self.Resource.Error(
code="OTHERS",
message=f"{sql_answer_result_request.query_id} is not found",
),
message=message
)
)
return response

return result
def __setitem__(self, query_id: str, value: Resource):
self._cache[query_id] = value
257 changes: 94 additions & 163 deletions wren-ai-service/src/web/v1/services/sql_expansion.py
Original file line number Diff line number Diff line change
@@ -1,102 +1,59 @@
import logging
from typing import Dict, List, Literal, Optional

from cachetools import TTLCache
from langfuse.decorators import observe
from pydantic import BaseModel

from src.core.pipeline import BasicPipeline
from src.utils import async_timer, remove_sql_summary_duplicates, trace_metadata
from src.web.v1.services.ask import AskError, AskHistory
from src.web.v1.services.ask_details import SQLBreakdown

logger = logging.getLogger("wren-ai-service")

class SQLExpansionRequest(BaseModel):
class Input(BaseModel):
query: str
history: AskHistory
project_id: Optional[str] = None
mdl_hash: Optional[str] = None
thread_id: Optional[str] = None
user_id: Optional[str] = None
configurations: Dict[str, str] = {"language": "English"}
Comment on lines +21 to +23
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

gently ping to mention here has the same issue as the previous comment.


class Resource(BaseModel):
class Result(BaseModel):
description: str
steps: List[SQLBreakdown]

class Error(BaseModel):
code: Literal["NO_RELEVANT_DATA", "NO_RELEVANT_SQL", "OTHERS"]
message: str

query_id: str
status: Optional[Literal[
"understanding", "searching", "generating", "finished", "failed", "stopped"
]] = None
response: Optional[Result] = None
error: Optional[Error] = None

# POST /v1/sql-expansions
class SqlExpansionConfigurations(BaseModel):
language: str = "English"


class SqlExpansionRequest(BaseModel):
_query_id: str | None = None
query: str
history: AskHistory
# for identifying which collection to access from vectordb
project_id: Optional[str] = None
mdl_hash: Optional[str] = None
thread_id: Optional[str] = None
user_id: Optional[str] = None
configurations: SqlExpansionConfigurations = SqlExpansionConfigurations(
language="English"
)

@property
def query_id(self) -> str:
return self._query_id

@query_id.setter
def query_id(self, query_id: str):
self._query_id = query_id


class SqlExpansionResponse(BaseModel):
query_id: str


# PATCH /v1/sql-expansions/{query_id}
class StopSqlExpansionRequest(BaseModel):
_query_id: str | None = None
status: Literal["stopped"]

@property
def query_id(self) -> str:
return self._query_id

@query_id.setter
def query_id(self, query_id: str):
self._query_id = query_id


class StopSqlExpansionResponse(BaseModel):
query_id: str


# GET /v1/sql-expansions/{query_id}/result
class SqlExpansionResultRequest(BaseModel):
query_id: str


class SqlExpansionResultResponse(BaseModel):
class SqlExpansionResult(BaseModel):
description: str
steps: List[SQLBreakdown]

status: Literal[
"understanding", "searching", "generating", "finished", "failed", "stopped"
]
response: Optional[SqlExpansionResult] = None
error: Optional[AskError] = None


class SqlExpansionService:
def __init__(
self,
pipelines: Dict[str, BasicPipeline],
maxsize: int = 1_000_000,
ttl: int = 120,
):
self._pipelines = pipelines
self._sql_expansion_results: Dict[str, SqlExpansionResultResponse] = TTLCache(
self._cache: Dict[str, SQLExpansionRequest.Resource] = TTLCache(
maxsize=maxsize, ttl=ttl
)

def _is_stopped(self, query_id: str):
if (
result := self._sql_expansion_results.get(query_id)
result := self._cache.get(query_id)
) is not None and result.status == "stopped":
return True

return False

def _get_failed_dry_run_results(self, invalid_generation_results: list[dict]):
@@ -107,69 +64,58 @@ def _get_failed_dry_run_results(self, invalid_generation_results: list[dict]):
@async_timer
@observe(name="SQL Expansion")
@trace_metadata
async def sql_expansion(
self,
sql_expansion_request: SqlExpansionRequest,
**kwargs,
):
results = {
"sql_expansion_result": {},
"metadata": {
"error_type": "",
"error_message": "",
},
}

async def generate(self, request: Input, **kwargs) -> Resource:
try:
query_id = sql_expansion_request.query_id

if not self._is_stopped(query_id):
self._sql_expansion_results[query_id] = SqlExpansionResultResponse(
status="understanding",
if not self._is_stopped(request.query_id):
self[request.query_id] = self.Resource(
query_id=request.query_id,
status="understanding"
)

if not self._is_stopped(query_id):
self._sql_expansion_results[query_id] = SqlExpansionResultResponse(
status="searching",
if not self._is_stopped(request.query_id):
self[request.query_id] = self.Resource(
query_id=request.query_id,
status="searching"
)

query_for_retrieval = (
sql_expansion_request.history.summary
request.history.summary
+ " "
+ sql_expansion_request.query
+ request.query
)
retrieval_result = await self._pipelines["retrieval"].run(
query=query_for_retrieval,
id=sql_expansion_request.project_id,
id=request.project_id,
)
documents = retrieval_result.get("construct_retrieval_results", [])

if not documents:
logger.exception(
f"sql expansion pipeline - NO_RELEVANT_DATA: {sql_expansion_request.query}"
f"sql expansion pipeline - NO_RELEVANT_DATA: {request.query}"
)
self._sql_expansion_results[query_id] = SqlExpansionResultResponse(
self[request.query_id] = self.Resource(
query_id=request.query_id,
status="failed",
error=AskError(
error=self.Resource.Error(
code="NO_RELEVANT_DATA",
message="No relevant data",
),
message="No relevant data"
)
)
results["metadata"]["error_type"] = "NO_RELEVANT_DATA"
return results
return self[request.query_id]

if not self._is_stopped(query_id):
self._sql_expansion_results[query_id] = SqlExpansionResultResponse(
status="generating",
if not self._is_stopped(request.query_id):
self[request.query_id] = self.Resource(
query_id=request.query_id,
status="generating"
)

sql_expansion_generation_results = await self._pipelines[
"sql_expansion"
].run(
query=sql_expansion_request.query,
query=request.query,
contexts=documents,
history=sql_expansion_request.history,
project_id=sql_expansion_request.project_id,
history=request.history,
project_id=request.project_id,
)

valid_generation_results = []
@@ -188,7 +134,7 @@ async def sql_expansion(
].run(
contexts=documents,
invalid_generation_results=failed_dry_run_results,
project_id=sql_expansion_request.project_id,
project_id=request.project_id,
)
valid_generation_results += sql_correction_results["post_process"][
"valid_generation_results"
@@ -197,34 +143,33 @@ async def sql_expansion(
valid_sql_summary_results = []
if valid_generation_results:
sql_summary_results = await self._pipelines["sql_summary"].run(
query=sql_expansion_request.query,
query=request.query,
sqls=valid_generation_results,
language=sql_expansion_request.configurations.language,
language=request.configurations["language"],
)
valid_sql_summary_results = sql_summary_results["post_process"][
"sql_summary_results"
]
# remove duplicates of valid_sql_summary_results, which consists of a sql and a summary
valid_sql_summary_results = remove_sql_summary_duplicates(
valid_sql_summary_results
)

if not valid_sql_summary_results:
logger.exception(
f"sql expansion pipeline - NO_RELEVANT_SQL: {sql_expansion_request.query}"
f"sql expansion pipeline - NO_RELEVANT_SQL: {request.query}"
)
self._sql_expansion_results[query_id] = SqlExpansionResultResponse(
self[request.query_id] = self.Resource(
query_id=request.query_id,
status="failed",
error=AskError(
error=self.Resource.Error(
code="NO_RELEVANT_SQL",
message="No relevant SQL",
),
message="No relevant SQL"
)
)
results["metadata"]["error_type"] = "NO_RELEVANT_SQL"
return results
return self[request.query_id]

api_results = SqlExpansionResultResponse.SqlExpansionResult(
description=sql_expansion_request.history.summary,
api_results = self.Resource.Result(
description=request.history.summary,
steps=[
{
"sql": valid_generation_results[0]["sql"],
@@ -234,56 +179,42 @@ async def sql_expansion(
],
)

self._sql_expansion_results[query_id] = SqlExpansionResultResponse(
self[request.query_id] = self.Resource(
query_id=request.query_id,
status="finished",
response=api_results,
response=api_results
)

results["sql_expansion_result"] = api_results
return results
return self[request.query_id]
except Exception as e:
logger.exception(f"sql expansion pipeline - OTHERS: {e}")

self._sql_expansion_results[
sql_expansion_request.query_id
] = SqlExpansionResultResponse(
self[request.query_id] = self.Resource(
query_id=request.query_id,
status="failed",
error=AskError(
error=self.Resource.Error(
code="OTHERS",
message=str(e),
),
)

results["metadata"]["error_type"] = "OTHERS"
results["metadata"]["error_message"] = str(e)
return results

def stop_sql_expansion(
self,
stop_sql_expansion_request: StopSqlExpansionRequest,
):
self._sql_expansion_results[
stop_sql_expansion_request.query_id
] = SqlExpansionResultResponse(status="stopped")

def get_sql_expansion_result(
self,
sql_expansion_result_request: SqlExpansionResultRequest,
) -> SqlExpansionResultResponse:
if (
result := self._sql_expansion_results.get(
sql_expansion_result_request.query_id
)
) is None:
logger.exception(
f"sql-expansion pipeline - OTHERS: {sql_expansion_result_request.query_id} is not found"
message=str(e)
)
)
return SqlExpansionResultResponse(
return self[request.query_id]

def stop(self, request: Dict[str, str]) -> None:
self[request["query_id"]] = self.Resource(query_id=request["query_id"], status="stopped")

def __getitem__(self, query_id: str) -> Resource:
response = self._cache.get(query_id)
if response is None:
message = f"SQL Expansion Resource with ID '{query_id}' not found."
logger.exception(message)
return self.Resource(
query_id=query_id,
status="failed",
error=AskError(
error=self.Resource.Error(
code="OTHERS",
message=f"{sql_expansion_result_request.query_id} is not found",
),
message=message
)
)
return response

return result
def __setitem__(self, query_id: str, value: Resource):
self._cache[query_id] = value
159 changes: 67 additions & 92 deletions wren-ai-service/src/web/v1/services/sql_explanation.py
Original file line number Diff line number Diff line change
@@ -1,104 +1,71 @@
import asyncio
import logging
from typing import Dict, List, Literal, Optional

from cachetools import TTLCache
from haystack import Pipeline
from pydantic import BaseModel

from src.utils import async_timer

logger = logging.getLogger("wren-ai-service")


# POST /v1/sql-explanations
class StepWithAnalysisResult(BaseModel):
sql: str
summary: str
sql_analysis_results: List[Dict]


class SQLExplanationRequest(BaseModel):
_query_id: str | None = None
question: str
steps_with_analysis_results: List[StepWithAnalysisResult]
mdl_hash: Optional[str] = None
thread_id: Optional[str] = None
project_id: Optional[str] = None
user_id: Optional[str] = None

@property
def query_id(self) -> str:
return self._query_id

@query_id.setter
def query_id(self, query_id: str):
self._query_id = query_id


class SQLExplanationResponse(BaseModel):
query_id: str


# GET /v1/sql-explanations/{query_id}/result
class SQLExplanationResultRequest(BaseModel):
query_id: str


class SQLExplanationResultResponse(BaseModel):
class SQLExplanationResultError(BaseModel):
code: Literal["OTHERS"]
message: str

status: Literal["understanding", "generating", "finished", "failed"]
response: Optional[List[List[Dict]]] = None
error: Optional[SQLExplanationResultError] = None

class Input(BaseModel):
question: str
steps_with_analysis_results: List[
Dict[str, object]
Comment on lines +15 to +19
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

gently ping to mention here has the same issue as the previous comment.

]
mdl_hash: Optional[str] = None
thread_id: Optional[str] = None
project_id: Optional[str] = None
user_id: Optional[str] = None

class Resource(BaseModel):
class Error(BaseModel):
code: Literal["OTHERS"]
message: str

query_id: str
status: Optional[Literal["understanding", "generating", "finished", "failed"]] = None
response: Optional[List[List[Dict]]] = None
error: Optional[Error] = None

class SQLExplanationService:
def __init__(
self,
pipelines: Dict[str, Pipeline],
maxsize: int = 1_000_000,
ttl: int = 120,
):
self._pipelines = pipelines
self._sql_explanation_results: Dict[
str, SQLExplanationResultResponse
] = TTLCache(maxsize=maxsize, ttl=ttl)
self._cache: Dict[str, SQLExplanationRequest.Resource] = TTLCache(
maxsize=maxsize, ttl=ttl
)

@async_timer
async def sql_explanation(
self,
sql_explanation_request: SQLExplanationRequest,
**kwargs,
):
async def generate(self, request: Input, **kwargs) -> Resource:
try:
query_id = sql_explanation_request.query_id

self._sql_explanation_results[query_id] = SQLExplanationResultResponse(
status="understanding",
self[request.query_id] = self.Resource(
query_id=request.query_id,
status="understanding"
)

self._sql_explanation_results[query_id] = SQLExplanationResultResponse(
status="generating",
self[request.query_id] = self.Resource(
query_id=request.query_id,
status="generating"
)

async def _task(
question: str,
step_with_analysis_results: StepWithAnalysisResult,
):
async def _task(question: str, step_with_analysis_results: Dict):
return await self._pipelines["sql_explanation"].run(
question=question,
step_with_analysis_results=step_with_analysis_results,
step_with_analysis_results=step_with_analysis_results
)

tasks = [
_task(
sql_explanation_request.question,
step_with_analysis_results,
)
for step_with_analysis_results in sql_explanation_request.steps_with_analysis_results
_task(request.question, step)
for step in request.steps_with_analysis_results
]
generation_results = await asyncio.gather(*tasks)

@@ -108,39 +75,47 @@ async def _task(
]

if sql_explanation_results:
self._sql_explanation_results[query_id] = SQLExplanationResultResponse(
self[request.query_id] = self.Resource(
query_id=request.query_id,
status="finished",
response=sql_explanation_results,
response=sql_explanation_results
)
else:
self._sql_explanation_results[query_id] = SQLExplanationResultResponse(
self[request.query_id] = self.Resource(
query_id=request.query_id,
status="failed",
error=SQLExplanationResultResponse.SQLExplanationResultError(
error=self.Resource.Error(
code="OTHERS",
message="No SQL explanation is found",
),
message="No SQL explanation is found"
)
)
except Exception as e:
logger.exception(
f"sql explanation pipeline - Failed to provide SQL explanation: {e}"
)
self._sql_explanation_results[
sql_explanation_request.query_id
] = SQLExplanationResultResponse(
logger.exception(f"sql explanation pipeline - Failed to provide SQL explanation: {e}")
self[request.query_id] = self.Resource(
query_id=request.query_id,
status="failed",
error=SQLExplanationResultResponse.SQLExplanationResultError(
error=self.Resource.Error(
code="OTHERS",
message=str(e),
),
message=str(e)
)
)

def get_sql_explanation_result(
self, sql_explanation_result_request: SQLExplanationResultRequest
) -> SQLExplanationResultResponse:
if sql_explanation_result_request.query_id not in self._sql_explanation_results:
return SQLExplanationResultResponse(
return self[request.query_id]

def __getitem__(self, query_id: str) -> Resource:
response = self._cache.get(query_id)
if response is None:
message = f"SQL Explanation Resource with ID '{query_id}' not found."
logger.exception(message)
return self.Resource(
query_id=query_id,
status="failed",
error=f"{sql_explanation_result_request.query_id} is not found",
error=self.Resource.Error(
code="OTHERS",
message=message
)
)
return response

return self._sql_explanation_results[sql_explanation_result_request.query_id]
def __setitem__(self, query_id: str, value: Resource):
self._cache[query_id] = value
194 changes: 82 additions & 112 deletions wren-ai-service/src/web/v1/services/sql_regeneration.py
Original file line number Diff line number Diff line change
@@ -1,163 +1,133 @@
import logging
from typing import Dict, List, Literal, Optional

from cachetools import TTLCache
from haystack import Pipeline
from pydantic import BaseModel

from src.utils import async_timer
from src.web.v1.services.ask_details import SQLBreakdown

logger = logging.getLogger("wren-ai-service")


# POST /v1/sql-regenerations
class DecisionPoint(BaseModel):
type: Literal["filter", "selectItems", "relation", "groupByKeys", "sortings"]
value: str


class CorrectionPoint(BaseModel):
type: Literal[
"sql_expression", "nl_expression"
] # nl_expression is natural language expression
value: str


class UserCorrection(BaseModel):
before: DecisionPoint
after: CorrectionPoint


class SQLExplanationWithUserCorrections(BaseModel):
summary: str
sql: str
cte_name: str
corrections: List[UserCorrection]


class SQLRegenerationRequest(BaseModel):
_query_id: str | None = None
description: str
steps: List[SQLExplanationWithUserCorrections]
mdl_hash: Optional[str] = None
thread_id: Optional[str] = None
project_id: Optional[str] = None
user_id: Optional[str] = None

@property
def query_id(self) -> str:
return self._query_id

@query_id.setter
def query_id(self, query_id: str):
self._query_id = query_id


class SQLRegenerationResponse(BaseModel):
query_id: str


# GET /v1/sql-regenerations/{query_id}/result
class SQLRegenerationResultRequest(BaseModel):
query_id: str


class SQLRegenerationResultResponse(BaseModel):
class SQLRegenerationResponseDetails(BaseModel):
class Input(BaseModel):
    """Request payload for the SQL regeneration pipeline."""

    # NOTE(review): generate() reads `request.query_id`, so it must exist on
    # the input model — confirm the router/service populates it.
    query_id: str
    description: str
    # Each step is shaped as:
    #   {"summary": str, "sql": str, "cte_name": str, "corrections": List[Dict]}
    # where a correction holds a "before" decision point
    # ({"type": filter|selectItems|relation|groupByKeys|sortings, "value": str})
    # and an "after" user correction
    # ({"type": sql_expression|nl_expression, "value": str}).
    # typing.Dict cannot declare per-key value types (the original
    # slice-style `Dict["summary": str, ...]` form is invalid), so a
    # plain Dict is used per the review suggestion.
    steps: List[Dict]
    mdl_hash: Optional[str] = None
    thread_id: Optional[str] = None
    project_id: Optional[str] = None
    user_id: Optional[str] = None

class Resource(BaseModel):
    """Cached state/result of one SQL regeneration request."""

    class Error(BaseModel):
        code: Literal["NO_RELEVANT_SQL", "OTHERS"]
        message: str

    class ResponseDetails(BaseModel):
        description: str
        steps: List[SQLBreakdown]

    query_id: str
    # Optional: a resource may exist before any phase is recorded.
    # `Literal[...] = None` alone would be type-dishonest (None is not a
    # member of the Literal), so the annotation is widened explicitly.
    status: Optional[
        Literal["understanding", "generating", "finished", "failed"]
    ] = None
    response: Optional[ResponseDetails] = None
    error: Optional[Error] = None


class SQLRegenerationService:
def __init__(
self,
pipelines: Dict[str, Pipeline],
maxsize: int = 1_000_000,
ttl: int = 120,
):
self._pipelines = pipelines
self._sql_regeneration_results: Dict[
str, SQLRegenerationResultResponse
] = TTLCache(maxsize=maxsize, ttl=ttl)
self._cache: Dict[str, SQLRegenerationRequest.Resource] = TTLCache(
maxsize=maxsize, ttl=ttl
)

@async_timer
async def sql_regeneration(
self,
sql_regeneration_request: SQLRegenerationRequest,
**kwargs,
):
async def generate(self, request: Input, **kwargs) -> Resource:
try:
query_id = sql_regeneration_request.query_id

self._sql_regeneration_results[query_id] = SQLRegenerationResultResponse(
status="understanding",
self[request.query_id] = self.Resource(
query_id=request.query_id,
status="understanding"
)

self._sql_regeneration_results[query_id] = SQLRegenerationResultResponse(
status="generating",
self[request.query_id] = self.Resource(
query_id=request.query_id,
status="generating"
)

generation_result = await self._pipelines["sql_regeneration"].run(
description=sql_regeneration_request.description,
steps=sql_regeneration_request.steps,
project_id=sql_regeneration_request.project_id,
description=request.description,
steps=request.steps,
project_id=request.project_id,
)

sql_regeneration_result = generation_result[
"sql_regeneration_post_process"
]["results"]

if not sql_regeneration_result["steps"]:
self._sql_regeneration_results[
query_id
] = SQLRegenerationResultResponse(
self[request.query_id] = self.Resource(
query_id=request.query_id,
status="failed",
error=SQLRegenerationResultResponse.SQLRegenerationError(
error=self.Resource.Error(
code="NO_RELEVANT_SQL",
message="SQL is not executable",
),
message="SQL is not executable"
)
)
else:
self._sql_regeneration_results[
query_id
] = SQLRegenerationResultResponse(
self[request.query_id] = self.Resource(
query_id=request.query_id,
status="finished",
response=sql_regeneration_result,
response=self.Resource.ResponseDetails(
description=sql_regeneration_result["description"],
steps=sql_regeneration_result["steps"]
)
)
except Exception as e:
logger.exception(f"sql regeneration pipeline - OTHERS: {e}")
self._sql_regeneration_results[
sql_regeneration_request.query_id
] = SQLRegenerationResultResponse(
self[request.query_id] = self.Resource(
query_id=request.query_id,
status="failed",
error=SQLRegenerationResultResponse.SQLRegenerationError(
error=self.Resource.Error(
code="OTHERS",
message=str(e),
),
message=str(e)
)
)

def get_sql_regeneration_result(
self, sql_regeneration_result_request: SQLRegenerationResultRequest
) -> SQLRegenerationResultResponse:
if (
sql_regeneration_result_request.query_id
not in self._sql_regeneration_results
):
return SQLRegenerationResultResponse(
return self[request.query_id]

def __getitem__(self, query_id: str) -> Resource:
response = self._cache.get(query_id)
if response is None:
message = f"SQL Regeneration Resource with ID '{query_id}' not found."
logger.exception(message)
return self.Resource(
query_id=query_id,
status="failed",
error=SQLRegenerationResultResponse.SQLRegenerationError(
error=self.Resource.Error(
code="OTHERS",
message=f"{sql_regeneration_result_request.query_id} is not found",
),
message=message
)
)
return response

return self._sql_regeneration_results[sql_regeneration_result_request.query_id]
def __setitem__(self, query_id: str, value: Resource):
self._cache[query_id] = value