From 9d15c5f53d822223755c3b1a3f1efe580fdbb821 Mon Sep 17 00:00:00 2001 From: epwalsh Date: Wed, 14 Apr 2021 13:28:09 -0700 Subject: [PATCH] improve error message from Registrable class --- CHANGELOG.md | 2 ++ allennlp/common/registrable.py | 30 +++++++++++++++++++++++++++--- tests/common/registrable_test.py | 13 +++++++++++++ 3 files changed, 42 insertions(+), 3 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 0a094a87b4f..150a07b0429 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -27,6 +27,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - If a transformer is not in cache but has override weights, the transformer's pretrained weights are no longer downloaded, that is, only its `config.json` file is downloaded. - `SanityChecksCallback` now raises `SanityCheckError` instead of `AssertionError` when a check fails. - `jsonpickle` removed from dependencies. +- Improved the error message from `Registrable.by_name()` when the name passed does not match any registered subclasses. + The error message will include a suggestion if there is a close match between the name passed and a registered name. ### Fixed diff --git a/allennlp/common/registrable.py b/allennlp/common/registrable.py index 1813f37519d..fad5bc36c4a 100644 --- a/allennlp/common/registrable.py +++ b/allennlp/common/registrable.py @@ -189,10 +189,18 @@ def resolve_class_name( else: # is not a qualified class name + available = cls.list_available() + suggestion = _get_suggestion(name, available) raise ConfigurationError( - f"{name} is not a registered name for {cls.__name__}. " - "You probably need to use the --include-package flag " - "to load your custom code. Alternatively, you can specify your choices " + ( + f"'{name}' is not a registered name for '{cls.__name__}'" + + (". " if not suggestion else f", did you mean '{suggestion}'? ") + ) + + "If your registered class comes from custom code, you'll need to import " + "the corresponding modules. 
If you're using AllenNLP from the command-line, " + "this is done by using the '--include-package' flag, or by specifying your imports " + "in a '.allennlp_plugins' file. " + "Alternatively, you can specify your choices " """using fully-qualified paths, e.g. {"model": "my_module.models.MyModel"} """ "in which case they will be automatically imported correctly." ) @@ -209,3 +217,19 @@ def list_available(cls) -> List[str]: raise ConfigurationError(f"Default implementation {default} is not registered") else: return [default] + [k for k in keys if k != default] + + +def _get_suggestion(name: str, available: List[str]) -> Optional[str]: + # First check for simple mistakes like using '-' instead of '_', or vice-versa. + for ch, repl_ch in (("_", "-"), ("-", "_")): + suggestion = name.replace(ch, repl_ch) + if suggestion in available: + return suggestion + # If we still haven't found a reasonable suggestion, we return the first suggestion + # with an edit distance (with transpositions allowed) of 1 to `name`. + from nltk.metrics.distance import edit_distance + + for suggestion in available: + if edit_distance(name, suggestion, transpositions=True) == 1: + return suggestion + return None diff --git a/tests/common/registrable_test.py b/tests/common/registrable_test.py index ff92b7f3861..972e9f7f2fe 100644 --- a/tests/common/registrable_test.py +++ b/tests/common/registrable_test.py @@ -130,3 +130,16 @@ def test_implicit_include_package(self): "testpackage.reader.TextClassificationJsonReader" ) assert duplicate_reader.__name__ == "TextClassificationJsonReader" + + +@pytest.mark.parametrize( + "name", + [ + "sequence-tagging", # using '-' instead of '_' + "sequence-taggign", # transposition of 'ng' + ], +) +def test_suggestions_when_name_not_found(name): + with pytest.raises(ConfigurationError) as exc: + DatasetReader.by_name(name) + assert "did you mean 'sequence_tagging'?" in str(exc.value)