Skip to content

Commit

Permalink
Disable signals on save feature
Browse files Browse the repository at this point in the history
  • Loading branch information
ebsaral committed May 20, 2019
1 parent b7d953d commit 5d6f8f4
Show file tree
Hide file tree
Showing 5 changed files with 133 additions and 2 deletions.
19 changes: 18 additions & 1 deletion docs/models.rst
Original file line number Diff line number Diff line change
Expand Up @@ -107,5 +107,22 @@ Also you can override the default uuid version. Versions 1,3,4 and 5 are now sup
pass
.. _`UUIDField`: https://github.com/jazzband/django-model-utils/blob/master/docs/fields.rst#uuidfield


SaveSignalHandlingModel
-----------------------

An abstract base class model to pass a parameter ``signals_to_disable``
to ``save`` method in order to disable signals

.. code-block:: python
from model_utils.models import SaveSignalHandlingModel
class SaveSignalTestModel(SaveSignalHandlingModel):
name = models.CharField(max_length=20)
obj = SaveSignalTestModel(name='Test')
# Note: If you use `Model.objects.create`, the signals can't be disabled
obj.save(signals_to_disable=['pre_save'] # disable `pre_save` signal
60 changes: 59 additions & 1 deletion model_utils/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@

import django
from django.core.exceptions import ImproperlyConfigured
from django.db import models
from django.db import models, transaction, router
from django.db.models.signals import post_save, pre_save
from django.utils.translation import ugettext_lazy as _

from model_utils.fields import (
Expand Down Expand Up @@ -159,3 +160,60 @@ class UUIDModel(models.Model):

class Meta:
abstract = True


class SaveSignalHandlingModel(models.Model):
"""
An abstract base class model to pass a parameter ``signals_to_disable``
to ``save`` method in order to disable signals
"""
class Meta:
abstract = True

def save(self, signals_to_disable=None, *args, **kwargs):
"""
Add an extra parameters to hold which signals to disable
If empty, nothing will change
"""

self.signals_to_disable = signals_to_disable or []

super(SaveSignalHandlingModel, self).save(*args, **kwargs)

def save_base(self, raw=False, force_insert=False,
force_update=False, using=None, update_fields=None):
"""
Copied from base class for a minor change.
This is an ugly overwriting but since Django's ``save_base`` method
does not differ between versions 1.8 and 1.10,
that way of implementing wouldn't harm the flow
"""
using = using or router.db_for_write(self.__class__, instance=self)
assert not (force_insert and (force_update or update_fields))
assert update_fields is None or len(update_fields) > 0
cls = origin = self.__class__

if cls._meta.proxy:
cls = cls._meta.concrete_model
meta = cls._meta
if not meta.auto_created and not 'pre_save' in self.signals_to_disable:
pre_save.send(
sender=origin, instance=self, raw=raw, using=using,
update_fields=update_fields,
)
with transaction.atomic(using=using, savepoint=False):
if not raw:
self._save_parents(cls, using, update_fields)
updated = self._save_table(raw, cls, force_insert, force_update, using, update_fields)

self._state.db = using
self._state.adding = False

if not meta.auto_created and not 'post_save' in self.signals_to_disable:
post_save.send(
sender=origin, instance=self, created=(not updated),
update_fields=update_fields, raw=raw, using=using,
)

# Empty the signals in case it might be used somewhere else in future
self.signals_to_disable = []
6 changes: 6 additions & 0 deletions tests/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from django.db import models
from django.db.models.query_utils import DeferredAttribute
from django.db.models import Manager
from django.dispatch import receiver
from django.utils.encoding import python_2_unicode_compatible
from django.utils.translation import ugettext_lazy as _

Expand All @@ -25,6 +26,7 @@
TimeFramedModel,
TimeStampedModel,
UUIDModel,
SaveSignalHandlingModel,
)
from tests.fields import MutableField
from tests.managers import CustomSoftDeleteManager
Expand Down Expand Up @@ -437,3 +439,7 @@ class CustomUUIDModel(UUIDModel):

class CustomNotPrimaryUUIDModel(models.Model):
uuid = UUIDField(primary_key=False)


class SaveSignalHandlingTestModel(SaveSignalHandlingModel):
name = models.CharField(max_length=20)
5 changes: 5 additions & 0 deletions tests/signals.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
def pre_save_test(instance, *args, **kwargs):
instance.pre_save_runned = True

def post_save_test(instance, created, *args, **kwargs):
instance.post_save_runned = True
45 changes: 45 additions & 0 deletions tests/test_models/test_savesignalhandling_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
from __future__ import unicode_literals

from django.test import TestCase

from tests.models import SaveSignalHandlingTestModel
from tests.signals import pre_save_test, post_save_test
from django.db.models.signals import pre_save, post_save


class SaveSignalHandlingModelTests(TestCase):

def test_pre_save(self):
pre_save.connect(pre_save_test, sender=SaveSignalHandlingTestModel)

obj = SaveSignalHandlingTestModel.objects.create(name='Test')
delattr(obj, 'pre_save_runned')
obj.name = 'Test A'
obj.save()
self.assertEqual(obj.name, 'Test A')
self.assertTrue(hasattr(obj, 'pre_save_runned'))

obj = SaveSignalHandlingTestModel.objects.create(name='Test')
delattr(obj, 'pre_save_runned')
obj.name = 'Test B'
obj.save(signals_to_disable=['pre_save'])
self.assertEqual(obj.name, 'Test B')
self.assertFalse(hasattr(obj, 'pre_save_runned'))


def test_post_save(self):
post_save.connect(post_save_test, sender=SaveSignalHandlingTestModel)

obj = SaveSignalHandlingTestModel.objects.create(name='Test')
delattr(obj, 'post_save_runned')
obj.name = 'Test A'
obj.save()
self.assertEqual(obj.name, 'Test A')
self.assertTrue(hasattr(obj, 'post_save_runned'))

obj = SaveSignalHandlingTestModel.objects.create(name='Test')
delattr(obj, 'post_save_runned')
obj.name = 'Test B'
obj.save(signals_to_disable=['post_save'])
self.assertEqual(obj.name, 'Test B')
self.assertFalse(hasattr(obj, 'post_save_runned'))

0 comments on commit 5d6f8f4

Please sign in to comment.