Skip to content

Commit d83619e

Browse files
john-halloranJohn Halloran
andauthored
refactor: apply_interpolation_matrix() is now compute_stretched_components() (#171)
* fix: guard against zero/NaN stretches in apply_interpolation_matrix * refactor: use broadcasting instead of np.tile in apply_interpolation_matrix * refactor: flatten from a single buffer in apply_interpolation_matrix() * refactor: drastically simplify indexing in apply_interpolation_matrix() and remove legacy MATLAB terminology * style: rename apply_interpolation_matrix() to compute_stretched_components() --------- Co-authored-by: John Halloran <jhalloran@oxy.edu>
1 parent 3c54016 commit d83619e

File tree

1 file changed

+76
-71
lines changed

1 file changed

+76
-71
lines changed

src/diffpy/snmf/snmf_class.py

Lines changed: 76 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -370,86 +370,89 @@ def get_objective_function(self, residuals=None, stretch=None):
370370
function = residual_term + regularization_term + sparsity_term
371371
return function
372372

373-
def apply_interpolation_matrix(self, components=None, weights=None, stretch=None):
373+
def compute_stretched_components(self, components=None, weights=None, stretch=None):
374374
"""
375-
Applies an interpolation-based transformation to the 'components' using `stretch`,
376-
weighted by `weights`. Optionally computes first (`d_stretched_components`) and
377-
second (`dd_stretched_components`) derivatives.
375+
Interpolates each component along its sample axis according to per-(component, signal)
376+
stretch factors, then applies per-(component, signal) weights. Also computes the
377+
first and second derivatives with respect to stretch. Left and right, respectively,
378+
refer to the sample prior to and subsequent to the interpolated sample's position.
379+
380+
Inputs
381+
------
382+
components : array, shape (signal_len, n_components)
383+
Each column is a component with signal_len samples.
384+
weights : array, shape (n_components, n_signals)
385+
Per-(component, signal) weights.
386+
stretch : array, shape (n_components, n_signals)
387+
Per-(component, signal) stretch factors.
388+
389+
Outputs
390+
-------
391+
stretched_components : array, shape (signal_len, n_components * n_signals)
392+
Interpolated and weighted components.
393+
d_stretched_components : array, shape (signal_len, n_components * n_signals)
394+
First derivatives with respect to stretch.
395+
dd_stretched_components : array, shape (signal_len, n_components * n_signals)
396+
Second derivatives with respect to stretch.
378397
"""
379398

399+
# --- Defaults ---
380400
if components is None:
381401
components = self.components_
382402
if weights is None:
383403
weights = self.weights_
384404
if stretch is None:
385405
stretch = self.stretch_
386406

387-
# Compute scaled indices
388-
stretch_flat = stretch.reshape(1, self.n_signals * self.n_components) ** -1
389-
stretch_tiled = np.tile(stretch_flat, (self.signal_length, 1))
390-
391-
# Compute `fractional_indices`
392-
fractional_indices = (
393-
np.tile(np.arange(self.signal_length)[:, None], (1, self.n_signals * self.n_components))
394-
* stretch_tiled
407+
# Dimensions
408+
signal_len = components.shape[0] # number of samples
409+
n_components = components.shape[1] # number of components
410+
n_signals = weights.shape[1] # number of signals
411+
412+
# Guard stretches
413+
eps = 1e-8
414+
stretch = np.clip(stretch, eps, None)
415+
stretch_inv = 1.0 / stretch
416+
417+
# Apply stretching to the original sample indices, represented as a "time-stretch"
418+
t = np.arange(signal_len, dtype=float)[:, None, None] * stretch_inv[None, :, :]
419+
# has shape (signal_len, n_components, n_signals)
420+
421+
# For each stretched coordinate, find its prior integer (original) index and their difference
422+
i0 = np.floor(t).astype(np.int64) # prior original index
423+
alpha = t - i0.astype(float) # fractional distance between left/right
424+
425+
# Clip indices to valid range (0, signal_len - 1) to maintain original size
426+
max_idx = signal_len - 1
427+
i0 = np.clip(i0, 0, max_idx)
428+
i1 = np.clip(i0 + 1, 0, max_idx)
429+
430+
# Gather sample values
431+
comps_3d = components[:, :, None] # expand components by a dimension for broadcasting across n_signals
432+
c0 = np.take_along_axis(comps_3d, i0, axis=0) # left sample values
433+
c1 = np.take_along_axis(comps_3d, i1, axis=0) # right sample values
434+
435+
# Linear interpolation to determine stretched sample values
436+
interp = c0 * (1.0 - alpha) + c1 * alpha
437+
interp_weighted = interp * weights[None, :, :]
438+
439+
# Derivatives
440+
di = -t * stretch_inv[None, :, :] # first-derivative coefficient
441+
ddi = -di * stretch_inv[None, :, :] * 2.0 # second-derivative coefficient
442+
443+
d_unweighted = c0 * (-di) + c1 * di
444+
dd_unweighted = c0 * (-ddi) + c1 * ddi
445+
446+
d_weighted = d_unweighted * weights[None, :, :]
447+
dd_weighted = dd_unweighted * weights[None, :, :]
448+
449+
# Flatten back to expected shape (signal_len, n_components * n_signals)
450+
return (
451+
interp_weighted.reshape(signal_len, n_components * n_signals),
452+
d_weighted.reshape(signal_len, n_components * n_signals),
453+
dd_weighted.reshape(signal_len, n_components * n_signals),
395454
)
396455

397-
# Weighting matrix
398-
weights_flat = weights.reshape(1, self.n_signals * self.n_components)
399-
weights_tiled = np.tile(weights_flat, (self.signal_length, 1))
400-
401-
# Bias for indexing into reshaped components
402-
# TODO break this up or describe what it does better
403-
bias = np.kron(
404-
np.arange(self.n_components) * (self.signal_length + 1),
405-
np.ones((self.signal_length, self.n_signals), dtype=int),
406-
).reshape(self.signal_length, self.n_components * self.n_signals)
407-
408-
# Handle boundary conditions for interpolation
409-
components_bounded = np.vstack(
410-
[components, components[-1, :]]
411-
) # Duplicate last row (like MATLAB, not sure why)
412-
413-
# Compute floor indices
414-
floor_indices = np.floor(fractional_indices).astype(int)
415-
416-
floor_indices_1 = np.minimum(floor_indices + 1, self.signal_length)
417-
floor_indices_2 = np.minimum(floor_indices_1 + 1, self.signal_length)
418-
419-
# Compute fractional part
420-
fractional_floor_indices = fractional_indices - floor_indices
421-
422-
# Compute offset indices
423-
offset_indices_1 = floor_indices_1 + bias
424-
offset_indices_2 = floor_indices_2 + bias
425-
426-
# Extract values
427-
# Note: this "-1" corrects an off-by-one error that may have originated in an earlier line
428-
comp_values_1 = components_bounded.flatten(order="F")[(offset_indices_1 - 1).ravel(order="F")].reshape(
429-
self.signal_length, self.n_components * self.n_signals, order="F"
430-
) # order = F uses FORTRAN, column major order
431-
comp_values_2 = components_bounded.flatten(order="F")[(offset_indices_2 - 1).ravel(order="F")].reshape(
432-
self.signal_length, self.n_components * self.n_signals, order="F"
433-
)
434-
435-
# Interpolation
436-
unweighted_stretched_comps = (
437-
comp_values_1 * (1 - fractional_floor_indices) + comp_values_2 * fractional_floor_indices
438-
)
439-
stretched_components = unweighted_stretched_comps * weights_tiled # Apply weighting
440-
441-
# Compute first derivative
442-
di = -fractional_indices * stretch_tiled
443-
d_comps_unweighted = comp_values_1 * (-di) + comp_values_2 * di
444-
d_stretched_components = d_comps_unweighted * weights_tiled
445-
446-
# Compute second derivative
447-
ddi = -di * stretch_tiled * 2
448-
dd_comps_unweighted = comp_values_1 * (-ddi) + comp_values_2 * ddi
449-
dd_stretched_components = dd_comps_unweighted * weights_tiled
450-
451-
return stretched_components, d_stretched_components, dd_stretched_components
452-
453456
def apply_transformation_matrix(self, stretch=None, weights=None, residuals=None):
454457
"""
455458
Computes the transformation matrix `stretch_transformed` for residuals,
@@ -560,7 +563,7 @@ def update_components(self):
560563
Updates `components` using gradient-based optimization with adaptive step size.
561564
"""
562565
# Compute stretched components using the interpolation function
563-
stretched_components, _, _ = self.apply_interpolation_matrix() # Discard the derivatives
566+
stretched_components, _, _ = self.compute_stretched_components() # Discard the derivatives
564567
# Compute reshaped_stretched_components and component_residuals
565568
intermediate_reshaped = stretched_components.flatten(order="F").reshape(
566569
(self.signal_length * self.n_signals, self.n_components), order="F"
@@ -648,7 +651,9 @@ def regularize_function(self, stretch=None):
648651
if stretch is None:
649652
stretch = self.stretch_
650653

651-
stretched_components, d_stretch_comps, dd_stretch_comps = self.apply_interpolation_matrix(stretch=stretch)
654+
stretched_components, d_stretch_comps, dd_stretch_comps = self.compute_stretched_components(
655+
stretch=stretch
656+
)
652657
intermediate = stretched_components.flatten(order="F").reshape(
653658
(self.signal_length * self.n_signals, self.n_components), order="F"
654659
)
@@ -751,8 +756,8 @@ def reconstruct_matrix(components, weights, stretch):
751756
"""
752757

753758
signal_len = components.shape[0]
754-
n_signals = weights.shape[1]
755759
n_components = components.shape[1]
760+
n_signals = weights.shape[1]
756761

757762
reconstructed_matrix = np.zeros((signal_len, n_signals))
758763
sample_indices = np.arange(signal_len)

0 commit comments

Comments
 (0)