Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@
import re
from typing import TYPE_CHECKING, Any

import sqlalchemy
import teradatasql
from sqlalchemy.engine import URL
from teradatasql import TeradataConnection

from airflow.providers.common.sql.hooks.sql import DbApiHook
Expand All @@ -34,6 +34,7 @@
except ImportError:
from airflow.models.connection import Connection # type: ignore[assignment]

DEFAULT_DB_PORT = 1025
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: Probably could add a link to the doc that mentions 1025 as a comment above this line.

PARAM_TYPES = {bool, float, int, str}


Expand Down Expand Up @@ -166,7 +167,7 @@ def _get_conn_config_teradatasql(self) -> dict[str, Any]:
conn: Connection = self.get_connection(self.get_conn_id())
conn_config = {
"host": conn.host or "localhost",
"dbs_port": conn.port or "1025",
"dbs_port": conn.port or DEFAULT_DB_PORT,
"database": conn.schema or "",
"user": conn.login or "dbc",
"password": conn.password or "dbc",
Expand Down Expand Up @@ -195,12 +196,32 @@ def _get_conn_config_teradatasql(self) -> dict[str, Any]:

return conn_config

def get_sqlalchemy_engine(self, engine_kwargs=None):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is technically a breaking change. @eladkal do we need to do anything?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@sc250072 can you raise PR to re add the function to preserve backward compatibility and raise deprecation warning for it?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It will fall back to this method and return the same URL from sqlalchemy_url. But any case,
I’ll re-add the function with a deprecation warning to preserve backward compatibility.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is technically a breaking change. @eladkal do we need to do anything?

Could you please clarify what you mean by a breaking change?
Are you referring to the latest Teradata provider not being compatible with older Airflow versions?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

one who uses older Teradata provider might break. (as this method has been removed and might be used), but if we're to bump major version. that should be fine

"""Return a connection object using sqlalchemy."""
conn: Connection = self.get_connection(self.get_conn_id())
link = f"teradatasql://{conn.login}:{conn.password}@{conn.host}"
connection = sqlalchemy.create_engine(link)
return connection
@property
def sqlalchemy_url(self) -> URL:
"""
Override to return a Sqlalchemy.engine.URL object from the Teradata connection.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
Override to return a Sqlalchemy.engine.URL object from the Teradata connection.
Override to return a `sqlalchemy.engine.URL` object from the Teradata connection.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm a bit confused. override what to return the object?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Needs to be implemented in the provider subclass to return the sqlalchemy.engine.URL object.

@Lee-W to return the sqlalchemy.engine.URL object
Will add the suggestion


:return: the extracted sqlalchemy.engine.URL object.
"""
connection = self.get_connection(self.get_conn_id())
# Adding only teradatasqlalchemy supported connection parameters.
# https://pypi.org/project/teradatasqlalchemy/#ConnectionParameters
url_kwargs = {
"drivername": "teradatasql",
"username": connection.login,
"password": connection.password,
"host": connection.host,
"port": connection.port,
}

if connection.schema: # Only include database if it's not None or empty
url_kwargs["database"] = connection.schema

return URL.create(**url_kwargs)

def get_uri(self) -> str:
"""Override DbApiHook get_uri method for get_sqlalchemy_engine()."""
return self.sqlalchemy_url.render_as_string()

@staticmethod
def get_ui_field_behaviour() -> dict:
Expand Down
46 changes: 26 additions & 20 deletions providers/teradata/tests/unit/teradata/hooks/test_teradata.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,11 @@ def setup_method(self):
self.db_hook.get_connection.return_value = self.connection
self.cur = mock.MagicMock(rowcount=0)
self.conn = mock.MagicMock()
self.conn.login = "mock_login"
self.conn.password = "mock_password"
self.conn.host = "mock_host"
self.conn.schema = "mock_schema"
self.conn.port = 1025
self.conn.cursor.return_value = self.cur
self.conn.extra_dejson = {}
conn = self.conn
Expand All @@ -53,6 +58,7 @@ def get_connection(cls, conn_id: str) -> Connection:
return conn

self.test_db_hook = UnitTestTeradataHook(teradata_conn_id="teradata_conn_id")
self.test_db_hook.get_uri = mock.Mock(return_value="sqlite://")

@mock.patch("teradatasql.connect")
def test_get_conn(self, mock_connect):
Expand All @@ -62,7 +68,7 @@ def test_get_conn(self, mock_connect):
assert args == ()
assert kwargs["host"] == "host"
assert kwargs["database"] == "schema"
assert kwargs["dbs_port"] == "1025"
assert kwargs["dbs_port"] == 1025
assert kwargs["user"] == "login"
assert kwargs["password"] == "password"

Expand All @@ -76,7 +82,7 @@ def test_get_tmode_conn(self, mock_connect):
assert args == ()
assert kwargs["host"] == "host"
assert kwargs["database"] == "schema"
assert kwargs["dbs_port"] == "1025"
assert kwargs["dbs_port"] == 1025
assert kwargs["user"] == "login"
assert kwargs["password"] == "password"
assert kwargs["tmode"] == "tera"
Expand All @@ -91,7 +97,7 @@ def test_get_sslmode_conn(self, mock_connect):
assert args == ()
assert kwargs["host"] == "host"
assert kwargs["database"] == "schema"
assert kwargs["dbs_port"] == "1025"
assert kwargs["dbs_port"] == 1025
assert kwargs["user"] == "login"
assert kwargs["password"] == "password"
assert kwargs["sslmode"] == "require"
Expand All @@ -106,7 +112,7 @@ def test_get_sslverifyca_conn(self, mock_connect):
assert args == ()
assert kwargs["host"] == "host"
assert kwargs["database"] == "schema"
assert kwargs["dbs_port"] == "1025"
assert kwargs["dbs_port"] == 1025
assert kwargs["user"] == "login"
assert kwargs["password"] == "password"
assert kwargs["sslmode"] == "verify-ca"
Expand All @@ -122,7 +128,7 @@ def test_get_sslverifyfull_conn(self, mock_connect):
assert args == ()
assert kwargs["host"] == "host"
assert kwargs["database"] == "schema"
assert kwargs["dbs_port"] == "1025"
assert kwargs["dbs_port"] == 1025
assert kwargs["user"] == "login"
assert kwargs["password"] == "password"
assert kwargs["sslmode"] == "verify-full"
Expand All @@ -138,7 +144,7 @@ def test_get_sslcrc_conn(self, mock_connect):
assert args == ()
assert kwargs["host"] == "host"
assert kwargs["database"] == "schema"
assert kwargs["dbs_port"] == "1025"
assert kwargs["dbs_port"] == 1025
assert kwargs["user"] == "login"
assert kwargs["password"] == "password"
assert kwargs["sslcrc"] == "sslcrc"
Expand All @@ -153,7 +159,7 @@ def test_get_sslprotocol_conn(self, mock_connect):
assert args == ()
assert kwargs["host"] == "host"
assert kwargs["database"] == "schema"
assert kwargs["dbs_port"] == "1025"
assert kwargs["dbs_port"] == 1025
assert kwargs["user"] == "login"
assert kwargs["password"] == "password"
assert kwargs["sslprotocol"] == "protocol"
Expand All @@ -168,25 +174,25 @@ def test_get_sslcipher_conn(self, mock_connect):
assert args == ()
assert kwargs["host"] == "host"
assert kwargs["database"] == "schema"
assert kwargs["dbs_port"] == "1025"
assert kwargs["dbs_port"] == 1025
assert kwargs["user"] == "login"
assert kwargs["password"] == "password"
assert kwargs["sslcipher"] == "cipher"

@mock.patch("sqlalchemy.create_engine")
def test_get_sqlalchemy_conn(self, mock_connect):
self.db_hook.get_sqlalchemy_engine()
assert mock_connect.call_count == 1
args = mock_connect.call_args.args
assert len(args) == 1
expected_link = (
f"teradatasql://{self.connection.login}:{self.connection.password}@{self.connection.host}"
)
assert expected_link == args[0]
def test_get_uri_without_schema(self):
self.connection.schema = "" # simulate missing schema
self.db_hook.get_connection.return_value = self.connection
uri = self.db_hook.get_uri()
expected_uri = f"teradatasql://{self.connection.login}:***@{self.connection.host}"
assert uri == expected_uri

def test_get_uri(self):
ret_uri = self.db_hook.get_uri()
expected_uri = f"teradata://{self.connection.login}:{self.connection.password}@{self.connection.host}/{self.connection.schema}"
expected_uri = (
f"teradatasql://{self.connection.login}:***@{self.connection.host}/{self.connection.schema}"
if self.connection.schema
else f"teradatasql://{self.connection.login}:***@{self.connection.host}"
)
assert expected_uri == ret_uri

def test_get_records(self):
Expand Down Expand Up @@ -260,7 +266,7 @@ def test_query_band_not_in_conn_config(self, mock_connect):
assert args == ()
assert kwargs["host"] == "host"
assert kwargs["database"] == "schema"
assert kwargs["dbs_port"] == "1025"
assert kwargs["dbs_port"] == 1025
assert kwargs["user"] == "login"
assert kwargs["password"] == "password"
assert "query_band" not in kwargs
Expand Down
Loading