Skip to content

Commit

Permalink
fix unnecessary imports for CloudSQL hook
Browse files Browse the repository at this point in the history
  • Loading branch information
jhongy1994 committed Jul 24, 2024
1 parent 68b3159 commit 74548fd
Showing 1 changed file with 9 additions and 5 deletions.
14 changes: 9 additions & 5 deletions airflow/providers/google/cloud/hooks/cloud_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,6 @@
GoogleBaseHook,
get_field,
)
from airflow.providers.mysql.hooks.mysql import MySqlHook
from airflow.providers.postgres.hooks.postgres import PostgresHook
from airflow.utils.log.logging_mixin import LoggingMixin

if TYPE_CHECKING:
Expand Down Expand Up @@ -856,7 +854,7 @@ def __init__(
# Port and socket path and db_hook are automatically generated
self.sql_proxy_tcp_port = None
self.sql_proxy_unique_path: str | None = None
self.db_hook: PostgresHook | MySqlHook | None = None
self.db_hook: BaseHook | None = None
self.reserved_tcp_socket: socket.socket | None = None
# Generated based on clock + clock sequence. Unique per host (!).
# This is important as different hosts share the database
Expand Down Expand Up @@ -1140,22 +1138,28 @@ def get_sqlproxy_runner(self) -> CloudSqlProxyRunner:
gcp_conn_id=self.gcp_conn_id,
)

def get_database_hook(self, connection: Connection) -> PostgresHook | MySqlHook:
def get_database_hook(self, connection: Connection) -> BaseHook:
"""
Retrieve database hook.
This is the actual Postgres or MySQL database hook that uses proxy or
connects directly to the Google Cloud SQL database.
"""
if self.database_type == "postgres":
db_hook: PostgresHook | MySqlHook = PostgresHook(connection=connection, database=self.database)
from airflow.providers.postgres.hooks.postgres import PostgresHook

db_hook: BaseHook = PostgresHook(connection=connection, database=self.database)
else:
from airflow.providers.mysql.hooks.mysql import MySqlHook

db_hook = MySqlHook(connection=connection, schema=self.database)
self.db_hook = db_hook
return db_hook

def cleanup_database_hook(self) -> None:
"""Clean up database hook after it was used."""
from airflow.providers.postgres.hooks.postgres import PostgresHook

if self.database_type == "postgres":
if not self.db_hook:
raise ValueError("The db_hook should be set")
Expand Down

0 comments on commit 74548fd

Please sign in to comment.