diff --git a/providers/google/src/airflow/providers/google/cloud/hooks/cloud_sql.py b/providers/google/src/airflow/providers/google/cloud/hooks/cloud_sql.py index 32b86cd9c0203..ab05183aef7fe 100644 --- a/providers/google/src/airflow/providers/google/cloud/hooks/cloud_sql.py +++ b/providers/google/src/airflow/providers/google/cloud/hooks/cloud_sql.py @@ -50,7 +50,13 @@ # Number of retries - used by googleapiclient method calls to perform retries # For requests that are "retriable" from airflow.exceptions import AirflowException -from airflow.models import Connection +from airflow.providers.google.version_compat import AIRFLOW_V_3_1_PLUS + +if AIRFLOW_V_3_1_PLUS: + from airflow.sdk import Connection +else: + from airflow.models import Connection # type: ignore[assignment,attr-defined,no-redef] + from airflow.providers.google.cloud.hooks.secret_manager import ( GoogleCloudSecretManagerHook, ) @@ -1045,15 +1051,26 @@ def _generate_unique_path() -> str: def _quote(value) -> str | None: return quote_plus(value) if value else None - def _generate_connection_uri(self) -> str: + def _reserve_port(self): if self.use_proxy: if self.sql_proxy_use_tcp: if not self.sql_proxy_tcp_port: self.reserve_free_tcp_port() if not self.sql_proxy_unique_path: self.sql_proxy_unique_path = self._generate_unique_path() + + def _generate_connection_uri(self) -> str: + self._reserve_port() if not self.database_type: raise ValueError("The database_type should be set") + if not self.user: + raise AirflowException("The login parameter needs to be set in connection") + if not self.public_ip: + raise AirflowException("The location parameter needs to be set in connection") + if not self.password: + raise AirflowException("The password parameter needs to be set in connection") + if not self.database: + raise AirflowException("The database parameter needs to be set in connection") database_uris = CONNECTION_URIS[self.database_type] ssl_spec = None @@ -1072,14 +1089,6 @@ def _generate_connection_uri(self) -> str: ssl_spec = {"cert": self.sslcert, "key": self.sslkey, "ca": self.sslrootcert} else: format_string = public_uris["non-ssl"] - if not self.user: - raise AirflowException("The login parameter needs to be set in connection") - if not self.public_ip: - raise AirflowException("The location parameter needs to be set in connection") - if not self.password: - raise AirflowException("The password parameter needs to be set in connection") - if not self.database: - raise AirflowException("The database parameter needs to be set in connection") connection_uri = format_string.format( user=quote_plus(self.user) if self.user else "", @@ -1113,6 +1122,69 @@ def _get_sqlproxy_instance_specification(self) -> str: instance_specification += f"=tcp:{self.sql_proxy_tcp_port}" return instance_specification + def _generate_connection_parameters(self) -> dict: + self._reserve_port() + if not self.database_type: + raise ValueError("The database_type should be set") + if not self.user: + raise AirflowException("The login parameter needs to be set in connection") + if not self.public_ip: + raise AirflowException("The location parameter needs to be set in connection") + if not self.password: + raise AirflowException("The password parameter needs to be set in connection") + if not self.database: + raise AirflowException("The database parameter needs to be set in connection") + + connection_parameters = {} + + connection_parameters["conn_type"] = self.database_type + connection_parameters["login"] = self.user + connection_parameters["password"] = self.password + connection_parameters["schema"] = self.database + connection_parameters["extra"] = {} + + database_uris = CONNECTION_URIS[self.database_type] + if self.use_proxy: + proxy_uris = database_uris["proxy"] + if self.sql_proxy_use_tcp: + connection_parameters["host"] = "127.0.0.1" + connection_parameters["port"] = self.sql_proxy_tcp_port + else: + socket_path = f"{self.sql_proxy_unique_path}/{self._get_instance_socket_name()}" + if "localhost" in proxy_uris["socket"]: + connection_parameters["host"] = "localhost" + connection_parameters["extra"].update({"unix_socket": socket_path}) + else: + connection_parameters["host"] = socket_path + else: + public_uris = database_uris["public"] + if self.use_ssl: + connection_parameters["host"] = self.public_ip + connection_parameters["port"] = self.public_port + if "ssl_spec" in public_uris["ssl"]: + connection_parameters["extra"].update( + { + "ssl": json.dumps( + {"cert": self.sslcert, "key": self.sslkey, "ca": self.sslrootcert} + ) + } + ) + else: + connection_parameters["extra"].update( + { + "sslmode": "verify-ca", + "sslcert": self.sslcert, + "sslkey": self.sslkey, + "sslrootcert": self.sslrootcert, + } + ) + else: + connection_parameters["host"] = self.public_ip + connection_parameters["port"] = self.public_port + if connection_parameters.get("extra"): + connection_parameters["extra"] = json.dumps(connection_parameters["extra"]) + return connection_parameters + def create_connection(self) -> Connection: """ Create a connection. @@ -1120,8 +1192,11 @@ def create_connection(self) -> Connection: Connection ID will be randomly generated according to whether it uses proxy, TCP, UNIX sockets, SSL. """ - uri = self._generate_connection_uri() - connection = Connection(conn_id=self.db_conn_id, uri=uri) + if AIRFLOW_V_3_1_PLUS: + kwargs = self._generate_connection_parameters() + else: + kwargs = {"uri": self._generate_connection_uri()} + connection = Connection(conn_id=self.db_conn_id, **kwargs) self.log.info("Creating connection %s", self.db_conn_id) return connection diff --git a/providers/google/tests/unit/google/cloud/hooks/test_cloud_sql.py b/providers/google/tests/unit/google/cloud/hooks/test_cloud_sql.py index f1a332eb4fb7a..d45039ff0c80a 100644 --- a/providers/google/tests/unit/google/cloud/hooks/test_cloud_sql.py +++ b/providers/google/tests/unit/google/cloud/hooks/test_cloud_sql.py @@ -24,6 +24,7 @@ import tempfile from unittest import mock from unittest.mock import PropertyMock, call, mock_open +from urllib.parse import parse_qsl, unquote, urlsplit import aiohttp import httplib2 @@ -33,7 +34,13 @@ from yarl import URL from airflow.exceptions import AirflowException -from airflow.models import Connection + +from tests_common.test_utils.version_compat import AIRFLOW_V_3_1_PLUS + +if AIRFLOW_V_3_1_PLUS: + from airflow.sdk import Connection +else: + from airflow.models import Connection # type: ignore[assignment,attr-defined,no-redef] from airflow.providers.google.cloud.hooks.cloud_sql import ( CloudSQLAsyncHook, CloudSQLDatabaseHook, @@ -761,13 +768,40 @@ def test_delete_database_overridden_project_id( ) +def _parse_from_uri(uri: str): + connection_parameters = {} + uri_parts = urlsplit(uri) + connection_parameters["conn_type"] = uri_parts.scheme + rest_of_the_url = uri.replace(f"{uri_parts.scheme}://", "//") + uri_parts = urlsplit(rest_of_the_url) + host = unquote(uri_parts.hostname or "") + connection_parameters["host"] = host + quoted_schema = uri_parts.path[1:] + connection_parameters["schema"] = unquote(quoted_schema) if quoted_schema else "" + connection_parameters["login"] = unquote(uri_parts.username) if uri_parts.username else "" + connection_parameters["password"] = unquote(uri_parts.password) if uri_parts.password else "" + connection_parameters["port"] = uri_parts.port # type: ignore[assignment] + if uri_parts.query: + query = dict(parse_qsl(uri_parts.query, keep_blank_values=True)) + connection_parameters["extra"] = json.dumps(query) + return connection_parameters + + class TestCloudSqlDatabaseHook: @mock.patch("airflow.providers.google.cloud.hooks.cloud_sql.CloudSQLDatabaseHook.get_connection") def test_cloudsql_database_hook_validate_ssl_certs_no_ssl(self, get_connection): - connection = Connection() - connection.set_extra( - json.dumps({"location": "test", "instance": "instance", "database_type": "postgres"}) + connection = Connection( + conn_id="my_test_connection", + conn_type="gcpcloudsqldb", ) + if AIRFLOW_V_3_1_PLUS: + connection.extra = json.dumps( + {"location": "test", "instance": "instance", "database_type": "postgres"} + ) + else: + connection.set_extra( + json.dumps({"location": "test", "instance": "instance", "database_type": "postgres"}) + ) get_connection.return_value = connection hook = CloudSQLDatabaseHook( gcp_cloudsql_conn_id="cloudsql_connection", default_gcp_project_id="google_connection" @@ -794,10 +828,16 @@ def test_cloudsql_database_hook_validate_ssl_certs_missing_cert_params( ): mock_is_file.side_effects = True mock_set_temporary_ssl_file.side_effect = cert_dict.values() - connection = Connection() + connection = Connection( + conn_id="my_test_connection", + conn_type="gcpcloudsqldb", + ) extras = {"location": "test", "instance": "instance", "database_type": "postgres", "use_ssl": "True"} extras.update(cert_dict) - connection.set_extra(json.dumps(extras)) + if AIRFLOW_V_3_1_PLUS: + connection.extra = json.dumps(extras) + else: + connection.set_extra(json.dumps(extras)) get_connection.return_value = connection hook = CloudSQLDatabaseHook( @@ -814,26 +854,31 @@ def test_cloudsql_database_hook_validate_ssl_certs_missing_cert_params( def test_cloudsql_database_hook_validate_ssl_certs_with_ssl( self, get_connection, mock_set_temporary_ssl_file, mock_is_file ): - connection = Connection() + connection = Connection( + conn_id="my_test_connection", + conn_type="gcpcloudsqldb", + ) mock_is_file.return_value = True mock_set_temporary_ssl_file.side_effect = [ "/tmp/cert_file.pem", "/tmp/rootcert_file.pem", "/tmp/key_file.pem", ] - connection.set_extra( - json.dumps( - { - "location": "test", - "instance": "instance", - "database_type": "postgres", - "use_ssl": "True", - "sslcert": "cert_file.pem", - "sslrootcert": "rootcert_file.pem", - "sslkey": "key_file.pem", - } - ) + extras = json.dumps( + { + "location": "test", + "instance": "instance", + "database_type": "postgres", + "use_ssl": "True", + "sslcert": "cert_file.pem", + "sslrootcert": "rootcert_file.pem", + "sslkey": "key_file.pem", + } ) + if AIRFLOW_V_3_1_PLUS: + connection.extra = extras + else: + connection.set_extra(extras) get_connection.return_value = connection hook = CloudSQLDatabaseHook( gcp_cloudsql_conn_id="cloudsql_connection", default_gcp_project_id="google_connection" @@ -846,26 +891,31 @@ def test_cloudsql_database_hook_validate_ssl_certs_with_ssl( def test_cloudsql_database_hook_validate_ssl_certs_with_ssl_files_not_readable( self, get_connection, mock_set_temporary_ssl_file, mock_is_file ): - connection = Connection() + connection = Connection( + conn_id="my_test_connection", + conn_type="gcpcloudsqldb", + ) mock_is_file.return_value = False mock_set_temporary_ssl_file.side_effect = [ "/tmp/cert_file.pem", "/tmp/rootcert_file.pem", "/tmp/key_file.pem", ] - connection.set_extra( - json.dumps( - { - "location": "test", - "instance": "instance", - "database_type": "postgres", - "use_ssl": "True", - "sslcert": "cert_file.pem", - "sslrootcert": "rootcert_file.pem", - "sslkey": "key_file.pem", - } - ) + extras = json.dumps( + { + "location": "test", + "instance": "instance", + "database_type": "postgres", + "use_ssl": "True", + "sslcert": "cert_file.pem", + "sslrootcert": "rootcert_file.pem", + "sslkey": "key_file.pem", + } ) + if AIRFLOW_V_3_1_PLUS: + connection.extra = extras + else: + connection.set_extra(extras) get_connection.return_value = connection hook = CloudSQLDatabaseHook( gcp_cloudsql_conn_id="cloudsql_connection", default_gcp_project_id="google_connection" @@ -881,18 +931,23 @@ def test_cloudsql_database_hook_validate_socket_path_length_too_long( self, get_connection, gettempdir_mock ): gettempdir_mock.return_value = "/tmp" - connection = Connection() - connection.set_extra( - json.dumps( - { - "location": "test", - "instance": "very_long_instance_name_that_will_be_too_long_to_build_socket_length", - "database_type": "postgres", - "use_proxy": "True", - "use_tcp": "False", - } - ) + connection = Connection( + conn_id="my_test_connection", + conn_type="gcpcloudsqldb", ) + extras = json.dumps( + { + "location": "test", + "instance": "very_long_instance_name_that_will_be_too_long_to_build_socket_length", + "database_type": "postgres", + "use_proxy": "True", + "use_tcp": "False", + } + ) + if AIRFLOW_V_3_1_PLUS: + connection.extra = extras + else: + connection.set_extra(extras) get_connection.return_value = connection hook = CloudSQLDatabaseHook( gcp_cloudsql_conn_id="cloudsql_connection", default_gcp_project_id="google_connection" @@ -908,18 +963,23 @@ def test_cloudsql_database_hook_validate_socket_path_length_not_too_long( self, get_connection, gettempdir_mock ): gettempdir_mock.return_value = "/tmp" - connection = Connection() - connection.set_extra( - json.dumps( - { - "location": "test", - "instance": "short_instance_name", - "database_type": "postgres", - "use_proxy": "True", - "use_tcp": "False", - } - ) + connection = Connection( + conn_id="my_test_connection", + conn_type="gcpcloudsqldb", + ) + extras = json.dumps( + { + "location": "test", + "instance": "short_instance_name", + "database_type": "postgres", + "use_proxy": "True", + "use_tcp": "False", + } ) + if AIRFLOW_V_3_1_PLUS: + connection.extra = extras + else: + connection.set_extra(extras) get_connection.return_value = connection hook = CloudSQLDatabaseHook( gcp_cloudsql_conn_id="cloudsql_connection", default_gcp_project_id="google_connection" @@ -940,7 +1000,10 @@ def test_cloudsql_database_hook_validate_socket_path_length_not_too_long( ) @mock.patch("airflow.providers.google.cloud.hooks.cloud_sql.CloudSQLDatabaseHook.get_connection") def test_cloudsql_database_hook_create_connection_missing_fields(self, get_connection, uri): - connection = Connection(uri=uri) + if AIRFLOW_V_3_1_PLUS: + connection = Connection(conn_id="test_conn_id", **_parse_from_uri(uri)) + else: + connection = Connection(uri=uri) params = { "location": "test", "instance": "instance", @@ -948,7 +1011,11 @@ def test_cloudsql_database_hook_create_connection_missing_fields(self, get_conne "use_proxy": "True", "use_tcp": "False", } - connection.set_extra(json.dumps(params)) + extras = json.dumps(params) + if AIRFLOW_V_3_1_PLUS: + connection.extra = extras + else: + connection.set_extra(extras) get_connection.return_value = connection hook = CloudSQLDatabaseHook( gcp_cloudsql_conn_id="cloudsql_connection", default_gcp_project_id="google_connection" @@ -960,16 +1027,23 @@ def test_cloudsql_database_hook_create_connection_missing_fields(self, get_conne @mock.patch("airflow.providers.google.cloud.hooks.cloud_sql.CloudSQLDatabaseHook.get_connection") def test_cloudsql_database_hook_get_sqlproxy_runner_no_proxy(self, get_connection): - connection = Connection(uri="http://user:password@host:80/database") - connection.set_extra( - json.dumps( - { - "location": "test", - "instance": "instance", - "database_type": "postgres", - } + if AIRFLOW_V_3_1_PLUS: + connection = Connection( + conn_id="test_conn_id", **_parse_from_uri("http://user:password@host:80/database") ) + else: + connection = Connection(uri="http://user:password@host:80/database") + extras = json.dumps( + { + "location": "test", + "instance": "instance", + "database_type": "postgres", + } ) + if AIRFLOW_V_3_1_PLUS: + connection.extra = extras + else: + connection.set_extra(extras) get_connection.return_value = connection hook = CloudSQLDatabaseHook( gcp_cloudsql_conn_id="cloudsql_connection", default_gcp_project_id="google_connection" @@ -981,18 +1055,25 @@ def test_cloudsql_database_hook_get_sqlproxy_runner_no_proxy(self, get_connectio @mock.patch("airflow.providers.google.cloud.hooks.cloud_sql.CloudSQLDatabaseHook.get_connection") def test_cloudsql_database_hook_get_sqlproxy_runner(self, get_connection): - connection = Connection(uri="http://user:password@host:80/database") - connection.set_extra( - json.dumps( - { - "location": "test", - "instance": "instance", - "database_type": "postgres", - "use_proxy": "True", - "use_tcp": "False", - } + if AIRFLOW_V_3_1_PLUS: + connection = Connection( + conn_id="test_conn_id", **_parse_from_uri("http://user:password@host:80/database") ) + else: + connection = Connection(uri="http://user:password@host:80/database") + extras = json.dumps( + { + "location": "test", + "instance": "instance", + "database_type": "postgres", + "use_proxy": "True", + "use_tcp": "False", + } ) + if AIRFLOW_V_3_1_PLUS: + connection.extra = extras + else: + connection.set_extra(extras) get_connection.return_value = connection hook = CloudSQLDatabaseHook( gcp_cloudsql_conn_id="cloudsql_connection", default_gcp_project_id="google_connection" @@ -1003,16 +1084,23 @@ def test_cloudsql_database_hook_get_sqlproxy_runner(self, get_connection): @mock.patch("airflow.providers.google.cloud.hooks.cloud_sql.CloudSQLDatabaseHook.get_connection") def test_cloudsql_database_hook_get_database_hook(self, get_connection): - connection = Connection(uri="http://user:password@host:80/database") - connection.set_extra( - json.dumps( - { - "location": "test", - "instance": "instance", - "database_type": "postgres", - } + if AIRFLOW_V_3_1_PLUS: + connection = Connection( + conn_id="test_conn_id", **_parse_from_uri("http://user:password@host:80/database") ) + else: + connection = Connection(uri="http://user:password@host:80/database") + extras = json.dumps( + { + "location": "test", + "instance": "instance", + "database_type": "postgres", + } ) + if AIRFLOW_V_3_1_PLUS: + connection.extra = extras + else: + connection.set_extra(extras) get_connection.return_value = connection hook = CloudSQLDatabaseHook( gcp_cloudsql_conn_id="cloudsql_connection", default_gcp_project_id="google_connection" @@ -1414,7 +1502,10 @@ def setup_method(self, method, mock_get_conn): "key_path": "/var/local/google_cloud_default.json", } conn_extra_json = json.dumps(conn_extra) - self.connection.set_extra(conn_extra_json) + if AIRFLOW_V_3_1_PLUS: + self.connection.extra = conn_extra_json + else: + self.connection.set_extra(conn_extra_json) mock_get_conn.side_effect = [self.sql_connection, self.connection] self.db_hook = CloudSQLDatabaseHook( @@ -1440,14 +1531,20 @@ def test_hook_with_not_too_long_unix_socket_path(self, get_connection): "test_db_with_longname_but_with_limit_of_UNIX_socket&" "use_proxy=True&sql_proxy_use_tcp=False" ) - get_connection.side_effect = [Connection(uri=uri)] + if AIRFLOW_V_3_1_PLUS: + get_connection.side_effect = [Connection(conn_id="test_conn_id", **_parse_from_uri(uri))] + else: + get_connection.side_effect = [Connection(uri=uri)] hook = CloudSQLDatabaseHook() connection = hook.create_connection() assert connection.conn_type == "postgres" assert connection.schema == "testdb" def _verify_postgres_connection(self, get_connection, uri): - get_connection.side_effect = [Connection(uri=uri)] + if AIRFLOW_V_3_1_PLUS: + get_connection.side_effect = [Connection(conn_id="test_conn_id", **_parse_from_uri(uri))] + else: + get_connection.side_effect = [Connection(uri=uri)] hook = CloudSQLDatabaseHook() connection = hook.create_connection() assert connection.conn_type == "postgres" @@ -1490,7 +1587,10 @@ def test_hook_with_correct_parameters_postgres_proxy_socket(self, get_connection "project_id=example-project&location=europe-west1&instance=testdb&" "use_proxy=True&sql_proxy_use_tcp=False" ) - get_connection.side_effect = [Connection(uri=uri)] + if AIRFLOW_V_3_1_PLUS: + get_connection.side_effect = [Connection(conn_id="test_conn_id", **_parse_from_uri(uri))] + else: + get_connection.side_effect = [Connection(uri=uri)] hook = CloudSQLDatabaseHook() connection = hook.create_connection() assert connection.conn_type == "postgres" @@ -1509,7 +1609,10 @@ def test_hook_with_correct_parameters_project_id_missing(self, get_connection): self.verify_mysql_connection(get_connection, uri) def verify_mysql_connection(self, get_connection, uri): - get_connection.side_effect = [Connection(uri=uri)] + if AIRFLOW_V_3_1_PLUS: + get_connection.side_effect = [Connection(conn_id="test_conn_id", **_parse_from_uri(uri))] + else: + get_connection.side_effect = [Connection(uri=uri)] hook = CloudSQLDatabaseHook() connection = hook.create_connection() assert connection.conn_type == "mysql" @@ -1525,7 +1628,10 @@ def test_hook_with_correct_parameters_postgres_proxy_tcp(self, get_connection): "project_id=example-project&location=europe-west1&instance=testdb&" "use_proxy=True&sql_proxy_use_tcp=True" ) - get_connection.side_effect = [Connection(uri=uri)] + if AIRFLOW_V_3_1_PLUS: + get_connection.side_effect = [Connection(conn_id="test_conn_id", **_parse_from_uri(uri))] + else: + get_connection.side_effect = [Connection(uri=uri)] hook = CloudSQLDatabaseHook() connection = hook.create_connection() assert connection.conn_type == "postgres" @@ -1567,7 +1673,10 @@ def test_hook_with_correct_parameters_mysql_proxy_socket(self, get_connection): "project_id=example-project&location=europe-west1&instance=testdb&" "use_proxy=True&sql_proxy_use_tcp=False" ) - get_connection.side_effect = [Connection(uri=uri)] + if AIRFLOW_V_3_1_PLUS: + get_connection.side_effect = [Connection(conn_id="test_conn_id", **_parse_from_uri(uri))] + else: + get_connection.side_effect = [Connection(uri=uri)] hook = CloudSQLDatabaseHook() connection = hook.create_connection() assert connection.conn_type == "mysql" @@ -1584,7 +1693,10 @@ def test_hook_with_correct_parameters_mysql_tcp(self, get_connection): "project_id=example-project&location=europe-west1&instance=testdb&" "use_proxy=True&sql_proxy_use_tcp=True" ) - get_connection.side_effect = [Connection(uri=uri)] + if AIRFLOW_V_3_1_PLUS: + get_connection.side_effect = [Connection(conn_id="test_conn_id", **_parse_from_uri(uri))] + else: + get_connection.side_effect = [Connection(uri=uri)] hook = CloudSQLDatabaseHook() connection = hook.create_connection() assert connection.conn_type == "mysql"