Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion airflow-core/src/airflow/api_fastapi/auth/tokens.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def key_to_jwk_dict(key: AllowedKeys, kid: str | None = None):
from cryptography.hazmat.primitives.asymmetric.rsa import RSAPrivateKey, RSAPublicKey
from jwt.algorithms import OKPAlgorithm, RSAAlgorithm

if isinstance(key, RSAPrivateKey | Ed25519PrivateKey):
if isinstance(key, (RSAPrivateKey, Ed25519PrivateKey)):
key = key.public_key()

if isinstance(key, RSAPublicKey):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,7 @@ def __init__(
self.filter_option: FilterOptionEnum = filter_option

def to_orm(self, select: Select) -> Select:
if isinstance(self.value, list | str) and not self.value and self.skip_none:
if isinstance(self.value, (list, str)) and not self.value and self.skip_none:
return select
if self.value is None and self.skip_none:
return select
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def handle_bulk_create(self, action: BulkCreateAction, results: BulkActionRespon

for variable in action.entities:
if variable.key in create_keys:
should_serialize_json = isinstance(variable.value, dict | list)
should_serialize_json = isinstance(variable.value, (dict, list))
Variable.set(
key=variable.key,
value=variable.value,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ def get_child_task_map(parent_task_id: str, task_node_map: dict[str, dict[str, A


def _count_tis(node: int | MappedTaskGroup | MappedOperator, run_id: str, session: SessionDep) -> int:
if not isinstance(node, MappedTaskGroup | MappedOperator):
if not isinstance(node, (MappedTaskGroup, MappedOperator)):
return node
with contextlib.suppress(NotFullyPopulated, NotMapped):
return DBBaseOperator.get_mapped_ti_count(node, run_id=run_id, session=session)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -434,7 +434,7 @@ def _create_ti_state_update_query_and_update_state(
dag_bag: DagBagDep,
dag_id: str,
) -> tuple[Update, TaskInstanceState]:
if isinstance(ti_patch_payload, TITerminalStatePayload | TIRetryStatePayload | TISuccessStatePayload):
if isinstance(ti_patch_payload, (TITerminalStatePayload, TIRetryStatePayload, TISuccessStatePayload)):
ti = session.get(TI, ti_id_str)
updated_state = ti_patch_payload.state
query = TI.duration_expression_update(ti_patch_payload.end_date, query, session.bind)
Expand Down
2 changes: 1 addition & 1 deletion airflow-core/src/airflow/cli/commands/plugins_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@


def _get_name(class_like_object) -> str:
if isinstance(class_like_object, str | PluginsDirectorySource):
if isinstance(class_like_object, (str, PluginsDirectorySource)):
return str(class_like_object)
if inspect.isclass(class_like_object):
return class_like_object.__name__
Expand Down
2 changes: 1 addition & 1 deletion airflow-core/src/airflow/cli/simple_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def print_as_plain_table(self, data: list[dict]):
print(output)

def _normalize_data(self, value: Any, output: str) -> list | str | dict | None:
if isinstance(value, tuple | list):
if isinstance(value, (tuple, list)):
if output == "table":
return ",".join(str(self._normalize_data(x, output)) for x in value)
return [self._normalize_data(x, output) for x in value]
Expand Down
2 changes: 1 addition & 1 deletion airflow-core/src/airflow/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -476,7 +476,7 @@ def __init__(
self.kwargs = kwargs
self.timeout: timedelta | None
# Check timeout type at runtime
if isinstance(timeout, int | float):
if isinstance(timeout, (int, float)):
self.timeout = timedelta(seconds=timeout)
else:
self.timeout = timeout
Expand Down
4 changes: 2 additions & 2 deletions airflow-core/src/airflow/models/asset.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,7 @@ def __hash__(self):
def __eq__(self, other: object) -> bool:
from airflow.sdk.definitions.asset import AssetAlias

if isinstance(other, self.__class__ | AssetAlias):
if isinstance(other, (self.__class__, AssetAlias)):
return self.name == other.name
return NotImplemented

Expand Down Expand Up @@ -306,7 +306,7 @@ def __init__(self, name: str = "", uri: str = "", **kwargs):
def __eq__(self, other: object) -> bool:
from airflow.sdk.definitions.asset import Asset

if isinstance(other, self.__class__ | Asset):
if isinstance(other, (self.__class__, Asset)):
return self.name == other.name and self.uri == other.uri
return NotImplemented

Expand Down
6 changes: 3 additions & 3 deletions airflow-core/src/airflow/models/dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -463,7 +463,7 @@ def _upgrade_outdated_dag_access_control(access_control=None):
for role, perms in access_control.items():
if packaging_version.parse(FAB_VERSION) >= packaging_version.parse("1.3.0"):
updated_access_control[role] = updated_access_control.get(role, {})
if isinstance(perms, set | list):
if isinstance(perms, (set, list)):
# Support for old-style access_control where only the actions are specified
updated_access_control[role][permissions.RESOURCE_DAG] = set(perms)
else:
Expand Down Expand Up @@ -541,7 +541,7 @@ def infer_automated_data_interval(self, logical_date: datetime) -> DataInterval:
:meta private:
"""
timetable_type = type(self.timetable)
if issubclass(timetable_type, NullTimetable | OnceTimetable | AssetTriggeredTimetable):
if issubclass(timetable_type, (NullTimetable, OnceTimetable, AssetTriggeredTimetable)):
return DataInterval.exact(timezone.coerce_datetime(logical_date))
start = timezone.coerce_datetime(logical_date)
if issubclass(timetable_type, CronDataIntervalTimetable):
Expand Down Expand Up @@ -959,7 +959,7 @@ def _get_task_instances(
tis = tis.where(DagRun.logical_date <= end_date)

if state:
if isinstance(state, str | TaskInstanceState):
if isinstance(state, (str, TaskInstanceState)):
tis = tis.where(TaskInstance.state == state)
elif len(state) == 1:
tis = tis.where(TaskInstance.state == state[0])
Expand Down
4 changes: 2 additions & 2 deletions airflow-core/src/airflow/models/dagbag.py
Original file line number Diff line number Diff line change
Expand Up @@ -448,7 +448,7 @@ def parse(mod_name, filepath):

dagbag_import_timeout = settings.get_dagbag_import_timeout(filepath)

if not isinstance(dagbag_import_timeout, int | float):
if not isinstance(dagbag_import_timeout, (int, float)):
raise TypeError(
f"Value ({dagbag_import_timeout}) from get_dagbag_import_timeout must be int or float"
)
Expand Down Expand Up @@ -520,7 +520,7 @@ def _process_modules(self, filepath, mods, file_last_changed_on_disk):
from airflow.sdk import DAG as SDKDAG
from airflow.sdk.definitions._internal.contextmanager import DagContext

top_level_dags = {(o, m) for m in mods for o in m.__dict__.values() if isinstance(o, DAG | SDKDAG)}
top_level_dags = {(o, m) for m in mods for o in m.__dict__.values() if isinstance(o, (DAG, SDKDAG))}

top_level_dags.update(DagContext.autoregistered_dags)

Expand Down
2 changes: 1 addition & 1 deletion airflow-core/src/airflow/models/expandinput.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@
def _needs_run_time_resolution(v: OperatorExpandArgument) -> TypeGuard[MappedArgument | SchedulerXComArg]:
from airflow.models.xcom_arg import SchedulerXComArg

return isinstance(v, MappedArgument | SchedulerXComArg)
return isinstance(v, (MappedArgument, SchedulerXComArg))


@attrs.define
Expand Down
2 changes: 1 addition & 1 deletion airflow-core/src/airflow/models/taskmap.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ def expand_mapped_task(cls, task, run_id: str, *, session: Session) -> tuple[Seq
from airflow.sdk.definitions.mappedoperator import MappedOperator
from airflow.settings import task_instance_mutation_hook

if not isinstance(task, BaseOperator | MappedOperator):
if not isinstance(task, (BaseOperator, MappedOperator)):
raise RuntimeError(
f"cannot expand unrecognized operator type {type(task).__module__}.{type(task).__name__}"
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -744,7 +744,7 @@ def serialize(
return cls._encode(var.timestamp(), type_=DAT.DATETIME)
elif isinstance(var, datetime.timedelta):
return cls._encode(var.total_seconds(), type_=DAT.TIMEDELTA)
elif isinstance(var, Timezone | FixedTimezone):
elif isinstance(var, (Timezone, FixedTimezone)):
return cls._encode(encode_timezone(var), type_=DAT.TIMEZONE)
elif isinstance(var, relativedelta.relativedelta):
return cls._encode(encode_relativedelta(var), type_=DAT.RELATIVEDELTA)
Expand All @@ -753,7 +753,7 @@ def serialize(
var._asdict(),
type_=DAT.TASK_INSTANCE_KEY,
)
elif isinstance(var, AirflowException | TaskDeferred) and hasattr(var, "serialize"):
elif isinstance(var, (AirflowException, TaskDeferred)) and hasattr(var, "serialize"):
exc_cls_name, args, kwargs = var.serialize()
return cls._encode(
cls.serialize(
Expand All @@ -762,7 +762,7 @@ def serialize(
),
type_=DAT.AIRFLOW_EXC_SER,
)
elif isinstance(var, KeyError | AttributeError):
elif isinstance(var, (KeyError, AttributeError)):
return cls._encode(
cls.serialize(
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def deserialize(classname: str, version: int, data: dict | str) -> datetime.date
if classname == qualname(DateTime) and isinstance(data, dict):
return DateTime.fromtimestamp(float(data[TIMESTAMP]), tz=tz)

if classname == qualname(datetime.timedelta) and isinstance(data, str | float):
if classname == qualname(datetime.timedelta) and isinstance(data, (str, float)):
return datetime.timedelta(seconds=float(data))

if classname == qualname(datetime.date) and isinstance(data, str):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def serialize(o: object) -> tuple[U, str, int, bool]:
if not k8s:
return "", "", 0, False

if isinstance(o, k8s.V1Pod | k8s.V1ResourceRequirements):
if isinstance(o, (k8s.V1Pod, k8s.V1ResourceRequirements)):
from airflow.providers.cncf.kubernetes.pod_generator import PodGenerator

# We're running this in an except block, so we don't want it to fail
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def serialize(o: object) -> tuple[U, str, int, bool]:
if isinstance(o, np.bool_):
return bool(o), name, __version__, True

if isinstance(o, np.float16 | np.float32 | np.float64 | np.complex64 | np.complex128):
if isinstance(o, (np.float16, np.float32, np.float64, np.complex64, np.complex128)):
return float(o), name, __version__, True

return "", "", 0, False
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def serialize(o: object) -> tuple[U, str, int, bool]:
def deserialize(classname: str, version: int, data: object) -> Any:
from airflow.utils.timezone import parse_timezone

if not isinstance(data, str | int):
if not isinstance(data, (str, int)):
raise TypeError(f"{data} is not of type int or str but of {type(data)}")

if version > __version__:
Expand Down
2 changes: 1 addition & 1 deletion airflow-core/src/airflow/utils/dot_renderer.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ def _draw_nodes(
node: DependencyMixin, parent_graph: graphviz.Digraph, states_by_task_id: dict[str, str] | None
) -> None:
"""Draw the node and its children on the given parent_graph recursively."""
if isinstance(node, BaseOperator | MappedOperator):
if isinstance(node, (BaseOperator, MappedOperator)):
_draw_task(node, parent_graph, states_by_task_id)
else:
if not isinstance(node, TaskGroup):
Expand Down
4 changes: 2 additions & 2 deletions airflow-core/src/airflow/utils/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,7 +300,7 @@ def is_empty(x):
for k, v in val.items():
if is_empty(v):
continue
if isinstance(v, list | dict):
if isinstance(v, (list, dict)):
new_val = prune_dict(v, mode=mode)
if not is_empty(new_val):
new_dict[k] = new_val
Expand All @@ -312,7 +312,7 @@ def is_empty(x):
for v in val:
if is_empty(v):
continue
if isinstance(v, list | dict):
if isinstance(v, (list, dict)):
new_val = prune_dict(v, mode=mode)
if not is_empty(new_val):
new_list.append(new_val)
Expand Down
4 changes: 2 additions & 2 deletions airflow-core/src/airflow/utils/log/colored_log.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def __init__(self, *args, **kwargs):

@staticmethod
def _color_arg(arg: Any) -> str | float | int:
if isinstance(arg, int | float):
if isinstance(arg, (int, float)):
# In case of %d or %f formatting
return arg
return BOLD_ON + str(arg) + BOLD_OFF
Expand All @@ -69,7 +69,7 @@ def _count_number_of_arguments_in_message(record: LogRecord) -> int:
return len(matches) if matches else 0

def _color_record_args(self, record: LogRecord) -> LogRecord:
if isinstance(record.args, tuple | list):
if isinstance(record.args, (tuple, list)):
record.args = tuple(self._color_arg(arg) for arg in record.args)
elif isinstance(record.args, dict):
if self._count_number_of_arguments_in_message(record) > 1:
Expand Down
2 changes: 1 addition & 1 deletion airflow-core/src/airflow/utils/setup_teardown.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ def set_dependency(
new_task: AbstractOperator | list[AbstractOperator],
upstream=True,
):
if isinstance(new_task, list | tuple):
if isinstance(new_task, (list, tuple)):
for task in new_task:
cls._set_dependency(task, receiving_task, upstream)
else:
Expand Down
4 changes: 2 additions & 2 deletions airflow-core/src/airflow/utils/sqlalchemy.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,13 +165,13 @@ def sanitize_for_serialization(obj: V1Pod):
"""
if obj is None:
return None
if isinstance(obj, float | bool | bytes | str | int):
if isinstance(obj, (float, bool, bytes, str, int)):
return obj
if isinstance(obj, list):
return [sanitize_for_serialization(sub_obj) for sub_obj in obj]
if isinstance(obj, tuple):
return tuple(sanitize_for_serialization(sub_obj) for sub_obj in obj)
if isinstance(obj, datetime.datetime | datetime.date):
if isinstance(obj, (datetime.datetime, datetime.date)):
return obj.isoformat()

if isinstance(obj, dict):
Expand Down
2 changes: 1 addition & 1 deletion airflow-core/tests/unit/always/test_project_structure.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,7 +267,7 @@ def get_imports_from_file(filepath: str):
doc_node = ast.parse(content, filepath)
import_names: set[str] = set()
for current_node in ast.walk(doc_node):
if not isinstance(current_node, ast.Import | ast.ImportFrom):
if not isinstance(current_node, (ast.Import, ast.ImportFrom)):
continue
for alias in current_node.names:
name = alias.name
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1095,7 +1095,7 @@ def test_bulk_create_entity_serialization(
response = test_client.patch("/variables", json=actions)
assert response.status_code == 200

if isinstance(entity_value, dict | list):
if isinstance(entity_value, (dict, list)):
retrieved_value_deserialized = Variable.get(entity_key, deserialize_json=True)
assert retrieved_value_deserialized == entity_value
retrieved_value_raw_string = Variable.get(entity_key, deserialize_json=False)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -404,7 +404,7 @@ def collect_dags(dag_folder=None):
"providers/*/*/tests/system/*/*/",
]
else:
if isinstance(dag_folder, list | tuple):
if isinstance(dag_folder, (list, tuple)):
patterns = dag_folder
else:
patterns = [dag_folder]
Expand Down Expand Up @@ -723,7 +723,7 @@ def validate_deserialized_task(
from airflow.sdk.definitions.mappedoperator import MappedOperator

assert not isinstance(task, SerializedBaseOperator)
assert isinstance(task, BaseOperator | MappedOperator)
assert isinstance(task, (BaseOperator, MappedOperator))

# Every task should have a task_group property -- even if it's the DAG's root task group
assert serialized_task.task_group
Expand Down
2 changes: 1 addition & 1 deletion dev/breeze/src/airflow_breeze/utils/parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,7 @@ def bytes2human(n):
def get_printable_value(key: str, value: Any) -> str:
if key == "percent":
return f"{value} %"
if isinstance(value, int | float):
if isinstance(value, (int, float)):
return bytes2human(value)
return str(value)

Expand Down
2 changes: 1 addition & 1 deletion devel-common/src/sphinx_exts/operators_and_hooks_ref.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,7 +287,7 @@ def analyze_decorators(node, _file_path, object_type, _class_name=None):
if isinstance(child, ast.ClassDef):
analyze_decorators(child, file_path, object_type="class")
deprecations.extend(_iter_module_for_deprecations(child, file_path, class_name=child.name))
elif isinstance(child, ast.FunctionDef | ast.AsyncFunctionDef):
elif isinstance(child, (ast.FunctionDef, ast.AsyncFunctionDef)):
analyze_decorators(
child, file_path, _class_name=class_name, object_type="method" if class_name else "function"
)
Expand Down
2 changes: 1 addition & 1 deletion devel-common/src/sphinx_exts/providers_extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ def get_import_mappings(tree) -> dict[str, str]:
"""
imports = {}
for node in ast.walk(tree):
if isinstance(node, ast.Import | ast.ImportFrom):
if isinstance(node, (ast.Import, ast.ImportFrom)):
for alias in node.names:
module_prefix = f"{node.module}." if hasattr(node, "module") and node.module else ""
imports[alias.asname or alias.name] = f"{module_prefix}{alias.name}"
Expand Down
2 changes: 1 addition & 1 deletion devel-common/src/sphinx_exts/removemarktransform.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def is_pycode(node: nodes.literal_block) -> bool:
if language == "guess":
try:
lexer = guess_lexer(node.rawsource)
return isinstance(lexer, PythonLexer | Python3Lexer)
return isinstance(lexer, (PythonLexer, Python3Lexer))
except Exception:
pass

Expand Down
2 changes: 1 addition & 1 deletion devel-common/src/sphinx_exts/substitution_extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ class SubstitutionCodeBlockTransform(SphinxTransform):

def apply(self, **kwargs: Any) -> None:
def condition(node):
return isinstance(node, nodes.literal_block | nodes.literal)
return isinstance(node, (nodes.literal_block, nodes.literal))

for node in self.document.traverse(condition):
if _SUBSTITUTION_OPTION_NAME not in node:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ class ForbiddenWarningsPlugin:
def __init__(self, config: pytest.Config, forbidden_warnings: tuple[str, ...]):
# Set by a pytest_configure hook in conftest
deprecations_ignore = config.inicfg["airflow_deprecations_ignore"]
if isinstance(deprecations_ignore, str | os.PathLike):
if isinstance(deprecations_ignore, (str, os.PathLike)):
self.deprecations_ignore = [deprecations_ignore]
else:
self.deprecations_ignore = deprecations_ignore
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -306,8 +306,8 @@ def _validate_list_of_stringables(vals: Sequence[str | int | float]) -> bool:
"""
if (
vals is None
or not isinstance(vals, tuple | list)
or not all(isinstance(val, str | int | float) for val in vals)
or not isinstance(vals, (tuple, list))
or not all(isinstance(val, (str, int, float)) for val in vals)
):
raise ValueError("List of strings expected")
return True
Expand All @@ -322,7 +322,7 @@ def _validate_extra_conf(conf: dict[Any, Any]) -> bool:
if conf:
if not isinstance(conf, dict):
raise ValueError("'conf' argument must be a dict")
if not all(isinstance(v, str | int) and v != "" for v in conf.values()):
if not all(isinstance(v, (str, int)) and v != "" for v in conf.values()):
raise ValueError("'conf' values must be either strings or ints")
return True

Expand Down
Loading