Skip to content

Commit

Permalink
fixing data augmentation example and considering all midi extensions (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
Natooz authored Jan 24, 2024
1 parent e6bdebd commit db3ad1d
Show file tree
Hide file tree
Showing 4 changed files with 17 additions and 10 deletions.
6 changes: 3 additions & 3 deletions docs/examples.rst
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ We also perform data augmentation on the pitch, velocity and duration dimension.
# Creates the tokenizer and list the file paths
tokenizer = REMI() # using defaults parameters (constants.py)
midi_paths = list(Path("path", "to", "dataset").glob("**/*.mid"))
data_path = Path("path", "to", "dataset")
# A validation method to discard MIDIs we do not want
# It can also be used for custom pre-processing, for instance if you want to merge
Expand All @@ -121,15 +121,15 @@ We also perform data augmentation on the pitch, velocity and duration dimension.
# Performs data augmentation on one pitch octave (up and down), velocities and
# durations
augment_midi_dataset(
midi_paths,
data_path,
pitch_offsets=[-12, 12],
velocity_offsets=[-4, 5],
duration_offsets=[-0.5, 1],
out_path=midi_aug_path,
Path("to", "new", "location", "augmented"),
)
tokenizer.tokenize_midi_dataset( # 2 velocity and 1 duration values
midi_paths,
data_path,
Path("path", "to", "tokens"),
midi_valid,
)
2 changes: 1 addition & 1 deletion miditok/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
CURRENT_TOKENIZERS_VERSION = version("tokenizers")
CURRENT_SYMUSIC_VERSION = version("symusic")

MIDI_FILES_EXTENSIONS = [".mid", ".midi", ".MID", ".MIDI"]
MIDI_FILES_EXTENSIONS = {".mid", ".midi", ".MID", ".MIDI"}
MIDI_LOADING_EXCEPTION = (
RuntimeError,
ValueError,
Expand Down
10 changes: 8 additions & 2 deletions miditok/data_augmentation/data_augmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,11 @@
from symusic import Score
from tqdm import tqdm

from miditok.constants import MIDI_INSTRUMENTS, MIDI_LOADING_EXCEPTION
from miditok.constants import (
MIDI_FILES_EXTENSIONS,
MIDI_INSTRUMENTS,
MIDI_LOADING_EXCEPTION,
)


def augment_midi_dataset(
Expand Down Expand Up @@ -73,7 +77,9 @@ def augment_midi_dataset(
if isinstance(out_path, str):
out_path = Path(out_path)
out_path.mkdir(parents=True, exist_ok=True)
files_paths = list(Path(data_path).glob("**/*.mid"))
files_paths = [
path for path in data_path.glob("**/*") if path.suffix in MIDI_FILES_EXTENSIONS
]

num_augmentations, num_tracks_augmented = 0, 0
for file_path in tqdm(files_paths, desc="Performing data augmentation"):
Expand Down
9 changes: 5 additions & 4 deletions miditok/midi_tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2470,10 +2470,11 @@ def tokenize_midi_dataset(
if isinstance(midi_paths, str):
midi_paths = Path(midi_paths)
root_dir = midi_paths
midi_paths = sum(
(list(midi_paths.glob(f"**/*{ext}")) for ext in MIDI_FILES_EXTENSIONS),
[],
)
midi_paths = [
path
for path in midi_paths.glob("**/*")
if path.suffix in MIDI_FILES_EXTENSIONS
]
# User gave a list of paths, we need to find the root / deepest common subdir
else:
all_parts = [Path(path).parent.parts for path in midi_paths]
Expand Down

0 comments on commit db3ad1d

Please sign in to comment.