Skip to content

Commit

Permalink
Refactor import from collections (#34406)
Browse files Browse the repository at this point in the history
  • Loading branch information
eumiro authored Sep 28, 2023
1 parent dd325b4 commit ca3ce78
Show file tree
Hide file tree
Showing 15 changed files with 51 additions and 64 deletions.
4 changes: 2 additions & 2 deletions airflow/api/common/experimental/get_lineage.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
"""Lineage APIs."""
from __future__ import annotations

import collections
from collections import defaultdict
from typing import TYPE_CHECKING, Any

from airflow.api.common.experimental import check_and_get_dag, check_and_get_dagrun
Expand All @@ -43,7 +43,7 @@ def get_lineage(
inlets = XCom.get_many(dag_ids=dag_id, run_id=dagrun.run_id, key=PIPELINE_INLETS, session=session)
outlets = XCom.get_many(dag_ids=dag_id, run_id=dagrun.run_id, key=PIPELINE_OUTLETS, session=session)

lineage: dict[str, dict[str, Any]] = collections.defaultdict(dict)
lineage: dict[str, dict[str, Any]] = defaultdict(dict)
for meta in inlets:
lineage[meta.task_id]["inlets"] = meta.value
for meta in outlets:
Expand Down
6 changes: 3 additions & 3 deletions airflow/auth/managers/fab/cli_commands/role_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,10 @@
"""Roles sub-commands."""
from __future__ import annotations

import collections
import itertools
import json
import os
from collections import defaultdict
from typing import TYPE_CHECKING

from airflow.auth.managers.fab.cli_commands.utils import get_application_builder
Expand All @@ -48,7 +48,7 @@ def roles_list(args):
)
return

permission_map: dict[tuple[str, str], list[str]] = collections.defaultdict(list)
permission_map: dict[tuple[str, str], list[str]] = defaultdict(list)
for role in roles:
for permission in role.permissions:
permission_map[(role.name, permission.resource.name)].append(permission.action.name)
Expand Down Expand Up @@ -92,7 +92,7 @@ def __roles_add_or_remove_permissions(args):
is_add: bool = args.subcommand.startswith("add")

role_map = {}
perm_map: dict[tuple[str, str], set[str]] = collections.defaultdict(set)
perm_map: dict[tuple[str, str], set[str]] = defaultdict(set)
asm = appbuilder.sm
for name in args.role:
role: Role | None = asm.find_role(name)
Expand Down
4 changes: 2 additions & 2 deletions airflow/cli/cli_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,9 @@
from __future__ import annotations

import argparse
import collections
import logging
from argparse import Action
from collections import Counter
from functools import lru_cache
from typing import TYPE_CHECKING, Iterable

Expand Down Expand Up @@ -82,7 +82,7 @@

# Check if sub-commands are defined twice, which could be an issue.
if len(ALL_COMMANDS_DICT) < len(airflow_commands):
dup = {k for k, v in collections.Counter([c.name for c in airflow_commands]).items() if v > 1}
dup = {k for k, v in Counter([c.name for c in airflow_commands]).items() if v > 1}
raise CliConflictError(
f"The following CLI {len(dup)} command(s) are defined more than once: {sorted(dup)}\n"
f"This can be due to the executor '{ExecutorLoader.get_default_executor_name()}' "
Expand Down
9 changes: 4 additions & 5 deletions airflow/dag_processing/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
"""Processes DAGs."""
from __future__ import annotations

import collections
import enum
import importlib
import inspect
Expand All @@ -30,7 +29,7 @@
import sys
import time
import zipfile
from collections import defaultdict
from collections import defaultdict, deque
from datetime import datetime, timedelta
from importlib import import_module
from pathlib import Path
Expand Down Expand Up @@ -386,7 +385,7 @@ def __init__(
super().__init__()
# known files; this will be updated every `dag_dir_list_interval` and stuff added/removed accordingly
self._file_paths: list[str] = []
self._file_path_queue: collections.deque[str] = collections.deque()
self._file_path_queue: deque[str] = deque()
self._max_runs = max_runs
# signal_conn is None for dag_processor_standalone mode.
self._direct_scheduler_conn = signal_conn
Expand Down Expand Up @@ -743,7 +742,7 @@ def _add_callback_to_queue(self, request: CallbackRequest):
# Remove file paths matching request.full_filepath from self._file_path_queue
# Since we are already going to use that filepath to run callback,
# there is no need to have same file path again in the queue
self._file_path_queue = collections.deque(
self._file_path_queue = deque(
file_path for file_path in self._file_path_queue if file_path != request.full_filepath
)
self._add_paths_to_queue([request.full_filepath], True)
Expand Down Expand Up @@ -986,7 +985,7 @@ def set_file_paths(self, new_file_paths):
self._file_paths = new_file_paths

# clean up the queues; remove anything queued which no longer in the list, including callbacks
self._file_path_queue = collections.deque(x for x in self._file_path_queue if x in new_file_paths)
self._file_path_queue = deque(x for x in self._file_path_queue if x in new_file_paths)
Stats.gauge("dag_processing.file_path_queue_size", len(self._file_path_queue))

callback_paths_to_del = [x for x in self._callback_to_execute if x not in new_file_paths]
Expand Down
1 change: 0 additions & 1 deletion airflow/models/baseoperator.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
from __future__ import annotations

import abc
import collections
import collections.abc
import contextlib
import copy
Expand Down
1 change: 0 additions & 1 deletion airflow/models/dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
# under the License.
from __future__ import annotations

import collections
import collections.abc
import copy
import functools
Expand Down
1 change: 0 additions & 1 deletion airflow/models/mappedoperator.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
# under the License.
from __future__ import annotations

import collections
import collections.abc
import contextlib
import copy
Expand Down
5 changes: 2 additions & 3 deletions airflow/providers/amazon/aws/hooks/sagemaker.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,13 @@
# under the License.
from __future__ import annotations

import collections
import os
import re
import tarfile
import tempfile
import time
import warnings
from collections import Counter
from collections import Counter, namedtuple
from datetime import datetime
from functools import partial
from typing import Any, Callable, Generator, cast
Expand Down Expand Up @@ -54,7 +53,7 @@ class LogState:

# Position is a tuple that includes the last read timestamp and the number of items that were read
# at that time. This is used to figure out which event to start with on the next read.
Position = collections.namedtuple("Position", ["timestamp", "skip"])
Position = namedtuple("Position", ["timestamp", "skip"])


def argmin(arr, f: Callable) -> int | None:
Expand Down
6 changes: 3 additions & 3 deletions airflow/ti_deps/deps/trigger_rule_dep.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@
# under the License.
from __future__ import annotations

import collections
import collections.abc
import functools
from collections import Counter
from typing import TYPE_CHECKING, Iterator, KeysView, NamedTuple

from sqlalchemy import and_, func, or_, select
Expand Down Expand Up @@ -64,8 +64,8 @@ def calculate(cls, finished_upstreams: Iterator[TaskInstance]) -> _UpstreamTISta
:param ti: the ti that we want to calculate deps for
:param finished_tis: all the finished tasks of the dag_run
"""
counter: dict[str, int] = collections.Counter()
setup_counter: dict[str, int] = collections.Counter()
counter: dict[str, int] = Counter()
setup_counter: dict[str, int] = Counter()
for ti in finished_upstreams:
curr_state = {ti.state: 1}
counter.update(curr_state)
Expand Down
4 changes: 2 additions & 2 deletions airflow/utils/entry_points.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@
# under the License.
from __future__ import annotations

import collections
import functools
import logging
from collections import defaultdict
from typing import Iterator, Tuple

try:
Expand All @@ -33,7 +33,7 @@

@functools.lru_cache(maxsize=None)
def _get_grouped_entry_points() -> dict[str, list[EPnD]]:
mapping: dict[str, list[EPnD]] = collections.defaultdict(list)
mapping: dict[str, list[EPnD]] = defaultdict(list)
for dist in metadata.distributions():
try:
for e in dist.entry_points:
Expand Down
4 changes: 2 additions & 2 deletions airflow/utils/serve_logs.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,10 @@
"""Serve logs process."""
from __future__ import annotations

import collections
import logging
import os
import socket
from collections import namedtuple

import gunicorn.app.base
from flask import Flask, abort, request, send_from_directory
Expand Down Expand Up @@ -134,7 +134,7 @@ def serve_logs_view(filename):
return flask_app


GunicornOption = collections.namedtuple("GunicornOption", ["key", "value"])
GunicornOption = namedtuple("GunicornOption", ["key", "value"])


class StandaloneGunicornApplication(gunicorn.app.base.BaseApplication):
Expand Down
14 changes: 7 additions & 7 deletions airflow/www/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
# under the License.
from __future__ import annotations

import collections
import collections.abc
import copy
import datetime
import itertools
Expand Down Expand Up @@ -477,8 +477,8 @@ def get_mapped_group_summaries():
.order_by(TaskInstance.task_id, TaskInstance.run_id)
)
# Group tis by run_id, and then map_index.
mapped_tis: Mapping[str, Mapping[int, list[TaskInstance]]] = collections.defaultdict(
lambda: collections.defaultdict(list),
mapped_tis: Mapping[str, Mapping[int, list[TaskInstance]]] = defaultdict(
lambda: defaultdict(list)
)
for ti in mapped_ti_query:
mapped_tis[ti.run_id][ti.map_index].append(ti)
Expand All @@ -499,7 +499,7 @@ def get_mapped_group_summary(run_id: str, mapped_instances: Mapping[int, list[Ta
# TODO: This assumes TI map index has a one-to-one mapping to
# its parent mapped task group, which will not be true when we
# allow nested mapping in the future.
mapped_states: MutableMapping[str, int] = collections.defaultdict(int)
mapped_states: MutableMapping[str, int] = defaultdict(int)
for mis in mapped_instances.values():
child_states = {mi.state for mi in mis}
state = next(s for s in wwwutils.priority if s in child_states)
Expand Down Expand Up @@ -1222,7 +1222,7 @@ def task_stats(self, session: Session = NEW_SESSION):
)
)
data = get_task_stats_from_query(qry)
payload: dict[str, list[dict[str, Any]]] = collections.defaultdict(list)
payload: dict[str, list[dict[str, Any]]] = defaultdict(list)
for dag_id, state in itertools.product(filter_dag_ids, State.task_states):
payload[dag_id].append({"state": state, "count": data.get(dag_id, {}).get(state, 0)})
return flask.json.jsonify(payload)
Expand Down Expand Up @@ -3260,8 +3260,8 @@ def landing_times(self, dag_id: str, session: Session = NEW_SESSION):
chart_attr=self.line_chart_attr,
)

y_points: dict[str, list[float]] = collections.defaultdict(list)
x_points: dict[str, list[tuple[int]]] = collections.defaultdict(list)
y_points: dict[str, list[float]] = defaultdict(list)
x_points: dict[str, list[tuple[int]]] = defaultdict(list)
for task in dag.tasks:
task_id = task.task_id
for ti in tis:
Expand Down
4 changes: 2 additions & 2 deletions dev/provider_packages/prepare_provider_packages.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
"""Setup.py for the Provider packages of Airflow project."""
from __future__ import annotations

import collections
import difflib
import glob
import json
Expand All @@ -33,6 +32,7 @@
import sys
import tempfile
import textwrap
from collections import namedtuple
from contextlib import contextmanager
from copy import deepcopy
from datetime import datetime, timedelta
Expand Down Expand Up @@ -605,7 +605,7 @@ def convert_cross_package_dependencies_to_table(
"""
Keeps information about historical releases.
"""
ReleaseInfo = collections.namedtuple(
ReleaseInfo = namedtuple(
"ReleaseInfo", "release_version release_version_no_leading_zeros last_commit_hash content file_name"
)

Expand Down
Loading

0 comments on commit ca3ce78

Please sign in to comment.