From e59c96ebe3908bbafa060168afccc7dc465cc084 Mon Sep 17 00:00:00 2001 From: Harrison Chase Date: Sun, 16 Oct 2022 20:19:27 -0700 Subject: [PATCH 1/2] add initial prompt stuff --- Makefile | 14 ++++++++ format.sh | 4 --- langchain/formatting.py | 21 ++++++++++++ langchain/prompt.py | 27 ++++++++++++++++ .../data/prompts/prompt_extra_args.json | 5 +++ .../data/prompts/prompt_missing_args.json | 3 ++ .../data/prompts/simple_prompt.json | 4 +++ tests/unit_tests/test_formatting.py | 22 +++++++++++++ tests/unit_tests/test_schema.py | 32 +++++++++++++++++++ 9 files changed, 128 insertions(+), 4 deletions(-) create mode 100644 Makefile delete mode 100644 format.sh create mode 100644 langchain/formatting.py create mode 100644 langchain/prompt.py create mode 100644 tests/unit_tests/data/prompts/prompt_extra_args.json create mode 100644 tests/unit_tests/data/prompts/prompt_missing_args.json create mode 100644 tests/unit_tests/data/prompts/simple_prompt.json create mode 100644 tests/unit_tests/test_formatting.py create mode 100644 tests/unit_tests/test_schema.py diff --git a/Makefile b/Makefile new file mode 100644 index 0000000000000..0a0f6e39bfb8c --- /dev/null +++ b/Makefile @@ -0,0 +1,14 @@ +.PHONY: format lint tests + +format: + black . + isort . + +lint: + black . --check + isort . --check + flake8 . + mypy . + +tests: + pytest tests \ No newline at end of file diff --git a/format.sh b/format.sh deleted file mode 100644 index 40ffe5bf706bd..0000000000000 --- a/format.sh +++ /dev/null @@ -1,4 +0,0 @@ -#!/bin/bash -set -eu -black . -isort . diff --git a/langchain/formatting.py b/langchain/formatting.py new file mode 100644 index 0000000000000..96b0bffa7c855 --- /dev/null +++ b/langchain/formatting.py @@ -0,0 +1,21 @@ +from string import Formatter + + +class StrictFormatter(Formatter): + """A subclass of formatter that checks for extra keys.""" + + def check_unused_args(self, used_args, args, kwargs): + extra = set(kwargs).difference(used_args) + if extra: + raise KeyError(extra) + + def vformat(self, format_string, args, kwargs): + if len(args) > 0: + raise ValueError( + "No arguments should be provided, " + "everything should be passed as keyword arguments." + ) + return super().vformat(format_string, args, kwargs) + + +formatter = StrictFormatter() diff --git a/langchain/prompt.py b/langchain/prompt.py new file mode 100644 index 0000000000000..f4065bb1bcfee --- /dev/null +++ b/langchain/prompt.py @@ -0,0 +1,27 @@ +"""Base schema types.""" +from typing import Dict, List + +from pydantic import BaseModel, Extra, root_validator + +from langchain.formatting import formatter + + +class Prompt(BaseModel): + """Schema to represent a prompt for an LLM.""" + + input_variables: List[str] + template: str + + class Config: + extra = Extra.forbid + + @root_validator() + def template_is_valid(cls, values: Dict) -> Dict: + input_variables = values["input_variables"] + template = values["template"] + dummy_inputs = {input_variable: "foo" for input_variable in input_variables} + try: + formatter.format(template, **dummy_inputs) + except KeyError: + raise ValueError("Invalid prompt schema.") + return values diff --git a/tests/unit_tests/data/prompts/prompt_extra_args.json b/tests/unit_tests/data/prompts/prompt_extra_args.json new file mode 100644 index 0000000000000..4bfc4fdcc4be6 --- /dev/null +++ b/tests/unit_tests/data/prompts/prompt_extra_args.json @@ -0,0 +1,5 @@ +{ + "input_variables": ["foo"], + "template": "This is a {foo} test.", + "bad_var": 1 +} \ No newline at end of file diff --git a/tests/unit_tests/data/prompts/prompt_missing_args.json b/tests/unit_tests/data/prompts/prompt_missing_args.json new file mode 100644 index 0000000000000..cb69d843e7ac5 --- /dev/null +++ b/tests/unit_tests/data/prompts/prompt_missing_args.json @@ -0,0 +1,3 @@ +{ + "input_variables": ["foo"] +} \ No newline at end of file diff --git a/tests/unit_tests/data/prompts/simple_prompt.json b/tests/unit_tests/data/prompts/simple_prompt.json new file mode 100644 index 0000000000000..d0f72b1c14f60 --- /dev/null +++ b/tests/unit_tests/data/prompts/simple_prompt.json @@ -0,0 +1,4 @@ +{ + "input_variables": ["foo"], + "template": "This is a {foo} test." +} \ No newline at end of file diff --git a/tests/unit_tests/test_formatting.py b/tests/unit_tests/test_formatting.py new file mode 100644 index 0000000000000..627a2ef43860e --- /dev/null +++ b/tests/unit_tests/test_formatting.py @@ -0,0 +1,22 @@ +import pytest + +from langchain.formatting import formatter + + +def test_valid_formatting(): + template = "This is a {foo} test." + output = formatter.format(template, foo="good") + expected_output = "This is a good test." + assert output == expected_output + + +def test_does_not_allow_args(): + template = "This is a {} test." + with pytest.raises(ValueError): + formatter.format(template, "good") + + +def test_does_not_allow_extra_kwargs(): + template = "This is a {foo} test." + with pytest.raises(KeyError): + formatter.format(template, foo="good", bar="oops") diff --git a/tests/unit_tests/test_schema.py b/tests/unit_tests/test_schema.py new file mode 100644 index 0000000000000..4706475e5acd6 --- /dev/null +++ b/tests/unit_tests/test_schema.py @@ -0,0 +1,32 @@ +import pytest + +from langchain.prompt import Prompt + + +def test_prompt_valid(): + template = "This is a {foo} test." + input_variables = ["foo"] + prompt = Prompt(input_variables=input_variables, template=template) + assert prompt.template == template + assert prompt.input_variables == input_variables + + +def test_prompt_missing_input_variables(): + template = "This is a {foo} test." + input_variables = [] + with pytest.raises(ValueError): + Prompt(input_variables=input_variables, template=template) + + +def test_prompt_extra_input_variables(): + template = "This is a {foo} test." + input_variables = ["foo", "bar"] + with pytest.raises(ValueError): + Prompt(input_variables=input_variables, template=template) + + +def test_prompt_wrong_input_variables(): + template = "This is a {foo} test." + input_variables = ["bar"] + with pytest.raises(ValueError): + Prompt(input_variables=input_variables, template=template) From 270fc8d0f9a0c1330a14ab6eef236f250ab44557 Mon Sep 17 00:00:00 2001 From: Harrison Chase Date: Sun, 16 Oct 2022 20:21:02 -0700 Subject: [PATCH 2/2] add newline --- Makefile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Makefile b/Makefile index 0a0f6e39bfb8c..895c8c8be1f5b 100644 --- a/Makefile +++ b/Makefile @@ -11,4 +11,4 @@ lint: mypy . tests: - pytest tests \ No newline at end of file + pytest tests