Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use asgiref when available instead of thread locals (#176) #198

Merged
merged 5 commits into from
Sep 26, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions django_multitenant/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,4 @@
TENANT_MODEL_NAME = getattr(settings, "TENANT_MODEL_NAME", None)
CITUS_EXTENSION_INSTALLED = getattr(settings, "CITUS_EXTENSION_INSTALLED", False)
TENANT_STRICT_MODE = getattr(settings, "TENANT_STRICT_MODE", False)
TENANT_USE_ASGIREF = getattr(settings, "TENANT_USE_ASGIREF", False)
2 changes: 2 additions & 0 deletions django_multitenant/tests/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,3 +78,5 @@

DEFAULT_AUTO_FIELD = "django.db.models.BigAutoField"
USE_TZ = True

TENANT_USE_ASGIREF = False
54 changes: 54 additions & 0 deletions django_multitenant/tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
import sys
import importlib
from asgiref.sync import async_to_sync


from django_multitenant.utils import (
set_current_tenant,
get_current_tenant,
Expand All @@ -11,6 +16,12 @@


class UtilsTest(BaseTestCase):
async def async_get_current_tenant(self):
return get_current_tenant()

async def async_set_current_tenant(self, tenant):
return set_current_tenant(tenant)

def test_set_current_tenant(self):
projects = self.projects
account = projects[0].account
Expand All @@ -19,6 +30,49 @@ def test_set_current_tenant(self):
self.assertEqual(get_current_tenant(), account)
unset_current_tenant()

def test_tenant_persists_from_thread_to_async_task(self):
projects = self.projects
account = projects[0].account

# Set the tenant in main thread
set_current_tenant(account)

with self.settings(TENANT_USE_ASGIREF=True):
importlib.reload(sys.modules["django_multitenant.utils"])

# Check the tenant within an async task when asgiref enabled
tenant = async_to_sync(self.async_get_current_tenant)()
self.assertEqual(get_current_tenant(), tenant)
unset_current_tenant()

with self.settings(TENANT_USE_ASGIREF=False):
importlib.reload(sys.modules["django_multitenant.utils"])

# Check the tenant within an async task when asgiref is disabled
tenant = async_to_sync(self.async_get_current_tenant)()
self.assertIsNone(get_current_tenant())
unset_current_tenant()

def test_tenant_persists_from_async_task_to_thread(self):
projects = self.projects
account = projects[0].account

with self.settings(TENANT_USE_ASGIREF=True):
importlib.reload(sys.modules["django_multitenant.utils"])

# Set the tenant in task
async_to_sync(self.async_set_current_tenant)(account)
self.assertEqual(get_current_tenant(), account)
unset_current_tenant()

with self.settings(TENANT_USE_ASGIREF=False):
importlib.reload(sys.modules["django_multitenant.utils"])

# Set the tenant in task
async_to_sync(self.async_set_current_tenant)(account)
self.assertIsNone(get_current_tenant())
unset_current_tenant()

def test_get_tenant_column(self):
from .models import Project

Expand Down
26 changes: 16 additions & 10 deletions django_multitenant/utils.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,20 @@
import inspect

from django.apps import apps
from django.conf import settings

try:
from threading import local
except ImportError:
from django.utils._threading_local import local

if settings.TENANT_USE_ASGIREF:
# asgiref must be installed, its included with Django >= 3.0
from asgiref.local import Local as local
else:
try:
from threading import local
except ImportError:
from django.utils._threading_local import local


_thread_locals = local()
_thread_locals = _context = local()


def get_model_by_db_table(db_table):
Expand All @@ -26,14 +32,14 @@ def get_model_by_db_table(db_table):

def get_current_tenant():
"""
Utils to get the tenant that hass been set in the current thread using `set_current_tenant`.
Utils to get the tenant that hass been set in the current thread/context using `set_current_tenant`.
Can be used by doing:
```
my_class_object = get_current_tenant()
```
Will return None if the tenant is not set
"""
return getattr(_thread_locals, "tenant", None)
return getattr(_context, "tenant", None)


def get_tenant_column(model_class_or_instance):
Expand Down Expand Up @@ -125,19 +131,19 @@ def get_tenant_filters(table, filters=None):

def set_current_tenant(tenant):
"""
Utils to set a tenant in the current thread.
Utils to set a tenant in the current thread/context.
Often used in a middleware once a user is logged in to make sure all db
calls are sharded to the current tenant.
Can be used by doing:
```
get_current_tenant(my_class_object)
```
"""
setattr(_thread_locals, "tenant", tenant)
setattr(_context, "tenant", tenant)


def unset_current_tenant():
setattr(_thread_locals, "tenant", None)
setattr(_context, "tenant", None)


def is_distributed_model(model):
Expand Down
7 changes: 7 additions & 0 deletions requirements/test-requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
#
# pip-compile --output-file=requirements/test-requirements.txt --resolver=backtracking requirements/test.in
#
asgiref==3.7.2
# via -r requirements/test.in
coverage[toml]==7.2.7
# via
# -r requirements/test.in
Expand Down Expand Up @@ -37,3 +39,8 @@ tomli==2.0.1
# via
# coverage
# pytest
typing-extensions==4.8.0 ; python_version >= "3.8"
# via
# -r requirements/test.in
# asgiref
typing-extensions==4.7.1 ; python_version < "3.8"
4 changes: 4 additions & 0 deletions requirements/test.in
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,7 @@ pytest
pytest-cov
pytest-django
exam
asgiref>= 3.5.2
typing-extensions==4.8.0; python_version >= "3.8"
typing-extensions==4.7.1; python_version < "3.8"