Skip to content

Commit c59cc70

Browse files
author
John Halloran
committed
refactor: reconstruct separate from residuals
1 parent 9cbd801 commit c59cc70

File tree

1 file changed

+44
-21
lines changed

1 file changed

+44
-21
lines changed

src/diffpy/snmf/snmf_class.py

Lines changed: 44 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -304,12 +304,7 @@ def outer_loop(self):
304304

305305
def get_residual_matrix(self, components=None, weights=None, stretch=None):
306306
"""
307-
Return the residuals (difference) between the source matrix and its reconstruction
308-
from the given components, weights, and stretch factors.
309-
310-
Each component profile is stretched, interpolated to fractional positions,
311-
weighted per signal, and summed to form the reconstruction. The residuals
312-
are the source matrix minus this reconstruction.
307+
Return the residuals (difference) between the source matrix and its reconstruction.
313308
314309
Parameters
315310
----------
@@ -329,21 +324,8 @@ def get_residual_matrix(self, components=None, weights=None, stretch=None):
329324
if stretch is None:
330325
stretch = self.stretch
331326

332-
residuals = -self.source_matrix.copy()
333-
sample_indices = np.arange(components.shape[0]) # (signal_len,)
334-
335-
for comp in range(components.shape[1]): # loop over components
336-
residuals += (
337-
np.interp(
338-
sample_indices[:, None]
339-
/ stretch[comp][None, :], # fractional positions (signal_len, n_signals)
340-
sample_indices, # (signal_len,)
341-
components[:, comp], # component profile (signal_len,)
342-
left=components[0, comp],
343-
right=components[-1, comp],
344-
)
345-
* weights[comp][None, :] # broadcast (n_signals,) over rows
346-
)
327+
reconstructed_matrix = reconstruct_matrix(components, weights, stretch)
328+
residuals = reconstructed_matrix - self.source_matrix
347329

348330
return residuals
349331

@@ -718,3 +700,44 @@ def cubic_largest_real_root(p, q):
718700
y = np.max(real_roots, axis=0) * (delta < 0) # Keep only real roots when delta < 0
719701

720702
return y
703+
704+
705+
def reconstruct_matrix(components=None, weights=None, stretch=None):
706+
"""
707+
Construct the approximation of the source matrix corresponding to the
708+
given components, weights, and stretch factors.
709+
710+
Each component profile is stretched, interpolated to fractional positions,
711+
weighted per signal, and summed to form the reconstruction.
712+
713+
Parameters
714+
----------
715+
components : (signal_len, n_components) array
716+
weights : (n_components, n_signals) array
717+
stretch : (n_components, n_signals) array
718+
719+
Returns
720+
-------
721+
reconstructed_matrix : (signal_len, n_signals) array
722+
"""
723+
724+
signal_len = components.shape[0]
725+
n_signals = weights.shape[1]
726+
n_components = components.shape[1]
727+
728+
reconstructed_matrix = np.zeros((signal_len, n_signals))
729+
sample_indices = np.arange(signal_len)
730+
731+
for comp in range(n_components): # loop over components
732+
reconstructed_matrix += (
733+
np.interp(
734+
sample_indices[:, None] / stretch[comp][None, :], # fractional positions (signal_len, n_signals)
735+
sample_indices, # (signal_len,)
736+
components[:, comp], # component profile (signal_len,)
737+
left=components[0, comp],
738+
right=components[-1, comp],
739+
)
740+
* weights[comp][None, :] # broadcast (n_signals,) over rows
741+
)
742+
743+
return reconstructed_matrix

0 commit comments

Comments
 (0)