Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix AnyUrl not being encodable #1071

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/github-actions-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ jobs:
matrix:
python-version: [ "3.8", "3.9", "3.10", "3.11", "3.12", "3.13" ]
mongodb-version: [4.4, 5.0, 6.0, 7.0, 8.0 ]
pydantic-version: [ "1.10.18", "2.9.2" ]
pydantic-version: [ "1.10.18", "2.10.1" ]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there any way we can ensure that it is now not also broken on Pydantic v2.9.2, or some version inbetween?
Or should we always support the newer versions, or rather, test against these newer versions?
Because in the pyproject.toml we support any Pydantic v2 version.

runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
Expand Down
41 changes: 18 additions & 23 deletions beanie/odm/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,6 @@
from pydantic.json_schema import JsonSchemaValue
from pydantic_core import CoreSchema, core_schema
from pydantic_core.core_schema import (
ValidationInfo,
simple_ser_schema,
)
else:
Expand All @@ -64,8 +63,8 @@

if IS_PYDANTIC_V2:
plain_validator = (
core_schema.with_info_plain_validator_function
if hasattr(core_schema, "with_info_plain_validator_function")
core_schema.no_info_plain_validator_function
if hasattr(core_schema, "no_info_plain_validator_function")
else core_schema.general_plain_validator_function
)
else:
Comment on lines 64 to 70
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In my unpublished commit (working on the Link/BackLink serialization bugfix) I simply removed all of this boilerplate code and used the Pydantic v2 no_info_plain_validator_function() and a single cls._validate(cls, v) method, which already had the same validation code running under both major Pydantic versions.
Of course, this required some small but simple refactoring.

Expand Down Expand Up @@ -135,12 +134,14 @@ class PydanticObjectId(ObjectId):

@classmethod
def __get_validators__(cls):
yield cls.validate
yield cls._validate

if IS_PYDANTIC_V2:

@classmethod
def validate(cls, v, _: ValidationInfo):
def _validate(cls, v: str | PydanticObjectId, *_) -> PydanticObjectId:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why write the signature of the _validate() method with "*_" at the end? Which other arguments can it receive?
I believe the Pydanticv2 no_info_plain_validator_function() signature of the called method is simply (cls, v).
I tested this and it worked for me both under Pydantic v1 and Pydantic v2, using a single cls._validate(cls, v) method.

Also, the v: str | PydanticObjectId part, the v can be "Any" essentially. Due to the if check on line 145, v can also be bytes.

if isinstance(v, ObjectId):
return PydanticObjectId(v)
if isinstance(v, bytes):
v = v.decode("utf-8")
try:
Expand All @@ -152,20 +153,14 @@ def validate(cls, v, _: ValidationInfo):
def __get_pydantic_core_schema__(
cls, source_type: Any, handler: GetCoreSchemaHandler
) -> CoreSchema: # type: ignore
return core_schema.json_or_python_schema(
python_schema=plain_validator(cls.validate),
json_schema=plain_validator(
cls.validate,
metadata={
"pydantic_js_input_core_schema": core_schema.str_schema(
pattern="^[0-9a-f]{24}$",
min_length=24,
max_length=24,
)
},
),
serialization=core_schema.plain_serializer_function_ser_schema(
lambda instance: str(instance), when_used="json"
return core_schema.no_info_after_validator_function(
cls,
schema=core_schema.json_or_python_schema(
json_schema=core_schema.str_schema(),
python_schema=plain_validator(cls._validate),
serialization=core_schema.plain_serializer_function_ser_schema(
lambda instance: str(instance), when_used="json"
),
),
)

Expand All @@ -185,7 +180,7 @@ def __get_pydantic_json_schema__(
else:

@classmethod
def validate(cls, v):
def _validate(cls, v):
if isinstance(v, bytes):
v = v.decode("utf-8")
try:
Expand Down Expand Up @@ -378,7 +373,7 @@ def serialize(value: Union[Link, BaseModel]):

@classmethod
def build_validation(cls, handler, source_type):
def validate(v: Union[DBRef, T], validation_info: ValidationInfo):
def validate(v: Union[DBRef, T], *_):
document_class = DocsRegistry.evaluate_fr(
get_args(source_type)[0]
) # type: ignore # noqa: F821
Expand Down Expand Up @@ -477,7 +472,7 @@ def __init__(self, document_class: Type[T]):

@classmethod
def build_validation(cls, handler, source_type):
def validate(v: Union[DBRef, T], field):
def validate(v: Union[DBRef, T], *_):
document_class = DocsRegistry.evaluate_fr(
get_args(source_type)[0]
) # type: ignore # noqa: F821
Expand Down Expand Up @@ -590,7 +585,7 @@ def merge_indexes(
def __get_pydantic_core_schema__(
cls, source_type: Any, handler: GetCoreSchemaHandler
) -> CoreSchema: # type: ignore
def validate(v, _):
def validate(v, *_):
if isinstance(v, IndexModel):
return IndexModelField(v)
else:
Expand Down
2 changes: 2 additions & 0 deletions beanie/odm/utils/encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

import bson
import pydantic
from pydantic import AnyUrl

import beanie
from beanie.odm.fields import Link, LinkTypes
Expand All @@ -45,6 +46,7 @@
decimal.Decimal: bson.Decimal128,
uuid.UUID: bson.Binary.from_uuid,
re.Pattern: bson.Regex.from_native,
AnyUrl: str,
}
if IS_PYDANTIC_V2:
from pydantic_core import Url
Expand Down
12 changes: 8 additions & 4 deletions beanie/odm/utils/pydantic.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,14 @@ def get_model_fields(model):


def parse_model(model_type: Type[BaseModel], data: Any):
if IS_PYDANTIC_V2:
return model_type.model_validate(data)
else:
return model_type.parse_obj(data)
try:
if IS_PYDANTIC_V2:
return model_type.model_validate(data)
else:
return model_type.parse_obj(data)
except Exception:
print(f"Error parsing model {model_type} with data {data}")
model_type.model_validate(data)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If the validation failed for whatever reason, why try and re-call the model_validate() method again, in the Exception case?
If we were using Pydantic v1, this would again re-throw...
Should this line be removed, since it's a no-op?



def get_extra_field_info(field, parameter: str):
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ classifiers = [
dependencies = [
"pydantic>=1.10.18,<3.0",
"motor>=2.5.0,<4.0.0",
"pymongo<4.10.0",
"click>=7",
"toml",
"lazy-model==0.2.0",
Expand Down
13 changes: 13 additions & 0 deletions tests/fastapi/test_openapi_retieval.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from fastapi.openapi.utils import get_openapi

from tests.fastapi.app import app


def test_openapi_schema_generation():
get_openapi(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A test should ideally have some assertation, but I get why this is here...

title=app.title,
version=app.version,
summary=app.summary,
description=app.description,
routes=app.routes,
)
12 changes: 6 additions & 6 deletions tests/odm/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,12 +309,12 @@ async def deprecated_init_beanie(db):
database=db,
document_models=[DocumentWithDeprecatedHiddenField],
)
assert len(w) == 1
assert issubclass(w[-1].category, DeprecationWarning)
assert (
"DocumentWithDeprecatedHiddenField: 'hidden=True' is deprecated, please use 'exclude=True'"
in str(w[-1].message)
)
found = False
for warning in w:
if issubclass(warning.category, DeprecationWarning):
found = True
break
assert found, "Deprecation warning not raised"


@pytest.fixture(autouse=True)
Expand Down
41 changes: 26 additions & 15 deletions tests/odm/test_concurrency.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,18 +18,29 @@ class SampleModel3(SampleModel2): ...

class TestConcurrency:
async def test_without_init(self, settings):
for i in range(10):
cli = motor.motor_asyncio.AsyncIOMotorClient(settings.mongodb_dsn)
cli.get_io_loop = asyncio.get_running_loop
db = cli[settings.mongodb_db_name]
await init_beanie(
db, document_models=[SampleModel3, SampleModel, SampleModel2]
)

async def insert_find():
await SampleModel2().insert()
docs = await SampleModel2.find(SampleModel2.i == 10).to_list()
return docs

await asyncio.gather(*[insert_find() for _ in range(10)])
await SampleModel2.delete_all()
clients = []
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I feel like there's too many unrelated changes now in this PR.
This seems to address this test spitting out some warning when there's multiple AsyncIOMotorClients being created and connecting to the same DB, but without closing the previous one.
Could this be put in another PR, please?

try:
for i in range(10):
cli = motor.motor_asyncio.AsyncIOMotorClient(
settings.mongodb_dsn
)
clients.append(cli)
cli.get_io_loop = asyncio.get_running_loop
db = cli[settings.mongodb_db_name]
await init_beanie(
db,
document_models=[SampleModel3, SampleModel, SampleModel2],
)

async def insert_find():
await SampleModel2().insert()
docs = await SampleModel2.find(
SampleModel2.i == 10
).to_list()
return docs

await asyncio.gather(*[insert_find() for _ in range(10)])
await SampleModel2.delete_all()
finally:
for cli in clients:
cli.close()
1 change: 1 addition & 0 deletions tests/odm/test_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,7 @@ def test_should_encode_pydantic_v2_url_correctly():
assert encoded_url == "https://example.com/"


# this used to fail before now it does not
async def test_should_be_able_to_save_retrieve_doc_with_url():
doc = DocumentWithHttpUrlField(url_field="https://example.com")
assert isinstance(doc.url_field, AnyUrl)
Expand Down
Loading