Skip to content

Commit

Permalink
fix_no_dereference_thread_safetyness
Browse files Browse the repository at this point in the history
  • Loading branch information
bagerard committed Aug 19, 2024
1 parent 6f7f7b7 commit d22305b
Show file tree
Hide file tree
Showing 9 changed files with 144 additions and 42 deletions.
7 changes: 6 additions & 1 deletion mongoengine/base/document.py
Original file line number Diff line number Diff line change
Expand Up @@ -806,10 +806,15 @@ def _from_son(cls, son, _auto_dereference=True, created=False):

fields = cls._fields
if not _auto_dereference:
# If auto-dereferencing is turned off, work on a deep copy of the fields so
# that changing their setting does not mutate the shared cls._fields instances.
fields = copy.deepcopy(fields)
for field in fields.values():
field.set_auto_dereferencing(_auto_dereference)

# Apply field-name / db-field conversion
for field_name, field in fields.items():
field._auto_dereference = _auto_dereference
field.set_auto_dereferencing(_auto_dereference)
if field.db_field in data:
value = data[field.db_field]
try:
Expand Down
50 changes: 47 additions & 3 deletions mongoengine/base/fields.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import operator
import threading
import weakref

import pymongo
Expand All @@ -16,6 +17,17 @@
__all__ = ("BaseField", "ComplexBaseField", "ObjectIdField", "GeoJsonBaseField")


class no_dereference_for_field:
    """Context manager bracketing a block with a field's no-dereference counter.

    Entering calls ``field._incr_no_dereference_context()`` and leaving calls
    ``field._decr_no_dereference_context()``, so each enter is matched by
    exactly one exit.
    """

    def __init__(self, field):
        """Remember the field whose no-dereference counter will be updated.

        :param field: object exposing ``_incr_no_dereference_context`` and
            ``_decr_no_dereference_context``
        """
        self.field = field

    def __enter__(self):
        """Increment the field's counter and return this manager."""
        self.field._incr_no_dereference_context()
        # Return the manager so ``with ... as ctx`` binds something useful
        # rather than None.
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Decrement the field's counter; exceptions are never suppressed."""
        self.field._decr_no_dereference_context()
        return False


class BaseField:
"""A base class for fields in a MongoDB document. Instances of this class
may be added to subclasses of `Document` to define a document's schema.
Expand All @@ -24,7 +36,8 @@ class BaseField:
name = None # set in TopLevelDocumentMetaclass
_geo_index = False
_auto_gen = False # Call `generate` to generate a value
_auto_dereference = True
_thread_local_storage = threading.local()
# _auto_dereference = True

# These track each time a Field instance is created. Used to retain order.
# The auto_creation_counter is used for fields that MongoEngine implicitly
Expand Down Expand Up @@ -85,6 +98,8 @@ def __init__(
self.sparse = sparse
self._owner_document = None

self.__auto_dereference = True

# Make sure db_field is a string (if it's explicitly defined).
if self.db_field is not None and not isinstance(self.db_field, str):
raise TypeError("db_field should be a string.")
Expand Down Expand Up @@ -120,6 +135,33 @@ def __init__(
self.creation_counter = BaseField.creation_counter
BaseField.creation_counter += 1

def set_auto_dereferencing(self, value):
self.__auto_dereference = value

@property
def no_dereference_context_local(self):
if not hasattr(self._thread_local_storage, "no_dereference_context"):
self._thread_local_storage.no_dereference_context = 0
return self._thread_local_storage.no_dereference_context

@property
def no_dereference_context_is_set(self):
return self.no_dereference_context_local > 0

def _incr_no_dereference_context(self):
self._thread_local_storage.no_dereference_context = (
self.no_dereference_context_local + 1
)

def _decr_no_dereference_context(self):
self._thread_local_storage.no_dereference_context = (
self.no_dereference_context_local - 1
)

@property
def _auto_dereference(self):
return self.__auto_dereference and not self.no_dereference_context_is_set

def __get__(self, instance, owner):
"""Descriptor for retrieving a value from a field in a document."""
if instance is None:
Expand Down Expand Up @@ -375,7 +417,9 @@ def to_python(self, value):
return value

if self.field:
self.field._auto_dereference = self._auto_dereference
self.field.set_auto_dereferencing(
self._auto_dereference
)  # TODO: this changes the nested field's setting permanently; it should probably only apply for the duration of to_python
value_dict = {
key: self.field.to_python(item) for key, item in value.items()
}
Expand Down Expand Up @@ -506,7 +550,7 @@ def lookup_member(self, member_name):
def _set_owner_document(self, owner_document):
if self.field:
self.field.owner_document = owner_document
self._owner_document = owner_document
self._owner_document = owner_document  # TODO: document what owner_document refers to


class ObjectIdField(BaseField):
Expand Down
49 changes: 40 additions & 9 deletions mongoengine/context_managers.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from pymongo.read_concern import ReadConcern
from pymongo.write_concern import WriteConcern

from mongoengine.base.fields import no_dereference_for_field
from mongoengine.common import _import_class
from mongoengine.connection import DEFAULT_CONNECTION_NAME, get_db
from mongoengine.pymongo_support import count_documents
Expand All @@ -22,6 +23,7 @@

class MyThreadLocals(threading.local):
    """Thread-local bookkeeping for active no_dereference contexts."""

    def __init__(self):
        # Maps a document class to the number of no_dereference contexts
        # currently active for it, i.e. {DocCls: count}.
        self.no_dereferencing_class = {}


Expand Down Expand Up @@ -126,14 +128,34 @@ def __exit__(self, t, value, traceback):
self.cls._get_collection_name = self.ori_get_collection_name


# import contextlib
#
# @contextlib.contextmanager
# def no_dereference2(cls):
# """no_dereference context manager.
#
# Turns off all dereferencing in Documents for the duration of the context
# manager::
#
# with no_dereference(Group):
# Group.objects()
# """
# try:
# print "Entering internal_cm"
# yield None
# print "Exiting cleanly from internal_cm"
# finally:
# print "Finally internal_cm"


class no_dereference:
"""no_dereference context manager.
Turns off all dereferencing in Documents for the duration of the context
manager::
with no_dereference(Group):
Group.objects.find()
Group.objects()
"""

def __init__(self, cls):
Expand All @@ -148,24 +170,33 @@ def __init__(self, cls):
ComplexBaseField = _import_class("ComplexBaseField")

self.deref_fields = [
k
for k, v in self.cls._fields.items()
if isinstance(v, (ReferenceField, GenericReferenceField, ComplexBaseField))
field
for name, field in self.cls._fields.items()
if isinstance(
field, (ReferenceField, GenericReferenceField, ComplexBaseField)
)
]
self.no_deref_for_fields_contexts = [
no_dereference_for_field(field) for field in self.deref_fields
]

def __enter__(self):
"""Change the objects default and _auto_dereference values."""
_register_no_dereferencing_for_class(self.cls)

for field in self.deref_fields:
self.cls._fields[field]._auto_dereference = False
for ndff_context in self.no_deref_for_fields_contexts:
ndff_context.__enter__()
# for field in self.deref_fields:
# self.cls._fields[field]._auto_dereference = False

def __exit__(self, t, value, traceback):
"""Reset the default and _auto_dereference values."""
_unregister_no_dereferencing_for_class(self.cls)

for field in self.deref_fields:
self.cls._fields[field]._auto_dereference = True
for ndff_context in self.no_deref_for_fields_contexts:
ndff_context.__exit__(t, value, traceback)
# for field in self.deref_fields:
# self.cls._fields[field]._auto_dereference = True # should set initial values back


class no_sub_classes:
Expand All @@ -180,7 +211,7 @@ class no_sub_classes:
def __init__(self, cls):
"""Construct the no_sub_classes context manager.
:param cls: the class to turn querying sub classes on
:param cls: the class to turn querying subclasses on
"""
self.cls = cls
self.cls_initial_subclasses = None
Expand Down
5 changes: 2 additions & 3 deletions mongoengine/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -1035,10 +1035,9 @@ class DictField(ComplexBaseField):
"""

def __init__(self, field=None, *args, **kwargs):
self._auto_dereference = False

kwargs.setdefault("default", lambda: {})
kwargs.setdefault("default", dict)
super().__init__(*args, field=field, **kwargs)
self.set_auto_dereferencing(False)

def validate(self, value):
"""Make sure that a list of valid fields is being used."""
Expand Down
48 changes: 31 additions & 17 deletions tests/fields/test_dict_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,39 +116,34 @@ class BlogPost(Document):
post.reload()
assert post.info["authors"] == []

def test_dictfield_dump_document(self):
def test_dictfield_dump_document_with_inheritance__cls(self):
"""Ensure a DictField can handle another document's dump."""

class Doc(Document):
field = DictField()

class ToEmbed(Document):
id = IntField(primary_key=True, default=1)
recursive = DictField()

class ToEmbedParent(Document):
id = IntField(primary_key=True, default=1)
id = IntField(primary_key=True)
recursive = DictField()

meta = {"allow_inheritance": True}

class ToEmbedChild(ToEmbedParent):
pass
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

to_embed_recursive = ToEmbed(id=1).save()
to_embed = ToEmbed(
id=2, recursive=to_embed_recursive.to_mongo().to_dict()
).save()
doc = Doc(field=to_embed.to_mongo().to_dict())
doc.save()
assert isinstance(doc.field, dict)
assert doc.field == {"_id": 2, "recursive": {"_id": 1, "recursive": {}}}
# Same thing with a Document with a _cls field
Doc.drop_collection()
ToEmbedParent.drop_collection()

# with a Document with a _cls field
to_embed_recursive = ToEmbedChild(id=1).save()
to_embed_child = ToEmbedChild(
id=2, recursive=to_embed_recursive.to_mongo().to_dict()
).save()
doc = Doc(field=to_embed_child.to_mongo().to_dict())

doc_dump_as_dict = to_embed_child.to_mongo().to_dict()
doc = Doc(field=doc_dump_as_dict)
assert isinstance(doc.field, dict) # depends on auto_dereference
doc.save()
assert isinstance(doc.field, dict)
expected = {
Expand All @@ -162,6 +157,25 @@ class ToEmbedChild(ToEmbedParent):
}
assert doc.field == expected

def test_dictfield_dump_document_no_inheritance(self):
    """A DictField can hold the dump of a document that has no ``_cls`` field."""

    class Doc(Document):
        field = DictField()

    class ToEmbed(Document):
        id = IntField(primary_key=True)
        recursive = DictField()

    # Build a two-level nesting: the dump of document 1 is stored inside
    # document 2, whose own dump is then stored in Doc.field.
    inner = ToEmbed(id=1).save()
    inner_dump = inner.to_mongo().to_dict()
    outer = ToEmbed(id=2, recursive=inner_dump).save()
    outer_dump = outer.to_mongo().to_dict()

    doc = Doc(field=outer_dump)
    doc.save()

    assert isinstance(doc.field, dict)
    assert doc.field == {"_id": 2, "recursive": {"_id": 1, "recursive": {}}}

def test_dictfield_strict(self):
"""Ensure that dict field handles validation if provided a strict field type."""

Expand Down
4 changes: 2 additions & 2 deletions tests/fields/test_fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -1381,9 +1381,9 @@ class Bar(Document):
# When auto_dereference is disabled, there is no trouble returning DBRef
bar = Bar.objects.get()
expected = foo.to_dbref()
bar._fields["ref"]._auto_dereference = False
bar._fields["ref"].set_auto_dereferencing(False)
assert bar.ref == expected
bar._fields["generic_ref"]._auto_dereference = False
bar._fields["generic_ref"].set_auto_dereferencing(False)
assert bar.generic_ref == {"_ref": expected, "_cls": "Foo"}

def test_list_item_dereference(self):
Expand Down
3 changes: 2 additions & 1 deletion tests/queryset/test_queryset.py
Original file line number Diff line number Diff line change
Expand Up @@ -5310,7 +5310,8 @@ class User(Document):

assert isinstance(qs.first().organization, Organization)

assert isinstance(qs.no_dereference().first().organization, DBRef)
user = qs.no_dereference().first()
assert isinstance(user.organization, DBRef)

assert isinstance(qs_user.organization, Organization)
assert isinstance(qs.first().organization, Organization)
Expand Down
2 changes: 1 addition & 1 deletion tests/test_ci.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ def test_ci_placeholder():
# setup the tox venv without running the test suite
# if we simply skip all tests with pytest -k=wrong_pattern
# pytest command would return with exit_code=5 (i.e "no tests run")
# making travis fail
# making the pipeline fail
# this empty test is the recommended way to handle this
# as described in https://github.com/pytest-dev/pytest/issues/2393
pass
18 changes: 13 additions & 5 deletions tests/test_context_managers.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,13 +209,21 @@ def run_in_thread(id):
time.sleep(random.uniform(0.1, 0.5)) # Force desync of threads
if id % 2 == 0:
with no_dereference(Group):
group = Group.objects.first()
assert isinstance(group.ref, DBRef)
for i in range(100):
time.sleep(random.uniform(0.1, 0.5))
assert Group.ref._auto_dereference is False
group = Group.objects.first()
assert isinstance(group.ref, DBRef)
else:
group = Group.objects.first()
assert isinstance(group.ref, User)
for i in range(100):
time.sleep(random.uniform(0.1, 0.5))
assert Group.ref._auto_dereference is True
group = Group.objects.first()
assert isinstance(group.ref, User)

threads = [TestableThread(target=run_in_thread, args=(id,)) for id in range(10)]
threads = [
TestableThread(target=run_in_thread, args=(id,)) for id in range(100)
]
_ = [th.start() for th in threads]
_ = [th.join() for th in threads]

Expand Down

0 comments on commit d22305b

Please sign in to comment.