Skip to content

Commit

Permalink
feat: progress on contrast limits
Browse files Browse the repository at this point in the history
  • Loading branch information
seankmartin committed Sep 11, 2024
1 parent 75b1a2e commit 5f9e7e0
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 30 deletions.
74 changes: 48 additions & 26 deletions cryoet_data_portal_neuroglancer/precompute/contrast_limits.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,10 @@ def wrapper(*args, **kwargs):
import time

start_time = time.time()
LOGGER.info(f"Running function {func.__name__}.")
LOGGER.info(f"Running function {func.__name__}")
result = func(*args, **kwargs)
end_time = time.time()
LOGGER.info(f"Function {func.__name__} took {end_time - start_time} seconds.")
LOGGER.info(f"Function {func.__name__} took {end_time - start_time:.2f} seconds")
return result

return wrapper
Expand Down Expand Up @@ -55,9 +55,10 @@ def _restrict_volume_around_central_z_slice(
central_z_slice = central_z_slice or (0.5 * volume.shape[0] - 0.5)

if z_radius is None:
lowest_points = find_peaks(-standard_deviation_per_z_slice, prominence=0.1)[0]
lowest_points, _ = find_peaks(-standard_deviation_per_z_slice, prominence=0.05)
if len(lowest_points) < 2:
raise ValueError("Not enough low points found to auto compute z-radius.")
LOGGER.warning("Could not find enough low points in the standard deviation per z-slice.")
return volume
for value in lowest_points:
if value < central_z_slice:
z_min = value
Expand Down Expand Up @@ -151,7 +152,7 @@ def contrast_limits_from_percentiles(
low_value = np.percentile(self.volume.flatten(), low_percentile)
high_value = np.percentile(self.volume.flatten(), high_percentile)

return low_value, high_value
return low_value.compute()[0], high_value.compute()[0]

@compute_with_timer
def contrast_limits_from_mean(
Expand Down Expand Up @@ -217,14 +218,22 @@ def contrast_limits_from_gmm(self) -> tuple[float, float]:
means = self.gmm_estimator.means_.flatten()
covariances = self.gmm_estimator.covariances_.flatten()

return means[1] - 2 * covariances[1], means[1] + 2 * covariances[1]
# pick the middle GMM component - TODO should actually be the one with the
# mean closest to the mean of the volume
means = means[np.argsort(means)]
covariances = covariances[np.argsort(means)]

return means[1] - 0.1 * np.sqrt(covariances[1]), means[1] + 0.1 * np.sqrt(covariances[1])

def plot_gmm_clusters(self, output_filename: Optional[str | Path] = None) -> None:
"""Plot the GMM clusters."""
fig, ax = plt.subplots()

# TODO improve this plot with std
ax.plot(self.gmm_estimator.means_)
ax.plot(
self.gmm_estimator.means_.flatten(),
self.gmm_estimator.covariances_.flatten(),
"o",
)
if output_filename:
fig.savefig(output_filename)
else:
Expand Down Expand Up @@ -253,7 +262,7 @@ def plot_kmeans_clusters(self, output_filename: Optional[str | Path] = None) ->
"""Plot the KMeans clusters."""
fig, ax = plt.subplots()

ax.plot(self.kmeans_estimator.cluster_centers_)
ax.plot(self.kmeans_estimator.cluster_centers_, "o")
if output_filename:
fig.savefig(output_filename)
else:
Expand Down Expand Up @@ -281,7 +290,13 @@ def contrast_limits_from_kmeans(self) -> tuple[float, float]:
cluster_centers = self.kmeans_estimator.cluster_centers_
cluster_centers.sort()

return cluster_centers[0], cluster_centers[-1]
# Find the boundaries of the middle cluster
distance_to_middle = np.abs(cluster_centers - cluster_centers[1])

left_boundary = cluster_centers[0][0] + 0.75 * distance_to_middle[0][0]
right_boundary = cluster_centers[-1][0] - 0.75 * distance_to_middle[-1][0]

return left_boundary, right_boundary


class CDFContrastLimitCalculator(ContrastLimitCalculator):
Expand All @@ -297,6 +312,7 @@ def __init__(self, volume: Optional["np.ndarray"] = None):
super().__init__(volume)
self.cdf = None
self.limits = None
self.second_derivative = None

@compute_with_timer
def contrast_limits_from_cdf(self) -> tuple[float, float]:
Expand All @@ -310,39 +326,45 @@ def contrast_limits_from_cdf(self) -> tuple[float, float]:
# Calculate the histogram of the volume
min_value = np.min(self.volume.flatten())
max_value = np.max(self.volume.flatten())
hist, bin_edges = np.histogram(self.volume.flatten(), bins=1000, range=[min_value, max_value])
hist, bin_edges = np.histogram(self.volume.flatten(), bins=400, range=[min_value, max_value])

# Calculate the CDF of the histogram
cdf = np.cumsum(hist) / np.sum(hist)
gradient = np.gradient(cdf)

# Find the biggest positive peak
peaks = find_peaks(gradient, prominence=0.1)
biggest_peak = np.argmax(gradient[peaks])

# Find where the function starts to become flat after the peak
# Find where the function starts to flatten
gradient = np.gradient(cdf.compute())
second_derivative = np.gradient(gradient)
flat_points = np.where(second_derivative[biggest_peak:] < 0.0001)[0]
# TODO improve error handling
peaks, _ = find_peaks(second_derivative, prominence=0.01)

self.cdf = cdf
self.limits = bin_edges[biggest_peak], bin_edges[biggest_peak + flat_points[0]]
# If no peaks, take the argmax of the gradient
biggest_peak = np.argmax(second_derivative) if len(peaks) == 0 else peaks[np.argmax(second_derivative[peaks])]

negative_peaks, _ = find_peaks(-second_derivative, prominence=0.01)
smallest_negative_peak = (
np.argmin(second_derivative)
if len(negative_peaks) == 0
else negative_peaks[np.argmin(second_derivative[negative_peaks])]
)

x = np.linspace(min_value, max_value, 400)
self.cdf = [x, cdf]
self.limits = (bin_edges[biggest_peak].compute(), bin_edges[smallest_negative_peak].compute())
self.second_derivative = second_derivative

return self.limits

def plot_cdf_and_limits(self, output_filename: Optional[str | Path] = None) -> None:
def plot_cdf(self, output_filename: Optional[str | Path] = None) -> None:
"""Plot the CDF and the calculated limits."""
fig, ax = plt.subplots()

ax.plot(self.cdf)
ax.plot(self.cdf[0], self.cdf[1])
ax.axvline(self.limits[0], color="r")
ax.axvline(self.limits[1], color="r")

ax.plot(self.cdf[0], self.second_derivative, "g")

if output_filename:
fig.savefig(output_filename)
else:
plt.show()
plt.close(fig)


# Other possibility is to take the derivative of the histogram and find the peaks
11 changes: 7 additions & 4 deletions manual_tests/contrast_limits_from_api.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import argparse
import json
import logging
from pathlib import Path

Expand All @@ -19,11 +20,11 @@
OUTPUT_FOLDER = "/media/starfish/LargeSSD/data/cryoET/data/FromAPI"

id_to_path_map = {
630: "630-TS_045.zarr",
1000: "1000/16.zarr",
}

id_to_human_contrast_limits = {
630: {
1000: {
"slice": [-0.2, 0.15],
"volume": [-0.035, 0.009],
"gain": -7.6,
Expand All @@ -34,8 +35,9 @@
def grab_tomogram(id_: int, zarr_path: Path):
client = Client()
if not zarr_path.exists():
zarr_path.mkdir(parents=True, exist_ok=True)
tomogram = Tomogram.get_by_id(client, id_)
tomogram.download_omezarr(str(zarr_path))
tomogram.download_omezarr(str(zarr_path.parent.resolve()))


def run_all_contrast_limit_calculations(id_, input_data_path, output_path):
Expand Down Expand Up @@ -71,8 +73,9 @@ def run_all_contrast_limit_calculations(id_, input_data_path, output_path):
limits_dict["cdf"] = limits
cdf_calculator.plot_cdf(output_path / "cdf.png")

print(limits_dict)
with open(output_path / "contrast_limits.json", "w") as f:
f.write(limits_dict)
json.dump(limits_dict, f)

human_contrast = id_to_human_contrast_limits[id_]
volume_limit = human_contrast["volume"]
Expand Down

0 comments on commit 5f9e7e0

Please sign in to comment.