PyTorch_data_loading module #61

Merged
merged 4 commits on Aug 7, 2023
1 change: 1 addition & 0 deletions docs/index.rst
@@ -50,6 +50,7 @@ Contents
examples
tokenizations
bpe
pytorch_data
data_augmentation
utils
citations
14 changes: 14 additions & 0 deletions docs/pytorch_data.rst
@@ -0,0 +1,14 @@
========================
PyTorch data loaders
========================

MidiTok features PyTorch `Dataset` objects to load MIDI or token files during training.
You can use them with PyTorch `DataLoader`s or your preferred libraries.
When indexed, the `Dataset`s will output dictionaries with values corresponding to the inputs and labels.

MidiTok also provides an "all-in-one" data collator: :class:`miditok.pytorch_data.DataCollator` to be used with PyTorch `DataLoader`s in order to pad batches, add `BOS` and `EOS` tokens and create attention masks.

**Note:** *This module is imported only if* `torch` *is installed in your Python environment.*

.. automodule:: miditok.pytorch_data
:members:
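
To illustrate the module documented above, here is a minimal usage sketch. The `DatasetTok` constructor arguments (`files_paths`, `min_seq_len`, `max_seq_len`) and the `tokenizer["PAD_None"]`-style lookups for special token ids are assumptions for illustration; only the `DataCollator` parameters are taken from this PR.

from pathlib import Path

from torch.utils.data import DataLoader
from miditok import REMI
from miditok.pytorch_data import DataCollator, DatasetTok

tokenizer = REMI()  # any MidiTok tokenizer would do here

# DatasetTok arguments are assumed for illustration; check the class docstring for the exact signature
dataset = DatasetTok(
    files_paths=list(Path("path", "to", "tokens").glob("**/*.json")),
    min_seq_len=50,
    max_seq_len=1024,
)
collator = DataCollator(
    pad_token_id=tokenizer["PAD_None"],  # special token ids, assumed retrievable this way
    bos_token_id=tokenizer["BOS_None"],
    eos_token_id=tokenizer["EOS_None"],
    copy_inputs_as_labels=True,
)
data_loader = DataLoader(dataset, batch_size=16, collate_fn=collator)

for batch in data_loader:
    # batch["input_ids"], batch["labels"] and batch["attention_mask"] are LongTensors
    pass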
33 changes: 30 additions & 3 deletions miditok/__init__.py
@@ -1,3 +1,5 @@
from .midi_tokenizer import MIDITokenizer, convert_sequence_to_tokseq
from .classes import Event, TokSequence, TokenizerConfig
from .tokenizations import (
MIDILike,
REMI,
@@ -10,8 +12,33 @@
MuMIDI,
MMM,
)
from .midi_tokenizer import MIDITokenizer, convert_sequence_to_tokseq
from .classes import Event, TokSequence, TokenizerConfig

from .utils import utils
from .data_augmentation import data_augmentation
from miditok import data_augmentation


__all__ = [
"MIDITokenizer",
"convert_sequence_to_tokseq",
"Event",
"TokSequence",
"TokenizerConfig",
"MIDILike",
"REMI",
"REMIPlus",
"TSD",
"Structured",
"Octuple",
"OctupleMono",
"CPWord",
"MuMIDI",
"MMM",
"utils",
"data_augmentation",
]

try:
from miditok import pytorch_data
__all__.append("pytorch_data")
except ImportError as e:
pass
5 changes: 4 additions & 1 deletion miditok/classes.py
@@ -1,3 +1,6 @@
"""
Common classes.
"""
from dataclasses import dataclass
from typing import Union, Any, List, Sequence, Dict, Tuple
from copy import deepcopy
@@ -114,7 +117,7 @@ def __eq__(self, other) -> bool:
class TokenizerConfig:
r"""
MIDI tokenizer base class, containing common methods and attributes for all tokenizers.
:param pitch_range: (default: range(21, 109)) range of MIDI pitches to use. Pitches can take
:param pitch_range: (default: (21, 109)) range of MIDI pitches to use. Pitches can take
values between 0 and 127 (included).
The `General MIDI 2 (GM2) specifications <https://www.midi.org/specifications-old/item/general-midi-2>`_
indicate the **recommended** ranges of pitches per MIDI program (instrument).
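For reference, a short hedged sketch of the corrected default, assuming the tokenizer accepts the config through its `tokenizer_config` argument:

from miditok import REMI, TokenizerConfig

# pitch_range is given as a (min, max) tuple; 21-109 mirrors the documented default
config = TokenizerConfig(pitch_range=(21, 109))
tokenizer = REMI(tokenizer_config=config)
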
27 changes: 15 additions & 12 deletions miditok/midi_tokenizer.py
@@ -1271,6 +1271,7 @@ def tokenize_midi_dataset(
self,
midi_paths: Union[List[str], List[Path]],
out_dir: Union[str, Path],
tokenizer_config_file_name: str = "tokenizer.conf",
validation_fn: Callable[[MidiFile], bool] = None,
data_augment_offsets=None,
apply_bpe: bool = True,
@@ -1286,20 +1287,23 @@

:param midi_paths: paths of the MIDI files.
:param out_dir: output directory to save the converted files.
:param tokenizer_config_file_name: name of the tokenizer config file. This file will be saved in
`out_dir`. (default: "tokenizer.conf")
:param validation_fn: a function checking if the MIDI is valid on your requirements
(e.g. time signature, minimum/maximum length, instruments ...).
:param data_augment_offsets: data augmentation arguments, to be passed to the
miditok.data_augmentation.data_augmentation_dataset method. Has to be given as a list / tuple
of offsets pitch octaves, velocities, durations, and finally their directions (up/down). (default: None)
:param apply_bpe: will apply BPE on the dataset to save, if the vocabulary was learned with.
:param save_programs: will also save the programs of the tracks of the MIDI. (default: True)
:param apply_bpe: will apply BPE on the dataset to save, if the vocabulary was learned with. (default: True)
:param save_programs: will also save the programs of the tracks of the MIDI. Note that this option is
probably unnecessary when using a multitrack tokenizer, as the Program information is already present
within the tokens and tracks sharing the same program are likely to have been merged. (default: True)
:param logging: logs progress bar.
"""
out_dir = Path(out_dir)
out_dir.mkdir(parents=True, exist_ok=True)
self.save_params(
out_dir / "config.txt"
) # Saves the parameters with which the MIDIs are converted
# Saves the tokenizer so that it can be reloaded
self.save_params(out_dir / tokenizer_config_file_name)

for midi_path in (
tqdm(
@@ -1317,9 +1321,8 @@
if logging:
print(f"File not found: {midi_path}")
continue
except (
Exception
): # ValueError, OSError, FileNotFoundError, IOError, EOFError, mido.KeySignatureError
except Exception:
# known are ValueError, OSError, FileNotFoundError, IOError, EOFError, mido.KeySignatureError
continue

# Checks the time division is valid
@@ -1330,10 +1333,10 @@
if not validation_fn(midi):
continue

# Converting the MIDI to tokens and saving them as json
tokens = self(
midi, apply_bpe_if_possible=False
) # BPE will be applied after if ordered
# Tokenizing the MIDI, without BPE here as this will be done at the end (as we might perform data aug)
tokens = self(midi, apply_bpe_if_possible=False)

# Save the tokens as JSON
self.save_tokens(
tokens,
Path(out_dir, f"{Path(midi_path).stem}.json").with_suffix(".json"),
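A brief usage sketch of the updated method, where the MIDI paths and output directory are placeholders; the arguments match the signature shown in this diff.

from pathlib import Path

from miditok import REMI

tokenizer = REMI()
midi_paths = list(Path("path", "to", "midis").glob("**/*.mid"))

# Tokenizes the files to JSON and saves the tokenizer config in out_dir,
# under the new tokenizer_config_file_name argument (default shown explicitly here)
tokenizer.tokenize_midi_dataset(
    midi_paths,
    out_dir=Path("path", "to", "tokens"),
    tokenizer_config_file_name="tokenizer.conf",
)
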
10 changes: 10 additions & 0 deletions miditok/pytorch_data/__init__.py
@@ -0,0 +1,10 @@
from .datasets import DatasetTok, DatasetJsonIO, split_dataset_to_subsequences, split_seq_in_subsequences
from .collators import DataCollator

__all__ = [
"DatasetTok",
"DatasetJsonIO",
"split_dataset_to_subsequences",
"split_seq_in_subsequences",
"DataCollator",
]
185 changes: 185 additions & 0 deletions miditok/pytorch_data/collators.py
@@ -0,0 +1,185 @@
"""
Collator objects for PyTorch `DataLoader`s.
"""
from typing import List, Any, Dict
from copy import deepcopy

from torch import LongTensor
import torch


class DataCollator:
def __init__(
self,
pad_token_id: int,
bos_token_id: int = None,
eos_token_id: int = None,
pad_on_left: bool = False,
copy_inputs_as_labels: bool = False,
shift_labels: bool = False,
labels_pad_idx: int = -100,
inputs_kwarg_name: str = "input_ids",
labels_kwarg_name: str = "labels",
):
r"""Multifunction data collator, applying padding (right or left), allowing to add `BOS` and `EOS` tokens.
It will also add an "attention_mask" entry to the batch, following the padding applied.

:param pad_token_id: padding token id.
:param bos_token_id: BOS token id. (default: None)
:param eos_token_id: EOS token id. (default: None)
:param pad_on_left: if True, the sequences will be padded on the left. This can be required by some
libraries expecting left padding, for example when generating with Hugging Face Transformers.
(default: False)
:param copy_inputs_as_labels: will add a labels entry (`labels_kwarg_name`) to the batch
(or replace the existing one), which is a copy of the input entry (`inputs_kwarg_name`). (default: False)
:param shift_labels: will shift inputs and labels for autoregressive training / teacher forcing.
(default: False)
:param labels_pad_idx: padding id for labels. (default: -100)
:param inputs_kwarg_name: name of dict / kwarg key for inputs. (default: "input_ids")
:param labels_kwarg_name: name of dict / kwarg key for labels. (default: "labels")
"""
self.pad_token = pad_token_id
self.bos_token = bos_token_id
self.eos_token = eos_token_id
self.pad_on_left = pad_on_left
self.copy_inputs_as_labels = copy_inputs_as_labels
self.shift_labels = shift_labels
self.labels_pad_idx = labels_pad_idx
self.inputs_kwarg_name = inputs_kwarg_name
self.labels_kwarg_name = labels_kwarg_name

def __call__(self, batch: List[Dict[str, Any]]) -> Dict[str, LongTensor]:
out_batch = {}
x, y = None, None

# Figure out inputs + adds BOS and EOS tokens
if self.inputs_kwarg_name in batch[0]:
x = [seq[self.inputs_kwarg_name] for seq in batch]
_add_bos_eos_tokens_to_batch(
x,
bos_tok_id=self.bos_token,
eos_tok_id=self.eos_token,
)

# Figure out labels + adds BOS and EOS tokens
if self.labels_kwarg_name in batch[0]:
y = [seq[self.labels_kwarg_name] for seq in batch]
_add_bos_eos_tokens_to_batch(
y,
bos_tok_id=self.bos_token,
eos_tok_id=self.eos_token,
)
elif self.copy_inputs_as_labels:
y = deepcopy(x)

# Pad inputs / convert to Tensors
if x is not None:
x = _pad_batch(x, self.pad_token, self.pad_on_left)
if y is not None:
if isinstance(y[0], LongTensor):
y = _pad_batch(y, self.labels_pad_idx, self.pad_on_left)
else: # classification
y = torch.stack(y)

# Shift labels, otherwise it's handled by models
if self.shift_labels:
x = x[:, :-1]
y = y[:, 1:]

# Add inputs / labels to output batch
if x is not None:
out_batch[self.inputs_kwarg_name] = x
if y is not None:
out_batch[self.labels_kwarg_name] = y

# Create attention mask (just for padding, causal mask is handled by models)
if x is not None:
attention_mask = (x != self.pad_token).int()
if attention_mask.dim() == 3:
attention_mask = attention_mask[..., 0] # (N,T,Z) --> (N,T)
out_batch["attention_mask"] = attention_mask

return out_batch


def _add_bos_eos_tokens_to_batch(
batch: List[LongTensor],
bos_tok_id: int = None,
eos_tok_id: int = None,
):
"""Adds (inplace) BOS and EOS tokens to inputs.

:param batch: batch as a list of Tensors.
:param bos_tok_id: BOS token id. (default: None)
:param eos_tok_id: EOS token id. (default: None)
"""
if bos_tok_id is None and eos_tok_id is None:
return

sos_shape = list(batch[0].shape)
sos_shape[0] = 1 # (1) or (1,Z)
for i in range(len(batch)):
if bos_tok_id is not None and eos_tok_id is not None:
batch[i] = torch.cat(
[
torch.full(sos_shape, bos_tok_id),
batch[i],
torch.full(sos_shape, eos_tok_id),
],
dim=0,
).long()
elif bos_tok_id is not None:
batch[i] = torch.cat(
[torch.full(sos_shape, bos_tok_id), batch[i]], dim=0
).long()
else: # EOS not None
batch[i] = torch.cat(
[batch[i], torch.full(sos_shape, eos_tok_id)], dim=0
).long()


def _pad_batch(
batch: List[LongTensor],
pad_token_id: int,
pad_on_left: bool = False,
) -> LongTensor:
r"""Pad sequences of a batch.

:param batch: batch as a list of Tensors.
:param pad_token_id: padding token id.
:param pad_on_left: if True, the sequences will be padded on the left. This can be required by some
libraries expecting left padding, for example when generating with Hugging Face Transformers.
(default: False)
:return: the batch sequences, padded into a unique Tensor.
"""
length_of_first = batch[0].size(0)

# Check if padding is necessary.
are_tensors_same_length = all(x.size(0) == length_of_first for x in batch)
if are_tensors_same_length:
return torch.stack(batch, dim=0).long()

# Creating the full tensor and filling it with our data.
if pad_on_left:
return _pad_left(batch, pad_token_id)
else:
return torch.nn.utils.rnn.pad_sequence(
batch, batch_first=True, padding_value=pad_token_id
).long()


def _pad_left(batch: List[LongTensor], pad_token_id: int) -> LongTensor:
r"""Here the sequences are padded to the left, so that the last token along the time dimension
is always the last token of each seq, allowing to efficiently generate by batch

:param batch: batch as a list of Tensors.
:param pad_token_id: padding token id.
:return: the batch sequences, padded into a unique Tensor.
"""
batch = [torch.flip(seq, dims=(0,)) for seq in batch]
batch = torch.nn.utils.rnn.pad_sequence(
batch, batch_first=True, padding_value=pad_token_id
) # (N,T)
batch = torch.flip(batch, dims=(1,)).long()
return batch
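
To make the collation behavior above concrete, here is a small toy run of `DataCollator` on two sequences of different lengths, based on the collator code in this PR; the token ids are arbitrary placeholders.

from torch import LongTensor

from miditok.pytorch_data import DataCollator

collator = DataCollator(
    pad_token_id=0,
    bos_token_id=1,
    eos_token_id=2,
    copy_inputs_as_labels=True,
    shift_labels=True,
)
batch = [
    {"input_ids": LongTensor([5, 6, 7])},
    {"input_ids": LongTensor([5, 6, 7, 8, 9])},
]
out = collator(batch)
# out["input_ids"]: BOS/EOS added, right-padded with 0, last column dropped by the shift
# out["labels"]: copies of the inputs shifted one step left, padded with -100
# out["attention_mask"]: 1 for real tokens, 0 for padding positions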
