Skip to content

Commit

Permalink
Minimize diff and fix scalar usage
Browse files Browse the repository at this point in the history
  • Loading branch information
uranusjr committed Oct 23, 2023
1 parent 57a5254 commit cc32607
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 19 deletions.
28 changes: 12 additions & 16 deletions airflow/models/xcom_arg.py
Original file line number Diff line number Diff line change
Expand Up @@ -385,25 +385,21 @@ def get_task_map_length(self, run_id: str, *, session: Session) -> int | None:
)
if unfinished_ti_exists:
return None # Not all of the expanded tis are done yet.
query = session.scalar(
select(func.count(XCom.map_index)).where(
XCom.dag_id == task.dag_id,
XCom.run_id == run_id,
XCom.task_id == task.task_id,
XCom.map_index >= 0,
XCom.key == XCOM_RETURN_KEY,
)
query = select(func.count(XCom.map_index)).where(
XCom.dag_id == task.dag_id,
XCom.run_id == run_id,
XCom.task_id == task.task_id,
XCom.map_index >= 0,
XCom.key == XCOM_RETURN_KEY,
)
else:
query = session.scalar(
select(TaskMap.length).where(
TaskMap.dag_id == task.dag_id,
TaskMap.run_id == run_id,
TaskMap.task_id == task.task_id,
TaskMap.map_index < 0,
)
query = select(TaskMap.length).where(
TaskMap.dag_id == task.dag_id,
TaskMap.run_id == run_id,
TaskMap.task_id == task.task_id,
TaskMap.map_index < 0,
)
return query
return session.scalar(query)

@provide_session
def resolve(self, context: Context, session: Session = NEW_SESSION) -> Any:
Expand Down
5 changes: 2 additions & 3 deletions airflow/operators/trigger_dagrun.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,12 +228,11 @@ def execute_complete(self, context: Context, session: Session, event: tuple[str,
# This execution date is parsed from the return trigger event
provided_execution_date = event[1]["execution_dates"][0]
try:
dag_run = session.scalar(
dag_run = session.execute(
select(DagRun).where(
DagRun.dag_id == self.trigger_dag_id, DagRun.execution_date == provided_execution_date
)
)

).scalar_one()
except NoResultFound:
raise AirflowException(
f"No DAG run found for DAG {self.trigger_dag_id} and execution date {self.execution_date}"
Expand Down

0 comments on commit cc32607

Please sign in to comment.