Prompt template override in BaseProvider #309

Merged: 13 commits, Aug 3, 2023
14 changes: 14 additions & 0 deletions docs/source/users/index.md
@@ -522,6 +522,20 @@ A function that computes the lowest common multiples of two integers, and
a function that runs 5 test cases of the lowest common multiple function
```

### Prompt templates

Each provider can define **prompt templates** for each supported format. A prompt
template guides the language model to produce output in a particular
format. The default prompt templates are a
[Python dictionary mapping formats to templates](https://github.com/jupyterlab/jupyter-ai/blob/57a758fa5cdd5a87da5519987895aa688b3766a8/packages/jupyter-ai-magics/jupyter_ai_magics/providers.py#L138-L166).
Developers who write subclasses of `BaseProvider` can override templates per
output format, per model, and based on the prompt being submitted, by
implementing their own
[`get_prompt_template` function](https://github.com/jupyterlab/jupyter-ai/blob/57a758fa5cdd5a87da5519987895aa688b3766a8/packages/jupyter-ai-magics/jupyter_ai_magics/providers.py#L186-L195).
Each prompt template includes the string `{prompt}`, which is replaced with
the user-provided prompt when the user runs a magic command.
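
For example, a subclass can route one format to a custom template and defer to
the base class for the rest. The following is an illustrative sketch only; the
`MyProvider` class, its metadata, and the template text are hypothetical and
not part of this PR:

```python
from langchain import PromptTemplate

from jupyter_ai_magics.providers import BaseProvider


class MyProvider(BaseProvider):
    # Hypothetical provider metadata; a real subclass would also inherit
    # from a LangChain LLM class and define its usual fields.
    id = "my-provider"
    name = "My Provider"
    models = ["my-model"]
    model_id_key = "model_id"

    def get_prompt_template(self, format: str) -> PromptTemplate:
        # Override the "code" format only; every other format falls back
        # to the defaults defined on BaseProvider.
        if format == "code":
            return PromptTemplate.from_template(
                "{prompt}\n\nReply with a single code block and nothing else."
            )
        return super().get_prompt_template(format)
```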


### Clearing the OpenAI chat history

With the `openai-chat` provider *only*, you can run a cell magic command using the `-r` or
49 changes: 18 additions & 31 deletions packages/jupyter-ai-magics/jupyter_ai_magics/magics.py
@@ -80,8 +80,6 @@ def _repr_mimebundle_(self, include=None, exclude=None):

NA_MESSAGE = '<abbr title="Not applicable">N/A</abbr>'

MARKDOWN_PROMPT_TEMPLATE = "{prompt}\n\nProduce output in markdown format only."

PROVIDER_NO_MODELS = "This provider does not define a list of models."

CANNOT_DETERMINE_MODEL_TEXT = """Cannot determine model provider from model ID '{0}'.
@@ -93,17 +91,6 @@ def _repr_mimebundle_(self, include=None, exclude=None):
To see a list of models you can use, run `%ai list`"""


PROMPT_TEMPLATES_BY_FORMAT = {
    "code": "{prompt}\n\nProduce output as source code only, with no text or explanation before or after it.",
    "html": "{prompt}\n\nProduce output in HTML format only, with no markup before or afterward.",
    "image": "{prompt}\n\nProduce output as an image only, with no text before or after it.",
    "markdown": MARKDOWN_PROMPT_TEMPLATE,
    "md": MARKDOWN_PROMPT_TEMPLATE,
    "math": "{prompt}\n\nProduce output in LaTeX format only, with $$ at the beginning and end.",
    "json": "{prompt}\n\nProduce output in JSON format only, with nothing before or after it.",
    "text": "{prompt}",  # No customization
}

AI_COMMANDS = {"delete", "error", "help", "list", "register", "update"}


@@ -465,24 +452,6 @@ def handle_list(self, args: ListArgs):
)

    def run_ai_cell(self, args: CellArgs, prompt: str):
        # Apply a prompt template.
        prompt = PROMPT_TEMPLATES_BY_FORMAT[args.format].format(prompt=prompt)

        # interpolate user namespace into prompt
        ip = get_ipython()
        prompt = prompt.format_map(FormatDict(ip.user_ns))

        # Determine provider and local model IDs
        # If this is a custom chain, send the message to the custom chain.
        if args.model_id in self.custom_model_registry and isinstance(
            self.custom_model_registry[args.model_id], LLMChain
        ):
            return self.display_output(
                self.custom_model_registry[args.model_id].run(prompt),
                args.format,
                {"jupyter_ai": {"custom_chain_id": args.model_id}},
            )

        provider_id, local_model_id = self._decompose_model_id(args.model_id)
        Provider = self._get_provider(provider_id)
        if Provider is None:
@@ -500,6 +469,17 @@ def run_ai_cell(self, args: CellArgs, prompt: str):
            self.transcript_openai = []
            return

        # If this is a custom chain, send the message to the custom chain.
        if args.model_id in self.custom_model_registry and isinstance(
            self.custom_model_registry[args.model_id], LLMChain
        ):
            return self.display_output(
                self.custom_model_registry[args.model_id].run(prompt),
                args.format,
                {"jupyter_ai": {"custom_chain_id": args.model_id}},
            )

        # validate presence of authn credentials
        auth_strategy = self.providers[provider_id].auth_strategy
        if auth_strategy:
@@ -541,6 +521,13 @@ def run_ai_cell(self, args: CellArgs, prompt: str):

        provider = Provider(**provider_params)

        # Apply a prompt template.
        prompt = provider.get_prompt_template(args.format).format(prompt=prompt)

        # interpolate user namespace into prompt
        ip = get_ipython()
        prompt = prompt.format_map(FormatDict(ip.user_ns))

        # generate output from model via provider
        result = provider.generate([prompt])
        output = result.generations[0][0].text
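
As an aside, the `format_map(FormatDict(...))` call above performs *partial*
interpolation: placeholders naming IPython user variables are filled in, while
unknown placeholders are left intact rather than raising `KeyError`. A minimal
sketch of that idea (the actual `FormatDict` in `magics.py` may differ in
detail):

```python
class FormatDict(dict):
    def __missing__(self, key):
        # Leave unresolved placeholders as-is instead of raising KeyError.
        return "{" + key + "}"


ns = {"city": "Paris"}
template = "Summarize the weather in {city}, output as {fmt}"
print(template.format_map(FormatDict(ns)))
# -> Summarize the weather in Paris, output as {fmt}
```
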
52 changes: 52 additions & 0 deletions packages/jupyter-ai-magics/jupyter_ai_magics/providers.py
@@ -8,6 +8,7 @@
from typing import Any, ClassVar, Coroutine, Dict, List, Literal, Optional, Union

from jsonpath_ng import parse
from langchain import PromptTemplate
from langchain.chat_models import ChatOpenAI
from langchain.llms import (
    AI21,
@@ -117,6 +118,10 @@ class Config:
    # instance attrs
    #
    model_id: str
    prompt_templates: Dict[str, PromptTemplate]
    """Prompt templates for each output type. Can be overridden with
    `update_prompt_template`. The `get_prompt_template` method on the base
    class reads from this."""

    def __init__(self, *args, **kwargs):
        try:
@@ -130,6 +135,36 @@ def __init__(self, *args, **kwargs):
        if self.__class__.model_id_key != "model_id":
            model_kwargs[self.__class__.model_id_key] = kwargs["model_id"]

model_kwargs["prompt_templates"] = {
"code": PromptTemplate.from_template(
"{prompt}\n\nProduce output as source code only, "
"with no text or explanation before or after it."
),
"html": PromptTemplate.from_template(
"{prompt}\n\nProduce output in HTML format only, "
"with no markup before or afterward."
),
"image": PromptTemplate.from_template(
"{prompt}\n\nProduce output as an image only, "
"with no text before or after it."
),
"markdown": PromptTemplate.from_template(
"{prompt}\n\nProduce output in markdown format only."
),
"md": PromptTemplate.from_template(
"{prompt}\n\nProduce output in markdown format only."
),
"math": PromptTemplate.from_template(
"{prompt}\n\nProduce output in LaTeX format only, "
"with $$ at the beginning and end."
),
"json": PromptTemplate.from_template(
"{prompt}\n\nProduce output in JSON format only, "
"with nothing before or after it."
),
"text": PromptTemplate.from_template("{prompt}"), # No customization
}

        super().__init__(*args, **kwargs, **model_kwargs)

    async def _call_in_executor(self, *args, **kwargs) -> Coroutine[Any, Any, str]:
@@ -142,6 +177,23 @@ async def _call_in_executor(self, *args, **kwargs) -> Coroutine[Any, Any, str]:
        _call_with_args = functools.partial(self._call, *args, **kwargs)
        return await loop.run_in_executor(executor, _call_with_args)

    def update_prompt_template(self, format: str, template: str):
        """
        Changes the prompt template for a given format on this provider
        instance.
        """
        self.prompt_templates[format] = PromptTemplate.from_template(template)

    def get_prompt_template(self, format: str) -> PromptTemplate:
        """
        Produce a prompt template suitable for use with a particular model, to
        produce output in a desired format.
        """

        if format in self.prompt_templates:
            return self.prompt_templates[format]
        else:
            return self.prompt_templates["text"]  # Default to plain format


class AI21Provider(BaseProvider, AI21):
id = "ai21"