Skip to content

Commit 5945de6

Browse files
mldwgbckohan
authored andcommitted
polymorphic accessors now use builtin caching from underlying fields
1 parent 11d2b74 commit 5945de6

File tree

2 files changed

+36
-5
lines changed

2 files changed

+36
-5
lines changed

src/polymorphic/models.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -201,11 +201,15 @@ def __init__(self, *args, **kwargs):
201201
return
202202
self.__class__.polymorphic_super_sub_accessors_replaced = True
203203

204-
def create_accessor_function_for_model(model, accessor_name):
204+
def create_accessor_function_for_model(model, field):
205205
def accessor_function(self):
206-
objects = getattr(model, "_base_objects", model.objects)
207-
attr = objects.get(pk=self.pk)
208-
return attr
206+
try:
207+
rel_obj = field.get_cached_value(self)
208+
except KeyError:
209+
objects = getattr(model, "_base_objects", model.objects)
210+
rel_obj = objects.get(pk=self.pk)
211+
field.set_cached_value(self, rel_obj)
212+
return rel_obj
209213

210214
return accessor_function
211215

@@ -218,10 +222,14 @@ def accessor_function(self):
218222
type(orig_accessor),
219223
(ReverseOneToOneDescriptor, ForwardManyToOneDescriptor),
220224
):
225+
226+
field = orig_accessor.related \
227+
if isinstance(orig_accessor, ReverseOneToOneDescriptor) else orig_accessor.field
228+
221229
setattr(
222230
self.__class__,
223231
name,
224-
property(create_accessor_function_for_model(model, name)),
232+
property(create_accessor_function_for_model(model, field)),
225233
)
226234

227235
def _get_inheritance_relation_fields_and_models(self):

src/polymorphic/tests/test_orm.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -973,6 +973,29 @@ def test_parent_link_and_related_name(self):
973973
# test that we can delete the object
974974
t.delete()
975975

976+
def test_polymorphic__accessor_caching(self):
977+
blog_a = BlogA.objects.create(name="blog")
978+
979+
blog_base = BlogBase.objects.non_polymorphic().get(id=blog_a.id)
980+
blog_a = BlogA.objects.get(id=blog_a.id)
981+
982+
# test reverse accessor & check that we get back cached object on repeated access
983+
self.assertEqual(blog_base.bloga, blog_a)
984+
self.assertIs(blog_base.bloga, blog_base.bloga)
985+
cached_blog_a = blog_base.bloga
986+
987+
# test forward accessor & check that we get back cached object on repeated access
988+
self.assertEqual(blog_a.blogbase_ptr, blog_base)
989+
self.assertIs(blog_a.blogbase_ptr, blog_a.blogbase_ptr)
990+
cached_blog_base = blog_a.blogbase_ptr
991+
992+
# check that refresh_from_db correctly clears cached related objects
993+
blog_base.refresh_from_db()
994+
blog_a.refresh_from_db()
995+
996+
self.assertIsNot(cached_blog_a, blog_base.bloga)
997+
self.assertIsNot(cached_blog_base, blog_a.blogbase_ptr)
998+
976999
def test_polymorphic__aggregate(self):
9771000
"""test ModelX___field syntax on aggregate (should work for annotate either)"""
9781001

0 commit comments

Comments
 (0)