From 183325d71900acf73965a7c4de3759efa1ecaa09 Mon Sep 17 00:00:00 2001
From: Pao-Sheng Wang
Date: Wed, 2 Oct 2024 16:15:13 +0800
Subject: [PATCH 01/13] feat: implement semantics description generation pipe

---
 .../generation/semantics_description.py      | 219 ++++++++++++++++++
 1 file changed, 219 insertions(+)
 create mode 100644 wren-ai-service/src/pipelines/generation/semantics_description.py

diff --git a/wren-ai-service/src/pipelines/generation/semantics_description.py b/wren-ai-service/src/pipelines/generation/semantics_description.py
new file mode 100644
index 000000000..91fd6bccc
--- /dev/null
+++ b/wren-ai-service/src/pipelines/generation/semantics_description.py
@@ -0,0 +1,219 @@
+import json
+import logging
+import sys
+from pathlib import Path
+from typing import Any
+
+import orjson
+from hamilton import base
+from hamilton.experimental.h_async import AsyncDriver
+from haystack.components.builders.prompt_builder import PromptBuilder
+from langfuse.decorators import observe
+
+from src.core.pipeline import BasicPipeline, async_validate
+from src.core.provider import LLMProvider
+
+logger = logging.getLogger("wren-ai-service")
+
+
+## Start of Pipeline
+def picked_models(mdl: dict, selected_models: list[str]) -> list[dict]:
+    def extract(model: dict) -> dict:
+        return {
+            "name": model["name"],
+            "columns": model["columns"],
+            "properties": model["properties"],
+        }
+
+    return [
+        extract(model) for model in mdl["models"] if model["name"] in selected_models
+    ]
+
+
+def prompt(
+    picked_models: list[dict],
+    user_prompt: str,
+    prompt_builder: PromptBuilder,
+) -> dict:
+    return prompt_builder.run(picked_models=picked_models, user_prompt=user_prompt)
+
+
+async def generate(prompt: dict, generator: Any) -> dict:
+    return await generator.run(prompt=prompt.get("prompt"))
+
+
+def post_process(generate: dict) -> dict:
+    def normalize(text: str) -> str:
+        text = text.replace("\n", " ")
+        text = " ".join(text.split())
+        # Convert the normalized text to a dictionary
+        try:
+            text_dict = orjson.loads(text.strip())
+            return text_dict
+        except orjson.JSONDecodeError as e:
+            logger.error(f"Error decoding JSON: {e}")
+            return {}  # Return an empty dictionary if JSON decoding fails
+
+    reply = generate.get("replies")[0]  # Expecting only one reply
+    normalized = normalize(reply)
+
+    return {model["name"]: model for model in normalized["models"]}
+
+
+## End of Pipeline
+
+system_prompt = """
+I have a data model represented in JSON format, with the following structure:
+
+```
+[
+    {'name': 'model', 'columns': [
+            {'name': 'column_1', 'type': 'type', 'notNull': True, 'properties': {}
+            },
+            {'name': 'column_2', 'type': 'type', 'notNull': True, 'properties': {}
+            },
+            {'name': 'column_3', 'type': 'type', 'notNull': False, 'properties': {}
+            }
+        ], 'properties': {}
+    }
+]
+```
+
+Your task is to update this JSON structure by adding a `description` field inside both the `properties` attribute of each `column` and the `model` itself.
+Each `description` should be derived from a user-provided input that explains the purpose or context of the `model` and its respective columns.
+Follow these steps:
+1. **For the `model`**: Prompt the user to provide a brief description of the model's overall purpose or its context. Insert this description in the `properties` field of the `model`.
+2. **For each `column`**: Ask the user to describe each column's role or significance. Each column's description should be added under its respective `properties` field in the format: `'description': 'user-provided text'`.
+3. Ensure that the output is a well-formatted JSON structure, preserving the input's original format and adding the appropriate `description` fields.
+
+### Output Format:
+
+```
+[
+    {
+        "name": "model",
+        "columns": [
+            {
+                "name": "column_1",
+                "properties": {
+                    "description": ""
+                }
+            },
+            {
+                "name": "column_2",
+                "properties": {
+                    "description": ""
+                }
+            },
+            {
+                "name": "column_3",
+                "properties": {
+                    "description": ""
+                }
+            }
+        ],
+        "properties": {
+            "description": ""
+        }
+    }
+]
+```
+
+Make sure that the descriptions are concise, informative, and contextually appropriate based on the input provided by the user.
+"""
+
+user_prompt_template = """
+
+### Input
+User's prompt: {{ user_prompt }}
+Picked models: {{ picked_models }}
+
+Please provide a brief description for the model and each column based on the user's prompt.
+"""
+
+
+class SemanticsDescription(BasicPipeline):
+    def __init__(
+        self,
+        llm_provider: LLMProvider,
+    ):
+        self._components = {
+            "prompt_builder": PromptBuilder(template=user_prompt_template),
+            "generator": llm_provider.get_generator(system_prompt=system_prompt),
+        }
+
+        super().__init__(
+            AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult())
+        )
+
+    def visualize(
+        self,
+        user_prompt: str,
+        selected_models: list[str],
+        mdl: dict,
+    ) -> None:
+        destination = "outputs/pipelines/generation"
+        if not Path(destination).exists():
+            Path(destination).mkdir(parents=True, exist_ok=True)
+
+        self._pipe.visualize_execution(
+            [""],
+            output_file_path=f"{destination}/semantics_description.dot",
+            inputs={
+                "user_prompt": user_prompt,
+                "selected_models": selected_models,
+                "mdl": mdl,
+                **self._components,
+            },
+            show_legend=True,
+            orient="LR",
+        )
+
+    @observe(name="Semantics Description Generation")
+    async def run(
+        self,
+        user_prompt: str,
+        selected_models: list[str],
+        mdl: dict,
+    ) -> dict:
+        logger.info("Semantics Description Generation pipeline is running...")
+        return await self._pipe.execute(
+            ["post_process"],
+            inputs={
+                "user_prompt": user_prompt,
+                "selected_models": selected_models,
+                "mdl": mdl,
+                **self._components,
+            },
+        )
+
+
+if __name__ == "__main__":
+    from src.core.engine import EngineConfig
+    from src.core.pipeline import async_validate
+    from src.providers import init_providers
+    from src.utils import init_langfuse, load_env_vars
+
+    load_env_vars()
+    init_langfuse()
+
+    llm_provider, _, _, _ = init_providers(EngineConfig())
+    pipeline = SemanticsDescription(llm_provider=llm_provider)
+
+    with open("src/pipelines/prototype/example.json", "r") as file:
+        mdl = json.load(file)
+
+    input = {
+        "user_prompt": "The Orders and Customers dataset is utilized to analyze customer behavior and preferences over time, enabling the improvement of marketing strategies. By examining purchasing patterns and trends, businesses can tailor their marketing efforts to better meet customer needs and enhance engagement.",
+        "selected_models": ["orders", "customers"],
+        "mdl": mdl,
+    }
+
+    # pipeline.visualize(**input)
+    async_validate(lambda: pipeline.run(**input))
+
+    # expected = {
+    #     "model_name": ["column1", "column2"],
+    # }
+
+    # langfuse_context.flush()

From 5ac4b396524f438ad0ceec634e345064ecea23e8 Mon Sep 17 00:00:00 2001
From: Pao-Sheng Wang
Date: Wed, 2 Oct 2024 16:33:54 +0800
Subject: [PATCH 02/13] chore: add the annotation for tracing

---
 .../generation/semantics_description.py      | 35 ++++++++++++-------
 1 file changed, 22 insertions(+), 13 deletions(-)

diff --git a/wren-ai-service/src/pipelines/generation/semantics_description.py b/wren-ai-service/src/pipelines/generation/semantics_description.py
index 91fd6bccc..60c67dba4 100644
--- a/wren-ai-service/src/pipelines/generation/semantics_description.py
+++ b/wren-ai-service/src/pipelines/generation/semantics_description.py
@@ -17,6 +17,7 @@
 
 
 ## Start of Pipeline
+@observe(capture_input=False)
 def picked_models(mdl: dict, selected_models: list[str]) -> list[dict]:
     def extract(model: dict) -> dict:
         return {
             "name": model["name"],
@@ -30,20 +31,26 @@ def extract(model: dict) -> dict:
     ]
 
 
+@observe(capture_input=False)
 def prompt(
     picked_models: list[dict],
     user_prompt: str,
     prompt_builder: PromptBuilder,
 ) -> dict:
+    logger.debug(f"User prompt: {user_prompt}")
+    logger.debug(f"Picked models: {picked_models}")
     return prompt_builder.run(picked_models=picked_models, user_prompt=user_prompt)
 
 
+@observe(as_type="generation", capture_input=False)
 async def generate(prompt: dict, generator: Any) -> dict:
+    logger.debug(f"prompt: {orjson.dumps(prompt, option=orjson.OPT_INDENT_2).decode()}")
     return await generator.run(prompt=prompt.get("prompt"))
 
 
-def post_process(generate: dict) -> dict:
-    def normalize(text: str) -> str:
+@observe(capture_input=False)
+def normalize(generate: dict) -> dict:
+    def wrapper(text: str) -> str:
         text = text.replace("\n", " ")
         text = " ".join(text.split())
         # Convert the normalized text to a dictionary
@@ -54,8 +61,12 @@ def normalize(text: str) -> str:
             logger.error(f"Error decoding JSON: {e}")
             return {}  # Return an empty dictionary if JSON decoding fails
 
+    logger.debug(
+        f"generate: {orjson.dumps(generate, option=orjson.OPT_INDENT_2).decode()}"
+    )
+
     reply = generate.get("replies")[0]  # Expecting only one reply
-    normalized = normalize(reply)
+    normalized = wrapper(reply)
 
     return {model["name"]: model for model in normalized["models"]}
 
@@ -123,8 +134,7 @@ def normalize(text: str) -> str:
 """
 
 user_prompt_template = """
-
-### Input
+### Input:
 User's prompt: {{ user_prompt }}
 Picked models: {{ picked_models }}
 
@@ -141,6 +151,7 @@ def __init__(
             "prompt_builder": PromptBuilder(template=user_prompt_template),
             "generator": llm_provider.get_generator(system_prompt=system_prompt),
         }
+        self._final = "normalize"
 
         super().__init__(
             AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult())
@@ -157,7 +168,7 @@ def visualize(
             Path(destination).mkdir(parents=True, exist_ok=True)
 
         self._pipe.visualize_execution(
-            [""],
+            [self._final],
             output_file_path=f"{destination}/semantics_description.dot",
             inputs={
                 "user_prompt": user_prompt,
@@ -178,7 +189,7 @@ async def run(
     ) -> dict:
         logger.info("Semantics Description Generation pipeline is running...")
         return await self._pipe.execute(
-            ["post_process"],
+            [self._final],
             inputs={
                 "user_prompt": user_prompt,
                 "selected_models": selected_models,
@@ -189,6 +200,8 @@ async def run(
 
 
 if __name__ == "__main__":
+    from langfuse.decorators import langfuse_context
+
     from src.core.engine import EngineConfig
     from src.core.pipeline import async_validate
     from src.providers import init_providers
@@ -209,11 +222,7 @@ async def run(
         "mdl": mdl,
     }
 
-    # pipeline.visualize(**input)
+    pipeline.visualize(**input)
     async_validate(lambda: pipeline.run(**input))
 
-    # expected = {
-    #     "model_name": ["column1", "column2"],
-    # }
-
-    # langfuse_context.flush()
+    langfuse_context.flush()

From 2d2dcb29f2780ac4242795bcb3c662b83559baa7 Mon Sep 17 00:00:00 2001
From: Pao-Sheng Wang
Date: Wed, 2 Oct 2024 16:51:00 +0800
Subject: [PATCH 03/13] chore: modify example in semantics description pipe

---
 .../pipelines/generation/semantics_description.py | 15 ++++++++++++---
 1 file changed, 12 insertions(+), 3 deletions(-)

diff --git a/wren-ai-service/src/pipelines/generation/semantics_description.py b/wren-ai-service/src/pipelines/generation/semantics_description.py
index 60c67dba4..bddcfcf36 100644
--- a/wren-ai-service/src/pipelines/generation/semantics_description.py
+++ b/wren-ai-service/src/pipelines/generation/semantics_description.py
@@ -213,12 +213,21 @@ async def run(
     llm_provider, _, _, _ = init_providers(EngineConfig())
     pipeline = SemanticsDescription(llm_provider=llm_provider)
 
-    with open("src/pipelines/prototype/example.json", "r") as file:
+    with open("sample/college_3_bigquery_mdl.json", "r") as file:
         mdl = json.load(file)
 
     input = {
-        "user_prompt": "The Orders and Customers dataset is utilized to analyze customer behavior and preferences over time, enabling the improvement of marketing strategies. By examining purchasing patterns and trends, businesses can tailor their marketing efforts to better meet customer needs and enhance engagement.",
-        "selected_models": ["orders", "customers"],
+        "user_prompt": "Track student enrollments, grades, and GPA calculations to monitor academic performance and identify areas for student support",
+        "selected_models": [
+            "Student",
+            "Minor_in",
+            "Member_of",
+            "Gradeconversion",
+            "Faculty",
+            "Enrolled_in",
+            "Department",
+            "Course",
+        ],
         "mdl": mdl,
     }
 
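A quick aside before the service patches: the pipeline above leans entirely on its post-processing step to turn the raw LLM reply into a lookup table keyed by model name. Note that at this point in the series the system prompt asks for a bare JSON list while `post_process` indexes into `normalized["models"]`; patch 08 later aligns the prompt to a `{"models": [...]}` wrapper. The following standalone sketch is not part of the patches; the sample reply is invented for illustration and uses the aligned shape:

```python
# Standalone sketch of the pipeline's post-processing step. The sample
# reply below is invented and uses the {"models": [...]} shape that
# patch 08 settles on.
import orjson


def post_process(generate: dict) -> dict:
    def normalize(text: str) -> dict:
        # Collapse newlines and repeated whitespace so pretty-printed
        # JSON from the LLM parses cleanly.
        text = " ".join(text.replace("\n", " ").split())
        try:
            return orjson.loads(text.strip())
        except orjson.JSONDecodeError:
            return {}  # fall back to an empty dict on malformed output

    reply = generate.get("replies")[0]  # the pipeline expects a single reply
    normalized = normalize(reply)
    # Re-key the described models by name so callers can look them up directly.
    return {model["name"]: model for model in normalized["models"]}


sample_reply = '{"models": [{"name": "orders", "columns": [], "properties": {"description": "Order facts"}}]}'
print(post_process({"replies": [sample_reply]}))
# {'orders': {'name': 'orders', 'columns': [], 'properties': {'description': 'Order facts'}}}
```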
From 3f987d3966dbb328b32b3a0238ab674b5482110e Mon Sep 17 00:00:00 2001
From: Pao-Sheng Wang
Date: Wed, 2 Oct 2024 18:42:38 +0800
Subject: [PATCH 04/13] feat: implement the service

---
 wren-ai-service/src/globals.py                |   2 +
 .../web/v1/services/semantics_description.py  | 163 ++++++++++++++++++
 2 files changed, 165 insertions(+)
 create mode 100644 wren-ai-service/src/web/v1/services/semantics_description.py

diff --git a/wren-ai-service/src/globals.py b/wren-ai-service/src/globals.py
index ea337e746..c3ae63678 100644
--- a/wren-ai-service/src/globals.py
+++ b/wren-ai-service/src/globals.py
@@ -21,6 +21,7 @@
 from src.pipelines.retrieval import historical_question, retrieval
 from src.web.v1.services.ask import AskService
 from src.web.v1.services.ask_details import AskDetailsService
+from src.web.v1.services.semantics_description import SemanticsDescription
 from src.web.v1.services.semantics_preparation import SemanticsPreparationService
 from src.web.v1.services.sql_answer import SqlAnswerService
 from src.web.v1.services.sql_expansion import SqlExpansionService
@@ -32,6 +33,7 @@
 
 @dataclass
 class ServiceContainer:
+    semantics_description: SemanticsDescription
     semantics_preparation_service: SemanticsPreparationService
     ask_service: AskService
     sql_answer_service: SqlAnswerService
diff --git a/wren-ai-service/src/web/v1/services/semantics_description.py b/wren-ai-service/src/web/v1/services/semantics_description.py
new file mode 100644
index 000000000..13e60c1bf
--- /dev/null
+++ b/wren-ai-service/src/web/v1/services/semantics_description.py
@@ -0,0 +1,163 @@
+import logging
+import uuid
+from dataclasses import asdict
+from typing import Dict, Literal, Optional
+
+from cachetools import TTLCache
+from fastapi import APIRouter, BackgroundTasks, Depends
+from langfuse.decorators import observe
+from pydantic import BaseModel
+
+from src.core.pipeline import BasicPipeline
+from src.globals import (
+    ServiceContainer,
+    ServiceMetadata,
+    get_service_container,
+    get_service_metadata,
+)
+from src.utils import trace_metadata
+
+logger = logging.getLogger("wren-ai-service")
+
+
+class SemanticsDescription:
+    """
+    SemanticsDescription Service
+
+    This service provides endpoints for generating and optimizing semantic descriptions
+    based on user prompts and selected models.
+
+    Endpoints:
+    1. POST /v1/semantics-descriptions
+       Generate a new semantic description.
+
+       Request body:
+       {
+           "selected_models": ["model1", "model2"],  # List of selected model names
+           "user_prompt": "Describe the data model",  # User's prompt for description
+           "mdl": "..."  # MDL (Model Definition Language) string
+       }
+
+       Response:
+       {
+           "id": "unique_id",  # Unique identifier for the generated description
+           "status": "generating"  # Initial status
+       }
+
+    2. GET /v1/semantics-descriptions/{id}
+       Retrieve the status and result of a semantic description generation.
+
+       Path parameter:
+       - id: Unique identifier of the semantic description resource.
+
+       Response:
+       {
+           "id": "unique_id",
+           "status": "finished",  # Can be "generating", "finished", or "failed"
+           "response": {  # Present only if status is "finished"
+               // Generated semantic description
+           },
+           "error": {  # Present only if status is "failed"
+               "code": "OTHERS",
+               "message": "Error description"
+           }
+       }
+
+    Usage:
+    1. Call the POST endpoint to initiate a semantic description generation.
+    2. Use the returned ID to poll the GET endpoint until the status is "finished" or "failed".
+    3. Once finished, retrieve the generated description from the "response" field.
+
+    Note: The generation process may take some time, so implement appropriate polling
+    intervals when checking the status.
+    """
+
+    class Request(BaseModel):
+        _id: str | None = None
+        selected_models: list[str] = []
+        user_prompt: str = ""
+        mdl: str
+
+        @property
+        def id(self) -> str:
+            return self._id
+
+        @id.setter
+        def id(self, id: str):
+            self._id = id
+
+    class Response(BaseModel):
+        class Error(BaseModel):
+            code: Literal["OTHERS"]
+            message: str
+
+        id: str
+        status: Literal["generating", "finished", "failed"] = "generating"
+        response: Optional[dict] = None
+        error: Optional[Error] = None
+
+    def __init__(
+        self,
+        pipelines: Dict[str, BasicPipeline],
+        maxsize: int = 1_000_000,
+        ttl: int = 120,
+    ):
+        self._pipelines = pipelines
+        self._cache: Dict[str, SemanticsDescription.Response] = TTLCache(
+            maxsize=maxsize, ttl=ttl
+        )
+
+    @observe(name="Generate Semantics Description")
+    @trace_metadata
+    async def generate(self, request: Request, **kwargs) -> Response:
+        logger.info("Generate Semantics Description pipeline is running...")
+        # todo: implement the service flow
+        pass
+
+    def get(self, request: Request) -> Response:
+        response = self._cache.get(request.id)
+
+        if response is None:
+            # todo: error handling
+            logger.error(
+                f"Semantics Description Resource with ID '{request.id}' not found."
+            )
+            return self.Response()
+
+        return response
+
+
+router = APIRouter()
+
+
+@router.post("/v1/semantics-descriptions", response_model=SemanticsDescription.Response)
+async def generate(
+    request: SemanticsDescription.Request,
+    background_tasks: BackgroundTasks,
+    service_container: ServiceContainer = Depends(get_service_container),
+    service_metadata: ServiceMetadata = Depends(get_service_metadata),
+) -> SemanticsDescription.Response:
+    id = str(uuid.uuid4())
+    request.id = id
+    service = service_container.semantics_description
+
+    # todo: consider to simplify the code by using the service_container
+    service._cache[request.id] = SemanticsDescription.Response(id=id)
+
+    background_tasks.add_task(
+        service.generate, request, service_metadata=asdict(service_metadata)
+    )
+    return service._cache[request.id]
+
+
+@router.get(
+    "/v1/semantics-descriptions/{id}",
+    response_model=SemanticsDescription.Response,
+)
+async def get(
+    id: str,
+    service_container: ServiceContainer = Depends(get_service_container),
+) -> SemanticsDescription.Response:
+    return service_container.semantics_description.get(
+        SemanticsDescription.Request(id=id)
+    )

From eec3ea530f3f6d2e04ed1f9174697b0350c8049f Mon Sep 17 00:00:00 2001
From: Pao-Sheng Wang
Date: Thu, 3 Oct 2024 11:30:23 +0800
Subject: [PATCH 05/13] feat: implement get and set item method to retrieve the cached response

---
 .../src/web/v1/services/semantics_description.py | 13 +++++++------
 1 file changed, 7 insertions(+), 6 deletions(-)

diff --git a/wren-ai-service/src/web/v1/services/semantics_description.py b/wren-ai-service/src/web/v1/services/semantics_description.py
index 13e60c1bf..c8d81e75b 100644
--- a/wren-ai-service/src/web/v1/services/semantics_description.py
+++ b/wren-ai-service/src/web/v1/services/semantics_description.py
@@ -114,7 +114,7 @@ async def generate(self, request: Request, **kwargs) -> Response:
         # todo: implement the service flow
         pass
 
-    def get(self, request: Request) -> Response:
+    def __getitem__(self, request: Request) -> Response:
         response = self._cache.get(request.id)
 
         if response is None:
@@ -126,6 +126,9 @@ def get(self, request: Request) -> Response:
 
         return response
 
+    def __setitem__(self, request: Request, value: Response):
+        self._cache[request.id] = value
+
 
 router = APIRouter()
 
@@ -142,12 +145,12 @@ async def generate(
     service = service_container.semantics_description
 
     # todo: consider to simplify the code by using the service_container
-    service._cache[request.id] = SemanticsDescription.Response(id=id)
+    service[request] = SemanticsDescription.Response(id=id)
 
     background_tasks.add_task(
         service.generate, request, service_metadata=asdict(service_metadata)
     )
-    return service._cache[request.id]
+    return service[request]
 
 
 @router.get(
@@ -158,6 +161,4 @@ async def get(
     id: str,
     service_container: ServiceContainer = Depends(get_service_container),
 ) -> SemanticsDescription.Response:
-    return service_container.semantics_description.get(
-        SemanticsDescription.Request(id=id)
-    )
+    return service_container.semantics_description[SemanticsDescription.Request(id=id)]

From 403f6c37f81500e823fd897ee5fba3b94f5e6b7c Mon Sep 17 00:00:00 2001
From: Pao-Sheng Wang
Date: Thu, 3 Oct 2024 11:44:56 +0800
Subject: [PATCH 06/13] feat: web service for semantics description

---
 wren-ai-service/src/globals.py                     |  1 +
 .../v1/{routers.py => routers/__init__.py}         |  2 +
 .../web/v1/routers/semantics_description.py        | 47 ++++++++++++++++
 .../web/v1/services/semantics_description.py       | 55 +++----------------
 4 files changed, 58 insertions(+), 47 deletions(-)
 rename wren-ai-service/src/web/v1/{routers.py => routers/__init__.py} (98%)
 create mode 100644 wren-ai-service/src/web/v1/routers/semantics_description.py

diff --git a/wren-ai-service/src/globals.py b/wren-ai-service/src/globals.py
index c3ae63678..eeb764cc9 100644
--- a/wren-ai-service/src/globals.py
+++ b/wren-ai-service/src/globals.py
@@ -57,6 +57,7 @@ def create_service_container(
     query_cache: Optional[dict] = {},
 ) -> ServiceContainer:
     return ServiceContainer(
+        semantics_description=SemanticsDescription(pipelines={}, **query_cache),
         semantics_preparation_service=SemanticsPreparationService(
             pipelines={
                 "indexing": indexing.Indexing(
diff --git a/wren-ai-service/src/web/v1/routers.py b/wren-ai-service/src/web/v1/routers/__init__.py
similarity index 98%
rename from wren-ai-service/src/web/v1/routers.py
rename to wren-ai-service/src/web/v1/routers/__init__.py
index 92a32846e..93f0d4094 100644
--- a/wren-ai-service/src/web/v1/routers.py
+++ b/wren-ai-service/src/web/v1/routers/__init__.py
@@ -9,6 +9,7 @@
     get_service_container,
     get_service_metadata,
 )
+from src.web.v1.routers import semantics_description
 from src.web.v1.services.ask import (
     AskRequest,
     AskResponse,
@@ -57,6 +58,7 @@
 )
 
 router = APIRouter()
+router.include_router(semantics_description.router)
 
 
 @router.post("/semantics-preparations")
diff --git a/wren-ai-service/src/web/v1/routers/semantics_description.py b/wren-ai-service/src/web/v1/routers/semantics_description.py
new file mode 100644
index 000000000..579d6f632
--- /dev/null
+++ b/wren-ai-service/src/web/v1/routers/semantics_description.py
@@ -0,0 +1,47 @@
+import uuid
+from dataclasses import asdict
+
+from fastapi import APIRouter, BackgroundTasks, Depends
+
+from src.globals import (
+    ServiceContainer,
+    ServiceMetadata,
+    get_service_container,
+    get_service_metadata,
+)
+from src.web.v1.services.semantics_description import SemanticsDescription
+
+router = APIRouter()
+
+
+@router.post("/semantics-descriptions", response_model=SemanticsDescription.Response)
+async def generate(
+    request: SemanticsDescription.Request,
+    background_tasks: BackgroundTasks,
+    service_container: ServiceContainer = Depends(get_service_container),
+    service_metadata: ServiceMetadata = Depends(get_service_metadata),
+) -> SemanticsDescription.Response:
+    id = str(uuid.uuid4())
+    request.id = id
+    service = service_container.semantics_description
+
+    service[request] = SemanticsDescription.Response(id=id)
+
+    background_tasks.add_task(
+        service.generate, request, service_metadata=asdict(service_metadata)
+    )
+    return service[request]
+
+
+@router.get(
+    "/semantics-descriptions/{id}",
+    response_model=SemanticsDescription.Response,
+)
+async def get(
+    id: str,
+    service_container: ServiceContainer = Depends(get_service_container),
+) -> SemanticsDescription.Response:
+    request = SemanticsDescription.Request()
+    request.id = id
+
+    return service_container.semantics_description[request]
diff --git a/wren-ai-service/src/web/v1/services/semantics_description.py b/wren-ai-service/src/web/v1/services/semantics_description.py
index c8d81e75b..70b1af00f 100644
--- a/wren-ai-service/src/web/v1/services/semantics_description.py
+++ b/wren-ai-service/src/web/v1/services/semantics_description.py
@@ -1,20 +1,11 @@
 import logging
-import uuid
-from dataclasses import asdict
 from typing import Dict, Literal, Optional
 
 from cachetools import TTLCache
-from fastapi import APIRouter, BackgroundTasks, Depends
 from langfuse.decorators import observe
 from pydantic import BaseModel
 
 from src.core.pipeline import BasicPipeline
-from src.globals import (
-    ServiceContainer,
-    ServiceMetadata,
-    get_service_container,
-    get_service_metadata,
-)
 from src.utils import trace_metadata
 
 logger = logging.getLogger("wren-ai-service")
@@ -76,7 +67,7 @@ class Request(BaseModel):
         _id: str | None = None
         selected_models: list[str] = []
         user_prompt: str = ""
-        mdl: str
+        mdl: str | None = None
 
         @property
         def id(self) -> str:
@@ -118,47 +109,17 @@ def __getitem__(self, request: Request) -> Response:
         response = self._cache.get(request.id)
 
         if response is None:
-            # todo: error handling
-            logger.error(
+            message = (
                 f"Semantics Description Resource with ID '{request.id}' not found."
             )
-            return self.Response()
+            logger.exception(message)
+            return self.Response(
+                id=request.id,
+                status="failed",
+                error=self.Response.Error(code="OTHERS", message=message),
+            )
 
         return response
 
     def __setitem__(self, request: Request, value: Response):
         self._cache[request.id] = value
-
-
-router = APIRouter()
-
-
-@router.post("/v1/semantics-descriptions", response_model=SemanticsDescription.Response)
-async def generate(
-    request: SemanticsDescription.Request,
-    background_tasks: BackgroundTasks,
-    service_container: ServiceContainer = Depends(get_service_container),
-    service_metadata: ServiceMetadata = Depends(get_service_metadata),
-) -> SemanticsDescription.Response:
-    id = str(uuid.uuid4())
-    request.id = id
-    service = service_container.semantics_description
-
-    # todo: consider to simplify the code by using the service_container
-    service[request] = SemanticsDescription.Response(id=id)
-
-    background_tasks.add_task(
-        service.generate, request, service_metadata=asdict(service_metadata)
-    )
-    return service[request]
-
-
-@router.get(
-    "/v1/semantics-descriptions/{id}",
-    response_model=SemanticsDescription.Response,
-)
-async def get(
-    id: str,
-    service_container: ServiceContainer = Depends(get_service_container),
-) -> SemanticsDescription.Response:
-    return service_container.semantics_description[SemanticsDescription.Request(id=id)]

From 10f448f71c9ec358efe4e3ccb65b19f1ed7fa25f Mon Sep 17 00:00:00 2001
From: Pao-Sheng Wang
Date: Thu, 3 Oct 2024 16:22:44 +0800
Subject: [PATCH 07/13] feat: impl the part in service to run pipeline

---
 wren-ai-service/config.example.yaml              |  2 ++
 wren-ai-service/src/globals.py                   | 10 ++++-
 .../generation/semantics_description.py          |  5 +--
 .../web/v1/services/semantics_description.py     | 38 ++++++++++++++++++-
 4 files changed, 48 insertions(+), 7 deletions(-)

diff --git a/wren-ai-service/config.example.yaml b/wren-ai-service/config.example.yaml
index 64c41add3..86c44ad93 100644
--- a/wren-ai-service/config.example.yaml
+++ b/wren-ai-service/config.example.yaml
@@ -103,3 +103,5 @@
   - name: sql_regeneration
     llm: openai_llm.gpt-4o-mini
     engine: wren_ui
+  - name: semantics_description
+    llm: openai_llm.gpt-4o-mini
diff --git a/wren-ai-service/src/globals.py b/wren-ai-service/src/globals.py
index eeb764cc9..d03b5e2ef 100644
--- a/wren-ai-service/src/globals.py
+++ b/wren-ai-service/src/globals.py
@@ -8,6 +8,7 @@
 from src.core.provider import EmbedderProvider, LLMProvider
 from src.pipelines.generation import (
     followup_sql_generation,
+    semantics_description,
     sql_answer,
     sql_breakdown,
     sql_correction,
@@ -57,7 +58,14 @@ def create_service_container(
     query_cache: Optional[dict] = {},
 ) -> ServiceContainer:
     return ServiceContainer(
-        semantics_description=SemanticsDescription(pipelines={}, **query_cache),
+        semantics_description=SemanticsDescription(
+            pipelines={
+                "semantics_description": semantics_description.SemanticsDescription(
+                    **pipe_components["semantics_description"],
+                )
+            },
+            **query_cache,
+        ),
         semantics_preparation_service=SemanticsPreparationService(
             pipelines={
                 "indexing": indexing.Indexing(
diff --git a/wren-ai-service/src/pipelines/generation/semantics_description.py b/wren-ai-service/src/pipelines/generation/semantics_description.py
index bddcfcf36..fe7f67fdc 100644
--- a/wren-ai-service/src/pipelines/generation/semantics_description.py
+++ b/wren-ai-service/src/pipelines/generation/semantics_description.py
@@ -143,10 +143,7 @@ def wrapper(text: str) -> str:
 
 
 class SemanticsDescription(BasicPipeline):
-    def __init__(
-        self,
-        llm_provider: LLMProvider,
-    ):
+    def __init__(self, llm_provider: LLMProvider, **_):
         self._components = {
             "prompt_builder": PromptBuilder(template=user_prompt_template),
             "generator": llm_provider.get_generator(system_prompt=system_prompt),
diff --git a/wren-ai-service/src/web/v1/services/semantics_description.py b/wren-ai-service/src/web/v1/services/semantics_description.py
index 70b1af00f..67ab7571a 100644
--- a/wren-ai-service/src/web/v1/services/semantics_description.py
+++ b/wren-ai-service/src/web/v1/services/semantics_description.py
@@ -1,6 +1,7 @@
 import logging
 from typing import Dict, Literal, Optional
 
+import orjson
 from cachetools import TTLCache
 from langfuse.decorators import observe
 from pydantic import BaseModel
@@ -98,12 +99,45 @@ def __init__(
             maxsize=maxsize, ttl=ttl
         )
 
+    def _handle_exception(self, request: Request, error_message: str):
+        self._cache[request.id] = self.Response(
+            id=request.id,
+            status="failed",
+            error=self.Response.Error(code="OTHERS", message=error_message),
+        )
+        logger.error(error_message)
+
     @observe(name="Generate Semantics Description")
     @trace_metadata
     async def generate(self, request: Request, **kwargs) -> Response:
         logger.info("Generate Semantics Description pipeline is running...")
-        # todo: implement the service flow
-        pass
+
+        try:
+            if request.mdl is None:
+                raise ValueError("MDL must be provided")
+
+            mdl_dict = orjson.loads(request.mdl)
+
+            input = {
+                "user_prompt": request.user_prompt,
+                "selected_models": request.selected_models,
+                "mdl": mdl_dict,
+            }
+
+            resp = await self._pipelines["semantics_description"].run(**input)
+
+            self._cache[request.id] = self.Response(
+                id=request.id, status="finished", response=resp.get("normalize")
+            )
+        except orjson.JSONDecodeError as e:
+            self._handle_exception(request, f"Failed to parse MDL: {str(e)}")
+        except Exception as e:
+            self._handle_exception(
+                request,
+                f"An error occurred during semantics description generation: {str(e)}",
+            )
+
+        return self._cache[request.id]
 
     def __getitem__(self, request: Request) -> Response:
         response = self._cache.get(request.id)

From e901919bad63e6682972dc579d36158c20cdc5f4 Mon Sep 17 00:00:00 2001
From: Pao-Sheng Wang
Date: Thu, 3 Oct 2024 16:32:46 +0800
Subject: [PATCH 08/13] feat: optimize the prompt for stability

---
 .../pipelines/generation/semantics_description.py | 12 +++++++-----
 1 file changed, 7 insertions(+), 5 deletions(-)

diff --git a/wren-ai-service/src/pipelines/generation/semantics_description.py b/wren-ai-service/src/pipelines/generation/semantics_description.py
index fe7f67fdc..29c410fd8 100644
--- a/wren-ai-service/src/pipelines/generation/semantics_description.py
+++ b/wren-ai-service/src/pipelines/generation/semantics_description.py
@@ -100,8 +100,9 @@ def wrapper(text: str) -> str:
 ### Output Format:
 
 ```
-[
-    {
+{
+    "models": [
+        {
         "name": "model",
         "columns": [
             {
@@ -124,10 +125,11 @@ def wrapper(text: str) -> str:
             }
         ],
         "properties": {
-            "description": ""
+                "description": ""
+            }
         }
-    }
-]
+    ]
+}
 ```
 
 Make sure that the descriptions are concise, informative, and contextually appropriate based on the input provided by the user.

From d5cb50d34134a33b4b816b0403411fb7f3bc78da Mon Sep 17 00:00:00 2001
From: Pao-Sheng Wang
Date: Thu, 3 Oct 2024 16:50:19 +0800
Subject: [PATCH 09/13] feat: test case to validate service

---
 .../services/test_semantics_description.py | 125 ++++++++++++++++++
 1 file changed, 125 insertions(+)
 create mode 100644 wren-ai-service/tests/pytest/services/test_semantics_description.py

diff --git a/wren-ai-service/tests/pytest/services/test_semantics_description.py b/wren-ai-service/tests/pytest/services/test_semantics_description.py
new file mode 100644
index 000000000..4b2678701
--- /dev/null
+++ b/wren-ai-service/tests/pytest/services/test_semantics_description.py
@@ -0,0 +1,125 @@
+from unittest.mock import AsyncMock
+
+import pytest
+
+from src.web.v1.services.semantics_description import SemanticsDescription
+
+
+@pytest.fixture
+def semantics_description_service():
+    mock_pipeline = AsyncMock()
+    mock_pipeline.run.return_value = {
+        "normalize": {
+            "model1": {
+                "columns": [],
+                "properties": {"description": "Test description"},
+            }
+        }
+    }
+
+    pipelines = {"semantics_description": mock_pipeline}
+    return SemanticsDescription(pipelines=pipelines)
+
+
+@pytest.mark.asyncio
+async def test_generate_semantics_description(
+    semantics_description_service: SemanticsDescription,
+):
+    request = SemanticsDescription.Request(
+        user_prompt="Describe the model",
+        selected_models=["model1"],
+        mdl='{"models": [{"name": "model1", "columns": []}]}',
+    )
+    request.id = "test_id"
+
+    response = await semantics_description_service.generate(request)
+
+    assert response.id == "test_id"
+    assert response.status == "finished"
+    assert response.response == {
+        "model1": {
+            "columns": [],
+            "properties": {"description": "Test description"},
+        }
+    }
+    assert response.error is None
+
+
+@pytest.mark.asyncio
+async def test_generate_semantics_description_with_invalid_mdl(
+    semantics_description_service: SemanticsDescription,
+):
+    request = SemanticsDescription.Request(
+        user_prompt="Describe the model",
+        selected_models=["model1"],
+        mdl="invalid_json",
+    )
+    request.id = "test_id"
+
+    response = await semantics_description_service.generate(request)
+
+    assert response.id == "test_id"
+    assert response.status == "failed"
+    assert response.response is None
+    assert response.error.code == "OTHERS"
+    assert "Failed to parse MDL" in response.error.message
+
+
+@pytest.mark.asyncio
+async def test_generate_semantics_description_with_exception(
+    semantics_description_service: SemanticsDescription,
+):
+    request = SemanticsDescription.Request(
+        user_prompt="Describe the model",
+        selected_models=["model1"],
+        mdl='{"models": [{"name": "model1", "columns": []}]}',
+    )
+    request.id = "test_id"
+
+    semantics_description_service._pipelines[
+        "semantics_description"
+    ].run.side_effect = Exception("Test exception")
+
+    response = await semantics_description_service.generate(request)
+
+    assert response.id == "test_id"
+    assert response.status == "failed"
+    assert response.response is None
+    assert response.error.code == "OTHERS"
+    assert (
+        "An error occurred during semantics description generation"
+        in response.error.message
+    )
+
+
+def test_get_semantics_description_result(
+    semantics_description_service: SemanticsDescription,
+):
+    request = SemanticsDescription.Request()
+    request.id = "test_id"
+
+    expected_response = SemanticsDescription.Response(
+        id="test_id",
+        status="finished",
+        response={"model1": {"description": "Test description"}},
+    )
+    semantics_description_service._cache["test_id"] = expected_response
+
+    result = semantics_description_service[request]
+
+    assert result == expected_response
+
+
+def test_get_non_existent_semantics_description_result(
+    semantics_description_service: SemanticsDescription,
+):
+    request = SemanticsDescription.Request()
+    request.id = "non_existent_id"
+
+    result = semantics_description_service[request]
+
+    assert result.id == "non_existent_id"
+    assert result.status == "failed"
+    assert result.response is None
+    assert result.error.code == "OTHERS"
+    assert "not found" in result.error.message
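Before the refactoring patches that follow, it is worth isolating the storage pattern the service relies on: a `cachetools.TTLCache` fronted by `__getitem__`/`__setitem__`, so resources silently expire and a late lookup degrades into a "failed" resource. A minimal sketch of that pattern; the class name and the dict-shaped resource are illustrative, not the service's actual models:

```python
# Minimal sketch of the TTL-cache-backed resource store used by the
# service. Entries evaporate after `ttl` seconds, which is why a client
# that polls too late can see a "not found" failure for a real job.
from cachetools import TTLCache


class ResourceStore:
    def __init__(self, maxsize: int = 1_000_000, ttl: int = 120):
        self._cache = TTLCache(maxsize=maxsize, ttl=ttl)

    def __getitem__(self, id: str) -> dict:
        resource = self._cache.get(id)
        if resource is None:
            # Missing and expired entries are indistinguishable here; both
            # surface as a failed resource, mirroring the service.
            return {
                "id": id,
                "status": "failed",
                "error": {"code": "OTHERS", "message": f"Resource '{id}' not found."},
            }
        return resource

    def __setitem__(self, id: str, value: dict):
        self._cache[id] = value


store = ResourceStore(ttl=5)
store["abc"] = {"id": "abc", "status": "generating"}
print(store["abc"]["status"])      # generating
print(store["missing"]["status"])  # failed
```

One practical consequence of the 120-second default TTL is that callers must finish polling within that window, or the result is gone.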
From f060d15205b3ba83d601d96e11d9aaffcbea680e Mon Sep 17 00:00:00 2001
From: Pao-Sheng Wang
Date: Fri, 4 Oct 2024 14:33:36 +0800
Subject: [PATCH 10/13] chore: remove status attr in response for generating semantics description task

---
 wren-ai-service/src/web/v1/services/semantics_description.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/wren-ai-service/src/web/v1/services/semantics_description.py b/wren-ai-service/src/web/v1/services/semantics_description.py
index 67ab7571a..4ba4dea3e 100644
--- a/wren-ai-service/src/web/v1/services/semantics_description.py
+++ b/wren-ai-service/src/web/v1/services/semantics_description.py
@@ -33,7 +33,6 @@ class SemanticsDescription:
        Response:
        {
            "id": "unique_id",  # Unique identifier for the generated description
-           "status": "generating"  # Initial status
        }
 
     2. GET /v1/semantics-descriptions/{id}
@@ -84,7 +83,7 @@ class Error(BaseModel):
             message: str
 
         id: str
-        status: Literal["generating", "finished", "failed"] = "generating"
+        status: Literal["generating", "finished", "failed"] = None
         response: Optional[dict] = None
         error: Optional[Error] = None

From b17d8eebd899e1d8e83c7e3b5a39b4b21ec0a005 Mon Sep 17 00:00:00 2001
From: Pao-Sheng Wang
Date: Fri, 4 Oct 2024 16:52:25 +0800
Subject: [PATCH 11/13] feat: refactor the route and service classes design

---
 .../web/v1/routers/semantics_description.py   |  88 ++++++++++++--
 .../web/v1/services/semantics_description.py  | 107 ++++--------------
 .../services/test_semantics_description.py    |  28 +++--
 3 files changed, 113 insertions(+), 110 deletions(-)

diff --git a/wren-ai-service/src/web/v1/routers/semantics_description.py b/wren-ai-service/src/web/v1/routers/semantics_description.py
index 579d6f632..28044a25a 100644
--- a/wren-ai-service/src/web/v1/routers/semantics_description.py
+++ b/wren-ai-service/src/web/v1/routers/semantics_description.py
@@ -1,7 +1,9 @@
 import uuid
 from dataclasses import asdict
+from typing import Literal, Optional
 
 from fastapi import APIRouter, BackgroundTasks, Depends
+from pydantic import BaseModel
 
 from src.globals import (
     ServiceContainer,
@@ -13,35 +15,101 @@
 
 router = APIRouter()
 
+"""
+Semantics Description Router
+
+This router handles endpoints related to generating and retrieving semantic descriptions.
+
+Endpoints:
+1. POST /semantics-descriptions
+   - Generates a new semantic description
+   - Request body: PostRequest
+   - Response: PostResponse with a unique ID
+
+2. GET /semantics-descriptions/{id}
+   - Retrieves the status and result of a semantic description generation
+   - Path parameter: id (str)
+   - Response: GetResponse with status, response, and error information
+
+The semantic description generation is an asynchronous process. The POST endpoint
+initiates the generation and returns immediately with an ID. The GET endpoint can
+then be used to check the status and retrieve the result when it's ready.
+
+Usage:
+1. Send a POST request to start the generation process.
+2. Use the returned ID to poll the GET endpoint until the status is "finished" or "failed".
+
+Note: The actual generation is performed in the background using FastAPI's BackgroundTasks.
+"""
+
+
+class PostRequest(BaseModel):
+    _id: str | None = None
+    selected_models: list[str]
+    user_prompt: str
+    mdl: str
+
+    @property
+    def id(self) -> str:
+        return self._id
+
+    @id.setter
+    def id(self, id: str):
+        self._id = id
+
+
+class PostResponse(BaseModel):
+    id: str
+
 
-@router.post("/semantics-descriptions", response_model=SemanticsDescription.Response)
+@router.post(
+    "/semantics-descriptions",
+    response_model=PostResponse,
+)
 async def generate(
-    request: SemanticsDescription.Request,
+    request: PostRequest,
     background_tasks: BackgroundTasks,
     service_container: ServiceContainer = Depends(get_service_container),
     service_metadata: ServiceMetadata = Depends(get_service_metadata),
-) -> SemanticsDescription.Response:
+) -> PostResponse:
     id = str(uuid.uuid4())
     request.id = id
     service = service_container.semantics_description
 
-    service[request] = SemanticsDescription.Response(id=id)
+    service[id] = SemanticsDescription.Resource(id=id)
+    SemanticsDescription.Input(
+        id=id,
+        selected_models=request.selected_models,
+        user_prompt=request.user_prompt,
+        mdl=request.mdl,
+    )
 
     background_tasks.add_task(
         service.generate, request, service_metadata=asdict(service_metadata)
     )
-    return service[request]
+    return PostResponse(id=id)
+
+
+class GetResponse(BaseModel):
+    id: str
+    status: Literal["generating", "finished", "failed"]
+    response: Optional[dict]
+    error: Optional[dict]
 
 
 @router.get(
     "/semantics-descriptions/{id}",
-    response_model=SemanticsDescription.Response,
+    response_model=GetResponse,
 )
 async def get(
     id: str,
     service_container: ServiceContainer = Depends(get_service_container),
-) -> SemanticsDescription.Response:
-    request = SemanticsDescription.Request()
-    request.id = id
+) -> GetResponse:
+    resource = service_container.semantics_description[id]
 
-    return service_container.semantics_description[request]
+    return GetResponse(
+        id=resource.id,
+        status=resource.status,
+        response=resource.response,
+        error=resource.error and resource.error.model_dump(),
+    )
diff --git a/wren-ai-service/src/web/v1/services/semantics_description.py b/wren-ai-service/src/web/v1/services/semantics_description.py
index 4ba4dea3e..172c81228 100644
--- a/wren-ai-service/src/web/v1/services/semantics_description.py
+++ b/wren-ai-service/src/web/v1/services/semantics_description.py
@@ -13,71 +13,13 @@
 
 
 class SemanticsDescription:
-    """
-    SemanticsDescription Service
-
-    This service provides endpoints for generating and optimizing semantic descriptions
-    based on user prompts and selected models.
-
-    Endpoints:
-    1. POST /v1/semantics-descriptions
-       Generate a new semantic description.
-
-       Request body:
-       {
-           "selected_models": ["model1", "model2"],  # List of selected model names
-           "user_prompt": "Describe the data model",  # User's prompt for description
-           "mdl": "..."  # MDL (Model Definition Language) string
-       }
-
-       Response:
-       {
-           "id": "unique_id",  # Unique identifier for the generated description
-       }
-
-    2. GET /v1/semantics-descriptions/{id}
-       Retrieve the status and result of a semantic description generation.
-
-       Path parameter:
-       - id: Unique identifier of the semantic description resource.
-
-       Response:
-       {
-           "id": "unique_id",
-           "status": "finished",  # Can be "generating", "finished", or "failed"
-           "response": {  # Present only if status is "finished"
-               // Generated semantic description
-           },
-           "error": {  # Present only if status is "failed"
-               "code": "OTHERS",
-               "message": "Error description"
-           }
-       }
-
-    Usage:
-    1. Call the POST endpoint to initiate a semantic description generation.
-    2. Use the returned ID to poll the GET endpoint until the status is "finished" or "failed".
-    3. Once finished, retrieve the generated description from the "response" field.
-
-    Note: The generation process may take some time, so implement appropriate polling
-    intervals when checking the status.
-    """
-
-    class Request(BaseModel):
-        _id: str | None = None
-        selected_models: list[str] = []
-        user_prompt: str = ""
-        mdl: str | None = None
-
-        @property
-        def id(self) -> str:
-            return self._id
-
-        @id.setter
-        def id(self, id: str):
-            self._id = id
-
-    class Response(BaseModel):
+    class Input(BaseModel):
+        id: str
+        selected_models: list[str]
+        user_prompt: str
+        mdl: str
+
+    class Resource(BaseModel):
         class Error(BaseModel):
             code: Literal["OTHERS"]
             message: str
@@ -94,27 +36,24 @@ def __init__(
         ttl: int = 120,
     ):
         self._pipelines = pipelines
-        self._cache: Dict[str, SemanticsDescription.Response] = TTLCache(
+        self._cache: Dict[str, SemanticsDescription.Resource] = TTLCache(
             maxsize=maxsize, ttl=ttl
         )
 
-    def _handle_exception(self, request: Request, error_message: str):
-        self._cache[request.id] = self.Response(
+    def _handle_exception(self, request: Input, error_message: str):
+        self[request.id] = self.Resource(
             id=request.id,
             status="failed",
-            error=self.Response.Error(code="OTHERS", message=error_message),
+            error=self.Resource.Error(code="OTHERS", message=error_message),
         )
         logger.error(error_message)
 
     @observe(name="Generate Semantics Description")
     @trace_metadata
-    async def generate(self, request: Request, **kwargs) -> Response:
+    async def generate(self, request: Input, **kwargs) -> Resource:
         logger.info("Generate Semantics Description pipeline is running...")
 
         try:
-            if request.mdl is None:
-                raise ValueError("MDL must be provided")
-
             mdl_dict = orjson.loads(request.mdl)
 
             input = {
@@ -125,7 +64,7 @@ async def generate(self, request: Input, **kwargs) -> Resource:
 
             resp = await self._pipelines["semantics_description"].run(**input)
 
-            self._cache[request.id] = self.Response(
+            self[request.id] = self.Resource(
                 id=request.id, status="finished", response=resp.get("normalize")
             )
         except orjson.JSONDecodeError as e:
@@ -136,23 +75,21 @@ async def generate(self, request: Input, **kwargs) -> Resource:
                 f"An error occurred during semantics description generation: {str(e)}",
             )
 
-        return self._cache[request.id]
+        return self[request.id]
 
-    def __getitem__(self, request: Request) -> Response:
-        response = self._cache.get(request.id)
+    def __getitem__(self, id: int) -> Resource:
+        response = self._cache.get(id)
 
         if response is None:
-            message = (
-                f"Semantics Description Resource with ID '{request.id}' not found."
-            )
+            message = f"Semantics Description Resource with ID '{id}' not found."
             logger.exception(message)
-            return self.Response(
-                id=request.id,
+            return self.Resource(
+                id=id,
                 status="failed",
-                error=self.Response.Error(code="OTHERS", message=message),
+                error=self.Resource.Error(code="OTHERS", message=message),
             )
 
         return response
 
-    def __setitem__(self, request: Request, value: Response):
-        self._cache[request.id] = value
+    def __setitem__(self, id: int, value: Resource):
+        self._cache[id] = value
diff --git a/wren-ai-service/tests/pytest/services/test_semantics_description.py b/wren-ai-service/tests/pytest/services/test_semantics_description.py
index 4b2678701..f35f382d7 100644
--- a/wren-ai-service/tests/pytest/services/test_semantics_description.py
+++ b/wren-ai-service/tests/pytest/services/test_semantics_description.py
@@ -25,12 +25,12 @@ def semantics_description_service():
 async def test_generate_semantics_description(
     semantics_description_service: SemanticsDescription,
 ):
-    request = SemanticsDescription.Request(
+    request = SemanticsDescription.Input(
+        id="test_id",
         user_prompt="Describe the model",
         selected_models=["model1"],
         mdl='{"models": [{"name": "model1", "columns": []}]}',
     )
-    request.id = "test_id"
 
     response = await semantics_description_service.generate(request)
 
@@ -49,12 +49,12 @@ async def test_generate_semantics_description(
 async def test_generate_semantics_description_with_invalid_mdl(
     semantics_description_service: SemanticsDescription,
 ):
-    request = SemanticsDescription.Request(
+    request = SemanticsDescription.Input(
+        id="test_id",
         user_prompt="Describe the model",
         selected_models=["model1"],
         mdl="invalid_json",
     )
-    request.id = "test_id"
 
     response = await semantics_description_service.generate(request)
 
@@ -69,12 +69,12 @@ async def test_generate_semantics_description_with_invalid_mdl(
 async def test_generate_semantics_description_with_exception(
     semantics_description_service: SemanticsDescription,
 ):
-    request = SemanticsDescription.Request(
+    request = SemanticsDescription.Input(
+        id="test_id",
         user_prompt="Describe the model",
         selected_models=["model1"],
         mdl='{"models": [{"name": "model1", "columns": []}]}',
     )
-    request.id = "test_id"
 
     semantics_description_service._pipelines[
         "semantics_description"
@@ -95,17 +95,16 @@ async def test_generate_semantics_description_with_exception(
 def test_get_semantics_description_result(
     semantics_description_service: SemanticsDescription,
 ):
-    request = SemanticsDescription.Request()
-    request.id = "test_id"
+    id = "test_id"
 
-    expected_response = SemanticsDescription.Response(
-        id="test_id",
+    expected_response = SemanticsDescription.Resource(
+        id=id,
         status="finished",
         response={"model1": {"description": "Test description"}},
     )
-    semantics_description_service._cache["test_id"] = expected_response
+    semantics_description_service._cache[id] = expected_response
 
-    result = semantics_description_service[request]
+    result = semantics_description_service[id]
 
     assert result == expected_response
 
@@ -113,10 +112,9 @@ def test_get_semantics_description_result(
 def test_get_non_existent_semantics_description_result(
     semantics_description_service: SemanticsDescription,
 ):
-    request = SemanticsDescription.Request()
-    request.id = "non_existent_id"
+    id = "non_existent_id"
 
-    result = semantics_description_service[request]
+    result = semantics_description_service[id]
 
     assert result.id == "non_existent_id"
     assert result.status == "failed"

From 1faa1865322fcf73aefc89ab74cb4a20da285659 Mon Sep 17 00:00:00 2001
From: Pao-Sheng Wang
Date: Fri, 4 Oct 2024 16:58:10 +0800
Subject: [PATCH 12/13] chore: optimize the route doc

---
 .../web/v1/routers/semantics_description.py   | 30 +++++++++++++++++--
 1 file changed, 28 insertions(+), 2 deletions(-)

diff --git a/wren-ai-service/src/web/v1/routers/semantics_description.py b/wren-ai-service/src/web/v1/routers/semantics_description.py
index 28044a25a..1b0165e4c 100644
--- a/wren-ai-service/src/web/v1/routers/semantics_description.py
+++ b/wren-ai-service/src/web/v1/routers/semantics_description.py
@@ -24,12 +24,38 @@
 1. POST /semantics-descriptions
    - Generates a new semantic description
    - Request body: PostRequest
+     {
+       "selected_models": ["model1", "model2"],  # List of model names to describe
+       "user_prompt": "Describe these models",  # User's instruction for description
+       "mdl": "{ ... }"  # JSON string of the MDL (Model Definition Language)
+     }
-   - Response: PostResponse with a unique ID
+   - Response: PostResponse
+     {
+       "id": "unique-uuid"  # Unique identifier for the generated description
+     }
 
 2. GET /semantics-descriptions/{id}
    - Retrieves the status and result of a semantic description generation
    - Path parameter: id (str)
-   - Response: GetResponse with status, response, and error information
+   - Response: GetResponse
+     {
+       "id": "unique-uuid",  # Unique identifier of the description
+       "status": "generating" | "finished" | "failed",
+       "response": {  # Present only if status is "finished"
+         "model1": {
+           "columns": [...],
+           "properties": {...}
+         },
+         "model2": {
+           "columns": [...],
+           "properties": {...}
+         }
+       },
+       "error": {  # Present only if status is "failed"
+         "code": "OTHERS",
+         "message": "Error description"
+       }
+     }
 
 The semantic description generation is an asynchronous process. The POST endpoint
 initiates the generation and returns immediately with an ID. The GET endpoint can

From a8e80ed2c601ea8e440a0340cc949f84ddb653ea Mon Sep 17 00:00:00 2001
From: Pao-Sheng Wang
Date: Fri, 4 Oct 2024 17:23:06 +0800
Subject: [PATCH 13/13] fix: using input instead of endpoint request

---
 .../src/web/v1/routers/semantics_description.py  | 14 ++------------
 .../src/web/v1/services/semantics_description.py |  4 ++--
 2 files changed, 4 insertions(+), 14 deletions(-)

diff --git a/wren-ai-service/src/web/v1/routers/semantics_description.py b/wren-ai-service/src/web/v1/routers/semantics_description.py
index 1b0165e4c..53adec9dd 100644
--- a/wren-ai-service/src/web/v1/routers/semantics_description.py
+++ b/wren-ai-service/src/web/v1/routers/semantics_description.py
@@ -70,19 +70,10 @@
 
 
 class PostRequest(BaseModel):
-    _id: str | None = None
     selected_models: list[str]
     user_prompt: str
     mdl: str
 
-    @property
-    def id(self) -> str:
-        return self._id
-
-    @id.setter
-    def id(self, id: str):
-        self._id = id
-
 
 class PostResponse(BaseModel):
     id: str
@@ -99,11 +90,10 @@ async def generate(
     service_metadata: ServiceMetadata = Depends(get_service_metadata),
 ) -> PostResponse:
     id = str(uuid.uuid4())
-    request.id = id
     service = service_container.semantics_description
 
     service[id] = SemanticsDescription.Resource(id=id)
-    SemanticsDescription.Input(
+    input = SemanticsDescription.Input(
         id=id,
         selected_models=request.selected_models,
         user_prompt=request.user_prompt,
@@ -111,7 +101,7 @@ async def generate(
     )
 
     background_tasks.add_task(
-        service.generate, request, service_metadata=asdict(service_metadata)
+        service.generate, input, service_metadata=asdict(service_metadata)
     )
     return PostResponse(id=id)
 
diff --git a/wren-ai-service/src/web/v1/services/semantics_description.py b/wren-ai-service/src/web/v1/services/semantics_description.py
index 172c81228..abc8311b6 100644
--- a/wren-ai-service/src/web/v1/services/semantics_description.py
+++ b/wren-ai-service/src/web/v1/services/semantics_description.py
@@ -77,7 +77,7 @@ async def generate(self, request: Input, **kwargs) -> Resource:
 
         return self[request.id]
 
-    def __getitem__(self, id: int) -> Resource:
+    def __getitem__(self, id: str) -> Resource:
         response = self._cache.get(id)
 
         if response is None:
@@ -91,5 +91,5 @@ def __getitem__(self, id: str) -> Resource:
 
         return response
 
-    def __setitem__(self, id: int, value: Resource):
+    def __setitem__(self, id: str, value: Resource):
         self._cache[id] = value
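With the final patch applied, the feature is a POST-then-poll API: POST /semantics-descriptions enqueues a background task and returns an ID, and GET /semantics-descriptions/{id} reports "generating", "finished", or "failed". A minimal client sketch of that flow follows; the base URL, port, and the /v1 mount prefix are assumptions about the deployment, not something these patches pin down:

```python
# Illustrative client for the POST-then-poll flow introduced in this
# series. Base URL and /v1 prefix are assumed; the request and response
# fields follow PostRequest/GetResponse from patches 11-13.
import json
import time

import requests

BASE = "http://localhost:5556/v1"  # assumed deployment address

mdl = {"models": [{"name": "orders", "columns": [], "properties": {}}]}
post = requests.post(
    f"{BASE}/semantics-descriptions",
    json={
        "selected_models": ["orders"],
        "user_prompt": "Describe the orders model",
        "mdl": json.dumps(mdl),  # the service expects the MDL as a JSON string
    },
)
post.raise_for_status()
resource_id = post.json()["id"]

# Poll until the background task settles; remember the service caches
# results with a 120-second TTL, so do not defer polling for long.
while True:
    resource = requests.get(f"{BASE}/semantics-descriptions/{resource_id}").json()
    if resource["status"] in ("finished", "failed"):
        break
    time.sleep(1)

print(resource.get("response") or resource.get("error"))
```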