diff --git a/airflow/plugins_manager.py b/airflow/plugins_manager.py index 80099b27d9836..a20126ffe2293 100644 --- a/airflow/plugins_manager.py +++ b/airflow/plugins_manager.py @@ -32,9 +32,13 @@ import warnings from typing import Any, Dict, List, Type -import pkg_resources from six import with_metaclass +try: + import importlib.metadata as importlib_metadata +except ImportError: + import importlib_metadata + from airflow import settings from airflow.models.baseoperator import BaseOperatorLink @@ -109,6 +113,23 @@ def on_load(cls, *args, **kwargs): """ +def entry_points_with_dist(group): + """ + Return EntryPoint objects of the given group, along with the distribution information. + + This is like the ``entry_points()`` function from importlib.metadata, + except it also returns the distribution the entry_point was loaded from. + + :param group: FIlter results to only this entrypoint group + :return: Generator of (EntryPoint, Distribution) objects for the specified groups + """ + for dist in importlib_metadata.distributions(): + for e in dist.entry_points: + if e.group != group: + continue + yield (e, dist) + + def load_entrypoint_plugins(entry_points, airflow_plugins): """ Load AirflowPlugin subclasses from the entrypoints @@ -122,16 +143,18 @@ def load_entrypoint_plugins(entry_points, airflow_plugins): :rtype: list[airflow.plugins_manager.AirflowPlugin] """ global import_errors # pylint: disable=global-statement - for entry_point in entry_points: + for entry_point, dist in entry_points: log.debug('Importing entry_point plugin %s', entry_point.name) try: plugin_obj = entry_point.load() - plugin_obj.__usable_import_name = entry_point.module_name - if is_valid_plugin(plugin_obj, airflow_plugins): - if callable(getattr(plugin_obj, 'on_load', None)): - plugin_obj.on_load() + plugin_obj.__usable_import_name = entry_point.module + if not is_valid_plugin(plugin_obj, airflow_plugins): + continue + + if callable(getattr(plugin_obj, 'on_load', None)): + plugin_obj.on_load() - airflow_plugins.append(plugin_obj) + airflow_plugins.append(plugin_obj) except Exception as e: # pylint: disable=broad-except log.exception("Failed to import plugin %s", entry_point.name) import_errors[entry_point.module_name] = str(e) @@ -204,7 +227,7 @@ def is_valid_plugin(plugin_obj, existing_plugins): import_errors[filepath] = str(e) plugins = load_entrypoint_plugins( - pkg_resources.iter_entry_points('airflow.plugins'), + entry_points_with_dist('airflow.plugins'), plugins ) diff --git a/tests/plugins/test_plugins_manager_rbac.py b/tests/plugins/test_plugins_manager_rbac.py index 83edcb67766d3..c2ca805177907 100644 --- a/tests/plugins/test_plugins_manager_rbac.py +++ b/tests/plugins/test_plugins_manager_rbac.py @@ -22,17 +22,16 @@ from __future__ import print_function from __future__ import unicode_literals -import unittest -import six -from tests.compat import mock +import logging -import pkg_resources +import pytest from airflow.www_rbac import app as application +from tests.compat import mock -class PluginsTestRBAC(unittest.TestCase): - def setUp(self): +class TestPluginsRBAC(object): + def setup_method(self, method): self.app, self.appbuilder = application.create_app(testing=True) def test_flaskappbuilder_views(self): @@ -41,18 +40,18 @@ def test_flaskappbuilder_views(self): plugin_views = [view for view in self.appbuilder.baseviews if view.blueprint.name == appbuilder_class_name] - self.assertTrue(len(plugin_views) == 1) + assert len(plugin_views) == 1 # view should have a menu item matching category of v_appbuilder_package links = [menu_item for menu_item in self.appbuilder.menu.menu if menu_item.name == v_appbuilder_package['category']] - self.assertTrue(len(links) == 1) + assert len(links) == 1 # menu link should also have a link matching the name of the package. link = links[0] - self.assertEqual(link.name, v_appbuilder_package['category']) - self.assertEqual(link.childs[0].name, v_appbuilder_package['name']) + assert link.name == v_appbuilder_package['category'] + assert link.childs[0].name == v_appbuilder_package['name'] def test_flaskappbuilder_menu_links(self): from tests.plugins.test_plugin import appbuilder_mitem @@ -61,40 +60,45 @@ def test_flaskappbuilder_menu_links(self): links = [menu_item for menu_item in self.appbuilder.menu.menu if menu_item.name == appbuilder_mitem['category']] - self.assertTrue(len(links) == 1) + assert len(links) == 1 # menu link should also have a link matching the name of the package. link = links[0] - self.assertEqual(link.name, appbuilder_mitem['category']) - self.assertEqual(link.childs[0].name, appbuilder_mitem['name']) + assert link.name == appbuilder_mitem['category'] + assert link.childs[0].name == appbuilder_mitem['name'] def test_app_blueprints(self): from tests.plugins.test_plugin import bp # Blueprint should be present in the app - self.assertTrue('test_plugin' in self.app.blueprints) - self.assertEqual(self.app.blueprints['test_plugin'].name, bp.name) + assert 'test_plugin' in self.app.blueprints + assert self.app.blueprints['test_plugin'].name == bp.name - @unittest.skipIf(six.PY2, 'self.assertLogs not available for Python 2') - @mock.patch('pkg_resources.iter_entry_points') - def test_entrypoint_plugin_errors_dont_raise_exceptions(self, mock_ep_plugins): + @pytest.mark.quarantined + def test_entrypoint_plugin_errors_dont_raise_exceptions(self, caplog): """ Test that Airflow does not raise an Error if there is any Exception because of the Plugin. """ - from airflow.plugins_manager import load_entrypoint_plugins, import_errors + from airflow.plugins_manager import import_errors, load_entrypoint_plugins, entry_points_with_dist + + mock_dist = mock.Mock() mock_entrypoint = mock.Mock() mock_entrypoint.name = 'test-entrypoint' + mock_entrypoint.group = 'airflow.plugins' mock_entrypoint.module_name = 'test.plugins.test_plugins_manager' - mock_entrypoint.load.side_effect = Exception('Version Conflict') - mock_ep_plugins.return_value = [mock_entrypoint] + mock_entrypoint.load.side_effect = ImportError('my_fake_module not found') + mock_dist.entry_points = [mock_entrypoint] + + with mock.patch('importlib_metadata.distributions', return_value=[mock_dist]), caplog.at_level( + logging.ERROR, logger='airflow.plugins_manager' + ): + load_entrypoint_plugins(entry_points_with_dist('airflow.plugins'), []) - with self.assertLogs("airflow.plugins_manager", level="ERROR") as log_output: - load_entrypoint_plugins(pkg_resources.iter_entry_points('airflow.plugins'), []) - received_logs = log_output.output[0] + received_logs = caplog.text # Assert Traceback is shown too assert "Traceback (most recent call last):" in received_logs - assert "Version Conflict" in received_logs + assert "my_fake_module not found" in received_logs assert "Failed to import plugin test-entrypoint" in received_logs - assert ('test.plugins.test_plugins_manager', 'Version Conflict') in import_errors.items() + assert ("test.plugins.test_plugins_manager", "my_fake_module not found") in import_errors.items() diff --git a/tests/plugins/test_plugins_manager_www.py b/tests/plugins/test_plugins_manager_www.py index 656f8b1b7c1f2..27d933e963b2d 100644 --- a/tests/plugins/test_plugins_manager_www.py +++ b/tests/plugins/test_plugins_manager_www.py @@ -23,7 +23,7 @@ from __future__ import unicode_literals import six -from mock import MagicMock, PropertyMock +from mock import MagicMock, Mock from flask.blueprints import Blueprint from flask_admin.menu import MenuLink, MenuView @@ -119,11 +119,16 @@ def setUp(self): ] def _build_mock(self, plugin_obj): - m = MagicMock(**{ - 'load.return_value': plugin_obj - }) - type(m).name = PropertyMock(return_value='plugin-' + plugin_obj.name) - return m + + mock_dist = Mock() + + mock_entrypoint = Mock() + mock_entrypoint.name = 'plugin-' + plugin_obj.name + mock_entrypoint.group = 'airflow.plugins' + mock_entrypoint.load.return_value = plugin_obj + mock_dist.entry_points = [mock_entrypoint] + + return (mock_entrypoint, mock_dist) def test_load_entrypoint_plugins(self): self.assertListEqual(