diff --git a/django_postgrespool/base.py b/django_postgrespool/base.py index d701467..3292980 100644 --- a/django_postgrespool/base.py +++ b/django_postgrespool/base.py @@ -8,10 +8,12 @@ from psycopg2 import InterfaceError, ProgrammingError, OperationalError from django.conf import settings +from django.core.exceptions import ImproperlyConfigured from django.db.backends.postgresql_psycopg2.base import * from django.db.backends.postgresql_psycopg2.base import DatabaseWrapper as Psycopg2DatabaseWrapper from django.db.backends.postgresql_psycopg2.base import CursorWrapper as DjangoCursorWrapper from django.db.backends.postgresql_psycopg2.creation import DatabaseCreation as Psycopg2DatabaseCreation +from django.utils.importlib import import_module POOL_SETTINGS = 'DATABASE_POOL_ARGS' @@ -57,6 +59,27 @@ def is_disconnect(e, connection, cursor): return False +def get_class(import_path): + try: + dot = import_path.rindex('.') + except ValueError: + raise ImproperlyConfigured("%s isn't a valid module." % import_path) + + module, classname = import_path[:dot], import_path[dot + 1:] + + try: + mod = import_module(module) + except ImportError as e: + raise ImproperlyConfigured('Error importing module %s: ' + '"%s"' % (module, e)) + + try: + return getattr(mod, classname) + except AttributeError: + raise ImproperlyConfigured('Module "%s" does not define a ' + '"%s" class.' % (module, classname)) + + class CursorWrapper(DjangoCursorWrapper): """ A thin wrapper around psycopg2's normal cursor class so that we can catch @@ -105,7 +128,26 @@ class DatabaseWrapper(Psycopg2DatabaseWrapper): def __init__(self, *args, **kwargs): super(DatabaseWrapper, self).__init__(*args, **kwargs) - self.creation = DatabaseCreation(self) + + if 'FEATURES_CLASS' in self.settings_dict: + autocommit = self.features.uses_autocommit + self.features = get_class(self.settings_dict['FEATURES_CLASS'])(self) + self.features.uses_autocommit = autocommit + + if 'OPERATIONS_CLASS' in self.settings_dict: + self.ops = get_class(self.settings_dict['OPERATIONS_CLASS'])(self) + + if 'CLIENT_CLASS' in self.settings_dict: + self.client = get_class(self.settings_dict['CLIENT_CLASS'])(self) + + if 'CREATION_CLASS' in self.settings_dict: + self.creation = get_class(self.settings_dict['CREATION_CLASS'])(self) + + if 'INTROSPECTION_CLASS' in self.settings_dict: + self.introspection = get_class(self.settings_dict['INTROSPECTION_CLASS'])(self) + + if 'VALIDATION_CLASS' in self.settings_dict: + self.validation = get_class(self.settings_dict['VALIDATION_CLASS'])(self) def _cursor(self): if self.connection is None or self.connection.is_valid == False: @@ -159,7 +201,6 @@ def _dispose(self): def _get_conn_params(self): settings_dict = self.settings_dict if not settings_dict['NAME']: - from django.core.exceptions import ImproperlyConfigured raise ImproperlyConfigured( "settings.DATABASES is improperly configured. " "Please supply the NAME value.")