Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[AIRFLOW-103] Allow jinja templates to be used in task params #1488

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 19 additions & 9 deletions airflow/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -1417,10 +1417,7 @@ def get_template_context(self, session=None):
session.expunge_all()
session.commit()

if task.params:
params.update(task.params)

return {
context = {
'dag': task.dag,
'ds': ds,
'ds_nodash': ds_nodash,
Expand All @@ -1447,6 +1444,18 @@ def get_template_context(self, session=None):
'test_mode': self.test_mode,
}

# Allow task level param definitions to be rendered via jinja
# using the context available up until this point
if task.params:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should be done in the render_templates and apply only on strings (use isinstance(obj, basestring) to check)

Copy link
Member

@mistercrunch mistercrunch Jun 9, 2016

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This comment still holds. It hasn't been addressed or contested.

if task.render_params:
rt = self.task.render_template # shortcut to method
rendered_content = rt('params', task.params, context)
params.update(rendered_content)
else:
params.update(task.params)

return context

def render_templates(self):
task = self.task
jinja_context = self.get_template_context()
Expand Down Expand Up @@ -1712,6 +1721,9 @@ class derived from this one results in the creation of a task object,
:param on_success_callback: much like the ``on_failure_callback`` excepts
that it is executed when the task succeeds.
:type on_success_callback: callable
:param render_params: set this to true to allow params to be rendered
and available to other templates
:type render_params: bool
:param trigger_rule: defines the rule by which dependencies are applied
for the task to get triggered. Options are:
``{ all_success | all_failed | all_done | one_success |
Expand Down Expand Up @@ -1757,6 +1769,7 @@ def __init__(
on_failure_callback=None,
on_success_callback=None,
on_retry_callback=None,
render_params=False,
trigger_rule=TriggerRule.ALL_SUCCESS,
*args,
**kwargs):
Expand Down Expand Up @@ -1790,6 +1803,7 @@ def __init__(
.format(all_triggers=TriggerRule.all_triggers,
d=dag.dag_id, t=task_id, tr = trigger_rule))

self.render_params = render_params
self.trigger_rule = trigger_rule
self.depends_on_past = depends_on_past
self.wait_for_downstream = wait_for_downstream
Expand Down Expand Up @@ -2048,11 +2062,7 @@ def render_template_from_field(self, attr, content, context, jinja_env):
k: rt("{}[{}]".format(attr, k), v, context)
for k, v in list(content.items())}
else:
param_type = type(content)
msg = (
"Type '{param_type}' used for parameter '{attr}' is "
"not supported for templating").format(**locals())
raise AirflowException(msg)
result = content
return result

def render_template(self, attr, content, context):
Expand Down
68 changes: 68 additions & 0 deletions tests/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -435,6 +435,74 @@ def test_py_op(templates_dict, ds, **kwargs):
dag=self.dag)
t.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, force=True)


def test_template_context_simple(self):
TI = models.TaskInstance
dag = models.DAG('test-dag', start_date=DEFAULT_DATE)
task = operators.DummyOperator(task_id='task', owner='unittest', dag=dag)
ti = TI(task, execution_date=DEFAULT_DATE)
context = ti.get_template_context()

expected = {
'tomorrow_ds': '2015-01-02',
'task_instance_key_str': 'test-dag__task__20150101',
'test_mode': False,
'params': {}
}
self.assertDictContainsSubset(expected, context)

def test_template_context_with_static_params(self):
TI = models.TaskInstance
dag = models.DAG('test-dag', start_date=DEFAULT_DATE,
params={'foo': 'dag', 'bar': 'dag'})
task = operators.DummyOperator(task_id='task', owner='unittest', dag=dag,
params={'foo': 'task', 'boolean': True})
ti = TI(task, execution_date=DEFAULT_DATE)
context = ti.get_template_context()

expected = {
'tomorrow_ds': '2015-01-02',
'task_instance_key_str': 'test-dag__task__20150101',
'test_mode': False,
'params': {'foo': 'task', 'bar': 'dag', 'boolean': True}
}
self.assertDictContainsSubset(expected, context)

def test_template_context_with_dynamic_params(self):
TI = models.TaskInstance
dag = models.DAG('test-dag', start_date=DEFAULT_DATE,
params={'foo': 'dag', 'bar': 'dag'})
task = operators.DummyOperator(task_id='task', owner='unittest', dag=dag,
params={'foo': '{{ ds }}'})
ti = TI(task, execution_date=DEFAULT_DATE)
context = ti.get_template_context()

expected = {
'tomorrow_ds': '2015-01-02',
'task_instance_key_str': 'test-dag__task__20150101',
'test_mode': False,
'params': {'foo': '{{ ds }}', 'bar': 'dag'}
}
self.assertDictContainsSubset(expected, context)

def test_template_context_with_dynamic_params_and_render_params(self):
TI = models.TaskInstance
dag = models.DAG('test-dag', start_date=DEFAULT_DATE,
params={'foo': 'dag', 'bar': 'dag'})
task = operators.DummyOperator(task_id='task', owner='unittest', dag=dag,
render_params=True,
params={'foo': '{{ ds }}'})
ti = TI(task, execution_date=DEFAULT_DATE)
context = ti.get_template_context()

expected = {
'tomorrow_ds': '2015-01-02',
'task_instance_key_str': 'test-dag__task__20150101',
'test_mode': False,
'params': {'foo': '2015-01-01', 'bar': 'dag'}
}
self.assertDictContainsSubset(expected, context)

def test_complex_template(self):
class OperatorSubclass(operators.BaseOperator):
template_fields = ['some_templated_field']
Expand Down