Skip to content

Commit

Permalink
Refactor: Simplify code in settings (#33267)
Browse files Browse the repository at this point in the history
(cherry picked from commit 95e9d83)
  • Loading branch information
eumiro authored and ephraimbuddy committed Aug 28, 2023
1 parent d31f20b commit 3439f91
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 24 deletions.
2 changes: 1 addition & 1 deletion airflow/policies.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ def get_airflow_context_vars(context):
return {}


def make_plugin_from_local_settings(pm: pluggy.PluginManager, module, names: list[str]):
def make_plugin_from_local_settings(pm: pluggy.PluginManager, module, names: set[str]):
"""
Turn the functions from airflow_local_settings module into a custom/local plugin.
Expand Down
39 changes: 18 additions & 21 deletions airflow/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -465,11 +465,23 @@ def import_local_settings():
"""Import airflow_local_settings.py files to allow overriding any configs in settings.py file."""
try:
import airflow_local_settings

except ModuleNotFoundError as e:
if e.name == "airflow_local_settings":
log.debug("No airflow_local_settings to import.", exc_info=True)
else:
log.critical(
"Failed to import airflow_local_settings due to a transitive module not found error.",
exc_info=True,
)
raise
except ImportError:
log.critical("Failed to import airflow_local_settings.", exc_info=True)
raise
else:
if hasattr(airflow_local_settings, "__all__"):
names = list(airflow_local_settings.__all__)
names = set(airflow_local_settings.__all__)
else:
names = list(filter(lambda n: not n.startswith("__"), airflow_local_settings.__dict__.keys()))
names = {n for n in airflow_local_settings.__dict__ if not n.startswith("__")}

if "policy" in names and "task_policy" not in names:
warnings.warn(
Expand All @@ -485,30 +497,15 @@ def import_local_settings():
POLICY_PLUGIN_MANAGER, airflow_local_settings, names
)

for name in names:
# If we have already handled a function by adding it to the plugin, then don't clobber the global
# function
if name in plugin_functions:
continue

# If we have already handled a function by adding it to the plugin,
# then don't clobber the global function
for name in names - plugin_functions:
globals()[name] = getattr(airflow_local_settings, name)

if POLICY_PLUGIN_MANAGER.hook.task_instance_mutation_hook.get_hookimpls():
task_instance_mutation_hook.is_noop = False

log.info("Loaded airflow_local_settings from %s .", airflow_local_settings.__file__)
except ModuleNotFoundError as e:
if e.name == "airflow_local_settings":
log.debug("No airflow_local_settings to import.", exc_info=True)
else:
log.critical(
"Failed to import airflow_local_settings due to a transitive module not found error.",
exc_info=True,
)
raise
except ImportError:
log.critical("Failed to import airflow_local_settings.", exc_info=True)
raise


def initialize():
Expand Down
4 changes: 2 additions & 2 deletions tests/core/test_policies.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def dag_policy(dag):

mod = Namespace(dag_policy=dag_policy)

policies.make_plugin_from_local_settings(plugin_manager, mod, ["dag_policy"])
policies.make_plugin_from_local_settings(plugin_manager, mod, {"dag_policy"})

plugin_manager.hook.dag_policy(dag="a")

Expand All @@ -64,7 +64,7 @@ def dag_policy(wrong_arg_name):

mod = Namespace(dag_policy=dag_policy)

policies.make_plugin_from_local_settings(plugin_manager, mod, ["dag_policy"])
policies.make_plugin_from_local_settings(plugin_manager, mod, {"dag_policy"})

plugin_manager.hook.dag_policy(dag="passed_dag_value")

Expand Down

0 comments on commit 3439f91

Please sign in to comment.