diff --git a/airflow-core/src/airflow/cli/commands/task_command.py b/airflow-core/src/airflow/cli/commands/task_command.py index cf074cd32c244..69859af3f5610 100644 --- a/airflow-core/src/airflow/cli/commands/task_command.py +++ b/airflow-core/src/airflow/cli/commands/task_command.py @@ -381,7 +381,7 @@ def task_test(args, dag: DAG | None = None) -> None: ) try: with redirect_stdout(RedactedIO()): - _run_task(ti=ti) + _run_task(ti=ti, run_triggerer=True) if ti.state == State.FAILED and args.post_mortem: debugger = _guess_debugger() debugger.set_trace() diff --git a/airflow-core/src/airflow/models/taskinstance.py b/airflow-core/src/airflow/models/taskinstance.py index da7d9a801dd54..bc16db4092908 100644 --- a/airflow-core/src/airflow/models/taskinstance.py +++ b/airflow-core/src/airflow/models/taskinstance.py @@ -24,15 +24,12 @@ import math import operator import os -import signal -import traceback from collections import defaultdict -from collections.abc import Collection, Generator, Iterable, Mapping, Sequence +from collections.abc import Collection, Generator, Iterable, Sequence from datetime import timedelta -from enum import Enum from functools import cache from pathlib import Path -from typing import TYPE_CHECKING, Any, Callable +from typing import TYPE_CHECKING, Any from urllib.parse import quote import attrs @@ -40,7 +37,6 @@ import jinja2 import lazy_object_proxy import uuid6 -from jinja2 import TemplateAssertionError, UndefinedError from sqlalchemy import ( Column, Float, @@ -77,32 +73,19 @@ from airflow.assets.manager import asset_manager from airflow.configuration import conf from airflow.exceptions import ( - AirflowException, - AirflowFailException, AirflowInactiveAssetInInletOrOutletException, - AirflowRescheduleException, - AirflowSensorTimeout, - AirflowSkipException, - AirflowTaskTerminated, - AirflowTaskTimeout, TaskDeferralError, TaskDeferred, - UnmappableXComLengthPushed, - UnmappableXComTypePushed, - XComForMappingNotPushed, ) from airflow.listeners.listener import get_listener_manager -from airflow.models.asset import AssetActive, AssetEvent, AssetModel +from airflow.models.asset import AssetEvent, AssetModel from airflow.models.base import Base, StringID, TaskInstanceDependencies from airflow.models.log import Log -from airflow.models.renderedtifields import get_serialized_template_fields from airflow.models.taskinstancekey import TaskInstanceKey from airflow.models.taskmap import TaskMap from airflow.models.taskreschedule import TaskReschedule from airflow.models.xcom import LazyXComSelectSequence, XComModel from airflow.plugins_manager import integrate_macros_plugins -from airflow.sdk.execution_time.context import context_to_airflow_vars -from airflow.sentry import Sentry from airflow.settings import task_instance_mutation_hook from airflow.stats import Stats from airflow.ti_deps.dep_context import DepContext @@ -118,8 +101,6 @@ from airflow.utils.span_status import SpanStatus from airflow.utils.sqlalchemy import ExecutorConfigType, ExtendedJSON, UtcDateTime from airflow.utils.state import DagRunState, State, TaskInstanceState -from airflow.utils.task_instance_session import set_current_task_instance_session -from airflow.utils.timeout import timeout from airflow.utils.xcom import XCOM_RETURN_KEY TR = TaskReschedule @@ -130,7 +111,6 @@ if TYPE_CHECKING: from datetime import datetime from pathlib import PurePath - from types import TracebackType import pendulum from sqlalchemy.engine import Connection as SAConnection, Engine @@ -139,12 +119,11 @@ from sqlalchemy.sql.elements import BooleanClauseList from sqlalchemy.sql.expression import ColumnOperators - from airflow.models.abstractoperator import TaskStateChangeCallback from airflow.models.baseoperator import BaseOperator from airflow.models.dag import DAG as SchedulerDAG, DagModel from airflow.models.dagrun import DagRun from airflow.sdk.api.datamodels._generated import AssetProfile - from airflow.sdk.definitions._internal.abstractoperator import Operator + from airflow.sdk.definitions._internal.abstractoperator import Operator, TaskStateChangeCallback from airflow.sdk.definitions.asset import AssetNameRef, AssetUniqueKey, AssetUriRef from airflow.sdk.definitions.dag import DAG from airflow.sdk.definitions.taskgroup import MappedTaskGroup @@ -157,17 +136,6 @@ PAST_DEPENDS_MET = "past_depends_met" -class TaskReturnCode(Enum): - """ - Enum to signal manner of exit for task run command. - - :meta private: - """ - - DEFERRED = 100 - """When task exits with deferral to trigger.""" - - @provide_session def _add_log( event, @@ -352,46 +320,6 @@ def _creator_note(val): return TaskInstanceNote(*val) -@provide_session -def _record_task_map_for_downstreams( - *, - task_instance: TaskInstance, - task: Operator, - value: Any, - session: Session, -) -> None: - """ - Record the task map for downstream tasks. - - :param task_instance: the task instance - :param task: The task object - :param dag: the dag associated with the task - :param value: The value - :param session: SQLAlchemy ORM Session - - :meta private: - """ - from airflow.sdk.definitions.mappedoperator import MappedOperator, is_mappable_value - - if next(task.iter_mapped_dependants(), None) is None: # No mapped dependants, no need to validate. - return - # TODO: We don't push TaskMap for mapped task instances because it's not - # currently possible for a downstream to depend on one individual mapped - # task instance. This will change when we implement task mapping inside - # a mapped task group, and we'll need to further analyze the case. - if isinstance(task, MappedOperator): - return - if value is None: - raise XComForMappingNotPushed() - if not is_mappable_value(value): - raise UnmappableXComTypePushed(value) - task_map = TaskMap.from_task_instance_xcom(task_instance, value) - max_map_length = conf.getint("core", "max_map_length", fallback=1024) - if task_map.length > max_map_length: - raise UnmappableXComLengthPushed(value, max_map_length) - session.merge(task_map) - - def _get_email_subject_content( *, task_instance: TaskInstance | RuntimeTaskInstanceProtocol, @@ -1128,36 +1056,6 @@ def refresh_from_task(self, task: Operator, pool_override: str | None = None) -> # Re-apply cluster policy here so that task default do not overload previous data task_instance_mutation_hook(self) - @staticmethod - @provide_session - def _clear_xcom_data(ti: TaskInstance, session: Session = NEW_SESSION) -> None: - """ - Clear all XCom data from the database for the task instance. - - If the task is unmapped, all XComs matching this task ID in the same DAG - run are removed. If the task is mapped, only the one with matching map - index is removed. - - :param ti: The TI for which we need to clear xcoms. - :param session: SQLAlchemy ORM Session - """ - ti.log.debug("Clearing XCom data") - if ti.map_index < 0: - map_index: int | None = None - else: - map_index = ti.map_index - XComModel.clear( - dag_id=ti.dag_id, - task_id=ti.task_id, - run_id=ti.run_id, - map_index=map_index, - session=session, - ) - - @provide_session - def clear_xcom_data(self, session: Session = NEW_SESSION): - self._clear_xcom_data(ti=self, session=session) - @property def key(self) -> TaskInstanceKey: """Returns a tuple that identifies the task instance uniquely.""" @@ -1670,184 +1568,25 @@ def clear_next_method_args(self) -> None: self.next_kwargs = None @provide_session - @Sentry.enrich_errors def _run_raw_task( self, mark_success: bool = False, - test_mode: bool = False, - pool: str | None = None, - raise_on_defer: bool = False, session: Session = NEW_SESSION, - ) -> TaskReturnCode | None: - """ - Run a task, update the state upon completion, and run any appropriate callbacks. - - Immediately runs the task (without checking or changing db state - before execution) and then sets the appropriate final state after - completion and runs any post-execute callbacks. Meant to be called - only after another function changes the state to running. - - :param mark_success: Don't run the task, mark its state as success - :param test_mode: Doesn't record success or failure in the DB - :param pool: specifies the pool to use to run the task instance - :param session: SQLAlchemy ORM Session - """ - if TYPE_CHECKING: - assert self.task - - if TYPE_CHECKING: - assert isinstance(self.task, BaseOperator) - - self.test_mode = test_mode - self.refresh_from_task(self.task, pool_override=pool) - self.refresh_from_db(session=session) - self.hostname = get_hostname() - self.pid = os.getpid() - if not test_mode: - TaskInstance.save_to_db(ti=self, session=session) - actual_start_date = timezone.utcnow() - Stats.incr(f"ti.start.{self.task.dag_id}.{self.task.task_id}", tags=self.stats_tags) - # Same metric with tagging - Stats.incr("ti.start", tags=self.stats_tags) - # Initialize final state counters at zero - for state in State.task_states: - Stats.incr( - f"ti.finish.{self.task.dag_id}.{self.task.task_id}.{state}", - count=0, - tags=self.stats_tags, - ) - # Same metric with tagging - Stats.incr( - "ti.finish", - count=0, - tags={**self.stats_tags, "state": str(state)}, - ) - with set_current_task_instance_session(session=session): - self.task = self.task.prepare_for_execution() - context = self.get_template_context(ignore_param_exceptions=False, session=session) - - try: - if self.task: - from airflow.sdk.definitions.asset import Asset - - inlets = [asset.asprofile() for asset in self.task.inlets if isinstance(asset, Asset)] - outlets = [asset.asprofile() for asset in self.task.outlets if isinstance(asset, Asset)] - TaskInstance.validate_inlet_outlet_assets_activeness(inlets, outlets, session=session) - if not mark_success: - TaskInstance._execute_task_with_callbacks( - self=self, # type: ignore[arg-type] - context=context, - test_mode=test_mode, - session=session, - ) - if not test_mode: - self.refresh_from_db(lock_for_update=True, session=session, keep_local_changes=True) - self.state = TaskInstanceState.SUCCESS - except TaskDeferred as defer: - # The task has signalled it wants to defer execution based on - # a trigger. - if raise_on_defer: - raise - self.defer_task(exception=defer, session=session) - self.log.info( - "Pausing task as DEFERRED. dag_id=%s, task_id=%s, run_id=%s, logical_date=%s, start_date=%s", - self.dag_id, - self.task_id, - self.run_id, - _date_or_empty(task_instance=self, attr="logical_date"), - _date_or_empty(task_instance=self, attr="start_date"), - ) - return TaskReturnCode.DEFERRED - except AirflowSkipException as e: - # Recording SKIP - # log only if exception has any arguments to prevent log flooding - if e.args: - self.log.info(e) - if not test_mode: - self.refresh_from_db(lock_for_update=True, session=session, keep_local_changes=True) - self.state = TaskInstanceState.SKIPPED - _run_finished_callback(callbacks=self.task.on_skipped_callback, context=context) - TaskInstance.save_to_db(ti=self, session=session) - except AirflowRescheduleException as reschedule_exception: - self._handle_reschedule(actual_start_date, reschedule_exception, test_mode, session=session) - self.log.info("Rescheduling task, marking task as UP_FOR_RESCHEDULE") - return None - except (AirflowFailException, AirflowSensorTimeout) as e: - # If AirflowFailException is raised, task should not retry. - # If a sensor in reschedule mode reaches timeout, task should not retry. - self.handle_failure( - e, test_mode, context, force_fail=True, session=session - ) # already saves to db - raise - except (AirflowTaskTimeout, AirflowException, AirflowTaskTerminated) as e: - if not test_mode: - self.refresh_from_db(lock_for_update=True, session=session) - # for case when task is marked as success/failed externally - # or dagrun timed out and task is marked as skipped - # current behavior doesn't hit the callbacks - if self.state in State.finished: - self.clear_next_method_args() - TaskInstance.save_to_db(ti=self, session=session) - return None - self.handle_failure(e, test_mode, context, session=session) - raise - except SystemExit as e: - # We have already handled SystemExit with success codes (0 and None) in the `_execute_task`. - # Therefore, here we must handle only error codes. - msg = f"Task failed due to SystemExit({e.code})" - self.handle_failure(msg, test_mode, context, session=session) - raise AirflowException(msg) - except BaseException as e: - self.handle_failure(e, test_mode, context, session=session) - raise - finally: - # Print a marker post execution for internals of post task processing - log.info("::group::Post task execution logs") - - Stats.incr( - f"ti.finish.{self.dag_id}.{self.task_id}.{self.state}", - tags=self.stats_tags, - ) - # Same metric with tagging - Stats.incr("ti.finish", tags={**self.stats_tags, "state": str(self.state)}) - - # Recording SKIPPED or SUCCESS - self.clear_next_method_args() - self.end_date = timezone.utcnow() - _log_state(task_instance=self) - self.set_duration() - - # run on_success_callback before db committing - # otherwise, the LocalTaskJob sees the state is changed to `success`, - # but the task_runner is still running, LocalTaskJob then treats the state is set externally! - if self.state == TaskInstanceState.SUCCESS: - _run_finished_callback(callbacks=self.task.on_success_callback, context=context) - - if not test_mode: - _add_log(event=self.state, task_instance=self, session=session) - if self.state == TaskInstanceState.SUCCESS: - from airflow.sdk.execution_time.task_runner import ( - _build_asset_profiles, - _serialize_outlet_events, - ) - - TaskInstance.register_asset_changes_in_db( - self, - list(_build_asset_profiles(self.task.outlets)), - list(_serialize_outlet_events(context["outlet_events"])), - session=session, - ) + **kwargs: Any, + ) -> None: + """Only kept for tests.""" + from airflow.sdk.definitions.dag import _run_task - TaskInstance.save_to_db(ti=self, session=session) - if self.state == TaskInstanceState.SUCCESS: - try: - get_listener_manager().hook.on_task_instance_success( - previous_state=TaskInstanceState.RUNNING, task_instance=self - ) - except Exception: - log.exception("error calling listener") + if mark_success: + self.set_state(TaskInstanceState.SUCCESS) + log.info("[DAG TEST] Marking success for %s ", self.task_id) return None + taskrun_result = _run_task(ti=self) + if taskrun_result is not None and taskrun_result.error: + raise taskrun_result.error + return None + @staticmethod @provide_session def register_asset_changes_in_db( @@ -1988,252 +1727,6 @@ def update_rtif(self, rendered_fields, session: Session = NEW_SESSION): session.flush() RenderedTaskInstanceFields.delete_old_records(self.task_id, self.dag_id, session=session) - def _execute_task_with_callbacks(self, context: Context, test_mode: bool = False, *, session: Session): - """Prepare Task for Execution.""" - from airflow.sdk.execution_time.callback_runner import create_executable_runner - from airflow.sdk.execution_time.context import context_get_outlet_events - - if TYPE_CHECKING: - assert self.task - - parent_pid = os.getpid() - - def signal_handler(signum, frame): - pid = os.getpid() - - # If a task forks during execution (from DAG code) for whatever - # reason, we want to make sure that we react to the signal only in - # the process that we've spawned ourselves (referred to here as the - # parent process). - if pid != parent_pid: - os._exit(1) - return - self.log.error("Received SIGTERM. Terminating subprocesses.") - self.log.error("Stacktrace: \n%s", "".join(traceback.format_stack())) - self.task.on_kill() - raise AirflowTaskTerminated( - f"Task received SIGTERM signal {self.task_id=} {self.dag_id=} {self.run_id=} {self.map_index=}" - ) - - signal.signal(signal.SIGTERM, signal_handler) - - # Don't clear Xcom until the task is certain to execute, and check if we are resuming from deferral. - if not self.next_method: - self.clear_xcom_data() - - with ( - Stats.timer(f"dag.{self.task.dag_id}.{self.task.task_id}.duration"), - Stats.timer("task.duration", tags=self.stats_tags), - ): - # Set the validated/merged params on the task object. - self.task.params = context["params"] - - with set_current_context(context): - dag = self.task.get_dag() - if dag is not None: - jinja_env = dag.get_template_env() - else: - jinja_env = None - task_orig = self.render_templates(context=context, jinja_env=jinja_env) - - # The task is never MappedOperator at this point. - if TYPE_CHECKING: - assert isinstance(self.task, BaseOperator) - - if not test_mode: - rendered_fields = get_serialized_template_fields(task=self.task) - self.update_rtif(rendered_fields=rendered_fields) - # Export context to make it available for operators to use. - airflow_context_vars = context_to_airflow_vars(context, in_env_var_format=True) - os.environ.update(airflow_context_vars) - - # Log context only for the default execution method, the assumption - # being that otherwise we're resuming a deferred task (in which - # case there's no need to log these again). - if not self.next_method: - self.log.info( - "Exporting env vars: %s", - " ".join(f"{k}={v!r}" for k, v in airflow_context_vars.items()), - ) - - # Run pre_execute callback - if self.task._pre_execute_hook: - create_executable_runner( - self.task._pre_execute_hook, - context_get_outlet_events(context), - logger=self.log, - ).run(context) - create_executable_runner( - self.task.pre_execute, - context_get_outlet_events(context), - logger=self.log, - ).run(context) - - # Run on_execute callback - self._run_execute_callback(context, self.task) - - # Run on_task_instance_running event - try: - get_listener_manager().hook.on_task_instance_running( - previous_state=TaskInstanceState.QUEUED, task_instance=self - ) - except Exception: - log.exception("error calling listener") - - def _render_map_index(context: Context, *, jinja_env: jinja2.Environment | None) -> str | None: - """Render named map index if the DAG author defined map_index_template at the task level.""" - if jinja_env is None or (template := context.get("map_index_template")) is None: - return None - rendered_map_index = jinja_env.from_string(template).render(context) - log.debug("Map index rendered as %s", rendered_map_index) - return rendered_map_index - - # Execute the task. - with set_current_context(context): - try: - result = self._execute_task(context, task_orig) - except Exception: - # If the task failed, swallow rendering error so it doesn't mask the main error. - with contextlib.suppress(jinja2.TemplateSyntaxError, jinja2.UndefinedError): - self._rendered_map_index = _render_map_index(context, jinja_env=jinja_env) - raise - else: # If the task succeeded, render normally to let rendering error bubble up. - self._rendered_map_index = _render_map_index(context, jinja_env=jinja_env) - - # Run post_execute callback - if self.task._post_execute_hook: - create_executable_runner( - self.task._post_execute_hook, - context_get_outlet_events(context), - logger=self.log, - ).run(context, result) - create_executable_runner( - self.task.post_execute, - context_get_outlet_events(context), - logger=self.log, - ).run(context, result) - - Stats.incr(f"operator_successes_{self.task.task_type}", tags=self.stats_tags) - # Same metric with tagging - Stats.incr("operator_successes", tags={**self.stats_tags, "task_type": self.task.task_type}) - Stats.incr("ti_successes", tags=self.stats_tags) - - def _execute_task(self, context: Context, task_orig: Operator): - """ - Execute Task (optionally with a Timeout) and push Xcom results. - - :param context: Jinja2 context - :param task_orig: origin task - """ - from airflow.sdk.bases.operator import ExecutorSafeguard - from airflow.sdk.definitions.mappedoperator import MappedOperator - - task_to_execute = self.task - - if TYPE_CHECKING: - # TODO: TaskSDK this function will need 100% re-writing - # This only works with a "rich" BaseOperator, not the SDK version - assert isinstance(task_to_execute, BaseOperator) - - if isinstance(task_to_execute, MappedOperator): - raise AirflowException("MappedOperator cannot be executed.") - - # If the task has been deferred and is being executed due to a trigger, - # then we need to pick the right method to come back to, otherwise - # we go for the default execute - execute_callable_kwargs: dict[str, Any] = {} - execute_callable: Callable - if self.next_method: - execute_callable = task_to_execute.resume_execution - execute_callable_kwargs["next_method"] = self.next_method - # We don't want modifictions we make here to be tracked by SQLA - execute_callable_kwargs["next_kwargs"] = {**(self.next_kwargs or {})} - if self.next_method == "execute": - execute_callable_kwargs["next_kwargs"][f"{task_to_execute.__class__.__name__}__sentinel"] = ( - ExecutorSafeguard.sentinel_value - ) - else: - execute_callable = task_to_execute.execute - if execute_callable.__name__ == "execute": - execute_callable_kwargs[f"{task_to_execute.__class__.__name__}__sentinel"] = ( - ExecutorSafeguard.sentinel_value - ) - - def _execute_callable(context: Context, **execute_callable_kwargs): - from airflow.sdk.execution_time.callback_runner import create_executable_runner - from airflow.sdk.execution_time.context import context_get_outlet_events - - try: - # Print a marker for log grouping of details before task execution - log.info("::endgroup::") - - return create_executable_runner( - execute_callable, - context_get_outlet_events(context), - logger=log, - ).run(context=context, **execute_callable_kwargs) - except SystemExit as e: - # Handle only successful cases here. Failure cases will be handled upper - # in the exception chain. - if e.code is not None and e.code != 0: - raise - return None - - # If a timeout is specified for the task, make it fail - # if it goes beyond - if task_to_execute.execution_timeout: - # If we are coming in with a next_method (i.e. from a deferral), - # calculate the timeout from our start_date. - if self.next_method and self.start_date: - timeout_seconds = ( - task_to_execute.execution_timeout - (timezone.utcnow() - self.start_date) - ).total_seconds() - else: - timeout_seconds = task_to_execute.execution_timeout.total_seconds() - try: - # It's possible we're already timed out, so fast-fail if true - if timeout_seconds <= 0: - raise AirflowTaskTimeout() - # Run task in timeout wrapper - with timeout(timeout_seconds): - result = _execute_callable(context=context, **execute_callable_kwargs) - except AirflowTaskTimeout: - task_to_execute.on_kill() - raise - else: - result = _execute_callable(context=context, **execute_callable_kwargs) - cm = create_session() - with cm as session_or_null: - if task_to_execute.do_xcom_push: - xcom_value = result - else: - xcom_value = None - if xcom_value is not None: # If the task returns a result, push an XCom containing it. - if task_to_execute.multiple_outputs: - if not isinstance(xcom_value, Mapping): - raise AirflowException( - f"Returned output was type {type(xcom_value)} " - "expected dictionary for multiple_outputs" - ) - for key in xcom_value.keys(): - if not isinstance(key, str): - raise AirflowException( - "Returned dictionary keys must be strings when using " - f"multiple_outputs, found {key} ({type(key)}) instead" - ) - for key, value in xcom_value.items(): - self.xcom_push(key=key, value=value, session=session_or_null) - self.xcom_push(key=XCOM_RETURN_KEY, value=xcom_value, session=session_or_null) - if TYPE_CHECKING: - assert task_orig.dag - _record_task_map_for_downstreams( - task_instance=self, - task=task_orig, - value=xcom_value, - session=session_or_null, - ) - return result - def update_heartbeat(self): with create_session() as session: session.execute( @@ -2318,16 +1811,6 @@ def defer_task(self, exception: TaskDeferred | None, session: Session = NEW_SESS session.merge(self) session.commit() - def _run_execute_callback(self, context: Context, task: BaseOperator) -> None: - """Functions that need to be run before a Task is executed.""" - if not (callbacks := task.on_execute_callback): - return - for callback in callbacks if isinstance(callbacks, list) else [callbacks]: - try: - callback(context) - except Exception: - self.log.exception("Failed when executing execute callback") - @provide_session def run( self, @@ -2343,7 +1826,7 @@ def run( session: Session = NEW_SESSION, raise_on_defer: bool = False, ) -> None: - """Run TaskInstance.""" + """Run TaskInstance (only kept for tests).""" res = self.check_and_change_state_before_execution( verbose=verbose, ignore_all_deps=ignore_all_deps, @@ -2359,13 +1842,7 @@ def run( if not res: return - self._run_raw_task( - mark_success=mark_success, - test_mode=test_mode, - pool=pool, - session=session, - raise_on_defer=raise_on_defer, - ) + self._run_raw_task(mark_success=mark_success) def dry_run(self) -> None: """Only Renders Templates for the TI.""" @@ -2378,65 +1855,6 @@ def dry_run(self) -> None: assert isinstance(self.task, BaseOperator) self.task.dry_run() - @provide_session - def _handle_reschedule( - self, - actual_start_date: datetime, - reschedule_exception: AirflowRescheduleException, - test_mode: bool = False, - session: Session = NEW_SESSION, - ): - # Don't record reschedule request in test mode - if test_mode: - return - - self.refresh_from_db(session) - - if TYPE_CHECKING: - assert self.task - - self.end_date = timezone.utcnow() - self.set_duration() - - # set state - self.state = TaskInstanceState.UP_FOR_RESCHEDULE - - self.clear_next_method_args() - - session.merge(self) - session.commit() - - # we add this in separate commit to reduce likelihood of deadlock - # see https://github.com/apache/airflow/pull/21362 for more info - session.add( - TaskReschedule( - self.id, - actual_start_date, - self.end_date, - reschedule_exception.reschedule_date, - ) - ) - session.commit() - return self - - @staticmethod - def get_truncated_error_traceback(error: BaseException, truncate_to: Callable) -> TracebackType | None: - """ - Truncate the traceback of an exception to the first frame called from within a given function. - - :param error: exception to get traceback from - :param truncate_to: Function to truncate TB to. Must have a ``__code__`` attribute - - :meta private: - """ - tb = error.__traceback__ - code = truncate_to.__func__.__code__ # type: ignore[attr-defined] - while tb is not None: - if tb.tb_frame.f_code is code: - return tb.tb_next - tb = tb.tb_next - return tb or error.__traceback__ - @classmethod def fetch_handle_failure_context( cls, @@ -2461,11 +1879,7 @@ def fetch_handle_failure_context( :param fail_fast: if True, fail all downstream tasks """ if error: - if isinstance(error, BaseException): - tb = TaskInstance.get_truncated_error_traceback(error, truncate_to=ti._execute_task) - cls.logger().error("Task failed with exception", exc_info=(type(error), error, tb)) - else: - cls.logger().error("%s", error) + cls.logger().error("%s", error) if not test_mode: ti.refresh_from_db(session) @@ -2778,47 +2192,6 @@ def get_triggering_events() -> dict[str, list[AssetEvent]]: return context - @provide_session - def get_rendered_template_fields(self, session: Session = NEW_SESSION) -> None: - """ - Update task with rendered template fields for presentation in UI. - - If task has already run, will fetch from DB; otherwise will render. - """ - from airflow.models.renderedtifields import RenderedTaskInstanceFields - - if TYPE_CHECKING: - assert isinstance(self.task, BaseOperator) - - rendered_task_instance_fields = RenderedTaskInstanceFields.get_templated_fields(self, session=session) - if rendered_task_instance_fields: - self.task = self.task.unmap(None) - for field_name, rendered_value in rendered_task_instance_fields.items(): - setattr(self.task, field_name, rendered_value) - return - - try: - # If we get here, either the task hasn't run or the RTIF record was purged. - from airflow.sdk.execution_time.secrets_masker import redact - - self.render_templates() - for field_name in self.task.template_fields: - rendered_value = getattr(self.task, field_name) - setattr(self.task, field_name, redact(rendered_value, field_name)) - except (TemplateAssertionError, UndefinedError) as e: - raise AirflowException( - "Webserver does not have access to User-defined Macros or Filters " - "when Dag Serialization is enabled. Hence for the task that have not yet " - "started running, please use 'airflow tasks render' for debugging the " - "rendering of template_fields." - ) from e - - def overwrite_params_with_dag_run_conf(self, params: dict, dag_run: DagRun): - """Overwrite Task Params with DagRun.conf.""" - if dag_run and dag_run.conf: - self.log.debug("Updating task params (%s) with DagRun.conf (%s)", params, dag_run.conf) - params.update(dag_run.conf) - def render_templates( self, context: Context | None = None, jinja_env: jinja2.Environment | None = None ) -> Operator: @@ -3263,60 +2636,6 @@ def duration_expression_update( } ) - @staticmethod - def validate_inlet_outlet_assets_activeness( - inlets: list[AssetProfile], outlets: list[AssetProfile], session: Session - ) -> None: - from airflow.sdk.definitions.asset import AssetUniqueKey - - if not (inlets or outlets): - return - - all_asset_unique_keys = { - AssetUniqueKey.from_asset(inlet_or_outlet) # type: ignore - for inlet_or_outlet in itertools.chain(inlets, outlets) - } - inactive_asset_unique_keys = TaskInstance._get_inactive_asset_unique_keys( - all_asset_unique_keys, session - ) - if inactive_asset_unique_keys: - raise AirflowInactiveAssetInInletOrOutletException(inactive_asset_unique_keys) - - @staticmethod - def _get_inactive_asset_unique_keys( - asset_unique_keys: set[AssetUniqueKey], session: Session - ) -> set[AssetUniqueKey]: - from airflow.sdk.definitions.asset import AssetUniqueKey - - active_asset_unique_keys = { - AssetUniqueKey(name, uri) - for name, uri in session.execute( - select(AssetActive.name, AssetActive.uri).where( - tuple_(AssetActive.name, AssetActive.uri).in_( - attrs.astuple(key) for key in asset_unique_keys - ) - ) - ) - } - return asset_unique_keys - active_asset_unique_keys - - def get_first_reschedule_date(self, context: Context) -> datetime | None: - """Get the first reschedule date for the task instance.""" - if TYPE_CHECKING: - assert isinstance(self.task, BaseOperator) - - with create_session() as session: - start_date = session.scalar( - select(TaskReschedule) - .where( - TaskReschedule.ti_id == str(self.id), - ) - .order_by(TaskReschedule.id.asc()) - .with_only_columns(TaskReschedule.start_date) - .limit(1) - ) - return start_date - def _find_common_ancestor_mapped_group(node1: Operator, node2: Operator) -> MappedTaskGroup | None: """Given two operators, find their innermost common mapped task group.""" diff --git a/airflow-core/tests/unit/api_fastapi/execution_api/conftest.py b/airflow-core/tests/unit/api_fastapi/execution_api/conftest.py index 9c7a6e39a0e96..9e26937b63c06 100644 --- a/airflow-core/tests/unit/api_fastapi/execution_api/conftest.py +++ b/airflow-core/tests/unit/api_fastapi/execution_api/conftest.py @@ -32,8 +32,35 @@ def client(request: pytest.FixtureRequest): with TestClient(app, headers={"Authorization": "Bearer fake"}) as client: auth = AsyncMock(spec=JWTValidator) - auth.avalidated_claims.return_value = {"sub": "edb09971-4e0e-4221-ad3f-800852d38085"} - # Inject our fake JWTValidator object. Can be over-ridden by tests if they want + # Create a side_effect function that dynamically extracts the task instance ID from validators + def smart_validated_claims(cred, validators=None): + # Extract task instance ID from validators if present + # This handles the JWTBearerTIPathDep case where the validator contains the task ID from the path + if ( + validators + and "sub" in validators + and isinstance(validators["sub"], dict) + and "value" in validators["sub"] + ): + return { + "sub": validators["sub"]["value"], + "exp": 9999999999, # Far future expiration + "iat": 1000000000, # Past issuance time + "aud": "test-audience", + } + + # For other cases (like JWTBearerDep) where no specific validators are provided + # Return a default UUID with all required claims + return { + "sub": "00000000-0000-0000-0000-000000000000", + "exp": 9999999999, # Far future expiration + "iat": 1000000000, # Past issuance time + "aud": "test-audience", + } + + # Set the side_effect for avalidated_claims + auth.avalidated_claims.side_effect = smart_validated_claims lifespan.registry.register_value(JWTValidator, auth) + yield client diff --git a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py index 77e3e7df3ebbd..21c733a5cced2 100644 --- a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py +++ b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py @@ -17,7 +17,6 @@ from __future__ import annotations -import operator from datetime import datetime from unittest import mock from uuid import uuid4 @@ -1084,22 +1083,18 @@ def test_ti_skip_downstream(self, client, session, create_task_instance, dag_mak t1 = EmptyOperator(task_id="t1") t0 >> t1 dr = dag_maker.create_dagrun(run_id="run") - decision = dr.task_instance_scheduling_decisions(session=session) - for ti in sorted(decision.schedulable_tis, key=operator.attrgetter("task_id")): - # TODO: TaskSDK #45549 - ti.task = dag_maker.dag.get_task(ti.task_id) - ti.run(session=session) - t0 = dr.get_task_instance("t0") + ti0 = dr.get_task_instance("t0") + ti0.set_state(State.SUCCESS) + response = client.patch( - f"/execution/task-instances/{t0.id}/skip-downstream", + f"/execution/task-instances/{ti0.id}/skip-downstream", json=_json, ) - t1 = dr.get_task_instance("t1") + ti1 = dr.get_task_instance("t1") assert response.status_code == 204 - assert decision.schedulable_tis[0].state == State.SUCCESS - assert t1.state == State.SKIPPED + assert ti1.state == State.SKIPPED class TestTIHealthEndpoint: diff --git a/airflow-core/tests/unit/listeners/test_listeners.py b/airflow-core/tests/unit/listeners/test_listeners.py index 68337c905b24c..3fceaaf0843cc 100644 --- a/airflow-core/tests/unit/listeners/test_listeners.py +++ b/airflow-core/tests/unit/listeners/test_listeners.py @@ -120,14 +120,13 @@ def test_listener_gets_only_subscribed_calls(create_task_instance, session=None) @provide_session -def test_listener_suppresses_exceptions(create_task_instance, session, caplog): +def test_listener_suppresses_exceptions(create_task_instance, session, cap_structlog): lm = get_listener_manager() lm.add_listener(throwing_listener) ti = create_task_instance(session=session, state=TaskInstanceState.QUEUED) - with caplog.at_level(logging.ERROR): - ti._run_raw_task() - assert "error calling listener" in caplog.messages + ti.run() + assert "error calling listener" in cap_structlog @provide_session @@ -139,7 +138,7 @@ def test_listener_captures_failed_taskinstances(create_task_instance_of_operator BashOperator, dag_id=DAG_ID, logical_date=LOGICAL_DATE, task_id=TASK_ID, bash_command="exit 1" ) with pytest.raises(AirflowException): - ti._run_raw_task() + ti.run() assert full_listener.state == [TaskInstanceState.RUNNING, TaskInstanceState.FAILED] assert len(full_listener.state) == 2 @@ -153,7 +152,7 @@ def test_listener_captures_longrunning_taskinstances(create_task_instance_of_ope ti = create_task_instance_of_operator( BashOperator, dag_id=DAG_ID, logical_date=LOGICAL_DATE, task_id=TASK_ID, bash_command="sleep 5" ) - ti._run_raw_task() + ti.run() assert full_listener.state == [TaskInstanceState.RUNNING, TaskInstanceState.SUCCESS] assert len(full_listener.state) == 2 @@ -166,13 +165,9 @@ def test_class_based_listener(create_task_instance, session=None): lm.add_listener(listener) ti = create_task_instance(session=session, state=TaskInstanceState.QUEUED) - # Using ti.run() instead of ti._run_raw_task() to capture state change to RUNNING - # that only happens on `check_and_change_state_before_execution()` that is called before - # `run()` calls `_run_raw_task()` ti.run() - assert len(listener.state) == 2 - assert listener.state == [TaskInstanceState.RUNNING, TaskInstanceState.SUCCESS] + assert listener.state == [TaskInstanceState.RUNNING, TaskInstanceState.SUCCESS, DagRunState.SUCCESS] def test_listener_logs_call(caplog, create_task_instance, session): @@ -181,10 +176,9 @@ def test_listener_logs_call(caplog, create_task_instance, session): lm.add_listener(full_listener) ti = create_task_instance(session=session, state=TaskInstanceState.QUEUED) - ti._run_raw_task() + ti.run() listener_logs = [r for r in caplog.record_tuples if r[0] == "airflow.listeners.listener"] - assert len(listener_logs) == 6 assert all(r[:-1] == ("airflow.listeners.listener", logging.DEBUG) for r in listener_logs) assert listener_logs[0][-1].startswith("Calling 'on_task_instance_running' with {'") assert listener_logs[1][-1].startswith("Hook impls: [" - assert body.startswith("Try 0") # try number only incremented by the scheduler - assert "test_email_alert" in body - def test_set_duration(self): task = EmptyOperator(task_id="op", email="test@test.test") ti = TI(task=task) @@ -2109,173 +1558,6 @@ def test_set_duration_empty_dates(self): ti.set_duration() assert ti.duration is None - def test_success_callback_no_race_condition(self, create_task_instance): - callback_wrapper = CallbackWrapper() - ti = create_task_instance( - on_success_callback=callback_wrapper.success_handler, - end_date=timezone.utcnow() + datetime.timedelta(days=10), - logical_date=timezone.utcnow(), - state=State.RUNNING, - ) - - session = settings.Session() - session.merge(ti) - session.commit() - - callback_wrapper.wrap_task_instance(ti) - ti._run_raw_task() - assert callback_wrapper.callback_ran - assert callback_wrapper.task_state_in_callback == State.SUCCESS - ti.refresh_from_db() - assert ti.state == State.SUCCESS - - def test_outlet_assets(self, create_task_instance, testing_dag_bundle): - """ - Verify that when we have an outlet asset on a task, and the task - completes successfully, an AssetDagRunQueue is logged. - """ - from airflow.example_dags import example_assets - from airflow.example_dags.example_assets import dag1 - - session = settings.Session() - dagbag = DagBag(dag_folder=example_assets.__file__) - dagbag.collect_dags(only_if_updated=False, safe_mode=False) - dagbag.sync_to_db("testing", None, session=session) - - asset_models = session.scalars(select(AssetModel)).all() - SchedulerJobRunner._activate_referenced_assets(asset_models, session=session) - session.flush() - - run_id = str(uuid4()) - dr = DagRun( - dag1.dag_id, - run_id=run_id, - run_type="manual", - state=DagRunState.RUNNING, - logical_date=timezone.utcnow(), - ) - session.merge(dr) - task = dag1.get_task("producing_task_1") - task.bash_command = "echo 1" # make it go faster - ti = TaskInstance(task, run_id=run_id) - session.merge(ti) - session.commit() - ti._run_raw_task() - ti.refresh_from_db() - assert ti.state == TaskInstanceState.SUCCESS - - # check that no other asset events recorded - event = ( - session.query(AssetEvent) - .join(AssetEvent.asset) - .filter(AssetEvent.source_task_instance == ti) - .one() - ) - assert event - assert event.asset - - # check that one queue record created for each dag that depends on asset 1 - assert session.query(AssetDagRunQueue.target_dag_id).filter_by(asset_id=event.asset.id).order_by( - AssetDagRunQueue.target_dag_id - ).all() == [ - ("asset_consumes_1",), - ("asset_consumes_1_and_2",), - ("asset_consumes_1_never_scheduled",), - ("conditional_asset_and_time_based_timetable",), - ("consume_1_and_2_with_asset_expressions",), - ("consume_1_or_2_with_asset_expressions",), - ("consume_1_or_both_2_and_3_with_asset_expressions",), - ] - - # check that one event record created for asset1 and this TI - assert session.query(AssetModel.uri).join(AssetEvent.asset).filter( - AssetEvent.source_task_instance == ti - ).one() == ("s3://dag1/output_1.txt",) - - # check that the asset event has an earlier timestamp than the ADRQ's - adrq_timestamps = session.query(AssetDagRunQueue.created_at).filter_by(asset_id=event.asset.id).all() - assert all(event.timestamp < adrq_timestamp for (adrq_timestamp,) in adrq_timestamps), ( - f"Some items in {[str(t) for t in adrq_timestamps]} are earlier than {event.timestamp}" - ) - - def test_outlet_assets_failed(self, create_task_instance, testing_dag_bundle): - """ - Verify that when we have an outlet asset on a task, and the task - failed, an AssetDagRunQueue is not logged, and an AssetEvent is - not generated - """ - from unit.dags import test_assets - from unit.dags.test_assets import dag_with_fail_task - - session = settings.Session() - dagbag = DagBag(dag_folder=test_assets.__file__) - dagbag.collect_dags(only_if_updated=False, safe_mode=False) - dagbag.sync_to_db("testing", None, session=session) - run_id = str(uuid4()) - dr = DagRun( - dag_with_fail_task.dag_id, - run_id=run_id, - run_type="manual", - state=DagRunState.RUNNING, - logical_date=timezone.utcnow(), - ) - session.merge(dr) - task = dag_with_fail_task.get_task("fail_task") - ti = TaskInstance(task, run_id=run_id) - session.merge(ti) - session.commit() - with pytest.raises(AirflowFailException): - ti._run_raw_task() - ti.refresh_from_db() - assert ti.state == TaskInstanceState.FAILED - - # check that no dagruns were queued - assert session.query(AssetDagRunQueue).count() == 0 - - # check that no asset events were generated - assert session.query(AssetEvent).count() == 0 - - def test_outlet_assets_skipped(self, testing_dag_bundle): - """ - Verify that when we have an outlet asset on a task, and the task - is skipped, an AssetDagRunQueue is not logged, and an AssetEvent is - not generated - """ - from unit.dags import test_assets - from unit.dags.test_assets import dag_with_skip_task - - session = settings.Session() - dagbag = DagBag(dag_folder=test_assets.__file__) - dagbag.collect_dags(only_if_updated=False, safe_mode=False) - dagbag.sync_to_db("testing", None, session=session) - - asset_models = session.scalars(select(AssetModel)).all() - SchedulerJobRunner._activate_referenced_assets(asset_models, session=session) - session.flush() - - run_id = str(uuid4()) - dr = DagRun( - dag_with_skip_task.dag_id, - run_id=run_id, - run_type="manual", - state=DagRunState.RUNNING, - logical_date=timezone.utcnow(), - ) - session.merge(dr) - task = dag_with_skip_task.get_task("skip_task") - ti = TaskInstance(task, run_id=run_id) - session.merge(ti) - session.commit() - ti._run_raw_task() - ti.refresh_from_db() - assert ti.state == TaskInstanceState.SKIPPED - - # check that no dagruns were queued - assert session.query(AssetDagRunQueue).count() == 0 - - # check that no asset events were generated - assert session.query(AssetEvent).count() == 0 - @pytest.mark.want_activate_assets(True) def test_outlet_asset_extra(self, dag_maker, session): from airflow.sdk.definitions.asset import Asset @@ -2339,77 +1621,9 @@ def write(*, outlet_events): assert event.source_task_id == "write" assert event.extra == {"one": 1} - @pytest.mark.want_activate_assets(True) - def test_outlet_asset_extra_yield(self, dag_maker, session): - from airflow.sdk.definitions.asset import Asset - from airflow.sdk.definitions.asset.metadata import Metadata - - with dag_maker(schedule=None, serialized=True, session=session): - - @task(outlets=Asset("test_outlet_asset_extra_1")) - def write1(): - result = "write_1 result" - yield Metadata(Asset(name="test_outlet_asset_extra_1"), {"foo": "bar"}) - return result - - write1() - - def _write2_post_execute(context, result): - yield Metadata(Asset(name="test_outlet_asset_extra_2", uri="test://asset-2"), extra={"x": 1}) - - BashOperator( - task_id="write2", - bash_command=":", - outlets=Asset(name="test_outlet_asset_extra_2", uri="test://asset-2"), - post_execute=_write2_post_execute, - ) - - @task(outlets=Asset("test_outlet_asset_extra_3")) - def write3(): - result = "write_3 result" - yield Metadata(Asset(name="test_outlet_asset_extra_3")) - return result - - write3() - - dr: DagRun = dag_maker.create_dagrun() - for ti in dr.get_task_instances(session=session): - ti.run(session=session) - - xcom = session.scalars( - select(XComModel).filter_by( - dag_id=dr.dag_id, run_id=dr.run_id, task_id="write1", key="return_value" - ) - ).one() - assert xcom.value == json.dumps("write_1 result") - - events = dict(iter(session.execute(select(AssetEvent.source_task_id, AssetEvent)))) - assert set(events) == {"write1", "write2", "write3"} - - assert events["write1"].source_dag_id == dr.dag_id - assert events["write1"].source_run_id == dr.run_id - assert events["write1"].source_task_id == "write1" - assert events["write1"].asset.uri == "test_outlet_asset_extra_1" - assert events["write1"].asset.name == "test_outlet_asset_extra_1" - assert events["write1"].extra == {"foo": "bar"} - - assert events["write2"].source_dag_id == dr.dag_id - assert events["write2"].source_run_id == dr.run_id - assert events["write2"].source_task_id == "write2" - assert events["write2"].asset.uri == "test://asset-2/" - assert events["write2"].asset.name == "test_outlet_asset_extra_2" - assert events["write2"].extra == {"x": 1} - - assert events["write3"].source_dag_id == dr.dag_id - assert events["write3"].source_run_id == dr.run_id - assert events["write3"].source_task_id == "write3" - assert events["write3"].asset.uri == "test_outlet_asset_extra_3" - assert events["write3"].asset.name == "test_outlet_asset_extra_3" - assert events["write3"].extra == {} - @pytest.mark.want_activate_assets(True) def test_outlet_asset_alias(self, dag_maker, session): - from airflow.sdk.definitions.asset import Asset, AssetAlias + from airflow.sdk.definitions.asset import Asset asset_uri = "test_outlet_asset_alias_test_case_ds" alias_name_1 = "test_outlet_asset_alias_test_case_asset_alias_1" @@ -2457,7 +1671,7 @@ def producer(*, outlet_events): @pytest.mark.want_activate_assets(True) def test_outlet_multiple_asset_alias(self, dag_maker, session): - from airflow.sdk.definitions.asset import Asset, AssetAlias + from airflow.sdk.definitions.asset import Asset asset_uri = "test_outlet_maa_ds" asset_alias_name_1 = "test_outlet_maa_asset_alias_1" @@ -2530,7 +1744,6 @@ def producer(*, outlet_events): @pytest.mark.want_activate_assets(True) def test_outlet_asset_alias_through_metadata(self, dag_maker, session): - from airflow.sdk.definitions.asset import AssetAlias from airflow.sdk.definitions.asset.metadata import Metadata asset_uri = "test_outlet_asset_alias_through_metadata_ds" @@ -2574,7 +1787,7 @@ def producer(*, outlet_events): @pytest.mark.want_activate_assets(True) def test_outlet_asset_alias_asset_not_exists(self, dag_maker, session): - from airflow.sdk.definitions.asset import Asset, AssetAlias + from airflow.sdk.definitions.asset import Asset asset_alias_name = "test_outlet_asset_alias_asset_not_exists_asset_alias" asset_uri = "does_not_exist" @@ -2603,7 +1816,7 @@ def producer(*, outlet_events): assert session.scalars(asset_event_check_stmt).one().uri == asset_uri def test_outlet_asset_alias_asset_inactive(self, dag_maker, session): - from airflow.sdk.definitions.asset import Asset, AssetAlias + from airflow.sdk.definitions.asset import Asset asset1 = Asset("asset1") asset2 = Asset("asset2") @@ -2743,89 +1956,6 @@ def read(*, inlet_events): assert not dr.task_instance_scheduling_decisions(session=session).schedulable_tis assert read_task_evaluated - @pytest.mark.want_activate_assets(True) - @pytest.mark.need_serialized_dag - def test_inlet_asset_alias_extra(self, dag_maker, session, mock_supervisor_comms): - from airflow.sdk.definitions.asset import Asset, AssetAlias - - mock_supervisor_comms.get_message.return_value = AssetEventsResult( - asset_events=[ - AssetEventResponse( - id=1, - created_dagruns=[], - timestamp=timezone.utcnow(), - extra={"from": f"write{i}"}, - asset=AssetResponse( - name="test_inlet_asset_extra_ds", uri="test_inlet_asset_extra_ds", group="asset" - ), - ) - for i in (1, 2, 3) - ] - ) - - asset_uri = "test_inlet_asset_extra_ds" - asset_alias_name = "test_inlet_asset_extra_asset_alias" - - asset_model = AssetModel(id=1, uri=asset_uri, group="asset") - asset_alias_model = AssetAliasModel(name=asset_alias_name) - asset_alias_model.assets.append(asset_model) - session.add_all([asset_model, asset_alias_model, AssetActive.for_asset(Asset(asset_uri))]) - session.commit() - - read_task_evaluated = False - - with dag_maker(schedule=None, serialized=True, session=session): - - @task(outlets=AssetAlias(asset_alias_name)) - def write(*, ti, outlet_events): - outlet_events[AssetAlias(asset_alias_name)].add(Asset(asset_uri), extra={"from": ti.task_id}) - - @task(inlets=AssetAlias(asset_alias_name)) - def read(*, inlet_events): - second_event = inlet_events[AssetAlias(asset_alias_name)][1] - assert second_event.asset.uri == asset_uri - assert second_event.extra == {"from": "write2"} - - last_event = inlet_events[AssetAlias(asset_alias_name)][-1] - assert last_event.asset.uri == asset_uri - assert last_event.extra == {"from": "write3"} - - with pytest.raises(KeyError): - inlet_events[Asset("does_not_exist")] - with pytest.raises(KeyError): - inlet_events[AssetAlias("does_not_exist")] - with pytest.raises(IndexError): - inlet_events[AssetAlias(asset_alias_name)][5] - - nonlocal read_task_evaluated - read_task_evaluated = True - - [ - write.override(task_id="write1")(), - write.override(task_id="write2")(), - write.override(task_id="write3")(), - ] >> read() - - dr: DagRun = dag_maker.create_dagrun() - - # Run "write1", "write2", and "write3" (in this order). - decision = dr.task_instance_scheduling_decisions(session=session) - for ti in sorted(decision.schedulable_tis, key=operator.attrgetter("task_id")): - # TODO: TaskSDK #45549 - ti.task = dag_maker.dag.get_task(ti.task_id) - ti.run(session=session) - - # Run "read". - decision = dr.task_instance_scheduling_decisions(session=session) - for ti in decision.schedulable_tis: - # TODO: TaskSDK #45549 - ti.task = dag_maker.dag.get_task(ti.task_id) - ti.run(session=session) - - # Should be done. - assert not dr.task_instance_scheduling_decisions(session=session).schedulable_tis - assert read_task_evaluated - @pytest.mark.need_serialized_dag def test_inlet_unresolved_asset_alias(self, dag_maker, session, mock_supervisor_comms): asset_alias_name = "test_inlet_asset_extra_asset_alias" @@ -2835,8 +1965,6 @@ def test_inlet_unresolved_asset_alias(self, dag_maker, session, mock_supervisor_ session.add(asset_alias_model) session.commit() - from airflow.sdk.definitions.asset import AssetAlias - with dag_maker(schedule=None, session=session): @task(inlets=AssetAlias(asset_alias_name)) @@ -2855,146 +1983,6 @@ def read(*, inlet_events): # Should be done. assert not dr.task_instance_scheduling_decisions(session=session).schedulable_tis - @pytest.mark.want_activate_assets(True) - @pytest.mark.parametrize( - "slicer, expected", - [ - (lambda x: x[-2:], [{"from": 8}, {"from": 9}]), - (lambda x: x[-5:-3], [{"from": 5}, {"from": 6}]), - (lambda x: x[:-8], [{"from": 0}, {"from": 1}]), - (lambda x: x[1:-7], [{"from": 1}, {"from": 2}]), - (lambda x: x[-8:4], [{"from": 2}, {"from": 3}]), - (lambda x: x[-5:5], []), - ], - ) - def test_inlet_asset_extra_slice(self, dag_maker, session, slicer, expected, mock_supervisor_comms): - from airflow.sdk.definitions.asset import Asset - - asset_uri = "test_inlet_asset_extra_slice" - mock_supervisor_comms.get_message.return_value = AssetEventsResult( - asset_events=[ - AssetEventResponse( - id=1, - created_dagruns=[], - timestamp=timezone.utcnow(), - extra={"from": i}, - asset=AssetResponse(name=asset_uri, uri=asset_uri, group="asset"), - ) - for i in range(0, 10) - ] - ) - - with dag_maker(dag_id="write", serialized=True, schedule="@daily", params={"i": -1}, session=session): - - @task(outlets=Asset(asset_uri)) - def write(*, params, outlet_events): - outlet_events[Asset(asset_uri)].extra = {"from": params["i"]} - - write() - - # Run the write DAG 10 times. - dr = dag_maker.create_dagrun(run_type=DagRunType.SCHEDULED, conf={"i": 0}) - for ti in dr.get_task_instances(session=session): - ti.run(session=session) - for i in range(1, 10): - dr = dag_maker.create_dagrun_after(dr, run_type=DagRunType.SCHEDULED, conf={"i": i}) - for ti in dr.get_task_instances(session=session): - ti.run(session=session) - - result = "the task does not run" - - with dag_maker(dag_id="read", schedule=None, session=session): - - @task(inlets=Asset(asset_uri)) - def read(*, inlet_events): - nonlocal result - events = inlet_events[Asset(asset_uri)] - result = [e.extra for e in slicer(events)] - - read() - - # Run the read DAG. - dr = dag_maker.create_dagrun() - for ti in dr.get_task_instances(session=session): - ti.run(session=session) - - # Should be done. - assert not dr.task_instance_scheduling_decisions(session=session).schedulable_tis - assert result == expected - - @pytest.mark.parametrize( - "slicer, expected", - [ - (lambda x: x[-2:], [{"from": 8}, {"from": 9}]), - (lambda x: x[-5:-3], [{"from": 5}, {"from": 6}]), - (lambda x: x[:-8], [{"from": 0}, {"from": 1}]), - (lambda x: x[1:-7], [{"from": 1}, {"from": 2}]), - (lambda x: x[-8:4], [{"from": 2}, {"from": 3}]), - (lambda x: x[-5:5], []), - ], - ) - @pytest.mark.want_activate_assets(True) - def test_inlet_asset_alias_extra_slice(self, dag_maker, session, slicer, expected, mock_supervisor_comms): - from airflow.sdk.definitions.asset import Asset - - asset_uri = "test_inlet_asset_alias_extra_slice_ds" - mock_supervisor_comms.get_message.return_value = AssetEventsResult( - asset_events=[ - AssetEventResponse( - id=1, - created_dagruns=[], - timestamp=timezone.utcnow(), - extra={"from": i}, - asset=AssetResponse(name=asset_uri, uri=asset_uri, group="asset"), - ) - for i in range(0, 10) - ] - ) - asset_alias_name = "test_inlet_asset_alias_extra_slice_asset_alias" - - asset_model = AssetModel(id=1, uri=asset_uri) - asset_alias_model = AssetAliasModel(name=asset_alias_name) - asset_alias_model.assets.append(asset_model) - session.add_all([asset_model, asset_alias_model, AssetActive.for_asset(Asset(asset_uri))]) - session.commit() - - with dag_maker(dag_id="write", schedule="@daily", params={"i": -1}, serialized=True, session=session): - - @task(outlets=AssetAlias(asset_alias_name)) - def write(*, params, outlet_events): - outlet_events[AssetAlias(asset_alias_name)].add(Asset(asset_uri), {"from": params["i"]}) - - write() - - # Run the write DAG 10 times. - dr = dag_maker.create_dagrun(run_type=DagRunType.SCHEDULED, conf={"i": 0}) - for ti in dr.get_task_instances(session=session): - ti.run(session=session) - for i in range(1, 10): - dr = dag_maker.create_dagrun_after(dr, run_type=DagRunType.SCHEDULED, conf={"i": i}) - for ti in dr.get_task_instances(session=session): - ti.run(session=session) - - result = "the task does not run" - - with dag_maker(dag_id="read", schedule=None, serialized=True, session=session): - - @task(inlets=AssetAlias(asset_alias_name)) - def read(*, inlet_events): - nonlocal result - result = [e.extra for e in slicer(inlet_events[AssetAlias(asset_alias_name)])] - - read() - - # Run the read DAG. - dr = dag_maker.create_dagrun() - for ti in dr.get_task_instances(session=session): - ti.run(session=session) - - # Should be done. - assert not dr.task_instance_scheduling_decisions(session=session).schedulable_tis - assert result == expected - def test_changing_of_asset_when_adrq_is_already_populated(self, dag_maker): """ Test that when a task that produces asset has ran, that changing the consumer @@ -3291,76 +2279,6 @@ def test_template_with_json_variable_missing(self, create_task_instance, session with pytest.raises(KeyError): ti.task.render_template('{{ var.json.get("missing_variable") }}', context) - def test_execute_callback(self, create_task_instance): - called = False - - def on_execute_callable(context): - nonlocal called - called = True - assert context["dag_run"].dag_id == "test_dagrun_execute_callback" - - for i, callback_input in enumerate([[on_execute_callable], on_execute_callable]): - ti = create_task_instance( - dag_id=f"test_execute_callback_{i}", - on_execute_callback=callback_input, - state=State.RUNNING, - ) - - session = settings.Session() - - session.merge(ti) - session.commit() - - ti._run_raw_task() - assert called - ti.refresh_from_db() - assert ti.state == State.SUCCESS - - def test_finished_callbacks_callable_handle_and_log_exception(self, caplog): - called = completed = False - - def on_finish_callable(context): - nonlocal called, completed - called = True - raise KeyError - completed = True - - for callback_input in [[on_finish_callable], on_finish_callable]: - called = completed = False - caplog.clear() - _run_finished_callback(callbacks=callback_input, context={}) - - assert called - assert not completed - callback_name = callback_input[0] if isinstance(callback_input, list) else callback_input - callback_name = qualname(callback_name).split(".")[-1] - assert "Executing callback at index 0: on_finish_callable" in caplog.text - assert "Error in callback at index 0: on_finish_callable" in caplog.text - - def test_finished_callbacks_notifier_handle_and_log_exception(self, caplog): - class OnFinishNotifier(BaseNotifier): - """ - error captured by BaseNotifier - """ - - def __init__(self, error: bool): - super().__init__() - self.raise_error = error - - def notify(self, context): - self.execute() - - def execute(self) -> None: - if self.raise_error: - raise KeyError - - caplog.clear() - callbacks = [OnFinishNotifier(error=False), OnFinishNotifier(error=True)] - _run_finished_callback(callbacks=callbacks, context={}) - assert "Executing callback at index 0: OnFinishNotifier" in caplog.text - assert "Executing callback at index 1: OnFinishNotifier" in caplog.text - assert "KeyError" in caplog.text - @provide_session def test_handle_failure(self, create_dummy_dag, session=None): start_date = timezone.datetime(2016, 6, 1) @@ -3622,30 +2540,6 @@ def fail(): ti.run() assert ti.state == State.UP_FOR_RETRY - @patch.object(TaskInstance, "logger") - def test_stacktrace_on_failure_starts_with_task_execute_method(self, mock_get_log, dag_maker): - def fail(): - raise AirflowException("maybe this will pass?") - - with dag_maker(dag_id="test_retries_on_other_exceptions"): - task = PythonOperator( - task_id="test_raise_other_exception", - python_callable=fail, - retries=1, - ) - ti = dag_maker.create_dagrun(logical_date=timezone.utcnow()).task_instances[0] - ti.task = task - mock_log = mock.Mock() - mock_get_log.return_value = mock_log - with pytest.raises(AirflowException): - ti.run() - mock_log.error.assert_called_once() - assert mock_log.error.call_args.args == ("Task failed with exception",) - exc_info = mock_log.error.call_args.kwargs["exc_info"] - filename = exc_info[2].tb_frame.f_code.co_filename - formatted_exc = format_exception(*exc_info) - assert sys.modules[TaskInstance.__module__].__file__ == filename, "".join(formatted_exc) - def _env_var_check_callback(self): assert os.environ["AIRFLOW_CTX_DAG_ID"] == "test_echo_env_variables" assert os.environ["AIRFLOW_CTX_TASK_ID"] == "hive_in_python_op" @@ -3657,198 +2551,6 @@ def _env_var_check_callback(self): == os.environ["AIRFLOW_CTX_DAG_RUN_ID"] ) - def test_echo_env_variables(self, dag_maker): - with dag_maker( - "test_echo_env_variables", - start_date=DEFAULT_DATE, - end_date=DEFAULT_DATE + datetime.timedelta(days=10), - ): - op = PythonOperator(task_id="hive_in_python_op", python_callable=self._env_var_check_callback) - dr = dag_maker.create_dagrun( - run_type=DagRunType.MANUAL, - ) - ti = dr.get_task_instance(op.task_id) - ti.state = State.RUNNING - session = settings.Session() - session.merge(ti) - session.commit() - ti._run_raw_task() - ti.refresh_from_db() - assert ti.state == State.SUCCESS - - @pytest.mark.parametrize( - "code, expected_state", - [ - pytest.param(1, State.FAILED, id="code-positive-number"), - pytest.param(-1, State.FAILED, id="code-negative-number"), - pytest.param("error", State.FAILED, id="code-text"), - pytest.param(0, State.SUCCESS, id="code-zero"), - pytest.param(None, State.SUCCESS, id="code-none"), - ], - ) - def test_handle_system_exit_failed(self, dag_maker, code, expected_state): - with dag_maker(): - - def f(*args, **kwargs): - exit(code) - - task = PythonOperator(task_id="mytask", python_callable=f) - - dr = dag_maker.create_dagrun() - ti = dr.get_task_instance(task.task_id) - ti.state = State.RUNNING - session = settings.Session() - session.merge(ti) - session.commit() - - if expected_state == State.SUCCESS: - ctx = contextlib.nullcontext() - else: - ctx = pytest.raises(AirflowException, match=rf"Task failed due to SystemExit\({code}\)") - - with ctx: - ti._run_raw_task() - ti.refresh_from_db() - assert ti.state == expected_state - - def test_get_current_context_works_in_template(self, dag_maker): - def user_defined_macro(): - from airflow.sdk import get_current_context - - get_current_context() - - with dag_maker( - "test_context_inside_template", - start_date=DEFAULT_DATE, - end_date=DEFAULT_DATE + datetime.timedelta(days=10), - user_defined_macros={"user_defined_macro": user_defined_macro}, - ): - - def foo(arg): - print(arg) - - PythonOperator( - task_id="context_inside_template", - python_callable=foo, - op_kwargs={"arg": "{{ user_defined_macro() }}"}, - ) - dagrun = dag_maker.create_dagrun() - tis = dagrun.get_task_instances() - ti: TaskInstance = next(x for x in tis if x.task_id == "context_inside_template") - ti._run_raw_task() - assert ti.state == State.SUCCESS - - @patch.object(Stats, "incr") - def test_task_stats(self, stats_mock, create_task_instance): - ti = create_task_instance( - dag_id="test_task_start_end_stats", - end_date=timezone.utcnow() + datetime.timedelta(days=10), - state=State.RUNNING, - ) - stats_mock.reset_mock() - - session = settings.Session() - session.merge(ti) - session.commit() - ti._run_raw_task() - ti.refresh_from_db() - stats_mock.assert_any_call( - f"ti.finish.{ti.dag_id}.{ti.task_id}.{ti.state}", - tags={"dag_id": ti.dag_id, "task_id": ti.task_id}, - ) - stats_mock.assert_any_call( - "ti.finish", - tags={"dag_id": ti.dag_id, "task_id": ti.task_id, "state": ti.state}, - ) - for state in State.task_states: - assert ( - call( - f"ti.finish.{ti.dag_id}.{ti.task_id}.{state}", - count=0, - tags={"dag_id": ti.dag_id, "task_id": ti.task_id}, - ) - in stats_mock.mock_calls - ) - assert ( - call( - "ti.finish", - count=0, - tags={"dag_id": ti.dag_id, "task_id": ti.task_id, "state": str(state)}, - ) - in stats_mock.mock_calls - ) - assert ( - call(f"ti.start.{ti.dag_id}.{ti.task_id}", tags={"dag_id": ti.dag_id, "task_id": ti.task_id}) - in stats_mock.mock_calls - ) - assert call("ti.start", tags={"dag_id": ti.dag_id, "task_id": ti.task_id}) in stats_mock.mock_calls - assert stats_mock.call_count == (2 * len(State.task_states)) + 7 - - def test_command_as_list(self, dag_maker): - with dag_maker(): - PythonOperator(python_callable=print, task_id="hi") - dr = dag_maker.create_dagrun() - ti = dr.task_instances[0] - assert ti.command_as_list() == [ - "airflow", - "tasks", - "run", - ti.dag_id, - ti.task_id, - ti.run_id, - "--subdir", - "DAGS_FOLDER/test_taskinstance.py", - ] - - def test_generate_command_default_param(self): - dag_id = "test_generate_command_default_param" - task_id = "task" - assert_command = ["airflow", "tasks", "run", dag_id, task_id, "run_1"] - generate_command = TI.generate_command(dag_id=dag_id, task_id=task_id, run_id="run_1") - assert assert_command == generate_command - - def test_generate_command_specific_param(self): - dag_id = "test_generate_command_specific_param" - task_id = "task" - assert_command = [ - "airflow", - "tasks", - "run", - dag_id, - task_id, - "run_1", - "--mark-success", - "--map-index", - "0", - ] - generate_command = TI.generate_command( - dag_id=dag_id, task_id=task_id, run_id="run_1", mark_success=True, map_index=0 - ) - assert assert_command == generate_command - - @provide_session - def test_get_rendered_template_fields(self, dag_maker, session=None): - with dag_maker("test-dag", session=session) as dag: - task = BashOperator(task_id="op1", bash_command="{{ task.task_id }}") - dag.fileloc = TEST_DAGS_FOLDER / "test_get_k8s_pod_yaml.py" - ti = dag_maker.create_dagrun().task_instances[0] - ti.task = task - - session.add(RenderedTaskInstanceFields(ti)) - session.flush() - - # Create new TI for the same Task - new_task = BashOperator(task_id="op12", bash_command="{{ task.task_id }}", dag=dag) - - new_ti = TI(task=new_task, run_id=ti.run_id) - new_ti.get_rendered_template_fields(session=session) - - assert ti.task.bash_command == "op1" - - # CleanUp - with create_session() as session: - session.query(RenderedTaskInstanceFields).delete() - def test_set_state_up_for_retry(self, create_task_instance): ti = create_task_instance(state=State.RUNNING) @@ -4064,123 +2766,6 @@ def duplicate_asset_task_in_outlet(*, outlet_events): assert "Asset(name='asset_second', uri='asset_second')" in str(exc.value) assert "Asset(name='asset_first', uri='test://asset/')" in str(exc.value) - @pytest.mark.want_activate_assets(True) - def test_run_with_inactive_assets_in_outlets_within_the_same_dag(self, dag_maker, session): - from airflow.sdk.definitions.asset import Asset - - with dag_maker(schedule=None, serialized=True, session=session): - - @task(outlets=Asset("asset_first")) - def first_asset_task(*, outlet_events): - outlet_events[Asset("asset_first")].extra = {"foo": "bar"} - - @task(outlets=Asset(name="asset_first", uri="test://asset")) - def duplicate_asset_task(*, outlet_events): - outlet_events[Asset(name="asset_first", uri="test://asset")].extra = {"foo": "bar"} - - first_asset_task() >> duplicate_asset_task() - - tis = {ti.task_id: ti for ti in dag_maker.create_dagrun().task_instances} - tis["first_asset_task"].run(session=session) - with pytest.raises(AirflowInactiveAssetInInletOrOutletException) as exc: - tis["duplicate_asset_task"].run(session=session) - - assert str(exc.value) == ( - "Task has the following inactive assets in its inlets or outlets: " - "Asset(name='asset_first', uri='test://asset/')" - ) - - @pytest.mark.skip( - reason="This test has some issues that were surfaced when dag_maker started allowing multiple serdag versions. Issue #48539 will track fixing this." - ) - @pytest.mark.want_activate_assets(True) - def test_run_with_inactive_assets_in_outlets_in_different_dag(self, dag_maker, session): - from airflow.sdk.definitions.asset import Asset - - with dag_maker(schedule=None, serialized=True, session=session): - - @task(outlets=Asset("asset_first")) - def first_asset_task(*, outlet_events): - outlet_events[Asset("asset_first")].extra = {"foo": "bar"} - - first_asset_task() - - with dag_maker(schedule=None, serialized=True, session=session): - - @task(outlets=Asset(name="asset_first", uri="test://asset")) - def duplicate_asset_task(*, outlet_events): - outlet_events[Asset(name="asset_first", uri="test://asset")].extra = {"foo": "bar"} - - duplicate_asset_task() - - tis = {ti.task_id: ti for ti in dag_maker.create_dagrun().task_instances} - with pytest.raises(AirflowInactiveAssetInInletOrOutletException) as exc: - tis["duplicate_asset_task"].run(session=session) - - assert str(exc.value) == ( - "Task has the following inactive assets in its inlets or outlets: " - "Asset(name='asset_first', uri='test://asset/')" - ) - - @pytest.mark.want_activate_assets(False) - def test_run_with_inactive_assets_in_inlets_within_the_same_dag(self, dag_maker, session): - valid_asset = Asset("asset_first") - conflict_asset = Asset(name="asset_first", uri="test://asset/") - - with dag_maker(schedule=None, serialized=True, session=session): - - @task(inlets=valid_asset) - def first_asset_task(): - pass - - @task(inlets=conflict_asset) - def conflict_asset_task(): - pass - - first_asset_task() >> conflict_asset_task() - - session.execute(delete(AssetActive)) - session.add(AssetActive.for_asset(valid_asset)) - - tis = {ti.task_id: ti for ti in dag_maker.create_dagrun().task_instances} - tis["first_asset_task"].run(session=session) - with pytest.raises(AirflowInactiveAssetInInletOrOutletException) as exc: - tis["conflict_asset_task"].run(session=session) - - assert str(exc.value) == ( - "Task has the following inactive assets in its inlets or outlets: " - "Asset(name='asset_first', uri='test://asset/')" - ) - - @pytest.mark.want_activate_assets(True) - def test_run_with_inactive_assets_in_inlets_in_different_dag(self, dag_maker, session): - from airflow.sdk.definitions.asset import Asset - - with dag_maker(schedule=None, serialized=True, session=session): - - @task(inlets=Asset("asset_first")) - def first_asset_task(*, outlet_events): - pass - - first_asset_task() - - with dag_maker(schedule=None, serialized=True, session=session): - - @task(inlets=Asset(name="asset_first", uri="test://asset")) - def duplicate_asset_task(*, outlet_events): - pass - - duplicate_asset_task() - - tis = {ti.task_id: ti for ti in dag_maker.create_dagrun().task_instances} - with pytest.raises(AirflowInactiveAssetInInletOrOutletException) as exc: - tis["duplicate_asset_task"].run(session=session) - - assert str(exc.value) == ( - "Task has the following inactive assets in its inlets or outlets: " - "Asset(name='asset_first', uri='test://asset/')" - ) - @pytest.mark.parametrize("pool_override", [None, "test_pool2"]) @pytest.mark.parametrize("queue_by_policy", [None, "forced_queue"]) @@ -4342,343 +2927,6 @@ def tg(arg): tis["push_4"].run() assert dag_maker.session.query(TaskMap).count() == 2 - @pytest.mark.parametrize( - "return_value, exception_type, error_message", - [ - ("abc", UnmappableXComTypePushed, "unmappable return type 'str'"), - (None, XComForMappingNotPushed, "did not push XCom for task mapping"), - ], - ) - def test_expand_error_if_unmappable_type(self, dag_maker, return_value, exception_type, error_message): - """If an unmappable return value is used for expand(), fail the task that pushed the XCom.""" - with dag_maker(dag_id="test_expand_error_if_unmappable_type") as dag: - - @dag.task() - def push_something(): - return return_value - - @dag.task() - def pull_something(value): - print(value) - - pull_something.expand(value=push_something()) - - ti = next(ti for ti in dag_maker.create_dagrun().task_instances if ti.task_id == "push_something") - with pytest.raises(exception_type) as ctx: - ti.run() - - assert dag_maker.session.query(TaskMap).count() == 0 - assert ti.state == TaskInstanceState.FAILED - assert str(ctx.value) == error_message - - @pytest.mark.parametrize( - "return_value, exception_type, error_message", - [ - (123, UnmappableXComTypePushed, "unmappable return type 'int'"), - (None, XComForMappingNotPushed, "did not push XCom for task mapping"), - ], - ) - def test_expand_kwargs_error_if_unmappable_type( - self, - dag_maker, - return_value, - exception_type, - error_message, - ): - """If an unmappable return value is used for expand_kwargs(), fail the task that pushed the XCom.""" - with dag_maker(dag_id="test_expand_kwargs_error_if_unmappable_type") as dag: - - @dag.task() - def push(): - return return_value - - MockOperator.partial(task_id="pull").expand_kwargs(push()) - - ti = next(ti for ti in dag_maker.create_dagrun().task_instances if ti.task_id == "push") - with pytest.raises(exception_type) as ctx: - ti.run() - - assert dag_maker.session.query(TaskMap).count() == 0 - assert ti.state == TaskInstanceState.FAILED - assert str(ctx.value) == error_message - - @pytest.mark.parametrize( - "return_value, exception_type, error_message", - [ - (123, UnmappableXComTypePushed, "unmappable return type 'int'"), - (None, XComForMappingNotPushed, "did not push XCom for task mapping"), - ], - ) - def test_task_group_expand_error_if_unmappable_type( - self, - dag_maker, - return_value, - exception_type, - error_message, - ): - """If an unmappable return value is used , fail the task that pushed the XCom.""" - with dag_maker(dag_id="test_task_group_expand_error_if_unmappable_type") as dag: - - @dag.task() - def push(): - return return_value - - @task_group - def tg(arg): - MockOperator(task_id="pull", arg1=arg) - - tg.expand(arg=push()) - - ti = next(ti for ti in dag_maker.create_dagrun().task_instances if ti.task_id == "push") - with pytest.raises(exception_type) as ctx: - ti.run() - - assert dag_maker.session.query(TaskMap).count() == 0 - assert ti.state == TaskInstanceState.FAILED - assert str(ctx.value) == error_message - - @pytest.mark.parametrize( - "return_value, exception_type, error_message", - [ - (123, UnmappableXComTypePushed, "unmappable return type 'int'"), - (None, XComForMappingNotPushed, "did not push XCom for task mapping"), - ], - ) - def test_task_group_expand_kwargs_error_if_unmappable_type( - self, - dag_maker, - return_value, - exception_type, - error_message, - ): - """If an unmappable return value is used, fail the task that pushed the XCom.""" - with dag_maker(dag_id="test_task_group_expand_kwargs_error_if_unmappable_type") as dag: - - @dag.task() - def push(): - return return_value - - @task_group - def tg(arg): - MockOperator(task_id="pull", arg1=arg) - - tg.expand_kwargs(push()) - - ti = next(ti for ti in dag_maker.create_dagrun().task_instances if ti.task_id == "push") - with pytest.raises(exception_type) as ctx: - ti.run() - - assert dag_maker.session.query(TaskMap).count() == 0 - assert ti.state == TaskInstanceState.FAILED - assert str(ctx.value) == error_message - - @pytest.mark.parametrize( - "create_upstream", - [ - # The task returns an invalid expand_kwargs() input (a list[int] instead of list[dict]). - pytest.param(lambda: task(task_id="push")(lambda: [0])(), id="normal"), - # This task returns a list[dict] (correct), but we use map() to transform it to list[int] (wrong). - pytest.param(lambda: task(task_id="push")(lambda: [{"v": ""}])().map(lambda _: 0), id="mapped"), - ], - ) - def test_expand_kwargs_error_if_received_invalid(self, dag_maker, session, create_upstream): - with dag_maker(dag_id="test_expand_kwargs_error_if_received_invalid", session=session): - push_task = create_upstream() - - @task() - def pull(v): - print(v) - - pull.expand_kwargs(push_task) - - dr = dag_maker.create_dagrun() - - # Run "push". - decision = dr.task_instance_scheduling_decisions(session=session) - assert decision.schedulable_tis - for ti in decision.schedulable_tis: - ti.run() - - # Run "pull". - decision = dr.task_instance_scheduling_decisions(session=session) - assert decision.schedulable_tis - for ti in decision.schedulable_tis: - with pytest.raises(ValueError) as ctx: - ti.run() - assert str(ctx.value) == "expand_kwargs() expects a list[dict], not list[int]" - - @pytest.mark.parametrize( - "downstream, error_message", - [ - ("taskflow", "mapping already partial argument: arg2"), - ("classic", "unmappable or already specified argument: arg2"), - ], - ids=["taskflow", "classic"], - ) - @pytest.mark.parametrize("strict", [True, False], ids=["strict", "override"]) - def test_expand_kwargs_override_partial(self, dag_maker, session, downstream, error_message, strict): - class ClassicOperator(MockOperator): - def execute(self, context): - return (self.arg1, self.arg2) - - with dag_maker(dag_id="test_expand_kwargs_override_partial", session=session) as dag: - - @dag.task() - def push(): - return [{"arg1": "a"}, {"arg1": "b", "arg2": "c"}] - - push_task = push() - - ClassicOperator.partial(task_id="classic", arg2="d").expand_kwargs(push_task, strict=strict) - - @dag.task(task_id="taskflow") - def pull(arg1, arg2): - return (arg1, arg2) - - pull.partial(arg2="d").expand_kwargs(push_task, strict=strict) - - dr = dag_maker.create_dagrun() - next(ti for ti in dr.task_instances if ti.task_id == "push").run() - - decision = dr.task_instance_scheduling_decisions(session=session) - tis = {(ti.task_id, ti.map_index, ti.state): ti for ti in decision.schedulable_tis} - assert sorted(tis) == [ - ("classic", 0, None), - ("classic", 1, None), - ("taskflow", 0, None), - ("taskflow", 1, None), - ] - - ti = tis[(downstream, 0, None)] - ti.run() - ti.xcom_pull(task_ids=downstream, map_indexes=0, session=session) == ["a", "d"] - - ti = tis[(downstream, 1, None)] - if strict: - with pytest.raises(TypeError) as ctx: - ti.run() - assert str(ctx.value) == error_message - else: - ti.run() - ti.xcom_pull(task_ids=downstream, map_indexes=1, session=session) == ["b", "c"] - - def test_error_if_upstream_does_not_push(self, dag_maker): - """Fail the upstream task if it fails to push the XCom used for task mapping.""" - with dag_maker(dag_id="test_not_recorded_for_unused") as dag: - - @dag.task(do_xcom_push=False) - def push_something(): - return [1, 2] - - @dag.task() - def pull_something(value): - print(value) - - pull_something.expand(value=push_something()) - - ti = next(ti for ti in dag_maker.create_dagrun().task_instances if ti.task_id == "push_something") - with pytest.raises(XComForMappingNotPushed) as ctx: - ti.run() - - assert dag_maker.session.query(TaskMap).count() == 0 - assert ti.state == TaskInstanceState.FAILED - assert str(ctx.value) == "did not push XCom for task mapping" - - @conf_vars({("core", "max_map_length"): "1"}) - def test_error_if_unmappable_length(self, dag_maker): - """If an unmappable return value is used to map, fail the task that pushed the XCom.""" - with dag_maker(dag_id="test_not_recorded_for_unused") as dag: - - @dag.task() - def push_something(): - return [1, 2] - - @dag.task() - def pull_something(value): - print(value) - - pull_something.expand(value=push_something()) - - ti = next(ti for ti in dag_maker.create_dagrun().task_instances if ti.task_id == "push_something") - with pytest.raises(UnmappableXComLengthPushed) as ctx: - ti.run() - - assert dag_maker.session.query(TaskMap).count() == 0 - assert ti.state == TaskInstanceState.FAILED - assert str(ctx.value) == "unmappable return value length: 2 > 1" - - @pytest.mark.parametrize( - "xcom_value, expected_length, expected_keys", - [ - ([1, 2, 3], 3, None), - ({"a": 1, "b": 2}, 2, ["a", "b"]), - ], - ) - def test_written_task_map(self, dag_maker, xcom_value, expected_length, expected_keys): - """Return value should be recorded in TaskMap if it's used by a downstream to map.""" - with dag_maker(dag_id="test_written_task_map") as dag: - - @dag.task() - def push_something(): - return xcom_value - - @dag.task() - def pull_something(value): - print(value) - - pull_something.expand(value=push_something()) - - dag_run = dag_maker.create_dagrun() - ti = next(ti for ti in dag_run.task_instances if ti.task_id == "push_something") - ti.run() - - task_map = dag_maker.session.query(TaskMap).one() - assert task_map.dag_id == "test_written_task_map" - assert task_map.task_id == "push_something" - assert task_map.run_id == dag_run.run_id - assert task_map.map_index == -1 - assert task_map.length == expected_length - assert task_map.keys == expected_keys - - @pytest.mark.xfail( - reason="not clear what this is really testing; " - "there's no API for removing a task; " - "and when a serialized dag is there, this fails; " - "and we need a serialized dag for dag.clear to work now" - ) - def test_no_error_on_changing_from_non_mapped_to_mapped(self, dag_maker, session): - """If a task changes from non-mapped to mapped, don't fail on integrity error.""" - with dag_maker(dag_id="test_no_error_on_changing_from_non_mapped_to_mapped") as dag: - - @dag.task() - def add_one(x): - return [x + 1] - - @dag.task() - def add_two(x): - return x + 2 - - task1 = add_one(2) - add_two.expand(x=task1) - - dr = dag_maker.create_dagrun() - ti = dr.get_task_instance(task_id="add_one") - ti.run() - assert ti.state == TaskInstanceState.SUCCESS - dag._remove_task("add_one") - with dag: - task1 = add_one.expand(x=[1, 2, 3]).operator - serialized_dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag)) - - dr.dag = serialized_dag - dr.verify_integrity(session=session) - ti = dr.get_task_instance(task_id="add_one") - assert ti.state == TaskInstanceState.REMOVED - dag.clear() - ti.refresh_from_task(task1) - # This should not raise an integrity error - dr.task_instance_scheduling_decisions() - class TestMappedTaskInstanceReceiveValue: @pytest.mark.parametrize( @@ -4714,29 +2962,6 @@ def show(value): ti.run() assert outputs == expected_outputs - def test_map_has_dag_version(self, dag_maker, session): - from airflow.models.dag_version import DagVersion - - known_versions = [] - - with dag_maker(dag_id="test_89eug7u6f7y", session=session) as dag: - - @dag.task - def show(value, *, ti): - # let's record the dag version ids we observe on the tis - known_versions.append(ti.dag_version_id) - - show.expand(value=[1, 2, 3]) - # get the dag version for the dag - dag_version = session.scalar(select(DagVersion).where(DagVersion.dag_id == dag.dag_id)) - dag_maker.create_dagrun(session=session) - task = dag.get_task("show") - for ti in session.scalars(select(TI)): - ti.refresh_from_task(task) - ti.run(session=session) - # verify that we only saw the dag version we created - assert known_versions == [dag_version.id] * 3 - @pytest.mark.parametrize( "upstream_return, expected_outputs", [ diff --git a/airflow-core/tests/unit/ti_deps/deps/test_not_previously_skipped_dep.py b/airflow-core/tests/unit/ti_deps/deps/test_not_previously_skipped_dep.py index 11d913e1ec568..d22dc6f1a5825 100644 --- a/airflow-core/tests/unit/ti_deps/deps/test_not_previously_skipped_dep.py +++ b/airflow-core/tests/unit/ti_deps/deps/test_not_previously_skipped_dep.py @@ -20,7 +20,6 @@ import pendulum import pytest -from airflow.exceptions import DownstreamTasksSkipped from airflow.models import DagRun, TaskInstance from airflow.providers.standard.operators.empty import EmptyOperator from airflow.providers.standard.operators.python import BranchPythonOperator @@ -130,10 +129,7 @@ def test_parent_skip_branch(session, dag_maker): ti.task_id: ti for ti in dag_maker.create_dagrun(run_type=DagRunType.MANUAL, state=State.RUNNING).task_instances } - with pytest.raises(DownstreamTasksSkipped) as exc_info: - tis["op1"].run() - - assert exc_info.value.tasks == [("op2", -1)] + tis["op1"].run() dep = NotPreviouslySkippedDep() assert len(list(dep.get_dep_statuses(tis["op2"], session, DepContext()))) == 1 diff --git a/airflow-core/tests/unit/utils/test_log_handlers.py b/airflow-core/tests/unit/utils/test_log_handlers.py index e9a0a1823cf19..6d06137573973 100644 --- a/airflow-core/tests/unit/utils/test_log_handlers.py +++ b/airflow-core/tests/unit/utils/test_log_handlers.py @@ -105,6 +105,7 @@ def test_default_task_logging_setup(self): handler = handlers[0] assert handler.name == FILE_TASK_HANDLER + @pytest.mark.xfail(reason="TODO: Needs to be ported over to the new structlog based logging") def test_file_task_handler_when_ti_value_is_invalid(self, dag_maker): def task_callable(ti): ti.log.info("test") @@ -149,6 +150,7 @@ def task_callable(ti): # Remove the generated tmp log file. os.remove(log_filename) + @pytest.mark.xfail(reason="TODO: Needs to be ported over to the new structlog based logging") def test_file_task_handler(self, dag_maker, session): def task_callable(ti): ti.log.info("test") diff --git a/devel-common/src/tests_common/test_utils/version_compat.py b/devel-common/src/tests_common/test_utils/version_compat.py index 2e990761628fb..289392c8ce118 100644 --- a/devel-common/src/tests_common/test_utils/version_compat.py +++ b/devel-common/src/tests_common/test_utils/version_compat.py @@ -32,5 +32,6 @@ def get_base_airflow_version_tuple() -> tuple[int, int, int]: return airflow_version.major, airflow_version.minor, airflow_version.micro +AIRFLOW_V_3_0_1 = get_base_airflow_version_tuple() == (3, 0, 1) AIRFLOW_V_3_0_PLUS = get_base_airflow_version_tuple() >= (3, 0, 0) [].sort() diff --git a/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/log_handlers/test_log_handlers.py b/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/log_handlers/test_log_handlers.py index cc63b8b15b310..e467bcda87633 100644 --- a/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/log_handlers/test_log_handlers.py +++ b/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/log_handlers/test_log_handlers.py @@ -117,7 +117,7 @@ def test_read_from_k8s_under_multi_namespace_mode( mock_list_pod = mock_kube_client.return_value.list_namespaced_pod def task_callable(ti): - ti.log.info("test") + ti.task.log.info("test") with DAG("dag_for_testing_file_task_handler", schedule=None, start_date=DEFAULT_DATE) as dag: task = PythonOperator( @@ -146,7 +146,7 @@ def task_callable(ti): ti.executor = "KubernetesExecutor" logger = ti.log - ti.log.disabled = False + ti.task.log.disabled = False file_handler = next((h for h in logger.handlers if h.name == FILE_TASK_HANDLER), None) set_context(logger, ti) diff --git a/providers/common/sql/tests/unit/common/sql/operators/test_sql.py b/providers/common/sql/tests/unit/common/sql/operators/test_sql.py index 12166d88f9f99..7f56b34f70304 100644 --- a/providers/common/sql/tests/unit/common/sql/operators/test_sql.py +++ b/providers/common/sql/tests/unit/common/sql/operators/test_sql.py @@ -48,7 +48,7 @@ from tests_common.test_utils.markers import skip_if_force_lowest_dependencies_marker from tests_common.test_utils.providers import get_provider_min_airflow_version -from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS +from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_1, AIRFLOW_V_3_0_PLUS if AIRFLOW_V_3_0_PLUS: from airflow.models.xcom import XComModel as XCom @@ -1224,7 +1224,7 @@ def test_branch_single_value_with_dag_run(self, mock_get_db_hook, branch_op): mock_get_records.return_value = 1 - if AIRFLOW_V_3_0_PLUS: + if AIRFLOW_V_3_0_1: from airflow.exceptions import DownstreamTasksSkipped with pytest.raises(DownstreamTasksSkipped) as exc_info: @@ -1272,7 +1272,7 @@ def test_branch_true_with_dag_run(self, mock_get_db_hook, true_value, branch_op) mock_get_records = mock_get_db_hook.return_value.get_first mock_get_records.return_value = true_value - if AIRFLOW_V_3_0_PLUS: + if AIRFLOW_V_3_0_1: from airflow.exceptions import DownstreamTasksSkipped with pytest.raises(DownstreamTasksSkipped) as exc_info: @@ -1320,7 +1320,7 @@ def test_branch_false_with_dag_run(self, mock_get_db_hook, false_value, branch_o mock_get_records = mock_get_db_hook.return_value.get_first mock_get_records.return_value = false_value - if AIRFLOW_V_3_0_PLUS: + if AIRFLOW_V_3_0_1: from airflow.exceptions import DownstreamTasksSkipped with pytest.raises(DownstreamTasksSkipped) as exc_info: @@ -1377,7 +1377,7 @@ def test_branch_list_with_dag_run(self, mock_get_db_hook): mock_get_records = mock_get_db_hook.return_value.get_first mock_get_records.return_value = [["1"]] - if AIRFLOW_V_3_0_PLUS: + if AIRFLOW_V_3_0_1: from airflow.exceptions import DownstreamTasksSkipped with pytest.raises(DownstreamTasksSkipped) as exc_info: @@ -1495,7 +1495,7 @@ def test_with_skip_in_branch_downstream_dependencies2(self, mock_get_db_hook, fa mock_get_records = mock_get_db_hook.return_value.get_first mock_get_records.return_value = [false_value] - if AIRFLOW_V_3_0_PLUS: + if AIRFLOW_V_3_0_1: from airflow.exceptions import DownstreamTasksSkipped with pytest.raises(DownstreamTasksSkipped) as exc_info: diff --git a/providers/microsoft/azure/tests/unit/microsoft/azure/operators/test_adx.py b/providers/microsoft/azure/tests/unit/microsoft/azure/operators/test_adx.py index ec63bba5b2a30..ec83ed1863f6c 100644 --- a/providers/microsoft/azure/tests/unit/microsoft/azure/operators/test_adx.py +++ b/providers/microsoft/azure/tests/unit/microsoft/azure/operators/test_adx.py @@ -25,6 +25,7 @@ from airflow.models import DAG from airflow.providers.microsoft.azure.hooks.adx import AzureDataExplorerHook from airflow.providers.microsoft.azure.operators.adx import AzureDataExplorerQueryOperator +from airflow.providers.microsoft.azure.version_compat import AIRFLOW_V_3_0_PLUS from airflow.utils.timezone import datetime TEST_DAG_ID = "unit_tests" @@ -88,12 +89,20 @@ def test_azure_data_explorer_query_operator_xcom_push_and_pull( mock_conn, mock_run_query, create_task_instance_of_operator, + request, ): - ti = create_task_instance_of_operator( - AzureDataExplorerQueryOperator, - dag_id="test_azure_data_explorer_query_operator_xcom_push_and_pull", - **MOCK_DATA, - ) - ti.run() - - assert ti.xcom_pull(task_ids=MOCK_DATA["task_id"]) == str(MOCK_RESULT) + if AIRFLOW_V_3_0_PLUS: + run_task = request.getfixturevalue("run_task") + task = AzureDataExplorerQueryOperator(**MOCK_DATA) + run_task(task=task) + + assert run_task.xcom.get(key="return_value", task_id=task.task_id) == str(MOCK_RESULT) + else: + ti = create_task_instance_of_operator( + AzureDataExplorerQueryOperator, + dag_id="test_azure_data_explorer_query_operator_xcom_push_and_pull", + **MOCK_DATA, + ) + ti.run() + + assert ti.xcom_pull(task_ids=MOCK_DATA["task_id"]) == str(MOCK_RESULT) diff --git a/providers/oracle/tests/unit/oracle/operators/test_oracle.py b/providers/oracle/tests/unit/oracle/operators/test_oracle.py index 2f06e1513472a..02bc3f391de11 100644 --- a/providers/oracle/tests/unit/oracle/operators/test_oracle.py +++ b/providers/oracle/tests/unit/oracle/operators/test_oracle.py @@ -27,6 +27,8 @@ from airflow.providers.oracle.hooks.oracle import OracleHook from airflow.providers.oracle.operators.oracle import OracleStoredProcedureOperator +from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS + class TestOracleStoredProcedureOperator: @mock.patch.object(OracleHook, "run", autospec=OracleHook.run) @@ -65,12 +67,20 @@ def test_push_oracle_exit_to_xcom(self, mock_callproc, request, dag_maker): error = f"ORA-{ora_exit_code}: This is a five-digit ORA error code" mock_callproc.side_effect = oracledb.DatabaseError(error) - with dag_maker(dag_id=f"dag_{request.node.name}"): + if AIRFLOW_V_3_0_PLUS: + run_task = request.getfixturevalue("run_task") task = OracleStoredProcedureOperator( procedure=procedure, oracle_conn_id=oracle_conn_id, parameters=parameters, task_id=task_id ) - dr = dag_maker.create_dagrun(run_id=task_id) - ti = TaskInstance(task=task, run_id=dr.run_id) - with pytest.raises(oracledb.DatabaseError, match=re.escape(error)): - ti.run() - assert ti.xcom_pull(task_ids=task.task_id, key="ORA") == ora_exit_code + run_task(task=task) + assert run_task.xcom.get(task_id=task.task_id, key="ORA") == ora_exit_code + else: + with dag_maker(dag_id=f"dag_{request.node.name}"): + task = OracleStoredProcedureOperator( + procedure=procedure, oracle_conn_id=oracle_conn_id, parameters=parameters, task_id=task_id + ) + dr = dag_maker.create_dagrun(run_id=task_id) + ti = TaskInstance(task=task, run_id=dr.run_id) + with pytest.raises(oracledb.DatabaseError, match=re.escape(error)): + ti.run() + assert ti.xcom_pull(task_ids=task.task_id, key="ORA") == ora_exit_code diff --git a/providers/snowflake/tests/unit/snowflake/decorators/test_snowpark.py b/providers/snowflake/tests/unit/snowflake/decorators/test_snowpark.py index b14b6bd5c0df1..f51b5d6a6acca 100644 --- a/providers/snowflake/tests/unit/snowflake/decorators/test_snowpark.py +++ b/providers/snowflake/tests/unit/snowflake/decorators/test_snowpark.py @@ -24,6 +24,7 @@ import pytest from airflow.decorators import task +from airflow.providers.snowflake.version_compat import AIRFLOW_V_3_0_PLUS from airflow.utils import timezone if TYPE_CHECKING: @@ -156,7 +157,7 @@ def func(session: Session): mock_snowflake_hook.return_value.get_snowpark_session.assert_called_once() @mock.patch("airflow.providers.snowflake.operators.snowpark.SnowflakeHook") - def test_snowpark_decorator_multiple_output(self, mock_snowflake_hook, dag_maker): + def test_snowpark_decorator_multiple_output(self, mock_snowflake_hook, dag_maker, request): @task.snowpark( task_id=TASK_ID, snowflake_conn_id=CONN_ID, @@ -171,15 +172,23 @@ def func(session: Session): assert session == mock_snowflake_hook.return_value.get_snowpark_session.return_value return {"a": 1, "b": "2"} - with dag_maker(dag_id=TEST_DAG_ID): - ret = func() - - dr = dag_maker.create_dagrun() - ret.operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) - ti = dr.get_task_instances()[0] - assert ti.xcom_pull(key="a") == 1 - assert ti.xcom_pull(key="b") == "2" - assert ti.xcom_pull() == {"a": 1, "b": "2"} + if AIRFLOW_V_3_0_PLUS: + run_task = request.getfixturevalue("run_task") + op = func().operator + run_task(task=op) + assert run_task.xcom.get(key="a") == 1 + assert run_task.xcom.get(key="b") == "2" + assert run_task.xcom.get(key="return_value") == {"a": 1, "b": "2"} + else: + with dag_maker(dag_id=TEST_DAG_ID): + ret = func() + + dr = dag_maker.create_dagrun() + ret.operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) + ti = dr.get_task_instances()[0] + assert ti.xcom_pull(key="a") == 1 + assert ti.xcom_pull(key="b") == "2" + assert ti.xcom_pull() == {"a": 1, "b": "2"} mock_snowflake_hook.assert_called_once() mock_snowflake_hook.return_value.get_snowpark_session.assert_called_once() diff --git a/providers/standard/src/airflow/providers/standard/operators/python.py b/providers/standard/src/airflow/providers/standard/operators/python.py index b67d186fbf4e1..5a8bd9e87b657 100644 --- a/providers/standard/src/airflow/providers/standard/operators/python.py +++ b/providers/standard/src/airflow/providers/standard/operators/python.py @@ -495,9 +495,21 @@ def get_python_source(self): return textwrap.dedent(inspect.getsource(self.python_callable)) def _write_args(self, file: Path): + def resolve_proxies(obj): + """Recursively replaces lazy_object_proxy.Proxy instances with their resolved values.""" + if isinstance(obj, lazy_object_proxy.Proxy): + return obj.__wrapped__ # force evaluation + if isinstance(obj, dict): + return {k: resolve_proxies(v) for k, v in obj.items()} + if isinstance(obj, list): + return [resolve_proxies(v) for v in obj] + return obj + if self.op_args or self.op_kwargs: self.log.info("Use %r as serializer.", self.serializer) - file.write_bytes(self.pickling_library.dumps({"args": self.op_args, "kwargs": self.op_kwargs})) + file.write_bytes( + self.pickling_library.dumps({"args": self.op_args, "kwargs": resolve_proxies(self.op_kwargs)}) + ) def _write_string_args(self, file: Path): file.write_text("\n".join(map(str, self.string_args))) diff --git a/providers/standard/tests/unit/standard/decorators/test_branch_external_python.py b/providers/standard/tests/unit/standard/decorators/test_branch_external_python.py index 43e9ae1d91a37..f0283c0307493 100644 --- a/providers/standard/tests/unit/standard/decorators/test_branch_external_python.py +++ b/providers/standard/tests/unit/standard/decorators/test_branch_external_python.py @@ -22,12 +22,13 @@ import pytest from airflow.decorators import task -from airflow.providers.standard.version_compat import AIRFLOW_V_3_0_PLUS +from airflow.utils.state import State -if AIRFLOW_V_3_0_PLUS: +from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_1 + +if AIRFLOW_V_3_0_1: from airflow.exceptions import DownstreamTasksSkipped -else: - from airflow.utils.state import State + pytestmark = pytest.mark.db_test @@ -79,7 +80,7 @@ def branch_operator(): dr = dag_maker.create_dagrun() df.operator.run(start_date=dr.logical_date, end_date=dr.logical_date, ignore_ti_state=True) - if AIRFLOW_V_3_0_PLUS: + if AIRFLOW_V_3_0_1: with pytest.raises(DownstreamTasksSkipped) as exc_info: branchoperator.operator.run( start_date=dr.logical_date, end_date=dr.logical_date, ignore_ti_state=True diff --git a/providers/standard/tests/unit/standard/decorators/test_branch_python.py b/providers/standard/tests/unit/standard/decorators/test_branch_python.py index a78050b6a3cfa..3d8a46d7a37cd 100644 --- a/providers/standard/tests/unit/standard/decorators/test_branch_python.py +++ b/providers/standard/tests/unit/standard/decorators/test_branch_python.py @@ -20,12 +20,12 @@ import pytest from airflow.decorators import task -from airflow.providers.standard.version_compat import AIRFLOW_V_3_0_PLUS +from airflow.utils.state import State -if AIRFLOW_V_3_0_PLUS: +from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_1 + +if AIRFLOW_V_3_0_1: from airflow.exceptions import DownstreamTasksSkipped -else: - from airflow.utils.state import State pytestmark = pytest.mark.db_test @@ -67,7 +67,7 @@ def branch_operator(): dr = dag_maker.create_dagrun() df.operator.run(start_date=dr.logical_date, end_date=dr.logical_date, ignore_ti_state=True) - if AIRFLOW_V_3_0_PLUS: + if AIRFLOW_V_3_0_1: with pytest.raises(DownstreamTasksSkipped) as exc_info: branchoperator.operator.run( start_date=dr.logical_date, end_date=dr.logical_date, ignore_ti_state=True diff --git a/providers/standard/tests/unit/standard/decorators/test_branch_virtualenv.py b/providers/standard/tests/unit/standard/decorators/test_branch_virtualenv.py index ab616b37435cd..170916c21a31b 100644 --- a/providers/standard/tests/unit/standard/decorators/test_branch_virtualenv.py +++ b/providers/standard/tests/unit/standard/decorators/test_branch_virtualenv.py @@ -22,12 +22,13 @@ import pytest from airflow.decorators import task -from airflow.providers.standard.version_compat import AIRFLOW_V_3_0_PLUS +from airflow.utils.state import State -if AIRFLOW_V_3_0_PLUS: +from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_1, AIRFLOW_V_3_0_PLUS + +if AIRFLOW_V_3_0_1: from airflow.exceptions import DownstreamTasksSkipped -else: - from airflow.utils.state import State + pytestmark = pytest.mark.db_test @@ -95,7 +96,7 @@ def branch_operator(): dr = dag_maker.create_dagrun() df.operator.run(start_date=dr.logical_date, end_date=dr.logical_date, ignore_ti_state=True) - if AIRFLOW_V_3_0_PLUS: + if AIRFLOW_V_3_0_1: with pytest.raises(DownstreamTasksSkipped) as exc_info: branchoperator.operator.run( start_date=dr.logical_date, end_date=dr.logical_date, ignore_ti_state=True diff --git a/providers/standard/tests/unit/standard/decorators/test_python.py b/providers/standard/tests/unit/standard/decorators/test_python.py index 8dfc7b4d8e40f..ed8d5ef0cdb97 100644 --- a/providers/standard/tests/unit/standard/decorators/test_python.py +++ b/providers/standard/tests/unit/standard/decorators/test_python.py @@ -33,7 +33,7 @@ from airflow.utils.types import DagRunType from airflow.utils.xcom import XCOM_RETURN_KEY -from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS +from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_1, AIRFLOW_V_3_0_PLUS from unit.standard.operators.test_python import BasePythonTest if AIRFLOW_V_3_0_PLUS: @@ -215,7 +215,22 @@ def identity_notyping_with_decorator_call(x: int): assert identity_notyping_with_decorator_call(5).operator.multiple_outputs is False - def test_manual_multiple_outputs_false_with_typings(self): + @pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Different test for AF 2") + def test_manual_multiple_outputs_false_with_typings(self, run_task): + @task_decorator(multiple_outputs=False) + def identity2(x: int, y: int) -> tuple[int, int]: + return x, y + + res = identity2(8, 4) + run_task(task=res.operator) + + assert not res.operator.multiple_outputs + assert run_task.xcom.get(key=res.key) == (8, 4) + assert run_task.xcom.get(key="return_value_0") is None + assert run_task.xcom.get(key="return_value_1") is None + + @pytest.mark.skipif(AIRFLOW_V_3_0_PLUS, reason="Different test for AF 3") + def test_manual_multiple_outputs_false_with_typings_af2(self): @task_decorator(multiple_outputs=False) def identity2(x: int, y: int) -> tuple[int, int]: return x, y @@ -233,7 +248,22 @@ def identity2(x: int, y: int) -> tuple[int, int]: assert ti.xcom_pull(key="return_value_0") is None assert ti.xcom_pull(key="return_value_1") is None - def test_multiple_outputs_ignore_typing(self): + @pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Different test for AF 2") + def test_multiple_outputs_ignore_typing(self, run_task): + @task_decorator + def identity_tuple(x: int, y: int) -> tuple[int, int]: + return x, y + + ident = identity_tuple(35, 36) + run_task(task=ident.operator) + + assert not ident.operator.multiple_outputs + assert run_task.xcom.get(key=ident.key) == (35, 36) + assert run_task.xcom.get(key="return_value_0") is None + assert run_task.xcom.get(key="return_value_1") is None + + @pytest.mark.skipif(AIRFLOW_V_3_0_PLUS, reason="Different test for AF 3") + def test_multiple_outputs_ignore_typing_af2(self): @task_decorator def identity_tuple(x: int, y: int) -> tuple[int, int]: return x, y @@ -296,7 +326,9 @@ def add_number(num: int): ret = add_number(2) self.create_dag_run() - with pytest.raises(AirflowException): + + error_expected = AirflowException if (not AIRFLOW_V_3_0_PLUS or AIRFLOW_V_3_0_1) else TypeError + with pytest.raises(error_expected): ret.operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) def test_fail_multiple_outputs_no_dict(self): @@ -308,7 +340,8 @@ def add_number(num: int): ret = add_number(2) self.create_dag_run() - with pytest.raises(AirflowException): + error_expected = AirflowException if (not AIRFLOW_V_3_0_PLUS or AIRFLOW_V_3_0_1) else TypeError + with pytest.raises(error_expected): ret.operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) def test_multiple_outputs_empty_dict(self): diff --git a/providers/standard/tests/unit/standard/decorators/test_short_circuit.py b/providers/standard/tests/unit/standard/decorators/test_short_circuit.py index 3992870b52f09..3ead1c252bbb3 100644 --- a/providers/standard/tests/unit/standard/decorators/test_short_circuit.py +++ b/providers/standard/tests/unit/standard/decorators/test_short_circuit.py @@ -21,10 +21,11 @@ from pendulum import datetime from airflow.decorators import task -from airflow.providers.standard.version_compat import AIRFLOW_V_3_0_PLUS from airflow.utils.state import State from airflow.utils.trigger_rule import TriggerRule +from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_1, AIRFLOW_V_3_0_PLUS + if AIRFLOW_V_3_0_PLUS: from airflow.exceptions import DownstreamTasksSkipped @@ -34,8 +35,8 @@ DEFAULT_DATE = datetime(2022, 8, 17) -@pytest.mark.skipif(AIRFLOW_V_3_0_PLUS, reason="Test doesn't run on AF3. Companion test below.") -def test_short_circuit_decorator_af2(dag_maker): +@pytest.mark.skipif(AIRFLOW_V_3_0_1, reason="Test doesn't run on AF 3.0.1. Companion test below.") +def test_short_circuit_decorator(dag_maker): with dag_maker(serialized=True): @task @@ -82,9 +83,9 @@ def short_circuit(condition): assert ti.state == task_state_mapping[ti.task_id] -@pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Test only runs on AF3") +@pytest.mark.skipif(not AIRFLOW_V_3_0_1, reason="Test only runs on AF3.0.1") @pytest.mark.parametrize(["condition", "should_be_skipped"], [(True, False), (False, True)]) -def test_short_circuit_decorator_af3(dag_maker, session, condition, should_be_skipped): +def test_short_circuit_decorator_af301(dag_maker, session, condition, should_be_skipped): with dag_maker(serialized=True, session=session): @task.short_circuit() @@ -112,7 +113,7 @@ def empty(): ... ti_sc.run() -@pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Test only runs on AF3") +@pytest.mark.skipif(not AIRFLOW_V_3_0_1, reason="Test only runs on AF3.0.1") @pytest.mark.parametrize( ["ignore_downstream_trigger_rules", "expected"], [(True, State.SKIPPED), (False, State.SUCCESS)] ) diff --git a/providers/standard/tests/unit/standard/operators/test_branch_operator.py b/providers/standard/tests/unit/standard/operators/test_branch_operator.py index c38a1074bcaa9..821e7cfb9c675 100644 --- a/providers/standard/tests/unit/standard/operators/test_branch_operator.py +++ b/providers/standard/tests/unit/standard/operators/test_branch_operator.py @@ -24,13 +24,14 @@ from airflow.models.taskinstance import TaskInstance as TI from airflow.providers.standard.operators.branch import BaseBranchOperator from airflow.providers.standard.operators.empty import EmptyOperator +from airflow.providers.standard.utils.skipmixin import XCOM_SKIPMIXIN_FOLLOWED, XCOM_SKIPMIXIN_KEY from airflow.timetables.base import DataInterval from airflow.utils import timezone from airflow.utils.state import State from airflow.utils.task_group import TaskGroup from airflow.utils.types import DagRunType -from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS +from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_1, AIRFLOW_V_3_0_PLUS if AIRFLOW_V_3_0_PLUS: from airflow.exceptions import DownstreamTasksSkipped @@ -75,7 +76,7 @@ def test_without_dag_run(self, dag_maker): branch_2.set_upstream(branch_op) dag_maker.create_dagrun(**triggered_by_kwargs) - if AIRFLOW_V_3_0_PLUS: + if AIRFLOW_V_3_0_1: with pytest.raises(DownstreamTasksSkipped) as exc_info: branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) @@ -83,9 +84,9 @@ def test_without_dag_run(self, dag_maker): else: branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) - for ti in dag_maker.session.query(TI).filter( - TI.dag_id == dag_id, TI.execution_date == DEFAULT_DATE - ): + ti_date = TI.logical_date if AIRFLOW_V_3_0_PLUS else TI.execution_date + + for ti in dag_maker.session.query(TI).filter(TI.dag_id == dag_id, ti_date == DEFAULT_DATE): if ti.task_id == "make_choice": assert ti.state == State.SUCCESS elif ti.task_id == "branch_1": @@ -115,7 +116,7 @@ def test_branch_list_without_dag_run(self, dag_maker): branch_3.set_upstream(branch_op) dag_maker.create_dagrun(**triggered_by_kwargs) - if AIRFLOW_V_3_0_PLUS: + if AIRFLOW_V_3_0_1: with pytest.raises(DownstreamTasksSkipped) as exc_info: branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) @@ -130,9 +131,9 @@ def test_branch_list_without_dag_run(self, dag_maker): "branch_3": State.SKIPPED, } - for ti in dag_maker.session.query(TI).filter( - TI.dag_id == dag_id, TI.execution_date == DEFAULT_DATE - ): + ti_date = TI.logical_date if AIRFLOW_V_3_0_PLUS else TI.execution_date + + for ti in dag_maker.session.query(TI).filter(TI.dag_id == dag_id, ti_date == DEFAULT_DATE): if ti.task_id in expected: assert ti.state == expected[ti.task_id] else: @@ -152,7 +153,7 @@ def test_with_dag_run(self, dag_maker): branch_op = ChooseBranchOne(task_id="make_choice") branch_1.set_upstream(branch_op) branch_2.set_upstream(branch_op) - if AIRFLOW_V_3_0_PLUS: + if AIRFLOW_V_3_0_1: dag_maker.create_dagrun( run_type=DagRunType.MANUAL, start_date=timezone.utcnow(), @@ -184,9 +185,9 @@ def test_with_dag_run(self, dag_maker): "branch_2": State.SKIPPED, } - for ti in dag_maker.session.query(TI).filter( - TI.dag_id == dag_id, TI.execution_date == DEFAULT_DATE - ): + ti_date = TI.logical_date if AIRFLOW_V_3_0_PLUS else TI.execution_date + + for ti in dag_maker.session.query(TI).filter(TI.dag_id == dag_id, ti_date == DEFAULT_DATE): if ti.task_id in expected: assert ti.state == expected[ti.task_id] else: @@ -244,7 +245,15 @@ def test_with_skip_in_branch_downstream_dependencies(self, dag_maker): def test_xcom_push(self, dag_maker): dag_id = "branch_operator_test" - triggered_by_kwargs = {"triggered_by": DagRunTriggeredByType.TEST} if AIRFLOW_V_3_0_PLUS else {} + + triggered_by_kwargs = ( + { + "triggered_by": DagRunTriggeredByType.TEST, + "logical_date": DEFAULT_DATE, + } + if AIRFLOW_V_3_0_PLUS + else {"execution_date": DEFAULT_DATE} + ) with dag_maker( dag_id, default_args={"owner": "airflow", "start_date": DEFAULT_DATE}, @@ -257,36 +266,25 @@ def test_xcom_push(self, dag_maker): branch_1.set_upstream(branch_op) branch_2.set_upstream(branch_op) - if AIRFLOW_V_3_0_PLUS: - dag_maker.create_dagrun( - run_type=DagRunType.MANUAL, - start_date=timezone.utcnow(), - logical_date=DEFAULT_DATE, - state=State.RUNNING, - data_interval=DataInterval(DEFAULT_DATE, DEFAULT_DATE), - **triggered_by_kwargs, - ) + dr = dag_maker.create_dagrun( + run_type=DagRunType.MANUAL, + start_date=timezone.utcnow(), + state=State.RUNNING, + data_interval=DataInterval(DEFAULT_DATE, DEFAULT_DATE), + **triggered_by_kwargs, + ) + if AIRFLOW_V_3_0_1: with pytest.raises(DownstreamTasksSkipped) as exc_info: branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) assert exc_info.value.tasks == [("branch_2", -1)] else: - dag_maker.create_dagrun( - run_type=DagRunType.MANUAL, - start_date=timezone.utcnow(), - execution_date=DEFAULT_DATE, - state=State.RUNNING, - data_interval=DataInterval(DEFAULT_DATE, DEFAULT_DATE), - **triggered_by_kwargs, - ) branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) - - ti_date = TI.logical_date if AIRFLOW_V_3_0_PLUS else TI.execution_date - - for ti in dag_maker.session.query(TI).filter(TI.dag_id == dag_id, ti_date == DEFAULT_DATE): - if ti.task_id == "make_choice": - assert ti.xcom_pull(task_ids="make_choice") == "branch_1" + branch_op_ti = dr.get_task_instance(branch_op.task_id) + assert branch_op_ti.xcom_pull(task_ids="make_choice", key=XCOM_SKIPMIXIN_KEY) == { + XCOM_SKIPMIXIN_FOLLOWED: ["branch_1"] + } def test_with_dag_run_task_groups(self, dag_maker): dag_id = "branch_operator_test" @@ -307,7 +305,7 @@ def test_with_dag_run_task_groups(self, dag_maker): branch_2.set_upstream(branch_op) branch_3.set_upstream(branch_op) - if AIRFLOW_V_3_0_PLUS: + if AIRFLOW_V_3_0_1: dag_maker.create_dagrun( run_type=DagRunType.MANUAL, start_date=timezone.utcnow(), @@ -332,9 +330,9 @@ def test_with_dag_run_task_groups(self, dag_maker): ) branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) - for ti in dag_maker.session.query(TI).filter( - TI.dag_id == dag_id, TI.execution_date == DEFAULT_DATE - ): + ti_date = TI.logical_date if AIRFLOW_V_3_0_PLUS else TI.execution_date + + for ti in dag_maker.session.query(TI).filter(TI.dag_id == dag_id, ti_date == DEFAULT_DATE): if ti.task_id == "make_choice": assert ti.state == State.SUCCESS elif ti.task_id == "branch_1": diff --git a/providers/standard/tests/unit/standard/operators/test_datetime.py b/providers/standard/tests/unit/standard/operators/test_datetime.py index eab06756610d9..0c6e40381793f 100644 --- a/providers/standard/tests/unit/standard/operators/test_datetime.py +++ b/providers/standard/tests/unit/standard/operators/test_datetime.py @@ -32,7 +32,7 @@ from airflow.utils.session import create_session from airflow.utils.state import State -from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS +from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_1, AIRFLOW_V_3_0_PLUS pytestmark = pytest.mark.db_test @@ -124,7 +124,7 @@ def test_branch_datetime_operator_falls_within_range(self, target_lower, target_ """Check BranchDateTimeOperator branch operation""" self.branch_op.target_lower = target_lower self.branch_op.target_upper = target_upper - if AIRFLOW_V_3_0_PLUS: + if AIRFLOW_V_3_0_1: from airflow.exceptions import DownstreamTasksSkipped with pytest.raises(DownstreamTasksSkipped) as exc_info: @@ -157,7 +157,7 @@ def test_branch_datetime_operator_falls_outside_range(self, date, target_lower, self.branch_op.target_lower = target_lower self.branch_op.target_upper = target_upper - if AIRFLOW_V_3_0_PLUS: + if AIRFLOW_V_3_0_1: from airflow.exceptions import DownstreamTasksSkipped with pytest.raises(DownstreamTasksSkipped) as exc_info, time_machine.travel(date): @@ -183,7 +183,7 @@ def test_branch_datetime_operator_upper_comparison_within_range(self, target_upp self.branch_op.target_upper = target_upper self.branch_op.target_lower = None - if AIRFLOW_V_3_0_PLUS: + if AIRFLOW_V_3_0_1: from airflow.exceptions import DownstreamTasksSkipped with pytest.raises(DownstreamTasksSkipped) as exc_info: @@ -208,7 +208,7 @@ def test_branch_datetime_operator_lower_comparison_within_range(self, target_low self.branch_op.target_lower = target_lower self.branch_op.target_upper = None - if AIRFLOW_V_3_0_PLUS: + if AIRFLOW_V_3_0_1: from airflow.exceptions import DownstreamTasksSkipped with pytest.raises(DownstreamTasksSkipped) as exc_info: @@ -233,7 +233,7 @@ def test_branch_datetime_operator_upper_comparison_outside_range(self, target_up self.branch_op.target_upper = target_upper self.branch_op.target_lower = None - if AIRFLOW_V_3_0_PLUS: + if AIRFLOW_V_3_0_1: from airflow.exceptions import DownstreamTasksSkipped with pytest.raises(DownstreamTasksSkipped) as exc_info: @@ -258,7 +258,7 @@ def test_branch_datetime_operator_lower_comparison_outside_range(self, target_lo self.branch_op.target_lower = target_lower self.branch_op.target_upper = None - if AIRFLOW_V_3_0_PLUS: + if AIRFLOW_V_3_0_1: from airflow.exceptions import DownstreamTasksSkipped with pytest.raises(DownstreamTasksSkipped) as exc_info: @@ -295,7 +295,7 @@ def test_branch_datetime_operator_use_task_logical_date(self, dag_maker, target_ self.branch_op.target_lower = target_lower self.branch_op.target_upper = target_upper - if AIRFLOW_V_3_0_PLUS: + if AIRFLOW_V_3_0_1: from airflow.exceptions import DownstreamTasksSkipped with pytest.raises(DownstreamTasksSkipped) as exc_info: diff --git a/providers/standard/tests/unit/standard/operators/test_latest_only_operator.py b/providers/standard/tests/unit/standard/operators/test_latest_only_operator.py index 81f89a5fdd936..fce99a64b8278 100644 --- a/providers/standard/tests/unit/standard/operators/test_latest_only_operator.py +++ b/providers/standard/tests/unit/standard/operators/test_latest_only_operator.py @@ -34,7 +34,7 @@ from airflow.utils.types import DagRunType from tests_common.test_utils.db import clear_db_runs, clear_db_xcom -from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS +from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_1, AIRFLOW_V_3_0_PLUS if AIRFLOW_V_3_0_PLUS: from airflow.sdk import DAG @@ -129,7 +129,7 @@ def test_skipping_non_latest(self, dag_maker): **triggered_by_kwargs, ) - if AIRFLOW_V_3_0_PLUS: + if AIRFLOW_V_3_0_1: from airflow.exceptions import DownstreamTasksSkipped # AIP-72 @@ -165,10 +165,12 @@ def test_skipping_non_latest(self, dag_maker): latest_ti2 = dr2.get_task_instance(task_id="latest") latest_ti2.task = latest_task latest_ti2.run() + else: + latest_task.run(start_date=DEFAULT_DATE, end_date=END_DATE) + if AIRFLOW_V_3_0_PLUS: date_getter = operator.attrgetter("logical_date") else: - latest_task.run(start_date=DEFAULT_DATE, end_date=END_DATE) date_getter = operator.attrgetter("execution_date") latest_instances = get_task_instances("latest") diff --git a/providers/standard/tests/unit/standard/operators/test_python.py b/providers/standard/tests/unit/standard/operators/test_python.py index 16921520cc2d7..4474795afbd50 100644 --- a/providers/standard/tests/unit/standard/operators/test_python.py +++ b/providers/standard/tests/unit/standard/operators/test_python.py @@ -71,7 +71,7 @@ from airflow.utils.types import NOTSET, DagRunType from tests_common.test_utils.db import clear_db_runs -from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS +from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_1, AIRFLOW_V_3_0_PLUS if TYPE_CHECKING: from airflow.models.dagrun import DagRun @@ -91,7 +91,7 @@ CLOUDPICKLE_INSTALLED = find_spec("cloudpickle") is not None CLOUDPICKLE_MARKER = pytest.mark.skipif(not CLOUDPICKLE_INSTALLED, reason="`cloudpickle` is not installed") -if AIRFLOW_V_3_0_PLUS: +if AIRFLOW_V_3_0_1: from airflow.exceptions import DownstreamTasksSkipped @@ -406,7 +406,7 @@ def f(): branch_op >> [self.branch_1, self.branch_2] dr = self.create_dag_run() - if AIRFLOW_V_3_0_PLUS: + if AIRFLOW_V_3_0_1: with pytest.raises(DownstreamTasksSkipped) as dts: branch_op.run(start_date=self.default_date, end_date=self.default_date) assert dts.value.tasks == [("branch_2", -1)] @@ -445,7 +445,7 @@ def f(): branch_op >> self.branch_2 dr = self.create_dag_run() - if AIRFLOW_V_3_0_PLUS: + if AIRFLOW_V_3_0_1: with pytest.raises(DownstreamTasksSkipped) as dts: branch_op.run(start_date=self.default_date, end_date=self.default_date) assert dts.value.tasks == [("branch_1", -1)] @@ -470,7 +470,7 @@ def f(): branch_op >> branches dr = dag_maker.create_dagrun() - if AIRFLOW_V_3_0_PLUS: + if AIRFLOW_V_3_0_1: from airflow.exceptions import DownstreamTasksSkipped with create_session() as session: @@ -503,7 +503,10 @@ def f(): tis = dr.get_task_instances() children_tis = [ti for ti in tis if ti.task_id in branch_op.get_direct_relative_ids()] with create_session() as session: - clear_task_instances(children_tis, session=session, dag=branch_op.dag) + if AIRFLOW_V_3_0_PLUS: + clear_task_instances(children_tis, session=session) + else: + clear_task_instances(children_tis, session=session, dag=branch_op.dag) # Run the cleared tasks again. for task in branches: @@ -563,7 +566,7 @@ def f(): for task_id in task_ids: # Mimic the specific order the scheduling would run the tests. task_instance = tis[task_id] task_instance.refresh_from_task(self.dag_non_serialized.get_task(task_id)) - if AIRFLOW_V_3_0_PLUS: + if AIRFLOW_V_3_0_1: from airflow.exceptions import DownstreamTasksSkipped try: @@ -721,7 +724,7 @@ def test_short_circuiting( self.op2.trigger_rule = test_trigger_rule dr = self.create_dag_run() - if AIRFLOW_V_3_0_PLUS: + if AIRFLOW_V_3_0_1: from airflow.exceptions import DownstreamTasksSkipped if expected_skipped_tasks: @@ -752,7 +755,7 @@ def test_clear_skipped_downstream_task(self): short_circuit >> self.op1 >> self.op2 dr = self.create_dag_run() - if AIRFLOW_V_3_0_PLUS: + if AIRFLOW_V_3_0_1: from airflow.exceptions import DownstreamTasksSkipped with create_session() as session: @@ -787,9 +790,12 @@ def test_clear_skipped_downstream_task(self): # Clear downstream task "op1" that was previously executed. tis = dr.get_task_instances() with create_session() as session: - clear_task_instances( - [ti for ti in tis if ti.task_id == "op1"], session=session, dag=short_circuit.dag - ) + if AIRFLOW_V_3_0_PLUS: + clear_task_instances([ti for ti in tis if ti.task_id == "op1"], session=session) + else: + clear_task_instances( + [ti for ti in tis if ti.task_id == "op1"], session=session, dag=short_circuit.dag + ) self.op1.run(start_date=self.default_date, end_date=self.default_date) self.assert_expected_task_states(dr, expected_states) @@ -819,7 +825,7 @@ def test_xcom_push_skipped_tasks(self): empty_task = EmptyOperator(task_id="empty_task") short_op_push_xcom >> empty_task dr = self.create_dag_run() - if AIRFLOW_V_3_0_PLUS: + if AIRFLOW_V_3_0_1: from airflow.exceptions import DownstreamTasksSkipped with pytest.raises(DownstreamTasksSkipped): @@ -1372,7 +1378,6 @@ def f( params, run_id, task_instance_key_str, - test_mode, ts, ts_nodash, ts_nodash_with_tz, @@ -1408,7 +1413,6 @@ def f( outlets, run_id, task_instance_key_str, - test_mode, ts, ts_nodash, ts_nodash_with_tz, @@ -1440,7 +1444,6 @@ def f( outlets, run_id, task_instance_key_str, - test_mode, ts, ts_nodash, ts_nodash_with_tz, @@ -1737,7 +1740,7 @@ def f(): branch_op >> [self.branch_1, self.branch_2] dr = self.create_dag_run() - if AIRFLOW_V_3_0_PLUS: + if AIRFLOW_V_3_0_1: with pytest.raises(DownstreamTasksSkipped) as dts: branch_op.run(start_date=self.default_date, end_date=self.default_date) @@ -1778,7 +1781,7 @@ def f(): dr = self.create_dag_run() - if AIRFLOW_V_3_0_PLUS: + if AIRFLOW_V_3_0_1: with pytest.raises(DownstreamTasksSkipped) as dts: branch_op.run(start_date=self.default_date, end_date=self.default_date) @@ -1805,7 +1808,7 @@ def f(): dr = self.create_dag_run() - if AIRFLOW_V_3_0_PLUS: + if AIRFLOW_V_3_0_1: from airflow.exceptions import DownstreamTasksSkipped with create_session() as session: @@ -1840,7 +1843,10 @@ def f(): tis = dr.get_task_instances() children_tis = [ti for ti in tis if ti.task_id in branch_op.get_direct_relative_ids()] with create_session() as session: - clear_task_instances(children_tis, session=session, dag=branch_op.dag) + if AIRFLOW_V_3_0_PLUS: + clear_task_instances(children_tis, session=session) + else: + clear_task_instances(children_tis, session=session, dag=branch_op.dag) # Run the cleared tasks again. for task in branches: @@ -2009,7 +2015,7 @@ class TestCurrentContextRuntime: def test_context_in_task(self): with DAG(dag_id="assert_context_dag", default_args=DEFAULT_ARGS, schedule="@once"): op = MyContextAssertOperator(task_id="assert_context") - if AIRFLOW_V_3_0_PLUS: + if AIRFLOW_V_3_0_1: with pytest.warns(AirflowProviderDeprecationWarning): op.run(ignore_first_depends_on_past=True, ignore_ti_state=True) else: @@ -2018,7 +2024,7 @@ def test_context_in_task(self): def test_get_context_in_old_style_context_task(self): with DAG(dag_id="edge_case_context_dag", default_args=DEFAULT_ARGS, schedule="@once"): op = PythonOperator(python_callable=get_all_the_context, task_id="get_all_the_context") - if AIRFLOW_V_3_0_PLUS: + if AIRFLOW_V_3_0_1: with pytest.warns(AirflowProviderDeprecationWarning): op.run(ignore_first_depends_on_past=True, ignore_ti_state=True) else: diff --git a/providers/standard/tests/unit/standard/operators/test_weekday.py b/providers/standard/tests/unit/standard/operators/test_weekday.py index 583f20fd663be..0372669c96179 100644 --- a/providers/standard/tests/unit/standard/operators/test_weekday.py +++ b/providers/standard/tests/unit/standard/operators/test_weekday.py @@ -29,12 +29,13 @@ from airflow.providers.standard.operators.weekday import BranchDayOfWeekOperator from airflow.providers.standard.utils.skipmixin import XCOM_SKIPMIXIN_FOLLOWED, XCOM_SKIPMIXIN_KEY from airflow.providers.standard.utils.weekday import WeekDay -from airflow.providers.standard.version_compat import AIRFLOW_V_3_0_PLUS from airflow.timetables.base import DataInterval from airflow.utils import timezone from airflow.utils.session import create_session from airflow.utils.state import State +from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_1, AIRFLOW_V_3_0_PLUS + if AIRFLOW_V_3_0_PLUS: from airflow.models.xcom import XComModel as XCom else: @@ -115,7 +116,7 @@ def test_branch_follow_true(self, weekday, dag_maker): data_interval=DataInterval(DEFAULT_DATE, DEFAULT_DATE), ) - if AIRFLOW_V_3_0_PLUS: + if AIRFLOW_V_3_0_1: from airflow.exceptions import DownstreamTasksSkipped with pytest.raises(DownstreamTasksSkipped) as exc_info: @@ -161,7 +162,7 @@ def test_branch_follow_true_with_logical_date(self, dag_maker): data_interval=DataInterval(DEFAULT_DATE, DEFAULT_DATE), ) - if AIRFLOW_V_3_0_PLUS: + if AIRFLOW_V_3_0_1: from airflow.exceptions import DownstreamTasksSkipped with pytest.raises(DownstreamTasksSkipped) as exc_info: @@ -230,7 +231,7 @@ def test_branch_follow_false(self, dag_maker): data_interval=DataInterval(DEFAULT_DATE, DEFAULT_DATE), ) - if AIRFLOW_V_3_0_PLUS: + if AIRFLOW_V_3_0_1: from airflow.exceptions import DownstreamTasksSkipped with pytest.raises(DownstreamTasksSkipped) as exc_info: @@ -336,20 +337,16 @@ def test_branch_xcom_push_true_branch(self, dag_maker): ) branch_op_ti = dr.get_task_instance(branch_op.task_id) - if AIRFLOW_V_3_0_PLUS: + if AIRFLOW_V_3_0_1: from airflow.exceptions import DownstreamTasksSkipped with pytest.raises(DownstreamTasksSkipped) as exc_info: branch_op_ti.run() assert exc_info.value.tasks == [("branch_2", -1)] - assert branch_op_ti.xcom_pull(task_ids="make_choice", key=XCOM_SKIPMIXIN_KEY) == { - XCOM_SKIPMIXIN_FOLLOWED: ["branch_1"] - } else: branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) - tis = dr.get_task_instances() - for ti in tis: - if ti.task_id == "make_choice": - assert ti.xcom_pull(task_ids="make_choice") == "branch_1" + assert branch_op_ti.xcom_pull(task_ids="make_choice", key=XCOM_SKIPMIXIN_KEY) == { + XCOM_SKIPMIXIN_FOLLOWED: ["branch_1"] + } diff --git a/pyproject.toml b/pyproject.toml index ab799ceffd1bd..aba8ac8d6503d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -848,7 +848,7 @@ fixture-parentheses = false ## pytest settings ## [tool.pytest.ini_options] addopts = [ - "--tb=short", + "--tb=no", "-rasl", "--verbosity=2", # Disable `flaky` plugin for pytest. This plugin conflicts with `rerunfailures` because provide the same marker. diff --git a/task-sdk/src/airflow/sdk/definitions/dag.py b/task-sdk/src/airflow/sdk/definitions/dag.py index 835e39dcdbad1..08cbb13335333 100644 --- a/task-sdk/src/airflow/sdk/definitions/dag.py +++ b/task-sdk/src/airflow/sdk/definitions/dag.py @@ -1193,7 +1193,7 @@ def add_logger_if_needed(ti: TaskInstance): ti.set_state(State.SUCCESS) log.info("[DAG TEST] Marking success for %s on %s", ti.task, ti.logical_date) else: - _run_task(ti=ti) + _run_task(ti=ti, run_triggerer=True) except Exception: log.exception("Task failed; ti=%s", ti) if use_executor: @@ -1208,7 +1208,7 @@ def add_logger_if_needed(ti: TaskInstance): return dr -def _run_task(*, ti): +def _run_task(*, ti, run_triggerer=False): """ Run a single task instance, and push result to Xcom for downstream tasks. @@ -1245,8 +1245,9 @@ def _run_task(*, ti): msg = taskrun_result.msg ti.set_state(taskrun_result.ti.state) + ti.task = taskrun_result.ti.task - if ti.state == State.DEFERRED and isinstance(msg, DeferTask): + if ti.state == State.DEFERRED and isinstance(msg, DeferTask) and run_triggerer: # API Server expects the task instance to be in QUEUED state before # resuming from deferral. ti.set_state(State.QUEUED) @@ -1255,11 +1256,12 @@ def _run_task(*, ti): trigger = import_string(msg.classpath)(**msg.trigger_kwargs) event = _run_inline_trigger(trigger) ti.next_method = msg.next_method - ti.next_kwargs = {"event": event.payload} if event else msg.kwargs + ti.next_kwargs = {"event": event.payload} if event else msg.next_kwargs log.info("[DAG TEST] Trigger completed") ti.set_state(State.SUCCESS) - break + + return taskrun_result except Exception: log.exception("[DAG TEST] Error running task %s", ti) if ti.state not in State.finished: diff --git a/task-sdk/src/airflow/sdk/execution_time/task_runner.py b/task-sdk/src/airflow/sdk/execution_time/task_runner.py index d017fd75ad160..cb8040278c474 100644 --- a/task-sdk/src/airflow/sdk/execution_time/task_runner.py +++ b/task-sdk/src/airflow/sdk/execution_time/task_runner.py @@ -1147,7 +1147,7 @@ def _execute_task(context: Context, ti: RuntimeTaskInstance, log: Logger): with timeout(timeout_seconds): result = ctx.run(execute, context=context) except AirflowTaskTimeout: - # TODO: handle on kill callback here + task.on_kill() raise else: result = ctx.run(execute, context=context)