Skip to content

Commit

Permalink
Change type definition for provider_info_cache decorator (#39750)
Browse files Browse the repository at this point in the history
  • Loading branch information
Taragolis authored May 26, 2024
1 parent ad7cb99 commit 4dffec4
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 26 deletions.
23 changes: 10 additions & 13 deletions airflow/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import os
import sys
import warnings
from typing import TYPE_CHECKING

if os.environ.get("_AIRFLOW_PATCH_GEVENT"):
# If you are using gevents and start airflow webserver, you might want to run gevent monkeypatching
Expand Down Expand Up @@ -81,6 +82,13 @@
# Deprecated lazy imports
"AirflowException": (".exceptions", "AirflowException", True),
}
if TYPE_CHECKING:
# These objects are imported by PEP-562, however, static analyzers and IDE's
# have no idea about typing of these objects.
# Add it under TYPE_CHECKING block should help with it.
from airflow.models.dag import DAG
from airflow.models.dataset import Dataset
from airflow.models.xcom_arg import XComArg


def __getattr__(name: str):
Expand Down Expand Up @@ -119,24 +127,13 @@ def __getattr__(name: str):


if not settings.LAZY_LOAD_PROVIDERS:
from airflow import providers_manager
from airflow.providers_manager import ProvidersManager

manager = providers_manager.ProvidersManager()
manager = ProvidersManager()
manager.initialize_providers_list()
manager.initialize_providers_hooks()
manager.initialize_providers_extra_links()
if not settings.LAZY_LOAD_PLUGINS:
from airflow import plugins_manager

plugins_manager.ensure_plugins_loaded()


# This is never executed, but tricks static analyzers (PyDev, PyCharm,)
# into knowing the types of these symbols, and what
# they contain.
STATICA_HACK = True
globals()["kcah_acitats"[::-1].upper()] = False
if STATICA_HACK: # pragma: no cover
from airflow.models.dag import DAG
from airflow.models.dataset import Dataset
from airflow.models.xcom_arg import XComArg
26 changes: 14 additions & 12 deletions airflow/providers_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,14 @@
from dataclasses import dataclass
from functools import wraps
from time import perf_counter
from typing import TYPE_CHECKING, Any, Callable, MutableMapping, NamedTuple, TypeVar, cast
from typing import TYPE_CHECKING, Any, Callable, MutableMapping, NamedTuple, NoReturn, TypeVar

from packaging.utils import canonicalize_name

from airflow.exceptions import AirflowOptionalProviderFeatureException
from airflow.hooks.filesystem import FSHook
from airflow.hooks.package_index import PackageIndexHook
from airflow.typing_compat import ParamSpec
from airflow.utils import yaml
from airflow.utils.entry_points import entry_points_with_dist
from airflow.utils.log.logging_mixin import LoggingMixin
Expand All @@ -51,6 +52,9 @@
else:
from importlib_resources import files as resource_files

PS = ParamSpec("PS")
RT = TypeVar("RT")

MIN_PROVIDER_VERSIONS = {
"apache-airflow-providers-celery": "2.1.0",
}
Expand Down Expand Up @@ -261,11 +265,6 @@ class ConnectionFormWidgetInfo(NamedTuple):
is_sensitive: bool


T = TypeVar("T", bound=Callable)

logger = logging.getLogger(__name__)


def log_debug_import_from_sources(class_name, e, provider_package):
"""Log debug imports from sources."""
log.debug(
Expand Down Expand Up @@ -362,31 +361,34 @@ def _correctness_check(provider_package: str, class_name: str, provider_info: Pr

# We want to have better control over initialization of parameters and be able to debug and test it
# So we add our own decorator
def provider_info_cache(cache_name: str) -> Callable[[T], T]:
def provider_info_cache(cache_name: str) -> Callable[[Callable[PS, NoReturn]], Callable[PS, None]]:
"""
Decorate and cache provider info.
Decorator factory that create decorator that caches initialization of provider's parameters
:param cache_name: Name of the cache
"""

def provider_info_cache_decorator(func: T):
def provider_info_cache_decorator(func: Callable[PS, NoReturn]) -> Callable[PS, None]:
@wraps(func)
def wrapped_function(*args, **kwargs):
def wrapped_function(*args: PS.args, **kwargs: PS.kwargs) -> None:
providers_manager_instance = args[0]
if TYPE_CHECKING:
assert isinstance(providers_manager_instance, ProvidersManager)

if cache_name in providers_manager_instance._initialized_cache:
return
start_time = perf_counter()
logger.debug("Initializing Providers Manager[%s]", cache_name)
log.debug("Initializing Providers Manager[%s]", cache_name)
func(*args, **kwargs)
providers_manager_instance._initialized_cache[cache_name] = True
logger.debug(
log.debug(
"Initialization of Providers Manager[%s] took %.2f seconds",
cache_name,
perf_counter() - start_time,
)

return cast(T, wrapped_function)
return wrapped_function

return provider_info_cache_decorator

Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -343,7 +343,7 @@ required-imports = ["from __future__ import annotations"]
combine-as-imports = true

[tool.ruff.lint.per-file-ignores]
"airflow/__init__.py" = ["F401"]
"airflow/__init__.py" = ["F401", "TCH004"]
"airflow/models/__init__.py" = ["F401", "TCH004"]
"airflow/models/sqla_models.py" = ["F401"]

Expand Down

0 comments on commit 4dffec4

Please sign in to comment.