Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

pydantic2.0 #168

Merged
merged 9 commits into from
Aug 3, 2023
Merged
23 changes: 12 additions & 11 deletions dff/context_storages/json.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,15 @@
except ImportError:
json_available = False

from pydantic import BaseModel, Extra, root_validator
from pydantic import BaseModel, model_validator

from .database import DBContextStorage, threadsafe_method
from dff.script import Context


class SerializableStorage(BaseModel, extra=Extra.allow):
@root_validator
class SerializableStorage(BaseModel, extra="allow"):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not directly related to pydantic 2.0, but why is this storage implemented using this class?
It seems the only use is to cast all values in the loaded JSON file to Context — but couldn't we just do that inside the _load method?
If not, this class should be documented and defined inside the JSONContextStorage class.

Previously this issue wasn't that evident, but now code like len(self.storage.model_extra) looks weird.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No idea

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think that changes to this class are necessary, since it's going to be replaced by the new context storage implementation

ruthenian8 marked this conversation as resolved.
Show resolved Hide resolved
@model_validator(mode="before")
@classmethod
def validate_any(cls, vals):
for key, value in vals.items():
vals[key] = Context.cast(value)
Expand All @@ -43,41 +44,41 @@ def __init__(self, path: str):

@threadsafe_method
async def len_async(self) -> int:
    """Return the number of contexts currently held in the storage.

    In pydantic v2, extra (undeclared) fields of a model with
    ``extra="allow"`` live in ``model_extra`` (previously ``__dict__``).
    """
    return len(self.storage.model_extra)

@threadsafe_method
async def set_item_async(self, key: Hashable, value: Context):
    """Store ``value`` under ``key`` and persist the storage to disk.

    The key is coerced to ``str`` because JSON object keys are strings.
    """
    # Plain indexing instead of calling __setitem__ explicitly.
    self.storage.model_extra[str(key)] = value
    await self._save()

@threadsafe_method
async def get_item_async(self, key: Hashable) -> Context:
    """Reload the backing file and return the context stored under ``key``.

    :raises KeyError: if ``str(key)`` is not present in the storage.
    """
    await self._load()
    # Plain indexing instead of calling __getitem__ explicitly.
    return Context.cast(self.storage.model_extra[str(key)])

@threadsafe_method
async def del_item_async(self, key: Hashable):
    """Remove the context stored under ``key`` and persist the change.

    :raises KeyError: if ``str(key)`` is not present in the storage.
    """
    # `del` statement instead of calling __delitem__ explicitly.
    del self.storage.model_extra[str(key)]
    await self._save()

@threadsafe_method
async def contains_async(self, key: Hashable) -> bool:
    """Reload the backing file and report whether ``key`` is stored."""
    await self._load()
    # `in` operator instead of calling __contains__ explicitly.
    return str(key) in self.storage.model_extra

@threadsafe_method
async def clear_async(self):
    """Delete every stored context and persist the now-empty storage."""
    self.storage.model_extra.clear()
    await self._save()

async def _save(self):
    """Serialize the storage to JSON and write it to ``self.path``."""
    async with aiofiles.open(self.path, "w+", encoding="utf-8") as file_stream:
        # pydantic v2: model_dump_json() replaces the deprecated .json().
        await file_stream.write(self.storage.model_dump_json())

async def _load(self):
    """(Re)load the storage from ``self.path``.

    A missing or empty file is replaced with a fresh empty storage,
    which is immediately written back so the file always exists and
    contains valid JSON afterwards.
    """
    if not await aiofiles.os.path.isfile(self.path) or (await aiofiles.os.stat(self.path)).st_size == 0:
        self.storage = SerializableStorage()
        await self._save()
    else:
        async with aiofiles.open(self.path, "r", encoding="utf-8") as file_stream:
            # pydantic v2: model_validate_json() replaces parse_raw().
            self.storage = SerializableStorage.model_validate_json(await file_stream.read())
2 changes: 1 addition & 1 deletion dff/context_storages/mongo.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def _adjust_key(key: Hashable) -> Dict[str, ObjectId]:
async def set_item_async(self, key: Hashable, value: Context):
    """Upsert the context into the collection under the adjusted key.

    The value is round-tripped through JSON to obtain a plain,
    BSON-serializable dict (pydantic v2: model_dump_json() replaces
    the deprecated .json()).
    """
    new_key = self._adjust_key(key)
    value = value if isinstance(value, Context) else Context.cast(value)
    document = json.loads(value.model_dump_json())

    document.update(new_key)
    await self.collection.replace_one(new_key, document, upsert=True)
Expand Down
2 changes: 1 addition & 1 deletion dff/context_storages/redis.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ async def contains_async(self, key: Hashable) -> bool:
@threadsafe_method
async def set_item_async(self, key: Hashable, value: Context):
    """Serialize ``value`` and store it in Redis under ``str(key)``."""
    value = value if isinstance(value, Context) else Context.cast(value)
    # pydantic v2: model_dump_json() replaces the deprecated .json().
    await self._redis.set(str(key), value.model_dump_json())

@threadsafe_method
async def get_item_async(self, key: Hashable) -> Context:
Expand Down
2 changes: 1 addition & 1 deletion dff/context_storages/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ def __init__(self, path: str, table_name: str = "contexts", custom_driver: bool
@threadsafe_method
async def set_item_async(self, key: Hashable, value: Context):
value = value if isinstance(value, Context) else Context.cast(value)
value = json.loads(value.json())
value = json.loads(value.model_dump_json())

insert_stmt = insert(self.table).values(id=str(key), context=value)
update_stmt = await self._get_update_stmt(insert_stmt)
Expand Down
2 changes: 1 addition & 1 deletion dff/context_storages/ydb.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ async def callee(session):

await session.transaction(ydb.SerializableReadWrite()).execute(
prepared_query,
{"$queryId": str(key), "$queryContext": value.json()},
{"$queryId": str(key), "$queryContext": value.model_dump_json()},
commit_tx=True,
)

Expand Down
22 changes: 10 additions & 12 deletions dff/messengers/telegram/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@
ChatJoinRequest,
)

from dff.script.core.message import Message, Location, Keyboard, DataModel, root_validator, ValidationError
from dff.script.core.message import Message, Location, Keyboard, DataModel
from pydantic import model_validator


class TelegramUI(Keyboard):
Expand All @@ -34,13 +35,13 @@ class TelegramUI(Keyboard):
row_width: int = 3
"""Limits the maximum number of buttons in a row."""

@model_validator(mode="after")
def validate_buttons(self, _):
    """Reject ``payload``/``source`` on buttons of non-inline keyboards.

    These two button attributes only make sense for inline keyboards,
    so their presence on a regular keyboard is a configuration error.
    Raising here is surfaced to the caller as a pydantic
    ValidationError when the model is constructed.
    """
    if not self.is_inline:
        for button in self.buttons:
            if button.payload is not None or button.source is not None:
                raise AssertionError(f"`payload` and `source` are only used for inline keyboards: {button}")
    return self


class _ClickButton(DataModel):
Expand All @@ -66,9 +67,6 @@ class ParseMode(Enum):


class TelegramMessage(Message):
class Config:
smart_union = True

ui: Optional[
Union[TelegramUI, RemoveKeyboard, ReplyKeyboardRemove, ReplyKeyboardMarkup, InlineKeyboardMarkup]
] = None
Expand Down Expand Up @@ -97,9 +95,9 @@ class Config:

def __eq__(self, other):
if isinstance(other, Message):
for field in self.__fields__:
for field in self.model_fields:
if field not in ("parse_mode", "update_id", "update", "update_type"):
if field not in other.__fields__:
if field not in other.model_fields:
return False
if self.__getattribute__(field) != other.__getattribute__(field):
return False
Expand Down
2 changes: 0 additions & 2 deletions dff/script/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,8 @@
from .core.normalization import (
normalize_label,
normalize_condition,
normalize_transitions,
normalize_response,
normalize_processing,
normalize_script,
)
Comment on lines 16 to 20
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not related to the PR, but:

Why import these functions if they are not supposed to be used by the user?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

from .core.script import Node, Script
from .core.types import (
Expand Down
26 changes: 13 additions & 13 deletions dff/script/conditions/std_conditions.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,15 @@
import logging
import re

from pydantic import validate_arguments
from pydantic import validate_call

from dff.pipeline import Pipeline
from dff.script import NodeLabel2Type, Context, Message

logger = logging.getLogger(__name__)


@validate_arguments
@validate_call
def exact_match(match: Message, skip_none: bool = True, *args, **kwargs) -> Callable[..., bool]:
"""
Return function handler. This handler returns `True` only if the last user phrase
Expand All @@ -35,11 +35,11 @@ def exact_match_condition_handler(ctx: Context, pipeline: Pipeline, *args, **kwa
request = ctx.last_request
if request is None:
return False
for field in match.__fields__:
for field in match.model_fields:
match_value = match.__getattribute__(field)
if skip_none and match_value is None:
continue
if field in request.__fields__.keys():
if field in request.model_fields.keys():
if request.__getattribute__(field) != match.__getattribute__(field):
return False
else:
Expand All @@ -49,7 +49,7 @@ def exact_match_condition_handler(ctx: Context, pipeline: Pipeline, *args, **kwa
return exact_match_condition_handler


@validate_arguments
@validate_call
def regexp(
pattern: Union[str, Pattern], flags: Union[int, re.RegexFlag] = 0, *args, **kwargs
) -> Callable[[Context, Pipeline, Any, Any], bool]:
Expand All @@ -75,7 +75,7 @@ def regexp_condition_handler(ctx: Context, pipeline: Pipeline, *args, **kwargs)
return regexp_condition_handler


@validate_arguments
@validate_call
def check_cond_seq(cond_seq: list):
"""
Check if the list consists only of Callables.
Expand All @@ -97,7 +97,7 @@ def check_cond_seq(cond_seq: list):
"""


@validate_arguments
@validate_call
def aggregate(
cond_seq: list, aggregate_func: Callable = _any, *args, **kwargs
) -> Callable[[Context, Pipeline, Any, Any], bool]:
Expand All @@ -119,7 +119,7 @@ def aggregate_condition_handler(ctx: Context, pipeline: Pipeline, *args, **kwarg
return aggregate_condition_handler


@validate_arguments
@validate_call
def any(cond_seq: list, *args, **kwargs) -> Callable[[Context, Pipeline, Any, Any], bool]:
"""
Return function handler. This handler returns `True`
Expand All @@ -135,7 +135,7 @@ def any_condition_handler(ctx: Context, pipeline: Pipeline, *args, **kwargs) ->
return any_condition_handler


@validate_arguments
@validate_call
def all(cond_seq: list, *args, **kwargs) -> Callable[[Context, Pipeline, Any, Any], bool]:
"""
Return function handler. This handler returns `True` only
Expand All @@ -151,7 +151,7 @@ def all_condition_handler(ctx: Context, pipeline: Pipeline, *args, **kwargs) ->
return all_condition_handler


@validate_arguments
@validate_call
def negation(condition: Callable, *args, **kwargs) -> Callable[[Context, Pipeline, Any, Any], bool]:
"""
Return function handler. This handler returns negation of the :py:func:`~condition`: `False`
Expand All @@ -166,7 +166,7 @@ def negation_condition_handler(ctx: Context, pipeline: Pipeline, *args, **kwargs
return negation_condition_handler


@validate_arguments
@validate_call
def has_last_labels(
flow_labels: Optional[List[str]] = None,
labels: Optional[List[NodeLabel2Type]] = None,
Expand Down Expand Up @@ -198,7 +198,7 @@ def has_last_labels_condition_handler(ctx: Context, pipeline: Pipeline, *args, *
return has_last_labels_condition_handler


@validate_arguments
@validate_call
def true(*args, **kwargs) -> Callable[[Context, Pipeline, Any, Any], bool]:
"""
Return function handler. This handler always returns `True`.
Expand All @@ -210,7 +210,7 @@ def true_handler(ctx: Context, pipeline: Pipeline, *args, **kwargs) -> bool:
return true_handler


@validate_arguments
@validate_call
def false(*args, **kwargs) -> Callable[[Context, Pipeline, Any, Any], bool]:
"""
Return function handler. This handler always returns `False`.
Expand Down
Loading