Skip to content

Commit

Permalink
Make schema in DBApiHook private
Browse files Browse the repository at this point in the history
There was a change in apache#16521 that introduced schema field in
DBApiHook, but unfortunately using it in provider Hooks deriving
from DBApiHook is backwards incompatible for Airflow 2.1 and below.

This caused Postgres 2.1.0 release backwards incompatibility and
failures for Airflow 2.1.0.

Since the change is small and most of DBApi-derived hooks already
set the schema field on their own, the best approach is to
make the schema field private for the DBApiHook and make a change
in Postgres Hook to store the schema in the same way as all other
operators.

Fixes: apache#17422
  • Loading branch information
potiuk committed Aug 7, 2021
1 parent cb0b895 commit b9801ec
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 7 deletions.
19 changes: 15 additions & 4 deletions airflow/hooks/dbapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,14 @@ def connect(self, host: str, port: int, username: str, schema: str) -> Any:
# #
#########################################################################################
class DbApiHook(BaseHook):
"""Abstract base class for sql hooks."""
"""
Abstract base class for sql hooks.
:param schema: Optional DB schema that overrides the schema specified in the connection. Make sure that
if you change the schema parameter value in the constructor of the derived Hook, such change
should be done before calling the ``DBApiHook.__init__()``.
:type schema: Optional[str]
"""

# Override to provide the connection name.
conn_name_attr = None # type: str
Expand All @@ -62,7 +69,7 @@ class DbApiHook(BaseHook):
# Override with the object that exposes the connect method
connector = None # type: Optional[ConnectorProtocol]

def __init__(self, *args, **kwargs):
def __init__(self, *args, schema: Optional[str] = None, **kwargs):
super().__init__()
if not self.conn_name_attr:
raise AirflowException("conn_name_attr is not defined")
Expand All @@ -72,7 +79,11 @@ def __init__(self, *args, **kwargs):
setattr(self, self.conn_name_attr, self.default_conn_name)
else:
setattr(self, self.conn_name_attr, kwargs[self.conn_name_attr])
self.schema: Optional[str] = kwargs.pop("schema", None)
# We should not make schema available in deriving hooks for backwards compatibility
# If a hook deriving from DBApiHook has a need to access schema, then it should retrieve it
# from kwargs and store it on its own. We do not run "pop" here as we want to give the
# Hook deriving from the DBApiHook to still have access to the field in it's constructor
self.__schema = schema

def get_conn(self):
"""Returns a connection object"""
Expand All @@ -92,7 +103,7 @@ def get_uri(self) -> str:
host = conn.host
if conn.port is not None:
host += f':{conn.port}'
schema = self.schema or conn.schema or ''
schema = self.__schema or conn.schema or ''
return urlunsplit((conn.conn_type, f'{login}{host}', schema, '', ''))

def get_sqlalchemy_engine(self, engine_kwargs=None):
Expand Down
1 change: 1 addition & 0 deletions airflow/providers/postgres/hooks/postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self.connection: Optional[Connection] = kwargs.pop("connection", None)
self.conn: connection = None
self.schema: Optional[str] = kwargs.pop("schema", None)

def _get_cursor(self, raw_cursor: str) -> CursorType:
_cursor = raw_cursor.lower()
Expand Down
6 changes: 3 additions & 3 deletions tests/hooks/test_dbapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ def get_conn(self):
return conn

self.db_hook = UnitTestDbApiHook()
self.db_hook_schema_override = UnitTestDbApiHook(schema='schema-override')

def test_get_records(self):
statement = "SQL"
Expand Down Expand Up @@ -160,7 +161,7 @@ def test_get_uri_schema_not_none(self):
assert "conn_type://login:password@host:1/schema" == self.db_hook.get_uri()

def test_get_uri_schema_override(self):
self.db_hook.get_connection = mock.MagicMock(
self.db_hook_schema_override.get_connection = mock.MagicMock(
return_value=Connection(
conn_type="conn_type",
host="host",
Expand All @@ -170,8 +171,7 @@ def test_get_uri_schema_override(self):
port=1,
)
)
self.db_hook.schema = 'schema-override'
assert "conn_type://login:password@host:1/schema-override" == self.db_hook.get_uri()
assert "conn_type://login:password@host:1/schema-override" == self.db_hook_schema_override.get_uri()

def test_get_uri_schema_none(self):
self.db_hook.get_connection = mock.MagicMock(
Expand Down

0 comments on commit b9801ec

Please sign in to comment.