-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: add generic typing for contexts to ensure type safety and updat…
…e python
- Loading branch information
1 parent
7c81b7f
commit 78683ea
Showing
9 changed files
with
84 additions
and
25 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -8,7 +8,7 @@ | |
.pdm-python | ||
|
||
__pycache__ | ||
*.pyi | ||
src/utils/load.pyi | ||
|
||
node_modules | ||
dist |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
def compose(f1, f2): | ||
return lambda *args, **kwargs: f2(f1(*args, **kwargs)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
from typing import Callable | ||
|
||
def compose[**P1, T1, T2](f1: Callable[P1, T1], f2: Callable[[T1], T2]) -> Callable[P1, T2]: ... |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
from promplate import Callback | ||
from promplate import Node as N | ||
from promplate.chain.node import ChainContext | ||
from promplate.prompt.utils import AutoNaming | ||
from promplate_trace.auto import patch | ||
|
||
|
||
@patch.node | ||
class Node(N, AutoNaming): | ||
def __init__(self, template, partial_context=None, *args, **config): | ||
super().__init__(template, partial_context, *args, **config) | ||
|
||
context_type = partial_context.__class__ | ||
if issubclass(context_type, ChainContext) and context_type is not ChainContext: | ||
self.add_callbacks(Callback(on_enter=lambda _, context, config: (context_type(context), config))) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
from typing import Awaitable, Callable, overload | ||
|
||
from promplate.chain.node import ChainContext, Context | ||
from promplate.chain.node import Node as N | ||
from promplate.llm.base import LLM | ||
from promplate.prompt.template import Template | ||
|
||
type _MaybeContext = Context | None | ||
|
||
type _Return = _MaybeContext | Awaitable[_Return] | ||
|
||
type _Process[C: ChainContext, T: _Return] = Callable[[C], T] | ||
|
||
class Node[C: ChainContext = ChainContext](N): | ||
@overload | ||
def __new__(cls, template: Template | str, partial_context: C, llm: LLM | None = ..., **config) -> Node[C]: ... | ||
@overload | ||
def __new__(cls, template: Template | str, partial_context: _MaybeContext = ..., llm: LLM | None = ..., **config) -> Node: ... | ||
def pre_process[T: _Return](self, process: Callable[[C], T]) -> Callable[[C], T]: ... # type: ignore | ||
def mid_process[T: _Return](self, process: Callable[[C], T]) -> Callable[[C], T]: ... # type: ignore | ||
def end_process[T: _Return](self, process: Callable[[C], T]) -> Callable[[C], T]: ... # type: ignore |