diff --git a/providers/apache/pinot/src/airflow/providers/apache/pinot/hooks/pinot.py b/providers/apache/pinot/src/airflow/providers/apache/pinot/hooks/pinot.py index 7b4c95a3b8374..e14136362b6ae 100644 --- a/providers/apache/pinot/src/airflow/providers/apache/pinot/hooks/pinot.py +++ b/providers/apache/pinot/src/airflow/providers/apache/pinot/hooks/pinot.py @@ -21,6 +21,7 @@ import subprocess from collections.abc import Iterable, Mapping from typing import TYPE_CHECKING, Any +from urllib.parse import quote_plus from pinotdb import connect @@ -74,6 +75,8 @@ def __init__( conn = self.get_connection(conn_id) self.host = conn.host self.port = str(conn.port) + self.username = conn.login + self.password = conn.password if cmd_path != "pinot-admin.sh": raise RuntimeError( "In version 4.0.0 of the PinotAdminHook the cmd_path has been hard-coded to" @@ -99,6 +102,10 @@ def add_schema(self, schema_file: str, with_exec: bool = True) -> Any: :param with_exec: bool """ cmd = ["AddSchema"] + if self.username: + cmd += ["-user", self.username] + if self.password: + cmd += ["-password", self.password] cmd += ["-controllerHost", self.host] cmd += ["-controllerPort", self.port] cmd += ["-schemaFile", schema_file] @@ -114,6 +121,10 @@ def add_table(self, file_path: str, with_exec: bool = True) -> Any: :param with_exec: bool """ cmd = ["AddTable"] + if self.username: + cmd += ["-user", self.username] + if self.password: + cmd += ["-password", self.password] cmd += ["-controllerHost", self.host] cmd += ["-controllerPort", self.port] cmd += ["-filePath", file_path] @@ -144,6 +155,11 @@ def create_segment( ) -> Any: """Create Pinot segment by run CreateSegment command.""" cmd = ["CreateSegment"] + if self.username: + cmd += ["-user", self.username] + + if self.password: + cmd += ["-password", self.password] if generator_config_file: cmd += ["-generatorConfigFile", generator_config_file] @@ -210,6 +226,10 @@ def upload_segment(self, segment_dir: str, table_name: str | None = None) -> Any :return: """ cmd = ["UploadSegment"] + if self.username: + cmd += ["-user", self.username] + if self.password: + cmd += ["-password", self.password] cmd += ["-controllerHost", self.host] cmd += ["-controllerPort", self.port] cmd += ["-segmentDir", segment_dir] @@ -277,6 +297,8 @@ def get_conn(self) -> Any: pinot_broker_conn = connect( host=conn.host, port=conn.port, + username=conn.login, + password=conn.password, path=conn.extra_dejson.get("endpoint", "/query/sql"), scheme=conn.extra_dejson.get("schema", "http"), ) @@ -291,7 +313,9 @@ def get_uri(self) -> str: """ conn = self.get_connection(self.get_conn_id()) host = conn.host - if conn.port is not None: + if conn.login and conn.password: + host = f"{quote_plus(conn.login)}:{quote_plus(conn.password)}@{host}" + if conn.port: host += f":{conn.port}" conn_type = conn.conn_type or "http" endpoint = conn.extra_dejson.get("endpoint", "query/sql") diff --git a/providers/apache/pinot/tests/unit/apache/pinot/hooks/test_pinot.py b/providers/apache/pinot/tests/unit/apache/pinot/hooks/test_pinot.py index fbedf1fb9a81f..8a433eace11fe 100644 --- a/providers/apache/pinot/tests/unit/apache/pinot/hooks/test_pinot.py +++ b/providers/apache/pinot/tests/unit/apache/pinot/hooks/test_pinot.py @@ -33,6 +33,8 @@ def setup_method(self): self.conn = conn = mock.MagicMock() self.conn.host = "host" self.conn.port = "1000" + self.conn.login = "" + self.conn.password = "" self.conn.extra_dejson = {} class PinotAdminHookTest(PinotAdminHook): @@ -217,6 +219,8 @@ def setup_method(self): self.conn = conn = mock.MagicMock() self.conn.host = "host" self.conn.port = "1000" + self.conn.login = "" + self.conn.password = "" self.conn.conn_type = "http" self.conn.extra_dejson = {"endpoint": "query/sql"} self.cur = mock.MagicMock(rowcount=0) @@ -272,3 +276,191 @@ def test_get_pandas_df(self): assert column == df.columns[0] for i, item in enumerate(result_sets): assert item[0] == df.values.tolist()[i][0] + + +class TestPinotAdminHookWithAuth: + def setup_method(self): + self.conn = conn = mock.MagicMock() + self.conn.host = "host" + self.conn.port = "1000" + self.conn.login = "user" + self.conn.password = "pwd" + self.conn.extra_dejson = {} + + class PinotAdminHookTest(PinotAdminHook): + def get_connection(self, conn_id): + return conn + + self.db_hook = PinotAdminHookTest() + + @mock.patch("airflow.providers.apache.pinot.hooks.pinot.PinotAdminHook.run_cli") + def test_add_schema_with_auth(self, mock_run_cli): + params = ["schema_file", False] + self.db_hook.add_schema(*params) + mock_run_cli.assert_called_once_with( + [ + "AddSchema", + "-user", + self.conn.login, + "-password", + self.conn.password, + "-controllerHost", + self.conn.host, + "-controllerPort", + self.conn.port, + "-schemaFile", + params[0], + ] + ) + + @mock.patch("airflow.providers.apache.pinot.hooks.pinot.PinotAdminHook.run_cli") + def test_add_table_with_auth(self, mock_run_cli): + params = ["config_file", False] + self.db_hook.add_table(*params) + mock_run_cli.assert_called_once_with( + [ + "AddTable", + "-user", + self.conn.login, + "-password", + self.conn.password, + "-controllerHost", + self.conn.host, + "-controllerPort", + self.conn.port, + "-filePath", + params[0], + ] + ) + + @mock.patch("airflow.providers.apache.pinot.hooks.pinot.PinotAdminHook.run_cli") + def test_create_segment_with_auth(self, mock_run_cli): + params = { + "generator_config_file": "a", + "data_dir": "b", + "segment_format": "c", + "out_dir": "d", + "overwrite": True, + "table_name": "e", + "segment_name": "f", + "time_column_name": "g", + "schema_file": "h", + "reader_config_file": "i", + "enable_star_tree_index": False, + "star_tree_index_spec_file": "j", + "hll_size": 9, + "hll_columns": "k", + "hll_suffix": "l", + "num_threads": 8, + "post_creation_verification": True, + "retry": 7, + } + + self.db_hook.create_segment(**params) + + mock_run_cli.assert_called_once_with( + [ + "CreateSegment", + "-user", + self.conn.login, + "-password", + self.conn.password, + "-generatorConfigFile", + params["generator_config_file"], + "-dataDir", + params["data_dir"], + "-format", + params["segment_format"], + "-outDir", + params["out_dir"], + "-overwrite", + params["overwrite"], + "-tableName", + params["table_name"], + "-segmentName", + params["segment_name"], + "-timeColumnName", + params["time_column_name"], + "-schemaFile", + params["schema_file"], + "-readerConfigFile", + params["reader_config_file"], + "-starTreeIndexSpecFile", + params["star_tree_index_spec_file"], + "-hllSize", + params["hll_size"], + "-hllColumns", + params["hll_columns"], + "-hllSuffix", + params["hll_suffix"], + "-numThreads", + params["num_threads"], + "-postCreationVerification", + params["post_creation_verification"], + "-retry", + params["retry"], + ] + ) + + @mock.patch("airflow.providers.apache.pinot.hooks.pinot.PinotAdminHook.run_cli") + def test_upload_segment_with_auth(self, mock_run_cli): + params = ["segment_dir", False] + self.db_hook.upload_segment(*params) + mock_run_cli.assert_called_once_with( + [ + "UploadSegment", + "-user", + self.conn.login, + "-password", + self.conn.password, + "-controllerHost", + self.conn.host, + "-controllerPort", + self.conn.port, + "-segmentDir", + params[0], + ] + ) + + +class TestPinotDbApiHookWithAuth: + def setup_method(self): + self.conn = conn = mock.MagicMock() + self.conn.host = "host" + self.conn.port = "1000" + self.conn.conn_type = "http" + self.conn.login = "user" + self.conn.password = "pwd" + self.conn.extra_dejson = {"endpoint": "query/sql"} + self.cur = mock.MagicMock(rowcount=0) + self.conn.cursor.return_value = self.cur + self.conn.__enter__.return_value = self.cur + self.conn.__exit__.return_value = None + + class TestPinotDBApiHook(PinotDbApiHook): + def get_conn(self): + return conn + + def get_connection(self, conn_id): + return conn + + self.db_hook = TestPinotDBApiHook + + def test_get_uri_with_auth(self): + """ + Test on getting a pinot connection uri + """ + db_hook = self.db_hook() + assert db_hook.get_uri() == "http://user:pwd@host:1000/query/sql" + + def test_get_conn_with_auth(self): + """ + Test on getting a pinot connection + """ + conn = self.db_hook().get_conn() + assert conn.host == "host" + assert conn.port == "1000" + assert conn.login == "user" + assert conn.password == "pwd" + assert conn.conn_type == "http" + assert conn.extra_dejson.get("endpoint") == "query/sql"