Skip to content

Commit

Permalink
Fix get_beats and get_downbeats for compound meters
Browse files Browse the repository at this point in the history
* Fix 1/2 for #152. Fix in get_downbeats for compound meters

* Fix 2/2 for #152. Fix in qpm_to_bpm for compound meters

* Added compound meter tests (#153)

* Travis check fixes (#153)

* python2 force float division fix (#153)

* Updated docstring for get_beats()
  • Loading branch information
apmcleod authored and craffel committed Dec 5, 2018
1 parent 48fae35 commit 369ce98
Show file tree
Hide file tree
Showing 4 changed files with 79 additions and 18 deletions.
21 changes: 17 additions & 4 deletions pretty_midi/pretty_midi.py
Original file line number Diff line number Diff line change
Expand Up @@ -498,7 +498,13 @@ def estimate_tempo(self):
return tempi[0]

def get_beats(self, start_time=0.):
"""Return a list of beat locations, according to MIDI tempo changes.
"""Returns a list of beat locations, according to MIDI tempo changes.
For compound meters (any whose numerator is a multiple of 3 greater
than 3), this method returns every third denominator note (for 6/8
or 6/16 time, for example, it will return every third 8th note or
16th note, respectively). For all other meters, this method returns
every denominator note (every quarter note for 3/4 or 4/4 time, for
example).
Parameters
----------
Expand Down Expand Up @@ -712,12 +718,19 @@ def index(array, value, default):
end_beat_idx = index(beats, end_ts.time, start_beat_idx)
# Add beats within this time signature range, skipping beats
# according to the current time signature
downbeats.append(
beats[start_beat_idx:end_beat_idx:start_ts.numerator])
if start_ts.numerator % 3 == 0 and start_ts.numerator != 3:
downbeats.append(beats[
start_beat_idx:end_beat_idx:(start_ts.numerator // 3)])
else:
downbeats.append(beats[
start_beat_idx:end_beat_idx:start_ts.numerator])
# Add in beats from the second-to-last to last time signature
final_ts = time_signatures[-1]
start_beat_idx = index(beats, final_ts.time, end_beat_idx)
downbeats.append(beats[start_beat_idx::final_ts.numerator])
if final_ts.numerator % 3 == 0 and final_ts.numerator != 3:
downbeats.append(beats[start_beat_idx::(final_ts.numerator // 3)])
else:
downbeats.append(beats[start_beat_idx::final_ts.numerator])
# Convert from list to array
downbeats = np.concatenate(downbeats)
# Return all downbeats after start_time
Expand Down
19 changes: 5 additions & 14 deletions pretty_midi/utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,26 +237,17 @@ def qpm_to_bpm(quarter_note_tempo, numerator, denominator):
'Time signature denominator must be an int greater than 0, but {} '
'was supplied.'.format(denominator))

# denominator is whole note
if denominator == 1:
return quarter_note_tempo / 4.0
# denominator is half note
elif denominator == 2:
return quarter_note_tempo / 2.0
# denominator is quarter note
elif denominator == 4:
return quarter_note_tempo
# denominator is eighth, sixteenth or 32nd
elif denominator in [8, 16, 32]:
# denominator is whole, half, quarter, eighth, sixteenth or 32nd note
if denominator in [1, 2, 4, 8, 16, 32]:
# simple triple
if numerator == 3:
return 2 * quarter_note_tempo
return quarter_note_tempo * denominator / 4.0
# compound meter 6/8*n, 9/8*n, 12/8*n...
elif numerator % 3 == 0:
return 2.0 * quarter_note_tempo / 3.0
return quarter_note_tempo / 3.0 * denominator / 4.0
# strongly assume two eighths equal a beat
else:
return quarter_note_tempo
return quarter_note_tempo * denominator / 4.0
else:
return quarter_note_tempo

Expand Down
10 changes: 10 additions & 0 deletions tests/test_pretty_midi.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,16 @@ def test_get_downbeats():
np.arange(expected_beats[-1] + 3*60./change_bpm,
pm.get_end_time(), 3*60./change_bpm))
assert np.allclose(pm.get_downbeats(2.2), expected_beats)
# Test for compound meters
pm = pretty_midi.PrettyMIDI()
# Add a note to force get_end_time() to be non-zero
i = pretty_midi.Instrument(0)
i.notes.append(pretty_midi.Note(100, 100, 0.3, 20.4))
pm.instruments.append(i)
# Simple test, assume 6/8 time for the entire piece
pm.time_signature_changes.append(pretty_midi.TimeSignature(6, 8, 0))
assert np.allclose(pm.get_downbeats(),
np.arange(0, pm.get_end_time(), 3*60./120.))


def test_adjust_times():
Expand Down
47 changes: 47 additions & 0 deletions tests/test_utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,3 +28,50 @@ def test_key_name_to_key_number():
for invalid_key in ['C# m', 'C# ma', 'ba', 'bm m', 'f## Major', 'O']:
with pytest.raises(ValueError):
pretty_midi.key_name_to_key_number(invalid_key)


def test_qpm_to_bpm():
# Test that twice the qpm leads to double the bpm for a range of qpm
for qpm in [60, 100, 125.56]:
for num in range(1, 24):
for den in range(1, 64):
assert 2 * pretty_midi.qpm_to_bpm(qpm, num, den) \
== pretty_midi.qpm_to_bpm(qpm * 2, num, den)
# Test that twice the denominator leads to double the bpm for a range
# of denominators (those outside of this set just fall back to the
# default of returning qpm.
for qpm in [60, 100, 125.56]:
for num in range(1, 24):
for den in [1, 2, 4, 8, 16]:
assert 2 * pretty_midi.qpm_to_bpm(qpm, num, den) \
== pretty_midi.qpm_to_bpm(qpm, num, den * 2)
# Check all compound meters
# qpb is quarter notes per beat. qpm / qpb = q/m / q/b = b/m = bpm
for den, qpb in zip([1, 2, 4, 8, 16, 32],
[12.0, 6.0, 3.0, 3/2.0, 3/4.0, 3/8.0]):
for qpm in [60, 120, 125.56]:
for num in range(2 * 3, 8 * 3, 3):
assert pretty_midi.qpm_to_bpm(qpm, num, den) == qpm / qpb
# Check all simple meters
# qpb is quarter notes per beat. qpm / qpb = q/m / q/b = b/m = bpm
for den, qpb in zip([1, 2, 4, 8, 16, 32],
[4.0, 2.0, 1.0, 1/2.0, 1/4.0, 1/8.0]):
for qpm in [60, 120, 125.56]:
for num in range(1, 24):
if num > 3 and num % 3 == 0:
continue
assert pretty_midi.qpm_to_bpm(qpm, num, den) == qpm / qpb
# Test invalid inputs
den = 4
num = 4
for qpm in [-1, 0, 'invalid']:
with pytest.raises(ValueError):
pretty_midi.qpm_to_bpm(qpm, num, den)
qpm = 120
for num in [-1, 0, 4.3, 'invalid']:
with pytest.raises(ValueError):
pretty_midi.qpm_to_bpm(qpm, num, den)
num = 4
for den in [-1, 0, 4.3, 'invalid']:
with pytest.raises(ValueError):
pretty_midi.qpm_to_bpm(qpm, num, den)

0 comments on commit 369ce98

Please sign in to comment.