Skip to content

Commit

Permalink
replace forwardref with type_checking
Browse files Browse the repository at this point in the history
  • Loading branch information
RLKRo committed Dec 19, 2023
1 parent 677ee7a commit a5b3c22
Show file tree
Hide file tree
Showing 13 changed files with 77 additions and 62 deletions.
5 changes: 5 additions & 0 deletions dff/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,8 @@
import nest_asyncio

nest_asyncio.apply()

from dff.pipeline import Pipeline
from dff.script import Context, Script

Script.model_rebuild()
2 changes: 2 additions & 0 deletions dff/pipeline/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,3 +32,5 @@
from .service.extra import BeforeHandler, AfterHandler
from .service.group import ServiceGroup
from .service.service import Service, to_service

ExtraHandlerRuntimeInfo.model_rebuild()
6 changes: 4 additions & 2 deletions dff/pipeline/conditions.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
are attached should be executed or not.
The standard set of them allows the user to set up dependencies between pipeline components.
"""
from typing import Optional, ForwardRef
from __future__ import annotations
from typing import Optional, TYPE_CHECKING

from dff.script import Context

Expand All @@ -16,7 +17,8 @@
StartConditionCheckerAggregationFunction,
)

Pipeline = ForwardRef("Pipeline")
if TYPE_CHECKING:
from dff.pipeline.pipeline.pipeline import Pipeline


def always_start_condition(_: Context, __: Pipeline) -> bool:
Expand Down
6 changes: 4 additions & 2 deletions dff/pipeline/pipeline/actor.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,10 @@
.. figure:: /_static/drawio/dfe/user_actor.png
"""
from __future__ import annotations
import logging
import asyncio
from typing import Union, Callable, Optional, Dict, List, ForwardRef
from typing import Union, Callable, Optional, Dict, List, TYPE_CHECKING
import copy

from dff.utils.turn_caching import cache_clear
Expand All @@ -39,7 +40,8 @@

logger = logging.getLogger(__name__)

Pipeline = ForwardRef("Pipeline")
if TYPE_CHECKING:
from dff.pipeline.pipeline.pipeline import Pipeline


def error_handler(error_msgs: list, msg: str, exception: Optional[Exception] = None, logging_flag: bool = True):
Expand Down
6 changes: 4 additions & 2 deletions dff/pipeline/pipeline/component.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,12 @@
The PipelineComponent class can be a group or a service. It is designed to be reusable and composable,
allowing developers to create complex processing pipelines by combining multiple components.
"""
from __future__ import annotations
import logging
import abc
import asyncio
import copy
from typing import Optional, Awaitable, ForwardRef
from typing import Optional, Awaitable, TYPE_CHECKING

from dff.script import Context

Expand All @@ -31,7 +32,8 @@

logger = logging.getLogger(__name__)

Pipeline = ForwardRef("Pipeline")
if TYPE_CHECKING:
from dff.pipeline.pipeline.pipeline import Pipeline


class PipelineComponent(abc.ABC):
Expand Down
6 changes: 4 additions & 2 deletions dff/pipeline/service/extra.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,11 @@
beyond the core functionality. Extra handlers are an input-converting addition to :py:class:`.PipelineComponent`.
For example, they are used to collect statistics from components, timing, logging, etc.
"""
from __future__ import annotations
import asyncio
import logging
import inspect
from typing import Optional, List, ForwardRef
from typing import Optional, List, TYPE_CHECKING

from dff.script import Context

Expand All @@ -23,7 +24,8 @@

logger = logging.getLogger(__name__)

Pipeline = ForwardRef("Pipeline")
if TYPE_CHECKING:
from dff.pipeline.pipeline.pipeline import Pipeline


class _ComponentExtraHandler:
Expand Down
6 changes: 4 additions & 2 deletions dff/pipeline/service/group.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,10 @@
allowing for easier management and organization of the services within the pipeline.
The :py:class:`~.ServiceGroup` serves the important function of grouping services to work together in parallel.
"""
from __future__ import annotations
import asyncio
import logging
from typing import Optional, List, Union, Awaitable, ForwardRef
from typing import Optional, List, Union, Awaitable, TYPE_CHECKING

from dff.script import Context

Expand All @@ -29,7 +30,8 @@

logger = logging.getLogger(__name__)

Pipeline = ForwardRef("Pipeline")
if TYPE_CHECKING:
from dff.pipeline.pipeline.pipeline import Pipeline


class ServiceGroup(PipelineComponent):
Expand Down
6 changes: 4 additions & 2 deletions dff/pipeline/service/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,10 @@
Service can be asynchronous only if its handler is a coroutine.
Actor wrapping service is asynchronous.
"""
from __future__ import annotations
import logging
import inspect
from typing import Optional, ForwardRef
from typing import Optional, TYPE_CHECKING

from dff.script import Context

Expand All @@ -27,7 +28,8 @@

logger = logging.getLogger(__name__)

Pipeline = ForwardRef("Pipeline")
if TYPE_CHECKING:
from dff.pipeline.pipeline.pipeline import Pipeline


class Service(PipelineComponent):
Expand Down
46 changes: 22 additions & 24 deletions dff/pipeline/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,24 +5,22 @@
The classes and special types in this module can include data models,
data structures, and other types that are defined for type hinting.
"""
from __future__ import annotations
from abc import ABC
from enum import unique, Enum
from typing import Callable, Union, Awaitable, Dict, List, Optional, NewType, Iterable, Any, Protocol, Hashable
from typing import Callable, Union, Awaitable, Dict, List, Optional, Iterable, Any, Protocol, Hashable, TYPE_CHECKING

from dff.context_storages import DBContextStorage
from dff.script import Context, ActorStage, NodeLabel2Type, Script, Message
from typing_extensions import NotRequired, TypedDict, TypeAlias
from pydantic import BaseModel


_ForwardPipeline = NewType("Pipeline", Any)
_ForwardPipelineComponent = NewType("PipelineComponent", Any)
_ForwardService = NewType("Service", _ForwardPipelineComponent)
_ForwardServiceBuilder = NewType("ServiceBuilder", Any)
_ForwardServiceGroup = NewType("ServiceGroup", _ForwardPipelineComponent)
_ForwardComponentExtraHandler = NewType("_ComponentExtraHandler", Any)
_ForwardProvider = NewType("ABCProvider", ABC)
_ForwardExtraHandlerRuntimeInfo = NewType("ExtraHandlerRuntimeInfo", Any)
if TYPE_CHECKING:
from dff.pipeline.pipeline.pipeline import Pipeline
from dff.pipeline.service.service import Service
from dff.pipeline.service.group import ServiceGroup
from dff.pipeline.service.extra import _ComponentExtraHandler
from dff.messengers.common.interface import MessengerInterface


class PipelineRunnerFunction(Protocol):
Expand Down Expand Up @@ -112,7 +110,7 @@ class ExtraHandlerType(str, Enum):
"""


StartConditionCheckerFunction: TypeAlias = Callable[[Context, _ForwardPipeline], bool]
StartConditionCheckerFunction: TypeAlias = Callable[[Context, "Pipeline"], bool]
"""
A function type for components `start_conditions`.
Accepts context and pipeline, returns boolean (whether service can be launched).
Expand Down Expand Up @@ -152,8 +150,8 @@ class ServiceRuntimeInfo(BaseModel):

ExtraHandlerFunction: TypeAlias = Union[
Callable[[Context], Any],
Callable[[Context, _ForwardPipeline], Any],
Callable[[Context, _ForwardPipeline, _ForwardExtraHandlerRuntimeInfo], Any],
Callable[[Context, "Pipeline"], Any],
Callable[[Context, "Pipeline", "ExtraHandlerRuntimeInfo"], Any],
]
"""
A function type for creating wrappers (before and after functions).
Expand All @@ -177,10 +175,10 @@ class ExtraHandlerRuntimeInfo(BaseModel):
ServiceFunction: TypeAlias = Union[
Callable[[Context], None],
Callable[[Context], Awaitable[None]],
Callable[[Context, _ForwardPipeline], None],
Callable[[Context, _ForwardPipeline], Awaitable[None]],
Callable[[Context, _ForwardPipeline, ServiceRuntimeInfo], None],
Callable[[Context, _ForwardPipeline, ServiceRuntimeInfo], Awaitable[None]],
Callable[[Context, "Pipeline"], None],
Callable[[Context, "Pipeline"], Awaitable[None]],
Callable[[Context, "Pipeline", ServiceRuntimeInfo], None],
Callable[[Context, "Pipeline", ServiceRuntimeInfo], Awaitable[None]],
]
"""
A function type for creating service handlers.
Expand All @@ -190,7 +188,7 @@ class ExtraHandlerRuntimeInfo(BaseModel):


ExtraHandlerBuilder: TypeAlias = Union[
_ForwardComponentExtraHandler,
"_ComponentExtraHandler",
TypedDict(
"WrapperDict",
{
Expand All @@ -205,19 +203,19 @@ class ExtraHandlerRuntimeInfo(BaseModel):
A type, representing anything that can be transformed to ExtraHandlers.
It can be:
- _ComponentExtraHandler object
- ExtraHandlerFunction object
- Dictionary, containing keys `timeout`, `asynchronous`, `functions`
"""


ServiceBuilder: TypeAlias = Union[
ServiceFunction,
_ForwardService,
"Service",
str,
TypedDict(
"ServiceDict",
{
"handler": _ForwardServiceBuilder,
"handler": "ServiceBuilder",
"before_handler": NotRequired[Optional[ExtraHandlerBuilder]],
"after_handler": NotRequired[Optional[ExtraHandlerBuilder]],
"timeout": NotRequired[Optional[float]],
Expand All @@ -239,8 +237,8 @@ class ExtraHandlerRuntimeInfo(BaseModel):


ServiceGroupBuilder: TypeAlias = Union[
List[Union[ServiceBuilder, List[ServiceBuilder], _ForwardServiceGroup]],
_ForwardServiceGroup,
List[Union[ServiceBuilder, List[ServiceBuilder], "ServiceGroup"]],
"ServiceGroup",
]
"""
A type, representing anything that can be transformed to service group.
Expand All @@ -254,7 +252,7 @@ class ExtraHandlerRuntimeInfo(BaseModel):
PipelineBuilder: TypeAlias = TypedDict(
"PipelineBuilder",
{
"messenger_interface": NotRequired[Optional[_ForwardProvider]],
"messenger_interface": NotRequired[Optional["MessengerInterface"]],
"context_storage": NotRequired[Optional[Union[DBContextStorage, Dict]]],
"components": ServiceGroupBuilder,
"before_handler": NotRequired[Optional[ExtraHandlerBuilder]],
Expand Down
13 changes: 6 additions & 7 deletions dff/script/core/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,18 +16,20 @@
The context can be easily serialized to a format that can be stored or transmitted, such as JSON.
This allows developers to save the context data and resume the conversation later.
"""
from __future__ import annotations
import logging
from uuid import UUID, uuid4
from typing import Any, Optional, Union, Dict, List, Set
from typing import Any, Optional, Union, Dict, List, Set, TYPE_CHECKING

from pydantic import BaseModel, Field, field_validator

from .types import NodeLabel2Type, ModuleName
from .message import Message

logger = logging.getLogger(__name__)
if TYPE_CHECKING:
from dff.script.core.script import Node

Node = BaseModel
logger = logging.getLogger(__name__)


def get_last_index(dictionary: dict) -> int:
Expand Down Expand Up @@ -120,7 +122,7 @@ def sort_dict_keys(cls, dictionary: dict) -> dict:
return {key: dictionary[key] for key in sorted(dictionary)}

@classmethod
def cast(cls, ctx: Optional[Union["Context", dict, str]] = None, *args, **kwargs) -> "Context":
def cast(cls, ctx: Optional[Union[Context, dict, str]] = None, *args, **kwargs) -> Context:
"""
Transform different data types to the objects of the
:py:class:`~.Context` class.
Expand Down Expand Up @@ -277,6 +279,3 @@ def current_node(self) -> Optional[Node]:
)

return node


Context.model_rebuild()
14 changes: 6 additions & 8 deletions dff/script/core/normalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,21 +5,20 @@
that is suitable for script and actor execution process.
This module contains a basic set of functions for normalizing data in a dialog script.
"""
from __future__ import annotations
import logging

from typing import Union, Callable, Optional, ForwardRef
from typing import Union, Callable, Optional, TYPE_CHECKING

from .keywords import Keywords
from .context import Context
from .types import NodeLabel3Type, NodeLabelType, ConditionType, LabelType
from .message import Message

from pydantic import validate_call
if TYPE_CHECKING:
from dff.pipeline.pipeline.pipeline import Pipeline

logger = logging.getLogger(__name__)

Pipeline = ForwardRef("Pipeline")


def normalize_label(
label: NodeLabelType, default_flow_label: LabelType = ""
Expand Down Expand Up @@ -83,10 +82,9 @@ def callable_condition_handler(ctx: Context, pipeline: Pipeline) -> bool:
return callable_condition_handler


@validate_call
def normalize_response(
response: Optional[Union[Message, Callable[[Context, Pipeline], Message]]]
) -> Callable[[Context, Pipeline], Message]:
response: Optional[Union[Message, Callable[[Context, "Pipeline"], Message]]]
) -> Callable[[Context, "Pipeline"], Message]:
"""
This function is used to normalize response. If the response is a Callable, it is returned, otherwise
the response is wrapped in an asynchronous function and this function is returned.
Expand Down
17 changes: 8 additions & 9 deletions dff/script/core/script.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,23 +6,22 @@
the user's input and the current state of the conversation.
"""
# %%

from __future__ import annotations
import logging
from typing import Callable, Optional, Any, Dict, Union
from typing import Callable, Optional, Any, Dict, Union, TYPE_CHECKING

from pydantic import BaseModel, field_validator
from pydantic import BaseModel, field_validator, validate_call

from .types import LabelType, NodeLabelType, ConditionType, NodeLabel3Type
from .message import Message
from .keywords import Keywords
from .normalization import normalize_condition, normalize_label, validate_call
from typing import ForwardRef

logger = logging.getLogger(__name__)
from .normalization import normalize_condition, normalize_label

if TYPE_CHECKING:
from dff.script.core.context import Context
from dff.pipeline.pipeline.pipeline import Pipeline

Pipeline = ForwardRef("Pipeline")
Context = ForwardRef("Context")
logger = logging.getLogger(__name__)


class Node(BaseModel, extra="forbid", validate_assignment=True):
Expand Down
6 changes: 4 additions & 2 deletions dff/script/labels/std_labels.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,12 @@
This module contains a standard set of scripting :py:const:`labels <dff.script.NodeLabelType>` that
can be used by developers to define the conversation flow.
"""
from typing import Optional, Callable, ForwardRef
from __future__ import annotations
from typing import Optional, Callable, TYPE_CHECKING
from dff.script import Context, NodeLabel3Type

Pipeline = ForwardRef("Pipeline")
if TYPE_CHECKING:
from dff.pipeline.pipeline.pipeline import Pipeline


def repeat(priority: Optional[float] = None) -> Callable:
Expand Down

0 comments on commit a5b3c22

Please sign in to comment.