Skip to content

Commit

Permalink
Type annotations for query results
Browse files Browse the repository at this point in the history
supporting type annotation for queries results
  • Loading branch information
roman-right authored Jul 6, 2021
2 parents dfe4cb4 + f9bfb30 commit eda1340
Show file tree
Hide file tree
Showing 16 changed files with 533 additions and 76 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/github-actions-mypy.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,4 @@ jobs:
- name: mypy install
run: pip3 install mypy types-click types-toml
- name: mypy
run: mypy beanie/ --config-file pyproject.toml
run: mypy beanie/ tests/typing --config-file pyproject.toml
30 changes: 30 additions & 0 deletions .github/workflows/github-actions-pyright.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
name: PyRight
on: [ pull_request ]

jobs:
pyright:
strategy:
fail-fast: false
matrix:
os: [ ubuntu-18.04 ]
python-version: [ 3.9 ]
poetry-version: [ 1.1.4 ]
runs-on: ${{ matrix.os }}
steps:
- uses: actions/checkout@v2
- uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python-version }}
- uses: abatilo/actions-poetry@v2.0.0
with:
poetry-version: ${{ matrix.poetry-version }}
- name: Setup node.js (for pyright)
uses: actions/setup-node@v1
with:
node-version: "12"
- name: poetry install
run: poetry install
- name: pyright install
run: npm install -g pyright
- name: pyright test
run: poetry run pyright
2 changes: 1 addition & 1 deletion beanie/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from beanie.odm.utils.general import init_beanie
from beanie.odm.documents import Document

__version__ = "1.2.0"
__version__ = "1.2.1"
__all__ = [
# ODM
"Document",
Expand Down
2 changes: 1 addition & 1 deletion beanie/migrations/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ async def build(cls, path: Path):
for name in names:
module = SourceFileLoader(
(path / name).stem, str((path / name).absolute())
).load_module()
).load_module((path / name).stem)
forward_class = getattr(module, "Forward", None)
backward_class = getattr(module, "Backward", None)
migration_node = MigrationNode(
Expand Down
188 changes: 167 additions & 21 deletions beanie/odm/documents.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import (
Dict,
Optional,
List,
Type,
Expand All @@ -7,6 +8,7 @@
Mapping,
TypeVar,
Any,
overload,
)

from bson import ObjectId
Expand Down Expand Up @@ -43,6 +45,7 @@
from beanie.odm.utils.dump import get_dict

DocType = TypeVar("DocType", bound="Document")
DocumentProjectionType = TypeVar("DocumentProjectionType", bound=BaseModel)


class Document(BaseModel, UpdateMethods):
Expand Down Expand Up @@ -163,13 +166,33 @@ async def get(
document_id = parse_obj_as(cls.__fields__["id"].type_, document_id)
return await cls.find_one({"_id": document_id}, session=session)

@overload
@classmethod
def find_one(
cls,
*args: Mapping[str, Any],
projection_model: Optional[Type[BaseModel]] = None,
cls: Type[DocType],
*args: Union[Mapping[str, Any], Any],
projection_model: None = None,
session: Optional[ClientSession] = None,
) -> FindOne[DocType]:
...

@overload
@classmethod
def find_one(
cls: Type[DocType],
*args: Union[Mapping[str, Any], Any],
projection_model: Type[DocumentProjectionType],
session: Optional[ClientSession] = None,
) -> FindOne:
) -> FindOne[DocumentProjectionType]:
...

@classmethod
def find_one(
cls: Type[DocType],
*args: Union[Mapping[str, Any], Any],
projection_model: Optional[Type[DocumentProjectionType]] = None,
session: Optional[ClientSession] = None,
):
"""
Find one document by criteria.
Returns [FindOne](https://roman-right.github.io/beanie/api/queries/#findone) query object.
Expand All @@ -186,16 +209,42 @@ def find_one(
session=session,
)

@overload
@classmethod
def find_many(
cls,
*args: Mapping[str, Any],
cls: Type[DocType],
*args: Union[Mapping[str, Any], Any],
projection_model: None = None,
skip: Optional[int] = None,
limit: Optional[int] = None,
sort: Union[None, str, List[Tuple[str, SortDirection]]] = None,
session: Optional[ClientSession] = None,
) -> FindMany[DocType]:
...

@overload
@classmethod
def find_many(
cls: Type[DocType],
*args: Union[Mapping[str, Any], Any],
projection_model: Type[DocumentProjectionType] = None,
skip: Optional[int] = None,
limit: Optional[int] = None,
sort: Union[None, str, List[Tuple[str, SortDirection]]] = None,
projection_model: Optional[Type[BaseModel]] = None,
session: Optional[ClientSession] = None,
) -> FindMany:
) -> FindMany[DocumentProjectionType]:
...

@classmethod
def find_many(
cls: Type[DocType],
*args: Union[Mapping[str, Any], Any],
projection_model: Optional[Type[DocumentProjectionType]] = None,
skip: Optional[int] = None,
limit: Optional[int] = None,
sort: Union[None, str, List[Tuple[str, SortDirection]]] = None,
session: Optional[ClientSession] = None,
) -> Union[FindMany[DocType], FindMany[DocumentProjectionType]]:
"""
Find many documents by criteria.
Returns [FindMany](https://roman-right.github.io/beanie/api/queries/#findmany) query object
Expand All @@ -219,16 +268,42 @@ def find_many(
session=session,
)

@overload
@classmethod
def find(
cls,
*args: Mapping[str, Any],
cls: Type[DocType],
*args: Union[Mapping[str, Any], Any],
projection_model: None = None,
skip: Optional[int] = None,
limit: Optional[int] = None,
sort: Union[None, str, List[Tuple[str, SortDirection]]] = None,
session: Optional[ClientSession] = None,
) -> FindMany[DocType]:
...

@overload
@classmethod
def find(
cls: Type[DocType],
*args: Union[Mapping[str, Any], Any],
projection_model: Type[DocumentProjectionType],
skip: Optional[int] = None,
limit: Optional[int] = None,
sort: Union[None, str, List[Tuple[str, SortDirection]]] = None,
session: Optional[ClientSession] = None,
) -> FindMany[DocumentProjectionType]:
...

@classmethod
def find(
cls: Type[DocType],
*args: Union[Mapping[str, Any], Any],
projection_model: Optional[Type[DocumentProjectionType]] = None,
skip: Optional[int] = None,
limit: Optional[int] = None,
sort: Union[None, str, List[Tuple[str, SortDirection]]] = None,
projection_model: Optional[Type[BaseModel]] = None,
session: Optional[ClientSession] = None,
) -> FindMany:
) -> Union[FindMany[DocType], FindMany[DocumentProjectionType]]:
"""
The same as find_many
"""
Expand All @@ -241,15 +316,39 @@ def find(
session=session,
)

@overload
@classmethod
def find_all(
cls,
cls: Type[DocType],
skip: Optional[int] = None,
limit: Optional[int] = None,
sort: Union[None, str, List[Tuple[str, SortDirection]]] = None,
projection_model: Optional[Type[BaseModel]] = None,
projection_model: None = None,
session: Optional[ClientSession] = None,
) -> FindMany:
) -> FindMany[DocType]:
...

@overload
@classmethod
def find_all(
cls: Type[DocType],
skip: Optional[int] = None,
limit: Optional[int] = None,
sort: Union[None, str, List[Tuple[str, SortDirection]]] = None,
projection_model: Optional[Type[DocumentProjectionType]] = None,
session: Optional[ClientSession] = None,
) -> FindMany[DocumentProjectionType]:
...

@classmethod
def find_all(
cls: Type[DocType],
skip: Optional[int] = None,
limit: Optional[int] = None,
sort: Union[None, str, List[Tuple[str, SortDirection]]] = None,
projection_model: Optional[Type[DocumentProjectionType]] = None,
session: Optional[ClientSession] = None,
) -> Union[FindMany[DocType], FindMany[DocumentProjectionType]]:
"""
Get all the documents
Expand All @@ -271,15 +370,39 @@ def find_all(
session=session,
)

@overload
@classmethod
def all(
cls,
cls: Type[DocType],
projection_model: None = None,
skip: Optional[int] = None,
limit: Optional[int] = None,
sort: Union[None, str, List[Tuple[str, SortDirection]]] = None,
projection_model: Optional[Type[BaseModel]] = None,
session: Optional[ClientSession] = None,
) -> FindMany:
) -> FindMany[DocType]:
...

@overload
@classmethod
def all(
cls: Type[DocType],
projection_model: Type[DocumentProjectionType],
skip: Optional[int] = None,
limit: Optional[int] = None,
sort: Union[None, str, List[Tuple[str, SortDirection]]] = None,
session: Optional[ClientSession] = None,
) -> FindMany[DocumentProjectionType]:
...

@classmethod
def all(
cls: Type[DocType],
projection_model: Optional[Type[DocumentProjectionType]] = None,
skip: Optional[int] = None,
limit: Optional[int] = None,
sort: Union[None, str, List[Tuple[str, SortDirection]]] = None,
session: Optional[ClientSession] = None,
) -> Union[FindMany[DocType], FindMany[DocumentProjectionType]]:
"""
the same as find_all
"""
Expand Down Expand Up @@ -395,13 +518,36 @@ async def delete_all(
"""
return await cls.find_all().delete(session=session)

@overload
@classmethod
def aggregate(
cls,
cls: Type[DocType],
aggregation_pipeline: list,
projection_model: None = None,
session: Optional[ClientSession] = None,
) -> AggregationQuery[Dict[str, Any]]:
...

@overload
@classmethod
def aggregate(
cls: Type[DocType],
aggregation_pipeline: list,
projection_model: Type[DocumentProjectionType],
session: Optional[ClientSession] = None,
) -> AggregationQuery[DocumentProjectionType]:
...

@classmethod
def aggregate(
cls: Type[DocType],
aggregation_pipeline: list,
projection_model: Type[BaseModel] = None,
projection_model: Optional[Type[DocumentProjectionType]] = None,
session: Optional[ClientSession] = None,
) -> AggregationQuery:
) -> Union[
AggregationQuery[Dict[str, Any]],
AggregationQuery[DocumentProjectionType],
]:
"""
Aggregate over collection.
Returns [AggregationQuery](https://roman-right.github.io/beanie/api/queries/#aggregationquery) query object
Expand Down
10 changes: 4 additions & 6 deletions beanie/odm/interfaces/aggregate.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
from abc import abstractmethod
from typing import Type, Any, Optional, Union, List, Dict, cast
from typing import Any, Optional, Union, List, Dict, cast

from pydantic import BaseModel
from pymongo.client_session import ClientSession

from beanie.odm.fields import ExpressionField
from beanie.odm.queries.aggregation import AggregationQuery


class AggregateMethods:
Expand All @@ -17,9 +15,9 @@ class AggregateMethods:
def aggregate(
self,
aggregation_pipeline,
projection_model: Type[BaseModel] = None,
session: Optional[ClientSession] = None,
) -> AggregationQuery:
projection_model=None,
session=None,
):
...

async def sum(
Expand Down
10 changes: 9 additions & 1 deletion beanie/odm/queries/aggregation.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
Optional,
TYPE_CHECKING,
Any,
Generic,
TypeVar,
)

from pydantic import BaseModel
Expand All @@ -16,8 +18,14 @@
if TYPE_CHECKING:
from beanie.odm.documents import DocType

AggregationProjectionType = TypeVar("AggregationProjectionType")

class AggregationQuery(BaseCursorQuery, SessionMethods):

class AggregationQuery(
Generic[AggregationProjectionType],
BaseCursorQuery[AggregationProjectionType],
SessionMethods,
):
"""
Aggregation Query
Expand Down
Loading

0 comments on commit eda1340

Please sign in to comment.