diff --git a/providers/amazon/pyproject.toml b/providers/amazon/pyproject.toml index 1385c77492033..a1595cb8b8d8e 100644 --- a/providers/amazon/pyproject.toml +++ b/providers/amazon/pyproject.toml @@ -151,6 +151,9 @@ dependencies = [ "common.messaging" = [ "apache-airflow-providers-common-messaging>=2.0.0" ] +"sqlalchemy" = [ + "sqlalchemy>=1.4.49", +] [dependency-groups] dev = [ diff --git a/providers/amazon/src/airflow/providers/amazon/aws/hooks/athena_sql.py b/providers/amazon/src/airflow/providers/amazon/aws/hooks/athena_sql.py index 451664a467c7f..e12589a75d12d 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/hooks/athena_sql.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/hooks/athena_sql.py @@ -21,7 +21,11 @@ from typing import TYPE_CHECKING, Any import pyathena -from sqlalchemy.engine.url import URL + +try: + from sqlalchemy.engine.url import URL +except ImportError: + URL = None from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook from airflow.providers.amazon.aws.utils.connection_wrapper import AwsConnectionWrapper @@ -152,9 +156,15 @@ def _get_conn_params(self) -> dict[str, str | None]: def get_uri(self) -> str: """Overridden to use the Athena dialect as driver name.""" + from airflow.exceptions import AirflowOptionalProviderFeatureException + + if URL is None: + raise AirflowOptionalProviderFeatureException( + "sqlalchemy is required to generate the connection URI. " + "Install it with: pip install 'apache-airflow-providers-amazon[sqlalchemy]'" + ) conn_params = self._get_conn_params() creds = self.get_credentials(region_name=conn_params["region_name"]) - return URL.create( f"awsathena+{conn_params['driver']}", username=creds.access_key, diff --git a/providers/amazon/src/airflow/providers/amazon/aws/hooks/redshift_sql.py b/providers/amazon/src/airflow/providers/amazon/aws/hooks/redshift_sql.py index ef0dffcbd4b20..33f2e4c83b2a8 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/hooks/redshift_sql.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/hooks/redshift_sql.py @@ -22,9 +22,14 @@ import redshift_connector import tenacity from redshift_connector import Connection as RedshiftConnection, InterfaceError, OperationalError -from sqlalchemy import create_engine -from sqlalchemy.engine.url import URL +try: + from sqlalchemy import create_engine + from sqlalchemy.engine.url import URL +except ImportError: + URL = create_engine = None + +from airflow.exceptions import AirflowOptionalProviderFeatureException from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook from airflow.providers.common.compat.sdk import AirflowException from airflow.providers.common.sql.hooks.sql import DbApiHook @@ -151,6 +156,11 @@ def get_iam_token(self, conn: Connection) -> tuple[str, str, int]: def get_uri(self) -> str: """Overridden to use the Redshift dialect as driver name.""" + if URL is None: + raise AirflowOptionalProviderFeatureException( + "sqlalchemy is required to generate the connection URI. " + "Install it with: pip install 'apache-airflow-providers-amazon[sqlalchemy]'" + ) conn_params = self._get_conn_params() if "user" in conn_params: @@ -174,6 +184,11 @@ def get_uri(self) -> str: def get_sqlalchemy_engine(self, engine_kwargs=None): """Overridden to pass Redshift-specific arguments.""" + if create_engine is None: + raise AirflowOptionalProviderFeatureException( + "sqlalchemy is required for creating the engine. Install it with" + ": pip install 'apache-airflow-providers-amazon[sqlalchemy]'" + ) conn_kwargs = self.conn.extra_dejson if engine_kwargs is None: engine_kwargs = {}