Added missing functions of torchaudio.functional #685

Merged 1 commit on Aug 8, 2022
262 changes: 262 additions & 0 deletions src/TorchSharp/TorchAudio/Functional.cs
@@ -5,6 +5,7 @@
using System.Runtime.InteropServices;

using static TorchSharp.torch;
using System.Diagnostics;

// A number of implementation details in this file have been translated from the Python version or torchvision,
Contributor:

Maybe this comment should refer to 'torchaudio' instead of 'torchvision'? :-)

// largely located in the files found in this folder:
@@ -150,6 +151,267 @@ public static torch.Tensor inverse_spectrogram(torch.Tensor spectrogram, long? l
}
}

private static ScalarType _get_complex_dtype(torch.ScalarType real_dtype)
{
if (real_dtype == ScalarType.Float64)
return ScalarType.ComplexFloat64;
if (real_dtype == ScalarType.Float32)
return ScalarType.ComplexFloat32;
if (real_dtype == ScalarType.Float16)
return ScalarType.ComplexFloat32;
throw new ArgumentException($"Unexpected dtype {real_dtype}");
}

/// <summary>
/// Compute waveform from a linear scale magnitude spectrogram using the Griffin-Lim transformation.
/// </summary>
/// <param name="specgram">A magnitude-only STFT spectrogram of dimension `(..., freq, frames)` where freq is ``n_fft // 2 + 1``.</param>
/// <param name="window">Window tensor that is applied/multiplied to each frame/window</param>
/// <param name="n_fft">Size of FFT, creates ``n_fft // 2 + 1`` bins</param>
/// <param name="hop_length">Length of hop between STFT windows.</param>
/// <param name="win_length">Window size.</param>
/// <param name="power">Exponent for the magnitude spectrogram, (must be > 0) e.g., 1 for energy, 2 for power, etc.</param>
/// <param name="n_iter">Number of iteration for phase recovery process.</param>
/// <param name="momentum">The momentum parameter for fast Griffin-Lim.</param>
/// <param name="length">Array length of the expected output.</param>
/// <param name="rand_init">Initializes phase randomly if True, to zero otherwise.</param>
/// <returns>Waveform of dimension `(..., time)`.</returns>
/// <exception cref="ArgumentOutOfRangeException">Thrown when momentum is not in the range [0, 1).</exception>
public static Tensor griffinlim(Tensor specgram, Tensor window, long n_fft, long hop_length, long win_length, double power, int n_iter, double momentum, long? length, bool rand_init)
{
if (momentum < 0.0 || 1.0 <= momentum) {
throw new ArgumentOutOfRangeException(nameof(momentum), $"momentum must be in range [0, 1). Found: {momentum}");
}
momentum = momentum / (1 + momentum);
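// Note: the rescaling above mirrors the fast Griffin-Lim variant used by librosa and torchaudio,
// where the previous iterate is weighted by momentum / (1 + momentum) before being subtracted.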

// pack batch
var shape = specgram.size();
specgram = specgram.reshape(new long[] { -1, shape[shape.Length - 2], shape[shape.Length - 1] });

specgram = specgram.pow(1 / power);

// initialize the phase
Tensor angles;
if (rand_init) {
angles = torch.rand(specgram.size(), dtype: _get_complex_dtype(specgram.dtype), device: specgram.device);
} else {
angles = torch.full(specgram.size(), 1, dtype: _get_complex_dtype(specgram.dtype), device: specgram.device);
}

// And initialize the previous iterate to 0
var tprev = torch.tensor(0.0, dtype: specgram.dtype, device: specgram.device);
for (int i = 0; i < n_iter; i++) {
// Invert with our current estimate of the phases
var inverse = torch.istft(
specgram * angles, n_fft: n_fft, hop_length: hop_length, win_length: win_length, window: window, length: length ?? -1
);

// Rebuild the spectrogram
var rebuilt = torch.stft(
input: inverse,
n_fft: n_fft,
hop_length: hop_length,
win_length: win_length,
window: window,
center: true,
pad_mode: PaddingModes.Reflect,
normalized: false,
onesided: true,
return_complex: true);

// Update our phase estimates
angles = rebuilt;
if (momentum > 0.0) {
angles = angles - tprev.mul_(momentum);
}
angles = angles.div(angles.abs().add(1e-16));

// Store the previous iterate
tprev = rebuilt;
}

// Return the final phase estimates
var waveform = torch.istft(
specgram * angles, n_fft: n_fft, hop_length: hop_length, win_length: win_length, window: window, length: length ?? -1
);

// unpack batch
var new_shape = new long[shape.Length - 1];
Array.Copy(shape, new_shape, shape.Length - 2);
new_shape[new_shape.Length - 1] = waveform.shape[waveform.dim() - 1];
waveform = waveform.reshape(new_shape);

return waveform;
}
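
// Illustrative usage sketch (not part of this PR's diff; assumes a `waveform` tensor of
// shape (..., time) is in scope, and the parameter values are placeholders). Note that
// `power` must match the exponent used when the spectrogram was built, since griffinlim
// applies specgram.pow(1 / power) to recover magnitudes:
//
//     var window = torch.hann_window(400);
//     var specgram = torchaudio.functional.spectrogram(
//         waveform: waveform, pad: 0, window: window, n_fft: 512, hop_length: 160,
//         win_length: 400, power: 2.0, normalized: false);
//     var recovered = torchaudio.functional.griffinlim(
//         specgram: specgram, window: window, n_fft: 512, hop_length: 160, win_length: 400,
//         power: 2.0, n_iter: 32, momentum: 0.99, length: null, rand_init: true);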

/// <summary>
/// Turn a spectrogram from the power/amplitude scale to the decibel scale.
/// </summary>
/// <param name="x">Input spectrogram(s) before being converted to decibel scale.</param>
/// <param name="multiplier">Use 10. for power and 20. for amplitude</param>
/// <param name="amin">Number to clamp x</param>
/// <param name="db_multiplier">Log10(max(reference value and amin))</param>
/// <param name="top_db">Minimum negative cut-off in decibels.</param>
/// <returns>Output tensor in decibel scale</returns>
public static Tensor amplitude_to_DB(Tensor x, double multiplier, double amin, double db_multiplier, double? top_db = null)
{
var x_db = multiplier * torch.log10(torch.clamp(x, min: amin));
x_db -= multiplier * db_multiplier;

if (top_db != null) {
// Expand batch
var shape = x_db.size();
var packed_channels = x_db.dim() > 2 ? shape[shape.Length - 3] : 1;
x_db = x_db.reshape(-1, packed_channels, shape[shape.Length - 2], shape[shape.Length - 1]);

x_db = torch.maximum(x_db, (x_db.amax(dims: new long[] { -3, -2, -1 }) - top_db).view(-1, 1, 1, 1));

// Repack batch
x_db = x_db.reshape(shape);
}
return x_db;
}
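
// Worked example (illustrative; `power_spec` is a placeholder name): for a power spectrogram,
// multiplier is 10.0 and db_multiplier is log10(max(reference, amin)). With a reference of 1.0
// this reduces to 10 * log10(clamp(x, amin)); top_db then clips everything that falls more than
// top_db below the peak value:
//
//     var db = torchaudio.functional.amplitude_to_DB(
//         power_spec, multiplier: 10.0, amin: 1e-10, db_multiplier: 0.0, top_db: 80.0);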

/// <summary>
/// Turn a tensor from the decibel scale to the power/amplitude scale.
/// </summary>
/// <param name="x">Input tensor before being converted to power/amplitude scale.</param>
/// <param name="ref">Reference which the output will be scaled by.</param>
/// <param name="power">If power equals 1, will compute DB to power. If 0.5, will compute DB to amplitude.</param>
/// <returns>Output tensor in power/amplitude scale.</returns>
public static Tensor DB_to_amplitude(Tensor x, double @ref, double power)
{
return @ref * torch.pow(torch.pow(10.0, 0.1 * x), power);
}
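
// Round-trip sketch: with @ref = 1.0 and power = 1.0 this inverts the power-scale conversion
// above, i.e. DB_to_amplitude(amplitude_to_DB(x, 10.0, 1e-10, 0.0), 1.0, 1.0) recovers x
// wherever x >= amin and no top_db clipping was applied.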

private static double _hz_to_mel(double freq, MelScale mel_scale = MelScale.htk)
{
if (mel_scale == MelScale.htk) {
return 2595.0 * Math.Log10(1.0 + freq / 700.0);
}

// Fill in the linear part
var f_min = 0.0;
var f_sp = 200.0 / 3;

var mels = (freq - f_min) / f_sp;

// Fill in the log-scale part
var min_log_hz = 1000.0;
var min_log_mel = (min_log_hz - f_min) / f_sp;
var logstep = Math.Log(6.4) / 27.0;

if (freq >= min_log_hz) {
mels = min_log_mel + Math.Log(freq / min_log_hz) / logstep;
}

return mels;
}
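
// Worked example: on the HTK scale, 1000 Hz maps to 2595 * log10(1 + 1000 / 700) ≈ 1000 mel;
// on the Slaney scale the mapping is linear (freq / (200 / 3)) up to 1000 Hz and logarithmic above.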

private static Tensor _mel_to_hz(Tensor mels, MelScale mel_scale = MelScale.htk)
{
if (mel_scale == MelScale.htk) {
return 700.0 * (torch.pow(10.0, mels / 2595.0) - 1.0);
}

// Fill in the linear scale
var f_min = 0.0;
var f_sp = 200.0 / 3;

var freqs = f_min + f_sp * mels;

// And now the nonlinear scale
var min_log_hz = 1000.0;
var min_log_mel = (min_log_hz - f_min) / f_sp;
var logstep = Math.Log(6.4) / 27.0;

var log_t = mels >= min_log_mel;
freqs[log_t] = min_log_hz * torch.exp(logstep * (mels[log_t] - min_log_mel));

return freqs;
}

private static Tensor _create_triangular_filterbank(Tensor all_freqs, Tensor f_pts)
{
// Adapted from Librosa
// calculate the difference between each filter mid point and each stft freq point in hertz
var f_diff = f_pts[TensorIndex.Slice(1, null)] - f_pts[TensorIndex.Slice(null, -1)]; // (n_filter + 1)
var slopes = f_pts.unsqueeze(0) - all_freqs.unsqueeze(1); // (n_freqs, n_filter + 2)
// create overlapping triangles
var zero = torch.zeros(1);
var down_slopes = (-1.0 * slopes[TensorIndex.Colon, TensorIndex.Slice(null, -2)]) / f_diff[TensorIndex.Slice(null, -1)]; // (n_freqs, n_filter)
var up_slopes = slopes[TensorIndex.Colon, TensorIndex.Slice(2, null)] / f_diff[TensorIndex.Slice(1, null)]; // (n_freqs, n_filter)
var fb = torch.maximum(zero, torch.minimum(down_slopes, up_slopes));

return fb;
}
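
// In other words, for filter m with edge frequencies f_pts[m] <= f_pts[m + 1] <= f_pts[m + 2],
// each STFT bin frequency f receives the weight
//
//     max(0, min((f - f_pts[m]) / (f_pts[m + 1] - f_pts[m]),
//                (f_pts[m + 2] - f) / (f_pts[m + 2] - f_pts[m + 1]))),
//
// i.e. a triangle rising from 0 at f_pts[m] to 1 at f_pts[m + 1] and falling back to 0 at f_pts[m + 2].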

/// <summary>
/// Create a frequency bin conversion matrix.
/// </summary>
/// <param name="n_freqs">Number of frequencies to highlight/apply</param>
/// <param name="f_min">Minimum frequency(Hz)</param>
/// <param name="f_max">Maximum frequency(Hz)</param>
/// <param name="n_mels">Number of mel filterbanks</param>
/// <param name="sample_rate">Sample rate of the audio waveform</param>
/// <param name="norm">If MelNorm.slaney, divide the triangular mel weights by the width of the mel band</param>
/// <param name="mel_scale">Scale to use</param>
/// <returns>Triangular filter banks</returns>
public static Tensor melscale_fbanks(int n_freqs, double f_min, double f_max, int n_mels, int sample_rate, MelNorm norm = MelNorm.none, MelScale mel_scale = MelScale.htk)
{
// freq bins
var all_freqs = torch.linspace(0, sample_rate / 2, n_freqs);

// calculate mel freq bins
var m_min = _hz_to_mel(f_min, mel_scale: mel_scale);
var m_max = _hz_to_mel(f_max, mel_scale: mel_scale);

var m_pts = torch.linspace(m_min, m_max, n_mels + 2);
var f_pts = _mel_to_hz(m_pts, mel_scale: mel_scale);

// create filterbank
var fb = _create_triangular_filterbank(all_freqs, f_pts);

if (norm == MelNorm.slaney) {
// Slaney-style mel is scaled to be approx constant energy per channel
var enorm = 2.0 / (f_pts[TensorIndex.Slice(2, n_mels + 2)] - f_pts[TensorIndex.Slice(null, n_mels)]);
fb *= enorm.unsqueeze(0);
}

if ((fb.max(dim: 0).values == 0.0).any().item<bool>()) {
Debug.Print(
"At least one mel filterbank has all zero values. " +
$"The value for `n_mels` ({n_mels}) may be set too high. " +
$"Or, the value for `n_freqs` ({n_freqs}) may be set too low."
);
}

return fb;
}
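
// Application sketch (illustrative; `specgram` is a placeholder for a spectrogram of shape
// (..., freq, time), and the matmul pattern mirrors how torchaudio's MelScale transform applies
// the matrix, which is an assumption for this port): a mel spectrogram is obtained by
// matrix-multiplying along the frequency axis:
//
//     var fb = torchaudio.functional.melscale_fbanks(n_freqs: 257, f_min: 0.0, f_max: 8000.0,
//         n_mels: 80, sample_rate: 16000, norm: torchaudio.MelNorm.slaney);
//     var mel_specgram = torch.matmul(specgram.transpose(-1, -2), fb).transpose(-1, -2);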

/// <summary>
/// Creates a linear triangular filterbank.
/// </summary>
/// <param name="n_freqs">Number of frequencies to highlight/apply</param>
/// <param name="f_min">Minimum frequency (Hz)</param>
/// <param name="f_max">Maximum frequency (Hz)</param>
/// <param name="n_filter">Number of (linear) triangular filters</param>
/// <param name="sample_rate">Sample rate of the audio waveform</param>
/// <returns>Triangular filter banks</returns>
public static Tensor linear_fbanks(int n_freqs, double f_min, double f_max, int n_filter, int sample_rate)
{
// freq bins
var all_freqs = torch.linspace(0, sample_rate / 2, n_freqs);

// filter mid-points
var f_pts = torch.linspace(f_min, f_max, n_filter + 2);

// create filterbank
var fb = _create_triangular_filterbank(all_freqs, f_pts);

return fb;
}

/// <summary>
/// Resample the waveform
/// </summary>
15 changes: 15 additions & 0 deletions src/TorchSharp/TorchAudio/MelNorm.cs
@@ -0,0 +1,15 @@
// Copyright (c) .NET Foundation and Contributors. All Rights Reserved. See LICENSE in the project root for license information.
namespace TorchSharp
{
public static partial class torchaudio
{
/// <summary>
/// Normalization type of mel filterbanks
/// </summary>
public enum MelNorm
{
none,
slaney
}
}
}
15 changes: 15 additions & 0 deletions src/TorchSharp/TorchAudio/MelScale.cs
@@ -0,0 +1,15 @@
// Copyright (c) .NET Foundation and Contributors. All Rights Reserved. See LICENSE in the project root for license information.
namespace TorchSharp
{
public static partial class torchaudio
{
/// <summary>
/// Scale type of mel filterbanks
/// </summary>
public enum MelScale
{
slaney,
htk
}
}
}
77 changes: 77 additions & 0 deletions test/TorchSharpTest/TestTorchAudio.cs
@@ -112,6 +112,83 @@ public void TransformsInverseSpectrogram()
Assert.InRange(mse, 0f, 1e-10f);
}

[Fact]
public void TestAmplitudeToDB()
{
var x = torch.linspace(0, 3.0, 101)[torch.TensorIndex.None, torch.TensorIndex.Colon];
var y = torchaudio.functional.amplitude_to_DB(x, 20, 1e-10, 0.0, 80);
var z = 20.0 * torch.log10(torch.clamp(x, min: 1e-10));
z = torch.clamp(z, min: torch.max(z) - 80);
var mse = torch.mean(torch.square(y - z)).item<float>();
Assert.InRange(mse, 0f, 1e-10f);
}

[Fact]
public void TestDBToAmplitude()
{
var x = torch.linspace(-20.0, 0.0, 101)[torch.TensorIndex.None, torch.TensorIndex.Colon];
var y = torchaudio.functional.DB_to_amplitude(x, 1.0, 0.5);
var z = torch.pow(torch.pow(10.0, 0.1 * x), 0.5);
var mse = torch.mean(torch.square(y - z)).item<float>();
Assert.InRange(mse, 0f, 1e-10f);
}

[Fact]
public void TestGriffinLim()
{
var waveform = make_waveform();
var window = torch.hann_window(400);
var specgram = torchaudio.functional.spectrogram(
waveform: waveform,
pad: 200,
window: window,
n_fft: 512,
hop_length: 160,
win_length: 400,
power: 2.0,
normalized: false);
var recovered_waveform = torchaudio.functional.griffinlim(
specgram: specgram,
window: window,
n_fft: 512,
hop_length: 160,
win_length: 400,
power: 2.0,
n_iter: 32,
momentum: 0.99,
length: null,
rand_init: true);
Assert.Equal(new long[] { 1, 80320 }, recovered_waveform.shape);
}

[Fact]
public void TestMelscaleFbanks()
{
int n_freqs = 257;
double f_min = 50;
double f_max = 7600;
int n_mels = 64;
int sample_rate = 16000;
var fb = torchaudio.functional.melscale_fbanks(n_freqs, f_min, f_max, n_mels, sample_rate);
Assert.Equal(new long[] { n_freqs, n_mels }, fb.shape);
// Sum of all banks should be 1.0
Assert.True((fb.sum(dim: 1)[torch.TensorIndex.Slice(3, -23)] == 1.0).all().item<bool>());
}

[Fact]
public void TestLinearFbanks()
{
int n_freqs = 257;
double f_min = 50;
double f_max = 7600;
int n_filter = 64;
int sample_rate = 16000;
var fb = torchaudio.functional.linear_fbanks(n_freqs, f_min, f_max, n_filter, sample_rate);
Assert.Equal(new long[] { n_freqs, n_filter }, fb.shape);
// Sum of all banks should be 1.0
Assert.True((fb.sum(dim: 1)[torch.TensorIndex.Slice(6, -17)] == 1.0).all().item<bool>());
}

[Fact]
public void TestFunctionalResampleIdent()
{