diff --git a/airflow/policies.py b/airflow/policies.py index 175ae5bbf9063..47c3dffcb22d9 100644 --- a/airflow/policies.py +++ b/airflow/policies.py @@ -138,7 +138,7 @@ def get_airflow_context_vars(context): return {} -def make_plugin_from_local_settings(pm: pluggy.PluginManager, module, names: list[str]): +def make_plugin_from_local_settings(pm: pluggy.PluginManager, module, names: set[str]): """ Turn the functions from airflow_local_settings module into a custom/local plugin. diff --git a/airflow/settings.py b/airflow/settings.py index e51cd208d9f74..bdf70ecf27c88 100644 --- a/airflow/settings.py +++ b/airflow/settings.py @@ -465,11 +465,23 @@ def import_local_settings(): """Import airflow_local_settings.py files to allow overriding any configs in settings.py file.""" try: import airflow_local_settings - + except ModuleNotFoundError as e: + if e.name == "airflow_local_settings": + log.debug("No airflow_local_settings to import.", exc_info=True) + else: + log.critical( + "Failed to import airflow_local_settings due to a transitive module not found error.", + exc_info=True, + ) + raise + except ImportError: + log.critical("Failed to import airflow_local_settings.", exc_info=True) + raise + else: if hasattr(airflow_local_settings, "__all__"): - names = list(airflow_local_settings.__all__) + names = set(airflow_local_settings.__all__) else: - names = list(filter(lambda n: not n.startswith("__"), airflow_local_settings.__dict__.keys())) + names = {n for n in airflow_local_settings.__dict__ if not n.startswith("__")} if "policy" in names and "task_policy" not in names: warnings.warn( @@ -485,30 +497,15 @@ def import_local_settings(): POLICY_PLUGIN_MANAGER, airflow_local_settings, names ) - for name in names: - # If we have already handled a function by adding it to the plugin, then don't clobber the global - # function - if name in plugin_functions: - continue - + # If we have already handled a function by adding it to the plugin, + # then don't clobber the global function + for name in names - plugin_functions: globals()[name] = getattr(airflow_local_settings, name) if POLICY_PLUGIN_MANAGER.hook.task_instance_mutation_hook.get_hookimpls(): task_instance_mutation_hook.is_noop = False log.info("Loaded airflow_local_settings from %s .", airflow_local_settings.__file__) - except ModuleNotFoundError as e: - if e.name == "airflow_local_settings": - log.debug("No airflow_local_settings to import.", exc_info=True) - else: - log.critical( - "Failed to import airflow_local_settings due to a transitive module not found error.", - exc_info=True, - ) - raise - except ImportError: - log.critical("Failed to import airflow_local_settings.", exc_info=True) - raise def initialize(): diff --git a/tests/core/test_policies.py b/tests/core/test_policies.py index e5fbf4fe682df..c7caff80b2c04 100644 --- a/tests/core/test_policies.py +++ b/tests/core/test_policies.py @@ -42,7 +42,7 @@ def dag_policy(dag): mod = Namespace(dag_policy=dag_policy) - policies.make_plugin_from_local_settings(plugin_manager, mod, ["dag_policy"]) + policies.make_plugin_from_local_settings(plugin_manager, mod, {"dag_policy"}) plugin_manager.hook.dag_policy(dag="a") @@ -64,7 +64,7 @@ def dag_policy(wrong_arg_name): mod = Namespace(dag_policy=dag_policy) - policies.make_plugin_from_local_settings(plugin_manager, mod, ["dag_policy"]) + policies.make_plugin_from_local_settings(plugin_manager, mod, {"dag_policy"}) plugin_manager.hook.dag_policy(dag="passed_dag_value")