diff --git a/providers/apache/druid/src/airflow/providers/apache/druid/hooks/druid.py b/providers/apache/druid/src/airflow/providers/apache/druid/hooks/druid.py index c44a6be346e4a..14899c94ed7a4 100644 --- a/providers/apache/druid/src/airflow/providers/apache/druid/hooks/druid.py +++ b/providers/apache/druid/src/airflow/providers/apache/druid/hooks/druid.py @@ -79,24 +79,44 @@ def __init__( if self.timeout < 1: raise ValueError("Druid timeout should be equal or greater than 1") + self.status_endpoint = "druid/indexer/v1/task" + @cached_property def conn(self) -> Connection: return self.get_connection(self.druid_ingest_conn_id) - def get_conn_url(self, ingestion_type: IngestionType = IngestionType.BATCH) -> str: - """Get Druid connection url.""" - host = self.conn.host - port = self.conn.port + @property + def get_connection_type(self) -> str: if self.conn.schema: conn_type = self.conn.schema else: conn_type = self.conn.conn_type or "http" + return conn_type + + def get_conn_url(self, ingestion_type: IngestionType = IngestionType.BATCH) -> str: + """Get Druid connection url.""" + host = self.conn.host + port = self.conn.port + conn_type = self.get_connection_type if ingestion_type == IngestionType.BATCH: endpoint = self.conn.extra_dejson.get("endpoint", "") else: endpoint = self.conn.extra_dejson.get("msq_endpoint", "") return f"{conn_type}://{host}:{port}/{endpoint}" + def get_status_url(self, ingestion_type): + """Return Druid status url.""" + if ingestion_type == IngestionType.MSQ: + if self.get_connection_type == "druid": + conn_type = self.conn.extra_dejson.get("schema", "http") + else: + conn_type = self.get_connection_type + + status_endpoint = self.conn.extra_dejson.get("status_endpoint", self.status_endpoint) + return f"{conn_type}://{self.conn.host}:{self.conn.port}/{status_endpoint}" + else: + return self.get_conn_url(ingestion_type) + def get_auth(self) -> requests.auth.HTTPBasicAuth | None: """ Return username and password from connections tab as requests.auth.HTTPBasicAuth object. @@ -141,7 +161,7 @@ def submit_indexing_job( druid_task_id = req_json["task"] else: druid_task_id = req_json["taskId"] - druid_task_status_url = f"{self.get_conn_url()}/{druid_task_id}/status" + druid_task_status_url = self.get_status_url(ingestion_type) + f"/{druid_task_id}/status" self.log.info("Druid indexing task-id: %s", druid_task_id) running = True diff --git a/providers/apache/druid/tests/unit/apache/druid/hooks/test_druid.py b/providers/apache/druid/tests/unit/apache/druid/hooks/test_druid.py index 5b350904e65f2..478509480ea9e 100644 --- a/providers/apache/druid/tests/unit/apache/druid/hooks/test_druid.py +++ b/providers/apache/druid/tests/unit/apache/druid/hooks/test_druid.py @@ -316,6 +316,18 @@ def test_get_conn_url_with_ingestion_type_and_schema(self, mock_get_connection): hook = DruidHook(timeout=1, max_ingestion_time=5) assert hook.get_conn_url(IngestionType.MSQ) == "https://test_host:1/sql_ingest" + @patch("airflow.providers.apache.druid.hooks.druid.DruidHook.get_connection") + def test_get_status_url(self, mock_get_connection): + get_conn_value = MagicMock() + get_conn_value.host = "test_host" + get_conn_value.conn_type = "http" + get_conn_value.schema = "https" + get_conn_value.port = "1" + get_conn_value.extra_dejson = {"endpoint": "ingest", "msq_endpoint": "sql_ingest"} + mock_get_connection.return_value = get_conn_value + hook = DruidHook(timeout=1, max_ingestion_time=5) + assert hook.get_status_url(IngestionType.MSQ) == "https://test_host:1/druid/indexer/v1/task" + @patch("airflow.providers.apache.druid.hooks.druid.DruidHook.get_connection") def test_get_auth(self, mock_get_connection): get_conn_value = MagicMock()