Skip to content

Commit

Permalink
feat: preserve type for ChainContext constructed by a subclass inst…
Browse files Browse the repository at this point in the history
…ance
  • Loading branch information
CNSeniorious000 committed May 19, 2024
1 parent bad4024 commit 4e6fe29
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 4 deletions.
20 changes: 16 additions & 4 deletions python/promplate/chain/node.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,38 @@
from inspect import isclass
from itertools import accumulate
from typing import Callable, Mapping, MutableMapping, overload
from typing import Callable, Mapping, MutableMapping, TypeVar, overload

from ..llm.base import *
from ..prompt.template import Context, Loader, SafeChainMapContext, Template
from .callback import BaseCallback, Callback
from .utils import accumulate_any, resolve

C = TypeVar("C", bound="ChainContext")


class ChainContext(SafeChainMapContext):
@overload
def __init__(self): ...
def __new__(cls): ...

@overload
def __init__(self, least: MutableMapping | None = None): ...
def __new__(cls, least: C, *maps: Mapping) -> C: ...

@overload
def __init__(self, least: MutableMapping | None = None, *maps: Mapping): ...
def __new__(cls, least: MutableMapping | None = None, *maps: Mapping): ...

def __init__(self, least: MutableMapping | None = None, *maps: Mapping):
super().__init__({} if least is None else least, *maps) # type: ignore

def __new__(cls, *args, **kwargs): # type: ignore
try:
least = args[0]
except IndexError:
least = kwargs.get("least")
if isinstance(least, cls) and least.__class__ is not cls:
return least.__class__(*args, **kwargs)

return super().__new__(cls, *args, **kwargs)

@classmethod
def ensure(cls, context):
return context if isinstance(context, cls) else cls(context)
Expand Down
11 changes: 11 additions & 0 deletions python/tests/test_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,3 +170,14 @@ def test_ensure_method():
assert new_ctx["key"] == "value"
assert len(new_ctx.maps) == 1
assert len(new_ctx) == 1


def test_subclass():
class MyChainContext(ChainContext):
pass

ctx = MyChainContext()

assert isinstance(ChainContext(ctx), MyChainContext)

assert isinstance(ChainContext.ensure(ctx), MyChainContext)

0 comments on commit 4e6fe29

Please sign in to comment.