Skip to content

Commit

Permalink
Update time signature tokens configuration (#65)
Browse files Browse the repository at this point in the history
* update time signature tokens configuration

* add config check in validate_midi_time_signatures
  • Loading branch information
ilya16 authored Aug 14, 2023
1 parent 5d1c12e commit 114d253
Show file tree
Hide file tree
Showing 15 changed files with 100 additions and 97 deletions.
15 changes: 12 additions & 3 deletions miditok/classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,8 @@ class TokenizerConfig:
:param nb_tempos: number of tempos "bins" to use. (default: 32)
:param tempo_range: range of minimum and maximum tempos within which the bins fall. (default: (40, 250))
:param log_tempos: will use log scaled tempo values instead of linearly scaled. (default: False)
:param time_signature_range: range as a tuple (max_beat_res, nb_notes). (default: (8, 2))
:param time_signature_range: range as a dictionary {denom_i: [num_i1, ..., num_in] / (min_num_i, max_num_i)}.
(default: {4: [4]})
:param programs: sequence of MIDI programs to use. Note that `-1` is used and reserved for drums tracks.
(default: from -1 to 127 included)
:param **kwargs: additional parameters that will be saved in `config.additional_params`.
Expand All @@ -207,7 +208,7 @@ def __init__(
nb_tempos: int = NB_TEMPOS,
tempo_range: Tuple[int, int] = TEMPO_RANGE,
log_tempos: bool = LOG_TEMPOS,
time_signature_range: Tuple[int, int] = TIME_SIGNATURE_RANGE,
time_signature_range: Dict[int, Union[List[int], Tuple[int, int]]] = TIME_SIGNATURE_RANGE,
programs: Sequence[int] = PROGRAMS,
**kwargs,
):
Expand Down Expand Up @@ -240,7 +241,10 @@ def __init__(
self.log_tempos: bool = log_tempos

# Time signature params
self.time_signature_range: Tuple[int, int] = time_signature_range
self.time_signature_range: Dict[int, List[int]] = {
beat_res: list(range(beats[0], beats[1] + 1)) if isinstance(beats, tuple) else beats
for beat_res, beats in time_signature_range.items()
}

# Programs
self.programs: Sequence[int] = programs
Expand Down Expand Up @@ -325,6 +329,11 @@ def load_from_json(cls, config_file_path: Union[str, Path]) -> "TokenizerConfig"
for beat_range, res in dict_config["beat_res"].items()
}

dict_config["time_signature_range"] = {
int(res): beat_range
for res, beat_range in dict_config["time_signature_range"].items()
}

return cls.from_dict(dict_config)

def __eq__(self, other):
Expand Down
2 changes: 1 addition & 1 deletion miditok/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@
LOG_TEMPOS = False # log or linear scale tempos

# Time signature params
TIME_SIGNATURE_RANGE = (8, 2)
TIME_SIGNATURE_RANGE = {4: [4]} # {denom_i: [num_i1, ..., num_in] / (min_num_i, max_num_i)}

# Programs
PROGRAMS = list(range(-1, 128))
Expand Down
89 changes: 43 additions & 46 deletions miditok/midi_tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from .constants import (
TIME_DIVISION,
CURRENT_VERSION_PACKAGE,
TIME_SIGNATURE,
CHR_ID_START,
PITCH_CLASSES,
UNKNOWN_CHORD_PREFIX,
Expand Down Expand Up @@ -240,7 +241,7 @@ def __init__(
self.rests = self.__create_rests()

# Time Signatures
self.time_signatures = []
self.time_signatures = [TIME_SIGNATURE]
if self.config.use_time_signatures:
self.time_signatures = self.__create_time_signatures()

Expand Down Expand Up @@ -362,7 +363,7 @@ def preprocess_midi(self, midi: MidiFile):

if len(midi.time_signature_changes) == 0: # can sometimes happen
midi.time_signature_changes.append(
TimeSignature(4, 4, 0)
TimeSignature(*TIME_SIGNATURE, 0)
) # 4/4 by default in this case
if self.config.use_time_signatures:
self._quantize_time_signatures(
Expand Down Expand Up @@ -448,7 +449,7 @@ def _quantize_time_signatures(time_sigs: List[TimeSignature], time_division: int
:param time_sigs: time signature changes to quantize.
:param time_division: MIDI time division / resolution, in ticks/beat (of the MIDI being parsed).
"""
ticks_per_bar = time_division * time_sigs[0].numerator
ticks_per_bar = MIDITokenizer._compute_ticks_per_bar(time_sigs[0], time_division)
current_bar = 0
previous_tick = 0 # first time signature change is always at tick 0
prev_time_sig = time_sigs[0]
Expand All @@ -472,7 +473,7 @@ def _quantize_time_signatures(time_sigs: List[TimeSignature], time_division: int
time_sig.time = previous_tick + bar_offset * ticks_per_bar

# Update values
ticks_per_bar = time_division * time_sig.numerator
ticks_per_bar = MIDITokenizer._compute_ticks_per_bar(time_sig, time_division)
current_bar += bar_offset
previous_tick = time_sig.time
prev_time_sig = time_sig
Expand Down Expand Up @@ -1136,54 +1137,26 @@ def __create_time_signatures(self) -> List[Tuple]:
:return: the time signatures.
"""
max_beat_res, nb_notes = self.config.time_signature_range
assert (
max_beat_res > 0 and math.log2(max_beat_res).is_integer()
), "The beat resolution in time signature must be a power of 2"
time_signature_range = self.config.time_signature_range

time_signatures = []
for i in range(0, int(math.log2(max_beat_res)) + 1): # 1 ~ max_beat_res
for j in range(1, ((2**i) * nb_notes) + 1):
time_signatures.append((j, 2**i))
return time_signatures

def _reduce_time_signature(
self, numerator: int, denominator: int
) -> Tuple[int, int]:
r"""Reduces and decomposes a time signature into one of the valid vocabulary time signatures.
If time signature's denominator (beat resolution) is larger than max_beat_res,
the denominator and numerator are reduced to max_beat_res if possible.
If time signature's numerator (bar length in beats) is larger than nb_notes * denominator,
the numerator is replaced with its GCD not larger than nb_notes * denominator.
Example: (10, 4), max_beat_res of 8, and nb_notes of 2 will convert the signature into (5, 4).
:param numerator: time signature's numerator (bar length in beats).
:param denominator: time signature's denominator (beat resolution).
:return: the numerator and denominator of a reduced and decomposed time signature.
"""
max_beat_res, nb_notes = self.config.time_signature_range
for beat_res, beats in time_signature_range.items():
assert beat_res > 0 and math.log2(beat_res).is_integer(), \
f"The beat resolution ({beat_res}) in time signature must be a power of 2"

# reduction (when denominator exceed max_beat_res)
while (
denominator > max_beat_res and denominator % 2 == 0 and numerator % 2 == 0
):
denominator //= 2
numerator //= 2
time_signatures.extend([(nb_beats, beat_res) for nb_beats in beats])

assert denominator <= max_beat_res, (
f"Unsupported time signature ({numerator}/{denominator}), "
f"beat resolution is irreducible to maximum beat resolution {max_beat_res}"
)
return time_signatures

# decomposition (when length of a bar exceed max_nb_beats_per_bar)
while numerator > nb_notes * denominator:
for i in range(2, numerator + 1):
if numerator % i == 0:
numerator //= i
break
@staticmethod
def _compute_ticks_per_bar(time_sig: TimeSignature, time_division: int):
r"""Computes time resolution of one bar in ticks.
return numerator, denominator
:param time_sig: time signature object
:param time_division: MIDI time division / resolution, in ticks/beat (of the MIDI being parsed)
:return: MIDI bar resolution, in ticks/bar
"""
return int(time_division * 4 * time_sig.numerator / time_sig.denominator)

@staticmethod
def _parse_token_time_signature(token_time_sig: str) -> Tuple[int, int]:
Expand All @@ -1196,6 +1169,17 @@ def _parse_token_time_signature(token_time_sig: str) -> Tuple[int, int]:
numerator, denominator = map(int, token_time_sig.split("/"))
return numerator, denominator

def validate_midi_time_signatures(self, midi: MidiFile) -> bool:
r"""Checks if MIDI files contains only time signatures supported by the encoding.
:param midi: MIDI file
:return: boolean indicating whether MIDI file could be processed by the Encoding
"""
if self.config.use_time_signatures:
for time_sig in midi.time_signature_changes:
if (time_sig.numerator, time_sig.denominator) not in self.time_signatures:
return False
return True

def learn_bpe(
self,
vocab_size: int,
Expand Down Expand Up @@ -1509,6 +1493,10 @@ def tokenize_midi_dataset(
if not validation_fn(midi):
continue

# Checks if MIDI contains supported time signatures
if not self.validate_midi_time_signatures(midi):
continue

# Tokenizing the MIDI, without BPE here as this will be done at the end (as we might perform data aug)
tokens = self(midi, apply_bpe_if_possible=False)

Expand Down Expand Up @@ -1743,13 +1731,22 @@ def _load_params(self, config_file_path: Union[str, Path]):
tuple(map(int, beat_range.split("_"))): res
for beat_range, res in value["beat_res"].items()
}
value["time_signature_range"] = {
int(res): beat_range
for res, beat_range in value["time_signature_range"].items()
}
value = TokenizerConfig.from_dict(value)
elif key in config_attributes:
if key == "beat_res":
value = {
tuple(map(int, beat_range.split("_"))): res
for beat_range, res in value.items()
}
elif key == "time_signature_range":
value = {
int(res): beat_range
for res, beat_range in value.items()
}
# Convert old attribute from < v2.1.0 to new for TokenizerConfig
elif key in old_add_tokens_attr:
key = old_add_tokens_attr[key]
Expand Down
4 changes: 3 additions & 1 deletion miditok/tokenizations/cp_word.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from math import ceil
from typing import List, Tuple, Dict, Optional, Union, Any

import numpy as np
Expand Down Expand Up @@ -395,7 +396,8 @@ def _create_base_vocabulary(self) -> List[List[str]]:
vocab[0].append("Family_Note")

# POSITION
nb_positions = max(self.config.beat_res.values()) * 4 # 4/* time signature
max_nb_beats = max(map(lambda ts: ceil(4 * ts[0] / ts[1]), self.time_signatures))
nb_positions = max(self.config.beat_res.values()) * max_nb_beats
vocab[1].append("Ignore_None")
vocab[1].append("Bar_None")
vocab[1] += [f"Position_{i}" for i in range(nb_positions)]
Expand Down
8 changes: 3 additions & 5 deletions miditok/tokenizations/mmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,10 +160,7 @@ def _track_to_tokens(self, track: Instrument) -> List[Event]:
# Time events
events.sort(key=lambda x: (x.time, self._order(x)))
time_sig_change = self._current_midi_metadata["time_sig_changes"][0]
first_time_sig = self._reduce_time_signature(
time_sig_change.numerator, time_sig_change.denominator
)
ticks_per_bar = time_division * first_time_sig[0]
ticks_per_bar = self._compute_ticks_per_bar(time_sig_change, time_division)
previous_tick = 0
current_bar = 0
for ei in range(len(events)):
Expand Down Expand Up @@ -263,7 +260,7 @@ def tokens_to_midi(
time_signature_changes = [
TimeSignature(*TIME_SIGNATURE, 0)
] # mock the first time signature change to optimize below
ticks_per_bar = time_division * TIME_SIGNATURE[0] # init
ticks_per_bar = self._compute_ticks_per_bar(time_signature_changes[0], time_division) # init

current_tick = 0
current_bar = -1
Expand Down Expand Up @@ -314,6 +311,7 @@ def tokens_to_midi(
and den != current_time_signature.denominator
):
time_signature_changes.append(TimeSignature(num, den, current_tick))
# ticks_per_bar = self._compute_ticks_per_bar(time_signature_changes[-1], time_division)
elif tok_type == "Pitch":
try:
if (
Expand Down
3 changes: 2 additions & 1 deletion miditok/tokenizations/mumidi.py
Original file line number Diff line number Diff line change
Expand Up @@ -410,7 +410,8 @@ def _create_base_vocabulary(self) -> List[List[str]]:
for i in range(*self.config.additional_params["drum_pitch_range"])
]
vocab[0] += ["Bar_None"] # new bar token
nb_positions = max(self.config.beat_res.values()) * 4 # 4/* time signature
max_nb_beats = max(map(lambda ts: ceil(4 * ts[0] / ts[1]), self.time_signatures))
nb_positions = max(self.config.beat_res.values()) * max_nb_beats
vocab[0] += [f"Position_{i}" for i in range(nb_positions)]
vocab[0] += [f"Program_{program}" for program in self.config.programs]

Expand Down
32 changes: 14 additions & 18 deletions miditok/tokenizations/octuple.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from math import ceil
from pathlib import Path
from typing import List, Tuple, Dict, Optional, Union, Any
from typing import List, Dict, Optional, Union, Any

import numpy as np
from miditoolkit import MidiFile, Instrument, Note, TempoChange, TimeSignature
Expand Down Expand Up @@ -134,13 +134,10 @@ def _track_to_tokens(self, track: Instrument) -> List[List[Union[Event, str]]]:
current_time_sig_idx = 0
current_time_sig_tick = 0
current_time_sig_bar = 0
time_sig_change = self._current_midi_metadata["time_sig_changes"][
current_time_sig = self._current_midi_metadata["time_sig_changes"][
current_time_sig_idx
]
current_time_sig = self._reduce_time_signature(
time_sig_change.numerator, time_sig_change.denominator
)
ticks_per_bar = time_division * current_time_sig[0]
ticks_per_bar = self._compute_ticks_per_bar(current_time_sig, time_division)

for note in track.notes:
# Positions and bars
Expand Down Expand Up @@ -205,18 +202,16 @@ def _track_to_tokens(self, track: Instrument) -> List[List[Union[Event, str]]]:
][current_time_sig_idx + 1 :]:
# If this time signature change happened before the current moment
if time_sig_change.time <= note.start:
current_time_sig = self._reduce_time_signature(
time_sig_change.numerator, time_sig_change.denominator
)
current_time_sig = time_sig_change
current_time_sig_idx += 1 # update time signature value (might not change) and index
current_time_sig_bar += (
time_sig_change.time - current_time_sig_tick
) // ticks_per_bar
current_time_sig_tick = time_sig_change.time
ticks_per_bar = time_division * current_time_sig[0]
ticks_per_bar = self._compute_ticks_per_bar(current_time_sig, time_division)
elif time_sig_change.time > note.start:
break # this time signature change is beyond the current time step, we break the loop
token.append(f"TimeSig_{current_time_sig[0]}/{current_time_sig[1]}")
token.append(f"TimeSig_{current_time_sig.numerator}/{current_time_sig.denominator}")

tokens.append(token)

Expand Down Expand Up @@ -273,8 +268,9 @@ def tokens_to_midi(
)
break

ticks_per_bar = time_division * time_sig[0]
time_sig_changes = [TimeSignature(*time_sig, 0)]
time_sig = TimeSignature(*time_sig, 0)
ticks_per_bar = self._compute_ticks_per_bar(time_sig, time_division)
time_sig_changes = [time_sig]

current_time_sig_tick = 0
current_time_sig_bar = 0
Expand Down Expand Up @@ -326,10 +322,9 @@ def tokens_to_midi(
current_bar - current_time_sig_bar
) * ticks_per_bar
current_time_sig_bar = current_bar
ticks_per_bar = time_division * time_sig[0]
time_sig_changes.append(
TimeSignature(*time_sig, current_time_sig_tick)
)
time_sig = TimeSignature(*time_sig, current_time_sig_tick)
ticks_per_bar = self._compute_ticks_per_bar(time_sig, time_division)
time_sig_changes.append(time_sig)

# Tempos
midi.tempo_changes = tempo_changes
Expand Down Expand Up @@ -385,7 +380,8 @@ def _create_base_vocabulary(self) -> List[List[str]]:
vocab[3] += [f"Program_{i}" for i in self.config.programs]

# POSITION
nb_positions = max(self.config.beat_res.values()) * 4 # 4/4 time signature
max_nb_beats = max(map(lambda ts: ceil(4 * ts[0] / ts[1]), self.time_signatures))
nb_positions = max(self.config.beat_res.values()) * max_nb_beats
vocab[4] += [f"Position_{i}" for i in range(nb_positions)]

# BAR (positional encoding)
Expand Down
3 changes: 2 additions & 1 deletion miditok/tokenizations/octuple_mono.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,8 @@ def _create_base_vocabulary(self) -> List[List[str]]:
]

# POSITION
nb_positions = max(self.config.beat_res.values()) * 4 # 4/4 time signature
max_nb_beats = max(map(lambda ts: ceil(4 * ts[0] / ts[1]), self.time_signatures))
nb_positions = max(self.config.beat_res.values()) * max_nb_beats
vocab[3] += [f"Position_{i}" for i in range(nb_positions)]

# BAR
Expand Down
Loading

0 comments on commit 114d253

Please sign in to comment.