Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Improve type checking support #103

Merged
merged 6 commits into from
Apr 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@ pip install pydantic-mongo
```python
from bson import ObjectId
from pydantic import BaseModel
from pydantic_mongo import AbstractRepository, ObjectIdField
from pydantic_mongo import AbstractRepository, PydanticObjectId
from pymongo import MongoClient
from typing import List
from typing import Optional, List
import os

class Foo(BaseModel):
Expand All @@ -31,7 +31,8 @@ class Bar(BaseModel):
banana: str = 'y'

class Spam(BaseModel):
id: ObjectIdField = None
# PydanticObjectId is an alias to Annotated[ObjectId, ObjectIdAnnotation]
id: Optional[PydanticObjectId] = None
foo: Foo
bars: List[Bar]

Expand Down Expand Up @@ -64,7 +65,7 @@ spam_repository.delete(spam)
# Find One By Id
result = spam_repository.find_one_by_id(spam.id)

# Find One By Id using string if the id attribute is a ObjectIdField
# Find One By Id using string if the id attribute is a PydanticObjectId
result = spam_repository.find_one_by_id(ObjectId('611827f2878b88b49ebb69fc'))
assert result.foo.count == 2

Expand Down
1 change: 1 addition & 0 deletions integration_test/test_readme.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import os
import re


def extract_python_snippets(content):
# Regular expression pattern for finding Python code blocks
pattern = r'```python(.*?)```'
Expand Down
2 changes: 2 additions & 0 deletions mypy.ini
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
[mypy]
plugins = pydantic.mypy
4 changes: 1 addition & 3 deletions phulpyfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,6 @@ def integration_test(phulpy):

@task
def typecheck(phulpy):
result = system(
r'find ./pydantic_mongo -name "*.py" -exec mypy --ignore-missing-imports --follow-imports=skip --strict-optional {} \+'
)
result = system('mypy pydantic_mongo test --check-untyped-defs')
if result:
raise Exception("lint test failed")
9 changes: 7 additions & 2 deletions pydantic_mongo/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
from .abstract_repository import AbstractRepository
from .fields import ObjectIdField
from .fields import ObjectIdAnnotation, ObjectIdField, PydanticObjectId
from .version import __version__ # noqa: F401

__all__ = ["ObjectIdField", "AbstractRepository"]
__all__ = [
"AbstractRepository",
"ObjectIdField",
"ObjectIdAnnotation",
"PydanticObjectId",
]
31 changes: 18 additions & 13 deletions pydantic_mongo/abstract_repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
Type,
TypeVar,
Union,
cast,
)

from pydantic import BaseModel
Expand All @@ -26,10 +27,13 @@

T = TypeVar("T", bound=BaseModel)
OutputT = TypeVar("OutputT", bound=BaseModel)

Sort = Sequence[Tuple[str, int]]


class ModelWithId(BaseModel):
id: Any


class AbstractRepository(Generic[T]):
class Meta:
collection_name: str
Expand All @@ -53,8 +57,6 @@ def get_collection(self) -> Collection:
return self.__database[self.__collection_name]

def __validate(self):
if not issubclass(self.__document_class, BaseModel):
raise Exception("Document class should inherit BaseModel")
if "id" not in self.__document_class.model_fields:
raise Exception("Document class should have id field")
if not self.__collection_name:
Expand All @@ -67,10 +69,11 @@ def to_document(model: T) -> dict:
:param model:
:return: dict
"""
data = model.model_dump()
model_with_id = cast(ModelWithId, model)
data = model_with_id.model_dump()
data.pop("id")
if model.id:
data["_id"] = model.id
if model_with_id.id:
data["_id"] = model_with_id.id
return data

def __map_id(self, data: dict) -> dict:
Expand Down Expand Up @@ -109,15 +112,16 @@ def save(self, model: T) -> Union[InsertOneResult, UpdateResult]:
Save entity to database. It will update the entity if it has id, otherwise it will insert it.
"""
document = self.to_document(model)
model_with_id = cast(ModelWithId, model)

if model.id:
if model_with_id.id:
mongo_id = document.pop("_id")
return self.get_collection().update_one(
{"_id": mongo_id}, {"$set": document}, upsert=True
)

result = self.get_collection().insert_one(document)
model.id = result.inserted_id
model_with_id.id = result.inserted_id
return result

def save_many(self, models: Iterable[T]):
Expand All @@ -128,7 +132,8 @@ def save_many(self, models: Iterable[T]):
models_to_update = []

for model in models:
if model.id:
model_with_id = cast(ModelWithId, model)
if model_with_id.id:
models_to_update.append(model)
else:
models_to_insert.append(model)
Expand All @@ -138,7 +143,7 @@ def save_many(self, models: Iterable[T]):
)

for idx, inserted_id in enumerate(result.inserted_ids):
models_to_insert[idx].id = inserted_id
cast(ModelWithId, models_to_insert[idx]).id = inserted_id

if len(models_to_update) == 0:
return
Expand All @@ -152,7 +157,7 @@ def save_many(self, models: Iterable[T]):
self.get_collection().bulk_write(bulk_operations)

def delete(self, model: T):
return self.get_collection().delete_one({"_id": model.id})
return self.get_collection().delete_one({"_id": cast(ModelWithId, model).id})

def delete_by_id(self, _id: Any):
return self.get_collection().delete_one({"_id": _id})
Expand Down Expand Up @@ -198,7 +203,7 @@ def find_by_with_output_type(
cursor.limit(limit)
if skip:
cursor.skip(skip)
if sort:
if mapped_sort:
cursor.sort(mapped_sort)
return map(lambda doc: self.to_model_custom(output_type, doc), cursor)

Expand Down Expand Up @@ -285,7 +290,7 @@ def paginate_with_output_type(
)

return map(
lambda model: Edge[T](
lambda model: Edge[OutputT](
node=model,
cursor=encode_pagination_cursor(
get_pagination_cursor_payload(model, sort_keys)
Expand Down
13 changes: 12 additions & 1 deletion pydantic_mongo/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@

from bson import ObjectId
from pydantic_core import core_schema
from typing_extensions import Annotated


class ObjectIdField(ObjectId):
class ObjectIdAnnotation:
@classmethod
def __get_pydantic_core_schema__(
cls, _source_type: Any, _handler: Any
Expand All @@ -31,3 +32,13 @@ def validate(cls, value):
raise ValueError("Invalid id")

return ObjectId(value)


# Deprecated, use PydanticObjectId instead.
class ObjectIdField(ObjectId):
@classmethod
def __get_pydantic_core_schema__(cls, _source_type: Any, _handler: Any):
return ObjectIdAnnotation.__get_pydantic_core_schema__(_source_type, _handler)


PydanticObjectId = Annotated[ObjectId, ObjectIdAnnotation]
2 changes: 1 addition & 1 deletion pydantic_mongo/pagination.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ class Edge(BaseModel, Generic[DataT]):


def encode_pagination_cursor(data: List) -> str:
byte_data = bson.BSON.encode({"v": data})
byte_data: bytes = bson.BSON.encode({"v": data})
byte_data = zlib.compress(byte_data, 9)
return b64encode(byte_data).decode("utf-8")

Expand Down
2 changes: 1 addition & 1 deletion requirements_test.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ pytest==8.1.1
pytest-cov==4.1.0
pytest-mock==3.12.0
mongomock==4.1.2
pydantic>=2.0.2
pydantic==2.7.1
pymongo==4.6.3
mypy==1.10.0
mypy-extensions==1.0.0
Expand Down
8 changes: 4 additions & 4 deletions test/test_enhance_meta.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import pytest
from pydantic import BaseModel, Field

from pydantic_mongo import AbstractRepository, ObjectIdField
from pydantic_mongo import AbstractRepository, PydanticObjectId
from typing_extensions import Optional


class HamModel(BaseModel):
id: ObjectIdField = Field(default=None)
id: Optional[PydanticObjectId]
name: str


Expand All @@ -31,7 +31,7 @@ def test_repository_with_v2_meta(ham_repo):


def test_save_with_new_repo(clean_ham_collection, ham_repo):
m = HamModel(name="wilfred")
m = HamModel(id=None, name="wilfred")
assert m.id is None, "should have no id"
ham_repo.save(m)
assert m.id
6 changes: 4 additions & 2 deletions test/test_fields.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
import pytest
from typing import Optional
from bson import ObjectId
from pydantic import BaseModel, ValidationError

from pydantic_mongo import ObjectIdField


class User(BaseModel):
id: ObjectIdField = None
id: ObjectIdField


class TestFields:
Expand All @@ -26,5 +27,6 @@ def test_modify_schema(self):
assert {
"title": "User",
"type": "object",
"properties": {"id": {"default": None, "title": "Id", "type": "string"}},
"properties": {"id": {"title": "Id", "type": "string"}},
"required": ["id"],
} == schema
7 changes: 3 additions & 4 deletions test/test_pagination.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import datetime
from typing import List
from typing import List, Optional

from bson import ObjectId
from pydantic import BaseModel
Expand All @@ -13,7 +12,7 @@

class Foo(BaseModel):
count: int
size: float = None
size: Optional[float] = None


class Bar(BaseModel):
Expand All @@ -22,7 +21,7 @@ class Bar(BaseModel):


class Spam(BaseModel):
id: str = None
id: Optional[str] = None
foo: Foo
bars: List[Bar]

Expand Down
26 changes: 11 additions & 15 deletions test/test_repository.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,17 @@
from typing import List
from typing import List, Optional, cast

import mongomock
import pytest
from bson import ObjectId
from pydantic import BaseModel, Field

from pydantic_mongo import AbstractRepository, ObjectIdField
from pydantic_mongo import AbstractRepository, PydanticObjectId
from pydantic_mongo.errors import PaginationError


class Foo(BaseModel):
count: int
size: float = None
size: Optional[float] = None


class Bar(BaseModel):
Expand All @@ -20,9 +20,9 @@ class Bar(BaseModel):


class Spam(BaseModel):
id: ObjectIdField = None
foo: Foo = None
bars: List[Bar] = None
id: Optional[PydanticObjectId] = None
foo: Optional[Foo] = None
bars: Optional[List[Bar]] = None


class SpamRepository(AbstractRepository[Spam]):
Expand All @@ -49,7 +49,7 @@ def test_save(self, database):
"bars": [{"apple": "x", "banana": "y"}],
} == database["spams"].find()[0]

spam.foo.count = 2
cast(Foo, spam.foo).count = 2
spam_repository.save(spam)

assert {
Expand Down Expand Up @@ -147,6 +147,8 @@ def test_find_by_id(self, database):
spam_repository = SpamRepository(database=database)
result = spam_repository.find_one_by_id(spam_id)

assert result is not None
assert result.bars is not None
assert issubclass(Spam, type(result))
assert spam_id == result.id
assert "x" == result.bars[0].apple
Expand All @@ -171,6 +173,8 @@ def test_find_by(self, database):
result = spam_repository.find_by({})
results = [x for x in result]
assert 2 == len(results)
assert results[0].foo is not None
assert results[1].foo is not None
assert 2 == results[0].foo.count
assert 3 == results[1].foo.count

Expand All @@ -181,14 +185,6 @@ def test_find_by(self, database):
results = [x for x in result]
assert 0 == len(results)

def test_invalid_model_class(self, database):
class BrokenRepository(AbstractRepository[int]):
class Meta:
collection_name = "spams"

with pytest.raises(Exception):
BrokenRepository(database=database)

def test_invalid_model_id_field(self, database):
class NoIdModel(BaseModel):
something: str
Expand Down
Loading