Skip to content

Commit

Permalink
Merge pull request #8208 from RasaHQ/fast-nlu-key-in-yaml
Browse files Browse the repository at this point in the history
use quick `is_key_in_yaml` implementation
  • Loading branch information
wochinge authored Mar 16, 2021
2 parents 6e535a8 + 92f5466 commit bfa97e1
Show file tree
Hide file tree
Showing 7 changed files with 53 additions and 55 deletions.
1 change: 1 addition & 0 deletions changelog/8208.improvement.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Faster reading of YAML NLU training data files.
32 changes: 3 additions & 29 deletions rasa/shared/core/training_data/story_reader/yaml_story_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,35 +160,9 @@ def is_stories_file(cls, file_path: Union[Text, Path]) -> bool:
YamlException: if the file seems to be a YAML file (extension) but
can not be read / parsed.
"""
return rasa.shared.data.is_likely_yaml_file(file_path) and cls.is_key_in_yaml(
file_path, KEY_STORIES, KEY_RULES
)

@classmethod
def is_key_in_yaml(cls, file_path: Union[Text, Path], *keys: Text) -> bool:
    """Check if any line of the yaml file starts with one of the given keys.

    This is a quick, line-based check: the file is not parsed as YAML. Leading
    whitespace is stripped from every line before the comparison, so a `key:`
    at any indentation level counts as a match (not only keys of the root
    object).

    Arguments:
        file_path: path to the yaml file
        keys: keys to look for

    Returns:
        `True` if at least one of the keys is found, `False` otherwise.

    Raises:
        FileNotFoundException: if the file cannot be found.
    """
    # `str.startswith` accepts a tuple of candidates, so build them once
    # instead of re-formatting every key for every line.
    prefixes = tuple(f"{key}:" for key in keys)

    try:
        # Decode explicitly as UTF-8 so the result does not depend on the
        # platform's locale encoding.
        with open(file_path, encoding="utf-8") as file:
            return any(line.lstrip().startswith(prefixes) for line in file)
    except FileNotFoundError:
        raise FileNotFoundException(
            f"Failed to read file, '{os.path.abspath(file_path)}' does not exist."
        )
return rasa.shared.data.is_likely_yaml_file(
file_path
) and rasa.shared.utils.io.is_key_in_yaml(file_path, KEY_STORIES, KEY_RULES)

@classmethod
def _has_test_prefix(cls, file_path: Text) -> bool:
Expand Down
6 changes: 2 additions & 4 deletions rasa/shared/nlu/training_data/formats/rasa_yaml.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,7 +334,7 @@ def _parse_multiline_example(self, item: Text, examples: Text) -> Iterator[Text]
yield example[1:].strip(STRIP_SYMBOLS)

@staticmethod
def is_yaml_nlu_file(filename: Text) -> bool:
def is_yaml_nlu_file(filename: Union[Text, Path]) -> bool:
"""Checks if the specified file possibly contains NLU training data in YAML.
Args:
Expand All @@ -351,9 +351,7 @@ def is_yaml_nlu_file(filename: Text) -> bool:
if not rasa.shared.data.is_likely_yaml_file(filename):
return False

content = rasa.shared.utils.io.read_yaml_file(filename)

return any(key in content for key in {KEY_NLU, KEY_RESPONSES})
return rasa.shared.utils.io.is_key_in_yaml(filename, KEY_NLU, KEY_RESPONSES)


class RasaYAMLWriter(TrainingDataWriter):
Expand Down
25 changes: 25 additions & 0 deletions rasa/shared/utils/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -414,6 +414,31 @@ def write_yaml(
YAML_LINE_MAX_WIDTH = 4096


def is_key_in_yaml(file_path: Union[Text, Path], *keys: Text) -> bool:
    """Checks if any line of the yaml file starts with one of the given keys.

    This is a quick, line-based check: the file is not parsed as YAML. Leading
    whitespace is stripped from every line before the comparison, so a `key:`
    at any indentation level counts as a match (not only keys of the root
    object).

    Arguments:
        file_path: path to the yaml file
        keys: keys to look for

    Returns:
        `True` if at least one of the keys is found, `False` otherwise.

    Raises:
        FileNotFoundException: if the file cannot be found.
    """
    # `str.startswith` accepts a tuple of candidates, so build them once
    # instead of re-formatting every key for every line.
    prefixes = tuple(f"{key}:" for key in keys)

    try:
        # Decode explicitly as UTF-8 so the result does not depend on the
        # platform's locale encoding.
        with open(file_path, encoding="utf-8") as file:
            return any(line.lstrip().startswith(prefixes) for line in file)
    except FileNotFoundError:
        raise FileNotFoundException(
            f"Failed to read file, '{os.path.abspath(file_path)}' does not exist."
        )


def convert_to_ordered_dict(obj: Any) -> Any:
"""Convert object to an `OrderedDict`.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -204,26 +204,6 @@ async def test_is_yaml_file(file: Text, is_yaml_file: bool):
assert YAMLStoryReader.is_stories_file(file) == is_yaml_file


@pytest.mark.parametrize(
    "file,keys,expected_result",
    [
        ("data/test_yaml_stories/stories.yml", ["stories"], True),
        ("data/test_yaml_stories/stories.yml", ["something_else"], False),
        ("data/test_yaml_stories/stories.yml", ["stories", "something_else"], True),
        (
            "data/test_domains/default_retrieval_intents.yml",
            ["intents", "responses"],
            True,
        ),
        ("data/test_yaml_stories/rules_without_stories.yml", ["rules"], True),
        ("data/test_yaml_stories/rules_without_stories.yml", ["stories"], False),
        ("data/test_stories/stories.md", ["something"], False),
    ],
)
def test_is_key_in_yaml(file: Text, keys: List[Text], expected_result: bool):
    """`YAMLStoryReader.is_key_in_yaml` reports whether any key occurs in the file.

    The check is synchronous and nothing is awaited, so this is a regular test
    function and does not need an async test plugin to be executed.
    """
    assert YAMLStoryReader.is_key_in_yaml(file, *keys) == expected_result


async def test_yaml_intent_with_leading_slash_warning(default_domain: Domain):
yaml_file = "data/test_wrong_yaml_stories/intent_with_leading_slash.yml"

Expand Down
4 changes: 2 additions & 2 deletions tests/shared/nlu/training_data/formats/test_rasa_yaml.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,9 +360,9 @@ def test_minimal_valid_example():
assert not len(record)


def test_minimal_yaml_nlu_file(tmp_path):
def test_minimal_yaml_nlu_file(tmp_path: pathlib.Path):
target_file = tmp_path / "test_nlu_file.yaml"
rasa.shared.utils.io.write_yaml(MINIMAL_VALID_EXAMPLE, target_file, True)
rasa.shared.utils.io.write_text_file(MINIMAL_VALID_EXAMPLE, target_file)
assert RasaYAMLReader.is_yaml_nlu_file(target_file)


Expand Down
20 changes: 20 additions & 0 deletions tests/shared/utils/test_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -532,3 +532,23 @@ def test_read_invalid_config_file(tmp_path: Path, content: Text):

with pytest.raises(rasa.shared.utils.validation.YamlValidationException):
rasa.shared.utils.io.read_model_configuration(config_file)


@pytest.mark.parametrize(
    "file,keys,expected_result",
    [
        ("data/test_yaml_stories/stories.yml", ["stories"], True),
        ("data/test_yaml_stories/stories.yml", ["something_else"], False),
        ("data/test_yaml_stories/stories.yml", ["stories", "something_else"], True),
        (
            "data/test_domains/default_retrieval_intents.yml",
            ["intents", "responses"],
            True,
        ),
        ("data/test_yaml_stories/rules_without_stories.yml", ["rules"], True),
        ("data/test_yaml_stories/rules_without_stories.yml", ["stories"], False),
        ("data/test_stories/stories.md", ["something"], False),
    ],
)
def test_is_key_in_yaml(file: Text, keys: List[Text], expected_result: bool):
    """`is_key_in_yaml` reports whether any of the given keys occurs in the file.

    The check is synchronous and nothing is awaited, so this is a regular test
    function and does not need an async test plugin to be executed.
    """
    assert rasa.shared.utils.io.is_key_in_yaml(file, *keys) == expected_result

0 comments on commit bfa97e1

Please sign in to comment.