diff --git a/django_multitenant/settings.py b/django_multitenant/settings.py index 77aa45a..da182ae 100644 --- a/django_multitenant/settings.py +++ b/django_multitenant/settings.py @@ -6,3 +6,4 @@ TENANT_MODEL_NAME = getattr(settings, "TENANT_MODEL_NAME", None) CITUS_EXTENSION_INSTALLED = getattr(settings, "CITUS_EXTENSION_INSTALLED", False) TENANT_STRICT_MODE = getattr(settings, "TENANT_STRICT_MODE", False) +TENANT_USE_ASGIREF = getattr(settings, "TENANT_USE_ASGIREF", False) diff --git a/django_multitenant/tests/settings.py b/django_multitenant/tests/settings.py index 3e10723..5e90198 100644 --- a/django_multitenant/tests/settings.py +++ b/django_multitenant/tests/settings.py @@ -78,3 +78,5 @@ DEFAULT_AUTO_FIELD = "django.db.models.BigAutoField" USE_TZ = True + +TENANT_USE_ASGIREF = False diff --git a/django_multitenant/tests/test_utils.py b/django_multitenant/tests/test_utils.py index 66ddfb0..0c75dff 100644 --- a/django_multitenant/tests/test_utils.py +++ b/django_multitenant/tests/test_utils.py @@ -1,3 +1,8 @@ +import sys +import importlib +from asgiref.sync import async_to_sync + + from django_multitenant.utils import ( set_current_tenant, get_current_tenant, @@ -11,6 +16,12 @@ class UtilsTest(BaseTestCase): + async def async_get_current_tenant(self): + return get_current_tenant() + + async def async_set_current_tenant(self, tenant): + return set_current_tenant(tenant) + def test_set_current_tenant(self): projects = self.projects account = projects[0].account @@ -19,6 +30,49 @@ def test_set_current_tenant(self): self.assertEqual(get_current_tenant(), account) unset_current_tenant() + def test_tenant_persists_from_thread_to_async_task(self): + projects = self.projects + account = projects[0].account + + # Set the tenant in main thread + set_current_tenant(account) + + with self.settings(TENANT_USE_ASGIREF=True): + importlib.reload(sys.modules["django_multitenant.utils"]) + + # Check the tenant within an async task when asgiref enabled + tenant = async_to_sync(self.async_get_current_tenant)() + self.assertEqual(get_current_tenant(), tenant) + unset_current_tenant() + + with self.settings(TENANT_USE_ASGIREF=False): + importlib.reload(sys.modules["django_multitenant.utils"]) + + # Check the tenant within an async task when asgiref is disabled + tenant = async_to_sync(self.async_get_current_tenant)() + self.assertIsNone(get_current_tenant()) + unset_current_tenant() + + def test_tenant_persists_from_async_task_to_thread(self): + projects = self.projects + account = projects[0].account + + with self.settings(TENANT_USE_ASGIREF=True): + importlib.reload(sys.modules["django_multitenant.utils"]) + + # Set the tenant in task + async_to_sync(self.async_set_current_tenant)(account) + self.assertEqual(get_current_tenant(), account) + unset_current_tenant() + + with self.settings(TENANT_USE_ASGIREF=False): + importlib.reload(sys.modules["django_multitenant.utils"]) + + # Set the tenant in task + async_to_sync(self.async_set_current_tenant)(account) + self.assertIsNone(get_current_tenant()) + unset_current_tenant() + def test_get_tenant_column(self): from .models import Project diff --git a/django_multitenant/utils.py b/django_multitenant/utils.py index 0053168..c3a9d2d 100644 --- a/django_multitenant/utils.py +++ b/django_multitenant/utils.py @@ -1,14 +1,20 @@ import inspect from django.apps import apps +from django.conf import settings -try: - from threading import local -except ImportError: - from django.utils._threading_local import local + +if settings.TENANT_USE_ASGIREF: + # asgiref must be installed, its included with Django >= 3.0 + from asgiref.local import Local as local +else: + try: + from threading import local + except ImportError: + from django.utils._threading_local import local -_thread_locals = local() +_thread_locals = _context = local() def get_model_by_db_table(db_table): @@ -26,14 +32,14 @@ def get_model_by_db_table(db_table): def get_current_tenant(): """ - Utils to get the tenant that hass been set in the current thread using `set_current_tenant`. + Utils to get the tenant that hass been set in the current thread/context using `set_current_tenant`. Can be used by doing: ``` my_class_object = get_current_tenant() ``` Will return None if the tenant is not set """ - return getattr(_thread_locals, "tenant", None) + return getattr(_context, "tenant", None) def get_tenant_column(model_class_or_instance): @@ -125,7 +131,7 @@ def get_tenant_filters(table, filters=None): def set_current_tenant(tenant): """ - Utils to set a tenant in the current thread. + Utils to set a tenant in the current thread/context. Often used in a middleware once a user is logged in to make sure all db calls are sharded to the current tenant. Can be used by doing: @@ -133,11 +139,11 @@ def set_current_tenant(tenant): get_current_tenant(my_class_object) ``` """ - setattr(_thread_locals, "tenant", tenant) + setattr(_context, "tenant", tenant) def unset_current_tenant(): - setattr(_thread_locals, "tenant", None) + setattr(_context, "tenant", None) def is_distributed_model(model): diff --git a/requirements/test-requirements.txt b/requirements/test-requirements.txt index 068ee0d..281cc65 100644 --- a/requirements/test-requirements.txt +++ b/requirements/test-requirements.txt @@ -4,6 +4,8 @@ # # pip-compile --output-file=requirements/test-requirements.txt --resolver=backtracking requirements/test.in # +asgiref==3.7.2 + # via -r requirements/test.in coverage[toml]==7.2.7 # via # -r requirements/test.in @@ -37,3 +39,8 @@ tomli==2.0.1 # via # coverage # pytest +typing-extensions==4.8.0 ; python_version >= "3.8" + # via + # -r requirements/test.in + # asgiref +typing-extensions==4.7.1 ; python_version < "3.8" diff --git a/requirements/test.in b/requirements/test.in index 33d78fb..901ad9c 100644 --- a/requirements/test.in +++ b/requirements/test.in @@ -5,3 +5,7 @@ pytest pytest-cov pytest-django exam +asgiref>= 3.5.2 +typing-extensions==4.8.0; python_version >= "3.8" +typing-extensions==4.7.1; python_version < "3.8" +