From 0472c2dfe3513151776a720c67d9db913643c8d7 Mon Sep 17 00:00:00 2001 From: Fokko Driesprong Date: Mon, 5 Aug 2019 16:37:19 +0200 Subject: [PATCH 01/15] Remove provide context --- .../example_dags/example_qubole_operator.py | 1 - .../contrib/utils/mlengine_operator_utils.py | 1 - airflow/example_dags/docker_copy_data.py | 1 - .../example_branch_python_dop_operator_3.py | 1 - ...example_passing_params_via_test_command.py | 9 ++--- .../example_dags/example_python_operator.py | 1 - .../example_trigger_target_dag.py | 1 - airflow/example_dags/example_xcom.py | 1 - airflow/operators/python_operator.py | 39 +++++++++---------- docs/howto/operator/python.rst | 3 +- .../operators/test_aws_athena_operator.py | 1 - .../operators/test_s3_to_sftp_operator.py | 1 - tests/contrib/operators/test_sftp_operator.py | 1 - .../operators/test_sftp_to_s3_operator.py | 1 - tests/contrib/operators/test_ssh_operator.py | 1 - tests/core.py | 1 - tests/dags/test_cli_triggered_dags.py | 1 - tests/operators/test_python_operator.py | 15 ++++--- tests/operators/test_virtualenv_operator.py | 14 ------- tests/utils/test_log_handlers.py | 2 - 20 files changed, 31 insertions(+), 65 deletions(-) diff --git a/airflow/contrib/example_dags/example_qubole_operator.py b/airflow/contrib/example_dags/example_qubole_operator.py index b07f2734e8ff..45f0f30d6a93 100644 --- a/airflow/contrib/example_dags/example_qubole_operator.py +++ b/airflow/contrib/example_dags/example_qubole_operator.py @@ -88,7 +88,6 @@ def compare_result(ds, **kwargs): t3 = PythonOperator( task_id='compare_result', - provide_context=True, python_callable=compare_result, trigger_rule="all_done", dag=dag) diff --git a/airflow/contrib/utils/mlengine_operator_utils.py b/airflow/contrib/utils/mlengine_operator_utils.py index e1682ef45ade..ed545fdb4663 100644 --- a/airflow/contrib/utils/mlengine_operator_utils.py +++ b/airflow/contrib/utils/mlengine_operator_utils.py @@ -238,7 +238,6 @@ def apply_validate_fn(*args, **kwargs): 
evaluate_validation = PythonOperator( task_id=(task_prefix + "-validation"), python_callable=apply_validate_fn, - provide_context=True, templates_dict={"prediction_path": prediction_path}, dag=dag) evaluate_validation.set_upstream(evaluate_summary) diff --git a/airflow/example_dags/docker_copy_data.py b/airflow/example_dags/docker_copy_data.py index 6aba3f2cb355..aa7600f3d8f2 100644 --- a/airflow/example_dags/docker_copy_data.py +++ b/airflow/example_dags/docker_copy_data.py @@ -70,7 +70,6 @@ # # t_is_data_available = ShortCircuitOperator( # task_id='check_if_data_available', -# provide_context=True, # python_callable=is_data_available, # dag=dag) # diff --git a/airflow/example_dags/example_branch_python_dop_operator_3.py b/airflow/example_dags/example_branch_python_dop_operator_3.py index ec60cfc01b90..7455ef7ebbd2 100644 --- a/airflow/example_dags/example_branch_python_dop_operator_3.py +++ b/airflow/example_dags/example_branch_python_dop_operator_3.py @@ -58,7 +58,6 @@ def should_run(**kwargs): cond = BranchPythonOperator( task_id='condition', - provide_context=True, python_callable=should_run, dag=dag, ) diff --git a/airflow/example_dags/example_passing_params_via_test_command.py b/airflow/example_dags/example_passing_params_via_test_command.py index 152b8cde9e63..e8fc9c963a91 100644 --- a/airflow/example_dags/example_passing_params_via_test_command.py +++ b/airflow/example_dags/example_passing_params_via_test_command.py @@ -37,17 +37,17 @@ ) -def my_py_command(**kwargs): +def my_py_command(test_mode, params): """ Print out the "foo" param passed in via `airflow tasks test example_passing_params_via_test_command run_this -tp '{"foo":"bar"}'` """ - if kwargs["test_mode"]: + if test_mode: print(" 'foo' was passed in via test={} command : kwargs[params][foo] \ - = {}".format(kwargs["test_mode"], kwargs["params"]["foo"])) + = {}".format(test_mode, params["foo"])) # Print out the value of "miff", passed in below via the Python Operator - print(" 'miff' was passed in 
via task params = {}".format(kwargs["params"]["miff"])) + print(" 'miff' was passed in via task params = {}".format(params["miff"])) return 1 @@ -58,7 +58,6 @@ def my_py_command(**kwargs): run_this = PythonOperator( task_id='run_this', - provide_context=True, python_callable=my_py_command, params={"miff": "agg"}, dag=dag, diff --git a/airflow/example_dags/example_python_operator.py b/airflow/example_dags/example_python_operator.py index 29c664f0a65e..86403ceb25b6 100644 --- a/airflow/example_dags/example_python_operator.py +++ b/airflow/example_dags/example_python_operator.py @@ -48,7 +48,6 @@ def print_context(ds, **kwargs): run_this = PythonOperator( task_id='print_the_context', - provide_context=True, python_callable=print_context, dag=dag, ) diff --git a/airflow/example_dags/example_trigger_target_dag.py b/airflow/example_dags/example_trigger_target_dag.py index 475817698115..32255103d804 100644 --- a/airflow/example_dags/example_trigger_target_dag.py +++ b/airflow/example_dags/example_trigger_target_dag.py @@ -69,7 +69,6 @@ def run_this_func(**kwargs): run_this = PythonOperator( task_id='run_this', - provide_context=True, python_callable=run_this_func, dag=dag, ) diff --git a/airflow/example_dags/example_xcom.py b/airflow/example_dags/example_xcom.py index 5b7b79aca80f..1c2bb2e53296 100644 --- a/airflow/example_dags/example_xcom.py +++ b/airflow/example_dags/example_xcom.py @@ -26,7 +26,6 @@ args = { 'owner': 'Airflow', 'start_date': airflow.utils.dates.days_ago(2), - 'provide_context': True, } dag = DAG('example_xcom', schedule_interval="@once", default_args=args) diff --git a/airflow/operators/python_operator.py b/airflow/operators/python_operator.py index 46430b215e93..ad674d12f5a1 100644 --- a/airflow/operators/python_operator.py +++ b/airflow/operators/python_operator.py @@ -25,6 +25,7 @@ import types from textwrap import dedent from typing import Optional, Iterable, Dict, Callable +from inspect import signature import dill @@ -51,12 +52,6 @@ class 
PythonOperator(BaseOperator): :param op_args: a list of positional arguments that will get unpacked when calling your callable :type op_args: list (templated) - :param provide_context: if set to true, Airflow will pass a set of - keyword arguments that can be used in your function. This set of - kwargs correspond exactly to what you can use in your jinja - templates. For this to work, you need to define `**kwargs` in your - function header. - :type provide_context: bool :param templates_dict: a dictionary where the values are templates that will get templated by the Airflow engine sometime between ``__init__`` and ``execute`` takes place and are made available @@ -79,7 +74,6 @@ def __init__( python_callable: Callable, op_args: Optional[Iterable] = None, op_kwargs: Optional[Dict] = None, - provide_context: bool = False, templates_dict: Optional[Dict] = None, templates_exts: Optional[Iterable[str]] = None, *args, @@ -91,7 +85,6 @@ def __init__( self.python_callable = python_callable self.op_args = op_args or [] self.op_kwargs = op_kwargs or {} - self.provide_context = provide_context self.templates_dict = templates_dict if templates_exts: self.template_ext = templates_exts @@ -104,10 +97,21 @@ def execute(self, context): for k, v in airflow_context_vars.items()])) os.environ.update(airflow_context_vars) - if self.provide_context: - context.update(self.op_kwargs) - context['templates_dict'] = self.templates_dict + context.update(self.op_kwargs) + context['templates_dict'] = self.templates_dict + + if {parameter for name, parameter + in signature(self.python_callable).parameters.items() + if str(parameter).startswith("**")}: + # If there is a **kwargs, **context or **_ then just pass everything. 
self.op_kwargs = context + else: + # If there is only for example, an execution_date, then pass only these in :-) + self.op_kwargs = { + name: context[name] for name, parameter + in signature(self.python_callable).parameters.items() + if name in context # If it isn't available on the context, then ignore + } return_value = self.execute_callable() self.log.info("Done. Returned value was: %s", return_value) @@ -130,6 +134,7 @@ class BranchPythonOperator(PythonOperator, SkipMixin): downstream to allow for the DAG state to fill up and the DAG run's state to be inferred. """ + def execute(self, context): branch = super().execute(context) self.skip_all_except(context['ti'], branch) @@ -147,6 +152,7 @@ class ShortCircuitOperator(PythonOperator, SkipMixin): The condition is determined by the result of `python_callable`. """ + def execute(self, context): condition = super().execute(context) self.log.info("Condition result is %s", condition) @@ -200,12 +206,6 @@ class PythonVirtualenvOperator(PythonOperator): :type op_kwargs: list :param op_kwargs: A dict of keyword arguments to pass to python_callable. :type op_kwargs: dict - :param provide_context: if set to true, Airflow will pass a set of - keyword arguments that can be used in your function. This set of - kwargs correspond exactly to what you can use in your jinja - templates. For this to work, you need to define `**kwargs` in your - function header. - :type provide_context: bool :param string_args: Strings that are present in the global var virtualenv_string_args, available to python_callable at runtime as a list[str]. Note that args are split by newline. 
@@ -219,6 +219,7 @@ class PythonVirtualenvOperator(PythonOperator): processing templated fields, for examples ``['.sql', '.hql']`` :type templates_exts: list[str] """ + @apply_defaults def __init__( self, @@ -229,7 +230,6 @@ def __init__( system_site_packages: bool = True, op_args: Iterable = None, op_kwargs: Dict = None, - provide_context: bool = False, string_args: Optional[Iterable[str]] = None, templates_dict: Optional[Dict] = None, templates_exts: Optional[Iterable[str]] = None, @@ -242,7 +242,6 @@ def __init__( op_kwargs=op_kwargs, templates_dict=templates_dict, templates_exts=templates_exts, - provide_context=provide_context, *args, **kwargs) self.requirements = requirements or [] @@ -383,7 +382,7 @@ def _generate_python_code(self): fn = self.python_callable # dont try to read pickle if we didnt pass anything if self._pass_op_args(): - load_args_line = 'with open(sys.argv[1], "rb") as file: arg_dict = {}.load(file)'\ + load_args_line = 'with open(sys.argv[1], "rb") as file: arg_dict = {}.load(file)' \ .format(pickling_library) else: load_args_line = 'arg_dict = {"args": [], "kwargs": {}}' diff --git a/docs/howto/operator/python.rst b/docs/howto/operator/python.rst index da2180138e06..1f361735fbfe 100644 --- a/docs/howto/operator/python.rst +++ b/docs/howto/operator/python.rst @@ -42,8 +42,7 @@ to the Python callable. Templating ^^^^^^^^^^ -When you set the ``provide_context`` argument to ``True``, Airflow passes in -an additional set of keyword arguments: one for each of the :doc:`Jinja +Airflow passes in a set of keyword arguments: one for each of the :doc:`Jinja template variables <../../macros>` and a ``templates_dict`` argument. 
The ``templates_dict`` argument is templated, so each value in the dictionary diff --git a/tests/contrib/operators/test_aws_athena_operator.py b/tests/contrib/operators/test_aws_athena_operator.py index b86bbab0f136..c3d98e25db97 100644 --- a/tests/contrib/operators/test_aws_athena_operator.py +++ b/tests/contrib/operators/test_aws_athena_operator.py @@ -53,7 +53,6 @@ def setUp(self): args = { 'owner': 'airflow', 'start_date': DEFAULT_DATE, - 'provide_context': True } self.dag = DAG(TEST_DAG_ID + 'test_schedule_dag_once', diff --git a/tests/contrib/operators/test_s3_to_sftp_operator.py b/tests/contrib/operators/test_s3_to_sftp_operator.py index fc78f0c8698d..a40fc4557247 100644 --- a/tests/contrib/operators/test_s3_to_sftp_operator.py +++ b/tests/contrib/operators/test_s3_to_sftp_operator.py @@ -69,7 +69,6 @@ def setUp(self): args = { 'owner': 'airflow', 'start_date': DEFAULT_DATE, - 'provide_context': True } dag = DAG(TEST_DAG_ID + 'test_schedule_dag_once', default_args=args) dag.schedule_interval = '@once' diff --git a/tests/contrib/operators/test_sftp_operator.py b/tests/contrib/operators/test_sftp_operator.py index b54c328ba5ef..30fa74101d7b 100644 --- a/tests/contrib/operators/test_sftp_operator.py +++ b/tests/contrib/operators/test_sftp_operator.py @@ -53,7 +53,6 @@ def setUp(self): args = { 'owner': 'airflow', 'start_date': DEFAULT_DATE, - 'provide_context': True } dag = DAG(TEST_DAG_ID + 'test_schedule_dag_once', default_args=args) dag.schedule_interval = '@once' diff --git a/tests/contrib/operators/test_sftp_to_s3_operator.py b/tests/contrib/operators/test_sftp_to_s3_operator.py index 02f4e84c010b..9b45e1da1723 100644 --- a/tests/contrib/operators/test_sftp_to_s3_operator.py +++ b/tests/contrib/operators/test_sftp_to_s3_operator.py @@ -68,7 +68,6 @@ def setUp(self): args = { 'owner': 'airflow', 'start_date': DEFAULT_DATE, - 'provide_context': True } dag = DAG(TEST_DAG_ID + 'test_schedule_dag_once', default_args=args) dag.schedule_interval = '@once' diff 
--git a/tests/contrib/operators/test_ssh_operator.py b/tests/contrib/operators/test_ssh_operator.py index a27dc27bc7ca..a7fe90fa95ee 100644 --- a/tests/contrib/operators/test_ssh_operator.py +++ b/tests/contrib/operators/test_ssh_operator.py @@ -51,7 +51,6 @@ def setUp(self): args = { 'owner': 'airflow', 'start_date': DEFAULT_DATE, - 'provide_context': True } dag = DAG(TEST_DAG_ID + 'test_schedule_dag_once', default_args=args) dag.schedule_interval = '@once' diff --git a/tests/core.py b/tests/core.py index c6de2e0f9679..698c996e6cbf 100644 --- a/tests/core.py +++ b/tests/core.py @@ -567,7 +567,6 @@ def test_py_op(templates_dict, ds, **kwargs): t = PythonOperator( task_id='test_py_op', - provide_context=True, python_callable=test_py_op, templates_dict={'ds': "{{ ds }}"}, dag=self.dag) diff --git a/tests/dags/test_cli_triggered_dags.py b/tests/dags/test_cli_triggered_dags.py index 9f53ca4c3ab0..f2dc7b63895d 100644 --- a/tests/dags/test_cli_triggered_dags.py +++ b/tests/dags/test_cli_triggered_dags.py @@ -51,6 +51,5 @@ def success(ti=None, *args, **kwargs): dag1_task2 = PythonOperator( task_id='test_run_dependent_task', python_callable=success, - provide_context=True, dag=dag1) dag1_task1.set_downstream(dag1_task2) diff --git a/tests/operators/test_python_operator.py b/tests/operators/test_python_operator.py index e5e8049aa134..de931fd863e2 100644 --- a/tests/operators/test_python_operator.py +++ b/tests/operators/test_python_operator.py @@ -57,8 +57,8 @@ def build_recording_function(calls_collection): Then using this custom function recording custom Call objects for further testing (replacing Mock.assert_called_with assertion method) """ - def recording_function(*args, **kwargs): - calls_collection.append(Call(*args, **kwargs)) + def recording_function(*args): + calls_collection.append(Call(*args)) return recording_function @@ -129,11 +129,10 @@ def test_python_operator_python_callable_is_callable(self): task_id='python_operator', dag=self.dag) - def 
_assertCallsEqual(self, first, second): + def _assert_calls_equal(self, first, second): self.assertIsInstance(first, Call) self.assertIsInstance(second, Call) self.assertTupleEqual(first.args, second.args) - self.assertDictEqual(first.kwargs, second.kwargs) def test_python_callable_arguments_are_templatized(self): """Test PythonOperator op_args are templatized""" @@ -148,7 +147,7 @@ def test_python_callable_arguments_are_templatized(self): task_id='python_operator', # a Mock instance cannot be used as a callable function or test fails with a # TypeError: Object of type Mock is not JSON serializable - python_callable=(build_recording_function(recorded_calls)), + python_callable=build_recording_function(recorded_calls), op_args=[ 4, date(2019, 1, 1), @@ -167,7 +166,7 @@ def test_python_callable_arguments_are_templatized(self): ds_templated = DEFAULT_DATE.date().isoformat() self.assertEqual(1, len(recorded_calls)) - self._assertCallsEqual( + self._assert_calls_equal( recorded_calls[0], Call(4, date(2019, 1, 1), @@ -183,7 +182,7 @@ def test_python_callable_keyword_arguments_are_templatized(self): task_id='python_operator', # a Mock instance cannot be used as a callable function or test fails with a # TypeError: Object of type Mock is not JSON serializable - python_callable=(build_recording_function(recorded_calls)), + python_callable=build_recording_function(recorded_calls), op_kwargs={ 'an_int': 4, 'a_date': date(2019, 1, 1), @@ -200,7 +199,7 @@ def test_python_callable_keyword_arguments_are_templatized(self): task.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) self.assertEqual(1, len(recorded_calls)) - self._assertCallsEqual( + self._assert_calls_equal( recorded_calls[0], Call(an_int=4, a_date=date(2019, 1, 1), diff --git a/tests/operators/test_virtualenv_operator.py b/tests/operators/test_virtualenv_operator.py index 52c14f88806d..95ff2142510e 100644 --- a/tests/operators/test_virtualenv_operator.py +++ b/tests/operators/test_virtualenv_operator.py @@ -202,17 
+202,3 @@ def test_context(self): def f(**kwargs): return kwargs['templates_dict']['ds'] self._run_as_operator(f, templates_dict={'ds': '{{ ds }}'}) - - def test_provide_context(self): - def fn(): - pass - task = PythonVirtualenvOperator( - python_callable=fn, - python_version=sys.version_info[0], - task_id='task', - dag=self.dag, - provide_context=True, - ) - self.assertTrue( - task.provide_context - ) diff --git a/tests/utils/test_log_handlers.py b/tests/utils/test_log_handlers.py index e06bdae90891..7126871cb580 100644 --- a/tests/utils/test_log_handlers.py +++ b/tests/utils/test_log_handlers.py @@ -71,7 +71,6 @@ def task_callable(ti, **kwargs): task_id='task_for_testing_file_log_handler', dag=dag, python_callable=task_callable, - provide_context=True ) ti = TaskInstance(task=task, execution_date=DEFAULT_DATE) @@ -123,7 +122,6 @@ def task_callable(ti, **kwargs): task_id='task_for_testing_file_log_handler', dag=dag, python_callable=task_callable, - provide_context=True ) ti = TaskInstance(task=task, execution_date=DEFAULT_DATE) ti.try_number = 2 From 66ab4f71092cfc8e02afa26d10d5b89d5730c156 Mon Sep 17 00:00:00 2001 From: Fokko Driesprong Date: Tue, 3 Sep 2019 11:32:02 +0200 Subject: [PATCH 02/15] Simplify the arguments --- tests/operators/test_virtualenv_operator.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/operators/test_virtualenv_operator.py b/tests/operators/test_virtualenv_operator.py index 95ff2142510e..97c3dcc4eb47 100644 --- a/tests/operators/test_virtualenv_operator.py +++ b/tests/operators/test_virtualenv_operator.py @@ -199,6 +199,6 @@ def f(_): self._run_as_operator(f, op_args=[datetime.datetime.utcnow()]) def test_context(self): - def f(**kwargs): - return kwargs['templates_dict']['ds'] + def f(templates_dict): + return templates_dict['ds'] self._run_as_operator(f, templates_dict={'ds': '{{ ds }}'}) From 15b8eefd8972a67e5ee9a12c7ae40ab11eb08eea Mon Sep 17 00:00:00 2001 From: Fokko Driesprong Date: Tue, 3 Sep 2019 
11:47:12 +0200 Subject: [PATCH 03/15] Feedback is een cadeautje --- airflow/operators/python_operator.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/airflow/operators/python_operator.py b/airflow/operators/python_operator.py index ad674d12f5a1..5ec578f6aa9d 100644 --- a/airflow/operators/python_operator.py +++ b/airflow/operators/python_operator.py @@ -100,16 +100,16 @@ def execute(self, context): context.update(self.op_kwargs) context['templates_dict'] = self.templates_dict - if {parameter for name, parameter + if {param for param in signature(self.python_callable).parameters.items() - if str(parameter).startswith("**")}: + if str(param).startswith("**")}: # If there is a **kwargs, **context or **_ then just pass everything. self.op_kwargs = context else: # If there is only for example, an execution_date, then pass only these in :-) self.op_kwargs = { - name: context[name] for name, parameter - in signature(self.python_callable).parameters.items() + name: context[name] + for name in signature(self.python_callable).parameters.keys() if name in context # If it isn't available on the context, then ignore } From 6c0c8bb61f4c23c5b12a8757e7fdf76eff33016c Mon Sep 17 00:00:00 2001 From: Fokko Driesprong Date: Tue, 3 Sep 2019 12:22:03 +0200 Subject: [PATCH 04/15] Add additional tests --- tests/operators/test_python_operator.py | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/tests/operators/test_python_operator.py b/tests/operators/test_python_operator.py index 31f18edf1dde..76867f1b4ab7 100644 --- a/tests/operators/test_python_operator.py +++ b/tests/operators/test_python_operator.py @@ -250,6 +250,28 @@ def test_echo_env_variables(self): ) t.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) + def test_dynamic_provide_context(self): + def fn(dag): + if dag != 1: + raise ValueError("Should be 1") + + python_operator = PythonOperator( + op_kwargs={'dag': 1}, + python_callable=fn + ) + 
python_operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) + + def test_dynamic_provide_context(self): + def fn(dag): + if dag != 1: + raise ValueError("Should be 1") + + python_operator = PythonOperator( + op_args=[1], + python_callable=fn + ) + python_operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) + class TestBranchOperator(unittest.TestCase): @classmethod From c6ba062e90ab400771e136b12915b2d8e785facb Mon Sep 17 00:00:00 2001 From: Fokko Driesprong Date: Tue, 3 Sep 2019 12:44:46 +0200 Subject: [PATCH 05/15] Cover some edge cases --- airflow/operators/python_operator.py | 17 +++++++++++------ tests/operators/test_python_operator.py | 4 ++-- 2 files changed, 13 insertions(+), 8 deletions(-) diff --git a/airflow/operators/python_operator.py b/airflow/operators/python_operator.py index 5ec578f6aa9d..3cc8f23b71a8 100644 --- a/airflow/operators/python_operator.py +++ b/airflow/operators/python_operator.py @@ -23,9 +23,10 @@ import subprocess import sys import types +from inspect import signature +from itertools import islice from textwrap import dedent from typing import Optional, Iterable, Dict, Callable -from inspect import signature import dill @@ -100,16 +101,20 @@ def execute(self, context): context.update(self.op_kwargs) context['templates_dict'] = self.templates_dict - if {param for param - in signature(self.python_callable).parameters.items() - if str(param).startswith("**")}: - # If there is a **kwargs, **context or **_ then just pass everything. + sig_full = signature(self.python_callable).parameters.items() + # Remove the first n arguments equal to len(op_args). + # The notation is a bit akward since the OrderedDict is not slice-able + # https://stackoverflow.com/questions/30975339/slicing-a-python-ordereddict + sig_without_op_args = islice(sig_full, len(self.op_args), sys.maxsize) + + if any(str(param).startswith("**") for param in sig_without_op_args): + # If there is a **kwargs, **context or **_ then just dump everything. 
self.op_kwargs = context else: # If there is only for example, an execution_date, then pass only these in :-) self.op_kwargs = { name: context[name] - for name in signature(self.python_callable).parameters.keys() + for name in sig_without_op_args if name in context # If it isn't available on the context, then ignore } diff --git a/tests/operators/test_python_operator.py b/tests/operators/test_python_operator.py index 76867f1b4ab7..852a6bdf8970 100644 --- a/tests/operators/test_python_operator.py +++ b/tests/operators/test_python_operator.py @@ -250,7 +250,7 @@ def test_echo_env_variables(self): ) t.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) - def test_dynamic_provide_context(self): + def test_conflicting_kwargs(self): def fn(dag): if dag != 1: raise ValueError("Should be 1") @@ -261,7 +261,7 @@ def fn(dag): ) python_operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) - def test_dynamic_provide_context(self): + def test_context_with_conflicting_op_args(self): def fn(dag): if dag != 1: raise ValueError("Should be 1") From d99bb9c96b56aa16e0cd26877f743ee7e952d8df Mon Sep 17 00:00:00 2001 From: Fokko Driesprong Date: Tue, 3 Sep 2019 22:30:21 +0200 Subject: [PATCH 06/15] Extend the tests --- tests/operators/test_python_operator.py | 24 ++++++++++++++++++++++-- 1 file changed, 22 insertions(+), 2 deletions(-) diff --git a/tests/operators/test_python_operator.py b/tests/operators/test_python_operator.py index 852a6bdf8970..27abed22a8b3 100644 --- a/tests/operators/test_python_operator.py +++ b/tests/operators/test_python_operator.py @@ -251,24 +251,44 @@ def test_echo_env_variables(self): t.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) def test_conflicting_kwargs(self): + self.dag.create_dagrun( + run_id='manual__' + DEFAULT_DATE.isoformat(), + execution_date=DEFAULT_DATE, + start_date=DEFAULT_DATE, + state=State.RUNNING, + external_trigger=False, + ) + def fn(dag): if dag != 1: raise ValueError("Should be 1") python_operator = PythonOperator( + 
task_id='python_operator', op_kwargs={'dag': 1}, - python_callable=fn + python_callable=fn, + dag=self.dag ) python_operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) def test_context_with_conflicting_op_args(self): + self.dag.create_dagrun( + run_id='manual__' + DEFAULT_DATE.isoformat(), + execution_date=DEFAULT_DATE, + start_date=DEFAULT_DATE, + state=State.RUNNING, + external_trigger=False, + ) + def fn(dag): if dag != 1: raise ValueError("Should be 1") python_operator = PythonOperator( + task_id='python_operator', op_args=[1], - python_callable=fn + python_callable=fn, + dag=self.dag ) python_operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) From 145f83594c1cebf7456c2082b3e57e3444de65fc Mon Sep 17 00:00:00 2001 From: Fokko Driesprong Date: Wed, 4 Sep 2019 10:41:56 +0200 Subject: [PATCH 07/15] Update the tests --- airflow/operators/python_operator.py | 22 +++++++++++++--------- tests/operators/test_python_operator.py | 21 +++++++++++++-------- 2 files changed, 26 insertions(+), 17 deletions(-) diff --git a/airflow/operators/python_operator.py b/airflow/operators/python_operator.py index 3cc8f23b71a8..6d5a92b1b339 100644 --- a/airflow/operators/python_operator.py +++ b/airflow/operators/python_operator.py @@ -100,21 +100,25 @@ def execute(self, context): context.update(self.op_kwargs) context['templates_dict'] = self.templates_dict + context_keys = context.keys() - sig_full = signature(self.python_callable).parameters.items() - # Remove the first n arguments equal to len(op_args). 
- # The notation is a bit akward since the OrderedDict is not slice-able - # https://stackoverflow.com/questions/30975339/slicing-a-python-ordereddict - sig_without_op_args = islice(sig_full, len(self.op_args), sys.maxsize) + sig = signature(self.python_callable).parameters.items() + op_args_names = islice(sig, len(self.op_args)) - if any(str(param).startswith("**") for param in sig_without_op_args): + for name in op_args_names: + # Check if it part of the context + if name in context_keys: + # Raise an exception + raise ValueError("The key {} in the op_args is part of the context, and therefore reserved".format(name)) + + if any(str(param).startswith("**") for param in sig): # If there is a **kwargs, **context or **_ then just dump everything. self.op_kwargs = context else: # If there is only for example, an execution_date, then pass only these in :-) self.op_kwargs = { name: context[name] - for name in sig_without_op_args + for name in sig if name in context # If it isn't available on the context, then ignore } @@ -268,8 +272,8 @@ def __init__( self.__class__.__name__) # check that args are passed iff python major version matches if (python_version is not None and - str(python_version)[0] != str(sys.version_info[0]) and - self._pass_op_args()): + str(python_version)[0] != str(sys.version_info[0]) and + self._pass_op_args()): raise AirflowException("Passing op_args or op_kwargs is not supported across " "different Python major versions " "for PythonVirtualenvOperator. 
" diff --git a/tests/operators/test_python_operator.py b/tests/operators/test_python_operator.py index 27abed22a8b3..8a34887098a0 100644 --- a/tests/operators/test_python_operator.py +++ b/tests/operators/test_python_operator.py @@ -24,6 +24,8 @@ from collections import namedtuple from datetime import timedelta, date +import pytest + from airflow.exceptions import AirflowException from airflow.models import TaskInstance as TI, DAG, DagRun from airflow.operators.dummy_operator import DummyOperator @@ -259,17 +261,20 @@ def test_conflicting_kwargs(self): external_trigger=False, ) + # dag is not allowed since it is a reserved keyword def fn(dag): - if dag != 1: - raise ValueError("Should be 1") + # An ValueError should be triggered since we're using dag as a + # reserved keyword + raise RuntimeError("Should not be triggered, dag: {}".format(dag)) python_operator = PythonOperator( task_id='python_operator', - op_kwargs={'dag': 1}, + op_args=[1], python_callable=fn, dag=self.dag ) - python_operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) + with pytest.raises(ValueError, match=r".* dag .*"): + python_operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) def test_context_with_conflicting_op_args(self): self.dag.create_dagrun( @@ -280,13 +285,13 @@ def test_context_with_conflicting_op_args(self): external_trigger=False, ) - def fn(dag): - if dag != 1: - raise ValueError("Should be 1") + def fn(custom, dag): + if custom != 1: + raise ValueError("Should be 1, but was {}, dag: {}".format(custom, dag)) python_operator = PythonOperator( task_id='python_operator', - op_args=[1], + op_kwargs={'custom': 1}, python_callable=fn, dag=self.dag ) From d0be33328b16b0794ef483da5938b6404c5f9242 Mon Sep 17 00:00:00 2001 From: Fokko Driesprong Date: Wed, 4 Sep 2019 11:08:30 +0200 Subject: [PATCH 08/15] Remove flake8 violations --- airflow/operators/python_operator.py | 8 +++++--- tests/contrib/hooks/test_aws_glue_catalog_hook.py | 1 + tests/core.py | 1 + 
tests/dags/test_dag_serialization.py | 4 ++-- 4 files changed, 9 insertions(+), 5 deletions(-) diff --git a/airflow/operators/python_operator.py b/airflow/operators/python_operator.py index 6d5a92b1b339..bc2b6d850ed4 100644 --- a/airflow/operators/python_operator.py +++ b/airflow/operators/python_operator.py @@ -109,7 +109,9 @@ def execute(self, context): # Check if it part of the context if name in context_keys: # Raise an exception - raise ValueError("The key {} in the op_args is part of the context, and therefore reserved".format(name)) + raise ValueError( + "The key {} in the op_args is part of the context, and therefore reserved".format(name) + ) if any(str(param).startswith("**") for param in sig): # If there is a **kwargs, **context or **_ then just dump everything. @@ -272,8 +274,8 @@ def __init__( self.__class__.__name__) # check that args are passed iff python major version matches if (python_version is not None and - str(python_version)[0] != str(sys.version_info[0]) and - self._pass_op_args()): + str(python_version)[0] != str(sys.version_info[0]) and + self._pass_op_args()): raise AirflowException("Passing op_args or op_kwargs is not supported across " "different Python major versions " "for PythonVirtualenvOperator. 
" diff --git a/tests/contrib/hooks/test_aws_glue_catalog_hook.py b/tests/contrib/hooks/test_aws_glue_catalog_hook.py index 311dbcaf3390..85b2777e88c5 100644 --- a/tests/contrib/hooks/test_aws_glue_catalog_hook.py +++ b/tests/contrib/hooks/test_aws_glue_catalog_hook.py @@ -43,6 +43,7 @@ } } + @unittest.skipIf(mock_glue is None, "Skipping test because moto.mock_glue is not available") class TestAwsGlueCatalogHook(unittest.TestCase): diff --git a/tests/core.py b/tests/core.py index 62a1271139f0..f7e6f4c5f5ef 100644 --- a/tests/core.py +++ b/tests/core.py @@ -2178,6 +2178,7 @@ def test_init_proxy_user(self): HDFSHook = None # type: Optional[hdfs_hook.HDFSHook] snakebite = None # type: None + @unittest.skipIf(HDFSHook is None, "Skipping test because HDFSHook is not installed") class TestHDFSHook(unittest.TestCase): diff --git a/tests/dags/test_dag_serialization.py b/tests/dags/test_dag_serialization.py index 9618925efa31..1934c642a423 100644 --- a/tests/dags/test_dag_serialization.py +++ b/tests/dags/test_dag_serialization.py @@ -158,7 +158,7 @@ def make_example_dags(module, dag_ids): def make_simple_dag(): """Make very simple DAG to verify serialization result.""" dag = DAG(dag_id='simple_dag') - _ = BaseOperator(task_id='simple_task', dag=dag, start_date=datetime(2019, 8, 1)) + BaseOperator(task_id='simple_task', dag=dag, start_date=datetime(2019, 8, 1)) return {'simple_dag': dag} @@ -186,7 +186,7 @@ def compute_next_execution_date(dag, execution_date): }, catchup=False ) - _ = BashOperator( + BashOperator( task_id='echo', bash_command='echo "{{ next_execution_date(dag, execution_date) }}"', dag=dag, From 3fc2e98f8f29a5056373573847148be012363a88 Mon Sep 17 00:00:00 2001 From: Fokko Driesprong Date: Wed, 4 Sep 2019 11:40:24 +0200 Subject: [PATCH 09/15] Works on my machine (using Breeze :-) --- airflow/operators/python_operator.py | 14 +++++++---- tests/operators/test_python_operator.py | 31 +++++++++++++++++++++---- 2 files changed, 35 insertions(+), 10 deletions(-) 
diff --git a/airflow/operators/python_operator.py b/airflow/operators/python_operator.py index bc2b6d850ed4..ec846556543b 100644 --- a/airflow/operators/python_operator.py +++ b/airflow/operators/python_operator.py @@ -104,26 +104,30 @@ def execute(self, context): sig = signature(self.python_callable).parameters.items() op_args_names = islice(sig, len(self.op_args)) - - for name in op_args_names: + for name, _ in op_args_names: # Check if it part of the context if name in context_keys: - # Raise an exception + # Raise an exception to let the user know that the keyword is reserved raise ValueError( "The key {} in the op_args is part of the context, and therefore reserved".format(name) ) - if any(str(param).startswith("**") for param in sig): + print(sig) + + if any(str(param).startswith("**") for _, param in sig): # If there is a **kwargs, **context or **_ then just dump everything. self.op_kwargs = context else: # If there is only for example, an execution_date, then pass only these in :-) self.op_kwargs = { name: context[name] - for name in sig + for name, _ in sig if name in context # If it isn't available on the context, then ignore } + print(self.op_kwargs) + print(sig) + return_value = self.execute_callable() self.log.info("Done. 
Returned value was: %s", return_value) return return_value diff --git a/tests/operators/test_python_operator.py b/tests/operators/test_python_operator.py index 8a34887098a0..497a00193926 100644 --- a/tests/operators/test_python_operator.py +++ b/tests/operators/test_python_operator.py @@ -24,8 +24,6 @@ from collections import namedtuple from datetime import timedelta, date -import pytest - from airflow.exceptions import AirflowException from airflow.models import TaskInstance as TI, DAG, DagRun from airflow.operators.dummy_operator import DummyOperator @@ -273,8 +271,10 @@ def fn(dag): python_callable=fn, dag=self.dag ) - with pytest.raises(ValueError, match=r".* dag .*"): + + with self.assertRaises(ValueError) as context: python_operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) + self.assertTrue('dag' in context.exception, "'dag' not found in the exception") def test_context_with_conflicting_op_args(self): self.dag.create_dagrun( @@ -286,8 +286,29 @@ def test_context_with_conflicting_op_args(self): ) def fn(custom, dag): - if custom != 1: - raise ValueError("Should be 1, but was {}, dag: {}".format(custom, dag)) + self.assertEqual(1, custom, "custom should be 1") + self.assertIsNotNone(dag, "dag should be set") + + python_operator = PythonOperator( + task_id='python_operator', + op_kwargs={'custom': 1}, + python_callable=fn, + dag=self.dag + ) + python_operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) + + def test_context_with_kwargs(self): + self.dag.create_dagrun( + run_id='manual__' + DEFAULT_DATE.isoformat(), + execution_date=DEFAULT_DATE, + start_date=DEFAULT_DATE, + state=State.RUNNING, + external_trigger=False, + ) + + def fn(**context): + # check if context is being set + self.assertGreater(len(context), 0, "Context has not been injected") python_operator = PythonOperator( task_id='python_operator', From c9995eef84afbbb476476d4411f57b62e39cfddf Mon Sep 17 00:00:00 2001 From: Fokko Driesprong Date: Wed, 4 Sep 2019 12:51:26 +0200 
Subject: [PATCH 10/15] Clean up occurrences with provide_context --- UPDATING.md | 5 ++ ...kins_job_trigger_operator.py.notexecutable | 5 +- airflow/contrib/sensors/python_sensor.py | 27 ++++------- airflow/gcp/utils/mlengine_operator_utils.py | 1 - airflow/operators/python_operator.py | 42 +++++++++-------- airflow/sensors/http_sensor.py | 46 +++++++++---------- docs/concepts.rst | 13 +++--- tests/contrib/sensors/test_file_sensor.py | 3 +- .../sensors/test_gcs_upload_session_sensor.py | 1 - tests/sensors/test_http_sensor.py | 13 ++---- 10 files changed, 72 insertions(+), 84 deletions(-) diff --git a/UPDATING.md b/UPDATING.md index daac8adf23b5..dd80e9e25266 100644 --- a/UPDATING.md +++ b/UPDATING.md @@ -37,8 +37,13 @@ assists users migrating to a new version. - [Airflow 1.7.1.2](#airflow-1712) + ## Airflow Master +### Remove provide_context + +Instead of settings `provide_context` we're automagically inferring the signature of the callable that is being passed to the PythonOperator. The only behavioural change in is that using a key that is already in the context in the function, such as `dag` or `ds` is not allowed anymore and will thrown an exception. If the `provide_context` is still explicitly passed to the function, it will just end up in the `kwargs`, which can cause no harm. + ### Change dag loading duration metric name Change DAG file loading duration metric from `dag.loading-duration.` to `dag.loading-duration.`. 
This is to diff --git a/airflow/contrib/example_dags/example_jenkins_job_trigger_operator.py.notexecutable b/airflow/contrib/example_dags/example_jenkins_job_trigger_operator.py.notexecutable index 2d8906b0b3b9..c0c3df61f5c1 100644 --- a/airflow/contrib/example_dags/example_jenkins_job_trigger_operator.py.notexecutable +++ b/airflow/contrib/example_dags/example_jenkins_job_trigger_operator.py.notexecutable @@ -7,9 +7,9 @@ # to you under the Apache License, Version 2.0 (the # "License"); you may not use this file except in compliance # with the License. You may obtain a copy of the License at -# +# # http://www.apache.org/licenses/LICENSE-2.0 -# +# # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -75,7 +75,6 @@ def grabArtifactFromJenkins(**context): artifact_grabber = PythonOperator( task_id='artifact_grabber', - provide_context=True, python_callable=grabArtifactFromJenkins, dag=dag) diff --git a/airflow/contrib/sensors/python_sensor.py b/airflow/contrib/sensors/python_sensor.py index a2dc2031a86b..ec611be3a236 100644 --- a/airflow/contrib/sensors/python_sensor.py +++ b/airflow/contrib/sensors/python_sensor.py @@ -16,9 +16,11 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from airflow.operators.python_operator import PythonOperator from airflow.sensors.base_sensor_operator import BaseSensorOperator from airflow.utils.decorators import apply_defaults +from typing import Optional, Iterable, Dict, Callable class PythonSensor(BaseSensorOperator): @@ -38,12 +40,6 @@ class PythonSensor(BaseSensorOperator): :param op_args: a list of positional arguments that will get unpacked when calling your callable :type op_args: list - :param provide_context: if set to true, Airflow will pass a set of - keyword arguments that can be used in your function. 
This set of - kwargs correspond exactly to what you can use in your jinja - templates. For this to work, you need to define `**kwargs` in your - function header. - :type provide_context: bool :param templates_dict: a dictionary where the values are templates that will get templated by the Airflow engine sometime between ``__init__`` and ``execute`` takes place and are made available @@ -56,24 +52,21 @@ class PythonSensor(BaseSensorOperator): @apply_defaults def __init__( self, - python_callable, - op_args=None, - op_kwargs=None, - provide_context=False, - templates_dict=None, + python_callable: Callable, + op_args: Optional[Iterable] = None, + op_kwargs: Optional[Dict] = None, + templates_dict: Optional[Dict] = None, *args, **kwargs): super().__init__(*args, **kwargs) self.python_callable = python_callable self.op_args = op_args or [] self.op_kwargs = op_kwargs or {} - self.provide_context = provide_context self.templates_dict = templates_dict - def poke(self, context): - if self.provide_context: - context.update(self.op_kwargs) - context['templates_dict'] = self.templates_dict - self.op_kwargs = context + def poke(self, context: Dict): + context.update(self.op_kwargs) + context['templates_dict'] = self.templates_dict + self.op_kwargs = PythonOperator.determine_op_kwargs(self.python_callable, context, len(self.op_args)) self.log.info("Poking callable: %s", str(self.python_callable)) return_value = self.python_callable(*self.op_args, **self.op_kwargs) diff --git a/airflow/gcp/utils/mlengine_operator_utils.py b/airflow/gcp/utils/mlengine_operator_utils.py index 66cdad8a171d..658a2088b542 100644 --- a/airflow/gcp/utils/mlengine_operator_utils.py +++ b/airflow/gcp/utils/mlengine_operator_utils.py @@ -240,7 +240,6 @@ def apply_validate_fn(*args, **kwargs): evaluate_validation = PythonOperator( task_id=(task_prefix + "-validation"), python_callable=apply_validate_fn, - provide_context=True, templates_dict={"prediction_path": prediction_path}, dag=dag) 
evaluate_validation.set_upstream(evaluate_summary) diff --git a/airflow/operators/python_operator.py b/airflow/operators/python_operator.py index ec846556543b..c760dbec20d3 100644 --- a/airflow/operators/python_operator.py +++ b/airflow/operators/python_operator.py @@ -26,7 +26,7 @@ from inspect import signature from itertools import islice from textwrap import dedent -from typing import Optional, Iterable, Dict, Callable +from typing import Optional, Iterable, Dict, Callable, Tuple import dill @@ -90,20 +90,13 @@ def __init__( if templates_exts: self.template_ext = templates_exts - def execute(self, context): - # Export context to make it available for callables to use. - airflow_context_vars = context_to_airflow_vars(context, in_env_var_format=True) - self.log.info("Exporting the following env vars:\n%s", - '\n'.join(["{}={}".format(k, v) - for k, v in airflow_context_vars.items()])) - os.environ.update(airflow_context_vars) - - context.update(self.op_kwargs) - context['templates_dict'] = self.templates_dict + @staticmethod + def determine_op_kwargs(python_callable: Callable, + context: Dict, + num_op_args: int = 0) -> Dict: context_keys = context.keys() - - sig = signature(self.python_callable).parameters.items() - op_args_names = islice(sig, len(self.op_args)) + sig = signature(python_callable).parameters.items() + op_args_names = islice(sig, num_op_args) for name, _ in op_args_names: # Check if it part of the context if name in context_keys: @@ -112,21 +105,30 @@ def execute(self, context): "The key {} in the op_args is part of the context, and therefore reserved".format(name) ) - print(sig) - if any(str(param).startswith("**") for _, param in sig): # If there is a **kwargs, **context or **_ then just dump everything. 
- self.op_kwargs = context + op_kwargs = context else: # If there is only for example, an execution_date, then pass only these in :-) - self.op_kwargs = { + op_kwargs = { name: context[name] for name, _ in sig if name in context # If it isn't available on the context, then ignore } + return op_kwargs + + def execute(self, context: Dict): + # Export context to make it available for callables to use. + airflow_context_vars = context_to_airflow_vars(context, in_env_var_format=True) + self.log.info("Exporting the following env vars:\n%s", + '\n'.join(["{}={}".format(k, v) + for k, v in airflow_context_vars.items()])) + os.environ.update(airflow_context_vars) + + context.update(self.op_kwargs) + context['templates_dict'] = self.templates_dict - print(self.op_kwargs) - print(sig) + self.op_kwargs = PythonOperator.determine_op_kwargs(self.python_callable, context, len(self.op_args)) return_value = self.execute_callable() self.log.info("Done. Returned value was: %s", return_value) diff --git a/airflow/sensors/http_sensor.py b/airflow/sensors/http_sensor.py index 6e5d69946232..3102d33f4c04 100644 --- a/airflow/sensors/http_sensor.py +++ b/airflow/sensors/http_sensor.py @@ -16,6 +16,9 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from typing import Dict, Callable + +from airflow.operators.python_operator import PythonOperator from airflow.exceptions import AirflowException from airflow.hooks.http_hook import HttpHook @@ -31,13 +34,17 @@ class HttpSensor(BaseSensorOperator): HTTP Error codes other than 404 (like 403) or Connection Refused Error would fail the sensor itself directly (no more poking). - The response check can access the template context by passing ``provide_context=True`` to the operator:: + The response check can access the template context to the operator: - def response_check(response, **context): - # Can look at context['ti'] etc. 
+        def response_check(response, task_instance):
+            # The task_instance is injected, so you can pull data from xcom
+            # Other context variables such as dag, ds, execution_date are also available.
+            xcom_data = task_instance.xcom_pull(task_ids='pushing_task')
+            # In practice you would do something more sensible with this data.
+            print(xcom_data)
             return True

-        HttpSensor(task_id='my_http_sensor', ..., provide_context=True, response_check=response_check)
+        HttpSensor(task_id='my_http_sensor', ..., response_check=response_check)

    :param http_conn_id: The connection to run the sensor against
@@ -50,12 +57,6 @@ def response_check(response, **context):
    :type request_params: a dictionary of string key/value pairs
    :param headers: The HTTP headers to be added to the GET request
    :type headers: a dictionary of string key/value pairs
-    :param provide_context: if set to true, Airflow will pass a set of
-        keyword arguments that can be used in your function. This set of
-        kwargs correspond exactly to what you can use in your jinja
-        templates. For this to work, you need to define context in your
-        function header.
-    :type provide_context: bool
    :param response_check: A check against the 'requests' response object.
        Returns True for 'pass' and False otherwise.
    :type response_check: A lambda or defined function.
@@ -69,14 +70,13 @@ def response_check(response, **context): @apply_defaults def __init__(self, - endpoint, - http_conn_id='http_default', - method='GET', - request_params=None, - headers=None, - response_check=None, - provide_context=False, - extra_options=None, *args, **kwargs): + endpoint: str, + http_conn_id: str = 'http_default', + method: str = 'GET', + request_params: Dict = None, + headers: Dict = None, + response_check: Callable = None, + extra_options: Dict = None, *args, **kwargs): super().__init__(*args, **kwargs) self.endpoint = endpoint self.http_conn_id = http_conn_id @@ -84,13 +84,12 @@ def __init__(self, self.headers = headers or {} self.extra_options = extra_options or {} self.response_check = response_check - self.provide_context = provide_context self.hook = HttpHook( method=method, http_conn_id=http_conn_id) - def poke(self, context): + def poke(self, context: Dict): self.log.info('Poking: %s', self.endpoint) try: response = self.hook.run(self.endpoint, @@ -98,10 +97,9 @@ def poke(self, context): headers=self.headers, extra_options=self.extra_options) if self.response_check: - if self.provide_context: - return self.response_check(response, **context) - else: - return self.response_check(response) + op_kwargs = PythonOperator.determine_op_kwargs(self.response_check, context) + return self.response_check(response, **op_kwargs) + except AirflowException as ae: if str(ae).startswith("404"): return False diff --git a/docs/concepts.rst b/docs/concepts.rst index d258b4da3815..96930d185a8d 100644 --- a/docs/concepts.rst +++ b/docs/concepts.rst @@ -548,9 +548,12 @@ passed, then a corresponding list of XCom values is returned. 
def push_function(): return value - # inside another PythonOperator where provide_context=True - def pull_function(**context): - value = context['task_instance'].xcom_pull(task_ids='pushing_task') + # inside another PythonOperator + def pull_function(task_instance): + value = task_instance.xcom_pull(task_ids='pushing_task') + +When specifying arguments that are part of the context, they will be +automatically passed to the function. It is also possible to pull XCom directly in a template, here's an example of what this may look like: @@ -632,8 +635,7 @@ For example: .. code:: python - def branch_func(**kwargs): - ti = kwargs['ti'] + def branch_func(ti): xcom_value = int(ti.xcom_pull(task_ids='start_task')) if xcom_value >= 5: return 'continue_task' @@ -648,7 +650,6 @@ For example: branch_op = BranchPythonOperator( task_id='branch_task', - provide_context=True, python_callable=branch_func, dag=dag) diff --git a/tests/contrib/sensors/test_file_sensor.py b/tests/contrib/sensors/test_file_sensor.py index 34720f5cbccc..8d520ce2572c 100644 --- a/tests/contrib/sensors/test_file_sensor.py +++ b/tests/contrib/sensors/test_file_sensor.py @@ -49,8 +49,7 @@ def setUp(self): hook = FSHook() args = { 'owner': 'airflow', - 'start_date': DEFAULT_DATE, - 'provide_context': True + 'start_date': DEFAULT_DATE } dag = DAG(TEST_DAG_ID + 'test_schedule_dag_once', default_args=args) dag.schedule_interval = '@once' diff --git a/tests/contrib/sensors/test_gcs_upload_session_sensor.py b/tests/contrib/sensors/test_gcs_upload_session_sensor.py index c23083592312..ea4515b6ecb5 100644 --- a/tests/contrib/sensors/test_gcs_upload_session_sensor.py +++ b/tests/contrib/sensors/test_gcs_upload_session_sensor.py @@ -62,7 +62,6 @@ def setUp(self): args = { 'owner': 'airflow', 'start_date': DEFAULT_DATE, - 'provide_context': True } dag = DAG(TEST_DAG_ID + 'test_schedule_dag_once', default_args=args) dag.schedule_interval = '@once' diff --git a/tests/sensors/test_http_sensor.py 
b/tests/sensors/test_http_sensor.py index 37a0a48e5090..40c374eb2185 100644 --- a/tests/sensors/test_http_sensor.py +++ b/tests/sensors/test_http_sensor.py @@ -96,19 +96,13 @@ def resp_check(_): @patch("airflow.hooks.http_hook.requests.Session.send") def test_poke_context(self, mock_session_send): - """ - test provide_context - """ response = requests.Response() response.status_code = 200 mock_session_send.return_value = response - def resp_check(resp, **context): - if context: - if "execution_date" in context: - if context["execution_date"] == DEFAULT_DATE: - return True - + def resp_check(resp, execution_date): + if execution_date == DEFAULT_DATE: + return True raise AirflowException('AirflowException raised here!') task = HttpSensor( @@ -117,7 +111,6 @@ def resp_check(resp, **context): endpoint='', request_params={}, response_check=resp_check, - provide_context=True, timeout=5, poke_interval=1, dag=self.dag) From d7f35d6b8e0061ef9d8b3fcec09d40c0e6297c16 Mon Sep 17 00:00:00 2001 From: Fokko Driesprong Date: Wed, 4 Sep 2019 13:02:00 +0200 Subject: [PATCH 11/15] Some cleanup --- airflow/operators/python_operator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/airflow/operators/python_operator.py b/airflow/operators/python_operator.py index c760dbec20d3..f1e6d1dbd15f 100644 --- a/airflow/operators/python_operator.py +++ b/airflow/operators/python_operator.py @@ -26,7 +26,7 @@ from inspect import signature from itertools import islice from textwrap import dedent -from typing import Optional, Iterable, Dict, Callable, Tuple +from typing import Optional, Iterable, Dict, Callable import dill From 7ab68bc90d4cf34d93ef01f491df1e425e2c89d7 Mon Sep 17 00:00:00 2001 From: Fokko Driesprong Date: Wed, 4 Sep 2019 13:47:47 +0200 Subject: [PATCH 12/15] Fix the types --- airflow/contrib/sensors/python_sensor.py | 4 ++-- airflow/operators/python_operator.py | 19 ++++++++++++++----- 2 files changed, 16 insertions(+), 7 deletions(-) diff --git 
a/airflow/contrib/sensors/python_sensor.py b/airflow/contrib/sensors/python_sensor.py index ec611be3a236..146ab7ec39ba 100644 --- a/airflow/contrib/sensors/python_sensor.py +++ b/airflow/contrib/sensors/python_sensor.py @@ -20,7 +20,7 @@ from airflow.sensors.base_sensor_operator import BaseSensorOperator from airflow.utils.decorators import apply_defaults -from typing import Optional, Iterable, Dict, Callable +from typing import Optional, Iterable, Dict, Callable, List class PythonSensor(BaseSensorOperator): @@ -53,7 +53,7 @@ class PythonSensor(BaseSensorOperator): def __init__( self, python_callable: Callable, - op_args: Optional[Iterable] = None, + op_args: Optional[List] = None, op_kwargs: Optional[Dict] = None, templates_dict: Optional[Dict] = None, *args, **kwargs): diff --git a/airflow/operators/python_operator.py b/airflow/operators/python_operator.py index f1e6d1dbd15f..49b89658aa9c 100644 --- a/airflow/operators/python_operator.py +++ b/airflow/operators/python_operator.py @@ -26,7 +26,7 @@ from inspect import signature from itertools import islice from textwrap import dedent -from typing import Optional, Iterable, Dict, Callable +from typing import Optional, Iterable, Dict, Callable, List import dill @@ -73,10 +73,10 @@ class PythonOperator(BaseOperator): def __init__( self, python_callable: Callable, - op_args: Optional[Iterable] = None, + op_args: Optional[List] = None, op_kwargs: Optional[Dict] = None, templates_dict: Optional[Dict] = None, - templates_exts: Optional[Iterable[str]] = None, + templates_exts: Optional[List[str]] = None, *args, **kwargs ) -> None: @@ -94,6 +94,15 @@ def __init__( def determine_op_kwargs(python_callable: Callable, context: Dict, num_op_args: int = 0) -> Dict: + """ + Function that will inspect the signature of a python_callable to determine which + values need to be passed to the function. 
+
+        :param python_callable: The function that you want to invoke
+        :param context: The context provided by the execute method of the Operator/Sensor
+        :param num_op_args: The number of op_args provided, so we know how many to skip
+        :return: The op_kwargs dictionary which contains the values that are compatible with the Callable
+        """
         context_keys = context.keys()
         sig = signature(python_callable).parameters.items()
         op_args_names = islice(sig, num_op_args)
@@ -152,7 +161,7 @@ class BranchPythonOperator(PythonOperator, SkipMixin):
     to be inferred.
     """

-    def execute(self, context):
+    def execute(self, context: Dict):
         branch = super().execute(context)
         self.skip_all_except(context['ti'], branch)
@@ -170,7 +179,7 @@ class ShortCircuitOperator(PythonOperator, SkipMixin):
     The condition is determined by the result of `python_callable`.
     """

-    def execute(self, context):
+    def execute(self, context: Dict):
         condition = super().execute(context)
         self.log.info("Condition result is %s", condition)

From 2cde97b63a1670185de714f0bfd61ed1604c3db6 Mon Sep 17 00:00:00 2001
From: Fokko Driesprong
Date: Wed, 4 Sep 2019 13:56:27 +0200
Subject: [PATCH 13/15] Less is more

---
 airflow/contrib/sensors/python_sensor.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/airflow/contrib/sensors/python_sensor.py b/airflow/contrib/sensors/python_sensor.py
index 146ab7ec39ba..a4e5ec77aa52 100644
--- a/airflow/contrib/sensors/python_sensor.py
+++ b/airflow/contrib/sensors/python_sensor.py
@@ -20,7 +20,7 @@
 from airflow.sensors.base_sensor_operator import BaseSensorOperator
 from airflow.utils.decorators import apply_defaults

-from typing import Optional, Iterable, Dict, Callable, List
+from typing import Optional, Dict, Callable, List


 class PythonSensor(BaseSensorOperator):

From 95764ee35ca8295099c1dd95f3458b93f26c7795 Mon Sep 17 00:00:00 2001
From: Fokko Driesprong
Date: Wed, 4 Sep 2019 14:35:58 +0200
Subject: [PATCH 14/15] Patch tests

---
 tests/sensors/test_http_sensor.py | 4 ++--
 1 file
changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/sensors/test_http_sensor.py b/tests/sensors/test_http_sensor.py index 40c374eb2185..1b6c3667ba65 100644 --- a/tests/sensors/test_http_sensor.py +++ b/tests/sensors/test_http_sensor.py @@ -63,7 +63,7 @@ def resp_check(resp): timeout=5, poke_interval=1) with self.assertRaisesRegex(AirflowException, 'AirflowException raised here!'): - task.execute(None) + task.execute(context={}) @patch("airflow.hooks.http_hook.requests.Session.send") def test_head_method(self, mock_session_send): @@ -81,7 +81,7 @@ def resp_check(_): timeout=5, poke_interval=1) - task.execute(None) + task.execute(context={}) args, kwargs = mock_session_send.call_args received_request = args[0] From b592e68f7ed0bff18dbbe4b326d18e4db97afe3f Mon Sep 17 00:00:00 2001 From: Fokko Driesprong Date: Tue, 10 Sep 2019 12:59:58 +0200 Subject: [PATCH 15/15] Feedback from Bas --- UPDATING.md | 38 +++++++++++++++++++++++++++- airflow/operators/python_operator.py | 4 +-- 2 files changed, 39 insertions(+), 3 deletions(-) diff --git a/UPDATING.md b/UPDATING.md index 6792ef694807..306a2da45478 100644 --- a/UPDATING.md +++ b/UPDATING.md @@ -42,7 +42,43 @@ assists users migrating to a new version. ### Remove provide_context -Instead of settings `provide_context` we're automagically inferring the signature of the callable that is being passed to the PythonOperator. The only behavioural change in is that using a key that is already in the context in the function, such as `dag` or `ds` is not allowed anymore and will thrown an exception. If the `provide_context` is still explicitly passed to the function, it will just end up in the `kwargs`, which can cause no harm. +`provide_context` argument on the PythonOperator was removed. The signature of the callable passed to the PythonOperator is now inferred and argument values are always automatically provided. There is no need to explicitly provide or not provide the context anymore. 
For example:
+
+```python
+def myfunc(execution_date):
+    print(execution_date)
+
+python_operator = PythonOperator(task_id='mytask', python_callable=myfunc, dag=dag)
+```
+
+Notice you don't have to set provide_context=True; variables from the task context are now automatically detected and provided.
+
+All context variables can still be provided with a double-asterisk argument:
+
+```python
+def myfunc(**context):
+    print(context)  # all variables will be provided to context
+
+python_operator = PythonOperator(task_id='mytask', python_callable=myfunc)
+```
+
+The task context variable names are reserved names in the callable function, hence a clash with `op_args` and `op_kwargs` results in an exception:
+
+```python
+def myfunc(dag):
+    # raises a ValueError because "dag" is a reserved name
+    # valid signature example: myfunc(mydag)
+
+python_operator = PythonOperator(
+    task_id='mytask',
+    op_args=[1],
+    python_callable=myfunc,
+)
+```
+
+The change is backwards compatible: setting `provide_context` will add the `provide_context` variable to the `kwargs` (but won't do anything).
+
+PR: [#5990](https://github.com/apache/airflow/pull/5990)

 ### Changes to FileSensor

diff --git a/airflow/operators/python_operator.py b/airflow/operators/python_operator.py
index 49b89658aa9c..4d3c8da19f60 100644
--- a/airflow/operators/python_operator.py
+++ b/airflow/operators/python_operator.py
@@ -107,7 +107,7 @@ def determine_op_kwargs(python_callable: Callable,
         sig = signature(python_callable).parameters.items()
         op_args_names = islice(sig, num_op_args)
         for name, _ in op_args_names:
-            # Check if it part of the context
+            # Check if it is part of the context
             if name in context_keys:
                 # Raise an exception to let the user know that the keyword is reserved
                 raise ValueError(
@@ -115,7 +115,7 @@ def determine_op_kwargs(python_callable: Callable,
             )

         if any(str(param).startswith("**") for _, param in sig):
-            # If there is a **kwargs, **context or **_ then just dump everything.
+ # If there is a ** argument then just dump everything. op_kwargs = context else: # If there is only for example, an execution_date, then pass only these in :-)