Skip to content

Commit

Permalink
Add schema as DbApiHook instance attribute (#16521)
Browse files Browse the repository at this point in the history
  • Loading branch information
LukeHong authored Jun 23, 2021
1 parent 86c2091 commit 3ee916e
Show file tree
Hide file tree
Showing 4 changed files with 46 additions and 7 deletions.
9 changes: 4 additions & 5 deletions airflow/hooks/dbapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from contextlib import closing
from datetime import datetime
from typing import Any, Optional
from urllib.parse import quote_plus
from urllib.parse import quote_plus, urlunsplit

from sqlalchemy import create_engine

Expand Down Expand Up @@ -64,6 +64,7 @@ def __init__(self, *args, **kwargs):
setattr(self, self.conn_name_attr, self.default_conn_name)
else:
setattr(self, self.conn_name_attr, kwargs[self.conn_name_attr])
self.schema: Optional[str] = kwargs.pop("schema", None)

def get_conn(self):
"""Returns a connection object"""
Expand All @@ -83,10 +84,8 @@ def get_uri(self) -> str:
host = conn.host
if conn.port is not None:
host += f':{conn.port}'
uri = f'{conn.conn_type}://{login}{host}/'
if conn.schema:
uri += conn.schema
return uri
schema = self.schema or conn.schema or ''
return urlunsplit((conn.conn_type, f'{login}{host}', schema, '', ''))

def get_sqlalchemy_engine(self, engine_kwargs=None):
"""
Expand Down
1 change: 0 additions & 1 deletion airflow/providers/postgres/hooks/postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,6 @@ class PostgresHook(DbApiHook):

def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self.schema: Optional[str] = kwargs.pop("schema", None)
self.connection: Optional[Connection] = kwargs.pop("connection", None)
self.conn: connection = None

Expand Down
16 changes: 15 additions & 1 deletion tests/hooks/test_dbapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,13 +159,27 @@ def test_get_uri_schema_not_none(self):
)
assert "conn_type://login:password@host:1/schema" == self.db_hook.get_uri()

def test_get_uri_schema_override(self):
self.db_hook.get_connection = mock.MagicMock(
return_value=Connection(
conn_type="conn_type",
host="host",
login="login",
password="password",
schema="schema",
port=1,
)
)
self.db_hook.schema = 'schema-override'
assert "conn_type://login:password@host:1/schema-override" == self.db_hook.get_uri()

def test_get_uri_schema_none(self):
self.db_hook.get_connection = mock.MagicMock(
return_value=Connection(
conn_type="conn_type", host="host", login="login", password="password", schema=None, port=1
)
)
assert "conn_type://login:password@host:1/" == self.db_hook.get_uri()
assert "conn_type://login:password@host:1" == self.db_hook.get_uri()

def test_get_uri_special_characters(self):
self.db_hook.get_connection = mock.MagicMock(
Expand Down
27 changes: 27 additions & 0 deletions tests/providers/postgres/hooks/test_postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,33 @@ def test_get_conn_rds_iam_redshift(self, mock_client, mock_connect):
[get_cluster_credentials_call, get_cluster_credentials_call]
)

def test_get_uri_from_connection_without_schema_override(self):
self.db_hook.get_connection = mock.MagicMock(
return_value=Connection(
conn_type="postgres",
host="host",
login="login",
password="password",
schema="schema",
port=1,
)
)
assert "postgres://login:password@host:1/schema" == self.db_hook.get_uri()

def test_get_uri_from_connection_with_schema_override(self):
hook = PostgresHook(schema='schema-override')
hook.get_connection = mock.MagicMock(
return_value=Connection(
conn_type="postgres",
host="host",
login="login",
password="password",
schema="schema",
port=1,
)
)
assert "postgres://login:password@host:1/schema-override" == hook.get_uri()


class TestPostgresHook(unittest.TestCase):
def __init__(self, *args, **kwargs):
Expand Down

0 comments on commit 3ee916e

Please sign in to comment.