72 changes: 38 additions & 34 deletions src/diffpy/snmf/snmf_class.py
@@ -52,10 +52,6 @@ class SNMFOptimizer:
num_updates : int
The total number of times that any of stretch, components, or weights has had its value changed.
If not terminated by other means, this value is used to stop when reaching max_iter.
objective_function: float
The value corresponding to the minimization of the difference between the source_matrix and the
products of (stretch, components, and weights). For full details see the sNMF paper. Smaller corresponds to
better agreement and is desirable.
objective_difference : float
The change in the objective function value since the last update, computed as the previous
value minus the current one. A positive value means that the result improved.
@@ -134,7 +130,7 @@ def __init__(
# Initialize weights and determine number of components
if init_weights is None:
self.n_components = n_components
self.weights = self._rng.beta(a=2.5, b=1.5, size=(self.n_components, self.n_signals))
self.weights = self._rng.beta(a=2.0, b=2.0, size=(self.n_components, self.n_signals))
else:
self.n_components = init_weights.shape[0]
self.weights = init_weights
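
For reference, a minimal standalone sketch of the two initializations (illustrative only; the shapes and seed are stand-ins, not from the PR): Beta(2.0, 2.0) is symmetric about 0.5, whereas the previous Beta(2.5, 1.5) has mean 0.625 and skews the starting weights toward 1.

import numpy as np

n_components, n_signals = 3, 10          # hypothetical sizes
rng = np.random.default_rng(42)          # stand-in for self._rng

old_style = rng.beta(a=2.5, b=1.5, size=(n_components, n_signals))  # mean ~0.625, skewed high
new_style = rng.beta(a=2.0, b=2.0, size=(n_components, n_signals))  # mean 0.5, symmetric
print(old_style.mean(), new_style.mean())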
@@ -165,20 +161,20 @@ def __init__(

# Set up residual matrix, objective function, and history
self.residuals = self.get_residual_matrix()
self.objective_function = self.get_objective_function()
self._objective_history = []
self.update_objective()
self.objective_difference = None
self._objective_history = [self.objective_function]

# Set up tracking variables for updateX()
# Set up tracking variables for update_components()
self._prev_components = None
self.grad_components = np.zeros_like(self.components) # Gradient of X (zeros for now)
self._prev_grad_components = np.zeros_like(self.components) # Previous gradient of X (zeros for now)

regularization_term = 0.5 * rho * np.linalg.norm(self._spline_smooth_operator @ self.stretch.T, "fro") ** 2
sparsity_term = eta * np.sum(np.sqrt(self.components)) # Square root penalty
print(
f"Start, Objective function: {self.objective_function:.5e}"
f", Obj - reg/sparse: {self.objective_function - regularization_term - sparsity_term:.5e}"
f"Start, Objective function: {self._objective_history[-1]:.5e}"
f", Obj - reg/sparse: {self._objective_history[-1] - regularization_term - sparsity_term:.5e}"
)
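
The quantity printed as "Obj - reg/sparse" is the objective with the smoothness and sparsity penalties subtracted, i.e. just the data-misfit part. A rough sketch of that decomposition (stand-in values only; P, A, and X take the place of the spline operator, stretch, and components):

import numpy as np

rho, eta = 1e-2, 1e-3                     # hypothetical penalty strengths
P = np.eye(4)                             # stands in for self._spline_smooth_operator
A = np.full((3, 4), 0.5)                  # stands in for self.stretch
X = np.full((5, 3), 0.2)                  # stands in for self.components
residual_term = 1.234                     # stands in for the data-misfit term

regularization_term = 0.5 * rho * np.linalg.norm(P @ A.T, "fro") ** 2
sparsity_term = eta * np.sum(np.sqrt(X))
objective = residual_term + regularization_term + sparsity_term
print(objective - regularization_term - sparsity_term)  # recovers residual_term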

# Main optimization loop
@@ -191,15 +187,15 @@ def __init__(
sparsity_term = eta * np.sum(np.sqrt(self.components)) # Square root penalty
print(
f"Num_updates: {self.num_updates}, "
f"Obj fun: {self.objective_function:.5e}, "
f"Obj - reg/sparse: {self.objective_function - regularization_term - sparsity_term:.5e}, "
f"Obj fun: {self._objective_history[-1]:.5e}, "
f"Obj - reg/sparse: {self._objective_history[-1] - regularization_term - sparsity_term:.5e}, "
f"Iter: {iter}"
)

# Convergence check: decide when to terminate for small/no improvement
print(self.objective_difference, " < ", self.objective_function * tol)
if self.objective_difference < self.objective_function * tol and iter >= 20:
if self.objective_difference < self._objective_history[-1] * tol and iter >= 20:
break
print(self.objective_difference, " < ", self._objective_history[-1] * tol)

# Normalize our results
weights_row_max = np.max(self.weights, axis=1, keepdims=True)
@@ -214,17 +210,17 @@ def __init__(
self.grad_components = np.zeros_like(self.components)
self._prev_grad_components = np.zeros_like(self.components)
self.residuals = self.get_residual_matrix()
self.objective_function = self.get_objective_function()
self.objective_difference = None
self._objective_history = [self.objective_function]
self._objective_history = []
self.update_objective()
for norm_iter in range(100):
self.update_components()
self.residuals = self.get_residual_matrix()
self.objective_function = self.get_objective_function()
print(f"Objective function after normX: {self.objective_function:.5e}")
self._objective_history.append(self.objective_function)
self.update_objective()
print(f"Objective function after normalize_components: {self._objective_history[-1]:.5e}")
self.objective_difference = self._objective_history[-2] - self._objective_history[-1]
if self.objective_difference < self.objective_function * tol and norm_iter >= 20:
if self.objective_difference < self._objective_history[-1] * tol and norm_iter >= 20:
break
# end of normalization (and program)
# note that objective function may not fully recover after normalization, this is okay
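
Both this normalization loop and the main loop use the same relative stopping rule: quit once the latest improvement is small compared to the current objective value, but only after a minimum number of iterations. A standalone sketch of that rule (assumed helper, not in the PR; history mirrors self._objective_history):

def should_stop(history, tol, iteration, min_iter=20):
    # Stop when the last improvement is small relative to the current objective value.
    if iteration < min_iter or len(history) < 2:
        return False
    improvement = history[-2] - history[-1]   # positive when the objective decreased
    return improvement < history[-1] * tol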
@@ -233,30 +229,33 @@ def __init__(
def optimize_loop(self):
# Update components first
self._prev_grad_components = self.grad_components.copy()

self.update_components()

self.num_updates += 1
self.residuals = self.get_residual_matrix()
self.objective_function = self.get_objective_function()
print(f"Objective function after update_components: {self.objective_function:.5e}")
self._objective_history.append(self.objective_function)
self.update_objective()
print(f"Objective function after update_components: {self._objective_history[-1]:.5e}")

if self.objective_difference is None:
self.objective_difference = self._objective_history[-1] - self.objective_function
self.objective_difference = self._objective_history[-2] - self._objective_history[-1]

# Now we update weights
self.update_weights()

self.num_updates += 1
self.residuals = self.get_residual_matrix()
self.objective_function = self.get_objective_function()
print(f"Objective function after update_weights: {self.objective_function:.5e}")
self._objective_history.append(self.objective_function)
self.update_objective()
print(f"Objective function after update_weights: {self._objective_history[-1]:.5e}")

# Now we update stretch
self.update_stretch()

self.num_updates += 1
self.residuals = self.get_residual_matrix()
self.objective_function = self.get_objective_function()
print(f"Objective function after update_stretch: {self.objective_function:.5e}")
self._objective_history.append(self.objective_function)
self.update_objective()
print(f"Objective function after update_stretch: {self._objective_history[-1]:.5e}")

self.objective_difference = self._objective_history[-2] - self._objective_history[-1]

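Each of the three sub-updates above repeats the same bookkeeping: run the update, bump num_updates, refresh the residual matrix, and append the new objective to the history. A hypothetical helper showing that shared pattern (a refactor sketch, not part of this PR; names mirror the SNMFOptimizer attributes):

def _run_update(self, update_fn, label):
    # One sub-update plus its bookkeeping.
    update_fn()
    self.num_updates += 1
    self.residuals = self.get_residual_matrix()
    self.update_objective()  # appends the new value to self._objective_history
    print(f"Objective function after {label}: {self._objective_history[-1]:.5e}")

With such a helper, one pass of optimize_loop would reduce to three calls, e.g. self._run_update(self.update_weights, "update_weights").
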
def apply_interpolation(self, a, x, return_derivatives=False):
@@ -328,7 +327,8 @@ def get_residual_matrix(self, components=None, weights=None, stretch=None):
residuals += weights[k, :] * stretched_components # Element-wise scaling and sum
return residuals

def get_objective_function(self, residuals=None, stretch=None):
def update_objective(self, residuals=None, stretch=None):
to_return = not (residuals is None and stretch is None)
if residuals is None:
residuals = self.residuals
if stretch is None:
@@ -338,7 +338,11 @@ def get_objective_function(self, residuals=None, stretch=None):
sparsity_term = self.eta * np.sum(np.sqrt(self.components)) # Square root penalty
# Final objective function value
function = residual_term + regularization_term + sparsity_term
return function

if to_return:
return function # Get value directly for use
else:
self._objective_history.append(function) # Store value

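update_objective now plays two roles: called with no arguments it recomputes the objective from the stored state and appends it to _objective_history; called with explicit residuals (and optionally stretch) it returns the trial value without touching the history, which is how update_components and regularize_function use it below. A small usage sketch (illustrative helper, not in the PR):

def _example_objective_usage(optimizer):
    # `optimizer` is an SNMFOptimizer instance.
    optimizer.update_objective()                      # recompute and store in the history
    current = optimizer._objective_history[-1]
    trial = optimizer.update_objective(               # evaluate a candidate; nothing is stored
        residuals=optimizer.get_residual_matrix()
    )
    return current - trial                            # positive => the candidate improves the fit
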
def apply_interpolation_matrix(self, components=None, weights=None, stretch=None, return_derivatives=False):
"""
@@ -590,7 +594,7 @@ def update_components(self):
)
self.components = mask * self.components

objective_improvement = self._objective_history[-1] - self.get_objective_function(
objective_improvement = self._objective_history[-1] - self.update_objective(
residuals=self.get_residual_matrix()
)

@@ -645,7 +649,7 @@ def regularize_function(self, stretch=None):
stretch_difference = stretch_difference - self.source_matrix

# Compute objective function
reg_func = self.get_objective_function(stretch_difference, stretch)
reg_func = self.update_objective(stretch_difference, stretch)

# Compute gradient
tiled_derivative = np.sum(