Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve module imports in the AWS provider by moving some of them into a type-checking block #33780

Merged
merged 3 commits into from
Aug 28, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions airflow/providers/amazon/aws/hooks/athena.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,15 @@
from __future__ import annotations

import warnings
from typing import Any

from botocore.paginate import PageIterator
from typing import TYPE_CHECKING, Any

from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
from airflow.providers.amazon.aws.utils.waiter_with_logging import wait

if TYPE_CHECKING:
from botocore.paginate import PageIterator


class AthenaHook(AwsBaseHook):
"""Interact with Amazon Athena.
Expand Down
5 changes: 3 additions & 2 deletions airflow/providers/amazon/aws/hooks/base_aws.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,7 @@
import jinja2
import requests
import tenacity
from botocore.client import ClientMeta
from botocore.config import Config
from botocore.credentials import ReadOnlyCredentials
from botocore.waiter import Waiter, WaiterModel
from dateutil.tz import tzlocal
from slugify import slugify
Expand All @@ -66,6 +64,9 @@
BaseAwsConnection = TypeVar("BaseAwsConnection", bound=Union[boto3.client, boto3.resource])

if TYPE_CHECKING:
from botocore.client import ClientMeta
from botocore.credentials import ReadOnlyCredentials

from airflow.models.connection import Connection # Avoid circular imports.


Expand Down
6 changes: 4 additions & 2 deletions airflow/providers/amazon/aws/hooks/batch_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,17 +29,19 @@
import itertools
from random import uniform
from time import sleep
from typing import Callable
from typing import TYPE_CHECKING, Callable

import botocore.client
import botocore.exceptions
import botocore.waiter

from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
from airflow.providers.amazon.aws.utils.task_log_fetcher import AwsTaskLogFetcher
from airflow.typing_compat import Protocol, runtime_checkable

if TYPE_CHECKING:
from airflow.providers.amazon.aws.utils.task_log_fetcher import AwsTaskLogFetcher


@runtime_checkable
class BatchProtocol(Protocol):
Expand Down
6 changes: 4 additions & 2 deletions airflow/providers/amazon/aws/hooks/batch_waiters.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,15 +29,17 @@
import sys
from copy import deepcopy
from pathlib import Path
from typing import Callable
from typing import TYPE_CHECKING, Callable

import botocore.client
import botocore.exceptions
import botocore.waiter

from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.batch_client import BatchClientHook
from airflow.providers.amazon.aws.utils.task_log_fetcher import AwsTaskLogFetcher

if TYPE_CHECKING:
from airflow.providers.amazon.aws.utils.task_log_fetcher import AwsTaskLogFetcher


class BatchWaitersHook(BatchClientHook):
Expand Down
6 changes: 5 additions & 1 deletion airflow/providers/amazon/aws/hooks/cloud_formation.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,15 @@
"""This module contains AWS CloudFormation Hook."""
from __future__ import annotations

from boto3 import client, resource
from typing import TYPE_CHECKING

from botocore.exceptions import ClientError

from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook

if TYPE_CHECKING:
from boto3 import client, resource


class CloudFormationHook(AwsBaseHook):
"""
Expand Down
5 changes: 4 additions & 1 deletion airflow/providers/amazon/aws/hooks/ecr.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,14 @@
import base64
import logging
from dataclasses import dataclass
from datetime import datetime
from typing import TYPE_CHECKING

from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
from airflow.utils.log.secrets_masker import mask_secret

if TYPE_CHECKING:
from datetime import datetime
Comment on lines +28 to +29
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think there is any benefit to putting datetime from the stdlib under a type-checking block.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, but I will add a pre-commit check that verifies whether the imported methods/classes are used at runtime or only for type checking, so if we want to keep it as a regular import we will need to add an ignore annotation.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would be very cautious with this pre-commit validation, because it needs to handle a lot of cases. A good example is pydantic, which uses annotations at runtime; putting those imports under an if TYPE_CHECKING block might lead to serious problems.

For example lets use example from the main page of pydantic documentation

This works fine

from datetime import datetime
from typing import Tuple

from pydantic import BaseModel


class Delivery(BaseModel):
    timestamp: datetime
    dimensions: Tuple[int, int]


m = Delivery(timestamp='2020-01-02T03:04:05Z', dimensions=['10', '20'])

and this one would fail on model compilation

from typing import Tuple, TYPE_CHECKING

from pydantic import BaseModel

if TYPE_CHECKING:
    from datetime import datetime


class Delivery(BaseModel):
    timestamp: datetime
    dimensions: Tuple[int, int]
/Users/taragolis/.pyenv/versions/3.9.10/envs/narrative/bin/python /Users/taragolis/Library/Application Support/JetBrains/PyCharm2023.2/scratches/pydantic_sample.py 
Traceback (most recent call last):
  File "/Users/taragolis/Library/Application Support/JetBrains/PyCharm2023.2/scratches/pydantic_sample.py", line 9, in <module>
    class Delivery(BaseModel):
  File "/Users/taragolis/Library/Application Support/JetBrains/PyCharm2023.2/scratches/pydantic_sample.py", line 10, in Delivery
    timestamp: datetime
NameError: name 'datetime' is not defined

Process finished with exit code 1

Copy link
Contributor

@Taragolis Taragolis Aug 27, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would rather build the pre-commit check around a deny-list than around everything.
That is because if TYPE_CHECKING is a hack, and putting everything under this block might add more pain than it solves; this is especially true in a project with 100k+ lines of code.

Copy link
Member

@uranusjr uranusjr Aug 28, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The above only fails because you didn’t provide from __future__ import annotations. You probably meant to illustrate something like this instead:

from __future__ import annotations
from typing import Tuple, TYPE_CHECKING
from pydantic import BaseModel

if TYPE_CHECKING:
    from datetime import datetime

class Delivery(BaseModel):
    timestamp: datetime
    dimensions: Tuple[int, int]

Delivery(timestamp="2023-08-28", dimentions=(2, 3))
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
    Delivery(timestamp="2023-08-28", dimentions=(2, 3))
  File "pydantic/main.py", line 404, in __init__
    values, fields_set, validation_error = validate_model(__pydantic_self__.__class__, data)
  File "pydantic/main.py", line 1040, in validate_model
    v_, errors_ = field.validate(value, values, loc=field.alias, cls=cls_)
  File "pydantic/fields.py", line 699, in validate
    raise ConfigError(
pydantic.errors.ConfigError: field "timestamp" not yet prepared so type is still a ForwardRef, you might need to call Delivery.update_forward_refs().


logger = logging.getLogger(__name__)


Expand Down
5 changes: 4 additions & 1 deletion airflow/providers/amazon/aws/hooks/ecs.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,16 @@
# under the License.
from __future__ import annotations

from botocore.waiter import Waiter
from typing import TYPE_CHECKING

from airflow.providers.amazon.aws.exceptions import EcsOperatorError, EcsTaskFailToStart
from airflow.providers.amazon.aws.hooks.base_aws import AwsGenericHook
from airflow.providers.amazon.aws.utils import _StringCompareEnum
from airflow.typing_compat import Protocol, runtime_checkable

if TYPE_CHECKING:
from botocore.waiter import Waiter


def should_retry(exception: Exception):
"""Check if exception is related to ECS resource quota (CPU, MEM)."""
Expand Down
7 changes: 4 additions & 3 deletions airflow/providers/amazon/aws/links/emr.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,16 @@
# under the License.
from __future__ import annotations

from typing import Any

import boto3
from typing import TYPE_CHECKING, Any

from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.s3 import S3Hook
from airflow.providers.amazon.aws.links.base_aws import BASE_AWS_CONSOLE_LINK, BaseAwsLink
from airflow.utils.helpers import exactly_one

if TYPE_CHECKING:
import boto3


class EmrClusterLink(BaseAwsLink):
"""Helper class for constructing AWS EMR Cluster Link."""
Expand Down
5 changes: 4 additions & 1 deletion airflow/providers/amazon/aws/log/cloudwatch_task_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,16 +19,19 @@

from datetime import datetime, timedelta
from functools import cached_property
from typing import TYPE_CHECKING

import watchtower

from airflow.configuration import conf
from airflow.models import TaskInstance
from airflow.providers.amazon.aws.hooks.logs import AwsLogsHook
from airflow.providers.amazon.aws.utils import datetime_to_epoch_utc_ms
from airflow.utils.log.file_task_handler import FileTaskHandler
from airflow.utils.log.logging_mixin import LoggingMixin

if TYPE_CHECKING:
from airflow.models import TaskInstance


class CloudwatchTaskHandler(FileTaskHandler, LoggingMixin):
"""
Expand Down
5 changes: 4 additions & 1 deletion airflow/providers/amazon/aws/notifications/chime.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,13 @@
from __future__ import annotations

from functools import cached_property
from typing import TYPE_CHECKING

from airflow.exceptions import AirflowOptionalProviderFeatureException
from airflow.providers.amazon.aws.hooks.chime import ChimeWebhookHook
from airflow.utils.context import Context

if TYPE_CHECKING:
from airflow.utils.context import Context

try:
from airflow.notifications.basenotifier import BaseNotifier
Expand Down
4 changes: 2 additions & 2 deletions airflow/providers/amazon/aws/operators/ecs.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,6 @@
from functools import cached_property
from typing import TYPE_CHECKING, Sequence

import boto3

from airflow.configuration import conf
from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
from airflow.models import BaseOperator
Expand All @@ -43,6 +41,8 @@
from airflow.utils.helpers import prune_dict

if TYPE_CHECKING:
import boto3

from airflow.models import TaskInstance
from airflow.utils.context import Context

Expand Down
4 changes: 2 additions & 2 deletions airflow/providers/amazon/aws/operators/rds.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,6 @@
from functools import cached_property
from typing import TYPE_CHECKING, Any, Sequence

from mypy_boto3_rds.type_defs import TagTypeDef

from airflow.configuration import conf
from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
from airflow.models import BaseOperator
Expand All @@ -39,6 +37,8 @@
from airflow.providers.amazon.aws.utils.waiter_with_logging import wait

if TYPE_CHECKING:
from mypy_boto3_rds.type_defs import TagTypeDef

from airflow.utils.context import Context


Expand Down
4 changes: 2 additions & 2 deletions airflow/providers/amazon/aws/sensors/ecs.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,6 @@
from functools import cached_property
from typing import TYPE_CHECKING, Sequence

import boto3

from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.ecs import (
EcsClusterStates,
Expand All @@ -31,6 +29,8 @@
from airflow.sensors.base import BaseSensorOperator

if TYPE_CHECKING:
import boto3

from airflow.utils.context import Context

DEFAULT_CONN_ID: str = "aws_default"
Expand Down
2 changes: 1 addition & 1 deletion airflow/providers/amazon/aws/sensors/sqs.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,13 @@

from airflow.configuration import conf
from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.base_aws import BaseAwsConnection
from airflow.providers.amazon.aws.hooks.sqs import SqsHook
from airflow.providers.amazon.aws.triggers.sqs import SqsSensorTrigger
from airflow.providers.amazon.aws.utils.sqs import process_response
from airflow.sensors.base import BaseSensorOperator

if TYPE_CHECKING:
from airflow.providers.amazon.aws.hooks.base_aws import BaseAwsConnection
from airflow.utils.context import Context
from datetime import timedelta

Expand Down
5 changes: 3 additions & 2 deletions airflow/providers/amazon/aws/transfers/mongo_to_s3.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,15 @@
from typing import TYPE_CHECKING, Any, Iterable, Sequence, cast

from bson import json_util
from pymongo.command_cursor import CommandCursor
from pymongo.cursor import Cursor

from airflow.models import BaseOperator
from airflow.providers.amazon.aws.hooks.s3 import S3Hook
from airflow.providers.mongo.hooks.mongo import MongoHook

if TYPE_CHECKING:
from pymongo.command_cursor import CommandCursor
from pymongo.cursor import Cursor

from airflow.utils.context import Context


Expand Down
2 changes: 1 addition & 1 deletion airflow/providers/amazon/aws/transfers/sql_to_s3.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,11 @@
from airflow.hooks.base import BaseHook
from airflow.models import BaseOperator
from airflow.providers.amazon.aws.hooks.s3 import S3Hook
from airflow.providers.common.sql.hooks.sql import DbApiHook

if TYPE_CHECKING:
import pandas as pd

from airflow.providers.common.sql.hooks.sql import DbApiHook
from airflow.utils.context import Context


Expand Down
6 changes: 5 additions & 1 deletion airflow/providers/amazon/aws/triggers/athena.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,14 @@
# under the License.
from __future__ import annotations

from typing import TYPE_CHECKING

from airflow.providers.amazon.aws.hooks.athena import AthenaHook
from airflow.providers.amazon.aws.hooks.base_aws import AwsGenericHook
from airflow.providers.amazon.aws.triggers.base import AwsBaseWaiterTrigger

if TYPE_CHECKING:
from airflow.providers.amazon.aws.hooks.base_aws import AwsGenericHook


class AthenaTrigger(AwsBaseWaiterTrigger):
"""
Expand Down
6 changes: 4 additions & 2 deletions airflow/providers/amazon/aws/triggers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,14 @@
from __future__ import annotations

from abc import abstractmethod
from typing import Any, AsyncIterator
from typing import TYPE_CHECKING, Any, AsyncIterator

from airflow.providers.amazon.aws.hooks.base_aws import AwsGenericHook
from airflow.providers.amazon.aws.utils.waiter_with_logging import async_wait
from airflow.triggers.base import BaseTrigger, TriggerEvent

if TYPE_CHECKING:
from airflow.providers.amazon.aws.hooks.base_aws import AwsGenericHook


class AwsBaseWaiterTrigger(BaseTrigger):
"""
Expand Down
6 changes: 4 additions & 2 deletions airflow/providers/amazon/aws/triggers/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,16 +19,18 @@
import asyncio
import itertools
from functools import cached_property
from typing import Any
from typing import TYPE_CHECKING, Any

from botocore.exceptions import WaiterError
from deprecated import deprecated

from airflow.providers.amazon.aws.hooks.base_aws import AwsGenericHook
from airflow.providers.amazon.aws.hooks.batch_client import BatchClientHook
from airflow.providers.amazon.aws.triggers.base import AwsBaseWaiterTrigger
from airflow.triggers.base import BaseTrigger, TriggerEvent

if TYPE_CHECKING:
from airflow.providers.amazon.aws.hooks.base_aws import AwsGenericHook


@deprecated(reason="use BatchJobTrigger instead")
class BatchOperatorTrigger(BaseTrigger):
Expand Down
6 changes: 4 additions & 2 deletions airflow/providers/amazon/aws/triggers/ecs.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,18 +18,20 @@
from __future__ import annotations

import asyncio
from typing import Any, AsyncIterator
from typing import TYPE_CHECKING, Any, AsyncIterator

from botocore.exceptions import ClientError, WaiterError

from airflow import AirflowException
from airflow.providers.amazon.aws.hooks.base_aws import AwsGenericHook
from airflow.providers.amazon.aws.hooks.ecs import EcsHook
from airflow.providers.amazon.aws.hooks.logs import AwsLogsHook
from airflow.providers.amazon.aws.triggers.base import AwsBaseWaiterTrigger
from airflow.providers.amazon.aws.utils.task_log_fetcher import AwsTaskLogFetcher
from airflow.triggers.base import BaseTrigger, TriggerEvent

if TYPE_CHECKING:
from airflow.providers.amazon.aws.hooks.base_aws import AwsGenericHook


class ClusterActiveTrigger(AwsBaseWaiterTrigger):
"""
Expand Down
6 changes: 4 additions & 2 deletions airflow/providers/amazon/aws/triggers/eks.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,17 @@
from __future__ import annotations

import warnings
from typing import Any
from typing import TYPE_CHECKING, Any

from airflow.exceptions import AirflowProviderDeprecationWarning
from airflow.providers.amazon.aws.hooks.base_aws import AwsGenericHook
from airflow.providers.amazon.aws.hooks.eks import EksHook
from airflow.providers.amazon.aws.triggers.base import AwsBaseWaiterTrigger
from airflow.providers.amazon.aws.utils.waiter_with_logging import async_wait
from airflow.triggers.base import TriggerEvent

if TYPE_CHECKING:
from airflow.providers.amazon.aws.hooks.base_aws import AwsGenericHook


class EksCreateClusterTrigger(AwsBaseWaiterTrigger):
"""
Expand Down
Loading