Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix mapped task immutability after clear #23667

Merged
merged 6 commits into from
Jun 18, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
114 changes: 92 additions & 22 deletions airflow/models/dagrun.py
Original file line number Diff line number Diff line change
Expand Up @@ -642,15 +642,9 @@ def task_instance_scheduling_decisions(self, session: Session = NEW_SESSION) ->
tis = list(self.get_task_instances(session=session, state=State.task_states))
self.log.debug("number of tis tasks for %s: %s task(s)", self, len(tis))
dag = self.get_dag()
for ti in tis:
try:
ti.task = dag.get_task(ti.task_id)
except TaskNotFound:
self.log.warning(
"Failed to get task '%s' for dag '%s'. Marking it as removed.", ti, ti.dag_id
)
ti.state = State.REMOVED
session.flush()
missing_indexes = self._find_missing_task_indexes(dag, tis, session=session)
if missing_indexes:
self.verify_integrity(missing_indexes=missing_indexes, session=session)

unfinished_tis = [t for t in tis if t.state in State.unfinished]
finished_tis = [t for t in tis if t.state in State.finished]
Expand Down Expand Up @@ -811,11 +805,17 @@ def _emit_duration_stats_for_finished_state(self):
Stats.timing(f'dagrun.duration.failed.{self.dag_id}', duration)

@provide_session
def verify_integrity(self, session: Session = NEW_SESSION):
def verify_integrity(
self,
*,
missing_indexes: Optional[Dict["MappedOperator", Sequence[int]]] = None,
session: Session = NEW_SESSION,
):
"""
Verifies the DagRun by checking for removed tasks or tasks that are not in the
database yet. It will set state to removed or add the task if required.

:param missing_indexes: A dictionary mapping each mapped task to the map indexes that are missing.
:param session: Sqlalchemy ORM Session
"""
from airflow.settings import task_instance_mutation_hook
Expand All @@ -824,9 +824,16 @@ def verify_integrity(self, session: Session = NEW_SESSION):
hook_is_noop = getattr(task_instance_mutation_hook, 'is_noop', False)

dag = self.get_dag()
task_ids = self._check_for_removed_or_restored_tasks(
dag, task_instance_mutation_hook, session=session
)
task_ids: Set[str] = set()
if missing_indexes:
tis = self.get_task_instances(session=session)
for ti in tis:
task_instance_mutation_hook(ti)
task_ids.add(ti.task_id)
else:
task_ids, missing_indexes = self._check_for_removed_or_restored_tasks(
dag, task_instance_mutation_hook, session=session
)

def task_filter(task: "Operator") -> bool:
return task.task_id not in task_ids and (
Expand All @@ -841,27 +848,29 @@ def task_filter(task: "Operator") -> bool:
task_creator = self._get_task_creator(created_counts, task_instance_mutation_hook, hook_is_noop)

# Create the missing tasks, including mapped tasks
tasks = self._create_missing_tasks(dag, task_creator, task_filter, session=session)
tasks = self._create_missing_tasks(dag, task_creator, task_filter, missing_indexes, session=session)

self._create_task_instances(dag.dag_id, tasks, created_counts, hook_is_noop, session=session)

def _check_for_removed_or_restored_tasks(
self, dag: "DAG", ti_mutation_hook, *, session: Session
) -> Set[str]:
) -> Tuple[Set[str], Dict["MappedOperator", Sequence[int]]]:
"""
Check for removed tasks/restored tasks.
Check for removed, restored, and missing tasks.

:param dag: DAG object corresponding to the dagrun
:param ti_mutation_hook: task_instance_mutation_hook function
:param session: Sqlalchemy ORM Session

:return: List of task_ids in the dagrun
:return: Tuple of the set of task_ids in the dagrun and the missing map indexes per mapped task

"""
tis = self.get_task_instances(session=session)

# check for removed or restored tasks
task_ids = set()
existing_indexes: Dict["MappedOperator", List[int]] = defaultdict(list)
expected_indexes: Dict["MappedOperator", Sequence[int]] = defaultdict(list)
for ti in tis:
ti_mutation_hook(ti)
task_ids.add(ti.task_id)
Expand Down Expand Up @@ -902,7 +911,8 @@ def _check_for_removed_or_restored_tasks(
else:
self.log.info("Restoring mapped task '%s'", ti)
Stats.incr(f"task_restored_to_dag.{dag.dag_id}", 1, 1)
ti.state = State.NONE
existing_indexes[task].append(ti.map_index)
expected_indexes[task] = range(num_mapped_tis)
else:
# What if it is _now_ dynamically mapped, but wasn't before?
total_length = task.run_time_mapped_ti_count(self.run_id, session=session)
Expand All @@ -923,8 +933,16 @@ def _check_for_removed_or_restored_tasks(
total_length,
)
ti.state = State.REMOVED
...
return task_ids
else:
self.log.info("Restoring mapped task '%s'", ti)
Stats.incr(f"task_restored_to_dag.{dag.dag_id}", 1, 1)
existing_indexes[task].append(ti.map_index)
expected_indexes[task] = range(total_length)
# Check if we have some missing indexes to create ti for
missing_indexes: Dict["MappedOperator", Sequence[int]] = defaultdict(list)
for k, v in existing_indexes.items():
missing_indexes.update({k: list(set(expected_indexes[k]).difference(v))})
return task_ids, missing_indexes

def _get_task_creator(
self, created_counts: Dict[str, int], ti_mutation_hook: Callable, hook_is_noop: bool
Expand Down Expand Up @@ -961,7 +979,13 @@ def create_ti(task: "Operator", indexes: Tuple[int, ...]) -> Generator:
return creator

def _create_missing_tasks(
self, dag: "DAG", task_creator: Callable, task_filter: Callable, *, session: Session
self,
dag: "DAG",
task_creator: Callable,
task_filter: Callable,
missing_indexes: Optional[Dict["MappedOperator", Sequence[int]]],
*,
session: Session,
) -> Iterable["Operator"]:
"""
Create missing tasks -- and expand any MappedOperator that _only_ have literals as input
Expand All @@ -972,7 +996,9 @@ def _create_missing_tasks(
:param session: the session to use
"""

def expand_mapped_literals(task: "Operator") -> Tuple["Operator", Sequence[int]]:
def expand_mapped_literals(
task: "Operator", sequence: Union[Sequence[int], None] = None
) -> Tuple["Operator", Sequence[int]]:
if not task.is_mapped:
return (task, (-1,))
task = cast("MappedOperator", task)
Expand All @@ -981,11 +1007,19 @@ def expand_mapped_literals(task: "Operator") -> Tuple["Operator", Sequence[int]]
)
if not count:
return (task, (-1,))
if sequence:
return (task, sequence)
return (task, range(count))

tasks_and_map_idxs = map(expand_mapped_literals, filter(task_filter, dag.task_dict.values()))

tasks = itertools.chain.from_iterable(itertools.starmap(task_creator, tasks_and_map_idxs))
if missing_indexes:
# If there are missing indexes, override the tasks to create
new_tasks_and_map_idxs = itertools.starmap(
expand_mapped_literals, [(k, v) for k, v in missing_indexes.items() if len(v) > 0]
)
tasks = itertools.chain.from_iterable(itertools.starmap(task_creator, new_tasks_and_map_idxs))
return tasks

def _create_task_instances(
Expand Down Expand Up @@ -1027,6 +1061,42 @@ def _create_task_instances(
# TODO[HA]: We probably need to savepoint this so we can keep the transaction alive.
session.rollback()

def _find_missing_task_indexes(self, dag, tis, *, session) -> Dict["MappedOperator", Sequence[int]]:
    """
    Find map indexes of run-time mapped tasks that do not have a task instance yet.

    For every task instance whose task is mapped and expanded at run time, the
    current mapped length is looked up again and compared against the map indexes
    that already exist among ``tis``.

    This function also marks task instances with missing tasks as REMOVED.

    :param dag: DAG object corresponding to the dagrun
    :param tis: task instances to check
    :param session: the session to use
    :return: Mapping of mapped task to the sorted map indexes that are missing.
        Tasks with nothing missing are left out, so the result is falsy when
        there is nothing to create.
    """
    existing_indexes: Dict["MappedOperator", list] = defaultdict(list)
    new_indexes: Dict["MappedOperator", Sequence[int]] = {}
    for ti in tis:
        try:
            task = ti.task = dag.get_task(ti.task_id)
        except TaskNotFound:
            self.log.error("Failed to get task '%s' for dag '%s'. Marking it as removed.", ti, ti.dag_id)

            ti.state = State.REMOVED
            session.flush()
            continue
        if not task.is_mapped:
            continue
        # Skip unexpanded tasks and also tasks that expand with literal arguments.
        if ti.map_index < 0 or task.parse_time_mapped_ti_count:
            continue
        existing_indexes[task].append(ti.map_index)
        if task not in new_indexes:
            # Drop the cached length so the lookup below is recomputed. Doing this
            # once per task, rather than once per task instance, avoids repeating
            # the same lookup for every map index of the same task.
            task.run_time_mapped_ti_count.cache_clear()
            new_length = task.run_time_mapped_ti_count(self.run_id, session=session) or 0
            new_indexes[task] = range(new_length)

    # Only record tasks that really lack indexes; an entry holding an empty list
    # would make the returned mapping truthy even though nothing is missing.
    missing_indexes: Dict["MappedOperator", Sequence[int]] = defaultdict(list)
    for mapped_task, present in existing_indexes.items():
        absent = sorted(set(new_indexes[mapped_task]).difference(present))
        if absent:
            missing_indexes[mapped_task] = absent
    return missing_indexes

@staticmethod
def get_run(session: Session, dag_id: str, execution_date: datetime) -> Optional['DagRun']:
"""
Expand Down
Loading