-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Hudson Cooper
committed
Feb 23, 2024
1 parent
c7e643c
commit cae7532
Showing
5 changed files
with
160 additions
and
137 deletions.
There are no files selected for viewing
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,111 @@ | ||
from abc import ABC | ||
from contextlib import nullcontext | ||
from typing import Generic, TypeVar | ||
from textwrap import dedent | ||
|
||
from pydantic import TypeAdapter, BaseModel, ValidationError | ||
|
||
import guidance | ||
from guidance import role, block, gen | ||
from guidance.models import Model, Chat | ||
from guidance._grammar import RawFunction | ||
|
||
from ..types import Type, gen_type | ||
from ..util import resolve_refs | ||
|
||
|
||
T = TypeVar("T") | ||
|
||
|
||
class Response(BaseModel, Generic[T]): | ||
prompt_with_completion: str | ||
completion: str | ||
object: T | None | ||
error: str | None | ||
|
||
|
||
class PromptHelper(ABC): | ||
SYSTEM_MESSAGE: str | None = None | ||
USER_MESSAGE: str | ||
|
||
def __init__(self, lm: Model): | ||
self.lm = lm | ||
|
||
def __call__(self, object_schema: Type | None = None, **kwds) -> Response: | ||
lm = self.lm + self.function( | ||
**kwds, response_type=object_schema, capture_name="result" | ||
) | ||
result = lm["result"] | ||
if object_schema is not None: | ||
try: | ||
object = TypeAdapter(object_schema).validate_json(result) | ||
except ValidationError as e: | ||
error = str(e) | ||
object = None | ||
else: | ||
error = None | ||
|
||
return Response[object_schema]( | ||
prompt_with_completion=str(lm), | ||
completion=lm["result"], | ||
object=object, | ||
error=error, | ||
) | ||
|
||
@classmethod | ||
def function( | ||
cls, response_type: Type | None = None, capture_name: str = "response", **kwds | ||
) -> RawFunction: | ||
return _prompt_with_system_and_user_message( | ||
system_message=cls.SYSTEM_MESSAGE, | ||
user_message=cls.USER_MESSAGE, | ||
response_type=response_type, | ||
capture_name=capture_name, | ||
**kwds, | ||
) | ||
|
||
|
||
@guidance(stateless=False) | ||
def _prompt_with_system_and_user_message( | ||
lm: Model, | ||
system_message: str, | ||
user_message: str, | ||
capture_name: str | None, | ||
response_type: Type | None, | ||
**kwds, | ||
): | ||
system = role("system") if isinstance(lm, Chat) else nullcontext | ||
user = role("user") if isinstance(lm, Chat) else nullcontext | ||
assistant = role("assistant") if isinstance(lm, Chat) else nullcontext | ||
if system_message is not None: | ||
with system: | ||
lm += system_message.format(**kwds) | ||
with user: | ||
lm += user_message.format(**kwds) | ||
if response_type is not None: | ||
lm += "\n" + _formatting_instructions(response_type) | ||
with assistant, block(capture_name): | ||
if response_type is None: | ||
grammar = gen() | ||
else: | ||
# TODO: add type-hints to the user/system bit? | ||
grammar = gen_type(response_type) | ||
lm += grammar | ||
return lm | ||
|
||
|
||
def _formatting_instructions(type): | ||
# Shamelessly stolen from langchain | ||
return dedent( | ||
""" | ||
The output should be formatted as a JSON instance that conforms to the JSON schema below. | ||
As an example, for the schema {{"properties": {{"foo": {{"title": "Foo", "description": "a list of strings", "type": "array", "items": {{"type": "string"}}}}}}, "required": ["foo"]}} | ||
the object {{"foo": ["bar", "baz"]}} is a well-formatted instance of the schema. The object {{"properties": {{"foo": ["bar", "baz"]}}}} is not well-formatted. | ||
Here is the output schema: | ||
``` | ||
{schema} | ||
``` | ||
""" | ||
).format(schema=TypeAdapter(type).json_schema()) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
from textwrap import dedent | ||
from pydantic import BaseModel, TypeAdapter | ||
from minml.types import Type | ||
from .base import PromptHelper, Response | ||
|
||
|
||
class GrepPrompt(PromptHelper): | ||
SYSTEM_MESSAGE = "Pay close attention to the schema and type of object the user is looking for. If there are no objects of the type the user is looking for, return an empty list." | ||
|
||
# Shamelessly stolen from llamaindex | ||
USER_MESSAGE = dedent( | ||
""" | ||
Please find each `{type}` in the text below. | ||
--------------------- | ||
{text} | ||
--------------------- | ||
""" | ||
).strip("\n") | ||
|
||
def __call__(self, name: str, object_schema: Type | type, text: str) -> Response: | ||
return super().__call__(object_schema=list[object_schema], type=name, text=text) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
from textwrap import dedent | ||
from .base import PromptHelper | ||
|
||
|
||
class DBQPrompt(PromptHelper): | ||
# Shamelessly stolen from llamaindex | ||
SYSTEM_MESSAGE = dedent( | ||
""" | ||
You are an expert Q&A system that is trusted around the world. | ||
Always answer the query using the provided context information, and not prior knowledge. | ||
Some rules to follow: | ||
1. Never directly reference the given context in your answer. | ||
2. Avoid statements like 'Based on the context, ...' or 'The context information ...' or anything along those lines. | ||
""" | ||
).strip("\n") | ||
|
||
# Shamelessly stolen from llamaindex | ||
USER_MESSAGE = dedent( | ||
""" | ||
Context information from multiple sources is below. | ||
--------------------- | ||
{context_str} | ||
--------------------- | ||
Given the information from multiple sources and not prior knowledge, answer the query. | ||
Query: {query_str} | ||
""" | ||
).strip("\n") |
This file was deleted.
Oops, something went wrong.