Skip to content
This repository has been archived by the owner on Dec 16, 2022. It is now read-only.

improve error message from Registrable class #5125

Merged
merged 2 commits into from
Apr 14, 2021
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- If a transformer is not in cache but has override weights, the transformer's pretrained weights are no longer downloaded, that is, only its `config.json` file is downloaded.
- `SanityChecksCallback` now raises `SanityCheckError` instead of `AssertionError` when a check fails.
- `jsonpickle` removed from dependencies.
- Improved the error message from `Registrable.by_name()` when the name passed does not match any registered subclasses.
The error message will include a suggestion if there is a close match between the name passed and a registered name.

### Fixed

Expand Down
30 changes: 27 additions & 3 deletions allennlp/common/registrable.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,10 +189,18 @@ def resolve_class_name(

else:
# is not a qualified class name
available = cls.list_available()
suggestion = _get_suggestion(name, available)
raise ConfigurationError(
f"{name} is not a registered name for {cls.__name__}. "
"You probably need to use the --include-package flag "
"to load your custom code. Alternatively, you can specify your choices "
(
f"'{name}' is not a registered name for '{cls.__name__}'"
+ (". " if not suggestion else f", did you mean '{suggestion}'? ")
)
+ "If your registered class comes from custom code, you'll need to import "
"the corresponding modules. If you're using AllenNLP from the command-line, "
"this is done by using the '--include-package' flag, or by specifying your imports "
"in a '.allennlp_plugins' file. "
"Alternatively, you can specify your choices "
"""using fully-qualified paths, e.g. {"model": "my_module.models.MyModel"} """
"in which case they will be automatically imported correctly."
)
Expand All @@ -209,3 +217,19 @@ def list_available(cls) -> List[str]:
raise ConfigurationError(f"Default implementation {default} is not registered")
else:
return [default] + [k for k in keys if k != default]


def _get_suggestion(name: str, available: List[str]) -> Optional[str]:
# First check for simple mistakes like using '-' instead of '_', or vice-versa.
for ch, repl_ch in (("_", "-"), ("-", "_")):
suggestion = name.replace(ch, repl_ch)
if suggestion in available:
return suggestion
# If we still haven't found a reasonable suggestion, we return the first suggestion
# with an edit distance (with transpositions allowed) of 1 to `name`.
from nltk.metrics.distance import edit_distance

for suggestion in available:
if edit_distance(name, suggestion, transpositions=True) == 1:
return suggestion
return None
13 changes: 13 additions & 0 deletions tests/common/registrable_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,3 +130,16 @@ def test_implicit_include_package(self):
"testpackage.reader.TextClassificationJsonReader"
)
assert duplicate_reader.__name__ == "TextClassificationJsonReader"


@pytest.mark.parametrize(
    "name",
    [
        "sequence-tagging",  # '-' used where '_' was intended
        "sequence-taggign",  # the letters 'n' and 'g' are swapped
    ],
)
def test_suggestions_when_name_not_found(name):
    """Check that `DatasetReader.by_name` raises a `ConfigurationError` whose
    message contains the "did you mean 'sequence_tagging'?" hint for each
    near-miss `name`.
    """
    expected_hint = "did you mean 'sequence_tagging'?"
    with pytest.raises(ConfigurationError) as raised:
        DatasetReader.by_name(name)
    message = str(raised.value)
    assert expected_hint in message