Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable MyPy in strict mode #60

Merged
merged 11 commits into from
Mar 9, 2022
26 changes: 26 additions & 0 deletions .circleci/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ workflows:
test:
jobs:
- static-checks
- mypy
- test:
name: test-python<< matrix.python_version >>
matrix:
Expand Down Expand Up @@ -49,6 +50,31 @@ jobs:
- ~/.cache/pre-commit
- ~/.pyenv/versions/

mypy:
description: "Mypy"
executor:
name: docker-executor
python_version: "3.9"
steps:
- checkout
- restore_cache:
keys:
- mypy-{{ .Branch }}-{{ checksum "setup.cfg" }}-{{ checksum "/home/circleci/.pyenv/version" }}
- mypy-main-{{ checksum "setup.cfg" }}-{{ checksum "/home/circleci/.pyenv/version" }}
- run:
name: Install Dependencies
command: pip install -U -e .[mypy]
- run:
name: Run Mypy
command: |
mypy --version
mypy astronomer/
- save_cache:
paths:
- ~/.cache/pip
- ~/.pyenv/versions/
key: mypy-{{ .Branch }}-{{ checksum "setup.cfg" }}-{{ checksum "/home/circleci/.pyenv/version" }}

test:
parameters:
python_version:
Expand Down
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -42,3 +42,6 @@ docs/*/_api/
test-report/
.coverage
coverage.xml

# Mypy Cache
.mypy_cache
1 change: 1 addition & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ repos:
- id: mixed-line-ending
- id: trailing-whitespace


- repo: https://github.com/codespell-project/codespell
rev: v2.1.0
hooks:
Expand Down
3 changes: 2 additions & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,8 @@ run-mypy: ## Run MyPy in Container
docker build -f dev/Dockerfile . -t astronomer-providers-dev
docker run -v `pwd`:/usr/local/airflow/astronomer_providers -v `pwd`/dev/.cache:/home/astro/.cache \
-w /usr/local/airflow/astronomer_providers \
--rm -it astronomer-providers-dev -- mypy --install-types $(RUN_ARGS)
--rm -it astronomer-providers-dev \
-- mypy --install-types --cache-dir /home/astro/.cache/.mypy_cache $(RUN_ARGS)

help: ## Prints this message
@grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | sort | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}'
14 changes: 7 additions & 7 deletions astronomer/providers/amazon/aws/hooks/s3.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,14 +205,14 @@ async def is_keys_unchanged(

if current_num_objects >= min_objects:
success_message = (
"SUCCESS: \nSensor found %s objects at %s.\n"
"Waited at least %s seconds, with no new objects uploaded.",
current_num_objects,
path,
inactivity_period,
"SUCCESS: Sensor found %s objects at %s. "
"Waited at least %s seconds, with no new objects uploaded."
)
self.log.info(success_message)
return {"status": "success", "message": success_message}
self.log.info(success_message, current_num_objects, path, inactivity_period)
return {
"status": "success",
"message": success_message % (current_num_objects, path, inactivity_period),
}

self.log.error("FAILURE: Inactivity Period passed, not enough objects found in %s", path)
return {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def execute(self, context: "Context") -> None:
"Unable to resume cluster since cluster is currently in status: %s", cluster_state
)

def execute_complete(self, context: Dict[Any, Any], event: Any = None) -> None:
def execute_complete(self, context: Dict[str, Any], event: Any = None) -> None:
"""
Callback for when the trigger fires - returns immediately.
Relies on trigger to throw an exception, otherwise it assumes execution was
Expand Down Expand Up @@ -108,7 +108,7 @@ def execute(self, context: "Context") -> None:
"Unable to pause cluster since cluster is currently in status: %s", cluster_state
)

def execute_complete(self, context: Dict[Any, Any], event: Any = None) -> None:
def execute_complete(self, context: Dict[str, Any], event: Any = None) -> None:
"""
Callback for when the trigger fires - returns immediately.
Relies on trigger to throw an exception, otherwise it assumes execution was
Expand Down
7 changes: 5 additions & 2 deletions astronomer/providers/amazon/aws/operators/redshift_sql.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
from typing import Any, Dict, cast
from typing import TYPE_CHECKING, Any, Dict, cast

from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.operators.redshift_sql import RedshiftSQLOperator

from astronomer.providers.amazon.aws.hooks.redshift_data import RedshiftDataHook
from astronomer.providers.amazon.aws.triggers.redshift_sql import RedshiftSQLTrigger

if TYPE_CHECKING:
from airflow.utils.context import Context


class RedshiftSQLOperatorAsync(RedshiftSQLOperator):
"""
Expand All @@ -21,7 +24,7 @@ def __init__(
self.poll_interval = poll_interval
super().__init__(**kwargs)

def execute(self, context: Dict[str, Any]) -> None:
def execute(self, context: "Context") -> None:
redshift_data_hook = RedshiftDataHook(aws_conn_id=self.redshift_conn_id)
query_ids, response = redshift_data_hook.execute_query(sql=cast(str, self.sql), params=self.params)
if response.get("status") == "error":
Expand Down
9 changes: 6 additions & 3 deletions astronomer/providers/amazon/aws/sensors/redshift_cluster.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import logging
from typing import Any, Dict
from typing import TYPE_CHECKING, Any, Dict, Optional

from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.sensors.redshift_cluster import RedshiftClusterSensor
Expand All @@ -8,6 +8,9 @@
RedshiftClusterSensorTrigger,
)

if TYPE_CHECKING:
from airflow.utils.context import Context

log = logging.getLogger(__name__)


Expand All @@ -28,7 +31,7 @@ def __init__(
self.poll_interval = poll_interval
super().__init__(**kwargs)

def execute(self, context: Dict[Any, Any]) -> None:
def execute(self, context: Dict[str, Any]) -> None:
self.defer(
timeout=self.execution_timeout,
trigger=RedshiftClusterSensorTrigger(
Expand All @@ -41,7 +44,7 @@ def execute(self, context: Dict[Any, Any]) -> None:
method_name="execute_complete",
)

def execute_complete(self, context: Dict[Any, Any], event: Any = None) -> None:
def execute_complete(self, context: "Context", event: Optional[Dict[Any, Any]] = None) -> None:
"""
Callback for when the trigger fires - returns immediately.
Relies on trigger to throw an exception, otherwise it assumes execution was
Expand Down
12 changes: 6 additions & 6 deletions astronomer/providers/amazon/aws/sensors/s3.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def _resolve_bucket_and_key(self) -> None:
if parsed_url.scheme != "" or parsed_url.netloc != "":
raise AirflowException("If bucket_name provided, bucket_key must be relative path, not URI.")

def execute(self, context: Dict[Any, Any]) -> None:
def execute(self, context: Dict[str, Any]) -> None:
self._resolve_bucket_and_key()
self.defer(
timeout=self.execution_timeout,
Expand All @@ -89,7 +89,7 @@ def execute(self, context: Dict[Any, Any]) -> None:
method_name="execute_complete",
)

def execute_complete(self, context: Dict[Any, Any], event: Any = None) -> None:
def execute_complete(self, context: Dict[str, Any], event: Any = None) -> None:
if event["status"] == "error":
raise AirflowException(event["message"])
return None
Expand Down Expand Up @@ -137,7 +137,7 @@ def __init__(
super().__init__(**kwargs)
self.check_fn_user = check_fn

def execute(self, context: Dict[Any, Any]) -> None:
def execute(self, context: Dict[str, Any]) -> None:
self._resolve_bucket_and_key()
self.defer(
timeout=self.execution_timeout,
Expand All @@ -152,7 +152,7 @@ def execute(self, context: Dict[Any, Any]) -> None:
method_name="execute_complete",
)

def execute_complete(self, context: Dict[Any, Any], event: Any = None) -> None:
def execute_complete(self, context: Dict[str, Any], event: Any = None) -> None:
if event["status"] == "error":
raise AirflowException(event["message"])
return None
Expand Down Expand Up @@ -219,7 +219,7 @@ def __init__(
self.verify = verify
self.last_activity_time: Optional[datetime] = None

def execute(self, context: Dict[Any, Any]) -> None:
def execute(self, context: Dict[str, Any]) -> None:
self.defer(
timeout=self.execution_timeout,
trigger=S3KeysUnchangedTrigger(
Expand All @@ -237,7 +237,7 @@ def execute(self, context: Dict[Any, Any]) -> None:
method_name="execute_complete",
)

def execute_complete(self, context: Dict[Any, Any], event: Any = None) -> None:
def execute_complete(self, context: Dict[str, Any], event: Any = None) -> None:
if event["status"] == "error":
raise AirflowException(event["message"])
return None
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import aiofiles
from airflow import AirflowException
from airflow.exceptions import AirflowException
from airflow.providers.cncf.kubernetes.hooks.kubernetes import KubernetesHook
from kubernetes_asyncio import client, config


class KubernetesHookAsync(KubernetesHook):
async def _load_config(self):
async def _load_config(self) -> client.ApiClient:
"""
cluster_context: Optional[str] = None,
config_file: Optional[str] = None,
Expand Down Expand Up @@ -50,7 +50,7 @@ async def _load_config(self):
return client.ApiClient()

if kubeconfig is not None:
async with aiofiles.tempfile.NamedTemporaryFile() as temp_config:
async with aiofiles.tempfile.NamedTemporaryFile() as temp_config: # type: ignore[attr-defined]
self.log.debug("loading kube_config from: connection kube_config")
temp_config.write(kubeconfig.encode())
temp_config.flush()
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
from airflow import AirflowException
from typing import Any, Dict

from airflow.exceptions import AirflowException
from airflow.providers.cncf.kubernetes.operators.kubernetes_pod import (
KubernetesPodOperator,
)
from airflow.utils.context import Context

from astronomer.providers.cncf.kubernetes.triggers.wait_container import (
PodLaunchTimeoutException,
Expand All @@ -20,12 +23,14 @@ class KubernetesPodOperatorAsync(KubernetesPodOperator):
:param poll_interval: interval in seconds to sleep between checking pod status
"""

def __init__(self, *, poll_interval: int = 5, **kwargs):
def __init__(self, *, poll_interval: int = 5, **kwargs: Any):
self.pod = None
self.pod_request_obj = None
self.poll_interval = poll_interval
super().__init__(**kwargs)

@staticmethod
def raise_for_trigger_status(event):
def raise_for_trigger_status(event: Dict[str, Any]) -> None:
if event["status"] == "error":
error_type = event["error_type"]
description = event["description"]
Expand All @@ -34,7 +39,7 @@ def raise_for_trigger_status(event):
else:
raise AirflowException(description)

def execute(self, context):
def execute(self, context: Context) -> None:
self.pod_request_obj = self.build_pod_request_obj(context)
self.pod = self.get_or_create_pod(self.pod_request_obj, context)
self.defer(
Expand All @@ -54,7 +59,7 @@ def execute(self, context):
method_name=self.execute_complete.__name__,
)

def execute_complete(self, context, event=None):
def execute_complete(self, context: Context, event: Dict[str, Any]) -> Any:
remote_pod = None
try:
self.pod_request_obj = self.build_pod_request_obj(context)
Expand Down
27 changes: 11 additions & 16 deletions astronomer/providers/cncf/kubernetes/triggers/wait_container.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import asyncio
import traceback
from datetime import timedelta
from typing import Any, Dict, Optional, Tuple
from typing import Any, AsyncIterator, Dict, Optional, Tuple

from airflow.exceptions import AirflowException
from airflow.providers.cncf.kubernetes.utils.pod_manager import (
Expand All @@ -27,32 +27,27 @@ class WaitContainerTrigger(BaseTrigger):
Next, waits for ``container_name`` to reach a terminal state.

:param kubernetes_conn_id: Airflow connection ID to use
:type kubernetes_conn_id: str
:param hook_params: kwargs for hook
:type hook_params: dict
:param container_name: container to wait for
:type container_name: str
:param pod_name: name of pod to monitor
:type pod_name: str
:param pod_namespace: pod namespace
:type pod_namespace: str
:param pending_phase_timeout: max time in seconds to wait for pod to leave pending phase
:type pending_phase_timeout: float
:param poll_interval: number of seconds between reading pod state
:type poll_interval: float

"""

def __init__(
self,
*,
container_name: str,
pod_name: str,
pod_namespace: str,
kubernetes_conn_id: Optional[str] = None,
hook_params: Optional[dict] = None,
container_name: Optional[str] = None,
pod_name: Optional[str] = None,
pod_namespace: Optional[str] = None,
hook_params: Optional[Dict[str, Any]] = None,
pending_phase_timeout: float = 120,
poll_interval: float = 5,
):
super().__init__()
self.kubernetes_conn_id = kubernetes_conn_id
self.hook_params = hook_params
self.container_name = container_name
Expand All @@ -78,7 +73,7 @@ def serialize(self) -> Tuple[str, Dict[str, Any]]:
async def get_hook(self) -> KubernetesHookAsync:
return KubernetesHookAsync(conn_id=self.kubernetes_conn_id, **(self.hook_params or {}))

async def wait_for_pod_start(self, v1_api: CoreV1Api):
async def wait_for_pod_start(self, v1_api: CoreV1Api) -> Any:
"""
Loops until pod phase leaves ``PENDING``
If timeout is reached, throws error.
Expand All @@ -92,15 +87,15 @@ async def wait_for_pod_start(self, v1_api: CoreV1Api):
await asyncio.sleep(self.poll_interval)
raise PodLaunchTimeoutException("Pod did not leave 'Pending' phase within specified timeout")

async def wait_for_container_completion(self, v1_api: CoreV1Api):
async def wait_for_container_completion(self, v1_api: CoreV1Api) -> None:
"""Waits until container ``self.container_name`` is no longer in running state."""
while True:
pod = await v1_api.read_namespaced_pod(self.pod_name, self.pod_namespace)
if not container_is_running(pod=pod, container_name=self.container_name):
break
await asyncio.sleep(self.poll_interval)

async def run(self):
async def run(self) -> AsyncIterator["TriggerEvent"]: # type: ignore[override]
self.log.debug("Checking pod %r in namespace %r.", self.pod_name, self.pod_namespace)
try:
hook = await self.get_hook()
Expand All @@ -120,7 +115,7 @@ async def run(self):
}
)

def _format_exception_description(self, exc: Exception):
def _format_exception_description(self, exc: Exception) -> Any:
if isinstance(exc, PodLaunchTimeoutException):
return exc.args[0]

Expand Down
Loading