From 896b42c06e4269e7dacb095234d4c7f6156d39d6 Mon Sep 17 00:00:00 2001 From: Charles Tapley Hoyt Date: Sun, 19 Feb 2023 20:59:29 +0100 Subject: [PATCH 1/3] Add subset functionality `BaseResolver.subresolver()` allows for the resolver to be subsetted based on a list of keys. This might be useful for HPO scenarios. --- src/class_resolver/base.py | 24 ++++++++++++++++++++++++ src/class_resolver/func.py | 9 +-------- 2 files changed, 25 insertions(+), 8 deletions(-) diff --git a/src/class_resolver/base.py b/src/class_resolver/base.py index 9d6c915..73c5ddc 100644 --- a/src/class_resolver/base.py +++ b/src/class_resolver/base.py @@ -109,6 +109,19 @@ def __iter__(self) -> Iterator[X]: """Iterate over the registered elements.""" return iter(self.lookup_dict.values()) + def subresolver(self, keys: Iterable[str]) -> "BaseResolver[X, Y]": + """Create a resolver that's a subset of this one.""" + elements = [ + self.lookup_str(key) + for key in keys + ] + return self.__class__( + elements=elements, + default=self.default, + synonyms=self.synonyms, + suffix=self.suffix, + ) + @property def options(self) -> Set[str]: """Return the normalized option names.""" @@ -172,6 +185,17 @@ def register( def lookup(self, query: Hint[X], default: Optional[X] = None) -> X: """Lookup an element.""" + def lookup_str(self, query: str) -> X: + """Lookup an element by name.""" + key = self.normalize(query) + if key in self.lookup_dict: + return self.lookup_dict[key] + elif key in self.synonyms: + return self.synonyms[key] + else: + valid_choices = sorted(self.options) + raise KeyError(f"{query} is an invalid. Try one of: {valid_choices}") + def docdata(self, query: Hint[X], *path: str, default: Optional[X] = None): """Lookup an element and get its docdata. diff --git a/src/class_resolver/func.py b/src/class_resolver/func.py index 92c0993..f89fb13 100644 --- a/src/class_resolver/func.py +++ b/src/class_resolver/func.py @@ -29,14 +29,7 @@ def lookup(self, query: Hint[X], default: Optional[X] = None) -> X: elif callable(query): return query # type: ignore elif isinstance(query, str): - key = self.normalize(query) - if key in self.lookup_dict: - return self.lookup_dict[key] - elif key in self.synonyms: - return self.synonyms[key] - else: - valid_choices = sorted(self.options) - raise KeyError(f"{query} is an invalid. Try one of: {valid_choices}") + return self.lookup_str(query) else: raise TypeError(f"Invalid function: {type(query)} - {query}") From fb50fc5308c01ba4ce1ff1bee43b13756f0f1d3e Mon Sep 17 00:00:00 2001 From: Charles Tapley Hoyt Date: Sun, 19 Feb 2023 21:01:21 +0100 Subject: [PATCH 2/3] Update base.py --- src/class_resolver/base.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/src/class_resolver/base.py b/src/class_resolver/base.py index 73c5ddc..03dfea9 100644 --- a/src/class_resolver/base.py +++ b/src/class_resolver/base.py @@ -111,10 +111,7 @@ def __iter__(self) -> Iterator[X]: def subresolver(self, keys: Iterable[str]) -> "BaseResolver[X, Y]": """Create a resolver that's a subset of this one.""" - elements = [ - self.lookup_str(key) - for key in keys - ] + elements = [self.lookup_str(key) for key in keys] return self.__class__( elements=elements, default=self.default, From a5db6ec6d3a7cb76cc3c9d376e022a30a27d4be4 Mon Sep 17 00:00:00 2001 From: Charles Tapley Hoyt Date: Sun, 19 Feb 2023 21:07:34 +0100 Subject: [PATCH 3/3] Add tests --- src/class_resolver/api.py | 23 ++++++++++++++++++++++- src/class_resolver/base.py | 2 +- tests/test_api.py | 5 +++++ tests/test_function_resolver.py | 5 +++++ 4 files changed, 33 insertions(+), 2 deletions(-) diff --git a/src/class_resolver/api.py b/src/class_resolver/api.py index c73c5e9..79edda1 100644 --- a/src/class_resolver/api.py +++ b/src/class_resolver/api.py @@ -5,7 +5,17 @@ import inspect import logging from textwrap import dedent -from typing import Any, Collection, List, Mapping, Optional, Sequence, Type, TypeVar +from typing import ( + Any, + Collection, + Iterable, + List, + Mapping, + Optional, + Sequence, + Type, + TypeVar, +) from .base import BaseResolver from .utils import ( @@ -118,6 +128,17 @@ def __init__( suffix=suffix, ) + def subresolver(self, keys: Iterable[str]) -> "ClassResolver[X]": + """Create a class resolver that's a subset of this one.""" + elements = [self.lookup_str(key) for key in keys] + return self.__class__( + elements, + default=self.default, + synonyms=self.synonyms, + suffix=self.suffix, + base=self.base, + ) + def extract_name(self, element: Type[X]) -> str: """Get the name for an element.""" return element.__name__ diff --git a/src/class_resolver/base.py b/src/class_resolver/base.py index 03dfea9..07c8c49 100644 --- a/src/class_resolver/base.py +++ b/src/class_resolver/base.py @@ -113,7 +113,7 @@ def subresolver(self, keys: Iterable[str]) -> "BaseResolver[X, Y]": """Create a resolver that's a subset of this one.""" elements = [self.lookup_str(key) for key in keys] return self.__class__( - elements=elements, + elements, default=self.default, synonyms=self.synonyms, suffix=self.suffix, diff --git a/tests/test_api.py b/tests/test_api.py index 3cd7b07..a9f7abc 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -444,3 +444,8 @@ class AAlt3Base(Alt3Base): with self.assertRaises(TypeError) as e: resolver.make("a") self.assertEqual("surprise!", str(e.exception)) + + def test_subresolver(self): + """Test getting a sub-resolver.""" + subresolver = self.resolver.subresolver(["a", "b", "c"]) + self.assertEqual(3, len(subresolver.lookup_dict)) diff --git a/tests/test_function_resolver.py b/tests/test_function_resolver.py index c2b498b..d66401d 100644 --- a/tests/test_function_resolver.py +++ b/tests/test_function_resolver.py @@ -112,3 +112,8 @@ def test_late_entrypoints(self): self.assertEqual({"add", "sub", "mul"}, set(resolver.lookup_dict)) self.assertEqual(set(), set(resolver.synonyms)) self.assertNotIn("expected_failure", resolver.lookup_dict) + + def test_subresolver(self): + """Test getting a sub-resolver.""" + subresolver = self.resolver.subresolver(["add_two", "add_one"]) + self.assertEqual(2, len(subresolver.lookup_dict))