diff --git a/kedro/framework/cli/cli.py b/kedro/framework/cli/cli.py index 14a633a6cd..9104e53c1c 100644 --- a/kedro/framework/cli/cli.py +++ b/kedro/framework/cli/cli.py @@ -12,6 +12,7 @@ from typing import Any, Sequence import click +from importlib_metadata import EntryPoint from kedro import __version__ as version from kedro.framework.cli import BRIGHT_BLACK, ORANGE @@ -105,6 +106,7 @@ def __init__(self, project_path: Path): super().__init__( ("Global commands", self.global_groups), ("Project specific commands", self.project_groups), + plugin_entry_points=self.plugin_groups, ) def main( @@ -172,13 +174,19 @@ def main( click.echo(hint) sys.exit(exc.code) + @property + def plugin_groups(self) -> dict[str, EntryPoint]: + eps = list(_get_entry_points("global")) + list(_get_entry_points("project")) + entry_point_dict = {ep.name: ep for ep in eps} + return entry_point_dict + @property def global_groups(self) -> Sequence[click.MultiCommand]: """Property which loads all global command groups from plugins and combines them with the built-in ones (eventually overriding the built-in ones if they are redefined by plugins). """ - return [cli, create_cli, *load_entry_points("global")] + return [cli, create_cli] @property def project_groups(self) -> Sequence[click.MultiCommand]: @@ -201,15 +209,13 @@ def project_groups(self) -> Sequence[click.MultiCommand]: registry_cli, ] - plugins = load_entry_points("project") - try: project_cli = importlib.import_module(f"{self._metadata.package_name}.cli") # fail gracefully if cli.py does not exist except ModuleNotFoundError: # return only built-in commands and commands from plugins # (plugins can override built-in commands) - return [*built_in, *plugins] + return [*built_in] # fail badly if cli.py exists, but has no `cli` in it if not hasattr(project_cli, "cli"): @@ -219,7 +225,7 @@ def project_groups(self) -> Sequence[click.MultiCommand]: user_defined = project_cli.cli # return built-in commands, plugin commands and user defined commands # (overriding happens as follows built-in < plugins < cli.py) - return [*built_in, *plugins, user_defined] + return [*built_in, user_defined] def main() -> None: # pragma: no cover diff --git a/kedro/framework/cli/utils.py b/kedro/framework/cli/utils.py index eda9bc005c..83770d70d7 100644 --- a/kedro/framework/cli/utils.py +++ b/kedro/framework/cli/utils.py @@ -90,6 +90,13 @@ def wrapit(func: Any) -> Any: return wrapit +def _partial_match(plugin_names: list[str], command_name: str) -> str | None: + for plugin_name in plugin_names: + if command_name in plugin_name: + return plugin_name + return None + + def _suggest_cli_command( original_command_name: str, existing_command_names: Iterable[str] ) -> str: @@ -111,13 +118,17 @@ def _suggest_cli_command( class CommandCollection(click.CommandCollection): """Modified from the Click one to still run the source groups function.""" - def __init__(self, *groups: tuple[str, Sequence[click.MultiCommand]]): + def __init__( + self, + *groups: tuple[str, Sequence[click.MultiCommand]], + plugin_entry_points: dict[str, importlib_metadata.EntryPoint] = {}, + ): self.groups = [ (title, self._merge_same_name_collections(cli_list)) for title, cli_list in groups ] + self.lazy_groups = plugin_entry_points sources = list(chain.from_iterable(cli_list for _, cli_list in self.groups)) - help_texts = [ cli.help for cli_collection in sources @@ -179,6 +190,44 @@ def _merge_same_name_collections( if cli_list ] + def main( + self, + args: Any | None = None, + prog_name: Any | None = None, + complete_var: Any | None = None, + standalone_mode: bool = True, + **extra: Any, + ) -> Any: + # Load plugins if the command is not found in the current sources + if args and args[0] not in self.list_commands(None): # type: ignore[arg-type] + self._load_plugins(args[0]) + + return super().main( + args=args, + prog_name=prog_name, + complete_var=complete_var, + standalone_mode=standalone_mode, + **extra, + ) + + def _load_plugins(self, command_name: str) -> None: + """Load plugins if the command is not found in the current sources.""" + ep_names = list(self.lazy_groups.keys()) + part_match = _partial_match(ep_names, command_name) + if part_match: + # Try to smartly load the plugin if there is partial match + loaded_ep = _safe_load_entry_point(self.lazy_groups[part_match]) + self.add_source(loaded_ep) + if command_name in self.list_commands(None): # type: ignore[arg-type] + return + # Load all plugins + for ep in self.lazy_groups.values(): + if command_name in self.list_commands(None): # type: ignore[arg-type] + return + loaded_ep = _safe_load_entry_point(ep) + self.add_source(loaded_ep) + return + def resolve_command( self, ctx: click.core.Context, args: list ) -> tuple[str | None, click.Command | None, list[str]]: diff --git a/tests/framework/cli/test_cli.py b/tests/framework/cli/test_cli.py index d147a6a0d1..d58ab95af0 100644 --- a/tests/framework/cli/test_cli.py +++ b/tests/framework/cli/test_cli.py @@ -1,5 +1,4 @@ from collections import namedtuple -from itertools import cycle from os import rename from pathlib import Path @@ -319,12 +318,7 @@ def test_project_commands_no_clipy(self, mocker, fake_metadata): mocker.patch( "kedro.framework.cli.cli.bootstrap_project", return_value=fake_metadata ) - mocker.patch( - "kedro.framework.cli.cli.importlib.import_module", - side_effect=cycle([ModuleNotFoundError()]), - ) kedro_cli = KedroCLI(fake_metadata.project_path) - print(kedro_cli.project_groups) assert len(kedro_cli.project_groups) == 6 assert kedro_cli.project_groups == [ catalog_cli, @@ -382,6 +376,8 @@ def test_kedro_cli_no_project(self, mocker, tmp_path): result = CliRunner().invoke(kedro_cli, []) + print(result) + assert result.exit_code == 0 assert "Global commands from Kedro" in result.output assert "Project specific commands from Kedro" not in result.output