diff --git a/providers/amazon/src/airflow/providers/amazon/aws/transfers/redshift_to_s3.py b/providers/amazon/src/airflow/providers/amazon/aws/transfers/redshift_to_s3.py index 8d82253f9de58..1d39ac2affd8d 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/transfers/redshift_to_s3.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/transfers/redshift_to_s3.py @@ -29,6 +29,7 @@ from airflow.providers.amazon.aws.hooks.redshift_sql import RedshiftSQLHook from airflow.providers.amazon.aws.hooks.s3 import S3Hook from airflow.providers.amazon.aws.utils.redshift import build_credentials_block +from airflow.utils.types import NOTSET, ArgNotSet if TYPE_CHECKING: from airflow.utils.context import Context @@ -102,7 +103,7 @@ def __init__( table: str | None = None, select_query: str | None = None, redshift_conn_id: str = "redshift_default", - aws_conn_id: str | None = "aws_default", + aws_conn_id: str | None | ArgNotSet = NOTSET, verify: bool | str | None = None, unload_options: list | None = None, autocommit: bool = False, @@ -118,7 +119,6 @@ def __init__( self.schema = schema self.table = table self.redshift_conn_id = redshift_conn_id - self.aws_conn_id = aws_conn_id self.verify = verify self.unload_options = unload_options or [] self.autocommit = autocommit @@ -127,6 +127,16 @@ def __init__( self.table_as_file_name = table_as_file_name self.redshift_data_api_kwargs = redshift_data_api_kwargs or {} self.select_query = select_query + # In execute() we attempt to fetch this aws connection to check for extras. If the user didn't + # actually provide a connection note that, because we don't want to let the exception bubble up in + # that case (since we're silently injecting a connection on their behalf). + self._aws_conn_id: str | None + if isinstance(aws_conn_id, ArgNotSet): + self.conn_set = False + self._aws_conn_id = "aws_default" + else: + self.conn_set = True + self._aws_conn_id = aws_conn_id def _build_unload_query( self, credentials_block: str, select_query: str, s3_key: str, unload_options: str @@ -176,11 +186,16 @@ def execute(self, context: Context) -> None: raise AirflowException(f"Cannot include param '{arg}' in Redshift Data API kwargs") else: redshift_sql_hook = RedshiftSQLHook(redshift_conn_id=self.redshift_conn_id) - conn = S3Hook.get_connection(conn_id=self.aws_conn_id) if self.aws_conn_id else None + conn = ( + S3Hook.get_connection(conn_id=self._aws_conn_id) + # Only fetch the connection if it was set by the user and it is not None + if self.conn_set and self._aws_conn_id + else None + ) if conn and conn.extra_dejson.get("role_arn", False): credentials_block = f"aws_iam_role={conn.extra_dejson['role_arn']}" else: - s3_hook = S3Hook(aws_conn_id=self.aws_conn_id, verify=self.verify) + s3_hook = S3Hook(aws_conn_id=self._aws_conn_id, verify=self.verify) credentials = s3_hook.get_credentials() credentials_block = build_credentials_block(credentials)