diff --git a/changelog/8208.improvement.md b/changelog/8208.improvement.md new file mode 100644 index 000000000000..309f9a04d45f --- /dev/null +++ b/changelog/8208.improvement.md @@ -0,0 +1 @@ +Faster reading of YAML NLU training data files. diff --git a/rasa/shared/core/training_data/story_reader/yaml_story_reader.py b/rasa/shared/core/training_data/story_reader/yaml_story_reader.py index e948cfd1638e..6939fba8f754 100644 --- a/rasa/shared/core/training_data/story_reader/yaml_story_reader.py +++ b/rasa/shared/core/training_data/story_reader/yaml_story_reader.py @@ -160,35 +160,9 @@ def is_stories_file(cls, file_path: Union[Text, Path]) -> bool: YamlException: if the file seems to be a YAML file (extension) but can not be read / parsed. """ - return rasa.shared.data.is_likely_yaml_file(file_path) and cls.is_key_in_yaml( - file_path, KEY_STORIES, KEY_RULES - ) - - @classmethod - def is_key_in_yaml(cls, file_path: Union[Text, Path], *keys: Text) -> bool: - """Check if any of the keys is contained in the root object of the yaml file. - - Arguments: - file_path: path to the yaml file - keys: keys to look for - - Returns: - `True` if at least one of the keys is found, `False` otherwise. - - Raises: - FileNotFoundException: if the file cannot be found. - """ - try: - with open(file_path) as file: - return any( - any(line.lstrip().startswith(f"{key}:") for key in keys) - for line in file - ) - except FileNotFoundError: - raise FileNotFoundException( - f"Failed to read file, " - f"'{os.path.abspath(file_path)}' does not exist." - ) + return rasa.shared.data.is_likely_yaml_file( + file_path + ) and rasa.shared.utils.io.is_key_in_yaml(file_path, KEY_STORIES, KEY_RULES) @classmethod def _has_test_prefix(cls, file_path: Text) -> bool: diff --git a/rasa/shared/nlu/training_data/formats/rasa_yaml.py b/rasa/shared/nlu/training_data/formats/rasa_yaml.py index ed54e58a8f23..1ef6b30de36b 100644 --- a/rasa/shared/nlu/training_data/formats/rasa_yaml.py +++ b/rasa/shared/nlu/training_data/formats/rasa_yaml.py @@ -334,7 +334,7 @@ def _parse_multiline_example(self, item: Text, examples: Text) -> Iterator[Text] yield example[1:].strip(STRIP_SYMBOLS) @staticmethod - def is_yaml_nlu_file(filename: Text) -> bool: + def is_yaml_nlu_file(filename: Union[Text, Path]) -> bool: """Checks if the specified file possibly contains NLU training data in YAML. Args: @@ -351,9 +351,7 @@ def is_yaml_nlu_file(filename: Text) -> bool: if not rasa.shared.data.is_likely_yaml_file(filename): return False - content = rasa.shared.utils.io.read_yaml_file(filename) - - return any(key in content for key in {KEY_NLU, KEY_RESPONSES}) + return rasa.shared.utils.io.is_key_in_yaml(filename, KEY_NLU, KEY_RESPONSES) class RasaYAMLWriter(TrainingDataWriter): diff --git a/rasa/shared/utils/io.py b/rasa/shared/utils/io.py index 91e1c3ce402d..b5de0b3a315f 100644 --- a/rasa/shared/utils/io.py +++ b/rasa/shared/utils/io.py @@ -414,6 +414,31 @@ def write_yaml( YAML_LINE_MAX_WIDTH = 4096 +def is_key_in_yaml(file_path: Union[Text, Path], *keys: Text) -> bool: + """Checks if any of the keys is contained in the root object of the yaml file. + + Arguments: + file_path: path to the yaml file + keys: keys to look for + + Returns: + `True` if at least one of the keys is found, `False` otherwise. + + Raises: + FileNotFoundException: if the file cannot be found. + """ + try: + with open(file_path) as file: + return any( + any(line.lstrip().startswith(f"{key}:") for key in keys) + for line in file + ) + except FileNotFoundError: + raise FileNotFoundException( + f"Failed to read file, " f"'{os.path.abspath(file_path)}' does not exist." + ) + + def convert_to_ordered_dict(obj: Any) -> Any: """Convert object to an `OrderedDict`. diff --git a/tests/shared/core/training_data/story_reader/test_yaml_story_reader.py b/tests/shared/core/training_data/story_reader/test_yaml_story_reader.py index ff87ecd26859..04a63630a452 100644 --- a/tests/shared/core/training_data/story_reader/test_yaml_story_reader.py +++ b/tests/shared/core/training_data/story_reader/test_yaml_story_reader.py @@ -204,26 +204,6 @@ async def test_is_yaml_file(file: Text, is_yaml_file: bool): assert YAMLStoryReader.is_stories_file(file) == is_yaml_file -@pytest.mark.parametrize( - "file,keys,expected_result", - [ - ("data/test_yaml_stories/stories.yml", ["stories"], True), - ("data/test_yaml_stories/stories.yml", ["something_else"], False), - ("data/test_yaml_stories/stories.yml", ["stories", "something_else"], True), - ( - "data/test_domains/default_retrieval_intents.yml", - ["intents", "responses"], - True, - ), - ("data/test_yaml_stories/rules_without_stories.yml", ["rules"], True), - ("data/test_yaml_stories/rules_without_stories.yml", ["stories"], False), - ("data/test_stories/stories.md", ["something"], False), - ], -) -async def test_is_key_in_yaml(file: Text, keys: List[Text], expected_result: bool): - assert YAMLStoryReader.is_key_in_yaml(file, *keys) == expected_result - - async def test_yaml_intent_with_leading_slash_warning(default_domain: Domain): yaml_file = "data/test_wrong_yaml_stories/intent_with_leading_slash.yml" diff --git a/tests/shared/nlu/training_data/formats/test_rasa_yaml.py b/tests/shared/nlu/training_data/formats/test_rasa_yaml.py index c18f1235ed40..fa922ed6cd9a 100644 --- a/tests/shared/nlu/training_data/formats/test_rasa_yaml.py +++ b/tests/shared/nlu/training_data/formats/test_rasa_yaml.py @@ -360,9 +360,9 @@ def test_minimal_valid_example(): assert not len(record) -def test_minimal_yaml_nlu_file(tmp_path): +def test_minimal_yaml_nlu_file(tmp_path: pathlib.Path): target_file = tmp_path / "test_nlu_file.yaml" - rasa.shared.utils.io.write_yaml(MINIMAL_VALID_EXAMPLE, target_file, True) + rasa.shared.utils.io.write_text_file(MINIMAL_VALID_EXAMPLE, target_file) assert RasaYAMLReader.is_yaml_nlu_file(target_file) diff --git a/tests/shared/utils/test_io.py b/tests/shared/utils/test_io.py index 600861230654..28c80031893a 100644 --- a/tests/shared/utils/test_io.py +++ b/tests/shared/utils/test_io.py @@ -532,3 +532,23 @@ def test_read_invalid_config_file(tmp_path: Path, content: Text): with pytest.raises(rasa.shared.utils.validation.YamlValidationException): rasa.shared.utils.io.read_model_configuration(config_file) + + +@pytest.mark.parametrize( + "file,keys,expected_result", + [ + ("data/test_yaml_stories/stories.yml", ["stories"], True), + ("data/test_yaml_stories/stories.yml", ["something_else"], False), + ("data/test_yaml_stories/stories.yml", ["stories", "something_else"], True), + ( + "data/test_domains/default_retrieval_intents.yml", + ["intents", "responses"], + True, + ), + ("data/test_yaml_stories/rules_without_stories.yml", ["rules"], True), + ("data/test_yaml_stories/rules_without_stories.yml", ["stories"], False), + ("data/test_stories/stories.md", ["something"], False), + ], +) +async def test_is_key_in_yaml(file: Text, keys: List[Text], expected_result: bool): + assert rasa.shared.utils.io.is_key_in_yaml(file, *keys) == expected_result