Refactor prompting (#807)
Co-authored-by: Philipp Rudiger <prudiger@anaconda.com>
ahuang11 and philippjfr authored Dec 5, 2024
1 parent dfb68b8 commit 72f0d63
Showing 4 changed files with 159 additions and 73 deletions.
67 changes: 56 additions & 11 deletions lumen/ai/actor.py
@@ -1,31 +1,76 @@
from abc import abstractmethod
from pathlib import Path
from types import FunctionType
from typing import Any

import param

from pydantic import BaseModel

from .llm import Message
from .utils import log_debug, render_template
from .utils import log_debug, render_template, warn_on_unused_variables


class Actor(param.Parameterized):

prompt_overrides = param.Dict(default={}, doc="""
Overrides the prompt's 'instructions' or 'context' jinja2 blocks.
template_overrides = param.Dict(default={}, doc="""
Overrides the template's 'instructions', 'context', or 'examples' jinja2 blocks.
Is a nested dictionary with the prompt name (e.g. main) as the key
and the block names as the inner keys with the new content as the
values.""")

prompt_templates = param.Dict(default={}, doc="""
The paths to the prompt's jinja2 templates.""")
prompts = param.Dict(default={}, doc="""
A dict mapping the prompt name (e.g. 'main') to a nested dict
with keys like 'template', 'model', and/or 'model_factory'.""")

def _get_model(self, prompt_name: str, **context) -> type[BaseModel]:
if prompt_name in self.prompts and "model" in self.prompts[prompt_name]:
prompt_spec = self.prompts[prompt_name]
else:
prompt_spec = self.param.prompts.default[prompt_name]
if "model" not in prompt_spec:
raise KeyError(f"Prompt {prompt_name!r} does not provide a model.")
model_spec = prompt_spec["model"]
if isinstance(model_spec, FunctionType):
model = model_spec(**context)
else:
model = model_spec
return model

def _render_prompt(self, prompt_name: str, **context) -> str:
if prompt_name in self.prompts and "template" in self.prompts[prompt_name]:
prompt_spec = self.prompts[prompt_name]
else:
prompt_spec = self.param.prompts.default[prompt_name]
if "template" not in prompt_spec:
raise KeyError(f"Prompt {prompt_name!r} does not provide a prompt template.")
prompt_template = prompt_spec["template"]

overrides = self.template_overrides.get(prompt_name, {})
prompt_label = f"\033[92m{self.name[:-5]}.prompts['{prompt_name}']['template']\033[0m"
context["memory"] = self._memory
prompt = render_template(
self.prompt_templates[prompt_name],
overrides=self.prompt_overrides.get(prompt_name, {}),
**context
)
log_debug(f"\033[92mRendered prompt\033[0m '{prompt_name}':\n{prompt}")
if isinstance(prompt_template, str) and not Path(prompt_template).exists():
# check that all the format_kwargs keys are used in prompt_template,
# e.g. the key "memory" is unused in "string template".format(memory=memory)
format_kwargs = dict(**overrides, **context)
warn_on_unused_variables(prompt_template, format_kwargs, prompt_label)
try:
prompt = prompt_template.format(**format_kwargs)
except KeyError as e:
# raised when a format variable in prompt_template is missing from
# format_kwargs, e.g. the key "var" is unavailable in
# "string template {var}".format(memory=memory)
raise KeyError(
f"Unexpected template variable: {e}. To resolve this, "
f"please ensure overrides contains the {e} key"
) from e
else:
prompt = render_template(
prompt_template,
overrides=overrides,
**context
)
log_debug(f"Below is the rendered prompt from {prompt_label}:\n{prompt}")
return prompt

async def _render_main_prompt(self, messages: list[Message], **context) -> str:
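The refactor above replaces the flat `prompt_templates`/`prompt_overrides` parameters with a single `prompts` dict (plus `template_overrides`) and a per-prompt model lookup via `_get_model`. Below is a minimal, illustrative sketch of how a subclass could be configured under the new scheme; the class name, template path, and override text are hypothetical and not part of this commit.

```python
# Illustrative sketch only (not part of the commit). The keys mirror the
# parameter docs in actor.py above; names and paths are hypothetical.
from pathlib import Path

PROMPTS_DIR = Path("lumen") / "ai" / "prompts"  # assumed layout

example_prompts = {
    "main": {
        # a jinja2 template on disk, rendered via render_template()
        "template": PROMPTS_DIR / "MyAgent" / "main.jinja2",
        # optionally a pydantic model class, or a callable invoked with the
        # keyword context passed to _get_model()
        # "model": MyResponseModel,
    },
    # a plain string that is not an existing path is treated as a
    # str.format() template; unused keys trigger warn_on_unused_variables()
    "summary": {"template": "Summarize this for the user: {memory}"},
}

example_template_overrides = {
    # override individual jinja2 blocks ('instructions', 'context',
    # 'examples') of the 'main' prompt without replacing the whole template
    "main": {"instructions": "Answer concisely and show the SQL you used."},
}
```

With this structure, `_render_prompt('main')` resolves the template (falling back to the class default when the instance dict omits it) and `_get_model('main')` returns the model class directly or calls the factory with the supplied context.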
94 changes: 47 additions & 47 deletions lumen/ai/agents.py
@@ -39,7 +39,7 @@
from .translate import param_to_pydantic
from .utils import (
clean_sql, describe_data, gather_table_sources, get_data, get_pipeline,
get_schema, report_error, retry_llm_output,
get_schema, log_debug, report_error, retry_llm_output,
)
from .views import AnalysisOutput, LumenOutput, SQLOutput

@@ -77,9 +77,6 @@ class Agent(Viewer, Actor):
requires = param.List(default=[], readonly=True, doc="""
List of context that this Agent requires to be in memory.""")

response_model = param.ClassSelector(class_=BaseModel, is_instance=False, doc="""
A Pydantic model determining the schema of the response.""")

user = param.String(default="Agent", doc="""
The name of the user that will respond to the user query.

@@ -230,7 +227,7 @@ async def respond(
system_prompt = await self._render_main_prompt(messages)
message = None
async for output_chunk in self.llm.stream(
messages, system=system_prompt, response_model=self.response_model, field="output"
messages, system=system_prompt, field="output"
):
message = self.interface.stream(
output_chunk, replace=True, message=message, user=self.user, max_width=self._max_width
@@ -282,14 +279,12 @@ class ChatAgent(Agent):
Usually not used concurrently with SQLAgent, unlike AnalystAgent.
""")

prompt_templates = param.Dict(
prompts = param.Dict(
default={
"main": PROMPTS_DIR / "ChatAgent" / "main.jinja2",
"main": {"template": PROMPTS_DIR / "ChatAgent" / "main.jinja2"},
}
)

response_model = param.ClassSelector(class_=BaseModel, is_instance=False)

requires = param.List(default=["source"], readonly=True)

async def _render_main_prompt(self, messages: list[Message], **context) -> str:
@@ -331,9 +326,9 @@ class AnalystAgent(ChatAgent):

requires = param.List(default=["source", "table", "pipeline", "sql"], readonly=True)

prompt_templates = param.Dict(
prompts = param.Dict(
default={
"main": PROMPTS_DIR / "AnalystAgent" / "main.jinja2",
"main": {"template": PROMPTS_DIR / "AnalystAgent" / "main.jinja2"},
}
)

@@ -372,9 +367,9 @@ class TableListAgent(LumenBaseAgent):
Renders a list of all available tables to the user.
Not useful for gathering information about the tables.""")

prompt_templates = param.Dict(
prompts = param.Dict(
default={
"main": PROMPTS_DIR / "TableListAgent" / "main.jinja2",
"main": {"template": PROMPTS_DIR / "TableListAgent" / "main.jinja2"},
}
)

@@ -431,19 +426,31 @@ class SQLAgent(LumenBaseAgent):
also capable of joining it with other tables. Will generate and
execute a query in a single step.""")

prompt_templates = param.Dict(
prompts = param.Dict(
default={
"main": PROMPTS_DIR / "SQLAgent" / "main.jinja2",
"select_table": PROMPTS_DIR / "SQLAgent" / "select_table.jinja2",
"require_joins": PROMPTS_DIR / "SQLAgent" / "require_joins.jinja2",
"find_joins": PROMPTS_DIR / "SQLAgent" / "find_joins.jinja2",
"main": {
"model": Sql,
"template": PROMPTS_DIR / "SQLAgent" / "main.jinja2"
},
"select_table": {
"model": make_table_model,
"template": PROMPTS_DIR / "SQLAgent" / "select_table.jinja2"
},
"require_joins": {
"model": JoinRequired,
"template": PROMPTS_DIR / "SQLAgent" / "require_joins.jinja2"
},
"find_joins": {
"model": TableJoins,
"template": PROMPTS_DIR / "SQLAgent" / "find_joins.jinja2"
},
}
)

requires = param.List(default=["source"], readonly=True)

provides = param.List(default=["table", "sql", "pipeline", "data"], readonly=True)

requires = param.List(default=["source"], readonly=True)

_extensions = ('codeeditor', 'tabulator',)

_output_type = SQLOutput
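Under the new scheme the response model for each LLM call is looked up per prompt instead of being hard-coded at the call site. The sketch below is an assumed usage example, not taken verbatim from the diff: `agent`, `messages`, and `tables` are placeholders for values built elsewhere in SQLAgent.

```python
async def pick_table(agent, messages, tables):
    # "main" maps to a plain class, so _get_model returns the Sql model:
    sql_model = agent._get_model("main")

    # "select_table" maps to a callable (make_table_model), so the supplied
    # context is forwarded as keyword arguments to build the model:
    table_model = agent._get_model("select_table", tables=tables)

    # The rendered prompt and the resolved model are passed to the LLM
    # together, mirroring _select_relevant_table below:
    result = await agent.llm.invoke(
        messages,
        system=agent._render_prompt("select_table", tables_schema_str="..."),
        response_model=table_model,
    )
    return sql_model, result
```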
@@ -467,7 +474,7 @@ async def _select_relevant_table(self, messages: list[Message]) -> tuple[str, Ba
elif len(tables) > FUZZY_TABLE_LENGTH:
tables = await self._get_closest_tables(messages, tables)
system_prompt = self._render_prompt("select_table", tables_schema_str=tables_schema_str)
table_model = make_table_model(tables)
table_model = self._get_model("select_table", tables=tables)
result = await self.llm.invoke(
messages,
system=system_prompt,
@@ -512,10 +519,10 @@ async def _create_valid_sql(
)
}
]
print(errors)
log_debug(f"Below are the errors in `_create_valid_sql` retry:\n{errors}")

with self.interface.add_step(title=title or "SQL query", steps_layout=self._steps_layout) as step:
response = self.llm.stream(messages, system=system, response_model=Sql)
response = self.llm.stream(messages, system=system, response_model=self._get_model("main"))
sql_query = None
async for output in response:
step_message = output.chain_of_thought
@@ -602,7 +609,7 @@ async def _check_requires_joins(
response = self.llm.stream(
messages[-1:],
system=join_prompt,
response_model=JoinRequired,
response_model=self._get_model("require_joins"),
)
async for output in response:
step.stream(output.chain_of_thought, replace=True)
@@ -625,7 +632,7 @@ async def find_join_tables(self, messages: list):
output = await self.llm.invoke(
messages,
system=find_joins_prompt,
response_model=TableJoins,
response_model=self._get_model("find_joins"),
)
tables_to_join = output.tables_to_join
step.stream(
@@ -729,19 +736,15 @@ class BaseViewAgent(LumenBaseAgent):

provides = param.List(default=["plot"], readonly=True)

prompt_templates = param.Dict(
prompts = param.Dict(
default={
"main": PROMPTS_DIR / "BaseViewAgent" / "main.jinja2",
"main": {"template": PROMPTS_DIR / "BaseViewAgent" / "main.jinja2"},
}
)

async def _extract_spec(self, model: BaseModel):
return dict(model)

@classmethod
def _get_model(cls, schema):
raise NotImplementedError()

async def respond(
self,
messages: list[Message],
@@ -770,7 +773,7 @@ async def respond(
output = await self.llm.invoke(
messages,
system=system_prompt,
response_model=self._get_model(schema),
response_model=self._get_model("main", schema=schema),
)
spec = await self._extract_spec(output)
chain_of_thought = spec.pop("chain_of_thought", None)
@@ -793,18 +796,17 @@ class hvPlotAgent(BaseViewAgent):
Generates a plot of the data given a user prompt.
If the user asks to plot, visualize or render the data, this is your best bet.""")

prompt_templates = param.Dict(
prompts = param.Dict(
default={
"main": PROMPTS_DIR / "hvPlotAgent" / "main.jinja2",
"main": {"template": PROMPTS_DIR / "hvPlotAgent" / "main.jinja2"},
}
)

view_type = hvPlotUIView

@classmethod
def _get_model(cls, schema):
def _get_model(self, prompt_name: str, schema: dict[str, Any]) -> type[BaseModel]:
# Find parameters
excluded = cls.view_type._internal_params + [
excluded = self.view_type._internal_params + [
"controls",
"type",
"source",
@@ -814,10 +816,10 @@ def _get_model(cls, schema):
"field",
"selection_group",
]
model = param_to_pydantic(cls.view_type, excluded=excluded, schema=schema, extra_fields={
model = param_to_pydantic(self.view_type, excluded=excluded, schema=schema, extra_fields={
"chain_of_thought": (str, FieldInfo(description="Your thought process behind the plot.")),
})
return model[cls.view_type.__name__]
return model[self.view_type.__name__]

async def _extract_spec(self, model):
pipeline = self._memory["pipeline"]
@@ -847,20 +849,18 @@ class VegaLiteAgent(BaseViewAgent):
If the user asks to plot, visualize or render the data, this is your best bet.
""")

prompt_templates = param.Dict(
prompts = param.Dict(
default={
"main": PROMPTS_DIR / "VegaLiteAgent" / "main.jinja2",
"main": {
"model": VegaLiteSpec,
"template": PROMPTS_DIR / "VegaLiteAgent" / "main.jinja2"},
}
)

view_type = VegaLiteView

_extensions = ('vega',)

@classmethod
def _get_model(cls, schema):
return VegaLiteSpec

async def _extract_spec(self, model: VegaLiteSpec):
vega_spec = json.loads(model.json_spec)
if "$schema" not in vega_spec:
@@ -877,9 +877,9 @@ class AnalysisAgent(LumenBaseAgent):
purpose = param.String(default="""
Perform custom analyses on the data.""")

prompt_templates = param.Dict(
prompts = param.Dict(
default={
"main": PROMPTS_DIR / "AnalysisAgent" / "main.jinja2",
"main": {"template": PROMPTS_DIR / "AnalysisAgent" / "main.jinja2"},
}
)

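Because `prompts` and `template_overrides` are ordinary params, individual prompts can also be overridden per instance. A hedged usage sketch follows, assuming the agents accept these keyword arguments at construction time; the override text is illustrative and the import path is an assumption.

```python
# Hedged usage sketch (assumed API; exact import locations may differ):
from lumen.ai import agents

# Tweak a single jinja2 block of AnalystAgent's 'main' prompt:
analyst = agents.AnalystAgent(
    template_overrides={
        "main": {"instructions": "Keep answers under three sentences."},
    },
)

# Replace one prompt entry wholesale; prompts not listed here fall back to
# the class defaults shown in the diff above:
sql_agent = agents.SQLAgent(
    prompts={
        "require_joins": {
            "model": agents.JoinRequired,
            "template": agents.PROMPTS_DIR / "SQLAgent" / "require_joins.jinja2",
        },
    },
)
```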
