From fcc07b584fb2d29d2507b4c914b0cc99a294b8f9 Mon Sep 17 00:00:00 2001 From: Dhruv <83733638+dhruv-ahuja@users.noreply.github.com> Date: Sat, 24 Aug 2024 08:35:25 +0530 Subject: [PATCH 1/2] Fix Nested Filter-Sort and Model Querying Issues (#29) * refactor: make filter operators constant * fix: resolve filtering issues - ensure each filter is applied only once per query - fix failures on nested filtering * fix: allow nested sorting --- src/config/constants/app.py | 18 +++++++++++++ src/services/poe.py | 26 +++++++----------- src/utils/services.py | 54 ++++++++++++++++++++++--------------- 3 files changed, 61 insertions(+), 37 deletions(-) diff --git a/src/config/constants/app.py b/src/config/constants/app.py index e8e0f4b..1fe1d23 100644 --- a/src/config/constants/app.py +++ b/src/config/constants/app.py @@ -1,8 +1,10 @@ from typing import Literal from uuid import uuid4 import datetime as dt +import operator from beanie.odm.interfaces.find import FindType, DocumentProjectionType +from beanie.odm.operators.find.evaluation import RegEx as RegExOperator from beanie.odm.queries.find import FindMany @@ -31,3 +33,19 @@ FIND_MANY_QUERY = FindMany[FindType] | FindMany[DocumentProjectionType] FILTER_OPERATION = Literal["=", "!=", ">", ">=", "<", "<=", "like"] +FILTER_OPERATION_MAP = { + "=": operator.eq, + "!=": operator.ne, + ">": operator.gt, + "<": operator.lt, + ">=": operator.ge, + "<=": operator.le, + "like": RegExOperator, +} +NESTED_FILTER_OPERATION_MAP = { + "=": "$eq", + ">": "$gt", + ">=": "$gte", + "<": "$lt", + "<=": "$lte", +} diff --git a/src/services/poe.py b/src/services/poe.py index f4946c3..06e1879 100644 --- a/src/services/poe.py +++ b/src/services/poe.py @@ -42,30 +42,24 @@ async def get_items( ) -> tuple[list[ItemBase], int]: """Gets items by given category group, and the total items' count in the database.""" - query = Item.find() - chainer = QueryChainer(query, Item) - if filter_sort_input is None: - items_count = await query.find(fetch_links=True).count() - items = await chainer.paginate(pagination).query.find(fetch_links=True).project(ItemBase).to_list() + items_count = await Item.find().count() + items = await QueryChainer(Item.find(), Item).paginate(pagination).query.find().project(ItemBase).to_list() return items, items_count - base_query_chain = chainer.filter(filter_sort_input.filter_).sort(filter_sort_input.sort) - - # * clone the query for use with total record counts and pagination calculations - count_query = ( - base_query_chain.filter(filter_sort_input.filter_) + items_query = ( + QueryChainer(Item.find(), Item) + .filter(filter_sort_input.filter_) .sort(filter_sort_input.sort) - .clone() - .query.find(fetch_links=True) - .count() + .paginate(pagination) + .query.project(ItemBase) + .to_list() ) - - paginated_query = base_query_chain.paginate(pagination).query.find(fetch_links=True).project(ItemBase).to_list() + count_query = QueryChainer(Item.find(), Item).filter(filter_sort_input.filter_).query.count() try: - items = await paginated_query + items = await items_query items_count = await count_query except Exception as exc: logger.error(f"error getting items from database; filter_sort: {filter_sort_input}: {exc}") diff --git a/src/utils/services.py b/src/utils/services.py index 752b6df..de54e9a 100644 --- a/src/utils/services.py +++ b/src/utils/services.py @@ -1,15 +1,15 @@ import copy -import operator from typing import Self, Type, cast from beanie import Document from beanie.odm.operators.find.evaluation import RegEx as RegExOperator +from bson import Decimal128 from loguru import logger import orjson import pymongo from redis.asyncio import Redis, RedisError -from src.config.constants.app import FIND_MANY_QUERY +from src.config.constants.app import FILTER_OPERATION_MAP, FIND_MANY_QUERY, NESTED_FILTER_OPERATION_MAP from src.schemas.requests import FilterInputType, FilterSchema, PaginationInput, SortInputType, SortSchema from src.schemas.responses import E, T, BaseResponse @@ -71,7 +71,8 @@ def sort_on_query(query: FIND_MANY_QUERY, model: Type[Document], sort: SortInput field = entry.field operation = pymongo.ASCENDING if entry.operation == "asc" else pymongo.DESCENDING - model_field = getattr(model, field) + is_nested = "." in field + model_field = field if is_nested else getattr(model, field) expression = (model_field, operation) sort_expressions.append(expression) @@ -80,6 +81,23 @@ def sort_on_query(query: FIND_MANY_QUERY, model: Type[Document], sort: SortInput return query +def _build_nested_query(entry: FilterSchema, query: FIND_MANY_QUERY) -> FIND_MANY_QUERY: + """Builds queries for nested fields, using raw BSON query syntax to ensure nested fields are parsed properly.""" + + field = entry.field + operation = entry.operation + value = entry.value + + if operation != "like": + operation_function = NESTED_FILTER_OPERATION_MAP[operation] + filter_query = {field: {operation_function: Decimal128(value)}} + else: + filter_query = {field: {"$regex": value, "$options": "i"}} + + query = query.find(filter_query) + return query + + def filter_on_query(query: FIND_MANY_QUERY, model: Type[Document], filter_: FilterInputType) -> FIND_MANY_QUERY: """Parses, gathers and chains filter operations on the input query. Skips the process if filter input is empty.\n Maps the operation list to operator arguments that allow using the operator dynamically, to create expressions @@ -89,31 +107,25 @@ def filter_on_query(query: FIND_MANY_QUERY, model: Type[Document], filter_: Filt if not isinstance(filter_, list): return query - operation_map = { - "=": operator.eq, - "!=": operator.ne, - ">": operator.gt, - "<": operator.lt, - ">=": operator.ge, - "<=": operator.le, - "like": RegExOperator, - } - for entry in filter_: field = entry.field operation = entry.operation - operation_function = operation_map[operation] + operation_function = FILTER_OPERATION_MAP[operation] value = entry.value - model_field = getattr(model, field) - - if operation != "like": - query = query.find(operation_function(model_field, value)) + is_nested = "." in field + if is_nested: + query = _build_nested_query(entry, query) else: - operation_function = RegExOperator - options = "i" # case-insensitive search + model_field = getattr(model, field) + + if operation != "like": + query = query.find(operation_function(model_field, value)) + else: + operation_function = RegExOperator + options = "i" # case-insensitive search - query = query.find(operation_function(model_field, value, options=options)) + query = query.find(operation_function(model_field, value, options=options)) return query From 5559cccda77e43fb7ed003fe517f8be31414d4b7 Mon Sep 17 00:00:00 2001 From: Dhruv Ahuja <83733638+dhruv-ahuja@users.noreply.github.com> Date: Thu, 29 Aug 2024 21:09:47 +0530 Subject: [PATCH 2/2] refactor: improve price data field representations --- src/schemas/poe.py | 31 +++++++++++++++++++++---------- src/scripts/price_prediction.py | 2 +- 2 files changed, 22 insertions(+), 11 deletions(-) diff --git a/src/schemas/poe.py b/src/schemas/poe.py index 777238f..84a116f 100644 --- a/src/schemas/poe.py +++ b/src/schemas/poe.py @@ -19,7 +19,7 @@ def convert_decimal_values(values: dict[dt.datetime, str | Decimal128 | Decimal] return converted_values -def convert_current_price(value: Decimal | Decimal128 | str) -> Decimal: +def convert_decimal_value(value: Decimal | Decimal128 | str) -> Decimal: return value.to_decimal() if isinstance(value, Decimal128) else Decimal(value) @@ -33,31 +33,42 @@ class ItemIdType(str, Enum): receive = "receive" +class PriceDatedData(BaseModel): + """PriceDatedData encapsulates an instance of a timestamp and item value.""" + + timestamp: dt.datetime + price: Annotated[Decimal, BeforeValidator(convert_decimal_value)] + + class ItemPrice(BaseModel): """ItemPrice holds information regarding the current, past and future price of an item. It stores the recent and predicted prices in a dictionary, with the date as the key.""" - chaos_price: Annotated[Decimal, BeforeValidator(convert_current_price)] = Decimal(0) - divine_price: Annotated[Decimal, BeforeValidator(convert_current_price)] = Decimal(0) - price_history: Annotated[dict[dt.datetime, Decimal], BeforeValidator(convert_decimal_values)] = {} + chaos_price: Annotated[Decimal, BeforeValidator(convert_decimal_value)] = Decimal(0) + divine_price: Annotated[Decimal, BeforeValidator(convert_decimal_value)] = Decimal(0) + price_history: list[PriceDatedData] | None = None price_history_currency: Currency = Currency.chaos - price_prediction: Annotated[dict[dt.datetime, Decimal], BeforeValidator(convert_decimal_values)] = {} + price_prediction: list[PriceDatedData] | None = None price_prediction_currency: Currency = Currency.chaos low_confidence: bool = False listings: int = 0 - def serialize(self) -> dict: + def serialize_price_data(self) -> dict: """Serializes the object instance's data, making it compatible with MongoDB. Converts Decimal values into Decimal128 values and datetime keys into string keys.""" - price_history = self.price_history - price_prediction = self.price_prediction + price_history = self.price_history if self.price_history else [] + price_prediction = self.price_prediction if self.price_prediction else [] serialized_data = self.model_dump() # convert datetime keys into string variants - serialized_data["price_history"] = {str(k): v for k, v in price_history.items()} - serialized_data["price_prediction"] = {str(k): v for k, v in price_prediction.items()} + serialized_data["price_history_new"] = [ + {"timestamp": str(entry.timestamp), "price": entry.price} for entry in price_history + ] + serialized_data["price_prediction_new"] = [ + {"timestamp": str(entry.timestamp), "price": entry.price} for entry in price_prediction + ] # convert decimal types into Decimal128 types and cast the output as dictionary serialized_data = convert_decimal(serialized_data) diff --git a/src/scripts/price_prediction.py b/src/scripts/price_prediction.py index 2a895eb..6221643 100644 --- a/src/scripts/price_prediction.py +++ b/src/scripts/price_prediction.py @@ -275,7 +275,7 @@ async def update_items_data(updated_items_queue: Queue[list[ItemRecord] | None]) for item in updated_items: assert item.price_info is not None - serialized_data = item.price_info.serialize() + serialized_data = item.price_info.serialize_price_data() bulk_operations.append( pymongo.UpdateOne(