Skip to content

Commit

Permalink
Release v0.9 (#393)
Browse files Browse the repository at this point in the history
  • Loading branch information
RLKRo authored Sep 27, 2024
2 parents d428908 + 8aedf8c commit 8ea8c7a
Show file tree
Hide file tree
Showing 56 changed files with 2,545 additions and 3,485 deletions.
7 changes: 5 additions & 2 deletions chatsky/__rebuild_pydantic_models__.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,17 @@
# flake8: noqa: F401

from chatsky.core.service.types import ExtraHandlerRuntimeInfo, StartConditionCheckerFunction, ComponentExecutionState
from chatsky.core.service.types import ExtraHandlerRuntimeInfo, ComponentExecutionState
from chatsky.core import Context, Script
from chatsky.core.script import Node
from chatsky.core.pipeline import Pipeline
from chatsky.slots.slots import SlotManager
from chatsky.core.context import FrameworkData
from chatsky.core.context import FrameworkData, ServiceState
from chatsky.core.service import PipelineComponent

PipelineComponent.model_rebuild()
Pipeline.model_rebuild()
Script.model_rebuild()
Context.model_rebuild()
ExtraHandlerRuntimeInfo.model_rebuild()
FrameworkData.model_rebuild()
ServiceState.model_rebuild()
1 change: 1 addition & 0 deletions chatsky/conditions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,4 @@
HasCallbackQuery,
)
from chatsky.conditions.slots import SlotsExtracted
from chatsky.conditions.service import ServiceFinished
40 changes: 40 additions & 0 deletions chatsky/conditions/service.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
"""
Service Conditions
------------------
Provides service-related conditions
"""

from __future__ import annotations

from chatsky.core.context import Context
from chatsky.core.script_function import BaseCondition

from chatsky.core.service.types import (
ComponentExecutionState,
)


class ServiceFinished(BaseCondition):
"""
Check if a :py:class:`~chatsky.core.service.service.Service` was executed successfully.
"""

path: str
"""The path of the condition pipeline component."""
wait: bool = False
"""
Whether to wait for the service to be finished.
This eliminates possible service states ``NOT_RUN`` and ``RUNNING``.
"""

def __init__(self, path: str, *, wait: bool = False):
super().__init__(path=path, wait=wait)

async def call(self, ctx: Context) -> bool:
if self.wait:
await ctx.framework_data.service_states[self.path].finished_event.wait()

state = ctx.framework_data.service_states[self.path].execution_status

return ComponentExecutionState[state] == ComponentExecutionState.FINISHED
18 changes: 11 additions & 7 deletions chatsky/conditions/standard.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ class HasText(BaseCondition):
Text to search for in the last request.
"""

def __init__(self, text):
def __init__(self, text: str):
super().__init__(text=text)

async def call(self, ctx: Context) -> bool:
Expand All @@ -94,7 +94,7 @@ class Regexp(BaseCondition):
Flags to pass to ``re.compile``.
"""

def __init__(self, pattern, *, flags=0):
def __init__(self, pattern: Union[str, Pattern], *, flags: Union[int, re.RegexFlag] = 0):
super().__init__(pattern=pattern, flags=flags)

@computed_field
Expand All @@ -120,7 +120,7 @@ class Any(BaseCondition):
List of conditions.
"""

def __init__(self, *conditions):
def __init__(self, *conditions: BaseCondition):
super().__init__(conditions=list(conditions))

async def call(self, ctx: Context) -> bool:
Expand All @@ -137,7 +137,7 @@ class All(BaseCondition):
List of conditions.
"""

def __init__(self, *conditions):
def __init__(self, *conditions: BaseCondition):
super().__init__(conditions=list(conditions))

async def call(self, ctx: Context) -> bool:
Expand All @@ -154,7 +154,7 @@ class Negation(BaseCondition):
Condition to negate.
"""

def __init__(self, condition):
def __init__(self, condition: BaseCondition):
super().__init__(condition=condition)

async def call(self, ctx: Context) -> bool:
Expand Down Expand Up @@ -189,7 +189,11 @@ class CheckLastLabels(BaseCondition):
"""

def __init__(
self, *, flow_labels=None, labels: Optional[List[AbsoluteNodeLabelInitTypes]] = None, last_n_indices=1
self,
*,
flow_labels: Optional[List[str]] = None,
labels: Optional[List[AbsoluteNodeLabelInitTypes]] = None,
last_n_indices: int = 1
):
if flow_labels is None:
flow_labels = []
Expand All @@ -216,7 +220,7 @@ class HasCallbackQuery(BaseCondition):
Query string to find in last request's attachments.
"""

def __init__(self, query_string):
def __init__(self, query_string: str):
super().__init__(query_string=query_string)

async def call(self, ctx: Context) -> bool:
Expand Down
25 changes: 21 additions & 4 deletions chatsky/core/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

from __future__ import annotations
import logging
import asyncio
from uuid import UUID, uuid4
from typing import Any, Optional, Union, Dict, TYPE_CHECKING

Expand All @@ -28,9 +29,9 @@
from chatsky.core.node_label import AbsoluteNodeLabel, AbsoluteNodeLabelInitTypes

if TYPE_CHECKING:
from chatsky.core.service import ComponentExecutionState
from chatsky.core.script import Node
from chatsky.core.pipeline import Pipeline
from chatsky.core.service.types import ComponentExecutionState

logger = logging.getLogger(__name__)

Expand All @@ -53,13 +54,29 @@ class ContextError(Exception):
"""Raised when context methods are not used correctly."""


class FrameworkData(BaseModel):
class ServiceState(BaseModel, arbitrary_types_allowed=True):
execution_status: ComponentExecutionState = Field(default="NOT_RUN")
"""
:py:class:`.ComponentExecutionState` of this pipeline service.
Cleared at the end of every turn.
"""
finished_event: asyncio.Event = Field(default_factory=asyncio.Event)
"""
Asyncio `Event` which can be awaited until this service finishes.
Cleared at the end of every turn.
"""


class FrameworkData(BaseModel, arbitrary_types_allowed=True):
"""
Framework uses this to store data related to any of its modules.
"""

service_states: Dict[str, ComponentExecutionState] = Field(default_factory=dict, exclude=True)
"Statuses of all the pipeline services. Cleared at the end of every turn."
service_states: Dict[str, ServiceState] = Field(default_factory=dict, exclude=True)
"""
Dictionary containing :py:class:`.ServiceState` of all the pipeline components.
Cleared at the end of every turn.
"""
current_node: Optional[Node] = Field(default=None, exclude=True)
"""
A copy of the current node provided by :py:meth:`~chatsky.core.script.Script.get_inherited_node`.
Expand Down
84 changes: 5 additions & 79 deletions chatsky/core/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,8 @@
from chatsky.slots.slots import GroupSlot
from chatsky.core.service.group import ServiceGroup, ServiceGroupInitTypes
from chatsky.core.service.extra import ComponentExtraHandlerInitTypes, BeforeHandler, AfterHandler
from chatsky.core.service.types import (
GlobalExtraHandlerType,
ExtraHandlerFunction,
)
from .service import Service
from .utils import finalize_service_group
from .utils import finalize_service_group, initialize_service_states
from chatsky.core.service.actor import Actor
from chatsky.core.node_label import AbsoluteNodeLabel, AbsoluteNodeLabelInitTypes
from chatsky.core.script_parsing import JSONImporter, Path
Expand Down Expand Up @@ -104,15 +100,6 @@ class Pipeline(BaseModel, extra="forbid", arbitrary_types_allowed=True):
timeout: Optional[float] = None
"""
Timeout to add to pipeline root service group.
"""
optimization_warnings: bool = False
"""
Asynchronous pipeline optimization check request flag;
warnings will be sent to logs. Additionally, it has some calculated fields:
- `services_pipeline` is a pipeline root :py:class:`~.ServiceGroup` object,
- `actor` is a pipeline actor, found among services.
"""
parallelize_processing: bool = False
"""
Expand All @@ -136,7 +123,6 @@ def __init__(
before_handler: ComponentExtraHandlerInitTypes = None,
after_handler: ComponentExtraHandlerInitTypes = None,
timeout: float = None,
optimization_warnings: bool = None,
parallelize_processing: bool = None,
):
if fallback_label is None:
Expand All @@ -154,7 +140,6 @@ def __init__(
"before_handler": before_handler,
"after_handler": after_handler,
"timeout": timeout,
"optimization_warnings": optimization_warnings,
"parallelize_processing": parallelize_processing,
}
empty_fields = set()
Expand Down Expand Up @@ -216,14 +201,11 @@ def services_pipeline(self) -> PipelineServiceGroup:
after_handler=self.after_handler,
timeout=self.timeout,
)
services_pipeline.name = "pipeline"
services_pipeline.path = ".pipeline"
services_pipeline.name = ""
services_pipeline.path = ""

finalize_service_group(services_pipeline, path=services_pipeline.path)

if self.optimization_warnings:
services_pipeline.log_optimization_warnings()

return services_pipeline

@model_validator(mode="after")
Expand All @@ -240,60 +222,6 @@ def validate_fallback_label(self):
raise ValueError(f"Unknown fallback_label={self.fallback_label}")
return self

def add_global_handler(
self,
global_handler_type: GlobalExtraHandlerType,
extra_handler: ExtraHandlerFunction,
whitelist: Optional[List[str]] = None,
blacklist: Optional[List[str]] = None,
):
"""
Method for adding global wrappers to pipeline.
Different types of global wrappers are called before/after pipeline execution
or before/after each pipeline component.
They can be used for pipeline statistics collection or other functionality extensions.
NB! Global wrappers are still wrappers,
they shouldn't be used for much time-consuming tasks (see :py:mod:`chatsky.core.service.extra`).
:param global_handler_type: (required) indication where the wrapper
function should be executed.
:param extra_handler: (required) wrapper function itself.
:type extra_handler: ExtraHandlerFunction
:param whitelist: a list of services to only add this wrapper to.
:param blacklist: a list of services to not add this wrapper to.
:return: `None`
"""

def condition(name: str) -> bool:
return (whitelist is None or name in whitelist) and (blacklist is None or name not in blacklist)

if (
global_handler_type is GlobalExtraHandlerType.BEFORE_ALL
or global_handler_type is GlobalExtraHandlerType.AFTER_ALL
):
whitelist = ["pipeline"]
global_handler_type = (
GlobalExtraHandlerType.BEFORE
if global_handler_type is GlobalExtraHandlerType.BEFORE_ALL
else GlobalExtraHandlerType.AFTER
)

self.services_pipeline.add_extra_handler(global_handler_type, extra_handler, condition)

@property
def info_dict(self) -> dict:
"""
Property for retrieving info dictionary about this pipeline.
Returns info dict, containing most important component public fields as well as its type.
All complex or unserializable fields here are replaced with 'Instance of [type]'.
"""
return {
"type": type(self).__name__,
"messenger_interface": f"Instance of {type(self.messenger_interface).__name__}",
"context_storage": f"Instance of {type(self.context_storage).__name__}",
"services": [self.services_pipeline.info_dict],
}

async def _run_pipeline(
self, request: Message, ctx_id: Optional[Hashable] = None, update_ctx_misc: Optional[dict] = None
) -> Context:
Expand Down Expand Up @@ -329,12 +257,10 @@ async def _run_pipeline(
ctx.framework_data.slot_manager.set_root_slot(self.slots)

ctx.framework_data.pipeline = self
initialize_service_states(ctx, self.services_pipeline)

ctx.add_request(request)
result = await self.services_pipeline(ctx, self)

if asyncio.iscoroutine(result):
await result
await self.services_pipeline(ctx)

ctx.framework_data.service_states.clear()
ctx.framework_data.pipeline = None
Expand Down
2 changes: 1 addition & 1 deletion chatsky/core/script_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ async def wrapped_call(self, ctx: Context, *, info: str = ""):
logger.debug(f"Function {self.__class__.__name__} returned {result!r}. {info}")
return result
except Exception as exc:
logger.warning(f"An exception occurred in {self.__class__.__name__}. {info}", exc_info=exc)
logger.error(f"An exception occurred in {self.__class__.__name__}. {info}", exc_info=exc)
return exc

async def __call__(self, ctx: Context):
Expand Down
11 changes: 0 additions & 11 deletions chatsky/core/service/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,25 +5,14 @@
"""

from .component import PipelineComponent
from .conditions import (
always_start_condition,
service_successful_condition,
not_condition,
all_condition,
any_condition,
)
from .extra import BeforeHandler, AfterHandler
from .group import ServiceGroup
from .service import Service, to_service
from .types import (
ServiceRuntimeInfo,
ExtraHandlerRuntimeInfo,
GlobalExtraHandlerType,
ExtraHandlerType,
PipelineRunnerFunction,
ComponentExecutionState,
StartConditionCheckerFunction,
ExtraHandlerConditionFunction,
ExtraHandlerFunction,
ServiceFunction,
)
Loading

0 comments on commit 8ea8c7a

Please sign in to comment.