diff --git a/miditok/utils/split_utils.py b/miditok/utils/split_utils.py index 70c18aeb..312c5598 100644 --- a/miditok/utils/split_utils.py +++ b/miditok/utils/split_utils.py @@ -8,7 +8,8 @@ from typing import TYPE_CHECKING, Any from warnings import warn -from symusic import Score, TextMeta +from symusic import Score, TextMeta, TimeSignature +from symusic.core import TimeSignatureTickList from tqdm import tqdm from miditok.constants import ( @@ -16,6 +17,7 @@ MIDI_FILES_EXTENSIONS, SCORE_LOADING_EXCEPTION, SUPPORTED_MUSIC_FILE_EXTENSIONS, + TIME_SIGNATURE, ) from .utils import ( @@ -112,7 +114,10 @@ def split_files_for_training( except SCORE_LOADING_EXCEPTION: continue - # Separate track first if needed + # First preprocess time signatures to avoid cases where they might cause errors + _preprocess_time_signatures(scores[0], tokenizer) + + # Separate track if needed tracks_separated = False if not tokenizer.one_token_stream and len(scores[0].tracks) > 1: scores = split_score_per_tracks(scores[0]) @@ -308,6 +313,9 @@ def get_average_num_tokens_per_note( if (num_notes := track.note_num()) > 0: num_tokens_per_note.append(len(seq) / num_notes) + if len(num_tokens_per_note) == 0: + msg = "All the music files provided are empty and contain no note." + raise ValueError(msg) return sum(num_tokens_per_note) / len(num_tokens_per_note) @@ -381,3 +389,20 @@ def split_tokens_files_to_subsequences( new_tok = deepcopy(tokens) new_tok["ids"] = subseq json.dump(tokens, outfile) + + +def _preprocess_time_signatures(score: Score, tokenizer: MusicTokenizer) -> None: + """ + Make sure a Score contains time signature valid according to a tokenizer. + + :param score: ``symusic.Score`` to preprocess the time signature. + :param tokenizer: :class:`miditok.MusicTokenizer`. + """ + if tokenizer.config.use_time_signatures: + tokenizer._filter_unsupported_time_signatures(score.time_signatures) + if len(score.time_signatures) == 0 or score.time_signatures[0].time != 0: + score.time_signatures.insert(0, TimeSignature(0, *TIME_SIGNATURE)) + else: + score.time_signatures = TimeSignatureTickList( + [TimeSignature(0, *TIME_SIGNATURE)] + ) diff --git a/miditok/utils/utils.py b/miditok/utils/utils.py index 994232b8..8f8b2f1e 100644 --- a/miditok/utils/utils.py +++ b/miditok/utils/utils.py @@ -780,7 +780,7 @@ def get_num_notes_per_bar( :param tracks_indep: whether to process each track independently or all together. :return: the number of notes within each bar. """ - if len(score.tracks) == 0: + if score.end() == 0: return [] if tracks_indep else [0] # Get bar and note times @@ -789,13 +789,16 @@ def get_num_notes_per_bar( bar_ticks.append(score.end()) tracks_times = [track.notes.numpy()["time"] for track in score.tracks] if not tracks_indep: - tracks_times = [np.concatenate(tracks_times)] - tracks_times[-1].sort() - num_notes_per_bar = [] + if len(tracks_times) > 0: + tracks_times = [np.concatenate(tracks_times)] + tracks_times[-1].sort() + num_notes_per_bar = np.zeros(len(bar_ticks) - 1, dtype=np.int32) else: - num_notes_per_bar = [[] for _ in range(len(bar_ticks) - 1)] + num_notes_per_bar = np.zeros( + (len(bar_ticks) - 1, max(len(score.tracks), 1)), dtype=np.int32 + ) - for notes_times in tracks_times: + for ti, notes_times in enumerate(tracks_times): current_note_time_idx = previous_note_time_idx = 0 current_bar_tick_idx = 0 while current_bar_tick_idx < len(bar_ticks) - 1: @@ -807,14 +810,14 @@ def get_num_notes_per_bar( current_note_time_idx += 1 num_notes = current_note_time_idx - previous_note_time_idx if tracks_indep and len(score.tracks) > 1: - num_notes_per_bar[current_bar_tick_idx].append(num_notes) + num_notes_per_bar[current_bar_tick_idx, ti] = num_notes else: - num_notes_per_bar.append(num_notes) + num_notes_per_bar[current_bar_tick_idx] = num_notes current_bar_tick_idx += 1 previous_note_time_idx = current_note_time_idx - return num_notes_per_bar + return num_notes_per_bar.tolist() def get_score_ticks_per_beat(score: Score) -> np.ndarray: