diff --git a/providers/teradata/src/airflow/providers/teradata/hooks/teradata.py b/providers/teradata/src/airflow/providers/teradata/hooks/teradata.py index 4f0d90e4a298e..866a51db18f8c 100644 --- a/providers/teradata/src/airflow/providers/teradata/hooks/teradata.py +++ b/providers/teradata/src/airflow/providers/teradata/hooks/teradata.py @@ -22,8 +22,8 @@ import re from typing import TYPE_CHECKING, Any -import sqlalchemy import teradatasql +from sqlalchemy.engine import URL from teradatasql import TeradataConnection from airflow.providers.common.sql.hooks.sql import DbApiHook @@ -34,6 +34,7 @@ except ImportError: from airflow.models.connection import Connection # type: ignore[assignment] +DEFAULT_DB_PORT = 1025 PARAM_TYPES = {bool, float, int, str} @@ -166,7 +167,7 @@ def _get_conn_config_teradatasql(self) -> dict[str, Any]: conn: Connection = self.get_connection(self.get_conn_id()) conn_config = { "host": conn.host or "localhost", - "dbs_port": conn.port or "1025", + "dbs_port": conn.port or DEFAULT_DB_PORT, "database": conn.schema or "", "user": conn.login or "dbc", "password": conn.password or "dbc", @@ -195,12 +196,32 @@ def _get_conn_config_teradatasql(self) -> dict[str, Any]: return conn_config - def get_sqlalchemy_engine(self, engine_kwargs=None): - """Return a connection object using sqlalchemy.""" - conn: Connection = self.get_connection(self.get_conn_id()) - link = f"teradatasql://{conn.login}:{conn.password}@{conn.host}" - connection = sqlalchemy.create_engine(link) - return connection + @property + def sqlalchemy_url(self) -> URL: + """ + Override to return a Sqlalchemy.engine.URL object from the Teradata connection. + + :return: the extracted sqlalchemy.engine.URL object. + """ + connection = self.get_connection(self.get_conn_id()) + # Adding only teradatasqlalchemy supported connection parameters. + # https://pypi.org/project/teradatasqlalchemy/#ConnectionParameters + url_kwargs = { + "drivername": "teradatasql", + "username": connection.login, + "password": connection.password, + "host": connection.host, + "port": connection.port, + } + + if connection.schema: # Only include database if it's not None or empty + url_kwargs["database"] = connection.schema + + return URL.create(**url_kwargs) + + def get_uri(self) -> str: + """Override DbApiHook get_uri method for get_sqlalchemy_engine().""" + return self.sqlalchemy_url.render_as_string() @staticmethod def get_ui_field_behaviour() -> dict: diff --git a/providers/teradata/tests/unit/teradata/hooks/test_teradata.py b/providers/teradata/tests/unit/teradata/hooks/test_teradata.py index f9a38e0d607f0..f10c1e629d211 100644 --- a/providers/teradata/tests/unit/teradata/hooks/test_teradata.py +++ b/providers/teradata/tests/unit/teradata/hooks/test_teradata.py @@ -40,6 +40,11 @@ def setup_method(self): self.db_hook.get_connection.return_value = self.connection self.cur = mock.MagicMock(rowcount=0) self.conn = mock.MagicMock() + self.conn.login = "mock_login" + self.conn.password = "mock_password" + self.conn.host = "mock_host" + self.conn.schema = "mock_schema" + self.conn.port = 1025 self.conn.cursor.return_value = self.cur self.conn.extra_dejson = {} conn = self.conn @@ -53,6 +58,7 @@ def get_connection(cls, conn_id: str) -> Connection: return conn self.test_db_hook = UnitTestTeradataHook(teradata_conn_id="teradata_conn_id") + self.test_db_hook.get_uri = mock.Mock(return_value="sqlite://") @mock.patch("teradatasql.connect") def test_get_conn(self, mock_connect): @@ -62,7 +68,7 @@ def test_get_conn(self, mock_connect): assert args == () assert kwargs["host"] == "host" assert kwargs["database"] == "schema" - assert kwargs["dbs_port"] == "1025" + assert kwargs["dbs_port"] == 1025 assert kwargs["user"] == "login" assert kwargs["password"] == "password" @@ -76,7 +82,7 @@ def test_get_tmode_conn(self, mock_connect): assert args == () assert kwargs["host"] == "host" assert kwargs["database"] == "schema" - assert kwargs["dbs_port"] == "1025" + assert kwargs["dbs_port"] == 1025 assert kwargs["user"] == "login" assert kwargs["password"] == "password" assert kwargs["tmode"] == "tera" @@ -91,7 +97,7 @@ def test_get_sslmode_conn(self, mock_connect): assert args == () assert kwargs["host"] == "host" assert kwargs["database"] == "schema" - assert kwargs["dbs_port"] == "1025" + assert kwargs["dbs_port"] == 1025 assert kwargs["user"] == "login" assert kwargs["password"] == "password" assert kwargs["sslmode"] == "require" @@ -106,7 +112,7 @@ def test_get_sslverifyca_conn(self, mock_connect): assert args == () assert kwargs["host"] == "host" assert kwargs["database"] == "schema" - assert kwargs["dbs_port"] == "1025" + assert kwargs["dbs_port"] == 1025 assert kwargs["user"] == "login" assert kwargs["password"] == "password" assert kwargs["sslmode"] == "verify-ca" @@ -122,7 +128,7 @@ def test_get_sslverifyfull_conn(self, mock_connect): assert args == () assert kwargs["host"] == "host" assert kwargs["database"] == "schema" - assert kwargs["dbs_port"] == "1025" + assert kwargs["dbs_port"] == 1025 assert kwargs["user"] == "login" assert kwargs["password"] == "password" assert kwargs["sslmode"] == "verify-full" @@ -138,7 +144,7 @@ def test_get_sslcrc_conn(self, mock_connect): assert args == () assert kwargs["host"] == "host" assert kwargs["database"] == "schema" - assert kwargs["dbs_port"] == "1025" + assert kwargs["dbs_port"] == 1025 assert kwargs["user"] == "login" assert kwargs["password"] == "password" assert kwargs["sslcrc"] == "sslcrc" @@ -153,7 +159,7 @@ def test_get_sslprotocol_conn(self, mock_connect): assert args == () assert kwargs["host"] == "host" assert kwargs["database"] == "schema" - assert kwargs["dbs_port"] == "1025" + assert kwargs["dbs_port"] == 1025 assert kwargs["user"] == "login" assert kwargs["password"] == "password" assert kwargs["sslprotocol"] == "protocol" @@ -168,25 +174,25 @@ def test_get_sslcipher_conn(self, mock_connect): assert args == () assert kwargs["host"] == "host" assert kwargs["database"] == "schema" - assert kwargs["dbs_port"] == "1025" + assert kwargs["dbs_port"] == 1025 assert kwargs["user"] == "login" assert kwargs["password"] == "password" assert kwargs["sslcipher"] == "cipher" - @mock.patch("sqlalchemy.create_engine") - def test_get_sqlalchemy_conn(self, mock_connect): - self.db_hook.get_sqlalchemy_engine() - assert mock_connect.call_count == 1 - args = mock_connect.call_args.args - assert len(args) == 1 - expected_link = ( - f"teradatasql://{self.connection.login}:{self.connection.password}@{self.connection.host}" - ) - assert expected_link == args[0] + def test_get_uri_without_schema(self): + self.connection.schema = "" # simulate missing schema + self.db_hook.get_connection.return_value = self.connection + uri = self.db_hook.get_uri() + expected_uri = f"teradatasql://{self.connection.login}:***@{self.connection.host}" + assert uri == expected_uri def test_get_uri(self): ret_uri = self.db_hook.get_uri() - expected_uri = f"teradata://{self.connection.login}:{self.connection.password}@{self.connection.host}/{self.connection.schema}" + expected_uri = ( + f"teradatasql://{self.connection.login}:***@{self.connection.host}/{self.connection.schema}" + if self.connection.schema + else f"teradatasql://{self.connection.login}:***@{self.connection.host}" + ) assert expected_uri == ret_uri def test_get_records(self): @@ -260,7 +266,7 @@ def test_query_band_not_in_conn_config(self, mock_connect): assert args == () assert kwargs["host"] == "host" assert kwargs["database"] == "schema" - assert kwargs["dbs_port"] == "1025" + assert kwargs["dbs_port"] == 1025 assert kwargs["user"] == "login" assert kwargs["password"] == "password" assert "query_band" not in kwargs