Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,13 @@
# Number of retries - used by googleapiclient method calls to perform retries
# For requests that are "retriable"
from airflow.exceptions import AirflowException
from airflow.models import Connection
from airflow.providers.google.version_compat import AIRFLOW_V_3_1_PLUS

if AIRFLOW_V_3_1_PLUS:
from airflow.sdk import Connection
else:
from airflow.models import Connection # type: ignore[assignment,attr-defined,no-redef]

from airflow.providers.google.cloud.hooks.secret_manager import (
GoogleCloudSecretManagerHook,
)
Expand Down Expand Up @@ -1045,15 +1051,26 @@ def _generate_unique_path() -> str:
def _quote(value) -> str | None:
return quote_plus(value) if value else None

def _generate_connection_uri(self) -> str:
def _reserve_port(self):
if self.use_proxy:
if self.sql_proxy_use_tcp:
if not self.sql_proxy_tcp_port:
self.reserve_free_tcp_port()
if not self.sql_proxy_unique_path:
self.sql_proxy_unique_path = self._generate_unique_path()

def _generate_connection_uri(self) -> str:
self._reserve_port()
if not self.database_type:
raise ValueError("The database_type should be set")
if not self.user:
raise AirflowException("The login parameter needs to be set in connection")
if not self.public_ip:
raise AirflowException("The location parameter needs to be set in connection")
if not self.password:
raise AirflowException("The password parameter needs to be set in connection")
if not self.database:
raise AirflowException("The database parameter needs to be set in connection")

database_uris = CONNECTION_URIS[self.database_type]
ssl_spec = None
Expand All @@ -1072,14 +1089,6 @@ def _generate_connection_uri(self) -> str:
ssl_spec = {"cert": self.sslcert, "key": self.sslkey, "ca": self.sslrootcert}
else:
format_string = public_uris["non-ssl"]
if not self.user:
raise AirflowException("The login parameter needs to be set in connection")
if not self.public_ip:
raise AirflowException("The location parameter needs to be set in connection")
if not self.password:
raise AirflowException("The password parameter needs to be set in connection")
if not self.database:
raise AirflowException("The database parameter needs to be set in connection")

connection_uri = format_string.format(
user=quote_plus(self.user) if self.user else "",
Expand Down Expand Up @@ -1113,15 +1122,81 @@ def _get_sqlproxy_instance_specification(self) -> str:
instance_specification += f"=tcp:{self.sql_proxy_tcp_port}"
return instance_specification

def _generate_connection_parameters(self) -> dict:
self._reserve_port()
if not self.database_type:
raise ValueError("The database_type should be set")
if not self.user:
raise AirflowException("The login parameter needs to be set in connection")
if not self.public_ip:
raise AirflowException("The location parameter needs to be set in connection")
if not self.password:
raise AirflowException("The password parameter needs to be set in connection")
if not self.database:
raise AirflowException("The database parameter needs to be set in connection")

connection_parameters = {}

connection_parameters["conn_type"] = self.database_type
connection_parameters["login"] = self.user
connection_parameters["password"] = self.password
connection_parameters["schema"] = self.database
connection_parameters["extra"] = {}

database_uris = CONNECTION_URIS[self.database_type]
if self.use_proxy:
proxy_uris = database_uris["proxy"]
if self.sql_proxy_use_tcp:
connection_parameters["host"] = "127.0.0.1"
connection_parameters["port"] = self.sql_proxy_tcp_port
else:
socket_path = f"{self.sql_proxy_unique_path}/{self._get_instance_socket_name()}"
if "localhost" in proxy_uris["socket"]:
connection_parameters["host"] = "localhost"
connection_parameters["extra"].update({"unix_socket": socket_path})
else:
connection_parameters["host"] = socket_path
else:
public_uris = database_uris["public"]
if self.use_ssl:
connection_parameters["host"] = self.public_ip
connection_parameters["port"] = self.public_port
if "ssl_spec" in public_uris["ssl"]:
connection_parameters["extra"].update(
{
"ssl": json.dumps(
{"cert": self.sslcert, "key": self.sslkey, "ca": self.sslrootcert}
)
}
)
else:
connection_parameters["extra"].update(
{
"sslmode": "verify-ca",
"sslcert": self.sslcert,
"sslkey": self.sslkey,
"sslrootcert": self.sslrootcert,
}
)
else:
connection_parameters["host"] = self.public_ip
connection_parameters["port"] = self.public_port
if connection_parameters.get("extra"):
connection_parameters["extra"] = json.dumps(connection_parameters["extra"])
return connection_parameters

def create_connection(self) -> Connection:
"""
Create a connection.

Connection ID will be randomly generated according to whether it uses
proxy, TCP, UNIX sockets, SSL.
"""
uri = self._generate_connection_uri()
connection = Connection(conn_id=self.db_conn_id, uri=uri)
if AIRFLOW_V_3_1_PLUS:
kwargs = self._generate_connection_parameters()
else:
kwargs = {"uri": self._generate_connection_uri()}
connection = Connection(conn_id=self.db_conn_id, **kwargs)
self.log.info("Creating connection %s", self.db_conn_id)
return connection

Expand Down
Loading
Loading