Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -98,13 +98,11 @@ def __init__(
*,
stack_name: str,
cloudformation_parameters: dict | None = None,
aws_conn_id: str | None = "aws_default",
**kwargs,
):
super().__init__(**kwargs)
self.cloudformation_parameters = cloudformation_parameters or {}
self.stack_name = stack_name
self.aws_conn_id = aws_conn_id

def execute(self, context: Context):
self.log.info("CloudFormation Parameters: %s", self.cloudformation_parameters)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -289,7 +289,6 @@ def __init__(
waiter_delay: int = 60,
waiter_max_attempts: int = 20,
deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
aws_conn_id: str | None = "aws_default",
**kwargs,
):
super().__init__(**kwargs)
Expand All @@ -305,7 +304,6 @@ def __init__(
self.waiter_delay = waiter_delay
self.waiter_max_attempts = waiter_max_attempts
self.deferrable = deferrable
self.aws_conn_id = aws_conn_id

def execute(self, context: Context) -> str:
if self.output_data_config:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,6 @@ def __init__(
table_mappings: dict,
migration_type: str = "full-load",
create_task_kwargs: dict | None = None,
aws_conn_id: str | None = "aws_default",
**kwargs,
):
super().__init__(**kwargs)
Expand All @@ -102,7 +101,6 @@ def __init__(
self.migration_type = migration_type
self.table_mappings = table_mappings
self.create_task_kwargs = create_task_kwargs or {}
self.aws_conn_id = aws_conn_id

def execute(self, context: Context):
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -313,7 +313,6 @@ def __init__(
description: str = "AWS Glue Data Quality Rule Set With Airflow",
update_rule_set: bool = False,
data_quality_ruleset_kwargs: dict | None = None,
aws_conn_id: str | None = "aws_default",
**kwargs,
):
super().__init__(**kwargs)
Expand All @@ -322,7 +321,6 @@ def __init__(
self.description = description
self.update_rule_set = update_rule_set
self.data_quality_ruleset_kwargs = data_quality_ruleset_kwargs or {}
self.aws_conn_id = aws_conn_id

def validate_inputs(self) -> None:
if not self.ruleset.startswith("Rules") or not self.ruleset.endswith("]"):
Expand Down Expand Up @@ -421,7 +419,6 @@ def __init__(
waiter_delay: int = 60,
waiter_max_attempts: int = 20,
deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
aws_conn_id: str | None = "aws_default",
**kwargs,
):
super().__init__(**kwargs)
Expand All @@ -437,7 +434,6 @@ def __init__(
self.waiter_delay = waiter_delay
self.waiter_max_attempts = waiter_max_attempts
self.deferrable = deferrable
self.aws_conn_id = aws_conn_id

def validate_inputs(self) -> None:
glue_table = self.datasource.get("GlueTable", {})
Expand Down Expand Up @@ -584,7 +580,6 @@ def __init__(
waiter_delay: int = 60,
waiter_max_attempts: int = 20,
deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
aws_conn_id: str | None = "aws_default",
**kwargs,
):
super().__init__(**kwargs)
Expand All @@ -598,7 +593,6 @@ def __init__(
self.waiter_delay = waiter_delay
self.waiter_max_attempts = waiter_max_attempts
self.deferrable = deferrable
self.aws_conn_id = aws_conn_id

def execute(self, context: Context) -> str:
glue_table = self.datasource.get("GlueTable", {})
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,23 @@ def test_template_fields(self):

validate_template_fields(op)

def test_overwritten_conn_passed_to_hook(self):
OVERWRITTEN_CONN = "new-conn-id"
op = CloudFormationCreateStackOperator(
task_id="cf_create_stack_pass_conn",
stack_name="fake-stack",
cloudformation_parameters={},
aws_conn_id=OVERWRITTEN_CONN,
)
assert op.hook.aws_conn_id == OVERWRITTEN_CONN

def test_default_conn_passed_to_hook(self):
DEFAULT_CONN = "aws_default"
op = CloudFormationCreateStackOperator(
task_id="cf_create_stack_pass_default_conn", stack_name="fake-stack", cloudformation_parameters={}
)
assert op.hook.aws_conn_id == DEFAULT_CONN


class TestCloudFormationDeleteStackOperator:
def test_init(self):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,29 @@ def test_initialize_comprehend_base_operator_hook(self, comprehend_base_operator
assert comprehend_base_op.client == mocked_client
comprehend_base_operator_mock_hook.assert_called_once()

def test_overwritten_conn_passed_to_hook(self):
OVERWRITTEN_CONN = "new-conn-id"
op = ComprehendBaseOperator(
task_id="comprehend_base_operator",
input_data_config=INPUT_DATA_CONFIG,
output_data_config=OUTPUT_DATA_CONFIG,
language_code=LANGUAGE_CODE,
data_access_role_arn=ROLE_ARN,
aws_conn_id=OVERWRITTEN_CONN,
)
assert op.hook.aws_conn_id == OVERWRITTEN_CONN

def test_default_conn_passed_to_hook(self):
DEFAULT_CONN = "aws_default"
op = ComprehendBaseOperator(
task_id="comprehend_base_operator",
input_data_config=INPUT_DATA_CONFIG,
output_data_config=OUTPUT_DATA_CONFIG,
language_code=LANGUAGE_CODE,
data_access_role_arn=ROLE_ARN,
)
assert op.hook.aws_conn_id == DEFAULT_CONN


class TestComprehendStartPiiEntitiesDetectionJobOperator:
JOB_ID = "random-job-id-1234567"
Expand Down
21 changes: 21 additions & 0 deletions providers/amazon/tests/unit/amazon/aws/operators/test_dms.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,27 @@ def test_template_fields(self):

validate_template_fields(op)

def test_overwritten_conn_passed_to_hook(self):
OVERWRITTEN_CONN = "new-conn-id"
op = DmsCreateTaskOperator(
task_id="dms_create_task_operator",
**self.TASK_DATA,
aws_conn_id=OVERWRITTEN_CONN,
verify=True,
botocore_config={"read_timeout": 42},
)
assert op.hook.aws_conn_id == OVERWRITTEN_CONN

def test_default_conn_passed_to_hook(self):
DEFAULT_CONN = "aws_default"
op = DmsCreateTaskOperator(
task_id="dms_create_task_operator",
**self.TASK_DATA,
verify=True,
botocore_config={"read_timeout": 42},
)
assert op.hook.aws_conn_id == DEFAULT_CONN


class TestDmsDeleteTaskOperator:
TASK_DATA = {
Expand Down
84 changes: 84 additions & 0 deletions providers/amazon/tests/unit/amazon/aws/operators/test_glue.py
Original file line number Diff line number Diff line change
Expand Up @@ -408,6 +408,25 @@ def test_template_fields(self):
)
validate_template_fields(operator)

def test_overwritten_conn_passed_to_hook(self):
OVERWRITTEN_CONN = "new-conn-id"
op = GlueJobOperator(
task_id=TASK_ID,
aws_conn_id=OVERWRITTEN_CONN,
iam_role_name="role_arn",
replace_script_file=True,
)
assert op.hook.aws_conn_id == OVERWRITTEN_CONN

def test_default_conn_passed_to_hook(self):
DEFAULT_CONN = "aws_default"
op = GlueJobOperator(
task_id=TASK_ID,
iam_role_name="role_arn",
replace_script_file=True,
)
assert op.hook.aws_conn_id == DEFAULT_CONN


class TestGlueDataQualityOperator:
RULE_SET_NAME = "TestRuleSet"
Expand Down Expand Up @@ -542,6 +561,23 @@ def test_template_fields(self):
)
validate_template_fields(operator)

def test_overwritten_conn_passed_to_hook(self):
OVERWRITTEN_CONN = "new-conn-id"
op = GlueDataQualityOperator(
task_id="test_overwritten_conn_passed_to_hook",
name=self.RULE_SET_NAME,
ruleset=self.RULE_SET,
aws_conn_id=OVERWRITTEN_CONN,
)
assert op.hook.aws_conn_id == OVERWRITTEN_CONN

def test_default_conn_passed_to_hook(self):
DEFAULT_CONN = "aws_default"
op = GlueDataQualityOperator(
task_id="test_default_conn_passed_to_hook", name=self.RULE_SET_NAME, ruleset=self.RULE_SET
)
assert op.hook.aws_conn_id == DEFAULT_CONN


class TestGlueDataQualityRuleSetEvaluationRunOperator:
RUN_ID = "1234567890"
Expand Down Expand Up @@ -648,6 +684,29 @@ def test_start_data_quality_ruleset_evaluation_run_wait_combinations(
def test_template_fields(self):
validate_template_fields(self.operator)

def test_overwritten_conn_passed_to_hook(self):
OVERWRITTEN_CONN = "new-conn-id"
op = GlueDataQualityRuleSetEvaluationRunOperator(
task_id="test_overwritten_conn_passed_to_hook",
datasource=self.DATA_SOURCE,
role=self.ROLE,
rule_set_names=self.RULE_SET_NAMES,
show_results=False,
aws_conn_id=OVERWRITTEN_CONN,
)
assert op.hook.aws_conn_id == OVERWRITTEN_CONN

def test_default_conn_passed_to_hook(self):
DEFAULT_CONN = "aws_default"
op = GlueDataQualityRuleSetEvaluationRunOperator(
task_id="test_default_conn_passed_to_hook",
datasource=self.DATA_SOURCE,
role=self.ROLE,
rule_set_names=self.RULE_SET_NAMES,
show_results=False,
)
assert op.hook.aws_conn_id == DEFAULT_CONN


class TestGlueDataQualityRuleRecommendationRunOperator:
RUN_ID = "1234567890"
Expand Down Expand Up @@ -756,3 +815,28 @@ def test_start_data_quality_rule_recommendation_run_wait_combinations(

def test_template_fields(self):
validate_template_fields(self.operator)

def test_overwritten_conn_passed_to_hook(self):
OVERWRITTEN_CONN = "new-conn-id"
op = GlueDataQualityRuleRecommendationRunOperator(
task_id="test_overwritten_conn_passed_to_hook",
datasource=self.DATA_SOURCE,
role=self.ROLE,
number_of_workers=10,
timeout=1000,
recommendation_run_kwargs={"CreatedRulesetName": "test-ruleset"},
aws_conn_id=OVERWRITTEN_CONN,
)
assert op.hook.aws_conn_id == OVERWRITTEN_CONN

def test_default_conn_passed_to_hook(self):
DEFAULT_CONN = "aws_default"
op = GlueDataQualityRuleRecommendationRunOperator(
task_id="test_default_conn_passed_to_hook",
datasource=self.DATA_SOURCE,
role=self.ROLE,
number_of_workers=10,
timeout=1000,
recommendation_run_kwargs={"CreatedRulesetName": "test-ruleset"},
)
assert op.hook.aws_conn_id == DEFAULT_CONN
Original file line number Diff line number Diff line change
Expand Up @@ -181,9 +181,9 @@ def test_overwritten_conn_passed_to_hook(self):
)
assert op.hook.aws_conn_id == OVERWRITTEN_CONN

def test_no_conn_passed_to_hook(self):
def test_default_conn_passed_to_hook(self):
DEFAULT_CONN = "aws_default"
op = RdsBaseOperator(task_id="test_no_conn_passed_to_hook_task", dag=self.dag)
op = RdsBaseOperator(task_id="test_default_conn_passed_to_hook_task", dag=self.dag)
assert op.hook.aws_conn_id == DEFAULT_CONN


Expand Down
Loading