11import cvxpy as cp
22import numpy as np
3+ from plotter import SNMFPlotter
34from scipy .optimize import minimize
45from scipy .sparse import coo_matrix , diags
56
@@ -73,6 +74,7 @@ def __init__(
7374 tol = 5e-7 ,
7475 n_components = None ,
7576 random_state = None ,
77+ show_plots = False ,
7678 ):
7779 """Initialize an instance of SNMF and run the optimization.
7880
@@ -112,6 +114,8 @@ def __init__(
112114 random_state : int Optional Default = None
113115 The seed for the initial guesses at the matrices (A, X, and Y) created by
114116 the decomposition.
117+ show_plots : boolean Optional Default = False
118+ Enables plotting at each step of the decomposition.
115119 """
116120
117121 self .source_matrix = source_matrix
@@ -123,6 +127,7 @@ def __init__(
123127 self .signal_length , self .n_signals = source_matrix .shape
124128 self .num_updates = 0
125129 self ._rng = np .random .default_rng (random_state )
130+ self .plotter = SNMFPlotter () if show_plots else None
126131
127132 # Enforce exclusive specification of n_components or init_weights
128133 if (n_components is None and init_weights is None ) or (
@@ -236,6 +241,13 @@ def normalize_results(self):
236241 print (f"Objective function after normalize_components: { self .objective_function :.5e} " )
237242 self ._objective_history .append (self .objective_function )
238243 self .objective_difference = self ._objective_history [- 2 ] - self ._objective_history [- 1 ]
244+ if self .plotter is not None :
245+ self .plotter .update (
246+ components = self .components ,
247+ weights = self .weights ,
248+ stretch = self .stretch ,
249+ update_tag = "normalize components" ,
250+ )
239251 if self .objective_difference < self .objective_function * self .tol and outiter >= 7 :
240252 break
241253
@@ -252,6 +264,10 @@ def outer_loop(self):
252264 if self .objective_function < self .best_objective :
253265 self .best_objective = self .objective_function
254266 self .best_matrices = [self .components .copy (), self .weights .copy (), self .stretch .copy ()]
267+ if self .plotter is not None :
268+ self .plotter .update (
269+ components = self .components , weights = self .weights , stretch = self .stretch , update_tag = "components"
270+ )
255271
256272 self .update_weights ()
257273 self .residuals = self .get_residual_matrix ()
@@ -262,6 +278,10 @@ def outer_loop(self):
262278 if self .objective_function < self .best_objective :
263279 self .best_objective = self .objective_function
264280 self .best_matrices = [self .components .copy (), self .weights .copy (), self .stretch .copy ()]
281+ if self .plotter is not None :
282+ self .plotter .update (
283+ components = self .components , weights = self .weights , stretch = self .stretch , update_tag = "weights"
284+ )
265285
266286 self .objective_difference = self ._objective_history [- 2 ] - self ._objective_history [- 1 ]
267287 if self ._objective_history [- 3 ] - self .objective_function < self .objective_difference * 1e-3 :
@@ -276,6 +296,10 @@ def outer_loop(self):
276296 if self .objective_function < self .best_objective :
277297 self .best_objective = self .objective_function
278298 self .best_matrices = [self .components .copy (), self .weights .copy (), self .stretch .copy ()]
299+ if self .plotter is not None :
300+ self .plotter .update (
301+ components = self .components , weights = self .weights , stretch = self .stretch , update_tag = "stretch"
302+ )
279303
280304 def get_residual_matrix (self , components = None , weights = None , stretch = None ):
281305 # Initialize residual matrix as negative of source_matrix
0 commit comments