diff --git a/providers/exasol/src/airflow/providers/exasol/hooks/exasol.py b/providers/exasol/src/airflow/providers/exasol/hooks/exasol.py
index e920e7e54b001..417ee56b1de17 100644
--- a/providers/exasol/src/airflow/providers/exasol/hooks/exasol.py
+++ b/providers/exasol/src/airflow/providers/exasol/hooks/exasol.py
@@ -24,6 +24,7 @@
 import pyexasol
 from deprecated import deprecated
 from pyexasol import ExaConnection, ExaStatement
+from sqlalchemy.engine import URL
 
 from airflow.exceptions import AirflowProviderDeprecationWarning
 from airflow.providers.common.sql.hooks.handlers import return_single_query_results
@@ -53,10 +54,12 @@ class ExasolHook(DbApiHook):
     conn_type = "exasol"
     hook_name = "Exasol"
     supports_autocommit = True
+    DEFAULT_SQLALCHEMY_SCHEME = "exa+websocket"  # sqlalchemy-exasol dialect
 
-    def __init__(self, *args, **kwargs) -> None:
+    def __init__(self, *args, sqlalchemy_scheme: str | None = None, **kwargs) -> None:
         super().__init__(*args, **kwargs)
         self.schema = kwargs.pop("schema", None)
+        self._sqlalchemy_scheme = sqlalchemy_scheme
 
     def get_conn(self) -> ExaConnection:
         conn = self.get_connection(self.get_conn_id())
@@ -74,6 +77,51 @@ def get_conn(self) -> ExaConnection:
         conn = pyexasol.connect(**conn_args)
         return conn
 
+    @property
+    def sqlalchemy_scheme(self) -> str:
+        """Return the SQLAlchemy scheme from the constructor, the connection extras or the default."""
+        extra_scheme = self.connection is not None and self.connection_extra_lower.get("sqlalchemy_scheme")
+        sqlalchemy_scheme = self._sqlalchemy_scheme or extra_scheme or self.DEFAULT_SQLALCHEMY_SCHEME
+        if sqlalchemy_scheme not in ["exa+websocket", "exa+pyodbc", "exa+turbodbc"]:
+            raise ValueError(
+                f"sqlalchemy_scheme should be one of 'exa+websocket', 'exa+pyodbc' or 'exa+turbodbc', "
+                f"but got '{sqlalchemy_scheme}'. See https://github.com/exasol/sqlalchemy-exasol?tab=readme-ov-file#using-sqlalchemy-with-exasol-db for more details."
+            )
+        return sqlalchemy_scheme
+
+    @property
+    def sqlalchemy_url(self) -> URL:
+        """
+        Return a sqlalchemy.engine.URL object from the connection.
+
+        :return: the extracted sqlalchemy.engine.URL object.
+        """
+        connection = self.connection
+        # URL.create() only accepts str (or sequences of str) as query values, while connection
+        # extras may hold JSON booleans/numbers (e.g. "encryption": true), so stringify scalars.
+        query = {
+            k: v if isinstance(v, (str, list, tuple)) else str(v)
+            for k, v in connection.extra_dejson.items()
+            if k.lower() != "sqlalchemy_scheme"
+        }
+        return URL.create(
+            drivername=self.sqlalchemy_scheme,
+            username=connection.login,
+            password=connection.password,
+            host=connection.host,
+            port=connection.port,
+            database=self.schema or connection.schema,
+            query=query,
+        )
+
+    def get_uri(self) -> str:
+        """
+        Extract the URI from the connection.
+
+        :return: the extracted uri.
+        """
+        return self.sqlalchemy_url.render_as_string(hide_password=False)
+
     def _get_pandas_df(
         self, sql, parameters: Iterable | Mapping[str, Any] | None = None, **kwargs
     ) -> pd.DataFrame:
diff --git a/providers/exasol/tests/unit/exasol/hooks/test_exasol.py b/providers/exasol/tests/unit/exasol/hooks/test_exasol.py
index 580adf2b12019..ae85f4a6c1228 100644
--- a/providers/exasol/tests/unit/exasol/hooks/test_exasol.py
+++ b/providers/exasol/tests/unit/exasol/hooks/test_exasol.py
@@ -63,6 +63,93 @@ def test_get_conn_extra_args(self, mock_pyexasol):
         assert kwargs["encryption"] is True
 
 
+class TestExasolHookSqlalchemy:
+    def get_connection(self, extra: dict | None = None) -> models.Connection:
+        return models.Connection(
+            login="login",
+            password="password",
+            host="host",
+            port=1234,
+            schema="schema",
+            extra=extra,
+        )
+
+    @pytest.mark.parametrize(
+        "init_scheme, extra_scheme, expected_result, expect_error",
+        [
+            (None, None, "exa+websocket", False),
+            ("exa+pyodbc", None, "exa+pyodbc", False),
+            (None, "exa+turbodbc", "exa+turbodbc", False),
+            ("exa+invalid", None, None, True),
+            (None, "exa+invalid", None, True),
+        ],
+        ids=[
+            "default",
+            "from_init_arg",
+            "from_extra",
+            "invalid_from_init_arg",
+            "invalid_from_extra",
+        ],
+    )
+    def test_sqlalchemy_scheme_property(self, init_scheme, extra_scheme, expected_result, expect_error):
+        hook = ExasolHook(sqlalchemy_scheme=init_scheme) if init_scheme else ExasolHook()
+        connection = self.get_connection(extra={"sqlalchemy_scheme": extra_scheme} if extra_scheme else None)
+        hook.get_connection = mock.Mock(return_value=connection)
+
+        if not expect_error:
+            assert hook.sqlalchemy_scheme == expected_result
+        else:
+            with pytest.raises(ValueError):
+                _ = hook.sqlalchemy_scheme
+
+    @pytest.mark.parametrize(
+        "hook_scheme, extra, expected_url",
+        [
+            (None, {}, "exa+websocket://login:password@host:1234/schema"),
+            (
+                None,
+                {"CONNECTIONLCALL": "en_US.UTF-8", "driver": "EXAODBC"},
+                "exa+websocket://login:password@host:1234/schema?CONNECTIONLCALL=en_US.UTF-8&driver=EXAODBC",
+            ),
+            (
+                None,
+                {"sqlalchemy_scheme": "exa+turbodbc", "CONNECTIONLCALL": "en_US.UTF-8", "driver": "EXAODBC"},
+                "exa+turbodbc://login:password@host:1234/schema?CONNECTIONLCALL=en_US.UTF-8&driver=EXAODBC",
+            ),
+            (
+                "exa+pyodbc",
+                {
+                    "sqlalchemy_scheme": "exa+turbodbc",  # should be overridden
+                    "CONNECTIONLCALL": "en_US.UTF-8",
+                    "driver": "EXAODBC",
+                },
+                "exa+pyodbc://login:password@host:1234/schema?CONNECTIONLCALL=en_US.UTF-8&driver=EXAODBC",
+            ),
+            (None, {"encryption": True}, "exa+websocket://login:password@host:1234/schema?encryption=True"),
+        ],
+        ids=[
+            "default",
+            "default_with_extra",
+            "scheme_from_extra_turbodbc",
+            "scheme_from_hook",
+            "non_string_extra",
+        ],
+    )
+    def test_sqlalchemy_url_property(self, hook_scheme, extra, expected_url):
+        hook = ExasolHook(sqlalchemy_scheme=hook_scheme) if hook_scheme else ExasolHook()
+        hook.get_connection = mock.Mock(return_value=self.get_connection(extra=extra))
+        assert hook.sqlalchemy_url.render_as_string(hide_password=False) == expected_url
+
+    def test_get_uri(self):
+        hook = ExasolHook()
+        connection = self.get_connection(extra={"CONNECTIONLCALL": "en_US.UTF-8", "driver": "EXAODBC"})
+        hook.get_connection = mock.Mock(return_value=connection)
+        assert (
+            hook.get_uri()
+            == "exa+websocket://login:password@host:1234/schema?CONNECTIONLCALL=en_US.UTF-8&driver=EXAODBC"
+        )
+
+
 class TestExasolHook:
     def setup_method(self):
         self.cur = mock.MagicMock(rowcount=lambda: 0)