-
Notifications
You must be signed in to change notification settings - Fork 0
/
lib.rs
96 lines (88 loc) · 3.14 KB
/
lib.rs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
use ::rustft::{fft, ifft, Stft, WindowFunction};
use num_complex::Complex;
use numpy::{PyArray1, PyArray2, PyArray3};
use pyo3::prelude::*;
/// Run the `rustft` crate's `fft` over a one-dimensional array of `f64` samples.
///
/// The input is borrowed read-only and the transform output is copied into a
/// newly created one-dimensional complex NumPy array.
///
/// # Errors
///
/// Returns a Python `ValueError` carrying the underlying message when `fft`
/// reports a failure.
#[pyfunction]
fn rust_fft(py: Python, input: &PyArray1<f64>) -> PyResult<Py<PyArray1<Complex<f64>>>> {
    let samples = input.readonly();
    let spectrum = match fft(samples.as_array()) {
        Ok(bins) => bins.to_vec(),
        Err(err) => {
            return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
                "Error in fft: {}",
                err
            )));
        }
    };
    Ok(PyArray1::from_vec_bound(py, spectrum).into())
}
#[pyfunction]
fn rust_ifft(py: Python, input: &PyArray1<Complex<f64>>) -> PyResult<Py<PyArray1<f64>>> {
// Convert input to Array1<Complex<f64>> for ifft
let binding = input.readonly();
let input_array = binding.as_array();
let ifft_res = ifft(input_array)
.map_err(|e| {
PyErr::new::<pyo3::exceptions::PyValueError, _>(format!("Error in ifft: {}", e))
})?
.to_vec();
let py_result = PyArray1::from_vec_bound(py, ifft_res);
Ok(py_result.into())
}
/// Feed `input` through `rust_fft` and then pass that result to `rust_ifft`.
///
/// # Errors
///
/// Propagates the `PyErr` produced by whichever of the two steps fails first.
#[pyfunction]
fn rust_fft_roundtrip_test(py: Python, input: &PyArray1<f64>) -> PyResult<Py<PyArray1<f64>>> {
    let spectrum = rust_fft(py, input)?;
    // The inverse step already yields the `PyResult` this function returns.
    rust_ifft(py, spectrum.as_ref(py))
}
/// Apply the `rustft` crate's `Stft::forward` to a two-dimensional `f64` array.
///
/// An `Stft` is built from `n_fft`, `hop_length`, `WindowFunction::Hann` and a
/// final `true` flag; the result is handed to Python as a three-dimensional
/// complex NumPy array.
/// NOTE(review): the meaning of the trailing `true` argument and the axis
/// layout of the input/output are not visible here; confirm against `Stft`.
///
/// # Errors
///
/// Returns a Python `ValueError` carrying the underlying message when the
/// forward transform reports a failure.
#[pyfunction]
fn rust_stft(
    py: Python,
    input: &PyArray2<f64>,
    n_fft: usize,
    hop_length: usize,
) -> PyResult<Py<PyArray3<Complex<f64>>>> {
    let signal = input.readonly();
    let transform = Stft::new(n_fft, hop_length, WindowFunction::Hann, true);
    match transform.forward(signal.as_array()) {
        Ok(spectrogram) => Ok(PyArray3::from_owned_array_bound(py, spectrogram).into()),
        Err(err) => Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
            "Error in stft: {}",
            err
        ))),
    }
}
#[pyfunction]
fn rust_istft(
py: Python,
input: &PyArray3<Complex<f64>>,
n_fft: usize,
hop_length: usize,
) -> PyResult<Py<PyArray2<f64>>> {
let binding = input.readonly();
let input_array = binding.as_array();
let stft = Stft::new(n_fft, hop_length, WindowFunction::Hann, true);
let istft_res = stft.inverse(input_array).map_err(|e| {
PyErr::new::<pyo3::exceptions::PyValueError, _>(format!("Error in stft: {}", e))
})?;
let py_output = PyArray2::from_owned_array_bound(py, istft_res);
Ok(py_output.into())
}
/// Feed `input` through `rust_stft` and then pass that result to `rust_istft`,
/// using the same `n_fft` and `hop_length` for both steps.
///
/// # Errors
///
/// Propagates the `PyErr` produced by whichever of the two steps fails first.
#[pyfunction]
fn rust_stft_roundtrip(
    py: Python,
    input: &PyArray2<f64>,
    n_fft: usize,
    hop_length: usize,
) -> PyResult<Py<PyArray2<f64>>> {
    rust_stft(py, input, n_fft, hop_length)
        .and_then(|spectrogram| rust_istft(py, spectrogram.as_ref(py), n_fft, hop_length))
}
/// Python module definition: registers every `#[pyfunction]` wrapper in this
/// file on the module `m`.
///
/// # Errors
///
/// Returns the first `PyErr` raised while wrapping or adding a function.
#[pymodule]
fn rustft(_py: Python, m: &PyModule) -> PyResult<()> {
    // One-dimensional FFT wrappers.
    let fft_fn = wrap_pyfunction!(rust_fft, m)?;
    m.add_function(fft_fn)?;
    let ifft_fn = wrap_pyfunction!(rust_ifft, m)?;
    m.add_function(ifft_fn)?;
    let fft_roundtrip_fn = wrap_pyfunction!(rust_fft_roundtrip_test, m)?;
    m.add_function(fft_roundtrip_fn)?;

    // STFT wrappers.
    let stft_fn = wrap_pyfunction!(rust_stft, m)?;
    m.add_function(stft_fn)?;
    let istft_fn = wrap_pyfunction!(rust_istft, m)?;
    m.add_function(istft_fn)?;
    let stft_roundtrip_fn = wrap_pyfunction!(rust_stft_roundtrip, m)?;
    m.add_function(stft_roundtrip_fn)?;

    Ok(())
}