Skip to content

Commit

Permalink
Ensure consistent use of Chebyshev domain.
Browse files Browse the repository at this point in the history
  • Loading branch information
Mark Hale committed Jun 26, 2024
1 parent 2851d24 commit 2cd0b4d
Show file tree
Hide file tree
Showing 4 changed files with 123 additions and 9 deletions.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

setup(
name="sgw_tools",
version="2.3.1",
version="2.3.2",
author="Mark Hale",
license="MIT",
description="Spectral graph wavelet tools",
Expand Down
108 changes: 108 additions & 0 deletions sgw_tools/approximations.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
import numpy as np
from scipy import sparse
from pygsp import utils


@utils.filterbank_handler
def compute_cheby_coeff(f, m=30, N=None, domain=None, *args, **kwargs):
r"""
Compute Chebyshev coefficients for a Filterbank.
Parameters
----------
f : Filter
Filterbank with at least 1 filter
m : int
Maximum order of Chebyshev coeff to compute
(default = 30)
N : int
Grid order used to compute quadrature
(default = m + 1)
i : int
Index of the Filterbank element to compute
(default = 0)
Returns
-------
c : ndarray
Matrix of Chebyshev coefficients
"""
G = f.G
i = kwargs.pop('i', 0)

if not N:
N = m + 1

a_arange = domain if domain else [0, G.lmax]

a1 = (a_arange[1] - a_arange[0]) / 2
a2 = (a_arange[1] + a_arange[0]) / 2
c = np.zeros(m + 1)

tmpN = np.arange(N)
num = np.cos(np.pi * (tmpN + 0.5) / N)
for o in range(m + 1):
c[o] = 2. / N * np.dot(f._kernels[i](a1 * num + a2),
np.cos(np.pi * o * (tmpN + 0.5) / N))

return c


def cheby_op(G, c, signal, domain=None, **kwargs):
r"""
Chebyshev polynomial of graph Laplacian applied to vector.
Parameters
----------
G : Graph
c : ndarray or list of ndarrays
Chebyshev coefficients for a Filter or a Filterbank
signal : ndarray
Signal to filter
Returns
-------
r : ndarray
Result of the filtering
"""
# Handle if we do not have a list of filters but only a simple filter in cheby_coeff.
if not isinstance(c, np.ndarray):
c = np.array(c)

c = np.atleast_2d(c)
Nscales, M = c.shape

if M < 2:
raise TypeError("The coefficients have an invalid shape")

# thanks to that, we can also have 1d signal.
try:
Nv = np.shape(signal)[1]
r = np.zeros((G.N * Nscales, Nv))
except IndexError:
r = np.zeros((G.N * Nscales))

a_arange = domain if domain else [0, G.lmax]

a1 = float(a_arange[1] - a_arange[0]) / 2.
a2 = float(a_arange[1] + a_arange[0]) / 2.

twf_old = signal
twf_cur = (G.L.dot(signal) - a2 * signal) / a1

tmpN = np.arange(G.N, dtype=int)
for i in range(Nscales):
r[tmpN + G.N*i] = 0.5 * c[i, 0] * twf_old + c[i, 1] * twf_cur

factor = 2/a1 * (G.L - a2 * sparse.eye(G.N))
for k in range(2, M):
twf_new = factor.dot(twf_cur) - twf_old
for i in range(Nscales):
r[tmpN + G.N*i] += c[i, k] * twf_new

twf_old = twf_cur
twf_cur = twf_new

return r
10 changes: 7 additions & 3 deletions sgw_tools/filters.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import numpy as np
import pygsp as gsp
from . import util
from . import approximations


class GWHeat(gsp.filters.Filter):
Expand Down Expand Up @@ -89,6 +90,8 @@ def __init__(self, G, coeff_bank, domain, coeff_normalization="pygsp"):
else:
raise ValueError(f"Invalid coefficient normalization: {coeff_normalization}")

self.domain = domain

kernels = [
np.polynomial.Chebyshev(coeffs, domain=domain) for coeffs in kernel_coeffs
]
Expand Down Expand Up @@ -139,7 +142,7 @@ def filter(self, s, method='chebyshev', order=30):

if n_features_in == 1: # Analysis.
s = s.squeeze(axis=2)
s = gsp.filters.approximations.cheby_op(self.G, c, s)
s = approximations.cheby_op(self.G, c, s, domain=self.domain)
s = s.reshape((self.G.N, n_features_out, n_signals), order='F')
s = s.swapaxes(1, 2)

Expand All @@ -150,9 +153,10 @@ def filter(self, s, method='chebyshev', order=30):
s = np.zeros((self.G.N, n_signals))
tmpN = np.arange(self.G.N, dtype=int)
for i in range(n_features_in):
s += gsp.filters.approximations.cheby_op(self.G,
s += approximations.cheby_op(self.G,
c[i],
s_in[i * self.G.N + tmpN])
s_in[i * self.G.N + tmpN],
domain=self.domain)
s = np.expand_dims(s, 2)

else:
Expand Down
12 changes: 7 additions & 5 deletions tests/test_sgw.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,22 +52,24 @@ def test_tig(self):
def test_chebyshev_filter(self):
G = gsp.graphs.Sensor(100, lap_type="normalized", seed=5)
func = lambda x: np.exp(-x**2)
signal = np.ones(G.N)
g = sgw.CustomFilter(G, func)

signal = np.ones(G.N)
order = 20
expected = g.filter(signal, order=order)

domain = [0, 2]
func_e = func(G.e)

gsp_coeffs = gsp.filters.compute_cheby_coeff(g, m=order)
gsp_g = sgw.ChebyshevFilter(G, gsp_coeffs, [0, G.lmax], "pygsp")
gsp_coeffs = sgw.approximations.compute_cheby_coeff(g, m=order, domain=domain)
gsp_g = sgw.ChebyshevFilter(G, gsp_coeffs, domain, "pygsp")
np.testing.assert_allclose(gsp_g.evaluate(G.e).squeeze(), func_e, err_msg="pygsp evaluate")
gsp_actual = gsp_g.filter(signal, order=order)
np.testing.assert_allclose(gsp_actual, expected, err_msg="pygsp coeffs")

domain = [0, 2]
np_cheby = np.polynomial.Chebyshev.fit(G.e, func(G.e), deg=order, domain=domain)
np_g = sgw.ChebyshevFilter(G, np_cheby.coef, domain, "numpy")
np.testing.assert_allclose(np_g.evaluate(G.e).squeeze(), func_e, err_msg="numpy evaluate")
np_actual = np_g.filter(signal, order=order)
np.testing.assert_allclose(np_actual, expected, err_msg="numpy coeffs", rtol=0.1)
np.testing.assert_allclose(np_actual, expected, err_msg="numpy coeffs")

0 comments on commit 2cd0b4d

Please sign in to comment.