From c6bf192d3add27f95ab09ee5f27edf2e16da4cdb Mon Sep 17 00:00:00 2001 From: Brett Date: Thu, 24 Oct 2024 11:17:37 -0400 Subject: [PATCH 1/3] support selecting steps by package name for strun --- src/stpipe/cli/strun.py | 4 ++-- src/stpipe/entry_points.py | 4 ++-- src/stpipe/utilities.py | 13 +++++++++++- tests/test_utilities.py | 42 +++++++++++++++++++++++++++++++++++++- 4 files changed, 57 insertions(+), 6 deletions(-) diff --git a/src/stpipe/cli/strun.py b/src/stpipe/cli/strun.py index 6ba31fbb..7c95bbde 100755 --- a/src/stpipe/cli/strun.py +++ b/src/stpipe/cli/strun.py @@ -1,6 +1,6 @@ import sys -from stpipe import Step +from stpipe import cmdline from stpipe.cli.main import _print_versions from stpipe.exceptions import StpipeExitException @@ -21,7 +21,7 @@ def main(): sys.exit(0) try: - Step.from_cmdline(sys.argv[1:]) + cmdline.step_from_cmdline(sys.argv[1:]) except StpipeExitException as e: sys.exit(e.exit_status) except Exception: diff --git a/src/stpipe/entry_points.py b/src/stpipe/entry_points.py index 133fa6ac..323abfce 100644 --- a/src/stpipe/entry_points.py +++ b/src/stpipe/entry_points.py @@ -1,7 +1,7 @@ import warnings from collections import namedtuple -from importlib_metadata import entry_points +import importlib_metadata STEPS_GROUP = "stpipe.steps" @@ -26,7 +26,7 @@ class alias, and the third is a bool indicating whether the class is to be """ steps = [] - for entry_point in entry_points(group=STEPS_GROUP): + for entry_point in importlib_metadata.entry_points(group=STEPS_GROUP): package_name = entry_point.dist.name package_version = entry_point.dist.version package_steps = [] diff --git a/src/stpipe/utilities.py b/src/stpipe/utilities.py index 1a069c0b..6a6b9551 100644 --- a/src/stpipe/utilities.py +++ b/src/stpipe/utilities.py @@ -19,13 +19,24 @@ def resolve_step_class_alias(name): Parameters ---------- name : str + If name contains "::" only the package with + a name matching the characters before "::" + will be searched for the matching step. Returns ------- str """ + # check if the name contains a package name + if "::" in name: + scope, class_name = name.split("::", maxsplit=1) + else: + scope, class_name = None, name + for info in entry_points.get_steps(): - if info.class_alias is not None and name == info.class_alias: + if scope and info.package_name != scope: + continue + if info.class_alias is not None and class_name == info.class_alias: return info.class_name return name diff --git a/tests/test_utilities.py b/tests/test_utilities.py index f466618b..cc409b65 100644 --- a/tests/test_utilities.py +++ b/tests/test_utilities.py @@ -1,7 +1,7 @@ import pytest from stpipe import Step -from stpipe.utilities import import_class, import_func +from stpipe.utilities import import_class, import_func, resolve_step_class_alias def what_is_your_quest(): @@ -13,6 +13,8 @@ class HovercraftFullOfEels: class Foo(Step): + class_alias = "foo_step" + def process(self, input_data): pass @@ -52,3 +54,41 @@ def test_import_class_no_module(): def test_import_func_no_module(): with pytest.raises(ImportError): import_func("foo") + + +@pytest.mark.parametrize( + "name, resolve", + ( + ("foo_step", True), + ("stpipe::foo_step", True), + ("some_other_package::foo_step", False), + ), +) +def test_class_alias_lookup(name, resolve, monkeypatch): + # as the test class above isn't registered via an entry point + # we mock the entry points here + class FakeDist: + name = "stpipe" + version = "dev" + + class FakeEntryPoint: + dist = FakeDist() + + def load(self): + def loader(): + return [("Foo", "foo_step", False)] + + return loader + + def fake_entrypoints(group=None): + return [FakeEntryPoint()] + + import importlib_metadata + + monkeypatch.setattr(importlib_metadata, "entry_points", fake_entrypoints) + + resolved_name = resolve_step_class_alias(name) + if resolve: + assert resolved_name == Foo.__name__ + else: + assert resolved_name == name From 4cc46da7f8c019b9134b7abfc1ee80d8d3378906 Mon Sep 17 00:00:00 2001 From: Brett Date: Thu, 24 Oct 2024 11:25:48 -0400 Subject: [PATCH 2/3] add changelog fragment --- changes/202.feature.rst | 1 + 1 file changed, 1 insertion(+) create mode 100644 changes/202.feature.rst diff --git a/changes/202.feature.rst b/changes/202.feature.rst new file mode 100644 index 00000000..1967d26c --- /dev/null +++ b/changes/202.feature.rst @@ -0,0 +1 @@ +Allow class aliases (used during strun) to contain the package name (for example "jwst::resample"). From fd68d872627eb69f979787c757a67697313b2a2a Mon Sep 17 00:00:00 2001 From: Brett Date: Thu, 24 Oct 2024 16:41:07 -0400 Subject: [PATCH 3/3] raise ValueError if class_alias lookup returns more than 1 step --- src/stpipe/utilities.py | 20 ++++++++- tests/test_utilities.py | 98 +++++++++++++++++++++++++++++++++-------- 2 files changed, 97 insertions(+), 21 deletions(-) diff --git a/src/stpipe/utilities.py b/src/stpipe/utilities.py index 6a6b9551..5ba436c0 100644 --- a/src/stpipe/utilities.py +++ b/src/stpipe/utilities.py @@ -33,13 +33,29 @@ def resolve_step_class_alias(name): else: scope, class_name = None, name + # track all found steps keyed by package name + found_class_names = {} for info in entry_points.get_steps(): if scope and info.package_name != scope: continue if info.class_alias is not None and class_name == info.class_alias: - return info.class_name + found_class_names[info.package_name] = info - return name + if not found_class_names: + return name + + if len(found_class_names) == 1: + return found_class_names.popitem()[1].class_name + + # class alias resolved to several possible steps + scopes = list(found_class_names.keys()) + msg = ( + f"class alias {name} matched more than 1 step. Please provide " + "the package name along with the step name. One of:\n" + ) + for scope in scopes: + msg += f" {scope}::{name}\n" + raise ValueError(msg) def import_class(full_name, subclassof=object, config_file=None): diff --git a/tests/test_utilities.py b/tests/test_utilities.py index cc409b65..9ca34e64 100644 --- a/tests/test_utilities.py +++ b/tests/test_utilities.py @@ -56,39 +56,99 @@ def test_import_func_no_module(): import_func("foo") -@pytest.mark.parametrize( - "name, resolve", - ( - ("foo_step", True), - ("stpipe::foo_step", True), - ("some_other_package::foo_step", False), - ), -) -def test_class_alias_lookup(name, resolve, monkeypatch): +@pytest.fixture() +def mock_entry_points(monkeypatch, request): # as the test class above isn't registered via an entry point # we mock the entry points here class FakeDist: - name = "stpipe" - version = "dev" + def __init__(self, name): + self.name = name + self.version = "dev" class FakeEntryPoint: - dist = FakeDist() + def __init__(self, dist_name, steps): + self.dist = FakeDist(dist_name) + self.steps = steps def load(self): def loader(): - return [("Foo", "foo_step", False)] + return self.steps return loader def fake_entrypoints(group=None): - return [FakeEntryPoint()] + return [FakeEntryPoint(k, v) for k, v in request.param.items()] import importlib_metadata monkeypatch.setattr(importlib_metadata, "entry_points", fake_entrypoints) + yield + - resolved_name = resolve_step_class_alias(name) - if resolve: - assert resolved_name == Foo.__name__ - else: - assert resolved_name == name +@pytest.mark.parametrize("name", ("foo_step", "stpipe::foo_step")) +@pytest.mark.parametrize( + "mock_entry_points", [{"stpipe": [("Foo", "foo_step", False)]}], indirect=True +) +def test_class_alias_lookup(name, mock_entry_points): + """ + Test that a step name can be resolved if either: + - only a single step is found that matches + - a step is found and a valid package name was provided + """ + assert resolve_step_class_alias(name) == "Foo" + + +@pytest.mark.parametrize("name", ("bar_step", "other_package::foo_step")) +@pytest.mark.parametrize( + "mock_entry_points", [{"stpipe": [("Foo", "foo_step", False)]}], indirect=True +) +def test_class_alias_lookup_fallthrough(name, mock_entry_points): + """ + Test that passing in an unknown class alias or an alias scoped + to a different package falls through to returning the unresolved + class_alias (to match previous behavior). + """ + assert resolve_step_class_alias(name) == name + + +@pytest.mark.parametrize("name", ("aaa::foo_step", "zzz::foo_step")) +@pytest.mark.parametrize( + "mock_entry_points", + [ + { + "aaa": [("Foo", "foo_step", False)], + "zzz": [("Foo", "foo_step", False)], + } + ], + indirect=True, +) +def test_class_alias_lookup_scoped(name, mock_entry_points): + """ + Test the lookup succeeds if more than 1 package + provides a matching step name but the "scope" (package name) + is provided on lookup. + """ + assert resolve_step_class_alias(name) == "Foo" + + +@pytest.mark.parametrize( + "mock_entry_points", + [ + { + "aaa": [("Foo", "foo_step", False)], + "zzz": [("Foo", "foo_step", False)], + } + ], + indirect=True, +) +def test_class_alias_lookup_conflict(mock_entry_points): + """ + Test that an ambiguous lookup (a class alias that resolves + to more than 1 step from different packages) results in + an error. + When the package name is provided, tes + """ + with pytest.raises(ValueError) as err: + resolve_step_class_alias("foo_step") + assert err.match("aaa::foo_step") + assert err.match("zzz::foo_step")