@@ -74,7 +74,7 @@
# If we change to Union syntax then mypy is not happy with UP007 Use `X | Y` for type annotations
# The only way to workaround it for now is to keep the union syntax with ignore for mypy
# We should try to resolve this later.
-BaseAwsConnection = TypeVar("BaseAwsConnection", bound=Union[BaseClient, ServiceResource]) # type: ignore[operator] # noqa: UP007
+BaseAwsConnection = TypeVar("BaseAwsConnection", bound=Union[BaseClient, ServiceResource]) # noqa: UP007
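
A minimal sketch of the tension the comment above describes, with the imports assumed from the surrounding module: ruff's UP007 wants the `X | Y` spelling, while (per the comment) mypy only accepts this TypeVar bound in its Union form, so the noqa stays even though the operator ignore could go:

    from typing import TypeVar, Union

    from boto3.resources.base import ServiceResource
    from botocore.client import BaseClient

    # Accepted by mypy; the `BaseClient | ServiceResource` spelling suggested
    # by UP007 is the one the comment says mypy rejects for this bound.
    BaseAwsConnection = TypeVar("BaseAwsConnection", bound=Union[BaseClient, ServiceResource])  # noqa: UP007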


if AIRFLOW_V_3_0_PLUS:
@@ -636,7 +636,7 @@ def conn_config(self) -> AwsConnectionWrapper:
raise

return AwsConnectionWrapper(
-conn=connection, # type: ignore[arg-type]
+conn=connection,
region_name=self._region_name,
botocore_config=self._config,
verify=self._verify,
@@ -718,10 +718,10 @@ def _get_config(self, config: Config | None = None) -> Config:
# because the user_agent_extra field is generated at runtime.
user_agent_config = Config(
user_agent_extra=self._generate_user_agent_extra_field(
-existing_user_agent_extra=config.user_agent_extra # type: ignore[union-attr]
+existing_user_agent_extra=config.user_agent_extra
)
)
-return config.merge(user_agent_config) # type: ignore[union-attr]
+return config.merge(user_agent_config)
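
For context, botocore's Config.merge returns a new Config in which the second config's options take precedence, which is what lets _get_config layer the generated user agent over a caller-supplied config; a self-contained sketch with illustrative values:

    from botocore.config import Config

    base = Config(retries={"max_attempts": 5}, user_agent_extra="base/1.0")
    user_agent_config = Config(user_agent_extra="base/1.0 extra/2.0")

    merged = base.merge(user_agent_config)
    # `retries` survives from `base`; `user_agent_extra` comes from the
    # second config, mirroring the merge above.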

def get_client_type(
self,
@@ -1050,7 +1050,7 @@ def _list_custom_waiters(self) -> list[str]:
return WaiterModel(model_config).waiter_names
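
botocore's WaiterModel simply reads waiter names out of the config mapping, so _list_custom_waiters returns whatever the custom waiter file declares; a runnable sketch with a hypothetical waiter definition:

    from botocore.waiter import WaiterModel

    model_config = {
        "version": 2,
        "waiters": {
            "custom_job_complete": {  # hypothetical waiter name
                "operation": "DescribeJob",
                "delay": 5,
                "maxAttempts": 10,
                "acceptors": [],
            }
        },
    }
    print(WaiterModel(model_config).waiter_names)  # ['custom_job_complete']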


-class AwsBaseHook(AwsGenericHook[Union[boto3.client, boto3.resource]]): # type: ignore[operator] # noqa: UP007
+class AwsBaseHook(AwsGenericHook[Union[boto3.client, boto3.resource]]): # noqa: UP007
"""
Base class for interacting with AWS.

@@ -74,7 +74,7 @@ def get_ui_field_behaviour(cls) -> dict:

@cached_property
def conn(self):
-return self.get_connection(self.redshift_conn_id) # type: ignore[attr-defined]
+return self.get_connection(self.redshift_conn_id)

def _get_conn_params(self) -> dict[str, str | int]:
"""Retrieve connection parameters."""
@@ -169,7 +169,7 @@ def read(self, relative_path, ti: RuntimeTI) -> tuple[LogSourceInfo, LogMessages]:
f"Reading remote log from Cloudwatch log_group: {self.log_group} log_stream: {relative_path}"
]
try:
-logs = [self.get_cloudwatch_logs(relative_path, ti)] # type: ignore[arg-value]
+logs = [self.get_cloudwatch_logs(relative_path, ti)]
except Exception as e:
logs = None
messages.append(str(e))
@@ -49,7 +49,7 @@ def hook(self):
"""To reduce overhead cache the hook for the notifier."""
return ChimeWebhookHook(chime_conn_id=self.chime_conn_id)

-def notify(self, context: Context) -> None: # type: ignore[override]
+def notify(self, context: Context) -> None:
"""Send a message to a Chime Chat Room."""
self.hook.send_message(message=self.message)
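
The caching referred to in the docstring is functools.cached_property at work: the hook is built on first access and reused on every later notify() call. A generic sketch (class name invented):

    from functools import cached_property

    class CachedHookHolder:
        @cached_property
        def hook(self):
            print("constructed once")
            return object()

    holder = CachedHookHolder()
    assert holder.hook is holder.hook  # later accesses reuse the cached object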

@@ -468,7 +468,7 @@ def _has_new_records_func(self, **kwargs) -> bool:
self.log.info("flow_name: %s", flow_name)
af_client = self.hook.conn
task_instance = kwargs["task_instance"]
-execution_id = task_instance.xcom_pull(task_ids=appflow_task_id, key="execution_id") # type: ignore
+execution_id = task_instance.xcom_pull(task_ids=appflow_task_id, key="execution_id")
if not execution_id:
raise AirflowException(f"No execution_id found from task_id {appflow_task_id}!")
self.log.info("execution_id: %s", execution_id)
@@ -494,5 +494,5 @@ def _has_new_records_func(self, **kwargs) -> bool:
raise AirflowException(f"Flow ({execution_id}) without recordsProcessed info!")
records_processed = execution["recordsProcessed"]
self.log.info("records_processed: %d", records_processed)
-task_instance.xcom_push("records_processed", records_processed) # type: ignore
+task_instance.xcom_push("records_processed", records_processed)
return records_processed > 0
@@ -362,15 +362,15 @@ def _execute_datasync_task(self, context: Context) -> None:
aws_domain=DataSyncTaskExecutionLink.get_aws_domain(self.hook.conn_partition),
region_name=self.hook.conn_region_name,
task_id=self.task_arn.split("/")[-1],
-task_execution_id=self.task_execution_arn.split("/")[-1], # type: ignore[union-attr]
+task_execution_id=self.task_execution_arn.split("/")[-1],
)
DataSyncTaskExecutionLink.persist(
context=context,
operator=self,
region_name=self.hook.conn_region_name,
aws_partition=self.hook.conn_partition,
task_id=self.task_arn.split("/")[-1],
-task_execution_id=self.task_execution_arn.split("/")[-1], # type: ignore[union-attr]
+task_execution_id=self.task_execution_arn.split("/")[-1],
)

self.log.info("You can view this DataSync task execution at %s", execution_url)
@@ -54,7 +54,7 @@ def __init__(
serialized_fields={"collection_id": collection_id, "collection_name": collection_name},
waiter_name="collection_available",
# waiter_args is a dict[str, Any], allow a possible list of None (it is caught above)
-waiter_args={"ids": [collection_id]} if collection_id else {"names": [collection_name]}, # type: ignore[list-item]
+waiter_args={"ids": [collection_id]} if collection_id else {"names": [collection_name]},
failure_message="OpenSearch Serverless Collection creation failed.",
status_message="Status of OpenSearch Serverless Collection is",
status_queries=["status"],
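
Spelled out, the two waiter_args shapes the conditional above can produce (IDs and names invented); the old list-item ignore only existed because one of the two values may be None ahead of the guard the comment mentions:

    # When collection_id is provided:
    waiter_args = {"ids": ["abc123example"]}
    # Otherwise, fall back to the collection name:
    waiter_args = {"names": ["my-collection"]}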
@@ -44,7 +44,7 @@ def get_base_airflow_version_tuple() -> tuple[int, int, int]:
from airflow.sdk import BaseOperator, BaseOperatorLink, BaseSensorOperator
from airflow.sdk.execution_time.xcom import XCom
else:
-from airflow.models import BaseOperator, XCom # type: ignore[no-redef]
+from airflow.models import BaseOperator, XCom
from airflow.models.baseoperatorlink import BaseOperatorLink # type: ignore[no-redef]
from airflow.sensors.base import BaseSensorOperator # type: ignore[no-redef]

@@ -119,7 +119,7 @@ def delete_pipeline(name: str):
chain(
# TEST SETUP
test_context,
-create_pipeline, # type: ignore[arg-type]
+create_pipeline,
# TEST BODY
start_pipeline1,
start_pipeline2,
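
For readers unfamiliar with it, chain() wires the listed tasks sequentially, so the TEST SETUP / TEST BODY comments above describe real execution order; a minimal sketch with invented task ids (EmptyOperator path per Airflow 2.x):

    from airflow.models.baseoperator import chain
    from airflow.operators.empty import EmptyOperator

    setup = EmptyOperator(task_id="setup")
    body = EmptyOperator(task_id="body")
    teardown = EmptyOperator(task_id="teardown")

    # Equivalent to: setup >> body >> teardown
    chain(setup, body, teardown)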
14 changes: 7 additions & 7 deletions providers/amazon/tests/unit/amazon/aws/hooks/test_eks.py
@@ -273,7 +273,7 @@ def test_create_cluster_throws_exception_when_cluster_exists(
with pytest.raises(ClientError) as raised_exception:
eks_hook.create_cluster(
name=generated_test_data.existing_cluster_name,
-**dict(ClusterInputs.REQUIRED), # type: ignore
+**dict(ClusterInputs.REQUIRED),
)
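
The **dict(ClusterInputs.REQUIRED) idiom above assumes REQUIRED is a list of (key, value) tuples that dict() turns into API kwargs; a sketch with made-up values:

    # Hypothetical shape of ClusterInputs.REQUIRED:
    REQUIRED = [
        ("roleArn", "arn:aws:iam::123456789012:role/demo"),
        ("resourcesVpcConfig", {"subnetIds": ["subnet-12345"]}),
    ]
    kwargs = dict(REQUIRED)
    # eks_hook.create_cluster(name=..., **kwargs) then receives roleArn=... etc.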

assert_client_error_exception_thrown(
@@ -434,7 +434,7 @@ def test_create_nodegroup_throws_exception_when_cluster_not_found(self) -> None:
eks_hook.create_nodegroup(
clusterName=non_existent_cluster_name,
nodegroupName=non_existent_nodegroup_name,
-**dict(NodegroupInputs.REQUIRED), # type: ignore
+**dict(NodegroupInputs.REQUIRED),
)

assert_client_error_exception_thrown(
@@ -458,7 +458,7 @@ def test_create_nodegroup_throws_exception_when_nodegroup_already_exists(
eks_hook.create_nodegroup(
clusterName=generated_test_data.cluster_name,
nodegroupName=generated_test_data.existing_nodegroup_name,
-**dict(NodegroupInputs.REQUIRED), # type: ignore
+**dict(NodegroupInputs.REQUIRED),
)

assert_client_error_exception_thrown(
@@ -488,7 +488,7 @@ def test_create_nodegroup_throws_exception_when_cluster_not_active(
eks_hook.create_nodegroup(
clusterName=generated_test_data.cluster_name,
nodegroupName=non_existent_nodegroup_name,
-**dict(NodegroupInputs.REQUIRED), # type: ignore
+**dict(NodegroupInputs.REQUIRED),
)

assert_client_error_exception_thrown(
@@ -846,7 +846,7 @@ def test_create_fargate_profile_throws_exception_when_cluster_not_found(self) -> None:
eks_hook.create_fargate_profile(
clusterName=non_existent_cluster_name,
fargateProfileName=non_existent_fargate_profile_name,
-**dict(FargateProfileInputs.REQUIRED), # type: ignore
+**dict(FargateProfileInputs.REQUIRED),
)

assert_client_error_exception_thrown(
@@ -867,7 +867,7 @@ def test_create_fargate_profile_throws_exception_when_fargate_profile_already_exists(
eks_hook.create_fargate_profile(
clusterName=generated_test_data.cluster_name,
fargateProfileName=generated_test_data.existing_fargate_profile_name,
-**dict(FargateProfileInputs.REQUIRED), # type: ignore
+**dict(FargateProfileInputs.REQUIRED),
)

assert_client_error_exception_thrown(
@@ -897,7 +897,7 @@ def test_create_fargate_profile_throws_exception_when_cluster_not_active(
eks_hook.create_fargate_profile(
clusterName=generated_test_data.cluster_name,
fargateProfileName=non_existent_fargate_profile_name,
-**dict(FargateProfileInputs.REQUIRED), # type: ignore
+**dict(FargateProfileInputs.REQUIRED),
)

assert_client_error_exception_thrown(
@@ -119,15 +119,15 @@ def test_run(appflow_conn, ctx, waiter_mock):
args = DUMP_COMMON_ARGS.copy()
args.pop("source")
operator = AppflowRunOperator(**args)
-operator.execute(ctx) # type: ignore
+operator.execute(ctx)
appflow_conn.start_flow.assert_called_once_with(flowName=FLOW_NAME)
appflow_conn.describe_flow_execution_records.assert_called_once()


@pytest.mark.db_test
def test_run_full(appflow_conn, ctx, waiter_mock):
operator = AppflowRunFullOperator(**DUMP_COMMON_ARGS)
-operator.execute(ctx) # type: ignore
+operator.execute(ctx)
run_assertions_base(appflow_conn, [])


@@ -136,7 +136,7 @@ def test_run_after(appflow_conn, ctx, waiter_mock):
operator = AppflowRunAfterOperator(
source_field="col0", filter_date="2022-05-26T00:00+00:00", **DUMP_COMMON_ARGS
)
-operator.execute(ctx) # type: ignore
+operator.execute(ctx)
run_assertions_base(
appflow_conn,
[
@@ -155,7 +155,7 @@ def test_run_before(appflow_conn, ctx, waiter_mock):
operator = AppflowRunBeforeOperator(
source_field="col0", filter_date="2022-05-26T00:00+00:00", **DUMP_COMMON_ARGS
)
-operator.execute(ctx) # type: ignore
+operator.execute(ctx)
run_assertions_base(
appflow_conn,
[
@@ -174,7 +174,7 @@ def test_run_daily(appflow_conn, ctx, waiter_mock):
operator = AppflowRunDailyOperator(
source_field="col0", filter_date="2022-05-26T00:00+00:00", **DUMP_COMMON_ARGS
)
-operator.execute(ctx) # type: ignore
+operator.execute(ctx)
run_assertions_base(
appflow_conn,
[
@@ -202,7 +202,7 @@ def test_short_circuit(appflow_conn, ctx):
flow_name=FLOW_NAME,
appflow_run_task_id=TASK_ID,
)
-operator.execute(ctx) # type: ignore
+operator.execute(ctx)
appflow_conn.describe_flow_execution_records.assert_called_once_with(
flowName=FLOW_NAME, maxResults=100
)
16 changes: 8 additions & 8 deletions providers/amazon/tests/unit/amazon/aws/operators/test_ecs.py
@@ -359,7 +359,7 @@ def test_execute_without_failures(
mock_ti = mock.MagicMock()
mock_context = {"ti": mock_ti, "task_instance": mock_ti}

-self.ecs.execute(mock_context) # type: ignore[arg-type]
+self.ecs.execute(mock_context)

client_mock.run_task.assert_called_once_with(
cluster="c",
@@ -394,7 +394,7 @@ def test_execute_with_failures(self, client_mock):
mock_context = {"ti": mock_ti, "task_instance": mock_ti}

with pytest.raises(EcsOperatorError):
-self.ecs.execute(mock_context) # type: ignore[arg-type]
+self.ecs.execute(mock_context)

client_mock.run_task.assert_called_once_with(
cluster="c",
@@ -715,7 +715,7 @@ def test_execute_xcom_with_log(self, log_fetcher_mock, client_mock):
mock_ti = mock.MagicMock()
mock_context = {"ti": mock_ti, "task_instance": mock_ti}

-assert self.ecs.execute(mock_context) == "Log output" # type: ignore[arg-type]
+assert self.ecs.execute(mock_context) == "Log output"

@mock.patch.object(EcsBaseOperator, "client")
@mock.patch("airflow.providers.amazon.aws.utils.task_log_fetcher.AwsTaskLogFetcher")
@@ -728,7 +728,7 @@ def test_execute_xcom_with_no_log(self, log_fetcher_mock, client_mock):
mock_ti = mock.MagicMock()
mock_context = {"ti": mock_ti, "task_instance": mock_ti}

-assert self.ecs.execute(mock_context) is None # type: ignore[arg-type]
+assert self.ecs.execute(mock_context) is None

@mock.patch.object(EcsBaseOperator, "client")
def test_execute_xcom_with_no_log_fetcher(self, client_mock):
@@ -737,7 +737,7 @@ def test_execute_xcom_with_no_log_fetcher(self, client_mock):
mock_ti = mock.MagicMock()
mock_context = {"ti": mock_ti, "task_instance": mock_ti}

-assert self.ecs.execute(mock_context) is None # type: ignore[arg-type]
+assert self.ecs.execute(mock_context) is None

@mock.patch.object(EcsBaseOperator, "client")
@mock.patch.object(AwsTaskLogFetcher, "get_last_log_message", return_value="Log output")
@@ -747,7 +747,7 @@ def test_execute_xcom_disabled(self, log_fetcher_mock, client_mock):
mock_ti = mock.MagicMock()
mock_context = {"ti": mock_ti, "task_instance": mock_ti}

-assert self.ecs.execute(mock_context) is None # type: ignore[arg-type]
+assert self.ecs.execute(mock_context) is None

@mock.patch.object(EcsRunTaskOperator, "client")
def test_with_defer(self, client_mock):
@@ -759,7 +759,7 @@ def test_with_defer(self, client_mock):
mock_context = {"ti": mock_ti, "task_instance": mock_ti}

with pytest.raises(TaskDeferred) as deferred:
-self.ecs.execute(mock_context) # type: ignore[arg-type]
+self.ecs.execute(mock_context)

assert isinstance(deferred.value.trigger, TaskDoneTrigger)
assert deferred.value.trigger.task_arn == f"arn:aws:ecs:us-east-1:012345678910:task/{TASK_ID}"
@@ -772,7 +772,7 @@ def test_execute_complete(self, client_mock):
mock_ti = mock.MagicMock()
mock_context = {"ti": mock_ti, "task_instance": mock_ti}

-self.ecs.execute_complete(mock_context, event) # type: ignore[arg-type]
+self.ecs.execute_complete(mock_context, event)

# task gets described to assert its success
client_mock().describe_tasks.assert_called_once_with(cluster="test_cluster", tasks=["my_arn"])
@@ -386,7 +386,7 @@ def test_template_fields(self):

class TestEksCreateFargateProfileOperator:
def setup_method(self) -> None:
-self.create_fargate_profile_params = CreateFargateProfileParams( # type: ignore
+self.create_fargate_profile_params = CreateFargateProfileParams(
cluster_name=CLUSTER_NAME,
pod_execution_role_arn=POD_EXECUTION_ROLE_ARN[1],
selectors=SELECTORS[1],
@@ -54,7 +54,7 @@ async def test_wait_job(self, mock_async_conn, mock_get_waiter):
waiter_delay=10,
)
generator = trigger.run()
-event = await generator.asend(None) # type:ignore[attr-defined]
+event = await generator.asend(None)

assert_expected_waiter_type(mock_get_waiter, "job_complete")
mock_get_waiter().wait.assert_called_once()
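
The generator.asend(None) pattern in these trigger tests is standard async-generator mechanics (it advances the generator to its first yield, like a first next()); the removed attr-defined ignores were only quieting the type checker. A standalone sketch:

    import asyncio

    async def trigger_run():
        yield {"status": "success"}  # stand-in for a TriggerEvent

    async def main():
        gen = trigger_run()
        event = await gen.asend(None)  # advance to the first yield
        print(event)

    asyncio.run(main())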
@@ -87,7 +87,7 @@ async def test_wait_job_failed(self, mock_async_conn, mock_get_waiter):
generator = trigger.run()

with pytest.raises(AirflowException):
-await generator.asend(None) # type:ignore[attr-defined]
+await generator.asend(None)
assert_expected_waiter_type(mock_get_waiter, "job_complete")

def test_serialization(self):
@@ -58,7 +58,7 @@ def attributes_to_test(
:param nodegroup_name: The name of the nodegroup under test if applicable.
:return: Returns a list of tuples containing the keys and values to be validated in testing.
"""
-result: list[tuple] = deepcopy(inputs.REQUIRED + inputs.OPTIONAL + [STATUS]) # type: ignore
+result: list[tuple] = deepcopy(inputs.REQUIRED + inputs.OPTIONAL + [STATUS])
if inputs == ClusterInputs:
result += [(ClusterAttributes.NAME, cluster_name)]
elif inputs == FargateProfileInputs:
@@ -178,7 +178,7 @@ def _input_builder(options: InputTypes, minimal: bool) -> dict:
:param minimal: If True, only the required values are generated; if False all values are generated.
:return: Returns a dict containing the keys and values to be validated in testing.
"""
-values: list[tuple] = deepcopy(options.REQUIRED) # type: ignore
+values: list[tuple] = deepcopy(options.REQUIRED)
if not minimal:
values.extend(deepcopy(options.OPTIONAL))
return dict(values) # type: ignore