Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion airflow/providers/amazon/aws/hooks/appflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def __init__(self, *args, **kwargs) -> None:

@cached_property
def conn(self) -> AppflowClient:
"""Get the underlying boto3 Appflow client (cached)"""
"""Get the underlying boto3 Appflow client (cached)."""
return super().conn

def run_flow(self, flow_name: str, poll_interval: int = 20, wait_for_completion: bool = True) -> str:
Expand Down
14 changes: 8 additions & 6 deletions airflow/providers/amazon/aws/hooks/athena.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def run_query(
workgroup: str = "primary",
) -> str:
"""
Run Presto query on athena with provided config and return submitted query_execution_id
Run Presto query on athena with provided config and return submitted query_execution_id.

.. seealso::
- :external+boto3:py:meth:`Athena.Client.start_query_execution`
Expand Down Expand Up @@ -152,8 +152,9 @@ def get_query_results(
self, query_execution_id: str, next_token_id: str | None = None, max_results: int = 1000
) -> dict | None:
"""
Fetch submitted athena query results. returns none if query is in intermediate state or
failed/cancelled state else dict of query output
Fetch submitted athena query results.

Returns none if query is in intermediate state or failed/cancelled state else dict of query output.

.. seealso::
- :external+boto3:py:meth:`Athena.Client.get_query_results`
Expand Down Expand Up @@ -188,7 +189,7 @@ def get_query_results_paginator(
"""
Fetch submitted athena query results. returns none if query is in intermediate state or
failed/cancelled state else a paginator to iterate through pages of results. If you
wish to get all results at once, call build_full_result() on the returned PageIterator
wish to get all results at once, call build_full_result() on the returned PageIterator.

.. seealso::
- :external+boto3:py:class:`Athena.Paginator.GetQueryResults`
Expand Down Expand Up @@ -227,7 +228,8 @@ def poll_query_status(
) -> str | None:
"""
Poll the status of submitted athena query until query state reaches final state.
Returns one of the final states

Returns one of the final states.

:param query_execution_id: Id of submitted athena query
:param max_polling_attempts: Number of times to poll for query state before function exits
Expand Down Expand Up @@ -298,7 +300,7 @@ def get_output_location(self, query_execution_id: str) -> str:

def stop_query(self, query_execution_id: str) -> dict:
"""
Cancel the submitted athena query
Cancel the submitted athena query.

.. seealso::
- :external+boto3:py:meth:`Athena.Client.stop_query_execution`
Expand Down
25 changes: 14 additions & 11 deletions airflow/providers/amazon/aws/hooks/base_aws.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ def config(self) -> Config | None:

@property
def role_arn(self) -> str | None:
"""Assume Role ARN from AWS Connection"""
"""Assume Role ARN from AWS Connection."""
return self.conn.role_arn

def _apply_session_kwargs(self, session):
Expand Down Expand Up @@ -585,7 +585,8 @@ def get_session(self, region_name: str | None = None, deferrable: bool = False)
def _get_config(self, config: Config | None = None) -> Config:
"""
No AWS Operators use the config argument to this method.
Keep backward compatibility with other users who might use it

Keep backward compatibility with other users who might use it.
"""
if config is None:
config = deepcopy(self.config)
Expand All @@ -605,7 +606,7 @@ def get_client_type(
config: Config | None = None,
deferrable: bool = False,
) -> boto3.client:
"""Get the underlying boto3 client using boto3 session"""
"""Get the underlying boto3 client using boto3 session."""
client_type = self.client_type
session = self.get_session(region_name=region_name, deferrable=deferrable)
if not isinstance(session, boto3.session.Session):
Expand All @@ -628,7 +629,7 @@ def get_resource_type(
region_name: str | None = None,
config: Config | None = None,
) -> boto3.resource:
"""Get the underlying boto3 resource using boto3 session"""
"""Get the underlying boto3 resource using boto3 session."""
resource_type = self.resource_type
session = self.get_session(region_name=region_name)
return session.resource(
Expand All @@ -641,7 +642,7 @@ def get_resource_type(
@cached_property
def conn(self) -> BaseAwsConnection:
"""
Get the underlying boto3 client/resource (cached)
Get the underlying boto3 client/resource (cached).

:return: boto3.client or boto3.resource
"""
Expand Down Expand Up @@ -683,7 +684,7 @@ def conn_partition(self) -> str:

def get_conn(self) -> BaseAwsConnection:
"""
Get the underlying boto3 client/resource (cached)
Get the underlying boto3 client/resource (cached).

Implemented so that caching works as intended. It exists for compatibility
with subclasses that rely on a super().get_conn() method.
Expand Down Expand Up @@ -873,7 +874,7 @@ def get_waiter(

@staticmethod
def _apply_parameters_value(config: dict, waiter_name: str, parameters: dict[str, str] | None) -> dict:
"""Replaces potential jinja templates in acceptors definition"""
"""Replaces potential jinja templates in acceptors definition."""
# only process the waiter we're going to use to not raise errors for missing params for other waiters.
acceptors = config["waiters"][waiter_name]["acceptors"]
for a in acceptors:
Expand Down Expand Up @@ -927,7 +928,7 @@ class AwsBaseHook(AwsGenericHook[Union[boto3.client, boto3.resource]]):


def resolve_session_factory() -> type[BaseSessionFactory]:
"""Resolves custom SessionFactory class"""
"""Resolves custom SessionFactory class."""
clazz = conf.getimport("aws", "session_factory", fallback=None)
if not clazz:
return BaseSessionFactory
Expand All @@ -943,7 +944,7 @@ def resolve_session_factory() -> type[BaseSessionFactory]:


def _parse_s3_config(config_file_name: str, config_format: str | None = "boto", profile: str | None = None):
"""For compatibility with airflow.contrib.hooks.aws_hook"""
"""For compatibility with airflow.contrib.hooks.aws_hook."""
from airflow.providers.amazon.aws.utils.connection_wrapper import _parse_s3_config

return _parse_s3_config(
Expand Down Expand Up @@ -978,7 +979,9 @@ def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

async def get_role_credentials(self) -> dict:
"""Get the role_arn, method credentials from connection details and get the role credentials detail"""
"""Get the role_arn, method credentials from connection details and get the role credentials
detail.
"""
async with self._basic_session.create_client("sts", region_name=self.region_name) as client:
response = await client.assume_role(
RoleArn=self.role_arn,
Expand Down Expand Up @@ -1086,7 +1089,7 @@ def get_async_session(self) -> AioSession:
).create_session()

async def get_client_async(self):
"""Get the underlying aiobotocore client using aiobotocore session"""
"""Get the underlying aiobotocore client using aiobotocore session."""
return self.get_async_session().create_client(
self.client_type,
region_name=self.region_name,
Expand Down
24 changes: 12 additions & 12 deletions airflow/providers/amazon/aws/hooks/batch_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
# specific language governing permissions and limitations
# under the License.
"""
A client for AWS Batch services
A client for AWS Batch services.

.. seealso::

Expand Down Expand Up @@ -53,7 +53,7 @@ class BatchProtocol(Protocol):

def describe_jobs(self, jobs: list[str]) -> dict:
"""
Get job descriptions from AWS Batch
Get job descriptions from AWS Batch.

:param jobs: a list of JobId to describe

Expand All @@ -63,7 +63,7 @@ def describe_jobs(self, jobs: list[str]) -> dict:

def get_waiter(self, waiterName: str) -> botocore.waiter.Waiter:
"""
Get an AWS Batch service waiter
Get an AWS Batch service waiter.

:param waiterName: The name of the waiter. The name should match
the name (including the casing) of the key name in the waiter
Expand Down Expand Up @@ -98,7 +98,7 @@ def submit_job(
tags: dict,
) -> dict:
"""
Submit a Batch job
Submit a Batch job.

:param jobName: the name for the AWS Batch job

Expand All @@ -120,7 +120,7 @@ def submit_job(

def terminate_job(self, jobId: str, reason: str) -> dict:
"""
Terminate a Batch job
Terminate a Batch job.

:param jobId: a job ID to terminate

Expand Down Expand Up @@ -216,7 +216,7 @@ def client(self) -> BatchProtocol | botocore.client.BaseClient:

def terminate_job(self, job_id: str, reason: str) -> dict:
"""
Terminate a Batch job
Terminate a Batch job.

:param job_id: a job ID to terminate

Expand All @@ -230,11 +230,11 @@ def terminate_job(self, job_id: str, reason: str) -> dict:

def check_job_success(self, job_id: str) -> bool:
"""
Check the final status of the Batch job; return True if the job
'SUCCEEDED', else raise an AirflowException
Check the final status of the Batch job.

:param job_id: a Batch job ID
Return True if the job 'SUCCEEDED', else raise an AirflowException.

:param job_id: a Batch job ID

:raises: AirflowException
"""
Expand All @@ -255,7 +255,7 @@ def check_job_success(self, job_id: str) -> bool:

def wait_for_job(self, job_id: str, delay: int | float | None = None) -> None:
"""
Wait for Batch job to complete
Wait for Batch job to complete.

:param job_id: a Batch job ID

Expand Down Expand Up @@ -396,7 +396,7 @@ def get_job_description(self, job_id: str) -> dict:
@staticmethod
def parse_job_description(job_id: str, response: dict) -> dict:
"""
Parse job description to extract description for job_id
Parse job description to extract description for job_id.

:param job_id: a Batch job ID

Expand Down Expand Up @@ -488,7 +488,7 @@ def get_job_all_awslogs_info(self, job_id: str) -> list[dict[str, str]]:
@staticmethod
def add_jitter(delay: int | float, width: int | float = 1, minima: int | float = 0) -> float:
"""
Use delay +/- width for random jitter
Use delay +/- width for random jitter.

Adding jitter to status polling can help to avoid
AWS Batch API limits for monitoring Batch jobs with
Expand Down
4 changes: 2 additions & 2 deletions airflow/providers/amazon/aws/hooks/batch_waiters.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
# specific language governing permissions and limitations
# under the License.
"""
AWS Batch service waiters
AWS Batch service waiters.

.. seealso::

Expand Down Expand Up @@ -107,7 +107,7 @@ def __init__(self, *args, waiter_config: dict | None = None, **kwargs) -> None:
@property
def default_config(self) -> dict:
"""
An immutable default waiter configuration
An immutable default waiter configuration.

:return: a waiter configuration for AWS Batch services
"""
Expand Down
2 changes: 1 addition & 1 deletion airflow/providers/amazon/aws/hooks/cloud_formation.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""This module contains AWS CloudFormation Hook"""
"""This module contains AWS CloudFormation Hook."""
from __future__ import annotations

from boto3 import client, resource
Expand Down
2 changes: 1 addition & 1 deletion airflow/providers/amazon/aws/hooks/datasync.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ def delete_task(self, task_arn: str) -> None:
self.get_conn().delete_task(TaskArn=task_arn)

def _refresh_tasks(self) -> None:
"""Refreshes the local list of Tasks"""
"""Refreshes the local list of Tasks."""
self.tasks = []
next_token = None
while True:
Expand Down
4 changes: 2 additions & 2 deletions airflow/providers/amazon/aws/hooks/dms.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def __init__(self, *args, **kwargs):

def describe_replication_tasks(self, **kwargs) -> tuple[str | None, list]:
"""
Describe replication tasks
Describe replication tasks.

.. seealso::
- :external+boto3:py:meth:`DatabaseMigrationService.Client.describe_replication_tasks`
Expand All @@ -65,7 +65,7 @@ def describe_replication_tasks(self, **kwargs) -> tuple[str | None, list]:

def find_replication_tasks_by_arn(self, replication_task_arn: str, without_settings: bool | None = False):
"""
Find and describe replication tasks by task ARN
Find and describe replication tasks by task ARN.

.. seealso::
- :external+boto3:py:meth:`DatabaseMigrationService.Client.describe_replication_tasks`
Expand Down
2 changes: 1 addition & 1 deletion airflow/providers/amazon/aws/hooks/dynamodb.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""This module contains the Amazon DynamoDB Hook"""
"""This module contains the Amazon DynamoDB Hook."""
from __future__ import annotations

from typing import Iterable
Expand Down
12 changes: 6 additions & 6 deletions airflow/providers/amazon/aws/hooks/ec2.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def get_instance(self, instance_id: str, filters: list | None = None):
@only_client_type
def stop_instances(self, instance_ids: list) -> dict:
"""
Stop instances with given ids
Stop instances with given ids.

:param instance_ids: List of instance ids to stop
:return: Dict with key `StoppingInstances` and value as list of instances being stopped
Expand All @@ -103,7 +103,7 @@ def stop_instances(self, instance_ids: list) -> dict:
@only_client_type
def start_instances(self, instance_ids: list) -> dict:
"""
Start instances with given ids
Start instances with given ids.

:param instance_ids: List of instance ids to start
:return: Dict with key `StartingInstances` and value as list of instances being started
Expand All @@ -115,7 +115,7 @@ def start_instances(self, instance_ids: list) -> dict:
@only_client_type
def terminate_instances(self, instance_ids: list) -> dict:
"""
Terminate instances with given ids
Terminate instances with given ids.

:param instance_ids: List of instance ids to terminate
:return: Dict with key `TerminatingInstances` and value as list of instances being terminated
Expand All @@ -127,7 +127,7 @@ def terminate_instances(self, instance_ids: list) -> dict:
@only_client_type
def describe_instances(self, filters: list | None = None, instance_ids: list | None = None):
"""
Describe EC2 instances, optionally applying filters and selective instance ids
Describe EC2 instances, optionally applying filters and selective instance ids.

:param filters: List of filters to specify instances to describe
:param instance_ids: List of instance IDs to describe
Expand All @@ -144,7 +144,7 @@ def describe_instances(self, filters: list | None = None, instance_ids: list | N
@only_client_type
def get_instances(self, filters: list | None = None, instance_ids: list | None = None) -> list:
"""
Get list of instance details, optionally applying filters and selective instance ids
Get list of instance details, optionally applying filters and selective instance ids.

:param instance_ids: List of ids to get instances for
:param filters: List of filters to specify instances to get
Expand All @@ -159,7 +159,7 @@ def get_instances(self, filters: list | None = None, instance_ids: list | None =
@only_client_type
def get_instance_ids(self, filters: list | None = None) -> list:
"""
Get list of instance ids, optionally applying filters to fetch selective instances
Get list of instance ids, optionally applying filters to fetch selective instances.

:param filters: List of filters to specify instances to get
:return: List of instance ids
Expand Down
Loading