diff --git a/providers/databricks/src/airflow/providers/databricks/hooks/databricks_sql.py b/providers/databricks/src/airflow/providers/databricks/hooks/databricks_sql.py index 9e361608eeee5..c689324f97dca 100644 --- a/providers/databricks/src/airflow/providers/databricks/hooks/databricks_sql.py +++ b/providers/databricks/src/airflow/providers/databricks/hooks/databricks_sql.py @@ -32,6 +32,7 @@ from databricks import sql from databricks.sql.types import Row +from sqlalchemy.engine import URL from airflow.exceptions import AirflowException from airflow.providers.common.sql.hooks.handlers import return_single_query_results @@ -171,6 +172,37 @@ def get_conn(self) -> AirflowConnection: raise AirflowException("SQL connection is not initialized") return cast("AirflowConnection", self._sql_conn) + @property + def sqlalchemy_url(self) -> URL: + """ + Return a Sqlalchemy.engine.URL object from the connection. + + :return: the extracted sqlalchemy.engine.URL object. + """ + conn = self.get_conn() + url_query = { + "http_path": self._http_path, + "catalog": self.catalog, + "schema": self.schema, + } + url_query = {k: v for k, v in url_query.items() if v is not None} + return URL.create( + drivername="databricks", + username="token", + password=conn.password, + host=conn.host, + port=conn.port, + query=url_query, + ) + + def get_uri(self) -> str: + """ + Extract the URI from the connection. + + :return: the extracted uri. + """ + return self.sqlalchemy_url.render_as_string(hide_password=False) + @overload def run( self, diff --git a/providers/databricks/tests/unit/databricks/hooks/test_databricks_sql.py b/providers/databricks/tests/unit/databricks/hooks/test_databricks_sql.py index 3fe7df0e83915..9e2ee58d60069 100644 --- a/providers/databricks/tests/unit/databricks/hooks/test_databricks_sql.py +++ b/providers/databricks/tests/unit/databricks/hooks/test_databricks_sql.py @@ -23,6 +23,7 @@ from datetime import timedelta from unittest import mock from unittest.mock import PropertyMock, patch +from urllib.parse import quote_plus import pandas as pd import polars as pl @@ -38,7 +39,11 @@ DEFAULT_CONN_ID = "databricks_default" HOST = "xx.cloud.databricks.com" HOST_WITH_SCHEME = "https://xx.cloud.databricks.com" +PORT = 443 TOKEN = "token" +HTTP_PATH = "sql/protocolv1/o/1234567890123456/0123-456789-abcd123" +SCHEMA = "test_schema" +CATALOG = "test_catalog" @pytest.fixture(autouse=True) @@ -107,6 +112,43 @@ def mock_timer(): yield mock_timer +def make_mock_connection(): + return Connection( + conn_id=DEFAULT_CONN_ID, + conn_type="databricks", + host=HOST, + port=PORT, + login="token", + password=TOKEN, + ) + + +def test_sqlachemy_url_property(mock_get_conn): + mock_get_conn.return_value = make_mock_connection() + hook = DatabricksSqlHook( + databricks_conn_id=DEFAULT_CONN_ID, http_path=HTTP_PATH, catalog=CATALOG, schema=SCHEMA + ) + url = hook.sqlalchemy_url.render_as_string(hide_password=False) + expected_url = ( + f"databricks://token:{TOKEN}@{HOST}:{PORT}?" + f"catalog={CATALOG}&http_path={quote_plus(HTTP_PATH)}&schema={SCHEMA}" + ) + assert url == expected_url + + +def test_get_uri(mock_get_conn): + mock_get_conn.return_value = make_mock_connection() + hook = DatabricksSqlHook( + databricks_conn_id=DEFAULT_CONN_ID, http_path=HTTP_PATH, catalog=CATALOG, schema=SCHEMA + ) + uri = hook.get_uri() + expected_uri = ( + f"databricks://token:{TOKEN}@{HOST}:{PORT}?" + f"catalog={CATALOG}&http_path={quote_plus(HTTP_PATH)}&schema={SCHEMA}" + ) + assert uri == expected_uri + + def get_cursor_descriptions(fields: list[str]) -> list[tuple[str]]: return [(field,) for field in fields]