From 6e3bc84905f7671ea60aa7e612ddf8c4095b70b8 Mon Sep 17 00:00:00 2001 From: Jiyoung Hong <82822254+jhongy1994@users.noreply.github.com> Date: Fri, 26 Jul 2024 15:03:27 +0900 Subject: [PATCH] fix unnecessary imports for CloudSQL hook (#41009) --- airflow/providers/google/cloud/hooks/cloud_sql.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/airflow/providers/google/cloud/hooks/cloud_sql.py b/airflow/providers/google/cloud/hooks/cloud_sql.py index bb7b2e6c87592..0baa30ec7adac 100644 --- a/airflow/providers/google/cloud/hooks/cloud_sql.py +++ b/airflow/providers/google/cloud/hooks/cloud_sql.py @@ -59,8 +59,6 @@ GoogleBaseHook, get_field, ) -from airflow.providers.mysql.hooks.mysql import MySqlHook -from airflow.providers.postgres.hooks.postgres import PostgresHook from airflow.utils.log.logging_mixin import LoggingMixin if TYPE_CHECKING: @@ -856,7 +854,7 @@ def __init__( # Port and socket path and db_hook are automatically generated self.sql_proxy_tcp_port = None self.sql_proxy_unique_path: str | None = None - self.db_hook: PostgresHook | MySqlHook | None = None + self.db_hook: BaseHook | None = None self.reserved_tcp_socket: socket.socket | None = None # Generated based on clock + clock sequence. Unique per host (!). # This is important as different hosts share the database @@ -1140,7 +1138,7 @@ def get_sqlproxy_runner(self) -> CloudSqlProxyRunner: gcp_conn_id=self.gcp_conn_id, ) - def get_database_hook(self, connection: Connection) -> PostgresHook | MySqlHook: + def get_database_hook(self, connection: Connection) -> BaseHook: """ Retrieve database hook. @@ -1148,14 +1146,20 @@ def get_database_hook(self, connection: Connection) -> PostgresHook | MySqlHook: connects directly to the Google Cloud SQL database. """ if self.database_type == "postgres": - db_hook: PostgresHook | MySqlHook = PostgresHook(connection=connection, database=self.database) + from airflow.providers.postgres.hooks.postgres import PostgresHook + + db_hook: BaseHook = PostgresHook(connection=connection, database=self.database) else: + from airflow.providers.mysql.hooks.mysql import MySqlHook + db_hook = MySqlHook(connection=connection, schema=self.database) self.db_hook = db_hook return db_hook def cleanup_database_hook(self) -> None: """Clean up database hook after it was used.""" + from airflow.providers.postgres.hooks.postgres import PostgresHook + if self.database_type == "postgres": if not self.db_hook: raise ValueError("The db_hook should be set")