19 changes: 18 additions & 1 deletion airflow/models/baseoperator.py
@@ -264,6 +264,10 @@ class derived from this one results in the creation of a task object,
     :param do_xcom_push: if True, an XCom is pushed containing the Operator's
         result
     :type do_xcom_push: bool
+    :param current_group: the current group of the task
+    :type current_group: str
+    :param parent_group: the parent group of the task
+    :type parent_group: str
     """
     # For derived classes to define which fields will get jinjaified
     template_fields: Iterable[str] = ()
@@ -359,6 +363,9 @@ def __init__(
         do_xcom_push: bool = True,
         inlets: Optional[Any] = None,
         outlets: Optional[Any] = None,
+        current_group: Optional[str] = None,
+        parent_group: Optional[str] = None,
+        *args,
         **kwargs
     ):
         from airflow.models.dag import DagContext
@@ -456,7 +463,7 @@ def __init__(
         # subdag parameter is only set for SubDagOperator.
         # Setting it to None by default as other Operators do not have that field
         from airflow.models.dag import DAG
-        self.subdag: Optional[DAG] = None
+        self._subdag: Optional[DAG] = None

         self._log = logging.getLogger("airflow.task.operators")

@@ -473,6 +480,9 @@ def __init__(
         if outlets:
             self._outlets = outlets if isinstance(outlets, list) else [outlets, ]

+        self.current_group = current_group
+        self.parent_group = parent_group
+
     def __eq__(self, other):
         if type(self) is type(other) and self.task_id == other.task_id:
             return all(self.__dict__.get(c, None) == other.__dict__.get(c, None) for c in self._comps)
@@ -1112,6 +1122,13 @@ def dry_run(self) -> None:
             self.log.info('Rendering template for %s', field)
             self.log.info(content)

+    def _remove_direct_relative_id(self, task_id: str, upstream: bool = False) -> None:
+        """Remove a task id from the direct relative upstream/downstream task ids"""
+        if upstream:
+            self._upstream_task_ids.remove(task_id)
+        else:
+            self._downstream_task_ids.remove(task_id)
+
     def get_direct_relative_ids(self, upstream: bool = False) -> Set[str]:
         """
         Get set of the direct relative ids to the current task, upstream or
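To make the new `_remove_direct_relative_id` helper concrete, here is a minimal stand-in sketch in plain Python (not Airflow code; `TaskStub` and the task names are invented for illustration) showing the one-sided bookkeeping it performs before the graph is re-wired:

from typing import Set


class TaskStub:
    """Hypothetical stand-in mirroring BaseOperator's relative-id bookkeeping."""

    def __init__(self, task_id: str) -> None:
        self.task_id = task_id
        self._upstream_task_ids: Set[str] = set()
        self._downstream_task_ids: Set[str] = set()

    def _remove_direct_relative_id(self, task_id: str, upstream: bool = False) -> None:
        # Same logic as the new BaseOperator helper: drop the id on one side only.
        if upstream:
            self._upstream_task_ids.remove(task_id)
        else:
            self._downstream_task_ids.remove(task_id)


start = TaskStub('start')
section = TaskStub('section_1')
start._downstream_task_ids.add(section.task_id)
section._upstream_task_ids.add(start.task_id)

# Detach the SubDagOperator node from its upstream neighbour before re-wiring.
start._remove_direct_relative_id('section_1', upstream=False)
assert 'section_1' not in start._downstream_task_ids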
50 changes: 31 additions & 19 deletions airflow/models/dagbag.py
@@ -355,34 +355,46 @@ def bag_dag(self, dag, root_dag):
         Throws AirflowDagCycleException if a cycle is detected in this dag or its subdags
         """

-        test_cycle(dag)  # throws if a task cycle is found
-
         dag.resolve_template_files()
-        dag.last_loaded = timezone.utcnow()

         for task in dag.tasks:
             settings.policy(task)

-        subdags = dag.subdags
+        from airflow.operators.subdag_operator import SubDagOperator
+        from airflow.models.baseoperator import cross_downstream
+        for task_id, task in dag.task_dict.copy().items():
+            if not isinstance(task, SubDagOperator):
+                continue
+            else:
+                del root_dag.task_dict[task_id]
+
+            subdag = task.subdag
+            for subdag_task in subdag.tasks:
+                del subdag_task._dag
+                root_dag.add_task(subdag_task)
+                subdag_task.parent_group = dag.dag_id
+                subdag_task.current_group = subdag.dag_id
+
+            upstream_tasks = task.upstream_list
+            for upstream_task in upstream_tasks:
+                upstream_task._remove_direct_relative_id(task_id, upstream=False)
+            cross_downstream(from_tasks=upstream_tasks, to_tasks=subdag.roots)
+
+            downstream_tasks = task.downstream_list
+            for downstream_task in downstream_tasks:
+                downstream_task._remove_direct_relative_id(task_id, upstream=True)
+            cross_downstream(from_tasks=subdag.leaves, to_tasks=downstream_tasks)
+
+            self.bag_dag(subdag, parent_dag=dag, root_dag=root_dag)

-        try:
-            for subdag in subdags:
-                subdag.full_filepath = dag.full_filepath
-                subdag.parent_dag = dag
-                subdag.is_subdag = True
-                self.bag_dag(dag=subdag, root_dag=root_dag)
-
-            self.dags[dag.dag_id] = dag
-            self.log.debug('Loaded DAG %s', dag)
-        except AirflowDagCycleException as cycle_exception:
-            # There was an error in bagging the dag. Remove it from the list of dags
-            self.log.exception('Exception bagging dag: %s', dag.dag_id)
-            # Only necessary at the root level since DAG.subdags automatically
-            # performs DFS to search through all subdags
-            if dag == root_dag:
-                for subdag in subdags:
-                    if subdag.dag_id in self.dags:
-                        del self.dags[subdag.dag_id]
-            raise cycle_exception
+        if dag is root_dag:
+            dag.last_loaded = timezone.utcnow()
+            test_cycle(dag)
+            self.dags[dag.dag_id] = dag
+            self.log.debug('Loaded DAG %s', dag)

     def collect_dags(
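The unpacking above is essentially graph surgery: drop the SubDagOperator node, wire its upstream tasks to the subdag's roots, and wire the subdag's leaves to its downstream tasks. A toy sketch of that re-wiring on a plain dict graph (all task names invented; `cross_link` stands in for airflow's `cross_downstream`):

from itertools import product
from typing import Dict, List, Set

# Toy graph: start >> section (the SubDagOperator) >> end; the subdag holds inner_a >> inner_b.
downstream: Dict[str, Set[str]] = {
    'start': {'section'}, 'section': {'end'},
    'inner_a': {'inner_b'}, 'inner_b': set(), 'end': set(),
}


def cross_link(from_tasks: List[str], to_tasks: List[str]) -> None:
    # Full bipartite wiring, analogous to airflow's cross_downstream.
    for frm, to in product(from_tasks, to_tasks):
        downstream[frm].add(to)


# Unpack 'section': detach it, then bridge its neighbours to the subdag's tasks.
downstream['start'].discard('section')
cross_link(['start'], ['inner_a'])  # upstream tasks -> subdag roots
cross_link(['inner_b'], ['end'])    # subdag leaves -> downstream tasks
del downstream['section']

print(downstream)  # start -> inner_a -> inner_b -> end; 'section' is gone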
31 changes: 31 additions & 0 deletions airflow/models/task_group.py
@@ -0,0 +1,31 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from sqlalchemy import Column, ForeignKey, String
+
+from airflow.models.base import COLLATION_ARGS, ID_LEN, Base
+
+
+class TaskGroup(Base):
+    """
+    A task group per dag per task; grouping is rendered in the Graph/Tree view.
+    """
+    __tablename__ = "task_group"
+    dag_id = Column(String(ID_LEN, **COLLATION_ARGS), ForeignKey('dag.dag_id'), primary_key=True)
+    task_id = Column(String(ID_LEN, **COLLATION_ARGS), primary_key=True)
+    current_group = Column(String(ID_LEN))
+    parent_group = Column(String(ID_LEN))

Review thread on the TaskGroup docstring:

    Member: With Serialized DAGs storing everything already, I'm not sure if this needs its own model or not.

    @xinbinhuang (Contributor, Author), Jun 14, 2020: Hi Ash, I am not familiar with the code on Serialized DAGs. Can you point me to the code that you are referring to? I will take a look at it.
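A speculative sketch of how these rows might be read back to recover the grouping for one dag (`group_for_tasks` is invented, not code from this PR, and it assumes an initialized Airflow metadata database with this model migrated):

from airflow.models.task_group import TaskGroup
from airflow.utils.session import create_session


def group_for_tasks(dag_id: str) -> dict:
    """Map each task_id in the dag to its (current_group, parent_group) pair."""
    with create_session() as session:
        rows = session.query(TaskGroup).filter(TaskGroup.dag_id == dag_id).all()
        # Build the dict inside the session so the rows are still attached.
        return {row.task_id: (row.current_group, row.parent_group) for row in rows}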
212 changes: 34 additions & 178 deletions airflow/operators/subdag_operator.py
@@ -18,197 +18,53 @@
 """
 The module which provides a way to nest your DAGs and so your levels of complexity.
 """
-from enum import Enum
-from typing import Optional
+from typing import Callable, Optional
+from cached_property import cached_property

-from sqlalchemy.orm.session import Session
-
-from airflow.api.common.experimental.get_task_instance import get_task_instance
-from airflow.exceptions import AirflowException, TaskInstanceNotFound
-from airflow.models import DagRun
-from airflow.models.dag import DAG, DagContext
-from airflow.models.pool import Pool
-from airflow.models.taskinstance import TaskInstance
-from airflow.sensors.base_sensor_operator import BaseSensorOperator
+from airflow.models.dag import DAG
+from airflow.models.baseoperator import BaseOperator
 from airflow.utils.decorators import apply_defaults
-from airflow.utils.session import create_session, provide_session
-from airflow.utils.state import State
-from airflow.utils.types import DagRunType


-class SkippedStatePropagationOptions(Enum):
-    """
-    Available options for skipped state propagation of subdag's tasks to parent dag tasks.
-    """
-    ALL_LEAVES = 'all_leaves'
-    ANY_LEAF = 'any_leaf'
-
-
-class SubDagOperator(BaseSensorOperator):
+class SubDagOperator(BaseOperator):
     """
-    This runs a sub dag. By convention, a sub dag's dag_id
-    should be prefixed by its parent and a dot. As in `parent.child`.
-
-    Although SubDagOperator can occupy a pool/concurrency slot,
-    user can specify the mode=reschedule so that the slot will be
-    released periodically to avoid potential deadlock.
-
-    :param subdag: the DAG object to run as a subdag of the current DAG.
-    :param session: sqlalchemy session
-    :param propagate_skipped_state: by setting this argument you can define
-        whether the skipped state of leaf task(s) should be propagated to the parent dag's downstream task.
+    This creates a SubDag. A SubDag's tasks will be recursively unpacked and appended
+    to the root DAG during parsing.
+
+    The factory function should satisfy the following signature:
+
+    def dag_factory(dag_id, ...):
+        dag = DAG(
+            dag_id=dag_id,
+            ...
+        )
+
+    The first positional argument must be the dag_id passed to the DAG constructor. Internally,
+    the operator's task_id is passed as this argument to create the metadata that renders
+    grouping in the UI.
+
+    :param subdag_factory: a DAG factory function that returns a dag when called
+    :param subdag_args: a list of positional arguments that will get unpacked when
+        calling the factory function
+    :param subdag_kwargs: a dictionary of keyword arguments that will get unpacked
+        in the factory function
     """

     ui_color = '#555'
     ui_fgcolor = '#fff'

-    @provide_session
     @apply_defaults
     def __init__(self,
-                 *,
-                 subdag: DAG,
-                 session: Optional[Session] = None,
-                 propagate_skipped_state: Optional[SkippedStatePropagationOptions] = None,
-                 **kwargs) -> None:
-        super().__init__(**kwargs)
-        self.subdag = subdag
-        self.propagate_skipped_state = propagate_skipped_state
-
-        self._validate_dag(kwargs)
-        self._validate_pool(session)
-
-    def _validate_dag(self, kwargs):
-        dag = kwargs.get('dag') or DagContext.get_current_dag()
-
-        if not dag:
-            raise AirflowException('Please pass in the `dag` param or call within a DAG context manager')
-
-        if dag.dag_id + '.' + kwargs['task_id'] != self.subdag.dag_id:
-            raise AirflowException(
-                "The subdag's dag_id should have the form '{{parent_dag_id}}.{{this_task_id}}'. "
-                "Expected '{d}.{t}'; received '{rcvd}'.".format(
-                    d=dag.dag_id, t=kwargs['task_id'], rcvd=self.subdag.dag_id)
-            )
-
-    def _validate_pool(self, session):
-        if self.pool:
-            conflicts = [t for t in self.subdag.tasks if t.pool == self.pool]
-            if conflicts:
-                # only query for pool conflicts if one may exist
-                pool = (session
-                        .query(Pool)
-                        .filter(Pool.slots == 1)
-                        .filter(Pool.pool == self.pool)
-                        .first())
-                if pool and any(t.pool == self.pool for t in self.subdag.tasks):
-                    raise AirflowException(
-                        'SubDagOperator {sd} and subdag task{plural} {t} both '
-                        'use pool {p}, but the pool only has 1 slot. The '
-                        'subdag tasks will never run.'.format(
-                            sd=self.task_id,
-                            plural=len(conflicts) > 1,
-                            t=', '.join(t.task_id for t in conflicts),
-                            p=self.pool
-                        )
-                    )
-
-    def _get_dagrun(self, execution_date):
-        dag_runs = DagRun.find(
-            dag_id=self.subdag.dag_id,
-            execution_date=execution_date,
-        )
-        return dag_runs[0] if dag_runs else None
-
-    def _reset_dag_run_and_task_instances(self, dag_run, execution_date):
-        """
-        Set the DagRun state to RUNNING and set the failed TaskInstances to None state
-        for scheduler to pick up.
-
-        :param dag_run: DAG run
-        :param execution_date: Execution date
-        :return: None
-        """
-        with create_session() as session:
-            dag_run.state = State.RUNNING
-            session.merge(dag_run)
-            failed_task_instances = (session
-                                     .query(TaskInstance)
-                                     .filter(TaskInstance.dag_id == self.subdag.dag_id)
-                                     .filter(TaskInstance.execution_date == execution_date)
-                                     .filter(TaskInstance.state.in_([State.FAILED, State.UPSTREAM_FAILED])))
-
-            for task_instance in failed_task_instances:
-                task_instance.state = State.NONE
-                session.merge(task_instance)
-            session.commit()
-
-    def pre_execute(self, context):
-        execution_date = context['execution_date']
-        dag_run = self._get_dagrun(execution_date)
-
-        if dag_run is None:
-            dag_run = self.subdag.create_dagrun(
-                run_type=DagRunType.SCHEDULED,
-                execution_date=execution_date,
-                state=State.RUNNING,
-                external_trigger=True,
-            )
-            self.log.info("Created DagRun: %s", dag_run.run_id)
-        else:
-            self.log.info("Found existing DagRun: %s", dag_run.run_id)
-            if dag_run.state == State.FAILED:
-                self._reset_dag_run_and_task_instances(dag_run, execution_date)
-
-    def poke(self, context):
-        execution_date = context['execution_date']
-        dag_run = self._get_dagrun(execution_date=execution_date)
-        return dag_run.state != State.RUNNING
-
-    def post_execute(self, context, result=None):
-        execution_date = context['execution_date']
-        dag_run = self._get_dagrun(execution_date=execution_date)
-        self.log.info("Execution finished. State is %s", dag_run.state)
-
-        if dag_run.state != State.SUCCESS:
-            raise AirflowException(
-                "Expected state: SUCCESS. Actual state: {}".format(dag_run.state)
-            )
-
-        if self.propagate_skipped_state and self._check_skipped_states(context):
-            self._skip_downstream_tasks(context)
-
-    def _check_skipped_states(self, context):
-        leaves_tis = self._get_leaves_tis(context['execution_date'])
-
-        if self.propagate_skipped_state == SkippedStatePropagationOptions.ANY_LEAF:
-            return any(ti.state == State.SKIPPED for ti in leaves_tis)
-        if self.propagate_skipped_state == SkippedStatePropagationOptions.ALL_LEAVES:
-            return all(ti.state == State.SKIPPED for ti in leaves_tis)
-        raise AirflowException(
-            'Unimplemented SkippedStatePropagationOptions {} used.'.format(self.propagate_skipped_state))
-
-    def _get_leaves_tis(self, execution_date):
-        leaves_tis = []
-        for leaf in self.subdag.leaves:
-            try:
-                ti = get_task_instance(
-                    dag_id=self.subdag.dag_id,
-                    task_id=leaf.task_id,
-                    execution_date=execution_date
-                )
-                leaves_tis.append(ti)
-            except TaskInstanceNotFound:
-                continue
-        return leaves_tis
-
-    def _skip_downstream_tasks(self, context):
-        self.log.info('Skipping downstream tasks because propagate_skipped_state is set to %s '
-                      'and skipped task(s) were found.', self.propagate_skipped_state)
-
-        downstream_tasks = context['task'].downstream_list
-        self.log.debug('Downstream task_ids %s', downstream_tasks)
-
-        if downstream_tasks:
-            self.skip(context['dag_run'], context['execution_date'], downstream_tasks)
-
-        self.log.info('Done.')
+                 subdag_factory: Callable[..., DAG],
+                 subdag_args: Optional[list] = None,
+                 subdag_kwargs: Optional[dict] = None,
+                 *args, **kwargs) -> None:
+        super().__init__(*args, **kwargs)
+        self.subdag_args = subdag_args or []
+        self.subdag_kwargs = subdag_kwargs or {}
+        self.subdag_factory = subdag_factory
+
+    @cached_property
+    def subdag(self) -> DAG:
+        """The SubDag carried by the operator"""
+        self._subdag = self.subdag_factory(self.task_id, *self.subdag_args, **self.subdag_kwargs)
+        return self._subdag
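Based on the docstring above, usage of the reworked operator would look roughly like this. This is a sketch against this PR's branch, not released Airflow; the dag, task, and factory names are invented:

from airflow.models.dag import DAG
from airflow.operators.dummy_operator import DummyOperator
from airflow.operators.subdag_operator import SubDagOperator
from airflow.utils.dates import days_ago


def section_factory(dag_id, n_tasks=2):
    # The operator calls this factory with its own task_id as the dag_id.
    dag = DAG(dag_id=dag_id, start_date=days_ago(1))
    for i in range(n_tasks):
        DummyOperator(task_id='inner_{}'.format(i), dag=dag)
    return dag


with DAG('parent_dag', start_date=days_ago(1)) as parent:
    section = SubDagOperator(
        task_id='section_1',
        subdag_factory=section_factory,
        subdag_kwargs={'n_tasks': 3},
    )

Because `subdag` is a `cached_property`, the factory runs once on first access (for example when the DagBag unpacks the operator) and the same DAG object is reused afterwards.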