Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 22 additions & 1 deletion src/class_resolver/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,17 @@
import inspect
import logging
from textwrap import dedent
from typing import Any, Collection, List, Mapping, Optional, Sequence, Type, TypeVar
from typing import (
Any,
Collection,
Iterable,
List,
Mapping,
Optional,
Sequence,
Type,
TypeVar,
)

from .base import BaseResolver
from .utils import (
Expand Down Expand Up @@ -118,6 +128,17 @@ def __init__(
suffix=suffix,
)

def subresolver(self, keys: Iterable[str]) -> "ClassResolver[X]":
"""Create a class resolver that's a subset of this one."""
elements = [self.lookup_str(key) for key in keys]
return self.__class__(
elements,
default=self.default,
synonyms=self.synonyms,
suffix=self.suffix,
base=self.base,
)

def extract_name(self, element: Type[X]) -> str:
"""Get the name for an element."""
return element.__name__
Expand Down
21 changes: 21 additions & 0 deletions src/class_resolver/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,16 @@ def __iter__(self) -> Iterator[X]:
"""Iterate over the registered elements."""
return iter(self.lookup_dict.values())

def subresolver(self, keys: Iterable[str]) -> "BaseResolver[X, Y]":
"""Create a resolver that's a subset of this one."""
elements = [self.lookup_str(key) for key in keys]
return self.__class__(
elements,
default=self.default,
synonyms=self.synonyms,
suffix=self.suffix,
)

@property
def options(self) -> Set[str]:
"""Return the normalized option names."""
Expand Down Expand Up @@ -172,6 +182,17 @@ def register(
def lookup(self, query: Hint[X], default: Optional[X] = None) -> X:
"""Lookup an element."""

def lookup_str(self, query: str) -> X:
"""Lookup an element by name."""
key = self.normalize(query)
if key in self.lookup_dict:
return self.lookup_dict[key]
elif key in self.synonyms:
return self.synonyms[key]
else:
valid_choices = sorted(self.options)
raise KeyError(f"{query} is an invalid. Try one of: {valid_choices}")

def docdata(self, query: Hint[X], *path: str, default: Optional[X] = None):
"""Lookup an element and get its docdata.

Expand Down
9 changes: 1 addition & 8 deletions src/class_resolver/func.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,14 +29,7 @@ def lookup(self, query: Hint[X], default: Optional[X] = None) -> X:
elif callable(query):
return query # type: ignore
elif isinstance(query, str):
key = self.normalize(query)
if key in self.lookup_dict:
return self.lookup_dict[key]
elif key in self.synonyms:
return self.synonyms[key]
else:
valid_choices = sorted(self.options)
raise KeyError(f"{query} is an invalid. Try one of: {valid_choices}")
return self.lookup_str(query)
else:
raise TypeError(f"Invalid function: {type(query)} - {query}")

Expand Down
5 changes: 5 additions & 0 deletions tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -444,3 +444,8 @@ class AAlt3Base(Alt3Base):
with self.assertRaises(TypeError) as e:
resolver.make("a")
self.assertEqual("surprise!", str(e.exception))

def test_subresolver(self):
"""Test getting a sub-resolver."""
subresolver = self.resolver.subresolver(["a", "b", "c"])
self.assertEqual(3, len(subresolver.lookup_dict))
5 changes: 5 additions & 0 deletions tests/test_function_resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,3 +112,8 @@ def test_late_entrypoints(self):
self.assertEqual({"add", "sub", "mul"}, set(resolver.lookup_dict))
self.assertEqual(set(), set(resolver.synonyms))
self.assertNotIn("expected_failure", resolver.lookup_dict)

def test_subresolver(self):
"""Test getting a sub-resolver."""
subresolver = self.resolver.subresolver(["add_two", "add_one"])
self.assertEqual(2, len(subresolver.lookup_dict))