Skip to content

Commit

Permalink
adding magic methods call, len and getitem for MIDITokenizer
Browse files Browse the repository at this point in the history
  • Loading branch information
Natooz committed May 25, 2022
1 parent 960cbfa commit 05c1ab9
Showing 1 changed file with 17 additions and 1 deletion.
18 changes: 17 additions & 1 deletion miditok/midi_tokenizer_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def __init__(self, pitch_range: range, beat_res: Dict[Tuple[int, int], int], nb_
# MIDI (being parsed) so that methods processing tracks can access them
self.current_midi_metadata = {} # needs to be updated each time a MIDI is read

def midi_to_tokens(self, midi: MidiFile) -> List[List[Union[int, List[int]]]]:
def midi_to_tokens(self, midi: MidiFile, *args, **kwargs) -> List[List[Union[int, List[int]]]]:
""" Converts a MIDI file in a tokens representation.
NOTE: if you override this method, be sure to keep the first lines in your method
Expand Down Expand Up @@ -608,6 +608,22 @@ def load_params(self, params: Union[str, Path, PurePath, Dict[str, Any]]):
if '_mask' not in params:
self._mask = False

def __call__(self, midi: MidiFile, *args, **kwargs):
return self.midi_to_tokens(midi, *args, **kwargs)

def __len__(self):
return [len(v) for v in self.vocab] if isinstance(self.vocab, list) else len(self.vocab)

def __getitem__(self, item: Union[int, str, Tuple[int, int]]) -> Union[str, int]:
if isinstance(item, str):
return self.vocab.event_to_token[item]
elif isinstance(item, int):
return self.vocab.token_to_event[item]
elif isinstance(item, tuple) and isinstance(self.vocab, list):
return self.vocab[item[0]].token_to_event[item[1]]
else:
raise IndexError('The index must be an integer or a string')


def get_midi_programs(midi: MidiFile) -> List[Tuple[int, bool]]:
""" Returns the list of programs of the tracks of a MIDI, deeping the
Expand Down

0 comments on commit 05c1ab9

Please sign in to comment.