From 72f0d638a205245e3fa8059beb213858790fb4c3 Mon Sep 17 00:00:00 2001
From: Andrew <15331990+ahuang11@users.noreply.github.com>
Date: Thu, 5 Dec 2024 05:26:04 -0800
Subject: [PATCH] Refactor prompting (#807)

Co-authored-by: Philipp Rudiger
---
 lumen/ai/actor.py       | 67 ++++++++++++++++++++++++-----
 lumen/ai/agents.py      | 94 ++++++++++++++++++++---------------------
 lumen/ai/coordinator.py | 50 +++++++++++++++-------
 lumen/ai/utils.py       | 21 +++++++++
 4 files changed, 159 insertions(+), 73 deletions(-)

diff --git a/lumen/ai/actor.py b/lumen/ai/actor.py
index e74306504..560453c3b 100644
--- a/lumen/ai/actor.py
+++ b/lumen/ai/actor.py
@@ -1,31 +1,76 @@
 from abc import abstractmethod
+from pathlib import Path
+from types import FunctionType
 from typing import Any
 
 import param
 
+from pydantic import BaseModel
+
 from .llm import Message
-from .utils import log_debug, render_template
+from .utils import log_debug, render_template, warn_on_unused_variables
 
 
 class Actor(param.Parameterized):
 
-    prompt_overrides = param.Dict(default={}, doc="""
-        Overrides the prompt's 'instructions' or 'context' jinja2 blocks.
+    template_overrides = param.Dict(default={}, doc="""
+        Overrides the template's 'instructions', 'context', or 'examples' jinja2 blocks.
         Is a nested dictionary with the prompt name (e.g. main) as the key
         and the block names as the inner keys with the new content as the
         values.""")
 
-    prompt_templates = param.Dict(default={}, doc="""
-        The paths to the prompt's jinja2 templates.""")
+    prompts = param.Dict(default={}, doc="""
+        A dict keyed by prompt name, e.g. 'main', nesting another dict with
+        keys like 'template' and 'model' (a model class or factory function).""")
+
+    def _get_model(self, prompt_name: str, **context) -> type[BaseModel]:
+        if prompt_name in self.prompts and "model" in self.prompts[prompt_name]:
+            prompt_spec = self.prompts[prompt_name]
+        else:
+            prompt_spec = self.param.prompts.default[prompt_name]
+        if "model" not in prompt_spec:
+            raise KeyError(f"Prompt {prompt_name!r} does not provide a model.")
+        model_spec = prompt_spec["model"]
+        if isinstance(model_spec, FunctionType):
+            model = model_spec(**context)
+        else:
+            model = model_spec
+        return model
 
     def _render_prompt(self, prompt_name: str, **context) -> str:
+        if prompt_name in self.prompts and "template" in self.prompts[prompt_name]:
+            prompt_spec = self.prompts[prompt_name]
+        else:
+            prompt_spec = self.param.prompts.default[prompt_name]
+        if "template" not in prompt_spec:
+            raise KeyError(f"Prompt {prompt_name!r} does not provide a prompt template.")
+        prompt_template = prompt_spec["template"]
+
+        overrides = self.template_overrides.get(prompt_name, {})
+        prompt_label = f"\033[92m{self.name[:-5]}.prompts['{prompt_name}']['template']\033[0m"
         context["memory"] = self._memory
-        prompt = render_template(
-            self.prompt_templates[prompt_name],
-            overrides=self.prompt_overrides.get(prompt_name, {}),
-            **context
-        )
-        log_debug(f"\033[92mRendered prompt\033[0m '{prompt_name}':\n{prompt}")
+        if isinstance(prompt_template, str) and not Path(prompt_template).exists():
+            # check that all the format_kwargs keys are contained in prompt_template,
+            # e.g. the key, "memory", is not used in "string template".format(memory=memory)
+            format_kwargs = dict(**overrides, **context)
+            warn_on_unused_variables(prompt_template, format_kwargs, prompt_label)
+            try:
+                prompt = prompt_template.format(**format_kwargs)
+            except KeyError as e:
+                # check that all the format variables in prompt_template
+                # are available from format_kwargs, e.g. the key, "var",
+                # is not available in context "string template {var}".format(memory=memory)
+                raise KeyError(
+                    f"Unexpected template variable: {e}. To resolve this, "
+                    f"please ensure overrides contains the {e} key"
+                ) from e
+        else:
+            prompt = render_template(
+                prompt_template,
+                overrides=overrides,
+                **context
+            )
+        log_debug(f"Below is the rendered prompt from {prompt_label}:\n{prompt}")
         return prompt
 
     async def _render_main_prompt(self, messages: list[Message], **context) -> str:
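
The refactored Actor above resolves every prompt from a nested 'prompts' spec: a template that points at a file on disk is rendered through jinja2 (with 'template_overrides' applied per block), while a plain string is filled in with str.format using the prompt context. A minimal, illustrative sketch of the new constructor arguments as this patch defines them; the template wording is invented, and an llm/interface still has to be supplied as usual before the agent can respond:

from lumen.ai.agents import ChatAgent

# A string template is formatted with the prompt context; "{memory}" is
# available because _render_prompt injects the shared memory object.
agent = ChatAgent(
    prompts={
        "main": {
            "template": "You are a terse analyst. Ground every answer in {memory}.",
        },
    },
)
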
the key, "var", + # is not available in context "string template {var}".format(memory=memory) + raise KeyError( + f"Unexpected template variable: {e}. To resolve this, " + f"please ensure overrides contains the {e} key" + ) from e + else: + prompt = render_template( + prompt_template, + overrides=overrides, + **context + ) + log_debug(f"Below is the rendered prompt from {prompt_label}:\n{prompt}") return prompt async def _render_main_prompt(self, messages: list[Message], **context) -> str: diff --git a/lumen/ai/agents.py b/lumen/ai/agents.py index 2ece2eddb..df40ba596 100644 --- a/lumen/ai/agents.py +++ b/lumen/ai/agents.py @@ -39,7 +39,7 @@ from .translate import param_to_pydantic from .utils import ( clean_sql, describe_data, gather_table_sources, get_data, get_pipeline, - get_schema, report_error, retry_llm_output, + get_schema, log_debug, report_error, retry_llm_output, ) from .views import AnalysisOutput, LumenOutput, SQLOutput @@ -77,9 +77,6 @@ class Agent(Viewer, Actor): requires = param.List(default=[], readonly=True, doc=""" List of context that this Agent requires to be in memory.""") - response_model = param.ClassSelector(class_=BaseModel, is_instance=False, doc=""" - A Pydantic model determining the schema of the response.""") - user = param.String(default="Agent", doc=""" The name of the user that will be respond to the user query.""") @@ -230,7 +227,7 @@ async def respond( system_prompt = await self._render_main_prompt(messages) message = None async for output_chunk in self.llm.stream( - messages, system=system_prompt, response_model=self.response_model, field="output" + messages, system=system_prompt, field="output" ): message = self.interface.stream( output_chunk, replace=True, message=message, user=self.user, max_width=self._max_width @@ -282,14 +279,12 @@ class ChatAgent(Agent): Usually not used concurrently with SQLAgent, unlike AnalystAgent. """) - prompt_templates = param.Dict( + prompts = param.Dict( default={ - "main": PROMPTS_DIR / "ChatAgent" / "main.jinja2", + "main": {"template": PROMPTS_DIR / "ChatAgent" / "main.jinja2"}, } ) - response_model = param.ClassSelector(class_=BaseModel, is_instance=False) - requires = param.List(default=["source"], readonly=True) async def _render_main_prompt(self, messages: list[Message], **context) -> str: @@ -331,9 +326,9 @@ class AnalystAgent(ChatAgent): requires = param.List(default=["source", "table", "pipeline", "sql"], readonly=True) - prompt_templates = param.Dict( + prompts = param.Dict( default={ - "main": PROMPTS_DIR / "AnalystAgent" / "main.jinja2", + "main": {"template": PROMPTS_DIR / "AnalystAgent" / "main.jinja2"}, } ) @@ -372,9 +367,9 @@ class TableListAgent(LumenBaseAgent): Renders a list of all availables tables to the user. Not useful for gathering information about the tables.""") - prompt_templates = param.Dict( + prompts = param.Dict( default={ - "main": PROMPTS_DIR / "TableListAgent" / "main.jinja2", + "main": {"template": PROMPTS_DIR / "TableListAgent" / "main.jinja2"}, } ) @@ -431,19 +426,31 @@ class SQLAgent(LumenBaseAgent): also capable of joining it with other tables. 
@@ -431,19 +426,31 @@ class SQLAgent(LumenBaseAgent):
         also capable of joining it with other tables.
         Will generate and execute a query in a single step.""")
 
-    prompt_templates = param.Dict(
+    prompts = param.Dict(
         default={
-            "main": PROMPTS_DIR / "SQLAgent" / "main.jinja2",
-            "select_table": PROMPTS_DIR / "SQLAgent" / "select_table.jinja2",
-            "require_joins": PROMPTS_DIR / "SQLAgent" / "require_joins.jinja2",
-            "find_joins": PROMPTS_DIR / "SQLAgent" / "find_joins.jinja2",
+            "main": {
+                "model": Sql,
+                "template": PROMPTS_DIR / "SQLAgent" / "main.jinja2"
+            },
+            "select_table": {
+                "model": make_table_model,
+                "template": PROMPTS_DIR / "SQLAgent" / "select_table.jinja2"
+            },
+            "require_joins": {
+                "model": JoinRequired,
+                "template": PROMPTS_DIR / "SQLAgent" / "require_joins.jinja2"
+            },
+            "find_joins": {
+                "model": TableJoins,
+                "template": PROMPTS_DIR / "SQLAgent" / "find_joins.jinja2"
+            },
         }
     )
 
-    requires = param.List(default=["source"], readonly=True)
-
     provides = param.List(default=["table", "sql", "pipeline", "data"], readonly=True)
 
+    requires = param.List(default=["source"], readonly=True)
+
     _extensions = ('codeeditor', 'tabulator',)
 
     _output_type = SQLOutput
@@ -467,7 +474,7 @@ async def _select_relevant_table(self, messages: list[Message]) -> tuple[str, Ba
         elif len(tables) > FUZZY_TABLE_LENGTH:
             tables = await self._get_closest_tables(messages, tables)
         system_prompt = self._render_prompt("select_table", tables_schema_str=tables_schema_str)
-        table_model = make_table_model(tables)
+        table_model = self._get_model("select_table", tables=tables)
         result = await self.llm.invoke(
             messages,
             system=system_prompt,
@@ -512,10 +519,10 @@ async def _create_valid_sql(
                     )
                 }
             ]
-            print(errors)
+            log_debug(f"Below are the errors in `_create_valid_sql` retry:\n{errors}")
 
         with self.interface.add_step(title=title or "SQL query", steps_layout=self._steps_layout) as step:
-            response = self.llm.stream(messages, system=system, response_model=Sql)
+            response = self.llm.stream(messages, system=system, response_model=self._get_model("main"))
             sql_query = None
             async for output in response:
                 step_message = output.chain_of_thought
@@ -602,7 +609,7 @@ async def _check_requires_joins(
             response = self.llm.stream(
                 messages[-1:],
                 system=join_prompt,
-                response_model=JoinRequired,
+                response_model=self._get_model("require_joins"),
             )
             async for output in response:
                 step.stream(output.chain_of_thought, replace=True)
@@ -625,7 +632,7 @@ async def find_join_tables(self, messages: list):
             output = await self.llm.invoke(
                 messages,
                 system=find_joins_prompt,
-                response_model=TableJoins,
+                response_model=self._get_model("find_joins"),
             )
             tables_to_join = output.tables_to_join
             step.stream(
@@ -729,19 +736,15 @@ class BaseViewAgent(LumenBaseAgent):
 
     provides = param.List(default=["plot"], readonly=True)
 
-    prompt_templates = param.Dict(
+    prompts = param.Dict(
         default={
-            "main": PROMPTS_DIR / "BaseViewAgent" / "main.jinja2",
+            "main": {"template": PROMPTS_DIR / "BaseViewAgent" / "main.jinja2"},
         }
     )
 
     async def _extract_spec(self, model: BaseModel):
         return dict(model)
 
-    @classmethod
-    def _get_model(cls, schema):
-        raise NotImplementedError()
-
     async def respond(
         self,
         messages: list[Message],
@@ -770,7 +773,7 @@ async def respond(
         output = await self.llm.invoke(
             messages,
             system=system_prompt,
-            response_model=self._get_model(schema),
+            response_model=self._get_model("main", schema=schema),
         )
         spec = await self._extract_spec(output)
         chain_of_thought = spec.pop("chain_of_thought", None)
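
SQLAgent and BaseViewAgent above no longer hard-code their response models; each prompt's 'model' entry is resolved through Actor._get_model, which returns a pydantic class directly or, when the entry is a plain function, calls it with the prompt context (as with make_table_model(tables=...) for 'select_table'). A self-contained sketch of that dispatch, with toy stand-ins for the real models:

from types import FunctionType

from pydantic import BaseModel, Field, create_model


class ToySql(BaseModel):
    # stand-in for the Sql model referenced in the diff
    query: str


def make_toy_table_model(tables: list[str]):
    # stand-in factory, analogous to make_table_model(tables=...)
    return create_model(
        "ToyTable",
        relevant_table=(str, Field(description=f"One of: {', '.join(tables)}")),
    )


def resolve_model(model_spec, **context):
    # mirrors the FunctionType check in Actor._get_model
    return model_spec(**context) if isinstance(model_spec, FunctionType) else model_spec


assert resolve_model(ToySql) is ToySql
assert resolve_model(make_toy_table_model, tables=["penguins"]).__name__ == "ToyTable"
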
@@ -793,18 +796,17 @@ class hvPlotAgent(BaseViewAgent):
         Generates a plot of the data given a user prompt.
         If the user asks to plot, visualize or render the data this is your best best.""")
 
-    prompt_templates = param.Dict(
+    prompts = param.Dict(
         default={
-            "main": PROMPTS_DIR / "hvPlotAgent" / "main.jinja2",
+            "main": {"template": PROMPTS_DIR / "hvPlotAgent" / "main.jinja2"},
         }
     )
 
     view_type = hvPlotUIView
 
-    @classmethod
-    def _get_model(cls, schema):
+    def _get_model(self, prompt_name: str, schema: dict[str, Any]) -> type[BaseModel]:
         # Find parameters
-        excluded = cls.view_type._internal_params + [
+        excluded = self.view_type._internal_params + [
             "controls",
             "type",
             "source",
@@ -814,10 +816,10 @@ def _get_model(cls, schema):
             "field",
             "selection_group",
         ]
-        model = param_to_pydantic(cls.view_type, excluded=excluded, schema=schema, extra_fields={
+        model = param_to_pydantic(self.view_type, excluded=excluded, schema=schema, extra_fields={
            "chain_of_thought": (str, FieldInfo(description="Your thought process behind the plot.")),
        })
-        return model[cls.view_type.__name__]
+        return model[self.view_type.__name__]
 
     async def _extract_spec(self, model):
         pipeline = self._memory["pipeline"]
@@ -847,9 +849,11 @@ class VegaLiteAgent(BaseViewAgent):
         If the user asks to plot, visualize or render the data this is your best best.
         """)
 
-    prompt_templates = param.Dict(
+    prompts = param.Dict(
         default={
-            "main": PROMPTS_DIR / "VegaLiteAgent" / "main.jinja2",
+            "main": {
+                "model": VegaLiteSpec,
+                "template": PROMPTS_DIR / "VegaLiteAgent" / "main.jinja2"},
         }
     )
@@ -857,10 +861,6 @@ class VegaLiteAgent(BaseViewAgent):
 
     _extensions = ('vega',)
 
-    @classmethod
-    def _get_model(cls, schema):
-        return VegaLiteSpec
-
     async def _extract_spec(self, model: VegaLiteSpec):
         vega_spec = json.loads(model.json_spec)
         if "$schema" not in vega_spec:
@@ -877,9 +877,9 @@ class AnalysisAgent(LumenBaseAgent):
 
     purpose = param.String(default="""
         Perform custom analyses on the data.""")
 
-    prompt_templates = param.Dict(
+    prompts = param.Dict(
         default={
-            "main": PROMPTS_DIR / "AnalysisAgent" / "main.jinja2",
+            "main": {"template": PROMPTS_DIR / "AnalysisAgent" / "main.jinja2"},
         }
     )
diff --git a/lumen/ai/coordinator.py b/lumen/ai/coordinator.py
index a743e8921..6b84bcfd5 100644
--- a/lumen/ai/coordinator.py
+++ b/lumen/ai/coordinator.py
@@ -82,10 +82,14 @@ class Coordinator(Viewer, Actor):
     suggestions = param.List(default=GETTING_STARTED_SUGGESTIONS, doc="""
         Initial list of suggestions of actions the user can take.""")
 
-    prompt_templates = param.Dict(default={
-        "main": PROMPTS_DIR / "Coordinator" / "main.jinja2",
-        "check_validity": PROMPTS_DIR / "Coordinator" / "check_validity.jinja2",
-    }, doc="""The paths to the prompt's jinja2 templates.""")
+    prompts = param.Dict(
+        default={
+            "check_validity": {
+                "template": PROMPTS_DIR / "Coordinator" / "check_validity.jinja2",
+                "model": Validity,
+            },
+        }
+    )
 
     __abstract = True
 
@@ -317,7 +321,7 @@ async def _invalidate_memory(self, messages):
             output = await self.llm.invoke(
                 messages=messages,
                 system=system,
-                response_model=Validity,
+                response_model=self._get_model("check_validity"),
             )
             step.stream(output.correct_assessment, replace=True)
             step.success_title = f"{output.is_invalid.title()} needs refresh" if output.is_invalid else "Memory still valid"
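
The Coordinator's check_validity spec above pairs the packaged template with the Validity model, and _get_model falls back to the class default whenever an instance override supplies only part of a spec, so either half can be swapped independently. Illustrative only; the template path is hypothetical and, because it is a Path rather than a plain string, it is always rendered through jinja2:

from pathlib import Path

from lumen.ai.coordinator import Planner

planner = Planner(
    prompts={
        "check_validity": {
            # custom template; the Validity model still comes from the class default
            "template": Path("prompts/my_check_validity.jinja2"),
        },
    },
)
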
""" - prompt_templates = param.Dict(default={ - "main": PROMPTS_DIR / "DependencyResolver" / "main.jinja2", - "check_validity": PROMPTS_DIR / "Coordinator" / "check_validity.jinja2", - }, doc="""The paths to the prompt's jinja2 templates.""") + prompts = param.Dict( + default={ + "main": { + "template": PROMPTS_DIR / "DependencyResolver" / "main.jinja2", + "model": make_agent_model, + }, + "check_validity": { + "template": PROMPTS_DIR / "Coordinator" / "check_validity.jinja2", + "model": Validity, + }, + }, + ) async def _choose_agent( self, @@ -479,7 +491,7 @@ async def _choose_agent( agents = self.agents agents = [agent for agent in agents if await agent.applies(self._memory)] agent_names = tuple(sagent.name[:-5] for sagent in agents) - agent_model = make_agent_model(agent_names, primary=primary) + agent_model = self._get_model("main", agent_names=agent_names, primary=primary) if len(agent_names) == 0: raise ValueError("No agents available to choose from.") if len(agent_names) == 1: @@ -544,10 +556,18 @@ class Planner(Coordinator): and then executes it. """ - prompt_templates = param.Dict(default={ - "main": PROMPTS_DIR / "Planner" / "main.jinja2", - "check_validity": PROMPTS_DIR / "Coordinator" / "check_validity.jinja2", - }, doc="""The paths to the prompt's jinja2 templates.""") + prompts = param.Dict( + default={ + "main": { + "template": PROMPTS_DIR / "Planner" / "main.jinja2", + "model": make_plan_models, + }, + "check_validity": { + "template": PROMPTS_DIR / "Coordinator" / "check_validity.jinja2", + "model": Validity, + }, + } + ) @classmethod async def _lookup_schemas( @@ -670,7 +690,7 @@ async def _compute_execution_graph(self, messages: list[Message], agents: dict[s for table in src.get_tables(): tables[table] = src - reason_model, plan_model = make_plan_models(agent_names, list(tables)) + reason_model, plan_model = self._get_model("main", agent_names=agent_names, tables=list(tables)) planned = False unmet_dependencies = set() schemas = {} diff --git a/lumen/ai/utils.py b/lumen/ai/utils.py index d725d2582..2f387e03f 100644 --- a/lumen/ai/utils.py +++ b/lumen/ai/utils.py @@ -3,6 +3,7 @@ import asyncio import inspect import math +import re import time from functools import wraps @@ -61,6 +62,26 @@ def render_template(template_path: Path, overrides: dict | None = None, relative template = env.get_template(template_name) return template.render(**context) + +def warn_on_unused_variables(string, kwargs, prompt_label): + used_keys = set() + + for key in kwargs: + pattern = r'\b' + re.escape(key) + r'\b' + if re.search(pattern, string): + used_keys.add(key) + + unused_keys = set(kwargs.keys()) - used_keys + if unused_keys: + # TODO: reword this concisely... what do you call those variables for formatting? + log.warning( + f"The prompt template, {prompt_label}, is missing keys, " + f"which could mean the LLM is lacking the context provided " + f"from these variables: {unused_keys}. If this is unintended, " + f"please create a template that contains those keys." + ) + + def retry_llm_output(retries=3, sleep=1): """ Retry a function that returns a response from the LLM API.