diff --git a/kamui/__init__.py b/kamui/__init__.py index e832eff..6e187c3 100644 --- a/kamui/__init__.py +++ b/kamui/__init__.py @@ -1,5 +1,5 @@ import numpy as np -from typing import Tuple, Optional, Iterable, Union +from typing import Tuple, Optional, Iterable, Union, Any from .core import * from .utils import * @@ -27,8 +27,10 @@ def unwrap_dimensional( start_pixel: Optional[Union[Tuple[int, int], Tuple[int, int, int]]] = None, use_edgelist: bool = False, cyclical_axis: Union[int, Tuple[int, int]] = (), - **kwargs, -) -> np.ndarray: + merging_method: str = "mean", + weights: Optional[np.ndarray] = None, + **kwargs: Any, +) -> Optional[np.ndarray]: """ Unwrap the phase of a 2-D or 3-D array. @@ -45,6 +47,8 @@ def unwrap_dimensional( cyclical_axis : int or (int, int) The axis that is cyclical. Default to (). + weights : Weights defining the 'goodness' of value at each vertex. Shape must match the shape of x. + merging_method : Way of combining two phase weights into a single edge weight. kwargs : dict Other arguments passed to `kamui.unwrap_arbitrary`. @@ -61,7 +65,6 @@ def unwrap_dimensional( for i, s in enumerate(start_pixel): start_i *= x.shape[i] start_i += s - if x.ndim == 2: edges, simplices = get_2d_edges_and_simplices( x.shape, cyclical_axis=cyclical_axis @@ -74,12 +77,20 @@ def unwrap_dimensional( raise ValueError("x must be 2D or 3D") psi = x.ravel() + if weights is not None: + # convert per-vertex weights to per-edge weights + + weights = prepare_weights(weights, edges=edges, merging_method=merging_method) result = unwrap_arbitrary( - psi, edges, None if use_edgelist else simplices, start_i=start_i, **kwargs + psi, + edges, + None if use_edgelist else simplices, + start_i=start_i, + weights=weights, + **kwargs, ) if result is None: return None - return result.reshape(x.shape) diff --git a/kamui/core.py b/kamui/core.py index f697ae6..49d6b8d 100644 --- a/kamui/core.py +++ b/kamui/core.py @@ -4,11 +4,11 @@ import numpy as np from typing import Optional, Iterable + try: import maxflow except ImportError: print("PyMaxflow not found, some functions will not be available.") - __all__ = ["integrate", "calculate_k", "calculate_m", "puma"] @@ -30,7 +30,6 @@ def integrate(edges: np.ndarray, weights: np.ndarray, start_i: int = 0): pairs = np.stack([nodes[:-1], nodes[1:]], axis=1) for u, v in pairs: result[v] = result[u] + G[u, v] - return result @@ -73,7 +72,6 @@ def calculate_k( raise ValueError("simplices contain invalid edges") vals.append(-1) u = v - rows = np.array(rows) cols = np.array(cols) vals = np.array(vals) @@ -102,7 +100,6 @@ def calculate_k( c = np.ones((M * 2,), dtype=np.int64) else: c = np.tile(weights, 2) - res = linprog(c, A_eq=A_eq, b_eq=b_eq, integrality=1) if res.x is None: return None @@ -144,7 +141,6 @@ def calculate_m( ) if weights is None: weights = np.ones((M,), dtype=np.int64) - c = np.concatenate((np.zeros(N, dtype=np.int64), weights, weights)) b_eq = differences @@ -170,7 +166,6 @@ def puma(psi: np.ndarray, edges: np.ndarray, max_jump: int = 1, p: float = 1): jump_steps = list(range(1, max_jump + 1)) * 2 else: jump_steps = [max_jump] - total_nodes = psi.size def V(x): @@ -211,7 +206,6 @@ def cal_Ek(K, psi, i, j): for i in range(total_nodes): G.add_tedge(i, tmp_st_weight[0, i], tmp_st_weight[1, i]) - G.maxflow() partition = G.get_grid_segments(np.arange(total_nodes)) @@ -224,5 +218,4 @@ def cal_Ek(K, psi, i, j): else: K[~partition] -= step break - return K diff --git a/kamui/utils.py b/kamui/utils.py index b590b14..675f6cd 100644 --- a/kamui/utils.py +++ b/kamui/utils.py @@ -1,7 +1,12 @@ import numpy as np -from typing import Tuple, Optional, Iterable, Union +import numpy.typing as npt +from typing import Tuple, Iterable, Union -__all__ = ["get_2d_edges_and_simplices", "get_3d_edges_and_simplices"] +__all__ = [ + "get_2d_edges_and_simplices", + "get_3d_edges_and_simplices", + "prepare_weights", +] def get_2d_edges_and_simplices( @@ -25,8 +30,8 @@ def get_2d_edges_and_simplices( nodes = np.arange(np.prod(shape)).reshape(shape) if type(cyclical_axis) is int: cyclical_axis = (cyclical_axis,) - # if the axis length <= 2, then the axis is already cyclical + cyclical_axis = tuple(filter(lambda ax: shape[ax] > 2, cyclical_axis)) edges = np.concatenate( @@ -78,7 +83,6 @@ def get_2d_edges_and_simplices( ), axis=0, ).tolist() - return edges, simplices @@ -192,5 +196,67 @@ def get_3d_edges_and_simplices( ), axis=0, ).tolist() - return edges, simplices + + +def prepare_weights( + weights: npt.NDArray, + edges: npt.NDArray[np.int_], + smoothing: float = 0.1, + merging_method: str = "mean", +) -> npt.NDArray[np.float_]: + """Prepare weights for `calculate_m` and `calculate_k` functions. + + Assume the weights are the same shape as the phases to be unwrapped. + + Scale the weights from 0 to 1. Pick the weights corresponding to the phase pairs connected by the edges. + Compute the mean/max/min (depending on the `merging_method`) of each of those pairs to give a weight for each edge. + + Args: + weights : Array of weights of shape corresponding to the original phases array shape. + edges : Edges connecting the phases. Shape: (M, 2), where M is the number of edges. + smoothing : A positive value in range [0, 1). This is the minimal value of the rescaled weights + where they are defined. If smoothing > 0, the value of 0 is reserved for places where + the weights are originally NaN. If smoothing == 0, 0 will be used for both NaN weights + and smallest non-NaN ones. + merging_method : Way of combining two phase weights into a single edge weight. + + Returns: + Array of weights for the edges, shape: (M,). Rescaled to [0, 1]. + """ + + if not 0 <= smoothing < 1: + raise ValueError( + "`smoothing` should be a value between 0 (inclusive) and 1 (non inclusive); got " + + str(smoothing) + ) + # scale the weights from 0 to 1 + + weights = weights - np.nanmin(weights) + current_max = np.nanmax(weights) + if not current_max: + # current maximum is 0, which means all weights originally had the same value, now 0; replace everything with 1 + + weights += 1 + else: + weights /= current_max + weights *= 1 - smoothing + weights += smoothing + # pick the weights corresponding to the phases connected by the edges + # and use `merging_method` to get one weight for each edge + + allowed_merging_methods = ["min", "max", "mean"] + if merging_method not in allowed_merging_methods: + raise ValueError( + "`merging_method` should be one of: " + + ", ".join(merging_method) + + "; got " + + str(merging_method) + ) + weights_for_edges = getattr(np, merging_method)(weights.ravel()[edges], axis=1) + + # make sure there are no NaNs in the weights; replace any with 0s + + weights_for_edges[np.isnan(weights_for_edges)] = 0 + + return weights_for_edges