Skip to content

Commit

Permalink
[AIRFLOW-3129] Improve test coverage of airflow.models. (apache#3982)
Browse files Browse the repository at this point in the history
  • Loading branch information
jmcarp authored and Alice Berard committed Jan 3, 2019
1 parent 0941203 commit f3af0c1
Show file tree
Hide file tree
Showing 2 changed files with 180 additions and 26 deletions.
25 changes: 0 additions & 25 deletions airflow/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -590,21 +590,6 @@ def dagbag_report(self):
table=pprinttable(stats),
)

@provide_session
def deactivate_inactive_dags(self, session=None):
active_dag_ids = [dag.dag_id for dag in list(self.dags.values())]
for dag in session.query(
DagModel).filter(~DagModel.dag_id.in_(active_dag_ids)).all():
dag.is_active = False
session.merge(dag)
session.commit()

@provide_session
def paused_dags(self, session=None):
dag_ids = [dp.dag_id for dp in session.query(DagModel).filter(
DagModel.is_paused.__eq__(True))]
return dag_ids


class User(Base):
__tablename__ = "users"
Expand Down Expand Up @@ -4202,16 +4187,6 @@ def add_tasks(self, tasks):
for task in tasks:
self.add_task(task)

@provide_session
def db_merge(self, session=None):
BO = BaseOperator
tasks = session.query(BO).filter(BO.dag_id == self.dag_id).all()
for t in tasks:
session.delete(t)
session.commit()
session.merge(self)
session.commit()

def run(
self,
start_date=None,
Expand Down
181 changes: 180 additions & 1 deletion tests/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,17 +43,20 @@
from airflow.models import clear_task_instances
from airflow.models import XCom
from airflow.models import Connection
from airflow.models import SkipMixin
from airflow.models import KubeResourceVersion, KubeWorkerIdentifier
from airflow.jobs import LocalTaskJob
from airflow.operators.dummy_operator import DummyOperator
from airflow.operators.bash_operator import BashOperator
from airflow.operators.python_operator import PythonOperator
from airflow.operators.python_operator import ShortCircuitOperator
from airflow.operators.subdag_operator import SubDagOperator
from airflow.ti_deps.deps.trigger_rule_dep import TriggerRuleDep
from airflow.utils import timezone
from airflow.utils.weight_rule import WeightRule
from airflow.utils.state import State
from airflow.utils.trigger_rule import TriggerRule
from mock import patch, ANY
from mock import patch, Mock, ANY
from parameterized import parameterized
from tempfile import mkdtemp, NamedTemporaryFile

Expand Down Expand Up @@ -575,6 +578,38 @@ def test_cycle(self):
with self.assertRaises(AirflowDagCycleException):
dag.test_cycle()

@patch('airflow.models.timezone.utcnow')
def test_sync_to_db(self, mock_now):
dag = DAG(
'dag',
start_date=DEFAULT_DATE,
)
with dag:
DummyOperator(task_id='task', owner='owner1')
SubDagOperator(
task_id='subtask',
owner='owner2',
subdag=DAG(
'dag.subtask',
start_date=DEFAULT_DATE,
)
)
now = datetime.datetime.utcnow().replace(tzinfo=pendulum.timezone('UTC'))
mock_now.return_value = now
session = settings.Session()
dag.sync_to_db(session=session)

orm_dag = session.query(DagModel).filter(DagModel.dag_id == 'dag').one()
self.assertEqual(set(orm_dag.owners.split(', ')), {'owner1', 'owner2'})
self.assertEqual(orm_dag.last_scheduler_run, now)
self.assertTrue(orm_dag.is_active)

orm_subdag = session.query(DagModel).filter(
DagModel.dag_id == 'dag.subtask').one()
self.assertEqual(set(orm_subdag.owners.split(', ')), {'owner1', 'owner2'})
self.assertEqual(orm_subdag.last_scheduler_run, now)
self.assertTrue(orm_subdag.is_active)


class DagStatTest(unittest.TestCase):
def test_dagstats_crud(self):
Expand Down Expand Up @@ -625,6 +660,25 @@ def test_dagstats_crud(self):
for stat in res:
self.assertFalse(stat.dirty)

def test_update_exception(self):
session = Mock()
(session.query.return_value
.filter.return_value
.with_for_update.return_value
.all.side_effect) = RuntimeError('it broke')
DagStat.update(session=session)
session.rollback.assert_called()

def test_set_dirty_exception(self):
session = Mock()
session.query.return_value.filter.return_value.all.return_value = []
(session.query.return_value
.filter.return_value
.with_for_update.return_value
.all.side_effect) = RuntimeError('it broke')
DagStat.set_dirty('dag', session)
session.rollback.assert_called()


class DagRunTest(unittest.TestCase):

Expand Down Expand Up @@ -2349,6 +2403,35 @@ def test_overwrite_params_with_dag_run_conf_none(self):

self.assertEqual(False, params["override"])

@patch('airflow.models.send_email')
def test_email_alert(self, mock_send_email):
task = DummyOperator(task_id='op', email='test@test.test')
ti = TI(task=task, execution_date=datetime.datetime.now())
ti.email_alert(RuntimeError('it broke'))

self.assertTrue(mock_send_email.called)
(email, title, body), _ = mock_send_email.call_args
self.assertEqual(email, 'test@test.test')
self.assertIn(repr(ti), title)
self.assertIn('it broke', body)

def test_set_duration(self):
task = DummyOperator(task_id='op', email='test@test.test')
ti = TI(
task=task,
execution_date=datetime.datetime.now(),
)
ti.start_date = datetime.datetime(2018, 10, 1, 1)
ti.end_date = datetime.datetime(2018, 10, 1, 2)
ti.set_duration()
self.assertEqual(ti.duration, 3600)

def test_set_duration_empty_dates(self):
task = DummyOperator(task_id='op', email='test@test.test')
ti = TI(task=task, execution_date=datetime.datetime.now())
ti.set_duration()
self.assertIsNone(ti.duration)


class ClearTasksTest(unittest.TestCase):

Expand Down Expand Up @@ -2705,3 +2788,99 @@ def test_connection_from_uri_with_extras(self):
self.assertEqual(connection.port, 1234)
self.assertDictEqual(connection.extra_dejson, {'extra1': 'a value',
'extra2': '/path/'})


class TestSkipMixin(unittest.TestCase):

@patch('airflow.models.timezone.utcnow')
def test_skip(self, mock_now):
session = settings.Session()
now = datetime.datetime.utcnow().replace(tzinfo=pendulum.timezone('UTC'))
mock_now.return_value = now
dag = DAG(
'dag',
start_date=DEFAULT_DATE,
)
with dag:
tasks = [DummyOperator(task_id='task')]
dag_run = dag.create_dagrun(
run_id='manual__' + now.isoformat(),
state=State.FAILED,
)
SkipMixin().skip(
dag_run=dag_run,
execution_date=now,
tasks=tasks,
session=session)

session.query(TI).filter(
TI.dag_id == 'dag',
TI.task_id == 'task',
TI.state == State.SKIPPED,
TI.start_date == now,
TI.end_date == now,
).one()

@patch('airflow.models.timezone.utcnow')
def test_skip_none_dagrun(self, mock_now):
session = settings.Session()
now = datetime.datetime.utcnow().replace(tzinfo=pendulum.timezone('UTC'))
mock_now.return_value = now
dag = DAG(
'dag',
start_date=DEFAULT_DATE,
)
with dag:
tasks = [DummyOperator(task_id='task')]
SkipMixin().skip(
dag_run=None,
execution_date=now,
tasks=tasks,
session=session)

session.query(TI).filter(
TI.dag_id == 'dag',
TI.task_id == 'task',
TI.state == State.SKIPPED,
TI.start_date == now,
TI.end_date == now,
).one()

def test_skip_none_tasks(self):
session = Mock()
SkipMixin().skip(dag_run=None, execution_date=None, tasks=[], session=session)
self.assertFalse(session.query.called)
self.assertFalse(session.commit.called)


class TestKubeResourceVersion(unittest.TestCase):

def test_checkpoint_resource_version(self):
session = settings.Session()
KubeResourceVersion.checkpoint_resource_version('7', session)
self.assertEqual(KubeResourceVersion.get_current_resource_version(session), '7')

def test_reset_resource_version(self):
session = settings.Session()
version = KubeResourceVersion.reset_resource_version(session)
self.assertEqual(version, '0')
self.assertEqual(KubeResourceVersion.get_current_resource_version(session), '0')


class TestKubeWorkerIdentifier(unittest.TestCase):

@patch('airflow.models.uuid.uuid4')
def test_get_or_create_not_exist(self, mock_uuid):
session = settings.Session()
session.query(KubeWorkerIdentifier).update({
KubeWorkerIdentifier.worker_uuid: ''
})
mock_uuid.return_value = 'abcde'
worker_uuid = KubeWorkerIdentifier.get_or_create_current_kube_worker_uuid(session)
self.assertEqual(worker_uuid, 'abcde')

def test_get_or_create_exist(self):
session = settings.Session()
KubeWorkerIdentifier.checkpoint_kube_worker_uuid('fghij', session)
worker_uuid = KubeWorkerIdentifier.get_or_create_current_kube_worker_uuid(session)
self.assertEqual(worker_uuid, 'fghij')

0 comments on commit f3af0c1

Please sign in to comment.