Skip to content
Merged
2 changes: 1 addition & 1 deletion airflow/cli/commands/dag_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ def dag_backfill(args, dag=None):

if args.task_regex:
dag = dag.partial_subset(
task_regex=args.task_regex,
task_ids_or_regex=args.task_regex,
include_upstream=not args.ignore_dependencies)

run_conf = None
Expand Down
2 changes: 1 addition & 1 deletion airflow/cli/commands/task_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -411,7 +411,7 @@ def task_clear(args):
if args.task_regex:
for idx, dag in enumerate(dags):
dags[idx] = dag.partial_subset(
task_regex=args.task_regex,
task_ids_or_regex=args.task_regex,
include_downstream=args.downstream,
include_upstream=args.upstream)

Expand Down
11 changes: 11 additions & 0 deletions airflow/config_templates/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -1617,6 +1617,17 @@
type: string
default: ~
see_also: ":ref:`scheduler:ha:tunables`"
- name: schedule_after_task_execution
description: |
Should the Task supervisor process perform a "mini scheduler" to attempt to schedule more tasks of the
same DAG. Leaving this on will mean tasks in the same DAG execute quicker, but might starve out other
DAGs in some circumstances.

Default: True
example: ~
version_added: 2.0.0
type: boolean
default: ~
- name: statsd_on
description: |
Statsd (https://github.com/etsy/statsd) integration settings
Expand Down
7 changes: 7 additions & 0 deletions airflow/config_templates/default_airflow.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -811,6 +811,13 @@ use_row_level_locking = True
# Default: 20
# max_dagruns_per_loop_to_schedule =

# Should the Task supervisor process perform a "mini scheduler" to attempt to schedule more tasks of the
# same DAG. Leaving this on will mean tasks in the same DAG execute quicker, but might starve out other
# DAGs in some circumstances.
#
# Default: True
# schedule_after_task_execution =

# Statsd (https://github.com/etsy/statsd) integration settings
statsd_on = False
statsd_host = localhost
Expand Down
1 change: 1 addition & 0 deletions airflow/config_templates/default_test.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ sync_parallelism = 0

[scheduler]
job_heartbeat_sec = 1
schedule_after_task_execution = False
scheduler_heartbeat_sec = 5
scheduler_health_check_threshold = 30
max_threads = 2
Expand Down
33 changes: 1 addition & 32 deletions airflow/jobs/scheduler_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -1652,42 +1652,11 @@ def _schedule_dag_run(

self._send_dag_callbacks_to_processor(dag_run, callback_to_run)

# Get list of TIs that do not need to executed, these are
# tasks using DummyOperator and without on_execute_callback / on_success_callback
dummy_tis = [
ti for ti in schedulable_tis
if
(
ti.task.task_type == "DummyOperator"
and not ti.task.on_execute_callback
and not ti.task.on_success_callback
)
]

# This will do one query per dag run. We "could" build up a complex
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems to me that we should also move the comment. Now it has lost context.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This comment doesn't make sense when called on an (instance) method of DagRun, as that almost by definition only operates on a single dag run. The comment is kept here in the scheduler because that's where one might think we want to batch the queries up, but shouldn't.

# query to update all the TIs across all the execution dates and dag
# IDs in a single query, but it turns out that can be _very very slow_
# see #11147/commit ee90807ac for more details
count = session.query(TI).filter(
TI.dag_id == dag_run.dag_id,
TI.execution_date == dag_run.execution_date,
TI.task_id.in_(ti.task_id for ti in schedulable_tis if ti not in dummy_tis)
).update({TI.state: State.SCHEDULED}, synchronize_session=False)

# Tasks using DummyOperator should not be executed, mark them as success
if dummy_tis:
session.query(TI).filter(
TI.dag_id == dag_run.dag_id,
TI.execution_date == dag_run.execution_date,
TI.task_id.in_(ti.task_id for ti in dummy_tis)
).update({
TI.state: State.SUCCESS,
TI.start_date: timezone.utcnow(),
TI.end_date: timezone.utcnow(),
TI.duration: 0
}, synchronize_session=False)

return count
return dag_run.schedule_tis(schedulable_tis, session)

@provide_session
def _verify_integrity_if_dag_changed(self, dag_run: DagRun, session=None):
Expand Down
6 changes: 3 additions & 3 deletions airflow/models/baseoperator.py
Original file line number Diff line number Diff line change
Expand Up @@ -1013,7 +1013,8 @@ def get_task_instances(self, start_date: Optional[datetime] = None,

def get_flat_relative_ids(self,
upstream: bool = False,
found_descendants: Optional[Set[str]] = None) -> Set[str]:
found_descendants: Optional[Set[str]] = None,
) -> Set[str]:
"""Get a flat set of relatives' ids, either upstream or downstream."""
if not self._dag:
return set()
Expand All @@ -1026,8 +1027,7 @@ def get_flat_relative_ids(self,
if relative_id not in found_descendants:
found_descendants.add(relative_id)
relative_task = self._dag.task_dict[relative_id]
relative_task.get_flat_relative_ids(upstream,
found_descendants)
relative_task.get_flat_relative_ids(upstream, found_descendants)

return found_descendants

Expand Down
44 changes: 35 additions & 9 deletions airflow/models/dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,14 @@
if TYPE_CHECKING:
from airflow.utils.task_group import TaskGroup


# Before Py 3.7, there is no re.Pattern class
try:
from re import Pattern as PatternType # type: ignore
except ImportError:
PatternType = type(re.compile('', 0))


log = logging.getLogger(__name__)

ScheduleInterval = Union[str, timedelta, relativedelta]
Expand Down Expand Up @@ -1172,7 +1180,7 @@ def clear(

if include_parentdag and self.is_subdag and self.parent_dag is not None:
p_dag = self.parent_dag.sub_dag(
task_regex=r"^{}$".format(self.dag_id.split('.')[1]),
task_ids_or_regex=r"^{}$".format(self.dag_id.split('.')[1]),
include_upstream=False,
include_downstream=True)

Expand Down Expand Up @@ -1245,7 +1253,7 @@ def clear(
if not external_dag:
raise AirflowException("Could not find dag {}".format(tii.dag_id))
downstream = external_dag.sub_dag(
task_regex=r"^{}$".format(tii.task_id),
task_ids_or_regex=r"^{}$".format(tii.task_id),
include_upstream=False,
include_downstream=True
)
Expand Down Expand Up @@ -1394,36 +1402,54 @@ def sub_dag(self, *args, **kwargs):
return self.partial_subset(*args, **kwargs)

def partial_subset(
self, task_regex, include_downstream=False, include_upstream=True
self,
task_ids_or_regex: Union[str, PatternType, Iterable[str]],
include_downstream=False,
include_upstream=True,
include_direct_upstream=False,
):
"""
Returns a subset of the current dag as a deep copy of the current dag
based on a regex that should match one or many tasks, and includes
upstream and downstream neighbours based on the flag passed.

:param task_ids_or_regex: Either a list of task_ids, or a regex to
match against task ids (as a string, or compiled regex pattern).
:type task_ids_or_regex: [str] or str or re.Pattern
:param include_downstream: Include all downstream tasks of matched
tasks, in addition to matched tasks.
:param include_upstream: Include all upstream tasks of matched tasks,
in addition to matched tasks.
"""
# deep-copying self.task_dict and self._task_group takes a long time, and we don't want all
# the tasks anyway, so we copy the tasks manually later
task_dict = self.task_dict
task_group = self._task_group
self.task_dict = {}
self._task_group = None
self._task_group = None # type: ignore
dag = copy.deepcopy(self)
self.task_dict = task_dict
self._task_group = task_group

regex_match = [
t for t in self.tasks if re.findall(task_regex, t.task_id)]
if isinstance(task_ids_or_regex, (str, PatternType)):
matched_tasks = [
t for t in self.tasks if re.findall(task_ids_or_regex, t.task_id)]
else:
matched_tasks = [t for t in self.tasks if t.task_id in task_ids_or_regex]

also_include = []
for t in regex_match:
for t in matched_tasks:
if include_downstream:
also_include += t.get_flat_relatives(upstream=False)
if include_upstream:
also_include += t.get_flat_relatives(upstream=True)
elif include_direct_upstream:
also_include += t.upstream_list

# Compiling the unique list of tasks that made the cut
# Make sure to not recursively deepcopy the dag while copying the task
dag.task_dict = {t.task_id: copy.deepcopy(t, {id(t.dag): dag})
for t in regex_match + also_include}
dag.task_dict = {t.task_id: copy.deepcopy(t, {id(t.dag): dag}) # type: ignore
for t in matched_tasks + also_include}

def filter_task_group(group, parent_group):
"""Exclude tasks not included in the subdag from the given TaskGroup."""
Expand Down
121 changes: 101 additions & 20 deletions airflow/models/dagrun.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
# specific language governing permissions and limitations
# under the License.
from datetime import datetime
from typing import Any, List, Optional, Tuple, Union
from typing import Any, Iterable, List, NamedTuple, Optional, Tuple, Union

from sqlalchemy import (
Boolean, Column, DateTime, Index, Integer, PickleType, String, UniqueConstraint, and_, func, or_,
Expand All @@ -43,6 +43,16 @@
from airflow.utils.types import DagRunType


class TISchedulingDecision(NamedTuple):
    """Type of return for DagRun.task_instance_scheduling_decisions"""

    # Every task instance of the dag run that was examined (states in
    # ``State.task_states`` plus SHUTDOWN).
    tis: List[TI]
    # Subset of ``tis`` selected by ``_get_ready_tis`` as ready to be scheduled.
    schedulable_tis: List[TI]
    # Second return value of ``_get_ready_tis``; presumably True when some TI
    # state was changed during the dependency check -- confirm against
    # ``_get_ready_tis``.
    changed_tis: bool
    # Members of ``tis`` whose state is in ``State.unfinished``.
    unfinished_tasks: List[TI]
    # Members of ``tis`` whose state is in ``State.finished``.
    finished_tasks: List[TI]


class DagRun(Base, LoggingMixin):
"""
DagRun describes an instance of a Dag. It can be created
Expand Down Expand Up @@ -380,27 +390,21 @@ def update_state(
self.last_scheduling_decision = start_dttm

dag = self.get_dag()
ready_tis: List[TI] = []
tis = list(self.get_task_instances(session=session, state=State.task_states + (State.SHUTDOWN,)))
self.log.debug("number of tis tasks for %s: %s task(s)", self, len(tis))
for ti in tis:
ti.task = dag.get_task(ti.task_id)
info = self.task_instance_scheduling_decisions(session)

tis = info.tis
schedulable_tis = info.schedulable_tis
changed_tis = info.changed_tis
finished_tasks = info.finished_tasks
unfinished_tasks = info.unfinished_tasks

unfinished_tasks = [t for t in tis if t.state in State.unfinished]
finished_tasks = [t for t in tis if t.state in State.finished]
none_depends_on_past = all(not t.task.depends_on_past for t in unfinished_tasks)
none_task_concurrency = all(t.task.task_concurrency is None for t in unfinished_tasks)
if unfinished_tasks:
scheduleable_tasks = [ut for ut in unfinished_tasks if ut.state in SCHEDULEABLE_STATES]
self.log.debug(
"number of scheduleable tasks for %s: %s task(s)",
self, len(scheduleable_tasks))
ready_tis, changed_tis = self._get_ready_tis(scheduleable_tasks, finished_tasks, session)
self.log.debug("ready tis length for %s: %s task(s)", self, len(ready_tis))
if none_depends_on_past and none_task_concurrency:
# small speed up
are_runnable_tasks = ready_tis or self._are_premature_tis(
unfinished_tasks, finished_tasks, session) or changed_tis

if unfinished_tasks and none_depends_on_past and none_task_concurrency:
# small speed up
are_runnable_tasks = schedulable_tis or self._are_premature_tis(
unfinished_tasks, finished_tasks, session) or changed_tis

duration = (timezone.utcnow() - start_dttm)
Stats.timing("dagrun.dependency-check.{}".format(self.dag_id), duration)
Expand Down Expand Up @@ -466,7 +470,35 @@ def update_state(

session.merge(self)

return ready_tis, callback
return schedulable_tis, callback

@provide_session
def task_instance_scheduling_decisions(self, session: Session = None) -> TISchedulingDecision:
    """
    Inspect this dag run's task instances and decide which ones can be scheduled.

    Loads all TIs of this run in active task states (plus SHUTDOWN), attaches
    the matching task object from the DAG to each, partitions them into
    finished/unfinished, and asks ``_get_ready_tis`` which of the unfinished,
    scheduleable-state ones are ready to run.

    :param session: SQLAlchemy ORM session (injected by ``provide_session``
        when not supplied by the caller).
    :return: a :class:`TISchedulingDecision` bundling the examined,
        schedulable, changed, unfinished and finished task instances.
    """
    schedulable_tis: List[TI] = []
    changed_tis = False

    tis = list(self.get_task_instances(session=session, state=State.task_states + (State.SHUTDOWN,)))
    self.log.debug("number of tis tasks for %s: %s task(s)", self, len(tis))
    # Each TI needs its task object attached before dependency evaluation.
    for ti in tis:
        ti.task = self.get_dag().get_task(ti.task_id)

    unfinished_tasks = [t for t in tis if t.state in State.unfinished]
    finished_tasks = [t for t in tis if t.state in State.finished]
    if unfinished_tasks:
        scheduleable_tasks = [ut for ut in unfinished_tasks if ut.state in SCHEDULEABLE_STATES]
        self.log.debug(
            "number of scheduleable tasks for %s: %s task(s)",
            self, len(scheduleable_tasks))
        schedulable_tis, changed_tis = self._get_ready_tis(scheduleable_tasks, finished_tasks, session)

    return TISchedulingDecision(
        tis=tis,
        schedulable_tis=schedulable_tis,
        changed_tis=changed_tis,
        unfinished_tasks=unfinished_tasks,
        finished_tasks=finished_tasks,
    )

def _get_ready_tis(
self,
Expand Down Expand Up @@ -638,3 +670,52 @@ def get_latest_runs(cls, session=None):
.all()
)
return dagruns

@provide_session
def schedule_tis(self, schedulable_tis: Iterable[TI], session: Session = None) -> int:
    """
    Set the given task instances into the scheduled state.

    Each element of ``schedulable_tis`` should have its ``task`` attribute already set.

    Any DummyOperator without callbacks is instead set straight to the success state.

    All the TIs should belong to this DagRun, but this code is in the hot-path, this is not checked -- it
    is the caller's responsibility to call this function only with TIs from a single dag run.

    :param schedulable_tis: task instances to move to SCHEDULED; each must
        have ``ti.task`` populated.
    :param session: SQLAlchemy ORM session (injected by ``provide_session``
        when not supplied by the caller).
    :return: number of task instance rows updated.
    """
    # Get the set of TIs that do not need to be executed: tasks using
    # DummyOperator without on_execute_callback / on_success_callback never
    # have to run and can be marked successful directly.
    dummy_tis = {
        ti for ti in schedulable_tis
        if (
            ti.task.task_type == "DummyOperator"
            and not ti.task.on_execute_callback
            and not ti.task.on_success_callback
        )
    }

    schedulable_ti_ids = [ti.task_id for ti in schedulable_tis if ti not in dummy_tis]
    count = 0

    if schedulable_ti_ids:
        count += session.query(TI).filter(
            TI.dag_id == self.dag_id,
            TI.execution_date == self.execution_date,
            TI.task_id.in_(schedulable_ti_ids)
        ).update({TI.state: State.SCHEDULED}, synchronize_session=False)

    # Tasks using DummyOperator should not be executed, mark them as success
    if dummy_tis:
        count += session.query(TI).filter(
            TI.dag_id == self.dag_id,
            TI.execution_date == self.execution_date,
            TI.task_id.in_(ti.task_id for ti in dummy_tis)
        ).update({
            TI.state: State.SUCCESS,
            TI.start_date: timezone.utcnow(),
            TI.end_date: timezone.utcnow(),
            TI.duration: 0
        }, synchronize_session=False)

    return count
Loading