From 2fb537f72f11da040831492a8d0db7630ff54536 Mon Sep 17 00:00:00 2001 From: Luke Hong Date: Wed, 23 Jun 2021 13:55:06 +0800 Subject: [PATCH] Add schema as DbApiHook instance attribute --- airflow/hooks/dbapi.py | 9 +++---- airflow/providers/postgres/hooks/postgres.py | 1 - tests/hooks/test_dbapi.py | 16 ++++++++++- .../providers/postgres/hooks/test_postgres.py | 27 +++++++++++++++++++ 4 files changed, 46 insertions(+), 7 deletions(-) diff --git a/airflow/hooks/dbapi.py b/airflow/hooks/dbapi.py index 031c2217ca9ea..8653d53e9c931 100644 --- a/airflow/hooks/dbapi.py +++ b/airflow/hooks/dbapi.py @@ -18,7 +18,7 @@ from contextlib import closing from datetime import datetime from typing import Any, Optional -from urllib.parse import quote_plus +from urllib.parse import quote_plus, urlunsplit from sqlalchemy import create_engine @@ -64,6 +64,7 @@ def __init__(self, *args, **kwargs): setattr(self, self.conn_name_attr, self.default_conn_name) else: setattr(self, self.conn_name_attr, kwargs[self.conn_name_attr]) + self.schema: Optional[str] = kwargs.pop("schema", None) def get_conn(self): """Returns a connection object""" @@ -83,10 +84,8 @@ def get_uri(self) -> str: host = conn.host if conn.port is not None: host += f':{conn.port}' - uri = f'{conn.conn_type}://{login}{host}/' - if conn.schema: - uri += conn.schema - return uri + schema = self.schema or conn.schema or '' + return urlunsplit((conn.conn_type, f'{login}{host}', schema, '', '')) def get_sqlalchemy_engine(self, engine_kwargs=None): """ diff --git a/airflow/providers/postgres/hooks/postgres.py b/airflow/providers/postgres/hooks/postgres.py index 1a78e12595083..91a78f6da98ae 100644 --- a/airflow/providers/postgres/hooks/postgres.py +++ b/airflow/providers/postgres/hooks/postgres.py @@ -68,7 +68,6 @@ class PostgresHook(DbApiHook): def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) - self.schema: Optional[str] = kwargs.pop("schema", None) self.connection: Optional[Connection] = kwargs.pop("connection", None) self.conn: connection = None diff --git a/tests/hooks/test_dbapi.py b/tests/hooks/test_dbapi.py index 383d69e16a5ae..5c60ab5bc9847 100644 --- a/tests/hooks/test_dbapi.py +++ b/tests/hooks/test_dbapi.py @@ -159,13 +159,27 @@ def test_get_uri_schema_not_none(self): ) assert "conn_type://login:password@host:1/schema" == self.db_hook.get_uri() + def test_get_uri_schema_override(self): + self.db_hook.get_connection = mock.MagicMock( + return_value=Connection( + conn_type="conn_type", + host="host", + login="login", + password="password", + schema="schema", + port=1, + ) + ) + self.db_hook.schema = 'schema-override' + assert "conn_type://login:password@host:1/schema-override" == self.db_hook.get_uri() + def test_get_uri_schema_none(self): self.db_hook.get_connection = mock.MagicMock( return_value=Connection( conn_type="conn_type", host="host", login="login", password="password", schema=None, port=1 ) ) - assert "conn_type://login:password@host:1/" == self.db_hook.get_uri() + assert "conn_type://login:password@host:1" == self.db_hook.get_uri() def test_get_uri_special_characters(self): self.db_hook.get_connection = mock.MagicMock( diff --git a/tests/providers/postgres/hooks/test_postgres.py b/tests/providers/postgres/hooks/test_postgres.py index 9a0226fd3798f..ed7a5750ed5cd 100644 --- a/tests/providers/postgres/hooks/test_postgres.py +++ b/tests/providers/postgres/hooks/test_postgres.py @@ -139,6 +139,33 @@ def test_get_conn_rds_iam_redshift(self, mock_client, mock_connect): [get_cluster_credentials_call, get_cluster_credentials_call] ) + def test_get_uri_from_connection_without_schema_override(self): + self.db_hook.get_connection = mock.MagicMock( + return_value=Connection( + conn_type="postgres", + host="host", + login="login", + password="password", + schema="schema", + port=1, + ) + ) + assert "postgres://login:password@host:1/schema" == self.db_hook.get_uri() + + def test_get_uri_from_connection_with_schema_override(self): + hook = PostgresHook(schema='schema-override') + hook.get_connection = mock.MagicMock( + return_value=Connection( + conn_type="postgres", + host="host", + login="login", + password="password", + schema="schema", + port=1, + ) + ) + assert "postgres://login:password@host:1/schema-override" == hook.get_uri() + class TestPostgresHook(unittest.TestCase): def __init__(self, *args, **kwargs):