diff --git a/python/promplate/chain/node.py b/python/promplate/chain/node.py index d9a2854..6ab1973 100644 --- a/python/promplate/chain/node.py +++ b/python/promplate/chain/node.py @@ -1,26 +1,38 @@ from inspect import isclass from itertools import accumulate -from typing import Callable, Mapping, MutableMapping, overload +from typing import Callable, Mapping, MutableMapping, TypeVar, overload from ..llm.base import * from ..prompt.template import Context, Loader, SafeChainMapContext, Template from .callback import BaseCallback, Callback from .utils import accumulate_any, resolve +C = TypeVar("C", bound="ChainContext") + class ChainContext(SafeChainMapContext): @overload - def __init__(self): ... + def __new__(cls): ... @overload - def __init__(self, least: MutableMapping | None = None): ... + def __new__(cls, least: C, *maps: Mapping) -> C: ... @overload - def __init__(self, least: MutableMapping | None = None, *maps: Mapping): ... + def __new__(cls, least: MutableMapping | None = None, *maps: Mapping): ... def __init__(self, least: MutableMapping | None = None, *maps: Mapping): super().__init__({} if least is None else least, *maps) # type: ignore + def __new__(cls, *args, **kwargs): # type: ignore + try: + least = args[0] + except IndexError: + least = kwargs.get("least") + if isinstance(least, cls) and least.__class__ is not cls: + return least.__class__(*args, **kwargs) + + return super().__new__(cls, *args, **kwargs) + @classmethod def ensure(cls, context): return context if isinstance(context, cls) else cls(context) diff --git a/python/tests/test_context.py b/python/tests/test_context.py index 024ed38..96ba3fa 100644 --- a/python/tests/test_context.py +++ b/python/tests/test_context.py @@ -170,3 +170,14 @@ def test_ensure_method(): assert new_ctx["key"] == "value" assert len(new_ctx.maps) == 1 assert len(new_ctx) == 1 + + +def test_subclass(): + class MyChainContext(ChainContext): + pass + + ctx = MyChainContext() + + assert isinstance(ChainContext(ctx), MyChainContext) + + assert isinstance(ChainContext.ensure(ctx), MyChainContext)