feat: impl the part in service to run pipeline
paopa committed Oct 3, 2024
1 parent d523faf commit 0e5a835
Showing 4 changed files with 48 additions and 7 deletions.
2 changes: 2 additions & 0 deletions wren-ai-service/config.example.yaml
@@ -103,3 +103,5 @@ pipes:
   - name: sql_regeneration
     llm: openai_llm.gpt-4o-mini
     engine: wren_ui
+  - name: semantics_description
+    llm: openai_llm.gpt-4o-mini
10 changes: 9 additions & 1 deletion wren-ai-service/src/globals.py
@@ -8,6 +8,7 @@
 from src.core.provider import EmbedderProvider, LLMProvider
 from src.pipelines.generation import (
     followup_sql_generation,
+    semantics_description,
     sql_answer,
     sql_breakdown,
     sql_correction,
@@ -57,7 +58,14 @@ def create_service_container(
     query_cache: Optional[dict] = {},
 ) -> ServiceContainer:
     return ServiceContainer(
-        semantics_description=SemanticsDescription(pipelines={}, **query_cache),
+        semantics_description=SemanticsDescription(
+            pipelines={
+                "semantics_description": semantics_description.SemanticsDescription(
+                    **pipe_components["semantics_description"],
+                )
+            },
+            **query_cache,
+        ),
         semantics_preparation_service=SemanticsPreparationService(
             pipelines={
                 "indexing": indexing.Indexing(
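
For context, a minimal sketch of how the new `pipes` entry in `config.example.yaml` presumably becomes the `pipe_components["semantics_description"]` dict unpacked above; the loader shape and variable names here are assumptions for illustration, not code from this commit:

    # Hypothetical wiring (assumed, not from this commit): each `pipes` entry is
    # resolved into a dict of provider components, keyed by pipeline name.
    llm_provider = llm_providers["openai_llm"]  # backs "openai_llm.gpt-4o-mini"
    pipe_components = {
        "semantics_description": {"llm_provider": llm_provider},
    }
    # `**pipe_components["semantics_description"]` then expands to
    # `llm_provider=llm_provider` in the pipeline constructor above.
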
5 changes: 1 addition & 4 deletions wren-ai-service/src/pipelines/generation/semantics_description.py
@@ -143,10 +143,7 @@ def wrapper(text: str) -> str:
 
 
 class SemanticsDescription(BasicPipeline):
-    def __init__(
-        self,
-        llm_provider: LLMProvider,
-    ):
+    def __init__(self, llm_provider: LLMProvider, **_):
         self._components = {
             "prompt_builder": PromptBuilder(template=user_prompt_template),
             "generator": llm_provider.get_generator(system_prompt=system_prompt),
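
The `**_` catch-all in the new signature matters for the wiring above: if a `pipe_components` entry ever carries components this pipeline does not use, they are absorbed instead of raising `TypeError`. A minimal sketch, where the extra `engine` keyword is a hypothetical example:

    # Hypothetical call (assumed, not from this commit):
    pipeline = semantics_description.SemanticsDescription(
        llm_provider=llm_provider,
        engine=engine,  # unused by this pipeline; swallowed by **_
    )
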
38 changes: 36 additions & 2 deletions wren-ai-service/src/web/v1/services/semantics_description.py
@@ -1,6 +1,7 @@
 import logging
 from typing import Dict, Literal, Optional
 
+import orjson
 from cachetools import TTLCache
 from langfuse.decorators import observe
 from pydantic import BaseModel
@@ -98,12 +99,45 @@ def __init__(
             maxsize=maxsize, ttl=ttl
         )
 
+    def _handle_exception(self, request: Request, error_message: str):
+        self._cache[request.id] = self.Response(
+            id=request.id,
+            status="failed",
+            error=self.Response.Error(code="OTHERS", message=error_message),
+        )
+        logger.error(error_message)
+
     @observe(name="Generate Semantics Description")
     @trace_metadata
     async def generate(self, request: Request, **kwargs) -> Response:
         logger.info("Generate Semantics Description pipeline is running...")
-        # todo: implement the service flow
-        pass
+
+        try:
+            if request.mdl is None:
+                raise ValueError("MDL must be provided")
+
+            mdl_dict = orjson.loads(request.mdl)
+
+            input = {
+                "user_prompt": request.user_prompt,
+                "selected_models": request.selected_models,
+                "mdl": mdl_dict,
+            }
+
+            resp = await self._pipelines["semantics_description"].run(**input)
+
+            self._cache[request.id] = self.Response(
+                id=request.id, status="finished", response=resp.get("normalize")
+            )
+        except orjson.JSONDecodeError as e:
+            self._handle_exception(request, f"Failed to parse MDL: {str(e)}")
+        except Exception as e:
+            self._handle_exception(
+                request,
+                f"An error occurred during semantics description generation: {str(e)}",
+            )
+
+        return self._cache[request.id]
 
     def __getitem__(self, request: Request) -> Response:
         response = self._cache.get(request.id)
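
To illustrate the new flow end to end, a hypothetical caller sketch; everything here beyond the `Request` fields visible in the diff (`id`, `mdl`, `user_prompt`, `selected_models`) is an assumption, including the nesting of `Request` under the service class:

    # Hypothetical usage (assumed, not from this commit).
    request = SemanticsDescription.Request(
        id="req-1",
        mdl='{"models": []}',  # JSON string; parsed with orjson.loads above
        user_prompt="Describe the sales models",
        selected_models=["orders"],
    )
    response = await service.generate(request)
    # Success path: status == "finished" and response holds the pipeline's
    # "normalize" output. Invalid MDL JSON: status == "failed", code "OTHERS".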
