Skip to content

Commit

Permalink
Add possibility to override the conn type for Druid (#42793)
Browse files Browse the repository at this point in the history
* Add possibility to override the conn type for Druid

Minor fix that allows using the scheme specified in the `schema` field
rather than `http` as the default. At the same time, it doesn't
change the logic, as any conn_type can still be selected. Intuitively, it's
expected that anything specified in the `schema` field will actually take
precedence when building the desired URL.

* Add druid endpoint connection from another PR

* Fix missing scheme in test

* Set schema to None where it's unused

Even though we don't need it directly set, by default the mock will set
it to an internal object, thus we need to override it to None.

---------

Co-authored-by: Oleg Auckenthaler <github.sitcom838@passmail.net>
  • Loading branch information
Rasnar and olegmayko authored Oct 11, 2024
1 parent 6c4d67f commit 7202ee8
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 6 deletions.
5 changes: 4 additions & 1 deletion providers/src/airflow/providers/apache/druid/hooks/druid.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,10 @@ def get_conn_url(self, ingestion_type: IngestionType = IngestionType.BATCH) -> s
"""Get Druid connection url."""
host = self.conn.host
port = self.conn.port
conn_type = self.conn.conn_type or "http"
if self.conn.schema:
conn_type = self.conn.schema
else:
conn_type = self.conn.conn_type or "http"
if ingestion_type == IngestionType.BATCH:
endpoint = self.conn.extra_dejson.get("endpoint", "")
else:
Expand Down
44 changes: 39 additions & 5 deletions providers/tests/apache/druid/hooks/test_druid.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,14 @@ class TestDRuidhook(DruidHook):
self.is_sql_based_ingestion = False

def get_conn_url(self, ingestion_type: IngestionType = IngestionType.BATCH):
if self.conn.schema:
conn_type = self.conn.schema
else:
conn_type = "http"

if ingestion_type == IngestionType.MSQ:
return "http://druid-overlord:8081/druid/v2/sql/task"
return "http://druid-overlord:8081/druid/indexer/v1/task"
return f"{conn_type}://druid-overlord:8081/druid/v2/sql/task"
return f"{conn_type}://druid-overlord:8081/druid/indexer/v1/task"

self.db_hook = TestDRuidhook()

Expand Down Expand Up @@ -257,7 +262,8 @@ def get_conn_url(self, ingestion_type: IngestionType = IngestionType.BATCH):
def test_conn_property(self, mock_get_connection):
get_conn_value = MagicMock()
get_conn_value.host = "test_host"
get_conn_value.conn_type = "https"
get_conn_value.conn_type = "http"
get_conn_value.schema = None
get_conn_value.port = "1"
get_conn_value.extra_dejson = {"endpoint": "ingest"}
mock_get_connection.return_value = get_conn_value
Expand All @@ -268,8 +274,22 @@ def test_conn_property(self, mock_get_connection):
def test_get_conn_url(self, mock_get_connection):
get_conn_value = MagicMock()
get_conn_value.host = "test_host"
get_conn_value.conn_type = "https"
get_conn_value.conn_type = "http"
get_conn_value.schema = None
get_conn_value.port = "1"
get_conn_value.extra_dejson = {"endpoint": "ingest"}
mock_get_connection.return_value = get_conn_value
hook = DruidHook(timeout=1, max_ingestion_time=5)
assert hook.get_conn_url() == "http://test_host:1/ingest"

@patch("airflow.providers.apache.druid.hooks.druid.DruidHook.get_connection")
def test_get_conn_url_with_schema(self, mock_get_connection):
get_conn_value = MagicMock()
get_conn_value.host = "test_host"
get_conn_value.conn_type = "http"
get_conn_value.schema = None
get_conn_value.port = "1"
get_conn_value.schema = "https"
get_conn_value.extra_dejson = {"endpoint": "ingest"}
mock_get_connection.return_value = get_conn_value
hook = DruidHook(timeout=1, max_ingestion_time=5)
Expand All @@ -279,8 +299,21 @@ def test_get_conn_url(self, mock_get_connection):
def test_get_conn_url_with_ingestion_type(self, mock_get_connection):
get_conn_value = MagicMock()
get_conn_value.host = "test_host"
get_conn_value.conn_type = "https"
get_conn_value.conn_type = "http"
get_conn_value.schema = None
get_conn_value.port = "1"
get_conn_value.extra_dejson = {"endpoint": "ingest", "msq_endpoint": "sql_ingest"}
mock_get_connection.return_value = get_conn_value
hook = DruidHook(timeout=1, max_ingestion_time=5)
assert hook.get_conn_url(IngestionType.MSQ) == "http://test_host:1/sql_ingest"

@patch("airflow.providers.apache.druid.hooks.druid.DruidHook.get_connection")
def test_get_conn_url_with_ingestion_type_and_schema(self, mock_get_connection):
get_conn_value = MagicMock()
get_conn_value.host = "test_host"
get_conn_value.conn_type = "http"
get_conn_value.port = "1"
get_conn_value.schema = "https"
get_conn_value.extra_dejson = {"endpoint": "ingest", "msq_endpoint": "sql_ingest"}
mock_get_connection.return_value = get_conn_value
hook = DruidHook(timeout=1, max_ingestion_time=5)
Expand Down Expand Up @@ -343,6 +376,7 @@ def setup_method(self):
self.conn = conn = MagicMock()
self.conn.host = "host"
self.conn.port = "1000"
self.conn.schema = None
self.conn.conn_type = "druid"
self.conn.extra_dejson = {"endpoint": "druid/v2/sql"}
self.conn.cursor.return_value = self.cur
Expand Down

0 comments on commit 7202ee8

Please sign in to comment.