feat(wren-ai-service): semantics description pipeline to generate or optimize the user description #716

Merged · 13 commits · Oct 7, 2024
2 changes: 2 additions & 0 deletions wren-ai-service/config.example.yaml
@@ -103,3 +103,5 @@ pipes:
- name: sql_regeneration
llm: openai_llm.gpt-4o-mini
engine: wren_ui
- name: semantics_description
llm: openai_llm.gpt-4o-mini
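
For orientation, each `pipes` entry pairs a pipeline name with the component aliases it should use. A minimal sketch of reading such a list into a name-keyed lookup (illustrative only; this is not the project's actual config loader, and `pyyaml` is assumed to be available):

```python
# Illustrative only: indexing the pipes list by name, the shape that a
# pipe_components-style lookup implies.
import yaml

config = yaml.safe_load("""
pipes:
  - name: sql_regeneration
    llm: openai_llm.gpt-4o-mini
    engine: wren_ui
  - name: semantics_description
    llm: openai_llm.gpt-4o-mini
""")
pipes_by_name = {pipe["name"]: pipe for pipe in config["pipes"]}
print(pipes_by_name["semantics_description"]["llm"])  # openai_llm.gpt-4o-mini
```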
11 changes: 11 additions & 0 deletions wren-ai-service/src/globals.py
@@ -8,6 +8,7 @@
from src.core.provider import EmbedderProvider, LLMProvider
from src.pipelines.generation import (
followup_sql_generation,
semantics_description,
sql_answer,
sql_breakdown,
sql_correction,
@@ -21,6 +22,7 @@
from src.pipelines.retrieval import historical_question, retrieval
from src.web.v1.services.ask import AskService
from src.web.v1.services.ask_details import AskDetailsService
from src.web.v1.services.semantics_description import SemanticsDescription
from src.web.v1.services.semantics_preparation import SemanticsPreparationService
from src.web.v1.services.sql_answer import SqlAnswerService
from src.web.v1.services.sql_expansion import SqlExpansionService
@@ -32,6 +34,7 @@

@dataclass
class ServiceContainer:
semantics_description: SemanticsDescription
semantics_preparation_service: SemanticsPreparationService
ask_service: AskService
sql_answer_service: SqlAnswerService
@@ -55,6 +58,14 @@ def create_service_container(
query_cache: Optional[dict] = {},
) -> ServiceContainer:
return ServiceContainer(
semantics_description=SemanticsDescription(
pipelines={
"semantics_description": semantics_description.SemanticsDescription(
**pipe_components["semantics_description"],
)
},
**query_cache,
),
semantics_preparation_service=SemanticsPreparationService(
pipelines={
"indexing": indexing.Indexing(
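
The wiring above follows the container's existing pattern: the component dict registered for the `semantics_description` pipe is splatted into the pipeline constructor. A self-contained sketch of that pattern (the class and values below are made up for illustration, not the project's API):

```python
# Illustrative pattern only: splatting a per-pipe component dict into a
# pipeline constructor, as create_service_container does above.
class FakePipeline:
    def __init__(self, llm_provider: str, **_):
        self.llm_provider = llm_provider


pipe_components = {"semantics_description": {"llm_provider": "openai_llm.gpt-4o-mini"}}
pipeline = FakePipeline(**pipe_components["semantics_description"])
print(pipeline.llm_provider)  # openai_llm.gpt-4o-mini
```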
236 changes: 236 additions & 0 deletions wren-ai-service/src/pipelines/generation/semantics_description.py
@@ -0,0 +1,236 @@
import json
import logging
import sys
from pathlib import Path
from typing import Any

import orjson
from hamilton import base
from hamilton.experimental.h_async import AsyncDriver
from haystack.components.builders.prompt_builder import PromptBuilder
from langfuse.decorators import observe

from src.core.pipeline import BasicPipeline, async_validate
from src.core.provider import LLMProvider

logger = logging.getLogger("wren-ai-service")


## Start of Pipeline
@observe(capture_input=False)
def picked_models(mdl: dict, selected_models: list[str]) -> list[dict]:
def extract(model: dict) -> dict:
return {
"name": model["name"],
"columns": model["columns"],
"properties": model["properties"],
}

return [
extract(model) for model in mdl["models"] if model["name"] in selected_models
]
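
For illustration, a toy MDL (names invented) shows how `picked_models` filters down to the selected models and the three fields the prompt needs:

```python
# Illustrative only: picked_models keeps just the selected models and the
# name/columns/properties fields, assuming it is importable from the module above.
toy_mdl = {
    "models": [
        {"name": "Student", "columns": [{"name": "id", "type": "INT64", "notNull": True, "properties": {}}], "properties": {}},
        {"name": "Course", "columns": [], "properties": {}},
    ]
}
print(picked_models(toy_mdl, selected_models=["Student"]))
# [{'name': 'Student', 'columns': [{'name': 'id', ...}], 'properties': {}}]
```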


@observe(capture_input=False)
def prompt(
picked_models: list[dict],
user_prompt: str,
prompt_builder: PromptBuilder,
) -> dict:
logger.debug(f"User prompt: {user_prompt}")
logger.debug(f"Picked models: {picked_models}")
return prompt_builder.run(picked_models=picked_models, user_prompt=user_prompt)


@observe(as_type="generation", capture_input=False)
async def generate(prompt: dict, generator: Any) -> dict:
logger.debug(f"prompt: {orjson.dumps(prompt, option=orjson.OPT_INDENT_2).decode()}")
return await generator.run(prompt=prompt.get("prompt"))


@observe(capture_input=False)
def normalize(generate: dict) -> dict:
def wrapper(text: str) -> dict:
text = text.replace("\n", " ")
text = " ".join(text.split())
# Convert the normalized text to a dictionary
try:
text_dict = orjson.loads(text.strip())
return text_dict
except orjson.JSONDecodeError as e:
logger.error(f"Error decoding JSON: {e}")
return {} # Return an empty dictionary if JSON decoding fails

logger.debug(
f"generate: {orjson.dumps(generate, option=orjson.OPT_INDENT_2).decode()}"
)

reply = generate.get("replies")[0] # Expecting only one reply
normalized = wrapper(reply)

# Tolerate a failed decode upstream by treating it as "no models"
return {model["name"]: model for model in normalized.get("models", [])}
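
Feeding `normalize` a fabricated generator payload illustrates the reshaping into a dict keyed by model name (the reply string below is made up):

```python
# Illustrative only: one fake reply, re-keyed by model name.
fake_generate = {
    "replies": ['{"models": [{"name": "Student", "columns": [], "properties": {"description": "Enrolled students"}}]}']
}
print(normalize(fake_generate))
# {'Student': {'name': 'Student', 'columns': [], 'properties': {'description': 'Enrolled students'}}}
```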


## End of Pipeline

system_prompt = """
I have a data model represented in JSON format, with the following structure:

```
[
  {
    "name": "model",
    "columns": [
      {"name": "column_1", "type": "type", "notNull": true, "properties": {}},
      {"name": "column_2", "type": "type", "notNull": true, "properties": {}},
      {"name": "column_3", "type": "type", "notNull": false, "properties": {}}
    ],
    "properties": {}
  }
]
```

Your task is to update this JSON structure by adding a `description` field inside both the `properties` attribute of each `column` and the `model` itself.
Each `description` should be derived from a user-provided input that explains the purpose or context of the `model` and its respective columns.
Follow these steps:
1. **For the `model`**: Derive a brief description of the model's overall purpose or context from the user's input, and insert it in the `properties` field of the `model`.
2. **For each `column`**: Derive a description of the column's role or significance from the user's input, and add it under the column's respective `properties` field in the format: `'description': '<generated text>'`.
3. Ensure that the output is a well-formatted JSON structure, preserving the input's original format and adding the appropriate `description` fields.

### Output Format:

```
{
"models": [
{
"name": "model",
"columns": [
{
"name": "column_1",
"properties": {
"description": "<description for column_1>"
}
},
{
"name": "column_2",
"properties": {
"description": "<description for column_1>"
}
},
{
"name": "column_3",
"properties": {
"description": "<description for column_1>"
}
}
],
"properties": {
"description": "<description for model>"
}
}
]
}
```

Make sure that the descriptions are concise, informative, and contextually appropriate based on the input provided by the user.
"""

user_prompt_template = """
### Input:
User's prompt: {{ user_prompt }}
Picked models: {{ picked_models }}

Please provide a brief description for the model and each column based on the user's prompt.
"""


class SemanticsDescription(BasicPipeline):
def __init__(self, llm_provider: LLMProvider, **_):
self._components = {
"prompt_builder": PromptBuilder(template=user_prompt_template),
"generator": llm_provider.get_generator(system_prompt=system_prompt),
}
self._final = "normalize"

super().__init__(
AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult())
)
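
The constructor hands this module to Hamilton's `AsyncDriver`, so each top-level function above becomes a graph node, and a parameter named after another function receives that function's output; `run()` inputs fill the remaining names. A toy, standalone Hamilton example of that naming convention (unrelated to this pipeline):

```python
# Toy Hamilton example: `doubled` depends on `add_one` because its parameter
# name matches that node; `x` is supplied at execution time.
import sys

from hamilton import base, driver


def add_one(x: int) -> int:
    return x + 1


def doubled(add_one: int) -> int:
    return add_one * 2


dr = driver.Driver(
    {}, sys.modules[__name__], adapter=base.SimplePythonGraphAdapter(base.DictResult())
)
print(dr.execute(["doubled"], inputs={"x": 3}))  # {'doubled': 8}
```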

def visualize(
self,
user_prompt: str,
selected_models: list[str],
mdl: dict,
) -> None:
destination = "outputs/pipelines/generation"
if not Path(destination).exists():
Path(destination).mkdir(parents=True, exist_ok=True)

self._pipe.visualize_execution(
[self._final],
output_file_path=f"{destination}/semantics_description.dot",
inputs={
"user_prompt": user_prompt,
"selected_models": selected_models,
"mdl": mdl,
**self._components,
},
show_legend=True,
orient="LR",
)

@observe(name="Semantics Description Generation")
async def run(
self,
user_prompt: str,
selected_models: list[str],
mdl: dict,
) -> dict:
logger.info("Semantics Description Generation pipeline is running...")
return await self._pipe.execute(
[self._final],
inputs={
"user_prompt": user_prompt,
"selected_models": selected_models,
"mdl": mdl,
**self._components,
},
)


if __name__ == "__main__":
from langfuse.decorators import langfuse_context

from src.core.engine import EngineConfig
from src.core.pipeline import async_validate
from src.providers import init_providers
from src.utils import init_langfuse, load_env_vars

load_env_vars()
init_langfuse()

llm_provider, _, _, _ = init_providers(EngineConfig())
pipeline = SemanticsDescription(llm_provider=llm_provider)

with open("sample/college_3_bigquery_mdl.json", "r") as file:
mdl = json.load(file)

input = {
"user_prompt": "Track student enrollments, grades, and GPA calculations to monitor academic performance and identify areas for student support",
"selected_models": [
"Student",
"Minor_in",
"Member_of",
"Gradeconversion",
"Faculty",
"Enrolled_in",
"Department",
"Course",
],
"mdl": mdl,
}

pipeline.visualize(**input)
async_validate(lambda: pipeline.run(**input))

langfuse_context.flush()
@@ -9,6 +9,7 @@
get_service_container,
get_service_metadata,
)
from src.web.v1.routers import semantics_description
from src.web.v1.services.ask import (
AskRequest,
AskResponse,
@@ -57,6 +58,7 @@
)

router = APIRouter()
router.include_router(semantics_description.router)
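
The new sub-router (defined in `src/web/v1/routers/semantics_description.py`, which this diff does not show) is mounted with FastAPI's standard `include_router`. A generic sketch of the pattern — the route path below is hypothetical, not the project's actual endpoint:

```python
# Generic FastAPI mounting pattern (illustrative; the real routes live in
# src/web/v1/routers/semantics_description.py, not shown in this diff).
from fastapi import APIRouter, FastAPI

sub_router = APIRouter()


@sub_router.post("/semantics-descriptions")  # hypothetical path
async def create_semantics_description() -> dict:
    return {"status": "accepted"}


v1_router = APIRouter()
v1_router.include_router(sub_router)

app = FastAPI()
app.include_router(v1_router, prefix="/v1")
```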


@router.post("/semantics-preparations")