From f5b9b2c22f1bfaea2a31ab510d1fc9c2456b21bc Mon Sep 17 00:00:00 2001 From: viiccwen Date: Sun, 17 Aug 2025 13:05:08 +0800 Subject: [PATCH 1/3] fix(aws): add _ensure_db_session function to initialize orm - Add `_ensure_db_session` to initialize orm when engine or session are None - Add `test_ensure_db_session_initializes_orm` to test ORM initialization logic --- .../airflow/providers/amazon/aws/utils/eks_get_token.py | 9 +++++++++ .../tests/unit/amazon/aws/utils/test_eks_get_token.py | 9 +++++++++ 2 files changed, 18 insertions(+) diff --git a/providers/amazon/src/airflow/providers/amazon/aws/utils/eks_get_token.py b/providers/amazon/src/airflow/providers/amazon/aws/utils/eks_get_token.py index 9a340671ef56e..23ce77d3fdf97 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/utils/eks_get_token.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/utils/eks_get_token.py @@ -19,7 +19,9 @@ import argparse from datetime import datetime, timedelta, timezone +from airflow import settings from airflow.providers.amazon.aws.hooks.eks import EksHook +from airflow.settings import configure_orm # Presigned STS urls are valid for 15 minutes, set token expiration to 1 minute before it expires for # some cushion @@ -51,9 +53,16 @@ def get_parser(): return parser +def _ensure_db_session(): + if not getattr(settings, "engine", None) or not getattr(settings, "Session", None): + configure_orm() + + def main(): parser = get_parser() args = parser.parse_args() + _ensure_db_session() + eks_hook = EksHook(aws_conn_id=args.aws_conn_id, region_name=args.region_name) access_token = eks_hook.fetch_access_token_for_cluster(args.cluster_name) access_token_expiration = get_expiration_time() diff --git a/providers/amazon/tests/unit/amazon/aws/utils/test_eks_get_token.py b/providers/amazon/tests/unit/amazon/aws/utils/test_eks_get_token.py index 0cb678a4074c0..8e735407444db 100644 --- a/providers/amazon/tests/unit/amazon/aws/utils/test_eks_get_token.py +++ b/providers/amazon/tests/unit/amazon/aws/utils/test_eks_get_token.py @@ -87,3 +87,12 @@ def test_run(self, mock_eks_hook, args, expected_aws_conn_id, expected_region_na aws_conn_id=expected_aws_conn_id, region_name=expected_region_name ) mock_eks_hook.return_value.fetch_access_token_for_cluster.assert_called_once_with("test-cluster") + + @mock.patch("airflow.providers.amazon.aws.utils.eks_get_token.configure_orm") + @mock.patch("airflow.providers.amazon.aws.utils.eks_get_token.settings.Session", None) + @mock.patch("airflow.providers.amazon.aws.utils.eks_get_token.settings.engine", None) + def test_ensure_db_session_initializes_orm(self, mock_configure_orm): + import airflow.providers.amazon.aws.utils.eks_get_token as eks_get_token + + eks_get_token._ensure_db_session() + mock_configure_orm.assert_called_once() From a3b3fc839b647df5051a818a9a7081ce4d9bd04c Mon Sep 17 00:00:00 2001 From: viiccwen Date: Sun, 17 Aug 2025 13:07:50 +0800 Subject: [PATCH 2/3] fix(aws): move config initialization inside get_session method Move config creation inside get_session for better lazy evaluation to avoid session issue --- .../src/airflow/providers/amazon/aws/hooks/base_aws.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/providers/amazon/src/airflow/providers/amazon/aws/hooks/base_aws.py b/providers/amazon/src/airflow/providers/amazon/aws/hooks/base_aws.py index b15c837ca3024..dfc423f6e71a0 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/hooks/base_aws.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/hooks/base_aws.py @@ -688,9 +688,10 @@ def account_id(self) -> str: def get_session(self, region_name: str | None = None, deferrable: bool = False) -> boto3.session.Session: """Get the underlying boto3.session.Session(region_name=region_name).""" - return SessionFactory( - conn=self.conn_config, region_name=region_name, config=self.config - ).create_session(deferrable=deferrable) + config = self._config or botocore.config.Config() + return SessionFactory(conn=self.conn_config, region_name=region_name, config=config).create_session( + deferrable=deferrable + ) def _get_config(self, config: Config | None = None) -> Config: """ From 96c776d9e0564d5efc39b2ad4e677358024c1ade Mon Sep 17 00:00:00 2001 From: viiccwen Date: Sun, 17 Aug 2025 22:49:25 +0800 Subject: [PATCH 3/3] refactor(eks_get_token): rename, simplify function and add docstring - Rename _ensure_db_session to _ensure_orm_configured - Simplify ORM check (remove session part) - add docstring with inline comment - Update corresponding test function name and calls --- .../src/airflow/providers/amazon/aws/hooks/base_aws.py | 7 +++---- .../airflow/providers/amazon/aws/utils/eks_get_token.py | 7 ++++--- .../tests/unit/amazon/aws/utils/test_eks_get_token.py | 4 ++-- 3 files changed, 9 insertions(+), 9 deletions(-) diff --git a/providers/amazon/src/airflow/providers/amazon/aws/hooks/base_aws.py b/providers/amazon/src/airflow/providers/amazon/aws/hooks/base_aws.py index dfc423f6e71a0..b15c837ca3024 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/hooks/base_aws.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/hooks/base_aws.py @@ -688,10 +688,9 @@ def account_id(self) -> str: def get_session(self, region_name: str | None = None, deferrable: bool = False) -> boto3.session.Session: """Get the underlying boto3.session.Session(region_name=region_name).""" - config = self._config or botocore.config.Config() - return SessionFactory(conn=self.conn_config, region_name=region_name, config=config).create_session( - deferrable=deferrable - ) + return SessionFactory( + conn=self.conn_config, region_name=region_name, config=self.config + ).create_session(deferrable=deferrable) def _get_config(self, config: Config | None = None) -> Config: """ diff --git a/providers/amazon/src/airflow/providers/amazon/aws/utils/eks_get_token.py b/providers/amazon/src/airflow/providers/amazon/aws/utils/eks_get_token.py index 23ce77d3fdf97..00d53c91233f0 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/utils/eks_get_token.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/utils/eks_get_token.py @@ -53,15 +53,16 @@ def get_parser(): return parser -def _ensure_db_session(): - if not getattr(settings, "engine", None) or not getattr(settings, "Session", None): +def _ensure_orm_configured(): + """Ensure Airflow ORM is configured if engine is not set.""" + if not getattr(settings, "engine", None): configure_orm() def main(): parser = get_parser() args = parser.parse_args() - _ensure_db_session() + _ensure_orm_configured() eks_hook = EksHook(aws_conn_id=args.aws_conn_id, region_name=args.region_name) access_token = eks_hook.fetch_access_token_for_cluster(args.cluster_name) diff --git a/providers/amazon/tests/unit/amazon/aws/utils/test_eks_get_token.py b/providers/amazon/tests/unit/amazon/aws/utils/test_eks_get_token.py index 8e735407444db..d3637305498ee 100644 --- a/providers/amazon/tests/unit/amazon/aws/utils/test_eks_get_token.py +++ b/providers/amazon/tests/unit/amazon/aws/utils/test_eks_get_token.py @@ -91,8 +91,8 @@ def test_run(self, mock_eks_hook, args, expected_aws_conn_id, expected_region_na @mock.patch("airflow.providers.amazon.aws.utils.eks_get_token.configure_orm") @mock.patch("airflow.providers.amazon.aws.utils.eks_get_token.settings.Session", None) @mock.patch("airflow.providers.amazon.aws.utils.eks_get_token.settings.engine", None) - def test_ensure_db_session_initializes_orm(self, mock_configure_orm): + def test_ensure_orm_configured_initializes_orm(self, mock_configure_orm): import airflow.providers.amazon.aws.utils.eks_get_token as eks_get_token - eks_get_token._ensure_db_session() + eks_get_token._ensure_orm_configured() mock_configure_orm.assert_called_once()