Skip to content

Commit 640dacf

Browse files
author
John Halloran
committed
refactor: replace remaining apply_interpolation with np.interp
1 parent c18d868 commit 640dacf

File tree

1 file changed

+20
-49
lines changed

1 file changed

+20
-49
lines changed

src/diffpy/snmf/snmf_class.py

Lines changed: 20 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -608,24 +608,29 @@ def update_components(self):
608608

609609
def update_weights(self):
610610
"""
611-
Updates weights using matrix operations, solving a quadratic program to do so.
611+
Updates weights by building the stretched component matrix `stretched_comps` with np.interp
612+
and solving a quadratic program for each signal.
612613
"""
613614

614-
signal_length = self.signal_length
615-
n_signals = self.n_signals
616-
617-
for m in range(n_signals):
618-
t = np.zeros((signal_length, self.n_components))
619-
620-
# Populate t using apply_interpolation
621-
for k in range(self.n_components):
622-
t[:, k] = apply_interpolation(self.stretch[k, m], self.components[:, k]).squeeze()
623-
624-
# Solve quadratic problem for y
625-
y = self.solve_quadratic_program(t=t, m=m)
615+
sample_indices = np.arange(self.signal_length)
616+
for signal in range(self.n_signals):
617+
# Stretch factors for this signal across components:
618+
this_stretch = self.stretch[:, signal]
619+
# Build stretched_comps[:, k] by interpolating component at frac. pos. index / this_stretch[comp]
620+
stretched_comps = np.empty((self.signal_length, self.n_components), dtype=self.components.dtype)
621+
for comp in range(self.n_components):
622+
pos = sample_indices / this_stretch[comp]
623+
stretched_comps[:, comp] = np.interp(
624+
pos,
625+
sample_indices,
626+
self.components[:, comp],
627+
left=self.components[0, comp],
628+
right=self.components[-1, comp],
629+
)
626630

627-
# Update Y
628-
self.weights[:, m] = y
631+
# Solve quadratic problem for a given signal and update its weight
632+
new_weight = self.solve_quadratic_program(t=stretched_comps, m=signal)
633+
self.weights[:, signal] = new_weight
629634

630635
def regularize_function(self, stretch=None):
631636
if stretch is None:
@@ -712,37 +717,3 @@ def cubic_largest_real_root(p, q):
712717
y = np.max(real_roots, axis=0) * (delta < 0) # Keep only real roots when delta < 0
713718

714719
return y
715-
716-
717-
def apply_interpolation(a, x):
718-
"""
719-
Applies an interpolation-based transformation to `x` based on scaling `a`.
720-
"""
721-
x_len = len(x)
722-
723-
# Ensure `a` is an array and reshape for broadcasting
724-
a = np.atleast_1d(np.asarray(a)) # Ensures a is at least 1D
725-
726-
# Compute fractional indices, broadcasting over `a`
727-
fractional_indices = np.arange(x_len)[:, None] / a # Shape (N, M)
728-
729-
integer_indices = np.floor(fractional_indices).astype(int) # Integer part (still (N, M))
730-
valid_mask = integer_indices < (x_len - 1) # Ensure indices are within bounds
731-
732-
# Apply valid_mask to keep correct indices
733-
idx_int = np.where(valid_mask, integer_indices, x_len - 2) # Prevent out-of-bounds indexing (previously "I")
734-
idx_frac = np.where(valid_mask, fractional_indices, integer_indices) # Keep aligned (previously "i")
735-
736-
# Ensure x is a 1D array
737-
x = np.asarray(x).ravel()
738-
739-
# Compute interpolated_x (linear interpolation)
740-
interpolated_x = x[idx_int] * (1 - idx_frac + idx_int) + x[np.minimum(idx_int + 1, x_len - 1)] * (
741-
idx_frac - idx_int
742-
)
743-
744-
# Fill the tail with the last valid value
745-
intr_x_tail = np.full((x_len - len(idx_int), interpolated_x.shape[1]), interpolated_x[-1, :])
746-
interpolated_x = np.vstack([interpolated_x, intr_x_tail])
747-
748-
return interpolated_x

0 commit comments

Comments
 (0)