diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 73ddeca095d26..cd51fbcd84b66 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -1,5 +1,4 @@ # Licensed to the Apache Software Foundation (ASF) under one -# Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. See the NOTICE file # distributed with this work for additional information # regarding copyright ownership. The ASF licenses this file @@ -675,6 +674,9 @@ ${{ hashFiles('.pre-commit-config.yaml') }}" uses: actions/setup-python@v2 with: python-version: ${{ needs.build-info.outputs.defaultPythonVersion }} + cache: 'pip' + cache-dependency-path: ./dev/breeze/setup* + - run: python -m pip install --editable ./dev/breeze/ - name: > Fetch incoming commit ${{ github.sha }} with its parent uses: actions/checkout@v2 diff --git a/UPDATING.md b/UPDATING.md index e68350735d555..917a07121782b 100644 --- a/UPDATING.md +++ b/UPDATING.md @@ -231,21 +231,9 @@ To support operator-mapping (AIP 42), the `deps` attribute on operator class mus If you set the `dag_default_view` config option or the `default_view` argument to `DAG()` to `tree` you will need to update your deployment. The old name will continue to work but will issue warnings. -## Airflow 2.2.5 - -No breaking changes - -## Airflow 2.2.4 - -### Smart sensors deprecated - -Smart sensors, an "early access" feature added in Airflow 2, are now deprecated and will be removed in Airflow 2.4.0. They have been superseded by Deferrable Operators, added in Airflow 2.2.0. - -See [Migrating to Deferrable Operators](https://airflow.apache.org/docs/apache-airflow/2.2.4/concepts/smart-sensors.html#migrating-to-deferrable-operators) for details on how to migrate. - ### Database configuration moved to new section -The following configurations have been moved from `[core]` to the new `[database]` section. However when reading new option, the old option will be checked to see if it exists. If it does a DeprecationWarning will be issued and the old option will be used instead. +The following configurations have been moved from `[core]` to the new `[database]` section. However when reading the new option, the old option will be checked to see if it exists. If it does a DeprecationWarning will be issued and the old option will be used instead. - sql_alchemy_conn - sql_engine_encoding @@ -260,6 +248,18 @@ The following configurations have been moved from `[core]` to the new `[database - load_default_connections - max_db_retries +## Airflow 2.2.5 + +No breaking changes + +## Airflow 2.2.4 + +### Smart sensors deprecated + +Smart sensors, an "early access" feature added in Airflow 2, are now deprecated and will be removed in Airflow 2.4.0. They have been superseded by Deferrable Operators, added in Airflow 2.2.0. + +See [Migrating to Deferrable Operators](https://airflow.apache.org/docs/apache-airflow/2.2.4/concepts/smart-sensors.html#migrating-to-deferrable-operators) for details on how to migrate. + ## Airflow 2.2.3 No breaking changes. 
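The UPDATING.md entry above says the relocated options keep working through a fallback to their old `[core]` location plus a `DeprecationWarning`. A minimal standalone sketch of that lookup pattern (plain `configparser`, not Airflow's actual `AirflowConfigParser`; the option names mirror the list above):

```python
import warnings
from configparser import ConfigParser

# Old (section, option) locations for options that moved to [database],
# mirroring the list in UPDATING.md above (abridged).
DEPRECATED_LOCATIONS = {
    ("database", "sql_alchemy_conn"): ("core", "sql_alchemy_conn"),
    ("database", "load_default_connections"): ("core", "load_default_connections"),
}


def get_with_fallback(config: ConfigParser, section: str, option: str) -> str:
    """Read the new option; fall back to the deprecated location with a warning."""
    if config.has_option(section, option):
        return config.get(section, option)
    old = DEPRECATED_LOCATIONS.get((section, option))
    if old and config.has_option(*old):
        warnings.warn(
            f"Option [{old[0]}] {old[1]} is deprecated; move it to [{section}] {option}.",
            DeprecationWarning,
        )
        return config.get(*old)
    raise KeyError(f"Missing option [{section}] {option}")


cfg = ConfigParser()
cfg.read_string("[core]\nsql_alchemy_conn = sqlite:///airflow.db\n")
print(get_with_fallback(cfg, "database", "sql_alchemy_conn"))  # warns, then returns the old value
```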
diff --git a/airflow/configuration.py b/airflow/configuration.py index 748e975d8330a..2fa88461c489c 100644 --- a/airflow/configuration.py +++ b/airflow/configuration.py @@ -187,19 +187,19 @@ class AirflowConfigParser(ConfigParser): ('core', 'max_active_tasks_per_dag'): ('core', 'dag_concurrency', '2.2.0'), ('logging', 'worker_log_server_port'): ('celery', 'worker_log_server_port', '2.2.0'), ('api', 'access_control_allow_origins'): ('api', 'access_control_allow_origin', '2.2.0'), - ('api', 'auth_backends'): ('api', 'auth_backend', '2.3'), - ('database', 'sql_alchemy_conn'): ('core', 'sql_alchemy_conn', '2.3'), - ('database', 'sql_engine_encoding'): ('core', 'sql_engine_encoding', '2.3'), - ('database', 'sql_engine_collation_for_ids'): ('core', 'sql_engine_collation_for_ids', '2.3'), - ('database', 'sql_alchemy_pool_enabled'): ('core', 'sql_alchemy_pool_enabled', '2.3'), - ('database', 'sql_alchemy_pool_size'): ('core', 'sql_alchemy_pool_size', '2.3'), - ('database', 'sql_alchemy_max_overflow'): ('core', 'sql_alchemy_max_overflow', '2.3'), - ('database', 'sql_alchemy_pool_recycle'): ('core', 'sql_alchemy_pool_recycle', '2.3'), - ('database', 'sql_alchemy_pool_pre_ping'): ('core', 'sql_alchemy_pool_pre_ping', '2.3'), - ('database', 'sql_alchemy_schema'): ('core', 'sql_alchemy_schema', '2.3'), - ('database', 'sql_alchemy_connect_args'): ('core', 'sql_alchemy_connect_args', '2.3'), - ('database', 'load_default_connections'): ('core', 'load_default_connections', '2.3'), - ('database', 'max_db_retries'): ('core', 'max_db_retries', '2.3'), + ('api', 'auth_backends'): ('api', 'auth_backend', '2.3.0'), + ('database', 'sql_alchemy_conn'): ('core', 'sql_alchemy_conn', '2.3.0'), + ('database', 'sql_engine_encoding'): ('core', 'sql_engine_encoding', '2.3.0'), + ('database', 'sql_engine_collation_for_ids'): ('core', 'sql_engine_collation_for_ids', '2.3.0'), + ('database', 'sql_alchemy_pool_enabled'): ('core', 'sql_alchemy_pool_enabled', '2.3.0'), + ('database', 'sql_alchemy_pool_size'): ('core', 'sql_alchemy_pool_size', '2.3.0'), + ('database', 'sql_alchemy_max_overflow'): ('core', 'sql_alchemy_max_overflow', '2.3.0'), + ('database', 'sql_alchemy_pool_recycle'): ('core', 'sql_alchemy_pool_recycle', '2.3.0'), + ('database', 'sql_alchemy_pool_pre_ping'): ('core', 'sql_alchemy_pool_pre_ping', '2.3.0'), + ('database', 'sql_alchemy_schema'): ('core', 'sql_alchemy_schema', '2.3.0'), + ('database', 'sql_alchemy_connect_args'): ('core', 'sql_alchemy_connect_args', '2.3.0'), + ('database', 'load_default_connections'): ('core', 'load_default_connections', '2.3.0'), + ('database', 'max_db_retries'): ('core', 'max_db_retries', '2.3.0'), } # A mapping of old default values that we want to change and warn the user diff --git a/airflow/jobs/backfill_job.py b/airflow/jobs/backfill_job.py index 0334a4f20ce26..7a7541eb80f0a 100644 --- a/airflow/jobs/backfill_job.py +++ b/airflow/jobs/backfill_job.py @@ -102,7 +102,7 @@ class _DagRunTaskStatus: def __init__( self, - dag, + dag: DAG, start_date=None, end_date=None, mark_success=False, @@ -228,7 +228,7 @@ def _update_counters(self, ti_status, session=None): def _manage_executor_state( self, running, session - ) -> Iterator[Tuple["MappedOperator", str, Sequence[TaskInstance]]]: + ) -> Iterator[Tuple["MappedOperator", str, Sequence[TaskInstance], int]]: """ Checks if the executor agrees with the state of task instances that are running. 
@@ -238,8 +238,6 @@ def _manage_executor_state( :param running: dict of key, task to verify :return: An iterable of expanded TaskInstance per MappedTask """ - from airflow.models.mappedoperator import MappedOperator - executor = self.executor # TODO: query all instead of refresh from db @@ -266,9 +264,11 @@ def _manage_executor_state( ti.handle_failure_with_callback(error=msg) continue if ti.state not in self.STATES_COUNT_AS_RUNNING: - for node in ti.task.mapped_dependants(): - assert isinstance(node, MappedOperator) - yield node, ti.run_id, node.expand_mapped_task(ti.run_id, session=session) + # Don't use ti.task; if this task is mapped, that attribute + # would hold the unmapped task. We need to original task here. + for node in self.dag.get_task(ti.task_id, include_subdags=True).mapped_dependants(): + new_tis, num_mapped_tis = node.expand_mapped_task(ti.run_id, session=session) + yield node, ti.run_id, new_tis, num_mapped_tis @provide_session def _get_dag_run(self, dagrun_info: DagRunInfo, dag: DAG, session: Session = None): @@ -609,18 +609,23 @@ def _per_task_process(key, ti: TaskInstance, session=None): ti_status.to_run.clear() # check executor state -- and expand any mapped TIs - for node, run_id, mapped_tis in self._manage_executor_state(ti_status.running, session): + for node, run_id, new_mapped_tis, max_map_index in self._manage_executor_state( + ti_status.running, session + ): def to_keep(key: TaskInstanceKey) -> bool: if key.dag_id != node.dag_id or key.task_id != node.task_id or key.run_id != run_id: # For another Dag/Task/Run -- don't remove return True - return False + return 0 <= key.map_index <= max_map_index # remove the old unmapped TIs for node -- they have been replaced with the mapped TIs ti_status.to_run = {key: ti for (key, ti) in ti_status.to_run.items() if to_keep(key)} - ti_status.to_run.update({ti.key: ti for ti in mapped_tis}) + ti_status.to_run.update({ti.key: ti for ti in new_mapped_tis}) + + for new_ti in new_mapped_tis: + new_ti.set_state(TaskInstanceState.SCHEDULED, session=session) # update the task counters self._update_counters(ti_status=ti_status, session=session) @@ -702,7 +707,7 @@ def tabulate_ti_keys_set(ti_keys: Iterable[TaskInstanceKey]) -> str: return err - def _get_dag_with_subdags(self): + def _get_dag_with_subdags(self) -> List[DAG]: return [self.dag] + self.dag.subdags @provide_session diff --git a/airflow/jobs/scheduler_job.py b/airflow/jobs/scheduler_job.py index 60cda34c13344..f3872c63a89d1 100644 --- a/airflow/jobs/scheduler_job.py +++ b/airflow/jobs/scheduler_job.py @@ -355,141 +355,126 @@ def _executable_task_instances_to_queued(self, max_tis: int, session: Session = "%s tasks up for execution:\n\t%s", len(task_instances_to_examine), task_instance_str ) - pool_to_task_instances: DefaultDict[str, List[TI]] = defaultdict(list) for task_instance in task_instances_to_examine: - pool_to_task_instances[task_instance.pool].append(task_instance) + pool_name = task_instance.pool - # Go through each pool, and queue up a task for execution if there are - # any open slots in the pool. 
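In the backfill change above, `_manage_executor_state` now yields the newly expanded task instances together with the largest map index, and the caller filters `ti_status.to_run` so the stale unmapped placeholder (map index -1) is dropped while in-range expanded indexes are kept. A reduced sketch of that filtering step, with a simplified key tuple standing in for `TaskInstanceKey`:

```python
from typing import Dict, NamedTuple


class Key(NamedTuple):  # simplified stand-in for TaskInstanceKey
    dag_id: str
    task_id: str
    run_id: str
    map_index: int


def keep_after_expansion(
    to_run: Dict[Key, str], dag_id: str, task_id: str, run_id: str, max_map_index: int
) -> Dict[Key, str]:
    """Drop the unmapped placeholder (map_index == -1) and any out-of-range indexes
    for the task that was just expanded; leave every other task/run untouched."""

    def to_keep(key: Key) -> bool:
        if (key.dag_id, key.task_id, key.run_id) != (dag_id, task_id, run_id):
            return True  # belongs to another DAG/task/run -- keep
        return 0 <= key.map_index <= max_map_index

    return {k: v for k, v in to_run.items() if to_keep(k)}


to_run = {
    Key("d", "consumer", "r", -1): "unmapped placeholder",
    Key("d", "other", "r", -1): "different task",
}
# Only the 'other' task survives; the freshly expanded TIs are then added back
# and set to SCHEDULED, as the hunk above does.
print(keep_after_expansion(to_run, "d", "consumer", "r", max_map_index=2))
```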
- - for pool, task_instances in pool_to_task_instances.items(): - pool_name = pool - if pool not in pools: - self.log.warning("Tasks using non-existent pool '%s' will not be scheduled", pool) + pool_stats = pools.get(pool_name) + if not pool_stats: + self.log.warning("Tasks using non-existent pool '%s' will not be scheduled", pool_name) starved_pools.add(pool_name) continue - pool_total = pools[pool]["total"] - open_slots = pools[pool]["open"] + # Make sure to emit metrics if pool has no starving tasks + pool_num_starving_tasks.setdefault(pool_name, 0) - num_ready = len(task_instances) - self.log.info( - "Figuring out tasks to run in Pool(name=%s) with %s open slots " - "and %s task instances ready to be queued", - pool, - open_slots, - num_ready, - ) + pool_total = pool_stats["total"] + open_slots = pool_stats["open"] - priority_sorted_task_instances = sorted( - task_instances, key=lambda ti: (-ti.priority_weight, ti.execution_date) - ) + if open_slots <= 0: + self.log.info( + "Not scheduling since there are %s open slots in pool %s", open_slots, pool_name + ) + # Can't schedule any more since there are no more open slots. + pool_num_starving_tasks[pool_name] += 1 + num_starving_tasks_total += 1 + starved_pools.add(pool_name) + continue - for current_index, task_instance in enumerate(priority_sorted_task_instances): - if open_slots <= 0: - self.log.info( - "Not scheduling since there are %s open slots in pool %s", open_slots, pool - ) - # Can't schedule any more since there are no more open slots. - num_unhandled = len(priority_sorted_task_instances) - current_index - pool_num_starving_tasks[pool_name] += num_unhandled - num_starving_tasks_total += num_unhandled - starved_pools.add(pool_name) - break - - if task_instance.pool_slots > pool_total: - self.log.warning( - "Not executing %s. Requested pool slots (%s) are greater than " - "total pool slots: '%s' for pool: %s.", - task_instance, - task_instance.pool_slots, - pool_total, - pool, - ) + if task_instance.pool_slots > pool_total: + self.log.warning( + "Not executing %s. Requested pool slots (%s) are greater than " + "total pool slots: '%s' for pool: %s.", + task_instance, + task_instance.pool_slots, + pool_total, + pool_name, + ) - starved_tasks.add((task_instance.dag_id, task_instance.task_id)) - continue + pool_num_starving_tasks[pool_name] += 1 + num_starving_tasks_total += 1 + starved_tasks.add((task_instance.dag_id, task_instance.task_id)) + continue - if task_instance.pool_slots > open_slots: - self.log.info( - "Not executing %s since it requires %s slots " - "but there are %s open slots in the pool %s.", - task_instance, - task_instance.pool_slots, - open_slots, - pool, - ) - pool_num_starving_tasks[pool_name] += 1 - num_starving_tasks_total += 1 - starved_tasks.add((task_instance.dag_id, task_instance.task_id)) - # Though we can execute tasks with lower priority if there's enough room - continue + if task_instance.pool_slots > open_slots: + self.log.info( + "Not executing %s since it requires %s slots " + "but there are %s open slots in the pool %s.", + task_instance, + task_instance.pool_slots, + open_slots, + pool_name, + ) + pool_num_starving_tasks[pool_name] += 1 + num_starving_tasks_total += 1 + starved_tasks.add((task_instance.dag_id, task_instance.task_id)) + # Though we can execute tasks with lower priority if there's enough room + continue - # Check to make sure that the task max_active_tasks of the DAG hasn't been - # reached. 
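The scheduler refactor in this hunk drops the pool-grouped inner loop and walks the already priority-sorted task instances once, checking each against its pool's stats and counting starving tasks per pool (including a zero entry so the metric is always emitted). A reduced, self-contained sketch of those admission checks, with plain dicts in place of the scheduler's ORM objects:

```python
from collections import defaultdict

# pool name -> {"total": configured slots, "open": currently free slots}
pools = {"default_pool": {"total": 4, "open": 3}}

# (pool, required slots) pairs standing in for priority-sorted scheduled task instances
tasks = [("default_pool", 2), ("default_pool", 8), ("default_pool", 2), ("missing", 1)]

executable, starving = [], defaultdict(int)
for pool_name, pool_slots in tasks:
    stats = pools.get(pool_name)
    if not stats:
        print(f"pool {pool_name!r} does not exist -- task will not be scheduled")
        continue
    starving.setdefault(pool_name, 0)   # emit a metric even when nothing is starving
    if stats["open"] <= 0:              # pool exhausted
        starving[pool_name] += 1
        continue
    if pool_slots > stats["total"]:     # could never fit, even in an empty pool
        starving[pool_name] += 1
        continue
    if pool_slots > stats["open"]:      # not enough free slots right now
        starving[pool_name] += 1
        continue
    executable.append((pool_name, pool_slots))
    stats["open"] -= pool_slots

print(executable, dict(starving))  # [('default_pool', 2)] {'default_pool': 2}
```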
- dag_id = task_instance.dag_id + # Check to make sure that the task max_active_tasks of the DAG hasn't been + # reached. + dag_id = task_instance.dag_id - current_active_tasks_per_dag = dag_active_tasks_map[dag_id] - max_active_tasks_per_dag_limit = task_instance.dag_model.max_active_tasks + current_active_tasks_per_dag = dag_active_tasks_map[dag_id] + max_active_tasks_per_dag_limit = task_instance.dag_model.max_active_tasks + self.log.info( + "DAG %s has %s/%s running and queued tasks", + dag_id, + current_active_tasks_per_dag, + max_active_tasks_per_dag_limit, + ) + if current_active_tasks_per_dag >= max_active_tasks_per_dag_limit: self.log.info( - "DAG %s has %s/%s running and queued tasks", + "Not executing %s since the number of tasks running or queued " + "from DAG %s is >= to the DAG's max_active_tasks limit of %s", + task_instance, dag_id, - current_active_tasks_per_dag, max_active_tasks_per_dag_limit, ) - if current_active_tasks_per_dag >= max_active_tasks_per_dag_limit: - self.log.info( - "Not executing %s since the number of tasks running or queued " - "from DAG %s is >= to the DAG's max_active_tasks limit of %s", - task_instance, + starved_dags.add(dag_id) + continue + + if task_instance.dag_model.has_task_concurrency_limits: + # Many dags don't have a task_concurrency, so where we can avoid loading the full + # serialized DAG the better. + serialized_dag = self.dagbag.get_dag(dag_id, session=session) + # If the dag is missing, fail the task and continue to the next task. + if not serialized_dag: + self.log.error( + "DAG '%s' for task instance %s not found in serialized_dag table", dag_id, - max_active_tasks_per_dag_limit, + task_instance, + ) + session.query(TI).filter(TI.dag_id == dag_id, TI.state == State.SCHEDULED).update( + {TI.state: State.FAILED}, synchronize_session='fetch' ) - starved_dags.add(dag_id) continue - if task_instance.dag_model.has_task_concurrency_limits: - # Many dags don't have a task_concurrency, so where we can avoid loading the full - # serialized DAG the better. - serialized_dag = self.dagbag.get_dag(dag_id, session=session) - # If the dag is missing, fail the task and continue to the next task. 
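After the pool checks, the loop enforces the DAG-level `max_active_tasks` limit and, when the DAG declares task concurrency limits, the per-task `max_active_tis_per_dag` limit. A small sketch of that gating with plain counters (the names here are illustrative, not the scheduler's actual data structures):

```python
from collections import Counter

max_active_tasks = {"etl": 2}                  # per-DAG limit (DagModel.max_active_tasks)
max_active_tis_per_dag = {("etl", "load"): 1}  # per-task limit, if set on the operator

running = Counter({"etl": 1})                      # running/queued TIs per DAG
running_per_task = Counter({("etl", "load"): 1})   # running/queued TIs per (dag, task)


def can_queue(dag_id: str, task_id: str) -> bool:
    if running[dag_id] >= max_active_tasks.get(dag_id, float("inf")):
        return False  # DAG-level max_active_tasks reached
    limit = max_active_tis_per_dag.get((dag_id, task_id))
    if limit is not None and running_per_task[(dag_id, task_id)] >= limit:
        return False  # task-level concurrency reached
    return True


print(can_queue("etl", "load"))     # False: the task's limit of 1 is already used
print(can_queue("etl", "extract"))  # True: DAG has 1/2 running and no task limit
```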
- if not serialized_dag: - self.log.error( - "DAG '%s' for task instance %s not found in serialized_dag table", - dag_id, + task_concurrency_limit: Optional[int] = None + if serialized_dag.has_task(task_instance.task_id): + task_concurrency_limit = serialized_dag.get_task( + task_instance.task_id + ).max_active_tis_per_dag + + if task_concurrency_limit is not None: + current_task_concurrency = task_concurrency_map[ + (task_instance.dag_id, task_instance.task_id) + ] + + if current_task_concurrency >= task_concurrency_limit: + self.log.info( + "Not executing %s since the task concurrency for" + " this task has been reached.", task_instance, ) - session.query(TI).filter(TI.dag_id == dag_id, TI.state == State.SCHEDULED).update( - {TI.state: State.FAILED}, synchronize_session='fetch' - ) + starved_tasks.add((task_instance.dag_id, task_instance.task_id)) continue - task_concurrency_limit: Optional[int] = None - if serialized_dag.has_task(task_instance.task_id): - task_concurrency_limit = serialized_dag.get_task( - task_instance.task_id - ).max_active_tis_per_dag - - if task_concurrency_limit is not None: - current_task_concurrency = task_concurrency_map[ - (task_instance.dag_id, task_instance.task_id) - ] - - if current_task_concurrency >= task_concurrency_limit: - self.log.info( - "Not executing %s since the task concurrency for" - " this task has been reached.", - task_instance, - ) - starved_tasks.add((task_instance.dag_id, task_instance.task_id)) - continue - - executable_tis.append(task_instance) - open_slots -= task_instance.pool_slots - dag_active_tasks_map[dag_id] += 1 - task_concurrency_map[(task_instance.dag_id, task_instance.task_id)] += 1 - - pools[pool]["open"] = open_slots + executable_tis.append(task_instance) + open_slots -= task_instance.pool_slots + dag_active_tasks_map[dag_id] += 1 + task_concurrency_map[(task_instance.dag_id, task_instance.task_id)] += 1 + + pool_stats["open"] = open_slots is_done = executable_tis or len(task_instances_to_examine) < max_tis # Check this to avoid accidental infinite loops diff --git a/airflow/migrations/versions/0096_587bdf053233_adding_index_for_dag_id_in_job.py b/airflow/migrations/versions/0096_587bdf053233_adding_index_for_dag_id_in_job.py index 693984b6edbfc..2d08d7f922a58 100644 --- a/airflow/migrations/versions/0096_587bdf053233_adding_index_for_dag_id_in_job.py +++ b/airflow/migrations/versions/0096_587bdf053233_adding_index_for_dag_id_in_job.py @@ -31,7 +31,7 @@ down_revision = 'c381b21cb7e4' branch_labels = None depends_on = None -airflow_version = '2.3.0' +airflow_version = '2.2.4' def upgrade(): diff --git a/airflow/models/dagrun.py b/airflow/models/dagrun.py index 048f850837a73..b5d91e74b14fc 100644 --- a/airflow/models/dagrun.py +++ b/airflow/models/dagrun.py @@ -706,7 +706,7 @@ def _get_ready_tis( if schedulable.task_id in ti.task.downstream_task_ids: assert isinstance(schedulable.task, MappedOperator) - new_tis = schedulable.task.expand_mapped_task(self.run_id, session=session) + new_tis, _ = schedulable.task.expand_mapped_task(self.run_id, session=session) if schedulable.state == TaskInstanceState.SKIPPED: # Task is now skipped (likely cos upstream returned 0 tasks continue diff --git a/airflow/models/mappedoperator.py b/airflow/models/mappedoperator.py index c282dc4d596bc..1ac4ff89fa5d2 100644 --- a/airflow/models/mappedoperator.py +++ b/airflow/models/mappedoperator.py @@ -579,10 +579,11 @@ def _resolve_map_lengths(self, run_id: str, *, session: Session) -> Dict[str, in return map_lengths - def expand_mapped_task(self, 
run_id: str, *, session: Session) -> Sequence["TaskInstance"]: + def expand_mapped_task(self, run_id: str, *, session: Session) -> Tuple[Sequence["TaskInstance"], int]: """Create the mapped task instances for mapped task. - :return: The mapped task instances, in ascending order by map index. + :return: The newly created mapped TaskInstances (if any) in ascending order by map index, and the + maximum map_index. """ from airflow.models.taskinstance import TaskInstance from airflow.settings import task_instance_mutation_hook @@ -619,7 +620,7 @@ def expand_mapped_task(self, run_id: str, *, session: Session) -> Sequence["Task ) unmapped_ti.state = TaskInstanceState.SKIPPED session.flush() - return ret + return ret, 0 # Otherwise convert this into the first mapped index, and create # TaskInstance for other indexes. unmapped_ti.map_index = 0 @@ -661,7 +662,7 @@ def expand_mapped_task(self, run_id: str, *, session: Session) -> Sequence["Task session.flush() - return ret + return ret, total_length def prepare_for_execution(self) -> "MappedOperator": # Since a mapped operator cannot be used for execution, and an unmapped diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py index 2f3d75436de55..5b53916acffe7 100644 --- a/airflow/models/taskinstance.py +++ b/airflow/models/taskinstance.py @@ -1510,8 +1510,10 @@ def signal_handler(signum, frame): signal.signal(signal.SIGTERM, signal_handler) - # Don't clear Xcom until the task is certain to execute - self.clear_xcom_data() + # Don't clear Xcom until the task is certain to execute, and check if we are resuming from deferral. + if not self.next_method: + self.clear_xcom_data() + with Stats.timer(f'dag.{self.task.dag_id}.{self.task.task_id}.duration'): # Set the validated/merged params on the task object. 
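The `clear_xcom_data` guard above exists for deferrable operators: a task resuming from deferral has `next_method` set, and an XCom pushed before `self.defer()` must survive that second invocation. A hedged sketch of the pattern being protected (assumes an Airflow 2.2+ environment with the deferral API; the operator, method, and key names here are made up):

```python
from datetime import timedelta

from airflow.models.baseoperator import BaseOperator
from airflow.triggers.temporal import TimeDeltaTrigger


class WaitThenReport(BaseOperator):
    """Push a value, defer, and read the value back after resuming."""

    def execute(self, context):
        # First invocation: record something, then hand off to the triggerer.
        context["ti"].xcom_push(key="started_at", value=str(context["ts"]))
        self.defer(trigger=TimeDeltaTrigger(timedelta(minutes=5)), method_name="resume")

    def resume(self, context, event=None):
        # Second invocation (next_method == "resume"): the XCom from execute()
        # is still there because XCom is only cleared on a fresh, non-deferred run.
        return context["ti"].xcom_pull(task_ids=self.task_id, key="started_at")
```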
self.task.params = context['params'] diff --git a/airflow/models/taskmixin.py b/airflow/models/taskmixin.py index 73221623d8747..c5d6165e8da0a 100644 --- a/airflow/models/taskmixin.py +++ b/airflow/models/taskmixin.py @@ -28,6 +28,7 @@ from logging import Logger from airflow.models.dag import DAG + from airflow.models.mappedoperator import MappedOperator from airflow.utils.edgemodifier import EdgeModifier from airflow.utils.task_group import TaskGroup @@ -290,7 +291,7 @@ def serialize_for_task_group(self) -> Tuple[DagAttributeTypes, Any]: """This is used by SerializedTaskGroup to serialize a task group's content.""" raise NotImplementedError() - def mapped_dependants(self) -> Iterator["DAGNode"]: + def mapped_dependants(self) -> Iterator["MappedOperator"]: """Return any mapped nodes that are direct dependencies of the current task For now, this walks the entire DAG to find mapped nodes that has this diff --git a/airflow/providers/cncf/kubernetes/hooks/kubernetes.py b/airflow/providers/cncf/kubernetes/hooks/kubernetes.py index 38305031dd022..1107655585ace 100644 --- a/airflow/providers/cncf/kubernetes/hooks/kubernetes.py +++ b/airflow/providers/cncf/kubernetes/hooks/kubernetes.py @@ -203,10 +203,24 @@ def create_custom_object( if namespace is None: namespace = self.get_namespace() if isinstance(body, str): - body = _load_body_to_dict(body) + body_dict = _load_body_to_dict(body) + else: + body_dict = body + try: + api.delete_namespaced_custom_object( + group=group, + version=version, + namespace=namespace, + plural=plural, + name=body_dict["metadata"]["name"], + ) + self.log.warning("Deleted SparkApplication with the same name.") + except client.rest.ApiException: + self.log.info(f"SparkApp {body_dict['metadata']['name']} not found.") + try: response = api.create_namespaced_custom_object( - group=group, version=version, namespace=namespace, plural=plural, body=body + group=group, version=version, namespace=namespace, plural=plural, body=body_dict ) self.log.debug("Response: %s", response) return response diff --git a/airflow/providers/cncf/kubernetes/operators/spark_kubernetes.py b/airflow/providers/cncf/kubernetes/operators/spark_kubernetes.py index 10296871efc60..fbf0aebd4948c 100644 --- a/airflow/providers/cncf/kubernetes/operators/spark_kubernetes.py +++ b/airflow/providers/cncf/kubernetes/operators/spark_kubernetes.py @@ -33,7 +33,7 @@ class SparkKubernetesOperator(BaseOperator): https://github.com/GoogleCloudPlatform/spark-on-k8s-operator/blob/v1beta2-1.1.0-2.4.5/docs/api-docs.md#sparkapplication :param application_file: Defines Kubernetes 'custom_resource_definition' of 'sparkApplication' as either a - path to a '.json' file or a JSON string. + path to a '.yaml' file, '.json' file, YAML string or JSON string. :param namespace: kubernetes namespace to put sparkApplication :param kubernetes_conn_id: The :ref:`kubernetes connection id ` for the to Kubernetes cluster. 
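The hook change above makes `create_custom_object` replace any existing object of the same name: it deletes first and swallows the API error raised when nothing is there. A standalone sketch of that delete-then-create pattern against the Kubernetes Python client (the function name and log messages are illustrative):

```python
from kubernetes import client


def recreate_custom_object(api: client.CustomObjectsApi, group: str, version: str,
                           namespace: str, plural: str, body: dict) -> dict:
    """Delete any existing object with the same name, then create it fresh."""
    name = body["metadata"]["name"]
    try:
        api.delete_namespaced_custom_object(
            group=group, version=version, namespace=namespace, plural=plural, name=name,
        )
        print(f"Deleted existing {plural[:-1]} {name!r}")
    except client.rest.ApiException:
        # Most commonly a 404 -- nothing to delete. Other API errors are swallowed
        # here too, mirroring the hook change above.
        print(f"{plural[:-1]} {name!r} not found")
    return api.create_namespaced_custom_object(
        group=group, version=version, namespace=namespace, plural=plural, body=body,
    )
```

Worth noting: this trades a duplicate-name failure for a behaviour change, since a pre-existing SparkApplication with the same name is now silently replaced rather than causing the create call to error out.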
@@ -61,14 +61,15 @@ def __init__( self.kubernetes_conn_id = kubernetes_conn_id self.api_group = api_group self.api_version = api_version + self.plural = "sparkapplications" def execute(self, context: 'Context'): - self.log.info("Creating sparkApplication") hook = KubernetesHook(conn_id=self.kubernetes_conn_id) + self.log.info("Creating sparkApplication") response = hook.create_custom_object( group=self.api_group, version=self.api_version, - plural="sparkapplications", + plural=self.plural, body=self.application_file, namespace=self.namespace, ) diff --git a/airflow/providers/jenkins/operators/jenkins_job_trigger.py b/airflow/providers/jenkins/operators/jenkins_job_trigger.py index 7a623d3651c5e..b7dcb25913b27 100644 --- a/airflow/providers/jenkins/operators/jenkins_job_trigger.py +++ b/airflow/providers/jenkins/operators/jenkins_job_trigger.py @@ -153,9 +153,18 @@ def poll_job_in_queue(self, location: str, jenkins_server: Jenkins) -> int: # once it will be available in python-jenkins (v > 0.4.15) self.log.info('Polling jenkins queue at the url %s', location) while try_count < self.max_try_before_job_appears: - location_answer = jenkins_request_with_headers( - jenkins_server, Request(method='POST', url=location) - ) + try: + location_answer = jenkins_request_with_headers( + jenkins_server, Request(method='POST', url=location) + ) + # we don't want to fail the operator, this will continue to poll + # until max_try_before_job_appears reached + except (HTTPError, JenkinsException): + self.log.warning('polling failed, retrying', exc_info=True) + try_count += 1 + time.sleep(self.sleep_time) + continue + if location_answer is not None: json_response = json.loads(location_answer['body']) if ( @@ -168,8 +177,9 @@ def poll_job_in_queue(self, location: str, jenkins_server: Jenkins) -> int: return build_number try_count += 1 time.sleep(self.sleep_time) + raise AirflowException( - "The job hasn't been executed after polling " f"the queue {self.max_try_before_job_appears} times" + f"The job hasn't been executed after polling the queue {self.max_try_before_job_appears} times" ) def get_hook(self) -> JenkinsHook: diff --git a/airflow/providers/oracle/operators/oracle.py b/airflow/providers/oracle/operators/oracle.py index b120fe6d4ec26..b60d4b6e89100 100644 --- a/airflow/providers/oracle/operators/oracle.py +++ b/airflow/providers/oracle/operators/oracle.py @@ -34,12 +34,15 @@ class OracleOperator(BaseOperator): (templated) :param oracle_conn_id: The :ref:`Oracle connection id ` reference to a specific Oracle database. - :param parameters: (optional) the parameters to render the SQL query with. + :param parameters: (optional, templated) the parameters to render the SQL query with. :param autocommit: if True, each command is automatically committed. (default value: False) """ - template_fields: Sequence[str] = ('sql',) + template_fields: Sequence[str] = ( + 'parameters', + 'sql', + ) template_ext: Sequence[str] = ('.sql',) template_fields_renderers = {'sql': 'sql'} ui_color = '#ededed' @@ -73,10 +76,13 @@ class OracleStoredProcedureOperator(BaseOperator): :param procedure: name of stored procedure to call (templated) :param oracle_conn_id: The :ref:`Oracle connection id ` reference to a specific Oracle database. 
- :param parameters: (optional) the parameters provided in the call + :param parameters: (optional, templated) the parameters provided in the call """ - template_fields: Sequence[str] = ('procedure',) + template_fields: Sequence[str] = ( + 'parameters', + 'procedure', + ) ui_color = '#ededed' def __init__( diff --git a/airflow/utils/db.py b/airflow/utils/db.py index 4d4732e11de8e..4b0f7e1646c4a 100644 --- a/airflow/utils/db.py +++ b/airflow/utils/db.py @@ -940,18 +940,13 @@ def _format_dangling_error(source_table, target_table, invalid_count, reason): def check_run_id_null(session: Session) -> Iterable[str]: - import sqlalchemy.schema - - metadata = sqlalchemy.schema.MetaData(session.bind) - try: - metadata.reflect(only=[DagRun.__tablename__], extend_existing=True, resolve_fks=False) - except exc.InvalidRequestError: - # Table doesn't exist -- empty db - return + metadata = reflect_tables([DagRun], session) # We can't use the model here since it may differ from the db state due to # this function is run prior to migration. Use the reflected table instead. - dagrun_table = metadata.tables[DagRun.__tablename__] + dagrun_table = metadata.tables.get(DagRun.__tablename__) + if dagrun_table is None: + return invalid_dagrun_filter = or_( dagrun_table.c.dag_id.is_(None), @@ -1048,12 +1043,10 @@ def check_task_tables_without_matching_dagruns(session: Session) -> Iterable[str When we find such "dangling" rows we back them up in a special table and delete them from the main table. """ - import sqlalchemy.schema from sqlalchemy import and_ from airflow.models.renderedtifields import RenderedTaskInstanceFields - metadata = sqlalchemy.schema.MetaData(session.bind) models_to_dagrun: List[Tuple[Base, str]] = [ (mod, ver) for ver, models in { @@ -1062,18 +1055,14 @@ def check_task_tables_without_matching_dagruns(session: Session) -> Iterable[str }.items() for mod in models ] - for model, _ in [*models_to_dagrun, (DagRun, '2.2')]: - try: - metadata.reflect( - only=[model.__tablename__], extend_existing=True, resolve_fks=False # type: ignore - ) - except exc.InvalidRequestError: - # Table doesn't exist, but try the other ones in case the user is upgrading from an _old_ DB - # version - pass - # Key table doesn't exist -- likely empty DB. - if DagRun.__tablename__ not in metadata or TaskInstance.__tablename__ not in metadata: + metadata = reflect_tables([*[x[0] for x in models_to_dagrun], DagRun], session) + + if ( + metadata.tables.get(DagRun.__tablename__) is None + or metadata.tables.get(TaskInstance.__tablename__) is None + ): + # Key table doesn't exist -- likely empty DB. 
return # We can't use the model here since it may differ from the db state due to @@ -1102,7 +1091,7 @@ def check_task_tables_without_matching_dagruns(session: Session) -> Iterable[str exists_subquery = ( session.query(text('1')).select_from(dagrun_table).filter(source_to_dag_run_join_cond) ) - invalid_rows_query = session.query(source_table.c.dag_id, source_table.c.execution_date).filter( + invalid_rows_query = session.query(*[x.label(x.name) for x in source_table.c]).filter( ~exists_subquery.exists() ) @@ -1123,7 +1112,7 @@ def check_task_tables_without_matching_dagruns(session: Session) -> Iterable[str _move_dangling_data_to_new_table( session, source_table, - invalid_rows_query.with_entities(*source_table.columns), + invalid_rows_query, exists_subquery, dangling_table_name, ) diff --git a/airflow/www/static/js/tree/Table.jsx b/airflow/www/static/js/tree/Table.jsx index 06eb84cad46a3..aef91ce905f2e 100644 --- a/airflow/www/static/js/tree/Table.jsx +++ b/airflow/www/static/js/tree/Table.jsx @@ -146,6 +146,7 @@ const Table = ({ })} + {totalEntries > data.length && ( + )} ); }; diff --git a/airflow/www/static/js/tree/dagRuns/Bar.jsx b/airflow/www/static/js/tree/dagRuns/Bar.jsx index 47aa04f7048aa..d972582cedd9e 100644 --- a/airflow/www/static/js/tree/dagRuns/Bar.jsx +++ b/airflow/www/static/js/tree/dagRuns/Bar.jsx @@ -88,7 +88,7 @@ const DagRunBar = ({ { return ( @@ -89,8 +89,6 @@ const DagRuns = ({ tableWidth }) => { - Runs - Tasks diff --git a/breeze-legacy b/breeze-legacy index 27ef65cb8d43f..2885ca87a6f71 100755 --- a/breeze-legacy +++ b/breeze-legacy @@ -1756,7 +1756,7 @@ Most flags are applicable to the shell command as it will run build when needed. export DETAILED_USAGE_EXEC=" ${CMDNAME} exec [-- ] - Execs into interactive shell to an already running container. The container mus be started + Execs into interactive shell to an already running container. The container must be started already by breeze shell command. If you are not familiar with tmux, this is the best way to run multiple processes in the same container at the same time for example scheduler, webserver, workers, database console and interactive terminal. diff --git a/chart/README.md b/chart/README.md index 4c307f02be915..ea11dc1538aca 100644 --- a/chart/README.md +++ b/chart/README.md @@ -36,7 +36,7 @@ cluster using the [Helm](https://helm.sh) package manager. ## Features -* Supported executors: ``LocalExecutor``, ``CeleryExecutor``, ``CeleryKubernetesExecutor``, ``KubernetesExecutor``. +* Supported executors: ``LocalExecutor``, ``LocalKubernetesExecutor``, ``CeleryExecutor``, ``CeleryKubernetesExecutor``, ``KubernetesExecutor``. * Supported Airflow version: ``1.10+``, ``2.0+`` * Supported database backend: ``PostgresSQL``, ``MySQL`` * Autoscaling for ``CeleryExecutor`` provided by KEDA diff --git a/chart/templates/_helpers.yaml b/chart/templates/_helpers.yaml index 2b49b848ad273..ed56d3ef84bfa 100644 --- a/chart/templates/_helpers.yaml +++ b/chart/templates/_helpers.yaml @@ -123,7 +123,7 @@ If release name contains chart name it will be used as a full name. 
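Stepping back to the `airflow/utils/db.py` changes earlier in this section: both checks now go through a shared `reflect_tables` helper and treat a missing table (an empty, pre-migration database) as "nothing to do" rather than an error. A minimal sketch of that reflect-and-check pattern with SQLAlchemy (in-memory SQLite; the helper here is a stand-in, not Airflow's):

```python
import sqlalchemy
from sqlalchemy import create_engine, exc


def reflect_tables(table_names, engine) -> sqlalchemy.MetaData:
    """Reflect only the requested tables; missing ones are simply absent from the result."""
    metadata = sqlalchemy.MetaData()
    try:
        metadata.reflect(bind=engine, only=table_names, extend_existing=True, resolve_fks=False)
    except exc.InvalidRequestError:
        # One or more tables do not exist (e.g. an empty database before migrations).
        pass
    return metadata


engine = create_engine("sqlite://")
metadata = reflect_tables(["dag_run"], engine)
dagrun_table = metadata.tables.get("dag_run")
if dagrun_table is None:
    print("dag_run table missing -- likely an empty database, nothing to check")
```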
{{- range $i, $config := .Values.env }} - name: {{ $config.name }} value: {{ $config.value | quote }} - {{- if or (eq $.Values.executor "KubernetesExecutor") (eq $.Values.executor "CeleryKubernetesExecutor") }} + {{- if or (eq $.Values.executor "KubernetesExecutor") (eq $.Values.executor "LocalKubernetesExecutor") (eq $.Values.executor "CeleryKubernetesExecutor") }} - name: AIRFLOW__KUBERNETES_ENVIRONMENT_VARIABLES__{{ $config.name }} value: {{ $config.value | quote }} {{- end }} @@ -136,7 +136,7 @@ If release name contains chart name it will be used as a full name. name: {{ $config.secretName }} key: {{ default "value" $config.secretKey }} {{- end }} - {{- if or (eq $.Values.executor "KubernetesExecutor") (eq $.Values.executor "CeleryKubernetesExecutor") }} + {{- if or (eq $.Values.executor "LocalKubernetesExecutor") (eq $.Values.executor "KubernetesExecutor") (eq $.Values.executor "CeleryKubernetesExecutor") }} {{- range $i, $config := .Values.secret }} - name: AIRFLOW__KUBERNETES_SECRETS__{{ $config.envName }} value: {{ printf "%s=%s" $config.secretName $config.secretKey }} diff --git a/chart/templates/configmaps/configmap.yaml b/chart/templates/configmaps/configmap.yaml index fccbb64b96036..c3a4ced89f555 100644 --- a/chart/templates/configmaps/configmap.yaml +++ b/chart/templates/configmaps/configmap.yaml @@ -54,7 +54,7 @@ data: known_hosts: | {{ .Values.dags.gitSync.knownHosts | nindent 4 }} {{- end }} -{{- if or (eq $.Values.executor "KubernetesExecutor") (eq $.Values.executor "CeleryKubernetesExecutor") }} +{{- if or (eq $.Values.executor "LocalKubernetesExecutor") (eq $.Values.executor "KubernetesExecutor") (eq $.Values.executor "CeleryKubernetesExecutor") }} {{- if semverCompare ">=1.10.12" .Values.airflowVersion }} pod_template_file.yaml: |- {{- if .Values.podTemplate }} diff --git a/chart/templates/rbac/pod-launcher-rolebinding.yaml b/chart/templates/rbac/pod-launcher-rolebinding.yaml index 7cbe568bfa24e..25cf3f5ef6db6 100644 --- a/chart/templates/rbac/pod-launcher-rolebinding.yaml +++ b/chart/templates/rbac/pod-launcher-rolebinding.yaml @@ -19,8 +19,8 @@ ## Airflow Pod Launcher Role Binding ################################# {{- if and .Values.rbac.create .Values.allowPodLaunching }} -{{- $schedulerLaunchExecutors := list "LocalExecutor" "KubernetesExecutor" "CeleryKubernetesExecutor" }} -{{- $workerLaunchExecutors := list "CeleryExecutor" "KubernetesExecutor" "CeleryKubernetesExecutor" }} +{{- $schedulerLaunchExecutors := list "LocalExecutor" "LocalKubernetesExecutor" "KubernetesExecutor" "CeleryKubernetesExecutor" }} +{{- $workerLaunchExecutors := list "CeleryExecutor" "LocalKubernetesExecutor" "KubernetesExecutor" "CeleryKubernetesExecutor" }} {{- if .Values.multiNamespaceMode }} kind: ClusterRoleBinding {{- else }} diff --git a/chart/templates/rbac/security-context-constraint-rolebinding.yaml b/chart/templates/rbac/security-context-constraint-rolebinding.yaml index 8bee565175827..97ea52ea58abd 100644 --- a/chart/templates/rbac/security-context-constraint-rolebinding.yaml +++ b/chart/templates/rbac/security-context-constraint-rolebinding.yaml @@ -19,7 +19,7 @@ ## Airflow SCC Role Binding ################################# {{- if and .Values.rbac.create .Values.rbac.createSCCRoleBinding }} -{{- $hasWorkers := has .Values.executor (list "CeleryExecutor" "KubernetesExecutor" "CeleryKubernetesExecutor") }} +{{- $hasWorkers := has .Values.executor (list "CeleryExecutor" "LocalKubernetesExecutor" "KubernetesExecutor" "CeleryKubernetesExecutor") }} {{- if .Values.multiNamespaceMode 
}} kind: ClusterRoleBinding {{- else }} diff --git a/chart/values.schema.json b/chart/values.schema.json index a35b9e109a565..3ce6fed4b472f 100644 --- a/chart/values.schema.json +++ b/chart/values.schema.json @@ -405,6 +405,7 @@ "default": "CeleryExecutor", "enum": [ "LocalExecutor", + "LocalKubernetesExecutor", "CeleryExecutor", "KubernetesExecutor", "CeleryKubernetesExecutor" diff --git a/chart/values.yaml b/chart/values.yaml index f9cf985e1e15e..3f5b3e4ad64d1 100644 --- a/chart/values.yaml +++ b/chart/values.yaml @@ -219,7 +219,7 @@ rbac: createSCCRoleBinding: false # Airflow executor -# Options: LocalExecutor, CeleryExecutor, KubernetesExecutor, CeleryKubernetesExecutor +# One of: LocalExecutor, LocalKubernetesExecutor, CeleryExecutor, KubernetesExecutor, CeleryKubernetesExecutor executor: "CeleryExecutor" # If this is true and using LocalExecutor/KubernetesExecutor/CeleryKubernetesExecutor, the scheduler's diff --git a/docs/apache-airflow/migrations-ref.rst b/docs/apache-airflow/migrations-ref.rst index 09b9750208876..c5a8a5caa6745 100644 --- a/docs/apache-airflow/migrations-ref.rst +++ b/docs/apache-airflow/migrations-ref.rst @@ -51,7 +51,7 @@ Here's the list of all the Database Migrations that are executed via when you ru | ``5e3ec427fdd3`` | ``587bdf053233`` | ``2.3.0`` | Increase length of email and username in ``ab_user`` and | | | | | ``ab_register_user`` table to ``256`` characters | +---------------------------------+-------------------+-------------+--------------------------------------------------------------+ -| ``587bdf053233`` | ``c381b21cb7e4`` | ``2.3.0`` | Add index for ``dag_id`` column in ``job`` table. | +| ``587bdf053233`` | ``c381b21cb7e4`` | ``2.2.4`` | Add index for ``dag_id`` column in ``job`` table. | +---------------------------------+-------------------+-------------+--------------------------------------------------------------+ | ``c381b21cb7e4`` | ``be2bfac3da23`` | ``2.2.4`` | Create a ``session`` table to store web session data | +---------------------------------+-------------------+-------------+--------------------------------------------------------------+ diff --git a/scripts/ci/pre_commit/pre_commit_breeze_cmd_line.py b/scripts/ci/pre_commit/pre_commit_breeze_cmd_line.py index 6cdf99ce72f96..bb34176bf426c 100755 --- a/scripts/ci/pre_commit/pre_commit_breeze_cmd_line.py +++ b/scripts/ci/pre_commit/pre_commit_breeze_cmd_line.py @@ -43,12 +43,14 @@ def print_help_for_all_commands(): env['RECORD_BREEZE_WIDTH'] = SCREENSHOT_WIDTH env['RECORD_BREEZE_TITLE'] = "Breeze commands" env['RECORD_BREEZE_OUTPUT_FILE'] = str(BREEZE_IMAGES_DIR / "output-commands.svg") + env['TERM'] = "xterm-256color" check_call(["breeze", "--help"], env=env) for command in get_command_list(): env = os.environ.copy() env['RECORD_BREEZE_WIDTH'] = SCREENSHOT_WIDTH env['RECORD_BREEZE_TITLE'] = f"Command: {command}" env['RECORD_BREEZE_OUTPUT_FILE'] = str(BREEZE_IMAGES_DIR / f"output-{command}.svg") + env['TERM'] = "xterm-256color" check_call(["breeze", command, "--help"], env=env) diff --git a/tests/charts/test_basic_helm_chart.py b/tests/charts/test_basic_helm_chart.py index dff1964e1acf4..308fb6b5def9d 100644 --- a/tests/charts/test_basic_helm_chart.py +++ b/tests/charts/test_basic_helm_chart.py @@ -327,7 +327,8 @@ def test_unsupported_executor(self): }, ) assert ( - 'executor must be one of the following: "LocalExecutor", "CeleryExecutor", ' + 'executor must be one of the following: "LocalExecutor", ' + '"LocalKubernetesExecutor", "CeleryExecutor", ' '"KubernetesExecutor", 
"CeleryKubernetesExecutor"' in ex_ctx.exception.stderr.decode() ) diff --git a/tests/dags/test_mapped_classic.py b/tests/dags/test_mapped_classic.py index 3880cc74fc6c7..cbf3a8a5b8178 100644 --- a/tests/dags/test_mapped_classic.py +++ b/tests/dags/test_mapped_classic.py @@ -32,3 +32,6 @@ def consumer(value): with DAG(dag_id='test_mapped_classic', start_date=days_ago(2)) as dag: PythonOperator.partial(task_id='consumer', python_callable=consumer).expand(op_args=make_arg_lists()) + PythonOperator.partial(task_id='consumer_literal', python_callable=consumer).expand( + op_args=[[1], [2], [3]], + ) diff --git a/tests/dags/test_mapped_taskflow.py b/tests/dags/test_mapped_taskflow.py index 34f6ae3d720fc..e4e796c3e45de 100644 --- a/tests/dags/test_mapped_taskflow.py +++ b/tests/dags/test_mapped_taskflow.py @@ -29,3 +29,4 @@ def consumer(value): print(repr(value)) consumer.expand(value=make_list()) + consumer.expand(value=[1, 2, 3]) diff --git a/tests/jobs/test_backfill_job.py b/tests/jobs/test_backfill_job.py index 65caff63a414c..edbd16a9e38ee 100644 --- a/tests/jobs/test_backfill_job.py +++ b/tests/jobs/test_backfill_job.py @@ -23,6 +23,7 @@ import threading from unittest.mock import patch +import pendulum import pytest from airflow import settings @@ -39,11 +40,11 @@ from airflow.models import DagBag, Pool, TaskInstance as TI from airflow.models.dagrun import DagRun from airflow.models.taskinstance import TaskInstanceKey +from airflow.models.taskmap import TaskMap from airflow.operators.dummy import DummyOperator from airflow.utils import timezone -from airflow.utils.dates import days_ago from airflow.utils.session import create_session -from airflow.utils.state import State +from airflow.utils.state import DagRunState, State, TaskInstanceState from airflow.utils.timeout import timeout from airflow.utils.types import DagRunType from tests.models import TEST_DAGS_FOLDER @@ -1581,7 +1582,7 @@ def test_backfill_has_job_id(self): @pytest.mark.long_running @pytest.mark.parametrize("executor_name", ["SequentialExecutor", "DebugExecutor"]) @pytest.mark.parametrize("dag_id", ["test_mapped_classic", "test_mapped_taskflow"]) - def test_mapped_dag(self, dag_id, executor_name): + def test_mapped_dag(self, dag_id, executor_name, session): """ End-to-end test of a simple mapped dag. 
@@ -1595,11 +1596,114 @@ def test_mapped_dag(self, dag_id, executor_name): self.dagbag.process_file(str(TEST_DAGS_FOLDER / f'{dag_id}.py')) dag = self.dagbag.get_dag(dag_id) + when = pendulum.today('UTC') + job = BackfillJob( dag=dag, - start_date=days_ago(1), - end_date=days_ago(1), + start_date=when, + end_date=when, donot_pickle=True, executor=ExecutorLoader.load_executor(executor_name), ) job.run() + + dr = DagRun.find(dag_id=dag.dag_id, execution_date=when, session=session)[0] + assert dr + assert dr.state == DagRunState.SUCCESS + + # Check that every task has a start and end date + for ti in dr.task_instances: + assert ti.state == TaskInstanceState.SUCCESS + assert ti.start_date is not None + assert ti.end_date is not None + + def test_mapped_dag_pre_existing_tis(self, dag_maker, session): + """If the DagRun already some mapped TIs, ensure that we re-run them successfully""" + from airflow.decorators import task + from airflow.operators.python import PythonOperator + + list_result = [[1], [2], [{'a': 'b'}]] + + @task + def make_arg_lists(): + return list_result + + def consumer(value): + print(repr(value)) + + with dag_maker(session=session) as dag: + consumer_op = PythonOperator.partial(task_id='consumer', python_callable=consumer).expand( + op_args=make_arg_lists() + ) + PythonOperator.partial(task_id='consumer_literal', python_callable=consumer).expand( + op_args=[[1], [2], [3]], + ) + + dr = dag_maker.create_dagrun() + + # Create the existing mapped TIs -- this the crucial part of this test + ti = dr.get_task_instance('consumer', session=session) + ti.map_index = 0 + for map_index in range(1, 3): + ti = TI(consumer_op, run_id=dr.run_id, map_index=map_index) + ti.dag_run = dr + session.add(ti) + session.flush() + + executor = MockExecutor() + + ti_status = BackfillJob._DagRunTaskStatus() + ti_status.active_runs.append(dr) + ti_status.to_run = {ti.key: ti for ti in dr.task_instances} + + job = BackfillJob( + dag=dag, + start_date=dr.execution_date, + end_date=dr.execution_date, + donot_pickle=True, + executor=executor, + ) + + executor_change_state = executor.change_state + + def on_change_state(key, state, info=None): + if key.task_id == 'make_arg_lists': + session.add( + TaskMap( + length=len(list_result), + keys=None, + dag_id=key.dag_id, + run_id=key.run_id, + task_id=key.task_id, + map_index=key.map_index, + ) + ) + session.flush() + executor_change_state(key, state, info) + + with patch.object(executor, 'change_state', side_effect=on_change_state): + job._process_backfill_task_instances( + ti_status=ti_status, + executor=job.executor, + start_date=dr.execution_date, + pickle_id=None, + session=session, + ) + assert ti_status.failed == set() + assert ti_status.succeeded == { + TaskInstanceKey(dag_id=dr.dag_id, task_id='consumer', run_id='test', try_number=1, map_index=0), + TaskInstanceKey(dag_id=dr.dag_id, task_id='consumer', run_id='test', try_number=1, map_index=1), + TaskInstanceKey(dag_id=dr.dag_id, task_id='consumer', run_id='test', try_number=1, map_index=2), + TaskInstanceKey( + dag_id=dr.dag_id, task_id='consumer_literal', run_id='test', try_number=1, map_index=0 + ), + TaskInstanceKey( + dag_id=dr.dag_id, task_id='consumer_literal', run_id='test', try_number=1, map_index=1 + ), + TaskInstanceKey( + dag_id=dr.dag_id, task_id='consumer_literal', run_id='test', try_number=1, map_index=2 + ), + TaskInstanceKey( + dag_id=dr.dag_id, task_id='make_arg_lists', run_id='test', try_number=1, map_index=-1 + ), + } diff --git a/tests/jobs/test_scheduler_job.py 
b/tests/jobs/test_scheduler_job.py index 80e3cb1f347b2..7a22f0bb21567 100644 --- a/tests/jobs/test_scheduler_job.py +++ b/tests/jobs/test_scheduler_job.py @@ -610,6 +610,45 @@ def test_find_executable_task_instances_order_priority(self, dag_maker): assert [ti.key for ti in res] == [tis[1].key] session.rollback() + def test_find_executable_task_instances_order_priority_with_pools(self, dag_maker): + """ + The scheduler job should pick tasks with higher priority for execution + even if different pools are involved. + """ + + self.scheduler_job = SchedulerJob(subdir=os.devnull) + session = settings.Session() + + dag_id = 'SchedulerJobTest.test_find_executable_task_instances_order_priority_with_pools' + + session.add(Pool(pool='pool1', slots=32)) + session.add(Pool(pool='pool2', slots=32)) + + with dag_maker(dag_id=dag_id, max_active_tasks=2): + op1 = DummyOperator(task_id='dummy1', priority_weight=1, pool='pool1') + op2 = DummyOperator(task_id='dummy2', priority_weight=2, pool='pool2') + op3 = DummyOperator(task_id='dummy3', priority_weight=3, pool='pool1') + + dag_run = dag_maker.create_dagrun(run_type=DagRunType.SCHEDULED) + + ti1 = dag_run.get_task_instance(op1.task_id, session) + ti2 = dag_run.get_task_instance(op2.task_id, session) + ti3 = dag_run.get_task_instance(op3.task_id, session) + + ti1.state = State.SCHEDULED + ti2.state = State.SCHEDULED + ti3.state = State.SCHEDULED + + session.flush() + + res = self.scheduler_job._executable_task_instances_to_queued(max_tis=32, session=session) + + assert 2 == len(res) + assert ti3.key == res[0].key + assert ti2.key == res[1].key + + session.rollback() + def test_find_executable_task_instances_order_execution_date_and_priority(self, dag_maker): dag_id_1 = 'SchedulerJobTest.test_find_executable_task_instances_order_execution_date_and_priority-a' dag_id_2 = 'SchedulerJobTest.test_find_executable_task_instances_order_execution_date_and_priority-b' @@ -1096,6 +1135,52 @@ def test_find_executable_task_instances_not_enough_task_concurrency_for_first(se session.rollback() + @mock.patch('airflow.jobs.scheduler_job.Stats.gauge') + def test_emit_pool_starving_tasks_metrics(self, mock_stats_gauge, dag_maker): + self.scheduler_job = SchedulerJob(subdir=os.devnull) + session = settings.Session() + + dag_id = 'SchedulerJobTest.test_emit_pool_starving_tasks_metrics' + with dag_maker(dag_id=dag_id): + op = DummyOperator(task_id='op', pool_slots=2) + + dr = dag_maker.create_dagrun(run_type=DagRunType.SCHEDULED) + + ti = dr.get_task_instance(op.task_id, session) + ti.state = State.SCHEDULED + + set_default_pool_slots(1) + session.flush() + + res = self.scheduler_job._executable_task_instances_to_queued(max_tis=32, session=session) + assert 0 == len(res) + + mock_stats_gauge.assert_has_calls( + [ + mock.call('scheduler.tasks.starving', 1), + mock.call(f'pool.starving_tasks.{Pool.DEFAULT_POOL_NAME}', 1), + ], + any_order=True, + ) + mock_stats_gauge.reset_mock() + + set_default_pool_slots(2) + session.flush() + + res = self.scheduler_job._executable_task_instances_to_queued(max_tis=32, session=session) + assert 1 == len(res) + + mock_stats_gauge.assert_has_calls( + [ + mock.call('scheduler.tasks.starving', 0), + mock.call(f'pool.starving_tasks.{Pool.DEFAULT_POOL_NAME}', 0), + ], + any_order=True, + ) + + session.rollback() + session.close() + def test_enqueue_task_instances_with_queued_state(self, dag_maker): dag_id = 'SchedulerJobTest.test_enqueue_task_instances_with_queued_state' task_id_1 = 'dummy' diff --git a/tests/models/test_taskinstance.py 
b/tests/models/test_taskinstance.py index 2a1ea0889e3d8..4cb44587810f3 100644 --- a/tests/models/test_taskinstance.py +++ b/tests/models/test_taskinstance.py @@ -1121,6 +1121,31 @@ def test_xcom_pull_after_success(self, create_task_instance): ti.run(ignore_all_deps=True) assert ti.xcom_pull(task_ids='test_xcom', key=key) is None + def test_xcom_pull_after_deferral(self, create_task_instance, session): + """ + tests xcom will not clear before a task runs its next method after deferral. + """ + + key = 'xcom_key' + value = 'xcom_value' + + ti = create_task_instance( + dag_id='test_xcom', + schedule_interval='@monthly', + task_id='test_xcom', + pool='test_xcom', + ) + + ti.run(mark_success=True) + ti.xcom_push(key=key, value=value) + + ti.next_method = "execute" + session.merge(ti) + session.commit() + + ti.run(ignore_all_deps=True) + assert ti.xcom_pull(task_ids='test_xcom', key=key) == value + def test_xcom_pull_different_execution_date(self, create_task_instance): """ tests xcom fetch behavior with different execution dates, using @@ -2512,8 +2537,8 @@ def show(value): emit_ti.run() show_task = dag.get_task("show") - mapped_tis = show_task.expand_mapped_task(dag_run.run_id, session=session) - assert len(mapped_tis) == len(upstream_return) + mapped_tis, num = show_task.expand_mapped_task(dag_run.run_id, session=session) + assert num == len(mapped_tis) == len(upstream_return) for ti in sorted(mapped_tis, key=operator.attrgetter("map_index")): ti.refresh_from_task(show_task) @@ -2546,8 +2571,8 @@ def show(number, letter): ti.run() show_task = dag.get_task("show") - mapped_tis = show_task.expand_mapped_task(dag_run.run_id, session=session) - assert len(mapped_tis) == 6 + mapped_tis, num = show_task.expand_mapped_task(dag_run.run_id, session=session) + assert len(mapped_tis) == 6 == num for ti in sorted(mapped_tis, key=operator.attrgetter("map_index")): ti.refresh_from_task(show_task) @@ -2584,8 +2609,8 @@ def show(a, b): ti.run() show_task = dag.get_task("show") - mapped_tis = show_task.expand_mapped_task(dag_run.run_id, session=session) - assert len(mapped_tis) == 4 + mapped_tis, num = show_task.expand_mapped_task(dag_run.run_id, session=session) + assert num == len(mapped_tis) == 4 for ti in sorted(mapped_tis, key=operator.attrgetter("map_index")): ti.refresh_from_task(show_task) @@ -2621,7 +2646,8 @@ def cmds(): ti.run() bash_task = dag.get_task("dynamic.bash") - mapped_bash_tis = bash_task.expand_mapped_task(dag_run.run_id, session=session) + mapped_bash_tis, num = bash_task.expand_mapped_task(dag_run.run_id, session=session) + assert num == 2 * 2 for ti in sorted(mapped_bash_tis, key=operator.attrgetter("map_index")): ti.refresh_from_task(bash_task) ti.run() @@ -2681,7 +2707,7 @@ def add_one(x): ti.run() task_345 = dag.get_task("add_one__1") - for ti in task_345.expand_mapped_task(dagrun.run_id, session=session): + for ti in task_345.expand_mapped_task(dagrun.run_id, session=session)[0]: ti.refresh_from_task(task_345) ti.run() diff --git a/tests/providers/cncf/kubernetes/operators/test_spark_kubernetes.py b/tests/providers/cncf/kubernetes/operators/test_spark_kubernetes.py index ed99a170caf0c..6d42a8f13607c 100644 --- a/tests/providers/cncf/kubernetes/operators/test_spark_kubernetes.py +++ b/tests/providers/cncf/kubernetes/operators/test_spark_kubernetes.py @@ -176,8 +176,11 @@ def setUp(self): args = {'owner': 'airflow', 'start_date': timezone.datetime(2020, 2, 1)} self.dag = DAG('test_dag_id', default_args=args) + 
@patch('kubernetes.client.api.custom_objects_api.CustomObjectsApi.delete_namespaced_custom_object') @patch('kubernetes.client.api.custom_objects_api.CustomObjectsApi.create_namespaced_custom_object') - def test_create_application_from_yaml(self, mock_create_namespaced_crd, mock_kubernetes_hook): + def test_create_application_from_yaml( + self, mock_create_namespaced_crd, mock_delete_namespaced_crd, mock_kubernetes_hook + ): op = SparkKubernetesOperator( application_file=TEST_VALID_APPLICATION_YAML, dag=self.dag, @@ -186,6 +189,13 @@ def test_create_application_from_yaml(self, mock_create_namespaced_crd, mock_kub ) op.execute(None) mock_kubernetes_hook.assert_called_once_with() + mock_delete_namespaced_crd.assert_called_once_with( + group='sparkoperator.k8s.io', + namespace='default', + plural='sparkapplications', + version='v1beta2', + name=TEST_APPLICATION_DICT["metadata"]["name"], + ) mock_create_namespaced_crd.assert_called_with( body=TEST_APPLICATION_DICT, group='sparkoperator.k8s.io', @@ -194,8 +204,11 @@ def test_create_application_from_yaml(self, mock_create_namespaced_crd, mock_kub version='v1beta2', ) + @patch('kubernetes.client.api.custom_objects_api.CustomObjectsApi.delete_namespaced_custom_object') @patch('kubernetes.client.api.custom_objects_api.CustomObjectsApi.create_namespaced_custom_object') - def test_create_application_from_json(self, mock_create_namespaced_crd, mock_kubernetes_hook): + def test_create_application_from_json( + self, mock_create_namespaced_crd, mock_delete_namespaced_crd, mock_kubernetes_hook + ): op = SparkKubernetesOperator( application_file=TEST_VALID_APPLICATION_JSON, dag=self.dag, @@ -204,6 +217,13 @@ def test_create_application_from_json(self, mock_create_namespaced_crd, mock_kub ) op.execute(None) mock_kubernetes_hook.assert_called_once_with() + mock_delete_namespaced_crd.assert_called_once_with( + group='sparkoperator.k8s.io', + namespace='default', + plural='sparkapplications', + version='v1beta2', + name=TEST_APPLICATION_DICT["metadata"]["name"], + ) mock_create_namespaced_crd.assert_called_with( body=TEST_APPLICATION_DICT, group='sparkoperator.k8s.io', @@ -212,9 +232,10 @@ def test_create_application_from_json(self, mock_create_namespaced_crd, mock_kub version='v1beta2', ) + @patch('kubernetes.client.api.custom_objects_api.CustomObjectsApi.delete_namespaced_custom_object') @patch('kubernetes.client.api.custom_objects_api.CustomObjectsApi.create_namespaced_custom_object') def test_create_application_from_json_with_api_group_and_version( - self, mock_create_namespaced_crd, mock_kubernetes_hook + self, mock_create_namespaced_crd, mock_delete_namespaced_crd, mock_kubernetes_hook ): api_group = 'sparkoperator.example.com' api_version = 'v1alpha1' @@ -228,6 +249,13 @@ def test_create_application_from_json_with_api_group_and_version( ) op.execute(None) mock_kubernetes_hook.assert_called_once_with() + mock_delete_namespaced_crd.assert_called_once_with( + group=api_group, + namespace='default', + plural='sparkapplications', + version=api_version, + name=TEST_APPLICATION_DICT["metadata"]["name"], + ) mock_create_namespaced_crd.assert_called_with( body=TEST_APPLICATION_DICT, group=api_group, @@ -236,8 +264,11 @@ def test_create_application_from_json_with_api_group_and_version( version=api_version, ) + @patch('kubernetes.client.api.custom_objects_api.CustomObjectsApi.delete_namespaced_custom_object') @patch('kubernetes.client.api.custom_objects_api.CustomObjectsApi.create_namespaced_custom_object') - def test_namespace_from_operator(self, 
mock_create_namespaced_crd, mock_kubernetes_hook): + def test_namespace_from_operator( + self, mock_create_namespaced_crd, mock_delete_namespaced_crd, mock_kubernetes_hook + ): op = SparkKubernetesOperator( application_file=TEST_VALID_APPLICATION_JSON, dag=self.dag, @@ -247,6 +278,13 @@ def test_namespace_from_operator(self, mock_create_namespaced_crd, mock_kubernet ) op.execute(None) mock_kubernetes_hook.assert_called_once_with() + mock_delete_namespaced_crd.assert_called_once_with( + group='sparkoperator.k8s.io', + namespace='operator_namespace', + plural='sparkapplications', + version='v1beta2', + name=TEST_APPLICATION_DICT["metadata"]["name"], + ) mock_create_namespaced_crd.assert_called_with( body=TEST_APPLICATION_DICT, group='sparkoperator.k8s.io', @@ -255,8 +293,11 @@ def test_namespace_from_operator(self, mock_create_namespaced_crd, mock_kubernet version='v1beta2', ) + @patch('kubernetes.client.api.custom_objects_api.CustomObjectsApi.delete_namespaced_custom_object') @patch('kubernetes.client.api.custom_objects_api.CustomObjectsApi.create_namespaced_custom_object') - def test_namespace_from_connection(self, mock_create_namespaced_crd, mock_kubernetes_hook): + def test_namespace_from_connection( + self, mock_create_namespaced_crd, mock_delete_namespaced_crd, mock_kubernetes_hook + ): op = SparkKubernetesOperator( application_file=TEST_VALID_APPLICATION_JSON, dag=self.dag, @@ -265,6 +306,13 @@ def test_namespace_from_connection(self, mock_create_namespaced_crd, mock_kubern ) op.execute(None) mock_kubernetes_hook.assert_called_once_with() + mock_delete_namespaced_crd.assert_called_once_with( + group='sparkoperator.k8s.io', + namespace='mock_namespace', + plural='sparkapplications', + version='v1beta2', + name=TEST_APPLICATION_DICT["metadata"]["name"], + ) mock_create_namespaced_crd.assert_called_with( body=TEST_APPLICATION_DICT, group='sparkoperator.k8s.io',
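A small aside on the test signatures above: stacked `@patch` decorators apply bottom-up, so the decorator closest to the function supplies the first mock parameter, which is why `mock_create_namespaced_crd` precedes the newly added `mock_delete_namespaced_crd`. A minimal self-contained illustration:

```python
import os
from unittest import mock


@mock.patch("os.remove")   # outermost decorator -> later mock parameter
@mock.patch("os.rename")   # innermost decorator -> first mock parameter
def show_order(mock_rename, mock_remove):
    os.rename("a", "b")  # calls hit the mocks, not the real os functions
    os.remove("a")
    mock_rename.assert_called_once_with("a", "b")
    mock_remove.assert_called_once_with("a")
    print("decorator order maps bottom-up to parameters")


show_order()
```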