Skip to content

Commit

Permalink
feat: update tuning
Browse files: browse the repository at this point in the history
  • Loading branch information
seankmartin committed Oct 16, 2024
1 parent 0003de9 commit 65c3f93
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 32 deletions.
52 changes: 23 additions & 29 deletions cryoet_data_portal_neuroglancer/precompute/contrast_limits.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,7 +272,6 @@ class GMMContrastLimitCalculator(ContrastLimitCalculator):
def compute_contrast_limit(
self,
num_components: int = 3,
covariance_type: str | int = "full",
low_variance_mult: float = 3.0,
high_variance_mult: float = 0.5,
) -> tuple[float, float]:
Expand All @@ -283,10 +282,6 @@ def compute_contrast_limit(
num_components: int, optional.
The number of components to use for the GMM.
By default 3.
covariance_type: str, optional.
The covariance type to use for the GMM.
By default "full".
Options are "full", "tied", "diag", "spherical".
low_variance_mult: float, optional.
The multiplier for the low variance.
By default 3.0.
Expand All @@ -299,22 +294,29 @@ def compute_contrast_limit(
tuple[float, float]
The calculated contrast limits.
"""
if not isinstance(covariance_type, str):
covariance_type = int(covariance_type)
covariance_type = ["full", "diag", "spherical"][covariance_type]

if num_components < 2:
raise ValueError("Number of components must be at least 2.")
self.num_components = num_components
covariance_type = "full"
self.gmm_estimator = GaussianMixture(
n_components=num_components,
covariance_type=covariance_type,
max_iter=20,
random_state=42,
reg_covar=1e-4,
reg_covar=1e-5,
init_params="k-means++",
)

sample_data = self.volume.flatten()
self.gmm_estimator.fit(sample_data.reshape(-1, 1))
try:
self.gmm_estimator.fit(sample_data.reshape(-1, 1))
except ValueError:
# GMM fit can fail if the data is not well distributed - try with fewer components
return self.compute_contrast_limit(
num_components - 1,
low_variance_mult,
high_variance_mult,
)

# Get the stats for the gaussian which sits in the middle
means = self.gmm_estimator.means_.flatten()
Expand All @@ -325,12 +327,7 @@ def compute_contrast_limit(
# (n_features, n_features) if 'tied',
# (n_components, n_features) if 'diag',
# (n_components, n_features, n_features) if 'full'
if covariance_type == "spherical":
variances = covariances
elif covariance_type == "diag":
variances = covariances[:, 0].flatten()
elif covariance_type == "full":
variances = covariances[:, 0, 0].flatten()
variances = covariances.flatten()

# Pick the GMM component which is closest to the mean of the volume
volume_mean = np.mean(sample_data)
Expand All @@ -343,16 +340,14 @@ def compute_contrast_limit(
def _objective_function(self, params):
return self.compute_contrast_limit(
params["num_components"],
params["covariance_type"],
params["low_variance_mult"],
params["high_variance_mult"],
)

def _define_parameter_space(self, parameter_optimizer):
parameter_optimizer.space_creator(
{
"num_components": {"type": "randint", "args": [2, 3]},
"covariance_type": {"type": "choice", "args": [["full", "diag", "spherical"]]},
"num_components": {"type": "randint", "args": [2, 4]},
"low_variance_mult": {"type": "uniform", "args": [1.0, 5.0]},
"high_variance_mult": {"type": "uniform", "args": [0.1, 1.0]},
},
Expand Down Expand Up @@ -506,7 +501,7 @@ def __init__(self, volume: Optional["np.ndarray"] = None):
def compute_contrast_limit(
self,
downsample_factor: int = 5,
sample_factor: float = 0.05,
sample_factor: float = 0.10,
threshold_factor: float = 0.01,
) -> tuple[float, float]:
"""Calculate the contrast limits using decimation.
Expand Down Expand Up @@ -536,16 +531,15 @@ def compute_contrast_limit(
# Compute threshold and lower_change threshold
sample_size = int(sample_factor * len(diff_decimated))

initial_flat = np.mean(cdf[:sample_size]) # Average of first 50 points (assumed flat region)
final_flat = np.mean(cdf[-sample_size:]) # Average of last 50 points (assumed flat region)
initial_flat = np.mean(cdf[:sample_size]) # Average of first points (assumed flat region)
final_flat = np.mean(cdf[-sample_size:]) # Average of last points (assumed flat region)
midpoint = (initial_flat + final_flat) / 2
lower_curve_threshold = threshold_factor * midpoint
upper_change_threshold = lower_curve_threshold
curve_threshold = threshold_factor * midpoint

# Detect start and end of slope
start_idx_decimated = np.argmax(diff_decimated > lower_curve_threshold) # First large change
start_idx_decimated = np.argmax(diff_decimated > curve_threshold) # First large change
end_idx_decimated = (
np.argmax(diff_decimated[start_idx_decimated + 1 :] < upper_change_threshold) + start_idx_decimated
np.argmax(diff_decimated[start_idx_decimated + 1 :] < curve_threshold) + start_idx_decimated
) # first small change

# Map back the indices to original values
Expand All @@ -568,9 +562,9 @@ def _objective_function(self, params):
def _define_parameter_space(self, parameter_optimizer):
parameter_optimizer.space_creator(
{
"downsample_factor": {"type": "randint", "args": [2, 8]},
"downsample_factor": {"type": "randint", "args": [3, 7]},
"sample_factor": {"type": "uniform", "args": [0.01, 0.1]},
"threshold_factor": {"type": "uniform", "args": [0.001, 0.1]},
"threshold_factor": {"type": "uniform", "args": [0.005, 0.2]},
},
)

Expand Down
7 changes: 4 additions & 3 deletions manual_tests/contrast_limits_from_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,9 +139,9 @@ def run_all_contrast_limit_calculations(
}
hyperopt_evals_dict = {
"percentile": 5,
"gmm": 5,
"cdf": 5,
"decimation": 5,
"gmm": 500,
"cdf": 500,
"decimation": 500,
}
for key, calc in calculator_dict.items():
max_hyperopt_evals = hyperopt_evals_dict[key]
Expand All @@ -164,6 +164,7 @@ def run_all_contrast_limit_calculations(

with open(output_path / f"contrast_limits_{id_}.json", "w") as f:
combined_dict = {k: {"limits": v, "info": info_dict[k]} for k, v in limits_dict.items()}
combined_dict["real_limits"] = volume_limit
json.dump(combined_dict, f, cls=NpEncoder, indent=4)

# Check which method is closest to the human contrast limits
Expand Down

0 comments on commit 65c3f93

Please sign in to comment.