Skip to content

Commit

Permalink
Fix try_number handling when db isolation enabled (apache#38943)
Browse files Browse the repository at this point in the history
There was an error in the refresh_from_db code, and because of try_number inconsistency, the same run was going into two different log files.  There is some ugliness here, but some ugliness is unavoidable when dealing with try_number as it is right now.
  • Loading branch information
dstandish authored and utkarsharma2 committed Apr 22, 2024
1 parent 32a19c5 commit 6fa08a9
Show file tree
Hide file tree
Showing 5 changed files with 47 additions and 10 deletions.
1 change: 0 additions & 1 deletion airflow/models/dagrun.py
Original file line number Diff line number Diff line change
Expand Up @@ -650,7 +650,6 @@ def get_task_instance(
)

@staticmethod
@internal_api_call
@provide_session
def fetch_task_instance(
dag_id: str,
Expand Down
27 changes: 25 additions & 2 deletions airflow/models/taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -539,7 +539,7 @@ def _refresh_from_db(
task_instance.end_date = ti.end_date
task_instance.duration = ti.duration
task_instance.state = ti.state
task_instance.try_number = ti._try_number # private attr to get value unaltered by accessor
task_instance.try_number = _get_private_try_number(task_instance=ti)
task_instance.max_tries = ti.max_tries
task_instance.hostname = ti.hostname
task_instance.unixname = ti.unixname
Expand Down Expand Up @@ -925,7 +925,7 @@ def _handle_failure(
TaskInstance.save_to_db(failure_context["ti"], session)


def _get_try_number(*, task_instance: TaskInstance | TaskInstancePydantic):
def _get_try_number(*, task_instance: TaskInstance):
"""
Return the try number that a task number will be when it is actually run.
Expand All @@ -943,6 +943,23 @@ def _get_try_number(*, task_instance: TaskInstance | TaskInstancePydantic):
return task_instance._try_number + 1


def _get_private_try_number(*, task_instance: TaskInstance | TaskInstancePydantic):
"""
Opposite of _get_try_number.
Given the value returned by try_number, return the value of _try_number that
should produce the same result.
This is needed for setting _try_number on TaskInstance from the value on PydanticTaskInstance, which has no private attrs.
:param task_instance: the task instance
:meta private:
"""
if task_instance.state == TaskInstanceState.RUNNING:
return task_instance.try_number
return task_instance.try_number - 1


def _set_try_number(*, task_instance: TaskInstance | TaskInstancePydantic, value: int) -> None:
"""
Set a task try number.
Expand Down Expand Up @@ -3000,6 +3017,12 @@ def fetch_handle_failure_context(
_stop_remaining_tasks(task_instance=ti, session=session)
else:
if ti.state == TaskInstanceState.QUEUED:
from airflow.serialization.pydantic.taskinstance import TaskInstancePydantic

if isinstance(ti, TaskInstancePydantic):
# todo: (AIP-44) we should probably "coalesce" `ti` to TaskInstance before here
# e.g. we could make refresh_from_db return a TI and replace ti with that
raise RuntimeError("Expected TaskInstance here. Further AIP-44 work required.")
# We increase the try_number to fail the task if it fails to start after sometime
ti._try_number += 1
ti.state = State.UP_FOR_RETRY
Expand Down
1 change: 0 additions & 1 deletion airflow/serialization/pydantic/taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,6 @@ class TaskInstancePydantic(BaseModelPydantic, LoggingMixin):
duration: Optional[float]
state: Optional[str]
try_number: int
_try_number: int
max_tries: int
hostname: str
unixname: str
Expand Down
13 changes: 8 additions & 5 deletions airflow/utils/log/file_task_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,8 @@ def _ensure_ti(ti: TaskInstanceKey | TaskInstance | TaskInstancePydantic, sessio
Will raise exception if no TI is found in the database.
"""
from airflow.models.taskinstance import TaskInstance
from airflow.models.taskinstance import TaskInstance, _get_private_try_number
from airflow.serialization.pydantic.taskinstance import TaskInstancePydantic

if isinstance(ti, TaskInstance):
return ti
Expand All @@ -159,11 +160,13 @@ def _ensure_ti(ti: TaskInstanceKey | TaskInstance | TaskInstancePydantic, sessio
)
.one_or_none()
)
if isinstance(val, TaskInstance):
val._try_number = ti.try_number
return val
else:
if not val:
raise AirflowException(f"Could not find TaskInstance for {ti}")
if isinstance(ti, TaskInstancePydantic):
val.try_number = _get_private_try_number(task_instance=ti)
else: # TaskInstanceKey
val.try_number = ti.try_number
return val


class FileTaskHandler(logging.Handler):
Expand Down
15 changes: 14 additions & 1 deletion tests/models/test_taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
from traceback import format_exception
from typing import cast
from unittest import mock
from unittest.mock import call, mock_open, patch
from unittest.mock import MagicMock, call, mock_open, patch
from uuid import uuid4

import pendulum
Expand Down Expand Up @@ -66,6 +66,8 @@
TaskInstance,
TaskInstance as TI,
TaskInstanceNote,
_get_private_try_number,
_get_try_number,
_run_finished_callback,
)
from airflow.models.taskmap import TaskMap
Expand Down Expand Up @@ -4636,3 +4638,14 @@ def test__refresh_from_db_should_not_increment_try_number(dag_maker, session):
assert ti.try_number == 1 # stays 1
ti.refresh_from_db()
assert ti.try_number == 1 # stays 1


@pytest.mark.parametrize("state", list(TaskInstanceState))
def test_get_private_try_number(state: str):
mock_ti = MagicMock()
mock_ti.state = state
private_try_number = 2
mock_ti._try_number = private_try_number
mock_ti.try_number = _get_try_number(task_instance=mock_ti)
delattr(mock_ti, "_try_number")
assert _get_private_try_number(task_instance=mock_ti) == private_try_number

0 comments on commit 6fa08a9

Please sign in to comment.