From e080b1c1f6a1e2504f5b215416dba4a323de2a76 Mon Sep 17 00:00:00 2001 From: Dominik Heilbock Date: Fri, 30 May 2025 12:21:10 +0200 Subject: [PATCH 1/2] Removed unnecessary aws_conn_id param from operators constructors --- .../providers/amazon/aws/operators/cloud_formation.py | 2 -- .../airflow/providers/amazon/aws/operators/comprehend.py | 2 -- .../src/airflow/providers/amazon/aws/operators/dms.py | 2 -- .../src/airflow/providers/amazon/aws/operators/glue.py | 6 ------ 4 files changed, 12 deletions(-) diff --git a/providers/amazon/src/airflow/providers/amazon/aws/operators/cloud_formation.py b/providers/amazon/src/airflow/providers/amazon/aws/operators/cloud_formation.py index b3168dccc464b..e4f3d211bea93 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/operators/cloud_formation.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/operators/cloud_formation.py @@ -98,13 +98,11 @@ def __init__( *, stack_name: str, cloudformation_parameters: dict | None = None, - aws_conn_id: str | None = "aws_default", **kwargs, ): super().__init__(**kwargs) self.cloudformation_parameters = cloudformation_parameters or {} self.stack_name = stack_name - self.aws_conn_id = aws_conn_id def execute(self, context: Context): self.log.info("CloudFormation Parameters: %s", self.cloudformation_parameters) diff --git a/providers/amazon/src/airflow/providers/amazon/aws/operators/comprehend.py b/providers/amazon/src/airflow/providers/amazon/aws/operators/comprehend.py index e8bc64973c79d..c1b459bc34ee2 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/operators/comprehend.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/operators/comprehend.py @@ -289,7 +289,6 @@ def __init__( waiter_delay: int = 60, waiter_max_attempts: int = 20, deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False), - aws_conn_id: str | None = "aws_default", **kwargs, ): super().__init__(**kwargs) @@ -305,7 +304,6 @@ def __init__( self.waiter_delay = waiter_delay self.waiter_max_attempts = waiter_max_attempts self.deferrable = deferrable - self.aws_conn_id = aws_conn_id def execute(self, context: Context) -> str: if self.output_data_config: diff --git a/providers/amazon/src/airflow/providers/amazon/aws/operators/dms.py b/providers/amazon/src/airflow/providers/amazon/aws/operators/dms.py index 75897f2b897ef..b9af469a506dc 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/operators/dms.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/operators/dms.py @@ -91,7 +91,6 @@ def __init__( table_mappings: dict, migration_type: str = "full-load", create_task_kwargs: dict | None = None, - aws_conn_id: str | None = "aws_default", **kwargs, ): super().__init__(**kwargs) @@ -102,7 +101,6 @@ def __init__( self.migration_type = migration_type self.table_mappings = table_mappings self.create_task_kwargs = create_task_kwargs or {} - self.aws_conn_id = aws_conn_id def execute(self, context: Context): """ diff --git a/providers/amazon/src/airflow/providers/amazon/aws/operators/glue.py b/providers/amazon/src/airflow/providers/amazon/aws/operators/glue.py index 2b7496de78a8a..4b80d47b046c8 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/operators/glue.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/operators/glue.py @@ -313,7 +313,6 @@ def __init__( description: str = "AWS Glue Data Quality Rule Set With Airflow", update_rule_set: bool = False, data_quality_ruleset_kwargs: dict | None = None, - aws_conn_id: str | None = "aws_default", **kwargs, ): super().__init__(**kwargs) @@ -322,7 +321,6 @@ def __init__( self.description = description self.update_rule_set = update_rule_set self.data_quality_ruleset_kwargs = data_quality_ruleset_kwargs or {} - self.aws_conn_id = aws_conn_id def validate_inputs(self) -> None: if not self.ruleset.startswith("Rules") or not self.ruleset.endswith("]"): @@ -421,7 +419,6 @@ def __init__( waiter_delay: int = 60, waiter_max_attempts: int = 20, deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False), - aws_conn_id: str | None = "aws_default", **kwargs, ): super().__init__(**kwargs) @@ -437,7 +434,6 @@ def __init__( self.waiter_delay = waiter_delay self.waiter_max_attempts = waiter_max_attempts self.deferrable = deferrable - self.aws_conn_id = aws_conn_id def validate_inputs(self) -> None: glue_table = self.datasource.get("GlueTable", {}) @@ -584,7 +580,6 @@ def __init__( waiter_delay: int = 60, waiter_max_attempts: int = 20, deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False), - aws_conn_id: str | None = "aws_default", **kwargs, ): super().__init__(**kwargs) @@ -598,7 +593,6 @@ def __init__( self.waiter_delay = waiter_delay self.waiter_max_attempts = waiter_max_attempts self.deferrable = deferrable - self.aws_conn_id = aws_conn_id def execute(self, context: Context) -> str: glue_table = self.datasource.get("GlueTable", {}) From 6ee6c34954dac45ba274b85d7e1dc33a3e80da47 Mon Sep 17 00:00:00 2001 From: Dominik Heilbock Date: Sat, 31 May 2025 13:59:21 +0200 Subject: [PATCH 2/2] Added regression tests to operators and renamed no_conn test to default_conn --- .pre-commit-config.yaml | 4 + .../aws/operators/test_cloud_formation.py | 17 ++++ .../amazon/aws/operators/test_comprehend.py | 23 +++++ .../unit/amazon/aws/operators/test_dms.py | 21 +++++ .../unit/amazon/aws/operators/test_glue.py | 84 +++++++++++++++++++ .../unit/amazon/aws/operators/test_rds.py | 12 +++ 6 files changed, 161 insertions(+) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 7590cfa6adfe9..d8c41c3244369 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1273,6 +1273,8 @@ repos: files: ^go-sdk/ types: [go] language: golang + env: + - GOPROXY=direct - id: gci name: Consistent import ordering for Go files # Since this is invoked from the root folder, not go-sdk/, gci can't auto-detect the prefix @@ -1281,6 +1283,8 @@ repos: files: ^go-sdk/ types: [go] language: golang + env: + - GOPROXY=direct ## ADD MOST PRE-COMMITS ABOVE THAT LINE # The below pre-commits are those requiring CI image to be built - id: mypy-dev diff --git a/providers/amazon/tests/unit/amazon/aws/operators/test_cloud_formation.py b/providers/amazon/tests/unit/amazon/aws/operators/test_cloud_formation.py index 1230e5d27fbd5..bbcd41f2176c4 100644 --- a/providers/amazon/tests/unit/amazon/aws/operators/test_cloud_formation.py +++ b/providers/amazon/tests/unit/amazon/aws/operators/test_cloud_formation.py @@ -103,6 +103,23 @@ def test_template_fields(self): validate_template_fields(op) + def test_overwritten_conn_passed_to_hook(self): + OVERWRITTEN_CONN = "new-conn-id" + op = CloudFormationCreateStackOperator( + task_id="cf_create_stack_pass_conn", + stack_name="fake-stack", + cloudformation_parameters={}, + aws_conn_id=OVERWRITTEN_CONN, + ) + assert op.hook.aws_conn_id == OVERWRITTEN_CONN + + def test_default_conn_passed_to_hook(self): + DEFAULT_CONN = "aws_default" + op = CloudFormationCreateStackOperator( + task_id="cf_create_stack_pass_default_conn", stack_name="fake-stack", cloudformation_parameters={} + ) + assert op.hook.aws_conn_id == DEFAULT_CONN + class TestCloudFormationDeleteStackOperator: def test_init(self): diff --git a/providers/amazon/tests/unit/amazon/aws/operators/test_comprehend.py b/providers/amazon/tests/unit/amazon/aws/operators/test_comprehend.py index 5d500f7637ef9..3e74c7896fa4d 100644 --- a/providers/amazon/tests/unit/amazon/aws/operators/test_comprehend.py +++ b/providers/amazon/tests/unit/amazon/aws/operators/test_comprehend.py @@ -80,6 +80,29 @@ def test_initialize_comprehend_base_operator_hook(self, comprehend_base_operator assert comprehend_base_op.client == mocked_client comprehend_base_operator_mock_hook.assert_called_once() + def test_overwritten_conn_passed_to_hook(self): + OVERWRITTEN_CONN = "new-conn-id" + op = ComprehendBaseOperator( + task_id="comprehend_base_operator", + input_data_config=INPUT_DATA_CONFIG, + output_data_config=OUTPUT_DATA_CONFIG, + language_code=LANGUAGE_CODE, + data_access_role_arn=ROLE_ARN, + aws_conn_id=OVERWRITTEN_CONN, + ) + assert op.hook.aws_conn_id == OVERWRITTEN_CONN + + def test_default_conn_passed_to_hook(self): + DEFAULT_CONN = "aws_default" + op = ComprehendBaseOperator( + task_id="comprehend_base_operator", + input_data_config=INPUT_DATA_CONFIG, + output_data_config=OUTPUT_DATA_CONFIG, + language_code=LANGUAGE_CODE, + data_access_role_arn=ROLE_ARN, + ) + assert op.hook.aws_conn_id == DEFAULT_CONN + class TestComprehendStartPiiEntitiesDetectionJobOperator: JOB_ID = "random-job-id-1234567" diff --git a/providers/amazon/tests/unit/amazon/aws/operators/test_dms.py b/providers/amazon/tests/unit/amazon/aws/operators/test_dms.py index 3cb97cad9d263..5771483cd320f 100644 --- a/providers/amazon/tests/unit/amazon/aws/operators/test_dms.py +++ b/providers/amazon/tests/unit/amazon/aws/operators/test_dms.py @@ -149,6 +149,27 @@ def test_template_fields(self): validate_template_fields(op) + def test_overwritten_conn_passed_to_hook(self): + OVERWRITTEN_CONN = "new-conn-id" + op = DmsCreateTaskOperator( + task_id="dms_create_task_operator", + **self.TASK_DATA, + aws_conn_id=OVERWRITTEN_CONN, + verify=True, + botocore_config={"read_timeout": 42}, + ) + assert op.hook.aws_conn_id == OVERWRITTEN_CONN + + def test_default_conn_passed_to_hook(self): + DEFAULT_CONN = "aws_default" + op = DmsCreateTaskOperator( + task_id="dms_create_task_operator", + **self.TASK_DATA, + verify=True, + botocore_config={"read_timeout": 42}, + ) + assert op.hook.aws_conn_id == DEFAULT_CONN + class TestDmsDeleteTaskOperator: TASK_DATA = { diff --git a/providers/amazon/tests/unit/amazon/aws/operators/test_glue.py b/providers/amazon/tests/unit/amazon/aws/operators/test_glue.py index 21ab76a7317fd..7a0210cf1303e 100644 --- a/providers/amazon/tests/unit/amazon/aws/operators/test_glue.py +++ b/providers/amazon/tests/unit/amazon/aws/operators/test_glue.py @@ -408,6 +408,25 @@ def test_template_fields(self): ) validate_template_fields(operator) + def test_overwritten_conn_passed_to_hook(self): + OVERWRITTEN_CONN = "new-conn-id" + op = GlueJobOperator( + task_id=TASK_ID, + aws_conn_id=OVERWRITTEN_CONN, + iam_role_name="role_arn", + replace_script_file=True, + ) + assert op.hook.aws_conn_id == OVERWRITTEN_CONN + + def test_default_conn_passed_to_hook(self): + DEFAULT_CONN = "aws_default" + op = GlueJobOperator( + task_id=TASK_ID, + iam_role_name="role_arn", + replace_script_file=True, + ) + assert op.hook.aws_conn_id == DEFAULT_CONN + class TestGlueDataQualityOperator: RULE_SET_NAME = "TestRuleSet" @@ -542,6 +561,23 @@ def test_template_fields(self): ) validate_template_fields(operator) + def test_overwritten_conn_passed_to_hook(self): + OVERWRITTEN_CONN = "new-conn-id" + op = GlueDataQualityOperator( + task_id="test_overwritten_conn_passed_to_hook", + name=self.RULE_SET_NAME, + ruleset=self.RULE_SET, + aws_conn_id=OVERWRITTEN_CONN, + ) + assert op.hook.aws_conn_id == OVERWRITTEN_CONN + + def test_default_conn_passed_to_hook(self): + DEFAULT_CONN = "aws_default" + op = GlueDataQualityOperator( + task_id="test_default_conn_passed_to_hook", name=self.RULE_SET_NAME, ruleset=self.RULE_SET + ) + assert op.hook.aws_conn_id == DEFAULT_CONN + class TestGlueDataQualityRuleSetEvaluationRunOperator: RUN_ID = "1234567890" @@ -648,6 +684,29 @@ def test_start_data_quality_ruleset_evaluation_run_wait_combinations( def test_template_fields(self): validate_template_fields(self.operator) + def test_overwritten_conn_passed_to_hook(self): + OVERWRITTEN_CONN = "new-conn-id" + op = GlueDataQualityRuleSetEvaluationRunOperator( + task_id="test_overwritten_conn_passed_to_hook", + datasource=self.DATA_SOURCE, + role=self.ROLE, + rule_set_names=self.RULE_SET_NAMES, + show_results=False, + aws_conn_id=OVERWRITTEN_CONN, + ) + assert op.hook.aws_conn_id == OVERWRITTEN_CONN + + def test_default_conn_passed_to_hook(self): + DEFAULT_CONN = "aws_default" + op = GlueDataQualityRuleSetEvaluationRunOperator( + task_id="test_default_conn_passed_to_hook", + datasource=self.DATA_SOURCE, + role=self.ROLE, + rule_set_names=self.RULE_SET_NAMES, + show_results=False, + ) + assert op.hook.aws_conn_id == DEFAULT_CONN + class TestGlueDataQualityRuleRecommendationRunOperator: RUN_ID = "1234567890" @@ -756,3 +815,28 @@ def test_start_data_quality_rule_recommendation_run_wait_combinations( def test_template_fields(self): validate_template_fields(self.operator) + + def test_overwritten_conn_passed_to_hook(self): + OVERWRITTEN_CONN = "new-conn-id" + op = GlueDataQualityRuleRecommendationRunOperator( + task_id="test_overwritten_conn_passed_to_hook", + datasource=self.DATA_SOURCE, + role=self.ROLE, + number_of_workers=10, + timeout=1000, + recommendation_run_kwargs={"CreatedRulesetName": "test-ruleset"}, + aws_conn_id=OVERWRITTEN_CONN, + ) + assert op.hook.aws_conn_id == OVERWRITTEN_CONN + + def test_default_conn_passed_to_hook(self): + DEFAULT_CONN = "aws_default" + op = GlueDataQualityRuleRecommendationRunOperator( + task_id="test_default_conn_passed_to_hook", + datasource=self.DATA_SOURCE, + role=self.ROLE, + number_of_workers=10, + timeout=1000, + recommendation_run_kwargs={"CreatedRulesetName": "test-ruleset"}, + ) + assert op.hook.aws_conn_id == DEFAULT_CONN diff --git a/providers/amazon/tests/unit/amazon/aws/operators/test_rds.py b/providers/amazon/tests/unit/amazon/aws/operators/test_rds.py index 6c22c75e8aff7..565780db09ad4 100644 --- a/providers/amazon/tests/unit/amazon/aws/operators/test_rds.py +++ b/providers/amazon/tests/unit/amazon/aws/operators/test_rds.py @@ -174,6 +174,18 @@ def test_hook_attribute(self): assert hasattr(self.op, "hook") assert self.op.hook.__class__.__name__ == "RdsHook" + def test_overwritten_conn_passed_to_hook(self): + OVERWRITTEN_CONN = "new-conn-id" + op = RdsBaseOperator( + task_id="test_overwritten_conn_passed_to_hook_task", aws_conn_id=OVERWRITTEN_CONN, dag=self.dag + ) + assert op.hook.aws_conn_id == OVERWRITTEN_CONN + + def test_default_conn_passed_to_hook(self): + DEFAULT_CONN = "aws_default" + op = RdsBaseOperator(task_id="test_default_conn_passed_to_hook_task", dag=self.dag) + assert op.hook.aws_conn_id == DEFAULT_CONN + class TestRdsCreateDbSnapshotOperator: @classmethod