Skip to content

Commit

Permalink
Annotate test_field_tracker module
Browse files Browse the repository at this point in the history
  • Loading branch information
mthuurne committed Apr 10, 2024
1 parent e02ba82 commit 869c6a6
Show file tree
Hide file tree
Showing 3 changed files with 65 additions and 35 deletions.
92 changes: 61 additions & 31 deletions tests/test_fields/test_field_tracker.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Any
from unittest import skip

from django.core.cache import cache
Expand All @@ -9,7 +10,7 @@
from django.test import TestCase

from model_utils import FieldTracker
from model_utils.tracker import DescriptorWrapper
from model_utils.tracker import DescriptorWrapper, FieldInstanceTracker
from tests.models import (
InheritedModelTracked,
InheritedTracked,
Expand All @@ -28,12 +29,18 @@
TrackerTimeStamped,
)

if TYPE_CHECKING:
MixinBase = TestCase
else:
MixinBase = object

class FieldTrackerTestCase(TestCase):

tracker = None
class FieldTrackerMixin(MixinBase):

def assertHasChanged(self, *, tracker=None, **kwargs):
tracker: FieldInstanceTracker
instance: models.Model

def assertHasChanged(self, *, tracker: FieldInstanceTracker | None = None, **kwargs: Any) -> None:
if tracker is None:
tracker = self.tracker
for field, value in kwargs.items():
Expand All @@ -43,29 +50,35 @@ def assertHasChanged(self, *, tracker=None, **kwargs):
else:
self.assertEqual(tracker.has_changed(field), value)

def assertPrevious(self, *, tracker=None, **kwargs):
def assertPrevious(self, *, tracker: FieldInstanceTracker | None = None, **kwargs: Any) -> None:
if tracker is None:
tracker = self.tracker
for field, value in kwargs.items():
self.assertEqual(tracker.previous(field), value)

def assertChanged(self, *, tracker=None, **kwargs):
def assertChanged(self, *, tracker: FieldInstanceTracker | None = None, **kwargs: Any) -> None:
if tracker is None:
tracker = self.tracker
self.assertEqual(tracker.changed(), kwargs)

def assertCurrent(self, *, tracker=None, **kwargs):
def assertCurrent(self, *, tracker: FieldInstanceTracker | None = None, **kwargs: Any) -> None:
if tracker is None:
tracker = self.tracker
self.assertEqual(tracker.current(), kwargs)

def update_instance(self, **kwargs):
def update_instance(self, **kwargs: Any) -> None:
for field, value in kwargs.items():
setattr(self.instance, field, value)
self.instance.save()


class FieldTrackerCommonTests:
class FieldTrackerCommonMixin(FieldTrackerMixin):

instance: (
Tracked | TrackedNotDefault | TrackedMultiple
| ModelTracked | ModelTrackedNotDefault | ModelTrackedMultiple
| TrackedAbstract
)

def test_pre_save_previous(self) -> None:
self.assertPrevious(name=None, number=None)
Expand All @@ -74,9 +87,10 @@ def test_pre_save_previous(self) -> None:
self.assertPrevious(name=None, number=None)


class FieldTrackerTests(FieldTrackerTestCase, FieldTrackerCommonTests):
class FieldTrackerTests(FieldTrackerCommonMixin, TestCase):

tracked_class: type[models.Model] = Tracked
tracked_class: type[Tracked | ModelTracked | TrackedAbstract] = Tracked
instance: Tracked | ModelTracked | TrackedAbstract

def setUp(self) -> None:
self.instance = self.tracked_class()
Expand Down Expand Up @@ -221,6 +235,7 @@ def test_with_deferred(self) -> None:
self.instance.number = 1
self.instance.save()
item = self.tracked_class.objects.only('name').first()
assert item is not None
self.assertTrue(item.get_deferred_fields())

# has_changed() returns False for deferred fields, without un-deferring them.
Expand All @@ -236,6 +251,7 @@ def test_with_deferred(self) -> None:

# examining a deferred field un-defers it
item = self.tracked_class.objects.only('name').first()
assert item is not None
self.assertEqual(item.number, 1)
self.assertTrue('number' not in item.get_deferred_fields())
self.assertEqual(item.tracker.previous('number'), 1)
Expand All @@ -254,6 +270,7 @@ def test_with_deferred(self) -> None:
if self.tracked_class == Tracked:

item = self.tracked_class.objects.only('name').first()
assert item is not None
item.number = 2

# previous() fetches correct value from database after deferred field is assigned
Expand All @@ -280,10 +297,10 @@ def test_with_deferred_fields_access_multiple(self) -> None:
instance.name


class FieldTrackedModelCustomTests(FieldTrackerTestCase,
FieldTrackerCommonTests):
class FieldTrackedModelCustomTests(FieldTrackerCommonMixin, TestCase):

tracked_class: type[models.Model] = TrackedNotDefault
tracked_class: type[TrackedNotDefault | ModelTrackedNotDefault] = TrackedNotDefault
instance: TrackedNotDefault | ModelTrackedNotDefault

def setUp(self) -> None:
self.instance = self.tracked_class()
Expand Down Expand Up @@ -360,9 +377,10 @@ def test_update_fields(self) -> None:
self.assertChanged()


class FieldTrackedModelAttributeTests(FieldTrackerTestCase):
class FieldTrackedModelAttributeTests(FieldTrackerMixin, TestCase):

tracked_class = TrackedNonFieldAttr
instance: TrackedNonFieldAttr

def setUp(self) -> None:
self.instance = self.tracked_class()
Expand Down Expand Up @@ -411,10 +429,10 @@ def test_current(self) -> None:
self.assertCurrent(rounded=8)


class FieldTrackedModelMultiTests(FieldTrackerTestCase,
FieldTrackerCommonTests):
class FieldTrackedModelMultiTests(FieldTrackerCommonMixin, TestCase):

tracked_class: type[models.Model] = TrackedMultiple
tracked_class: type[TrackedMultiple | ModelTrackedMultiple] = TrackedMultiple
instance: TrackedMultiple | ModelTrackedMultiple

def setUp(self) -> None:
self.instance = self.tracked_class()
Expand Down Expand Up @@ -503,10 +521,11 @@ def test_current(self) -> None:
self.assertCurrent(tracker=self.trackers[1], number=8)


class FieldTrackerForeignKeyTests(FieldTrackerTestCase):
class FieldTrackerForeignKeyMixin(FieldTrackerMixin):

fk_class: type[models.Model] = Tracked
tracked_class: type[models.Model] = TrackedFK
fk_class: type[Tracked | ModelTracked]
tracked_class: type[TrackedFK | ModelTrackedFK]
instance: TrackedFK | ModelTrackedFK

def setUp(self) -> None:
self.old_fk = self.fk_class.objects.create(number=8)
Expand Down Expand Up @@ -545,11 +564,18 @@ def test_custom_without_id(self) -> None:
self.assertCurrent(fk=self.instance.fk_id)


class FieldTrackerForeignKeyPrefetchRelatedTests(FieldTrackerTestCase):
class FieldTrackerForeignKeyTests(FieldTrackerForeignKeyMixin, TestCase):

fk_class = Tracked
tracked_class = TrackedFK


class FieldTrackerForeignKeyPrefetchRelatedTests(FieldTrackerMixin, TestCase):
"""Test that using `prefetch_related` on a tracked field does not raise a ValueError."""

fk_class = Tracked
tracked_class = TrackedFK
instance: TrackedFK

def setUp(self) -> None:
model_tracked = self.fk_class.objects.create(name="", number=0)
Expand All @@ -568,10 +594,11 @@ def test_custom_without_id(self) -> None:
self.assertIsNotNone(list(self.tracked_class.objects.prefetch_related("fk")))


class FieldTrackerTimeStampedTests(FieldTrackerTestCase):
class FieldTrackerTimeStampedTests(FieldTrackerMixin, TestCase):

fk_class = Tracked
tracked_class = TrackerTimeStamped
instance: TrackerTimeStamped

def setUp(self) -> None:
self.instance = self.tracked_class.objects.create(name='old', number=1)
Expand Down Expand Up @@ -607,9 +634,10 @@ class FieldTrackerInheritedForeignKeyTests(FieldTrackerForeignKeyTests):
tracked_class = InheritedTrackedFK


class FieldTrackerFileFieldTests(FieldTrackerTestCase):
class FieldTrackerFileFieldTests(FieldTrackerMixin, TestCase):

tracked_class = TrackedFileField
instance: TrackedFileField

def setUp(self) -> None:
self.instance = self.tracked_class()
Expand All @@ -631,7 +659,7 @@ def test_saved_data_without_instance(self) -> None:
self.assertEqual(self.tracker.saved_data, {})
self.update_instance(some_file=self.some_file)
field_file_copy = self.tracker.saved_data.get('some_file')
self.assertIsNotNone(field_file_copy)
assert field_file_copy is not None
self.assertEqual(field_file_copy.__getstate__().get('instance'), None)
self.assertEqual(self.instance.some_file.instance, self.instance)
self.assertIsInstance(self.instance.some_file, FieldFile)
Expand Down Expand Up @@ -732,7 +760,8 @@ def test_current(self) -> None:

class ModelTrackerTests(FieldTrackerTests):

tracked_class: type[models.Model] = ModelTracked
tracked_class: type[ModelTracked | TrackedAbstract] = ModelTracked
instance: ModelTracked

def test_cache_compatible(self) -> None:
cache.set('key', self.instance)
Expand Down Expand Up @@ -848,10 +877,11 @@ def test_pre_save_changed(self) -> None:
self.assertChanged()


class ModelTrackerForeignKeyTests(FieldTrackerForeignKeyTests):
class ModelTrackerForeignKeyTests(FieldTrackerForeignKeyMixin, TestCase):

fk_class = ModelTracked
tracked_class = ModelTrackedFK
instance: ModelTrackedFK

def test_custom_without_id(self) -> None:
with self.assertNumQueries(2):
Expand Down Expand Up @@ -889,11 +919,11 @@ def setUp(self) -> None:
self.instance = Tracked.objects.create(number=1)
self.tracker = self.instance.tracker

def assertChanged(self, *fields):
def assertChanged(self, *fields: str) -> None:
for f in fields:
self.assertTrue(self.tracker.has_changed(f))

def assertNotChanged(self, *fields):
def assertNotChanged(self, *fields: str) -> None:
for f in fields:
self.assertFalse(self.tracker.has_changed(f))

Expand Down Expand Up @@ -924,7 +954,7 @@ def test_context_manager_fields(self) -> None:
def test_tracker_decorator(self) -> None:

@Tracked.tracker
def tracked_method(obj):
def tracked_method(obj: Tracked) -> None:
obj.name = 'new'
self.assertChanged('name')

Expand All @@ -935,7 +965,7 @@ def tracked_method(obj):
def test_tracker_decorator_fields(self) -> None:

@Tracked.tracker(fields=['name'])
def tracked_method(obj):
def tracked_method(obj: Tracked) -> None:
obj.name = 'new'
obj.number += 1
self.assertChanged('name', 'number')
Expand Down
2 changes: 1 addition & 1 deletion tests/test_fields/test_monitor_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def test_double_save(self) -> None:

def test_no_monitor_arg(self) -> None:
with self.assertRaises(TypeError):
MonitorField()
MonitorField() # type: ignore[call-arg]

def test_monitor_default_is_none_when_nullable(self) -> None:
self.assertIsNone(self.instance.name_changed_nullable)
Expand Down
6 changes: 3 additions & 3 deletions tests/test_fields/test_urlsafe_token_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def test_factory_default(self) -> None:

def test_factory_not_callable(self) -> None:
with self.assertRaises(TypeError):
UrlsafeTokenField(factory='INVALID')
UrlsafeTokenField(factory='INVALID') # type: ignore[arg-type]

def test_get_default(self) -> None:
field = UrlsafeTokenField()
Expand All @@ -57,8 +57,8 @@ def test_no_default_param(self) -> None:
self.assertIs(field.default, NOT_PROVIDED)

def test_deconstruct(self) -> None:
def test_factory() -> None:
pass
def test_factory(max_length: int) -> str:
assert False
instance = UrlsafeTokenField(factory=test_factory)
name, path, args, kwargs = instance.deconstruct()
new_instance = UrlsafeTokenField(*args, **kwargs)
Expand Down

0 comments on commit 869c6a6

Please sign in to comment.