diff --git a/onadata/apps/api/tests/viewsets/test_entity_list_viewset.py b/onadata/apps/api/tests/viewsets/test_entity_list_viewset.py index f56edaf5b8..5f577e0123 100644 --- a/onadata/apps/api/tests/viewsets/test_entity_list_viewset.py +++ b/onadata/apps/api/tests/viewsets/test_entity_list_viewset.py @@ -8,14 +8,15 @@ from datetime import datetime, timezone as dtz from unittest.mock import patch +from django.core.cache import cache from django.test import override_settings from django.utils import timezone from onadata.apps.api.viewsets.entity_list_viewset import EntityListViewSet from onadata.apps.api.tests.viewsets.test_abstract_viewset import TestAbstractViewSet -from onadata.libs.pagination import StandardPageNumberPagination from onadata.apps.logger.models import Entity, EntityHistory, EntityList, Project from onadata.libs.models.share_project import ShareProject +from onadata.libs.pagination import StandardPageNumberPagination from onadata.libs.permissions import ROLES, OwnerRole from onadata.libs.utils.user_auth import get_user_default_project @@ -243,16 +244,17 @@ def _create_entity_list(self, name, project=None): @override_settings(TIME_ZONE="UTC") def test_get_all(self): """Getting all EntityLists works""" - Entity.objects.create( - entity_list=self.trees_entity_list, - json={ - "species": "purpleheart", - "geometry": "-1.286905 36.772845 0 0", - "circumference_cm": 300, - "label": "300cm purpleheart", - }, - uuid="dbee4c32-a922-451c-9df7-42f40bf78f48", - ) + with self.captureOnCommitCallbacks(execute=True): + Entity.objects.create( + entity_list=self.trees_entity_list, + json={ + "species": "purpleheart", + "geometry": "-1.286905 36.772845 0 0", + "circumference_cm": 300, + "label": "300cm purpleheart", + }, + uuid="dbee4c32-a922-451c-9df7-42f40bf78f48", + ) qs = EntityList.objects.all().order_by("pk") first = qs[0] second = qs[1] @@ -384,6 +386,30 @@ def test_soft_deleted_excluded(self): self.assertEqual(response.status_code, 200) 
self.assertEqual(len(response.data), 0) + def test_num_entities_cached(self): + """`num_entities` includes cached counter""" + entity_list = EntityList.objects.get(name="trees") + entity_list.num_entities = 5 + entity_list.save() + cache.set(f"elist-num-entities-{entity_list.pk}", 7) + + request = self.factory.get("/", **self.extra) + response = self.view(request) + + self.assertEqual(response.status_code, 200) + self.assertEqual(response.data[0]["num_entities"], 12) + + # Defaults to database counter if cache inaccessible + with patch.object(cache, "get") as mock_cache_get: + with patch("onadata.libs.utils.cache_tools.logger.exception") as mock_exc: + mock_cache_get.side_effect = ConnectionError + request = self.factory.get("/", **self.extra) + response = self.view(request) + + self.assertEqual(response.status_code, 200) + self.assertEqual(response.data[0]["num_entities"], 5) + mock_exc.assert_called() + @override_settings(TIME_ZONE="UTC") class GetSingleEntityListTestCase(TestAbstractViewSet): @@ -401,16 +427,18 @@ def setUp(self): # Create Entity for trees EntityList trees_entity_list = EntityList.objects.get(name="trees") OwnerRole.add(self.user, trees_entity_list) - Entity.objects.create( - entity_list=trees_entity_list, - json={ - "species": "purpleheart", - "geometry": "-1.286905 36.772845 0 0", - "circumference_cm": 300, - "label": "300cm purpleheart", - }, - uuid="dbee4c32-a922-451c-9df7-42f40bf78f48", - ) + + with self.captureOnCommitCallbacks(execute=True): + Entity.objects.create( + entity_list=trees_entity_list, + json={ + "species": "purpleheart", + "geometry": "-1.286905 36.772845 0 0", + "circumference_cm": 300, + "label": "300cm purpleheart", + }, + uuid="dbee4c32-a922-451c-9df7-42f40bf78f48", + ) def test_get_entity_list(self): """Returns a single EntityList""" @@ -1452,27 +1480,32 @@ def setUp(self): super().setUp() self.view = EntityListViewSet.as_view({"delete": "entities"}) - self._create_entity() + + with 
self.captureOnCommitCallbacks(execute=True): + self._create_entity() + OwnerRole.add(self.user, self.entity_list) @patch("django.utils.timezone.now") def test_delete(self, mock_now): """Delete Entity works""" self.entity_list.refresh_from_db() - self.assertEqual(self.entity_list.num_entities, 1) + self.assertEqual(cache.get(f"elist-num-entities-{self.entity_list.pk}"), 1) date = datetime(2024, 6, 11, 14, 9, 0, tzinfo=timezone.utc) mock_now.return_value = date - request = self.factory.delete( - "/", data={"entity_ids": [self.entity.pk]}, **self.extra - ) - response = self.view(request, pk=self.entity_list.pk) - self.entity.refresh_from_db() - self.entity_list.refresh_from_db() + + with self.captureOnCommitCallbacks(execute=True): + request = self.factory.delete( + "/", data={"entity_ids": [self.entity.pk]}, **self.extra + ) + response = self.view(request, pk=self.entity_list.pk) + self.entity.refresh_from_db() + self.entity_list.refresh_from_db() self.assertEqual(response.status_code, 204) self.assertEqual(self.entity.deleted_at, date) self.assertEqual(self.entity.deleted_by, self.user) - self.assertEqual(self.entity_list.num_entities, 0) + self.assertEqual(cache.get(f"elist-num-entities-{self.entity_list.pk}"), 0) self.assertEqual( self.entity_list.last_entity_update_time, self.entity.date_modified ) diff --git a/onadata/apps/api/tests/viewsets/test_organization_profile_viewset.py b/onadata/apps/api/tests/viewsets/test_organization_profile_viewset.py index 358a2cc7a1..07fe2d4dd2 100644 --- a/onadata/apps/api/tests/viewsets/test_organization_profile_viewset.py +++ b/onadata/apps/api/tests/viewsets/test_organization_profile_viewset.py @@ -955,7 +955,7 @@ def test_member_added_to_org_with_correct_perms(self): ) response = view(request, user="denoinc") self.assertEqual(response.status_code, 201) - self.assertEqual(response.data, ["denoinc", "aboy"]) + self.assertCountEqual(response.data, ["denoinc", "aboy"]) project_view = ProjectViewSet.as_view({"get": "retrieve"}) 
request = self.factory.get( diff --git a/onadata/apps/logger/models/entity.py b/onadata/apps/logger/models/entity.py index cbf3ebe492..14bfb039cd 100644 --- a/onadata/apps/logger/models/entity.py +++ b/onadata/apps/logger/models/entity.py @@ -3,6 +3,7 @@ """ import uuid +import importlib from django.contrib.auth import get_user_model from django.db import models, transaction @@ -40,8 +41,13 @@ def soft_delete(self, deleted_by=None): self.deleted_at = deletion_time self.deleted_by = deleted_by self.save(update_fields=["deleted_at", "deleted_by"]) - self.entity_list.num_entities = models.F("num_entities") - 1 - self.entity_list.save() + # Avoid cyclic dependency errors + logger_tasks = importlib.import_module("onadata.apps.logger.tasks") + transaction.on_commit( + lambda: logger_tasks.dec_elist_num_entities_async.delay( + self.entity_list.pk + ) + ) class Meta(BaseModel.Meta): app_label = "logger" diff --git a/onadata/apps/logger/signals.py b/onadata/apps/logger/signals.py index 28ac7806f4..f292eca552 100644 --- a/onadata/apps/logger/signals.py +++ b/onadata/apps/logger/signals.py @@ -4,14 +4,17 @@ """ from django.contrib.contenttypes.models import ContentType from django.db import transaction -from django.db.models import F from django.db.models.signals import post_save, post_delete from django.dispatch import receiver from django.utils import timezone from onadata.apps.logger.models import Entity, EntityList, Instance, SubmissionReview from onadata.apps.logger.models.xform import clear_project_cache -from onadata.apps.logger.tasks import set_entity_list_perms_async +from onadata.apps.logger.tasks import ( + dec_elist_num_entities_async, + inc_elist_num_entities_async, + set_entity_list_perms_async, +) from onadata.apps.main.models.meta_data import MetaData from onadata.libs.utils.logger_tools import create_or_update_entity_from_instance @@ -49,21 +52,16 @@ def increment_entity_list_num_entities(sender, instance, created=False, **kwargs entity_list = 
instance.entity_list if created: - # Using Queryset.update ensures we do not call the model's save method and - # signals - EntityList.objects.filter(pk=entity_list.pk).update( - num_entities=F("num_entities") + 1 + transaction.on_commit( + lambda: inc_elist_num_entities_async.delay(entity_list.pk) ) @receiver(post_delete, sender=Entity, dispatch_uid="update_enti_el_dec_num_entities") def decrement_entity_list_num_entities(sender, instance, **kwargs): """Decrement EntityList `num_entities`""" - entity_list = instance.entity_list - # Using Queryset.update ensures we do not call the model's save method and - # signals - EntityList.objects.filter(pk=entity_list.pk).update( - num_entities=F("num_entities") - 1 + transaction.on_commit( + lambda: dec_elist_num_entities_async.delay(instance.entity_list.pk) ) diff --git a/onadata/apps/logger/tasks.py b/onadata/apps/logger/tasks.py index 9ca9139fb6..16f1ecde60 100644 --- a/onadata/apps/logger/tasks.py +++ b/onadata/apps/logger/tasks.py @@ -1,16 +1,27 @@ # pylint: disable=import-error,ungrouped-imports -"""Module for logger tasks""" +""" +Asynchronous tasks for the logger app +""" import logging from django.core.cache import cache from django.contrib.auth import get_user_model from django.db import DatabaseError +from multidb.pinning import use_master from onadata.apps.logger.models import Entity, EntityList, Project from onadata.celeryapp import app -from onadata.libs.utils.cache_tools import PROJECT_DATE_MODIFIED_CACHE, safe_delete +from onadata.libs.utils.cache_tools import ( + PROJECT_DATE_MODIFIED_CACHE, + safe_delete, +) from onadata.libs.utils.project_utils import set_project_perms_to_object -from onadata.libs.utils.logger_tools import soft_delete_entities_bulk +from onadata.libs.utils.logger_tools import ( + commit_cached_elist_num_entities, + dec_elist_num_entities, + inc_elist_num_entities, + soft_delete_entities_bulk, +) logger = logging.getLogger(__name__) @@ -24,14 +35,15 @@ def 
set_entity_list_perms_async(entity_list_id): Args: pk (int): Primary key for EntityList """ - try: - entity_list = EntityList.objects.get(pk=entity_list_id) + with use_master: + try: + entity_list = EntityList.objects.get(pk=entity_list_id) - except EntityList.DoesNotExist as err: - logger.exception(err) - return + except EntityList.DoesNotExist as err: + logger.exception(err) + return - set_project_perms_to_object(entity_list, entity_list.project) + set_project_perms_to_object(entity_list, entity_list.project) @app.task(retry_backoff=3, autoretry_for=(DatabaseError, ConnectionError)) @@ -59,15 +71,50 @@ def delete_entities_bulk_async(entity_pks: list[int], username: str | None = Non entity_pks (list(int)): Primary keys of Entities to be deleted username (str): Username of the user initiating the delete """ - entity_qs = Entity.objects.filter(pk__in=entity_pks, deleted_at__isnull=True) - deleted_by = None + with use_master: + entity_qs = Entity.objects.filter(pk__in=entity_pks, deleted_at__isnull=True) + deleted_by = None + + try: + if username is not None: + deleted_by = User.objects.get(username=username) + + except User.DoesNotExist as exc: + logger.exception(exc) + + else: + soft_delete_entities_bulk(entity_qs, deleted_by) + - try: - if username is not None: - deleted_by = User.objects.get(username=username) +@app.task(retry_backoff=3, autoretry_for=(DatabaseError, ConnectionError)) +def commit_cached_elist_num_entities_async(): + """Commit cached EntityList `num_entities` counter to the database + + Call this task periodically, such as in a background task to ensure + cached counters for EntityList `num_entities` are commited to the + database. - except User.DoesNotExist as exc: - logger.exception(exc) + Cached counters have no expiry, so it is essential to ensure that + this task is called periodically. 
+ """ + commit_cached_elist_num_entities() - else: - soft_delete_entities_bulk(entity_qs, deleted_by) + +@app.task(retry_backoff=3, autoretry_for=(DatabaseError, ConnectionError)) +def inc_elist_num_entities_async(elist_pk: int): + """Increment EntityList `num_entities` counter asynchronously + + Args: + elist_pk (int): Primary key for EntityList + """ + inc_elist_num_entities(elist_pk) + + +@app.task(retry_backoff=3, autoretry_for=(DatabaseError, ConnectionError)) +def dec_elist_num_entities_async(elist_pk: int) -> None: + """Decrement EntityList `num_entities` counter asynchronously + + Args: + elist_pk (int): Primary key for EntityList + """ + dec_elist_num_entities(elist_pk) diff --git a/onadata/apps/logger/tests/models/test_entity.py b/onadata/apps/logger/tests/models/test_entity.py index 4a2c568829..b375bf3f36 100644 --- a/onadata/apps/logger/tests/models/test_entity.py +++ b/onadata/apps/logger/tests/models/test_entity.py @@ -24,6 +24,7 @@ class EntityTestCase(TestBase): def setUp(self): super().setUp() + self.mocked_now = datetime(2023, 11, 8, 13, 17, 0, tzinfo=pytz.utc) self.project = get_user_default_project(self.user) self.entity_list = EntityList.objects.create(name="trees", project=self.project) @@ -57,14 +58,14 @@ def test_optional_fields(self): self.assertEqual(entity.json, {}) self.assertIsInstance(entity.uuid, uuid.UUID) + @patch("onadata.apps.logger.tasks.dec_elist_num_entities_async.delay") @patch("django.utils.timezone.now") - def test_soft_delete(self, mock_now): + def test_soft_delete(self, mock_now, mock_dec): """Soft delete works""" mock_now.return_value = self.mocked_now entity = Entity.objects.create(entity_list=self.entity_list) self.entity_list.refresh_from_db() - self.assertEqual(self.entity_list.num_entities, 1) self.assertIsNone(entity.deleted_at) self.assertIsNone(entity.deleted_by) @@ -72,10 +73,10 @@ def test_soft_delete(self, mock_now): self.entity_list.refresh_from_db() entity.refresh_from_db() - 
self.assertEqual(self.entity_list.num_entities, 0) self.assertEqual(self.entity_list.last_entity_update_time, self.mocked_now) self.assertEqual(entity.deleted_at, self.mocked_now) self.assertEqual(entity.deleted_at, self.mocked_now) + mock_dec.assert_called_once_with(self.entity_list.pk) # Soft deleted item cannot be soft deleted again deleted_at = timezone.now() @@ -95,20 +96,19 @@ def test_soft_delete(self, mock_now): self.assertEqual(entity3.deleted_at, self.mocked_now) self.assertIsNone(entity3.deleted_by) - def test_hard_delete(self): + @patch("onadata.apps.logger.tasks.dec_elist_num_entities_async.delay") + def test_hard_delete(self, mock_dec): """Hard deleting updates dataset info""" entity = Entity.objects.create(entity_list=self.entity_list) self.entity_list.refresh_from_db() old_last_entity_update_time = self.entity_list.last_entity_update_time - self.assertEqual(self.entity_list.num_entities, 1) - entity.delete() self.entity_list.refresh_from_db() new_last_entity_update_time = self.entity_list.last_entity_update_time - self.assertEqual(self.entity_list.num_entities, 0) self.assertTrue(old_last_entity_update_time < new_last_entity_update_time) + mock_dec.assert_called_once_with(self.entity_list.pk) def test_entity_list_uuid_unique(self): """`entity_list` and `uuid` are unique together""" diff --git a/onadata/apps/logger/tests/test_tasks.py b/onadata/apps/logger/tests/test_tasks.py index 51c1fa0d3f..ac0c5af1c9 100644 --- a/onadata/apps/logger/tests/test_tasks.py +++ b/onadata/apps/logger/tests/test_tasks.py @@ -14,6 +14,7 @@ from onadata.apps.logger.tasks import ( set_entity_list_perms_async, apply_project_date_modified_async, + commit_cached_elist_num_entities_async, ) from onadata.apps.main.tests.test_base import TestBase from onadata.libs.utils.cache_tools import PROJECT_DATE_MODIFIED_CACHE @@ -105,4 +106,51 @@ def test_update_project_date_modified_empty_cache(self): apply_project_date_modified_async.delay() # Verify that no projects were updated - 
self.assertIsNone(cache.get(PROJECT_DATE_MODIFIED_CACHE)) # Cache should remain empty + self.assertIsNone( + cache.get(PROJECT_DATE_MODIFIED_CACHE) + ) # Cache should remain empty + + +@patch("onadata.apps.logger.tasks.commit_cached_elist_num_entities") +class CommitEListNumEntitiesAsyncTestCase(TestBase): + """Tests for method `commit_cached_elist_num_entities_async`""" + + def setUp(self): + super().setUp() + + self.project = get_user_default_project(self.user) + self.entity_list = EntityList.objects.create( + name="trees", project=self.project, num_entities=10 + ) + + def test_counter_commited(self, mock_commit): + """Cached counter is commited in the database""" + # pylint: disable=no-member + commit_cached_elist_num_entities_async.delay() + mock_commit.assert_called_once() + + @patch("onadata.apps.logger.tasks.commit_cached_elist_num_entities_async.retry") + def test_retry_connection_error(self, mock_retry, mock_set_perms): + """ConnectionError exception is retried""" + mock_retry.side_effect = Retry + mock_set_perms.side_effect = ConnectionError + # pylint: disable=no-member + commit_cached_elist_num_entities_async.delay() + + self.assertTrue(mock_retry.called) + + _, kwargs = mock_retry.call_args_list[0] + self.assertTrue(isinstance(kwargs["exc"], ConnectionError)) + + @patch("onadata.apps.logger.tasks.commit_cached_elist_num_entities_async.retry") + def test_retry_database_error(self, mock_retry, mock_set_perms): + """DatabaseError exception is retried""" + mock_retry.side_effect = Retry + mock_set_perms.side_effect = DatabaseError + # pylint: disable=no-member + commit_cached_elist_num_entities_async.delay() + + self.assertTrue(mock_retry.called) + + _, kwargs = mock_retry.call_args_list[0] + self.assertTrue(isinstance(kwargs["exc"], DatabaseError)) diff --git a/onadata/libs/serializers/entity_serializer.py b/onadata/libs/serializers/entity_serializer.py index 07174143ae..61666d6a7b 100644 --- a/onadata/libs/serializers/entity_serializer.py +++ 
b/onadata/libs/serializers/entity_serializer.py @@ -2,6 +2,7 @@ """ Entities serializer module. """ + from django.utils.translation import gettext as _ from pyxform.constants import ENTITIES_RESERVED_PREFIX @@ -19,6 +20,7 @@ ) from onadata.apps.logger.tasks import delete_entities_bulk_async from onadata.libs.permissions import CAN_VIEW_PROJECT +from onadata.libs.utils.cache_tools import ELIST_NUM_ENTITIES, safe_cache_get class EntityListSerializer(serializers.ModelSerializer): @@ -79,6 +81,7 @@ class EntityListArraySerializer(serializers.HyperlinkedModelSerializer): public = serializers.BooleanField(source="project.shared") num_registration_forms = serializers.SerializerMethodField() num_follow_up_forms = serializers.SerializerMethodField() + num_entities = serializers.SerializerMethodField() def get_num_registration_forms(self, obj: EntityList) -> int: """Returns number of RegistrationForms for EntityList object""" @@ -88,6 +91,15 @@ def get_num_follow_up_forms(self, obj: EntityList) -> int: """Returns number of FollowUpForms consuming Entities from dataset""" return obj.follow_up_forms.count() + def get_num_entities(self, obj: EntityList) -> int: + """Returns number of Entities in the dataset + + Adds cached counter to database counter + """ + cached_counter = safe_cache_get(f"{ELIST_NUM_ENTITIES}{obj.pk}", 0) + + return obj.num_entities + cached_counter + class Meta: model = EntityList fields = ( diff --git a/onadata/libs/tests/utils/test_logger_tools.py b/onadata/libs/tests/utils/test_logger_tools.py index af14cf6cae..825fa1de69 100644 --- a/onadata/libs/tests/utils/test_logger_tools.py +++ b/onadata/libs/tests/utils/test_logger_tools.py @@ -4,12 +4,16 @@ """ import os import re +from datetime import datetime, timedelta from io import BytesIO -from unittest.mock import patch +from unittest.mock import patch, call from django.conf import settings +from django.core.cache import cache from django.core.files.uploadedfile import InMemoryUploadedFile from 
django.http.request import HttpRequest +from django.utils import timezone +from django.test.utils import override_settings from defusedxml.ElementTree import ParseError @@ -27,12 +31,16 @@ from onadata.libs.test_utils.pyxform_test_case import PyxformTestCase from onadata.libs.utils.common_tags import MEDIA_ALL_RECEIVED, MEDIA_COUNT, TOTAL_MEDIA from onadata.libs.utils.logger_tools import ( + commit_cached_elist_num_entities, create_entity_from_instance, create_instance, + dec_elist_num_entities, generate_content_disposition_header, get_first_record, + inc_elist_num_entities, safe_create_instance, ) +from onadata.libs.utils.user_auth import get_user_default_project class TestLoggerTools(PyxformTestCase, TestBase): @@ -717,7 +725,7 @@ def test_entity_created(self): self.assertCountEqual(entity.json, expected_json) self.assertEqual(str(entity.uuid), "dbee4c32-a922-451c-9df7-42f40bf78f48") - self.assertEqual(entity_list.num_entities, 1) + self.assertEqual(cache.get(f"elist-num-entities-{entity_list.pk}"), 1) self.assertEqual(entity_list.last_entity_update_time, entity.date_modified) self.assertEqual(entity.history.count(), 1) @@ -787,3 +795,240 @@ def test_grouped_section(self): self.assertEqual(Entity.objects.count(), 1) self.assertCountEqual(entity.json, expected_json) + + +class EntityListNumEntitiesBase(TestBase): + def setUp(self): + super().setUp() + + self.project = get_user_default_project(self.user) + self.entity_list = EntityList.objects.create( + name="trees", project=self.project, num_entities=10 + ) + self.ids_key = "elist-num-entities-ids" + self.lock_key = f"{self.ids_key}-lock" + self.counter_key_prefix = "elist-num-entities-" + self.counter_key = f"{self.counter_key_prefix}{self.entity_list.pk}" + self.created_at_key = "elist-num-entities-ids-created-at" + + def tearDown(self) -> None: + super().tearDown() + + cache.clear() + + +class IncEListNumEntitiesTestCase(EntityListNumEntitiesBase): + """Tests for method `inc_elist_num_entities`""" + + def 
test_cache_locked(self): + """Database counter is incremented if cache is locked""" + cache.set(self.lock_key, "true") + cache.set(self.counter_key, 3) + inc_elist_num_entities(self.entity_list.pk) + self.entity_list.refresh_from_db() + + self.assertEqual(self.entity_list.num_entities, 11) + # Cached counter should not be updated + self.assertEqual(cache.get(self.counter_key), 3) + + @patch("django.utils.timezone.now") + def test_cache_unlocked(self, mock_now): + """Cache counter is incremented if cache is unlocked""" + mocked_now = datetime(2024, 7, 26, 12, 45, 0, tzinfo=timezone.utc) + mock_now.return_value = mocked_now + + self.assertIsNone(cache.get(self.counter_key)) + self.assertIsNone(cache.get(self.ids_key)) + self.assertIsNone(cache.get(self.created_at_key)) + + inc_elist_num_entities(self.entity_list.pk) + + self.assertEqual(cache.get(self.counter_key), 1) + self.assertEqual(cache.get(self.ids_key), {self.entity_list.pk}) + self.assertEqual(cache.get(self.created_at_key), mocked_now) + self.entity_list.refresh_from_db() + # Database counter should not be updated + self.assertEqual(self.entity_list.num_entities, 10) + # New EntityList + vaccine = EntityList.objects.create(name="vaccine", project=self.project) + inc_elist_num_entities(vaccine.pk) + + self.assertEqual(cache.get(f"{self.counter_key_prefix}{vaccine.pk}"), 1) + self.assertEqual(cache.get(self.ids_key), {self.entity_list.pk, vaccine.pk}) + vaccine.refresh_from_db() + self.assertEqual(vaccine.num_entities, 0) + + # Database counter incremented if cache inacessible + with patch( + "onadata.libs.utils.logger_tools._inc_elist_num_entities_cache" + ) as mock_inc: + with patch("onadata.libs.utils.logger_tools.logger.exception") as mock_exc: + mock_inc.side_effect = ConnectionError + cache.set(self.counter_key, 3) + inc_elist_num_entities(self.entity_list.pk) + self.entity_list.refresh_from_db() + + self.assertEqual(cache.get(self.counter_key), 3) + self.assertEqual(self.entity_list.num_entities, 11) + 
mock_exc.assert_called_once() + + @patch("django.utils.timezone.now") + @patch.object(cache, "set") + @patch.object(cache, "add") + def test_cache_no_expire(self, mock_cache_add, mock_cache_set, mock_now): + """Cached counter does not expire + + Clean up should be done periodically such as in a background task + """ + mocked_now = datetime(2024, 7, 26, 12, 45, 0, tzinfo=timezone.utc) + mock_now.return_value = mocked_now + inc_elist_num_entities(self.entity_list.pk) + + # Timeout should be `None` + self.assertTrue( + call(self.counter_key, 1, None) in mock_cache_add.call_args_list + ) + self.assertTrue( + call(self.created_at_key, mocked_now, None) in mock_cache_add.call_args_list + ) + mock_cache_set.assert_called_once_with( + self.ids_key, {self.entity_list.pk}, None + ) + + def test_time_cache_set_once(self): + """The cached time of creation is set once""" + now = timezone.now() + cache.set(self.created_at_key, now) + + inc_elist_num_entities(self.entity_list.pk) + # Cache value is not overridden + self.assertEqual(cache.get(self.created_at_key), now) + + @override_settings(ELIST_COUNTER_COMMIT_FAILOVER_TIMEOUT=3) + @patch("onadata.libs.utils.logger_tools.report_exception") + def test_failover(self, mock_report_exc): + """Failover is executed if commit timeout threshold exceeded""" + cache_created_at = timezone.now() - timedelta(minutes=10) + cache.set(self.counter_key, 3) + cache.set(self.created_at_key, cache_created_at) + cache.set(self.ids_key, {self.entity_list.pk}) + + inc_elist_num_entities(self.entity_list.pk) + self.entity_list.refresh_from_db() + + self.assertEqual(self.entity_list.num_entities, 14) + self.assertIsNone(cache.get(self.counter_key)) + self.assertIsNone(cache.get(self.ids_key)) + self.assertIsNone(cache.get(self.created_at_key)) + subject = "Periodic task not running" + task_name = "onadata.apps.logger.tasks.commit_cached_elist_num_entities_async" + msg = ( + f"The failover has been executed because task {task_name} " + "is not configured 
or has malfunctioned" + ) + mock_report_exc.assert_called_once_with(subject, msg) + self.assertEqual(cache.get("elist-failover-report-sent"), "sent") + + @override_settings(ELIST_COUNTER_COMMIT_FAILOVER_TIMEOUT=3) + @patch("onadata.libs.utils.logger_tools.report_exception") + def test_failover_report_cache_hit(self, mock_report_exc): + """Report exception not sent if cache `elist-failover-report-sent` set""" + cache.set("elist-failover-report-sent", "sent") + cache_created_at = timezone.now() - timedelta(minutes=10) + cache.set(self.counter_key, 3) + cache.set(self.created_at_key, cache_created_at) + cache.set(self.ids_key, {self.entity_list.pk}) + + inc_elist_num_entities(self.entity_list.pk) + self.entity_list.refresh_from_db() + + self.assertEqual(self.entity_list.num_entities, 14) + self.assertIsNone(cache.get(self.counter_key)) + self.assertIsNone(cache.get(self.ids_key)) + self.assertIsNone(cache.get(self.created_at_key)) + mock_report_exc.assert_not_called() + + +class DecEListNumEntitiesTestCase(EntityListNumEntitiesBase): + """Tests for method `dec_elist_num_entities`""" + + def test_cache_locked(self): + """Database counter is decremented if cache is locked""" + counter_key = f"{self.counter_key_prefix}{self.entity_list.pk}" + cache.set(self.lock_key, "true") + cache.set(counter_key, 3) + dec_elist_num_entities(self.entity_list.pk) + self.entity_list.refresh_from_db() + + self.assertEqual(self.entity_list.num_entities, 9) + # Cached counter should not be updated + self.assertEqual(cache.get(counter_key), 3) + + def test_cache_unlocked(self): + """Cache counter is decremented if cache is unlocked""" + counter_key = f"{self.counter_key_prefix}{self.entity_list.pk}" + cache.set(counter_key, 3) + dec_elist_num_entities(self.entity_list.pk) + + self.assertEqual(cache.get(counter_key), 2) + self.entity_list.refresh_from_db() + # Database counter should not be updated + self.assertEqual(self.entity_list.num_entities, 10) + + # Database counter is decremented if 
cache missing + cache.delete(counter_key) + dec_elist_num_entities(self.entity_list.pk) + self.entity_list.refresh_from_db() + self.assertEqual(self.entity_list.num_entities, 9) + + # Database counter is decremented if cache inaccesible + with patch( + "onadata.libs.utils.logger_tools._dec_elist_num_entities_cache" + ) as mock_dec: + with patch("onadata.libs.utils.logger_tools.logger.exception") as mock_exc: + mock_dec.side_effect = ConnectionError + cache.set(counter_key, 3) + dec_elist_num_entities(self.entity_list.pk) + self.entity_list.refresh_from_db() + + self.assertEqual(cache.get(counter_key), 3) + self.assertEqual(self.entity_list.num_entities, 8) + mock_exc.assert_called_once() + + +class CommitCachedEListNumEntitiesTestCase(EntityListNumEntitiesBase): + """Tests for method `commit_cached_elist_num_entities`""" + + def test_counter_commited(self): + """Cached counter is commited in the database""" + cache.set(self.ids_key, {self.entity_list.pk}) + cache.set(self.counter_key, 3) + cache.set(self.created_at_key, timezone.now()) + commit_cached_elist_num_entities() + self.entity_list.refresh_from_db() + + self.assertEqual(self.entity_list.num_entities, 13) + self.assertIsNone(cache.get(self.ids_key)) + self.assertIsNone(cache.get(self.counter_key)) + self.assertIsNone(cache.get(self.created_at_key)) + + def test_cache_empty(self): + """Empty cache is handled appropriately""" + commit_cached_elist_num_entities() + self.entity_list.refresh_from_db() + self.assertEqual(self.entity_list.num_entities, 10) + + def test_lock_already_acquired(self): + """Commit unsuccessful if lock is already acquired""" + cache.set(self.lock_key, "true") + cache.set(self.ids_key, {self.entity_list.pk}) + cache.set(self.counter_key, 3) + cache.set(self.created_at_key, timezone.now()) + commit_cached_elist_num_entities() + self.entity_list.refresh_from_db() + + self.assertEqual(self.entity_list.num_entities, 10) + self.assertIsNotNone(cache.get(self.lock_key)) + 
LOCK_SUFFIX = "-lock"

# Entities
# Per-EntityList counter keys are built as f"{ELIST_NUM_ENTITIES}{pk}"
ELIST_NUM_ENTITIES = "elist-num-entities-"
# Set of EntityList pks that currently have a cached counter
ELIST_NUM_ENTITIES_IDS = "elist-num-entities-ids"
# Held while cached counters are being persisted to the database
ELIST_NUM_ENTITIES_LOCK = f"{ELIST_NUM_ENTITIES_IDS}{LOCK_SUFFIX}"
# Timestamp of when the cached counters were first created
ELIST_NUM_ENTITIES_CREATED_AT = f"{ELIST_NUM_ENTITIES_IDS}-created-at"

# Report exception
# Marker key used to avoid sending duplicate failover reports within 24 hours
ELIST_FAILOVER_REPORT_SENT = "elist-failover-report-sent"


class CacheLockError(Exception):
    """Custom exception raised when a cache lock cannot be acquired."""


@contextmanager
def with_cache_lock(cache_key, lock_expire=30, lock_timeout=10):
    """
    Context manager for safely setting a cache value with a lock.

    Args:
        cache_key (str): The key under which the value is stored in the cache.
        lock_expire (int): The expiration time for the lock in seconds.
        lock_timeout (int): The maximum time to wait for the lock in seconds.

    Raises:
        CacheLockError: If the lock cannot be acquired within
                        the specified lock_timeout.

    Yields:
        None
    """
    lock_key = f"lock:{cache_key}"
    # Use a monotonic clock for elapsed-time measurement; time.time() is
    # wall-clock and can jump (NTP adjustments, DST), distorting the timeout
    start_time = time.monotonic()

    # cache.add is atomic: it only succeeds if the key does not already exist,
    # so at most one caller acquires the lock
    lock_acquired = cache.add(lock_key, "locked", lock_expire)

    while not lock_acquired and time.monotonic() - start_time < lock_timeout:
        time.sleep(0.1)
        lock_acquired = cache.add(lock_key, "locked", lock_expire)

    if not lock_acquired:
        raise CacheLockError(f"Could not acquire lock for {cache_key}")

    try:
        yield

    finally:
        # Always release the lock, even if the caller's block raised
        cache.delete(lock_key)


def set_cache_with_lock(
    cache_key, modify_callback, cache_timeout=None, lock_expire=30, lock_timeout=10
):
    """
    Set a cache value with a lock, using a callback function to modify the value.

    Use of lock ensures that race conditions are avoided, even when multiple processes
    or threads attempt to modify the cache concurrently.

    Args:
        cache_key (str): The key under which the value is stored in the cache.
        modify_callback (callable): A callback function that takes the current cache
                                    value and returns the modified value.
        cache_timeout (int, optional): The expiration time for the cached value
                                       in seconds. If None, the default cache
                                       timeout is used.
        lock_expire (int): The expiration time for the lock in seconds.
        lock_timeout (int): The maximum time to wait for the lock in seconds.

    Raises:
        CacheLockError: If the lock cannot be acquired within the specified
                        lock_timeout.

    Returns:
        None
    """
    with with_cache_lock(cache_key, lock_expire, lock_timeout):
        # Read-modify-write is safe here because the lock serializes access
        current_value = cache.get(cache_key)
        new_value = modify_callback(current_value)
        cache.set(cache_key, new_value, cache_timeout)
def _inc_elist_num_entities_db(pk: int, count=1) -> None:
    """Increment EntityList `num_entities` counter in the database

    Args:
        pk (int): Primary key for EntityList
        count (int): Value to increase by
    """
    # Queryset.update issues a single atomic UPDATE using an F expression and
    # does not call the model's save method or fire signals
    EntityList.objects.filter(pk=pk).update(num_entities=F("num_entities") + count)


def _dec_elist_num_entities_db(pk: int, count=1) -> None:
    """Decrement EntityList `num_entities` counter in the database

    Args:
        pk (int): Primary key for EntityList
        count (int): Value to decrease by
    """
    # Queryset.update issues a single atomic UPDATE using an F expression and
    # does not call the model's save method or fire signals
    EntityList.objects.filter(pk=pk).update(num_entities=F("num_entities") - count)


def _inc_elist_num_entities_cache(pk: int) -> None:
    """Increment EntityList `num_entities` counter in cache

    Args:
        pk (int): Primary key for EntityList
    """
    counter_cache_key = f"{ELIST_NUM_ENTITIES}{pk}"
    # Cache timeout is None (no expiry). A background task should be run
    # periodically to persist the cached counters to the db
    # and delete the cache. If we were to set a timeout, the cache could
    # expire before the next periodic run and data will be lost.
    counter_cache_ttl = None
    # cache.add only succeeds when the key does not exist yet; a True result
    # means we just created the counter with an initial value of 1
    counter_cache_created = cache.add(counter_cache_key, 1, counter_cache_ttl)

    def add_to_cached_ids(current_ids: set | None):
        # Callback run under the cache lock; records this EntityList's pk so
        # the commit task knows which counters exist
        if current_ids is None:
            current_ids = set()

        if pk not in current_ids:
            current_ids.add(pk)

        return current_ids

    set_cache_with_lock(ELIST_NUM_ENTITIES_IDS, add_to_cached_ids, counter_cache_ttl)
    # Record when the cached counters were first created (add is a no-op if
    # the timestamp already exists)
    cache.add(ELIST_NUM_ENTITIES_CREATED_AT, timezone.now(), counter_cache_ttl)

    if not counter_cache_created:
        cache.incr(counter_cache_key)


def _dec_elist_num_entities_cache(pk: int) -> None:
    """Decrement EntityList `num_entities` counter in cache

    Args:
        pk (int): Primary key for EntityList
    """
    counter_cache_key = f"{ELIST_NUM_ENTITIES}{pk}"

    # Only decrement an existing counter; cache.decr on a missing key raises
    if cache.get(counter_cache_key) is not None:
        cache.decr(counter_cache_key)


def inc_elist_num_entities(pk: int) -> None:
    """Increment EntityList `num_entities` counter

    Updates cached counter if cache is not locked. Else, the database
    counter is updated

    Args:
        pk (int): Primary key for EntityList
    """

    if _is_elist_num_entities_cache_locked():
        # A commit is in progress; write straight to the database so the
        # in-flight commit does not lose this increment
        _inc_elist_num_entities_db(pk)

    else:
        try:
            _inc_elist_num_entities_cache(pk)
            _exec_cached_elist_counter_commit_failover()

        except ConnectionError as exc:
            logger.exception(exc)
            # Fallback to db if cache inaccessible
            _inc_elist_num_entities_db(pk)


def dec_elist_num_entities(pk: int) -> None:
    """Decrement EntityList `num_entities` counter

    Updates cached counter if cache is not locked. Else, the database
    counter is updated.

    Args:
        pk (int): Primary key for EntityList
    """
    counter_cache_key = f"{ELIST_NUM_ENTITIES}{pk}"

    if _is_elist_num_entities_cache_locked() or cache.get(counter_cache_key) is None:
        _dec_elist_num_entities_db(pk)

    else:
        try:
            _dec_elist_num_entities_cache(pk)

        except ConnectionError as exc:
            logger.exception(exc)
            # Fallback to db if cache inaccessible
            _dec_elist_num_entities_db(pk)


def _is_elist_num_entities_cache_locked() -> bool:
    """Checks if EntityList `num_entities` cached counter is locked

    Typically, the cache is locked if the cached data is in the process
    of being persisted in the database.

    The cache is locked to ensure no further updates are made when the
    data is being committed to the database.

    Returns True, if cache is locked, False otherwise
    """

    return cache.get(ELIST_NUM_ENTITIES_LOCK) is not None


def commit_cached_elist_num_entities() -> None:
    """Commit cached EntityList `num_entities` counter to the database

    Commit is successful if no other process holds the lock
    """
    # Atomic add acts as a mutex; the 7200s expiry is a safety net in case
    # this process dies while holding the lock
    lock_acquired = cache.add(ELIST_NUM_ENTITIES_LOCK, "true", 7200)

    if not lock_acquired:
        return

    try:
        entity_list_pks: set[int] = cache.get(ELIST_NUM_ENTITIES_IDS, set())

        for pk in entity_list_pks:
            counter_key = f"{ELIST_NUM_ENTITIES}{pk}"
            counter: int = cache.get(counter_key, 0)

            if counter:
                _inc_elist_num_entities_db(pk, counter)

            safe_delete(counter_key)

        safe_delete(ELIST_NUM_ENTITIES_IDS)
        safe_delete(ELIST_NUM_ENTITIES_CREATED_AT)

    finally:
        # Release the lock even if persisting fails; otherwise all counter
        # commits would stall until the lock expires
        safe_delete(ELIST_NUM_ENTITIES_LOCK)


def _exec_cached_elist_counter_commit_failover() -> None:
    """Check the time lapse since the cached EntityList `num_entities`
    counters were created and commit if the time lapse exceeds
    the threshold allowed.

    Acts as a failover in case the cron job responsible for committing
    the cached data fails or is not configured
    """
    cache_created_at: datetime | None = cache.get(ELIST_NUM_ENTITIES_CREATED_AT)

    if cache_created_at is None:
        return

    # If the time lapse is > ELIST_COUNTER_COMMIT_FAILOVER_TIMEOUT, run the failover
    failover_timeout: int = getattr(
        settings, "ELIST_COUNTER_COMMIT_FAILOVER_TIMEOUT", 7200
    )
    time_lapse = timezone.now() - cache_created_at

    if time_lapse.total_seconds() > failover_timeout:
        commit_cached_elist_num_entities()
        # Do not send report exception if already sent within the past 24 hrs
        if cache.get(ELIST_FAILOVER_REPORT_SENT) is None:
            subject = "Periodic task not running"
            task_name = (
                "onadata.apps.logger.tasks.commit_cached_elist_num_entities_async"
            )
            msg = (
                f"The failover has been executed because task {task_name} "
                "is not configured or has malfunctioned"
            )
            report_exception(subject, msg)
            cache.set(ELIST_FAILOVER_REPORT_SENT, "sent", 86400)