diff --git a/airflow/dag_processing/collection.py b/airflow/dag_processing/collection.py index d26c417c55014..fff1cc64adfae 100644 --- a/airflow/dag_processing/collection.py +++ b/airflow/dag_processing/collection.py @@ -29,9 +29,11 @@ import itertools import logging +import traceback from typing import TYPE_CHECKING, NamedTuple -from sqlalchemy import and_, exists, func, select, tuple_ +from sqlalchemy import and_, delete, exists, func, select, tuple_ +from sqlalchemy.exc import OperationalError from sqlalchemy.orm import joinedload, load_only from airflow.assets.manager import asset_manager @@ -45,9 +47,12 @@ ) from airflow.models.dag import DAG, DagModel, DagOwnerAttributes, DagTag from airflow.models.dagrun import DagRun +from airflow.models.dagwarning import DagWarningType +from airflow.models.errors import ParseImportError from airflow.models.trigger import Trigger from airflow.sdk.definitions.asset import Asset, AssetAlias from airflow.triggers.base import BaseTrigger +from airflow.utils.retries import MAX_DB_RETRIES, run_with_db_retries from airflow.utils.sqlalchemy import with_row_locks from airflow.utils.timezone import utcnow from airflow.utils.types import DagRunType @@ -58,6 +63,7 @@ from sqlalchemy.orm import Session from sqlalchemy.sql import Select + from airflow.models.dagwarning import DagWarning from airflow.typing_compat import Self log = logging.getLogger(__name__) @@ -163,6 +169,181 @@ def _update_dag_owner_links(dag_owner_links: dict[str, str], dm: DagModel, *, se ) +def _serialize_dag_capturing_errors(dag: DAG, session: Session, processor_subdir: str | None): + """ + Try to serialize the dag to the DB, but make a note of any errors. + + We can't place them directly in import_errors, as this may be retried, and work the next time + """ + from airflow import settings + from airflow.configuration import conf + from airflow.models.dagcode import DagCode + from airflow.models.serialized_dag import SerializedDagModel + + try: + # We can't use bulk_write_to_db as we want to capture each error individually + dag_was_updated = SerializedDagModel.write_dag( + dag, + min_update_interval=settings.MIN_SERIALIZED_DAG_UPDATE_INTERVAL, + session=session, + processor_subdir=processor_subdir, + ) + if dag_was_updated: + _sync_dag_perms(dag, session=session) + else: + # Check and update DagCode + DagCode.update_source_code(dag) + return [] + except OperationalError: + raise + except Exception: + log.exception("Failed to write serialized DAG dag_id=%s fileloc=%s", dag.dag_id, dag.fileloc) + dagbag_import_error_traceback_depth = conf.getint("core", "dagbag_import_error_traceback_depth") + return [(dag.fileloc, traceback.format_exc(limit=-dagbag_import_error_traceback_depth))] + + +def _sync_dag_perms(dag: DAG, session: Session): + """Sync DAG specific permissions.""" + dag_id = dag.dag_id + + log.debug("Syncing DAG permissions: %s to the DB", dag_id) + from airflow.www.security_appless import ApplessAirflowSecurityManager + + security_manager = ApplessAirflowSecurityManager(session=session) + security_manager.sync_perm_for_dag(dag_id, dag.access_control) + + +def _update_dag_warnings( + dag_ids: list[str], warnings: set[DagWarning], warning_types: tuple[DagWarningType], session: Session +): + from airflow.models.dagwarning import DagWarning + + stored_warnings = set( + session.scalars( + select(DagWarning).where( + DagWarning.dag_id.in_(dag_ids), + DagWarning.warning_type.in_(warning_types), + ) + ) + ) + + for warning_to_delete in stored_warnings - warnings: + 
        session.delete(warning_to_delete)
+
+    for warning_to_add in warnings:
+        session.merge(warning_to_add)
+
+
+def _update_import_errors(
+    files_parsed: set[str], import_errors: dict[str, str], processor_subdir: str | None, session: Session
+):
+    from airflow.listeners.listener import get_listener_manager
+
+    # We can remove anything from files parsed in this batch that doesn't have an error. We need to remove old
+    # errors (i.e. from files that are removed) separately
+
+    session.execute(delete(ParseImportError).where(ParseImportError.filename.in_(list(files_parsed))))
+
+    query = select(ParseImportError.filename).where(ParseImportError.processor_subdir == processor_subdir)
+    existing_import_error_files = set(session.scalars(query))
+
+    # Add the errors of the processed files
+    for filename, stacktrace in import_errors.items():
+        if filename in existing_import_error_files:
+            session.query(ParseImportError).where(ParseImportError.filename == filename).update(
+                {"filename": filename, "timestamp": utcnow(), "stacktrace": stacktrace},
+            )
+            # sending notification when an existing dag import error occurs
+            get_listener_manager().hook.on_existing_dag_import_error(filename=filename, stacktrace=stacktrace)
+        else:
+            session.add(
+                ParseImportError(
+                    filename=filename,
+                    timestamp=utcnow(),
+                    stacktrace=stacktrace,
+                    processor_subdir=processor_subdir,
+                )
+            )
+            # sending notification when a new dag import error occurs
+            get_listener_manager().hook.on_new_dag_import_error(filename=filename, stacktrace=stacktrace)
+        session.query(DagModel).filter(DagModel.fileloc == filename).update({"has_import_errors": True})
+
+
+def update_dag_parsing_results_in_db(
+    dags: Collection[DAG],
+    import_errors: dict[str, str],
+    processor_subdir: str | None,
+    warnings: set[DagWarning],
+    session: Session,
+    *,
+    warning_types: tuple[DagWarningType] = (DagWarningType.NONEXISTENT_POOL,),
+):
+    """
+    Update everything to do with DAG parsing in the DB.
+
+    This function will create or update rows in the following tables:
+
+    - DagModel (`dag` table), DagTag, DagCode and DagVersion
+    - SerializedDagModel (`serialized_dag` table)
+    - ParseImportError (including with any errors as a result of serialization, not just parsing)
+    - DagWarning
+    - DAG Permissions
+
+    This function will not remove any rows for dags not passed in. It will remove parse errors and warnings
+    from dags/dag files that are passed in. In other words, if a DAG is passed in with a fileloc of `a.py`
+    then all warnings and errors related to this file will be removed.
+
+    ``import_errors`` will be updated in place with any new errors.
+    """
+    # Retry 'DAG.bulk_write_to_db' & 'SerializedDagModel.bulk_sync_to_db' in case
+    # of any Operational Errors
+    # In case of failures, provide_session handles rollback
+    for attempt in run_with_db_retries(logger=log):
+        with attempt:
+            serialize_errors = []
+            log.debug(
+                "Running dagbag.bulk_write_to_db with retries. Try %d of %d",
+                attempt.retry_state.attempt_number,
+                MAX_DB_RETRIES,
+            )
+            log.debug("Calling the DAG.bulk_write_to_db method")
+            try:
+                DAG.bulk_write_to_db(dags, processor_subdir=processor_subdir, session=session)
+                # Write Serialized DAGs to DB, capturing errors
+                for dag in dags:
+                    serialize_errors.extend(_serialize_dag_capturing_errors(dag, session, processor_subdir))
+            except OperationalError:
+                session.rollback()
+                raise
+            # Only now that we are "complete" do we update import_errors - don't want to record errors from
+            # previous failed attempts
+            import_errors.update(dict(serialize_errors))
+
+    # Record import errors into the ORM - we don't retry on this one as it's not as critical that it works
+    try:
+        # TODO: This won't clear errors for files that exist that no longer contain DAGs. Do we need to pass
+        # in the list of file parsed?
+
+        good_dag_filelocs = {dag.fileloc for dag in dags if dag.fileloc not in import_errors}
+        _update_import_errors(
+            files_parsed=good_dag_filelocs,
+            import_errors=import_errors,
+            processor_subdir=processor_subdir,
+            session=session,
+        )
+    except Exception:
+        log.exception("Error logging import errors!")
+
+    # Record DAG warnings in the metadatabase.
+    try:
+        _update_dag_warnings([dag.dag_id for dag in dags], warnings, warning_types, session)
+    except Exception:
+        log.exception("Error logging DAG warnings.")
+
+    session.flush()
+
+
 class DagModelOperation(NamedTuple):
     """Collect DAG objects and perform database operations for them."""
diff --git a/airflow/dag_processing/manager.py b/airflow/dag_processing/manager.py
index f60377d496625..57c69238a1f7e 100644
--- a/airflow/dag_processing/manager.py
+++ b/airflow/dag_processing/manager.py
@@ -661,10 +661,7 @@ def _refresh_dag_dir(self) -> bool:
             self.set_file_paths(self._file_paths)
 
             try:
-                self.log.debug("Removing old import errors")
-                DagFileProcessorManager.clear_nonexistent_import_errors(
-                    file_paths=self._file_paths, processor_subdir=self.get_dag_directory()
-                )
+                self.clear_nonexistent_import_errors()
             except Exception:
                 self.log.exception("Error removing old import errors")
 
@@ -702,24 +699,19 @@ def _print_stat(self):
                 self._log_file_processing_stats(self._file_paths)
             self.last_stat_print_time = time.monotonic()
 
-    @staticmethod
     @provide_session
-    def clear_nonexistent_import_errors(
-        file_paths: list[str] | None, processor_subdir: str | None, session=NEW_SESSION
-    ):
+    def clear_nonexistent_import_errors(self, session=NEW_SESSION):
         """
         Clear import errors for files that no longer exist.
:param file_paths: list of paths to DAG definition files :param session: session for ORM operations """ - query = delete(ParseImportError) + self.log.debug("Removing old import errors") + query = delete(ParseImportError).where(ParseImportError.processor_subdir == self.get_dag_directory()) - if file_paths: - query = query.where( - ~ParseImportError.filename.in_(file_paths), - ParseImportError.processor_subdir == processor_subdir, - ) + if self._file_paths: + query = query.where(ParseImportError.filename.notin_(self._file_paths)) session.execute(query.execution_options(synchronize_session="fetch")) session.commit() diff --git a/airflow/dag_processing/processor.py b/airflow/dag_processing/processor.py index b3e6ff770b8c1..a2b4474402588 100644 --- a/airflow/dag_processing/processor.py +++ b/airflow/dag_processing/processor.py @@ -29,7 +29,7 @@ from typing import TYPE_CHECKING from setproctitle import setproctitle -from sqlalchemy import delete, event, select +from sqlalchemy import event from airflow import settings from airflow.callbacks.callback_requests import ( @@ -38,11 +38,8 @@ ) from airflow.configuration import conf from airflow.exceptions import AirflowException -from airflow.listeners.listener import get_listener_manager -from airflow.models.dag import DAG, DagModel +from airflow.models.dag import DAG from airflow.models.dagbag import DagBag -from airflow.models.dagwarning import DagWarning, DagWarningType -from airflow.models.errors import ParseImportError from airflow.models.pool import Pool from airflow.models.serialized_dag import SerializedDagModel from airflow.models.taskinstance import TaskInstance, _run_finished_callback @@ -136,6 +133,7 @@ def _run_file_processor( thread_name: str, dag_directory: str, callback_requests: list[CallbackRequest], + known_pools: set[str] | None = None, ) -> None: """ Process the given file. @@ -172,6 +170,7 @@ def _handle_dag_file_processing(): result: tuple[int, int, int] = dag_file_processor.process_file( file_path=file_path, callback_requests=callback_requests, + known_pools=known_pools, ) result_channel.send(result) @@ -228,6 +227,8 @@ def start(self) -> None: context = self._get_multiprocessing_context() + pool_names = {p.pool for p in Pool.get_pools()} + _parent_channel, _child_channel = context.Pipe(duplex=False) process = context.Process( target=type(self)._run_file_processor, @@ -238,6 +239,7 @@ def start(self) -> None: f"DagFileProcessor{self._instance_id}", self._dag_directory, self._callback_requests, + pool_names, ), name=f"DagFileProcessor{self._instance_id}-Process", ) @@ -417,113 +419,8 @@ def __init__(self, dag_directory: str, log: logging.Logger): super().__init__() self._log = log self._dag_directory = dag_directory - self.dag_warnings: set[tuple[str, str]] = set() self._last_num_of_db_queries = 0 - @staticmethod - @provide_session - def update_import_errors( - file_last_changed: dict[str, datetime], - import_errors: dict[str, str], - processor_subdir: str | None, - session: Session = NEW_SESSION, - ) -> None: - """ - Update any import errors to be displayed in the UI. - - For the DAGs in the given DagBag, record any associated import errors and clears - errors for files that no longer have them. These are usually displayed through the - Airflow UI so that users know that there are issues parsing DAGs. 
- :param file_last_changed: Dictionary containing the last changed time of the files - :param import_errors: Dictionary containing the import errors - :param session: session for ORM operations - """ - files_without_error = file_last_changed - import_errors.keys() - - # Clear the errors of the processed files - # that no longer have errors - for dagbag_file in files_without_error: - session.execute( - delete(ParseImportError) - .where(ParseImportError.filename.startswith(dagbag_file)) - .execution_options(synchronize_session="fetch") - ) - - # files that still have errors - existing_import_error_files = [x.filename for x in session.query(ParseImportError.filename).all()] - - # Add the errors of the processed files - for filename, stacktrace in import_errors.items(): - if filename in existing_import_error_files: - session.query(ParseImportError).filter(ParseImportError.filename == filename).update( - {"filename": filename, "timestamp": timezone.utcnow(), "stacktrace": stacktrace}, - synchronize_session="fetch", - ) - # sending notification when an existing dag import error occurs - get_listener_manager().hook.on_existing_dag_import_error( - filename=filename, stacktrace=stacktrace - ) - else: - session.add( - ParseImportError( - filename=filename, - timestamp=timezone.utcnow(), - stacktrace=stacktrace, - processor_subdir=processor_subdir, - ) - ) - # sending notification when a new dag import error occurs - get_listener_manager().hook.on_new_dag_import_error(filename=filename, stacktrace=stacktrace) - ( - session.query(DagModel) - .filter(DagModel.fileloc == filename) - .update({"has_import_errors": True}, synchronize_session="fetch") - ) - - session.commit() - session.flush() - - @classmethod - @provide_session - def update_dag_warnings(cla, *, dagbag: DagBag, session: Session = NEW_SESSION) -> None: - """Validate and raise exception if any task in a dag is using a non-existent pool.""" - - def get_pools(dag) -> dict[str, set[str]]: - return {dag.dag_id: {task.pool for task in dag.tasks}} - - pool_dict: dict[str, set[str]] = {} - for dag in dagbag.dags.values(): - pool_dict.update(get_pools(dag)) - dag_ids = {dag.dag_id for dag in dagbag.dags.values()} - - all_pools = {p.pool for p in Pool.get_pools(session)} - warnings: set[DagWarning] = set() - for dag_id, dag_pools in pool_dict.items(): - nonexistent_pools = dag_pools - all_pools - if nonexistent_pools: - warnings.add( - DagWarning( - dag_id, - DagWarningType.NONEXISTENT_POOL, - f"Dag '{dag_id}' references non-existent pools: {sorted(nonexistent_pools)!r}", - ) - ) - - stored_warnings = set( - session.scalars( - select(DagWarning).where( - DagWarning.dag_id.in_(dag_ids), - DagWarning.warning_type == DagWarningType.NONEXISTENT_POOL, - ) - ) - ) - - for warning_to_delete in stored_warnings - warnings: - session.delete(warning_to_delete) - - for warning_to_add in warnings: - session.merge(warning_to_add) - @classmethod @provide_session def execute_callbacks( @@ -666,9 +563,9 @@ def _execute_task_callbacks( session.flush() @classmethod - def _get_dagbag(cls, file_path: str): + def _get_dagbag(cls, file_path: str, known_pools: set[str] | None): try: - return DagBag(file_path, include_examples=False) + return DagBag(file_path, include_examples=False, known_pools=known_pools) except Exception: cls.logger().exception("Failed at reloading the DAG file %s", file_path) Stats.incr("dag_file_refresh_error", tags={"file_path": file_path}) @@ -679,6 +576,7 @@ def process_file( self, file_path: str, callback_requests: list[CallbackRequest], + 
known_pools: set[str] | None = None, session: Session = NEW_SESSION, ) -> tuple[int, int, int]: """ @@ -700,7 +598,7 @@ def process_file( with count_queries(session) as query_counter: try: - dagbag = DagFileProcessor._get_dagbag(file_path) + dagbag = DagFileProcessor._get_dagbag(file_path, known_pools) except Exception: self.log.exception("Failed at reloading the DAG file %s", file_path) Stats.incr("dag_file_refresh_error", 1, 1, tags={"file_path": file_path}) @@ -708,44 +606,16 @@ def process_file( if dagbag.dags: self.log.info("DAG(s) %s retrieved from %s", ", ".join(map(repr, dagbag.dags)), file_path) + self.execute_callbacks(dagbag, callback_requests, self.UNIT_TEST_MODE) else: self.log.warning("No viable dags retrieved from %s", file_path) - DagFileProcessor.update_import_errors( - file_last_changed=dagbag.file_last_changed, - import_errors=dagbag.import_errors, - processor_subdir=self._dag_directory, - ) if callback_requests: # If there were callback requests for this file but there was a # parse error we still need to progress the state of TIs, # otherwise they might be stuck in queued/running for ever! DagFileProcessor.execute_callbacks_without_dag(callback_requests, self.UNIT_TEST_MODE) - return 0, len(dagbag.import_errors), self._cache_last_num_of_db_queries(query_counter) - - self.execute_callbacks(dagbag, callback_requests, self.UNIT_TEST_MODE) - - serialize_errors = DagFileProcessor.save_dag_to_db( - dags=dagbag.dags, - dag_directory=self._dag_directory, - ) - - dagbag.import_errors.update(dict(serialize_errors)) - # Record import errors into the ORM - try: - DagFileProcessor.update_import_errors( - file_last_changed=dagbag.file_last_changed, - import_errors=dagbag.import_errors, - processor_subdir=self._dag_directory, - ) - except Exception: - self.log.exception("Error logging import errors!") - - # Record DAG warnings in the metadatabase. 
- try: - self.update_dag_warnings(dagbag=dagbag) - except Exception: - self.log.exception("Error logging DAG warnings.") + dagbag.sync_to_db(self._dag_directory, session=session) return len(dagbag.dags), len(dagbag.import_errors), self._cache_last_num_of_db_queries(query_counter) @@ -753,14 +623,3 @@ def _cache_last_num_of_db_queries(self, query_counter: _QueryCounter | None = No if query_counter: self._last_num_of_db_queries = query_counter.queries_number return self._last_num_of_db_queries - - @staticmethod - @provide_session - def save_dag_to_db( - dags: dict[str, DAG], - dag_directory: str, - session=NEW_SESSION, - ): - import_errors = DagBag._sync_to_db(dags=dags, processor_subdir=dag_directory, session=session) - session.commit() - return import_errors diff --git a/airflow/models/dagbag.py b/airflow/models/dagbag.py index 1e98e4922e4f2..03e416c2478b8 100644 --- a/airflow/models/dagbag.py +++ b/airflow/models/dagbag.py @@ -34,7 +34,6 @@ Column, String, ) -from sqlalchemy.exc import OperationalError from tabulate import tabulate from airflow import settings @@ -50,7 +49,6 @@ ) from airflow.listeners.listener import get_listener_manager from airflow.models.base import Base -from airflow.models.dagcode import DagCode from airflow.stats import Stats from airflow.utils import timezone from airflow.utils.dag_cycle_tester import check_cycle @@ -62,7 +60,6 @@ might_contain_dag, ) from airflow.utils.log.logging_mixin import LoggingMixin -from airflow.utils.retries import MAX_DB_RETRIES, run_with_db_retries from airflow.utils.session import NEW_SESSION, provide_session from airflow.utils.timeout import timeout from airflow.utils.types import NOTSET @@ -72,6 +69,7 @@ from sqlalchemy.orm import Session from airflow.models.dag import DAG + from airflow.models.dagwarning import DagWarning from airflow.utils.types import ArgNotSet @@ -117,6 +115,7 @@ class DagBag(LoggingMixin): de-serializing the DAG? This flag is set to False in Scheduler so that Extra Operator links are not loaded to not run User code in Scheduler. :param collect_dags: when True, collects dags during class initialization. + :param known_pools: If not none, then generate warnings if a Task attempts to use an unknown pool. """ def __init__( @@ -127,9 +126,8 @@ def __init__( read_dags_from_db: bool = False, load_op_links: bool = True, collect_dags: bool = True, + known_pools: set[str] | None = None, ): - # Avoid circular import - super().__init__() include_examples = ( @@ -155,6 +153,8 @@ def __init__( # Only used by SchedulerJob to compare the dag_hash to identify change in DAGs self.dags_hash: dict[str, str] = {} + self.known_pools = known_pools + self.dagbag_import_error_tracebacks = conf.getboolean("core", "dagbag_import_error_tracebacks") self.dagbag_import_error_traceback_depth = conf.getint("core", "dagbag_import_error_traceback_depth") if collect_dags: @@ -328,6 +328,35 @@ def process_file(self, filepath, only_if_updated=True, safe_mode=True): self.file_last_changed[filepath] = file_last_changed_on_disk return found_dags + @property + def dag_warnings(self) -> set[DagWarning]: + """Get the set of DagWarnings for the bagged dags.""" + from airflow.models.dagwarning import DagWarning, DagWarningType + + # None means this feature is not enabled. Empty set means we don't know about any pools at all! 
+ if self.known_pools is None: + return set() + + def get_pools(dag) -> dict[str, set[str]]: + return {dag.dag_id: {task.pool for task in dag.tasks}} + + pool_dict: dict[str, set[str]] = {} + for dag in self.dags.values(): + pool_dict.update(get_pools(dag)) + + warnings: set[DagWarning] = set() + for dag_id, dag_pools in pool_dict.items(): + nonexistent_pools = dag_pools - self.known_pools + if nonexistent_pools: + warnings.add( + DagWarning( + dag_id, + DagWarningType.NONEXISTENT_POOL, + f"Dag '{dag_id}' references non-existent pools: {sorted(nonexistent_pools)!r}", + ) + ) + return warnings + def _load_modules_from_file(self, filepath, safe_mode): from airflow.sdk.definitions.contextmanager import DagContext @@ -596,95 +625,18 @@ def dagbag_report(self): ) return report - @classmethod - @provide_session - def _sync_to_db( - cls, - dags: dict[str, DAG], - processor_subdir: str | None = None, - session: Session = NEW_SESSION, - ): - """Save attributes about list of DAG to the DB.""" - # To avoid circular import - airflow.models.dagbag -> airflow.models.dag -> airflow.models.dagbag - from airflow.models.dag import DAG - from airflow.models.serialized_dag import SerializedDagModel - - log = cls.logger() - - def _serialize_dag_capturing_errors(dag, session, processor_subdir): - """ - Try to serialize the dag to the DB, but make a note of any errors. - - We can't place them directly in import_errors, as this may be retried, and work the next time - """ - try: - # We can't use bulk_write_to_db as we want to capture each error individually - dag_was_updated = SerializedDagModel.write_dag( - dag, - min_update_interval=settings.MIN_SERIALIZED_DAG_UPDATE_INTERVAL, - session=session, - processor_subdir=processor_subdir, - ) - if dag_was_updated: - DagBag._sync_perm_for_dag(dag, session=session) - else: - # Check and update DagCode - DagCode.update_source_code(dag) - return [] - except OperationalError: - raise - except Exception: - log.exception("Failed to write serialized DAG: %s", dag.fileloc) - dagbag_import_error_traceback_depth = conf.getint( - "core", "dagbag_import_error_traceback_depth" - ) - return [(dag.fileloc, traceback.format_exc(limit=-dagbag_import_error_traceback_depth))] - - # Retry 'DAG.bulk_write_to_db' & 'SerializedDagModel.bulk_sync_to_db' in case - # of any Operational Errors - # In case of failures, provide_session handles rollback - import_errors = {} - for attempt in run_with_db_retries(logger=log): - with attempt: - serialize_errors = [] - log.debug( - "Running dagbag.sync_to_db with retries. 
Try %d of %d", - attempt.retry_state.attempt_number, - MAX_DB_RETRIES, - ) - log.debug("Calling the DAG.bulk_sync_to_db method") - try: - DAG.bulk_write_to_db(dags.values(), processor_subdir=processor_subdir, session=session) - # Write Serialized DAGs to DB, capturing errors - for dag in dags.values(): - serialize_errors.extend( - _serialize_dag_capturing_errors(dag, session, processor_subdir) - ) - except OperationalError: - session.rollback() - raise - # Only now we are "complete" do we update import_errors - don't want to record errors from - # previous failed attempts - import_errors.update(dict(serialize_errors)) - - return import_errors - @provide_session def sync_to_db(self, processor_subdir: str | None = None, session: Session = NEW_SESSION): - import_errors = DagBag._sync_to_db(dags=self.dags, processor_subdir=processor_subdir, session=session) - self.import_errors.update(import_errors) - - @classmethod - @provide_session - def _sync_perm_for_dag(cls, dag: DAG, session: Session = NEW_SESSION): - """Sync DAG specific permissions.""" - dag_id = dag.dag_id - - cls.logger().debug("Syncing DAG permissions: %s to the DB", dag_id) - from airflow.www.security_appless import ApplessAirflowSecurityManager - - security_manager = ApplessAirflowSecurityManager(session=session) - security_manager.sync_perm_for_dag(dag_id, dag.access_control) + """Save attributes about list of DAG to the DB.""" + from airflow.dag_processing.collection import update_dag_parsing_results_in_db + + update_dag_parsing_results_in_db( + self.dags.values(), + self.import_errors, + processor_subdir, + self.dag_warnings, + session=session, + ) def generate_md5_hash(context): diff --git a/tests/api_connexion/endpoints/test_task_endpoint.py b/tests/api_connexion/endpoints/test_task_endpoint.py index 437bfd81895ca..d1662943604ac 100644 --- a/tests/api_connexion/endpoints/test_task_endpoint.py +++ b/tests/api_connexion/endpoints/test_task_endpoint.py @@ -86,7 +86,7 @@ def setup_dag(self, configured_app): mapped_dag.dag_id: mapped_dag, unscheduled_dag.dag_id: unscheduled_dag, } - DagBag._sync_to_db(dag_bag.dags) + dag_bag.sync_to_db() configured_app.dag_bag = dag_bag # type:ignore @staticmethod diff --git a/tests/dag_processing/test_collection.py b/tests/dag_processing/test_collection.py index d0256c5c288d4..bbd75361e7970 100644 --- a/tests/dag_processing/test_collection.py +++ b/tests/dag_processing/test_collection.py @@ -18,25 +18,50 @@ from __future__ import annotations +import logging import warnings from collections.abc import Generator from datetime import timedelta +from typing import TYPE_CHECKING +from unittest import mock +from unittest.mock import patch import pytest -from sqlalchemy.exc import SAWarning +from sqlalchemy import func, select +from sqlalchemy.exc import OperationalError, SAWarning -from airflow.dag_processing.collection import AssetModelOperation, _get_latest_runs_stmt +import airflow.dag_processing.collection +from airflow.dag_processing.collection import ( + AssetModelOperation, + _get_latest_runs_stmt, + _sync_dag_perms, + update_dag_parsing_results_in_db, +) +from airflow.exceptions import SerializationError from airflow.models import DagModel, Trigger from airflow.models.asset import ( AssetActive, asset_trigger_association_table, ) +from airflow.models.dag import DAG +from airflow.models.errors import ParseImportError +from airflow.models.serialized_dag import SerializedDagModel from airflow.operators.empty import EmptyOperator from airflow.providers.standard.triggers.temporal import 
TimeDeltaTrigger from airflow.sdk.definitions.asset import Asset +from airflow.utils import timezone as tz from airflow.utils.session import create_session -from tests_common.test_utils.db import clear_db_assets, clear_db_dags, clear_db_triggers +from tests_common.test_utils.db import ( + clear_db_assets, + clear_db_dags, + clear_db_import_errors, + clear_db_serialized_dags, + clear_db_triggers, +) + +if TYPE_CHECKING: + from kgb import SpyAgency def test_statement_latest_runs_one_dag(): @@ -129,3 +154,249 @@ def test_add_asset_trigger_references(self, is_active, is_paused, expected_num_t assert session.query(Trigger).count() == expected_num_triggers assert session.query(asset_trigger_association_table).count() == expected_num_triggers + + +@pytest.mark.db_test +class TestUpdateDagParsingResults: + """Tests centred around the ``update_dag_parsing_results_in_db`` function.""" + + @pytest.fixture + def clean_db(self, session): + yield + clear_db_serialized_dags() + clear_db_dags() + clear_db_import_errors() + + @pytest.mark.usefixtures("clean_db") # sync_perms in fab has bad session commit hygiene + def test_sync_perms_syncs_dag_specific_perms_on_update( + self, monkeypatch, spy_agency: SpyAgency, session, time_machine + ): + """ + Test that dagbag.sync_to_db will sync DAG specific permissions when a DAG is + new or updated + """ + from airflow import settings + + serialized_dags_count = session.query(func.count(SerializedDagModel.dag_id)).scalar() + assert serialized_dags_count == 0 + + monkeypatch.setattr(settings, "MIN_SERIALIZED_DAG_UPDATE_INTERVAL", 5) + time_machine.move_to(tz.datetime(2020, 1, 5, 0, 0, 0), tick=False) + + dag = DAG(dag_id="test") + + sync_perms_spy = spy_agency.spy_on( + airflow.dag_processing.collection._sync_dag_perms, + call_original=False, + ) + + def _sync_to_db(): + sync_perms_spy.reset_calls() + time_machine.shift(20) + + update_dag_parsing_results_in_db([dag], dict(), None, set(), session) + + _sync_to_db() + spy_agency.assert_spy_called_with(sync_perms_spy, dag, session=session) + + # DAG isn't updated + _sync_to_db() + spy_agency.assert_spy_not_called(sync_perms_spy) + + # DAG is updated + dag.tags = {"new_tag"} + _sync_to_db() + spy_agency.assert_spy_called_with(sync_perms_spy, dag, session=session) + + serialized_dags_count = session.query(func.count(SerializedDagModel.dag_id)).scalar() + + @patch.object(SerializedDagModel, "write_dag") + @patch("airflow.models.dag.DAG.bulk_write_to_db") + def test_sync_to_db_is_retried(self, mock_bulk_write_to_db, mock_s10n_write_dag, session): + """Test that important DB operations in db sync are retried on OperationalError""" + + serialized_dags_count = session.query(func.count(SerializedDagModel.dag_id)).scalar() + assert serialized_dags_count == 0 + mock_dag = mock.MagicMock() + dags = [mock_dag] + + op_error = OperationalError(statement=mock.ANY, params=mock.ANY, orig=mock.ANY) + + # Mock error for the first 2 tries and a successful third try + side_effect = [op_error, op_error, mock.ANY] + + mock_bulk_write_to_db.side_effect = side_effect + + mock_session = mock.MagicMock() + update_dag_parsing_results_in_db( + dags=dags, import_errors={}, processor_subdir=None, warnings=set(), session=mock_session + ) + + # Test that 3 attempts were made to run 'DAG.bulk_write_to_db' successfully + mock_bulk_write_to_db.assert_has_calls( + [ + mock.call(mock.ANY, processor_subdir=None, session=mock.ANY), + mock.call(mock.ANY, processor_subdir=None, session=mock.ANY), + mock.call(mock.ANY, processor_subdir=None, session=mock.ANY), + 
] + ) + # Assert that rollback is called twice (i.e. whenever OperationalError occurs) + mock_session.rollback.assert_has_calls([mock.call(), mock.call()]) + # Check that 'SerializedDagModel.write_dag' is also called + # Only called once since the other two times the 'DAG.bulk_write_to_db' error'd + # and the session was roll-backed before even reaching 'SerializedDagModel.write_dag' + mock_s10n_write_dag.assert_has_calls( + [ + mock.call( + mock_dag, min_update_interval=mock.ANY, processor_subdir=None, session=mock_session + ), + ] + ) + + serialized_dags_count = session.query(func.count(SerializedDagModel.dag_id)).scalar() + assert serialized_dags_count == 0 + + def test_serialized_dags_are_written_to_db_on_sync(self, session): + """ + Test that when dagbag.sync_to_db is called the DAGs are Serialized and written to DB + even when dagbag.read_dags_from_db is False + """ + serialized_dags_count = session.query(func.count(SerializedDagModel.dag_id)).scalar() + assert serialized_dags_count == 0 + + dag = DAG(dag_id="test") + + update_dag_parsing_results_in_db([dag], dict(), None, set(), session) + + new_serialized_dags_count = session.query(func.count(SerializedDagModel.dag_id)).scalar() + assert new_serialized_dags_count == 1 + + @patch.object(SerializedDagModel, "write_dag") + def test_serialized_dag_errors_are_import_errors(self, mock_serialize, caplog, session): + """ + Test that errors serializing a DAG are recorded as import_errors in the DB + """ + mock_serialize.side_effect = SerializationError + + caplog.set_level(logging.ERROR) + + dag = DAG(dag_id="test") + dag.fileloc = "abc.py" + + import_errors = {} + update_dag_parsing_results_in_db([dag], import_errors, None, set(), session) + assert "SerializationError" in caplog.text + + # Should have been edited in places + err = import_errors.get(dag.fileloc) + assert "SerializationError" in err + + dag_model: DagModel = session.get(DagModel, (dag.dag_id,)) + assert dag_model.has_import_errors is True + + import_errors = session.query(ParseImportError).all() + + assert len(import_errors) == 1 + import_error = import_errors[0] + assert import_error.filename == dag.fileloc + assert "SerializationError" in import_error.stacktrace + + def test_new_import_error_replaces_old(self, session): + """ + Test that existing import error is updated and new record not created + for a dag with the same filename + """ + filename = "abc.py" + prev_error = ParseImportError( + filename=filename, + timestamp=tz.utcnow(), + stacktrace="Some error", + processor_subdir=None, + ) + session.add(prev_error) + session.flush() + prev_error_id = prev_error.id + + update_dag_parsing_results_in_db( + dags=[], + import_errors={"abc.py": "New error"}, + processor_subdir=None, + warnings=set(), + session=session, + ) + + import_error = session.query(ParseImportError).filter(ParseImportError.filename == filename).one() + + # assert that the ID of the import error did not change + assert import_error.id == prev_error_id + assert import_error.stacktrace == "New error" + + def test_remove_error_clears_import_error(self, session): + # Pre-condition: there is an import error for the dag file + filename = "abc.py" + prev_error = ParseImportError( + filename=filename, + timestamp=tz.utcnow(), + stacktrace="Some error", + processor_subdir=None, + ) + session.add(prev_error) + + # And one for another file we haven't been given results for -- this shouldn't be deleted + session.add( + ParseImportError( + filename="def.py", + timestamp=tz.utcnow(), + stacktrace="Some error", + 
processor_subdir=None, + ) + ) + session.flush() + + # Sanity check of pre-condition + import_errors = set(session.scalars(select(ParseImportError.filename))) + assert import_errors == {"abc.py", "def.py"} + + dag = DAG(dag_id="test") + dag.fileloc = filename + + import_errors = {} + update_dag_parsing_results_in_db([dag], import_errors, None, set(), session) + + dag_model: DagModel = session.get(DagModel, (dag.dag_id,)) + assert dag_model.has_import_errors is False + + import_errors = set(session.scalars(select(ParseImportError.filename))) + + assert import_errors == {"def.py"} + + def test_sync_perm_for_dag_with_dict_access_control(self, session, spy_agency: SpyAgency): + """ + Test that dagbag._sync_perm_for_dag will call ApplessAirflowSecurityManager.sync_perm_for_dag + """ + from airflow.www.security_appless import ApplessAirflowSecurityManager + + spy = spy_agency.spy_on( + ApplessAirflowSecurityManager.sync_perm_for_dag, owner=ApplessAirflowSecurityManager + ) + + dag = DAG(dag_id="test") + + def _sync_perms(): + spy.reset_calls() + _sync_dag_perms(dag, session=session) + + # perms dont exist + _sync_perms() + spy_agency.assert_spy_called_with(spy, dag.dag_id, access_control=None) + + # perms now exist + _sync_perms() + spy_agency.assert_spy_called_with(spy, dag.dag_id, access_control=None) + + # Always sync if we have access_control + dag.access_control = {"Public": {"DAGs": {"can_read"}, "DAG Runs": {"can_create"}}} + _sync_perms() + spy_agency.assert_spy_called_with( + spy, dag.dag_id, access_control={"Public": {"DAGs": {"can_read"}, "DAG Runs": {"can_create"}}} + ) diff --git a/tests/dag_processing/test_processor.py b/tests/dag_processing/test_processor.py index b23cd44f959ae..d3ecd98b91680 100644 --- a/tests/dag_processing/test_processor.py +++ b/tests/dag_processing/test_processor.py @@ -17,7 +17,6 @@ # under the License. 
from __future__ import annotations -import os import pathlib import sys from unittest import mock @@ -26,11 +25,10 @@ import pytest -from airflow import settings from airflow.callbacks.callback_requests import TaskCallbackRequest from airflow.configuration import TEST_DAGS_FOLDER, conf from airflow.dag_processing.processor import DagFileProcessor, DagFileProcessorProcess -from airflow.models import DagBag, DagModel, TaskInstance +from airflow.models import DagBag, TaskInstance from airflow.models.serialized_dag import SerializedDagModel from airflow.models.taskinstance import SimpleTaskInstance from airflow.utils import timezone @@ -39,13 +37,10 @@ from airflow.utils.types import DagRunType from tests_common.test_utils.asserts import assert_queries_count -from tests_common.test_utils.compat import ParseImportError from tests_common.test_utils.config import conf_vars, env_vars from tests_common.test_utils.db import ( clear_db_dags, - clear_db_import_errors, clear_db_jobs, - clear_db_pools, clear_db_runs, clear_db_serialized_dags, ) @@ -64,8 +59,6 @@ # tricking airflow into thinking these # files contain a DAG (otherwise Airflow will skip them) PARSEABLE_DAG_FILE_CONTENTS = '"airflow DAG"' -UNPARSEABLE_DAG_FILE_CONTENTS = "airflow DAG" -INVALID_DAG_WITH_DEPTH_FILE_CONTENTS = "def something():\n return airflow_DAG\nsomething()" # Filename to be used for dags that are created in an ad-hoc manner and can be removed/ # created at runtime @@ -85,9 +78,7 @@ class TestDagFileProcessor: @staticmethod def clean_db(): clear_db_runs() - clear_db_pools() clear_db_dags() - clear_db_import_errors() clear_db_jobs() clear_db_serialized_dags() @@ -246,320 +237,6 @@ def test_process_file_should_failure_callback(self, monkeypatch, tmp_path, get_t msg = " ".join([str(k) for k in ti.key.primary]) + " fired callback" assert msg in callback_file.read_text() - @conf_vars({("core", "dagbag_import_error_tracebacks"): "False"}) - def test_add_unparseable_file_before_sched_start_creates_import_error(self, tmp_path): - unparseable_filename = tmp_path.joinpath(TEMP_DAG_FILENAME).as_posix() - with open(unparseable_filename, "w") as unparseable_file: - unparseable_file.writelines(UNPARSEABLE_DAG_FILE_CONTENTS) - - with create_session() as session: - self._process_file(unparseable_filename, dag_directory=tmp_path, session=session) - import_errors = session.query(ParseImportError).all() - - assert len(import_errors) == 1 - import_error = import_errors[0] - assert import_error.filename == unparseable_filename - assert import_error.stacktrace == f"invalid syntax ({TEMP_DAG_FILENAME}, line 1)" - session.rollback() - - @conf_vars({("core", "dagbag_import_error_tracebacks"): "False"}) - def test_add_unparseable_zip_file_creates_import_error(self, tmp_path): - zip_filename = (tmp_path / "test_zip.zip").as_posix() - invalid_dag_filename = os.path.join(zip_filename, TEMP_DAG_FILENAME) - with ZipFile(zip_filename, "w") as zip_file: - zip_file.writestr(TEMP_DAG_FILENAME, UNPARSEABLE_DAG_FILE_CONTENTS) - - with create_session() as session: - self._process_file(zip_filename, dag_directory=tmp_path, session=session) - import_errors = session.query(ParseImportError).all() - - assert len(import_errors) == 1 - import_error = import_errors[0] - assert import_error.filename == invalid_dag_filename - assert import_error.stacktrace == f"invalid syntax ({TEMP_DAG_FILENAME}, line 1)" - session.rollback() - - @conf_vars({("core", "dagbag_import_error_tracebacks"): "False"}) - def 
test_dag_model_has_import_error_is_true_when_import_error_exists(self, tmp_path, session): - dag_file = os.path.join(TEST_DAGS_FOLDER, "test_example_bash_operator.py") - temp_dagfile = tmp_path.joinpath(TEMP_DAG_FILENAME).as_posix() - with open(dag_file) as main_dag, open(temp_dagfile, "w") as next_dag: - for line in main_dag: - next_dag.write(line) - # first we parse the dag - self._process_file(temp_dagfile, dag_directory=tmp_path, session=session) - # assert DagModel.has_import_errors is false - dm = session.query(DagModel).filter(DagModel.fileloc == temp_dagfile).first() - assert not dm.has_import_errors - # corrupt the file - with open(temp_dagfile, "a") as file: - file.writelines(UNPARSEABLE_DAG_FILE_CONTENTS) - - self._process_file(temp_dagfile, dag_directory=tmp_path, session=session) - import_errors = session.query(ParseImportError).all() - - assert len(import_errors) == 1 - import_error = import_errors[0] - assert import_error.filename == temp_dagfile - assert import_error.stacktrace - dm = session.query(DagModel).filter(DagModel.fileloc == temp_dagfile).first() - assert dm.has_import_errors - - def test_no_import_errors_with_parseable_dag(self, tmp_path): - parseable_filename = tmp_path / TEMP_DAG_FILENAME - parseable_filename.write_text(PARSEABLE_DAG_FILE_CONTENTS) - - with create_session() as session: - self._process_file(parseable_filename.as_posix(), dag_directory=tmp_path, session=session) - import_errors = session.query(ParseImportError).all() - - assert len(import_errors) == 0 - - session.rollback() - - def test_no_import_errors_with_parseable_dag_in_zip(self, tmp_path): - zip_filename = (tmp_path / "test_zip.zip").as_posix() - with ZipFile(zip_filename, "w") as zip_file: - zip_file.writestr(TEMP_DAG_FILENAME, PARSEABLE_DAG_FILE_CONTENTS) - - with create_session() as session: - self._process_file(zip_filename, dag_directory=tmp_path, session=session) - import_errors = session.query(ParseImportError).all() - - assert len(import_errors) == 0 - - session.rollback() - - @conf_vars({("core", "dagbag_import_error_tracebacks"): "False"}) - def test_new_import_error_replaces_old(self, tmp_path): - unparseable_filename = tmp_path / TEMP_DAG_FILENAME - # Generate original import error - unparseable_filename.write_text(UNPARSEABLE_DAG_FILE_CONTENTS) - - session = settings.Session() - self._process_file(unparseable_filename.as_posix(), dag_directory=tmp_path, session=session) - - # Generate replacement import error (the error will be on the second line now) - unparseable_filename.write_text( - PARSEABLE_DAG_FILE_CONTENTS + os.linesep + UNPARSEABLE_DAG_FILE_CONTENTS - ) - self._process_file(unparseable_filename.as_posix(), dag_directory=tmp_path, session=session) - - import_errors = session.query(ParseImportError).all() - - assert len(import_errors) == 1 - import_error = import_errors[0] - assert import_error.filename == unparseable_filename.as_posix() - assert import_error.stacktrace == f"invalid syntax ({TEMP_DAG_FILENAME}, line 2)" - - session.rollback() - - def test_import_error_record_is_updated_not_deleted_and_recreated(self, tmp_path): - """ - Test that existing import error is updated and new record not created - for a dag with the same filename - """ - filename_to_parse = tmp_path.joinpath(TEMP_DAG_FILENAME).as_posix() - # Generate original import error - with open(filename_to_parse, "w") as file_to_parse: - file_to_parse.writelines(UNPARSEABLE_DAG_FILE_CONTENTS) - session = settings.Session() - self._process_file(filename_to_parse, dag_directory=tmp_path, session=session) - 
- import_error_1 = ( - session.query(ParseImportError).filter(ParseImportError.filename == filename_to_parse).one() - ) - - # process the file multiple times - for _ in range(10): - self._process_file(filename_to_parse, dag_directory=tmp_path, session=session) - - import_error_2 = ( - session.query(ParseImportError).filter(ParseImportError.filename == filename_to_parse).one() - ) - - # assert that the ID of the import error did not change - assert import_error_1.id == import_error_2.id - - def test_remove_error_clears_import_error(self, tmp_path): - filename_to_parse = tmp_path.joinpath(TEMP_DAG_FILENAME).as_posix() - - # Generate original import error - with open(filename_to_parse, "w") as file_to_parse: - file_to_parse.writelines(UNPARSEABLE_DAG_FILE_CONTENTS) - session = settings.Session() - self._process_file(filename_to_parse, dag_directory=tmp_path, session=session) - - # Remove the import error from the file - with open(filename_to_parse, "w") as file_to_parse: - file_to_parse.writelines(PARSEABLE_DAG_FILE_CONTENTS) - self._process_file(filename_to_parse, dag_directory=tmp_path, session=session) - - import_errors = session.query(ParseImportError).all() - - assert len(import_errors) == 0 - - session.rollback() - - def test_remove_error_clears_import_error_zip(self, tmp_path): - session = settings.Session() - - # Generate original import error - zip_filename = (tmp_path / "test_zip.zip").as_posix() - with ZipFile(zip_filename, "w") as zip_file: - zip_file.writestr(TEMP_DAG_FILENAME, UNPARSEABLE_DAG_FILE_CONTENTS) - self._process_file(zip_filename, dag_directory=tmp_path, session=session) - - import_errors = session.query(ParseImportError).all() - assert len(import_errors) == 1 - - # Remove the import error from the file - with ZipFile(zip_filename, "w") as zip_file: - zip_file.writestr(TEMP_DAG_FILENAME, "import os # airflow DAG") - self._process_file(zip_filename, dag_directory=tmp_path, session=session) - - import_errors = session.query(ParseImportError).all() - assert len(import_errors) == 0 - - session.rollback() - - def test_import_error_tracebacks(self, tmp_path): - unparseable_filename = (tmp_path / TEMP_DAG_FILENAME).as_posix() - with open(unparseable_filename, "w") as unparseable_file: - unparseable_file.writelines(INVALID_DAG_WITH_DEPTH_FILE_CONTENTS) - - with create_session() as session: - self._process_file(unparseable_filename, dag_directory=tmp_path, session=session) - import_errors = session.query(ParseImportError).all() - - assert len(import_errors) == 1 - import_error = import_errors[0] - assert import_error.filename == unparseable_filename - if PY311: - expected_stacktrace = ( - "Traceback (most recent call last):\n" - ' File "{}", line 3, in \n' - " something()\n" - ' File "{}", line 2, in something\n' - " return airflow_DAG\n" - " ^^^^^^^^^^^\n" - "NameError: name 'airflow_DAG' is not defined\n" - ) - else: - expected_stacktrace = ( - "Traceback (most recent call last):\n" - ' File "{}", line 3, in \n' - " something()\n" - ' File "{}", line 2, in something\n' - " return airflow_DAG\n" - "NameError: name 'airflow_DAG' is not defined\n" - ) - assert import_error.stacktrace == expected_stacktrace.format( - unparseable_filename, unparseable_filename - ) - session.rollback() - - @conf_vars({("core", "dagbag_import_error_traceback_depth"): "1"}) - def test_import_error_traceback_depth(self, tmp_path): - unparseable_filename = tmp_path.joinpath(TEMP_DAG_FILENAME).as_posix() - with open(unparseable_filename, "w") as unparseable_file: - 
unparseable_file.writelines(INVALID_DAG_WITH_DEPTH_FILE_CONTENTS) - - with create_session() as session: - self._process_file(unparseable_filename, dag_directory=tmp_path, session=session) - import_errors = session.query(ParseImportError).all() - - assert len(import_errors) == 1 - import_error = import_errors[0] - assert import_error.filename == unparseable_filename - if PY311: - expected_stacktrace = ( - "Traceback (most recent call last):\n" - ' File "{}", line 2, in something\n' - " return airflow_DAG\n" - " ^^^^^^^^^^^\n" - "NameError: name 'airflow_DAG' is not defined\n" - ) - else: - expected_stacktrace = ( - "Traceback (most recent call last):\n" - ' File "{}", line 2, in something\n' - " return airflow_DAG\n" - "NameError: name 'airflow_DAG' is not defined\n" - ) - assert import_error.stacktrace == expected_stacktrace.format(unparseable_filename) - - session.rollback() - - def test_import_error_tracebacks_zip(self, tmp_path): - invalid_zip_filename = (tmp_path / "test_zip_invalid.zip").as_posix() - invalid_dag_filename = os.path.join(invalid_zip_filename, TEMP_DAG_FILENAME) - with ZipFile(invalid_zip_filename, "w") as invalid_zip_file: - invalid_zip_file.writestr(TEMP_DAG_FILENAME, INVALID_DAG_WITH_DEPTH_FILE_CONTENTS) - - with create_session() as session: - self._process_file(invalid_zip_filename, dag_directory=tmp_path, session=session) - import_errors = session.query(ParseImportError).all() - - assert len(import_errors) == 1 - import_error = import_errors[0] - assert import_error.filename == invalid_dag_filename - if PY311: - expected_stacktrace = ( - "Traceback (most recent call last):\n" - ' File "{}", line 3, in \n' - " something()\n" - ' File "{}", line 2, in something\n' - " return airflow_DAG\n" - " ^^^^^^^^^^^\n" - "NameError: name 'airflow_DAG' is not defined\n" - ) - else: - expected_stacktrace = ( - "Traceback (most recent call last):\n" - ' File "{}", line 3, in \n' - " something()\n" - ' File "{}", line 2, in something\n' - " return airflow_DAG\n" - "NameError: name 'airflow_DAG' is not defined\n" - ) - assert import_error.stacktrace == expected_stacktrace.format( - invalid_dag_filename, invalid_dag_filename - ) - session.rollback() - - @conf_vars({("core", "dagbag_import_error_traceback_depth"): "1"}) - def test_import_error_tracebacks_zip_depth(self, tmp_path): - invalid_zip_filename = (tmp_path / "test_zip_invalid.zip").as_posix() - invalid_dag_filename = os.path.join(invalid_zip_filename, TEMP_DAG_FILENAME) - with ZipFile(invalid_zip_filename, "w") as invalid_zip_file: - invalid_zip_file.writestr(TEMP_DAG_FILENAME, INVALID_DAG_WITH_DEPTH_FILE_CONTENTS) - - with create_session() as session: - self._process_file(invalid_zip_filename, dag_directory=tmp_path, session=session) - import_errors = session.query(ParseImportError).all() - - assert len(import_errors) == 1 - import_error = import_errors[0] - assert import_error.filename == invalid_dag_filename - if PY311: - expected_stacktrace = ( - "Traceback (most recent call last):\n" - ' File "{}", line 2, in something\n' - " return airflow_DAG\n" - " ^^^^^^^^^^^\n" - "NameError: name 'airflow_DAG' is not defined\n" - ) - else: - expected_stacktrace = ( - "Traceback (most recent call last):\n" - ' File "{}", line 2, in something\n' - " return airflow_DAG\n" - "NameError: name 'airflow_DAG' is not defined\n" - ) - assert import_error.stacktrace == expected_stacktrace.format(invalid_dag_filename) - session.rollback() - @conf_vars({("logging", "dag_processor_log_target"): "stdout"}) 
@mock.patch("airflow.dag_processing.processor.settings.dispose_orm", MagicMock) @mock.patch("airflow.dag_processing.processor.redirect_stdout") diff --git a/tests/models/test_dagbag.py b/tests/models/test_dagbag.py index 81f94e09b2ccb..f8cbc4d3e6632 100644 --- a/tests/models/test_dagbag.py +++ b/tests/models/test_dagbag.py @@ -16,6 +16,7 @@ # under the License. from __future__ import annotations +import contextlib import inspect import logging import os @@ -31,19 +32,16 @@ import pytest import time_machine -from sqlalchemy import func -from sqlalchemy.exc import OperationalError import airflow.example_dags from airflow import settings -from airflow.exceptions import SerializationError from airflow.models.dag import DAG, DagModel from airflow.models.dagbag import DagBag +from airflow.models.dagwarning import DagWarning, DagWarningType from airflow.models.serialized_dag import SerializedDagModel from airflow.serialization.serialized_objects import SerializedDAG from airflow.utils import timezone as tz from airflow.utils.session import create_session -from airflow.www.security_appless import ApplessAirflowSecurityManager from tests import cluster_policies from tests.models import TEST_DAGS_FOLDER @@ -55,6 +53,13 @@ example_dags_folder = pathlib.Path(airflow.example_dags.__path__[0]) # type: ignore[attr-defined] +PY311 = sys.version_info >= (3, 11) + +# Include the words "airflow" and "dag" in the file contents, +# tricking airflow into thinking these +# files contain a DAG (otherwise Airflow will skip them) +INVALID_DAG_WITH_DEPTH_FILE_CONTENTS = "def something():\n return airflow_DAG\nsomething()" + def db_clean_up(): db.clear_db_dags() @@ -210,6 +215,7 @@ def test_zip(self, tmp_path): dagbag.process_file(os.path.join(TEST_DAGS_FOLDER, "test_zip.zip")) assert dagbag.get_dag("test_zip_dag") assert sys.path == syspath_before # sys.path doesn't change + assert not dagbag.import_errors @patch("airflow.models.dagbag.timeout") @patch("airflow.models.dagbag.settings.get_dagbag_import_timeout") @@ -569,51 +575,6 @@ def test_deactivate_unknown_dags(self): with create_session() as session: session.query(DagModel).filter(DagModel.dag_id == "test_deactivate_unknown_dags").delete() - def test_serialized_dags_are_written_to_db_on_sync(self): - """ - Test that when dagbag.sync_to_db is called the DAGs are Serialized and written to DB - even when dagbag.read_dags_from_db is False - """ - with create_session() as session: - serialized_dags_count = session.query(func.count(SerializedDagModel.dag_id)).scalar() - assert serialized_dags_count == 0 - - dagbag = DagBag( - dag_folder=os.path.join(TEST_DAGS_FOLDER, "test_example_bash_operator.py"), - include_examples=False, - ) - dagbag.sync_to_db() - - assert not dagbag.read_dags_from_db - - new_serialized_dags_count = session.query(func.count(SerializedDagModel.dag_id)).scalar() - assert new_serialized_dags_count == 1 - - @patch("airflow.models.serialized_dag.SerializedDagModel.write_dag") - def test_serialized_dag_errors_are_import_errors(self, mock_serialize, caplog): - """ - Test that errors serializing a DAG are recorded as import_errors in the DB - """ - mock_serialize.side_effect = SerializationError - - with create_session() as session: - path = os.path.join(TEST_DAGS_FOLDER, "test_example_bash_operator.py") - - dagbag = DagBag( - dag_folder=path, - include_examples=False, - ) - assert dagbag.import_errors == {} - - caplog.set_level(logging.ERROR) - dagbag.sync_to_db(session=session) - assert "SerializationError" in caplog.text - - assert path in 
dagbag.import_errors - err = dagbag.import_errors[path] - assert "SerializationError" in err - session.rollback() - def test_timeout_dag_errors_are_import_errors(self, tmp_path, caplog): """ Test that if the DAG contains Timeout error it will be still loaded to DB as import_errors @@ -655,153 +616,49 @@ def f(): assert "tmp_file.py" in dagbag.import_errors assert "DagBag import timeout for" in caplog.text - @patch("airflow.models.dagbag.DagBag.collect_dags") - @patch("airflow.models.serialized_dag.SerializedDagModel.write_dag") - @patch("airflow.models.dag.DAG.bulk_write_to_db") - def test_sync_to_db_is_retried(self, mock_bulk_write_to_db, mock_s10n_write_dag, mock_collect_dags): - """Test that dagbag.sync_to_db is retried on OperationalError""" - - dagbag = DagBag("/dev/null") - mock_dag = mock.MagicMock() - dagbag.dags["mock_dag"] = mock_dag - - op_error = OperationalError(statement=mock.ANY, params=mock.ANY, orig=mock.ANY) - - # Mock error for the first 2 tries and a successful third try - side_effect = [op_error, op_error, mock.ANY] - - mock_bulk_write_to_db.side_effect = side_effect - - mock_session = mock.MagicMock() - dagbag.sync_to_db(session=mock_session) - - # Test that 3 attempts were made to run 'DAG.bulk_write_to_db' successfully - mock_bulk_write_to_db.assert_has_calls( - [ - mock.call(mock.ANY, processor_subdir=None, session=mock.ANY), - mock.call(mock.ANY, processor_subdir=None, session=mock.ANY), - mock.call(mock.ANY, processor_subdir=None, session=mock.ANY), - ] + @staticmethod + def _make_test_traceback(unparseable_filename: str, depth=None) -> str: + marker = " ^^^^^^^^^^^\n" if PY311 else "" + frames = ( + f' File "{unparseable_filename}", line 3, in \n something()\n', + f' File "{unparseable_filename}", line 2, in something\n return airflow_DAG\n{marker}', ) - # Assert that rollback is called twice (i.e. 
whenever OperationalError occurs) - mock_session.rollback.assert_has_calls([mock.call(), mock.call()]) - # Check that 'SerializedDagModel.write_dag' is also called - # Only called once since the other two times the 'DAG.bulk_write_to_db' error'd - # and the session was roll-backed before even reaching 'SerializedDagModel.write_dag' - mock_s10n_write_dag.assert_has_calls( - [ - mock.call( - mock_dag, min_update_interval=mock.ANY, processor_subdir=None, session=mock_session - ), - ] + depth = 0 if depth is None else -depth + return ( + "Traceback (most recent call last):\n" + + "".join(frames[depth:]) + + "NameError: name 'airflow_DAG' is not defined\n" ) - @patch("airflow.models.dagbag.settings.MIN_SERIALIZED_DAG_UPDATE_INTERVAL", 5) - @patch("airflow.models.dagbag.DagBag._sync_perm_for_dag") - def test_sync_to_db_syncs_dag_specific_perms_on_update(self, mock_sync_perm_for_dag): - """ - Test that dagbag.sync_to_db will sync DAG specific permissions when a DAG is - new or updated - """ - db_clean_up() - session = settings.Session() - with time_machine.travel(tz.datetime(2020, 1, 5, 0, 0, 0), tick=False) as frozen_time: - dagbag = DagBag( - dag_folder=os.path.join(TEST_DAGS_FOLDER, "test_example_bash_operator.py"), - include_examples=False, - ) - - def _sync_to_db(): - mock_sync_perm_for_dag.reset_mock() - frozen_time.shift(20) - dagbag.sync_to_db(session=session) - - dag = dagbag.dags["test_example_bash_operator"] - dag.sync_to_db() - _sync_to_db() - mock_sync_perm_for_dag.assert_called_once_with(dag, session=session) - - # DAG isn't updated - _sync_to_db() - mock_sync_perm_for_dag.assert_not_called() - - # DAG is updated - dag.tags = ["new_tag"] - _sync_to_db() - session.commit() - mock_sync_perm_for_dag.assert_called_once_with(dag, session=session) - - @patch("airflow.www.security_appless.ApplessAirflowSecurityManager") - def test_sync_perm_for_dag(self, mock_security_manager): - """ - Test that dagbag._sync_perm_for_dag will call ApplessAirflowSecurityManager.sync_perm_for_dag - """ - db_clean_up() - with create_session() as session: - security_manager = ApplessAirflowSecurityManager(session) - mock_sync_perm_for_dag = mock_security_manager.return_value.sync_perm_for_dag - mock_sync_perm_for_dag.side_effect = security_manager.sync_perm_for_dag - - dagbag = DagBag( - dag_folder=os.path.join(TEST_DAGS_FOLDER, "test_example_bash_operator.py"), - include_examples=False, - ) - dag = dagbag.dags["test_example_bash_operator"] - - def _sync_perms(): - mock_sync_perm_for_dag.reset_mock() - DagBag._sync_perm_for_dag(dag, session=session) - - # perms dont exist - _sync_perms() - mock_sync_perm_for_dag.assert_called_once_with("test_example_bash_operator", None) - - # perms now exist - _sync_perms() - mock_sync_perm_for_dag.assert_called_once_with("test_example_bash_operator", None) - - # Always sync if we have access_control - dag.access_control = {"Public": {"can_read"}} - _sync_perms() - mock_sync_perm_for_dag.assert_called_once_with( - "test_example_bash_operator", {"Public": {"DAGs": {"can_read"}}} - ) - - @patch("airflow.www.security_appless.ApplessAirflowSecurityManager") - def test_sync_perm_for_dag_with_dict_access_control(self, mock_security_manager): - """ - Test that dagbag._sync_perm_for_dag will call ApplessAirflowSecurityManager.sync_perm_for_dag - """ - db_clean_up() - with create_session() as session: - security_manager = ApplessAirflowSecurityManager(session) - mock_sync_perm_for_dag = mock_security_manager.return_value.sync_perm_for_dag - mock_sync_perm_for_dag.side_effect = 
security_manager.sync_perm_for_dag - - dagbag = DagBag( - dag_folder=os.path.join(TEST_DAGS_FOLDER, "test_example_bash_operator.py"), - include_examples=False, - ) - dag = dagbag.dags["test_example_bash_operator"] - - def _sync_perms(): - mock_sync_perm_for_dag.reset_mock() - DagBag._sync_perm_for_dag(dag, session=session) - - # perms dont exist - _sync_perms() - mock_sync_perm_for_dag.assert_called_once_with("test_example_bash_operator", None) - - # perms now exist - _sync_perms() - mock_sync_perm_for_dag.assert_called_once_with("test_example_bash_operator", None) - - # Always sync if we have access_control - dag.access_control = {"Public": {"DAGs": {"can_read"}, "DAG Runs": {"can_create"}}} - _sync_perms() - mock_sync_perm_for_dag.assert_called_once_with( - "test_example_bash_operator", {"Public": {"DAGs": {"can_read"}, "DAG Runs": {"can_create"}}} - ) + @pytest.mark.parametrize(("depth",), ((None,), (1,))) + def test_import_error_tracebacks(self, tmp_path, depth): + unparseable_filename = tmp_path.joinpath("dag.py").as_posix() + with open(unparseable_filename, "w") as unparseable_file: + unparseable_file.writelines(INVALID_DAG_WITH_DEPTH_FILE_CONTENTS) + + with contextlib.ExitStack() as cm: + if depth is not None: + cm.enter_context(conf_vars({("core", "dagbag_import_error_traceback_depth"): str(depth)})) + dagbag = DagBag(dag_folder=unparseable_filename, include_examples=False) + import_errors = dagbag.import_errors + + assert unparseable_filename in import_errors + assert import_errors[unparseable_filename] == self._make_test_traceback(unparseable_filename, depth) + + @pytest.mark.parametrize(("depth",), ((None,), (1,))) + def test_import_error_tracebacks_zip(self, tmp_path, depth): + invalid_zip_filename = (tmp_path / "test_zip_invalid.zip").as_posix() + invalid_dag_filename = os.path.join(invalid_zip_filename, "dag.py") + with zipfile.ZipFile(invalid_zip_filename, "w") as invalid_zip_file: + invalid_zip_file.writestr("dag.py", INVALID_DAG_WITH_DEPTH_FILE_CONTENTS) + + with contextlib.ExitStack() as cm: + if depth is not None: + cm.enter_context(conf_vars({("core", "dagbag_import_error_traceback_depth"): str(depth)})) + dagbag = DagBag(dag_folder=invalid_zip_filename, include_examples=False) + import_errors = dagbag.import_errors + assert invalid_dag_filename in import_errors + assert import_errors[invalid_dag_filename] == self._make_test_traceback(invalid_dag_filename, depth) @patch("airflow.models.dagbag.settings.MIN_SERIALIZED_DAG_UPDATE_INTERVAL", 5) @patch("airflow.models.dagbag.settings.MIN_SERIALIZED_DAG_FETCH_INTERVAL", 5) @@ -1030,3 +887,36 @@ def test_dabgag_captured_warnings_zip(self): assert len(captured_warnings) == 2 assert captured_warnings[0] == (f"{in_zip_dag_file}:47: DeprecationWarning: Deprecated Parameter") assert captured_warnings[1] == f"{in_zip_dag_file}:49: UserWarning: Some Warning" + + @pytest.mark.parametrize( + ("known_pools", "expected"), + ( + pytest.param(None, set(), id="disabled"), + pytest.param( + {"default_pool"}, + { + DagWarning( + "test", + DagWarningType.NONEXISTENT_POOL, + "Dag 'test' references non-existent pools: ['pool1']", + ), + }, + id="only-default", + ), + pytest.param( + {"default_pool", "pool1"}, + set(), + id="known-pools", + ), + ), + ) + def test_dag_warnings_invalid_pool(self, known_pools, expected): + from airflow.models.baseoperator import BaseOperator + + with DAG(dag_id="test") as dag: + BaseOperator(task_id="1") + BaseOperator(task_id="2", pool="pool1") + + dagbag = DagBag(dag_folder="", include_examples=False, 
collect_dags=False, known_pools=known_pools) + dagbag.bag_dag(dag) + assert dagbag.dag_warnings == expected
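
The `known_pools` feature added to `DagBag.dag_warnings` above only computes warnings when a set of pool names is supplied; `None` disables the check entirely. Below is a minimal sketch of that logic using plain data structures, so it runs without an Airflow installation; the dag and pool names are invented for illustration.

def nonexistent_pool_warnings(
    dag_pools: dict[str, set[str]], known_pools: set[str] | None
) -> set[tuple[str, str]]:
    """Return (dag_id, message) pairs for dags that reference pools we do not know about."""
    if known_pools is None:  # None disables the check, mirroring DagBag.dag_warnings
        return set()
    warnings: set[tuple[str, str]] = set()
    for dag_id, pools in dag_pools.items():
        missing = pools - known_pools
        if missing:
            warnings.add((dag_id, f"Dag '{dag_id}' references non-existent pools: {sorted(missing)!r}"))
    return warnings

# Example: 'etl' uses an undeclared pool, 'reporting' only uses default_pool.
print(nonexistent_pool_warnings(
    {"etl": {"default_pool", "gpu_pool"}, "reporting": {"default_pool"}},
    known_pools={"default_pool"},
))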
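
`update_dag_parsing_results_in_db` retries the bulk write on `OperationalError` and only merges per-DAG serialization errors into `import_errors` after an attempt completes, so a failed attempt never leaves stale errors behind. The sketch below shows that shape under simplifying assumptions: a plain retry loop and a stub exception stand in for `run_with_db_retries` and SQLAlchemy's `OperationalError`, and `bulk_write`/`serialize_one` are hypothetical callables playing the roles of `DAG.bulk_write_to_db` and `_serialize_dag_capturing_errors`.

import logging

log = logging.getLogger(__name__)

MAX_DB_RETRIES = 3  # assumed default; the real value comes from airflow.utils.retries


class StubOperationalError(Exception):
    """Stand-in for sqlalchemy.exc.OperationalError so the sketch has no dependencies."""


def write_parsing_results(dags, import_errors, session, bulk_write, serialize_one):
    """Retry the bulk write; merge serialization errors only after a clean attempt."""
    for attempt in range(1, MAX_DB_RETRIES + 1):
        serialize_errors = []
        try:
            bulk_write(dags, session=session)  # plays the role of DAG.bulk_write_to_db
            for dag in dags:
                serialize_errors.extend(serialize_one(dag, session=session))
        except StubOperationalError:
            log.debug("Attempt %d of %d failed, rolling back", attempt, MAX_DB_RETRIES)
            session.rollback()  # discard partial work before retrying
            if attempt == MAX_DB_RETRIES:
                raise
            continue
        # Only a fully successful attempt contributes errors.
        import_errors.update(dict(serialize_errors))
        return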
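
`_update_import_errors` clears rows for files that parsed cleanly and upserts the rest: an existing `ParseImportError` row is updated in place, keeping its id (which `test_new_import_error_replaces_old` relies on), while unseen files get a fresh row. A rough sketch of that behaviour over an in-memory dict keyed by filename; the record fields here are illustrative, not the ORM columns.

from datetime import datetime, timezone


def record_import_errors(
    existing: dict[str, dict],
    files_parsed: set[str],
    import_errors: dict[str, str],
) -> None:
    """Mimic the delete-then-upsert behaviour over an in-memory mapping."""
    # Files that parsed without error lose any previously recorded error.
    for filename in files_parsed:
        existing.pop(filename, None)
    # Files with errors are updated in place (same "row") or added as new entries.
    for filename, stacktrace in import_errors.items():
        now = datetime.now(timezone.utc)
        if filename in existing:
            existing[filename].update(timestamp=now, stacktrace=stacktrace)
        else:
            existing[filename] = {"timestamp": now, "stacktrace": stacktrace}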