diff --git a/UPDATING.md b/UPDATING.md
index 98dc7faf07ee..306a2da45478 100644
--- a/UPDATING.md
+++ b/UPDATING.md
@@ -37,9 +37,51 @@ assists users migrating to a new version.
 - [Airflow 1.7.1.2](#airflow-1712)
 
 ## Airflow Master
 
+### Remove provide_context
+
+The `provide_context` argument of the PythonOperator was removed. The signature of the callable passed to the PythonOperator is now inferred, and argument values are always automatically provided. There is no longer any need to explicitly provide (or not provide) the context. For example:
+
+```python
+def myfunc(execution_date):
+    print(execution_date)
+
+python_operator = PythonOperator(task_id='mytask', python_callable=myfunc, dag=dag)
+```
+
+Notice that you don't have to set `provide_context=True`; variables from the task context are now automatically detected and provided.
+
+All context variables can still be provided with a double-asterisk argument:
+
+```python
+def myfunc(**context):
+    print(context)  # all variables will be provided to context
+
+python_operator = PythonOperator(task_id='mytask', python_callable=myfunc)
+```
+
+The task context variable names are reserved names in the callable function, so a clash with `op_args` or `op_kwargs` results in an exception:
+
+```python
+def myfunc(dag):
+    # raises a ValueError because "dag" is a reserved name
+    # a valid signature would be, for example: myfunc(mydag)
+    pass
+
+python_operator = PythonOperator(
+    task_id='mytask',
+    op_args=[1],
+    python_callable=myfunc,
+)
+```
+
+The change is backwards compatible: setting `provide_context` will add a `provide_context` variable to the `kwargs` (but it won't do anything).
+
+PR: [#5990](https://github.com/apache/airflow/pull/5990)
+
 ### Changes to FileSensor
 
 FileSensor now takes a glob pattern, not just a filename. If the filename you are looking for has `*`, `?`, or `[` in it then you should replace these with `[*]`, `[?]`, and `[[]`.
 
 ### Change dag loading duration metric name
diff --git a/airflow/contrib/example_dags/example_jenkins_job_trigger_operator.py.notexecutable b/airflow/contrib/example_dags/example_jenkins_job_trigger_operator.py.notexecutable
index 2d8906b0b3b9..c0c3df61f5c1 100644
--- a/airflow/contrib/example_dags/example_jenkins_job_trigger_operator.py.notexecutable
+++ b/airflow/contrib/example_dags/example_jenkins_job_trigger_operator.py.notexecutable
@@ -7,9 +7,9 @@
 # to you under the Apache License, Version 2.0 (the
 # "License"); you may not use this file except in compliance
 # with the License. You may obtain a copy of the License at
-# 
+#
 #   http://www.apache.org/licenses/LICENSE-2.0
-# 
+#
 # Unless required by applicable law or agreed to in writing,
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -75,7 +75,6 @@ def grabArtifactFromJenkins(**context):
 artifact_grabber = PythonOperator(
     task_id='artifact_grabber',
-    provide_context=True,
     python_callable=grabArtifactFromJenkins,
     dag=dag)
diff --git a/airflow/contrib/example_dags/example_qubole_operator.py b/airflow/contrib/example_dags/example_qubole_operator.py
index 1f7e2a8ce9d8..ef4681a85b79 100644
--- a/airflow/contrib/example_dags/example_qubole_operator.py
+++ b/airflow/contrib/example_dags/example_qubole_operator.py
@@ -97,7 +97,6 @@ def compare_result(**kwargs):
 t3 = PythonOperator(
     task_id='compare_result',
-    provide_context=True,
     python_callable=compare_result,
     trigger_rule="all_done",
     dag=dag)
diff --git a/airflow/contrib/sensors/python_sensor.py b/airflow/contrib/sensors/python_sensor.py
index a2dc2031a86b..a4e5ec77aa52 100644
--- a/airflow/contrib/sensors/python_sensor.py
+++ b/airflow/contrib/sensors/python_sensor.py
@@ -16,9 +16,11 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
+from airflow.operators.python_operator import PythonOperator
 from airflow.sensors.base_sensor_operator import BaseSensorOperator
 from airflow.utils.decorators import apply_defaults
+from typing import Optional, Dict, Callable, List
 
 
 class PythonSensor(BaseSensorOperator):
@@ -38,12 +40,6 @@ class PythonSensor(BaseSensorOperator):
     :param op_args: a list of positional arguments that will get unpacked when
         calling your callable
     :type op_args: list
-    :param provide_context: if set to true, Airflow will pass a set of
-        keyword arguments that can be used in your function. This set of
-        kwargs correspond exactly to what you can use in your jinja
-        templates. For this to work, you need to define `**kwargs` in your
-        function header.
-    :type provide_context: bool
     :param templates_dict: a dictionary where the values are templates that
         will get templated by the Airflow engine sometime between
         ``__init__`` and ``execute`` takes place and are made available
@@ -56,24 +52,21 @@ class PythonSensor(BaseSensorOperator):
     @apply_defaults
     def __init__(
             self,
-            python_callable,
-            op_args=None,
-            op_kwargs=None,
-            provide_context=False,
-            templates_dict=None,
+            python_callable: Callable,
+            op_args: Optional[List] = None,
+            op_kwargs: Optional[Dict] = None,
+            templates_dict: Optional[Dict] = None,
             *args, **kwargs):
         super().__init__(*args, **kwargs)
         self.python_callable = python_callable
         self.op_args = op_args or []
         self.op_kwargs = op_kwargs or {}
-        self.provide_context = provide_context
         self.templates_dict = templates_dict
 
-    def poke(self, context):
-        if self.provide_context:
-            context.update(self.op_kwargs)
-            context['templates_dict'] = self.templates_dict
-            self.op_kwargs = context
+    def poke(self, context: Dict):
+        context.update(self.op_kwargs)
+        context['templates_dict'] = self.templates_dict
+        self.op_kwargs = PythonOperator.determine_op_kwargs(self.python_callable, context, len(self.op_args))
 
         self.log.info("Poking callable: %s", str(self.python_callable))
         return_value = self.python_callable(*self.op_args, **self.op_kwargs)
diff --git a/airflow/example_dags/docker_copy_data.py b/airflow/example_dags/docker_copy_data.py
index f091969777ee..484f82f683df 100644
--- a/airflow/example_dags/docker_copy_data.py
+++ b/airflow/example_dags/docker_copy_data.py
@@ -69,7 +69,6 @@
 #
 # t_is_data_available = ShortCircuitOperator(
 #     task_id='check_if_data_available',
-#     provide_context=True,
 #     python_callable=is_data_available,
 #     dag=dag)
 #
diff --git a/airflow/example_dags/example_branch_python_dop_operator_3.py b/airflow/example_dags/example_branch_python_dop_operator_3.py
index ec60cfc01b90..7455ef7ebbd2 100644
--- a/airflow/example_dags/example_branch_python_dop_operator_3.py
+++ b/airflow/example_dags/example_branch_python_dop_operator_3.py
@@ -58,7 +58,6 @@ def should_run(**kwargs):
 cond = BranchPythonOperator(
     task_id='condition',
-    provide_context=True,
     python_callable=should_run,
     dag=dag,
 )
diff --git a/airflow/example_dags/example_passing_params_via_test_command.py b/airflow/example_dags/example_passing_params_via_test_command.py
index 152b8cde9e63..e8fc9c963a91 100644
--- a/airflow/example_dags/example_passing_params_via_test_command.py
+++ b/airflow/example_dags/example_passing_params_via_test_command.py
@@ -37,17 +37,17 @@
 )
 
 
-def my_py_command(**kwargs):
+def my_py_command(test_mode, params):
     """
     Print out the "foo" param passed in via
     `airflow tasks test example_passing_params_via_test_command run_this -tp '{"foo":"bar"}'`
     """
-    if kwargs["test_mode"]:
+    if test_mode:
         print(" 'foo' was passed in via test={} command : kwargs[params][foo] \
-            = {}".format(kwargs["test_mode"], kwargs["params"]["foo"]))
+            = {}".format(test_mode, params["foo"]))
     # Print out the value of "miff", passed in below via the Python Operator
-    print(" 'miff' was passed in via task params = {}".format(kwargs["params"]["miff"]))
+    print(" 'miff' was passed in via task params = {}".format(params["miff"]))
     return 1
 
 
@@ -58,7 +58,6 @@ def my_py_command(**kwargs):
 run_this = PythonOperator(
     task_id='run_this',
-    provide_context=True,
     python_callable=my_py_command,
     params={"miff": "agg"},
     dag=dag,
diff --git a/airflow/example_dags/example_python_operator.py b/airflow/example_dags/example_python_operator.py
index 29c664f0a65e..86403ceb25b6 100644
--- a/airflow/example_dags/example_python_operator.py
+++ b/airflow/example_dags/example_python_operator.py
@@ -48,7 +48,6 @@ def print_context(ds, **kwargs):
 run_this = PythonOperator(
     task_id='print_the_context',
-    provide_context=True,
     python_callable=print_context,
     dag=dag,
 )
diff --git a/airflow/example_dags/example_trigger_target_dag.py b/airflow/example_dags/example_trigger_target_dag.py
index 475817698115..32255103d804 100644
--- a/airflow/example_dags/example_trigger_target_dag.py
+++ b/airflow/example_dags/example_trigger_target_dag.py
@@ -69,7 +69,6 @@ def run_this_func(**kwargs):
 run_this = PythonOperator(
     task_id='run_this',
-    provide_context=True,
     python_callable=run_this_func,
     dag=dag,
 )
diff --git a/airflow/example_dags/example_xcom.py b/airflow/example_dags/example_xcom.py
index 8bd8e93b38cf..fb043b021ebc 100644
--- a/airflow/example_dags/example_xcom.py
+++ b/airflow/example_dags/example_xcom.py
@@ -26,7 +26,6 @@
 args = {
     'owner': 'Airflow',
     'start_date': airflow.utils.dates.days_ago(2),
-    'provide_context': True,
 }
 
 dag = DAG('example_xcom', schedule_interval="@once", default_args=args)
diff --git a/airflow/gcp/utils/mlengine_operator_utils.py b/airflow/gcp/utils/mlengine_operator_utils.py
index d09c2318061f..dd44f7c543d2 100644
--- a/airflow/gcp/utils/mlengine_operator_utils.py
+++ b/airflow/gcp/utils/mlengine_operator_utils.py
@@ -241,7 +241,6 @@ def apply_validate_fn(*args, **kwargs):
     evaluate_validation = PythonOperator(
         task_id=(task_prefix + "-validation"),
         python_callable=apply_validate_fn,
-        provide_context=True,
         templates_dict={"prediction_path": prediction_path},
         dag=dag)
 
     evaluate_validation.set_upstream(evaluate_summary)
diff --git a/airflow/operators/python_operator.py b/airflow/operators/python_operator.py
index 46430b215e93..4d3c8da19f60 100644
--- a/airflow/operators/python_operator.py
+++ b/airflow/operators/python_operator.py
@@ -23,8 +23,10 @@
 import subprocess
 import sys
 import types
+from inspect import signature
+from itertools import islice
 from textwrap import dedent
-from typing import Optional, Iterable, Dict, Callable
+from typing import Optional, Iterable, Dict, Callable, List
 
 import dill
 
@@ -51,12 +53,6 @@ class PythonOperator(BaseOperator):
     :param op_args: a list of positional arguments that will get unpacked when
         calling your callable
     :type op_args: list (templated)
-    :param provide_context: if set to true, Airflow will pass a set of
-        keyword arguments that can be used in your function. This set of
-        kwargs correspond exactly to what you can use in your jinja
-        templates. For this to work, you need to define `**kwargs` in your
-        function header.
-    :type provide_context: bool
     :param templates_dict: a dictionary where the values are templates that
         will get templated by the Airflow engine sometime between
         ``__init__`` and ``execute`` takes place and are made available
@@ -77,11 +73,10 @@ class PythonOperator(BaseOperator):
     def __init__(
         self,
         python_callable: Callable,
-        op_args: Optional[Iterable] = None,
+        op_args: Optional[List] = None,
         op_kwargs: Optional[Dict] = None,
-        provide_context: bool = False,
         templates_dict: Optional[Dict] = None,
-        templates_exts: Optional[Iterable[str]] = None,
+        templates_exts: Optional[List[str]] = None,
         *args,
         **kwargs
     ) -> None:
@@ -91,12 +86,47 @@ def __init__(
         self.python_callable = python_callable
         self.op_args = op_args or []
         self.op_kwargs = op_kwargs or {}
-        self.provide_context = provide_context
         self.templates_dict = templates_dict
         if templates_exts:
             self.template_ext = templates_exts
 
-    def execute(self, context):
+    @staticmethod
+    def determine_op_kwargs(python_callable: Callable,
+                            context: Dict,
+                            num_op_args: int = 0) -> Dict:
+        """
+        Inspect the signature of a python_callable to determine which
+        values need to be passed to the function.
+
+        :param python_callable: The function that you want to invoke
+        :param context: The context provided by the execute method of the Operator/Sensor
+        :param num_op_args: The number of op_args provided, so we know how many to skip
+        :return: The op_kwargs dictionary which contains the values that are compatible with the Callable
+        """
+        context_keys = context.keys()
+        sig = signature(python_callable).parameters.items()
+        op_args_names = islice(sig, num_op_args)
+        for name, _ in op_args_names:
+            # Check if it is part of the context
+            if name in context_keys:
+                # Raise an exception to let the user know that the keyword is reserved
+                raise ValueError(
+                    "The key {} in the op_args is part of the context, and therefore reserved".format(name)
+                )
+
+        if any(str(param).startswith("**") for _, param in sig):
+            # If there is a ** argument, then just dump the whole context.
+            op_kwargs = context
+        else:
+            # Otherwise pass only the context variables that appear in the signature,
+            # e.g. only an execution_date.
+            op_kwargs = {
+                name: context[name]
+                for name, _ in sig
+                if name in context  # If it isn't available on the context, then ignore
+            }
+        return op_kwargs
+
+    def execute(self, context: Dict):
         # Export context to make it available for callables to use.
         airflow_context_vars = context_to_airflow_vars(context, in_env_var_format=True)
         self.log.info("Exporting the following env vars:\n%s",
                       '\n'.join(["{}={}".format(k, v)
@@ -104,10 +134,10 @@ def execute(self, context):
                                  for k, v in airflow_context_vars.items()]))
         os.environ.update(airflow_context_vars)
 
-        if self.provide_context:
-            context.update(self.op_kwargs)
-            context['templates_dict'] = self.templates_dict
-            self.op_kwargs = context
+        context.update(self.op_kwargs)
+        context['templates_dict'] = self.templates_dict
+
+        self.op_kwargs = PythonOperator.determine_op_kwargs(self.python_callable, context, len(self.op_args))
 
         return_value = self.execute_callable()
         self.log.info("Done. Returned value was: %s", return_value)
@@ -130,7 +160,8 @@ class BranchPythonOperator(PythonOperator, SkipMixin):
     downstream to allow for the DAG state to fill up and
     the DAG run's state to be inferred.
     """
-    def execute(self, context):
+
+    def execute(self, context: Dict):
         branch = super().execute(context)
         self.skip_all_except(context['ti'], branch)
 
@@ -147,7 +178,8 @@ class ShortCircuitOperator(PythonOperator, SkipMixin):
 
     The condition is determined by the result of `python_callable`.
""" - def execute(self, context): + + def execute(self, context: Dict): condition = super().execute(context) self.log.info("Condition result is %s", condition) @@ -200,12 +232,6 @@ class PythonVirtualenvOperator(PythonOperator): :type op_kwargs: list :param op_kwargs: A dict of keyword arguments to pass to python_callable. :type op_kwargs: dict - :param provide_context: if set to true, Airflow will pass a set of - keyword arguments that can be used in your function. This set of - kwargs correspond exactly to what you can use in your jinja - templates. For this to work, you need to define `**kwargs` in your - function header. - :type provide_context: bool :param string_args: Strings that are present in the global var virtualenv_string_args, available to python_callable at runtime as a list[str]. Note that args are split by newline. @@ -219,6 +245,7 @@ class PythonVirtualenvOperator(PythonOperator): processing templated fields, for examples ``['.sql', '.hql']`` :type templates_exts: list[str] """ + @apply_defaults def __init__( self, @@ -229,7 +256,6 @@ def __init__( system_site_packages: bool = True, op_args: Iterable = None, op_kwargs: Dict = None, - provide_context: bool = False, string_args: Optional[Iterable[str]] = None, templates_dict: Optional[Dict] = None, templates_exts: Optional[Iterable[str]] = None, @@ -242,7 +268,6 @@ def __init__( op_kwargs=op_kwargs, templates_dict=templates_dict, templates_exts=templates_exts, - provide_context=provide_context, *args, **kwargs) self.requirements = requirements or [] @@ -264,8 +289,8 @@ def __init__( self.__class__.__name__) # check that args are passed iff python major version matches if (python_version is not None and - str(python_version)[0] != str(sys.version_info[0]) and - self._pass_op_args()): + str(python_version)[0] != str(sys.version_info[0]) and + self._pass_op_args()): raise AirflowException("Passing op_args or op_kwargs is not supported across " "different Python major versions " "for PythonVirtualenvOperator. " @@ -383,7 +408,7 @@ def _generate_python_code(self): fn = self.python_callable # dont try to read pickle if we didnt pass anything if self._pass_op_args(): - load_args_line = 'with open(sys.argv[1], "rb") as file: arg_dict = {}.load(file)'\ + load_args_line = 'with open(sys.argv[1], "rb") as file: arg_dict = {}.load(file)' \ .format(pickling_library) else: load_args_line = 'arg_dict = {"args": [], "kwargs": {}}' diff --git a/airflow/sensors/http_sensor.py b/airflow/sensors/http_sensor.py index 6e5d69946232..3102d33f4c04 100644 --- a/airflow/sensors/http_sensor.py +++ b/airflow/sensors/http_sensor.py @@ -16,6 +16,9 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from typing import Dict, Callable + +from airflow.operators.python_operator import PythonOperator from airflow.exceptions import AirflowException from airflow.hooks.http_hook import HttpHook @@ -31,13 +34,17 @@ class HttpSensor(BaseSensorOperator): HTTP Error codes other than 404 (like 403) or Connection Refused Error would fail the sensor itself directly (no more poking). - The response check can access the template context by passing ``provide_context=True`` to the operator:: + The response check can access the template context to the operator: - def response_check(response, **context): - # Can look at context['ti'] etc. 
+        def response_check(response, task_instance):
+            # The task_instance is injected, so you can pull data from XCom
+            # Other context variables such as dag, ds, execution_date are also available.
+            xcom_data = task_instance.xcom_pull(task_ids='pushing_task')
+            # In practice you would do something more sensible with this data.
+            print(xcom_data)
             return True
 
-    HttpSensor(task_id='my_http_sensor', ..., provide_context=True, response_check=response_check)
+    HttpSensor(task_id='my_http_sensor', ..., response_check=response_check)
 
     :param http_conn_id: The connection to run the sensor against
@@ -50,12 +57,6 @@ def response_check(response, **context):
     :type request_params: a dictionary of string key/value pairs
     :param headers: The HTTP headers to be added to the GET request
     :type headers: a dictionary of string key/value pairs
-    :param provide_context: if set to true, Airflow will pass a set of
-        keyword arguments that can be used in your function. This set of
-        kwargs correspond exactly to what you can use in your jinja
-        templates. For this to work, you need to define context in your
-        function header.
-    :type provide_context: bool
     :param response_check: A check against the 'requests' response object.
         Returns True for 'pass' and False otherwise.
     :type response_check: A lambda or defined function.
@@ -69,14 +70,13 @@ def response_check(response, **context):
 
     @apply_defaults
     def __init__(self,
-                 endpoint,
-                 http_conn_id='http_default',
-                 method='GET',
-                 request_params=None,
-                 headers=None,
-                 response_check=None,
-                 provide_context=False,
-                 extra_options=None, *args, **kwargs):
+                 endpoint: str,
+                 http_conn_id: str = 'http_default',
+                 method: str = 'GET',
+                 request_params: Dict = None,
+                 headers: Dict = None,
+                 response_check: Callable = None,
+                 extra_options: Dict = None, *args, **kwargs):
         super().__init__(*args, **kwargs)
         self.endpoint = endpoint
         self.http_conn_id = http_conn_id
@@ -84,13 +84,12 @@ def __init__(self,
         self.headers = headers or {}
         self.extra_options = extra_options or {}
         self.response_check = response_check
-        self.provide_context = provide_context
 
         self.hook = HttpHook(
             method=method,
             http_conn_id=http_conn_id)
 
-    def poke(self, context):
+    def poke(self, context: Dict):
         self.log.info('Poking: %s', self.endpoint)
         try:
             response = self.hook.run(self.endpoint,
@@ -98,10 +97,9 @@ def poke(self, context):
                                      headers=self.headers,
                                      extra_options=self.extra_options)
             if self.response_check:
-                if self.provide_context:
-                    return self.response_check(response, **context)
-                else:
-                    return self.response_check(response)
+                op_kwargs = PythonOperator.determine_op_kwargs(self.response_check, context)
+                return self.response_check(response, **op_kwargs)
+
         except AirflowException as ae:
             if str(ae).startswith("404"):
                 return False
diff --git a/docs/concepts.rst b/docs/concepts.rst
index db4a0f34d23b..83c9f0e1c79f 100644
--- a/docs/concepts.rst
+++ b/docs/concepts.rst
@@ -549,9 +549,12 @@ passed, then a corresponding list of XCom values is returned.
     def push_function():
         return value
 
-    # inside another PythonOperator where provide_context=True
-    def pull_function(**context):
-        value = context['task_instance'].xcom_pull(task_ids='pushing_task')
+    # inside another PythonOperator
+    def pull_function(task_instance):
+        value = task_instance.xcom_pull(task_ids='pushing_task')
+
+Context variables that are named in the function signature are
+automatically passed to the function.
 
 It is also possible to pull XCom directly in a template, here's an example
 of what this may look like:
@@ -633,8 +636,7 @@ For example:
 
 .. code:: python
 
-    def branch_func(**kwargs):
-        ti = kwargs['ti']
+    def branch_func(ti):
         xcom_value = int(ti.xcom_pull(task_ids='start_task'))
         if xcom_value >= 5:
             return 'continue_task'
@@ -649,7 +651,6 @@ For example:
 
     branch_op = BranchPythonOperator(
         task_id='branch_task',
-        provide_context=True,
         python_callable=branch_func,
         dag=dag)
 
diff --git a/docs/howto/operator/python.rst b/docs/howto/operator/python.rst
index d0a0da4fb7e4..e3ff9d61ac5e 100644
--- a/docs/howto/operator/python.rst
+++ b/docs/howto/operator/python.rst
@@ -44,9 +44,9 @@ to the Python callable.
 Templating
 ^^^^^^^^^^
 
-When you set the ``provide_context`` argument to ``True``, Airflow passes in
-an additional set of keyword arguments: one for each of the :doc:`Jinja
-template variables <../../macros-ref>` and a ``templates_dict`` argument.
+Airflow passes in an additional set of keyword arguments: one for each of the
+:doc:`Jinja template variables <../../macros-ref>` and a ``templates_dict``
+argument.
 
 The ``templates_dict`` argument is templated, so each value in the dictionary
 is evaluated as a :ref:`Jinja template <jinja-templating>`.
diff --git a/tests/contrib/hooks/test_aws_glue_catalog_hook.py b/tests/contrib/hooks/test_aws_glue_catalog_hook.py
index 311dbcaf3390..85b2777e88c5 100644
--- a/tests/contrib/hooks/test_aws_glue_catalog_hook.py
+++ b/tests/contrib/hooks/test_aws_glue_catalog_hook.py
@@ -43,6 +43,7 @@
     }
 }
 
+
 @unittest.skipIf(mock_glue is None,
                  "Skipping test because moto.mock_glue is not available")
 class TestAwsGlueCatalogHook(unittest.TestCase):
diff --git a/tests/contrib/operators/test_aws_athena_operator.py b/tests/contrib/operators/test_aws_athena_operator.py
index b86bbab0f136..c3d98e25db97 100644
--- a/tests/contrib/operators/test_aws_athena_operator.py
+++ b/tests/contrib/operators/test_aws_athena_operator.py
@@ -53,7 +53,6 @@ def setUp(self):
         args = {
             'owner': 'airflow',
             'start_date': DEFAULT_DATE,
-            'provide_context': True
         }
 
         self.dag = DAG(TEST_DAG_ID + 'test_schedule_dag_once',
diff --git a/tests/contrib/operators/test_s3_to_sftp_operator.py b/tests/contrib/operators/test_s3_to_sftp_operator.py
index c6cd0369557d..4d1e2fd5134b 100644
--- a/tests/contrib/operators/test_s3_to_sftp_operator.py
+++ b/tests/contrib/operators/test_s3_to_sftp_operator.py
@@ -69,7 +69,6 @@ def setUp(self):
         args = {
             'owner': 'airflow',
             'start_date': DEFAULT_DATE,
-            'provide_context': True
         }
         dag = DAG(TEST_DAG_ID + 'test_schedule_dag_once', default_args=args)
         dag.schedule_interval = '@once'
diff --git a/tests/contrib/operators/test_sftp_operator.py b/tests/contrib/operators/test_sftp_operator.py
index 298cdb244bd4..b5aa65269115 100644
--- a/tests/contrib/operators/test_sftp_operator.py
+++ b/tests/contrib/operators/test_sftp_operator.py
@@ -53,7 +53,6 @@ def setUp(self):
         args = {
             'owner': 'airflow',
             'start_date': DEFAULT_DATE,
-            'provide_context': True
         }
         dag = DAG(TEST_DAG_ID + 'test_schedule_dag_once', default_args=args)
         dag.schedule_interval = '@once'
diff --git a/tests/contrib/operators/test_sftp_to_s3_operator.py b/tests/contrib/operators/test_sftp_to_s3_operator.py
index c86c51b8e8a1..00214107f6f6 100644
--- a/tests/contrib/operators/test_sftp_to_s3_operator.py
+++ b/tests/contrib/operators/test_sftp_to_s3_operator.py
@@ -68,7 +68,6 @@ def setUp(self):
         args = {
             'owner': 'airflow',
             'start_date': DEFAULT_DATE,
-            'provide_context': True
         }
         dag = DAG(TEST_DAG_ID + 'test_schedule_dag_once', default_args=args)
         dag.schedule_interval = '@once'
diff --git a/tests/contrib/operators/test_ssh_operator.py b/tests/contrib/operators/test_ssh_operator.py
index f66541d010a1..e8cb8d00d72a 100644
--- a/tests/contrib/operators/test_ssh_operator.py
+++ b/tests/contrib/operators/test_ssh_operator.py
@@ -51,7 +51,6 @@ def setUp(self):
         args = {
             'owner': 'airflow',
             'start_date': DEFAULT_DATE,
-            'provide_context': True
         }
         dag = DAG(TEST_DAG_ID + 'test_schedule_dag_once', default_args=args)
         dag.schedule_interval = '@once'
diff --git a/tests/contrib/sensors/test_file_sensor.py b/tests/contrib/sensors/test_file_sensor.py
index f7704ddd76d3..6d7b3f3d5e0e 100644
--- a/tests/contrib/sensors/test_file_sensor.py
+++ b/tests/contrib/sensors/test_file_sensor.py
@@ -50,8 +50,7 @@ def setUp(self):
         hook = FSHook()
         args = {
             'owner': 'airflow',
-            'start_date': DEFAULT_DATE,
-            'provide_context': True
+            'start_date': DEFAULT_DATE
         }
         dag = DAG(TEST_DAG_ID + 'test_schedule_dag_once', default_args=args)
         dag.schedule_interval = '@once'
diff --git a/tests/contrib/sensors/test_gcs_upload_session_sensor.py b/tests/contrib/sensors/test_gcs_upload_session_sensor.py
index c23083592312..ea4515b6ecb5 100644
--- a/tests/contrib/sensors/test_gcs_upload_session_sensor.py
+++ b/tests/contrib/sensors/test_gcs_upload_session_sensor.py
@@ -62,7 +62,6 @@ def setUp(self):
         args = {
             'owner': 'airflow',
             'start_date': DEFAULT_DATE,
-            'provide_context': True
         }
         dag = DAG(TEST_DAG_ID + 'test_schedule_dag_once', default_args=args)
         dag.schedule_interval = '@once'
diff --git a/tests/core.py b/tests/core.py
index 0dfb9ba3d0ba..f7e6f4c5f5ef 100644
--- a/tests/core.py
+++ b/tests/core.py
@@ -568,7 +568,6 @@ def test_py_op(templates_dict, ds, **kwargs):
 
         t = PythonOperator(
             task_id='test_py_op',
-            provide_context=True,
             python_callable=test_py_op,
             templates_dict={'ds': "{{ ds }}"},
             dag=self.dag)
@@ -2179,6 +2178,7 @@ def test_init_proxy_user(self):
 HDFSHook = None  # type: Optional[hdfs_hook.HDFSHook]
 snakebite = None  # type: None
 
+
 @unittest.skipIf(HDFSHook is None,
                  "Skipping test because HDFSHook is not installed")
 class TestHDFSHook(unittest.TestCase):
diff --git a/tests/dags/test_cli_triggered_dags.py b/tests/dags/test_cli_triggered_dags.py
index 7747d20710b1..64d827dc9cf8 100644
--- a/tests/dags/test_cli_triggered_dags.py
+++ b/tests/dags/test_cli_triggered_dags.py
@@ -50,6 +50,5 @@ def success(ti=None, *args, **kwargs):
     dag1_task2 = PythonOperator(
         task_id='test_run_dependent_task',
         python_callable=success,
-        provide_context=True,
         dag=dag1)
     dag1_task1.set_downstream(dag1_task2)
diff --git a/tests/dags/test_dag_serialization.py b/tests/dags/test_dag_serialization.py
index 9618925efa31..1934c642a423 100644
--- a/tests/dags/test_dag_serialization.py
+++ b/tests/dags/test_dag_serialization.py
@@ -158,7 +158,7 @@ def make_example_dags(module, dag_ids):
 def make_simple_dag():
     """Make very simple DAG to verify serialization result."""
     dag = DAG(dag_id='simple_dag')
-    _ = BaseOperator(task_id='simple_task', dag=dag, start_date=datetime(2019, 8, 1))
+    BaseOperator(task_id='simple_task', dag=dag, start_date=datetime(2019, 8, 1))
     return {'simple_dag': dag}
 
 
@@ -186,7 +186,7 @@ def compute_next_execution_date(dag, execution_date):
         },
         catchup=False
     )
-    _ = BashOperator(
+    BashOperator(
         task_id='echo',
         bash_command='echo "{{ next_execution_date(dag, execution_date) }}"',
         dag=dag,
diff --git a/tests/operators/test_python_operator.py b/tests/operators/test_python_operator.py
index 3dd8b323fc40..497a00193926 100644
--- a/tests/operators/test_python_operator.py
+++ b/tests/operators/test_python_operator.py
@@ -57,8 +57,8 @@ def build_recording_function(calls_collection):
     Then using this custom function recording
     custom Call objects for further testing
     (replacing Mock.assert_called_with assertion method)
     """
-    def recording_function(*args, **kwargs):
-        calls_collection.append(Call(*args, **kwargs))
+    def recording_function(*args):
+        calls_collection.append(Call(*args))
 
     return recording_function
 
@@ -129,11 +129,10 @@ def test_python_operator_python_callable_is_callable(self):
                 task_id='python_operator',
                 dag=self.dag)
 
-    def _assertCallsEqual(self, first, second):
+    def _assert_calls_equal(self, first, second):
         self.assertIsInstance(first, Call)
         self.assertIsInstance(second, Call)
         self.assertTupleEqual(first.args, second.args)
-        self.assertDictEqual(first.kwargs, second.kwargs)
 
     def test_python_callable_arguments_are_templatized(self):
         """Test PythonOperator op_args are templatized"""
@@ -148,7 +147,7 @@ def test_python_callable_arguments_are_templatized(self):
             task_id='python_operator',
             # a Mock instance cannot be used as a callable function or test fails with a
             # TypeError: Object of type Mock is not JSON serializable
-            python_callable=(build_recording_function(recorded_calls)),
+            python_callable=build_recording_function(recorded_calls),
             op_args=[
                 4,
                 date(2019, 1, 1),
@@ -167,7 +166,7 @@ def test_python_callable_arguments_are_templatized(self):
         ds_templated = DEFAULT_DATE.date().isoformat()
         self.assertEqual(1, len(recorded_calls))
-        self._assertCallsEqual(
+        self._assert_calls_equal(
             recorded_calls[0],
             Call(4,
                  date(2019, 1, 1),
@@ -183,7 +182,7 @@ def test_python_callable_keyword_arguments_are_templatized(self):
             task_id='python_operator',
             # a Mock instance cannot be used as a callable function or test fails with a
             # TypeError: Object of type Mock is not JSON serializable
-            python_callable=(build_recording_function(recorded_calls)),
+            python_callable=build_recording_function(recorded_calls),
             op_kwargs={
                 'an_int': 4,
                 'a_date': date(2019, 1, 1),
@@ -200,7 +199,7 @@ def test_python_callable_keyword_arguments_are_templatized(self):
         task.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
 
         self.assertEqual(1, len(recorded_calls))
-        self._assertCallsEqual(
+        self._assert_calls_equal(
             recorded_calls[0],
             Call(an_int=4,
                  a_date=date(2019, 1, 1),
@@ -251,6 +250,74 @@ def test_echo_env_variables(self):
         )
         t.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
 
+    def test_conflicting_kwargs(self):
+        self.dag.create_dagrun(
+            run_id='manual__' + DEFAULT_DATE.isoformat(),
+            execution_date=DEFAULT_DATE,
+            start_date=DEFAULT_DATE,
+            state=State.RUNNING,
+            external_trigger=False,
+        )
+
+        # dag is not allowed since it is a reserved keyword
+        def fn(dag):
+            # A ValueError should be triggered since we're using dag as a
+            # reserved keyword
+            raise RuntimeError("Should not be triggered, dag: {}".format(dag))
+
+        python_operator = PythonOperator(
+            task_id='python_operator',
+            op_args=[1],
+            python_callable=fn,
+            dag=self.dag
+        )
+
+        with self.assertRaises(ValueError) as context:
+            python_operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
+        self.assertTrue('dag' in str(context.exception), "'dag' not found in the exception")
+
+    def test_context_with_conflicting_op_args(self):
+        self.dag.create_dagrun(
+            run_id='manual__' + DEFAULT_DATE.isoformat(),
+            execution_date=DEFAULT_DATE,
+            start_date=DEFAULT_DATE,
+            state=State.RUNNING,
+            external_trigger=False,
+        )
+
+        def fn(custom, dag):
+            self.assertEqual(1, custom, "custom should be 1")
+            self.assertIsNotNone(dag, "dag should be set")
+
+        python_operator = PythonOperator(
+            task_id='python_operator',
+            op_kwargs={'custom': 1},
+            python_callable=fn,
+            dag=self.dag
+        )
+        python_operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
+
+    def test_context_with_kwargs(self):
+        self.dag.create_dagrun(
+            run_id='manual__' + DEFAULT_DATE.isoformat(),
+            execution_date=DEFAULT_DATE,
+            start_date=DEFAULT_DATE,
+            state=State.RUNNING,
+            external_trigger=False,
+        )
+
+        def fn(**context):
+            # check if context is being set
+            self.assertGreater(len(context), 0, "Context has not been injected")
+
+        python_operator = PythonOperator(
+            task_id='python_operator',
+            op_kwargs={'custom': 1},
+            python_callable=fn,
+            dag=self.dag
+        )
+        python_operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
+
 
 class TestBranchOperator(unittest.TestCase):
     @classmethod
diff --git a/tests/operators/test_virtualenv_operator.py b/tests/operators/test_virtualenv_operator.py
index 52c14f88806d..97c3dcc4eb47 100644
--- a/tests/operators/test_virtualenv_operator.py
+++ b/tests/operators/test_virtualenv_operator.py
@@ -199,20 +199,6 @@ def f(_):
         self._run_as_operator(f, op_args=[datetime.datetime.utcnow()])
 
     def test_context(self):
-        def f(**kwargs):
-            return kwargs['templates_dict']['ds']
+        def f(templates_dict):
+            return templates_dict['ds']
         self._run_as_operator(f, templates_dict={'ds': '{{ ds }}'})
-
-    def test_provide_context(self):
-        def fn():
-            pass
-        task = PythonVirtualenvOperator(
-            python_callable=fn,
-            python_version=sys.version_info[0],
-            task_id='task',
-            dag=self.dag,
-            provide_context=True,
-        )
-        self.assertTrue(
-            task.provide_context
-        )
diff --git a/tests/sensors/test_http_sensor.py b/tests/sensors/test_http_sensor.py
index 37a0a48e5090..1b6c3667ba65 100644
--- a/tests/sensors/test_http_sensor.py
+++ b/tests/sensors/test_http_sensor.py
@@ -63,7 +63,7 @@ def resp_check(resp):
             timeout=5,
             poke_interval=1)
         with self.assertRaisesRegex(AirflowException, 'AirflowException raised here!'):
-            task.execute(None)
+            task.execute(context={})
 
     @patch("airflow.hooks.http_hook.requests.Session.send")
     def test_head_method(self, mock_session_send):
@@ -81,7 +81,7 @@ def resp_check(_):
             timeout=5,
             poke_interval=1)
 
-        task.execute(None)
+        task.execute(context={})
 
         args, kwargs = mock_session_send.call_args
         received_request = args[0]
@@ -96,19 +96,13 @@ def resp_check(_):
 
     @patch("airflow.hooks.http_hook.requests.Session.send")
     def test_poke_context(self, mock_session_send):
-        """
-        test provide_context
-        """
         response = requests.Response()
         response.status_code = 200
         mock_session_send.return_value = response
 
-        def resp_check(resp, **context):
-            if context:
-                if "execution_date" in context:
-                    if context["execution_date"] == DEFAULT_DATE:
-                        return True
-
+        def resp_check(resp, execution_date):
+            if execution_date == DEFAULT_DATE:
+                return True
             raise AirflowException('AirflowException raised here!')
 
         task = HttpSensor(
@@ -117,7 +111,6 @@ def resp_check(resp, execution_date):
             endpoint='',
             request_params={},
             response_check=resp_check,
-            provide_context=True,
             timeout=5,
             poke_interval=1,
             dag=self.dag)
diff --git a/tests/utils/test_log_handlers.py b/tests/utils/test_log_handlers.py
index e06bdae90891..7126871cb580 100644
--- a/tests/utils/test_log_handlers.py
+++ b/tests/utils/test_log_handlers.py
@@ -71,7 +71,6 @@ def task_callable(ti, **kwargs):
         task_id='task_for_testing_file_log_handler',
         dag=dag,
         python_callable=task_callable,
-        provide_context=True
     )
     ti = TaskInstance(task=task, execution_date=DEFAULT_DATE)
@@ -123,7 +122,6 @@ def task_callable(ti, **kwargs):
         task_id='task_for_testing_file_log_handler',
         dag=dag,
         python_callable=task_callable,
-        provide_context=True
    )
    ti = TaskInstance(task=task, execution_date=DEFAULT_DATE)
    ti.try_number = 2
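
For illustration, the filtering behavior introduced by `determine_op_kwargs` can be reproduced with plain `inspect`. The sketch below is a minimal standalone approximation, not part of the patch: the `fake_context` dict and the `filter_kwargs_by_signature` name are invented for this example, and where the patch detects a catch-all parameter via `str(param).startswith("**")`, the sketch uses the equivalent `Parameter.VAR_KEYWORD` check.

```python
from inspect import Parameter, signature


def filter_kwargs_by_signature(func, context):
    """Mimic the patch's signature inspection: a callable with a **kwargs
    parameter receives the whole context, otherwise only the context keys
    that its signature names are passed through."""
    params = signature(func).parameters
    if any(p.kind is Parameter.VAR_KEYWORD for p in params.values()):
        return dict(context)  # ** catch-all: provide everything
    return {name: context[name] for name in params if name in context}


# Hypothetical stand-in for an Airflow task context:
fake_context = {'execution_date': '2019-08-01', 'ds': '2019-08-01', 'dag': None}


def myfunc(execution_date):
    return execution_date


print(filter_kwargs_by_signature(myfunc, fake_context))
# {'execution_date': '2019-08-01'}
```

The real `determine_op_kwargs` additionally skips the first `num_op_args` parameters (they are filled positionally from `op_args`) and raises a `ValueError` when one of them shadows a reserved context key; that part is omitted here for brevity.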