diff --git a/providers/amazon/src/airflow/providers/amazon/aws/hooks/redshift_sql.py b/providers/amazon/src/airflow/providers/amazon/aws/hooks/redshift_sql.py index 57c26259d0168..fc91a8776d0cd 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/hooks/redshift_sql.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/hooks/redshift_sql.py @@ -20,7 +20,8 @@ from typing import TYPE_CHECKING import redshift_connector -from redshift_connector import Connection as RedshiftConnection +import tenacity +from redshift_connector import Connection as RedshiftConnection, InterfaceError, OperationalError from sqlalchemy import create_engine from sqlalchemy.engine.url import URL @@ -206,6 +207,14 @@ def get_table_primary_key(self, table: str, schema: str | None = "public") -> li pk_columns = [row[0] for row in self.get_records(sql, (schema, table))] return pk_columns or None + @tenacity.retry( + stop=tenacity.stop_after_attempt(5), + wait=tenacity.wait_exponential(max=20), + # OperationalError is thrown when the connection times out + # InterfaceError is thrown when the connection is refused + retry=tenacity.retry_if_exception_type((OperationalError, InterfaceError)), + reraise=True, + ) def get_conn(self) -> RedshiftConnection: """Get a ``redshift_connector.Connection`` object.""" conn_params = self._get_conn_params()