diff --git a/airflow-core/src/airflow/api_fastapi/core_api/routes/public/log.py b/airflow-core/src/airflow/api_fastapi/core_api/routes/public/log.py
index 282e7b8f945a3..01cf859f05efd 100644
--- a/airflow-core/src/airflow/api_fastapi/core_api/routes/public/log.py
+++ b/airflow-core/src/airflow/api_fastapi/core_api/routes/public/log.py
@@ -112,6 +112,7 @@ def get_log(
             TaskInstance.dag_id == dag_id,
             TaskInstance.run_id == dag_run_id,
             TaskInstance.map_index == map_index,
+            TaskInstance.try_number == try_number,
         )
         .join(TaskInstance.dag_run)
         .options(joinedload(TaskInstance.trigger).joinedload(Trigger.triggerer_job))
diff --git a/airflow-core/src/airflow/models/taskinstancehistory.py b/airflow-core/src/airflow/models/taskinstancehistory.py
index 7fdd55178e55a..cbd83643fdb8c 100644
--- a/airflow-core/src/airflow/models/taskinstancehistory.py
+++ b/airflow-core/src/airflow/models/taskinstancehistory.py
@@ -52,6 +52,7 @@
 if TYPE_CHECKING:
     from sqlalchemy.orm.session import Session
 
+    from airflow.models import DagRun
     from airflow.models.taskinstance import TaskInstance
 
 
@@ -161,6 +162,11 @@ def __init__(
         Index("idx_tih_dag_run", dag_id, run_id),
     )
 
+    @property
+    def id(self) -> str:
+        """Alias for the primary key field, so TaskInstanceHistory can stand in for TaskInstance."""
+        return self.task_instance_id
+
     @staticmethod
     @provide_session
     def record_ti(ti: TaskInstance, session: Session = NEW_SESSION) -> None:
@@ -183,3 +189,8 @@ def record_ti(ti: TaskInstance, session: Session = NEW_SESSION) -> None:
         ti.set_duration()
         ti_history = TaskInstanceHistory(ti, state=ti_history_state)
         session.add(ti_history)
+
+    @provide_session
+    def get_dagrun(self, session: Session = NEW_SESSION) -> DagRun:
+        """Return the DagRun for this TaskInstanceHistory, mirroring TaskInstance.get_dagrun."""
+        return self.dag_run
diff --git a/airflow-core/src/airflow/utils/log/file_task_handler.py b/airflow-core/src/airflow/utils/log/file_task_handler.py
index c39792baa81c9..15339c00cbb69 100644
--- a/airflow-core/src/airflow/utils/log/file_task_handler.py
+++ b/airflow-core/src/airflow/utils/log/file_task_handler.py
@@ -34,7 +34,6 @@
 from pydantic import BaseModel, ConfigDict, ValidationError
 
 from airflow.configuration import conf
-from airflow.exceptions import AirflowException
 from airflow.executors.executor_loader import ExecutorLoader
 from airflow.utils.helpers import parse_template_string, render_template
 from airflow.utils.log.logging_mixin import SetContextPropagate
@@ -45,7 +44,7 @@
 if TYPE_CHECKING:
     from airflow.executors.base_executor import BaseExecutor
     from airflow.models.taskinstance import TaskInstance
-    from airflow.models.taskinstancekey import TaskInstanceKey
+    from airflow.models.taskinstancehistory import TaskInstanceHistory
     from airflow.typing_compat import TypeAlias
 
 
@@ -180,32 +179,6 @@ def _interleave_logs(*logs: str | LogMessages) -> Iterable[StructuredLogMessage]
         last = msg
 
 
-def _ensure_ti(ti: TaskInstanceKey | TaskInstance, session) -> TaskInstance:
-    """
-    Given TI | TIKey, return a TI object.
-
-    Will raise exception if no TI is found in the database.
- """ - from airflow.models.taskinstance import TaskInstance - - if isinstance(ti, TaskInstance): - return ti - val = ( - session.query(TaskInstance) - .filter( - TaskInstance.task_id == ti.task_id, - TaskInstance.dag_id == ti.dag_id, - TaskInstance.run_id == ti.run_id, - TaskInstance.map_index == ti.map_index, - ) - .one_or_none() - ) - if not val: - raise AirflowException(f"Could not find TaskInstance for {ti}") - val.try_number = ti.try_number - return val - - class FileTaskHandler(logging.Handler): """ FileTaskHandler is a python log handler that handles and reads task instance logs. @@ -253,7 +226,9 @@ def __init__( Some handlers emit "end of log" markers, and may not wish to do so when task defers. """ - def set_context(self, ti: TaskInstance, *, identifier: str | None = None) -> None | SetContextPropagate: + def set_context( + self, ti: TaskInstance | TaskInstanceHistory, *, identifier: str | None = None + ) -> None | SetContextPropagate: """ Provide task_instance context to airflow task handler. @@ -309,9 +284,10 @@ def close(self): self.handler.close() @provide_session - def _render_filename(self, ti: TaskInstance, try_number: int, session=NEW_SESSION) -> str: + def _render_filename( + self, ti: TaskInstance | TaskInstanceHistory, try_number: int, session=NEW_SESSION + ) -> str: """Return the worker log filename.""" - ti = _ensure_ti(ti, session) dag_run = ti.get_dagrun(session=session) date = dag_run.logical_date or dag_run.run_after @@ -344,8 +320,8 @@ def _render_filename(self, ti: TaskInstance, try_number: int, session=NEW_SESSIO raise RuntimeError(f"Unable to render log filename for {ti}. This should never happen") def _get_executor_get_task_log( - self, ti: TaskInstance - ) -> Callable[[TaskInstance, int], tuple[list[str], list[str]]]: + self, ti: TaskInstance | TaskInstanceHistory + ) -> Callable[[TaskInstance | TaskInstanceHistory, int], tuple[list[str], list[str]]]: """ Get the get_task_log method from executor of current task instance. 
@@ -367,7 +343,7 @@
 
     def _read(
         self,
-        ti: TaskInstance,
+        ti: TaskInstance | TaskInstanceHistory,
         try_number: int,
         metadata: dict[str, Any] | None = None,
     ):
@@ -455,7 +431,7 @@ def _read(
         return logs, {"end_of_log": end_of_log, "log_pos": log_pos}
 
     @staticmethod
-    def _get_pod_namespace(ti: TaskInstance):
+    def _get_pod_namespace(ti: TaskInstance | TaskInstanceHistory):
         pod_override = ti.executor_config.get("pod_override")
         namespace = None
         with suppress(Exception):
@@ -463,7 +439,10 @@ def _get_pod_namespace(ti: TaskInstance):
         return namespace or conf.get("kubernetes_executor", "namespace")
 
     def _get_log_retrieval_url(
-        self, ti: TaskInstance, log_relative_path: str, log_type: LogType | None = None
+        self,
+        ti: TaskInstance | TaskInstanceHistory,
+        log_relative_path: str,
+        log_type: LogType | None = None,
     ) -> tuple[str, str]:
         """Given TI, generate URL with which to fetch logs from service log server."""
         if log_type == LogType.TRIGGER:
@@ -487,7 +466,7 @@ def _get_log_retrieval_url(
 
     def read(
         self,
-        task_instance: TaskInstance,
+        task_instance: TaskInstance | TaskInstanceHistory,
         try_number: int | None = None,
         metadata: dict[str, Any] | None = None,
     ) -> tuple[list[StructuredLogMessage] | str, dict[str, Any]]:
diff --git a/airflow-core/src/airflow/utils/log/log_reader.py b/airflow-core/src/airflow/utils/log/log_reader.py
index c00a6877511d0..4a36b3ef1b07e 100644
--- a/airflow-core/src/airflow/utils/log/log_reader.py
+++ b/airflow-core/src/airflow/utils/log/log_reader.py
@@ -33,6 +33,7 @@
     from sqlalchemy.orm.session import Session
 
     from airflow.models.taskinstance import TaskInstance
+    from airflow.models.taskinstancehistory import TaskInstanceHistory
     from airflow.typing_compat import TypeAlias
 
 LogMessages: TypeAlias = Union[list[StructuredLogMessage], str]
@@ -48,7 +49,10 @@ class TaskLogReader:
     """Number of empty loop iterations before stopping the stream"""
 
     def read_log_chunks(
-        self, ti: TaskInstance, try_number: int | None, metadata
+        self,
+        ti: TaskInstance | TaskInstanceHistory,
+        try_number: int | None,
+        metadata: LogMetadata,
     ) -> tuple[LogMessages, LogMetadata]:
         """
         Read chunks of Task Instance logs.
@@ -70,7 +74,12 @@ def read_log_chunks(
         """
         return self.log_handler.read(ti, try_number, metadata=metadata)
 
-    def read_log_stream(self, ti: TaskInstance, try_number: int | None, metadata: dict) -> Iterator[str]:
+    def read_log_stream(
+        self,
+        ti: TaskInstance | TaskInstanceHistory,
+        try_number: int | None,
+        metadata: LogMetadata,
+    ) -> Iterator[str]:
         """
         Continuously read log to the end.
 
@@ -147,7 +156,7 @@ def supports_external_link(self) -> bool:
     @provide_session
     def render_log_filename(
         self,
-        ti: TaskInstance,
+        ti: TaskInstance | TaskInstanceHistory,
         try_number: int | None = None,
         *,
         session: Session = NEW_SESSION,
diff --git a/airflow-core/tests/unit/utils/test_log_handlers.py b/airflow-core/tests/unit/utils/test_log_handlers.py
index 6d06137573973..5813c22819eec 100644
--- a/airflow-core/tests/unit/utils/test_log_handlers.py
+++ b/airflow-core/tests/unit/utils/test_log_handlers.py
@@ -42,6 +42,7 @@
 from airflow.jobs.triggerer_job_runner import TriggererJobRunner
 from airflow.models.dagrun import DagRun
 from airflow.models.taskinstance import TaskInstance
+from airflow.models.taskinstancehistory import TaskInstanceHistory
 from airflow.models.trigger import Trigger
 from airflow.providers.standard.operators.python import PythonOperator
 from airflow.utils.log.file_task_handler import (
@@ -606,6 +607,38 @@ def test_jinja_rendering_catchup_false(self, create_log_template, create_task_in
         rendered_filename = fth._render_filename(filename_rendering_ti, 42)
         assert expected_filename == rendered_filename
 
+    def test_jinja_id_in_template_for_history(
+        self, create_log_template, create_task_instance, logical_date, session
+    ):
+        """Test that a Jinja template using ti.id works for both TaskInstance and TaskInstanceHistory"""
+        create_log_template("{{ ti.id }}.log")
+        ti = create_task_instance(
+            dag_id="dag_history_test",
+            task_id="history_task",
+            run_type=DagRunType.SCHEDULED,
+            logical_date=DEFAULT_DATE,
+            catchup=True,
+        )
+        TaskInstanceHistory.record_ti(ti, session=session)
+        session.flush()
+        tih = (
+            session.query(TaskInstanceHistory)
+            .filter_by(
+                dag_id=ti.dag_id,
+                task_id=ti.task_id,
+                run_id=ti.run_id,
+                map_index=ti.map_index,
+                try_number=ti.try_number,
+            )
+            .one()
+        )
+        fth = FileTaskHandler("")
+        rendered_ti = fth._render_filename(ti, ti.try_number, session=session)
+        rendered_tih = fth._render_filename(tih, ti.try_number, session=session)
+        expected = f"{ti.id}.log"
+        assert rendered_ti == expected
+        assert rendered_tih == expected
+
 
 class TestLogUrl:
     def test_log_retrieval_valid(self, create_task_instance):
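
Taken together, these changes let the log-reading stack accept a TaskInstanceHistory anywhere it previously required a live TaskInstance, and the new try_number filter in log.py means the live row only answers for its current try. A minimal sketch of how a caller could fetch logs for an earlier try under that contract follows; the read_try_logs helper and its history fallback are illustrative assumptions for review, not code this patch adds.

# Illustrative sketch, not part of the patch: resolve the requested try to a
# TaskInstance or TaskInstanceHistory row, then read its logs.
from sqlalchemy import select

from airflow.models.taskinstance import TaskInstance
from airflow.models.taskinstancehistory import TaskInstanceHistory
from airflow.utils.log.log_reader import TaskLogReader


def read_try_logs(session, dag_id, run_id, task_id, map_index, try_number):
    # With the new filter in log.py, the live row matches only its latest try;
    # earlier tries are recorded in TaskInstanceHistory by record_ti().
    filters = {
        "dag_id": dag_id,
        "run_id": run_id,
        "task_id": task_id,
        "map_index": map_index,
        "try_number": try_number,
    }
    ti = session.scalars(select(TaskInstance).filter_by(**filters)).one_or_none()
    if ti is None:
        ti = session.scalars(select(TaskInstanceHistory).filter_by(**filters)).one_or_none()
    if ti is None:
        raise ValueError(f"No try {try_number} recorded for task {task_id}")
    # FileTaskHandler.read() now accepts either type, so TaskLogReader does too.
    logs, metadata = TaskLogReader().read_log_chunks(ti, try_number, metadata={})
    return logs, metadata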