-
Notifications
You must be signed in to change notification settings - Fork 16.3k
Make ExternalTaskSensor work with Task SDK
#48651
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -23,13 +23,14 @@ | |
| from uuid import UUID | ||
|
|
||
| from cadwyn import VersionedAPIRouter | ||
| from fastapi import Body, Depends, HTTPException, status | ||
| from fastapi import Body, Depends, HTTPException, Query, status | ||
| from pydantic import JsonValue | ||
| from sqlalchemy import func, tuple_, update | ||
| from sqlalchemy import func, or_, tuple_, update | ||
| from sqlalchemy.exc import NoResultFound, SQLAlchemyError | ||
| from sqlalchemy.sql import select | ||
|
|
||
| from airflow.api_fastapi.common.db.common import SessionDep | ||
| from airflow.api_fastapi.common.types import UtcDateTime | ||
| from airflow.api_fastapi.execution_api.datamodels.taskinstance import ( | ||
| PrevSuccessfulDagRunResponse, | ||
| TIDeferredStatePayload, | ||
|
|
@@ -45,6 +46,7 @@ | |
| TITerminalStatePayload, | ||
| ) | ||
| from airflow.api_fastapi.execution_api.deps import JWTBearer | ||
| from airflow.models.dagbag import DagBag | ||
| from airflow.models.dagrun import DagRun as DR | ||
| from airflow.models.taskinstance import TaskInstance as TI, _update_rtif | ||
| from airflow.models.taskreschedule import TaskReschedule | ||
|
|
@@ -53,7 +55,9 @@ | |
| from airflow.utils import timezone | ||
| from airflow.utils.state import DagRunState, TaskInstanceState | ||
|
|
||
| router = VersionedAPIRouter( | ||
| router = VersionedAPIRouter() | ||
|
|
||
| ti_id_router = VersionedAPIRouter( | ||
|
Comment on lines
+58
to
+60
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can we add a comment here explaining why we need this one? It will be clear to reader too |
||
| dependencies=[ | ||
| # This checks that the UUID in the url matches the one in the token for us. | ||
| Depends(JWTBearer(path_param_name="task_instance_id")), | ||
|
|
@@ -64,7 +68,7 @@ | |
| log = logging.getLogger(__name__) | ||
|
|
||
|
|
||
| @router.patch( | ||
| @ti_id_router.patch( | ||
| "/{task_instance_id}/run", | ||
| status_code=status.HTTP_200_OK, | ||
| responses={ | ||
|
|
@@ -243,7 +247,7 @@ def ti_run( | |
| ) | ||
|
|
||
|
|
||
| @router.patch( | ||
| @ti_id_router.patch( | ||
| "/{task_instance_id}/state", | ||
| status_code=status.HTTP_204_NO_CONTENT, | ||
| responses={ | ||
|
|
@@ -404,7 +408,7 @@ def ti_update_state( | |
| ) | ||
|
|
||
|
|
||
| @router.patch( | ||
| @ti_id_router.patch( | ||
| "/{task_instance_id}/skip-downstream", | ||
| status_code=status.HTTP_204_NO_CONTENT, | ||
| responses={ | ||
|
|
@@ -436,7 +440,7 @@ def ti_skip_downstream( | |
| log.info("TI %s updated the state of %s task(s) to skipped", ti_id_str, result.rowcount) | ||
|
|
||
|
|
||
| @router.put( | ||
| @ti_id_router.put( | ||
| "/{task_instance_id}/heartbeat", | ||
| status_code=status.HTTP_204_NO_CONTENT, | ||
| responses={ | ||
|
|
@@ -498,7 +502,7 @@ def ti_heartbeat( | |
| log.debug("Task with %s state heartbeated", previous_state) | ||
|
|
||
|
|
||
| @router.put( | ||
| @ti_id_router.put( | ||
| "/{task_instance_id}/rtif", | ||
| status_code=status.HTTP_201_CREATED, | ||
| # TODO: Add description to the operation | ||
|
|
@@ -528,7 +532,7 @@ def ti_put_rtif( | |
| return {"message": "Rendered task instance fields successfully set"} | ||
|
|
||
|
|
||
| @router.get( | ||
| @ti_id_router.get( | ||
| "/{task_instance_id}/previous-successful-dagrun", | ||
| status_code=status.HTTP_200_OK, | ||
| responses={ | ||
|
|
@@ -564,8 +568,86 @@ def get_previous_successful_dagrun( | |
| return PrevSuccessfulDagRunResponse.model_validate(dag_run) | ||
|
|
||
|
|
||
| @router.only_exists_in_older_versions | ||
| @router.post( | ||
| @router.get("/count", status_code=status.HTTP_200_OK) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We will need to add cadwyn migration for the new endpoints: https://docs.cadwyn.dev/concepts/version_changes/
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't think so. We only need to add a migration for breaking changes (or changes to existing endpoints) from what I understand.
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same as #48651 (comment)
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Check #48651 (comment) :) |
||
| def get_count( | ||
| dag_id: str, | ||
| session: SessionDep, | ||
| task_ids: Annotated[list[str] | None, Query()] = None, | ||
| task_group_id: Annotated[str | None, Query()] = None, | ||
| logical_dates: Annotated[list[UtcDateTime] | None, Query()] = None, | ||
| run_ids: Annotated[list[str] | None, Query()] = None, | ||
| states: Annotated[list[str] | None, Query()] = None, | ||
| ) -> int: | ||
| """Get the count of task instances matching the given criteria.""" | ||
| query = select(func.count()).select_from(TI).where(TI.dag_id == dag_id) | ||
|
|
||
| if task_ids: | ||
| query = query.where(TI.task_id.in_(task_ids)) | ||
|
|
||
| if logical_dates: | ||
| query = query.where(TI.logical_date.in_(logical_dates)) | ||
|
|
||
| if run_ids: | ||
| query = query.where(TI.run_id.in_(run_ids)) | ||
|
|
||
| if task_group_id: | ||
| # Get all tasks in the task group | ||
| dag = DagBag(read_dags_from_db=True).get_dag(dag_id, session) | ||
| if not dag: | ||
| raise HTTPException( | ||
| status.HTTP_404_NOT_FOUND, | ||
| detail={ | ||
| "reason": "not_found", | ||
| "message": f"DAG {dag_id} not found", | ||
| }, | ||
| ) | ||
|
|
||
| task_group = dag.task_group_dict.get(task_group_id) | ||
| if not task_group: | ||
| raise HTTPException( | ||
| status.HTTP_404_NOT_FOUND, | ||
| detail={ | ||
| "reason": "not_found", | ||
| "message": f"Task group {task_group_id} not found in DAG {dag_id}", | ||
| }, | ||
| ) | ||
|
|
||
| # First get all task instances to get the task_id, map_index pairs | ||
| group_tasks = session.scalars( | ||
| select(TI).where( | ||
| TI.dag_id == dag_id, | ||
| TI.task_id.in_(task.task_id for task in task_group.iter_tasks()), | ||
| *([TI.logical_date.in_(logical_dates)] if logical_dates else []), | ||
| *([TI.run_id.in_(run_ids)] if run_ids else []), | ||
| ) | ||
| ).all() | ||
|
|
||
| # Get unique (task_id, map_index) pairs | ||
| task_map_pairs = [(ti.task_id, ti.map_index) for ti in group_tasks] | ||
| if not task_map_pairs: | ||
| # If no task group tasks found, default to checking the task group ID itself | ||
| # This matches the behavior in _get_external_task_group_task_ids | ||
| task_map_pairs = [(task_group_id, -1)] | ||
|
|
||
| # Update query to use task_id, map_index pairs | ||
| query = query.where(tuple_(TI.task_id, TI.map_index).in_(task_map_pairs)) | ||
|
|
||
| if states: | ||
| if "null" in states: | ||
| not_none_states = [s for s in states if s != "null"] | ||
| if not_none_states: | ||
| query = query.where(or_(TI.state.is_(None), TI.state.in_(not_none_states))) | ||
| else: | ||
| query = query.where(TI.state.is_(None)) | ||
| else: | ||
| query = query.where(TI.state.in_(states)) | ||
|
|
||
| count = session.scalar(query) | ||
| return count or 0 | ||
|
|
||
|
|
||
| @ti_id_router.only_exists_in_older_versions | ||
| @ti_id_router.post( | ||
| "/{task_instance_id}/runtime-checks", | ||
| status_code=status.HTTP_204_NO_CONTENT, | ||
| # TODO: Add description to the operation | ||
|
|
@@ -602,3 +684,7 @@ def _is_eligible_to_retry(state: str, try_number: int, max_tries: int) -> bool: | |
| # max_tries is initialised with the retries defined at task level, we do not need to explicitly ask for | ||
| # retries from the task SDK now, we can handle using max_tries | ||
| return max_tries != 0 and try_number <= max_tries | ||
|
|
||
|
|
||
| # This line should be at the end of the file to ensure all routes are registered | ||
| router.include_router(ti_id_router) | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We will need to add cadwyn migration for the new endpoints: https://docs.cadwyn.dev/concepts/version_changes/
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think so. We only need to add a migration for breaking changes (or changes to existing endpoints) from what I understand.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just checked it here: https://docs.cadwyn.dev/concepts/endpoint_migrations/#defining-endpoints-that-didnt-exist-in-old-versions. Seems we will need it. Let me take it up
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I just had a chat with the Cadwyn Author ( @zmievsa ) who recommends to only add it for breaking changes.
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'll make Cadwyn's docs more verbose on when it makes the most sense to add a migration. Concepts section mostly focuses on what's possible with Cadwyn while "how to" focuses on what you should actually do.
Either way 99% of the time it makes sense to add an endpoint to all versions since it's not a breaking change. Your users will thank you later
Update: https://docs.cadwyn.dev/concepts/endpoint_migrations/#defining-endpoints-that-didnt-exist-in-old-versions added a bunch of notes here and there about this.