@@ -370,86 +370,89 @@ def get_objective_function(self, residuals=None, stretch=None):
370
370
function = residual_term + regularization_term + sparsity_term
371
371
return function
372
372
373
- def apply_interpolation_matrix (self , components = None , weights = None , stretch = None ):
373
+ def compute_stretched_components (self , components = None , weights = None , stretch = None ):
374
374
"""
375
- Applies an interpolation-based transformation to the 'components' using `stretch`,
376
- weighted by `weights`. Optionally computes first (`d_stretched_components`) and
377
- second (`dd_stretched_components`) derivatives.
375
+ Interpolates each component along its sample axis according to per-(component, signal)
376
+ stretch factors, then applies per-(component, signal) weights. Also computes the
377
+ first and second derivatives with respect to stretch. Left and right, respectively,
378
+ refer to the sample prior to and subsequent to the interpolated sample's position.
379
+
380
+ Inputs
381
+ ------
382
+ components : array, shape (signal_len, n_components)
383
+ Each column is a component with signal_len samples.
384
+ weights : array, shape (n_components, n_signals)
385
+ Per-(component, signal) weights.
386
+ stretch : array, shape (n_components, n_signals)
387
+ Per-(component, signal) stretch factors.
388
+
389
+ Outputs
390
+ -------
391
+ stretched_components : array, shape (signal_len, n_components * n_signals)
392
+ Interpolated and weighted components.
393
+ d_stretched_components : array, shape (signal_len, n_components * n_signals)
394
+ First derivatives with respect to stretch.
395
+ dd_stretched_components : array, shape (signal_len, n_components * n_signals)
396
+ Second derivatives with respect to stretch.
378
397
"""
379
398
399
+ # --- Defaults ---
380
400
if components is None :
381
401
components = self .components_
382
402
if weights is None :
383
403
weights = self .weights_
384
404
if stretch is None :
385
405
stretch = self .stretch_
386
406
387
- # Compute scaled indices
388
- stretch_flat = stretch .reshape (1 , self .n_signals * self .n_components ) ** - 1
389
- stretch_tiled = np .tile (stretch_flat , (self .signal_length , 1 ))
390
-
391
- # Compute `fractional_indices`
392
- fractional_indices = (
393
- np .tile (np .arange (self .signal_length )[:, None ], (1 , self .n_signals * self .n_components ))
394
- * stretch_tiled
407
+ # Dimensions
408
+ signal_len = components .shape [0 ] # number of samples
409
+ n_components = components .shape [1 ] # number of components
410
+ n_signals = weights .shape [1 ] # number of signals
411
+
412
+ # Guard stretches
413
+ eps = 1e-8
414
+ stretch = np .clip (stretch , eps , None )
415
+ stretch_inv = 1.0 / stretch
416
+
417
+ # Apply stretching to the original sample indices, represented as a "time-stretch"
418
+ t = np .arange (signal_len , dtype = float )[:, None , None ] * stretch_inv [None , :, :]
419
+ # has shape (signal_len, n_components, n_signals)
420
+
421
+ # For each stretched coordinate, find its prior integer (original) index and their difference
422
+ i0 = np .floor (t ).astype (np .int64 ) # prior original index
423
+ alpha = t - i0 .astype (float ) # fractional distance between left/right
424
+
425
+ # Clip indices to valid range (0, signal_len - 1) to maintain original size
426
+ max_idx = signal_len - 1
427
+ i0 = np .clip (i0 , 0 , max_idx )
428
+ i1 = np .clip (i0 + 1 , 0 , max_idx )
429
+
430
+ # Gather sample values
431
+ comps_3d = components [:, :, None ] # expand components by a dimension for broadcasting across n_signals
432
+ c0 = np .take_along_axis (comps_3d , i0 , axis = 0 ) # left sample values
433
+ c1 = np .take_along_axis (comps_3d , i1 , axis = 0 ) # right sample values
434
+
435
+ # Linear interpolation to determine stretched sample values
436
+ interp = c0 * (1.0 - alpha ) + c1 * alpha
437
+ interp_weighted = interp * weights [None , :, :]
438
+
439
+ # Derivatives
440
+ di = - t * stretch_inv [None , :, :] # first-derivative coefficient
441
+ ddi = - di * stretch_inv [None , :, :] * 2.0 # second-derivative coefficient
442
+
443
+ d_unweighted = c0 * (- di ) + c1 * di
444
+ dd_unweighted = c0 * (- ddi ) + c1 * ddi
445
+
446
+ d_weighted = d_unweighted * weights [None , :, :]
447
+ dd_weighted = dd_unweighted * weights [None , :, :]
448
+
449
+ # Flatten back to expected shape (signal_len, n_components * n_signals)
450
+ return (
451
+ interp_weighted .reshape (signal_len , n_components * n_signals ),
452
+ d_weighted .reshape (signal_len , n_components * n_signals ),
453
+ dd_weighted .reshape (signal_len , n_components * n_signals ),
395
454
)
396
455
397
- # Weighting matrix
398
- weights_flat = weights .reshape (1 , self .n_signals * self .n_components )
399
- weights_tiled = np .tile (weights_flat , (self .signal_length , 1 ))
400
-
401
- # Bias for indexing into reshaped components
402
- # TODO break this up or describe what it does better
403
- bias = np .kron (
404
- np .arange (self .n_components ) * (self .signal_length + 1 ),
405
- np .ones ((self .signal_length , self .n_signals ), dtype = int ),
406
- ).reshape (self .signal_length , self .n_components * self .n_signals )
407
-
408
- # Handle boundary conditions for interpolation
409
- components_bounded = np .vstack (
410
- [components , components [- 1 , :]]
411
- ) # Duplicate last row (like MATLAB, not sure why)
412
-
413
- # Compute floor indices
414
- floor_indices = np .floor (fractional_indices ).astype (int )
415
-
416
- floor_indices_1 = np .minimum (floor_indices + 1 , self .signal_length )
417
- floor_indices_2 = np .minimum (floor_indices_1 + 1 , self .signal_length )
418
-
419
- # Compute fractional part
420
- fractional_floor_indices = fractional_indices - floor_indices
421
-
422
- # Compute offset indices
423
- offset_indices_1 = floor_indices_1 + bias
424
- offset_indices_2 = floor_indices_2 + bias
425
-
426
- # Extract values
427
- # Note: this "-1" corrects an off-by-one error that may have originated in an earlier line
428
- comp_values_1 = components_bounded .flatten (order = "F" )[(offset_indices_1 - 1 ).ravel (order = "F" )].reshape (
429
- self .signal_length , self .n_components * self .n_signals , order = "F"
430
- ) # order = F uses FORTRAN, column major order
431
- comp_values_2 = components_bounded .flatten (order = "F" )[(offset_indices_2 - 1 ).ravel (order = "F" )].reshape (
432
- self .signal_length , self .n_components * self .n_signals , order = "F"
433
- )
434
-
435
- # Interpolation
436
- unweighted_stretched_comps = (
437
- comp_values_1 * (1 - fractional_floor_indices ) + comp_values_2 * fractional_floor_indices
438
- )
439
- stretched_components = unweighted_stretched_comps * weights_tiled # Apply weighting
440
-
441
- # Compute first derivative
442
- di = - fractional_indices * stretch_tiled
443
- d_comps_unweighted = comp_values_1 * (- di ) + comp_values_2 * di
444
- d_stretched_components = d_comps_unweighted * weights_tiled
445
-
446
- # Compute second derivative
447
- ddi = - di * stretch_tiled * 2
448
- dd_comps_unweighted = comp_values_1 * (- ddi ) + comp_values_2 * ddi
449
- dd_stretched_components = dd_comps_unweighted * weights_tiled
450
-
451
- return stretched_components , d_stretched_components , dd_stretched_components
452
-
453
456
def apply_transformation_matrix (self , stretch = None , weights = None , residuals = None ):
454
457
"""
455
458
Computes the transformation matrix `stretch_transformed` for residuals,
@@ -560,7 +563,7 @@ def update_components(self):
560
563
Updates `components` using gradient-based optimization with adaptive step size.
561
564
"""
562
565
# Compute stretched components using the interpolation function
563
- stretched_components , _ , _ = self .apply_interpolation_matrix () # Discard the derivatives
566
+ stretched_components , _ , _ = self .compute_stretched_components () # Discard the derivatives
564
567
# Compute reshaped_stretched_components and component_residuals
565
568
intermediate_reshaped = stretched_components .flatten (order = "F" ).reshape (
566
569
(self .signal_length * self .n_signals , self .n_components ), order = "F"
@@ -648,7 +651,9 @@ def regularize_function(self, stretch=None):
648
651
if stretch is None :
649
652
stretch = self .stretch_
650
653
651
- stretched_components , d_stretch_comps , dd_stretch_comps = self .apply_interpolation_matrix (stretch = stretch )
654
+ stretched_components , d_stretch_comps , dd_stretch_comps = self .compute_stretched_components (
655
+ stretch = stretch
656
+ )
652
657
intermediate = stretched_components .flatten (order = "F" ).reshape (
653
658
(self .signal_length * self .n_signals , self .n_components ), order = "F"
654
659
)
@@ -751,8 +756,8 @@ def reconstruct_matrix(components, weights, stretch):
751
756
"""
752
757
753
758
signal_len = components .shape [0 ]
754
- n_signals = weights .shape [1 ]
755
759
n_components = components .shape [1 ]
760
+ n_signals = weights .shape [1 ]
756
761
757
762
reconstructed_matrix = np .zeros ((signal_len , n_signals ))
758
763
sample_indices = np .arange (signal_len )
0 commit comments