Skip to content

Commit

Permalink
Fixing split methods for empty files (no tracks and/or no notes) (#177)
Browse files Browse the repository at this point in the history
* fixing split methods for empty files (no tracks and/or no notes)

* fix
  • Loading branch information
Natooz authored Jun 5, 2024
1 parent 612d580 commit 0671ca1
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 11 deletions.
29 changes: 27 additions & 2 deletions miditok/utils/split_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,16 @@
from typing import TYPE_CHECKING, Any
from warnings import warn

from symusic import Score, TextMeta
from symusic import Score, TextMeta, TimeSignature
from symusic.core import TimeSignatureTickList
from tqdm import tqdm

from miditok.constants import (
MAX_NUM_FILES_NUM_TOKENS_PER_NOTE,
MIDI_FILES_EXTENSIONS,
SCORE_LOADING_EXCEPTION,
SUPPORTED_MUSIC_FILE_EXTENSIONS,
TIME_SIGNATURE,
)

from .utils import (
Expand Down Expand Up @@ -112,7 +114,10 @@ def split_files_for_training(
except SCORE_LOADING_EXCEPTION:
continue

# Separate track first if needed
# First preprocess time signatures to avoid cases where they might cause errors
_preprocess_time_signatures(scores[0], tokenizer)

# Separate track if needed
tracks_separated = False
if not tokenizer.one_token_stream and len(scores[0].tracks) > 1:
scores = split_score_per_tracks(scores[0])
Expand Down Expand Up @@ -308,6 +313,9 @@ def get_average_num_tokens_per_note(
if (num_notes := track.note_num()) > 0:
num_tokens_per_note.append(len(seq) / num_notes)

if len(num_tokens_per_note) == 0:
msg = "All the music files provided are empty and contain no note."
raise ValueError(msg)
return sum(num_tokens_per_note) / len(num_tokens_per_note)


Expand Down Expand Up @@ -381,3 +389,20 @@ def split_tokens_files_to_subsequences(
new_tok = deepcopy(tokens)
new_tok["ids"] = subseq
json.dump(tokens, outfile)


def _preprocess_time_signatures(score: Score, tokenizer: MusicTokenizer) -> None:
    """
    Normalize the time signatures of a Score with respect to a tokenizer.

    The score is modified in place and nothing is returned. Two cases are handled:

    * the tokenizer's config has ``use_time_signatures`` enabled: the score's time
      signatures are passed to the tokenizer's
      ``_filter_unsupported_time_signatures`` method, then a default time signature
      (built from ``TIME_SIGNATURE``) is inserted at the front when the list is
      empty or when its first element has a ``time`` different from 0;
    * otherwise: the score's time signatures are replaced by a list holding only
      that default time signature.

    :param score: ``symusic.Score`` to preprocess the time signature.
    :param tokenizer: :class:`miditok.MusicTokenizer`.
    """
    # Tokenizer ignores time signatures: keep a single default one and stop here.
    if not tokenizer.config.use_time_signatures:
        default_only = [TimeSignature(0, *TIME_SIGNATURE)]
        score.time_signatures = TimeSignatureTickList(default_only)
        return

    # NOTE(review): presumably filters ``score.time_signatures`` in place, as its
    # return value is not used — confirm against the tokenizer implementation.
    tokenizer._filter_unsupported_time_signatures(score.time_signatures)

    # After filtering, the list must begin with a time signature whose time is 0.
    missing_initial = (
        len(score.time_signatures) == 0 or score.time_signatures[0].time != 0
    )
    if missing_initial:
        score.time_signatures.insert(0, TimeSignature(0, *TIME_SIGNATURE))
21 changes: 12 additions & 9 deletions miditok/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -780,7 +780,7 @@ def get_num_notes_per_bar(
:param tracks_indep: whether to process each track independently or all together.
:return: the number of notes within each bar.
"""
if len(score.tracks) == 0:
if score.end() == 0:
return [] if tracks_indep else [0]

# Get bar and note times
Expand All @@ -789,13 +789,16 @@ def get_num_notes_per_bar(
bar_ticks.append(score.end())
tracks_times = [track.notes.numpy()["time"] for track in score.tracks]
if not tracks_indep:
tracks_times = [np.concatenate(tracks_times)]
tracks_times[-1].sort()
num_notes_per_bar = []
if len(tracks_times) > 0:
tracks_times = [np.concatenate(tracks_times)]
tracks_times[-1].sort()
num_notes_per_bar = np.zeros(len(bar_ticks) - 1, dtype=np.int32)
else:
num_notes_per_bar = [[] for _ in range(len(bar_ticks) - 1)]
num_notes_per_bar = np.zeros(
(len(bar_ticks) - 1, max(len(score.tracks), 1)), dtype=np.int32
)

for notes_times in tracks_times:
for ti, notes_times in enumerate(tracks_times):
current_note_time_idx = previous_note_time_idx = 0
current_bar_tick_idx = 0
while current_bar_tick_idx < len(bar_ticks) - 1:
Expand All @@ -807,14 +810,14 @@ def get_num_notes_per_bar(
current_note_time_idx += 1
num_notes = current_note_time_idx - previous_note_time_idx
if tracks_indep and len(score.tracks) > 1:
num_notes_per_bar[current_bar_tick_idx].append(num_notes)
num_notes_per_bar[current_bar_tick_idx, ti] = num_notes
else:
num_notes_per_bar.append(num_notes)
num_notes_per_bar[current_bar_tick_idx] = num_notes

current_bar_tick_idx += 1
previous_note_time_idx = current_note_time_idx

return num_notes_per_bar
return num_notes_per_bar.tolist()


def get_score_ticks_per_beat(score: Score) -> np.ndarray:
Expand Down

0 comments on commit 0671ca1

Please sign in to comment.