Commit a98c38a

improving tests / coverage, legacy sos_eos_tokens arg removed from _create_base_vocabulary

Natooz committed Aug 7, 2023
1 parent df96c76 commit a98c38a
Showing 8 changed files with 61 additions and 33 deletions.
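
The signature change is identical across the seven tokenizations below. A minimal before/after sketch (the MyTok class is hypothetical, for illustration only; per the warning removed in structured.py further down, SOS/EOS handling now relies on an attribute set at init rather than a per-call argument):

from typing import List

class MyTok:  # hypothetical subclass, not part of miditok
    def __init__(self, sos_eos: bool = False):
        self._sos_eos = sos_eos  # special-token behaviour is fixed at init

    # before: def _create_base_vocabulary(self, sos_eos_tokens: bool = None) -> List[str]:
    def _create_base_vocabulary(self) -> List[str]:
        # tokens follow the "Type_Value" form, e.g. Pitch_58
        return [f"Pitch_{i}" for i in range(21, 109)]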
4 changes: 1 addition & 3 deletions miditok/tokenizations/mmm.py
@@ -374,9 +374,7 @@ def tokens_to_midi(
         midi.dump(output_path)
         return midi
 
-    def _create_base_vocabulary(
-        self, sos_eos_tokens: Optional[bool] = None
-    ) -> List[str]:
+    def _create_base_vocabulary(self) -> List[str]:
         r"""Creates the vocabulary, as a list of string tokens.
         Each token as to be given as the form of "Type_Value", separated with an underscore.
         Example: Pitch_58
2 changes: 1 addition & 1 deletion miditok/tokenizations/mumidi.py
@@ -382,7 +382,7 @@ def tokens_to_track(
         """
         pass
 
-    def _create_base_vocabulary(self, sos_eos_tokens: bool = None) -> List[List[str]]:
+    def _create_base_vocabulary(self) -> List[List[str]]:
         r"""Creates the vocabulary, as a list of string tokens.
         Each token as to be given as the form of "Type_Value", separated with an underscore.
         Example: Pitch_58
2 changes: 1 addition & 1 deletion miditok/tokenizations/octuple_mono.py
@@ -200,7 +200,7 @@ def tokens_to_track(
 
         return instrument, tempo_changes
 
-    def _create_base_vocabulary(self, sos_eos_tokens: bool = None) -> List[List[str]]:
+    def _create_base_vocabulary(self) -> List[List[str]]:
         r"""Creates the vocabulary, as a list of string tokens.
         Each token as to be given as the form of "Type_Value", separated with an underscore.
         Example: Pitch_58
2 changes: 1 addition & 1 deletion miditok/tokenizations/remi.py
@@ -286,7 +286,7 @@ def tokens_to_track(
         tempo_changes[0].time = 0
         return instrument, tempo_changes
 
-    def _create_base_vocabulary(self, sos_eos_tokens: bool = None) -> List[str]:
+    def _create_base_vocabulary(self) -> List[str]:
         r"""Creates the vocabulary, as a list of string tokens.
         Each token as to be given as the form of "Type_Value", separated with an underscore.
         Example: Pitch_58
4 changes: 1 addition & 3 deletions miditok/tokenizations/remi_plus.py
@@ -444,9 +444,7 @@ def tokens_to_midi(
         midi.dump(output_path)
         return midi
 
-    def _create_base_vocabulary(
-        self, sos_eos_tokens: Optional[bool] = None
-    ) -> List[str]:
+    def _create_base_vocabulary(self) -> List[str]:
         r"""Creates the vocabulary, as a list of string tokens.
         Each token as to be given as the form of "Type_Value", separated with an underscore.
         Example: Pitch_58
7 changes: 1 addition & 6 deletions miditok/tokenizations/structured.py
@@ -184,7 +184,7 @@ def tokens_to_track(
 
         return instrument, [TempoChange(TEMPO, 0)]
 
-    def _create_base_vocabulary(self, sos_eos_tokens: bool = None) -> List[str]:
+    def _create_base_vocabulary(self) -> List[str]:
         r"""Creates the vocabulary, as a list of string tokens.
         Each token as to be given as the form of "Type_Value", separated with an underscore.
         Example: Pitch_58
@@ -195,11 +195,6 @@ def _create_base_vocabulary(self, sos_eos_tokens: bool = None) -> List[str]:
         :return: the vocabulary as a list of string.
         """
-        if sos_eos_tokens is not None:
-            print(
-                "\033[93msos_eos_tokens argument is depreciated and will be removed in a future update, "
-                "_create_vocabulary now uses self._sos_eos attribute set a class init \033[0m"
-            )
         vocab = []
 
         # PITCH
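
With the shim gone, passing sos_eos_tokens now raises a TypeError instead of printing the warning above; special tokens are declared once when building the tokenizer. A hedged sketch of the init-time configuration (the special_tokens parameter of TokenizerConfig is an assumption to verify against the installed version's docs):

import miditok

# assumed API: special tokens declared in the config at init, not per call
config = miditok.TokenizerConfig(special_tokens=["PAD", "BOS", "EOS"])
tokenizer = miditok.Structured(config)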
2 changes: 1 addition & 1 deletion miditok/tokenizations/tsd.py
@@ -390,7 +390,7 @@ def tokens_to_midi(
         midi.dump(output_path)
         return midi
 
-    def _create_base_vocabulary(self, sos_eos_tokens: bool = False) -> List[str]:
+    def _create_base_vocabulary(self) -> List[str]:
         r"""Creates the vocabulary, as a list of string tokens.
         Each token as to be given as the form of "Type_Value", separated with an underscore.
         Example: Pitch_58
71 changes: 54 additions & 17 deletions tests/test_pytorch_data_loading.py
@@ -20,67 +20,104 @@ def test_split_seq():
 
 
 def test_dataset_ram():
-    # One token stream
+    multitrack_midis_paths = list(Path("tests", "Multitrack_MIDIs").glob("**/*.mid"))[:3]
+    one_track_midis_paths = list(Path("tests", "Maestro_MIDIs").glob("**/*.mid"))[:3]
+    tokens_os_dir = Path("tests", "multitrack_tokens_os")
+
+    # MIDI + One token stream
     config = miditok.TokenizerConfig(use_programs=True)
     tokenizer_os = miditok.TSD(config)
     dataset_os = miditok.pytorch_data.DatasetTok(
-        list(Path("tests", "Maestro_MIDIs").glob("**/*.mid")),
+        one_track_midis_paths,
         50,
         100,
         tokenizer_os,
     )
     for _ in dataset_os:
         pass
 
-    # Multiple token streams
+    # MIDI + Multiple token streams
    tokenizer_ms = miditok.TSD(miditok.TokenizerConfig())
     dataset_ms = miditok.pytorch_data.DatasetTok(
-        list(Path("tests", "Multitrack_MIDIs").glob("**/*.mid")),
+        multitrack_midis_paths,
         50,
         100,
         tokenizer_ms,
     )
     _ = dataset_ms.__repr__()
     dataset_ms.reduce_nb_samples(2)
     assert len(dataset_ms) == 2
 
+    # JSON + one token stream
+    if not tokens_os_dir.is_dir():
+        tokenizer_os.tokenize_midi_dataset(
+            multitrack_midis_paths,
+            tokens_os_dir,
+        )
+    _ = miditok.pytorch_data.DatasetTok(
+        list(tokens_os_dir.glob("**/*.json")),
+        50,
+        100,
+    )
+
     assert True
 
 
 def test_dataset_io():
-    midi_paths = list(Path("tests", "Multitrack_MIDIs").glob("**/*.mid"))[:3]
-    tokens_dir = Path("tests", "dataset_io_tokens")
+    multitrack_midis_paths = list(Path("tests", "Multitrack_MIDIs").glob("**/*.mid"))[:3]
+    tokens_os_dir = Path("tests", "multitrack_tokens_os")
 
-    config = miditok.TokenizerConfig(use_programs=True)
-    tokenizer = miditok.TSD(config)
-    tokenizer.tokenize_midi_dataset(midi_paths, tokens_dir)
+    if not tokens_os_dir.is_dir():
+        config = miditok.TokenizerConfig(use_programs=True)
+        tokenizer = miditok.TSD(config)
+        tokenizer.tokenize_midi_dataset(multitrack_midis_paths, tokens_os_dir)
 
     dataset = miditok.pytorch_data.DatasetJsonIO(
-        list(tokens_dir.glob("**/*.json")),
+        list(tokens_os_dir.glob("**/*.json")),
         100,
     )
 
     dataset.reduce_nb_samples(2)
     assert len(dataset) == 2
 
     for _ in dataset:
         pass
 
     assert True
 
 
 def test_split_dataset_to_subsequences():
-    midi_paths = list(Path("tests", "Multitrack_MIDIs").glob("**/*.mid"))[:3]
-    tokens_dir = Path("tests", "dataset_io_tokens")
-    tokens_split_dir = Path("tests", "dataset_io_tokens_split")
+    multitrack_midis_paths = list(Path("tests", "Multitrack_MIDIs").glob("**/*.mid"))[:3]
+    tokens_os_dir = Path("tests", "multitrack_tokens_os")
+    tokens_split_dir = Path("tests", "multitrack_tokens_os_split")
+    tokens_split_dir_ms = Path("tests", "multitrack_tokens_ms_split")
 
-    if not tokens_dir.is_dir():
+    # One token stream
+    if not tokens_os_dir.is_dir():
         config = miditok.TokenizerConfig(use_programs=True)
         tokenizer = miditok.TSD(config)
-        tokenizer.tokenize_midi_dataset(midi_paths, tokens_dir)
-
+        tokenizer.tokenize_midi_dataset(multitrack_midis_paths, tokens_os_dir)
     miditok.pytorch_data.split_dataset_to_subsequences(
-        list(tokens_dir.glob("**/*.json")),
+        list(tokens_os_dir.glob("**/*.json")),
         tokens_split_dir,
         50,
         100,
         True,
     )
 
+    # Multiple token streams
+    if not tokens_split_dir_ms.is_dir():
+        config = miditok.TokenizerConfig(use_programs=False)
+        tokenizer = miditok.TSD(config)
+        tokenizer.tokenize_midi_dataset(multitrack_midis_paths, tokens_split_dir_ms)
+    miditok.pytorch_data.split_dataset_to_subsequences(
+        list(tokens_split_dir_ms.glob("**/*.json")),
+        tokens_split_dir,
+        50,
+        100,
+        False,
+    )
+
     assert True
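
For context on what these tests exercise: DatasetTok appears to build all samples up front, either from MIDI files or from previously saved JSON token files, while DatasetJsonIO reads the JSON files on access; the 50/100 arguments are the minimum and maximum sequence lengths, and the trailing bool passed to split_dataset_to_subsequences appears to flag whether the source files hold one token stream (use_programs=True) or one stream per track. A hedged sketch of feeding such a dataset to a PyTorch DataLoader (the DataCollator arguments are assumptions to verify against the installed miditok version):

from pathlib import Path
from torch.utils.data import DataLoader
import miditok

tokenizer = miditok.TSD(miditok.TokenizerConfig(use_programs=True))
dataset = miditok.pytorch_data.DatasetTok(
    list(Path("tests", "Maestro_MIDIs").glob("**/*.mid")),
    50,   # min sequence length, as in the tests above
    100,  # max sequence length
    tokenizer,
)
collator = miditok.pytorch_data.DataCollator(
    tokenizer["PAD_None"],  # assumption: special-token ids looked up by string
    tokenizer["BOS_None"],
    tokenizer["EOS_None"],
)
loader = DataLoader(dataset, batch_size=4, collate_fn=collator)
for batch in loader:
    pass  # batch of padded sequences (shape depends on the collator settings)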
