Skip to content

Commit

Permalink
Improve modules import in Airflow core by some of them into a type-ch…
Browse files Browse the repository at this point in the history
…ecking block (#33755)

* Improve modules import in Airflow core by some of them into a type-checking block

* Restore lazy import in models and fix import in utils.sessions

* fix unit tests

* fix unit tests

* fix static checks

(cherry picked from commit b82ce61)
  • Loading branch information
hussein-awala authored and ephraimbuddy committed Oct 30, 2023
1 parent e64ffe6 commit 334d123
Show file tree
Hide file tree
Showing 18 changed files with 80 additions and 39 deletions.
9 changes: 6 additions & 3 deletions airflow/configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,22 +36,25 @@
from contextlib import contextmanager
from copy import deepcopy
from json.decoder import JSONDecodeError
from typing import IO, Any, Dict, Generator, Iterable, Pattern, Set, Tuple, Union
from typing import IO, TYPE_CHECKING, Any, Dict, Generator, Iterable, Pattern, Set, Tuple, Union
from urllib.parse import urlsplit

import re2
from packaging.version import parse as parse_version
from typing_extensions import overload

from airflow.auth.managers.base_auth_manager import BaseAuthManager
from airflow.exceptions import AirflowConfigException
from airflow.secrets import DEFAULT_SECRETS_SEARCH_PATH, BaseSecretsBackend
from airflow.secrets import DEFAULT_SECRETS_SEARCH_PATH
from airflow.utils import yaml
from airflow.utils.empty_set import _get_empty_set_for_configuration
from airflow.utils.module_loading import import_string
from airflow.utils.providers_configuration_loader import providers_configuration_loaded
from airflow.utils.weight_rule import WeightRule

if TYPE_CHECKING:
from airflow.auth.managers.base_auth_manager import BaseAuthManager
from airflow.secrets import BaseSecretsBackend

log = logging.getLogger(__name__)

# show Airflow's deprecation warnings
Expand Down
3 changes: 2 additions & 1 deletion airflow/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,15 @@
"""Exceptions used by Airflow."""
from __future__ import annotations

import datetime
import warnings
from http import HTTPStatus
from typing import TYPE_CHECKING, Any, NamedTuple, Sized

from airflow.utils.trigger_rule import TriggerRule

if TYPE_CHECKING:
import datetime

from airflow.models import DAG, DagRun


Expand Down
8 changes: 5 additions & 3 deletions airflow/executors/base_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,18 +17,16 @@
"""Base executor - this is the base class for all the implemented executors."""
from __future__ import annotations

import argparse
import logging
import sys
import warnings
from collections import defaultdict
from dataclasses import dataclass, field
from datetime import datetime
from typing import TYPE_CHECKING, Any, List, Optional, Sequence, Tuple

import pendulum

from airflow.cli.cli_config import DefaultHelpParser, GroupCommand
from airflow.cli.cli_config import DefaultHelpParser
from airflow.configuration import conf
from airflow.exceptions import RemovedInAirflow3Warning
from airflow.stats import Stats
Expand All @@ -38,8 +36,12 @@
PARALLELISM: int = conf.getint("core", "PARALLELISM")

if TYPE_CHECKING:
import argparse
from datetime import datetime

from airflow.callbacks.base_callback_sink import BaseCallbackSink
from airflow.callbacks.callback_requests import CallbackRequest
from airflow.cli.cli_config import GroupCommand
from airflow.models.taskinstance import TaskInstance
from airflow.models.taskinstancekey import TaskInstanceKey

Expand Down
6 changes: 4 additions & 2 deletions airflow/executors/local_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,7 @@
import subprocess
from abc import abstractmethod
from multiprocessing import Manager, Process
from multiprocessing.managers import SyncManager
from queue import Empty, Queue
from queue import Empty
from typing import TYPE_CHECKING, Any, Optional, Tuple

from setproctitle import getproctitle, setproctitle
Expand All @@ -43,6 +42,9 @@
from airflow.utils.state import TaskInstanceState

if TYPE_CHECKING:
from multiprocessing.managers import SyncManager
from queue import Queue

from airflow.executors.base_executor import CommandType
from airflow.models.taskinstance import TaskInstanceStateType
from airflow.models.taskinstancekey import TaskInstanceKey
Expand Down
4 changes: 3 additions & 1 deletion airflow/kubernetes/pre_7_4_0_compatibility/k8s_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,10 @@

from abc import ABC, abstractmethod
from functools import reduce
from typing import TYPE_CHECKING

from kubernetes.client import models as k8s
if TYPE_CHECKING:
from kubernetes.client import models as k8s


class K8SModel(ABC):
Expand Down
5 changes: 4 additions & 1 deletion airflow/kubernetes/pre_7_4_0_compatibility/pod_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,13 @@
from __future__ import annotations

import copy
import datetime
import logging
import os
import secrets
import string
import warnings
from functools import reduce
from typing import TYPE_CHECKING

import re2
from dateutil import parser
Expand All @@ -54,6 +54,9 @@
from airflow.utils.hashlib_wrapper import md5
from airflow.version import version as airflow_version

if TYPE_CHECKING:
import datetime

log = logging.getLogger(__name__)

MAX_LABEL_LEN = 63
Expand Down
13 changes: 6 additions & 7 deletions airflow/plugins_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,19 +29,18 @@
from pathlib import Path
from typing import TYPE_CHECKING, Any, Iterable

try:
import importlib_metadata
except ImportError:
from importlib import metadata as importlib_metadata # type: ignore[no-redef]

from types import ModuleType

from airflow import settings
from airflow.utils.entry_points import entry_points_with_dist
from airflow.utils.file import find_path_from_directory
from airflow.utils.module_loading import import_string, qualname

if TYPE_CHECKING:
try:
import importlib_metadata
except ImportError:
from importlib import metadata as importlib_metadata # type: ignore[no-redef]
from types import ModuleType

from airflow.hooks.base import BaseHook
from airflow.listeners.listener import ListenerManager
from airflow.timetables.base import Timetable
Expand Down
2 changes: 2 additions & 0 deletions airflow/providers_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,8 @@ def ensure_prefix(field):


if TYPE_CHECKING:
from typing_extensions import Literal

from airflow.decorators.base import TaskDecorator
from airflow.hooks.base import BaseHook

Expand Down
6 changes: 4 additions & 2 deletions airflow/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,7 @@
import pluggy
import sqlalchemy
from sqlalchemy import create_engine, exc, text
from sqlalchemy.engine import Engine
from sqlalchemy.orm import Session as SASession, scoped_session, sessionmaker
from sqlalchemy.orm import scoped_session, sessionmaker
from sqlalchemy.pool import NullPool

from airflow import policies
Expand All @@ -43,6 +42,9 @@
from airflow.utils.state import State

if TYPE_CHECKING:
from sqlalchemy.engine import Engine
from sqlalchemy.orm import Session as SASession

from airflow.www.utils import UIAlert

log = logging.getLogger(__name__)
Expand Down
5 changes: 4 additions & 1 deletion airflow/stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,10 @@
from typing import TYPE_CHECKING, Callable

from airflow.configuration import conf
from airflow.metrics.base_stats_logger import NoStatsLogger, StatsLogger
from airflow.metrics.base_stats_logger import NoStatsLogger

if TYPE_CHECKING:
from airflow.metrics.base_stats_logger import StatsLogger

log = logging.getLogger(__name__)

Expand Down
5 changes: 4 additions & 1 deletion airflow/templates.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,14 @@
# under the License.
from __future__ import annotations

import datetime
from typing import TYPE_CHECKING

import jinja2.nativetypes
import jinja2.sandbox

if TYPE_CHECKING:
import datetime


class _AirflowEnvironmentMixin:
def __init__(self, **kwargs):
Expand Down
6 changes: 4 additions & 2 deletions airflow/timetables/_cron.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,17 +18,19 @@

import datetime
from functools import cached_property
from typing import Any
from typing import TYPE_CHECKING, Any

from cron_descriptor import CasingTypeEnum, ExpressionDescriptor, FormatException, MissingFieldException
from croniter import CroniterBadCronError, CroniterBadDateError, croniter
from pendulum import DateTime
from pendulum.tz.timezone import Timezone

from airflow.exceptions import AirflowTimetableInvalid
from airflow.utils.dates import cron_presets
from airflow.utils.timezone import convert_to_utc, make_aware, make_naive

if TYPE_CHECKING:
from pendulum import DateTime


def _is_schedule_fixed(expression: str) -> bool:
"""Figures out if the schedule has a fixed time (e.g. 3 AM every day).
Expand Down
4 changes: 2 additions & 2 deletions airflow/timetables/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,11 @@
from typing import TYPE_CHECKING, Any, NamedTuple, Sequence
from warnings import warn

from pendulum import DateTime

from airflow.typing_compat import Protocol, runtime_checkable

if TYPE_CHECKING:
from pendulum import DateTime

from airflow.utils.types import DagRunType


Expand Down
10 changes: 7 additions & 3 deletions airflow/timetables/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,16 @@
from __future__ import annotations

import itertools
from typing import Iterable
from typing import TYPE_CHECKING, Iterable

import pendulum
from pendulum import DateTime

from airflow.timetables.base import DagRunInfo, DataInterval, TimeRestriction, Timetable
from airflow.timetables.base import DagRunInfo, DataInterval, Timetable

if TYPE_CHECKING:
from pendulum import DateTime

from airflow.timetables.base import TimeRestriction


class EventsTimetable(Timetable):
Expand Down
7 changes: 5 additions & 2 deletions airflow/timetables/interval.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,19 @@
from __future__ import annotations

import datetime
from typing import Any, Union
from typing import TYPE_CHECKING, Any, Union

from dateutil.relativedelta import relativedelta
from pendulum import DateTime

from airflow.exceptions import AirflowTimetableInvalid
from airflow.timetables._cron import CronMixin
from airflow.timetables.base import DagRunInfo, DataInterval, TimeRestriction, Timetable
from airflow.timetables.base import DagRunInfo, DataInterval, Timetable
from airflow.utils.timezone import convert_to_utc

if TYPE_CHECKING:
from airflow.timetables.base import TimeRestriction

Delta = Union[datetime.timedelta, relativedelta]


Expand Down
3 changes: 2 additions & 1 deletion airflow/timetables/simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,13 @@

from pendulum import DateTime

from airflow.timetables.base import DagRunInfo, DataInterval, TimeRestriction, Timetable
from airflow.timetables.base import DagRunInfo, DataInterval, Timetable

if TYPE_CHECKING:
from sqlalchemy import Session

from airflow.models.dataset import DatasetEvent
from airflow.timetables.base import TimeRestriction
from airflow.utils.types import DagRunType


Expand Down
12 changes: 8 additions & 4 deletions airflow/timetables/trigger.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,18 @@
from __future__ import annotations

import datetime
from typing import Any
from typing import TYPE_CHECKING, Any

from dateutil.relativedelta import relativedelta
from pendulum import DateTime
from pendulum.tz.timezone import Timezone

from airflow.timetables._cron import CronMixin
from airflow.timetables.base import DagRunInfo, DataInterval, TimeRestriction, Timetable
from airflow.timetables.base import DagRunInfo, DataInterval, Timetable

if TYPE_CHECKING:
from dateutil.relativedelta import relativedelta
from pendulum.tz.timezone import Timezone

from airflow.timetables.base import TimeRestriction


class CronTriggerTimetable(CronMixin, Timetable):
Expand Down
11 changes: 8 additions & 3 deletions airflow/triggers/external_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,18 +18,23 @@

import asyncio
import typing
from datetime import datetime

from asgiref.sync import sync_to_async
from sqlalchemy import func
from sqlalchemy.orm import Session

from airflow.models import DagRun, TaskInstance
from airflow.triggers.base import BaseTrigger, TriggerEvent
from airflow.utils.session import NEW_SESSION, provide_session
from airflow.utils.state import DagRunState, TaskInstanceState
from airflow.utils.state import TaskInstanceState
from airflow.utils.timezone import utcnow

if typing.TYPE_CHECKING:
from datetime import datetime

from sqlalchemy.orm import Session

from airflow.utils.state import DagRunState


class TaskStateTrigger(BaseTrigger):
"""
Expand Down

0 comments on commit 334d123

Please sign in to comment.