diff --git a/cryoet_data_portal_neuroglancer/precompute/contrast_limits.py b/cryoet_data_portal_neuroglancer/precompute/contrast_limits.py index 0128ebe..ed50d88 100644 --- a/cryoet_data_portal_neuroglancer/precompute/contrast_limits.py +++ b/cryoet_data_portal_neuroglancer/precompute/contrast_limits.py @@ -272,7 +272,6 @@ class GMMContrastLimitCalculator(ContrastLimitCalculator): def compute_contrast_limit( self, num_components: int = 3, - covariance_type: str | int = "full", low_variance_mult: float = 3.0, high_variance_mult: float = 0.5, ) -> tuple[float, float]: @@ -283,10 +282,6 @@ def compute_contrast_limit( num_components: int, optional. The number of components to use for the GMM. By default 3. - covariance_type: str, optional. - The covariance type to use for the GMM. - By default "full". - Options are "full", "tied", "diag", "spherical". low_variance_mult: float, optional. The multiplier for the low variance. By default 3.0. @@ -299,22 +294,29 @@ def compute_contrast_limit( tuple[float, float] The calculated contrast limits. """ - if not isinstance(covariance_type, str): - covariance_type = int(covariance_type) - covariance_type = ["full", "diag", "spherical"][covariance_type] - + if num_components < 2: + raise ValueError("Number of components must be at least 2.") self.num_components = num_components + covariance_type = "full" self.gmm_estimator = GaussianMixture( n_components=num_components, covariance_type=covariance_type, max_iter=20, random_state=42, - reg_covar=1e-4, + reg_covar=1e-5, init_params="k-means++", ) sample_data = self.volume.flatten() - self.gmm_estimator.fit(sample_data.reshape(-1, 1)) + try: + self.gmm_estimator.fit(sample_data.reshape(-1, 1)) + except ValueError: + # GMM fit can fail if the data is not well distributed - try with less components + return self.compute_contrast_limit( + num_components - 1, + low_variance_mult, + high_variance_mult, + ) # Get the stats for the gaussian which sits in the middle means = self.gmm_estimator.means_.flatten() @@ -325,12 +327,7 @@ def compute_contrast_limit( # (n_features, n_features) if 'tied', # (n_components, n_features) if 'diag', # (n_components, n_features, n_features) if 'full' - if covariance_type == "spherical": - variances = covariances - elif covariance_type == "diag": - variances = covariances[:, 0].flatten() - elif covariance_type == "full": - variances = covariances[:, 0, 0].flatten() + variances = covariances.flatten() # Pick the GMM component which is closest to the mean of the volume volume_mean = np.mean(sample_data) @@ -343,7 +340,6 @@ def compute_contrast_limit( def _objective_function(self, params): return self.compute_contrast_limit( params["num_components"], - params["covariance_type"], params["low_variance_mult"], params["high_variance_mult"], ) @@ -351,8 +347,7 @@ def _objective_function(self, params): def _define_parameter_space(self, parameter_optimizer): parameter_optimizer.space_creator( { - "num_components": {"type": "randint", "args": [2, 3]}, - "covariance_type": {"type": "choice", "args": [["full", "diag", "spherical"]]}, + "num_components": {"type": "randint", "args": [2, 4]}, "low_variance_mult": {"type": "uniform", "args": [1.0, 5.0]}, "high_variance_mult": {"type": "uniform", "args": [0.1, 1.0]}, }, @@ -506,7 +501,7 @@ def __init__(self, volume: Optional["np.ndarray"] = None): def compute_contrast_limit( self, downsample_factor: int = 5, - sample_factor: float = 0.05, + sample_factor: float = 0.10, threshold_factor: float = 0.01, ) -> tuple[float, float]: """Calculate the contrast limits using decimation. @@ -536,16 +531,15 @@ def compute_contrast_limit( # Compute threshold and lower_change threshold sample_size = int(sample_factor * len(diff_decimated)) - initial_flat = np.mean(cdf[:sample_size]) # Average of first 50 points (assumed flat region) - final_flat = np.mean(cdf[-sample_size:]) # Average of last 50 points (assumed flat region) + initial_flat = np.mean(cdf[:sample_size]) # Average of first points (assumed flat region) + final_flat = np.mean(cdf[-sample_size:]) # Average of last points (assumed flat region) midpoint = (initial_flat + final_flat) / 2 - lower_curve_threshold = threshold_factor * midpoint - upper_change_threshold = lower_curve_threshold + curve_threshold = threshold_factor * midpoint # Detect start and end of slope - start_idx_decimated = np.argmax(diff_decimated > lower_curve_threshold) # First large change + start_idx_decimated = np.argmax(diff_decimated > curve_threshold) # First large change end_idx_decimated = ( - np.argmax(diff_decimated[start_idx_decimated + 1 :] < upper_change_threshold) + start_idx_decimated + np.argmax(diff_decimated[start_idx_decimated + 1 :] < curve_threshold) + start_idx_decimated ) # first small change # Map back the indices to original values @@ -568,9 +562,9 @@ def _objective_function(self, params): def _define_parameter_space(self, parameter_optimizer): parameter_optimizer.space_creator( { - "downsample_factor": {"type": "randint", "args": [2, 8]}, + "downsample_factor": {"type": "randint", "args": [3, 7]}, "sample_factor": {"type": "uniform", "args": [0.01, 0.1]}, - "threshold_factor": {"type": "uniform", "args": [0.001, 0.1]}, + "threshold_factor": {"type": "uniform", "args": [0.005, 0.2]}, }, ) diff --git a/manual_tests/contrast_limits_from_api.py b/manual_tests/contrast_limits_from_api.py index 653eb06..6741c8d 100644 --- a/manual_tests/contrast_limits_from_api.py +++ b/manual_tests/contrast_limits_from_api.py @@ -139,9 +139,9 @@ def run_all_contrast_limit_calculations( } hyperopt_evals_dict = { "percentile": 5, - "gmm": 5, - "cdf": 5, - "decimation": 5, + "gmm": 500, + "cdf": 500, + "decimation": 500, } for key, calc in calculator_dict.items(): max_hyperopt_evals = hyperopt_evals_dict[key] @@ -164,6 +164,7 @@ def run_all_contrast_limit_calculations( with open(output_path / f"contrast_limits_{id_}.json", "w") as f: combined_dict = {k: {"limits": v, "info": info_dict[k]} for k, v in limits_dict.items()} + combined_dict["real_limits"] = volume_limit json.dump(combined_dict, f, cls=NpEncoder, indent=4) # Check which method is closest to the human contrast limits