@@ -608,24 +608,29 @@ def update_components(self):
608
608
609
609
def update_weights (self ):
610
610
"""
611
- Updates weights using matrix operations, solving a quadratic program to do so.
611
+ Updates weights by building the stretched component matrix `stretched_comps` with np.interp
612
+ and solving a quadratic program for each signal.
612
613
"""
613
614
614
- signal_length = self .signal_length
615
- n_signals = self .n_signals
616
-
617
- for m in range (n_signals ):
618
- t = np .zeros ((signal_length , self .n_components ))
619
-
620
- # Populate t using apply_interpolation
621
- for k in range (self .n_components ):
622
- t [:, k ] = apply_interpolation (self .stretch [k , m ], self .components [:, k ]).squeeze ()
623
-
624
- # Solve quadratic problem for y
625
- y = self .solve_quadratic_program (t = t , m = m )
615
+ sample_indices = np .arange (self .signal_length )
616
+ for signal in range (self .n_signals ):
617
+ # Stretch factors for this signal across components:
618
+ this_stretch = self .stretch [:, signal ]
619
+ # Build stretched_comps[:, k] by interpolating component at frac. pos. index / this_stretch[comp]
620
+ stretched_comps = np .empty ((self .signal_length , self .n_components ), dtype = self .components .dtype )
621
+ for comp in range (self .n_components ):
622
+ pos = sample_indices / this_stretch [comp ]
623
+ stretched_comps [:, comp ] = np .interp (
624
+ pos ,
625
+ sample_indices ,
626
+ self .components [:, comp ],
627
+ left = self .components [0 , comp ],
628
+ right = self .components [- 1 , comp ],
629
+ )
626
630
627
- # Update Y
628
- self .weights [:, m ] = y
631
+ # Solve quadratic problem for a given signal and update its weight
632
+ new_weight = self .solve_quadratic_program (t = stretched_comps , m = signal )
633
+ self .weights [:, signal ] = new_weight
629
634
630
635
def regularize_function (self , stretch = None ):
631
636
if stretch is None :
@@ -712,37 +717,3 @@ def cubic_largest_real_root(p, q):
712
717
y = np .max (real_roots , axis = 0 ) * (delta < 0 ) # Keep only real roots when delta < 0
713
718
714
719
return y
715
-
716
-
717
- def apply_interpolation (a , x ):
718
- """
719
- Applies an interpolation-based transformation to `x` based on scaling `a`.
720
- """
721
- x_len = len (x )
722
-
723
- # Ensure `a` is an array and reshape for broadcasting
724
- a = np .atleast_1d (np .asarray (a )) # Ensures a is at least 1D
725
-
726
- # Compute fractional indices, broadcasting over `a`
727
- fractional_indices = np .arange (x_len )[:, None ] / a # Shape (N, M)
728
-
729
- integer_indices = np .floor (fractional_indices ).astype (int ) # Integer part (still (N, M))
730
- valid_mask = integer_indices < (x_len - 1 ) # Ensure indices are within bounds
731
-
732
- # Apply valid_mask to keep correct indices
733
- idx_int = np .where (valid_mask , integer_indices , x_len - 2 ) # Prevent out-of-bounds indexing (previously "I")
734
- idx_frac = np .where (valid_mask , fractional_indices , integer_indices ) # Keep aligned (previously "i")
735
-
736
- # Ensure x is a 1D array
737
- x = np .asarray (x ).ravel ()
738
-
739
- # Compute interpolated_x (linear interpolation)
740
- interpolated_x = x [idx_int ] * (1 - idx_frac + idx_int ) + x [np .minimum (idx_int + 1 , x_len - 1 )] * (
741
- idx_frac - idx_int
742
- )
743
-
744
- # Fill the tail with the last valid value
745
- intr_x_tail = np .full ((x_len - len (idx_int ), interpolated_x .shape [1 ]), interpolated_x [- 1 , :])
746
- interpolated_x = np .vstack ([interpolated_x , intr_x_tail ])
747
-
748
- return interpolated_x
0 commit comments