Skip to content

Commit

Permalink
Add variable time-stretching and pitch-shifting.
Browse files Browse the repository at this point in the history
  • Loading branch information
psobot committed May 28, 2024
1 parent 273eb53 commit c7fc4fd
Show file tree
Hide file tree
Showing 2 changed files with 225 additions and 33 deletions.
193 changes: 162 additions & 31 deletions pedalboard/TimeStretch.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,8 @@ static const int MAX_SEMITONES_TO_PITCH_SHIFT = 72;
*/
static juce::AudioBuffer<float>
timeStretch(const juce::AudioBuffer<float> input, double sampleRate,
double stretchFactor, double pitchShiftInSemitones,
std::variant<double, std::vector<double>> stretchFactor,
std::variant<double, std::vector<double>> pitchShiftInSemitones,
bool highQuality, std::string transientMode,
std::string transientDetector, bool retainPhaseContinuity,
std::optional<bool> useLongFFTWindow, bool useTimeDomainSmoothing,
Expand Down Expand Up @@ -99,17 +100,87 @@ timeStretch(const juce::AudioBuffer<float> input, double sampleRate,
}

SuppressOutput suppress_cerr(std::cerr);
double initialStretchFactor = 1;
double initialPitchShiftInSemitones = 0;
size_t expectedNumberOfOutputSamples = 0;
// Validate the stretch factor (either one constant or one value per input
// sample) and derive from it the initial ratio handed to Rubber Band and the
// capacity of the output buffer.
if (auto *constantStretchFactor = std::get_if<double>(&stretchFactor)) {
  // `!(x > 0)` rejects zero, negative values and NaN, which is what the
  // "greater than 0.0x" wording of the error message promises. A negative
  // factor would otherwise yield a negative sample count below.
  if (!(*constantStretchFactor > 0))
    throw std::domain_error(
        "stretch_factor must be greater than 0.0x, but was passed " +
        std::to_string(*constantStretchFactor) + "x.");

  initialStretchFactor = *constantStretchFactor;
  expectedNumberOfOutputSamples =
      (((double)input.getNumSamples()) / *constantStretchFactor);
} else if (auto *variableStretchFactor =
               std::get_if<std::vector<double>>(&stretchFactor)) {
  for (size_t i = 0; i < variableStretchFactor->size(); i++) {
    const double factor = (*variableStretchFactor)[i];
    if (!(factor > 0))
      throw std::domain_error(
          "stretch_factor must be greater than 0.0x, but element at index " +
          std::to_string(i) + " was " + std::to_string(factor) + "x.");
  }

  // Check the length before touching std::min_element: on an empty vector
  // min_element returns end(), and dereferencing that is undefined behaviour.
  if (variableStretchFactor->size() !=
      static_cast<size_t>(input.getNumSamples()))
    throw std::domain_error(
        "stretch_factor must be the same length as the input audio "
        "buffer, but was passed an array of length " +
        std::to_string(variableStretchFactor->size()) + " instead of " +
        std::to_string(input.getNumSamples()) + " samples.");

  // The smallest factor produces the longest output, so sizing by it gives
  // an upper bound for a factor that varies over the buffer.
  expectedNumberOfOutputSamples =
      variableStretchFactor->empty()
          ? 0
          : (((double)input.getNumSamples()) /
             *std::min_element(variableStretchFactor->begin(),
                               variableStretchFactor->end()));

  // Per-sample factors are applied while processing, which this function
  // does in Rubber Band's real-time mode.
  options |= RubberBandStretcher::OptionProcessRealTime;
}

if (auto *constantPitchShiftInSemitones =
std::get_if<double>(&pitchShiftInSemitones)) {
if (*constantPitchShiftInSemitones < -MAX_SEMITONES_TO_PITCH_SHIFT ||
*constantPitchShiftInSemitones > MAX_SEMITONES_TO_PITCH_SHIFT)
throw std::domain_error(
"pitch_shift_in_semitones must be between -" +
std::to_string(MAX_SEMITONES_TO_PITCH_SHIFT) + " and +" +
std::to_string(MAX_SEMITONES_TO_PITCH_SHIFT) +
" semitones, but was passed " +
std::to_string(*constantPitchShiftInSemitones) + " semitones.");
initialPitchShiftInSemitones = *constantPitchShiftInSemitones;
} else if (auto *variablePitchShift =
std::get_if<std::vector<double>>(&pitchShiftInSemitones)) {
for (int i = 0; i < variablePitchShift->size(); i++) {
if (variablePitchShift->data()[i] < -MAX_SEMITONES_TO_PITCH_SHIFT ||
variablePitchShift->data()[i] > MAX_SEMITONES_TO_PITCH_SHIFT)
throw std::domain_error(
"pitch_shift_in_semitones must be between -" +
std::to_string(MAX_SEMITONES_TO_PITCH_SHIFT) + " and +" +
std::to_string(MAX_SEMITONES_TO_PITCH_SHIFT) +
" semitones, but element at index " + std::to_string(i) + " was " +
std::to_string(variablePitchShift->data()[i]) + " semitones.");
}

if (variablePitchShift->size() != input.getNumSamples())
throw std::domain_error(
"pitch_shift_in_semitones must be the same length as the input audio "
"buffer, but was passed an array of length " +
std::to_string(variablePitchShift->size()) + " instead of " +
std::to_string(input.getNumSamples()) + " samples.");
options |= RubberBandStretcher::OptionProcessRealTime;
}

RubberBandStretcher rubberBandStretcher(
sampleRate, input.getNumChannels(), options, 1.0 / stretchFactor,
pow(2.0, (pitchShiftInSemitones / 12.0)));
rubberBandStretcher.setMaxProcessSize(input.getNumSamples());
sampleRate, input.getNumChannels(), options, 1.0 / initialStretchFactor,
pow(2.0, (initialPitchShiftInSemitones / 12.0)));

rubberBandStretcher.setExpectedInputDuration(input.getNumSamples());

rubberBandStretcher.study(input.getArrayOfReadPointers(),
input.getNumSamples(), true);
if (!(options & RubberBandStretcher::OptionProcessRealTime)) {
rubberBandStretcher.study(input.getArrayOfReadPointers(),
input.getNumSamples(), true);
rubberBandStretcher.setMaxProcessSize(input.getNumSamples());
}

size_t expectedNumberOfOutputSamples =
(((double)input.getNumSamples()) / stretchFactor);
juce::AudioBuffer<float> output(input.getNumChannels(),
expectedNumberOfOutputSamples);

Expand All @@ -120,6 +191,13 @@ timeStretch(const juce::AudioBuffer<float> input, double sampleRate,
/* avoidReallocating */ true);

size_t blockSize = rubberBandStretcher.getProcessSizeLimit();
if (options & RubberBandStretcher::OptionProcessRealTime) {
// Process a small number of samples at a time in real-time
// mode to allow for variable stretch factors and pitch shifts.
// TODO: Make this configurable!
blockSize = 16;
}

const float **inputChannelPointers =
(const float **)alloca(sizeof(float *) * input.getNumChannels());
float **outputChannelPointers =
Expand All @@ -136,6 +214,19 @@ timeStretch(const juce::AudioBuffer<float> input, double sampleRate,
inputChannelPointers[c] = input.getReadPointer(c, i);
}

if (options & RubberBandStretcher::OptionProcessRealTime) {
  // Before each chunk, update the stretcher with the values found at `i`,
  // the same offset used for this chunk's read pointers.
  if (auto *variableStretchFactor =
          std::get_if<std::vector<double>>(&stretchFactor)) {
    // Rubber Band's time ratio is output duration over input duration,
    // i.e. the reciprocal of stretch_factor. The constructor call and the
    // output-buffer sizing in this function already use 1.0 / factor, so
    // the per-chunk update must invert the value as well.
    rubberBandStretcher.setTimeRatio(1.0 / (*variableStretchFactor)[i]);
  }

  if (auto *variablePitchShift =
          std::get_if<std::vector<double>>(&pitchShiftInSemitones)) {
    // Convert semitones to a frequency ratio: 12 semitones is one octave,
    // which doubles the frequency.
    double scale = pow(2.0, ((*variablePitchShift)[i] / 12.0));
    rubberBandStretcher.setPitchScale(scale);
  }
}

rubberBandStretcher.process(inputChannelPointers, chunkSize, isLastCall);
}

Expand Down Expand Up @@ -165,35 +256,63 @@ inline void init_time_stretch(py::module &m) {
m.def(
"time_stretch",
[](py::array_t<float, py::array::c_style> input, double sampleRate,
double stretchFactor, double pitchShiftInSemitones, bool highQuality,
std::string transientMode, std::string transientDetector,
bool retainPhaseContinuity, std::optional<bool> useLongFFTWindow,
bool useTimeDomainSmoothing, bool preserveFormants) {
if (stretchFactor == 0)
throw std::domain_error(
"stretch_factor must be greater than 0.0x, but was passed " +
std::to_string(stretchFactor) + "x.");

if (pitchShiftInSemitones < -MAX_SEMITONES_TO_PITCH_SHIFT ||
pitchShiftInSemitones > MAX_SEMITONES_TO_PITCH_SHIFT)
throw std::domain_error(
"pitch_shift_in_semitones must be between -" +
std::to_string(MAX_SEMITONES_TO_PITCH_SHIFT) + " and +" +
std::to_string(MAX_SEMITONES_TO_PITCH_SHIFT) +
" semitones, but was passed " +
std::to_string(pitchShiftInSemitones) + " semitones.");
std::variant<double, py::array_t<double, py::array::c_style>>
stretchFactor,
std::variant<double, py::array_t<double, py::array::c_style>>
pitchShiftInSemitones,
bool highQuality, std::string transientMode,
std::string transientDetector, bool retainPhaseContinuity,
std::optional<bool> useLongFFTWindow, bool useTimeDomainSmoothing,
bool preserveFormants) {
// Translate the Python-facing stretch factor into its C++ counterpart: a
// scalar is passed through unchanged, while a NumPy array is checked to be
// one-dimensional and then copied into a std::vector<double>.
std::variant<double, std::vector<double>> cppStretchFactor;
if (std::holds_alternative<double>(stretchFactor)) {
  cppStretchFactor = std::get<double>(stretchFactor);
} else {
  py::buffer_info factorInfo =
      std::get<py::array_t<double, py::array::c_style>>(stretchFactor)
          .request();
  if (factorInfo.ndim != 1) {
    throw std::domain_error(
        "stretch_factor must be a one-dimensional array of "
        "double-precision floating point numbers, but a " +
        std::to_string(factorInfo.ndim) +
        "-dimensional array was provided.");
  }
  const double *firstFactor = static_cast<double *>(factorInfo.ptr);
  cppStretchFactor =
      std::vector<double>(firstFactor, firstFactor + factorInfo.size);
}

// Translate the Python-facing pitch shift into its C++ counterpart: a scalar
// is passed through unchanged, while a NumPy array is checked to be
// one-dimensional and then copied into a std::vector<double>.
std::variant<double, std::vector<double>> cppPitchShift;
if (auto *variablePitchShift =
        std::get_if<py::array_t<double, py::array::c_style>>(
            &pitchShiftInSemitones)) {
  py::buffer_info inputInfo = variablePitchShift->request();
  if (inputInfo.ndim != 1) {
    // Name the argument that is actually at fault here; the message
    // previously blamed stretch_factor.
    throw std::domain_error(
        "pitch_shift_in_semitones must be a one-dimensional array of "
        "double-precision floating point numbers, but a " +
        std::to_string(inputInfo.ndim) +
        "-dimensional array was provided.");
  }
  cppPitchShift = std::vector<double>(
      static_cast<double *>(inputInfo.ptr),
      static_cast<double *>(inputInfo.ptr) + inputInfo.size);
} else {
  cppPitchShift = std::get<double>(pitchShiftInSemitones);
}

juce::AudioBuffer<float> inputBuffer =
convertPyArrayIntoJuceBuffer(input, detectChannelLayout(input));

juce::AudioBuffer<float> output;
{
py::gil_scoped_release release;
output = timeStretch(inputBuffer, sampleRate, stretchFactor,
pitchShiftInSemitones, highQuality,
transientMode, transientDetector,
retainPhaseContinuity, useLongFFTWindow,
useTimeDomainSmoothing, preserveFormants);
output = timeStretch(inputBuffer, sampleRate, cppStretchFactor,
cppPitchShift, highQuality, transientMode,
transientDetector, retainPhaseContinuity,
useLongFFTWindow, useTimeDomainSmoothing,
preserveFormants);
}

return copyJuceBufferIntoPyArray(output, detectChannelLayout(input), 0);
Expand All @@ -210,6 +329,12 @@ operation. The ``stretch_factor`` and ``pitch_shift_in_semitones`` arguments are
independent and do not affect each other (i.e.: you can change one, the other, or both
without worrying about how they interact).
Both ``stretch_factor`` and ``pitch_shift_in_semitones`` can be either floating-point
numbers or NumPy arrays of double-precision floating point numbers. If NumPy arrays
are provided, the length of the array must match the number of samples in the input
audio buffer, allowing dynamic transitions of the stretch factor and pitch shift over
the length of the buffer.
The additional arguments provided to this function allow for more fine-grained control
over the behavior of the time stretcher:
Expand Down Expand Up @@ -245,6 +370,12 @@ over the behavior of the time stretcher:
This is a function, not a :py:class:`Plugin` instance, and cannot be
used in :py:class:`Pedalboard` objects, as it changes the duration of
the audio stream.
.. info::
The ability to pass a NumPy array for ``stretch_factor`` and
``pitch_shift_in_semitones`` was added in Pedalboard v0.9.7.
)",
py::arg("input_audio"), py::arg("samplerate"),
py::arg("stretch_factor") = 1.0,
Expand Down
65 changes: 63 additions & 2 deletions tests/test_time_stretch.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#! /usr/bin/env python
#
# Copyright 2023 Spotify AB
# Copyright 2024 Spotify AB
#
# Licensed under the GNU Public License, Version 3.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -15,8 +15,9 @@
# limitations under the License.


import pytest
import numpy as np
import pytest

from pedalboard import time_stretch


Expand Down Expand Up @@ -104,3 +105,63 @@ def test_time_stretch_long_passthrough(
high_quality=high_quality,
)
np.testing.assert_allclose(output[0], sine_wave, atol=0.25)


@pytest.mark.parametrize("semitones_start", [-1, 0, 1])
@pytest.mark.parametrize("semitones_end", [-1, 0, 1])
@pytest.mark.parametrize("stretch_factor_start", [0.75, 1, 1.25])
@pytest.mark.parametrize("stretch_factor_end", [0.75, 1, 1.25])
def test_time_stretch_with_array(
    semitones_start, semitones_end, stretch_factor_start, stretch_factor_end
):
    """Per-sample stretch and pitch ramps must produce only finite samples."""
    sample_rate = 44100
    fundamental_hz = 440
    num_seconds = 1.0

    # One second of a 440 Hz sine wave as float32 audio.
    sample_indices = np.arange(num_seconds * sample_rate)
    phase = 2 * np.pi * fundamental_hz * sample_indices / sample_rate
    sine_wave = np.sin(phase).astype(np.float32)

    # Both control curves carry exactly one value per input sample.
    num_samples = sine_wave.shape[0]
    stretch_curve = np.linspace(stretch_factor_start, stretch_factor_end, num_samples)
    pitch_curve = np.linspace(semitones_start, semitones_end, num_samples)

    output = time_stretch(
        sine_wave,
        sample_rate,
        stretch_factor=stretch_curve,
        pitch_shift_in_semitones=pitch_curve,
    )

    assert np.all(np.isfinite(output))


def test_time_stretch_mismatched_buffer_length_and_stretch_factors():
    """A stretch_factor array whose length differs from the input is rejected."""
    with pytest.raises(ValueError) as e:
        time_stretch(
            np.zeros((1, 10), dtype=np.float32), 44100, stretch_factor=np.linspace(0.1, 1.0, 11)
        )
    # Check the message too, so an unrelated ValueError cannot satisfy the test.
    assert "buffer" in str(e.value)


def test_time_stretch_mismatched_buffer_length_and_pitch_shift():
    """A pitch-shift array whose length differs from the input is rejected."""
    with pytest.raises(ValueError) as e:
        time_stretch(
            np.zeros((1, 10), dtype=np.float32),
            44100,
            pitch_shift_in_semitones=np.linspace(0.1, 1.0, 11),
        )
    # Inspect the raised exception itself (e.value), not the ExceptionInfo
    # wrapper, and do so after the `with` block so the assertion actually runs.
    assert "buffer" in str(e.value)


def test_time_stretch_variable_stretch_factor_out_of_range():
    """A zero anywhere in a stretch_factor array is reported with its index."""
    with pytest.raises(ValueError) as e:
        time_stretch(
            np.zeros((1, 10), dtype=np.float32),
            44100,
            stretch_factor=np.zeros((10,), dtype=np.float32),
        )
    # Inspect the raised exception itself (e.value), not the ExceptionInfo
    # wrapper, and do so after the `with` block so the assertion actually runs.
    assert "element at index 0 was 0" in str(e.value)


def test_time_stretch_variable_pitch_shift_out_of_range():
    """A pitch shift outside the supported range is reported with its index."""
    with pytest.raises(ValueError) as e:
        time_stretch(
            np.zeros((1, 10), dtype=np.float32),
            44100,
            pitch_shift_in_semitones=np.ones((10,), dtype=np.float32) * 73,
        )
    # Inspect the raised exception itself (e.value), not the ExceptionInfo
    # wrapper, and do so after the `with` block so the assertion actually runs.
    assert "element at index 0 was 73" in str(e.value)

0 comments on commit c7fc4fd

Please sign in to comment.