Skip to content

Commit

Permalink
feat: add generic typing for contexts to ensure type safety and update python
Browse files Browse the repository at this point in the history
  • Loading branch information
CNSeniorious000 committed Dec 29, 2024
1 parent 7c81b7f commit 78683ea
Show file tree
Hide file tree
Showing 9 changed files with 84 additions and 25 deletions.
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
.pdm-python

__pycache__
*.pyi
src/utils/load.pyi

node_modules
dist
4 changes: 2 additions & 2 deletions Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,12 @@ RUN bun install
COPY frontend .
RUN NODE_ENV=production bun run build

FROM python:3.12-slim AS py
FROM python:3.13-slim AS py
WORKDIR /app
COPY pyproject.toml .
RUN pip install uv --disable-pip-version-check && uv venv && uv pip install -r pyproject.toml --compile-bytecode

FROM python:3.12-slim AS base
FROM python:3.13-slim AS base
WORKDIR /app
COPY --from=js /app/dist frontend/dist
COPY --from=py /app .
Expand Down
5 changes: 3 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
name = "promplate-demo"
version = "1"
authors = [{ name = "Muspi Merol", email = "me@promplate.dev" }]
requires-python = ">=3.10,<3.13"
requires-python = ">=3.10,<3.14"
readme = "README.md"
license = { text = "MIT" }
dependencies = [
Expand Down Expand Up @@ -46,4 +46,5 @@ line-length = 130

[tool.pyright]
exclude = ["**/*.pyi"]
include = ["src"]
include = ["src"]
typeCheckingMode = "standard"
51 changes: 34 additions & 17 deletions src/logic/main.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,27 @@
from asyncio import gather, get_running_loop
from json import JSONDecodeError, dumps, loads
from json import dumps
from typing import cast

from promplate import Chain, ChainContext, Jump, Message, Node
from promplate import Chain, ChainContext, Jump, Message
from promplate.prompt.utils import AutoNaming
from promplate_trace.auto import patch
from promptools.extractors import extract_json
from pydantic import TypeAdapter, ValidationError
from rich import print

from ..templates.schema.output import Output
from ..utils.functional import compose
from ..utils.load import load_template
from ..utils.node import Node
from .tools import call_tool, tools

main = patch.node(Node)(load_template("main"), {"tools": tools})

class TypedContext(ChainContext):
partial = True
parsed: Output = {}


main = Node(load_template("main"), TypedContext({"tools": tools}))


@patch.chain
Expand All @@ -23,19 +32,24 @@ class Loop(Chain, AutoNaming):
main_loop = Loop(main)


_validator = TypeAdapter(Output)
serialize = compose(_validator.dump_json, bytes.decode)
loads = _validator.validate_json


@main.end_process
async def collect_results(context: ChainContext):
parsed = cast(dict, context["parsed"] or {"content": [{"text": context.result}]})
async def collect_results(context: TypedContext):
parsed = context.parsed or {"content": [{"text": context.result}]}
actions = parsed.get("actions", [])

if not actions:
return

results = await gather(*(call_tool(i["name"], i["body"]) for i in actions))
results = await gather(*(call_tool(i["name"], i.get("body", {})) for i in actions))

messages = cast(list[Message], context["messages"])

messages.append({"role": "assistant", "content": dumps(parsed, ensure_ascii=False)})
messages.append({"role": "assistant", "content": serialize(parsed)})

messages.extend(
[
Expand All @@ -54,24 +68,27 @@ async def collect_results(context: ChainContext):


@main.mid_process
def parse_json(context: ChainContext):
def parse_json(context: TypedContext):
try:
context["parsed"] = loads(context.result)
context.parsed = loads(context.result)
context.pop("partial", None)
print("parsed json:", context["parsed"])
except JSONDecodeError:
context.partial = False
print("parsed json:", context.parsed)
raise Jump(out_of=main)
except ValidationError:
context["partial"] = True
try:
context["parsed"] = extract_json(context.result, context.get("parsed", {}), Output)
context.parsed = extract_json(context.result, context.parsed, Output)
except SyntaxError:
context["parsed"] = {"content": [{"text": context.result}]}

context["partial"] = True
finally:
context["parsed"] = context.parsed


@main.mid_process
async def run_tools(context: ChainContext):
if actions := cast(dict, context["parsed"]).get("actions", []):
async def run_tools(context: TypedContext):
if actions := cast(dict, context.parsed).get("actions", []):
loop = get_running_loop()
for action in actions[slice(None, -1 if context.get("partial") else None)]:
for action in actions[slice(None, -1 if context.partial else None)]:
loop.create_task(call_tool(action["name"], action["body"]))
print(f"start <{action['name']}> with {action['body']}")
6 changes: 3 additions & 3 deletions src/logic/translate.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
from typing import Literal

from partial_json_parser import ARR
from promplate import ChainContext, Node
from promplate_trace.auto import patch
from promplate import ChainContext
from promptools.extractors import extract_json

from ..utils.llm.openai import openai
from ..utils.load import load_template
from ..utils.node import Node

translate = patch.node(Node)(load_template("translate"), llm=openai)
translate = Node(load_template("translate"), llm=openai)

translate.run_config["stop"] = ["\n```"]
translate.run_config["temperature"] = 0
Expand Down
2 changes: 2 additions & 0 deletions src/utils/functional.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
def compose(f1, f2):
    """Left-to-right composition: the returned callable applies *f1*, then feeds its result to *f2*."""

    def composed(*args, **kwargs):
        intermediate = f1(*args, **kwargs)
        return f2(intermediate)

    return composed
3 changes: 3 additions & 0 deletions src/utils/functional.pyi
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from typing import Callable

# Typed overlay for functional.compose: PEP 695 generics carry the first
# function's full parameter spec (P1) through to the composed callable, with
# T1 as the hand-off type between f1 and f2 and T2 as the final result.
def compose[**P1, T1, T2](f1: Callable[P1, T1], f2: Callable[[T1], T2]) -> Callable[P1, T2]: ...
15 changes: 15 additions & 0 deletions src/utils/node.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
from promplate import Callback
from promplate import Node as N
from promplate.chain.node import ChainContext
from promplate.prompt.utils import AutoNaming
from promplate_trace.auto import patch


@patch.node
class Node(N, AutoNaming):
    """Traced Node that keeps the caller's ChainContext subclass alive at runtime."""

    def __init__(self, template, partial_context=None, *args, **config):
        super().__init__(template, partial_context, *args, **config)

        # When the partial context is an instance of a proper ChainContext
        # subclass, register an on_enter callback that re-wraps each incoming
        # context in that subclass — presumably so typed attribute access
        # (e.g. `context.parsed`) works inside process handlers.
        ctx_cls = type(partial_context)
        if ctx_cls is not ChainContext and issubclass(ctx_cls, ChainContext):
            self.add_callbacks(Callback(on_enter=lambda _, context, config: (ctx_cls(context), config)))
21 changes: 21 additions & 0 deletions src/utils/node.pyi
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
from typing import Awaitable, Callable, overload

from promplate.chain.node import ChainContext, Context
from promplate.chain.node import Node as N
from promplate.llm.base import LLM
from promplate.prompt.template import Template

# What a process step may return to update the chain: a context mapping or nothing.
type _MaybeContext = Context | None

# Recursive alias: a process result may be sync, or an awaitable nested to any depth.
type _Return = _MaybeContext | Awaitable[_Return]

# A chain-process callback: takes the (typed) context C, returns a _Return.
type _Process[C: ChainContext, T: _Return] = Callable[[C], T]

class Node[C: ChainContext = ChainContext](N):
    # Overload 1: a concrete typed context is supplied — its class parameterizes Node[C].
    @overload
    def __new__(cls, template: Template | str, partial_context: C, llm: LLM | None = ..., **config) -> Node[C]: ...
    # Overload 2: no context (or an untyped one) — fall back to the default parameterization.
    @overload
    def __new__(cls, template: Template | str, partial_context: _MaybeContext = ..., llm: LLM | None = ..., **config) -> Node: ...
    # Process decorators re-typed so handlers receive this node's context type C.
    # The `type: ignore` comments acknowledge the deliberate narrowing of the base signatures.
    def pre_process[T: _Return](self, process: Callable[[C], T]) -> Callable[[C], T]: ... # type: ignore
    def mid_process[T: _Return](self, process: Callable[[C], T]) -> Callable[[C], T]: ... # type: ignore
    def end_process[T: _Return](self, process: Callable[[C], T]) -> Callable[[C], T]: ... # type: ignore

0 comments on commit 78683ea

Please sign in to comment.