Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added a new Spectrogram layer based on Conv1D operations, supporting GPU-parallelization and fine-tuning #20313

Merged
merged 17 commits into from
Oct 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
17 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions keras/api/_tf_keras/keras/initializers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from keras.src.initializers.constant_initializers import Identity as identity
from keras.src.initializers.constant_initializers import Ones
from keras.src.initializers.constant_initializers import Ones as ones
from keras.src.initializers.constant_initializers import STFTInitializer
from keras.src.initializers.constant_initializers import Zeros
from keras.src.initializers.constant_initializers import Zeros as zeros
from keras.src.initializers.initializer import Initializer
Expand Down
1 change: 1 addition & 0 deletions keras/api/_tf_keras/keras/layers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,7 @@
from keras.src.layers.preprocessing.normalization import Normalization
from keras.src.layers.preprocessing.pipeline import Pipeline
from keras.src.layers.preprocessing.rescaling import Rescaling
from keras.src.layers.preprocessing.stft_spectrogram import STFTSpectrogram
from keras.src.layers.preprocessing.string_lookup import StringLookup
from keras.src.layers.preprocessing.text_vectorization import TextVectorization
from keras.src.layers.regularization.activity_regularization import (
Expand Down
1 change: 1 addition & 0 deletions keras/api/initializers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from keras.src.initializers.constant_initializers import Identity as identity
from keras.src.initializers.constant_initializers import Ones
from keras.src.initializers.constant_initializers import Ones as ones
from keras.src.initializers.constant_initializers import STFTInitializer
from keras.src.initializers.constant_initializers import Zeros
from keras.src.initializers.constant_initializers import Zeros as zeros
from keras.src.initializers.initializer import Initializer
Expand Down
1 change: 1 addition & 0 deletions keras/api/layers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,7 @@
from keras.src.layers.preprocessing.normalization import Normalization
from keras.src.layers.preprocessing.pipeline import Pipeline
from keras.src.layers.preprocessing.rescaling import Rescaling
from keras.src.layers.preprocessing.stft_spectrogram import STFTSpectrogram
from keras.src.layers.preprocessing.string_lookup import StringLookup
from keras.src.layers.preprocessing.text_vectorization import TextVectorization
from keras.src.layers.regularization.activity_regularization import (
Expand Down
2 changes: 2 additions & 0 deletions keras/src/initializers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from keras.src.initializers.constant_initializers import Constant
from keras.src.initializers.constant_initializers import Identity
from keras.src.initializers.constant_initializers import Ones
from keras.src.initializers.constant_initializers import STFTInitializer
from keras.src.initializers.constant_initializers import Zeros
from keras.src.initializers.initializer import Initializer
from keras.src.initializers.random_initializers import GlorotNormal
Expand All @@ -25,6 +26,7 @@
Constant,
Identity,
Ones,
STFTInitializer,
Zeros,
GlorotNormal,
GlorotUniform,
Expand Down
118 changes: 118 additions & 0 deletions keras/src/initializers/constant_initializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from keras.src.backend import standardize_dtype
from keras.src.initializers.initializer import Initializer
from keras.src.saving import serialization_lib
from keras.src.utils.module_utils import scipy


@keras_export(["keras.initializers.Constant", "keras.initializers.constant"])
Expand Down Expand Up @@ -151,3 +152,120 @@ def __call__(self, shape, dtype=None):
)
dtype = standardize_dtype(dtype)
return self.gain * ops.eye(*shape, dtype=dtype)


@keras_export(["keras.initializers.STFTInitializer"])
class STFTInitializer(Initializer):
"""Initializer of Conv kernels for Short-term Fourier Transformation (STFT).

Since the formula involves complex numbers, this class compute either the
real or the imaginary components of the final output.

Additionally, this initializer supports windowing functions across the time
dimension as commonly used in STFT. Windowing functions from the module
`scipy.signal.windows` are supported, including the common `hann` and
`hamming` windowing functions. This layer supports periodic windows and
scaling-based normalization.

This is primarly intended for use in the `STFTSpectrogram` layer.

Examples:

>>> # Standalone usage:
>>> initializer = STFTInitializer("real", "hann", "density", False)
>>> values = initializer(shape=(128, 1, 513))

Args:
side: String, `"real"` or `"imag"` deciding if the kernel will compute
the real side or the imaginary side of the output.
window: String for the name of the windowing function in the
`scipy.signal.windows` module, or array_like for the window values,
or `None` for no windowing.
scaling: String, `"density"` or `"spectrum"` for scaling of the window
for normalization, either L2 or L1 normalization.
`None` for no scaling.
periodic: Boolean, if True, the window function will be treated as
periodic. Defaults to `False`.
"""

def __init__(self, side, window="hann", scaling="density", periodic=False):
if side not in ["real", "imag"]:
raise ValueError(f"side should be 'real' or 'imag', not {side}")
if isinstance(window, str):
# throws an exception for invalid window function
scipy.signal.get_window(window, 1)
if scaling is not None and scaling not in ["density", "spectrum"]:
raise ValueError(
"Scaling is invalid, it must be `None`, 'density' "
f"or 'spectrum'. Received scaling={scaling}"
)
self.side = side
self.window = window
self.scaling = scaling
self.periodic = periodic

def __call__(self, shape, dtype=None):
"""Returns a tensor object initialized as specified by the initializer.

The shape is assumed to be `(T, 1, F // 2 + 1)`, where `T` is the size
of the given window, and `F` is the number of frequency bands. Only half
the frequency bands are used, which is a common practice in STFT,
because the second half are the conjugates of the first half in
a reversed order.

Args:
shape: Shape of the tensor.
dtype: Optional dtype of the tensor. Only numeric or boolean dtypes
are supported. If not specified, `keras.backend.floatx()`
is used, which default to `float32` unless you configured it
otherwise (via `keras.backend.set_floatx(float_dtype)`).
"""
dtype = standardize_dtype(dtype)
frame_length, input_channels, fft_length = shape

win = None
scaling = 1
if self.window is not None:
win = self.window
if isinstance(win, str):
# Using SciPy since it provides more windowing functions,
# easier to be compatible with multiple backends.
win = scipy.signal.get_window(win, frame_length, self.periodic)
win = ops.convert_to_tensor(win, dtype=dtype)
if len(win.shape) != 1 or win.shape[-1] != frame_length:
raise ValueError(
"The shape of `window` must be equal to [frame_length]."
f"Received: window shape={win.shape}"
)
win = ops.reshape(win, [frame_length, 1, 1])
if self.scaling == "density":
scaling = ops.sqrt(ops.sum(ops.square(win)))
elif self.scaling == "spectrum":
scaling = ops.sum(ops.abs(win))

_fft_length = (fft_length - 1) * 2
freq = (
ops.reshape(ops.arange(fft_length, dtype=dtype), (1, 1, fft_length))
/ _fft_length
)
time = ops.reshape(
ops.arange(frame_length, dtype=dtype), (frame_length, 1, 1)
)
args = -2 * time * freq * ops.arccos(ops.cast(-1, dtype))

if self.side == "real":
kernel = ops.cast(ops.cos(args), dtype)
else:
kernel = ops.cast(ops.sin(args), dtype)

if win is not None:
kernel = kernel * win / scaling
return kernel

def get_config(self):
return {
"side": self.side,
"window": self.window,
"periodic": self.periodic,
"scaling": self.scaling,
}
63 changes: 63 additions & 0 deletions keras/src/initializers/constant_initializers_test.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import numpy as np
import scipy.signal

from keras.src import backend
from keras.src import initializers
Expand Down Expand Up @@ -67,3 +68,65 @@ def test_identity_initializer(self):
self.assertAllClose(np_values, np.eye(*shape) * gain)

self.run_class_serialization_test(initializer)

def test_stft_initializer(self):
shape = (256, 1, 513)
time_range = np.arange(256).reshape((-1, 1, 1))
freq_range = (np.arange(513) / 1024.0).reshape((1, 1, -1))
pi = np.arccos(np.float64(-1))
args = -2 * pi * time_range * freq_range

tol_kwargs = {}
if backend.backend() == "jax":
# TODO(mostafa-mahmoud): investigate the cases
# of non-small error in jax and torch
tol_kwargs = {"atol": 1e-4, "rtol": 1e-6}

initializer = initializers.STFTInitializer("real", None)
values = backend.convert_to_numpy(initializer(shape))
self.assertAllClose(np.cos(args), values, atol=1e-4)
self.run_class_serialization_test(initializer)

initializer = initializers.STFTInitializer(
"real",
"hamming",
None,
True,
)
window = scipy.signal.windows.get_window("hamming", 256, True)
window = window.astype("float64").reshape((-1, 1, 1))
values = backend.convert_to_numpy(initializer(shape, "float64"))
self.assertAllClose(np.cos(args) * window, values, **tol_kwargs)
self.run_class_serialization_test(initializer)

initializer = initializers.STFTInitializer(
"imag",
"tukey",
"density",
False,
)
window = scipy.signal.windows.get_window("tukey", 256, False)
window = window.astype("float64").reshape((-1, 1, 1))
window = window / np.sqrt(np.sum(window**2))
values = backend.convert_to_numpy(initializer(shape, "float64"))
self.assertAllClose(np.sin(args) * window, values, **tol_kwargs)
self.run_class_serialization_test(initializer)

initializer = initializers.STFTInitializer(
"imag",
list(range(1, 257)),
"spectrum",
)
window = np.arange(1, 257)
window = window.astype("float64").reshape((-1, 1, 1))
window = window / np.sum(window)
values = backend.convert_to_numpy(initializer(shape, "float64"))
self.assertAllClose(np.sin(args) * window, values, **tol_kwargs)
self.run_class_serialization_test(initializer)

with self.assertRaises(ValueError):
initializers.STFTInitializer("imaginary")
with self.assertRaises(ValueError):
initializers.STFTInitializer("real", scaling="l2")
with self.assertRaises(ValueError):
initializers.STFTInitializer("real", window="unknown")
1 change: 1 addition & 0 deletions keras/src/layers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@
from keras.src.layers.preprocessing.normalization import Normalization
from keras.src.layers.preprocessing.pipeline import Pipeline
from keras.src.layers.preprocessing.rescaling import Rescaling
from keras.src.layers.preprocessing.stft_spectrogram import STFTSpectrogram
from keras.src.layers.preprocessing.string_lookup import StringLookup
from keras.src.layers.preprocessing.text_vectorization import TextVectorization
from keras.src.layers.regularization.activity_regularization import (
Expand Down
Loading