Fix mypy issues for HTTPSensorsAsync hooks, operators and triggers
Fix mypy issues for S3KeySensors hooks, operators and triggers

Fix mypy issues for bigquery async

Remove duplicate mypy configs

Add some more TypeHints

Fix mypy issues for Amazon async operators and sensors
sunank200 authored and kaxil committed Mar 2, 2022
1 parent 14301b9 commit 6d4890b
Showing 16 changed files with 63 additions and 53 deletions.
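Most of the hook changes below apply a single pattern: thin wrapper classes pass `*args`/`**kwargs` straight through to an AWS base hook, and annotating both with `Any` is what mypy needs to accept the untyped pass-through. A minimal, self-contained sketch of that pattern — the class names and `client_type` value are invented, with a stand-in base class rather than the repository's `AwsBaseHookAsync`:

```python
from typing import Any


class StubBaseHook:
    """Stand-in for an untyped base class such as AwsBaseHook."""

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        self.config = dict(kwargs)


class ExampleHookAsync(StubBaseHook):
    """Hypothetical hook mirroring the typed pass-through used in this commit."""

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        # Annotating *args/**kwargs with Any satisfies mypy's
        # disallow_untyped_defs check without constraining callers.
        kwargs["client_type"] = "example"
        super().__init__(*args, **kwargs)


print(ExampleHookAsync().config)  # {'client_type': 'example'}
```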
2 changes: 1 addition & 1 deletion astronomer/providers/amazon/aws/hooks/redshift_cluster.py
@@ -14,7 +14,7 @@ class RedshiftHookAsync(AwsBaseHookAsync):
     Interact with AWS Redshift using aiobotocore python library
     """
 
-    def __init__(self, *args, **kwargs) -> None:
+    def __init__(self, *args: Any, **kwargs: Any) -> None:
         kwargs["client_type"] = "redshift"
         kwargs["resource_type"] = "redshift"
         super().__init__(*args, **kwargs)
8 changes: 4 additions & 4 deletions astronomer/providers/amazon/aws/hooks/redshift_data.py
@@ -1,14 +1,14 @@
 from io import StringIO
-from typing import Dict, Iterable, List, Optional, Union
+from typing import Any, Dict, Iterable, List, Optional, Union
 
 import botocore.exceptions
-from airflow import AirflowException
+from airflow.exceptions import AirflowException
 from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
 from snowflake.connector.util_text import split_statements
 
 
 class RedshiftDataHook(AwsBaseHook):
-    def __init__(self, *args, **kwargs) -> None:
+    def __init__(self, *args: Any, **kwargs: Any) -> None:
         client_type: str = "redshift-data"
         kwargs["client_type"] = "redshift-data"
         kwargs["resource_type"] = "redshift-data"
@@ -17,7 +17,7 @@ def __init__(self, *args, **kwargs) -> None:
 
     def get_conn_params(self) -> Dict[str, Union[str, int]]:
         """Helper method to retrieve connection args"""
-        connection_object = self.get_connection(self.aws_conn_id)
+        connection_object = self.get_connection(self.aws_conn_id)  # type: ignore
         extra_config = connection_object.extra_dejson
 
         conn_params: Dict[str, Union[str, int]] = {}
3 changes: 2 additions & 1 deletion astronomer/providers/amazon/aws/hooks/s3.py
@@ -1,4 +1,5 @@
 import logging
+from typing import Any
 
 from astronomer.providers.amazon.aws.hooks.base_aws_async import AwsBaseHookAsync
 
@@ -13,7 +14,7 @@ class S3HookAsync(AwsBaseHookAsync):
     conn_type = "s3"
     hook_name = "S3"
 
-    def __init__(self, *args, **kwargs) -> None:
+    def __init__(self, *args: Any, **kwargs: Any) -> None:
         kwargs["client_type"] = "s3"
         kwargs["resource_type"] = "s3"
         super().__init__(*args, **kwargs)
14 changes: 7 additions & 7 deletions astronomer/providers/amazon/aws/operators/redshift_cluster.py
@@ -1,4 +1,4 @@
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Any, Dict, Optional
 
 from airflow.exceptions import AirflowException
 from airflow.providers.amazon.aws.hooks.redshift_cluster import RedshiftHook
@@ -27,12 +27,12 @@ def __init__(
         self,
         *,
         poll_interval: float = 5,
-        **kwargs,
+        **kwargs: Any,
     ):
         self.poll_interval = poll_interval
         super().__init__(**kwargs)
 
-    def execute(self, context: "Context"):
+    def execute(self, context: "Context") -> Any:
         redshift_hook = RedshiftHook(aws_conn_id=self.aws_conn_id)
         cluster_state = redshift_hook.cluster_status(cluster_identifier=self.cluster_identifier)
         if cluster_state == "paused":
@@ -52,7 +52,7 @@ def execute(self, context: "Context"):
                 "Unable to resume cluster since cluster is currently in status: %s", cluster_state
             )
 
-    def execute_complete(self, context, event=None):
+    def execute_complete(self, context: Dict[Any, Any], event: Optional[Dict[Any, Any]] = None) -> None:
         """
         Callback for when the trigger fires - returns immediately.
         Relies on trigger to throw an exception, otherwise it assumes execution was
@@ -83,12 +83,12 @@ def __init__(
         self,
         *,
         poll_interval: float = 5,
-        **kwargs,
+        **kwargs: Any,
     ):
         self.poll_interval = poll_interval
         super().__init__(**kwargs)
 
-    def execute(self, context: "Context"):
+    def execute(self, context: "Context") -> Any:
         redshift_hook = RedshiftHook(aws_conn_id=self.aws_conn_id)
         cluster_state = redshift_hook.cluster_status(cluster_identifier=self.cluster_identifier)
         if cluster_state == "available":
@@ -108,7 +108,7 @@ def execute(self, context: "Context"):
                 "Unable to pause cluster since cluster is currently in status: %s", cluster_state
             )
 
-    def execute_complete(self, context, event=None):
+    def execute_complete(self, context: Dict[Any, Any], event: Optional[Dict[Any, Any]] = None) -> None:
         """
         Callback for when the trigger fires - returns immediately.
         Relies on trigger to throw an exception, otherwise it assumes execution was
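Both operators above follow the deferral flow the new annotations describe: `execute` suspends the task via `self.defer(...)`, and the triggerer resumes it by calling `execute_complete` with the fired event's payload, now typed as an optional `Dict[Any, Any]`. A rough sketch of that generic shape, using Airflow's built-in `TimeDeltaTrigger` rather than this repository's Redshift triggers — the operator name and 30-second delay are illustrative:

```python
import datetime
from typing import Any, Dict, Optional

from airflow.exceptions import AirflowException
from airflow.models.baseoperator import BaseOperator
from airflow.triggers.temporal import TimeDeltaTrigger


class ExampleDeferringOperator(BaseOperator):
    """Hypothetical operator showing the typed defer / execute_complete pair."""

    def execute(self, context: Dict[Any, Any]) -> None:
        # Suspend the task; the triggerer resumes it by calling the method
        # named below once the trigger yields an event.
        self.defer(
            timeout=self.execution_timeout,
            trigger=TimeDeltaTrigger(datetime.timedelta(seconds=30)),
            method_name="execute_complete",
        )

    def execute_complete(self, context: Dict[Any, Any], event: Optional[Dict[Any, Any]] = None) -> None:
        # Triggers report failures through the event payload, so re-raise here.
        if event and event.get("status") == "error":
            raise AirflowException(event["message"])
```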
12 changes: 7 additions & 5 deletions astronomer/providers/amazon/aws/operators/redshift_sql.py
@@ -1,6 +1,6 @@
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Any, Dict, Optional
 
-from airflow import AirflowException
+from airflow.exceptions import AirflowException
 from airflow.providers.amazon.aws.operators.redshift_sql import RedshiftSQLOperator
 
 from astronomer.providers.amazon.aws.hooks.redshift_data import RedshiftDataHook
@@ -19,12 +19,12 @@ def __init__(
         self,
         *,
         poll_interval: float = 5,
-        **kwargs,
+        **kwargs: Any,
     ) -> None:
         self.poll_interval = poll_interval
         super().__init__(**kwargs)
 
-    def execute(self, context: "Context"):
+    def execute(self, context: "Context") -> Any:
         redshift_data_hook = RedshiftDataHook(aws_conn_id=self.redshift_conn_id)
         query_ids = redshift_data_hook.execute_query(sql=self.sql, params=self.params)
         self.defer(
@@ -38,7 +38,9 @@ def execute(self, context: "Context"):
             method_name="execute_complete",
         )
 
-    def execute_complete(self, context, event=None):  # pylint: disable=unused-argument
+    def execute_complete(
+        self, context: Dict[Any, Any], event: Optional[Dict[Any, Any]] = None
+    ) -> Any:  # pylint: disable=unused-argument
         """
         Callback for when the trigger fires - returns immediately.
         Relies on trigger to throw an exception, otherwise it assumes execution was
10 changes: 6 additions & 4 deletions astronomer/providers/amazon/aws/sensors/redshift_cluster.py
@@ -1,5 +1,5 @@
 import logging
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Any, Dict, Optional
 
 from airflow.exceptions import AirflowException
 from airflow.providers.amazon.aws.sensors.redshift_cluster import RedshiftClusterSensor
@@ -26,12 +26,12 @@ def __init__(
         self,
         *,
         poll_interval: float = 5,
-        **kwargs,
+        **kwargs: Any,
     ):
         self.poll_interval = poll_interval
         super().__init__(**kwargs)
 
-    def execute(self, context: "Context"):
+    def execute(self, context: Dict[Any, Any]) -> None:
         self.defer(
             timeout=self.execution_timeout,
             trigger=RedshiftClusterSensorTrigger(
@@ -44,7 +44,9 @@ def execute(self, context: "Context"):
             method_name="execute_complete",
         )
 
-    def execute_complete(self, context: "Context", event=None):  # pylint: disable=unused-argument
+    def execute_complete(
+        self, context: "Context", event: Optional[Dict[Any, Any]] = None
+    ) -> None:  # pylint: disable=unused-argument
         """
         Callback for when the trigger fires - returns immediately.
         Relies on trigger to throw an exception, otherwise it assumes execution was
8 changes: 5 additions & 3 deletions astronomer/providers/amazon/aws/sensors/s3.py
@@ -50,7 +50,7 @@ def __init__(
         wildcard_match: bool = False,
         aws_conn_id: str = "aws_default",
         verify: Optional[Union[str, bool]] = None,
-        **kwargs,
+        **kwargs: Any,
     ):
         super().__init__(**kwargs)
 
@@ -70,7 +70,7 @@ def _resolve_bucket_and_key(self) -> None:
         if parsed_url.scheme != "" or parsed_url.netloc != "":
             raise AirflowException("If bucket_name provided, bucket_key must be relative path, not URI.")
 
-    def execute(self, context: Dict) -> Any:
+    def execute(self, context: Dict[Any, Any]) -> Any:
         self._resolve_bucket_and_key()
         self.defer(
             timeout=self.execution_timeout,
@@ -84,5 +84,7 @@ def execute(self, context: Dict) -> Any:
             method_name="execute_complete",
         )
 
-    def execute_complete(self, context: Dict, event=None):  # pylint: disable=unused-argument
+    def execute_complete(
+        self, context: Dict[Any, Any], event: Optional[Dict[Any, Any]] = None
+    ) -> None:  # pylint: disable=unused-argument
         return None
8 changes: 3 additions & 5 deletions astronomer/providers/amazon/aws/triggers/redshift_cluster.py
@@ -1,5 +1,5 @@
 import asyncio
-from typing import Any, Dict, Tuple
+from typing import Any, AsyncIterator, Dict, Tuple
 
 from airflow.exceptions import AirflowException
 from airflow.triggers.base import BaseTrigger, TriggerEvent
@@ -16,7 +16,6 @@ def __init__(
         cluster_identifier: str,
         operation_type: str,
     ):
-        super().__init__()
         self.task_id = task_id
         self.polling_period_seconds = polling_period_seconds
         self.aws_conn_id = aws_conn_id
@@ -38,7 +37,7 @@ def serialize(self) -> Tuple[str, Dict[str, Any]]:
             },
         )
 
-    async def run(self):
+    async def run(self) -> AsyncIterator["TriggerEvent"]:  # type: ignore[override]
         """
         Make async connection to redshift, based on the operation type call
         the RedshiftHookAsync functions
@@ -77,7 +76,6 @@ def __init__(
         target_status: str,
         polling_period_seconds: float,
     ):
-        super().__init__()
         self.task_id = task_id
         self.aws_conn_id = aws_conn_id
         self.cluster_identifier = cluster_identifier
@@ -99,7 +97,7 @@ def serialize(self) -> Tuple[str, Dict[str, Any]]:
             },
         )
 
-    async def run(self):
+    async def run(self) -> AsyncIterator["TriggerEvent"]:  # type: ignore[override]
         """
         Simple async function run until the cluster status match the target status.
         """
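Both `run` methods above are async generators, so their new `AsyncIterator["TriggerEvent"]` annotation needs the `# type: ignore[override]`, presumably because mypy types the base class's plain `async def run` as a coroutine, which it considers incompatible with an async generator's return type. A minimal trigger sketch showing the annotated shape — the class name, dotted path, and event payload are made up:

```python
import asyncio
from typing import Any, AsyncIterator, Dict, Tuple

from airflow.triggers.base import BaseTrigger, TriggerEvent


class ExampleTrigger(BaseTrigger):
    """Hypothetical trigger; only the run() annotation mirrors this commit."""

    def __init__(self, polling_period_seconds: float = 5.0):
        self.polling_period_seconds = polling_period_seconds

    def serialize(self) -> Tuple[str, Dict[str, Any]]:
        # Dotted path plus kwargs so the triggerer can rebuild the instance.
        return (
            "example.triggers.ExampleTrigger",
            {"polling_period_seconds": self.polling_period_seconds},
        )

    async def run(self) -> AsyncIterator["TriggerEvent"]:  # type: ignore[override]
        # An async generator: yielding (rather than returning) events is what
        # makes the AsyncIterator annotation, and the ignore, necessary.
        await asyncio.sleep(self.polling_period_seconds)
        yield TriggerEvent({"status": "success"})
```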
5 changes: 2 additions & 3 deletions astronomer/providers/amazon/aws/triggers/redshift_sql.py
@@ -1,4 +1,4 @@
-from typing import Any, Dict, List, Tuple
+from typing import Any, AsyncIterator, Dict, List, Tuple
 
 from airflow.triggers.base import BaseTrigger, TriggerEvent
 
@@ -13,7 +13,6 @@ def __init__(
         aws_conn_id: str,
         query_ids: List[str],
     ):
-        super().__init__()
         self.task_id = task_id
         self.polling_period_seconds = polling_period_seconds
         self.aws_conn_id = aws_conn_id
@@ -33,7 +32,7 @@ def serialize(self) -> Tuple[str, Dict[str, Any]]:
             },
         )
 
-    async def run(self):
+    async def run(self) -> AsyncIterator["TriggerEvent"]:  # type: ignore[override]
         """
         Make async connection to redshiftSQL and execute query using
         the Amazon Redshift Data API to interact with Amazon Redshift clusters
16 changes: 9 additions & 7 deletions astronomer/providers/amazon/aws/triggers/s3.py
@@ -1,8 +1,9 @@
 import fnmatch
 import logging
 import re
-from typing import Any, Dict, Tuple
+from typing import Any, AsyncIterator, Dict, Tuple
 
+from aiobotocore.session import ClientCreatorContext
 from airflow.triggers.base import BaseTrigger, TriggerEvent
 from botocore.exceptions import ClientError
 
@@ -18,17 +19,16 @@ def __init__(
         bucket_key: str,
         wildcard_match: bool = False,
         aws_conn_id: str = "aws_default",
-        **hook_params,
+        **hook_params: Any,
     ):
-        super().__init__()
         self.bucket_name = bucket_name
         self.bucket_key = bucket_key
         self.wildcard_match = wildcard_match
         self.aws_conn_id = aws_conn_id
         self.hook_params = hook_params
 
     @staticmethod
-    async def _check_exact_key(client, bucket, key) -> bool:
+    async def _check_exact_key(client: ClientCreatorContext, bucket: str, key: str) -> bool:
         """
         Checks if a key exists in a bucket asynchronously
@@ -47,7 +47,7 @@ async def _check_exact_key(client, bucket, key) -> bool:
             raise e
 
     @staticmethod
-    async def _check_wildcard_key(client, bucket: str, wildcard_key: str) -> bool:
+    async def _check_wildcard_key(client: ClientCreatorContext, bucket: str, wildcard_key: str) -> bool:
         """
         Checks that a key matching a wildcard expression exists in a bucket asynchronously
@@ -67,7 +67,9 @@ async def _check_wildcard_key(client, bucket: str, wildcard_key: str) -> bool:
                 return True
         return False
 
-    async def _check_key(self, client, bucket: str, key: str, wildcard_match: bool) -> bool:
+    async def _check_key(
+        self, client: ClientCreatorContext, bucket: str, key: str, wildcard_match: bool
+    ) -> bool:
         """
         Checks if key exists or a key matching a wildcard expression exists in a bucket asynchronously
@@ -97,7 +99,7 @@ def serialize(self) -> Tuple[str, Dict[str, Any]]:
             },
         )
 
-    async def run(self):
+    async def run(self) -> AsyncIterator["TriggerEvent"]:  # type: ignore[override]
         """
         Make an asynchronous connection using S3HookAsync.
         """
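The `_check_wildcard_key` helper typed above performs a client-side match of listed S3 keys against a shell-style pattern, which is what the `fnmatch` and `re` imports are for. A small illustrative sketch of that matching step, assuming the keys have already been listed — the function name and sample data are hypothetical:

```python
import fnmatch
import re
from typing import Iterable


def key_matches_wildcard(keys: Iterable[str], wildcard_key: str) -> bool:
    # Compile the shell-style pattern (e.g. "logs/*/part-*.csv") into a
    # regex once, then test every listed key against it.
    pattern = re.compile(fnmatch.translate(wildcard_key))
    return any(pattern.match(key) for key in keys)


print(key_matches_wildcard(["logs/2022-03-02/part-0.csv"], "logs/*/part-*.csv"))  # True
```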
2 changes: 1 addition & 1 deletion astronomer/providers/core/triggers/filesystem.py
@@ -46,7 +46,7 @@ def serialize(self) -> Tuple[str, Dict[str, Any]]:
             },
         )
 
-    async def run(self):
+    async def run(self) -> None:
         """
         Simple loop until the relevant files are found.
         """
2 changes: 1 addition & 1 deletion astronomer/providers/google/cloud/triggers/gcs.py
@@ -26,7 +26,7 @@ def __init__(
         object_name: str,
         polling_period_seconds: float,
         google_cloud_conn_id: str,
-        hook_params: dict[str, Any],
+        hook_params: Dict[str, Any],
     ):
         self.bucket = bucket
         self.object_name = object_name
10 changes: 5 additions & 5 deletions astronomer/providers/http/hooks/http.py
@@ -1,23 +1,24 @@
 import asyncio
-from typing import Any, Callable, Dict, Optional, Union
+from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Union
 
 import aiohttp
 from aiohttp import ClientResponseError
 from airflow.exceptions import AirflowException
 from airflow.hooks.base import BaseHook
 from asgiref.sync import sync_to_async
 
+if TYPE_CHECKING:
+    from aiohttp.client_reqrep import ClientResponse
 
 
 class HttpHookAsync(BaseHook):
     """
     Interact with HTTP servers using Python Async.
     :param method: the API method to be called
     :type method: str
     :param http_conn_id: :ref:`http connection<howto/connection:http>` that has the base
         API url i.e https://www.google.com/ and optional authentication credentials. Default
         headers can also be specified in the Extra field in json format.
     :type http_conn_id: str
     :param auth_type: The auth type for the service
     :type auth_type: AuthBase of python aiohttp lib
     """
@@ -35,7 +36,6 @@ def __init__(
         retry_limit: int = 3,
         retry_delay: float = 1.0,
     ) -> None:
-        super().__init__()
         self.http_conn_id = http_conn_id
         self.method = method.upper()
         self.base_url: str = ""
@@ -52,7 +52,7 @@ async def run(
         data: Optional[Union[Dict[str, Any], str]] = None,
         headers: Optional[Dict[str, Any]] = None,
         extra_options: Optional[Dict[str, Any]] = None,
-    ) -> Any:
+    ) -> "ClientResponse":
         r"""
         Performs an asynchronous HTTP request call
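The hook's `run` now returns `"ClientResponse"` instead of `Any`, importing the class only under `TYPE_CHECKING` so the name exists for the type checker without a runtime import. A minimal sketch of that idiom — the `head` helper is invented for illustration:

```python
from typing import TYPE_CHECKING

import aiohttp

if TYPE_CHECKING:
    # Imported only while type checking; the quoted annotation below
    # resolves against this name without any runtime cost.
    from aiohttp.client_reqrep import ClientResponse


async def head(session: aiohttp.ClientSession, url: str) -> "ClientResponse":
    # The annotation is a string, so Python never needs ClientResponse
    # at runtime.
    return await session.head(url)
```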