Skip to content

Commit

Permalink
fixes CPWord time signature encoding/decoding + warning when used wit…
Browse files Browse the repository at this point in the history
…h rests
  • Loading branch information
Natooz committed Oct 10, 2023
1 parent 0506d75 commit 70a2a45
Show file tree
Hide file tree
Showing 5 changed files with 76 additions and 43 deletions.
2 changes: 1 addition & 1 deletion docs/additional_tokens_table.csv
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ MIDILike,✅,✅,✅,✅,✅,✅
REMI,✅,✅,✅,✅,✅,✅
TSD,✅,✅,✅,✅,✅,✅
Structured,❌,❌,❌,❌,❌,❌
CPWord,✅,✅,✅,✅,❌,❌
CPWord,✅,✅¹,✅,✅¹,❌,❌
Octuple,✅,✅,❌,❌,❌,❌
MuMIDI,✅,❌,✅,❌,❌,❌
MMM,✅,✅,✅,❌,❌,❌
2 changes: 2 additions & 0 deletions docs/midi_tokenizer.rst
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,8 @@ MidiTok offers to include additional tokens on music information. You can specif
:file: additional_tokens_table.csv
:header-rows: 1

¹: Using both time signatures and rests with `CPWord` might result in time alterations, as the time signature is carried by the Bar tokens, which are skipped during periods of rest.


Special tokens
------------------------
Expand Down
106 changes: 65 additions & 41 deletions miditok/tokenizations/cp_word.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from math import ceil
from typing import List, Tuple, Dict, Optional, Union, Any
from pathlib import Path
import warnings

import numpy as np
from miditoolkit import MidiFile, Instrument, Note, TempoChange, TimeSignature
Expand Down Expand Up @@ -37,6 +38,17 @@ class CPWord(MIDITokenizer):
"""

def _tweak_config_before_creating_voc(self):
if self.config.use_time_signatures and self.config.use_rests:
# NOTE: this configuration could work by adding a Bar token with the new TimeSig after the Rest, but the
# decoding would then have to handle this so as not to add another bar. Alternatively, Rests could be prevented
# from crossing new bars: a Rest would have a maximal value corresponding to the difference between the previous
# event tick and the tick of the next bar. However, for long rests of more than one bar, we would have
# successions of Rest --> Bar --> Rest --> Bar ... tokens.
warnings.warn("You are using both Time Signatures and Rests with CPWord. Be aware that this configuration"
"can result in altered time, as the time signature is carried by the Bar tokens, that are"
"skipped during rests. To disable this warning, you can disable either Time Signatures or"
"Rests. Otherwise, you can check that your data does not have time signature changes"
"occurring during rests.")
self.config.use_sustain_pedals = False
self.config.use_pitch_bends = False
self.config.program_changes = False
Expand Down Expand Up @@ -120,6 +132,7 @@ def _add_time_events(self, events: List[Event]) -> List[List[Event]]:
current_tempo = event.value
elif event.type == "Program":
current_program = event.value
continue
if event.time != previous_tick:
# (Rest)
if (
Expand Down Expand Up @@ -158,26 +171,17 @@ def _add_time_events(self, events: List[Event]) -> List[List[Event]]:
- current_bar
)
if nb_new_bars >= 1:
if self.config.use_time_signatures:
time_sig_arg = (
f"{current_time_sig[0]}/{current_time_sig[1]}"
)
else:
time_sig_arg = None
for i in range(nb_new_bars):
# Update time signature time variables before adding the last bar
if self.config.use_time_signatures:
if event.type == "TimeSig" and i + 1 == nb_new_bars:
current_time_sig = list(
map(int, event.value.split("/"))
)
bar_at_last_ts_change += (
event.time - tick_at_last_ts_change
) // ticks_per_bar
tick_at_last_ts_change = event.time
ticks_per_bar = self._compute_ticks_per_bar(
TimeSignature(*current_time_sig, event.time),
time_division,
)
time_sig_arg = (
f"{current_time_sig[0]}/{current_time_sig[1]}"
)
else:
time_sig_arg = None
# exception when last bar and event.type == "TimeSig"
if i == nb_new_bars - 1 and event.type == "TimeSig":
time_sig_arg = list(map(int, event.value.split("/")))
time_sig_arg = f"{time_sig_arg[0]}/{time_sig_arg[1]}"
all_events.append(
self.__create_cp_token(
(current_bar + i + 1) * ticks_per_bar,
Expand All @@ -193,18 +197,33 @@ def _add_time_events(self, events: List[Event]) -> List[List[Event]]:
)

# Position
pos_index = int((event.time - tick_at_current_bar) / ticks_per_sample)
all_events.append(
self.__create_cp_token(
event.time,
pos=pos_index,
tempo=current_tempo if self.config.use_tempos else None,
desc="Position",
if event.type != "TimeSig":
pos_index = int((event.time - tick_at_current_bar) / ticks_per_sample)
all_events.append(
self.__create_cp_token(
event.time,
pos=pos_index,
chord=event.value if event.type == "Chord" else None,
tempo=current_tempo if self.config.use_tempos else None,
desc="Position",
)
)
)

previous_tick = event.time

# Update time signature time variables, after adjusting the time (above)
if event.type == "TimeSig":
current_time_sig = list(map(int, event.value.split("/")))
bar_at_last_ts_change += (
event.time - tick_at_last_ts_change
) // ticks_per_bar
tick_at_last_ts_change = event.time
ticks_per_bar = self._compute_ticks_per_bar(
TimeSignature(*current_time_sig, event.time), time_division
)
# We decrease the previous tick so that a Position token is enforced for the next event
previous_tick -= 1

# Convert event to CP Event
# Update max offset time of the notes encountered
if event.type == "Pitch" and e + 2 < len(events):
Expand All @@ -218,7 +237,12 @@ def _add_time_events(self, events: List[Event]) -> List[List[Event]]:
)
)
previous_note_end = max(previous_note_end, event.desc)
elif event.type == "Tempo":
elif event.type in [
"Program",
"Tempo",
"TimeSig",
"Chord",
]:
previous_note_end = max(previous_note_end, event.time)

return all_events
Expand Down Expand Up @@ -377,11 +401,13 @@ def check_inst(prog: int):
)[1]
)
time_signature_changes.append(TimeSignature(num, den, 0))
break
else:
break
if len(time_signature_changes) == 0:
time_signature_changes.append(TimeSignature(*TIME_SIGNATURE, 0))
ticks_per_bar = self._compute_ticks_per_bar(time_signature_changes[0], time_division)
current_time_sig = time_signature_changes[0]
ticks_per_bar = self._compute_ticks_per_bar(current_time_sig, time_division)
# Set track / sequence program if needed
if not self.one_token_stream:
current_tick = tick_at_last_ts_change = tick_at_current_bar = 0
Expand Down Expand Up @@ -430,25 +456,23 @@ def check_inst(prog: int):
current_tick = tick_at_current_bar + ticks_per_bar
tick_at_current_bar = current_tick
# Add new TS only if different from the last one
if self.config.use_time_signatures and si == 0:
if self.config.use_time_signatures:
num, den = self._parse_token_time_signature(
compound_token[self.vocab_types_idx["TimeSig"]].split(
"_"
)[1]
)
if (
num != time_signature_changes[-1].numerator
or den != time_signature_changes[-1].denominator
num != current_time_sig.numerator
or den != current_time_sig.denominator
):
time_sig = TimeSignature(num, den, current_tick)
current_time_sig = TimeSignature(num, den, current_tick)
if si == 0:
time_signature_changes.append(time_sig)
tick_at_last_ts_change = (
tick_at_current_bar # == current_tick
)
time_signature_changes.append(current_time_sig)
tick_at_last_ts_change = tick_at_current_bar
bar_at_last_ts_change = current_bar
ticks_per_bar = self._compute_ticks_per_bar(
time_sig, time_division
current_time_sig, time_division
)
elif bar_pos == "Position": # i.e. its a position
if current_bar == -1:
Expand All @@ -466,12 +490,10 @@ def check_inst(prog: int):
)[1]
)
if (
si == 0
and tempo != tempo_changes[-1].tempo
tempo != tempo_changes[-1].tempo
and current_tick != tempo_changes[-1].time
):
tempo_changes.append(TempoChange(tempo, current_tick))
previous_note_end = max(previous_note_end, current_tick)
elif (
self.config.use_rests
and compound_token[self.vocab_types_idx["Rest"]].split("_")[1]
Expand All @@ -492,6 +514,8 @@ def check_inst(prog: int):
) * ticks_per_bar
current_bar = real_current_bar

previous_note_end = max(previous_note_end, current_tick)

# Add current_inst to midi and handle notes still active
if not self.one_token_stream:
midi.instruments.append(current_instrument)
Expand Down
5 changes: 5 additions & 0 deletions tests/test_one_track.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,11 @@ def test_one_track_midi_to_tokens_to_midi(
elif tokenization == "Octuple":
params["max_bar_embedding"] = 300
params["use_time_signatures"] = False # because of time shifted
elif tokenization == "CPWord":
# Rests and time signatures can conflict with CPWord when a Rest that crosses a new bar is followed
# by a TimeSig change, as TimeSig values are carried by Bar tokens (and there is none in this case)
if params["use_time_signatures"] and params["use_rests"]:
params["use_rests"] = False

tokenizer_config = miditok.TokenizerConfig(**params)
tokenizer: miditok.MIDITokenizer = getattr(miditok, tokenization)(
Expand Down
4 changes: 3 additions & 1 deletion tests/tests_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,9 @@ def tempo_changes_equals(

def time_signature_changes_equals(
time_sig_changes1: List[TimeSignature], time_sig_changes2: List[TimeSignature]
) -> List[Tuple[str, TimeSignature, float]]:
) -> List[Tuple[str, Union[TimeSignature, int], float]]:
if len(time_sig_changes1) != len(time_sig_changes2):
return [("len", len(time_sig_changes1), len(time_sig_changes2))]
errors = []
for time_sig_change1, time_sig_change2 in zip(time_sig_changes1, time_sig_changes2):
if time_sig_change1.time != time_sig_change2.time:
Expand Down

0 comments on commit 70a2a45

Please sign in to comment.