Skip to content

Commit

Permalink
feat: add v1 of GMM and k-means contrast limits
Browse files Browse the repository at this point in the history
  • Loading branch information
seankmartin committed Sep 10, 2024
1 parent 48cd2f7 commit eda99a9
Showing 1 changed file with 99 additions and 4 deletions.
103 changes: 99 additions & 4 deletions cryoet_data_portal_neuroglancer/precompute/contrast_limits.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
"""Methods for computing contrast limits for Neuroglancer image layers."""

from abc import abstractmethod
from pathlib import Path
from typing import Optional

import matplotlib.pyplot as plt
import numpy as np
from scipy.signal import find_peaks
from sklearn.cluster import KMeans
from sklearn.mixture import GaussianMixture


def _restrict_volume_around_central_z_slice(
Expand Down Expand Up @@ -36,8 +39,7 @@ def _restrict_volume_around_central_z_slice(
if z_radius is None:
lowest_points = find_peaks(-standard_deviation_per_z_slice, prominence=0.1)[0]
if len(lowest_points) < 2:
# TODO create fallback instead
raise ValueError("Not enough low points found")
raise ValueError("Not enough low points found to auto compute z-radius.")
for value in lowest_points:
if value < central_z_slice:
z_min = value
Expand Down Expand Up @@ -108,7 +110,6 @@ def trim_volume_around_central_zslice(
z_radius,
)

@abstractmethod
def contrast_limits_from_percentiles(
self,
low_percentile: float = 1.0,
Expand Down Expand Up @@ -155,3 +156,97 @@ def contrast_limits_from_mean(
width = multipler * rms_value

return mean_value - width, mean_value + width


class GMMContrastLimitCalculator(ContrastLimitCalculator):
    """Compute contrast limits by fitting a 1D Gaussian Mixture Model.

    The voxel intensities of the volume are modelled as a mixture of
    Gaussians, and the limits are taken as a two standard deviation window
    around the component whose mean sits in the middle of the fitted means.
    """

    def __init__(self, volume: Optional["np.ndarray"] = None, num_components: int = 3):
        """Initialize the contrast limit calculator.

        Parameters
        ----------
        volume: np.ndarray or None, optional.
            Input volume for calculating contrast limits.
        num_components: int, optional.
            The number of components to use for GMM.
            By default 3.
        """
        super().__init__(volume)
        self.num_components = num_components
        # covariance_type can be one of "spherical", "diag", "tied", "full".
        self.gmm_estimator = GaussianMixture(
            n_components=num_components,
            covariance_type="full",
            max_iter=200,
            random_state=0,
        )

    def contrast_limits_from_gmm(self) -> tuple[float, float]:
        """Calculate the contrast limits using Gaussian Mixture Model.

        The mixture is fitted to the flattened volume. The component with the
        median mean is selected (for an even number of components, the upper
        of the two central ones) and the limits are its mean plus/minus two
        standard deviations.

        Returns
        -------
        tuple[float, float]
            The calculated contrast limits as (low, high).
        """
        # Each voxel is one sample with a single feature (its intensity).
        self.gmm_estimator.fit(self.volume.reshape(-1, 1))

        # With a single feature and "full" covariances, flattening yields one
        # mean and one variance per component.
        means = self.gmm_estimator.means_.flatten()
        variances = self.gmm_estimator.covariances_.flatten()

        # The fitted components come back in arbitrary order, so rank them by
        # mean to find the Gaussian which actually sits in the middle.
        middle_index = np.argsort(means)[len(means) // 2]
        center = float(means[middle_index])

        # covariances_ holds variances; the window width must be expressed in
        # standard deviations to share units with the mean.
        half_width = 2.0 * float(np.sqrt(variances[middle_index]))

        return center - half_width, center + half_width


class KMeansContrastLimitCalculator(ContrastLimitCalculator):
    """Compute contrast limits from 1D KMeans clustering of voxel intensities.

    The limits are the smallest and largest cluster centers found when
    clustering the flattened volume.
    """

    def __init__(self, volume: Optional["np.ndarray"] = None, num_clusters: int = 3):
        """Initialize the contrast limit calculator.

        Parameters
        ----------
        volume: np.ndarray or None, optional.
            Input volume for calculating contrast limits.
        num_clusters: int, optional.
            The number of clusters to use for KMeans.
            By default 3.
        """
        super().__init__(volume)
        self.num_clusters = num_clusters
        self.kmeans_estimator = KMeans(n_clusters=num_clusters, random_state=0)

    def plot_kmeans_clusters(self, output_filename: Optional[str | Path] = None) -> None:
        """Plot the volume histogram overlaid with the KMeans cluster centers.

        The estimator must already be fitted (for example by calling
        ``contrast_limits_from_kmeans``), as the cluster centers only exist
        after fitting.

        Parameters
        ----------
        output_filename: str, Path or None, optional.
            Where to save the figure. If not provided, the figure is shown
            interactively instead.
        """
        fig, ax = plt.subplots()

        ax.hist(self.volume.flatten(), bins=100, alpha=0.5)
        ax.hist(self.kmeans_estimator.cluster_centers_, bins=100, alpha=0.5)
        if output_filename:
            fig.savefig(output_filename)
        else:
            plt.show()
        # Release the figure so repeated calls do not accumulate open figures.
        plt.close(fig)

    def contrast_limits_from_kmeans(self) -> tuple[float, float]:
        """Calculate the contrast limits using KMeans clustering.

        The number of clusters is the one given at construction time.

        Returns
        -------
        tuple[float, float]
            The calculated contrast limits, i.e. the lowest and highest
            cluster centers.
        """
        # Each voxel is one sample with a single feature (its intensity).
        self.kmeans_estimator.fit(self.volume.reshape(-1, 1))

        # cluster_centers_ has shape (n_clusters, 1): an in-place sort() would
        # act on the length-1 last axis and leave the order unchanged, and it
        # would also mutate the fitted estimator. Reduce over a flat view
        # instead to get the true extremes as plain floats.
        cluster_centers = self.kmeans_estimator.cluster_centers_.ravel()

        return float(cluster_centers.min()), float(cluster_centers.max())


# Another possibility is to take the derivative of the histogram and find its peaks.

0 comments on commit eda99a9

Please sign in to comment.