@@ -636,19 +636,19 @@ def regularize_function(self, stretch=None):
636
636
if stretch is None :
637
637
stretch = self .stretch
638
638
639
- K = self .n_components
640
- M = self .n_signals
641
- N = self .signal_length
642
-
643
639
stretched_components , d_stretch_comps , dd_stretch_comps = self .apply_interpolation_matrix (stretch = stretch )
644
- intermediate = stretched_components .flatten (order = "F" ).reshape ((N * M , K ), order = "F" )
645
- residuals = intermediate .sum (axis = 1 ).reshape ((N , M ), order = "F" ) - self .source_matrix
640
+ intermediate = stretched_components .flatten (order = "F" ).reshape (
641
+ (self .signal_length * self .n_signals , self .n_components ), order = "F"
642
+ )
643
+ residuals = (
644
+ intermediate .sum (axis = 1 ).reshape ((self .signal_length , self .n_signals ), order = "F" ) - self .source_matrix
645
+ )
646
646
647
647
fun = self .get_objective_function (residuals , stretch )
648
648
649
- tiled_res = np .tile (residuals , (1 , K ))
649
+ tiled_res = np .tile (residuals , (1 , self . n_components ))
650
650
grad_flat = np .sum (d_stretch_comps * tiled_res , axis = 0 )
651
- gra = grad_flat .reshape ((M , K ), order = "F" ).T
651
+ gra = grad_flat .reshape ((self . n_signals , self . n_components ), order = "F" ).T
652
652
gra += self .rho * stretch @ (self ._spline_smooth_operator .T @ self ._spline_smooth_operator )
653
653
654
654
# Hessian would go here
@@ -657,10 +657,10 @@ def regularize_function(self, stretch=None):
657
657
658
658
def update_stretch (self ):
659
659
"""
660
- Updates matrix A using constrained optimization (equivalent to fmincon in MATLAB).
660
+ Updates stretching matrix using constrained optimization (equivalent to fmincon in MATLAB).
661
661
"""
662
662
663
- # Flatten A for compatibility with the optimizer (since SciPy expects 1D input)
663
+ # Flatten stretch for compatibility with the optimizer (since SciPy expects 1D input)
664
664
stretch_flat_initial = self .stretch .flatten ()
665
665
666
666
# Define the optimization function
@@ -682,7 +682,7 @@ def objective(stretch_vec):
682
682
bounds = bounds ,
683
683
)
684
684
685
- # Update A with the optimized values
685
+ # Update stretch with the optimized values
686
686
self .stretch = result .x .reshape (self .stretch .shape )
687
687
688
688
0 commit comments