diff --git a/airflow/hooks/dbapi.py b/airflow/hooks/dbapi.py index 1e1f4c7094ba2..bac75a2a971b0 100644 --- a/airflow/hooks/dbapi.py +++ b/airflow/hooks/dbapi.py @@ -51,7 +51,14 @@ def connect(self, host: str, port: int, username: str, schema: str) -> Any: # # ######################################################################################### class DbApiHook(BaseHook): - """Abstract base class for sql hooks.""" + """ + Abstract base class for sql hooks. + + :param schema: Optional DB schema that overrides the schema specified in the connection. Make sure that + if you change the schema parameter value in the constructor of the derived Hook, such change + should be done before calling the ``DBApiHook.__init__()``. + :type schema: Optional[str] + """ # Override to provide the connection name. conn_name_attr = None # type: str @@ -62,7 +69,7 @@ class DbApiHook(BaseHook): # Override with the object that exposes the connect method connector = None # type: Optional[ConnectorProtocol] - def __init__(self, *args, **kwargs): + def __init__(self, *args, schema: Optional[str] = None, **kwargs): super().__init__() if not self.conn_name_attr: raise AirflowException("conn_name_attr is not defined") @@ -72,7 +79,11 @@ def __init__(self, *args, **kwargs): setattr(self, self.conn_name_attr, self.default_conn_name) else: setattr(self, self.conn_name_attr, kwargs[self.conn_name_attr]) - self.schema: Optional[str] = kwargs.pop("schema", None) + # We should not make schema available in deriving hooks for backwards compatibility + # If a hook deriving from DBApiHook has a need to access schema, then it should retrieve it + # from kwargs and store it on its own. We do not run "pop" here as we want to give the + # Hook deriving from the DBApiHook to still have access to the field in it's constructor + self.__schema = schema def get_conn(self): """Returns a connection object""" @@ -92,7 +103,7 @@ def get_uri(self) -> str: host = conn.host if conn.port is not None: host += f':{conn.port}' - schema = self.schema or conn.schema or '' + schema = self.__schema or conn.schema or '' return urlunsplit((conn.conn_type, f'{login}{host}', schema, '', '')) def get_sqlalchemy_engine(self, engine_kwargs=None): diff --git a/airflow/providers/postgres/hooks/postgres.py b/airflow/providers/postgres/hooks/postgres.py index 4e54cf926249a..446b6f3b1327a 100644 --- a/airflow/providers/postgres/hooks/postgres.py +++ b/airflow/providers/postgres/hooks/postgres.py @@ -70,6 +70,7 @@ def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) self.connection: Optional[Connection] = kwargs.pop("connection", None) self.conn: connection = None + self.schema: Optional[str] = kwargs.pop("schema", None) def _get_cursor(self, raw_cursor: str) -> CursorType: _cursor = raw_cursor.lower() diff --git a/tests/hooks/test_dbapi.py b/tests/hooks/test_dbapi.py index 5c60ab5bc9847..97e2c4a17a2d6 100644 --- a/tests/hooks/test_dbapi.py +++ b/tests/hooks/test_dbapi.py @@ -43,6 +43,7 @@ def get_conn(self): return conn self.db_hook = UnitTestDbApiHook() + self.db_hook_schema_override = UnitTestDbApiHook(schema='schema-override') def test_get_records(self): statement = "SQL" @@ -160,7 +161,7 @@ def test_get_uri_schema_not_none(self): assert "conn_type://login:password@host:1/schema" == self.db_hook.get_uri() def test_get_uri_schema_override(self): - self.db_hook.get_connection = mock.MagicMock( + self.db_hook_schema_override.get_connection = mock.MagicMock( return_value=Connection( conn_type="conn_type", host="host", @@ -170,8 +171,7 @@ def test_get_uri_schema_override(self): port=1, ) ) - self.db_hook.schema = 'schema-override' - assert "conn_type://login:password@host:1/schema-override" == self.db_hook.get_uri() + assert "conn_type://login:password@host:1/schema-override" == self.db_hook_schema_override.get_uri() def test_get_uri_schema_none(self): self.db_hook.get_connection = mock.MagicMock(