From 922f8a96be086a92e86f56f9cb826707cb92c3ca Mon Sep 17 00:00:00 2001 From: ruthenian8 Date: Fri, 14 Jul 2023 16:19:50 +0300 Subject: [PATCH 1/8] initial commit --- dff/context_storages/json.py | 23 ++++---- dff/context_storages/mongo.py | 2 +- dff/context_storages/redis.py | 2 +- dff/context_storages/sql.py | 2 +- dff/context_storages/ydb.py | 2 +- dff/messengers/telegram/message.py | 25 +++++---- dff/pipeline/types.py | 2 +- dff/script/__init__.py | 2 - dff/script/conditions/std_conditions.py | 26 ++++----- dff/script/core/context.py | 54 +++++++++---------- dff/script/core/message.py | 36 +++++++------ dff/script/core/normalization.py | 37 ++----------- dff/script/core/script.py | 48 ++++++++++++++--- dff/utils/testing/telegram.py | 2 +- tests/context_storages/test_dbs.py | 5 +- tests/script/core/test_context.py | 4 +- tests/script/core/test_normalization.py | 8 +-- .../3_load_testing_with_locust.py | 8 +-- .../web_api_interface/4_streamlit_chat.py | 6 ++- .../pipeline/4_groups_and_conditions_full.py | 3 +- .../script/core/6_context_serialization.py | 4 +- 21 files changed, 153 insertions(+), 148 deletions(-) diff --git a/dff/context_storages/json.py b/dff/context_storages/json.py index 52ea29560..4a56b7ff5 100644 --- a/dff/context_storages/json.py +++ b/dff/context_storages/json.py @@ -16,14 +16,15 @@ except ImportError: json_available = False -from pydantic import BaseModel, Extra, root_validator +from pydantic import BaseModel, model_validator from .database import DBContextStorage, threadsafe_method from dff.script import Context -class SerializableStorage(BaseModel, extra=Extra.allow): - @root_validator +class SerializableStorage(BaseModel, extra="allow"): + @model_validator(mode="before") + @classmethod def validate_any(cls, vals): for key, value in vals.items(): vals[key] = Context.cast(value) @@ -43,36 +44,36 @@ def __init__(self, path: str): @threadsafe_method async def len_async(self) -> int: - return len(self.storage.__dict__) + return len(self.storage.model_extra) @threadsafe_method async def set_item_async(self, key: Hashable, value: Context): - self.storage.__dict__.__setitem__(str(key), value) + self.storage.model_extra.__setitem__(str(key), value) await self._save() @threadsafe_method async def get_item_async(self, key: Hashable) -> Context: await self._load() - return Context.cast(self.storage.__dict__.__getitem__(str(key))) + return Context.cast(self.storage.model_extra.__getitem__(str(key))) @threadsafe_method async def del_item_async(self, key: Hashable): - self.storage.__dict__.__delitem__(str(key)) + self.storage.model_extra.__delitem__(str(key)) await self._save() @threadsafe_method async def contains_async(self, key: Hashable) -> bool: await self._load() - return self.storage.__dict__.__contains__(str(key)) + return self.storage.model_extra.__contains__(str(key)) @threadsafe_method async def clear_async(self): - self.storage.__dict__.clear() + self.storage.model_extra.clear() await self._save() async def _save(self): async with aiofiles.open(self.path, "w+", encoding="utf-8") as file_stream: - await file_stream.write(self.storage.json()) + await file_stream.write(self.storage.model_dump_json()) async def _load(self): if not await aiofiles.os.path.isfile(self.path) or (await aiofiles.os.stat(self.path)).st_size == 0: @@ -80,4 +81,4 @@ async def _load(self): await self._save() else: async with aiofiles.open(self.path, "r", encoding="utf-8") as file_stream: - self.storage = SerializableStorage.parse_raw(await file_stream.read()) + self.storage = SerializableStorage.model_validate_json(await file_stream.read()) diff --git a/dff/context_storages/mongo.py b/dff/context_storages/mongo.py index 0df2a8715..1afeeda84 100644 --- a/dff/context_storages/mongo.py +++ b/dff/context_storages/mongo.py @@ -60,7 +60,7 @@ def _adjust_key(key: Hashable) -> Dict[str, ObjectId]: async def set_item_async(self, key: Hashable, value: Context): new_key = self._adjust_key(key) value = value if isinstance(value, Context) else Context.cast(value) - document = json.loads(value.json()) + document = json.loads(value.model_dump_json()) document.update(new_key) await self.collection.replace_one(new_key, document, upsert=True) diff --git a/dff/context_storages/redis.py b/dff/context_storages/redis.py index d2014eb89..d506a1d47 100644 --- a/dff/context_storages/redis.py +++ b/dff/context_storages/redis.py @@ -49,7 +49,7 @@ async def contains_async(self, key: Hashable) -> bool: @threadsafe_method async def set_item_async(self, key: Hashable, value: Context): value = value if isinstance(value, Context) else Context.cast(value) - await self._redis.set(str(key), value.json()) + await self._redis.set(str(key), value.model_dump_json()) @threadsafe_method async def get_item_async(self, key: Hashable) -> Context: diff --git a/dff/context_storages/sql.py b/dff/context_storages/sql.py index 99f72a9bb..9bc5c7822 100644 --- a/dff/context_storages/sql.py +++ b/dff/context_storages/sql.py @@ -115,7 +115,7 @@ def __init__(self, path: str, table_name: str = "contexts", custom_driver: bool @threadsafe_method async def set_item_async(self, key: Hashable, value: Context): value = value if isinstance(value, Context) else Context.cast(value) - value = json.loads(value.json()) + value = json.loads(value.model_dump_json()) insert_stmt = insert(self.table).values(id=str(key), context=value) update_stmt = await self._get_update_stmt(insert_stmt) diff --git a/dff/context_storages/ydb.py b/dff/context_storages/ydb.py index c1899a001..f499592c6 100644 --- a/dff/context_storages/ydb.py +++ b/dff/context_storages/ydb.py @@ -74,7 +74,7 @@ async def callee(session): await session.transaction(ydb.SerializableReadWrite()).execute( prepared_query, - {"$queryId": str(key), "$queryContext": value.json()}, + {"$queryId": str(key), "$queryContext": value.model_dump_json()}, commit_tx=True, ) diff --git a/dff/messengers/telegram/message.py b/dff/messengers/telegram/message.py index 3b9db63d1..fc652a146 100644 --- a/dff/messengers/telegram/message.py +++ b/dff/messengers/telegram/message.py @@ -22,7 +22,8 @@ ChatJoinRequest, ) -from dff.script.core.message import Message, Location, Keyboard, DataModel, root_validator, ValidationError +from dff.script.core.message import Message, Location, Keyboard, DataModel +from pydantic import model_validator class TelegramUI(Keyboard): @@ -34,13 +35,14 @@ class TelegramUI(Keyboard): row_width: int = 3 """Limits the maximum number of buttons in a row.""" - @root_validator - def validate_buttons(cls, values): - if not values.get("is_inline"): - for button in values.get("buttons"): - if button.payload is not None or button.source is not None: - raise ValidationError(f"`payload` and `source` are only used for inline keyboards: {button}") - return values + @model_validator(mode="after") + def validate_buttons(self, _): + if not self.is_inline: + for button in self.buttons: + assert ( + button.payload is None and button.source is None + ), f"`payload` and `source` are only used for inline keyboards: {button}" + return self class _ClickButton(DataModel): @@ -66,9 +68,6 @@ class ParseMode(Enum): class TelegramMessage(Message): - class Config: - smart_union = True - ui: Optional[ Union[TelegramUI, RemoveKeyboard, ReplyKeyboardRemove, ReplyKeyboardMarkup, InlineKeyboardMarkup] ] = None @@ -97,9 +96,9 @@ class Config: def __eq__(self, other): if isinstance(other, Message): - for field in self.__fields__: + for field in self.model_fields: if field not in ("parse_mode", "update_id", "update", "update_type"): - if field not in other.__fields__: + if field not in other.model_fields: return False if self.__getattribute__(field) != other.__getattribute__(field): return False diff --git a/dff/pipeline/types.py b/dff/pipeline/types.py index 0a1e679fe..510976ad5 100644 --- a/dff/pipeline/types.py +++ b/dff/pipeline/types.py @@ -112,7 +112,7 @@ class ExtraHandlerType(str, Enum): class ServiceRuntimeInfo(BaseModel): name: str path: str - timeout: Optional[float] + timeout: Optional[float] = None asynchronous: bool execution_state: Dict[str, ComponentExecutionState] diff --git a/dff/script/__init__.py b/dff/script/__init__.py index 6c88b5461..74399f58a 100644 --- a/dff/script/__init__.py +++ b/dff/script/__init__.py @@ -15,10 +15,8 @@ from .core.normalization import ( normalize_label, normalize_condition, - normalize_transitions, normalize_response, normalize_processing, - normalize_script, ) from .core.script import Node, Script from .core.types import ( diff --git a/dff/script/conditions/std_conditions.py b/dff/script/conditions/std_conditions.py index dd0330bc4..0dddb7266 100644 --- a/dff/script/conditions/std_conditions.py +++ b/dff/script/conditions/std_conditions.py @@ -12,7 +12,7 @@ import logging import re -from pydantic import validate_arguments +from pydantic import validate_call from dff.pipeline import Pipeline from dff.script import NodeLabel2Type, Context, Message @@ -20,7 +20,7 @@ logger = logging.getLogger(__name__) -@validate_arguments +@validate_call def exact_match(match: Message, skip_none: bool = True, *args, **kwargs) -> Callable[..., bool]: """ Return function handler. This handler returns `True` only if the last user phrase @@ -35,11 +35,11 @@ def exact_match_condition_handler(ctx: Context, pipeline: Pipeline, *args, **kwa request = ctx.last_request if request is None: return False - for field in match.__fields__: + for field in match.model_fields: match_value = match.__getattribute__(field) if skip_none and match_value is None: continue - if field in request.__fields__.keys(): + if field in request.model_fields.keys(): if request.__getattribute__(field) != match.__getattribute__(field): return False else: @@ -49,7 +49,7 @@ def exact_match_condition_handler(ctx: Context, pipeline: Pipeline, *args, **kwa return exact_match_condition_handler -@validate_arguments +@validate_call def regexp( pattern: Union[str, Pattern], flags: Union[int, re.RegexFlag] = 0, *args, **kwargs ) -> Callable[[Context, Pipeline, Any, Any], bool]: @@ -75,7 +75,7 @@ def regexp_condition_handler(ctx: Context, pipeline: Pipeline, *args, **kwargs) return regexp_condition_handler -@validate_arguments +@validate_call def check_cond_seq(cond_seq: list): """ Check if the list consists only of Callables. @@ -97,7 +97,7 @@ def check_cond_seq(cond_seq: list): """ -@validate_arguments +@validate_call def aggregate( cond_seq: list, aggregate_func: Callable = _any, *args, **kwargs ) -> Callable[[Context, Pipeline, Any, Any], bool]: @@ -119,7 +119,7 @@ def aggregate_condition_handler(ctx: Context, pipeline: Pipeline, *args, **kwarg return aggregate_condition_handler -@validate_arguments +@validate_call def any(cond_seq: list, *args, **kwargs) -> Callable[[Context, Pipeline, Any, Any], bool]: """ Return function handler. This handler returns `True` @@ -135,7 +135,7 @@ def any_condition_handler(ctx: Context, pipeline: Pipeline, *args, **kwargs) -> return any_condition_handler -@validate_arguments +@validate_call def all(cond_seq: list, *args, **kwargs) -> Callable[[Context, Pipeline, Any, Any], bool]: """ Return function handler. This handler returns `True` only @@ -151,7 +151,7 @@ def all_condition_handler(ctx: Context, pipeline: Pipeline, *args, **kwargs) -> return all_condition_handler -@validate_arguments +@validate_call def negation(condition: Callable, *args, **kwargs) -> Callable[[Context, Pipeline, Any, Any], bool]: """ Return function handler. This handler returns negation of the :py:func:`~condition`: `False` @@ -166,7 +166,7 @@ def negation_condition_handler(ctx: Context, pipeline: Pipeline, *args, **kwargs return negation_condition_handler -@validate_arguments +@validate_call def has_last_labels( flow_labels: Optional[List[str]] = None, labels: Optional[List[NodeLabel2Type]] = None, @@ -198,7 +198,7 @@ def has_last_labels_condition_handler(ctx: Context, pipeline: Pipeline, *args, * return has_last_labels_condition_handler -@validate_arguments +@validate_call def true(*args, **kwargs) -> Callable[[Context, Pipeline, Any, Any], bool]: """ Return function handler. This handler always returns `True`. @@ -210,7 +210,7 @@ def true_handler(ctx: Context, pipeline: Pipeline, *args, **kwargs) -> bool: return true_handler -@validate_arguments +@validate_call def false(*args, **kwargs) -> Callable[[Context, Pipeline, Any, Any], bool]: """ Return function handler. This handler always returns `False`. diff --git a/dff/script/core/context.py b/dff/script/core/context.py index 17cf70f71..df552b955 100644 --- a/dff/script/core/context.py +++ b/dff/script/core/context.py @@ -21,7 +21,7 @@ from typing import Any, Optional, Union, Dict, List, Set -from pydantic import BaseModel, validate_arguments, Field, validator +from pydantic import ConfigDict, BaseModel, validate_call, Field, field_validator from .types import NodeLabel2Type, ModuleName from .message import Message @@ -30,19 +30,7 @@ Node = BaseModel -@validate_arguments -def sort_dict_keys(dictionary: dict) -> dict: - """ - Sorting the keys in the `dictionary`. This needs to be done after deserialization, - since the keys are deserialized in a random order. - - :param dictionary: Dictionary with unsorted keys. - :return: Dictionary with sorted keys. - """ - return {key: dictionary[key] for key in sorted(dictionary)} - - -@validate_arguments +@validate_call def get_last_index(dictionary: dict) -> int: """ Obtaining the last index from the `dictionary`. Functions returns `-1` if the `dict` is empty. @@ -59,11 +47,12 @@ class Context(BaseModel): A structure that is used to store data about the context of a dialog. """ - class Config: - property_set_methods = { + model_config = ConfigDict( + property_set_methods={ "last_response": "set_last_response", "last_request": "set_last_request", } + ) id: Union[UUID, int, str] = Field(default_factory=uuid4) """ @@ -121,10 +110,17 @@ class Config: - value - Temporary variable data. """ - # validators - _sort_labels = validator("labels", allow_reuse=True)(sort_dict_keys) - _sort_requests = validator("requests", allow_reuse=True)(sort_dict_keys) - _sort_responses = validator("responses", allow_reuse=True)(sort_dict_keys) + @field_validator("labels", "requests", "responses") + @classmethod + def sort_dict_keys(cls, dictionary: dict) -> dict: + """ + Sorting the keys in the `dictionary`. This needs to be done after deserialization, + since the keys are deserialized in a random order. + + :param dictionary: Dictionary with unsorted keys. + :return: Dictionary with sorted keys. + """ + return {key: dictionary[key] for key in sorted(dictionary)} @classmethod def cast(cls, ctx: Optional[Union["Context", dict, str]] = None, *args, **kwargs) -> "Context": @@ -144,16 +140,16 @@ def cast(cls, ctx: Optional[Union["Context", dict, str]] = None, *args, **kwargs if not ctx: ctx = Context(*args, **kwargs) elif isinstance(ctx, dict): - ctx = Context.parse_obj(ctx) + ctx = Context.model_validate(ctx) elif isinstance(ctx, str): - ctx = Context.parse_raw(ctx) + ctx = Context.model_validate_json(ctx) elif not issubclass(type(ctx), Context): raise ValueError( f"context expected as sub class of Context class or object of dict/str(json) type, but got {ctx}" ) return ctx - @validate_arguments + # @validate_call def add_request(self, request: Message): """ Adds to the context the next `request` corresponding to the next turn. @@ -164,7 +160,7 @@ def add_request(self, request: Message): last_index = get_last_index(self.requests) self.requests[last_index + 1] = request - @validate_arguments + # @validate_call def add_response(self, response: Message): """ Adds to the context the next `response` corresponding to the next turn. @@ -175,7 +171,7 @@ def add_response(self, response: Message): last_index = get_last_index(self.responses) self.responses[last_index + 1] = response - @validate_arguments + # @validate_call def add_label(self, label: NodeLabel2Type): """ Adds to the context the next :py:const:`label `, @@ -187,7 +183,7 @@ def add_label(self, label: NodeLabel2Type): last_index = get_last_index(self.labels) self.labels[last_index + 1] = label - @validate_arguments + # @validate_call def clear( self, hold_last_n_indices: int, @@ -282,7 +278,7 @@ def current_node(self) -> Optional[Node]: return node - @validate_arguments + # @validate_call def overwrite_current_node_in_processing(self, processed_node: Node): """ Overwrites the current node with a processed node. This method only works in processing functions. @@ -299,11 +295,11 @@ def overwrite_current_node_in_processing(self, processed_node: Node): ) def __setattr__(self, key, val): - method = self.__config__.property_set_methods.get(key, None) + method = self.model_config.get("property_set_methods", {}).get(key, None) if method is None: super().__setattr__(key, val) else: getattr(self, method)(val) -Context.update_forward_refs() +Context.model_rebuild() diff --git a/dff/script/core/message.py b/dff/script/core/message.py index c26a29db4..5b3a1a0e8 100644 --- a/dff/script/core/message.py +++ b/dff/script/core/message.py @@ -9,8 +9,8 @@ from pathlib import Path from urllib.request import urlopen -from pydantic import Extra, Field, ValidationError, FilePath, HttpUrl, BaseModel -from pydantic import validator, root_validator +from pydantic import field_validator, ConfigDict, Field, FilePath, HttpUrl, BaseModel +from pydantic import model_validator class Session(Enum): @@ -27,9 +27,7 @@ class DataModel(BaseModel): This class is a Pydantic BaseModel that serves as a base class for all DFF models. """ - class Config: - extra = Extra.allow - arbitrary_types_allowed = True + model_config = ConfigDict(extra="allow", arbitrary_types_allowed=True) class Command(DataModel): @@ -72,11 +70,11 @@ class Attachment(DataModel): def get_bytes(self) -> Optional[bytes]: if self.source is None: return None - if isinstance(self.source, HttpUrl): - with urlopen(self.source) as file: + if isinstance(self.source, Path): + with open(self.source, "rb") as file: return file.read() else: - with open(self.source, "rb") as file: + with urlopen(self.source.unicode_string()) as file: return file.read() def __eq__(self, other): @@ -88,13 +86,17 @@ def __eq__(self, other): return self.get_bytes() == other.get_bytes() return NotImplemented - @root_validator - def validate_source_or_id(cls, values): - if bool(values.get("source")) == bool(values.get("id")): - raise ValidationError("Attachment type requires exactly one parameter, `source` or `id`, to be set.") + @model_validator(mode="before") + @classmethod + def validate_source_or_id(cls, values: dict): + assert isinstance(values, dict) + assert bool(values.get("source")) != bool( + values.get("id") + ), "Attachment type requires exactly one parameter, `source` or `id`, to be set." return values - @validator("source") + @field_validator("source", mode="before") + @classmethod def validate_source(cls, value): if isinstance(value, Path): return Path(value) @@ -175,7 +177,7 @@ class Keyboard(DataModel): that can be used for a chatbot or messaging application. """ - buttons: List[Button] = Field(default_factory=list, min_items=1) + buttons: List[Button] = Field(default_factory=list, min_length=1) def __eq__(self, other): if isinstance(other, Keyboard): @@ -201,8 +203,8 @@ class level variables to store message information. def __eq__(self, other): if isinstance(other, Message): - for field in self.__fields__: - if field not in other.__fields__: + for field in self.model_fields: + if field not in other.model_fields: return False if self.__getattribute__(field) != other.__getattribute__(field): return False @@ -210,7 +212,7 @@ def __eq__(self, other): return NotImplemented def __repr__(self) -> str: - return " ".join([f"{key}='{value}'" for key, value in self.dict(exclude_none=True).items()]) + return " ".join([f"{key}='{value}'" for key, value in self.model_dump(exclude_none=True).items()]) class MultiMessage(Message): diff --git a/dff/script/core/normalization.py b/dff/script/core/normalization.py index 9aeeaa79c..a0f91407f 100644 --- a/dff/script/core/normalization.py +++ b/dff/script/core/normalization.py @@ -14,7 +14,7 @@ from .types import NodeLabel3Type, NodeLabelType, ConditionType, LabelType from .message import Message -from pydantic import validate_arguments +from pydantic import validate_call logger = logging.getLogger(__name__) @@ -81,21 +81,7 @@ def callable_condition_handler(ctx: Context, pipeline: Pipeline, *args, **kwargs return callable_condition_handler -@validate_arguments -def normalize_transitions( - transitions: Dict[NodeLabelType, ConditionType] -) -> Dict[Union[Callable, NodeLabel3Type], Callable]: - """ - The function which is used to normalize transitions and returns normalized dict. - - :param transitions: Transitions to normalize. - :return: Transitions with normalized label and condition. - """ - transitions = {normalize_label(label): normalize_condition(condition) for label, condition in transitions.items()} - return transitions - - -@validate_arguments +@validate_call def normalize_response(response: Optional[Union[Message, Callable[..., Message]]]) -> Callable[..., Message]: """ This function is used to normalize response, if response Callable, it is returned, otherwise @@ -120,7 +106,7 @@ def response_handler(ctx: Context, pipeline: Pipeline, *args, **kwargs): return response_handler -@validate_arguments +@validate_call def normalize_processing(processing: Dict[Any, Callable]) -> Callable: """ This function is used to normalize processing. @@ -144,20 +130,3 @@ def processing_handler(ctx: Context, pipeline: Pipeline, *args, **kwargs) -> Con return ctx return processing_handler - - -@validate_arguments -def normalize_script(script: Dict[LabelType, Any]) -> Dict[LabelType, Dict[LabelType, Dict[str, Any]]]: - """ - This function normalizes :py:class:`.Script`: it returns dict where the GLOBAL node is moved - into the flow with the GLOBAL name. The function returns the structure - - `{GLOBAL: {...NODE...}, ...}` -> `{GLOBAL: {GLOBAL: {...NODE...}}, ...}`. - - :param script: :py:class:`.Script` that describes the dialog scenario. - :return: Normalized :py:class:`.Script`. - """ - if isinstance(script, dict): - if Keywords.GLOBAL in script and all([isinstance(item, Keywords) for item in script[Keywords.GLOBAL].keys()]): - script[Keywords.GLOBAL] = {Keywords.GLOBAL: script[Keywords.GLOBAL]} - return script diff --git a/dff/script/core/script.py b/dff/script/core/script.py index 0499afdce..7259fe8b7 100644 --- a/dff/script/core/script.py +++ b/dff/script/core/script.py @@ -10,11 +10,12 @@ import logging from typing import Callable, Optional, Any, Dict, Union -from pydantic import BaseModel, validator, Extra +from pydantic import BaseModel, field_validator -from .types import LabelType, NodeLabelType, ConditionType +from .types import LabelType, NodeLabelType, ConditionType, NodeLabel3Type from .message import Message -from .normalization import normalize_response, normalize_processing, normalize_transitions, normalize_script +from .keywords import Keywords +from .normalization import normalize_response, normalize_processing, normalize_condition, normalize_label, validate_call from typing import ForwardRef logger = logging.getLogger(__name__) @@ -24,7 +25,7 @@ Context = ForwardRef("Context") -class Node(BaseModel, extra=Extra.forbid): +class Node(BaseModel, extra="forbid"): """ The class for the `Node` object. """ @@ -35,7 +36,22 @@ class Node(BaseModel, extra=Extra.forbid): pre_response_processing: Dict[Any, Callable] = {} misc: dict = {} - _normalize_transitions = validator("transitions", allow_reuse=True)(normalize_transitions) + @field_validator("transitions", mode="before") + @classmethod + @validate_call + def normalize_transitions( + _, transitions: Dict[NodeLabelType, ConditionType] + ) -> Dict[Union[Callable, NodeLabel3Type], Callable]: + """ + The function which is used to normalize transitions and returns normalized dict. + + :param transitions: Transitions to normalize. + :return: Transitions with normalized label and condition. + """ + transitions = { + normalize_label(label): normalize_condition(condition) for label, condition in transitions.items() + } + return transitions def run_response(self, ctx: Context, pipeline: Pipeline, *args, **kwargs) -> Context: """ @@ -68,14 +84,32 @@ def run_processing( return processing(ctx, pipeline, *args, **kwargs) -class Script(BaseModel, extra=Extra.forbid): +class Script(BaseModel, extra="forbid"): """ The class for the `Script` object. """ script: Dict[LabelType, Dict[LabelType, Node]] - _normalize_script = validator("script", allow_reuse=True, pre=True)(normalize_script) + @field_validator("script", mode="before") + @classmethod + @validate_call + def normalize_script(cls, script: Dict[LabelType, Any]) -> Dict[LabelType, Dict[LabelType, Dict[str, Any]]]: + """ + This function normalizes :py:class:`.Script`: it returns dict where the GLOBAL node is moved + into the flow with the GLOBAL name. The function returns the structure + + `{GLOBAL: {...NODE...}, ...}` -> `{GLOBAL: {GLOBAL: {...NODE...}}, ...}`. + + :param script: :py:class:`.Script` that describes the dialog scenario. + :return: Normalized :py:class:`.Script`. + """ + if isinstance(script, dict): + if Keywords.GLOBAL in script and all( + [isinstance(item, Keywords) for item in script[Keywords.GLOBAL].keys()] + ): + script[Keywords.GLOBAL] = {Keywords.GLOBAL: script[Keywords.GLOBAL]} + return script def __getitem__(self, key): return self.script[key] diff --git a/dff/utils/testing/telegram.py b/dff/utils/testing/telegram.py index 46df9212d..8fe110be4 100644 --- a/dff/utils/testing/telegram.py +++ b/dff/utils/testing/telegram.py @@ -275,4 +275,4 @@ async def null(): last_message = bot_messages[0] logging.info("Got responses") result = await self.parse_responses(bot_messages, file_download_destination) - assert result == TelegramMessage.parse_obj(response) + assert result == TelegramMessage.model_validate(response) diff --git a/tests/context_storages/test_dbs.py b/tests/context_storages/test_dbs.py index 3d48ff2af..bd8624bc0 100644 --- a/tests/context_storages/test_dbs.py +++ b/tests/context_storages/test_dbs.py @@ -75,7 +75,10 @@ def generic_test(db, testing_context, context_id): # test read operations new_ctx = db[context_id] assert isinstance(new_ctx, Context) - assert {**new_ctx.dict(), "id": str(new_ctx.id)} == {**testing_context.dict(), "id": str(testing_context.id)} + assert {**new_ctx.model_dump(), "id": str(new_ctx.id)} == { + **testing_context.model_dump(), + "id": str(testing_context.id), + } # test delete operations del db[context_id] assert context_id not in db diff --git a/tests/script/core/test_context.py b/tests/script/core/test_context.py index f0f3f3917..bf7d0c4d6 100644 --- a/tests/script/core/test_context.py +++ b/tests/script/core/test_context.py @@ -17,7 +17,7 @@ def test_context(): ctx.labels = shuffle_dict_keys(ctx.labels) ctx.requests = shuffle_dict_keys(ctx.requests) ctx.responses = shuffle_dict_keys(ctx.responses) - ctx = Context.cast(ctx.json()) + ctx = Context.cast(ctx.model_dump_json()) ctx.misc[123] = 312 ctx.clear(5, ["requests", "responses", "misc", "labels", "framework_states"]) ctx.misc[1001] = "11111" @@ -52,7 +52,7 @@ def test_context(): assert ctx.misc == {1001: "11111"} assert ctx.current_node is None ctx.overwrite_current_node_in_processing(Node(**{"response": Message(text="text")})) - ctx.json() + ctx.model_dump_json() try: Context.cast(123) diff --git a/tests/script/core/test_normalization.py b/tests/script/core/test_normalization.py index e2df4b526..bbe6309c0 100644 --- a/tests/script/core/test_normalization.py +++ b/tests/script/core/test_normalization.py @@ -10,6 +10,8 @@ PRE_RESPONSE_PROCESSING, PRE_TRANSITIONS_PROCESSING, Context, + Script, + Node, NodeLabel3Type, Message, ) @@ -19,10 +21,8 @@ from dff.script import ( normalize_condition, normalize_label, - normalize_script, normalize_processing, normalize_response, - normalize_transitions, ) @@ -80,7 +80,7 @@ def false_condition_func(ctx: Context, pipeline: Pipeline, *args, **kwargs) -> b def test_normalize_transitions(): - trans = normalize_transitions({("flow", "node", 1.0): std_func}) + trans = Node.normalize_transitions({("flow", "node", 1.0): std_func}) assert list(trans)[0] == ("flow", "node", 1.0) assert callable(list(trans.values())[0]) @@ -148,7 +148,7 @@ def test_normalize_script(): MISC.name.lower(): {"key": "val"}, } script = {GLOBAL: node_template.copy(), "flow": {"node": node_template.copy()}} - script = normalize_script(script) + script = Script.normalize_script(script) assert isinstance(script, dict) assert script[GLOBAL][GLOBAL] == node_template_gold assert script["flow"]["node"] == node_template_gold diff --git a/tutorials/messengers/web_api_interface/3_load_testing_with_locust.py b/tutorials/messengers/web_api_interface/3_load_testing_with_locust.py index 9cde69fe3..57c9b18ea 100644 --- a/tutorials/messengers/web_api_interface/3_load_testing_with_locust.py +++ b/tutorials/messengers/web_api_interface/3_load_testing_with_locust.py @@ -92,7 +92,7 @@ def check_happy_path(self, happy_path): data=request.json(), catch_response=True, ) as candidate_response: - text_response = Message.parse_obj(candidate_response.json().get("response")) + text_response = Message.model_validate(candidate_response.json().get("response")) if response is not None: if callable(response): @@ -101,7 +101,7 @@ def check_happy_path(self, happy_path): candidate_response.failure(error_message) elif text_response != response: candidate_response.failure( - f"Expected: {response.json()}\nGot: {text_response.json()}" + f"Expected: {response.model_dump_json()}\nGot: {text_response.model_dump_json()}" ) time.sleep(self.wait_time()) @@ -114,9 +114,9 @@ def dialog_1(self): def dialog_2(self): def check_first_message(msg: Message) -> str | None: if msg.text is None: - return f"Message does not contain text: {msg.json()}" + return f"Message does not contain text: {msg.model_dump_json()}" if "Hi" not in msg.text: - return f'"Hi" is not in the response message: {msg.json()}' + return f'"Hi" is not in the response message: {msg.model_dump_json()}' return None self.check_happy_path( diff --git a/tutorials/messengers/web_api_interface/4_streamlit_chat.py b/tutorials/messengers/web_api_interface/4_streamlit_chat.py index 636934f25..ae46f6e39 100644 --- a/tutorials/messengers/web_api_interface/4_streamlit_chat.py +++ b/tutorials/messengers/web_api_interface/4_streamlit_chat.py @@ -115,10 +115,12 @@ def send_and_receive(): st.session_state["user_requests"].append(user_request) - bot_response = query(Message(text=user_request).dict(), user_id=st.session_state["user_id"]) + bot_response = query( + Message(text=user_request).model_dump(), user_id=st.session_state["user_id"] + ) bot_response.raise_for_status() - bot_message = Message.parse_obj(bot_response.json()["response"]).text + bot_message = Message.model_validate(bot_response.json()["response"]).text # # Implementation without using Message: # bot_response = query({"text": user_request}, user_id=st.session_state["user_id"]) diff --git a/tutorials/pipeline/4_groups_and_conditions_full.py b/tutorials/pipeline/4_groups_and_conditions_full.py index 12d69bed1..585471bdd 100644 --- a/tutorials/pipeline/4_groups_and_conditions_full.py +++ b/tutorials/pipeline/4_groups_and_conditions_full.py @@ -150,7 +150,8 @@ def never_running_service(_, __, info: ServiceRuntimeInfo): def runtime_info_printing_service(_, __, info: ServiceRuntimeInfo): logger.info( - f"Service '{info.name}' runtime execution info:" f"{info.json(indent=4, default=str)}" + f"Service '{info.name}' runtime execution info:" + f"{info.model_dump_json(indent=4, default=str)}" ) diff --git a/tutorials/script/core/6_context_serialization.py b/tutorials/script/core/6_context_serialization.py index 87a7fcca0..9c9f36539 100644 --- a/tutorials/script/core/6_context_serialization.py +++ b/tutorials/script/core/6_context_serialization.py @@ -59,13 +59,13 @@ def response_handler(ctx: Context, _: Pipeline, *args, **kwargs) -> Message: # %% def process_response(ctx: Context): - ctx_json = ctx.json() + ctx_json = ctx.model_dump_json() if isinstance(ctx_json, str): logging.info("context serialized to json str") else: raise Exception(f"ctx={ctx_json} has to be serialized to json string") - ctx_dict = ctx.dict() + ctx_dict = ctx.model_dump() if isinstance(ctx_dict, dict): logging.info("context serialized to dict") else: From cbce46d8b21f6669b8b1803f655f110dc97872c6 Mon Sep 17 00:00:00 2001 From: ruthenian8 Date: Fri, 14 Jul 2023 16:26:54 +0300 Subject: [PATCH 2/8] update setup.py --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 621804f78..78e44780c 100644 --- a/setup.py +++ b/setup.py @@ -27,7 +27,7 @@ def merge_req_lists(*req_lists: List[str]) -> List[str]: core = [ - "pydantic<2.0", + "pydantic>=2.0,<3.0", "nest-asyncio", "typing-extensions", ] From 9a58516d581e140ed2fc7097dbb0551fee946e3e Mon Sep 17 00:00:00 2001 From: ruthenian8 Date: Fri, 14 Jul 2023 18:24:19 +0300 Subject: [PATCH 3/8] Uncomment validate_call for context methods; avoid pydantic errors by using the 'singledispatchmethod' decorator --- dff/script/core/context.py | 22 ++++++++++++++++------ 1 file changed, 16 insertions(+), 6 deletions(-) diff --git a/dff/script/core/context.py b/dff/script/core/context.py index df552b955..f41a7be44 100644 --- a/dff/script/core/context.py +++ b/dff/script/core/context.py @@ -21,9 +21,10 @@ from typing import Any, Optional, Union, Dict, List, Set -from pydantic import ConfigDict, BaseModel, validate_call, Field, field_validator +from pydantic import ConfigDict, BaseModel, Field, field_validator, validate_call from .types import NodeLabel2Type, ModuleName from .message import Message +from functools import singledispatchmethod logger = logging.getLogger(__name__) @@ -149,7 +150,8 @@ def cast(cls, ctx: Optional[Union["Context", dict, str]] = None, *args, **kwargs ) return ctx - # @validate_call + @singledispatchmethod + @validate_call def add_request(self, request: Message): """ Adds to the context the next `request` corresponding to the next turn. @@ -160,7 +162,8 @@ def add_request(self, request: Message): last_index = get_last_index(self.requests) self.requests[last_index + 1] = request - # @validate_call + @singledispatchmethod + @validate_call def add_response(self, response: Message): """ Adds to the context the next `response` corresponding to the next turn. @@ -171,7 +174,8 @@ def add_response(self, response: Message): last_index = get_last_index(self.responses) self.responses[last_index + 1] = response - # @validate_call + @singledispatchmethod + @validate_call def add_label(self, label: NodeLabel2Type): """ Adds to the context the next :py:const:`label `, @@ -183,7 +187,8 @@ def add_label(self, label: NodeLabel2Type): last_index = get_last_index(self.labels) self.labels[last_index + 1] = label - # @validate_call + @singledispatchmethod + @validate_call def clear( self, hold_last_n_indices: int, @@ -232,6 +237,8 @@ def last_response(self) -> Optional[Message]: last_index = get_last_index(self.responses) return self.responses.get(last_index) + @singledispatchmethod + @validate_call def set_last_response(self, response: Optional[Message]): """ Sets the last `response` of the current :py:class:`~dff.core.engine.core.context.Context`. @@ -249,6 +256,8 @@ def last_request(self) -> Optional[Message]: last_index = get_last_index(self.requests) return self.requests.get(last_index) + @singledispatchmethod + @validate_call def set_last_request(self, request: Optional[Message]): """ Sets the last `request` of the current :py:class:`~dff.core.engine.core.context.Context`. @@ -278,7 +287,8 @@ def current_node(self) -> Optional[Node]: return node - # @validate_call + @singledispatchmethod + @validate_call def overwrite_current_node_in_processing(self, processed_node: Node): """ Overwrites the current node with a processed node. This method only works in processing functions. From b8760f94e4d879cd390a26d6361598ae28a77af0 Mon Sep 17 00:00:00 2001 From: ruthenian8 Date: Mon, 17 Jul 2023 11:53:57 +0300 Subject: [PATCH 4/8] update version requirements to pydantic 2.0.3 --- dff/script/core/context.py | 8 -------- setup.py | 2 +- 2 files changed, 1 insertion(+), 9 deletions(-) diff --git a/dff/script/core/context.py b/dff/script/core/context.py index f41a7be44..0fe703b1e 100644 --- a/dff/script/core/context.py +++ b/dff/script/core/context.py @@ -24,7 +24,6 @@ from pydantic import ConfigDict, BaseModel, Field, field_validator, validate_call from .types import NodeLabel2Type, ModuleName from .message import Message -from functools import singledispatchmethod logger = logging.getLogger(__name__) @@ -150,7 +149,6 @@ def cast(cls, ctx: Optional[Union["Context", dict, str]] = None, *args, **kwargs ) return ctx - @singledispatchmethod @validate_call def add_request(self, request: Message): """ @@ -162,7 +160,6 @@ def add_request(self, request: Message): last_index = get_last_index(self.requests) self.requests[last_index + 1] = request - @singledispatchmethod @validate_call def add_response(self, response: Message): """ @@ -174,7 +171,6 @@ def add_response(self, response: Message): last_index = get_last_index(self.responses) self.responses[last_index + 1] = response - @singledispatchmethod @validate_call def add_label(self, label: NodeLabel2Type): """ @@ -187,7 +183,6 @@ def add_label(self, label: NodeLabel2Type): last_index = get_last_index(self.labels) self.labels[last_index + 1] = label - @singledispatchmethod @validate_call def clear( self, @@ -237,7 +232,6 @@ def last_response(self) -> Optional[Message]: last_index = get_last_index(self.responses) return self.responses.get(last_index) - @singledispatchmethod @validate_call def set_last_response(self, response: Optional[Message]): """ @@ -256,7 +250,6 @@ def last_request(self) -> Optional[Message]: last_index = get_last_index(self.requests) return self.requests.get(last_index) - @singledispatchmethod @validate_call def set_last_request(self, request: Optional[Message]): """ @@ -287,7 +280,6 @@ def current_node(self) -> Optional[Node]: return node - @singledispatchmethod @validate_call def overwrite_current_node_in_processing(self, processed_node: Node): """ diff --git a/setup.py b/setup.py index 78e44780c..dd3e6eff5 100644 --- a/setup.py +++ b/setup.py @@ -27,7 +27,7 @@ def merge_req_lists(*req_lists: List[str]) -> List[str]: core = [ - "pydantic>=2.0,<3.0", + "pydantic>=2.0.3,<3.0", "nest-asyncio", "typing-extensions", ] From d2371953aa2fb9a4c7bfea7532ad36b21ced216e Mon Sep 17 00:00:00 2001 From: ruthenian8 Date: Mon, 17 Jul 2023 14:40:23 +0300 Subject: [PATCH 5/8] Comment validate_call temporarily --- dff/script/core/context.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/dff/script/core/context.py b/dff/script/core/context.py index 0fe703b1e..38d60c99c 100644 --- a/dff/script/core/context.py +++ b/dff/script/core/context.py @@ -30,7 +30,7 @@ Node = BaseModel -@validate_call +# @validate_call def get_last_index(dictionary: dict) -> int: """ Obtaining the last index from the `dictionary`. Functions returns `-1` if the `dict` is empty. @@ -149,7 +149,7 @@ def cast(cls, ctx: Optional[Union["Context", dict, str]] = None, *args, **kwargs ) return ctx - @validate_call + # @validate_call def add_request(self, request: Message): """ Adds to the context the next `request` corresponding to the next turn. @@ -160,7 +160,7 @@ def add_request(self, request: Message): last_index = get_last_index(self.requests) self.requests[last_index + 1] = request - @validate_call + # @validate_call def add_response(self, response: Message): """ Adds to the context the next `response` corresponding to the next turn. @@ -171,7 +171,7 @@ def add_response(self, response: Message): last_index = get_last_index(self.responses) self.responses[last_index + 1] = response - @validate_call + # @validate_call def add_label(self, label: NodeLabel2Type): """ Adds to the context the next :py:const:`label `, @@ -183,7 +183,7 @@ def add_label(self, label: NodeLabel2Type): last_index = get_last_index(self.labels) self.labels[last_index + 1] = label - @validate_call + # @validate_call def clear( self, hold_last_n_indices: int, @@ -232,7 +232,7 @@ def last_response(self) -> Optional[Message]: last_index = get_last_index(self.responses) return self.responses.get(last_index) - @validate_call + # @validate_call def set_last_response(self, response: Optional[Message]): """ Sets the last `response` of the current :py:class:`~dff.core.engine.core.context.Context`. @@ -250,7 +250,7 @@ def last_request(self) -> Optional[Message]: last_index = get_last_index(self.requests) return self.requests.get(last_index) - @validate_call + # @validate_call def set_last_request(self, request: Optional[Message]): """ Sets the last `request` of the current :py:class:`~dff.core.engine.core.context.Context`. @@ -280,7 +280,7 @@ def current_node(self) -> Optional[Node]: return node - @validate_call + # @validate_call def overwrite_current_node_in_processing(self, processed_node: Node): """ Overwrites the current node with a processed node. This method only works in processing functions. From ae9b3e5ae275416ded64199ee4cd9ca24212512a Mon Sep 17 00:00:00 2001 From: ruthenian8 Date: Mon, 17 Jul 2023 14:46:28 +0300 Subject: [PATCH 6/8] bypass lint problem --- dff/script/core/context.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/dff/script/core/context.py b/dff/script/core/context.py index 38d60c99c..b30135e7b 100644 --- a/dff/script/core/context.py +++ b/dff/script/core/context.py @@ -21,7 +21,8 @@ from typing import Any, Optional, Union, Dict, List, Set -from pydantic import ConfigDict, BaseModel, Field, field_validator, validate_call +from pydantic import ConfigDict, BaseModel, Field, field_validator +from pydantic import validate_call # noqa: F401 from .types import NodeLabel2Type, ModuleName from .message import Message From f67d2420b01156ebc399558b96fe817fc64ecd77 Mon Sep 17 00:00:00 2001 From: ruthenian8 Date: Mon, 17 Jul 2023 15:34:16 +0300 Subject: [PATCH 7/8] remove commented references to validate_call --- dff/script/core/context.py | 20 +++++++------------- 1 file changed, 7 insertions(+), 13 deletions(-) diff --git a/dff/script/core/context.py b/dff/script/core/context.py index b30135e7b..2b5b65e82 100644 --- a/dff/script/core/context.py +++ b/dff/script/core/context.py @@ -22,7 +22,6 @@ from typing import Any, Optional, Union, Dict, List, Set from pydantic import ConfigDict, BaseModel, Field, field_validator -from pydantic import validate_call # noqa: F401 from .types import NodeLabel2Type, ModuleName from .message import Message @@ -150,7 +149,6 @@ def cast(cls, ctx: Optional[Union["Context", dict, str]] = None, *args, **kwargs ) return ctx - # @validate_call def add_request(self, request: Message): """ Adds to the context the next `request` corresponding to the next turn. @@ -158,10 +156,10 @@ def add_request(self, request: Message): :param request: `request` to be added to the context. """ + request_message = Message.model_validate(request) last_index = get_last_index(self.requests) - self.requests[last_index + 1] = request + self.requests[last_index + 1] = request_message - # @validate_call def add_response(self, response: Message): """ Adds to the context the next `response` corresponding to the next turn. @@ -169,10 +167,10 @@ def add_response(self, response: Message): :param response: `response` to be added to the context. """ + response_message = Message.model_validate(response) last_index = get_last_index(self.responses) - self.responses[last_index + 1] = response + self.responses[last_index + 1] = response_message - # @validate_call def add_label(self, label: NodeLabel2Type): """ Adds to the context the next :py:const:`label `, @@ -184,7 +182,6 @@ def add_label(self, label: NodeLabel2Type): last_index = get_last_index(self.labels) self.labels[last_index + 1] = label - # @validate_call def clear( self, hold_last_n_indices: int, @@ -233,14 +230,13 @@ def last_response(self) -> Optional[Message]: last_index = get_last_index(self.responses) return self.responses.get(last_index) - # @validate_call def set_last_response(self, response: Optional[Message]): """ Sets the last `response` of the current :py:class:`~dff.core.engine.core.context.Context`. Required for use with various response wrappers. """ last_index = get_last_index(self.responses) - self.responses[last_index] = Message() if response is None else response + self.responses[last_index] = Message() if response is None else Message.model_validate(response) @property def last_request(self) -> Optional[Message]: @@ -251,14 +247,13 @@ def last_request(self) -> Optional[Message]: last_index = get_last_index(self.requests) return self.requests.get(last_index) - # @validate_call def set_last_request(self, request: Optional[Message]): """ Sets the last `request` of the current :py:class:`~dff.core.engine.core.context.Context`. Required for use with various request wrappers. """ last_index = get_last_index(self.requests) - self.requests[last_index] = Message() if request is None else request + self.requests[last_index] = Message() if request is None else Message.model_validate(request) @property def current_node(self) -> Optional[Node]: @@ -281,7 +276,6 @@ def current_node(self) -> Optional[Node]: return node - # @validate_call def overwrite_current_node_in_processing(self, processed_node: Node): """ Overwrites the current node with a processed node. This method only works in processing functions. @@ -290,7 +284,7 @@ def overwrite_current_node_in_processing(self, processed_node: Node): """ is_processing = self.framework_states.get("actor", {}).get("processed_node") if is_processing: - self.framework_states["actor"]["processed_node"] = processed_node + self.framework_states["actor"]["processed_node"] = Node.model_validate(processed_node) else: logger.warning( f"The `{self.overwrite_current_node_in_processing.__name__}` " From 45010ee81141a702f99bb2e9679b98330f2ab3c5 Mon Sep 17 00:00:00 2001 From: ruthenian8 Date: Wed, 26 Jul 2023 11:37:50 +0300 Subject: [PATCH 8/8] Update model configurations, update property setters for Context --- dff/messengers/telegram/message.py | 5 ++--- dff/pipeline/types.py | 2 +- dff/script/core/context.py | 23 +++++------------------ dff/script/core/message.py | 15 +++++++-------- dff/script/core/script.py | 2 +- 5 files changed, 16 insertions(+), 31 deletions(-) diff --git a/dff/messengers/telegram/message.py b/dff/messengers/telegram/message.py index fc652a146..bc47c4f21 100644 --- a/dff/messengers/telegram/message.py +++ b/dff/messengers/telegram/message.py @@ -39,9 +39,8 @@ class TelegramUI(Keyboard): def validate_buttons(self, _): if not self.is_inline: for button in self.buttons: - assert ( - button.payload is None and button.source is None - ), f"`payload` and `source` are only used for inline keyboards: {button}" + if button.payload is not None or button.source is not None: + raise AssertionError(f"`payload` and `source` are only used for inline keyboards: {button}") return self diff --git a/dff/pipeline/types.py b/dff/pipeline/types.py index 510976ad5..0a1e679fe 100644 --- a/dff/pipeline/types.py +++ b/dff/pipeline/types.py @@ -112,7 +112,7 @@ class ExtraHandlerType(str, Enum): class ServiceRuntimeInfo(BaseModel): name: str path: str - timeout: Optional[float] = None + timeout: Optional[float] asynchronous: bool execution_state: Dict[str, ComponentExecutionState] diff --git a/dff/script/core/context.py b/dff/script/core/context.py index 2b5b65e82..6658c346f 100644 --- a/dff/script/core/context.py +++ b/dff/script/core/context.py @@ -21,7 +21,7 @@ from typing import Any, Optional, Union, Dict, List, Set -from pydantic import ConfigDict, BaseModel, Field, field_validator +from pydantic import BaseModel, Field, field_validator from .types import NodeLabel2Type, ModuleName from .message import Message @@ -30,7 +30,6 @@ Node = BaseModel -# @validate_call def get_last_index(dictionary: dict) -> int: """ Obtaining the last index from the `dictionary`. Functions returns `-1` if the `dict` is empty. @@ -47,13 +46,6 @@ class Context(BaseModel): A structure that is used to store data about the context of a dialog. """ - model_config = ConfigDict( - property_set_methods={ - "last_response": "set_last_response", - "last_request": "set_last_request", - } - ) - id: Union[UUID, int, str] = Field(default_factory=uuid4) """ `id` is the unique context identifier. By default, randomly generated using `uuid4` `id` is used. @@ -230,7 +222,8 @@ def last_response(self) -> Optional[Message]: last_index = get_last_index(self.responses) return self.responses.get(last_index) - def set_last_response(self, response: Optional[Message]): + @last_response.setter + def last_response(self, response: Optional[Message]): """ Sets the last `response` of the current :py:class:`~dff.core.engine.core.context.Context`. Required for use with various response wrappers. @@ -247,7 +240,8 @@ def last_request(self) -> Optional[Message]: last_index = get_last_index(self.requests) return self.requests.get(last_index) - def set_last_request(self, request: Optional[Message]): + @last_request.setter + def last_request(self, request: Optional[Message]): """ Sets the last `request` of the current :py:class:`~dff.core.engine.core.context.Context`. Required for use with various request wrappers. @@ -291,12 +285,5 @@ def overwrite_current_node_in_processing(self, processed_node: Node): "function can only be run during processing functions." ) - def __setattr__(self, key, val): - method = self.model_config.get("property_set_methods", {}).get(key, None) - if method is None: - super().__setattr__(key, val) - else: - getattr(self, method)(val) - Context.model_rebuild() diff --git a/dff/script/core/message.py b/dff/script/core/message.py index 5b3a1a0e8..05f9974f9 100644 --- a/dff/script/core/message.py +++ b/dff/script/core/message.py @@ -9,8 +9,7 @@ from pathlib import Path from urllib.request import urlopen -from pydantic import field_validator, ConfigDict, Field, FilePath, HttpUrl, BaseModel -from pydantic import model_validator +from pydantic import field_validator, Field, FilePath, HttpUrl, BaseModel, model_validator class Session(Enum): @@ -22,12 +21,12 @@ class Session(Enum): FINISHED = auto() -class DataModel(BaseModel): +class DataModel(BaseModel, extra="allow", arbitrary_types_allowed=True): """ This class is a Pydantic BaseModel that serves as a base class for all DFF models. """ - model_config = ConfigDict(extra="allow", arbitrary_types_allowed=True) + ... class Command(DataModel): @@ -89,10 +88,10 @@ def __eq__(self, other): @model_validator(mode="before") @classmethod def validate_source_or_id(cls, values: dict): - assert isinstance(values, dict) - assert bool(values.get("source")) != bool( - values.get("id") - ), "Attachment type requires exactly one parameter, `source` or `id`, to be set." + if not isinstance(values, dict): + raise AssertionError(f"Invalid constructor parameters: {str(values)}") + if bool(values.get("source")) == bool(values.get("id")): + raise AssertionError("Attachment type requires exactly one parameter, `source` or `id`, to be set.") return values @field_validator("source", mode="before") diff --git a/dff/script/core/script.py b/dff/script/core/script.py index 7259fe8b7..31fc7c9a5 100644 --- a/dff/script/core/script.py +++ b/dff/script/core/script.py @@ -40,7 +40,7 @@ class Node(BaseModel, extra="forbid"): @classmethod @validate_call def normalize_transitions( - _, transitions: Dict[NodeLabelType, ConditionType] + cls, transitions: Dict[NodeLabelType, ConditionType] ) -> Dict[Union[Callable, NodeLabel3Type], Callable]: """ The function which is used to normalize transitions and returns normalized dict.