From c6d7c92d9a2f1fd5fd6b04b4b6f79da6f11f555a Mon Sep 17 00:00:00 2001 From: Bart Smits Date: Sun, 8 Dec 2024 20:49:58 +0100 Subject: [PATCH] Resolve primary key name appropriately using _meta Co-authored-by: Bart Smits Co-authored-by: Brian Kohan --- docs/changelog.rst | 1 + src/polymorphic/base.py | 6 +- .../tests/migrations/0001_initial.py | 50 +++++++++- src/polymorphic/tests/models.py | 19 ++++ src/polymorphic/tests/test_orm.py | 96 +++++++++++++++++++ 5 files changed, 167 insertions(+), 5 deletions(-) diff --git a/docs/changelog.rst b/docs/changelog.rst index 5c799245..d19b1441 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -4,6 +4,7 @@ Changelog v4.3.0 (202X-XX-XX) ------------------- +* Fixed `Resolve primary key name correctly. `_ * Implemented `Include get_child_inlines() hook in stacked inline admin forms. `_ * Fixed `multi-database support in inheritance accessors. `_ * Fixed `Caching in inheritance accessor functions `_ diff --git a/src/polymorphic/base.py b/src/polymorphic/base.py index 7affd7c8..2a9783fa 100644 --- a/src/polymorphic/base.py +++ b/src/polymorphic/base.py @@ -79,10 +79,8 @@ def __new__(self, model_name, bases, attrs, **kwargs): # determine the name of the primary key field and store it into the class variable # polymorphic_primary_key_name (it is needed by query.py) - for f in new_class._meta.fields: - if f.primary_key and type(f) is not models.OneToOneField: - new_class.polymorphic_primary_key_name = f.name - break + if new_class._meta.pk: + new_class.polymorphic_primary_key_name = new_class._meta.pk.name return new_class diff --git a/src/polymorphic/tests/migrations/0001_initial.py b/src/polymorphic/tests/migrations/0001_initial.py index 3bd2f78b..d75bfc18 100644 --- a/src/polymorphic/tests/migrations/0001_initial.py +++ b/src/polymorphic/tests/migrations/0001_initial.py @@ -1,5 +1,6 @@ -# Generated by Django 4.2.26 on 2025-12-05 02:34 +# Generated by Django 4.2.27 on 2025-12-08 15:20 +from django.conf import settings from django.db import migrations, models import django.db.models.deletion import django.db.models.manager @@ -17,6 +18,17 @@ class Migration(migrations.Migration): ] operations = [ + migrations.CreateModel( + name='Account', + fields=[ + ('user', models.OneToOneField(on_delete=django.db.models.deletion.CASCADE, primary_key=True, related_name='account', serialize=False, to=settings.AUTH_USER_MODEL)), + ('polymorphic_ctype', models.ForeignKey(editable=False, null=True, on_delete=django.db.models.deletion.CASCADE, related_name='polymorphic_%(app_label)s.%(class)s_set+', to='contenttypes.contenttype')), + ], + options={ + 'abstract': False, + 'base_manager_name': 'objects', + }, + ), migrations.CreateModel( name='Base', fields=[ @@ -599,6 +611,30 @@ class Migration(migrations.Migration): }, bases=('tests.relationbase',), ), + migrations.CreateModel( + name='SpecialAccount1', + fields=[ + ('account_ptr', models.OneToOneField(auto_created=True, on_delete=django.db.models.deletion.CASCADE, parent_link=True, primary_key=True, serialize=False, to='tests.account')), + ('extra1', models.IntegerField(blank=True, default=None, null=True)), + ], + options={ + 'abstract': False, + 'base_manager_name': 'objects', + }, + bases=('tests.account',), + ), + migrations.CreateModel( + name='SpecialAccount2', + fields=[ + ('account_ptr', models.OneToOneField(auto_created=True, on_delete=django.db.models.deletion.CASCADE, parent_link=True, primary_key=True, serialize=False, to='tests.account')), + ('extra1', models.CharField(blank=True, default='', max_length=30)), + ], + options={ + 'abstract': False, + 'base_manager_name': 'objects', + }, + bases=('tests.account',), + ), migrations.CreateModel( name='SubclassSelectorAbstractConcreteModel', fields=[ @@ -1060,6 +1096,18 @@ class Migration(migrations.Migration): }, bases=('tests.relationb',), ), + migrations.CreateModel( + name='SpecialAccount1_1', + fields=[ + ('specialaccount1_ptr', models.OneToOneField(auto_created=True, on_delete=django.db.models.deletion.CASCADE, parent_link=True, primary_key=True, serialize=False, to='tests.specialaccount1')), + ('extra2', models.IntegerField(blank=True, default=None, null=True)), + ], + options={ + 'abstract': False, + 'base_manager_name': 'objects', + }, + bases=('tests.specialaccount1',), + ), migrations.CreateModel( name='SubclassSelectorProxyConcreteModel', fields=[ diff --git a/src/polymorphic/tests/models.py b/src/polymorphic/tests/models.py index 54710ab5..7e9bd8c3 100644 --- a/src/polymorphic/tests/models.py +++ b/src/polymorphic/tests/models.py @@ -2,6 +2,7 @@ import django from django.contrib.auth.models import Group +from django.contrib.auth import get_user_model from django.contrib.contenttypes.models import ContentType from django.db import models from django.db.models.query import QuerySet @@ -565,3 +566,21 @@ class Meta: class PurpleHeadDuck(HomeDuck, BlueHeadDuck): class Meta: proxy = True + + +class Account(PolymorphicModel): + user = models.OneToOneField( + get_user_model(), primary_key=True, on_delete=models.CASCADE, related_name="account" + ) + + +class SpecialAccount1(Account): + extra1 = models.IntegerField(null=True, default=None, blank=True) + + +class SpecialAccount1_1(SpecialAccount1): + extra2 = models.IntegerField(null=True, default=None, blank=True) + + +class SpecialAccount2(Account): + extra1 = models.CharField(default="", blank=True, max_length=30) diff --git a/src/polymorphic/tests/test_orm.py b/src/polymorphic/tests/test_orm.py index 4028684c..be260512 100644 --- a/src/polymorphic/tests/test_orm.py +++ b/src/polymorphic/tests/test_orm.py @@ -2,6 +2,7 @@ import re import uuid +from django.contrib.auth import get_user_model from django.contrib.contenttypes.models import ContentType from django.db import models, connection from django.db.models import Case, Count, FilteredRelation, Q, Sum, When, Exists, OuterRef @@ -94,6 +95,10 @@ UUIDResearchProject, Duck, PurpleHeadDuck, + Account, + SpecialAccount1, + SpecialAccount1_1, + SpecialAccount2, ) @@ -1521,3 +1526,94 @@ def test_subqueries(self): InlineParent.objects.all().delete() InlineModelA.objects.all().delete() InlineModelB.objects.all().delete() + + def test_one_to_one_primary_key(self): + # check pk name resolution + for mdl in [Account, SpecialAccount1, SpecialAccount1_1, SpecialAccount2]: + assert mdl.polymorphic_primary_key_name == mdl._meta.pk.name + + user1 = get_user_model().objects.create( + username="user1", email="user1@example.com", password="password" + ) + user2 = get_user_model().objects.create( + username="user2", email="user2@example.com", password="password" + ) + user3 = get_user_model().objects.create( + username="user3", email="user3@example.com", password="password" + ) + user4 = get_user_model().objects.create( + username="user4", email="user4@example.com", password="password" + ) + + user1_profile = SpecialAccount1_1.objects.create(user=user1, extra1=5, extra2=6) + + user2_profile = SpecialAccount1.objects.create(user=user2, extra1=5) + + user3_profile = SpecialAccount2.objects.create(user=user3, extra1="test") + + user4_profile = SpecialAccount1_1.objects.create(user=user4, extra1=7, extra2=8) + + user1.refresh_from_db() + assert user1.account.__class__ is SpecialAccount1_1 + assert user1.account.extra1 == 5 + assert user1.account.extra2 == 6 + assert user1_profile.pk == user1.account.pk + + user2.refresh_from_db() + assert user2.account.__class__ is SpecialAccount1 + assert user2.account.extra1 == 5 + assert user2_profile.pk == user2.account.pk + assert not hasattr(user2.account, "extra2") + + user3.refresh_from_db() + assert user3.account.__class__ is SpecialAccount2 + assert user3.account.extra1 == "test" + assert user3_profile.pk == user3.account.pk + assert not hasattr(user3.account, "extra2") + + user4.refresh_from_db() + assert user4.account.__class__ is SpecialAccount1_1 + assert user4.account.extra1 == 7 + assert user4.account.extra2 == 8 + assert user4_profile.pk == user4.account.pk + + assert get_user_model().objects.filter(pk=user2.pk).delete() == ( + 3, + {"tests.SpecialAccount1": 1, "tests.Account": 1, "auth.User": 1}, + ) + + assert SpecialAccount1.objects.count() == 2 + assert Account.objects.count() == 3 + + remaining = get_user_model().objects.filter( + pk__in=[user1.pk, user2.pk, user3.pk, user4.pk] + ) + assert remaining.count() == 3 + for usr, expected in zip( + remaining.order_by("pk"), (user1_profile, user3_profile, user4_profile) + ): + assert usr.account == expected + + assert get_user_model().objects.filter(pk__in=[user3.pk]).delete() == ( + 3, + {"tests.SpecialAccount2": 1, "tests.Account": 1, "auth.User": 1}, + ) + + assert Account.objects.count() == 2 + + assert SpecialAccount1_1.objects.all().delete() == ( + 6, + {"tests.SpecialAccount1_1": 2, "tests.SpecialAccount1": 2, "tests.Account": 2}, + ) + + assert Account.objects.count() == 0 + + remaining = get_user_model().objects.filter(pk__gte=user1.pk) + assert remaining.count() == 2 + for usr in remaining: + assert not hasattr(usr, "account") + + assert get_user_model().objects.filter(pk__in=[user1.pk, user4.pk]).delete() == ( + 2, + {"auth.User": 2}, + )