Skip to content

Commit

Permalink
Initial changes for issue apache#39092
Browse files Browse the repository at this point in the history
  • Loading branch information
karenbraganz committed May 2, 2024
1 parent 6112745 commit 01134e3
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 25 deletions.
11 changes: 10 additions & 1 deletion airflow/models/dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -2112,12 +2112,21 @@ def set_task_instance_state(

task = self.get_task(task_id)
task.dag = self
dagrun = self.fetch_dagrun(dag_id=self.dag_id, run_id=run_id, session=session)
tasks_to_set_state = []

tasks_to_set_state: list[Operator | tuple[Operator, int]]
if map_indexes is None:
tasks_to_set_state = [task]
else:
tasks_to_set_state = [(task, map_index) for map_index in map_indexes]
for map_index in map_indexes:
tasks_to_set_state.append((task, map_index))

ti = dagrun.get_task_instance(task_id=task_id, session=session, map_index=map_index)
context = ti.get_template_context(session=session)
jinja_env = self.get_template_env()
ti.render_map_index(context, jinja_env=jinja_env)
print(f"Rendered Map Index: {ti.rendered_map_index}")

altered = set_state(
tasks=tasks_to_set_state,
Expand Down
41 changes: 17 additions & 24 deletions airflow/models/taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -1250,9 +1250,8 @@ def _log_state(*, task_instance: TaskInstance | TaskInstancePydantic, lead_msg:
str(task_instance.state).upper(),
task_instance.dag_id,
task_instance.task_id,
task_instance.run_id,
]
message = "%sMarking task as %s. dag_id=%s, task_id=%s, run_id=%s, "
message = "%sMarking task as %s. dag_id=%s, task_id=%s, "
if task_instance.map_index >= 0:
params.append(task_instance.map_index)
message += "map_index=%d, "
Expand Down Expand Up @@ -1787,17 +1786,15 @@ def generate_command(
@property
def log_url(self) -> str:
"""Log URL for TaskInstance."""
run_id = quote(self.run_id)
iso = quote(self.execution_date.isoformat())
base_url = conf.get_mandatory_value("webserver", "BASE_URL")
return (
f"{base_url}"
f"/dags"
f"/{self.dag_id}"
f"/grid"
f"?dag_run_id={run_id}"
"/log"
f"?execution_date={iso}"
f"&task_id={self.task_id}"
f"&dag_id={self.dag_id}"
f"&map_index={self.map_index}"
"&tab=logs"
)

@property
Expand Down Expand Up @@ -1857,7 +1854,7 @@ def get_task_instance(
) -> TaskInstance | TaskInstancePydantic | None:
query = (
session.query(TaskInstance)
.options(lazyload(TaskInstance.dag_run)) # lazy load dag run to avoid locking it
.options(lazyload("dag_run")) # lazy load dag run to avoid locking it
.filter_by(
dag_id=dag_id,
run_id=run_id,
Expand Down Expand Up @@ -2559,10 +2556,9 @@ def _run_raw_task(
raise
self.defer_task(defer=defer, session=session)
self.log.info(
"Pausing task as DEFERRED. dag_id=%s, task_id=%s, run_id=%s, execution_date=%s, start_date=%s",
"Pausing task as DEFERRED. dag_id=%s, task_id=%s, execution_date=%s, start_date=%s",
self.dag_id,
self.task_id,
self.run_id,
_date_or_empty(task_instance=self, attr="execution_date"),
_date_or_empty(task_instance=self, attr="start_date"),
)
Expand Down Expand Up @@ -2733,26 +2729,17 @@ def signal_handler(signum, frame):
get_listener_manager().hook.on_task_instance_running(
previous_state=TaskInstanceState.QUEUED, task_instance=self, session=session
)

def _render_map_index(context: Context, *, jinja_env: jinja2.Environment | None) -> str | None:
"""Render named map index if the DAG author defined map_index_template at the task level."""
if jinja_env is None or (template := context.get("map_index_template")) is None:
return None
rendered_map_index = jinja_env.from_string(template).render(context)
log.debug("Map index rendered as %s", rendered_map_index)
return rendered_map_index

# Execute the task.
with set_current_context(context):
try:
result = self._execute_task(context, task_orig)
except Exception:
# If the task failed, swallow rendering error so it doesn't mask the main error.
with contextlib.suppress(jinja2.TemplateSyntaxError, jinja2.UndefinedError):
self.rendered_map_index = _render_map_index(context, jinja_env=jinja_env)
self.render_map_index(context, jinja_env=jinja_env)
raise
else: # If the task succeeded, render normally to let rendering error bubble up.
self.rendered_map_index = _render_map_index(context, jinja_env=jinja_env)
self.render_map_index(context, jinja_env=jinja_env)

# Run post_execute callback
self.task.post_execute(context=context, result=result)
Expand All @@ -2761,6 +2748,13 @@ def _render_map_index(context: Context, *, jinja_env: jinja2.Environment | None)
# Same metric with tagging
Stats.incr("operator_successes", tags={**self.stats_tags, "task_type": self.task.task_type})
Stats.incr("ti_successes", tags=self.stats_tags)

def render_map_index(self, context: Context, *, jinja_env: jinja2.Environment | None) -> str | None:
"""Render named map index if the DAG author defined map_index_template at the task level."""
if jinja_env is None or (template := context.get("map_index_template")) is None:
return None
self.rendered_map_index = jinja_env.from_string(template).render(context)
self.log.debug("Map index rendered as %s", self.rendered_map_index)

def _execute_task(self, context: Context, task_orig: Operator):
"""
Expand Down Expand Up @@ -2911,8 +2905,7 @@ def _handle_reschedule(
# Log reschedule request
session.add(
TaskReschedule(
self.task_id,
self.dag_id,
self.task,
self.run_id,
self._try_number,
actual_start_date,
Expand Down

0 comments on commit 01134e3

Please sign in to comment.