diff --git a/python/promplate/chain/node.py b/python/promplate/chain/node.py index 62691bf..42d1124 100644 --- a/python/promplate/chain/node.py +++ b/python/promplate/chain/node.py @@ -13,16 +13,13 @@ class ChainContext(SafeChainMapContext): @overload - def __init__(self): - ... + def __init__(self): ... @overload - def __init__(self, least: MutableMapping | None = None): - ... + def __init__(self, least: MutableMapping | None = None): ... @overload - def __init__(self, least: MutableMapping | None = None, *maps: Mapping): - ... + def __init__(self, least: MutableMapping | None = None, *maps: Mapping): ... def __init__(self, least: MutableMapping | None = None, *maps: Mapping): super().__init__({} if least is None else least, *maps) # type: ignore @@ -56,8 +53,7 @@ def invoke( /, complete: Complete | None = None, **config, - ) -> ChainContext: - ... + ) -> ChainContext: ... async def ainvoke( self, @@ -65,8 +61,7 @@ async def ainvoke( /, complete: Complete | AsyncComplete | None = None, **config, - ) -> ChainContext: - ... + ) -> ChainContext: ... def stream( self, @@ -74,8 +69,7 @@ def stream( /, generate: Generate | None = None, **config, - ) -> Iterable[ChainContext]: - ... + ) -> Iterable[ChainContext]: ... def astream( self, @@ -83,8 +77,7 @@ def astream( /, generate: Generate | AsyncGenerate | None = None, **config, - ) -> AsyncIterable[ChainContext]: - ... + ) -> AsyncIterable[ChainContext]: ... @classmethod def _get_chain_type(cls): @@ -108,8 +101,7 @@ def _invoke( complete: Complete | None, callbacks: list[BaseCallback], **config, - ): - ... + ): ... async def _ainvoke( self, @@ -118,8 +110,7 @@ async def _ainvoke( complete: Complete | AsyncComplete | None, callbacks: list[BaseCallback], **config, - ): - ... + ): ... def _stream( self, @@ -128,8 +119,7 @@ def _stream( generate: Generate | None, callbacks: list[BaseCallback], **config, - ) -> Iterable: - ... + ) -> Iterable: ... def _astream( self, @@ -138,8 +128,7 @@ def _astream( generate: Generate | AsyncGenerate | None, callbacks: list[BaseCallback], **config, - ) -> AsyncIterable: - ... + ) -> AsyncIterable: ... callbacks: list[BaseCallback | type[BaseCallback]] diff --git a/python/promplate/llm/base.py b/python/promplate/llm/base.py index 13e5a98..b7b183b 100644 --- a/python/promplate/llm/base.py +++ b/python/promplate/llm/base.py @@ -13,30 +13,24 @@ def _config(self): class Complete(Protocol): - def __call__(self, prompt, /, **config) -> str: - ... + def __call__(self, prompt, /, **config) -> str: ... class Generate(Protocol): - def __call__(self, prompt, /, **config) -> Iterable[str]: - ... + def __call__(self, prompt, /, **config) -> Iterable[str]: ... class AsyncComplete(Protocol): - def __call__(self, prompt, /, **config) -> Awaitable[str]: - ... + def __call__(self, prompt, /, **config) -> Awaitable[str]: ... class AsyncGenerate(Protocol): - def __call__(self, prompt, /, **config) -> AsyncIterable[str]: - ... + def __call__(self, prompt, /, **config) -> AsyncIterable[str]: ... class LLM(Protocol): @partial(cast, Complete | AsyncComplete) - def complete(self, prompt, /, **config) -> str | Awaitable[str]: - ... + def complete(self, prompt, /, **config) -> str | Awaitable[str]: ... @partial(cast, Generate | AsyncGenerate) - def generate(self, prompt, /, **config) -> Iterable[str] | AsyncIterable[str]: - ... + def generate(self, prompt, /, **config) -> Iterable[str] | AsyncIterable[str]: ... diff --git a/python/promplate/llm/openai/v0.py b/python/promplate/llm/openai/v0.py index 1f9cca0..1bd922f 100644 --- a/python/promplate/llm/openai/v0.py +++ b/python/promplate/llm/openai/v0.py @@ -44,11 +44,9 @@ def __init__( for key, val in other_config.items(): setattr(self, key, val) - def __setattr__(self, *_): - ... + def __setattr__(self, *_): ... - def __getattr__(self, _): - ... + def __getattr__(self, _): ... else: Config = Configurable diff --git a/python/promplate/llm/openai/v1.py b/python/promplate/llm/openai/v1.py index f5f8f16..7c4d6a0 100644 --- a/python/promplate/llm/openai/v1.py +++ b/python/promplate/llm/openai/v1.py @@ -62,13 +62,11 @@ def _aclient(self): class ClientConfig(Config): @same_params_as(Client) - def __init__(self, **config): - ... + def __init__(self, **config): ... class AsyncClientConfig(Config): @same_params_as(AsyncClient) - def __init__(self, **config): - ... + def __init__(self, **config): ... else: ClientConfig = AsyncClientConfig = Config diff --git a/python/promplate/prompt/template.py b/python/promplate/prompt/template.py index 1a1cb3c..28b5772 100644 --- a/python/promplate/prompt/template.py +++ b/python/promplate/prompt/template.py @@ -10,11 +10,9 @@ class Component(Protocol): - def render(self, context: Context) -> str: - ... + def render(self, context: Context) -> str: ... - async def arender(self, context: Context) -> str: - ... + async def arender(self, context: Context) -> str: ... class TemplateCore(AutoNaming): diff --git a/python/pyproject.toml b/python/pyproject.toml index 2b15ff2..3fdceaa 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -30,7 +30,7 @@ all = ["aiofiles", "httpx", "openai"] [tool.poetry.group.dev.dependencies] isort = "^5" -black = "^23" +black = "^24" pytest = "^7" coverage = "^7" pytest-asyncio = "^0.23"