Skip to content

Commit

Permalink
Added torch.stft() and torch.istft().
Browse files Browse the repository at this point in the history
  • Loading branch information
kaiidams committed May 15, 2022
1 parent b22cede commit ada6545
Show file tree
Hide file tree
Showing 4 changed files with 82 additions and 0 deletions.
4 changes: 4 additions & 0 deletions src/Native/LibTorchSharp/THSTensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -620,6 +620,8 @@ EXPORT_API(Tensor) THSTensor_isneginf(const Tensor tensor);

EXPORT_API(Tensor) THSTensor_isreal(const Tensor tensor);

EXPORT_API(Tensor) THSTensor_istft(const Tensor x, int64_t n_fft, int64_t hop_length, int64_t win_length, const Tensor window, bool center, bool normalized, bool onesided, int64_t length, bool return_complex);

EXPORT_API(Scalar) THSTensor_item(const Tensor tensor);

EXPORT_API(Tensor) THSTensor_kron(const Tensor left, const Tensor right);
Expand Down Expand Up @@ -1056,6 +1058,8 @@ EXPORT_API(Tensor) THSTensor_sqrt_(const Tensor tensor);

EXPORT_API(Tensor) THSTensor_std(const Tensor tensor);

EXPORT_API(Tensor) THSTensor_stft(const Tensor x, int64_t n_fft, int64_t hop_length, int64_t win_length, const Tensor window, bool normalized, bool onesided, bool return_complex);

EXPORT_API(Tensor) THSTensor_std_along_dimensions(const Tensor tensor, const int64_t* dimensions, int length, bool unbiased, bool keepdim);

EXPORT_API(Tensor) THSTensor_sub(const Tensor left, const Tensor right);
Expand Down
17 changes: 17 additions & 0 deletions src/Native/LibTorchSharp/THSTensorMath.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -515,6 +515,15 @@ Tensor THSTensor_histc(const Tensor tensor, const int64_t bins, const int64_t mi
CATCH_TENSOR(tensor->histc(bins, min, max));
}

Tensor THSTensor_istft(const Tensor x, int64_t n_fft, int64_t hop_length, int64_t win_length, const Tensor window, bool center, bool normalized, bool onesided, int64_t length, bool return_complex)
{
auto _hop_length = hop_length == -1 ? c10::optional<int64_t>() : c10::optional<int64_t>(hop_length);
auto _win_length = win_length == -1 ? c10::optional<int64_t>() : c10::optional<int64_t>(win_length);
auto _window = window == nullptr ? c10::optional<at::Tensor>() : *window;
auto _length = length == -1 ? c10::optional<int64_t>() : c10::optional<int64_t>(length);
CATCH_TENSOR(x->istft(n_fft, _hop_length, _win_length, _window, center, normalized, onesided, _length, return_complex));
}

Tensor THSTensor_ldexp(const Tensor left, const Tensor right)
{
CATCH_TENSOR(left->ldexp(*right));
Expand Down Expand Up @@ -945,3 +954,11 @@ Tensor THSTensor_xlogy_scalar_(const Tensor x, const Scalar y)
{
CATCH_TENSOR(x->xlogy_(*y));
}

Tensor THSTensor_stft(const Tensor x, int64_t n_fft, int64_t hop_length, int64_t win_length, const Tensor window, bool normalized, bool onesided, bool return_complex)
{
auto _hop_length = hop_length == -1 ? c10::optional<int64_t>() : c10::optional<int64_t>(hop_length);
auto _win_length = win_length == -1 ? c10::optional<int64_t>() : c10::optional<int64_t>(win_length);
auto _window = window == nullptr ? c10::optional<at::Tensor>() : *window;
CATCH_TENSOR(x->stft(n_fft, _hop_length, _win_length, _window, normalized, onesided, return_complex));
}
60 changes: 60 additions & 0 deletions src/TorchSharp/Tensor/Tensor.Math.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1155,6 +1155,18 @@ public Tensor hypot(Tensor other)
return new Tensor(res);
}

[DllImport("LibTorchSharp")]
extern static IntPtr THSTensor_istft(IntPtr x, long n_fft, long hop_length, long win_length, IntPtr window, bool center, bool normalized, bool onesided, long length, bool return_complex);

public Tensor istft(long n_fft, long hop_length = -1, long win_length = -1, Tensor window = null, bool center = true, string pad_mode = "reflect", bool normalized = false, bool? onesided = null, long length = -1, bool return_complex = false)
{
IntPtr _window = (window is null) ? IntPtr.Zero : window.Handle;
bool _onesided = (onesided is null) ? true : (bool)onesided;
var res = THSTensor_istft(Handle, n_fft, hop_length, win_length, _window, center, normalized, _onesided, length, return_complex);
if (res == IntPtr.Zero) { torch.CheckForErrors(); }
return new Tensor(res);
}

[DllImport("LibTorchSharp")]
static extern IntPtr THSTensor_log(IntPtr tensor);

Expand Down Expand Up @@ -1877,6 +1889,21 @@ public Tensor signbit()
return new Tensor(res);
}

[DllImport("LibTorchSharp")]
extern static IntPtr THSTensor_stft(IntPtr x, long n_fft, long hop_length, long win_length, IntPtr window, bool normalized, bool onesided, bool return_complex);

public Tensor stft(long n_fft, long hop_length = -1, long win_length = -1, Tensor window = null, bool center = true, string pad_mode = "reflect", bool normalized = false, bool? onesided = null, bool return_complex = false)
{
IntPtr _window = (window is null) ? IntPtr.Zero : window.Handle;
bool _onesided = (onesided is null) ? true : (bool)onesided;
if (center) {
// TODO
}
var res = THSTensor_stft(Handle, n_fft, hop_length, win_length, _window, normalized, _onesided, return_complex);
if (res == IntPtr.Zero) { torch.CheckForErrors(); }
return new Tensor(res);
}

[DllImport("LibTorchSharp")]
static extern IntPtr THSTensor_sub(IntPtr tensor, IntPtr trg);

Expand Down Expand Up @@ -3361,5 +3388,38 @@ public static Tensor einsum(string equation, params Tensor[] tensors)
/// <param name="replacement">Whether to draw with replacement or not</param>
/// <param name="generator">Optional random number generator</param>
public static Tensor multinomial(Tensor input, long num_samples, bool replacement = false, torch.Generator generator = null) => input.multinomial(num_samples, replacement, generator);

/// <summary>
///
/// </summary>
/// <param name="x"></param>
/// <param name="n_fft"></param>
/// <param name="hop_length"></param>
/// <param name="win_length"></param>
/// <param name="window"></param>
/// <param name="center"></param>
/// <param name="pad_mode"></param>
/// <param name="normalized"></param>
/// <param name="onesided"></param>
/// <param name="length"></param>
/// <param name="return_complex"></param>
/// <returns></returns>
public static Tensor istft(Tensor x, long n_fft, long hop_length = -1, long win_length = -1, Tensor window = null, bool center = true, string pad_mode = "reflect", bool normalized = false, bool? onesided = null, long length = -1, bool return_complex = false) => x.istft(n_fft, hop_length, win_length, window, center, pad_mode, normalized, onesided, length, return_complex);

/// <summary>
///
/// </summary>
/// <param name="x"></param>
/// <param name="n_fft"></param>
/// <param name="hop_length"></param>
/// <param name="win_length"></param>
/// <param name="window"></param>
/// <param name="center"></param>
/// <param name="pad_mode"></param>
/// <param name="normalized"></param>
/// <param name="onesided"></param>
/// <param name="return_complex"></param>
/// <returns></returns>
public static Tensor stft(Tensor x, long n_fft, long hop_length = -1, long win_length = -1, Tensor window = null, bool center = true, string pad_mode = "reflect", bool normalized = false, bool? onesided = null, bool return_complex = false) => x.stft(n_fft, hop_length, win_length, window, center, pad_mode, normalized, onesided, return_complex);
}
}
1 change: 1 addition & 0 deletions src/TorchSharp/Tensor/Tensor.torch.cs
Original file line number Diff line number Diff line change
Expand Up @@ -325,6 +325,7 @@ public static Tensor _standard_gamma(Tensor input, torch.Generator generator = n
if (res == IntPtr.Zero) { torch.CheckForErrors(); }
return new Tensor(res);
}

[DllImport("LibTorchSharp")]
extern static IntPtr THSTensor_sample_dirichlet_(IntPtr tensor, IntPtr gen);

Expand Down

0 comments on commit ada6545

Please sign in to comment.