From 89c4678673925d35d4c3ee18152a78349212d3d1 Mon Sep 17 00:00:00 2001 From: Nathan Fradet <56734983+Natooz@users.noreply.github.com> Date: Tue, 28 Nov 2023 09:01:23 +0100 Subject: [PATCH] Better tests + minor improvements (#108) * parametrizing tests, improvements in preprocess_midi, fixes for miditoolkit 1.0.1 * fixing absolute path for data aug and io tests * fix test file tree, data aug report saved in out_path * fix in data augmentation saving paths * forced disabling original in out_dir when calling data aug from tok_dataset * using pytest tmp_path to write files, and TEST_LOG_DIR if required * lighter and more elegant MIDI assertions + covering check_midi_equals * better tokenization test sets, set_midi_max_tick method, renamed "nb" contractions to "num", handling empty tokens lists in methods * dealing with empty midi file (#110) * dealing with empty midi file * add a new test midi tokenizer file instead of changing the original one * delete test_midi_tokenizer * Adding check empty input for _ids_to_tokens as well --------- Co-authored-by: Nathan Fradet <56734983+Natooz@users.noreply.github.com> * adding tests for empty MIDI and associated fixes * fixes from tests with empty midi + retry hf hub tests when http errors * fix convert_sequence_to_tokseq when list in last dim is empty * better tok test sets * testing with multiple time resolutions, adjusting notes ends * fix _quantize_time_signatures (delete_equal_successive_time_sig_changes) --------- Co-authored-by: feiyuehchen <46064584+feiyuehchen@users.noreply.github.com> --- .../{python-publish.yml => publish-pypi.yml} | 6 +- README.md | 2 +- miditok/classes.py | 40 +- miditok/constants.py | 4 +- miditok/midi_tokenizer.py | 66 +++- miditok/tokenizations/cp_word.py | 20 +- miditok/tokenizations/midi_like.py | 18 +- miditok/tokenizations/mmm.py | 13 +- miditok/tokenizations/mumidi.py | 10 +- miditok/tokenizations/octuple.py | 16 +- miditok/tokenizations/remi.py | 12 +- miditok/tokenizations/structured.py | 14 +- miditok/tokenizations/tsd.py | 14 +- miditok/utils/__init__.py | 2 + miditok/utils/utils.py | 44 ++- pyproject.toml | 3 +- .../Aicha.mid | Bin .../All The Small Things.mid | Bin .../Funkytown.mid | Bin .../Girls Just Want to Have Fun.mid | Bin .../I Gotta Feeling.mid | Bin .../In Too Deep.mid | Bin .../Les Yeux Revolvers.mid | Bin .../Mr. Blue Sky.mid | Bin .../Shut Up.mid | Bin .../What a Fool Believes.mid | Bin .../6338816_Etude No. 
4.mid | Bin .../6354774_Macabre Waltz.mid | Bin .../Maestro_1.mid | Bin .../Maestro_10.mid | Bin .../Maestro_2.mid | Bin .../Maestro_3.mid | Bin .../Maestro_4.mid | Bin .../Maestro_5.mid | Bin .../Maestro_6.mid | Bin .../Maestro_7.mid | Bin .../Maestro_8.mid | Bin .../Maestro_9.mid | Bin .../POP909_008.mid | Bin .../POP909_010.mid | Bin .../POP909_022.mid | Bin .../POP909_191.mid | Bin tests/MIDIs_one_track/empty.mid | Bin 0 -> 76 bytes tests/conftest.py | 5 + tests/test_bpe.py | 249 ++++++------ tests/test_data_augmentation.py | 251 ++++++++++++ tests/test_hf_hub.py | 25 +- tests/test_io_formats.py | 193 ++++------ tests/test_methods.py | 277 +------------- tests/test_multitrack.py | 163 -------- tests/test_one_track.py | 167 -------- tests/test_pytorch_data_loading.py | 75 ++-- tests/test_results/.gitignore | 4 - tests/test_saving_loading_config.py | 62 ++- tests/test_tokenize_multitrack.py | 117 ++++++ tests/test_tokenize_one_track.py | 113 ++++++ tests/test_utils.py | 180 +++++++-- tests/tests_utils.py | 300 --------------- tests/utils.py | 360 ++++++++++++++++++ 59 files changed, 1462 insertions(+), 1363 deletions(-) rename .github/workflows/{python-publish.yml => publish-pypi.yml} (91%) rename tests/{Multitrack_MIDIs => MIDIs_multitrack}/Aicha.mid (100%) rename tests/{Multitrack_MIDIs => MIDIs_multitrack}/All The Small Things.mid (100%) rename tests/{Multitrack_MIDIs => MIDIs_multitrack}/Funkytown.mid (100%) rename tests/{Multitrack_MIDIs => MIDIs_multitrack}/Girls Just Want to Have Fun.mid (100%) rename tests/{Multitrack_MIDIs => MIDIs_multitrack}/I Gotta Feeling.mid (100%) rename tests/{Multitrack_MIDIs => MIDIs_multitrack}/In Too Deep.mid (100%) rename tests/{Multitrack_MIDIs => MIDIs_multitrack}/Les Yeux Revolvers.mid (100%) rename tests/{Multitrack_MIDIs => MIDIs_multitrack}/Mr. Blue Sky.mid (100%) rename tests/{Multitrack_MIDIs => MIDIs_multitrack}/Shut Up.mid (100%) rename tests/{Multitrack_MIDIs => MIDIs_multitrack}/What a Fool Believes.mid (100%) rename tests/{One_track_MIDIs => MIDIs_one_track}/6338816_Etude No. 
4.mid (100%) rename tests/{One_track_MIDIs => MIDIs_one_track}/6354774_Macabre Waltz.mid (100%) rename tests/{One_track_MIDIs => MIDIs_one_track}/Maestro_1.mid (100%) rename tests/{One_track_MIDIs => MIDIs_one_track}/Maestro_10.mid (100%) rename tests/{One_track_MIDIs => MIDIs_one_track}/Maestro_2.mid (100%) rename tests/{One_track_MIDIs => MIDIs_one_track}/Maestro_3.mid (100%) rename tests/{One_track_MIDIs => MIDIs_one_track}/Maestro_4.mid (100%) rename tests/{One_track_MIDIs => MIDIs_one_track}/Maestro_5.mid (100%) rename tests/{One_track_MIDIs => MIDIs_one_track}/Maestro_6.mid (100%) rename tests/{One_track_MIDIs => MIDIs_one_track}/Maestro_7.mid (100%) rename tests/{One_track_MIDIs => MIDIs_one_track}/Maestro_8.mid (100%) rename tests/{One_track_MIDIs => MIDIs_one_track}/Maestro_9.mid (100%) rename tests/{One_track_MIDIs => MIDIs_one_track}/POP909_008.mid (100%) rename tests/{One_track_MIDIs => MIDIs_one_track}/POP909_010.mid (100%) rename tests/{One_track_MIDIs => MIDIs_one_track}/POP909_022.mid (100%) rename tests/{One_track_MIDIs => MIDIs_one_track}/POP909_191.mid (100%) create mode 100644 tests/MIDIs_one_track/empty.mid create mode 100644 tests/test_data_augmentation.py delete mode 100644 tests/test_multitrack.py delete mode 100644 tests/test_one_track.py delete mode 100644 tests/test_results/.gitignore create mode 100644 tests/test_tokenize_multitrack.py create mode 100644 tests/test_tokenize_one_track.py delete mode 100644 tests/tests_utils.py create mode 100644 tests/utils.py diff --git a/.github/workflows/python-publish.yml b/.github/workflows/publish-pypi.yml similarity index 91% rename from .github/workflows/python-publish.yml rename to .github/workflows/publish-pypi.yml index 13291ca8..61c244dd 100644 --- a/.github/workflows/python-publish.yml +++ b/.github/workflows/publish-pypi.yml @@ -6,7 +6,7 @@ # separate terms of service, privacy policy, and support # documentation. -name: Upload Python Package +name: Publish package on PyPi on: release: @@ -21,9 +21,9 @@ jobs: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Set up Python - uses: actions/setup-python@v3 + uses: actions/setup-python@v4 with: python-version: '3.x' - name: Install dependencies diff --git a/README.md b/README.md index e0a65e78..a9a171d5 100644 --- a/README.md +++ b/README.md @@ -35,7 +35,7 @@ from miditoolkit import MidiFile from pathlib import Path # Creating a multitrack tokenizer configuration, read the doc to explore other parameters -config = TokenizerConfig(nb_velocities=16, use_chords=True, use_programs=True) +config = TokenizerConfig(num_velocities=16, use_chords=True, use_programs=True) tokenizer = REMI(config) # Loads a midi, converts to tokens, and back to a MIDI diff --git a/miditok/classes.py b/miditok/classes.py index a09b806b..4ee234cf 100644 --- a/miditok/classes.py +++ b/miditok/classes.py @@ -2,6 +2,7 @@ Common classes. """ import json +import warnings from copy import deepcopy from dataclasses import dataclass from pathlib import Path @@ -20,8 +21,8 @@ DELETE_EQUAL_SUCCESSIVE_TIME_SIG_CHANGES, LOG_TEMPOS, MAX_PITCH_INTERVAL, - NB_TEMPOS, - NB_VELOCITIES, + NUM_TEMPOS, + NUM_VELOCITIES, ONE_TOKEN_STREAM_FOR_PROGRAMS, PITCH_BEND_RANGE, PITCH_INTERVALS_MAX_TIME_DIST, @@ -160,9 +161,9 @@ class TokenizerConfig: lengths / resolutions. Note: for tokenization with ``Position`` tokens, the total number of possible positions will be set at four times the maximum resolution given (``max(beat_res.values)``\). 
(default: ``{(0, 4): 8, (4, 12): 4}``)
-    :param nb_velocities: number of velocity bins. In the MIDI norm, velocities can take
+    :param num_velocities: number of velocity bins. In the MIDI norm, velocities can take
        up to 128 values (0 to 127). This parameter allows reducing the number of velocity values.
-        The velocities of the MIDIs resolution will be downsampled to ``nb_velocities`` values, equally
+        The velocities of the MIDI files will be downsampled to ``num_velocities`` values, equally
        separated between 0 and 127. (default: ``32``)
    :param special_tokens: list of special tokens. This must be given as a list of strings containing only the names
        of the tokens. (default: ``["PAD", "BOS", "EOS", "MASK"]``\)
@@ -178,7 +179,7 @@
        values to represent with the ``beat_res_rest`` argument. (default: ``False``)
    :param use_tempos: will use ``Tempo`` tokens, if the tokenizer is compatible. ``Tempo`` tokens will
        specify the current tempo. This allows training a model to predict tempo changes.
-        Tempo values are quantized accordingly to the ``nb_tempos`` and ``tempo_range`` entries in the
+        Tempo values are quantized according to the ``num_tempos`` and ``tempo_range`` entries in the
        ``additional_tokens`` dictionary (default is 32 tempos from 40 to 250). (default: ``False``)
    :param use_time_signatures: will use ``TimeSignature`` tokens, if the tokenizer is compatible.
        ``TimeSignature`` tokens will specify the current time signature. Note that :ref:`REMI` adds a
@@ -215,7 +216,7 @@
    :param chord_unknown: range of number of notes to represent unknown chords. If you want to represent
        chords that do not match any combination in ``chord_maps``, use this argument. Leave ``None`` to not
        represent unknown chords. (default: ``None``)
-    :param nb_tempos: number of tempos "bins" to use. (default: ``32``)
+    :param num_tempos: number of tempos "bins" to use. (default: ``32``)
    :param tempo_range: range of minimum and maximum tempos within which the bins fall. (default: ``(40, 250)``)
    :param log_tempos: will use log scaled tempo values instead of linearly scaled. (default: ``False``)
    :param delete_equal_successive_tempo_changes: setting this option True will delete identical successive tempo
@@ -234,7 +235,7 @@
        durations. If you use this parameter, make sure to configure ``beat_res`` to cover the durations you expect.
        (default: ``False``)
    :param pitch_bend_range: range of the pitch bend to consider, to be given as a tuple with the form
-        ``(lowest_value, highest_value, nb_of_values)``. There will be ``nb_of_values`` tokens equally spaced
+        ``(lowest_value, highest_value, num_of_values)``. There will be ``num_of_values`` tokens equally spaced
        between ``lowest_value`` and ``highest_value``. (default: ``(-8192, 8191, 32)``)
    :param delete_equal_successive_time_sig_changes: setting this option True will delete identical successive time
        signature changes when preprocessing a MIDI file after loading it.
For example, if a MIDI has two time
@@ -268,7 +269,7 @@ def __init__(
         self,
         pitch_range: Tuple[int, int] = PITCH_RANGE,
         beat_res: Dict[Tuple[int, int], int] = BEAT_RES,
-        nb_velocities: int = NB_VELOCITIES,
+        num_velocities: int = NUM_VELOCITIES,
         special_tokens: Sequence[str] = SPECIAL_TOKENS,
         use_chords: bool = USE_CHORDS,
         use_rests: bool = USE_RESTS,
@@ -282,7 +283,7 @@ def __init__(
         chord_maps: Dict[str, Tuple] = CHORD_MAPS,
         chord_tokens_with_root_note: bool = CHORD_TOKENS_WITH_ROOT_NOTE,
         chord_unknown: Tuple[int, int] = CHORD_UNKNOWN,
-        nb_tempos: int = NB_TEMPOS,
+        num_tempos: int = NUM_TEMPOS,
         tempo_range: Tuple[int, int] = TEMPO_RANGE,
         log_tempos: bool = LOG_TEMPOS,
         delete_equal_successive_tempo_changes: bool = DELETE_EQUAL_SUCCESSIVE_TEMPO_CHANGES,
@@ -306,8 +307,8 @@ def __init__(
             f"(received {pitch_range})"
         )
         assert (
-            1 <= nb_velocities <= 127
-        ), f"nb_velocities must be within 1 and 127 (received {nb_velocities})"
+            1 <= num_velocities <= 127
+        ), f"num_velocities must be within 1 and 127 (received {num_velocities})"
         assert (
             0 <= max_pitch_interval <= 127
         ), f"max_pitch_interval must be within 0 and 127 (received {max_pitch_interval})."
@@ -315,7 +316,7 @@ def __init__(
         # Global parameters
         self.pitch_range: Tuple[int, int] = pitch_range
         self.beat_res: Dict[Tuple[int, int], int] = beat_res
-        self.nb_velocities: int = nb_velocities
+        self.num_velocities: int = num_velocities
         self.special_tokens: Sequence[str] = special_tokens

         # Additional token types params, enabling additional token types
@@ -339,7 +340,7 @@ def __init__(
         self.chord_unknown: Tuple[int, int] = chord_unknown

         # Tempo params
-        self.nb_tempos: int = nb_tempos  # nb of tempo bins for additional tempo tokens, quantized like velocities
+        self.num_tempos: int = num_tempos
         self.tempo_range: Tuple[int, int] = tempo_range  # (min_tempo, max_tempo)
         self.log_tempos: bool = log_tempos
         self.delete_equal_successive_tempo_changes = (
@@ -372,6 +373,19 @@ def __init__(
         self.max_pitch_interval = max_pitch_interval
         self.pitch_intervals_max_time_dist = pitch_intervals_max_time_dist

+        # Pop legacy kwargs
+        legacy_args = (
+            ("nb_velocities", "num_velocities"),
+            ("nb_tempos", "num_tempos"),
+        )
+        for legacy_arg, new_arg in legacy_args:
+            if legacy_arg in kwargs:
+                setattr(self, new_arg, kwargs.pop(legacy_arg))
+                warnings.warn(
+                    f"Argument {legacy_arg} has been renamed {new_arg}, you should consider updating "
+                    f"your code with this new argument name."
+                )
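The shim above keeps the old ``nb_velocities`` / ``nb_tempos`` keyword arguments working while steering users toward the new names. A minimal sketch of the intended behaviour, using only the public ``TokenizerConfig`` API (the values are arbitrary):

    import warnings

    from miditok import TokenizerConfig

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        config = TokenizerConfig(nb_velocities=16)  # legacy name, caught by the kwargs shim

    assert config.num_velocities == 16  # value forwarded to the new attribute
    assert "nb_velocities" not in config.additional_params  # popped, not kept as an extra param
    assert any("renamed" in str(w.message) for w in caught)  # the deprecation warning was emitted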
+
         # Additional params
         self.additional_params = kwargs

diff --git a/miditok/constants.py b/miditok/constants.py
index e4015f18..5133fc88 100644
--- a/miditok/constants.py
+++ b/miditok/constants.py
@@ -22,7 +22,7 @@
 PITCH_RANGE = (21, 109)
 BEAT_RES = {(0, 4): 8, (4, 12): 4}  # samples per beat
 # nb of velocity bins, velocity values from 0 to 127 will be quantized
-NB_VELOCITIES = 32
+NUM_VELOCITIES = 32

 # default special tokens
 SPECIAL_TOKENS = ["PAD", "BOS", "EOS", "MASK"]
@@ -70,7 +70,7 @@
 # Tempo params
 # nb of tempo bins for additional tempo tokens, quantized like velocities
-NB_TEMPOS = 32  # TODO raname num contractions
+NUM_TEMPOS = 32
 TEMPO_RANGE = (40, 250)  # (min_tempo, max_tempo)
 LOG_TEMPOS = False  # log or linear scale tempos
 DELETE_EQUAL_SUCCESSIVE_TEMPO_CHANGES = False

diff --git a/miditok/midi_tokenizer.py b/miditok/midi_tokenizer.py
index 3df10086..58bb7227 100644
--- a/miditok/midi_tokenizer.py
+++ b/miditok/midi_tokenizer.py
@@ -34,6 +34,7 @@
     DEFAULT_TOKENIZER_FILE_NAME,
     MIDI_FILES_EXTENSIONS,
     PITCH_CLASSES,
+    TEMPO,
     TIME_DIVISION,
     TIME_SIGNATURE,
     UNKNOWN_CHORD_PREFIX,
@@ -45,6 +46,7 @@
     get_midi_programs,
     merge_same_program_tracks,
     remove_duplicated_notes,
+    set_midi_max_tick,
 )
@@ -76,18 +78,20 @@ def convert_sequence_to_tokseq(
     # Deduce nb of subscripts / dims
     nb_io_dims = len(tokenizer.io_format)
     nb_seq_dims = 1
-    if isinstance(arg[1][0], list):
+    if len(arg[1]) > 0 and isinstance(arg[1][0], list):
         nb_seq_dims += 1
-        if isinstance(arg[1][0][0], list):
+        if len(arg[1][0]) > 0 and isinstance(arg[1][0][0], list):
+            nb_seq_dims += 1
+        elif len(arg[1][0]) == 0 and nb_seq_dims == nb_io_dims - 1:
+            # Special case where the sequence contains no tokens, we increment anyway
             nb_seq_dims += 1

     # Check the number of dimensions is good
     # In case of no one_token_stream and one dimension short --> unsqueeze
     if not tokenizer.one_token_stream and nb_seq_dims == nb_io_dims - 1:
         print(
-            f"The input sequence has one dimension less than expected ({nb_seq_dims} instead of "
-            f"{nb_io_dims}). It is being unsqueezed to conform with the tokenizer's i/o format "
-            f"({tokenizer.io_format})"
+            f"The input sequence has one dimension less than expected ({nb_seq_dims} instead of {nb_io_dims}). "
+            f"It is being unsqueezed to conform with the tokenizer's i/o format ({tokenizer.io_format})"
         )
         arg = (arg[0], [arg[1]])
@@ -218,7 +222,7 @@ def __init__(
             self.config.pitch_range[0] >= 0 and self.config.pitch_range[1] <= 128
         ), "You must specify a pitch_range between 0 and 127 (included, i.e. range.stop at 128)"
         assert (
-            0 < self.config.nb_velocities < 128
-        ), "You must specify a nb_velocities between 1 and 127 (included)"
+            0 < self.config.num_velocities < 128
+        ), "You must specify a num_velocities between 1 and 127 (included)"

         # Tweak the tokenizer's configuration and / or attributes before creating the vocabulary
@@ -233,7 +237,7 @@
         self.durations = self.__create_durations_tuples()
         # [1:] so that there is no velocity_0
         self.velocities = np.linspace(
-            0, 127, self.config.nb_velocities + 1, dtype=np.intc
+            0, 127, self.config.num_velocities + 1, dtype=np.intc
         )[1:]
         self._first_beat_res = list(self.config.beat_res.values())[0]
         for beat_range, res in self.config.beat_res.items():
@@ -242,9 +246,15 @@
                 break

         # Tempos
+        # _DEFAULT_TEMPO is useful when `log_tempos` is enabled
         self.tempos = np.zeros(1)
+        self._DEFAULT_TEMPO = TEMPO
         if self.config.use_tempos:
             self.tempos = self.__create_tempos()
+            if self.config.log_tempos:
+                self._DEFAULT_TEMPO = self.tempos[
+                    np.argmin(np.abs(self.tempos - TEMPO))
+                ]

         # Rests
         self.rests = []
@@ -366,8 +376,7 @@
         if self.config.use_programs and self.one_token_stream:
             merge_same_program_tracks(midi.instruments)

-        t = 0
-        while t < len(midi.instruments):
+        for t in range(len(midi.instruments) - 1, -1, -1):
             # quantize notes attributes
             self._quantize_notes(midi.instruments[t].notes, midi.ticks_per_beat)
             # sort notes
@@ -388,17 +397,12 @@
                     midi.instruments[t].pitch_bends, midi.ticks_per_beat
                 )
             # TODO quantize control changes
-            t += 1
-
-        # Recalculate max_tick is this could have changed after notes quantization
-        if len(midi.instruments) > 0:
-            midi.max_tick = max(
-                [max([note.end for note in track.notes]) for track in midi.instruments]
-            )

+        # Process tempo changes
        if self.config.use_tempos:
            self._quantize_tempos(midi.tempo_changes, midi.ticks_per_beat)

+        # Process time signature changes
        if len(midi.time_signature_changes) == 0:  # can sometimes happen
            midi.time_signature_changes.append(
                TimeSignature(*TIME_SIGNATURE, 0)
            )
@@ -408,6 +412,11 @@
                midi.time_signature_changes, midi.ticks_per_beat
            )

+        # We do not change key signature changes, markers and lyrics here as they are not used by MidiTok (yet)
+
+        # Recalculate max_tick as it could have changed after the notes quantization
+        set_midi_max_tick(midi)
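The ``_DEFAULT_TEMPO`` snapping above matters because, with ``log_tempos`` enabled, the bins are built with ``np.geomspace`` and generally do not contain the 120 BPM default exactly. A standalone sketch of the same nearest-bin lookup (plain NumPy; the bin bounds mirror the defaults above):

    import numpy as np

    TEMPO = 120  # MidiTok's default tempo constant
    tempos = np.geomspace(40, 250, 32).round(2)  # log-scaled bins, as __create_tempos builds them

    # Snap the default tempo to the closest available bin
    default_tempo = tempos[np.argmin(np.abs(tempos - TEMPO))]
    assert default_tempo in tempos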
    def _quantize_notes(self, notes: List[Note], time_division: int):
        r"""Quantize the notes attributes: their pitch, velocity, start and end values.
        It shifts the notes so that they start at times that match the time resolution
@@ -464,6 +473,11 @@ def _quantize_tempos(self, tempos: List[TempoChange], time_division: int):
        """
        ticks_per_sample = int(time_division / max(self.config.beat_res.values()))
        prev_tempo = TempoChange(-1, -1)
+        # If we delete the successive equal tempo changes, we need to sort them by time
+        # Otherwise it is not required here as the tokens will be sorted by time
+        if self.config.delete_equal_successive_tempo_changes:
+            tempos.sort(key=lambda x: x.time)
+
        i = 0
        while i < len(tempos):
            # Quantize tempo value
@@ -505,6 +519,11 @@
        )
        previous_tick = 0  # first time signature change is always at tick 0
        prev_ts = time_sigs[0]
+        # If we delete the successive equal time signature changes, we need to sort them by time
+        # Otherwise it is not required here as the tokens will be sorted by time
+        if self.config.delete_equal_successive_time_sig_changes:
+            time_sigs.sort(key=lambda x: x.time)
+
        i = 1
        while i < len(time_sigs):
            time_sig = time_sigs[i]
@@ -562,7 +581,6 @@ def _quantize_sustain_pedals(self, pedals: List[Pedal], time_division: int):
            )
            if pedal.start == pedal.end:
                pedal.end += ticks_per_sample
-            pedal.duration = pedal.end - pedal.start

    def _quantize_pitch_bends(self, pitch_bends: List[PitchBend], time_division: int):
        r"""Quantize the pitch bend events from a track. Their onset and offset times will be adjusted
@@ -609,6 +627,8 @@
        # Create events list
        all_events = []
        if not self.one_token_stream:
+            if len(midi.instruments) == 0:
+                all_events.append([])
            for i in range(len(midi.instruments)):
                all_events.append([])
@@ -778,7 +798,7 @@ def _create_track_events(self, track: Instrument) -> List[Event]:
            # Pitch / interval
            add_absolute_pitch_token = True
-            if self.config.use_pitch_intervals:
+            if self.config.use_pitch_intervals and not track.is_drum:
                if note.start != previous_note_onset:
                    if (
                        note.start - previous_note_onset <= max_time_interval
@@ -1487,7 +1507,7 @@ def __create_tempos(self) -> np.ndarray:
        :return: the tempos.
""" tempo_fn = np.geomspace if self.config.log_tempos else np.linspace - tempos = tempo_fn(*self.config.tempo_range, self.config.nb_tempos).round(2) + tempos = tempo_fn(*self.config.tempo_range, self.config.num_tempos).round(2) return tempos @@ -1810,7 +1830,7 @@ def decode_bpe(self, seq: Union[TokSequence, List[TokSequence]]): def tokenize_midi_dataset( self, - midi_paths: Union[str, Path, List[str], List[Path]], + midi_paths: Union[str, Path, Sequence[Union[str, Path]]], out_dir: Union[str, Path], overwrite_mode: bool = True, tokenizer_config_file_name: str = DEFAULT_TOKENIZER_FILE_NAME, @@ -1854,7 +1874,7 @@ def tokenize_midi_dataset( out_dir.mkdir(parents=True, exist_ok=True) # User gave a path to a directory, we'll scan it to find MIDI files - if not isinstance(midi_paths, list): + if not isinstance(midi_paths, Sequence): if isinstance(midi_paths, str): midi_paths = Path(midi_paths) root_dir = midi_paths @@ -1988,6 +2008,8 @@ def tokens_errors( # If list of TokSequence -> recursive if isinstance(tokens, list): return [self.tokens_errors(tok_seq) for tok_seq in tokens] + elif len(tokens) == 0: + return 0 nb_tok_predicted = len(tokens) # used to norm the score if self.has_bpe: @@ -2094,6 +2116,8 @@ def save_tokens( self.complete_sequence(tokens) ids_bpe_encoded = tokens.ids_bpe_encoded ids = tokens.ids + elif isinstance(tokens, list) and len(tokens) == 0: + pass elif isinstance(tokens[0], TokSequence): ids_bpe_encoded = [] for seq in tokens: diff --git a/miditok/tokenizations/cp_word.py b/miditok/tokenizations/cp_word.py index 0d4794a5..4835ede1 100644 --- a/miditok/tokenizations/cp_word.py +++ b/miditok/tokenizations/cp_word.py @@ -7,8 +7,9 @@ from miditoolkit import Instrument, MidiFile, Note, TempoChange, TimeSignature from ..classes import Event, TokSequence -from ..constants import MIDI_INSTRUMENTS, TEMPO, TIME_DIVISION, TIME_SIGNATURE +from ..constants import MIDI_INSTRUMENTS, TIME_DIVISION, TIME_SIGNATURE from ..midi_tokenizer import MIDITokenizer, _in_as_seq +from ..utils import set_midi_max_tick class CPWord(MIDITokenizer): @@ -91,9 +92,11 @@ def _add_time_events(self, events: List[Event]) -> List[List[Event]]: current_time_sig = TIME_SIGNATURE if self.config.log_tempos: # pick the closest to the default value - current_tempo = float(self.tempos[(np.abs(self.tempos - TEMPO)).argmin()]) + current_tempo = float( + self.tempos[(np.abs(self.tempos - self._DEFAULT_TEMPO)).argmin()] + ) else: - current_tempo = TEMPO + current_tempo = self._DEFAULT_TEMPO current_program = None ticks_per_bar = self._compute_ticks_per_bar( TimeSignature(*current_time_sig, 0), time_division @@ -372,7 +375,7 @@ def tokens_to_midi( # RESULTS instruments: Dict[int, Instrument] = {} - tempo_changes = [TempoChange(TEMPO, -1)] + tempo_changes = [TempoChange(self._DEFAULT_TEMPO, -1)] time_signature_changes = [] def check_inst(prog: int): @@ -536,12 +539,7 @@ def check_inst(prog: int): midi.instruments = list(instruments.values()) midi.tempo_changes = tempo_changes midi.time_signature_changes = time_signature_changes - midi.max_tick = max( - [ - max([note.end for note in track.notes]) if len(track.notes) > 0 else 0 - for track in midi.instruments - ] - ) + set_midi_max_tick(midi) # Write MIDI file if output_path: Path(output_path).mkdir(parents=True, exist_ok=True) @@ -675,6 +673,8 @@ def tokens_errors( # If list of TokSequence -> recursive if isinstance(tokens, list): return [self.tokens_errors(tok_seq) for tok_seq in tokens] + if len(tokens) == 0: + return 0 def cp_token_type(tok: List[int]) -> List[str]: 
family = self[0, tok[0]].split("_")[1] diff --git a/miditok/tokenizations/midi_like.py b/miditok/tokenizations/midi_like.py index 195c38f4..f3cbf82b 100644 --- a/miditok/tokenizations/midi_like.py +++ b/miditok/tokenizations/midi_like.py @@ -15,12 +15,11 @@ from ..classes import Event, TokSequence from ..constants import ( MIDI_INSTRUMENTS, - TEMPO, TIME_DIVISION, TIME_SIGNATURE, ) from ..midi_tokenizer import MIDITokenizer, _in_as_seq -from ..utils import fix_offsets_overlapping_notes +from ..utils import fix_offsets_overlapping_notes, set_midi_max_tick class MIDILike(MIDITokenizer): @@ -174,7 +173,7 @@ def tokens_to_midi( # RESULTS instruments: Dict[int, Instrument] = {} - tempo_changes = [TempoChange(TEMPO, -1)] + tempo_changes = [TempoChange(self._DEFAULT_TEMPO, -1)] time_signature_changes = [TimeSignature(*TIME_SIGNATURE, 0)] active_notes = {p: {} for p in self.config.programs} @@ -356,12 +355,7 @@ def clear_active_notes(): midi.instruments = list(instruments.values()) midi.tempo_changes = tempo_changes midi.time_signature_changes = time_signature_changes - midi.max_tick = max( - [ - max([note.end for note in track.notes]) if len(track.notes) > 0 else 0 - for track in midi.instruments - ] - ) + set_midi_max_tick(midi) # Write MIDI file if output_path: Path(output_path).mkdir(parents=True, exist_ok=True) @@ -427,7 +421,7 @@ def _create_token_types_graph(self) -> Dict[str, List[str]]: first_note_token_type = "NoteOn" dic["Velocity"] = [first_note_token_type, "TimeShift"] dic["NoteOff"] = ["NoteOff", first_note_token_type, "TimeShift"] - dic["TimeShift"] = ["NoteOff", first_note_token_type] + dic["TimeShift"] = ["NoteOff", first_note_token_type, "TimeShift"] if self.config.use_pitch_intervals: for token_type in ("PitchIntervalTime", "PitchIntervalChord"): dic[token_type] = ["Velocity"] @@ -594,6 +588,8 @@ def tokens_errors( # If list of TokSequence -> recursive if isinstance(tokens, list): return [self.tokens_errors(tok_seq) for tok_seq in tokens] + if len(tokens) == 0: + return 0 nb_tok_predicted = len(tokens) # used to norm the score if self.has_bpe: @@ -620,7 +616,7 @@ def tokens_errors( ) for i in range(1, len(events)): - # err_tokens = events[i - 4: i + 4] # uncomment for debug + # err_tokens = events[i - 4 : i + 4] # uncomment for debug # Good token type if events[i].type in self.tokens_types_graph[events[i - 1].type]: if events[i].type in [ diff --git a/miditok/tokenizations/mmm.py b/miditok/tokenizations/mmm.py index a3c06e0c..46339da0 100644 --- a/miditok/tokenizations/mmm.py +++ b/miditok/tokenizations/mmm.py @@ -9,11 +9,11 @@ from ..constants import ( MIDI_INSTRUMENTS, MMM_DENSITY_BINS_MAX, - TEMPO, TIME_DIVISION, TIME_SIGNATURE, ) from ..midi_tokenizer import MIDITokenizer, _in_as_seq +from ..utils import set_midi_max_tick class MMM(MIDITokenizer): @@ -223,7 +223,7 @@ def tokens_to_midi( # RESULTS instruments: List[Instrument] = [] tempo_changes = [ - TempoChange(TEMPO, -1) + TempoChange(self._DEFAULT_TEMPO, -1) ] # mock the first tempo change to optimize below time_signature_changes = [ TimeSignature(*TIME_SIGNATURE, 0) @@ -321,12 +321,7 @@ def tokens_to_midi( midi.instruments = instruments midi.tempo_changes = tempo_changes midi.time_signature_changes = time_signature_changes - midi.max_tick = max( - [ - max([note.end for note in track.notes]) if len(track.notes) > 0 else 0 - for track in midi.instruments - ] - ) + set_midi_max_tick(midi) # Write MIDI file if output_path: Path(output_path).mkdir(parents=True, exist_ok=True) @@ -436,6 +431,8 @@ def tokens_errors( ) -> float: 
tokens_to_check = cast(TokSequence, tokens_to_check) nb_tok_predicted = len(tokens_to_check) # used to norm the score + if nb_tok_predicted == 0: + return 0 if self.has_bpe: self.decode_bpe(tokens_to_check) self.complete_sequence(tokens_to_check) diff --git a/miditok/tokenizations/mumidi.py b/miditok/tokenizations/mumidi.py index 1a382d10..20752b4b 100644 --- a/miditok/tokenizations/mumidi.py +++ b/miditok/tokenizations/mumidi.py @@ -9,11 +9,10 @@ from ..constants import ( DRUM_PITCH_RANGE, MIDI_INSTRUMENTS, - TEMPO, TIME_DIVISION, ) from ..midi_tokenizer import MIDITokenizer, _in_as_seq, _out_as_complete_seq -from ..utils import detect_chords +from ..utils import detect_chords, set_midi_max_tick class MuMIDI(MIDITokenizer): @@ -292,10 +291,10 @@ def tokens_to_midi( midi = MidiFile(ticks_per_beat=time_division) # Tempos - if self.config.use_tempos: + if self.config.use_tempos and len(tokens) > 0: first_tempo = float(tokens.tokens[0][3].split("_")[1]) else: - first_tempo = TEMPO + first_tempo = self._DEFAULT_TEMPO midi.tempo_changes.append(TempoChange(first_tempo, 0)) ticks_per_sample = time_division // max(self.config.beat_res.values()) @@ -351,6 +350,7 @@ def tokens_to_midi( ) ) midi.instruments[-1].notes = notes + set_midi_max_tick(midi) # Write MIDI file if output_path: @@ -462,6 +462,8 @@ def tokens_errors(self, tokens: Union[TokSequence, List, np.ndarray, Any]) -> fl :param tokens: sequence of tokens to check :return: the error ratio (lower is better) """ + if len(tokens) == 0: + return 0 tokens = tokens.tokens err = 0 previous_type = tokens[0][0].split("_")[0] diff --git a/miditok/tokenizations/octuple.py b/miditok/tokenizations/octuple.py index 17b178fd..7c3f42d7 100644 --- a/miditok/tokenizations/octuple.py +++ b/miditok/tokenizations/octuple.py @@ -8,11 +8,11 @@ from ..classes import Event, TokSequence from ..constants import ( MIDI_INSTRUMENTS, - TEMPO, TIME_DIVISION, TIME_SIGNATURE, ) from ..midi_tokenizer import MIDITokenizer, _in_as_seq +from ..utils import set_midi_max_tick class Octuple(MIDITokenizer): @@ -102,7 +102,7 @@ def _add_time_events(self, events: List[Event]) -> List[List[Event]]: current_pos = 0 previous_tick = 0 current_time_sig = TIME_SIGNATURE - current_tempo = TEMPO + current_tempo = self._DEFAULT_TEMPO current_program = None ticks_per_bar = self._compute_ticks_per_bar( TimeSignature(*current_time_sig, 0), time_division @@ -214,7 +214,7 @@ def tokens_to_midi( # RESULTS instruments: Dict[int, Instrument] = {} - tempo_changes = [TempoChange(TEMPO, -1)] + tempo_changes = [TempoChange(self._DEFAULT_TEMPO, -1)] time_signature_changes = [] def check_inst(prog: int): @@ -339,12 +339,8 @@ def check_inst(prog: int): midi.instruments = list(instruments.values()) midi.tempo_changes = tempo_changes midi.time_signature_changes = time_signature_changes - midi.max_tick = max( - [ - max([note.end for note in track.notes]) if len(track.notes) > 0 else 0 - for track in midi.instruments - ] - ) + set_midi_max_tick(midi) + # Write MIDI file if output_path: Path(output_path).mkdir(parents=True, exist_ok=True) @@ -429,6 +425,8 @@ def tokens_errors( # If list of TokSequence -> recursive if isinstance(tokens, list): return [self.tokens_errors(tok_seq) for tok_seq in tokens] + if len(tokens) == 0: + return 0 err = 0 current_bar = current_pos = -1 diff --git a/miditok/tokenizations/remi.py b/miditok/tokenizations/remi.py index 767e2e51..39560213 100644 --- a/miditok/tokenizations/remi.py +++ b/miditok/tokenizations/remi.py @@ -16,11 +16,11 @@ from ..classes import Event, 
TokenizerConfig, TokSequence from ..constants import ( MIDI_INSTRUMENTS, - TEMPO, TIME_DIVISION, TIME_SIGNATURE, ) from ..midi_tokenizer import MIDITokenizer, _in_as_seq +from ..utils import set_midi_max_tick class REMI(MIDITokenizer): @@ -261,7 +261,7 @@ def tokens_to_midi( # RESULTS instruments: Dict[int, Instrument] = {} - tempo_changes = [TempoChange(TEMPO, -1)] + tempo_changes = [TempoChange(self._DEFAULT_TEMPO, -1)] time_signature_changes = [] def check_inst(prog: int): @@ -457,12 +457,8 @@ def check_inst(prog: int): midi.instruments = list(instruments.values()) midi.tempo_changes = tempo_changes midi.time_signature_changes = time_signature_changes - midi.max_tick = max( - [ - max([note.end for note in track.notes]) if len(track.notes) > 0 else 0 - for track in midi.instruments - ] - ) + set_midi_max_tick(midi) + # Write MIDI file if output_path: Path(output_path).mkdir(parents=True, exist_ok=True) diff --git a/miditok/tokenizations/structured.py b/miditok/tokenizations/structured.py index 35dc0e13..19a721ed 100644 --- a/miditok/tokenizations/structured.py +++ b/miditok/tokenizations/structured.py @@ -7,11 +7,11 @@ from ..classes import Event, TokSequence from ..constants import ( MIDI_INSTRUMENTS, - TEMPO, TIME_DIVISION, TIME_SIGNATURE, ) from ..midi_tokenizer import MIDITokenizer, _in_as_seq +from ..utils import set_midi_max_tick class Structured(MIDITokenizer): @@ -150,6 +150,8 @@ def _midi_to_tokens( all_events = [] # Adds note tokens + if not self.one_token_stream and len(midi.instruments) == 0: + all_events.append([]) for track in midi.instruments: note_events = self._create_track_events(track) if self.one_token_stream: @@ -204,7 +206,7 @@ def tokens_to_midi( # RESULTS instruments: Dict[int, Instrument] = {} - tempo_changes = [TempoChange(TEMPO, 0)] + tempo_changes = [TempoChange(self._DEFAULT_TEMPO, 0)] time_signature_changes = [TimeSignature(*TIME_SIGNATURE, 0)] def check_inst(prog: int): @@ -274,12 +276,8 @@ def check_inst(prog: int): midi.instruments = list(instruments.values()) midi.tempo_changes = tempo_changes midi.time_signature_changes = time_signature_changes - midi.max_tick = max( - [ - max([note.end for note in track.notes]) if len(track.notes) > 0 else 0 - for track in midi.instruments - ] - ) + set_midi_max_tick(midi) + # Write MIDI file if output_path: Path(output_path).mkdir(parents=True, exist_ok=True) diff --git a/miditok/tokenizations/tsd.py b/miditok/tokenizations/tsd.py index aedddf89..be476cf8 100644 --- a/miditok/tokenizations/tsd.py +++ b/miditok/tokenizations/tsd.py @@ -15,11 +15,11 @@ from ..classes import Event, TokSequence from ..constants import ( MIDI_INSTRUMENTS, - TEMPO, TIME_DIVISION, TIME_SIGNATURE, ) from ..midi_tokenizer import MIDITokenizer, _in_as_seq +from ..utils import set_midi_max_tick class TSD(MIDITokenizer): @@ -142,7 +142,7 @@ def tokens_to_midi( # RESULTS instruments: Dict[int, Instrument] = {} - tempo_changes = [TempoChange(TEMPO, -1)] + tempo_changes = [TempoChange(self._DEFAULT_TEMPO, -1)] time_signature_changes = [TimeSignature(*TIME_SIGNATURE, 0)] def check_inst(prog: int): @@ -304,12 +304,8 @@ def check_inst(prog: int): midi.instruments = list(instruments.values()) midi.tempo_changes = tempo_changes midi.time_signature_changes = time_signature_changes - midi.max_tick = max( - [ - max([note.end for note in track.notes]) if len(track.notes) > 0 else 0 - for track in midi.instruments - ] - ) + set_midi_max_tick(midi) + # Write MIDI file if output_path: Path(output_path).mkdir(parents=True, exist_ok=True) @@ -371,7 +367,7 @@ def 
_create_token_types_graph(self) -> Dict[str, List[str]]:
        dic["Pitch"] = ["Velocity"]
        dic["Velocity"] = ["Duration"]
        dic["Duration"] = [first_note_token_type, "TimeShift"]
-        dic["TimeShift"] = [first_note_token_type]
+        dic["TimeShift"] = [first_note_token_type, "TimeShift"]
        if self.config.use_pitch_intervals:
            for token_type in ("PitchIntervalTime", "PitchIntervalChord"):
                dic[token_type] = ["Velocity"]

diff --git a/miditok/utils/__init__.py b/miditok/utils/__init__.py
index 88cacd46..59cc4a57 100644
--- a/miditok/utils/__init__.py
+++ b/miditok/utils/__init__.py
@@ -8,6 +8,7 @@
    merge_tracks_per_class,
    nb_bar_pos,
    remove_duplicated_notes,
+    set_midi_max_tick,
 )

 __all__ = [
@@ -20,4 +21,5 @@
    "merge_tracks",
    "merge_same_program_tracks",
    "nb_bar_pos",
+    "set_midi_max_tick",
 ]

diff --git a/miditok/utils/utils.py b/miditok/utils/utils.py
index c20ff969..70effbbb 100644
--- a/miditok/utils/utils.py
+++ b/miditok/utils/utils.py
@@ -38,10 +38,12 @@ def convert_ids_tensors_to_list(ids: Any):
    # Recursively checks the content are ints (only check first item)
    el = ids[0]
    while isinstance(el, list):
-        el = el[0]
+        el = el[0] if len(el) > 0 else None

    # Check endpoint type
-    if not isinstance(el, int):
+    if el is None:
+        pass
+    elif not isinstance(el, int):
        # Recursively try to convert elements of the list
        for ei in range(len(ids)):
            ids[ei] = convert_ids_tensors_to_list(ids[ei])
@@ -391,6 +393,44 @@ def merge_same_program_tracks(tracks: List[Instrument]):
        del tracks[i]


+def set_midi_max_tick(midi: MidiFile):
+    r"""Set the maximal tick of a MIDI file, i.e. the time of its latest event, by
+    scanning the notes, pedals, control changes and pitch bends of its tracks, as
+    well as its global events (tempos, time signatures, key signatures, lyrics).
+
+    :param midi: MIDI file to analyze, its ``max_tick`` attribute is set in place.
+    """
+    midi.max_tick = 0
+
+    # Parse track events
+    if len(midi.instruments) > 0:
+        event_type_attr = (
+            ("notes", "end"),
+            ("pedals", "end"),
+            ("control_changes", "time"),
+            ("pitch_bends", "time"),
+        )
+        for track in midi.instruments:
+            for event_type, time_attr in event_type_attr:
+                if len(getattr(track, event_type)) > 0:
+                    midi.max_tick = max(
+                        midi.max_tick,
+                        max(
+                            [
+                                getattr(event, time_attr)
+                                for event in getattr(track, event_type)
+                            ]
+                        ),
+                    )
+
+    # Parse global MIDI events
+    for event_type in (
+        "tempo_changes",
+        "time_signature_changes",
+        "key_signature_changes",
+        "lyrics",
+    ):
+        if len(getattr(midi, event_type)) > 0:
+            midi.max_tick = max(
+                midi.max_tick,
+                max(event.time for event in getattr(midi, event_type)),
+            )
+
+
 def nb_bar_pos(
    seq: Sequence[int], bar_token: int, position_tokens: Sequence[int]
 ) -> Tuple[int, int]:

diff --git a/pyproject.toml b/pyproject.toml
index 286ef4ee..aa4f6c5a 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -37,11 +37,10 @@ classifiers = [
 ]
 dependencies = [
    "numpy>=1.19",
-    "miditoolkit",  # TODO >=v1.0.1?
+ "miditoolkit", "tqdm", "tokenizers>=0.13.0", "huggingface_hub>=0.16.4", - "scipy", # TODO remove when miditoolkit v1.0.1 ] [project.optional-dependencies] diff --git a/tests/Multitrack_MIDIs/Aicha.mid b/tests/MIDIs_multitrack/Aicha.mid similarity index 100% rename from tests/Multitrack_MIDIs/Aicha.mid rename to tests/MIDIs_multitrack/Aicha.mid diff --git a/tests/Multitrack_MIDIs/All The Small Things.mid b/tests/MIDIs_multitrack/All The Small Things.mid similarity index 100% rename from tests/Multitrack_MIDIs/All The Small Things.mid rename to tests/MIDIs_multitrack/All The Small Things.mid diff --git a/tests/Multitrack_MIDIs/Funkytown.mid b/tests/MIDIs_multitrack/Funkytown.mid similarity index 100% rename from tests/Multitrack_MIDIs/Funkytown.mid rename to tests/MIDIs_multitrack/Funkytown.mid diff --git a/tests/Multitrack_MIDIs/Girls Just Want to Have Fun.mid b/tests/MIDIs_multitrack/Girls Just Want to Have Fun.mid similarity index 100% rename from tests/Multitrack_MIDIs/Girls Just Want to Have Fun.mid rename to tests/MIDIs_multitrack/Girls Just Want to Have Fun.mid diff --git a/tests/Multitrack_MIDIs/I Gotta Feeling.mid b/tests/MIDIs_multitrack/I Gotta Feeling.mid similarity index 100% rename from tests/Multitrack_MIDIs/I Gotta Feeling.mid rename to tests/MIDIs_multitrack/I Gotta Feeling.mid diff --git a/tests/Multitrack_MIDIs/In Too Deep.mid b/tests/MIDIs_multitrack/In Too Deep.mid similarity index 100% rename from tests/Multitrack_MIDIs/In Too Deep.mid rename to tests/MIDIs_multitrack/In Too Deep.mid diff --git a/tests/Multitrack_MIDIs/Les Yeux Revolvers.mid b/tests/MIDIs_multitrack/Les Yeux Revolvers.mid similarity index 100% rename from tests/Multitrack_MIDIs/Les Yeux Revolvers.mid rename to tests/MIDIs_multitrack/Les Yeux Revolvers.mid diff --git a/tests/Multitrack_MIDIs/Mr. Blue Sky.mid b/tests/MIDIs_multitrack/Mr. Blue Sky.mid similarity index 100% rename from tests/Multitrack_MIDIs/Mr. Blue Sky.mid rename to tests/MIDIs_multitrack/Mr. Blue Sky.mid diff --git a/tests/Multitrack_MIDIs/Shut Up.mid b/tests/MIDIs_multitrack/Shut Up.mid similarity index 100% rename from tests/Multitrack_MIDIs/Shut Up.mid rename to tests/MIDIs_multitrack/Shut Up.mid diff --git a/tests/Multitrack_MIDIs/What a Fool Believes.mid b/tests/MIDIs_multitrack/What a Fool Believes.mid similarity index 100% rename from tests/Multitrack_MIDIs/What a Fool Believes.mid rename to tests/MIDIs_multitrack/What a Fool Believes.mid diff --git a/tests/One_track_MIDIs/6338816_Etude No. 4.mid b/tests/MIDIs_one_track/6338816_Etude No. 4.mid similarity index 100% rename from tests/One_track_MIDIs/6338816_Etude No. 4.mid rename to tests/MIDIs_one_track/6338816_Etude No. 
4.mid diff --git a/tests/One_track_MIDIs/6354774_Macabre Waltz.mid b/tests/MIDIs_one_track/6354774_Macabre Waltz.mid similarity index 100% rename from tests/One_track_MIDIs/6354774_Macabre Waltz.mid rename to tests/MIDIs_one_track/6354774_Macabre Waltz.mid diff --git a/tests/One_track_MIDIs/Maestro_1.mid b/tests/MIDIs_one_track/Maestro_1.mid similarity index 100% rename from tests/One_track_MIDIs/Maestro_1.mid rename to tests/MIDIs_one_track/Maestro_1.mid diff --git a/tests/One_track_MIDIs/Maestro_10.mid b/tests/MIDIs_one_track/Maestro_10.mid similarity index 100% rename from tests/One_track_MIDIs/Maestro_10.mid rename to tests/MIDIs_one_track/Maestro_10.mid diff --git a/tests/One_track_MIDIs/Maestro_2.mid b/tests/MIDIs_one_track/Maestro_2.mid similarity index 100% rename from tests/One_track_MIDIs/Maestro_2.mid rename to tests/MIDIs_one_track/Maestro_2.mid diff --git a/tests/One_track_MIDIs/Maestro_3.mid b/tests/MIDIs_one_track/Maestro_3.mid similarity index 100% rename from tests/One_track_MIDIs/Maestro_3.mid rename to tests/MIDIs_one_track/Maestro_3.mid diff --git a/tests/One_track_MIDIs/Maestro_4.mid b/tests/MIDIs_one_track/Maestro_4.mid similarity index 100% rename from tests/One_track_MIDIs/Maestro_4.mid rename to tests/MIDIs_one_track/Maestro_4.mid diff --git a/tests/One_track_MIDIs/Maestro_5.mid b/tests/MIDIs_one_track/Maestro_5.mid similarity index 100% rename from tests/One_track_MIDIs/Maestro_5.mid rename to tests/MIDIs_one_track/Maestro_5.mid diff --git a/tests/One_track_MIDIs/Maestro_6.mid b/tests/MIDIs_one_track/Maestro_6.mid similarity index 100% rename from tests/One_track_MIDIs/Maestro_6.mid rename to tests/MIDIs_one_track/Maestro_6.mid diff --git a/tests/One_track_MIDIs/Maestro_7.mid b/tests/MIDIs_one_track/Maestro_7.mid similarity index 100% rename from tests/One_track_MIDIs/Maestro_7.mid rename to tests/MIDIs_one_track/Maestro_7.mid diff --git a/tests/One_track_MIDIs/Maestro_8.mid b/tests/MIDIs_one_track/Maestro_8.mid similarity index 100% rename from tests/One_track_MIDIs/Maestro_8.mid rename to tests/MIDIs_one_track/Maestro_8.mid diff --git a/tests/One_track_MIDIs/Maestro_9.mid b/tests/MIDIs_one_track/Maestro_9.mid similarity index 100% rename from tests/One_track_MIDIs/Maestro_9.mid rename to tests/MIDIs_one_track/Maestro_9.mid diff --git a/tests/One_track_MIDIs/POP909_008.mid b/tests/MIDIs_one_track/POP909_008.mid similarity index 100% rename from tests/One_track_MIDIs/POP909_008.mid rename to tests/MIDIs_one_track/POP909_008.mid diff --git a/tests/One_track_MIDIs/POP909_010.mid b/tests/MIDIs_one_track/POP909_010.mid similarity index 100% rename from tests/One_track_MIDIs/POP909_010.mid rename to tests/MIDIs_one_track/POP909_010.mid diff --git a/tests/One_track_MIDIs/POP909_022.mid b/tests/MIDIs_one_track/POP909_022.mid similarity index 100% rename from tests/One_track_MIDIs/POP909_022.mid rename to tests/MIDIs_one_track/POP909_022.mid diff --git a/tests/One_track_MIDIs/POP909_191.mid b/tests/MIDIs_one_track/POP909_191.mid similarity index 100% rename from tests/One_track_MIDIs/POP909_191.mid rename to tests/MIDIs_one_track/POP909_191.mid diff --git a/tests/MIDIs_one_track/empty.mid b/tests/MIDIs_one_track/empty.mid new file mode 100644 index 0000000000000000000000000000000000000000..4d2cb2e75ba109e2bb61da1eb58f9e234e7fd2e1 GIT binary patch literal 76 zcmeYb$w*;fU|<7cM#cxeAw}6hmKno;My|a4l2nC~qQvBEhW~-g> 0 + and tokenizer[original_track[idx - 1]] == "Program_-1" + ): + pitch_offset = 0 + assert aug_token == original_token + pitch_offset + elif original_token in 
vel_tokens:
+                        assert aug_token in [
+                            original_token + offsets[1],
+                            tok_vel_min,
+                            tok_vel_max,
+                        ]
+                    elif original_token in dur_tokens and tokenization != "MIDILike":
+                        assert aug_token in [
+                            original_token + offsets[2],
+                            tok_dur_min,
+                            tok_dur_max,
+                        ]
+                    elif original_token in note_off_tokens:
+                        assert aug_token == original_token + offsets[0]
+                else:
+                    if original_token[pitch_voc_idx] in pitch_tokens:
+                        assert (
+                            aug_token[pitch_voc_idx]
+                            == original_token[pitch_voc_idx] + offsets[0]
+                        )
+                    elif original_token[vel_voc_idx] in vel_tokens:
+                        assert aug_token[vel_voc_idx] in [
+                            original_token[vel_voc_idx] + offsets[1],
+                            tok_vel_min,
+                            tok_vel_max,
+                        ]
+                    elif (
+                        original_token[dur_voc_idx] in dur_tokens
+                        and tokenization != "MIDILike"
+                    ):
+                        assert aug_token[dur_voc_idx] in [
+                            original_token[dur_voc_idx] + offsets[2],
+                            tok_dur_min,
+                            tok_dur_max,
+                        ]


"""def time_data_augmentation_tokens_vs_mid():
    from time import time
    tokenizers = [miditok.TSD(), miditok.REMI()]
    data_paths = [Path('./tests/One_track_MIDIs'), Path('./tests/Multitrack_MIDIs')]

    for data_path in data_paths:
        for tokenizer in tokenizers:
            print(f'\n{data_path.stem} - {type(tokenizer).__name__}')
            files = list(data_path.glob('**/*.mid'))

            # Testing opening midi -> augment midis -> tokenize midis
            t0 = time()
            for file_path in files:
                # Reads the MIDI
                try:
                    midi = MidiFile(Path(file_path))
                except Exception:  # ValueError, OSError, FileNotFoundError, IOError, EOFError, mido.KeySignatureError
                    continue

                offsets = miditok.data_augmentation.get_offsets(tokenizer, 2, 2, 2, midi=midi)
                midis = miditok.data_augmentation.data_augmentation_midi(midi, tokenizer, *offsets)
                for _, aug_mid in midis:
                    _ = tokenizer(aug_mid)
            tt = time() - t0
            print(f'Opening midi -> augment midis -> tokenize midis: took {tt:.2f} sec '
                  f'({tt / len(files):.2f} sec/file)')

            # Testing opening midi -> tokenize midi -> augment tokens
            t0 = time()
            for file_path in files:
                # Reads the MIDI
                try:
                    midi = MidiFile(Path(file_path))
                except Exception:  # ValueError, OSError, FileNotFoundError, IOError, EOFError, mido.KeySignatureError
                    continue

                tokens = tokenizer(midi)
                for track_tokens in tokens:
                    offsets = miditok.data_augmentation.get_offsets(tokenizer, 2, 2, 2, tokens=tokens)
                    _ = miditok.data_augmentation.data_augmentation_tokens(track_tokens, tokenizer, *offsets)
            tt = time() - t0
            print(f'Opening midi -> tokenize midi -> augment tokens: took {tt:.2f} sec '
                  f'({tt / len(files):.2f} sec/file)')"""

diff --git a/tests/test_hf_hub.py b/tests/test_hf_hub.py
index de0c5229..8c64b5d5 100644
--- a/tests/test_hf_hub.py
+++ b/tests/test_hf_hub.py
@@ -4,20 +4,37 @@
 """

+from pathlib import Path
+from time import sleep
+
+from huggingface_hub.utils._errors import HfHubHTTPError
+
 from miditok import REMI, TSD

+MAX_NUM_TRIES_HF_PUSH = 5
+NUM_SECONDS_RETRY = 8
+

 def test_push_and_load_to_hf_hub(hf_token: str):
     tokenizer = REMI()
-    tokenizer.push_to_hub("Natooz/MidiTok-tests", private=True, token=hf_token)
+    num_tries = 0
+    while num_tries < MAX_NUM_TRIES_HF_PUSH:
+        try:
+            tokenizer.push_to_hub("Natooz/MidiTok-tests", private=True, token=hf_token)
+            break  # the push succeeded, no need to retry
+        except HfHubHTTPError as e:
+            if e.response.status_code in [500, 412, 429]:
+                num_tries += 1
+                sleep(NUM_SECONDS_RETRY)
+            else:
+                num_tries = MAX_NUM_TRIES_HF_PUSH

     tokenizer2 = REMI.from_pretrained("Natooz/MidiTok-tests", token=hf_token)
     assert tokenizer == tokenizer2
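The retry loop above guards against transient Hub errors (HTTP 500, 412 and 429). The same idea, factored into a reusable helper — a sketch only, assuming nothing beyond ``huggingface_hub`` and the standard library; the helper name and parameters are illustrative, not part of MidiTok:

    from time import sleep
    from typing import Any, Callable, Sequence

    from huggingface_hub.utils._errors import HfHubHTTPError

    def call_with_retries(
        func: Callable[[], Any],
        retryable_codes: Sequence[int] = (500, 412, 429),
        max_num_tries: int = 5,
        num_seconds_retry: float = 8,
    ) -> None:
        # Retry only on transient HTTP errors; re-raise anything else immediately
        for num_try in range(1, max_num_tries + 1):
            try:
                func()
                return  # success, stop retrying
            except HfHubHTTPError as e:
                if e.response.status_code not in retryable_codes or num_try == max_num_tries:
                    raise
                sleep(num_seconds_retry)

The test above could then reduce to ``call_with_retries(lambda: tokenizer.push_to_hub("Natooz/MidiTok-tests", private=True, token=hf_token))``.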
-def test_from_pretrained_local():
+def test_from_pretrained_local(tmp_path: Path):
     # Here using paths to directories
     tokenizer = TSD()
-    tokenizer.save_pretrained("tests/tokenizer_confs")
-    tokenizer2 = TSD.from_pretrained("tests/tokenizer_confs")
+    tokenizer.save_pretrained(tmp_path)
+    tokenizer2 = TSD.from_pretrained(tmp_path)
     assert tokenizer == tokenizer2

diff --git a/tests/test_io_formats.py b/tests/test_io_formats.py
index 93ad5559..22b09736 100644
--- a/tests/test_io_formats.py
+++ b/tests/test_io_formats.py
@@ -6,71 +6,79 @@
 from copy import deepcopy
 from pathlib import Path
+from typing import Any, Dict, Tuple, Union

+import pytest
 from miditoolkit import MidiFile

 import miditok

-from .tests_utils import ALL_TOKENIZATIONS, midis_equals
-
-BEAT_RES_TEST = {(0, 16): 8}
-TOKENIZER_PARAMS = {
-    "beat_res": BEAT_RES_TEST,
-    "use_chords": True,
-    "use_rests": True,
-    "use_tempos": True,
-    "use_time_signatures": True,
-    "use_sustain_pedals": True,
-    "use_pitch_bends": True,
-    "use_programs": False,
-    "chord_maps": miditok.constants.CHORD_MAPS,
-    "chord_tokens_with_root_note": True,  # Tokens will look as "Chord_C:maj"
-    "chord_unknown": (3, 6),
-    "beat_res_rest": {(0, 16): 4},
-    "nb_tempos": 32,
-    "tempo_range": (40, 250),
-    "time_signature_range": {4: [4]},
-}
-
-programs_tokenizations = ["TSD", "REMI", "MIDILike", "Structured", "CPWord"]
-test_cases_programs = [
-    (
-        {
-            "use_programs": True,
-            "one_token_stream_for_programs": True,
-            "program_changes": False,
-        },
-        [],
-    ),
-    (
-        {
-            "use_programs": True,
-            "one_token_stream_for_programs": True,
-            "program_changes": True,
-        },
-        ["Structured", "CPWord"],
-    ),
-    (
-        {
-            "use_programs": True,
-            "one_token_stream_for_programs": False,
-            "program_changes": False,
-        },
-        ["Structured"],
-    ),
+from .utils import (
+    ALL_TOKENIZATIONS,
+    HERE,
+    TOKENIZER_CONFIG_KWARGS,
+    adjust_tok_params_for_tests,
+    prepare_midi_for_tests,
+)
+
+default_params = deepcopy(TOKENIZER_CONFIG_KWARGS)
+default_params.update(
+    {
+        "use_chords": True,
+        "use_rests": True,
+        "use_tempos": True,
+        "use_time_signatures": True,
+        "use_sustain_pedals": True,
+        "use_pitch_bends": True,
+    }
+)
+tokenizations_no_one_stream = [
+    "TSD",
+    "REMI",
+    "MIDILike",
+    "Structured",
+    "CPWord",
+    "Octuple",
 ]
-
-
-def encode_decode_and_check(tokenizer: miditok.MIDITokenizer, midi: MidiFile):
+configs = (
+    {
+        "use_programs": True,
+        "one_token_stream_for_programs": True,
+        "program_changes": False,
+    },
+    {
+        "use_programs": True,
+        "one_token_stream_for_programs": True,
+        "program_changes": True,
+    },
+    {
+        "use_programs": True,
+        "one_token_stream_for_programs": False,
+        "program_changes": False,
+    },
+)
+TOK_PARAMS_IO = []
+for tokenization_ in ALL_TOKENIZATIONS:
+    params_ = deepcopy(default_params)
+    adjust_tok_params_for_tests(tokenization_, params_)
+    TOK_PARAMS_IO.append((tokenization_, params_))
+
+    if tokenization_ in tokenizations_no_one_stream:
+        for config in configs:
+            params_tmp = deepcopy(params_)
+            params_tmp.update(config)
+            TOK_PARAMS_IO.append((tokenization_, params_tmp))
+
+
+def encode_decode_and_check(tokenizer: miditok.MIDITokenizer, midi: MidiFile) -> bool:
+    """Tokenizes a MIDI file, decodes the tokens back into a MIDI and compares it
+    with the original (preprocessed) one.
+
+    :param tokenizer: tokenizer to use.
+    :param midi: MIDI file to test.
+    :return: True if at least one error occurred, False otherwise.
+    """
     # Process the MIDI
-    midi_to_compare = deepcopy(midi)
-    for track in midi_to_compare.instruments:
-        if track.is_drum:
-            track.program = 0  # need to be done before sorting tracks per program
-    # MIDI produced with one_token_stream contains tracks with different orders
-    midi_to_compare.instruments.sort(
-        key=lambda x: (x.program, x.is_drum)
-    )  # sort tracks
+    midi_to_compare = prepare_midi_for_tests(midi)
     # Convert the MIDI to tokens, and keep the ids (integers)
     tokens = tokenizer(midi_to_compare)
@@ -90,62 +98,27 @@ def encode_decode_and_check(tokenizer: miditok.MIDITokenizer, midi: MidiFile):
         return True

     # Check it's good
-    decoded_midi.instruments.sort(key=lambda x: (x.program, x.is_drum))
-    if type(tokenizer).__name__ == "MIDILike":
-        for track in decoded_midi.instruments:
-            track.notes.sort(key=lambda x: (x.start, x.pitch, x.end))
-    errors = midis_equals(midi_to_compare, decoded_midi)
-    if len(errors) > 0:
-        print(
-            f"Failed to encode/decode NOTES with {tokenizer.__class__.__name__} ({len(errors)} errors)"
-        )
-        return True
+    decoded_midi = prepare_midi_for_tests(decoded_midi, sort_notes=True)
+    return decoded_midi != midi_to_compare  # True if at least one error

-    return False
-

-def test_io_formats():
+@pytest.mark.parametrize("tok_params_set", TOK_PARAMS_IO)
+def test_io_formats(
+    tok_params_set: Tuple[str, Dict[str, Any]],
+    midi_path: Union[str, Path] = HERE / "MIDIs_multitrack" / "Funkytown.mid",
+):
     r"""Reads a few MIDI files, converts them into token sequences, and converts them back to MIDI files.
     The MIDI files converted back should be identical to the original ones, except with note starting and ending
-    times quantized, and maybe a some duplicated notes removed
+    times quantized, and maybe some duplicated notes removed.
+
+    :param tok_params_set: tokenizer and its parameters to run.
+    :param midi_path: path to the MIDI file to test.
     """
-    at_least_one_error = False
-
-    file_path = Path("tests", "Multitrack_MIDIs", "Funkytown.mid")
-    midi = MidiFile(file_path)
-
-    for tokenization in ALL_TOKENIZATIONS:
-        params = deepcopy(TOKENIZER_PARAMS)
-        if tokenization == "Structured":
-            params["beat_res"] = {(0, 512): 8}
-        elif tokenization == "Octuple":
-            params["use_time_signatures"] = False
-        tokenizer_config = miditok.TokenizerConfig(**params)
-        tokenizer: miditok.MIDITokenizer = getattr(miditok, tokenization)(
-            tokenizer_config=tokenizer_config
-        )
-
-        at_least_one_error = (
-            encode_decode_and_check(tokenizer, midi) or at_least_one_error
-        )
-
-        # If TSD, also test in use_programs / one_token_stream mode
-        if tokenization in programs_tokenizations:
-            for custom_params, excluded_tok in test_cases_programs:
-                if tokenization in excluded_tok:
-                    continue
-                params = deepcopy(TOKENIZER_PARAMS)
-                params.update(custom_params)
-                tokenizer_config = miditok.TokenizerConfig(**params)
-                tokenizer: miditok.MIDITokenizer = getattr(miditok, tokenization)(
-                    tokenizer_config=tokenizer_config
-                )
-                at_least_one_error = (
-                    encode_decode_and_check(tokenizer, midi) or at_least_one_error
-                )
+    midi = MidiFile(midi_path)
+    tokenization, params = tok_params_set
+    tokenizer: miditok.MIDITokenizer = getattr(miditok, tokenization)(
+        tokenizer_config=miditok.TokenizerConfig(**params)
+    )
+    at_least_one_error = encode_decode_and_check(tokenizer, midi)
     assert not at_least_one_error
-
-
-if __name__ == "__main__":
-    test_io_formats()

diff --git a/tests/test_methods.py b/tests/test_methods.py
index 5f71bf57..dbd37a2f 100644
--- a/tests/test_methods.py
+++ b/tests/test_methods.py
@@ -4,13 +4,9 @@
 """

-import json
-import random
 from pathlib import Path
-from typing import Union
+from typing import Sequence, Union

-import numpy as np
-from miditoolkit import MidiFile
 from tensorflow import Tensor as tfTensor
 from tensorflow import convert_to_tensor
 from torch import (
@@ -22,11 +18,10 @@ from torch import (
     Tensor as ptTensor,
 )
-from tqdm import tqdm

 import miditok

-from .tests_utils import ALL_TOKENIZATIONS
+from .utils import HERE, MIDI_PATHS_ALL


 def
test_convert_tensors(): @@ -44,259 +39,23 @@ def test_convert_tensors(): assert as_list == original -"""def time_data_augmentation_tokens_vs_mid(): - from time import time - tokenizers = [miditok.TSD(), miditok.REMI()] - data_paths = [Path('./tests/One_track_MIDIs'), Path('./tests/Multitrack_MIDIs')] +def test_tokenize_datasets_file_tree( + tmp_path: Path, midi_paths: Sequence[Union[str, Path]] = None +): + if midi_paths is None: + midi_paths = MIDI_PATHS_ALL - for data_path in data_paths: - for tokenizer in tokenizers: - print(f'\n{data_path.stem} - {type(tokenizer).__name__}') - files = list(data_path.glob('**/*.mid')) - - # Testing opening midi -> augment midis -> tokenize midis - t0 = time() - for file_path in files: - # Reads the MIDI - try: - midi = MidiFile(Path(file_path)) - except Exception: # ValueError, OSError, FileNotFoundError, IOError, EOFError, mido.KeySignatureError - continue - - offsets = miditok.data_augmentation.get_offsets(tokenizer, 2, 2, 2, midi=midi) - midis = miditok.data_augmentation.data_augmentation_midi(midi, tokenizer, *offsets) - for _, aug_mid in midis: - _ = tokenizer(aug_mid) - tt = time() - t0 - print(f'Opening midi -> augment midis -> tokenize midis: took {tt:.2f} sec ' - f'({tt / len(files):.2f} sec/file)') - - # Testing opening midi -> tokenize midi -> augment tokens - t0 = time() - for file_path in files: - # Reads the MIDI - try: - midi = MidiFile(Path(file_path)) - except Exception: # ValueError, OSError, FileNotFoundError, IOError, EOFError, mido.KeySignatureError - continue - - tokens = tokenizer(midi) - for track_tokens in tokens: - offsets = miditok.data_augmentation.get_offsets(tokenizer, 2, 2, 2, tokens=tokens) - _ = miditok.data_augmentation.data_augmentation_tokens(track_tokens, tokenizer, *offsets) - tt = time() - t0 - print(f'Opening midi -> tokenize midi -> augment tokens: took {tt:.2f} sec ' - f'({tt / len(files):.2f} sec/file)')""" - - -def test_data_augmentation(): - data_path = Path("./tests/Multitrack_MIDIs") - original_midi_paths = list(data_path.glob("**/*.mid"))[:7] - ALL_TOKENIZATIONS.remove("MuMIDI") # not compatible - - for tokenization in ALL_TOKENIZATIONS: - print(f"TESTING WITH {tokenization}") - tokenizer = getattr(miditok, tokenization)() - midi_aug_path = Path("tests", "Multitrack_MIDIs_aug", tokenization) - tokens_path = Path("tests", "Multitrack_tokens", tokenization) - tokens_aug_path = Path("tests", "Multitrack_tokens_aug", tokenization) - - # We only perform and test data augmentation on MIDIs once, as tokenizers does not play here - if tokenization == "MIDILike": - print("PERFORMING DATA AUGMENTATION ON MIDIS") - miditok.data_augmentation.data_augmentation_dataset( - data_path, - tokenizer, - 2, - 1, - 1, - out_path=midi_aug_path, - copy_original_in_new_location=False, - ) - aug_midi_paths = list(midi_aug_path.glob("**/*.mid")) - for aug_midi_path in tqdm( - aug_midi_paths, desc="CHECKING DATA AUGMENTATION ON MIDIS" - ): - if "Mr. 
Blue Sky" in aug_midi_path.stem: - continue # TODO remove when miditoolkit v1.0.1 is released - # Determine offsets of file - parts = aug_midi_path.stem.split("ยง") - original_stem, offsets_str = parts[0], parts[1].split("_") - offsets = [0, 0, 0] - for offset_str in offsets_str: - for pos, letter in enumerate(["p", "v", "d"]): - if offset_str[0] == letter: - offsets[pos] = int(offset_str[1:]) - - # Loads MIDIs to compare - try: - aug_midi = MidiFile(aug_midi_path) - original_midi = MidiFile(data_path / f"{original_stem}.mid") - except Exception: # ValueError, OSError, FileNotFoundError, IOError, EOFError, mido.KeySignatureError - continue - - # Compare them - for original_track, aug_track in zip( - original_midi.instruments, aug_midi.instruments - ): - if original_track.is_drum: - continue - original_track.notes.sort( - key=lambda x: (x.start, x.pitch, x.end, x.velocity) - ) # sort notes - aug_track.notes.sort( - key=lambda x: (x.start, x.pitch, x.end, x.velocity) - ) # sort notes - for note_o, note_s in zip(original_track.notes, aug_track.notes): - assert note_s.pitch == note_o.pitch + offsets[0] - assert note_s.velocity in [ - tokenizer.velocities[0], - tokenizer.velocities[-1], - note_o.velocity + offsets[1], - ] - - print("PERFORMING DATA AUGMENTATION ON TOKENS") - tokenizer.tokenize_midi_dataset(original_midi_paths, tokens_path) - miditok.data_augmentation.data_augmentation_dataset( - tokens_path, - tokenizer, - 2, - 1, - 1, - out_path=tokens_aug_path, - copy_original_in_new_location=False, - ) - - # Getting tokens idx from tokenizer for assertions - aug_tokens_paths = list(tokens_aug_path.glob("**/*.json")) - pitch_voc_idx, vel_voc_idx, dur_voc_idx = None, None, None - note_off_tokens = [] - if tokenizer.is_multi_voc: - pitch_voc_idx = tokenizer.vocab_types_idx["Pitch"] - vel_voc_idx = tokenizer.vocab_types_idx["Velocity"] - dur_voc_idx = tokenizer.vocab_types_idx["Duration"] - pitch_tokens = np.array(tokenizer.token_ids_of_type("Pitch", pitch_voc_idx)) - vel_tokens = np.array(tokenizer.token_ids_of_type("Velocity", vel_voc_idx)) - dur_tokens = np.array(tokenizer.token_ids_of_type("Duration", dur_voc_idx)) - else: - pitch_tokens = np.array( - tokenizer.token_ids_of_type("Pitch") - + tokenizer.token_ids_of_type("NoteOn") - ) - vel_tokens = np.array(tokenizer.token_ids_of_type("Velocity")) - dur_tokens = np.array(tokenizer.token_ids_of_type("Duration")) - note_off_tokens = np.array( - tokenizer.token_ids_of_type("NoteOff") - ) # for MidiLike - tok_vel_min, tok_vel_max = vel_tokens[0], vel_tokens[-1] - tok_dur_min, tok_dur_max = None, None - if tokenization != "MIDILike": - tok_dur_min, tok_dur_max = dur_tokens[0], dur_tokens[-1] - - for aug_token_path in aug_tokens_paths: - # Determine offsets of file - parts = aug_token_path.stem.split("ยง") - original_stem, offsets_str = parts[0], parts[1].split("_") - offsets = [0, 0, 0] - for offset_str in offsets_str: - for pos, letter in enumerate(["p", "v", "d"]): - if offset_str[0] == letter: - offsets[pos] = int(offset_str[1:]) - - # Loads tokens to compare - with open(aug_token_path) as json_file: - file = json.load(json_file) - aug_tokens = file["ids"] - - with open(tokens_path / f"{original_stem}.json") as json_file: - file = json.load(json_file) - original_tokens = file["ids"] - original_programs = file["programs"] if "programs" in file else None - - # Compare them - if tokenizer.one_token_stream: - original_tokens, aug_tokens = [original_tokens], [aug_tokens] - for ti, (original_track, aug_track) in enumerate( - zip(original_tokens, 
aug_tokens) - ): - if original_programs is not None and original_programs[ti][1]: # drums - continue - for idx, (original_token, aug_token) in enumerate( - zip(original_track, aug_track) - ): - if not tokenizer.is_multi_voc: - if original_token in pitch_tokens: - pitch_offset = offsets[0] - # no offset for drum pitches - if ( - tokenizer.one_token_stream - and idx > 0 - and tokenizer[original_track[idx - 1]] == "Program_-1" - ): - pitch_offset = 0 - assert aug_token == original_token + pitch_offset - elif original_token in vel_tokens: - assert aug_token in [ - original_token + offsets[1], - tok_vel_min, - tok_vel_max, - ] - elif ( - original_token in dur_tokens and tokenization != "MIDILike" - ): - assert aug_token in [ - original_token + offsets[2], - tok_dur_min, - tok_dur_max, - ] - elif original_token in note_off_tokens: - assert aug_token == original_token + offsets[0] - else: - if original_token[pitch_voc_idx] in pitch_tokens: - assert ( - aug_token[pitch_voc_idx] - == original_token[pitch_voc_idx] + offsets[0] - ) - elif original_token[vel_voc_idx] in vel_tokens: - assert aug_token[vel_voc_idx] in [ - original_token[vel_voc_idx] + offsets[1], - tok_vel_min, - tok_vel_max, - ] - elif ( - original_token[dur_voc_idx] in dur_tokens - and tokenization != "MIDILike" - ): - assert aug_token[dur_voc_idx] in [ - original_token[dur_voc_idx] + offsets[2], - tok_dur_min, - tok_dur_max, - ] - - -def test_tokenize_datasets(data_path: Union[str, Path] = Path("./tests")): # Check the file tree is copied - random.seed(8) - midi_paths = list((data_path / "One_track_MIDIs").glob("**/*.mid")) + list( - (data_path / "Multitrack_MIDIs").glob("**/*.mid") - ) - midi_paths = random.choices(midi_paths, k=6) - config = miditok.TokenizerConfig() - tokenizer = miditok.TSD(config) - out_path = Path("tests", "test_results", "file_tree") - tokenizer.tokenize_midi_dataset(midi_paths, out_path) - json_paths = list(out_path.glob("**/*.json")) + tokenizer = miditok.TSD(miditok.TokenizerConfig()) + tokenizer.tokenize_midi_dataset(midi_paths, tmp_path, overwrite_mode=True) + json_paths = list(tmp_path.glob("**/*.json")) json_paths.sort(key=lambda x: x.stem) midi_paths.sort(key=lambda x: x.stem) - assert all( - json_path.relative_to(out_path).with_suffix(".test") - == midi_path.relative_to(data_path).with_suffix(".test") - for json_path, midi_path in zip(json_paths, midi_paths) - ) - tokenizer.tokenize_midi_dataset(midi_paths, out_path, overwrite_mode=False) - - -if __name__ == "__main__": - test_tokenize_datasets() - test_convert_tensors() - test_data_augmentation() + for json_path, midi_path in zip(json_paths, midi_paths): + assert ( + json_path.relative_to(tmp_path).with_suffix(".test") + == midi_path.relative_to(HERE).with_suffix(".test") + ), f"The file tree has not been reproduced as it should, for the file {midi_path} tokenized {json_path}" + + # Just make sure the non-overwrite mode doesn't crash + tokenizer.tokenize_midi_dataset(midi_paths, tmp_path, overwrite_mode=False) diff --git a/tests/test_multitrack.py b/tests/test_multitrack.py deleted file mode 100644 index 3549af7d..00000000 --- a/tests/test_multitrack.py +++ /dev/null @@ -1,163 +0,0 @@ -#!/usr/bin/python3 python - -"""Multitrack test file -""" - -from copy import deepcopy -from pathlib import Path -from time import time -from typing import Union - -from miditoolkit import MidiFile, Pedal -from tqdm import tqdm - -import miditok - -from .tests_utils import ( - ALL_TOKENIZATIONS, - adapt_tempo_changes_times, - remove_equal_successive_tempos, - 
tokenize_check_equals, -) - -BEAT_RES_TEST = {(0, 16): 8} -TOKENIZER_PARAMS = { - "beat_res": BEAT_RES_TEST, - "use_chords": True, - "use_rests": True, # tempo decode fails when False for MIDILike because beat_res range is too short - "use_tempos": True, - "use_time_signatures": True, - "use_sustain_pedals": True, - "use_pitch_bends": True, - "use_programs": True, - "chord_maps": miditok.constants.CHORD_MAPS, - "chord_tokens_with_root_note": True, # Tokens will look as "Chord_C:maj" - "chord_unknown": (3, 6), - "beat_res_rest": {(0, 2): 4, (2, 12): 2}, - "nb_tempos": 32, - "tempo_range": (40, 250), - "log_tempos": False, - "sustain_pedal_duration": False, - "one_token_stream_for_programs": True, - "program_changes": False, -} - -# Define kwargs sets -# The first set is empty, using the default params -params_kwargs_sets = {tok: [{}] for tok in ALL_TOKENIZATIONS} -programs_tokenizations = ["TSD", "REMI", "MIDILike", "Structured", "CPWord", "Octuple"] -for tok in programs_tokenizations: - params_kwargs_sets[tok].append( - {"one_token_stream_for_programs": False}, - ) -for tok in ["TSD", "REMI", "MIDILike"]: - params_kwargs_sets[tok].append( - {"program_changes": True}, - ) -# Disable tempos for Octuple with one_token_stream_for_programs, as tempos are carried by note tokens, and -# time signatures for the same reasons (as time could be shifted by on or several bars) -params_kwargs_sets["Octuple"][1]["use_tempos"] = False -params_kwargs_sets["Octuple"][0]["use_time_signatures"] = False -params_kwargs_sets["Octuple"][1]["use_time_signatures"] = False -# Increase the TimeShift voc for Structured as it doesn't support successive TimeShifts -for kwargs_set in params_kwargs_sets["Structured"]: - kwargs_set["beat_res"] = {(0, 512): 8} - - -def test_multitrack_midi_to_tokens_to_midi( - data_path: Union[str, Path] = "./tests/Multitrack_MIDIs", - saving_erroneous_midis: bool = False, -): - r"""Reads a few MIDI files, convert them into token sequences, convert them back to MIDI files. 
- The converted back MIDI files should identical to original one, expect with note starting and ending - times quantized, and maybe a some duplicated notes removed - """ - files = list(Path(data_path).glob("**/*.mid")) - at_least_one_error = False - t0 = time() - - for fi, file_path in enumerate(tqdm(files, desc="Testing multitrack")): - # Reads the MIDI - midi = MidiFile(Path(file_path)) - if midi.ticks_per_beat % max(BEAT_RES_TEST.values()) != 0: - continue - # add pedal messages - for ti in range(max(3, len(midi.instruments))): - midi.instruments[ti].pedals = [ - Pedal(start, start + 200) for start in [100, 600, 1800, 2200] - ] - - for tokenization in ALL_TOKENIZATIONS: - for pi, params_kwargs in enumerate(params_kwargs_sets[tokenization]): - idx = f"{fi}_{pi}" - params = deepcopy(TOKENIZER_PARAMS) - params.update(params_kwargs) - tokenizer_config = miditok.TokenizerConfig(**params) - tokenizer: miditok.MIDITokenizer = getattr(miditok, tokenization)( - tokenizer_config=tokenizer_config - ) - - # Process the MIDI - # midi notes / tempos / time signature quantized with the line above - midi_to_compare = deepcopy(midi) - for track in midi_to_compare.instruments: - if track.is_drum: - track.program = ( - 0 # need to be done before sorting tracks per program - ) - - # Sort and merge tracks if needed - # MIDI produced with one_token_stream contains tracks with different orders - # This step is also performed in preprocess_midi, but we need to call it here for the assertions below - tokenizer.preprocess_midi(midi_to_compare) - # For Octuple, as tempo is only carried at notes times, we need to adapt their times for comparison - # Same for CPWord which carries tempo with Position (for notes) - if tokenization in ["Octuple", "CPWord"]: - adapt_tempo_changes_times( - midi_to_compare.instruments, midi_to_compare.tempo_changes - ) - # When the tokenizer only decoded tempo changes different from the last tempo val - if tokenization in ["CPWord"]: - remove_equal_successive_tempos(midi_to_compare.tempo_changes) - - # MIDI -> Tokens -> MIDI - decoded_midi, has_errors = tokenize_check_equals( - midi_to_compare, tokenizer, idx, file_path.stem - ) - - if has_errors: - at_least_one_error = True - if saving_erroneous_midis: - decoded_midi.dump( - Path( - "tests", - "test_results", - f"{file_path.stem}_{tokenization}.mid", - ) - ) - midi_to_compare.dump( - Path( - "tests", - "test_results", - f"{file_path.stem}_{tokenization}_original.mid", - ) - ) - - ttotal = time() - t0 - print(f"Took {ttotal:.2f} seconds") - assert not at_least_one_error - - -if __name__ == "__main__": - import argparse - - parser = argparse.ArgumentParser(description="MIDI Encoding test") - parser.add_argument( - "--data", - type=str, - default="tests/Multitrack_MIDIs", - help="directory of MIDI files to use for test", - ) - args = parser.parse_args() - - test_multitrack_midi_to_tokens_to_midi(args.data) diff --git a/tests/test_one_track.py b/tests/test_one_track.py deleted file mode 100644 index 195d7278..00000000 --- a/tests/test_one_track.py +++ /dev/null @@ -1,167 +0,0 @@ -#!/usr/bin/python3 python - -"""One track test file -""" - -from copy import deepcopy -from pathlib import Path, PurePath -from time import time -from typing import Union - -from miditoolkit import MidiFile -from tqdm import tqdm - -import miditok -from miditok.constants import CHORD_MAPS - -from .tests_utils import ( - ALL_TOKENIZATIONS, - TIME_SIGNATURE_RANGE_TESTS, - adapt_tempo_changes_times, - adjust_pedal_durations, - remove_equal_successive_tempos, - 
tokenize_check_equals, -) - -BEAT_RES_TEST = {(0, 16): 8} -TOKENIZER_PARAMS = { - "beat_res": BEAT_RES_TEST, - "use_chords": False, # set false to speed up tests as it takes some time on maestro MIDIs - "use_rests": True, - "use_tempos": True, - "use_time_signatures": True, - "use_sustain_pedals": True, - "use_pitch_bends": True, - "use_programs": False, - "use_pitch_intervals": True, - "beat_res_rest": {(0, 2): 4, (2, 12): 2}, - "nb_tempos": 32, - "tempo_range": (40, 250), - "log_tempos": True, - "time_signature_range": TIME_SIGNATURE_RANGE_TESTS, - "chord_maps": CHORD_MAPS, - "chord_tokens_with_root_note": True, # Tokens will look as "Chord_C:maj" - "chord_unknown": False, - "delete_equal_successive_time_sig_changes": True, - "delete_equal_successive_tempo_changes": True, - "sustain_pedal_duration": True, - "one_token_stream_for_programs": False, - "program_changes": True, -} - - -def test_one_track_midi_to_tokens_to_midi( - data_path: Union[str, Path, PurePath] = "./tests/One_track_MIDIs", - saving_erroneous_midis: bool = True, -): - r"""Reads a few MIDI files, convert them into token sequences, convert them back to MIDI files. - The converted back MIDI files should identical to original one, expect with note starting and ending - times quantized, and maybe a some duplicated notes removed - - :param data_path: root path to the data to test - :param saving_erroneous_midis: will save MIDIs converted back with errors, to be used to debug - """ - files = list(Path(data_path).glob("**/*.mid")) - at_least_one_error = False - t0 = time() - - for i, file_path in enumerate(tqdm(files, desc="Testing One Track")): - # Reads the midi - midi = MidiFile(file_path) - # midi.instruments = [midi.instruments[0]] - # Will store the tracks tokenized / detokenized, to be saved in case of errors - for ti, track in enumerate(midi.instruments): - track.name = f"original {ti} not quantized" - tracks_with_errors = [] - - for tokenization in ALL_TOKENIZATIONS: - params = deepcopy(TOKENIZER_PARAMS) - # Special beat res for test, up to 64 beats so the duration and time-shift values are - # long enough for Structured, and with a single beat resolution - if tokenization == "Structured": - params["beat_res"] = {(0, 64): 8} - elif tokenization == "Octuple": - params["max_bar_embedding"] = 300 - params["use_time_signatures"] = False # because of time shifted - elif tokenization == "CPWord": - # Rests and time sig can mess up with CPWord, when a Rest that is crossing new bar is followed - # by a new TimeSig change, as TimeSig are carried with Bar tokens (and there is None is this case) - if params["use_time_signatures"] and params["use_rests"]: - params["use_rests"] = False - - tokenizer_config = miditok.TokenizerConfig(**params) - tokenizer: miditok.MIDITokenizer = getattr(miditok, tokenization)( - tokenizer_config=tokenizer_config - ) - - # Process the MIDI - # midi notes / tempos / time signature quantized with the line above - midi_to_compare = deepcopy(midi) - for track in midi_to_compare.instruments: - if track.is_drum: - track.program = ( - 0 # need to be done before sorting tracks per program - ) - - # This step is also performed in preprocess_midi, but we need to call it here for the assertions below - tokenizer.preprocess_midi(midi_to_compare) - # For Octuple, as tempo is only carried at notes times, we need to adapt their times for comparison - # Same for CPWord which carries tempo with Position (for notes) - if tokenization in ["Octuple", "CPWord"]: - # We use the first track only, as it is the one for which 
tempos are decoded - adapt_tempo_changes_times( - [midi_to_compare.instruments[0]], midi_to_compare.tempo_changes - ) - # When the tokenizer only decoded tempo changes different from the last tempo val - if tokenization in ["CPWord"]: - remove_equal_successive_tempos(midi_to_compare.tempo_changes) - # Adjust pedal ends to the maximum possible value - if tokenizer.config.use_sustain_pedals: - for track in midi_to_compare.instruments: - adjust_pedal_durations(track.pedals, tokenizer, midi.ticks_per_beat) - # Store preprocessed track - if len(tracks_with_errors) == 0: - tracks_with_errors += midi_to_compare.instruments - for ti, track in enumerate(midi_to_compare.instruments): - track.name = f"original {ti} quantized" - - # printing the tokenizer shouldn't fail - _ = str(tokenizer) - - # MIDI -> Tokens -> MIDI - decoded_midi, has_errors = tokenize_check_equals( - midi_to_compare, tokenizer, i, file_path.stem - ) - - # Add track to error list - if has_errors: - for ti, track in enumerate(decoded_midi.instruments): - track.name = f"{ti} encoded with {tokenization}" - tracks_with_errors += decoded_midi.instruments - - # > 1 as the first one is the preprocessed - if len(tracks_with_errors) > len(midi.instruments): - at_least_one_error = True - if saving_erroneous_midis: - midi.tempo_changes = midi_to_compare.tempo_changes - midi.time_signature_changes = midi_to_compare.time_signature_changes - midi.instruments += tracks_with_errors - midi.dump(PurePath("tests", "test_results", file_path.name)) - - ttotal = time() - t0 - print(f"Took {ttotal:.2f} seconds") - assert not at_least_one_error - - -if __name__ == "__main__": - import argparse - - parser = argparse.ArgumentParser(description="MIDI Encoding test") - parser.add_argument( - "--data", - type=str, - default="tests/One_track_MIDIs", - help="directory of MIDI files to use for test", - ) - args = parser.parse_args() - test_one_track_midi_to_tokens_to_midi(args.data) diff --git a/tests/test_pytorch_data_loading.py b/tests/test_pytorch_data_loading.py index 4197b867..ca371559 100644 --- a/tests/test_pytorch_data_loading.py +++ b/tests/test_pytorch_data_loading.py @@ -5,13 +5,15 @@ """ from pathlib import Path -from typing import Sequence +from typing import Sequence, Union from miditoolkit import MidiFile from torch import randint import miditok +from .utils import MIDI_PATHS_MULTITRACK, MIDI_PATHS_ONE_TRACK + def test_split_seq(): min_seq_len = 50 @@ -26,16 +28,20 @@ def test_split_seq(): ], "Sequence split failed" -def test_dataset_ram(): - multitrack_midis_paths = list(Path("tests", "Multitrack_MIDIs").glob("**/*.mid"))[ - :3 - ] - one_track_midis_paths = list(Path("tests", "One_track_MIDIs").glob("**/*.mid"))[:3] - tokens_os_dir = Path("tests", "multitrack_tokens_os") +def test_dataset_ram( + tmp_path: Path, + midi_paths_one_track: Sequence[Union[str, Path]] = None, + midi_paths_multitrack: Sequence[Union[str, Path]] = None, +): + if midi_paths_one_track is None: + midi_paths_one_track = MIDI_PATHS_ONE_TRACK[:3] + if midi_paths_multitrack is None: + midi_paths_multitrack = MIDI_PATHS_MULTITRACK[:3] + tokens_os_dir = tmp_path / "multitrack_tokens_os" dummy_labels = { label: i for i, label in enumerate( - set(path.name.split("_")[0] for path in one_track_midis_paths) + set(path.name.split("_")[0] for path in midi_paths_one_track) ) } @@ -52,7 +58,7 @@ def get_labels_multitrack_one_stream(tokens: Sequence, _: Path) -> int: config = miditok.TokenizerConfig(use_programs=True) tokenizer_os = miditok.TSD(config) dataset_os = 
miditok.pytorch_data.DatasetTok( - one_track_midis_paths, + midi_paths_one_track, 50, 100, tokenizer_os, @@ -72,7 +78,7 @@ def get_labels_multitrack_one_stream(tokens: Sequence, _: Path) -> int: # MIDI + Multiple token streams + labels tokenizer_ms = miditok.TSD(miditok.TokenizerConfig()) dataset_ms = miditok.pytorch_data.DatasetTok( - multitrack_midis_paths, + midi_paths_multitrack, 50, 100, tokenizer_ms, @@ -85,7 +91,7 @@ def get_labels_multitrack_one_stream(tokens: Sequence, _: Path) -> int: # JSON + one token stream if not tokens_os_dir.is_dir(): tokenizer_os.tokenize_midi_dataset( - multitrack_midis_paths, + midi_paths_multitrack, tokens_os_dir, ) _ = miditok.pytorch_data.DatasetTok( @@ -95,19 +101,16 @@ def get_labels_multitrack_one_stream(tokens: Sequence, _: Path) -> int: func_to_get_labels=get_labels_multitrack_one_stream, ) - assert True - -def test_dataset_io(): - multitrack_midis_paths = list(Path("tests", "Multitrack_MIDIs").glob("**/*.mid"))[ - :3 - ] - tokens_os_dir = Path("tests", "multitrack_tokens_os") +def test_dataset_io(tmp_path: Path, midi_path: Sequence[Union[str, Path]] = None): + if midi_path is None: + midi_path = MIDI_PATHS_MULTITRACK[:3] + tokens_os_dir = tmp_path / "multitrack_tokens_os" if not tokens_os_dir.is_dir(): config = miditok.TokenizerConfig(use_programs=True) tokenizer = miditok.TSD(config) - tokenizer.tokenize_midi_dataset(multitrack_midis_paths, tokens_os_dir) + tokenizer.tokenize_midi_dataset(midi_path, tokens_os_dir) dataset = miditok.pytorch_data.DatasetJsonIO( list(tokens_os_dir.glob("**/*.json")), @@ -120,22 +123,22 @@ def test_dataset_io(): for _ in dataset: pass - assert True - -def test_split_dataset_to_subsequences(): - multitrack_midis_paths = list(Path("tests", "Multitrack_MIDIs").glob("**/*.mid"))[ - :3 - ] - tokens_os_dir = Path("tests", "multitrack_tokens_os") - tokens_split_dir = Path("tests", "multitrack_tokens_os_split") - tokens_split_dir_ms = Path("tests", "multitrack_tokens_ms_split") +def test_split_dataset_to_subsequences( + tmp_path: Path, + midi_path: Sequence[Union[str, Path]] = None, +): + if midi_path is None: + midi_path = MIDI_PATHS_MULTITRACK[:3] + tokens_os_dir = tmp_path / "multitrack_tokens_os" + tokens_split_dir = tmp_path / "multitrack_tokens_os_split" + tokens_split_dir_ms = tmp_path / "multitrack_tokens_ms_split" # One token stream if not tokens_os_dir.is_dir(): config = miditok.TokenizerConfig(use_programs=True) tokenizer = miditok.TSD(config) - tokenizer.tokenize_midi_dataset(multitrack_midis_paths, tokens_os_dir) + tokenizer.tokenize_midi_dataset(midi_path, tokens_os_dir) miditok.pytorch_data.split_dataset_to_subsequences( list(tokens_os_dir.glob("**/*.json")), tokens_split_dir, @@ -148,7 +151,7 @@ def test_split_dataset_to_subsequences(): if not tokens_split_dir_ms.is_dir(): config = miditok.TokenizerConfig(use_programs=False) tokenizer = miditok.TSD(config) - tokenizer.tokenize_midi_dataset(multitrack_midis_paths, tokens_split_dir_ms) + tokenizer.tokenize_midi_dataset(midi_path, tokens_split_dir_ms) miditok.pytorch_data.split_dataset_to_subsequences( list(tokens_split_dir_ms.glob("**/*.json")), tokens_split_dir, @@ -157,8 +160,6 @@ def test_split_dataset_to_subsequences(): False, ) - assert True - def test_collator(): collator = miditok.pytorch_data.DataCollator( @@ -197,13 +198,3 @@ def test_collator(): max(seq_lengths) + 1, 5, ] - - assert True - - -if __name__ == "__main__": - test_split_seq() - test_dataset_ram() - test_dataset_io() - test_split_dataset_to_subsequences() - test_collator() diff --git 
a/tests/test_results/.gitignore b/tests/test_results/.gitignore deleted file mode 100644 index 5e7d2734..00000000 --- a/tests/test_results/.gitignore +++ /dev/null @@ -1,4 +0,0 @@ -# Ignore everything in this directory -* -# Except this file -!.gitignore diff --git a/tests/test_saving_loading_config.py b/tests/test_saving_loading_config.py index 09d8f88a..cbfd44ef 100644 --- a/tests/test_saving_loading_config.py +++ b/tests/test_saving_loading_config.py @@ -5,9 +5,13 @@ """ +from pathlib import Path + +import pytest + import miditok -from .tests_utils import ALL_TOKENIZATIONS +from .utils import ALL_TOKENIZATIONS ADDITIONAL_TOKENS_TEST = { "use_chords": False, # set False to speed up tests as it takes some time on maestro MIDIs @@ -21,41 +25,35 @@ } -def test_saving_loading_tokenizer_config(): - for tokenization in ALL_TOKENIZATIONS: - config1 = miditok.TokenizerConfig() - config1.save_to_json(f"./tests/configs/tok_conf_{tokenization}.json") +@pytest.mark.parametrize("tokenization", ALL_TOKENIZATIONS) +def test_saving_loading_tokenizer_config(tokenization: str, tmp_path: Path): + config1 = miditok.TokenizerConfig() + config1.save_to_json(tmp_path / f"tok_conf_{tokenization}.json") - config2 = miditok.TokenizerConfig.load_from_json( - f"./tests/configs/tok_conf_{tokenization}.json" - ) + config2 = miditok.TokenizerConfig.load_from_json( + tmp_path / f"tok_conf_{tokenization}.json" + ) - assert config1 == config2 - config1.pitch_range = (0, 777) - assert config1 != config2 + assert config1 == config2 + config1.pitch_range = (0, 777) + assert config1 != config2 -def test_saving_loading_tokenizer(): +@pytest.mark.parametrize("tokenization", ALL_TOKENIZATIONS) +def test_saving_loading_tokenizer(tokenization: str, tmp_path: Path): r"""Tests to create tokenizers, save their config, and load it back. If all went well the tokenizer should be identical. 
""" - - for tokenization in ALL_TOKENIZATIONS: - tokenizer_config = miditok.TokenizerConfig(**ADDITIONAL_TOKENS_TEST) - tokenizer: miditok.MIDITokenizer = getattr(miditok, tokenization)( - tokenizer_config=tokenizer_config - ) - tokenizer.save_params(f"./tests/configs/{tokenization}.txt") - - tokenizer2: miditok.MIDITokenizer = getattr(miditok, tokenization)( - params=f"./tests/configs/{tokenization}.txt" - ) - assert tokenizer == tokenizer2 - if tokenization == "Octuple": - tokenizer.vocab[0]["PAD_None"] = 8 - assert tokenizer != tokenizer2 - - -if __name__ == "__main__": - test_saving_loading_tokenizer_config() - test_saving_loading_tokenizer() + tokenizer_config = miditok.TokenizerConfig(**ADDITIONAL_TOKENS_TEST) + tokenizer: miditok.MIDITokenizer = getattr(miditok, tokenization)( + tokenizer_config=tokenizer_config + ) + tokenizer.save_params(tmp_path / f"{tokenization}.txt") + + tokenizer2: miditok.MIDITokenizer = getattr(miditok, tokenization)( + params=tmp_path / f"{tokenization}.txt" + ) + assert tokenizer == tokenizer2 + if tokenization == "Octuple": + tokenizer.vocab[0]["PAD_None"] = 8 + assert tokenizer != tokenizer2 diff --git a/tests/test_tokenize_multitrack.py b/tests/test_tokenize_multitrack.py new file mode 100644 index 00000000..7d5dd6b3 --- /dev/null +++ b/tests/test_tokenize_multitrack.py @@ -0,0 +1,117 @@ +#!/usr/bin/python3 python + +"""Multitrack test file +""" + +from copy import deepcopy +from pathlib import Path +from typing import Any, Dict, Sequence, Tuple, Union + +import pytest +from miditoolkit import MidiFile, Pedal + +import miditok + +from .utils import ( + ALL_TOKENIZATIONS, + MIDI_PATHS_MULTITRACK, + TEST_LOG_DIR, + TOKENIZER_CONFIG_KWARGS, + adjust_tok_params_for_tests, + prepare_midi_for_tests, + tokenize_and_check_equals, +) + +default_params = deepcopy(TOKENIZER_CONFIG_KWARGS) +default_params.update( + { + "use_chords": True, + "use_rests": True, # tempo decode fails when False for MIDILike because beat_res range is too short + "use_tempos": True, + "use_time_signatures": True, + "use_sustain_pedals": True, + "use_pitch_bends": True, + "use_programs": True, + "sustain_pedal_duration": False, + "one_token_stream_for_programs": True, + "program_changes": False, + } +) +TOK_PARAMS_MULTITRACK = [] +tokenizations_non_one_stream = [ + "TSD", + "REMI", + "MIDILike", + "Structured", + "CPWord", + "Octuple", +] +tokenizations_program_change = ["TSD", "REMI", "MIDILike"] +for tokenization_ in ALL_TOKENIZATIONS: + params_ = deepcopy(default_params) + adjust_tok_params_for_tests(tokenization_, params_) + TOK_PARAMS_MULTITRACK.append((tokenization_, params_)) + + if tokenization_ in tokenizations_non_one_stream: + params_tmp = deepcopy(params_) + params_tmp["one_token_stream_for_programs"] = False + # Disable tempos for Octuple with one_token_stream_for_programs, as tempos are carried by note tokens + if tokenization_ == "Octuple": + params_tmp["use_tempos"] = False + TOK_PARAMS_MULTITRACK.append((tokenization_, params_tmp)) + if tokenization_ in tokenizations_program_change: + params_tmp = deepcopy(params_) + params_tmp["program_changes"] = True + TOK_PARAMS_MULTITRACK.append((tokenization_, params_tmp)) + + +@pytest.mark.parametrize("midi_path", MIDI_PATHS_MULTITRACK) +def test_multitrack_midi_to_tokens_to_midi( + midi_path: Union[str, Path], + tok_params_sets: Sequence[Tuple[str, Dict[str, Any]]] = None, + saving_erroneous_midis: bool = False, +): + r"""Reads a MIDI file, converts it into tokens, convert it back to a MIDI object. 
+    The decoded MIDI should be identical to the original one after downsampling, and potentially note deduplication.
+    We only parametrize over MIDI files, as it would otherwise require loading each of them multiple times.
+    # TODO test parametrize tokenization / params_set
+
+    :param midi_path: path to the MIDI file to test.
+    :param tok_params_sets: sequence of tokenizations and their parameters to run.
+    :param saving_erroneous_midis: will save MIDIs decoded with errors, to be used to debug.
+    """
+    if tok_params_sets is None:
+        tok_params_sets = TOK_PARAMS_MULTITRACK
+    at_least_one_error = False
+
+    # Reads the MIDI and adds pedal messages
+    midi = MidiFile(Path(midi_path))
+    for ti in range(max(3, len(midi.instruments))):
+        midi.instruments[ti].pedals = [
+            Pedal(start, start + 200) for start in [100, 600, 1800, 2200]
+        ]
+
+    for tok_i, (tokenization, params) in enumerate(tok_params_sets):
+        tokenizer: miditok.MIDITokenizer = getattr(miditok, tokenization)(
+            tokenizer_config=miditok.TokenizerConfig(**params)
+        )
+
+        # Process the MIDI
+        # midi notes / tempos / time signature quantized with the line above
+        midi_to_compare = prepare_midi_for_tests(midi, tokenizer=tokenizer)
+
+        # MIDI -> Tokens -> MIDI
+        decoded_midi, has_errors = tokenize_and_check_equals(
+            midi_to_compare, tokenizer, tok_i, midi_path.stem
+        )
+
+        if has_errors:
+            TEST_LOG_DIR.mkdir(exist_ok=True, parents=True)
+            at_least_one_error = True
+            if saving_erroneous_midis:
+                decoded_midi.dump(TEST_LOG_DIR / f"{midi_path.stem}_{tokenization}.mid")
+                midi_to_compare.dump(
+                    TEST_LOG_DIR / f"{midi_path.stem}_{tokenization}_original.mid"
+                )
+
+    assert not at_least_one_error
diff --git a/tests/test_tokenize_one_track.py b/tests/test_tokenize_one_track.py
new file mode 100644
index 00000000..193ab7f6
--- /dev/null
+++ b/tests/test_tokenize_one_track.py
@@ -0,0 +1,113 @@
+#!/usr/bin/python3 python
+
+"""One track test file
+"""
+
+from copy import deepcopy
+from pathlib import Path
+from typing import Any, Dict, Sequence, Tuple, Union
+
+import pytest
+from miditoolkit import MidiFile
+
+import miditok
+
+from .utils import (
+    ALL_TOKENIZATIONS,
+    MIDI_PATHS_ONE_TRACK,
+    TEST_LOG_DIR,
+    TOKENIZER_CONFIG_KWARGS,
+    adjust_tok_params_for_tests,
+    prepare_midi_for_tests,
+    tokenize_and_check_equals,
+)
+
+default_params = deepcopy(TOKENIZER_CONFIG_KWARGS)
+default_params.update(
+    {
+        "use_chords": False,  # set False to speed up tests, as it takes some time on Maestro MIDIs
+        "use_rests": True,
+        "use_tempos": True,
+        "use_time_signatures": True,
+        "use_sustain_pedals": True,
+        "use_pitch_bends": True,
+        "use_pitch_intervals": True,
+        "log_tempos": True,
+        "chord_unknown": False,
+        "delete_equal_successive_time_sig_changes": True,
+        "delete_equal_successive_tempo_changes": True,
+        "sustain_pedal_duration": True,
+    }
+)
+TOK_PARAMS_ONE_TRACK = []
+for tokenization_ in ALL_TOKENIZATIONS:
+    params_ = deepcopy(default_params)
+    adjust_tok_params_for_tests(tokenization_, params_)
+    TOK_PARAMS_ONE_TRACK.append((tokenization_, params_))
+
+
+@pytest.mark.parametrize("midi_path", MIDI_PATHS_ONE_TRACK)
+def test_one_track_midi_to_tokens_to_midi(
+    midi_path: Union[str, Path],
+    tok_params_sets: Sequence[Tuple[str, Dict[str, Any]]] = None,
+    saving_erroneous_midis: bool = True,
+):
+    r"""Reads a MIDI file, converts it into tokens, and converts it back to a MIDI object.
+    The decoded MIDI should be identical to the original one after downsampling, and potentially note deduplication.
+    We only parametrize over MIDI files, as it would otherwise require loading each of them multiple times.
+    # TODO test parametrize tokenization / params_set, if faster --> unique method for test tok (one+multi)
+
+    :param midi_path: path to the MIDI file to test.
+    :param tok_params_sets: sequence of tokenizations and their parameters to run.
+    :param saving_erroneous_midis: will save MIDIs decoded with errors, to be used to debug.
+    """
+    if tok_params_sets is None:
+        tok_params_sets = TOK_PARAMS_ONE_TRACK
+    at_least_one_error = False
+
+    # Reads the MIDI
+    midi = MidiFile(midi_path)
+    # Will store the tracks tokenized / detokenized, to be saved in case of errors
+    for ti, track in enumerate(midi.instruments):
+        track.name = f"original {ti} not quantized"
+    tracks_with_errors = []
+
+    for tok_i, (tokenization, params) in enumerate(tok_params_sets):
+        tokenizer: miditok.MIDITokenizer = getattr(miditok, tokenization)(
+            tokenizer_config=miditok.TokenizerConfig(**params)
+        )
+
+        # Process the MIDI
+        # preprocess_midi is also performed when tokenizing, but we need to call it here for the following adaptations
+        midi_to_compare = prepare_midi_for_tests(midi, tokenizer=tokenizer)
+        # Store preprocessed track
+        if len(tracks_with_errors) == 0:
+            tracks_with_errors += midi_to_compare.instruments
+            for ti, track in enumerate(midi_to_compare.instruments):
+                track.name = f"original {ti} quantized"
+
+        # printing the tokenizer shouldn't fail
+        _ = str(tokenizer)
+
+        # MIDI -> Tokens -> MIDI
+        decoded_midi, has_errors = tokenize_and_check_equals(
+            midi_to_compare, tokenizer, tok_i, midi_path.stem
+        )
+
+        # Add track to error list
+        if has_errors:
+            at_least_one_error = True
+            for ti, track in enumerate(decoded_midi.instruments):
+                track.name = f"{tok_i} encoded with {tokenization}"
+            tracks_with_errors += decoded_midi.instruments
+
+    # Strictly greater, as the first len(midi.instruments) entries are the preprocessed tracks
+    if len(tracks_with_errors) > len(midi.instruments):
+        if saving_erroneous_midis:
+            TEST_LOG_DIR.mkdir(exist_ok=True, parents=True)
+            midi.tempo_changes = midi_to_compare.tempo_changes
+            midi.time_signature_changes = midi_to_compare.time_signature_changes
+            midi.instruments += tracks_with_errors
+            midi.dump(TEST_LOG_DIR / midi_path.name)
+
+    assert not at_least_one_error
diff --git a/tests/test_utils.py b/tests/test_utils.py
index 9e2f5f0a..5b9c926b 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -6,8 +6,18 @@
 from copy import deepcopy
 from pathlib import Path
-
-from miditoolkit import MidiFile
+from typing import Union
+
+import pytest
+from miditoolkit import (
+    ControlChange,
+    KeySignature,
+    MidiFile,
+    Pedal,
+    PitchBend,
+    TempoChange,
+    TimeSignature,
+)
 
 from miditok import REMI
 from miditok.constants import CLASS_OF_INST
@@ -18,18 +28,113 @@
     nb_bar_pos,
 )
 
+from .utils import MIDI_PATHS_MULTITRACK, MIDI_PATHS_ONE_TRACK, check_midis_equals
+
+
+def test_containers_assertions():
+    tc1 = [TempoChange(120, 2), TempoChange(110, 0)]
+    tc2 = [TempoChange(120, 3), TempoChange(110, 0)]
+    tc3 = [TempoChange(120, 3), TempoChange(110, 0)]
+    assert tc1 != tc2
+    assert tc2 == tc3
+
+    ts1 = [TimeSignature(4, 4, 0), TimeSignature(6, 4, 10)]
+    ts2 = [TimeSignature(2, 4, 0), TimeSignature(6, 4, 10)]
+    ts3 = [TimeSignature(2, 4, 0), TimeSignature(6, 4, 10)]
+    assert ts1 != ts2
+    assert ts2 == ts3
+
+    sp1 = [Pedal(0, 2), TempoChange(10, 20)]
+    sp2 = [Pedal(0, 2), TempoChange(15, 20)]
+    sp3 = [Pedal(0, 2), TempoChange(15, 20)]
+    assert sp1 != sp2
+    assert sp2 == sp3
+
+    pb1 = [PitchBend(120, 2), PitchBend(110, 0)]
+    pb2 = [PitchBend(120, 
3), PitchBend(110, 0)] + pb3 = [PitchBend(120, 3), PitchBend(110, 0)] + assert pb1 != pb2 + assert pb2 == pb3 + + ks1 = [KeySignature("C#", 2), KeySignature("C#", 0)] + ks2 = [KeySignature("C#", 20), KeySignature("C#", 0)] + ks3 = [KeySignature("C#", 20), KeySignature("C#", 0)] + assert ks1 != ks2 + assert ks2 == ks3 + + cc1 = [ControlChange(120, 50, 2), ControlChange(110, 50, 0)] + cc2 = [ControlChange(120, 50, 2), ControlChange(110, 50, 10)] + cc3 = [ControlChange(120, 50, 2), ControlChange(110, 50, 10)] + assert cc1 != cc2 + assert cc2 == cc3 + + +@pytest.mark.parametrize("midi_path", MIDI_PATHS_ONE_TRACK) +def test_check_midi_equals(midi_path: Path): + midi = MidiFile(midi_path) + midi_copy = deepcopy(midi) + + # Check when midi is untouched + assert check_midis_equals(midi, midi_copy)[1] + + # Altering notes + i = 0 + while i < len(midi_copy.instruments): + if len(midi_copy.instruments[i].notes) > 0: + midi_copy.instruments[i].notes[-1].pitch += 5 + assert not check_midis_equals(midi, midi_copy)[1] + break + i += 1 + + # Altering track events + if len(midi_copy.instruments) > 0: + # Altering pedals + midi_copy = deepcopy(midi) + if len(midi_copy.instruments[0].pedals) == 0: + midi_copy.instruments[0].pedals.append(Pedal(0, 10)) + else: + midi_copy.instruments[0].pedals[-1].end += 10 + assert not check_midis_equals(midi, midi_copy)[1] -def test_merge_tracks(): - midi = MidiFile(Path("tests", "One_track_MIDIs", "Maestro_1.mid")) + # Altering pitch bends + midi_copy = deepcopy(midi) + if len(midi_copy.instruments[0].pitch_bends) == 0: + midi_copy.instruments[0].pitch_bends.append(PitchBend(50, 10)) + else: + midi_copy.instruments[0].pitch_bends[-1].end += 10 + assert not check_midis_equals(midi, midi_copy)[1] + + # Altering tempos + midi_copy = deepcopy(midi) + if len(midi_copy.tempo_changes) == 0: + midi_copy.tempo_changes.append(TempoChange(50, 10)) + else: + midi_copy.tempo_changes[-1].time += 10 + assert not check_midis_equals(midi, midi_copy)[1] + + # Altering time signatures + midi_copy = deepcopy(midi) + if len(midi_copy.time_signature_changes) == 0: + midi_copy.time_signature_changes.append(TimeSignature(4, 4, 10)) + else: + midi_copy.time_signature_changes[-1].time += 10 + assert not check_midis_equals(midi, midi_copy)[1] + + +def test_merge_tracks( + midi_path: Union[str, Path] = MIDI_PATHS_ONE_TRACK[0], +): + # Load MIDI and only keep the first track + midi = MidiFile(midi_path) + midi.instruments = [midi.instruments[0]] + + # Duplicate the track and merge it original_track = deepcopy(midi.instruments[0]) midi.instruments.append(deepcopy(midi.instruments[0])) - merge_tracks(midi.instruments) - assert len(midi.instruments[0].notes) == 2 * len(original_track.notes) # Test merge with effects - midi.instruments.append(deepcopy(midi.instruments[0])) merge_tracks(midi, effects=True) - assert len(midi.instruments[0].notes) == 4 * len(original_track.notes) + assert len(midi.instruments[0].notes) == 2 * len(original_track.notes) assert len(midi.instruments[0].pedals) == 2 * len(original_track.pedals) assert len(midi.instruments[0].control_changes) == 2 * len( original_track.control_changes @@ -37,44 +142,37 @@ def test_merge_tracks(): assert len(midi.instruments[0].pitch_bends) == 2 * len(original_track.pitch_bends) -def test_merge_same_program_tracks_and_by_class(): - multitrack_midi_paths = list(Path("tests", "Multitrack_MIDIs").glob("**/*.mid")) - for midi_path in multitrack_midi_paths: - midi = MidiFile(midi_path) - for track in midi.instruments: - if track.is_drum: - track.program 
= -1 - - # Test merge same program - midi_copy = deepcopy(midi) - programs = [track.program for track in midi_copy.instruments] - unique_programs = list(set(programs)) - merge_same_program_tracks(midi_copy.instruments) - new_programs = [track.program for track in midi_copy.instruments] - unique_programs.sort() - new_programs.sort() - assert new_programs == unique_programs - - # Test merge same class - midi_copy = deepcopy(midi) - merge_tracks_per_class( - midi_copy, - CLASS_OF_INST, - valid_programs=list(range(-1, 128)), - filter_pitches=True, - ) +@pytest.mark.parametrize("midi_path", MIDI_PATHS_MULTITRACK) +def test_merge_same_program_tracks_and_by_class(midi_path: Union[str, Path]): + midi = MidiFile(midi_path) + for track in midi.instruments: + if track.is_drum: + track.program = -1 + + # Test merge same program + midi_copy = deepcopy(midi) + programs = [track.program for track in midi_copy.instruments] + unique_programs = list(set(programs)) + merge_same_program_tracks(midi_copy.instruments) + new_programs = [track.program for track in midi_copy.instruments] + unique_programs.sort() + new_programs.sort() + assert new_programs == unique_programs + + # Test merge same class + midi_copy = deepcopy(midi) + merge_tracks_per_class( + midi_copy, + CLASS_OF_INST, + valid_programs=list(range(-1, 128)), + filter_pitches=True, + ) def test_nb_pos(): tokenizer = REMI() _ = nb_bar_pos( - tokenizer(Path("tests", "One_track_MIDIs", "Maestro_1.mid"))[0].ids, + tokenizer(MIDI_PATHS_ONE_TRACK[0])[0].ids, tokenizer["Bar_None"], tokenizer.token_ids_of_type("Position"), ) - - -if __name__ == "__main__": - test_merge_tracks() - test_merge_same_program_tracks_and_by_class() - test_nb_pos() diff --git a/tests/tests_utils.py b/tests/tests_utils.py deleted file mode 100644 index a7f5ce7c..00000000 --- a/tests/tests_utils.py +++ /dev/null @@ -1,300 +0,0 @@ -""" Test validation methods - -""" - -from typing import List, Tuple, Union - -import numpy as np -from miditoolkit import ( - Instrument, - Marker, - MidiFile, - Note, - Pedal, - PitchBend, - TempoChange, - TimeSignature, -) - -import miditok -from miditok.constants import TIME_SIGNATURE_RANGE - -ALL_TOKENIZATIONS = [ - "MIDILike", - "TSD", - "Structured", - "REMI", - "CPWord", - "Octuple", - "MuMIDI", - "MMM", -] -TIME_SIGNATURE_RANGE_TESTS = TIME_SIGNATURE_RANGE -TIME_SIGNATURE_RANGE_TESTS.update({2: [2, 3, 4]}) -TIME_SIGNATURE_RANGE_TESTS[4].append(8) - - -def midis_equals( - midi1: MidiFile, midi2: MidiFile -) -> List[Tuple[int, str, List[Tuple[str, Union[Note, int], int]]]]: - errors = [] - for track1, track2 in zip(midi1.instruments, midi2.instruments): - track_errors = track_equals(track1, track2) - if len(track_errors) > 0: - errors.append((track1.program, track1.name, track_errors)) - return errors - - -def track_equals( - track1: Instrument, track2: Instrument -) -> List[Tuple[str, Union[Note, int], int]]: - if len(track1.notes) != len(track2.notes): - return [("len", len(track2.notes), len(track1.notes))] - errors = [] - for note1, note2 in zip(track1.notes, track2.notes): - err = notes_equals(note1, note2) - if err != "": - errors.append((err, note2, getattr(note1, err))) - return errors - - -def notes_equals(note1: Note, note2: Note) -> str: - if note1.start != note2.start: - return "start" - elif note1.end != note2.end: - return "end" - elif note1.pitch != note2.pitch: - return "pitch" - elif note1.velocity != note2.velocity: - return "velocity" - return "" - - -def tempo_changes_equals( - tempo_changes1: List[TempoChange], tempo_changes2: 
List[TempoChange] -) -> List[Tuple[str, Union[TempoChange, int], float]]: - if len(tempo_changes1) != len(tempo_changes2): - return [("len", len(tempo_changes2), len(tempo_changes1))] - errors = [] - for tempo_change1, tempo_change2 in zip(tempo_changes1, tempo_changes2): - if tempo_change1.time != tempo_change2.time: - errors.append(("time", tempo_change1, tempo_change2.time)) - if tempo_change1.tempo != tempo_change2.tempo: - errors.append(("tempo", tempo_change1, tempo_change2.tempo)) - return errors - - -def time_signature_changes_equals( - time_sig_changes1: List[TimeSignature], time_sig_changes2: List[TimeSignature] -) -> List[Tuple[str, Union[TimeSignature, int], float]]: - if len(time_sig_changes1) != len(time_sig_changes2): - return [("len", len(time_sig_changes1), len(time_sig_changes2))] - errors = [] - for time_sig_change1, time_sig_change2 in zip(time_sig_changes1, time_sig_changes2): - if time_sig_change1.time != time_sig_change2.time: - errors.append(("time", time_sig_change1, time_sig_change2.time)) - if time_sig_change1.numerator != time_sig_change2.numerator: - errors.append(("numerator", time_sig_change1, time_sig_change2.numerator)) - if time_sig_change1.denominator != time_sig_change2.denominator: - errors.append( - ("denominator", time_sig_change1, time_sig_change2.denominator) - ) - return errors - - -def pedal_equals( - midi1: MidiFile, midi2: MidiFile -) -> List[List[Tuple[str, Union[Pedal, int], float]]]: - errors = [] - for inst1, inst2 in zip(midi1.instruments, midi2.instruments): - if len(inst1.pedals) != len(inst2.pedals): - errors.append([("len", len(inst1.pedals), len(inst2.pedals))]) - continue - errors.append([]) - for pedal1, pedal2 in zip(inst1.pedals, inst2.pedals): - if pedal1.start != pedal2.start: - errors[-1].append(("start", pedal1, pedal2.start)) - elif pedal1.end != pedal2.end: - errors[-1].append(("end", pedal1, pedal2.end)) - return errors - - -def pitch_bend_equals( - midi1: MidiFile, midi2: MidiFile -) -> List[List[Tuple[str, Union[PitchBend, int], float]]]: - errors = [] - for inst1, inst2 in zip(midi1.instruments, midi2.instruments): - if len(inst1.pitch_bends) != len(inst2.pitch_bends): - errors.append([("len", len(inst1.pitch_bends), len(inst2.pitch_bends))]) - continue - errors.append([]) - for pitch_bend1, pitch_bend2 in zip(inst1.pitch_bends, inst2.pitch_bends): - if pitch_bend1.time != pitch_bend2.time: - errors[-1].append(("time", pitch_bend1, pitch_bend2.time)) - elif pitch_bend1.pitch != pitch_bend2.pitch: - errors[-1].append(("pitch", pitch_bend1, pitch_bend2.pitch)) - return errors - - -def tokenize_check_equals( - midi: MidiFile, - tokenizer: miditok.MIDITokenizer, - file_idx: Union[int, str], - file_name: str, -) -> Tuple[MidiFile, bool]: - has_errors = False - tokenization = type(tokenizer).__name__ - midi.instruments.sort(key=lambda x: (x.program, x.is_drum)) - # merging is performed in preprocess only in one_token_stream mode - # but in multi token stream, decoding will actually keep one track per program - if tokenizer.config.use_programs: - miditok.utils.merge_same_program_tracks(midi.instruments) - - tokens = tokenizer(midi) - midi_decoded = tokenizer( - tokens, - miditok.utils.get_midi_programs(midi), - time_division=midi.ticks_per_beat, - ) - midi_decoded.instruments.sort(key=lambda x: (x.program, x.is_drum)) - if tokenization == "MIDILike": - for track in midi_decoded.instruments: - track.notes.sort(key=lambda x: (x.start, x.pitch, x.end)) - - # Checks types and values conformity following the rules - err_tse = 
tokenizer.tokens_errors(tokens) - if isinstance(err_tse, list): - err_tse = sum(err_tse) - if err_tse != 0.0: - print( - f"Validation of tokens types / values successions failed with {tokenization}: {err_tse:.2f}" - ) - - # Checks notes - errors = midis_equals(midi, midi_decoded) - if len(errors) > 0: - has_errors = True - for e, track_err in enumerate(errors): - if track_err[-1][0][0] != "len": - for err, note, exp in track_err[-1]: - midi_decoded.markers.append( - Marker( - f"{e}: with note {err} (pitch {note.pitch})", - note.start, - ) - ) - print( - f"MIDI {file_idx} - {file_name} / {tokenization} failed to encode/decode NOTES" - f"({sum(len(t[2]) for t in errors)} errors)" - ) - - # Checks tempos - if ( - tokenizer.config.use_tempos and tokenization != "MuMIDI" - ): # MuMIDI doesn't decode tempos - tempo_errors = tempo_changes_equals( - midi.tempo_changes, midi_decoded.tempo_changes - ) - if len(tempo_errors) > 0: - has_errors = True - print( - f"MIDI {file_idx} - {file_name} / {tokenization} failed to encode/decode TEMPO changes" - f"({len(tempo_errors)} errors)" - ) - - # Checks time signatures - if tokenizer.config.use_time_signatures: - time_sig_errors = time_signature_changes_equals( - midi.time_signature_changes, - midi_decoded.time_signature_changes, - ) - if len(time_sig_errors) > 0: - has_errors = True - print( - f"MIDI {file_idx} - {file_name} / {tokenization} failed to encode/decode TIME SIGNATURE changes" - f"({len(time_sig_errors)} errors)" - ) - - # Checks pedals - if tokenizer.config.use_sustain_pedals: - pedal_errors = pedal_equals(midi, midi_decoded) - if any(len(err) > 0 for err in pedal_errors): - has_errors = True - print( - f"MIDI {file_idx} - {file_name} / {tokenization} failed to encode/decode PEDALS" - f"({sum(len(err) for err in pedal_errors)} errors)" - ) - - # Checks pitch bends - if tokenizer.config.use_pitch_bends: - pitch_bend_errors = pitch_bend_equals(midi, midi_decoded) - if any(len(err) > 0 for err in pitch_bend_errors): - has_errors = True - print( - f"MIDI {file_idx} - {file_name} / {tokenization} failed to encode/decode PITCH BENDS" - f"({sum(len(err) for err in pitch_bend_errors)} errors)" - ) - - # TODO check control changes - - return midi_decoded, has_errors - - -def adapt_tempo_changes_times( - tracks: List[Instrument], tempo_changes: List[TempoChange] -): - r"""Will adapt the times of tempo changes depending on the - onset times of the notes of the MIDI. - This is needed to pass the tempo tests for Octuple as the tempos - will be decoded only from the notes. 
- - :param tracks: tracks of the MIDI to adapt the tempo changes - :param tempo_changes: tempo changes to adapt - """ - notes = sum((t.notes for t in tracks), []) - notes.sort(key=lambda x: x.start) - max_tick = max(note.start for note in notes) - - current_note_idx = 0 - tempo_idx = 1 - while tempo_idx < len(tempo_changes): - if tempo_changes[tempo_idx].time > max_tick: - del tempo_changes[tempo_idx] - continue - for n, note in enumerate(notes[current_note_idx:]): - if note.start >= tempo_changes[tempo_idx].time: - tempo_changes[tempo_idx].time = note.start - current_note_idx += n - break - if tempo_changes[tempo_idx].time == tempo_changes[tempo_idx - 1].time: - del tempo_changes[tempo_idx - 1] - continue - tempo_idx += 1 - - -def adjust_pedal_durations( - pedals: List[Pedal], tokenizer: miditok.MIDITokenizer, time_division: int -): - durations_in_tick = np.array( - [ - (beat * res + pos) * time_division // res - for beat, pos, res in tokenizer.durations - ] - ) - for pedal in pedals: - dur_index = np.argmin(np.abs(durations_in_tick - pedal.duration)) - beat, pos, res = tokenizer.durations[dur_index] - dur_index_in_tick = (beat * res + pos) * time_division // res - pedal.end = pedal.start + dur_index_in_tick - pedal.duration = pedal.end - pedal.start - - -def remove_equal_successive_tempos(tempo_changes: List[TempoChange]): - current_tempo = -1 - i = 0 - while i < len(tempo_changes): - if tempo_changes[i].tempo == current_tempo: - del tempo_changes[i] - continue - current_tempo = tempo_changes[i].tempo - i += 1 diff --git a/tests/utils.py b/tests/utils.py new file mode 100644 index 00000000..66d02796 --- /dev/null +++ b/tests/utils.py @@ -0,0 +1,360 @@ +""" +Test validation methods. +""" + +from copy import deepcopy +from pathlib import Path +from typing import Any, Dict, List, Tuple, Union + +import numpy as np +from miditoolkit import ( + Instrument, + Marker, + MidiFile, + Note, + Pedal, + TempoChange, + TimeSignature, +) + +import miditok +from miditok.constants import CHORD_MAPS, TIME_SIGNATURE, TIME_SIGNATURE_RANGE + +SEED = 777 + +HERE = Path(__file__).parent +MIDI_PATHS_ONE_TRACK = sorted((HERE / "MIDIs_one_track").rglob("*.mid")) +MIDI_PATHS_MULTITRACK = sorted((HERE / "MIDIs_multitrack").rglob("*.mid")) +MIDI_PATHS_ALL = sorted( + deepcopy(MIDI_PATHS_ONE_TRACK) + deepcopy(MIDI_PATHS_MULTITRACK) +) +TEST_LOG_DIR = HERE / "test_logs" + +# TOKENIZATIONS +ALL_TOKENIZATIONS = miditok.tokenizations.__all__ +TOKENIZATIONS_BPE = ["REMI", "MIDILike", "TSD", "MMM", "Structured"] + +# TOK CONFIG PARAMS +TIME_SIGNATURE_RANGE_TESTS = TIME_SIGNATURE_RANGE +TIME_SIGNATURE_RANGE_TESTS.update({2: [2, 3, 4]}) +TIME_SIGNATURE_RANGE_TESTS[4].append(8) +TOKENIZER_CONFIG_KWARGS = { + "beat_res": {(0, 4): 8, (4, 12): 4, (12, 16): 2}, + "beat_res_rest": {(0, 2): 4, (2, 12): 2}, + "num_tempos": 32, + "tempo_range": (40, 250), + "time_signature_range": TIME_SIGNATURE_RANGE_TESTS, + "chord_maps": CHORD_MAPS, + "chord_tokens_with_root_note": True, # Tokens will look as "Chord_C:maj" + "chord_unknown": (3, 6), + "delete_equal_successive_time_sig_changes": True, + "delete_equal_successive_tempo_changes": True, +} + + +def adjust_tok_params_for_tests(tokenization: str, params: Dict[str, Any]): + """Adjusts parameters (as dictionary for keyword arguments) depending on the tokenization. + + :param tokenization: tokenization. + :param params: parameters as a dictionary of keyword arguments. + """ + # Increase the TimeShift voc for Structured as it doesn't support successive TimeShifts. 
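+    # (Structured represents time with exactly one TimeShift token between consecutive notes,
+    # so the vocabulary must cover the longest note gap present in the test files.)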
+    if tokenization == "Structured":
+        params["beat_res"] = {(0, 512): 8}
+    # We don't test time signatures with Octuple as it can lead to time shifts, as the TS changes are only
+    # detectable at the onset times of the notes.
+    elif tokenization == "Octuple":
+        params["max_bar_embedding"] = 300
+        params["use_time_signatures"] = False
+    # Rests and time signatures can mess up with CPWord, when a Rest crossing a new bar is followed
+    # by a new TimeSig change, as TimeSigs are carried with Bar tokens (and there is none in this case).
+    elif tokenization == "CPWord":
+        if params["use_time_signatures"] and params["use_rests"]:
+            params["use_rests"] = False
+
+
+def prepare_midi_for_tests(
+    midi: MidiFile, sort_notes: bool = False, tokenizer: miditok.MIDITokenizer = None
+) -> MidiFile:
+    """Prepares a MIDI for tests by returning a copy with its tracks, and optionally notes, sorted.
+    It also downsamples the MIDI when a tokenizer is given.
+
+    :param midi: midi reference.
+    :param sort_notes: whether to sort the notes. This is not necessary before tokenizing a MIDI, as the sorting
+        will be performed by the tokenizer. (default: False)
+    :param tokenizer: if given, used to downsample the MIDI before sorting its content.
+    :return: a new MIDI object with tracks (and optionally notes) sorted.
+    """
+    tokenization = type(tokenizer).__name__ if tokenizer is not None else None
+    new_midi = deepcopy(midi)
+
+    # Downsamples the MIDI if a tokenizer is given
+    if tokenizer is not None:
+        tokenizer.preprocess_midi(new_midi)
+
+        # For Octuple/CPWord, as tempo is only carried at notes times, we need to adapt their times for comparison
+        # Set tempo changes at onset times of notes
+        # We use the first track only, as it is the one for which tempos are decoded
+        if tokenizer.config.use_tempos and tokenization in ["Octuple", "CPWord"]:
+            if len(new_midi.instruments) > 0:
+                adapt_tempo_changes_times(
+                    [new_midi.instruments[0]], new_midi.tempo_changes
+                )
+            else:
+                new_midi.tempo_changes = [TempoChange(tokenizer._DEFAULT_TEMPO, 0)]
+        if (
+            tokenizer.config.use_time_signatures
+            and tokenization in ["Octuple", "CPWord", "MMM"]
+            and len(new_midi.instruments) == 0
+        ):
+            new_midi.time_signature_changes = [TimeSignature(*TIME_SIGNATURE, 0)]
+
+    for track in new_midi.instruments:
+        # Adjust note and pedal ends to the closest durations covered by the tokenizer's vocabulary
+        if tokenizer is not None:
+            adjust_notes_durations(track.notes, tokenizer, midi.ticks_per_beat)
+            if tokenizer.config.use_sustain_pedals:
+                adjust_pedal_durations(track.pedals, tokenizer, midi.ticks_per_beat)
+        if track.is_drum:
+            track.program = 0  # need to be done before sorting tracks per program
+        if sort_notes:
+            track.notes.sort(key=lambda x: (x.start, x.pitch, x.end, x.velocity))
+
+    # Sorts tracks
+    # MIDI detokenized with one_token_stream contains tracks sorted by note occurrence
+    new_midi.instruments.sort(key=lambda x: (x.program, x.is_drum))
+
+    return new_midi
+
+
+def midis_notes_equals(
+    midi1: MidiFile, midi2: MidiFile
+) -> List[Tuple[int, str, List[Tuple[str, Union[Note, int], int]]]]:
+    """Checks if the notes from two MIDIs are all equal, and if not returns the list of errors.
+
+    :param midi1: first MIDI.
+    :param midi2: second MIDI.
+    :return: list of errors.
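+        Each error is a tuple holding the track program, the track name, and the list of its
+        note attribute mismatches.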
+ """ + errors = [] + for track1, track2 in zip(midi1.instruments, midi2.instruments): + track_errors = tracks_notes_equals(track1, track2) + if len(track_errors) > 0: + errors.append((track1.program, track1.name, track_errors)) + return errors + + +def tracks_notes_equals( + track1: Instrument, track2: Instrument +) -> List[Tuple[str, Union[Note, int], int]]: + if len(track1.notes) != len(track2.notes): + return [("len", len(track2.notes), len(track1.notes))] + errors = [] + for note1, note2 in zip(track1.notes, track2.notes): + err = notes_equals(note1, note2) + if err != "": + errors.append((err, note2, getattr(note1, err))) + return errors + + +def notes_equals(note1: Note, note2: Note) -> str: + if note1.start != note2.start: + return "start" + elif note1.end != note2.end: + return "end" + elif note1.pitch != note2.pitch: + return "pitch" + elif note1.velocity != note2.velocity: + return "velocity" + return "" + + +def check_midis_equals( + midi1: MidiFile, + midi2: MidiFile, + check_tempos: bool = True, + check_time_signatures: bool = True, + check_pedals: bool = True, + check_pitch_bends: bool = True, + log_prefix: str = "", +) -> Tuple[MidiFile, bool]: + has_errors = False + types_of_errors = [] + + # Checks notes and add markers if errors + errors = midis_notes_equals(midi1, midi2) + if len(errors) > 0: + has_errors = True + for e, track_err in enumerate(errors): + if track_err[-1][0][0] != "len": + for err, note, exp in track_err[-1]: + midi2.markers.append( + Marker( + f"{e}: with note {err} (pitch {note.pitch})", + note.start, + ) + ) + print( + f"{log_prefix} failed to encode/decode NOTES ({sum(len(t[2]) for t in errors)} errors)" + ) + + # Check pedals + if check_pedals: + for inst1, inst2 in zip(midi1.instruments, midi2.instruments): + if inst1.pedals != inst2.pedals: + types_of_errors.append("PEDALS") + break + + # Check pitch bends + if check_pitch_bends: + for inst1, inst2 in zip(midi1.instruments, midi2.instruments): + if inst1.pitch_bends != inst2.pitch_bends: + types_of_errors.append("PITCH BENDS") + break + + """# Check control changes + if check_control_changes: + for inst1, inst2 in zip(midi1.instruments, midi2.instruments): + if inst1.control_changes != inst2.control_changes: + types_of_errors.append("CONTROL CHANGES") + break""" + + # Checks tempos + if check_tempos: + if midi1.tempo_changes != midi2.tempo_changes: + types_of_errors.append("TEMPOS") + + # Checks time signatures + if check_time_signatures: + if midi1.time_signature_changes != midi2.time_signature_changes: + types_of_errors.append("TIME SIGNATURES") + + # Prints types of errors + has_errors = has_errors or len(types_of_errors) > 0 + for err_type in types_of_errors: + print(f"{log_prefix} failed to encode/decode {err_type}") + + return midi2, not has_errors + + +def tokenize_and_check_equals( + midi: MidiFile, + tokenizer: miditok.MIDITokenizer, + file_idx: Union[int, str], + file_name: str, +) -> Tuple[MidiFile, bool]: + tokenization = type(tokenizer).__name__ + log_prefix = f"MIDI {file_idx} - {file_name} / {tokenization}" + midi.instruments.sort(key=lambda x: (x.program, x.is_drum)) + # merging is performed in preprocess only in one_token_stream mode + # but in multi token stream, decoding will actually keep one track per program + if tokenizer.config.use_programs: + miditok.utils.merge_same_program_tracks(midi.instruments) + + # Tokenize and detokenize + tokens = tokenizer(midi) + midi_decoded = tokenizer( + tokens, + miditok.utils.get_midi_programs(midi) if len(midi.instruments) > 0 else None, + 
time_division=midi.ticks_per_beat, + ) + midi_decoded = prepare_midi_for_tests( + midi_decoded, sort_notes=tokenization == "MIDILike" + ) + + # Check decoded MIDI is identical + midi_decoded, no_error = check_midis_equals( + midi, + midi_decoded, + check_tempos=tokenizer.config.use_tempos and not tokenization == "MuMIDI", + check_time_signatures=tokenizer.config.use_time_signatures, + check_pedals=tokenizer.config.use_sustain_pedals, + check_pitch_bends=tokenizer.config.use_pitch_bends, + log_prefix=log_prefix, + ) + + # Checks types and values conformity following the rules + err_tse = tokenizer.tokens_errors(tokens) + if isinstance(err_tse, list): + err_tse = sum(err_tse) + if err_tse != 0.0: + no_error = False + print(f"{log_prefix} Validation of tokens types / values successions failed") + + return midi_decoded, not no_error + + +def adapt_tempo_changes_times( + tracks: List[Instrument], tempo_changes: List[TempoChange] +): + r"""Will adapt the times of tempo changes depending on the + onset times of the notes of the MIDI. + This is needed to pass the tempo tests for Octuple as the tempos + will be decoded only from the notes. + + :param tracks: tracks of the MIDI to adapt the tempo changes + :param tempo_changes: tempo changes to adapt + """ + notes = sum((t.notes for t in tracks), []) + notes.sort(key=lambda x: x.start) + max_tick = max(note.start for note in notes) + + current_note_idx = 0 + tempo_idx = 1 + while tempo_idx < len(tempo_changes): + if tempo_changes[tempo_idx].time > max_tick: + del tempo_changes[tempo_idx] + continue + for n, note in enumerate(notes[current_note_idx:]): + if note.start >= tempo_changes[tempo_idx].time: + tempo_changes[tempo_idx].time = note.start + current_note_idx += n + break + if tempo_changes[tempo_idx].time == tempo_changes[tempo_idx - 1].time: + del tempo_changes[tempo_idx - 1] + continue + tempo_idx += 1 + + +def adjust_notes_durations( + notes: List[Note], tokenizer: miditok.MIDITokenizer, time_division: int +): + """Adapt notes offset times so that they match the possible durations covered by a tokenizer. + + :param notes: list of Note objects to adapt. + :param tokenizer: tokenizer (needed for durations). + :param time_division: time division of the MIDI of origin. + """ + durations_in_tick = np.array( + [ + (beat * res + pos) * time_division // res + for beat, pos, res in tokenizer.durations + ] + ) + for note in notes: + dur_index = np.argmin(np.abs(durations_in_tick - note.duration)) + beat, pos, res = tokenizer.durations[dur_index] + dur_index_in_tick = (beat * res + pos) * time_division // res + note.end = note.start + dur_index_in_tick + + +def adjust_pedal_durations( + pedals: List[Pedal], tokenizer: miditok.MIDITokenizer, time_division: int +): + """Adapt pedal offset times so that they match the possible durations covered by a tokenizer. + + :param pedals: list of Pedal objects to adapt. + :param tokenizer: tokenizer (needed for durations). + :param time_division: time division of the MIDI of origin. + """ + durations_in_tick = np.array( + [ + (beat * res + pos) * time_division // res + for beat, pos, res in tokenizer.durations + ] + ) + for pedal in pedals: + dur_index = np.argmin(np.abs(durations_in_tick - pedal.duration)) + beat, pos, res = tokenizer.durations[dur_index] + dur_index_in_tick = (beat * res + pos) * time_division // res + pedal.end = pedal.start + dur_index_in_tick
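
For reference, the duration snapping performed by adjust_notes_durations and adjust_pedal_durations above boils down to a nearest-neighbour lookup over the tick values that the tokenizer's duration vocabulary can express. Below is a minimal, self-contained sketch of that logic; the (beats, positions, resolution) triplets and the time division are illustrative assumptions for the example, not miditok's defaults:

import numpy as np

# Illustrative duration triplets mimicking tokenizer.durations: (beats, positions, resolution).
durations = [(0, 4, 8), (1, 0, 8), (2, 0, 4)]  # assumed values for the example
time_division = 384  # ticks per beat of the source MIDI (assumed)

# Convert each representable duration to ticks: (beats * res + positions) * time_division // res
grid = np.array([(b * r + p) * time_division // r for b, p, r in durations])


def snap(duration_ticks: int) -> int:
    # Return the closest duration, in ticks, that the vocabulary can express
    return int(grid[np.argmin(np.abs(grid - duration_ticks))])


assert snap(200) == 192  # half a beat: (0, 4, 8)
assert snap(500) == 384  # one beat: (1, 0, 8)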