Skip to content

Commit

Permalink
Determine needs_expansion at time of serialization (#39604)
Browse files Browse the repository at this point in the history
This way we do not necessarily need to also pass the dag and do the evaluation on the server side.

---------

Co-authored-by: Tzu-ping Chung <uranusjr@gmail.com>
  • Loading branch information
dstandish and uranusjr authored May 20, 2024
1 parent 4ee46b9 commit 4d525aa
Show file tree
Hide file tree
Showing 10 changed files with 43 additions and 41 deletions.
3 changes: 1 addition & 2 deletions airflow/api_connexion/endpoints/task_instance_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@
from airflow.exceptions import TaskNotFound
from airflow.models import SlaMiss
from airflow.models.dagrun import DagRun as DR
from airflow.models.operator import needs_expansion
from airflow.models.taskinstance import TaskInstance as TI, clear_task_instances
from airflow.utils.airflow_flask_app import get_airflow_app
from airflow.utils.db import get_query_count
Expand Down Expand Up @@ -201,7 +200,7 @@ def get_mapped_task_instances(
except TaskNotFound:
error_message = f"Task id {task_id} not found"
raise NotFound(error_message)
if not needs_expansion(task):
if not task.get_needs_expansion():
error_message = f"Task id {task_id} is not mapped"
raise NotFound(error_message)

Expand Down
11 changes: 5 additions & 6 deletions airflow/cli/commands/task_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,6 @@
from airflow.models import DagPickle, TaskInstance
from airflow.models.dag import DAG, _run_inline_trigger
from airflow.models.dagrun import DagRun
from airflow.models.operator import needs_expansion
from airflow.models.param import ParamsDict
from airflow.models.taskinstance import TaskReturnCode
from airflow.serialization.pydantic.taskinstance import TaskInstancePydantic
Expand Down Expand Up @@ -177,7 +176,7 @@ def _get_ti_db_access(

if not exec_date_or_run_id and not create_if_necessary:
raise ValueError("Must provide `exec_date_or_run_id` if not `create_if_necessary`.")
if needs_expansion(task):
if task.get_needs_expansion():
if map_index < 0:
raise RuntimeError("No map_index passed to mapped task")
elif map_index >= 0:
Expand Down Expand Up @@ -228,10 +227,10 @@ def _get_ti(
pool=pool,
create_if_necessary=create_if_necessary,
)
# setting ti.task is necessary for AIP-44 since the task object does not serialize perfectly
# if we update the serialization logic for Operator to also serialize the dag object on it,
# then this would not be necessary;
ti.task = task

# we do refresh_from_task so that if TI has come back via RPC, we ensure that ti.task
# is the original task object and not the result of the round trip
ti.refresh_from_task(task, pool_override=pool)
return ti, dr_created


Expand Down
15 changes: 14 additions & 1 deletion airflow/models/abstractoperator.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ class AbstractOperator(Templater, DAGNode):
outlets: list
inlets: list
trigger_rule: TriggerRule

_needs_expansion: bool | None = None
_on_failure_fail_dagrun = False

HIDE_ATTRS_FROM_UI: ClassVar[frozenset[str]] = frozenset(
Expand Down Expand Up @@ -395,6 +395,19 @@ def get_closest_mapped_task_group(self) -> MappedTaskGroup | None:
"""
return next(self.iter_mapped_task_groups(), None)

def get_needs_expansion(self) -> bool:
"""
Return true if the task is MappedOperator or is in a mapped task group.
:meta private:
"""
if self._needs_expansion is None:
if self.get_closest_mapped_task_group() is not None:
self._needs_expansion = True
else:
self._needs_expansion = False
return self._needs_expansion

def unmap(self, resolve: None | dict[str, Any] | tuple[Context, Session]) -> BaseOperator:
"""Get the "normal" operator from current abstract operator.
Expand Down
1 change: 1 addition & 0 deletions airflow/models/baseoperator.py
Original file line number Diff line number Diff line change
Expand Up @@ -1680,6 +1680,7 @@ def get_serialized_fields(cls):
"map_index_template",
"start_trigger",
"next_method",
"_needs_expansion",
}
)
DagContext.pop_context_managed_dag()
Expand Down
1 change: 1 addition & 0 deletions airflow/models/mappedoperator.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,6 +283,7 @@ class MappedOperator(AbstractOperator):
_operator_name: str
start_trigger: BaseTrigger | None
next_method: str | None
_needs_expansion: bool = True

dag: DAG | None
task_group: TaskGroup | None
Expand Down
26 changes: 2 additions & 24 deletions airflow/models/operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,34 +17,12 @@
# under the License.
from __future__ import annotations

from typing import TYPE_CHECKING, Union
from typing import Union

from airflow.models.baseoperator import BaseOperator
from airflow.models.mappedoperator import MappedOperator

if TYPE_CHECKING:
from airflow.models.abstractoperator import AbstractOperator
from airflow.typing_compat import TypeGuard

Operator = Union[BaseOperator, MappedOperator]


def needs_expansion(task: AbstractOperator) -> TypeGuard[Operator]:
"""Whether a task needs expansion at runtime.
A task needs expansion if it either
* Is a mapped operator, or
* Is in a mapped task group.
This is implemented as a free function (instead of a property) so we can
make it a type guard.
"""
if isinstance(task, MappedOperator):
return True
if task.get_closest_mapped_task_group() is not None:
return True
return False


__all__ = ["Operator", "needs_expansion"]
__all__ = ["Operator"]
1 change: 1 addition & 0 deletions airflow/serialization/serialized_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -553,6 +553,7 @@ def serialize(
elif isinstance(var, MappedOperator):
return cls._encode(SerializedBaseOperator.serialize_mapped_operator(var), type_=DAT.OP)
elif isinstance(var, BaseOperator):
var._needs_expansion = var.get_needs_expansion()
return cls._encode(SerializedBaseOperator.serialize_operator(var), type_=DAT.OP)
elif isinstance(var, cls._datetime_types):
return cls._encode(var.timestamp(), type_=DAT.DATETIME)
Expand Down
5 changes: 2 additions & 3 deletions airflow/ti_deps/deps/trigger_rule_dep.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,6 @@ def _evaluate_trigger_rule(
"""
from airflow.models.abstractoperator import NotMapped
from airflow.models.expandinput import NotFullyPopulated
from airflow.models.operator import needs_expansion
from airflow.models.taskinstance import TaskInstance

@functools.lru_cache
Expand Down Expand Up @@ -260,7 +259,7 @@ def _evaluate_setup_constraint(*, relevant_setups) -> Iterator[tuple[TIDepStatus

# Optimization: Don't need to hit the database if all upstreams are
# "simple" tasks (no task or task group mapping involved).
if not any(needs_expansion(t) for t in indirect_setups.values()):
if not any(t.get_needs_expansion() for t in indirect_setups.values()):
upstream = len(indirect_setups)
else:
task_id_counts = session.execute(
Expand Down Expand Up @@ -353,7 +352,7 @@ def _evaluate_direct_relatives() -> Iterator[TIDepStatus]:

# Optimization: Don't need to hit the database if all upstreams are
# "simple" tasks (no task or task group mapping involved).
if not any(needs_expansion(t) for t in upstream_tasks.values()):
if not any(t.get_needs_expansion() for t in upstream_tasks.values()):
upstream = len(upstream_tasks)
upstream_setup = sum(1 for x in upstream_tasks.values() if x.is_setup)
else:
Expand Down
3 changes: 1 addition & 2 deletions airflow/www/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,6 @@
from airflow.models.dagrun import RUN_ID_REGEX, DagRun, DagRunType
from airflow.models.dataset import DagScheduleDatasetReference, DatasetDagRunQueue, DatasetEvent, DatasetModel
from airflow.models.errors import ParseImportError
from airflow.models.operator import needs_expansion
from airflow.models.serialized_dag import SerializedDagModel
from airflow.models.taskinstance import TaskInstance, TaskInstanceNote
from airflow.plugins_manager import PLUGINS_ATTRIBUTES_TO_DUMP
Expand Down Expand Up @@ -427,7 +426,7 @@ def set_overall_state(record):
set_overall_state(record)
yield record

if item_is_mapped := needs_expansion(item):
if item_is_mapped := item.get_needs_expansion():
instances = list(_mapped_summary(grouped_tis[item.task_id]))
else:
instances = [
Expand Down
18 changes: 15 additions & 3 deletions tests/serialization/test_dag_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,7 @@ def detect_task_dependencies(task: Operator) -> DagDependency | None: # type: i
},
"doc_md": "### Task Tutorial Documentation",
"_log_config_logger_name": "airflow.task.operators",
"_needs_expansion": False,
"weight_rule": "downstream",
"next_method": None,
"start_trigger": None,
Expand Down Expand Up @@ -224,6 +225,7 @@ def detect_task_dependencies(task: Operator) -> DagDependency | None: # type: i
"is_teardown": False,
"on_failure_fail_dagrun": False,
"_log_config_logger_name": "airflow.task.operators",
"_needs_expansion": False,
"weight_rule": "downstream",
"next_method": None,
"start_trigger": None,
Expand Down Expand Up @@ -456,13 +458,14 @@ def test_dag_serialization_to_timetable(self, timetable, serialized_timetable):
del expected["dag"]["schedule_interval"]
expected["dag"]["timetable"] = serialized_timetable

# these tasks are not mapped / in mapped task group
for task in expected["dag"]["tasks"]:
task["__var"]["_needs_expansion"] = False

actual, expected = self.prepare_ser_dags_for_comparison(
actual=serialized_dag,
expected=expected,
)
for task in actual["dag"]["tasks"]:
for k, v in task.items():
print(task["__var"]["task_id"], k, v)
assert actual == expected

@pytest.mark.db_test
Expand Down Expand Up @@ -654,6 +657,7 @@ def validate_deserialized_task(
# Checked separately
"resources",
"on_failure_fail_dagrun",
"_needs_expansion",
}
else: # Promised to be mapped by the assert above.
assert isinstance(serialized_task, MappedOperator)
Expand Down Expand Up @@ -2243,6 +2247,7 @@ def test_operator_expand_serde():
assert serialized["__var"] == {
"_is_empty": False,
"_is_mapped": True,
"_needs_expansion": True,
"_task_module": "airflow.operators.bash",
"_task_type": "BashOperator",
"start_trigger": None,
Expand Down Expand Up @@ -2278,6 +2283,7 @@ def test_operator_expand_serde():

assert op.operator_class == {
"_task_type": "BashOperator",
"_needs_expansion": True,
"start_trigger": None,
"next_method": None,
"downstream_task_ids": [],
Expand All @@ -2304,6 +2310,7 @@ def test_operator_expand_xcomarg_serde():
assert serialized["__var"] == {
"_is_empty": False,
"_is_mapped": True,
"_needs_expansion": True,
"_task_module": "tests.test_utils.mock_operators",
"_task_type": "MockOperator",
"downstream_task_ids": [],
Expand Down Expand Up @@ -2358,6 +2365,7 @@ def test_operator_expand_kwargs_literal_serde(strict):
assert serialized["__var"] == {
"_is_empty": False,
"_is_mapped": True,
"_needs_expansion": True,
"_task_module": "tests.test_utils.mock_operators",
"_task_type": "MockOperator",
"downstream_task_ids": [],
Expand Down Expand Up @@ -2412,6 +2420,7 @@ def test_operator_expand_kwargs_xcomarg_serde(strict):
assert serialized["__var"] == {
"_is_empty": False,
"_is_mapped": True,
"_needs_expansion": True,
"_task_module": "tests.test_utils.mock_operators",
"_task_type": "MockOperator",
"downstream_task_ids": [],
Expand Down Expand Up @@ -2513,6 +2522,7 @@ def x(arg1, arg2, arg3):
assert serialized["__var"] == {
"_is_empty": False,
"_is_mapped": True,
"_needs_expansion": True,
"_task_module": "airflow.decorators.python",
"_task_type": "_PythonDecoratedOperator",
"_operator_name": "@task",
Expand Down Expand Up @@ -2610,6 +2620,7 @@ def x(arg1, arg2, arg3):
assert serialized["__var"] == {
"_is_empty": False,
"_is_mapped": True,
"_needs_expansion": True,
"_task_module": "airflow.decorators.python",
"_task_type": "_PythonDecoratedOperator",
"_operator_name": "@task",
Expand Down Expand Up @@ -2765,6 +2776,7 @@ def operator_extra_links(self):
"_task_module": "tests.serialization.test_dag_serialization",
"_is_empty": False,
"_is_mapped": True,
"_needs_expansion": True,
"next_method": None,
"start_trigger": None,
}
Expand Down

0 comments on commit 4d525aa

Please sign in to comment.