diff --git a/.github/workflows/pytest.yml b/.github/workflows/pytest.yml index 96bcdf2a..3860eb4f 100644 --- a/.github/workflows/pytest.yml +++ b/.github/workflows/pytest.yml @@ -25,7 +25,7 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip - pip install setuptools flake8 pytest coverage torch tensorflow + pip install setuptools flake8 pytest-xdist[psutil] coverage torch tensorflow pip install -r requirements.txt - name: Lint with flake8 run: | @@ -35,6 +35,6 @@ jobs: flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics - name: Test with pytest run: | - coverage run -m pytest + coverage run -m pytest -n auto - name: Codecov uses: codecov/codecov-action@v3.1.0 diff --git a/README.md b/README.md index c2de7307..5e2f66c6 100644 --- a/README.md +++ b/README.md @@ -10,7 +10,7 @@ Python package to tokenize MIDI music files, presented at the ISMIR 2021 LBD. [![GitHub CI](https://github.com/Natooz/MidiTok/actions/workflows/pytest.yml/badge.svg)](https://github.com/Natooz/MidiTok/actions/workflows/pytest.yml) [![Codecov](https://img.shields.io/codecov/c/github/Natooz/MidiTok)](https://codecov.io/gh/Natooz/MidiTok) [![GitHub license](https://img.shields.io/github/license/Natooz/MidiTok.svg)](https://github.com/Natooz/MidiTok/blob/main/LICENSE) -[![Downloads](https://pepy.tech/badge/MidiTok)](https://pepy.tech/project/MidiTok) +[![Downloads](https://static.pepy.tech/badge/miditok)](https://pepy.tech/project/MidiTok) [![Code style](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black) Using Deep Learning with symbolic music ? MidiTok can take care of converting (tokenizing) your MIDI files into tokens, ready to be fed to models such as Transformer, for any generation, transcription or MIR task. 
diff --git a/docs/additional_tokens_table.csv b/docs/additional_tokens_table.csv new file mode 100644 index 00000000..6805f444 --- /dev/null +++ b/docs/additional_tokens_table.csv @@ -0,0 +1,9 @@ +Tokenization,Tempo,Time signature,Chord,Rest +MIDILike,✅,✅,✅,✅ +REMI,✅,✅,✅,✅ +TSD,✅,✅,✅,✅ +Structured,❌,❌,❌,❌ +CPWord,✅,✅,✅,✅ +Octuple,✅,✅,❌,❌ +MuMIDI,✅,❌,✅,❌ +MMM,✅,✅,✅,❌ \ No newline at end of file diff --git a/docs/midi_tokenizer.rst b/docs/midi_tokenizer.rst index 55e69453..27b3a627 100644 --- a/docs/midi_tokenizer.rst +++ b/docs/midi_tokenizer.rst @@ -62,72 +62,10 @@ Additional tokens MidiTok offers to include additional tokens on music information. You can specify them in the ``tokenizer_config`` argument (:class:`miditok.TokenizerConfig`) when creating a tokenizer. The :class:`miditok.TokenizerConfig` documentations specifically details the role of each of them, and their associated parameters. Cells with ❕ markers means the additional token is implemented by default and not optionnal. -.. list-table:: Compatibility table of tokenizations and additional tokens. +.. csv-table:: Compatibility table of tokenizations and additional tokens. + :file: additional_tokens_table.csv :header-rows: 1 - * - Token type - - :ref:`REMI` - - :ref:`REMIPlus` - - :ref:`MIDI-Like` - - :ref:`TSD` - - :ref:`Structured` - - :ref:`CPWord` - - :ref:`Octuple` - - :ref:`MuMIDI` - - :ref:`MMM` - * - Chord - - ✅ - - ✅ - - ✅ - - ✅ - - ✅ - - ❌ - - ❌ - - ✅ - - ✅ - * - Rest - - ✅ - - ✅ - - ✅ - - ✅ - - ✅ - - ❌ - - ❌ - - ❌ - - ❌ - * - Tempo - - ✅ - - ✅ - - ✅ - - ✅ - - ✅ - - ❌ - - ✅ - - ✅ - - ✅ - * - Program - - ✅¹ - - ✅¹ - - ✅¹ - - ✅¹ - - ✅¹ - - ✅² - - ✅❕ - - ✅❕ - - ✅❕ - * - Time signature - - ✅ - - ✅ - - ✅ - - ✅ - - ❌ - - ❌ - - ✅ - - ❌ - - ✅ - -**¹** the tokenizer will add `Program` tokens before each `Pitch` / `NoteOn` token, and will treat all the tracks of a MIDI as a single sequence of tokens. 
-**²** unimplemented, the tokenizer's vocabulary will contain the `Program` tokens, but it will not use it. Special tokens ------------------------ @@ -148,7 +86,7 @@ Tokens & TokSequence input / output format Depending on the tokenizer at use, the **format** of the tokens returned by the ``midi_to_tokens`` method may vary, as well as the expected format for the ``tokens_to_midi`` method. The format is given by the ``tokenizer.io_format` property. For any tokenizer, the format is the same for both methods. -The format is deduced from the ``is_multi_voc`` and ``one_token_stream`` tokenizer properties. In short: **one_token_stream** being True means that the tokenizer will convert a MIDI file into a single stream of tokens for all instrument tracks, otherwise it will convert each track to a distinct token stream; **is_mult_voc** being True means that each token stream is a list of lists of tokens, of shape ``(T,C)`` for T time steps and C subtokens per time step. +The format is deduced from the ``is_multi_voc`` and ``one_token_stream`` tokenizer properties. **one_token_stream** being True means that the tokenizer will convert a MIDI file into a single stream of tokens for all instrument tracks, otherwise it will convert each track to a distinct token sequence. **is_multi_voc** being True means that each token stream is a list of lists of tokens, of shape ``(T,C)`` for T time steps and C subtokens per time step. This results in four situations, where I is the number of tracks, T is the number of tokens (or time steps) and C the number of subtokens per time step: @@ -163,7 +101,7 @@ Some tokenizer examples to illustrate: * **TSD** without ``config.use_programs`` will not have multiple vocabularies and will treat each MIDI track as a unique stream of tokens, hence it will convert MIDI files to a list of ``TokSequence`` objects, ``(I,T)`` format. 
* **TSD** with ``config.use_programs`` being True will convert all MIDI tracks to a single stream of tokens, hence one ``TokSequence`` object, ``(T)`` format. -* **CPWord** is a multi-voc tokenizer and treats each MIDI track as a distinct stream of tokens, hence it will convert MIDI files to a list of ``TokSequence`` objects with ``(I,T,C)`` format. +* **CPWord** is a multi-voc tokenizer, without ``config.use_programs`` it will treat each MIDI track as a distinct stream of tokens, hence it will convert MIDI files to a list of ``TokSequence`` objects with the ``(I,T,C)`` format. * **Octuple** is a multi-voc tokenizer and converts all MIDI track to a single stream of tokens, hence it will convert MIDI files to a ``TokSequence`` object, ``(T,C)`` format. diff --git a/docs/tokenizations.rst b/docs/tokenizations.rst index ea2ee45b..61ec2dea 100644 --- a/docs/tokenizations.rst +++ b/docs/tokenizations.rst @@ -85,13 +85,6 @@ Octuple :noindex: :show-inheritance: -Octuple Mono ------------------------- - -.. autoclass:: miditok.OctupleMono - :noindex: - :show-inheritance: - MuMIDI ------------------------ diff --git a/miditok/__init__.py b/miditok/__init__.py index 8a3095d0..c814be98 100644 --- a/miditok/__init__.py +++ b/miditok/__init__.py @@ -6,7 +6,6 @@ TSD, Structured, Octuple, - OctupleMono, CPWord, MuMIDI, MMM, @@ -44,7 +43,6 @@ def _tweak_config_before_creating_voc(self): "TSD", "Structured", "Octuple", - "OctupleMono", "CPWord", "MuMIDI", "MMM", diff --git a/miditok/classes.py b/miditok/classes.py index 6ccd9f11..70aba14f 100644 --- a/miditok/classes.py +++ b/miditok/classes.py @@ -26,7 +26,9 @@ NB_TEMPOS, TEMPO_RANGE, LOG_TEMPOS, + DELETE_EQUAL_SUCCESSIVE_TEMPO_CHANGES, TIME_SIGNATURE_RANGE, + DELETE_EQUAL_SUCCESSIVE_TIME_SIG_CHANGES, PROGRAMS, CURRENT_VERSION_PACKAGE, ) @@ -164,8 +166,8 @@ class TokenizerConfig: add more TimeSignatureChange objects. (default: False) :param use_programs: will use ``Program`` tokens, if the tokenizer is compatible. 
Used to specify an instrument / MIDI program. The :ref:`Octuple`, :ref:`MMM` and :ref:`MuMIDI` tokenizers - use natively `Program` tokens, this option is always enabled. :ref:`TSD`, :ref:`REMI`, :ref:`REMIPlus`, - :ref:`MIDILike` and :ref:`Structured` will add `Program` tokens before each `Pitch` / `NoteOn` token to + use natively `Program` tokens, this option is always enabled. :ref:`TSD`, :ref:`REMI`, :ref:`MIDILike`, + :ref:`Structured` and :ref:`CPWord` will add `Program` tokens before each `Pitch` / `NoteOn` token to indicate its associated instrument and will treat all the tracks of a MIDI as a single sequence of tokens. :ref:`CPWord`, :ref:`Octuple` and :ref:`MuMIDI` add a `Program` tokens with the stacks of `Pitch`, `Velocity` and `Duration` tokens. (default: False) @@ -183,8 +185,25 @@ class TokenizerConfig: :param nb_tempos: number of tempos "bins" to use. (default: 32) :param tempo_range: range of minimum and maximum tempos within which the bins fall. (default: (40, 250)) :param log_tempos: will use log scaled tempo values instead of linearly scaled. (default: False) + :param delete_equal_successive_tempo_changes: setting this option True will delete identical successive tempo + changes when preprocessing a MIDI file after loading it. For example, if a MIDI has two tempo changes + for tempo 120 at tick 1000 and the next one is for tempo 121 at tick 1200, during preprocessing the tempo + values are likely to be downsampled and become identical (120 or 121). If that's the case, the second + tempo change will be deleted and not tokenized. This parameter doesn't apply for tokenizations that natively + inject the tempo information at recurrent timings (e.g. Octuple). For others, note that setting it True + might reduce the number of `Tempo` tokens and in turn the recurrence of this information. Leave it False if + you want to have recurrent `Tempo` tokens, that you might inject yourself by adding `TempoChange` objects to + your MIDIs. 
(default: False) :param time_signature_range: range as a dictionary {denom_i: [num_i1, ..., num_in] / (min_num_i, max_num_i)}. (default: {4: [4]}) + :param delete_equal_successive_time_sig_changes: setting this option True will delete identical successive time + signature changes when preprocessing a MIDI file after loading it. For example, if a MIDI has two time + signature changes for 4/4 at tick 1000 and the next one is also 4/4 at tick 1200, the second time signature + change will be deleted and not tokenized. This parameter doesn't apply for tokenizations that natively + inject the time signature information at recurrent timings (e.g. Octuple). For others, note that setting it + True might reduce the number of `TimeSig` tokens and in turn the recurrence of this information. Leave it + False if you want to have recurrent `TimeSig` tokens, that you might inject yourself by adding + `TimeSignatureChange` objects to your MIDIs. (default: False) :param programs: sequence of MIDI programs to use. Note that `-1` is used and reserved for drums tracks. (default: from -1 to 127 included) :param **kwargs: additional parameters that will be saved in `config.additional_params`. 
@@ -208,7 +227,11 @@ def __init__( nb_tempos: int = NB_TEMPOS, tempo_range: Tuple[int, int] = TEMPO_RANGE, log_tempos: bool = LOG_TEMPOS, - time_signature_range: Dict[int, Union[List[int], Tuple[int, int]]] = TIME_SIGNATURE_RANGE, + delete_equal_successive_tempo_changes: bool = DELETE_EQUAL_SUCCESSIVE_TEMPO_CHANGES, + time_signature_range: Dict[ + int, Union[List[int], Tuple[int, int]] + ] = TIME_SIGNATURE_RANGE, + delete_equal_successive_time_sig_changes: bool = DELETE_EQUAL_SUCCESSIVE_TIME_SIG_CHANGES, programs: Sequence[int] = PROGRAMS, **kwargs, ): @@ -239,12 +262,20 @@ def __init__( self.nb_tempos: int = nb_tempos # nb of tempo bins for additional tempo tokens, quantized like velocities self.tempo_range: Tuple[int, int] = tempo_range # (min_tempo, max_tempo) self.log_tempos: bool = log_tempos + self.delete_equal_successive_tempo_changes = ( + delete_equal_successive_tempo_changes + ) # Time signature params self.time_signature_range: Dict[int, List[int]] = { - beat_res: list(range(beats[0], beats[1] + 1)) if isinstance(beats, tuple) else beats + beat_res: list(range(beats[0], beats[1] + 1)) + if isinstance(beats, tuple) + else beats for beat_res, beats in time_signature_range.items() } + self.delete_equal_successive_time_sig_changes = ( + delete_equal_successive_time_sig_changes + ) # Programs self.programs: Sequence[int] = programs diff --git a/miditok/constants.py b/miditok/constants.py index 149564fa..991cebde 100644 --- a/miditok/constants.py +++ b/miditok/constants.py @@ -62,9 +62,11 @@ NB_TEMPOS = 32 TEMPO_RANGE = (40, 250) # (min_tempo, max_tempo) LOG_TEMPOS = False # log or linear scale tempos +DELETE_EQUAL_SUCCESSIVE_TEMPO_CHANGES = False # Time signature params -TIME_SIGNATURE_RANGE = {4: [4]} # {denom_i: [num_i1, ..., num_in] / (min_num_i, max_num_i)} +# {denom_i: [num_i1, ..., num_in] / (min_num_i, max_num_i)} +TIME_SIGNATURE_RANGE = {8: [3, 12, 6], 4: [5, 6, 3, 2, 1, 4]} # Programs PROGRAMS = list(range(-1, 128)) @@ -80,6 +82,7 @@ TIME_DIVISION = 
384 # 384 and 480 are convenient as divisible by 4, 8, 12, 16, 24, 32 TEMPO = 120 TIME_SIGNATURE = (4, 4) +DELETE_EQUAL_SUCCESSIVE_TIME_SIG_CHANGES = False # Used with chords PITCH_CLASSES = [ diff --git a/miditok/midi_tokenizer.py b/miditok/midi_tokenizer.py index 4e9712dd..ff610033 100644 --- a/miditok/midi_tokenizer.py +++ b/miditok/midi_tokenizer.py @@ -423,25 +423,36 @@ def _quantize_tempos(self, tempos: List[TempoChange], time_division: int): :param time_division: MIDI time division / resolution, in ticks/beat (of the MIDI being parsed). """ ticks_per_sample = int(time_division / max(self.config.beat_res.values())) - prev_tempo = -1 + prev_tempo = TempoChange(-1, -1) i = 0 while i < len(tempos): # Quantize tempo value tempos[i].tempo = self.tempos[ np.argmin(np.abs(self.tempos - tempos[i].tempo)) ] - if tempos[i].tempo == prev_tempo: + if ( + self.config.delete_equal_successive_tempo_changes + and tempos[i].tempo == prev_tempo.tempo + ): del tempos[i] continue rest = tempos[i].time % ticks_per_sample tempos[i].time += ( -rest if rest <= ticks_per_sample / 2 else ticks_per_sample - rest ) - prev_tempo = tempos[i].tempo + + # If the current tempo is now at the same time as the previous one, we delete the previous + if tempos[i].time == prev_tempo.time: + prev_tempo = tempos[i] + del tempos[i - 1] + continue + + prev_tempo = tempos[i] i += 1 - @staticmethod - def _quantize_time_signatures(time_sigs: List[TimeSignature], time_division: int): + def _quantize_time_signatures( + self, time_sigs: List[TimeSignature], time_division: int + ): r"""Quantize the time signature changes, delayed to the next bar. See MIDI 1.0 Detailed specifications, pages 54 - 56, for more information on delayed time signature messages. @@ -449,18 +460,19 @@ def _quantize_time_signatures(time_sigs: List[TimeSignature], time_division: int :param time_sigs: time signature changes to quantize. :param time_division: MIDI time division / resolution, in ticks/beat (of the MIDI being parsed). 
""" - ticks_per_bar = MIDITokenizer._compute_ticks_per_bar(time_sigs[0], time_division) - current_bar = 0 + ticks_per_bar = MIDITokenizer._compute_ticks_per_bar( + time_sigs[0], time_division + ) previous_tick = 0 # first time signature change is always at tick 0 - prev_time_sig = time_sigs[0] + prev_ts = time_sigs[0] i = 1 while i < len(time_sigs): time_sig = time_sigs[i] - if (time_sig.numerator, time_sig.denominator) == ( - prev_time_sig.numerator, - prev_time_sig.denominator, - ) or time_sig.time == previous_tick: + if self.config.delete_equal_successive_time_sig_changes and ( + time_sig.numerator, + time_sig.denominator, + ) == (prev_ts.numerator, prev_ts.denominator): del time_sigs[i] continue @@ -473,13 +485,23 @@ def _quantize_time_signatures(time_sigs: List[TimeSignature], time_division: int time_sig.time = previous_tick + bar_offset * ticks_per_bar # Update values - ticks_per_bar = MIDITokenizer._compute_ticks_per_bar(time_sig, time_division) - current_bar += bar_offset + ticks_per_bar = MIDITokenizer._compute_ticks_per_bar( + time_sig, time_division + ) + + # If the current time signature is now at the same time as the previous one, we delete the previous + if time_sig.time == previous_tick: + previous_tick = time_sig.time + del time_sigs[i - 1] + continue + previous_tick = time_sig.time - prev_time_sig = time_sig + prev_ts = time_sig i += 1 - def _midi_to_tokens(self, midi: MidiFile, *args, **kwargs) -> List[TokSequence]: + def _midi_to_tokens( + self, midi: MidiFile, *args, **kwargs + ) -> Union[TokSequence, List[TokSequence]]: r"""Converts a preprocessed MIDI object to a sequence of tokens. The workflow of this method is as follows: the events (Pitch, Velocity, Tempo, TimeSignature...) are gathered into a list, then the time events are added. 
If `one_token_stream` is true, all events of all tracks @@ -712,7 +734,7 @@ def complete_sequence(self, seq: TokSequence): """ if seq.tokens is None: if seq.events is not None: - seq.tokens = [str(event) for event in seq.events] + seq.tokens = self._events_to_tokens(seq.events) elif seq.ids is not None: seq.tokens = self._ids_to_tokens(seq.ids) elif seq.bytes is not None: @@ -770,6 +792,29 @@ def _ids_to_tokens( tokens.append(event_str if as_str else Event(*event_str.split("_"))) return tokens + @staticmethod + def _events_to_tokens( + events: List[Union[Event, List[Event]]] + ) -> List[Union[str, List[str]]]: + r"""Converts a sequence of Events to their associated tokens (str). + + :param events: sequence of Events to convert. + :return: the sequence of corresponding tokens (str). + """ + tokens = [] + if isinstance(events[0], list): # multiple vocabularies + for ( + multi_event + ) in events: # cannot use recursion here because of the vocabulary type id + multi_token = [] + for i, event in enumerate(multi_event): + multi_token.append(str(event)) + tokens.append(multi_token) + return tokens + + tokens = [str(event) for event in events] + return tokens + def _ids_to_bytes( self, ids: List[Union[int, List[int]]], as_one_str: bool = False ) -> Union[str, List[str]]: @@ -1141,8 +1186,9 @@ def __create_time_signatures(self) -> List[Tuple]: time_signatures = [] for beat_res, beats in time_signature_range.items(): - assert beat_res > 0 and math.log2(beat_res).is_integer(), \ - f"The beat resolution ({beat_res}) in time signature must be a power of 2" + assert ( + beat_res > 0 and math.log2(beat_res).is_integer() + ), f"The beat resolution ({beat_res}) in time signature must be a power of 2" time_signatures.extend([(nb_beats, beat_res) for nb_beats in beats]) @@ -1176,7 +1222,10 @@ def validate_midi_time_signatures(self, midi: MidiFile) -> bool: """ if self.config.use_time_signatures: for time_sig in midi.time_signature_changes: - if (time_sig.numerator, 
time_sig.denominator) not in self.time_signatures: + if ( + time_sig.numerator, + time_sig.denominator, + ) not in self.time_signatures: return False return True @@ -1743,10 +1792,7 @@ def _load_params(self, config_file_path: Union[str, Path]): for beat_range, res in value.items() } elif key == "time_signature_range": - value = { - int(res): beat_range - for res, beat_range in value.items() - } + value = {int(res): beat_range for res, beat_range in value.items()} # Convert old attribute from < v2.1.0 to new for TokenizerConfig elif key in old_add_tokens_attr: key = old_add_tokens_attr[key] diff --git a/miditok/tokenizations/__init__.py b/miditok/tokenizations/__init__.py index aca8a663..e52aa6a9 100644 --- a/miditok/tokenizations/__init__.py +++ b/miditok/tokenizations/__init__.py @@ -5,7 +5,6 @@ from .cp_word import CPWord from .mumidi import MuMIDI from .octuple import Octuple -from .octuple_mono import OctupleMono from .mmm import MMM __all__ = [ @@ -14,7 +13,6 @@ "TSD", "Structured", "Octuple", - "OctupleMono", "CPWord", "MuMIDI", "MMM", diff --git a/miditok/tokenizations/cp_word.py b/miditok/tokenizations/cp_word.py index 1d7436ce..7a63781f 100644 --- a/miditok/tokenizations/cp_word.py +++ b/miditok/tokenizations/cp_word.py @@ -1,13 +1,13 @@ from math import ceil from typing import List, Tuple, Dict, Optional, Union, Any +from pathlib import Path import numpy as np -from miditoolkit import MidiFile, Instrument, Note, TempoChange +from miditoolkit import MidiFile, Instrument, Note, TempoChange, TimeSignature -from ..midi_tokenizer import MIDITokenizer, _in_as_seq, _out_as_complete_seq +from ..midi_tokenizer import MIDITokenizer, _in_as_seq from ..classes import TokSequence, Event -from ..utils import detect_chords -from ..constants import TIME_DIVISION, TEMPO, MIDI_INSTRUMENTS +from ..constants import TIME_DIVISION, TEMPO, MIDI_INSTRUMENTS, TIME_SIGNATURE class CPWord(MIDITokenizer): @@ -31,16 +31,18 @@ class CPWord(MIDITokenizer): (one per token type). 
This means that the training requires to add multiple losses. For generation, the decoding implies sample from several distributions, which can be very delicate. Hence, we do not recommend this tokenization for generation with small models. + **Note:** When decoding multiple token sequences (of multiple tracks), i.e. when `config.use_programs` is False, + only the tempos and time signatures of the first sequence will be decoded for the whole MIDI. """ def _tweak_config_before_creating_voc(self): - self.config.use_time_signatures = False token_types = ["Family", "Position", "Pitch", "Velocity", "Duration"] for add_tok_attr, add_token in [ ("use_programs", "Program"), ("use_chords", "Chord"), ("use_rests", "Rest"), ("use_tempos", "Tempo"), + ("use_time_signatures", "TimeSig"), ]: if getattr(self.config, add_tok_attr): token_types.append(add_token) @@ -48,79 +50,74 @@ def _tweak_config_before_creating_voc(self): type_: idx for idx, type_ in enumerate(token_types) } # used for data augmentation self.vocab_types_idx["Bar"] = 1 # same as position + if self.config.use_programs: + self.one_token_stream = True - def _midi_to_tokens(self, midi: MidiFile, *args, **kwargs) -> List[TokSequence]: - # Convert each track to tokens - tokens = [] - for track in midi.instruments: - tokens.append(self._track_to_tokens(track)) - self.complete_sequence(tokens[-1]) - return tokens - - @_out_as_complete_seq - def _track_to_tokens(self, track: Instrument) -> TokSequence: - r"""Converts a track (miditoolkit.Instrument object) into a sequence of tokens (:class:`miditok.TokSequence`). + def _add_time_events(self, events: List[Event]) -> List[List[Event]]: + r""" + Takes a sequence of note events (containing optionally Chord, Tempo and TimeSignature tokens), + and insert (not inplace) time tokens (TimeShift, Rest) to complete the sequence. - :param track: MIDI track to convert - :return: :class:`miditok.TokSequence` of corresponding tokens. + :param events: note events to complete. 
+ :return: the same events, with time events inserted. """ - # Make sure the notes are sorted first by their onset (start) times, second by pitch - # notes.sort(key=lambda x: (x.start, x.pitch)) # done in midi_to_tokens - ticks_per_sample = self._current_midi_metadata["time_division"] / max( - self.config.beat_res.values() - ) - ticks_per_bar = self._current_midi_metadata["time_division"] * 4 - dur_bins = self._durations_ticks[self._current_midi_metadata["time_division"]] + time_division = self._current_midi_metadata["time_division"] + ticks_per_sample = time_division / max(self.config.beat_res.values()) min_rest = ( - self._current_midi_metadata["time_division"] * self.rests[0][0] - + ticks_per_sample * self.rests[0][1] + time_division * self.rests[0][0] + ticks_per_sample * self.rests[0][1] if self.config.use_rests else 0 ) - tokens: List[List[Union[str, Event]]] = [] # list of lists of tokens - # Creates tokens - previous_tick = -1 - previous_note_end = ( - track.notes[0].start + 1 - ) # so that no rest is created before the first note + # Add time events + all_events = [] current_bar = -1 - current_tempo_idx = 0 - current_tempo = self._current_midi_metadata["tempo_changes"][ - current_tempo_idx - ].tempo - for note in track.notes: - # Bar / Position / (Tempo) / (Rest) - if note.start != previous_tick: + previous_tick = -1 + previous_note_end = 0 + current_time_sig = TIME_SIGNATURE + current_tempo = TEMPO + current_program = None + ticks_per_bar = self._compute_ticks_per_bar( + TimeSignature(*current_time_sig, 0), time_division + ) + for e, event in enumerate(events): + if event.type == "TimeSig": + current_time_sig = list(map(int, event.value.split("/"))) + ticks_per_bar = self._compute_ticks_per_bar( + TimeSignature(*current_time_sig, event.time), time_division + ) + elif event.type == "Tempo": + current_tempo = event.value + elif event.type == "Program": + current_program = event.value + if event.time != previous_tick: # (Rest) if ( - self.config.use_rests - 
and note.start > previous_note_end - and note.start - previous_note_end >= min_rest + event.type in ["Pitch", "Chord", "Tempo", "TimeSig"] + and self.config.use_rests + and event.time - previous_note_end >= min_rest ): previous_tick = previous_note_end rest_beat, rest_pos = divmod( - note.start - previous_tick, - self._current_midi_metadata["time_division"], + event.time - previous_tick, + time_division, ) rest_beat = min(rest_beat, max([r[0] for r in self.rests])) rest_pos = round(rest_pos / ticks_per_sample) if rest_beat > 0: - tokens.append( + all_events.append( self.__create_cp_token( previous_note_end, rest=f"{rest_beat}.0", desc="Rest" ) ) - previous_tick += ( - rest_beat * self._current_midi_metadata["time_division"] - ) + previous_tick += rest_beat * time_division while rest_pos >= self.rests[0][1]: rest_pos_temp = min( [r[1] for r in self.rests], key=lambda x: abs(x - rest_pos) ) - tokens.append( + all_events.append( self.__create_cp_token( previous_note_end, rest=f"0.{rest_pos_temp}", @@ -132,93 +129,53 @@ def _track_to_tokens(self, track: Instrument) -> TokSequence: current_bar = previous_tick // ticks_per_bar - # (Tempo) - if self.config.use_tempos: - # If the current tempo is not the last one - if current_tempo_idx + 1 < len( - self._current_midi_metadata["tempo_changes"] - ): - # Will loop over incoming tempo changes - for tempo_change in self._current_midi_metadata[ - "tempo_changes" - ][current_tempo_idx + 1 :]: - # If this tempo change happened before the current moment - if tempo_change.time <= note.start: - current_tempo = tempo_change.tempo - current_tempo_idx += ( - 1 # update tempo value (might not change) and index - ) - elif tempo_change.time > note.start: - break # this tempo change is beyond the current time step, we break the loop - # Bar - nb_new_bars = note.start // ticks_per_bar - current_bar + nb_new_bars = event.time // ticks_per_bar - current_bar for i in range(nb_new_bars): - tokens.append( + if self.config.use_time_signatures: + 
time_sig_arg = f"{current_time_sig[0]}/{current_time_sig[1]}" + else: + time_sig_arg = None + all_events.append( self.__create_cp_token( - (current_bar + i + 1) * ticks_per_bar, bar=True, desc="Bar" + (current_bar + i + 1) * ticks_per_bar, + bar=True, + desc="Bar", + time_signature=time_sig_arg, ) ) current_bar += nb_new_bars # Position - pos_index = int((note.start % ticks_per_bar) / ticks_per_sample) - tokens.append( + pos_index = int((event.time % ticks_per_bar) / ticks_per_sample) + all_events.append( self.__create_cp_token( - int(note.start), + event.time, pos=pos_index, tempo=current_tempo if self.config.use_tempos else None, desc="Position", ) ) - previous_tick = note.start - - # Note - duration = note.end - note.start - dur_index = np.argmin(np.abs(dur_bins - duration)) - dur_value = ".".join(map(str, self.durations[dur_index])) - tokens.append( - self.__create_cp_token( - int(note.start), - pitch=note.pitch, - vel=note.velocity, - dur=dur_value, - desc=f"{duration} ticks", - ) - ) - previous_note_end = max(previous_note_end, note.end) - - tokens.sort(key=lambda x: x[0].time) - - # Adds chord tokens if specified - if self.config.use_chords and not track.is_drum: - chord_events = detect_chords( - track.notes, - self._current_midi_metadata["time_division"], - chord_maps=self.config.chord_maps, - specify_root_note=self.config.chord_tokens_with_root_note, - beat_res=self._first_beat_res, - unknown_chords_nb_notes_range=self.config.chord_unknown, - ) - count = 0 - for chord_event in chord_events: - for e, cp_token in enumerate(tokens[count:]): - if ( - cp_token[0].time == chord_event.time - and cp_token[0].desc == "Position" - ): - cp_token[self.vocab_types_idx["Chord"]] = self[ - self.vocab_types_idx["Chord"], f"Chord_{chord_event.value}" - ] - count = e - break - # Convert the first element of each compound token from Event to int - for cp_token in tokens: - cp_token[0] = str(cp_token[0]) + previous_tick = event.time + + # Convert event to CP Event + # Update 
max offset time of the notes encountered + if event.type == "Pitch" and e + 2 < len(events): + all_events.append( + self.__create_cp_token( + event.time, + pitch=event.value, + vel=events[e + 1].value, + dur=events[e + 2].value, + program=current_program, + ) + ) + previous_note_end = max(previous_note_end, event.desc) + elif event.type == "Tempo": + previous_note_end = max(previous_note_end, event.time) - tokens: List[List[str]] = tokens # just to prevent IDE type warning - return TokSequence(tokens=tokens) + return all_events def __create_cp_token( self, @@ -231,9 +188,10 @@ def __create_cp_token( chord: str = None, rest: str = None, tempo: float = None, + time_signature: str = None, program: int = None, desc: str = "", - ) -> List[Union[Event, str]]: + ) -> List[Event]: r"""Create a CP Word token, with the following structure: (index. Token type) 0. Family @@ -245,6 +203,7 @@ def __create_cp_token( (6. Chord) optional, chords occurring with position tokens (7. Rest) optional, rest acting as a TimeShift token (8. Tempo) optional, occurring with position tokens + (9. TimeSig) optional, occurring with bar tokens NOTE: the first Family token (first in list) will be given as an Event object to keep track of time easily so that other method can sort CP tokens afterwards. 
@@ -261,122 +220,216 @@ def __create_cp_token( :param desc: an optional argument for debug and used to spot position tokens in track_to_tokens :return: The compound token as a list of integers """ - cp_token_template = [ + + def create_event(type_: str, value) -> Event: + return Event(type=type_, value=value, time=time, desc=desc) + + cp_token = [ Event(type="Family", value="Metric", time=time, desc=desc), - "Ignore_None", - "Ignore_None", - "Ignore_None", - "Ignore_None", + Event(type="Ignore", value="None", time=time, desc=desc), + Event(type="Ignore", value="None", time=time, desc=desc), + Event(type="Ignore", value="None", time=time, desc=desc), + Event(type="Ignore", value="None", time=time, desc=desc), ] - for add_tok_attr in ["use_programs", "use_chords", "use_rests", "use_tempos"]: + for add_tok_attr in [ + "use_programs", + "use_chords", + "use_rests", + "use_tempos", + "use_time_signatures", + ]: if getattr(self.config, add_tok_attr): - cp_token_template.append("Ignore_None") + cp_token.append(create_event("Ignore", "None")) if bar: - cp_token_template[1] = "Bar_None" + cp_token[1] = create_event("Bar", "None") + if time_signature is not None: + cp_token[self.vocab_types_idx["TimeSig"]] = create_event( + "TimeSig", time_signature + ) elif pos is not None: - cp_token_template[1] = f"Position_{pos}" + cp_token[1] = create_event("Position", pos) if chord is not None: - cp_token_template[self.vocab_types_idx["Chord"]] = f"Chord_{chord}" + cp_token[self.vocab_types_idx["Chord"]] = create_event("Chord", chord) if tempo is not None: - cp_token_template[self.vocab_types_idx["Tempo"]] = f"Tempo_{tempo}" + cp_token[self.vocab_types_idx["Tempo"]] = create_event("Tempo", tempo) elif rest is not None: - cp_token_template[self.vocab_types_idx["Rest"]] = f"Rest_{rest}" + cp_token[self.vocab_types_idx["Rest"]] = create_event("Rest", rest) elif pitch is not None: - cp_token_template[0].value = "Note" - cp_token_template[2] = f"Pitch_{pitch}" - cp_token_template[3] = 
f"Velocity_{vel}" - cp_token_template[4] = f"Duration_{dur}" + cp_token[0].value = "Note" + cp_token[2] = create_event("Pitch", pitch) + cp_token[3] = create_event("Velocity", vel) + cp_token[4] = create_event("Duration", dur) if program is not None: - cp_token_template[ - self.vocab_types_idx["Program"] - ] = f"Program_{program}" + cp_token[self.vocab_types_idx["Program"]] = create_event( + "Program", program + ) - return cp_token_template + return cp_token - def _tokens_to_track( + @_in_as_seq() + def tokens_to_midi( self, - tokens: TokSequence, - time_division: Optional[int] = TIME_DIVISION, - program: Optional[Tuple[int, bool]] = (0, False), - ) -> Tuple[Instrument, List[TempoChange]]: - r"""Converts a sequence of tokens into a track object - - :param tokens: sequence of tokens to convert. Can be either a Tensor (PyTorch and Tensorflow are supported), - a numpy array, a Python list or a TokSequence. - :param time_division: MIDI time division / resolution, in ticks/beat (of the MIDI to create) - :param program: the MIDI program of the produced track and if it drum, (default (0, False), piano) - :return: the miditoolkit instrument object and tempo changes + tokens: Union[ + Union[TokSequence, List, np.ndarray, Any], + List[Union[TokSequence, List, np.ndarray, Any]], + ], + programs: Optional[List[Tuple[int, bool]]] = None, + output_path: Optional[str] = None, + time_division: int = TIME_DIVISION, + ) -> MidiFile: + r"""Converts tokens (:class:`miditok.TokSequence`) into a MIDI and saves it. + + :param tokens: tokens to convert. Can be either a list of :class:`miditok.TokSequence`, + :param programs: programs of the tracks. If none is given, will default to piano, program 0. (default: None) + :param output_path: path to save the file. (default: None) + :param time_division: MIDI time division / resolution, in ticks/beat (of the MIDI to create). + :return: the midi object (:class:`miditoolkit.MidiFile`). 
""" + # Unsqueeze tokens in case of one_token_stream + if self.one_token_stream: # ie single token seq + tokens = [tokens] + for i in range(len(tokens)): + tokens[i] = tokens[i].tokens + midi = MidiFile(ticks_per_beat=time_division) assert ( time_division % max(self.config.beat_res.values()) == 0 ), f"Invalid time division, please give one divisible by {max(self.config.beat_res.values())}" - ticks_per_sample = time_division // max(self.config.beat_res.values()) - ticks_per_bar = time_division * 4 - name = "Drums" if program[1] else MIDI_INSTRUMENTS[program[0]]["name"] - instrument = Instrument(program[0], is_drum=program[1], name=name) - tempo_changes = [ - TempoChange(TEMPO, -1) - ] # mock the first tempo change to optimize below + + # RESULTS + instruments: Dict[int, Instrument] = {} + tempo_changes = [TempoChange(TEMPO, -1)] + time_signature_changes = [TimeSignature(*TIME_SIGNATURE, 0)] + ticks_per_bar = self._compute_ticks_per_bar( + time_signature_changes[0], time_division + ) # init current_tick = 0 current_bar = -1 + current_program = 0 previous_note_end = 0 - for compound_token in tokens.tokens: - token_family = compound_token[0].split("_")[1] - if token_family == "Note": - if any(tok.split("_")[1] == "None" for tok in compound_token[2:5]): - continue - pitch = int(compound_token[2].split("_")[1]) - vel = int(compound_token[3].split("_")[1]) - duration = self._token_duration_to_ticks( - compound_token[4].split("_")[1], time_division + for si, seq in enumerate(tokens): + # Set track / sequence program if needed + if not self.one_token_stream: + current_tick = 0 + current_bar = -1 + ticks_per_bar = self._compute_ticks_per_bar( + time_signature_changes[0], time_division ) - instrument.notes.append( - Note(vel, pitch, current_tick, current_tick + duration) - ) - previous_note_end = max(previous_note_end, current_tick + duration) - elif token_family == "Metric": - if compound_token[1].split("_")[0] == "Bar": - current_bar += 1 - current_tick = current_bar * 
ticks_per_bar - elif ( - compound_token[1].split("_")[0] == "Position" - ): # i.e. its a position - if current_bar == -1: - current_bar = ( - 0 # as this Position token occurs before any Bar token - ) - current_tick = ( - current_bar * ticks_per_bar - + int(compound_token[1].split("_")[1]) * ticks_per_sample + previous_note_end = 0 + if programs is not None: + current_program = -1 if programs[si][1] else programs[si][0] + + # Decode tokens + for ti, compound_token in enumerate(seq): + token_family = compound_token[0].split("_")[1] + if token_family == "Note": + if any(tok.split("_")[1] == "None" for tok in compound_token[2:5]): + continue + pitch = int(compound_token[2].split("_")[1]) + vel = int(compound_token[3].split("_")[1]) + duration = self._token_duration_to_ticks( + compound_token[4].split("_")[1], time_division ) - if self.config.use_tempos: - tempo = float(compound_token[-1].split("_")[1]) - if tempo != tempo_changes[-1].tempo: - tempo_changes.append(TempoChange(tempo, current_tick)) - elif ( - self.config.use_rests - and compound_token[self.vocab_types_idx["Rest"]].split("_")[1] - != "None" - ): - if ( - current_tick < previous_note_end - ): # if in case successive rest happen - current_tick = previous_note_end - beat, pos = map( - int, - compound_token[self.vocab_types_idx["Rest"]] - .split("_")[1] - .split("."), + if self.config.use_programs: + current_program = int(compound_token[5].split("_")[1]) + if current_program not in instruments.keys(): + instruments[current_program] = Instrument( + program=0 if current_program == -1 else current_program, + is_drum=current_program == -1, + name="Drums" + if current_program == -1 + else MIDI_INSTRUMENTS[current_program]["name"], + ) + instruments[current_program].notes.append( + Note(vel, pitch, current_tick, current_tick + duration) ) - current_tick += beat * time_division + pos * ticks_per_sample - current_bar = current_tick // ticks_per_bar + previous_note_end = max(previous_note_end, current_tick + 
duration) + + elif token_family == "Metric": + bar_pos = compound_token[1].split("_")[0] + if bar_pos == "Bar": + current_bar += 1 + current_tick = current_bar * ticks_per_bar + # Add new TS only if different from the last one + if self.config.use_time_signatures and si == 0: + num, den = self._parse_token_time_signature( + compound_token[self.vocab_types_idx["TimeSig"]].split( + "_" + )[1] + ) + if ( + num != time_signature_changes[-1].numerator + and den != time_signature_changes[-1].denominator + ): + time_sig = TimeSignature(num, den, current_tick) + if si == 0: + time_signature_changes.append(time_sig) + ticks_per_bar = self._compute_ticks_per_bar( + time_sig, time_division + ) + elif bar_pos == "Position": # i.e. its a position + if current_bar == -1: + # in case this Position token comes before any Bar token + current_bar = 0 + current_tick = ( + current_bar * ticks_per_bar + + int(compound_token[1].split("_")[1]) * ticks_per_sample + ) + # Add new tempo change only if different from the last one + if self.config.use_tempos and si == 0: + tempo = float( + compound_token[self.vocab_types_idx["Tempo"]].split( + "_" + )[1] + ) + if ( + si == 0 + and tempo != tempo_changes[-1].tempo + and current_tick != tempo_changes[-1].time + ): + tempo_changes.append(TempoChange(tempo, current_tick)) + previous_note_end = max(previous_note_end, current_tick) + elif ( + self.config.use_rests + and compound_token[self.vocab_types_idx["Rest"]].split("_")[1] + != "None" + ): + if current_tick < previous_note_end: + # if in case successive rest happen + current_tick = previous_note_end + beat, pos = map( + int, + compound_token[self.vocab_types_idx["Rest"]] + .split("_")[1] + .split("."), + ) + current_tick += beat * time_division + pos * ticks_per_sample + current_bar = current_tick // ticks_per_bar + if len(tempo_changes) > 1: - del tempo_changes[0] + del tempo_changes[0] # delete mocked tempo change tempo_changes[0].time = 0 - return instrument, tempo_changes + if 
len(time_signature_changes) > 1: + del time_signature_changes[0] # delete mocked time signature change + time_signature_changes[0].time = 0 + + # create MidiFile + midi.instruments = list(instruments.values()) + midi.tempo_changes = tempo_changes + midi.time_signature_changes = time_signature_changes + midi.max_tick = max( + [ + max([note.end for note in track.notes]) if len(track.notes) > 0 else 0 + for track in midi.instruments + ] + ) + # Write MIDI file + if output_path: + Path(output_path).mkdir(parents=True, exist_ok=True) + midi.dump(output_path) + return midi def _create_base_vocabulary(self) -> List[List[str]]: r"""Creates the vocabulary, as a list of string tokens. @@ -396,7 +449,9 @@ def _create_base_vocabulary(self) -> List[List[str]]: vocab[0].append("Family_Note") # POSITION - max_nb_beats = max(map(lambda ts: ceil(4 * ts[0] / ts[1]), self.time_signatures)) + max_nb_beats = max( + map(lambda ts: ceil(4 * ts[0] / ts[1]), self.time_signatures) + ) nb_positions = max(self.config.beat_res.values()) * max_nb_beats vocab[1].append("Ignore_None") vocab[1].append("Bar_None") @@ -438,6 +493,13 @@ def _create_base_vocabulary(self) -> List[List[str]]: if self.config.use_tempos: vocab += [["Ignore_None"] + [f"Tempo_{i}" for i in self.tempos]] + # TIME_SIGNATURE + if self.config.use_time_signatures: + vocab += [ + ["Ignore_None"] + + [f"TimeSig_{i[0]}/{i[1]}" for i in self.time_signatures] + ] + return vocab def _create_token_types_graph(self) -> Dict[str, List[str]]: @@ -467,6 +529,13 @@ def _create_token_types_graph(self) -> Dict[str, List[str]]: dic["Rest"] = ["Rest", "Position", "Bar"] dic["Pitch"] += ["Rest"] + if self.config.use_tempos: + # Because a tempo change can happen at any moment + dic["Position"] += ["Position", "Bar"] + if self.config.use_rests: + dic["Position"].append("Rest") + dic["Rest"].append("Position") + for key in dic: dic[key].append("Ignore") dic["Ignore"] = list(dic.keys()) @@ -513,7 +582,8 @@ def cp_token_type(tok: List[int]) -> 
List[str]: err = 0 previous_type = cp_token_type(tokens[0])[0] current_pos = -1 - current_pitches = [] + program = 0 + current_pitches = {p: [] for p in self.config.programs} for token in tokens[1:]: token_type, token_value = cp_token_type(token) @@ -521,18 +591,20 @@ def cp_token_type(tok: List[int]) -> List[str]: if token_type in self.tokens_types_graph[previous_type]: if token_type == "Bar": # reset current_pos = -1 - current_pitches = [] + current_pitches = {p: [] for p in self.config.programs} elif token_type == "Pitch": - if int(token_value) in current_pitches: + if self.config.use_programs: + program = int(self[5, token[5]].split("_")[1]) + if int(token_value) in current_pitches[program]: err += 1 # pitch already played at current position else: - current_pitches.append(int(token_value)) + current_pitches[program].append(int(token_value)) elif token_type == "Position": if int(token_value) <= current_pos and previous_type != "Rest": err += 1 # token position value <= to the current position else: current_pos = int(token_value) - current_pitches = [] + current_pitches = {p: [] for p in self.config.programs} # Bad token type else: err += 1 diff --git a/miditok/tokenizations/midi_like.py b/miditok/tokenizations/midi_like.py index 315a8470..5c431eaa 100644 --- a/miditok/tokenizations/midi_like.py +++ b/miditok/tokenizations/midi_like.py @@ -31,6 +31,8 @@ class MIDILike(MIDITokenizer): played when a second one is also played, the offset time of the first will be set to the onset time of the second. This is done to prevent unwanted duration alterations that could happen in such case, as the `NoteOff` token associated to the first note will also end the second one. + **Note:** When decoding multiple token sequences (of multiple tracks), i.e. when `config.use_programs` is False, + only the tempos and time signatures of the first sequence will be decoded for the whole MIDI. 
""" def _tweak_config_before_creating_voc(self): @@ -148,6 +150,8 @@ def _add_time_events(self, events: List[Event]) -> List[Event]: previous_note_end = max(previous_note_end, events[e + 2].desc) elif event.type in ["NoteOn", "Program"]: previous_note_end = max(previous_note_end, event.desc) + elif event.type == "Tempo": + previous_note_end = max(previous_note_end, event.time) elif event.type == "Chord" and e + 1 < len(events): # Next event is either a NoteOn or Program with the end info previous_note_end = max(previous_note_end, events[e + 1].desc) @@ -309,13 +313,15 @@ def tokens_to_midi( # If your encoding include tempo tokens, each Position token should be followed by # a tempo token, but if it is not the case this method will skip this step tempo = float(token.split("_")[1]) - if tempo != tempo_changes[-1].tempo: + if si == 0 and current_tick != tempo_changes[-1].time: tempo_changes.append(TempoChange(tempo, current_tick)) - elif token.split("_")[0] == "TimeSig": + previous_note_end = max(previous_note_end, current_tick) + elif si == 0 and token.split("_")[0] == "TimeSig": num, den = self._parse_token_time_signature(token.split("_")[1]) current_time_signature = time_signature_changes[-1] if ( - num != current_time_signature.numerator + si == 0 + and num != current_time_signature.numerator and den != current_time_signature.denominator ): time_signature_changes.append( diff --git a/miditok/tokenizations/mmm.py b/miditok/tokenizations/mmm.py index 26e5d81e..69bc2a62 100644 --- a/miditok/tokenizations/mmm.py +++ b/miditok/tokenizations/mmm.py @@ -1,10 +1,10 @@ from pathlib import Path -from typing import Any, Dict, List, Optional, Tuple, Union, cast +from typing import Any, Dict, List, Optional, Union, cast import numpy as np from miditoolkit import Instrument, MidiFile, Note, TempoChange, TimeSignature -from ..classes import Event, TokSequence, TokenizerConfig +from ..classes import Event, TokSequence from ..constants import ( MIDI_INSTRUMENTS, TEMPO, @@ -26,31 
+26,19 @@ class MMM(MIDITokenizer): strategy of the [original paper](https://arxiv.org/abs/2008.06048). The reason being that ``NoteOff`` tokens perform poorer for generation with causal models. - :param tokenizer_config: the tokenizer's configuration, as a :class:`miditok.classes.TokenizerConfig` object. - :param density_bins_max: tuple specifying the number of density bins, and the maximum density in - notes per beat to consider. (default: (10, 20)) - :param params: path to a tokenizer config file. This will override other arguments and - load the tokenizer based on the config file. This is particularly useful if the - tokenizer learned Byte Pair Encoding. (default: None) - """ + **Add a `density_bins_max` entry in the config, mapping to a tuple specifying the number of density bins, and the + maximum density in notes per beat to consider. (default: (10, 20))** - def __init__( - self, - tokenizer_config: TokenizerConfig = None, - density_bins_max: Tuple[int, int] = MMM_DENSITY_BINS_MAX, - params: Optional[Union[str, Path]] = None, - ): - if ( - tokenizer_config is not None - and "density_bins_max" not in tokenizer_config.additional_params - ): - tokenizer_config.additional_params["density_bins_max"] = density_bins_max - super().__init__(tokenizer_config, True, params) + **Note:** When decoding tokens with tempos, only the tempos of the first track will be decoded. 
+ """ def _tweak_config_before_creating_voc(self): + self.one_token_stream = True self.config.use_programs = True self.config.use_rests = False # Recreate densities here just in case density_bins_max was loaded from params (list to np array) + if "density_bins_max" not in self.config.additional_params: + self.config.additional_params["density_bins_max"] = MMM_DENSITY_BINS_MAX if "note_densities" in self.config.additional_params: if isinstance( self.config.additional_params["note_densities"], (list, tuple) @@ -260,11 +248,15 @@ def tokens_to_midi( time_signature_changes = [ TimeSignature(*TIME_SIGNATURE, 0) ] # mock the first time signature change to optimize below - ticks_per_bar = self._compute_ticks_per_bar(time_signature_changes[0], time_division) # init + ticks_per_bar = self._compute_ticks_per_bar( + time_signature_changes[0], time_division + ) # init current_tick = 0 current_bar = -1 previous_note_end = 0 # unused (rest) + first_program = None + current_program = -2 for ti, token in enumerate(tokens): tok_type, tok_val = token.split("_") if tok_type == "Program": @@ -278,6 +270,8 @@ def tokens_to_midi( else MIDI_INSTRUMENTS[current_program]["name"], ) ) + if first_program is None: + first_program = current_program current_tick = 0 current_bar = -1 previous_note_end = 0 @@ -297,11 +291,13 @@ def tokens_to_midi( # as this Position token occurs before any Bar token current_bar = 0 current_tick += self._token_duration_to_ticks(tok_val, time_division) - elif tok_type == "Tempo": + elif ( + first_program is None or current_program == first_program + ) and tok_type == "Tempo": # If the tokenizer includes tempo tokens, each Position token should be followed by # a tempo token, but if it is not the case this method will skip this step tempo = float(token.split("_")[1]) - if tempo != tempo_changes[-1].tempo: + if current_tick != tempo_changes[-1].time: tempo_changes.append(TempoChange(tempo, current_tick)) elif tok_type == "TimeSig": num, den = 
self._parse_token_time_signature(token.split("_")[1]) @@ -444,6 +440,8 @@ def _create_token_types_graph(self) -> Dict[str, List[str]]: dic["TimeShift"] += ["Tempo"] if self.config.use_time_signatures: dic["TimeSig"] += ["Tempo"] + if self.config.use_chords: + dic["Tempo"] += ["Chord"] dic["Fill"] = list(dic.keys()) @@ -508,6 +506,7 @@ def tokens_errors( current_pitches.append(pitch_val) for i, token in enumerate(tokens[1:]): + # err_tokens = tokens[i - 4 : i + 4] # uncomment for debug event_type, event_value = token.split("_")[0], token.split("_")[1] # Good token type diff --git a/miditok/tokenizations/mumidi.py b/miditok/tokenizations/mumidi.py index c864d015..8606e41f 100644 --- a/miditok/tokenizations/mumidi.py +++ b/miditok/tokenizations/mumidi.py @@ -6,7 +6,7 @@ from miditoolkit import MidiFile, Instrument, Note, TempoChange from ..midi_tokenizer import MIDITokenizer, _in_as_seq, _out_as_complete_seq -from ..classes import TokSequence, Event, TokenizerConfig +from ..classes import TokSequence, Event from ..utils import detect_chords from ..constants import ( TIME_DIVISION, @@ -37,37 +37,25 @@ class MuMIDI(MIDITokenizer): For generation, the decoding implies sample from several distributions, which can be very delicate. Hence, we do not recommend this tokenization for generation with small models. + **Add a `drum_pitch_range` entry in the config, mapping to a tuple of values to restrict the range of drum pitches + to use.** + **Notes:** * Tokens are first sorted by time, then track, then pitch values. * Tracks with the same *Program* will be merged. - - :param tokenizer_config: the tokenizer's configuration, as a :class:`miditok.classes.TokenizerConfig` object. - :param drum_pitch_range: range of used MIDI pitches for drums exclusively - :param params: path to a tokenizer config file. This will override other arguments and - load the tokenizer based on the config file. This is particularly useful if the - tokenizer learned Byte Pair Encoding. 
(default: None) """ - - def __init__( - self, - tokenizer_config: TokenizerConfig = None, - drum_pitch_range: Tuple[int, int] = DRUM_PITCH_RANGE, - params: Union[str, Path] = None, - ): - if tokenizer_config is not None: - if "drum_pitch_range" not in tokenizer_config.additional_params: - tokenizer_config.additional_params[ - "drum_pitch_range" - ] = drum_pitch_range - if "max_bar_embedding" not in tokenizer_config.additional_params: - # this attribute might increase over tokenizations, if the tokenizer encounter longer MIDIs - tokenizer_config.additional_params["max_bar_embedding"] = 60 - super().__init__(tokenizer_config, True, params=params) - def _tweak_config_before_creating_voc(self): self.config.use_rests = False self.config.use_time_signatures = False - # self.one_token_stream = True + self.one_token_stream = True + + if "drum_pitch_range" not in self.config.additional_params: + self.config.additional_params[ + "drum_pitch_range" + ] = DRUM_PITCH_RANGE + if "max_bar_embedding" not in self.config.additional_params: + # this attribute might increase over tokenizations, if the tokenizer encounter longer MIDIs + self.config.additional_params["max_bar_embedding"] = 60 self.vocab_types_idx = { "Pitch": 0, @@ -410,7 +398,9 @@ def _create_base_vocabulary(self) -> List[List[str]]: for i in range(*self.config.additional_params["drum_pitch_range"]) ] vocab[0] += ["Bar_None"] # new bar token - max_nb_beats = max(map(lambda ts: ceil(4 * ts[0] / ts[1]), self.time_signatures)) + max_nb_beats = max( + map(lambda ts: ceil(4 * ts[0] / ts[1]), self.time_signatures) + ) nb_positions = max(self.config.beat_res.values()) * max_nb_beats vocab[0] += [f"Position_{i}" for i in range(nb_positions)] vocab[0] += [f"Program_{program}" for program in self.config.programs] diff --git a/miditok/tokenizations/octuple.py b/miditok/tokenizations/octuple.py index cc1ed426..686a5929 100644 --- a/miditok/tokenizations/octuple.py +++ b/miditok/tokenizations/octuple.py @@ -1,11 +1,11 @@ from math 
import ceil from pathlib import Path -from typing import List, Dict, Optional, Union, Any +from typing import List, Tuple, Dict, Optional, Union, Any import numpy as np from miditoolkit import MidiFile, Instrument, Note, TempoChange, TimeSignature -from ..midi_tokenizer import MIDITokenizer, _in_as_seq, _out_as_complete_seq +from ..midi_tokenizer import MIDITokenizer, _in_as_seq from ..classes import TokSequence, Event from ..constants import ( TIME_DIVISION, @@ -24,9 +24,9 @@ class Octuple(MIDITokenizer): * 0: Pitch * 1: Velocity * 2: Duration - * 3: Program (track) - * 4: Position - * 5: Bar + * 3: Position + * 4: Bar + * (+ Optional) Program * (+ Optional) Tempo * (+ Optional) TimeSignature @@ -39,313 +39,283 @@ class Octuple(MIDITokenizer): **Notes:** * Tokens are first sorted by time, then track, then pitch values. * Tracks with the same *Program* will be merged. + * When decoding multiple token sequences (of multiple tracks), i.e. when `config.use_programs` is False, + only the tempos and time signatures of the first sequence will be decoded for the whole MIDI. 
""" def _tweak_config_before_creating_voc(self): self.config.use_chords = False self.config.use_rests = False - self.one_token_stream = True + if self.config.use_programs: + self.one_token_stream = True + self.config.delete_equal_successive_tempo_changes = True # used in place of positional encoding # This attribute might increase over tokenizations, if the tokenizer encounter longer MIDIs if "max_bar_embedding" not in self.config.additional_params: self.config.additional_params["max_bar_embedding"] = 60 - token_types = ["Pitch", "Velocity", "Duration", "Program", "Position", "Bar"] + token_types = ["Pitch", "Velocity", "Duration", "Position", "Bar"] + if self.config.use_programs: + token_types.append("Program") if self.config.use_tempos: token_types.append("Tempo") if self.config.use_time_signatures: - token_types.append("TimeSignature") + token_types.append("TimeSig") self.vocab_types_idx = { type_: idx for idx, type_ in enumerate(token_types) } # used for data augmentation - @_out_as_complete_seq - def _midi_to_tokens(self, midi: MidiFile, *args, **kwargs) -> TokSequence: - r"""Override the parent class method - Converts a MIDI file in a token representation, a sequence of "time steps". + def _add_time_events(self, events: List[Event]) -> List[List[Event]]: + r""" + Takes a sequence of note events (containing optionally Chord, Tempo and TimeSignature tokens), + and insert (not inplace) time tokens (TimeShift, Rest) to complete the sequence. A time step is a list of tokens where: (list index: token type) 0: Pitch 1: Velocity 2: Duration - 3: Program (track) - 4: Position - 5: Bar + 3: Position + 4: Bar + (5: Program) (6: Tempo) (7: TimeSignature) - :param midi: the MIDI object to convert - :return: sequences of tokens + :param events: note events to complete. + :return: the same events, with time events inserted. 
""" - # Check bar embedding limit, update if needed - nb_bars = ceil(midi.max_tick / (midi.ticks_per_beat * 4)) - if self.config.additional_params["max_bar_embedding"] < nb_bars: - for i in range(self.config.additional_params["max_bar_embedding"], nb_bars): - self.add_to_vocab(f"Bar_{i}", 5) - self.config.additional_params["max_bar_embedding"] = nb_bars - - # Convert each track to tokens - tokens = [] - for track in midi.instruments: - if track.program in self.config.programs: - tokens += self._track_to_tokens(track) + time_division = self._current_midi_metadata["time_division"] + ticks_per_sample = time_division / max(self.config.beat_res.values()) - tokens.sort( - key=lambda x: (x[0].time, x[0].desc, x[0].value) - ) # Sort by time, then track, then pitch + # Add time events + all_events = [] + current_bar = 0 + current_bar_from_ts_time = 0 + current_tick_from_ts_time = 0 + current_pos = 0 + previous_tick = 0 + current_time_sig = TIME_SIGNATURE + current_tempo = TEMPO + current_program = None + ticks_per_bar = self._compute_ticks_per_bar( + TimeSignature(*current_time_sig, 0), time_division + ) + for e, event in enumerate(events): + # Set current bar and position + # This is done first, as we need to compute these values with the current ticks_per_bar, + # which might change if the current event is a TimeSig + if event.time != previous_tick: + elapsed_tick = event.time - current_tick_from_ts_time + current_bar = current_bar_from_ts_time + elapsed_tick // ticks_per_bar + current_pos = int((elapsed_tick % ticks_per_bar) / ticks_per_sample) + previous_tick = event.time + + if event.type == "TimeSig": + current_time_sig = list(map(int, event.value.split("/"))) + current_bar_from_ts_time = current_bar + current_tick_from_ts_time = previous_tick + ticks_per_bar = self._compute_ticks_per_bar( + TimeSignature(*current_time_sig, event.time), time_division + ) + elif event.type == "Tempo": + current_tempo = event.value + elif event.type == "Program": + current_program = 
event.value + elif event.type == "Pitch" and e + 2 < len(events): + new_event = [ + Event(type="Pitch", value=event.value, time=event.time), + Event(type="Velocity", value=events[e + 1].value, time=event.time), + Event(type="Duration", value=events[e + 2].value, time=event.time), + Event(type="Position", value=current_pos, time=event.time), + Event(type="Bar", value=current_bar, time=event.time), + ] + if self.config.use_programs: + new_event.append(Event("Program", current_program)) + if self.config.use_tempos: + new_event.append(Event(type="Tempo", value=current_tempo)) + if self.config.use_time_signatures: + new_event.append( + Event( + type="TimeSig", + value=f"{current_time_sig[0]}/{current_time_sig[1]}", + ) + ) + all_events.append(new_event) - # Convert pitch events into tokens - for time_step in tokens: - time_step[0] = str(time_step[0]) + return all_events - return TokSequence(tokens=tokens) + def _midi_to_tokens( + self, midi: MidiFile, *args, **kwargs + ) -> Union[TokSequence, List[TokSequence]]: + r"""Converts a preprocessed MIDI object to a sequence of tokens. + The workflow of this method is as follows: the events (Pitch, Velocity, Tempo, TimeSignature...) are + gathered into a list, then the time events are added. If `one_token_stream` is true, all events of all tracks + are treated all at once, otherwise the events of each track are treated independently. - def _track_to_tokens(self, track: Instrument) -> List[List[Union[Event, str]]]: - r"""Converts a track (miditoolkit.Instrument object) into a sequence of tokens (:class:`miditok.TokSequence`). 
A time step is a list of tokens where: (list index: token type) - 0: Pitch (as an Event object for sorting purpose afterwards) + 0: Pitch 1: Velocity 2: Duration - 3: Program (track) - 4: Position - 5: Bar + 3: Position + 4: Bar + (5: Program) (6: Tempo) (7: TimeSignature) - :param track: track object to convert - :return: sequence of corresponding tokens + :param midi: the MIDI object to convert + :return: sequences of tokens """ - # Make sure the notes are sorted first by their onset (start) times, second by pitch - # notes.sort(key=lambda x: (x.start, x.pitch)) # done in midi_to_tokens - time_division = self._current_midi_metadata["time_division"] - ticks_per_sample = time_division / max(self.config.beat_res.values()) - dur_bins = self._durations_ticks[time_division] - - tokens = [] - current_tick = -1 - current_bar = -1 - current_pos = -1 - current_tempo_idx = 0 - current_tempo = self._current_midi_metadata["tempo_changes"][ - current_tempo_idx - ].tempo - current_time_sig_idx = 0 - current_time_sig_tick = 0 - current_time_sig_bar = 0 - current_time_sig = self._current_midi_metadata["time_sig_changes"][ - current_time_sig_idx - ] - ticks_per_bar = self._compute_ticks_per_bar(current_time_sig, time_division) - - for note in track.notes: - # Positions and bars - if note.start != current_tick: - pos_index = int( - ((note.start - current_time_sig_tick) % ticks_per_bar) - / ticks_per_sample - ) - current_tick = note.start - current_bar = ( - current_time_sig_bar - + (current_tick - current_time_sig_tick) // ticks_per_bar - ) - current_pos = pos_index - - # Note attributes - duration = note.end - note.start - dur_index = np.argmin(np.abs(dur_bins - duration)) - token = [ - Event( - type="Pitch", - value=note.pitch, - time=note.start, - desc=-1 if track.is_drum else track.program, - ), - f"Velocity_{note.velocity}", - f'Duration_{".".join(map(str, self.durations[dur_index]))}', - f"Program_{-1 if track.is_drum else track.program}", - f"Position_{current_pos}", - 
f"Bar_{current_bar}", - ] + # Check bar embedding limit, update if needed + nb_bars = ceil(midi.max_tick / (midi.ticks_per_beat * 4)) + if self.config.additional_params["max_bar_embedding"] < nb_bars: + for i in range(self.config.additional_params["max_bar_embedding"], nb_bars): + self.add_to_vocab(f"Bar_{i}", 4) + self.config.additional_params["max_bar_embedding"] = nb_bars - # (Tempo) - if self.config.use_tempos: - # If the current tempo is not the last one - if current_tempo_idx + 1 < len( - self._current_midi_metadata["tempo_changes"] - ): - # Will loop over incoming tempo changes - for tempo_change in self._current_midi_metadata["tempo_changes"][ - current_tempo_idx + 1 : - ]: - # If this tempo change happened before the current moment - if tempo_change.time <= note.start: - current_tempo = tempo_change.tempo - current_tempo_idx += ( - 1 # update tempo value (might not change) and index - ) - elif tempo_change.time > note.start: - break # this tempo change is beyond the current time step, we break the loop - token.append(f"Tempo_{current_tempo}") - - # (TimeSignature) - if self.config.use_time_signatures: - # If the current time signature is not the last one - if current_time_sig_idx + 1 < len( - self._current_midi_metadata["time_sig_changes"] - ): - # Will loop over incoming time signature changes - for time_sig_change in self._current_midi_metadata[ - "time_sig_changes" - ][current_time_sig_idx + 1 :]: - # If this time signature change happened before the current moment - if time_sig_change.time <= note.start: - current_time_sig = time_sig_change - current_time_sig_idx += 1 # update time signature value (might not change) and index - current_time_sig_bar += ( - time_sig_change.time - current_time_sig_tick - ) // ticks_per_bar - current_time_sig_tick = time_sig_change.time - ticks_per_bar = self._compute_ticks_per_bar(current_time_sig, time_division) - elif time_sig_change.time > note.start: - break # this time signature change is beyond the current time 
step, we break the loop - token.append(f"TimeSig_{current_time_sig.numerator}/{current_time_sig.denominator}") - - tokens.append(token) - - return tokens + return super()._midi_to_tokens(midi, *args, **kwargs) @_in_as_seq() def tokens_to_midi( self, - tokens: Union[TokSequence, List, np.ndarray, Any], - _=None, + tokens: Union[ + Union[TokSequence, List, np.ndarray, Any], + List[Union[TokSequence, List, np.ndarray, Any]], + ], + programs: Optional[List[Tuple[int, bool]]] = None, output_path: Optional[str] = None, - time_division: Optional[int] = TIME_DIVISION, + time_division: int = TIME_DIVISION, ) -> MidiFile: - r"""Override the parent class method - Convert multiple sequences of tokens into a multitrack MIDI and save it. - The tokens will be converted to event objects and then to a miditoolkit.MidiFile object. - A time step is a list of tokens where (list index: token type): - * 0: Pitch - * 1: Velocity - * 2: Duration - * 3: Program (track) - * 4: Position - * 5: Bar - * (6: Tempo) - * (7: TimeSignature) - :param tokens: tokens to convert. Can be either a Tensor (PyTorch and Tensorflow are supported), - a numpy array, a Python list or a TokSequence. - :param _: unused, to match parent method signature - :param output_path: path to save the file (with its name, e.g. music.mid), - leave None to not save the file - :param time_division: MIDI time division / resolution, in ticks/beat (of the MIDI to create) - :return: the midi object (miditoolkit.MidiFile) + r"""Converts tokens (:class:`miditok.TokSequence`) into a MIDI and saves it. + + :param tokens: tokens to convert. Can be either a list of :class:`miditok.TokSequence`, + :param programs: programs of the tracks. If none is given, will default to piano, program 0. (default: None) + :param output_path: path to save the file. (default: None) + :param time_division: MIDI time division / resolution, in ticks/beat (of the MIDI to create). + :return: the midi object (:class:`miditoolkit.MidiFile`). 
""" + # Unsqueeze tokens in case of one_token_stream + if self.one_token_stream: # ie single token seq + tokens = [tokens] + for i in range(len(tokens)): + tokens[i] = tokens[i].tokens + midi = MidiFile(ticks_per_beat=time_division) assert ( time_division % max(self.config.beat_res.values()) == 0 ), f"Invalid time division, please give one divisible by {max(self.config.beat_res.values())}" - midi = MidiFile(ticks_per_beat=time_division) ticks_per_sample = time_division // max(self.config.beat_res.values()) - tokens = tokens.tokens - - tempo_changes = [TempoChange(TEMPO, 0)] - if self.config.use_tempos: - for i in range(len(tokens)): - if tokens[i][6].split("_")[1] != "None": - tempo_changes = [TempoChange(float(tokens[i][6].split("_")[1]), 0)] - break - time_sig = TIME_SIGNATURE - if self.config.use_time_signatures: - for i in range(len(tokens)): - if tokens[i][-1].split("_")[1] != "None": - time_sig = self._parse_token_time_signature( - tokens[i][-1].split("_")[1] - ) - break - - time_sig = TimeSignature(*time_sig, 0) - ticks_per_bar = self._compute_ticks_per_bar(time_sig, time_division) - time_sig_changes = [time_sig] - - current_time_sig_tick = 0 - current_time_sig_bar = 0 - - tracks = dict([(n, []) for n in self.config.programs]) - for time_step in tokens: - if any(tok.split("_")[1] == "None" for tok in time_step[:6]): - continue # Either padding, mask: error of prediction or end of sequence anyway - - # Note attributes - pitch = int(time_step[0].split("_")[1]) - vel = int(time_step[1].split("_")[1]) - duration = self._token_duration_to_ticks( - time_step[2].split("_")[1], time_division - ) - - # Time and track values - program = int(time_step[3].split("_")[1]) - current_pos = int(time_step[4].split("_")[1]) - current_bar = int(time_step[5].split("_")[1]) - current_tick = ( - current_time_sig_tick - + (current_bar - current_time_sig_bar) * ticks_per_bar - + current_pos * ticks_per_sample - ) - - # Append the created note - tracks[program].append( - Note(vel, 
pitch, current_tick, current_tick + duration) - ) - - # Tempo, adds a TempoChange if necessary - if self.config.use_tempos and time_step[6].split("_")[1] != "None": - tempo = float(time_step[6].split("_")[1]) - if tempo != tempo_changes[-1].tempo: - tempo_changes.append(TempoChange(tempo, current_tick)) - - # Time Signature, adds a TimeSignatureChange if necessary - if ( - self.config.use_time_signatures - and time_step[-1].split("_")[1] != "None" - ): - time_sig = self._parse_token_time_signature(time_step[-1].split("_")[1]) - if time_sig != ( - time_sig_changes[-1].numerator, - time_sig_changes[-1].denominator, + # RESULTS + instruments: Dict[int, Instrument] = {} + tempo_changes = [TempoChange(TEMPO, -1)] + time_signature_changes = [TimeSignature(*TIME_SIGNATURE, 0)] + ticks_per_bar = self._compute_ticks_per_bar( + time_signature_changes[0], time_division + ) # init + + current_bar_from_ts_time = 0 + current_tick_from_ts_time = 0 + current_program = 0 + for si, seq in enumerate(tokens): + # Set track / sequence program if needed + if not self.one_token_stream: + current_tick = 0 + ticks_per_bar = self._compute_ticks_per_bar( + time_signature_changes[0], time_division + ) + if programs is not None: + current_program = -1 if programs[si][1] else programs[si][0] + + # Decode tokens + for time_step in seq: + nb_tok_to_check = 6 if self.config.use_programs else 5 + if any( + tok.split("_")[1] == "None" for tok in time_step[:nb_tok_to_check] ): - current_time_sig_tick += ( - current_bar - current_time_sig_bar - ) * ticks_per_bar - current_time_sig_bar = current_bar - time_sig = TimeSignature(*time_sig, current_time_sig_tick) - ticks_per_bar = self._compute_ticks_per_bar(time_sig, time_division) - time_sig_changes.append(time_sig) - - # Tempos - midi.tempo_changes = tempo_changes + continue # Either padding, mask: error of prediction or end of sequence anyway - # Time Signatures - midi.time_signature_changes = time_sig_changes + # Note attributes + pitch = 
int(time_step[0].split("_")[1]) + vel = int(time_step[1].split("_")[1]) + duration = self._token_duration_to_ticks( + time_step[2].split("_")[1], time_division + ) + if self.config.use_programs: + current_program = int(time_step[5].split("_")[1]) + + # Time values + event_pos = int(time_step[3].split("_")[1]) + event_bar = int(time_step[4].split("_")[1]) + current_tick = ( + current_tick_from_ts_time + + (event_bar - current_bar_from_ts_time) * ticks_per_bar + + event_pos * ticks_per_sample + ) - # Appends created notes to MIDI object - for program, notes in tracks.items(): - if len(notes) == 0: - continue - if int(program) == -1: - midi.instruments.append(Instrument(0, True, "Drums")) - else: - midi.instruments.append( - Instrument( - int(program), False, MIDI_INSTRUMENTS[int(program)]["name"] + # Append the created note + if current_program not in instruments.keys(): + instruments[current_program] = Instrument( + program=0 if current_program == -1 else current_program, + is_drum=current_program == -1, + name="Drums" + if current_program == -1 + else MIDI_INSTRUMENTS[current_program]["name"], ) + instruments[current_program].notes.append( + Note(vel, pitch, current_tick, current_tick + duration) ) - midi.instruments[-1].notes = notes + # Tempo, adds a TempoChange if necessary + if ( + si == 0 + and self.config.use_tempos + and time_step[self.vocab_types_idx["Tempo"]].split("_")[1] != "None" + ): + tempo = float( + time_step[self.vocab_types_idx["Tempo"]].split("_")[1] + ) + if tempo != tempo_changes[-1].tempo: + tempo_changes.append(TempoChange(tempo, current_tick)) + + # Time Signature, adds a TimeSignatureChange if necessary + if ( + self.config.use_time_signatures + and time_step[self.vocab_types_idx["TimeSig"]].split("_")[1] + != "None" + ): + num, den = self._parse_token_time_signature( + time_step[self.vocab_types_idx["TimeSig"]].split("_")[1] + ) + if ( + num != time_signature_changes[-1].numerator + and den != time_signature_changes[-1].denominator + ): + 
time_sig = TimeSignature(num, den, current_tick) + if si == 0: + time_signature_changes.append(time_sig) + current_bar_from_ts_time = event_bar + current_tick_from_ts_time = current_tick + ticks_per_bar = self._compute_ticks_per_bar( + time_sig, time_division + ) + + if len(tempo_changes) > 1: + del tempo_changes[0] # delete mocked tempo change + tempo_changes[0].time = 0 + if len(time_signature_changes) > 1: + del time_signature_changes[0] # delete mocked time signature change + time_signature_changes[0].time = 0 + + # create MidiFile + midi.instruments = list(instruments.values()) + midi.tempo_changes = tempo_changes + midi.time_signature_changes = time_signature_changes + midi.max_tick = max( + [ + max([note.end for note in track.notes]) if len(track.notes) > 0 else 0 + for track in midi.instruments + ] + ) # Write MIDI file if output_path: Path(output_path).mkdir(parents=True, exist_ok=True) @@ -363,7 +333,7 @@ def _create_base_vocabulary(self) -> List[List[str]]: :return: the vocabulary as a list of string. 
""" - vocab = [[] for _ in range(6)] + vocab = [[] for _ in range(5)] # PITCH vocab[0] += [f"Pitch_{i}" for i in range(*self.config.pitch_range)] @@ -376,20 +346,23 @@ def _create_base_vocabulary(self) -> List[List[str]]: f'Duration_{".".join(map(str, duration))}' for duration in self.durations ] - # PROGRAM - vocab[3] += [f"Program_{i}" for i in self.config.programs] - # POSITION - max_nb_beats = max(map(lambda ts: ceil(4 * ts[0] / ts[1]), self.time_signatures)) + max_nb_beats = max( + map(lambda ts: ceil(4 * ts[0] / ts[1]), self.time_signatures) + ) nb_positions = max(self.config.beat_res.values()) * max_nb_beats - vocab[4] += [f"Position_{i}" for i in range(nb_positions)] + vocab[3] += [f"Position_{i}" for i in range(nb_positions)] # BAR (positional encoding) - vocab[5] += [ + vocab[4] += [ f"Bar_{i}" for i in range(self.config.additional_params["max_bar_embedding"]) ] + # PROGRAM + if self.config.use_programs: + vocab.append([f"Program_{i}" for i in self.config.programs]) + # TEMPO if self.config.use_tempos: vocab.append([f"Tempo_{i}" for i in self.tempos]) @@ -407,7 +380,7 @@ def _create_token_types_graph(self) -> Dict[str, List[str]]: :return: the token types transitions dictionary """ - return {} # not relevant for this encoding + return {} # not relevant for Octuple @_in_as_seq() def tokens_errors(self, tokens: Union[TokSequence, List, np.ndarray, Any]) -> float: @@ -425,16 +398,18 @@ def tokens_errors(self, tokens: Union[TokSequence, List, np.ndarray, Any]) -> fl err = 0 current_bar = current_pos = -1 current_pitches = {p: [] for p in self.config.programs} + current_program = 0 for token in tokens.tokens: if any(tok.split("_")[1] == "None" for tok in token): err += 1 continue has_error = False - bar_value = int(token[5].split("_")[1]) - pos_value = int(token[4].split("_")[1]) + bar_value = int(token[4].split("_")[1]) + pos_value = int(token[3].split("_")[1]) pitch_value = int(token[0].split("_")[1]) - program_value = int(token[3].split("_")[1]) + if 
self.config.use_programs: + current_program = int(token[5].split("_")[1]) # Bar if bar_value < current_bar: @@ -452,10 +427,10 @@ def tokens_errors(self, tokens: Union[TokSequence, List, np.ndarray, Any]) -> fl current_pitches = {p: [] for p in self.config.programs} # Pitch - if pitch_value in current_pitches[program_value]: + if pitch_value in current_pitches[current_program]: has_error = True else: - current_pitches[program_value].append(pitch_value) + current_pitches[current_program].append(pitch_value) if has_error: err += 1 diff --git a/miditok/tokenizations/octuple_mono.py b/miditok/tokenizations/octuple_mono.py deleted file mode 100644 index 37aafcf2..00000000 --- a/miditok/tokenizations/octuple_mono.py +++ /dev/null @@ -1,317 +0,0 @@ -from math import ceil -from typing import List, Tuple, Dict, Optional, Union, Any - -import numpy as np -from miditoolkit import MidiFile, Instrument, Note, TempoChange - -from ..midi_tokenizer import MIDITokenizer, _in_as_seq, _out_as_complete_seq -from ..classes import TokSequence -from ..constants import ( - TIME_DIVISION, - TEMPO, - MIDI_INSTRUMENTS, -) - - -class OctupleMono(MIDITokenizer): - r"""OctupleMono is similar to :ref:`Octuple` - (`MusicBert (Zeng et al.) `_) but without the - *Program* token. OctupleMono is hence better suited for tasks with one track. 
- Each pooled token will be a list of the form (index: Token type): - * 0: Pitch - * 1: Velocity - * 2: Duration - * 3: Position - * 4: Bar - * (+ Optional) Tempo - * (+ Optional) TimeSignature - """ - - def _tweak_config_before_creating_voc(self): - self.config.use_chords = False - self.config.use_rests = False - self.config.use_programs = False - - # used in place of positional encoding - # This attribute might increase over tokenizations, if the tokenizer encounter longer MIDIs - if "max_bar_embedding" not in self.config.additional_params: - self.config.additional_params["max_bar_embedding"] = 60 - - token_types = ["Pitch", "Velocity", "Duration", "Position", "Bar"] - if self.config.use_tempos: - token_types.append("Tempo") - if self.config.use_time_signatures: - token_types.append("TimeSignature") - self.vocab_types_idx = { - type_: idx for idx, type_ in enumerate(token_types) - } # used for data augmentation - - def _midi_to_tokens(self, midi: MidiFile, *args, **kwargs) -> List[TokSequence]: - # Convert each track to tokens - tokens = [] - for track in midi.instruments: - tokens.append(self._track_to_tokens(track)) - self.complete_sequence(tokens[-1]) - return tokens - - @_out_as_complete_seq - def _track_to_tokens(self, track: Instrument) -> TokSequence: - r"""Converts a track (miditoolkit.Instrument object) into a sequence of tokens (:class:`miditok.TokSequence`). - A time step is a list of tokens where: - (list index: token type) - 0: Pitch - 1: Velocity - 2: Duration - 4: Position - 5: Bar - (6: Tempo) - - :param track: MIDI track to convert - :return: :class:`miditok.TokSequence` of corresponding tokens. 
- """ - # Make sure the notes are sorted first by their onset (start) times, second by pitch - # notes.sort(key=lambda x: (x.start, x.pitch)) # done in midi_to_tokens - ticks_per_sample = self._current_midi_metadata["time_division"] / max( - self.config.beat_res.values() - ) - ticks_per_bar = self._current_midi_metadata["time_division"] * 4 - dur_bins = self._durations_ticks[self._current_midi_metadata["time_division"]] - - # Check bar embedding limit, update if needed - nb_bars = ceil( - max(note.end for note in track.notes) - / (self._current_midi_metadata["time_division"] * 4) - ) - if self.config.additional_params["max_bar_embedding"] < nb_bars: - for i in range(self.config.additional_params["max_bar_embedding"], nb_bars): - self.add_to_vocab(f"Bar_{i}", 4) - self.config.additional_params["max_bar_embedding"] = nb_bars - - tokens = [] - current_tick = -1 - current_bar = -1 - current_pos = -1 - current_tempo_idx = 0 - current_tempo = self._current_midi_metadata["tempo_changes"][ - current_tempo_idx - ].tempo - for note in track.notes: - # Positions and bars - if note.start != current_tick: - pos_index = int((note.start % ticks_per_bar) / ticks_per_sample) - current_tick = note.start - current_bar = current_tick // ticks_per_bar - current_pos = pos_index - - # Note attributes - duration = note.end - note.start - dur_index = np.argmin(np.abs(dur_bins - duration)) - token_ts = [ - f"Pitch_{note.pitch}", - f"Velocity_{note.velocity}", - f'Duration_{".".join(map(str, self.durations[dur_index]))}', - f"Position_{current_pos}", - f"Bar_{current_bar}", - ] - - # (Tempo) - if self.config.use_tempos: - # If the current tempo is not the last one - if current_tempo_idx + 1 < len( - self._current_midi_metadata["tempo_changes"] - ): - # Will loop over incoming tempo changes - for tempo_change in self._current_midi_metadata["tempo_changes"][ - current_tempo_idx + 1 : - ]: - # If this tempo change happened before the current moment - if tempo_change.time <= note.start: - 
current_tempo = tempo_change.tempo - current_tempo_idx += ( - 1 # update tempo value (might not change) and index - ) - elif tempo_change.time > note.start: - break # this tempo change is beyond the current time step, we break the loop - token_ts.append(f"Tempo_{current_tempo}") - - tokens.append(token_ts) - - return TokSequence(tokens=tokens) - - def _tokens_to_track( - self, - tokens: TokSequence, - time_division: Optional[int] = TIME_DIVISION, - program: Optional[Tuple[int, bool]] = (0, False), - ) -> Tuple[Instrument, List[TempoChange]]: - r"""Converts a sequence of tokens into a track object - A time step is a list of tokens where: - (list index: token type) - 0: Pitch - 1: Velocity - 2: Duration - 4: Position - 5: Bar - (+ TimeSignature) - (+ Tempo) - - :param tokens: sequence of tokens to convert. Can be either a Tensor (PyTorch and Tensorflow are supported), - a numpy array, a Python list or a TokSequence. - :param time_division: MIDI time division / resolution, in ticks/beat (of the MIDI to create) - :param program: the MIDI program of the produced track and if it drum, (default (0, False), piano) - :return: the miditoolkit instrument object and tempo changes - """ - assert ( - time_division % max(self.config.beat_res.values()) == 0 - ), f"Invalid time division, please give one divisible by {max(self.config.beat_res.values())}" - tokens = tokens.tokens - - ticks_per_sample = time_division // max(self.config.beat_res.values()) - name = "Drums" if program[1] else MIDI_INSTRUMENTS[program[0]]["name"] - instrument = Instrument(program[0], is_drum=program[1], name=name) - - tempo_changes = [TempoChange(TEMPO, 0)] - if self.config.use_tempos: - for i in range(len(tokens)): - if tokens[i][-1].split("_")[1] != "None": - tempo_changes = [TempoChange(float(tokens[i][-1].split("_")[1]), 0)] - break - - for time_step in tokens: - if any(tok.split("_")[1] == "None" for tok in time_step[:6]): - continue # Either padding, mask: error of prediction or end of sequence 
anyway - - # Note attributes - pitch = int(time_step[0].split("_")[1]) - vel = int(time_step[1].split("_")[1]) - duration = self._token_duration_to_ticks( - time_step[2].split("_")[1], time_division - ) - - # Time and track values - current_pos = int(time_step[3].split("_")[1]) - current_bar = int(time_step[4].split("_")[1]) - current_tick = ( - current_bar * time_division * 4 + current_pos * ticks_per_sample - ) - - # Append the created note - instrument.notes.append( - Note(vel, pitch, current_tick, current_tick + duration) - ) - - # Tempo, adds a TempoChange if necessary - if self.config.use_tempos and time_step[-1].split("_")[1] != "None": - tempo = float(time_step[-1].split("_")[1]) - if tempo != tempo_changes[-1].tempo: - tempo_changes.append(TempoChange(tempo, current_tick)) - - return instrument, tempo_changes - - def _create_base_vocabulary(self) -> List[List[str]]: - r"""Creates the vocabulary, as a list of string tokens. - Each token as to be given as the form of "Type_Value", separated with an underscore. - Example: Pitch_58 - The :class:`miditok.MIDITokenizer` main class will then create the "real" vocabulary as - a dictionary. - Special tokens have to be given when creating the tokenizer, and - will be added to the vocabulary by :class:`miditok.MIDITokenizer`. - - :return: the vocabulary as a list of string. 
- """ - vocab = [[] for _ in range(5)] - - # PITCH - vocab[0] += [f"Pitch_{i}" for i in range(*self.config.pitch_range)] - - # VELOCITY - vocab[1] += [f"Velocity_{i}" for i in self.velocities] - - # DURATION - vocab[2] += [ - f'Duration_{".".join(map(str, duration))}' for duration in self.durations - ] - - # POSITION - max_nb_beats = max(map(lambda ts: ceil(4 * ts[0] / ts[1]), self.time_signatures)) - nb_positions = max(self.config.beat_res.values()) * max_nb_beats - vocab[3] += [f"Position_{i}" for i in range(nb_positions)] - - # BAR - vocab[4] += [ - f"Bar_{i}" - for i in range(self.config.additional_params["max_bar_embedding"]) - ] # bar embeddings (positional encoding) - - # TEMPO - if self.config.use_tempos: - vocab.append([f"Tempo_{i}" for i in self.tempos]) - - return vocab - - def _create_token_types_graph(self) -> Dict[str, List[str]]: - r"""Returns a graph (as a dictionary) of the possible token - types successions. - Not relevant for Octuple. - - :return: the token types transitions dictionary - """ - return {} # not relevant for this encoding - - @_in_as_seq() - def tokens_errors( - self, tokens: Union[TokSequence, List, np.ndarray, Any] - ) -> Union[float, List[float]]: - r"""Checks if a sequence of tokens is made of good token values and - returns the error ratio (lower is better). 
- The token types are always the same in Octuple so this method only checks - if their values are correct: - - a bar token value cannot be < to the current bar (it would go back in time) - - same for positions - - a pitch token should not be present if the same pitch is already played at the current position - - :param tokens: sequence of tokens to check - :return: the error ratio (lower is better) - """ - # If list of TokSequence -> recursive - if isinstance(tokens, list): - return [self.tokens_errors(tok_seq) for tok_seq in tokens] - - err = 0 - current_bar = current_pos = -1 - current_pitches = [] - - for token in tokens.tokens: - if any(tok.split("_")[1] == "None" for tok in token): - err += 1 - continue - has_error = False - bar_value = int(token[4].split("_")[1]) - pos_value = int(token[3].split("_")[1]) - pitch_value = int(token[0].split("_")[1]) - - # Bar - if bar_value < current_bar: - has_error = True - elif bar_value > current_bar: - current_bar = bar_value - current_pos = -1 - current_pitches = [] - - # Position - if pos_value < current_pos: - has_error = True - elif pos_value > current_pos: - current_pos = pos_value - current_pitches = [] - - # Pitch - if pitch_value in current_pitches: - has_error = True - else: - current_pitches.append(pitch_value) - - if has_error: - err += 1 - - return err / len(tokens) diff --git a/miditok/tokenizations/remi.py b/miditok/tokenizations/remi.py index 8665000b..ec623699 100644 --- a/miditok/tokenizations/remi.py +++ b/miditok/tokenizations/remi.py @@ -29,10 +29,12 @@ class REMI(MIDITokenizer): `FIGARO (Rütte et al.) `, which handle multiple instruments by adding `Program` tokens before the `Pitch` ones. - **NOTE:** in the original paper, the tempo information is represented as the succession + **Note:** in the original paper, the tempo information is represented as the succession of two token types: a *TempoClass* indicating if the tempo is fast or slow, and a *TempoValue* indicating its value. 
MidiTok only uses one *Tempo* token for its value (see :ref:`Additional tokens`). + **Note:** When decoding multiple token sequences (of multiple tracks), i.e. when `config.use_programs` is False, + only the tempos and time signatures of the first sequence will be decoded for the whole MIDI. :param tokenizer_config: the tokenizer's configuration, as a :class:`miditok.classes.TokenizerConfig` object. :param max_bar_embedding: Maximum number of bars ("Bar_0", "Bar_1",...,"Bar_{num_bars-1}"). @@ -76,8 +78,7 @@ def _add_time_events(self, events: List[Event]) -> List[Event]: time_division = self._current_midi_metadata["time_division"] ticks_per_sample = time_division / max(self.config.beat_res.values()) min_rest = ( - time_division * self.rests[0][0] - + ticks_per_sample * self.rests[0][1] + time_division * self.rests[0][0] + ticks_per_sample * self.rests[0][1] if self.config.use_rests else 0 ) @@ -121,9 +122,7 @@ def _add_time_events(self, events: List[Event]) -> List[Event]: desc=f"{rest_beat}.0", ) ) - previous_tick += ( - rest_beat * time_division - ) + previous_tick += rest_beat * time_division while rest_pos >= self.rests[0][1]: rest_pos_temp = min( @@ -189,7 +188,6 @@ def _add_time_events(self, events: List[Event]) -> List[Event]: elif event.type == "Tempo": previous_note_end = max(previous_note_end, event.time) - # So sorting needed return all_events @_in_as_seq() @@ -226,7 +224,9 @@ def tokens_to_midi( instruments: Dict[int, Instrument] = {} tempo_changes = [TempoChange(TEMPO, -1)] time_signature_changes = [TimeSignature(*TIME_SIGNATURE, 0)] - ticks_per_bar = self._compute_ticks_per_bar(time_signature_changes[0], time_division) # init + ticks_per_bar = self._compute_ticks_per_bar( + time_signature_changes[0], time_division + ) # init current_tick = 0 current_bar = -1 @@ -237,6 +237,9 @@ def tokens_to_midi( if not self.one_token_stream: current_tick = 0 current_bar = -1 + ticks_per_bar = self._compute_ticks_per_bar( + time_signature_changes[0], time_division + ) 
previous_note_end = 0 if programs is not None: current_program = -1 if programs[si][1] else programs[si][0] @@ -300,18 +303,21 @@ def tokens_to_midi( # If your encoding include tempo tokens, each Position token should be followed by # a tempo token, but if it is not the case this method will skip this step tempo = float(token.split("_")[1]) - if tempo != tempo_changes[-1].tempo: + if si == 0 and current_tick != tempo_changes[-1].time: tempo_changes.append(TempoChange(tempo, current_tick)) + previous_note_end = max(previous_note_end, current_tick) elif token.split("_")[0] == "TimeSig": num, den = self._parse_token_time_signature(token.split("_")[1]) if ( num != time_signature_changes[-1].numerator and den != time_signature_changes[-1].denominator ): - time_signature_changes.append( - TimeSignature(num, den, current_tick) + time_sig = TimeSignature(num, den, current_tick) + if si == 0: + time_signature_changes.append(time_sig) + ticks_per_bar = self._compute_ticks_per_bar( + time_sig, time_division ) - ticks_per_bar = self._compute_ticks_per_bar(time_signature_changes[-1], time_division) if len(tempo_changes) > 1: del tempo_changes[0] # delete mocked tempo change tempo_changes[0].time = 0 @@ -369,7 +375,9 @@ def _create_base_vocabulary(self) -> List[str]: ] # POSITION - max_nb_beats = max(map(lambda ts: ceil(4 * ts[0] / ts[1]), self.time_signatures)) + max_nb_beats = max( + map(lambda ts: ceil(4 * ts[0] / ts[1]), self.time_signatures) + ) nb_positions = max(self.config.beat_res.values()) * max_nb_beats vocab += [f"Position_{i}" for i in range(nb_positions)] diff --git a/miditok/tokenizations/tsd.py b/miditok/tokenizations/tsd.py index 6a5dc41a..2569f047 100644 --- a/miditok/tokenizations/tsd.py +++ b/miditok/tokenizations/tsd.py @@ -24,6 +24,8 @@ class TSD(MIDITokenizer): **Note:** as `TSD` uses *TimeShifts* events to move the time from note to note, it can be unsuited for tracks with pauses longer than the maximum `TimeShift` value. 
In such cases, the maximum *TimeShift* value will be used. + **Note:** When decoding multiple token sequences (of multiple tracks), i.e. when `config.use_programs` is False, + only the tempos and time signatures of the first sequence will be decoded for the whole MIDI. """ def _tweak_config_before_creating_voc(self): @@ -133,6 +135,8 @@ def _add_time_events(self, events: List[Event]) -> List[Event]: # Update max offset time of the notes encountered if event.type == "Pitch": previous_note_end = max(previous_note_end, event.desc) + elif event.type == "Tempo": + previous_note_end = max(previous_note_end, event.time) return all_events @@ -232,13 +236,15 @@ def tokens_to_midi( # If your encoding include tempo tokens, each Position token should be followed by # a tempo token, but if it is not the case this method will skip this step tempo = float(token.split("_")[1]) - if tempo != tempo_changes[-1].tempo: + if si == 0 and current_tick != tempo_changes[-1].time: tempo_changes.append(TempoChange(tempo, current_tick)) + previous_note_end = max(previous_note_end, current_tick) elif token.split("_")[0] == "TimeSig": num, den = self._parse_token_time_signature(token.split("_")[1]) current_time_signature = time_signature_changes[-1] if ( - num != current_time_signature.numerator + si == 0 + and num != current_time_signature.numerator and den != current_time_signature.denominator ): time_signature_changes.append( diff --git a/tests/test_bpe.py b/tests/test_bpe.py index 09405dc0..0e6b0ffc 100644 --- a/tests/test_bpe.py +++ b/tests/test_bpe.py @@ -40,7 +40,7 @@ def test_bpe_conversion( :param data_path: root path to the data to test """ random.seed(777) - tokenizations = ["Structured", "REMI", "REMIPlus", "MIDILike", "TSD", "MMM"] + tokenizations = ["Structured", "REMI", "MIDILike", "TSD", "MMM"] data_path = Path(data_path) files = list(data_path.glob("**/*.mid")) diff --git a/tests/test_methods.py b/tests/test_methods.py index d0dfb3fc..472be465 100644 --- a/tests/test_methods.py 
+++ b/tests/test_methods.py @@ -17,6 +17,7 @@ from tensorflow import Tensor as tfTensor, convert_to_tensor import miditok +from .tests_utils import ALL_TOKENIZATIONS def test_convert_tensors(): @@ -81,19 +82,10 @@ def test_convert_tensors(): def test_data_augmentation(): data_path = Path("./tests/Multitrack_MIDIs") - tokenizations = [ - "TSD", - "MIDILike", - "REMI", - "REMIPlus", - "Structured", - "CPWord", - "Octuple", - "OctupleMono", - ] original_midi_paths = list(data_path.glob("**/*.mid")) + ALL_TOKENIZATIONS.remove("MuMIDI") # not compatible - for tokenization in tokenizations: + for tokenization in ALL_TOKENIZATIONS: print(f"TESTING WITH {tokenization}") tokenizer = getattr(miditok, tokenization)() midi_aug_path = Path("tests", "Multitrack_MIDIs_aug", tokenization) diff --git a/tests/test_multitrack.py b/tests/test_multitrack.py index c9e86e4b..c3a74d53 100644 --- a/tests/test_multitrack.py +++ b/tests/test_multitrack.py @@ -18,6 +18,7 @@ from copy import deepcopy from pathlib import Path from typing import Union +from time import time import miditok from miditoolkit import MidiFile, Marker @@ -30,6 +31,7 @@ reduce_note_durations, adapt_tempo_changes_times, time_signature_changes_equals, + remove_equal_successive_tempos, ) # Special beat res for test, up to 16 beats so the duration and time-shift values are @@ -66,16 +68,12 @@ def test_multitrack_midi_to_tokens_to_midi( """ files = list(Path(data_path).glob("**/*.mid")) - at_least_one_error = False + has_errors = False + t0 = time() for i, file_path in enumerate(tqdm(files, desc="Testing multitrack")): # Reads the MIDI - try: - midi = MidiFile(Path(file_path)) - except ( - Exception - ): # ValueError, OSError, FileNotFoundError, IOError, EOFError, mido.KeySignatureError - continue + midi = MidiFile(Path(file_path)) if midi.ticks_per_beat % max(BEAT_RES_TEST.values()) != 0: continue has_errors = False @@ -87,9 +85,8 @@ def test_multitrack_midi_to_tokens_to_midi( ) # Process the MIDI - midi_to_compare = 
deepcopy( - midi - ) # midi notes / tempos / time signature quantized with the line above + # midi notes / tempos / time signature quantized with the line above + midi_to_compare = deepcopy(midi) for track in midi_to_compare.instruments: if track.is_drum: track.program = ( @@ -105,11 +102,16 @@ def test_multitrack_midi_to_tokens_to_midi( track.notes, max(tu[1] for tu in BEAT_RES_TEST) * midi_to_compare.ticks_per_beat, ) - miditok.utils.fix_offsets_overlapping_notes(track.notes) - if tokenization in ["Octuple", "REMIPlus"]: # needed + if tokenization in ["MIDILike"]: + miditok.utils.fix_offsets_overlapping_notes(track.notes) + # For Octuple, as tempo is only carried at notes times, we need to adapt their times for comparison + if tokenization in ["Octuple"]: adapt_tempo_changes_times( midi_to_compare.instruments, midi_to_compare.tempo_changes ) + # When the tokenizer only decoded tempo changes different from the last tempo val + if tokenization in ["CPWord"]: + remove_equal_successive_tempos(midi_to_compare.tempo_changes) # MIDI -> Tokens -> MIDI midi_to_compare.instruments.sort( @@ -149,7 +151,6 @@ def test_multitrack_midi_to_tokens_to_midi( f"MIDI {i} - {file_path} failed to encode/decode NOTES with " f"{tokenization} ({sum(len(t[2]) for t in errors)} errors)" ) - # return False # Checks tempos if ( @@ -166,10 +167,7 @@ def test_multitrack_midi_to_tokens_to_midi( ) # Checks time signatures - if tokenizer.config.use_time_signatures and tokenization in [ - "Octuple", - "REMIPlus", - ]: + if tokenizer.config.use_time_signatures: time_sig_errors = time_signature_changes_equals( midi_to_compare.time_signature_changes, new_midi.time_signature_changes, @@ -182,7 +180,7 @@ def test_multitrack_midi_to_tokens_to_midi( ) if has_errors: - at_least_one_error = True + has_errors = True if saving_erroneous_midis: new_midi.dump( Path( @@ -198,7 +196,10 @@ def test_multitrack_midi_to_tokens_to_midi( f"{file_path.stem}_{tokenization}_original.mid", ) ) - assert not 
at_least_one_error + + ttotal = time() - t0 + print(f"Took {ttotal:.2f} seconds") + assert not has_errors if __name__ == "__main__": diff --git a/tests/test_one_track.py b/tests/test_one_track.py index 87ad1e43..19558e83 100644 --- a/tests/test_one_track.py +++ b/tests/test_one_track.py @@ -13,6 +13,7 @@ from copy import deepcopy from pathlib import Path, PurePath from typing import Union +from time import time import miditok from miditoolkit import MidiFile, Marker @@ -24,6 +25,7 @@ tempo_changes_equals, time_signature_changes_equals, adapt_tempo_changes_times, + remove_equal_successive_tempos, ) # Special beat res for test, up to 64 beats so the duration and time-shift values are @@ -44,6 +46,8 @@ "chord_maps": miditok.constants.CHORD_MAPS, "chord_tokens_with_root_note": True, # Tokens will look as "Chord_C:maj" "chord_unknown": False, + "delete_equal_successive_time_sig_changes": True, + "delete_equal_successive_tempo_changes": True, } @@ -60,13 +64,14 @@ def test_one_track_midi_to_tokens_to_midi( """ files = list(Path(data_path).glob("**/*.mid")) at_least_one_error = False + t0 = time() for i, file_path in enumerate(tqdm(files, desc="Testing One Track")): # Reads the midi midi = MidiFile(file_path) - adapt_tempo_changes_times(midi.instruments, midi.tempo_changes) - tracks = [deepcopy(midi.instruments[0])] has_errors = False + # Will store the tracks tokenized / detokenized, to be saved in case of errors + tracks = [deepcopy(midi.instruments[0])] for tokenization in ALL_TOKENIZATIONS: tokenizer_config = miditok.TokenizerConfig(**TOKENIZER_PARAMS) @@ -80,11 +85,31 @@ def test_one_track_midi_to_tokens_to_midi( tokenizer_config=tokenizer_config ) + # Process the MIDI + # midi notes / tempos / time signature quantized with the line above + midi_to_compare = deepcopy(midi) + for track in midi_to_compare.instruments: + if track.is_drum: + track.program = ( + 0 # need to be done before sorting tracks per program + ) + + # This step is also performed in preprocess_midi, 
but we need to call it here for the assertions below + tokenizer.preprocess_midi(midi_to_compare) + # For Octuple, as tempo is only carried at notes times, we need to adapt their times for comparison + if tokenization in ["Octuple", "OctupleMono"]: + adapt_tempo_changes_times( + midi_to_compare.instruments, midi_to_compare.tempo_changes + ) + # When the tokenizer only decoded tempo changes different from the last tempo val + if tokenization in ["CPWord"]: + remove_equal_successive_tempos(midi_to_compare.tempo_changes) + # printing the tokenizer shouldn't fail _ = str(tokenizer) # Convert the track in tokens - tokens = tokenizer(midi) + tokens = tokenizer(midi_to_compare) if not tokenizer.one_token_stream: tokens = tokens[0] @@ -99,16 +124,12 @@ def test_one_track_midi_to_tokens_to_midi( if not tokenizer.one_token_stream: tokens = [tokens] new_midi = tokenizer.tokens_to_midi( - tokens, time_division=midi.ticks_per_beat + tokens, time_division=midi_to_compare.ticks_per_beat ) track = new_midi.instruments[0] - tempo_changes = new_midi.tempo_changes - time_sig_changes = None - if tokenization == "Octuple": - time_sig_changes = new_midi.time_signature_changes # Checks its good - errors = track_equals(midi.instruments[0], track) + errors = track_equals(midi_to_compare.instruments[0], track) if len(errors) > 0: has_errors = True if errors[0][0] != "len": @@ -122,13 +143,14 @@ def test_one_track_midi_to_tokens_to_midi( print( f"MIDI {i} - {file_path} failed to encode/decode NOTES with {tokenization} ({len(errors)} errors)" ) - # return False track.name = f"encoded with {tokenization}" tracks.append(track) # Checks tempos - if tempo_changes is not None and tokenizer.config.use_tempos: - tempo_errors = tempo_changes_equals(midi.tempo_changes, tempo_changes) + if tokenizer.config.use_tempos and tokenization != "MuMIDI": + tempo_errors = tempo_changes_equals( + midi_to_compare.tempo_changes, new_midi.tempo_changes + ) if len(tempo_errors) > 0: has_errors = True print( @@ 
-137,9 +159,10 @@ def test_one_track_midi_to_tokens_to_midi( ) # Checks time signatures - if time_sig_changes is not None and tokenizer.config.use_time_signatures: + if tokenizer.config.use_time_signatures: time_sig_errors = time_signature_changes_equals( - midi.time_signature_changes, time_sig_changes + midi_to_compare.time_signature_changes, + new_midi.time_signature_changes, ) if len(time_sig_errors) > 0: has_errors = True @@ -158,6 +181,8 @@ def test_one_track_midi_to_tokens_to_midi( midi.instruments += tracks midi.dump(PurePath("tests", "test_results", file_path.name)) + ttotal = time() - t0 + print(f"Took {ttotal:.2f} seconds") assert not at_least_one_error diff --git a/tests/test_saving_loading_config.py b/tests/test_saving_loading_config.py index 46e7de1a..ed3b1479 100644 --- a/tests/test_saving_loading_config.py +++ b/tests/test_saving_loading_config.py @@ -6,6 +6,7 @@ """ import miditok +from .tests_utils import ALL_TOKENIZATIONS ADDITIONAL_TOKENS_TEST = { @@ -19,21 +20,10 @@ "tempo_range": (40, 250), "time_signature_range": {4: [4]}, } -tokenizations = [ - "MIDILike", - "TSD", - "Structured", - "REMI", - "REMIPlus", - "CPWord", - "Octuple", - "OctupleMono", - "MuMIDI", -] def test_saving_loading_tokenizer_config(): - for tokenization in tokenizations: + for tokenization in ALL_TOKENIZATIONS: config1 = miditok.TokenizerConfig() config1.save_to_json(f"./tests/configs/tok_conf_{tokenization}.json") @@ -51,7 +41,7 @@ def test_saving_loading_tokenizer(): If all went well the tokenizer should be identical. 
""" - for tokenization in tokenizations: + for tokenization in ALL_TOKENIZATIONS: tokenizer_config = miditok.TokenizerConfig(**ADDITIONAL_TOKENS_TEST) tokenizer: miditok.MIDITokenizer = getattr(miditok, tokenization)( tokenizer_config=tokenizer_config diff --git a/tests/tests_utils.py b/tests/tests_utils.py index 01f33b50..132c1f74 100644 --- a/tests/tests_utils.py +++ b/tests/tests_utils.py @@ -14,7 +14,6 @@ "REMI", "CPWord", "Octuple", - "OctupleMono", "MuMIDI", "MMM", ] @@ -58,7 +57,9 @@ def notes_equals(note1: Note, note2: Note) -> str: def tempo_changes_equals( tempo_changes1: List[TempoChange], tempo_changes2: List[TempoChange] -) -> List[Tuple[str, TempoChange, float]]: +) -> List[Tuple[str, Union[TempoChange, int], float]]: + if len(tempo_changes1) != len(tempo_changes2): + return [("len", len(tempo_changes2), len(tempo_changes1))] errors = [] for tempo_change1, tempo_change2 in zip(tempo_changes1, tempo_changes2): if tempo_change1.time != tempo_change2.time: @@ -110,10 +111,14 @@ def adapt_tempo_changes_times( """ notes = sum((t.notes for t in tracks), []) notes.sort(key=lambda x: x.start) + max_tick = max(note.start for note in notes) current_note_idx = 0 tempo_idx = 1 while tempo_idx < len(tempo_changes): + if tempo_changes[tempo_idx].time > max_tick: + del tempo_changes[tempo_idx] + continue for n, note in enumerate(notes[current_note_idx:]): if note.start >= tempo_changes[tempo_idx].time: tempo_changes[tempo_idx].time = note.start @@ -123,3 +128,14 @@ def adapt_tempo_changes_times( del tempo_changes[tempo_idx - 1] continue tempo_idx += 1 + + +def remove_equal_successive_tempos(tempo_changes: List[TempoChange]): + current_tempo = -1 + i = 0 + while i < len(tempo_changes): + if tempo_changes[i].tempo == current_tempo: + del tempo_changes[i] + continue + current_tempo = tempo_changes[i].tempo + i += 1