Skip to content
This repository was archived by the owner on Aug 26, 2021. It is now read-only.

Allows overriding DatabaseWrapper collaborators per DATABASE #16

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 43 additions & 2 deletions django_postgrespool/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,12 @@
from psycopg2 import InterfaceError, ProgrammingError, OperationalError

from django.conf import settings
from django.core.exceptions import ImproperlyConfigured
from django.db.backends.postgresql_psycopg2.base import *
from django.db.backends.postgresql_psycopg2.base import DatabaseWrapper as Psycopg2DatabaseWrapper
from django.db.backends.postgresql_psycopg2.base import CursorWrapper as DjangoCursorWrapper
from django.db.backends.postgresql_psycopg2.creation import DatabaseCreation as Psycopg2DatabaseCreation
from django.utils.importlib import import_module

POOL_SETTINGS = 'DATABASE_POOL_ARGS'

Expand Down Expand Up @@ -57,6 +59,27 @@ def is_disconnect(e, connection, cursor):
return False


def get_class(import_path):
try:
dot = import_path.rindex('.')
except ValueError:
raise ImproperlyConfigured("%s isn't a valid module." % import_path)

module, classname = import_path[:dot], import_path[dot + 1:]

try:
mod = import_module(module)
except ImportError as e:
raise ImproperlyConfigured('Error importing module %s: '
'"%s"' % (module, e))

try:
return getattr(mod, classname)
except AttributeError:
raise ImproperlyConfigured('Module "%s" does not define a '
'"%s" class.' % (module, classname))


class CursorWrapper(DjangoCursorWrapper):
"""
A thin wrapper around psycopg2's normal cursor class so that we can catch
Expand Down Expand Up @@ -105,7 +128,26 @@ class DatabaseWrapper(Psycopg2DatabaseWrapper):

def __init__(self, *args, **kwargs):
super(DatabaseWrapper, self).__init__(*args, **kwargs)
self.creation = DatabaseCreation(self)

if 'FEATURES_CLASS' in self.settings_dict:
autocommit = self.features.uses_autocommit
self.features = get_class(self.settings_dict['FEATURES_CLASS'])(self)
self.features.uses_autocommit = autocommit

if 'OPERATIONS_CLASS' in self.settings_dict:
self.ops = get_class(self.settings_dict['OPERATIONS_CLASS'])(self)

if 'CLIENT_CLASS' in self.settings_dict:
self.client = get_class(self.settings_dict['CLIENT_CLASS'])(self)

if 'CREATION_CLASS' in self.settings_dict:
self.creation = get_class(self.settings_dict['CREATION_CLASS'])(self)

if 'INTROSPECTION_CLASS' in self.settings_dict:
self.introspection = get_class(self.settings_dict['INTROSPECTION_CLASS'])(self)

if 'VALIDATION_CLASS' in self.settings_dict:
self.validation = get_class(self.settings_dict['VALIDATION_CLASS'])(self)

def _cursor(self):
if self.connection is None or self.connection.is_valid == False:
Expand Down Expand Up @@ -159,7 +201,6 @@ def _dispose(self):
def _get_conn_params(self):
settings_dict = self.settings_dict
if not settings_dict['NAME']:
from django.core.exceptions import ImproperlyConfigured
raise ImproperlyConfigured(
"settings.DATABASES is improperly configured. "
"Please supply the NAME value.")
Expand Down