From e5dbca5b6817f3732e7c39d4c20a5ca1658492d4 Mon Sep 17 00:00:00 2001 From: piyushrpt Date: Fri, 17 May 2024 17:38:47 -0700 Subject: [PATCH] Addressing PR feedback from #10 --- src/spurt/mcf/_ortools.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/spurt/mcf/_ortools.py b/src/spurt/mcf/_ortools.py index 356cdd7..b295a77 100644 --- a/src/spurt/mcf/_ortools.py +++ b/src/spurt/mcf/_ortools.py @@ -147,7 +147,7 @@ def compute_residues_from_gradients( self, graddata: ArrayLike, ) -> ArrayLike: - """Compute phase residues for one set of real input gradients.""" + if graddata.size != self.nedges: errmsg = ( f"Size mismatch for residue computation." @@ -155,8 +155,8 @@ def compute_residues_from_gradients( ) raise ValueError(errmsg) - cyc0 = np.abs(self.dual_edges[:, 0]) - cyc1 = np.abs(self.dual_edges[:, 1]) + cyc0 = self.dual_edges[:, 0] + cyc1 = self.dual_edges[:, 1] cyc0_dir = self.dual_edge_dir[:, 0] cyc1_dir = self.dual_edge_dir[:, 1] grad_sum = np.zeros(self.ncycles + 1, dtype=np.float32) @@ -164,7 +164,7 @@ def compute_residues_from_gradients( np.add.at(grad_sum, cyc0, cyc0_dir * graddata) np.add.at(grad_sum, cyc1, cyc1_dir * graddata) - residues = np.rint(grad_sum / (2 * np.pi)) + residues = np.rint(grad_sum / (2 * np.pi)).astype(int) # Set supply of groud_node residues[0] = -np.sum(residues[1:])