From ab4d04870cfbc7b07a634be90a90792b04dcd1f7 Mon Sep 17 00:00:00 2001 From: Peter van 't Hof Date: Mon, 31 Dec 2018 12:46:06 +0100 Subject: [PATCH] [AIRFLOW-3600] Remove dagbag from trigger (#4407) * Remove dagbag from trigger call * Adding fix to rbac * empty commit * Added create_dagrun to DagModel * Adding testing to /trigger calls * Make session a class var --- airflow/models/__init__.py | 38 ++++++++++++++++++++++++++++++++++++ airflow/www/views.py | 8 ++++---- airflow/www_rbac/views.py | 6 +++--- tests/www/test_views.py | 31 +++++++++++++++++++++++++++++ tests/www_rbac/test_views.py | 27 +++++++++++++++++++++++++ 5 files changed, 103 insertions(+), 7 deletions(-) diff --git a/airflow/models/__init__.py b/airflow/models/__init__.py index 354377b57c515..d1236fa40b207 100755 --- a/airflow/models/__init__.py +++ b/airflow/models/__init__.py @@ -3000,6 +3000,44 @@ def get_default_view(self): else: return self.default_view + def get_dag(self): + return DagBag(dag_folder=self.fileloc).get_dag(self.dag_id) + + @provide_session + def create_dagrun(self, + run_id, + state, + execution_date, + start_date=None, + external_trigger=False, + conf=None, + session=None): + """ + Creates a dag run from this dag including the tasks associated with this dag. + Returns the dag run. + + :param run_id: defines the the run id for this dag run + :type run_id: str + :param execution_date: the execution date of this dag run + :type execution_date: datetime + :param state: the state of the dag run + :type state: State + :param start_date: the date this dag run should be evaluated + :type start_date: datetime + :param external_trigger: whether this dag run is externally triggered + :type external_trigger: bool + :param session: database session + :type session: Session + """ + + return self.get_dag().create_dagrun(run_id=run_id, + state=state, + execution_date=execution_date, + start_date=start_date, + external_trigger=external_trigger, + conf=conf, + session=session) + @functools.total_ordering class DAG(BaseDag, LoggingMixin): diff --git a/airflow/www/views.py b/airflow/www/views.py index 5bde8d3ee824e..7746fb6de5fc9 100644 --- a/airflow/www/views.py +++ b/airflow/www/views.py @@ -1069,11 +1069,11 @@ def delete(self): @login_required @wwwutils.action_logging @wwwutils.notify_owner - def trigger(self): + @provide_session + def trigger(self, session=None): dag_id = request.args.get('dag_id') origin = request.args.get('origin') or "/admin/" - dag = dagbag.get_dag(dag_id) - + dag = session.query(models.DagModel).filter(models.DagModel.dag_id == dag_id).first() if not dag: flash("Cannot find dag {}".format(dag_id)) return redirect(origin) @@ -1592,7 +1592,7 @@ class GraphForm(DateTimeWithNumRunsWithDagRunsForm): task_instances=json.dumps(task_instances, indent=2), tasks=json.dumps(tasks, indent=2), nodes=json.dumps(nodes, indent=2), - edges=json.dumps(edges, indent=2), ) + edges=json.dumps(edges, indent=2)) @expose('/duration') @login_required diff --git a/airflow/www_rbac/views.py b/airflow/www_rbac/views.py index 5c4476a667d4a..17efcdfe7c22d 100644 --- a/airflow/www_rbac/views.py +++ b/airflow/www_rbac/views.py @@ -798,11 +798,11 @@ def delete(self): @has_dag_access(can_dag_edit=True) @has_access @action_logging - def trigger(self): + @provide_session + def trigger(self, session=None): dag_id = request.args.get('dag_id') origin = request.args.get('origin') or "/" - dag = dagbag.get_dag(dag_id) - + dag = session.query(models.DagModel).filter(models.DagModel.dag_id == dag_id).first() if not dag: flash("Cannot find dag {}".format(dag_id)) return redirect(origin) diff --git a/tests/www/test_views.py b/tests/www/test_views.py index a0d86ac54a02a..aeaca96333e9d 100644 --- a/tests/www/test_views.py +++ b/tests/www/test_views.py @@ -30,6 +30,7 @@ from urllib.parse import quote_plus from werkzeug.test import Client +from sqlalchemy import func from airflow import models, configuration from airflow.config_templates.airflow_local_settings import DEFAULT_LOGGING_CONFIG @@ -821,5 +822,35 @@ def test_delete_dag_button_for_dag_on_scheduler_only(self): session.commit() +class TestTriggerDag(unittest.TestCase): + + def setUp(self): + conf.load_test_config() + app = application.create_app(testing=True) + app.config['WTF_CSRF_METHODS'] = [] + self.app = app.test_client() + self.session = Session() + models.DagBag().get_dag("example_bash_operator").sync_to_db() + + def test_trigger_dag_button_normal_exist(self): + resp = self.app.get('/', follow_redirects=True) + self.assertIn('/trigger?dag_id=example_bash_operator', resp.data.decode('utf-8')) + self.assertIn("return confirmDeleteDag('example_bash_operator')", resp.data.decode('utf-8')) + + def test_trigger_dag_button(self): + + test_dag_id = "example_bash_operator" + + DR = models.DagRun + self.session.query(DR).delete() + self.session.commit() + + self.app.get('/admin/airflow/trigger?dag_id={}'.format(test_dag_id)) + + run = self.session.query(DR).filter(DR.dag_id == test_dag_id).first() + self.assertIsNotNone(run) + self.assertIn("manual__", run.run_id) + + if __name__ == '__main__': unittest.main() diff --git a/tests/www_rbac/test_views.py b/tests/www_rbac/test_views.py index 500fcf3d991d8..0f0de56ceaffc 100644 --- a/tests/www_rbac/test_views.py +++ b/tests/www_rbac/test_views.py @@ -1428,5 +1428,32 @@ def test_start_date_filter(self): pass +class TestTriggerDag(TestBase): + + def setUp(self): + super(TestTriggerDag, self).setUp() + self.session = Session() + models.DagBag().get_dag("example_bash_operator").sync_to_db(session=self.session) + + def test_trigger_dag_button_normal_exist(self): + resp = self.client.get('/', follow_redirects=True) + self.assertIn('/trigger?dag_id=example_bash_operator', resp.data.decode('utf-8')) + self.assertIn("return confirmDeleteDag('example_bash_operator')", resp.data.decode('utf-8')) + + def test_trigger_dag_button(self): + + test_dag_id = "example_bash_operator" + + DR = models.DagRun + self.session.query(DR).delete() + self.session.commit() + + resp = self.client.get('trigger?dag_id={}'.format(test_dag_id)) + + run = self.session.query(DR).filter(DR.dag_id == test_dag_id).first() + self.assertIsNotNone(run) + self.assertIn("manual__", run.run_id) + + if __name__ == '__main__': unittest.main()