Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions airflow-core/src/airflow/callbacks/callback_requests.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,11 +77,19 @@ def is_failure_callback(self) -> bool:
}


class DagRunContext(BaseModel):
    """
    Context info passed from the server to build an execution context object.

    Carries the DagRun and its last TaskInstance so DAG-level callbacks can be
    given a template context resembling what a running task would see.
    """

    # The run the callback relates to; None when no run context is available.
    dag_run: ti_datamodel.DagRun | None = None
    # Last task instance of the run; used to build the callback's template context.
    last_ti: ti_datamodel.TaskInstance | None = None

class DagCallbackRequest(BaseCallbackRequest):
    """A Class with information about the success/failure DAG callback to be executed."""

    # Identify which DAG and which run the callback belongs to.
    dag_id: str
    run_id: str
    # Optional server-built context (dag run + last TI) used to construct a richer
    # execution context for the callback; when None, callers fall back to a minimal context.
    context_from_server: DagRunContext | None = None
    is_failure_callback: bool | None = True
    """Flag to determine whether it is a Failure Callback or Success Callback"""
    # Discriminator used for (de)serializing the callback-request union.
    type: Literal["DagCallbackRequest"] = "DagCallbackRequest"
Expand Down
60 changes: 52 additions & 8 deletions airflow-core/src/airflow/dag_processing/processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,12 @@
DeleteVariable,
ErrorResponse,
GetConnection,
GetPreviousDagRun,
GetPrevSuccessfulDagRun,
GetVariable,
OKResponse,
PreviousDagRunResult,
PrevSuccessfulDagRunResult,
PutVariable,
VariableResult,
)
Expand Down Expand Up @@ -94,12 +98,24 @@ class DagFileParsingResult(BaseModel):


ToManager = Annotated[
DagFileParsingResult | GetConnection | GetVariable | PutVariable | DeleteVariable,
DagFileParsingResult
| GetConnection
| GetVariable
| PutVariable
| DeleteVariable
| GetPrevSuccessfulDagRun
| GetPreviousDagRun,
Field(discriminator="type"),
]

ToDagProcessor = Annotated[
DagFileParseRequest | ConnectionResult | VariableResult | ErrorResponse | OKResponse,
DagFileParseRequest
| ConnectionResult
| VariableResult
| PreviousDagRunResult
| PrevSuccessfulDagRunResult
| ErrorResponse
| OKResponse,
Field(discriminator="type"),
]

Expand Down Expand Up @@ -209,6 +225,8 @@ def _execute_callbacks(


def _execute_dag_callbacks(dagbag: DagBag, request: DagCallbackRequest, log: FilteringBoundLogger) -> None:
from airflow.sdk.api.datamodels._generated import TIRunContext

dag = dagbag.dags[request.dag_id]

callbacks = dag.on_failure_callback if request.is_failure_callback else dag.on_success_callback
Expand All @@ -217,12 +235,27 @@ def _execute_dag_callbacks(dagbag: DagBag, request: DagCallbackRequest, log: Fil
return

callbacks = callbacks if isinstance(callbacks, list) else [callbacks]
# TODO:We need a proper context object!
context: Context = {
"dag": dag,
"run_id": request.run_id,
"reason": request.msg,
}
ctx_from_server = request.context_from_server

if ctx_from_server is not None and ctx_from_server.last_ti is not None:
task = dag.get_task(ctx_from_server.last_ti.task_id)

runtime_ti = RuntimeTaskInstance.model_construct(
**ctx_from_server.last_ti.model_dump(exclude_unset=True),
task=task,
_ti_context_from_server=TIRunContext.model_construct(
dag_run=ctx_from_server.dag_run,
max_tries=task.retries,
),
)
context = runtime_ti.get_template_context()
context["reason"] = request.msg
else:
context: Context = { # type: ignore[no-redef]
"dag": dag,
"run_id": request.run_id,
"reason": request.msg,
}

for callback in callbacks:
log.info(
Expand Down Expand Up @@ -383,6 +416,17 @@ def _handle_request(self, msg: ToManager, log: FilteringBoundLogger, req_id: int
self.client.variables.set(msg.key, msg.value, msg.description)
elif isinstance(msg, DeleteVariable):
resp = self.client.variables.delete(msg.key)
elif isinstance(msg, GetPreviousDagRun):
resp = self.client.dag_runs.get_previous(
dag_id=msg.dag_id,
logical_date=msg.logical_date,
state=msg.state,
)
elif isinstance(msg, GetPrevSuccessfulDagRun):
dagrun_resp = self.client.task_instances.get_previous_successful_dagrun(self.id)
dagrun_result = PrevSuccessfulDagRunResult.from_dagrun_response(dagrun_resp)
resp = dagrun_result
dump_opts = {"exclude_unset": True}
else:
log.error("Unhandled request", msg=msg)
self.send_msg(
Expand Down
6 changes: 5 additions & 1 deletion airflow-core/src/airflow/jobs/scheduler_job_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
from airflow import settings
from airflow._shared.timezones import timezone
from airflow.api_fastapi.execution_api.datamodels.taskinstance import TIRunContext
from airflow.callbacks.callback_requests import DagCallbackRequest, TaskCallbackRequest
from airflow.callbacks.callback_requests import DagCallbackRequest, DagRunContext, TaskCallbackRequest
from airflow.configuration import conf
from airflow.dag_processing.bundles.base import BundleUsageTrackingManager
from airflow.executors import workloads
Expand Down Expand Up @@ -1854,6 +1854,10 @@ def _schedule_dag_run(
run_id=dag_run.run_id,
bundle_name=dag_model.bundle_name,
bundle_version=dag_run.bundle_version,
context_from_server=DagRunContext(
dag_run=dag_run,
last_ti=dag_run.get_last_ti(dag=dag, session=session),
),
is_failure_callback=True,
msg="timed_out",
)
Expand Down
85 changes: 78 additions & 7 deletions airflow-core/src/airflow/models/dagrun.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@
from sqlalchemy_utils import UUIDType

from airflow._shared.timezones import timezone
from airflow.callbacks.callback_requests import DagCallbackRequest
from airflow.callbacks.callback_requests import DagCallbackRequest, DagRunContext
from airflow.configuration import conf as airflow_conf
from airflow.exceptions import AirflowException, TaskNotFound
from airflow.listeners.listener import get_listener_manager
Expand Down Expand Up @@ -102,7 +102,7 @@
from airflow.models.dag import DAG
from airflow.models.dag_version import DagVersion
from airflow.models.taskinstancekey import TaskInstanceKey
from airflow.sdk import DAG as SDKDAG, Context
from airflow.sdk import DAG as SDKDAG
from airflow.sdk.types import Operator
from airflow.serialization.serialized_objects import SerializedBaseOperator as BaseOperator
from airflow.utils.types import ArgNotSet
Expand Down Expand Up @@ -1186,6 +1186,10 @@ def recalculate(self) -> _UnfinishedStates:
run_id=self.run_id,
bundle_name=self.dag_model.bundle_name,
bundle_version=self.bundle_version,
context_from_server=DagRunContext(
dag_run=self,
last_ti=self.get_last_ti(dag=dag, session=session),
),
is_failure_callback=True,
msg="task_failure",
)
Expand Down Expand Up @@ -1215,6 +1219,10 @@ def recalculate(self) -> _UnfinishedStates:
run_id=self.run_id,
bundle_name=self.dag_model.bundle_name,
bundle_version=self.bundle_version,
context_from_server=DagRunContext(
dag_run=self,
last_ti=self.get_last_ti(dag=dag, session=session),
),
is_failure_callback=False,
msg="success",
)
Expand All @@ -1238,6 +1246,10 @@ def recalculate(self) -> _UnfinishedStates:
run_id=self.run_id,
bundle_name=self.dag_model.bundle_name,
bundle_version=self.bundle_version,
context_from_server=DagRunContext(
dag_run=self,
last_ti=self.get_last_ti(dag=dag, session=session),
),
is_failure_callback=True,
msg="all_tasks_deadlocked",
)
Expand Down Expand Up @@ -1350,13 +1362,72 @@ def notify_dagrun_state_changed(self, msg: str = ""):
# we can't get all the state changes on SchedulerJob,
# or LocalTaskJob, so we don't want to "falsely advertise" we notify about that

@provide_session
def get_last_ti(self, dag: DAG, session: Session = NEW_SESSION) -> TI | None:
    """
    Get the last TaskInstance of this DagRun, or None if none qualifies.

    The result is used to build and pass an execution context object from the
    server so DAG-level callbacks can run with task-like template context.

    :param dag: the (possibly partial) DAG this run belongs to
    :param session: database session used to load the run's task instances
    :return: the last eligible TaskInstance, or None when the run has none
    """
    tis = self.get_task_instances(session=session)
    # TIs from a dagrun may not be a part of dag.partial_subset, since
    # dag.partial_subset is a subset of the dag. This ensures that we will
    # only use the accessible TI context for the callback.
    if dag.partial:
        tis = [ti for ti in tis if ti.state != State.NONE]
    # Filter out removed tasks: they no longer exist in the DAG structure.
    tis = [ti for ti in tis if ti.state != TaskInstanceState.REMOVED]
    # Last TaskInstance of the DagRun, or None when nothing survived filtering.
    return tis[-1] if tis else None

def handle_dag_callback(self, dag: SDKDAG, success: bool = True, reason: str = "success"):
"""Only needed for `dag.test` where `execute_callbacks=True` is passed to `update_state`."""
context: Context = {
"dag": dag,
"run_id": str(self.run_id),
"reason": reason,
}
from airflow.api_fastapi.execution_api.datamodels.taskinstance import (
DagRun as DRDataModel,
TaskInstance as TIDataModel,
TIRunContext,
)
from airflow.sdk.execution_time.task_runner import RuntimeTaskInstance

last_ti = self.get_last_ti(dag) # type: ignore[arg-type]
if last_ti:
last_ti_model = TIDataModel.model_validate(last_ti, from_attributes=True)
task = dag.get_task(last_ti.task_id)

dag_run_data = DRDataModel(
dag_id=self.dag_id,
run_id=self.run_id,
logical_date=self.logical_date,
data_interval_start=self.data_interval_start,
data_interval_end=self.data_interval_end,
run_after=self.run_after,
start_date=self.start_date,
end_date=self.end_date,
run_type=self.run_type,
state=self.state,
conf=self.conf,
consumed_asset_events=[],
)

runtime_ti = RuntimeTaskInstance.model_construct(
**last_ti_model.model_dump(exclude_unset=True),
task=task,
_ti_context_from_server=TIRunContext(
dag_run=dag_run_data,
max_tries=last_ti.max_tries,
variables=[],
connections=[],
xcom_keys_to_clear=[],
),
max_tries=last_ti.max_tries,
)
context = runtime_ti.get_template_context()
else:
context = {
"dag": dag,
"run_id": self.run_id,
}

context["reason"] = reason

callbacks = dag.on_success_callback if success else dag.on_failure_callback
if not callbacks:
Expand Down
Loading