Skip to content

Commit

Permalink
Use asgiref when available instead of thread locals (citusdata#176)
Browse files Browse the repository at this point in the history
add asgiref as test dependency

add tests for asgiref

add setting TENANT_USE_ASGIREF, updated tests
  • Loading branch information
darwing1210 committed Jul 12, 2023
1 parent 30e3238 commit 3465bfc
Show file tree
Hide file tree
Showing 6 changed files with 79 additions and 17 deletions.
1 change: 1 addition & 0 deletions django_multitenant/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,4 @@
TENANT_MODEL_NAME = getattr(settings, "TENANT_MODEL_NAME", None)
CITUS_EXTENSION_INSTALLED = getattr(settings, "CITUS_EXTENSION_INSTALLED", False)
TENANT_STRICT_MODE = getattr(settings, "TENANT_STRICT_MODE", False)
TENANT_USE_ASGIREF = getattr(settings, "TENANT_USE_ASGIREF", False)
3 changes: 3 additions & 0 deletions django_multitenant/tests/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,3 +78,6 @@

DEFAULT_AUTO_FIELD = "django.db.models.BigAutoField"
USE_TZ = True

TENANT_USE_ASGIREF = False

55 changes: 55 additions & 0 deletions django_multitenant/tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
import asyncio
import sys, importlib
from asgiref.sync import async_to_sync, sync_to_async


from django_multitenant.utils import (
set_current_tenant,
get_current_tenant,
Expand All @@ -11,6 +16,12 @@


class UtilsTest(BaseTestCase):
async def async_get_current_tenant(self):
return get_current_tenant()

async def async_set_current_tenant(self, tenant):
return set_current_tenant(tenant)

def test_set_current_tenant(self):
projects = self.projects
account = projects[0].account
Expand All @@ -19,6 +30,50 @@ def test_set_current_tenant(self):
self.assertEqual(get_current_tenant(), account)
unset_current_tenant()

def test_tenant_persists_from_thread_to_async_task(self):
projects = self.projects
account = projects[0].account

# Set the tenant in main thread
set_current_tenant(account)

with self.settings(TENANT_USE_ASGIREF=True):
importlib.reload(sys.modules['django_multitenant.utils'])
from django_multitenant.utils import get_current_tenant
# Check the tenant within an async task when asgiref enabled
tenant = async_to_sync(self.async_get_current_tenant)()
self.assertEqual(get_current_tenant(), tenant)
unset_current_tenant()

with self.settings(TENANT_USE_ASGIREF=False):
importlib.reload(sys.modules['django_multitenant.utils'])
from django_multitenant.utils import get_current_tenant
# Check the tenant within an async task when asgiref is disabled
tenant = async_to_sync(self.async_get_current_tenant)()
self.assertIsNone(get_current_tenant())
unset_current_tenant()

def test_tenant_persists_from_async_task_to_thread(self):
projects = self.projects
account = projects[0].account

with self.settings(TENANT_USE_ASGIREF=True):
importlib.reload(sys.modules['django_multitenant.utils'])
from django_multitenant.utils import get_current_tenant
# Set the tenant in task
async_to_sync(self.async_set_current_tenant)(account)
self.assertEqual(get_current_tenant(), account)
unset_current_tenant()

with self.settings(TENANT_USE_ASGIREF=False):
importlib.reload(sys.modules['django_multitenant.utils'])
from django_multitenant.utils import get_current_tenant
# Set the tenant in task
async_to_sync(self.async_set_current_tenant)(account)
self.assertIsNone(get_current_tenant())
unset_current_tenant()


def test_get_tenant_column(self):
from .models import Project

Expand Down
26 changes: 16 additions & 10 deletions django_multitenant/utils.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,20 @@
import inspect

from django.apps import apps
from django.conf import settings

try:
from threading import local
except ImportError:
from django.utils._threading_local import local

if settings.TENANT_USE_ASGIREF:
# asgiref must be installed, its included with Django >= 3.0
from asgiref.local import Local as local
else:
try:
from threading import local
except ImportError:
from django.utils._threading_local import local


_thread_locals = local()
_thread_locals = _context = local()


def get_model_by_db_table(db_table):
Expand All @@ -26,14 +32,14 @@ def get_model_by_db_table(db_table):

def get_current_tenant():
"""
Utils to get the tenant that hass been set in the current thread using `set_current_tenant`.
Utils to get the tenant that hass been set in the current thread/context using `set_current_tenant`.
Can be used by doing:
```
my_class_object = get_current_tenant()
```
Will return None if the tenant is not set
"""
return getattr(_thread_locals, "tenant", None)
return getattr(_context, "tenant", None)


def get_tenant_column(model_class_or_instance):
Expand Down Expand Up @@ -125,19 +131,19 @@ def get_tenant_filters(table, filters=None):

def set_current_tenant(tenant):
"""
Utils to set a tenant in the current thread.
Utils to set a tenant in the current thread/context.
Often used in a middleware once a user is logged in to make sure all db
calls are sharded to the current tenant.
Can be used by doing:
```
get_current_tenant(my_class_object)
```
"""
setattr(_thread_locals, "tenant", tenant)
setattr(_context, "tenant", tenant)


def unset_current_tenant():
setattr(_thread_locals, "tenant", None)
setattr(_context, "tenant", None)


def is_distributed_model(model):
Expand Down
10 changes: 3 additions & 7 deletions requirements/test-requirements.txt
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
#
# This file is autogenerated by pip-compile with Python 3.8
# This file is autogenerated by pip-compile with Python 3.11
# by the following command:
#
# pip-compile --output-file=requirements/test-requirements.txt --resolver=backtracking requirements/test.in
#
asgiref==3.7.2
# via -r requirements/test.in
coverage[toml]==7.2.7
# via pytest-cov
exam==0.10.6
# via -r requirements/test.in
exceptiongroup==1.1.2
# via pytest
iniconfig==2.0.0
# via pytest
mock==5.0.2
Expand All @@ -29,7 +29,3 @@ pytest-cov==4.1.0
# via -r requirements/test.in
pytest-django==4.5.2
# via -r requirements/test.in
tomli==2.0.1
# via
# coverage
# pytest
1 change: 1 addition & 0 deletions requirements/test.in
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,4 @@ pytest
pytest-cov
pytest-django
exam
asgiref>= 3.5.2

0 comments on commit 3465bfc

Please sign in to comment.