Skip to content

Commit

Permalink
Fix mypy issues for Amazon async operators and sensors
Browse files Browse the repository at this point in the history
  • Loading branch information
sunank200 committed Feb 28, 2022
1 parent 0f5562e commit 13e57cb
Show file tree
Hide file tree
Showing 11 changed files with 48 additions and 42 deletions.
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import os
from datetime import datetime

import airflow
from airflow.operators.dummy import DummyOperator
from airflow.utils.dates import days_ago

from astronomer.providers.amazon.aws.operators.redshift_cluster import (
RedshiftPauseClusterOperatorAsync,
Expand All @@ -16,7 +16,7 @@

with airflow.DAG(
dag_id="example_async_redshift_cluster_management",
start_date=days_ago(1),
start_date=datetime(2021, 1, 1),
tags=["example", "async"],
schedule_interval="@once",
catchup=False,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
from datetime import datetime

import airflow
from airflow.utils.dates import days_ago

from astronomer.providers.amazon.aws.operators.redshift_sql import (
RedshiftSQLOperatorAsync,
)

with airflow.DAG(
dag_id="example_async_redshift_sql",
start_date=days_ago(1),
start_date=datetime(2021, 1, 1),
tags=["example", "async"],
schedule_interval="@once",
catchup=False,
Expand Down
5 changes: 2 additions & 3 deletions astronomer/providers/amazon/aws/example_dags/example_s3.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from datetime import timedelta
from datetime import datetime, timedelta

from airflow import DAG
from airflow.utils.dates import days_ago

from astronomer.providers.amazon.aws.sensors.s3 import S3KeySensorAsync

Expand All @@ -13,7 +12,7 @@
with DAG(
dag_id="example_s3_key_sensor",
schedule_interval="@daily",
start_date=days_ago(3),
start_date=datetime(2021, 1, 1),
catchup=False,
default_args=default_args,
tags=["async"],
Expand Down
6 changes: 4 additions & 2 deletions astronomer/providers/amazon/aws/hooks/redshift_cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ class RedshiftHookAsync(AwsBaseHookAsync):
Interact with AWS Redshift using aiobotocore python library
"""

def __init__(self, *args, **kwargs) -> None:
def __init__(self, *args: Any, **kwargs: Any) -> None:
kwargs["client_type"] = "redshift"
kwargs["resource_type"] = "redshift"
super().__init__(*args, **kwargs)
Expand Down Expand Up @@ -84,7 +84,9 @@ async def resume_cluster(self, cluster_identifier: str) -> Dict[str, Any]:
except botocore.exceptions.ClientError as error:
return {"status": "error", "message": str(error)}

async def get_cluster_status(self, cluster_identifier, expected_state, flag) -> Dict[str, Any]:
async def get_cluster_status(
self, cluster_identifier: str, expected_state: str, flag: Any
) -> Dict[str, Any]:
"""
Call self.cluster_status repeatedly to check the status, run until the expected_state is met, and set the flag
Expand Down
10 changes: 5 additions & 5 deletions astronomer/providers/amazon/aws/hooks/redshift_data.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
from io import StringIO
from typing import Dict, Iterable, List, Optional, Union
from typing import Any, Dict, List, Union

import botocore.exceptions
from airflow import AirflowException
from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
from snowflake.connector.util_text import split_statements


class RedshiftDataHook(AwsBaseHook):
def __init__(self, *args, **kwargs) -> None:
def __init__(self, *args: Any, **kwargs: Any) -> None:
client_type: str = "redshift-data"
kwargs["client_type"] = "redshift-data"
kwargs["resource_type"] = "redshift-data"
Expand All @@ -17,7 +17,7 @@ def __init__(self, *args, **kwargs) -> None:

def get_conn_params(self) -> Dict[str, Union[str, int]]:
"""Helper method to retrieve connection args"""
connection_object = self.get_connection(self.aws_conn_id)
connection_object = self.get_connection(self.aws_conn_id) # type: ignore
extra_config = connection_object.extra_dejson

conn_params: Dict[str, Union[str, int]] = {}
Expand Down Expand Up @@ -64,7 +64,7 @@ def get_conn_params(self) -> Dict[str, Union[str, int]]:

return conn_params

def execute_query(self, sql: Optional[Union[Dict, Iterable]], params: Optional[Dict]):
def execute_query(self, sql: Any, params: Any) -> Any:
"""
Runs an SQL statement, which can be data manipulation language (DML)
or data definition language (DDL)
Expand Down
11 changes: 7 additions & 4 deletions astronomer/providers/amazon/aws/hooks/redshift_sql.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,20 @@
import asyncio
import logging
from typing import List
from typing import Any, Dict, List

import botocore.exceptions
from asgiref.sync import sync_to_async
from async_timeout import asyncio

from astronomer.providers.amazon.aws.hooks.redshift_data import RedshiftDataHook

# from async_timeout import asyncio


log = logging.getLogger(__name__)


class RedshiftSQLHookAsync(RedshiftDataHook):
async def get_query_status(self, query_ids: List[str]):
async def get_query_status(self, query_ids: List[str]) -> Dict[Any, Any]:
"""
Async function to get the query status by query IDs; this function
takes a list of query_ids and makes an async connection
Expand Down Expand Up @@ -42,7 +45,7 @@ async def get_query_status(self, query_ids: List[str]):
except botocore.exceptions.ClientError as error:
return {"status": "error", "message": str(error), "type": "ERROR"}

async def is_still_running(self, id: str):
async def is_still_running(self, id: str) -> Any:
"""
Async function to check whether the query is still running, i.e. in the
"PICKED", "STARTED", or "SUBMITTED" state, and returns True; else
Expand Down
14 changes: 7 additions & 7 deletions astronomer/providers/amazon/aws/operators/redshift_cluster.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any, Dict, Optional

from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.redshift_cluster import RedshiftHook
Expand Down Expand Up @@ -27,12 +27,12 @@ def __init__(
self,
*,
poll_interval: float = 5,
**kwargs,
**kwargs: Any,
):
self.poll_interval = poll_interval
super().__init__(**kwargs)

def execute(self, context: "Context"):
def execute(self, context: "Context") -> Any:
redshift_hook = RedshiftHook(aws_conn_id=self.aws_conn_id)
cluster_state = redshift_hook.cluster_status(cluster_identifier=self.cluster_identifier)
if cluster_state == "paused":
Expand All @@ -52,7 +52,7 @@ def execute(self, context: "Context"):
"Unable to resume cluster since cluster is currently in status: %s", cluster_state
)

def execute_complete(self, context, event=None):
def execute_complete(self, context: Dict[Any, Any], event: Optional[Dict[Any, Any]] = None) -> None:
"""
Callback for when the trigger fires - returns immediately.
Relies on trigger to throw an exception, otherwise it assumes execution was
Expand Down Expand Up @@ -83,12 +83,12 @@ def __init__(
self,
*,
poll_interval: float = 5,
**kwargs,
**kwargs: Any,
):
self.poll_interval = poll_interval
super().__init__(**kwargs)

def execute(self, context: "Context"):
def execute(self, context: "Context") -> Any:
redshift_hook = RedshiftHook(aws_conn_id=self.aws_conn_id)
cluster_state = redshift_hook.cluster_status(cluster_identifier=self.cluster_identifier)
if cluster_state == "available":
Expand All @@ -108,7 +108,7 @@ def execute(self, context: "Context"):
"Unable to pause cluster since cluster is currently in status: %s", cluster_state
)

def execute_complete(self, context, event=None):
def execute_complete(self, context: Dict[Any, Any], event: Optional[Dict[Any, Any]] = None) -> None:
"""
Callback for when the trigger fires - returns immediately.
Relies on trigger to throw an exception, otherwise it assumes execution was
Expand Down
12 changes: 7 additions & 5 deletions astronomer/providers/amazon/aws/operators/redshift_sql.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any, Dict, Optional

from airflow import AirflowException
from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.operators.redshift_sql import RedshiftSQLOperator

from astronomer.providers.amazon.aws.hooks.redshift_data import RedshiftDataHook
Expand All @@ -19,12 +19,12 @@ def __init__(
self,
*,
poll_interval: float = 5,
**kwargs,
**kwargs: Any,
) -> None:
self.poll_interval = poll_interval
super().__init__(**kwargs)

def execute(self, context: "Context"):
def execute(self, context: "Context") -> Any:
redshift_data_hook = RedshiftDataHook(aws_conn_id=self.redshift_conn_id)
query_ids = redshift_data_hook.execute_query(sql=self.sql, params=self.params)
self.defer(
Expand All @@ -38,7 +38,9 @@ def execute(self, context: "Context"):
method_name="execute_complete",
)

def execute_complete(self, context, event=None): # pylint: disable=unused-argument
def execute_complete(
self, context: Dict[Any, Any], event: Optional[Dict[Any, Any]] = None
) -> Any: # pylint: disable=unused-argument
"""
Callback for when the trigger fires - returns immediately.
Relies on trigger to throw an exception, otherwise it assumes execution was
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import logging
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any, Dict, Optional

from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.sensors.redshift_cluster import RedshiftClusterSensor
Expand All @@ -26,12 +26,12 @@ def __init__(
self,
*,
poll_interval: float = 5,
**kwargs,
**kwargs: Any,
):
self.poll_interval = poll_interval
super().__init__(**kwargs)

def execute(self, context: "Context"):
def execute(self, context: Dict[Any, Any]) -> None:
self.defer(
timeout=self.execution_timeout,
trigger=RedshiftClusterSensorTrigger(
Expand All @@ -44,7 +44,9 @@ def execute(self, context: "Context"):
method_name="execute_complete",
)

def execute_complete(self, context: "Context", event=None): # pylint: disable=unused-argument
def execute_complete(
self, context: "Context", event: Optional[Dict[Any, Any]] = None
) -> None: # pylint: disable=unused-argument
"""
Callback for when the trigger fires - returns immediately.
Relies on trigger to throw an exception, otherwise it assumes execution was
Expand Down
8 changes: 3 additions & 5 deletions astronomer/providers/amazon/aws/triggers/redshift_cluster.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import asyncio
from typing import Any, Dict, Tuple
from typing import Any, AsyncIterator, Dict, Tuple

from airflow.exceptions import AirflowException
from airflow.triggers.base import BaseTrigger, TriggerEvent
Expand All @@ -16,7 +16,6 @@ def __init__(
cluster_identifier: str,
operation_type: str,
):
super().__init__()
self.task_id = task_id
self.polling_period_seconds = polling_period_seconds
self.aws_conn_id = aws_conn_id
Expand All @@ -38,7 +37,7 @@ def serialize(self) -> Tuple[str, Dict[str, Any]]:
},
)

async def run(self):
async def run(self) -> AsyncIterator["TriggerEvent"]: # type: ignore[override]
"""
Make async connection to redshift, based on the operation type call
the RedshiftHookAsync functions
Expand Down Expand Up @@ -77,7 +76,6 @@ def __init__(
target_status: str,
polling_period_seconds: float,
):
super().__init__()
self.task_id = task_id
self.aws_conn_id = aws_conn_id
self.cluster_identifier = cluster_identifier
Expand All @@ -99,7 +97,7 @@ def serialize(self) -> Tuple[str, Dict[str, Any]]:
},
)

async def run(self):
async def run(self) -> AsyncIterator["TriggerEvent"]: # type: ignore[override]
"""
Simple async function run until the cluster status match the target status.
"""
Expand Down
5 changes: 2 additions & 3 deletions astronomer/providers/amazon/aws/triggers/redshift_sql.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Dict, List, Tuple
from typing import Any, AsyncIterator, Dict, List, Tuple

from airflow.triggers.base import BaseTrigger, TriggerEvent

Expand All @@ -13,7 +13,6 @@ def __init__(
aws_conn_id: str,
query_ids: List[str],
):
super().__init__()
self.task_id = task_id
self.polling_period_seconds = polling_period_seconds
self.aws_conn_id = aws_conn_id
Expand All @@ -33,7 +32,7 @@ def serialize(self) -> Tuple[str, Dict[str, Any]]:
},
)

async def run(self):
async def run(self) -> AsyncIterator["TriggerEvent"]: # type: ignore[override]
"""
Make async connection to redshiftSQL and execute query using
the Amazon Redshift Data API to interact with Amazon Redshift clusters
Expand Down

0 comments on commit 13e57cb

Please sign in to comment.