From a98c38af14763fceeb173627e16e26b31f381be3 Mon Sep 17 00:00:00 2001
From: Nathan Fradet <56734983+Natooz@users.noreply.github.com>
Date: Mon, 7 Aug 2023 15:35:11 +0200
Subject: [PATCH] improving tests / coverage, legacy sos_eos_tokens arg removed from _create_base_vocabulary

---
 miditok/tokenizations/mmm.py          |  4 +-
 miditok/tokenizations/mumidi.py       |  2 +-
 miditok/tokenizations/octuple_mono.py |  2 +-
 miditok/tokenizations/remi.py         |  2 +-
 miditok/tokenizations/remi_plus.py    |  4 +-
 miditok/tokenizations/structured.py   |  7 +--
 miditok/tokenizations/tsd.py          |  2 +-
 tests/test_pytorch_data_loading.py    | 71 ++++++++++++++++++++------
 8 files changed, 61 insertions(+), 33 deletions(-)

diff --git a/miditok/tokenizations/mmm.py b/miditok/tokenizations/mmm.py
index d6d218d2..c6b62811 100644
--- a/miditok/tokenizations/mmm.py
+++ b/miditok/tokenizations/mmm.py
@@ -374,9 +374,7 @@ def tokens_to_midi(
         midi.dump(output_path)
         return midi
 
-    def _create_base_vocabulary(
-        self, sos_eos_tokens: Optional[bool] = None
-    ) -> List[str]:
+    def _create_base_vocabulary(self) -> List[str]:
         r"""Creates the vocabulary, as a list of string tokens.
         Each token as to be given as the form of "Type_Value", separated with an underscore.
         Example: Pitch_58
diff --git a/miditok/tokenizations/mumidi.py b/miditok/tokenizations/mumidi.py
index 85d156ce..7d01356b 100644
--- a/miditok/tokenizations/mumidi.py
+++ b/miditok/tokenizations/mumidi.py
@@ -382,7 +382,7 @@ def tokens_to_track(
         """
         pass
 
-    def _create_base_vocabulary(self, sos_eos_tokens: bool = None) -> List[List[str]]:
+    def _create_base_vocabulary(self) -> List[List[str]]:
         r"""Creates the vocabulary, as a list of string tokens.
         Each token as to be given as the form of "Type_Value", separated with an underscore.
         Example: Pitch_58
diff --git a/miditok/tokenizations/octuple_mono.py b/miditok/tokenizations/octuple_mono.py
index 242d838e..2c74affe 100644
--- a/miditok/tokenizations/octuple_mono.py
+++ b/miditok/tokenizations/octuple_mono.py
@@ -200,7 +200,7 @@ def tokens_to_track(
 
         return instrument, tempo_changes
 
-    def _create_base_vocabulary(self, sos_eos_tokens: bool = None) -> List[List[str]]:
+    def _create_base_vocabulary(self) -> List[List[str]]:
         r"""Creates the vocabulary, as a list of string tokens.
         Each token as to be given as the form of "Type_Value", separated with an underscore.
         Example: Pitch_58
diff --git a/miditok/tokenizations/remi.py b/miditok/tokenizations/remi.py
index 98bdc7ad..3ddd0b2d 100644
--- a/miditok/tokenizations/remi.py
+++ b/miditok/tokenizations/remi.py
@@ -286,7 +286,7 @@ def tokens_to_track(
             tempo_changes[0].time = 0
         return instrument, tempo_changes
 
-    def _create_base_vocabulary(self, sos_eos_tokens: bool = None) -> List[str]:
+    def _create_base_vocabulary(self) -> List[str]:
         r"""Creates the vocabulary, as a list of string tokens.
         Each token as to be given as the form of "Type_Value", separated with an underscore.
         Example: Pitch_58
diff --git a/miditok/tokenizations/remi_plus.py b/miditok/tokenizations/remi_plus.py
index 5dc5cde7..7ebfaf49 100644
--- a/miditok/tokenizations/remi_plus.py
+++ b/miditok/tokenizations/remi_plus.py
@@ -444,9 +444,7 @@ def tokens_to_midi(
         midi.dump(output_path)
         return midi
 
-    def _create_base_vocabulary(
-        self, sos_eos_tokens: Optional[bool] = None
-    ) -> List[str]:
+    def _create_base_vocabulary(self) -> List[str]:
         r"""Creates the vocabulary, as a list of string tokens.
         Each token as to be given as the form of "Type_Value", separated with an underscore.
         Example: Pitch_58
diff --git a/miditok/tokenizations/structured.py b/miditok/tokenizations/structured.py
index 0f2cf09b..0aa5178f 100644
--- a/miditok/tokenizations/structured.py
+++ b/miditok/tokenizations/structured.py
@@ -184,7 +184,7 @@ def tokens_to_track(
 
         return instrument, [TempoChange(TEMPO, 0)]
 
-    def _create_base_vocabulary(self, sos_eos_tokens: bool = None) -> List[str]:
+    def _create_base_vocabulary(self) -> List[str]:
         r"""Creates the vocabulary, as a list of string tokens.
         Each token as to be given as the form of "Type_Value", separated with an underscore.
         Example: Pitch_58
@@ -195,11 +195,6 @@ def _create_base_vocabulary(self, sos_eos_tokens: bool = None) -> List[str]:
 
         :return: the vocabulary as a list of string.
         """
-        if sos_eos_tokens is not None:
-            print(
-                "\033[93msos_eos_tokens argument is depreciated and will be removed in a future update, "
-                "_create_vocabulary now uses self._sos_eos attribute set a class init \033[0m"
-            )
         vocab = []
 
         # PITCH
diff --git a/miditok/tokenizations/tsd.py b/miditok/tokenizations/tsd.py
index eaab2d73..87c6c1fc 100644
--- a/miditok/tokenizations/tsd.py
+++ b/miditok/tokenizations/tsd.py
@@ -390,7 +390,7 @@ def tokens_to_midi(
         midi.dump(output_path)
         return midi
 
-    def _create_base_vocabulary(self, sos_eos_tokens: bool = False) -> List[str]:
+    def _create_base_vocabulary(self) -> List[str]:
         r"""Creates the vocabulary, as a list of string tokens.
         Each token as to be given as the form of "Type_Value", separated with an underscore.
         Example: Pitch_58
diff --git a/tests/test_pytorch_data_loading.py b/tests/test_pytorch_data_loading.py
index f954d671..ac2a996a 100644
--- a/tests/test_pytorch_data_loading.py
+++ b/tests/test_pytorch_data_loading.py
@@ -20,11 +20,15 @@ def test_split_seq():
 
 
 def test_dataset_ram():
-    # One token stream
+    multitrack_midis_paths = list(Path("tests", "Multitrack_MIDIs").glob("**/*.mid"))[:3]
+    one_track_midis_paths = list(Path("tests", "Maestro_MIDIs").glob("**/*.mid"))[:3]
+    tokens_os_dir = Path("tests", "multitrack_tokens_os")
+
+    # MIDI + One token stream
     config = miditok.TokenizerConfig(use_programs=True)
     tokenizer_os = miditok.TSD(config)
     dataset_os = miditok.pytorch_data.DatasetTok(
-        list(Path("tests", "Maestro_MIDIs").glob("**/*.mid")),
+        one_track_midis_paths,
         50,
         100,
         tokenizer_os,
@@ -32,31 +36,50 @@ def test_dataset_ram():
     for _ in dataset_os:
         pass
 
-    # Multiple token streams
+    # MIDI + Multiple token streams
     tokenizer_ms = miditok.TSD(miditok.TokenizerConfig())
     dataset_ms = miditok.pytorch_data.DatasetTok(
-        list(Path("tests", "Multitrack_MIDIs").glob("**/*.mid")),
+        multitrack_midis_paths,
         50,
         100,
         tokenizer_ms,
     )
+    _ = dataset_ms.__repr__()
+    dataset_ms.reduce_nb_samples(2)
+    assert len(dataset_ms) == 2
+
+    # JSON + one token stream
+    if not tokens_os_dir.is_dir():
+        tokenizer_os.tokenize_midi_dataset(
+            multitrack_midis_paths,
+            tokens_os_dir,
+        )
+    _ = miditok.pytorch_data.DatasetTok(
+        list(tokens_os_dir.glob("**/*.json")),
+        50,
+        100,
+    )
 
     assert True
 
 
 def test_dataset_io():
-    midi_paths = list(Path("tests", "Multitrack_MIDIs").glob("**/*.mid"))[:3]
-    tokens_dir = Path("tests", "dataset_io_tokens")
+    multitrack_midis_paths = list(Path("tests", "Multitrack_MIDIs").glob("**/*.mid"))[:3]
+    tokens_os_dir = Path("tests", "multitrack_tokens_os")
 
-    config = miditok.TokenizerConfig(use_programs=True)
-    tokenizer = miditok.TSD(config)
-    tokenizer.tokenize_midi_dataset(midi_paths, tokens_dir)
+    if not tokens_os_dir.is_dir():
+        config = miditok.TokenizerConfig(use_programs=True)
+        tokenizer = miditok.TSD(config)
+        tokenizer.tokenize_midi_dataset(multitrack_midis_paths, tokens_os_dir)
 
     dataset = miditok.pytorch_data.DatasetJsonIO(
-        list(tokens_dir.glob("**/*.json")),
+        list(tokens_os_dir.glob("**/*.json")),
        100,
     )
 
+    dataset.reduce_nb_samples(2)
+    assert len(dataset) == 2
+
     for _ in dataset:
         pass
@@ -64,23 +87,37 @@ def test_split_dataset_to_subsequences():
-    midi_paths = list(Path("tests", "Multitrack_MIDIs").glob("**/*.mid"))[:3]
-    tokens_dir = Path("tests", "dataset_io_tokens")
-    tokens_split_dir = Path("tests", "dataset_io_tokens_split")
+    multitrack_midis_paths = list(Path("tests", "Multitrack_MIDIs").glob("**/*.mid"))[:3]
+    tokens_os_dir = Path("tests", "multitrack_tokens_os")
+    tokens_split_dir = Path("tests", "multitrack_tokens_os_split")
+    tokens_split_dir_ms = Path("tests", "multitrack_tokens_ms_split")
 
-    if not tokens_dir.is_dir():
+    # One token stream
+    if not tokens_os_dir.is_dir():
         config = miditok.TokenizerConfig(use_programs=True)
         tokenizer = miditok.TSD(config)
-        tokenizer.tokenize_midi_dataset(midi_paths, tokens_dir)
-
+        tokenizer.tokenize_midi_dataset(multitrack_midis_paths, tokens_os_dir)
     miditok.pytorch_data.split_dataset_to_subsequences(
-        list(tokens_dir.glob("**/*.json")),
+        list(tokens_os_dir.glob("**/*.json")),
         tokens_split_dir,
         50,
         100,
         True,
     )
 
+    # Multiple token streams
+    if not tokens_split_dir_ms.is_dir():
+        config = miditok.TokenizerConfig(use_programs=False)
+        tokenizer = miditok.TSD(config)
+        tokenizer.tokenize_midi_dataset(multitrack_midis_paths, tokens_split_dir_ms)
+    miditok.pytorch_data.split_dataset_to_subsequences(
+        list(tokens_split_dir_ms.glob("**/*.json")),
+        tokens_split_dir,
+        50,
+        100,
+        False,
+    )
+
     assert True
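
Note for reviewers: a minimal sketch of the flow the updated tests exercise,
in case you want to try the change locally. Every call below appears in the
diff above; the paths and local variable names are illustrative only.

    from pathlib import Path

    import miditok

    # Tokenizer configured as in the one-token-stream tests.
    tokenizer = miditok.TSD(miditok.TokenizerConfig(use_programs=True))

    # After this patch, _create_base_vocabulary() takes no argument: the
    # legacy sos_eos_tokens parameter is gone. (Called here only to show
    # the new signature; it is a protected method.)
    vocab = tokenizer._create_base_vocabulary()

    # Tokenize a few MIDIs to JSON once, then load them back.
    midi_paths = list(Path("tests", "Multitrack_MIDIs").glob("**/*.mid"))[:3]
    tokens_dir = Path("tests", "multitrack_tokens_os")
    if not tokens_dir.is_dir():
        tokenizer.tokenize_midi_dataset(midi_paths, tokens_dir)

    dataset = miditok.pytorch_data.DatasetJsonIO(
        list(tokens_dir.glob("**/*.json")),
        100,  # max sequence length, same value the tests use
    )
    dataset.reduce_nb_samples(2)
    assert len(dataset) == 2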