From 60d13806223dcf3129dc41d12a52b5f899b1a4a2 Mon Sep 17 00:00:00 2001 From: loren-ac Date: Thu, 11 Dec 2025 00:31:02 -0800 Subject: [PATCH 01/48] Refactor regression code to incorporate optional computation of pairwise subspace orthogonality metrics --- simplexity/analysis/linear_regression.py | 275 +++++++++++++++++++++-- 1 file changed, 261 insertions(+), 14 deletions(-) diff --git a/simplexity/analysis/linear_regression.py b/simplexity/analysis/linear_regression.py index 1ce5c086..18e81a07 100644 --- a/simplexity/analysis/linear_regression.py +++ b/simplexity/analysis/linear_regression.py @@ -10,6 +10,7 @@ import numpy as np from simplexity.analysis.normalization import normalize_weights, standardize_features, standardize_targets +from simplexity.logger import SIMPLEXITY_LOGGER def _design_matrix(x: jax.Array, fit_intercept: bool) -> jax.Array: @@ -69,8 +70,33 @@ def linear_regression( beta, _, _, _ = jnp.linalg.lstsq(weighted_design, weighted_targets, rcond=None) predictions = design @ beta scalars = _regression_metrics(predictions, y_arr, w_arr) - projections = {"projected": predictions} - return scalars, projections + arrays = {"projected": predictions, "parameters": beta} + return scalars, arrays + + +def _compute_regression_metrics( + x: jax.Array, + y: jax.Array, + weights: jax.Array | np.ndarray | None, + beta: jax.Array, + predictions: jax.Array | None = None, + *, + fit_intercept: bool = True, +): + x_arr = standardize_features(x) + y_arr = standardize_targets(y) + if x_arr.shape[0] != y_arr.shape[0]: + raise ValueError("Features and targets must share the same first dimension") + if x_arr.shape[0] == 0: + raise ValueError("At least one sample is required") + w_arr = normalize_weights(weights, x_arr.shape[0]) + if w_arr is None: + w_arr = jnp.ones(x_arr.shape[0], dtype=x_arr.dtype) / x_arr.shape[0] + if predictions is None: + design = _design_matrix(x_arr, fit_intercept) + predictions = design @ beta + scalars = _regression_metrics(predictions, y_arr, w_arr) + return scalars def _compute_beta_from_svd( @@ -115,6 +141,7 @@ def linear_regression_svd( best_scalars: Mapping[str, float] | None = None best_rcond = rconds[0] best_error = float("inf") + best_beta: jax.Array | None = None for rcond in rconds: threshold = rcond * max_singular beta = _compute_beta_from_svd(u, s, vh, weighted_targets, threshold) @@ -128,12 +155,193 @@ def linear_regression_svd( best_pred = predictions best_scalars = scalars best_rcond = rcond - if best_pred is None or best_scalars is None: + best_beta = beta + if best_pred is None or best_scalars is None or best_beta is None: raise RuntimeError("Unable to compute linear regression solution") scalars = dict(best_scalars) scalars["best_rcond"] = float(best_rcond) - projections = {"projected": best_pred} - return scalars, projections + arrays = {"projected": best_pred, "parameters": best_beta} + return scalars, arrays + + +def _process_individual_factors( + layer_activations: jax.Array, + belief_states: tuple[jax.Array, ...], + weights: jax.Array, + use_svd: bool, + **kwargs: Any, +) -> list[tuple[Mapping[str, float], Mapping[str, jax.Array]]]: + """Process each factor individually using either standard or SVD regression.""" + results = [] + regression_fn = linear_regression_svd if use_svd else linear_regression + for factor in belief_states: + if not isinstance(factor, jax.Array): + raise ValueError("Each factor in belief_states must be a jax.Array") + factor_scalars, factor_arrays = regression_fn(layer_activations, factor, weights, **kwargs) + 
results.append((factor_scalars, factor_arrays)) + return results + + +def _merge_results_with_prefix( + scalars: dict[str, float], + arrays: dict[str, jax.Array], + results: tuple[Mapping[str, float], Mapping[str, jax.Array]], + prefix: str, +) -> None: + results_scalars, results_arrays = results + scalars.update({f"{prefix}/{key}": value for key, value in results_scalars.items()}) + arrays.update({f"{prefix}/{key}": value for key, value in results_arrays.items()}) + + +def _split_concat_results( + layer_activations: jax.Array, + weights: jax.Array, + belief_states: tuple[jax.Array, ...], + concat_results: tuple[Mapping[str, float], Mapping[str, jax.Array]], + **kwargs: Any, +) -> list[tuple[Mapping[str, float], Mapping[str, jax.Array]]]: + """Split concatenated regression results into individual factors.""" + _, concat_arrays = concat_results + + # Split the concatenated parameters and projections into the individual factors + factor_dims = [factor.shape[-1] for factor in belief_states] + split_indices = jnp.cumsum(jnp.array(factor_dims))[:-1] + + parameters_list = jnp.split(concat_arrays["parameters"], split_indices, axis=-1) + projections_list = jnp.split(concat_arrays["projected"], split_indices, axis=-1) + + # Only recompute scalar metrics, reuse projections and parameters + # Filter out rcond_values from kwargs (only relevant for SVD during fitting, not metrics) + metrics_kwargs = {k: v for k, v in kwargs.items() if k != 'rcond_values'} + results = [] + for factor, parameters, projections in zip(belief_states, parameters_list, projections_list): + factor_scalars = _compute_regression_metrics( + layer_activations, + factor, + weights, + parameters, + predictions=projections, + **metrics_kwargs, + ) + factor_arrays = {"projected": projections, "parameters": parameters} + results.append((factor_scalars, factor_arrays)) + return results + + +def _compute_subspace_orthogonality( + parameters_pair: list[jax.Array], +) -> tuple[dict[str, float], dict[str, jax.Array]]: + # Compute the orthonormal bases for the two subspaces using QR decomposition + q1, _ = jnp.linalg.qr(parameters_pair[0]) + q2, _ = jnp.linalg.qr(parameters_pair[1]) + # Compute the singular values of the interaction matrix + interaction_matrix = q1.T @ q2 + singular_values = jnp.linalg.svd(interaction_matrix, compute_uv=False) + + # Clip the singular values to the range [0, 1] + singular_values = jnp.clip(singular_values, 0, 1) + + # Compute the subspace overlap score + min_dim = min(q1.shape[1], q2.shape[1]) + subspace_overlap_score = jnp.sum(singular_values**2) / min_dim + + # Compute the max singular value + max_singular_value = jnp.max(singular_values) + + # Compute the min singular value + min_singular_value = jnp.min(singular_values) + + # Compute the participation ratio + participation_ratio = jnp.sum(singular_values**2)**2 / jnp.sum(singular_values**4) + + # Compute the entropy + probs = singular_values**2 / jnp.sum(singular_values**2) + entropy = -jnp.sum(probs * jnp.log(probs)) + + # Compute the effective rank + effective_rank = jnp.exp(entropy) + + scalars = { + "subspace_overlap": float(subspace_overlap_score), + "max_singular_value": float(max_singular_value), + "min_singular_value": float(min_singular_value), + "participation_ratio": float(participation_ratio), + "entropy": float(entropy), + "effective_rank": float(effective_rank), + } + + singular_values = { + "singular_values": singular_values, + } + + return scalars, singular_values + + +def _compute_all_pairwise_orthogonality( + parameters_list: 
list[jax.Array], +) -> tuple[dict[str, float], dict[str, jax.Array]]: + scalars = {} + singular_values = {} + factor_pairs = list(itertools.combinations(range(len(parameters_list)), 2)) + for i, j in factor_pairs: + params_pair = [ + parameters_list[i], + parameters_list[j], + ] + orthogonality_scalars, orthogonality_singular_values = _compute_subspace_orthogonality(params_pair) + scalars.update({ + f"orthogonality_{i}_{j}/{key}": value for key, value in orthogonality_scalars.items() + }) + singular_values.update({ + f"orthogonality_{i}_{j}/{key}": value for key, value in orthogonality_singular_values.items() + }) + return scalars, singular_values + + +def _handle_factored_regression( + layer_activations: jax.Array, + weights: jax.Array, + belief_states: tuple[jax.Array, ...], + concat_belief_states: bool, + compute_subspace_orthogonality: bool, + use_svd: bool, + **kwargs: Any, +) -> tuple[Mapping[str, float], Mapping[str, jax.Array]]: + """Handle regression for factored belief states using either standard or SVD method.""" + scalars: dict[str, float] = {} + arrays: dict[str, jax.Array] = {} + + regression_fn = linear_regression_svd if use_svd else linear_regression + + # Process concatenated belief states if requested + if concat_belief_states: + belief_states_concat = jnp.concatenate(belief_states, axis=-1) + concat_results = regression_fn(layer_activations, belief_states_concat, weights, **kwargs) + _merge_results_with_prefix(scalars, arrays, concat_results, "concat") + + # Split the concatenated parameters and projections into the individual factors + factor_results = _split_concat_results( + layer_activations, + weights, + belief_states, + concat_results, + **kwargs, + ) + else: + factor_results = _process_individual_factors( + layer_activations, belief_states, weights, use_svd, **kwargs + ) + + for factor_idx, factor_result in enumerate(factor_results): + _merge_results_with_prefix(scalars, arrays, factor_result, f"factor_{factor_idx}") + + if compute_subspace_orthogonality: + parameters_list = [factor_arrays["parameters"] for _, factor_arrays in factor_results] + orthogonality_scalars, orthogonality_singular_values = _compute_all_pairwise_orthogonality(parameters_list) + scalars.update(orthogonality_scalars) + arrays.update(orthogonality_singular_values) + + return scalars, arrays def _apply_layer_regression( @@ -168,25 +376,64 @@ def layer_linear_regression( layer_activations: jax.Array, weights: jax.Array, belief_states: jax.Array | tuple[jax.Array, ...] | None, - to_factors: bool = False, + concat_belief_states: bool = False, + compute_subspace_orthogonality: bool = False, + use_svd: bool = False, **kwargs: Any, ) -> tuple[Mapping[str, float], Mapping[str, jax.Array]]: - """Layer-wise regression helper that wraps :func:`linear_regression`.""" + """ + Layer-wise regression helper that wraps :func:`linear_regression` or :func:`linear_regression_svd`. + + Args: + layer_activations: Neural network activations for a single layer + weights: Sample weights for weighted regression + belief_states: Target belief states (single array or tuple for factored processes) + concat_belief_states: If True and belief_states is a tuple, concatenate and regress jointly + compute_subspace_orthogonality: If True, compute orthogonality between factor subspaces + use_svd: If True, use SVD-based regression instead of standard least squares + **kwargs: Additional arguments passed to regression function (fit_intercept, rcond_values, etc.) 
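+            (rcond_values is only consumed by the SVD solver, i.e. when use_svd=True)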
+ + Returns: + scalars: Dictionary of scalar metrics + arrays: Dictionary of arrays (projected predictions, parameters, singular values if orthogonality computed) + """ if belief_states is None: raise ValueError("linear_regression requires belief_states") - return _apply_layer_regression(linear_regression, layer_activations, weights, belief_states, to_factors, **kwargs) + + regression_fn = linear_regression_svd if use_svd else linear_regression + + if not isinstance(belief_states, tuple): + if compute_subspace_orthogonality: + SIMPLEXITY_LOGGER.warning("Subspace orthogonality cannot be computed for a single belief state") + scalars, arrays = regression_fn(layer_activations, belief_states, weights, **kwargs) + return scalars, arrays + + return _handle_factored_regression( + layer_activations, weights, belief_states, + concat_belief_states, compute_subspace_orthogonality, use_svd, **kwargs + ) def layer_linear_regression_svd( layer_activations: jax.Array, weights: jax.Array, belief_states: jax.Array | tuple[jax.Array, ...] | None, - to_factors: bool = False, + concat_belief_states: bool = False, + compute_subspace_orthogonality: bool = False, **kwargs: Any, ) -> tuple[Mapping[str, float], Mapping[str, jax.Array]]: - """Layer-wise regression helper that wraps :func:`linear_regression_svd`.""" - if belief_states is None: - raise ValueError("linear_regression_svd requires belief_states") - return _apply_layer_regression( - linear_regression_svd, layer_activations, weights, belief_states, to_factors, **kwargs + """ + Layer-wise SVD regression helper (wrapper around layer_linear_regression with use_svd=True). + + This function is provided for backward compatibility and convenience. + Consider using layer_linear_regression with use_svd=True for new code. + """ + return layer_linear_regression( + layer_activations=layer_activations, + weights=weights, + belief_states=belief_states, + concat_belief_states=concat_belief_states, + compute_subspace_orthogonality=compute_subspace_orthogonality, + use_svd=True, + **kwargs, ) From 09c1d89a24afcb9fec6ccc8552b81457f19b37ec Mon Sep 17 00:00:00 2001 From: loren-ac Date: Thu, 11 Dec 2025 18:46:22 -0800 Subject: [PATCH 02/48] Refine regression API and add comprehensive orthogonality tests MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Separate coeffs/intercept in return structure (omit intercept key when fit_intercept=False) - Rename to_factors → concat_belief_states for clarity - Add 9 orthogonality tests with principled numerical thresholds (safety_factor=10) - Test orthogonal, aligned, contained subspaces; multi-factor scenarios; edge cases - Update validators and existing tests for new parameter structure - Add informative assertion messages for debugging numerical precision --- simplexity/analysis/layerwise_analysis.py | 21 +- simplexity/analysis/linear_regression.py | 96 ++- tests/analysis/test_layerwise_analysis.py | 69 +- tests/analysis/test_linear_regression.py | 913 +++++++++++++++++++--- 4 files changed, 955 insertions(+), 144 deletions(-) diff --git a/simplexity/analysis/layerwise_analysis.py b/simplexity/analysis/layerwise_analysis.py index e76e1c6d..e22c6f4c 100644 --- a/simplexity/analysis/layerwise_analysis.py +++ b/simplexity/analysis/layerwise_analysis.py @@ -34,23 +34,31 @@ class AnalysisRegistration: def _validate_linear_regression_kwargs(kwargs: Mapping[str, Any] | None) -> dict[str, Any]: provided = dict(kwargs or {}) - allowed = {"fit_intercept", "to_factors"} + allowed = {"fit_intercept", 
"concat_belief_states", "compute_subspace_orthogonality", "use_svd"} unexpected = set(provided) - allowed if unexpected: raise ValueError(f"Unexpected linear_regression kwargs: {sorted(unexpected)}") fit_intercept = bool(provided.get("fit_intercept", True)) - to_factors = bool(provided.get("to_factors", False)) - return {"fit_intercept": fit_intercept, "to_factors": to_factors} + concat_belief_states = bool(provided.get("concat_belief_states", False)) + compute_subspace_orthogonality = bool(provided.get("compute_subspace_orthogonality", False)) + use_svd = bool(provided.get("use_svd", False)) + return { + "fit_intercept": fit_intercept, + "concat_belief_states": concat_belief_states, + "compute_subspace_orthogonality": compute_subspace_orthogonality, + "use_svd": use_svd, + } def _validate_linear_regression_svd_kwargs(kwargs: Mapping[str, Any] | None) -> dict[str, Any]: provided = dict(kwargs or {}) - allowed = {"fit_intercept", "rcond_values", "to_factors"} + allowed = {"fit_intercept", "rcond_values", "concat_belief_states", "compute_subspace_orthogonality"} unexpected = set(provided) - allowed if unexpected: raise ValueError(f"Unexpected linear_regression_svd kwargs: {sorted(unexpected)}") fit_intercept = bool(provided.get("fit_intercept", True)) - to_factors = bool(provided.get("to_factors", False)) + concat_belief_states = bool(provided.get("concat_belief_states", False)) + compute_subspace_orthogonality = bool(provided.get("compute_subspace_orthogonality", False)) rcond_values = provided.get("rcond_values") if rcond_values is not None: if not isinstance(rcond_values, (list, tuple)): @@ -60,7 +68,8 @@ def _validate_linear_regression_svd_kwargs(kwargs: Mapping[str, Any] | None) -> rcond_values = tuple(float(v) for v in rcond_values) return { "fit_intercept": fit_intercept, - "to_factors": to_factors, + "concat_belief_states": concat_belief_states, + "compute_subspace_orthogonality": compute_subspace_orthogonality, "rcond_values": rcond_values, } diff --git a/simplexity/analysis/linear_regression.py b/simplexity/analysis/linear_regression.py index 18e81a07..8b1577cf 100644 --- a/simplexity/analysis/linear_regression.py +++ b/simplexity/analysis/linear_regression.py @@ -70,7 +70,20 @@ def linear_regression( beta, _, _, _ = jnp.linalg.lstsq(weighted_design, weighted_targets, rcond=None) predictions = design @ beta scalars = _regression_metrics(predictions, y_arr, w_arr) - arrays = {"projected": predictions, "parameters": beta} + + # Separate intercept and coefficients + if fit_intercept: + arrays = { + "projected": predictions, + "coeffs": beta[1:], # Linear coefficients (excluding intercept) + "intercept": beta[0:1], # Intercept term (keep 2D: [1, n_targets]) + } + else: + arrays = { + "projected": predictions, + "coeffs": beta, # All parameters are coefficients when no intercept + } + return scalars, arrays @@ -160,7 +173,20 @@ def linear_regression_svd( raise RuntimeError("Unable to compute linear regression solution") scalars = dict(best_scalars) scalars["best_rcond"] = float(best_rcond) - arrays = {"projected": best_pred, "parameters": best_beta} + + # Separate intercept and coefficients + if fit_intercept: + arrays = { + "projected": best_pred, + "coeffs": best_beta[1:], # Linear coefficients (excluding intercept) + "intercept": best_beta[0:1], # Intercept term (keep 2D: [1, n_targets]) + } + else: + arrays = { + "projected": best_pred, + "coeffs": best_beta, # All parameters are coefficients when no intercept + } + return scalars, arrays @@ -203,37 +229,62 @@ def 
_split_concat_results( """Split concatenated regression results into individual factors.""" _, concat_arrays = concat_results - # Split the concatenated parameters and projections into the individual factors + # Split the concatenated coefficients and projections into the individual factors factor_dims = [factor.shape[-1] for factor in belief_states] split_indices = jnp.cumsum(jnp.array(factor_dims))[:-1] - parameters_list = jnp.split(concat_arrays["parameters"], split_indices, axis=-1) + coeffs_list = jnp.split(concat_arrays["coeffs"], split_indices, axis=-1) projections_list = jnp.split(concat_arrays["projected"], split_indices, axis=-1) - # Only recompute scalar metrics, reuse projections and parameters + # Handle intercept - split if present + if "intercept" in concat_arrays: + intercepts_list = jnp.split(concat_arrays["intercept"], split_indices, axis=-1) + else: + intercepts_list = [None] * len(belief_states) + + # Only recompute scalar metrics, reuse projections and coefficients # Filter out rcond_values from kwargs (only relevant for SVD during fitting, not metrics) metrics_kwargs = {k: v for k, v in kwargs.items() if k != 'rcond_values'} + fit_intercept = kwargs.get("fit_intercept", True) + results = [] - for factor, parameters, projections in zip(belief_states, parameters_list, projections_list): + for factor, coeffs, intercept, projections in zip(belief_states, coeffs_list, intercepts_list, projections_list): + # Reconstruct full beta for metrics computation + if intercept is not None: + beta = jnp.concatenate([intercept, coeffs], axis=0) + else: + beta = coeffs + factor_scalars = _compute_regression_metrics( layer_activations, factor, weights, - parameters, + beta, predictions=projections, **metrics_kwargs, ) - factor_arrays = {"projected": projections, "parameters": parameters} + + # Build factor arrays - include intercept only if present + factor_arrays = {"projected": projections, "coeffs": coeffs} + if intercept is not None: + factor_arrays["intercept"] = intercept + results.append((factor_scalars, factor_arrays)) return results def _compute_subspace_orthogonality( - parameters_pair: list[jax.Array], + coeffs_pair: list[jax.Array], ) -> tuple[dict[str, float], dict[str, jax.Array]]: + """ + Compute orthogonality metrics between two coefficient subspaces. + + Args: + coeffs_pair: List of two coefficient matrices (excludes intercept) + """ # Compute the orthonormal bases for the two subspaces using QR decomposition - q1, _ = jnp.linalg.qr(parameters_pair[0]) - q2, _ = jnp.linalg.qr(parameters_pair[1]) + q1, _ = jnp.linalg.qr(coeffs_pair[0]) + q2, _ = jnp.linalg.qr(coeffs_pair[1]) # Compute the singular values of the interaction matrix interaction_matrix = q1.T @ q2 singular_values = jnp.linalg.svd(interaction_matrix, compute_uv=False) @@ -278,17 +329,23 @@ def _compute_subspace_orthogonality( def _compute_all_pairwise_orthogonality( - parameters_list: list[jax.Array], + coeffs_list: list[jax.Array], ) -> tuple[dict[str, float], dict[str, jax.Array]]: + """ + Compute pairwise orthogonality metrics for all factor pairs. 
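+    Each pair is compared via the singular values of q1.T @ q2 for orthonormal
+    bases q1, q2, i.e. the cosines of the principal angles between the subspaces.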
+ + Args: + coeffs_list: List of coefficient matrices (one per factor, excludes intercepts) + """ scalars = {} singular_values = {} - factor_pairs = list(itertools.combinations(range(len(parameters_list)), 2)) + factor_pairs = list(itertools.combinations(range(len(coeffs_list)), 2)) for i, j in factor_pairs: - params_pair = [ - parameters_list[i], - parameters_list[j], + coeffs_pair = [ + coeffs_list[i], + coeffs_list[j], ] - orthogonality_scalars, orthogonality_singular_values = _compute_subspace_orthogonality(params_pair) + orthogonality_scalars, orthogonality_singular_values = _compute_subspace_orthogonality(coeffs_pair) scalars.update({ f"orthogonality_{i}_{j}/{key}": value for key, value in orthogonality_scalars.items() }) @@ -336,8 +393,9 @@ def _handle_factored_regression( _merge_results_with_prefix(scalars, arrays, factor_result, f"factor_{factor_idx}") if compute_subspace_orthogonality: - parameters_list = [factor_arrays["parameters"] for _, factor_arrays in factor_results] - orthogonality_scalars, orthogonality_singular_values = _compute_all_pairwise_orthogonality(parameters_list) + # Extract coefficients (excludes intercept) for orthogonality computation + coeffs_list = [factor_arrays["coeffs"] for _, factor_arrays in factor_results] + orthogonality_scalars, orthogonality_singular_values = _compute_all_pairwise_orthogonality(coeffs_list) scalars.update(orthogonality_scalars) arrays.update(orthogonality_singular_values) diff --git a/tests/analysis/test_layerwise_analysis.py b/tests/analysis/test_layerwise_analysis.py index c0d1a839..2df502d2 100644 --- a/tests/analysis/test_layerwise_analysis.py +++ b/tests/analysis/test_layerwise_analysis.py @@ -30,7 +30,7 @@ def test_layerwise_analysis_linear_regression_namespacing(analysis_inputs) -> No ) assert set(scalars) >= {"layer_a_r2", "layer_b_r2"} - assert set(projections) == {"layer_a_projected", "layer_b_projected"} + assert set(projections) == {"layer_a_projected", "layer_b_projected", "layer_a_coeffs", "layer_b_coeffs", "layer_a_intercept", "layer_b_intercept"} def test_layerwise_analysis_requires_targets(analysis_inputs) -> None: @@ -161,3 +161,70 @@ def test_layerwise_analysis_property_accessors() -> None: assert analysis.concat_layers assert not analysis.use_probs_as_weights assert not analysis.requires_belief_states + + +def test_linear_regression_accepts_concat_belief_states() -> None: + """linear_regression validator should accept concat_belief_states parameter.""" + validator = ANALYSIS_REGISTRY["linear_regression"].validator + params = validator({"fit_intercept": False, "concat_belief_states": True}) + + assert params["fit_intercept"] is False + assert params["concat_belief_states"] is True + + +def test_linear_regression_svd_accepts_concat_belief_states() -> None: + """linear_regression_svd validator should accept concat_belief_states parameter.""" + validator = ANALYSIS_REGISTRY["linear_regression_svd"].validator + params = validator({"fit_intercept": True, "concat_belief_states": True, "rcond_values": [1e-3]}) + + assert params["fit_intercept"] is True + assert params["concat_belief_states"] is True + assert params["rcond_values"] == (0.001,) + + +def test_linear_regression_concat_belief_states_defaults_false() -> None: + """concat_belief_states should default to False when not provided.""" + validator = ANALYSIS_REGISTRY["linear_regression"].validator + params = validator({"fit_intercept": True}) + + assert params["concat_belief_states"] is False + +def test_linear_regression_accepts_compute_subspace_orthogonality() -> 
None: + """linear_regression validator should accept compute_subspace_orthogonality parameter.""" + validator = ANALYSIS_REGISTRY["linear_regression"].validator + params = validator({"fit_intercept": True, "compute_subspace_orthogonality": True}) + + assert params["fit_intercept"] is True + assert params["compute_subspace_orthogonality"] is True + +def test_linear_regression_svd_accepts_compute_subspace_orthogonality() -> None: + """linear_regression_svd validator should accept compute_subspace_orthogonality parameter.""" + validator = ANALYSIS_REGISTRY["linear_regression_svd"].validator + params = validator({"fit_intercept": True, "compute_subspace_orthogonality": True, "rcond_values": [1e-3]}) + + assert params["fit_intercept"] is True + assert params["compute_subspace_orthogonality"] is True + assert params["rcond_values"] == (0.001,) + +def test_linear_regression_compute_subspace_orthogonality_defaults_false() -> None: + """compute_subspace_orthogonality should default to False when not provided.""" + validator = ANALYSIS_REGISTRY["linear_regression"].validator + params = validator({"fit_intercept": True}) + + assert params["compute_subspace_orthogonality"] is False + + +def test_linear_regression_accepts_use_svd() -> None: + """linear_regression validator should accept use_svd parameter.""" + validator = ANALYSIS_REGISTRY["linear_regression"].validator + params = validator({"use_svd": True}) + + assert params["use_svd"] is True + + +def test_linear_regression_use_svd_defaults_false() -> None: + """use_svd should default to False when not provided.""" + validator = ANALYSIS_REGISTRY["linear_regression"].validator + params = validator({}) + + assert params["use_svd"] is False diff --git a/tests/analysis/test_linear_regression.py b/tests/analysis/test_linear_regression.py index f6bfe084..63a6fd22 100644 --- a/tests/analysis/test_linear_regression.py +++ b/tests/analysis/test_linear_regression.py @@ -1,6 +1,7 @@ """Tests for reusable linear regression helpers.""" import chex +import jax import jax.numpy as jnp import pytest @@ -12,18 +13,44 @@ ) +def _compute_orthogonality_threshold( + x: jax.Array, + *factors: jax.Array, + safety_factor: int = 10, +) -> float: + """Compute principled threshold for near-zero orthogonality checks. + + Threshold is based on machine precision scaled by problem dimensions. + For orthogonality via QR + SVD, typical numerical error is O(ε·n) where + ε is machine epsilon and n is the maximum relevant dimension. 
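+    The safety factor absorbs the constants hidden by this O(ε·n) estimate.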
+
+    Args:
+        x: Input features array (used for dtype and dimension)
+        *factors: Factor arrays being compared (used for output dimensions)
+        safety_factor: Multiplicative safety factor (default 10)
+
+    Returns:
+        Threshold value for considering singular values as effectively zero
+    """
+    eps = jnp.finfo(x.dtype).eps
+    n_features = x.shape[1]
+    factor_dims = [f.shape[1] for f in factors]
+    max_dim = max(n_features, *factor_dims)
+    return float(max_dim * eps * safety_factor)
+
+
 def test_linear_regression_perfect_fit() -> None:
     """Verify weighted least squares recovers a perfect linear relation."""
     x = jnp.arange(6.0).reshape(-1, 1)
     y = 3.0 * x + 2.0
     weights = jnp.ones(x.shape[0])

-    scalars, projections = linear_regression(x, y, weights)
+    scalars, arrays = linear_regression(x, y, weights)

     assert pytest.approx(1.0) == scalars["r2"]
     assert pytest.approx(0.0, abs=1e-5) == scalars["rmse"]
     assert pytest.approx(0.0, abs=1e-5) == scalars["mae"]
-    chex.assert_trees_all_close(projections["projected"], y)
+    chex.assert_trees_all_close(arrays["projected"], y)


 def test_linear_regression_svd_selects_best_rcond() -> None:
@@ -32,7 +59,7 @@ def test_linear_regression_svd_selects_best_rcond() -> None:
     y = jnp.sum(x, axis=1, keepdims=True)
     weights = jnp.array([0.1, 0.2, 0.3, 0.4])

-    scalars, projections = linear_regression_svd(
+    scalars, arrays = linear_regression_svd(
         x,
         y,
         weights,
@@ -40,7 +67,7 @@ def test_linear_regression_svd_selects_best_rcond() -> None:
     )

     assert scalars["best_rcond"] in {1e-6, 1e-4, 1e-2}
-    chex.assert_trees_all_close(projections["projected"], y)
+    chex.assert_trees_all_close(arrays["projected"], y)


 def test_layer_regression_requires_targets() -> None:
@@ -90,10 +117,10 @@ def test_linear_regression_without_intercept_uses_uniform_weights() -> None:
     x = jnp.arange(1.0, 4.0)[:, None]
     y = 2.0 * x

-    scalars, projections = linear_regression(x, y, None, fit_intercept=False)
+    scalars, arrays = linear_regression(x, y, None, fit_intercept=False)

     assert pytest.approx(1.0) == scalars["r2"]
-    chex.assert_trees_all_close(projections["projected"], y)
+    chex.assert_trees_all_close(arrays["projected"], y)


 def test_linear_regression_svd_handles_empty_features() -> None:
@@ -102,10 +129,10 @@ def test_linear_regression_svd_handles_empty_features() -> None:
     y = jnp.arange(3.0)[:, None]
     weights = jnp.ones(3)

-    scalars, projections = linear_regression_svd(x, y, weights, fit_intercept=False)
+    scalars, arrays = linear_regression_svd(x, y, weights, fit_intercept=False)

     assert scalars["best_rcond"] == pytest.approx(1e-15)
-    chex.assert_trees_all_close(projections["projected"], jnp.zeros_like(y))
+    chex.assert_trees_all_close(arrays["projected"], jnp.zeros_like(y))


 def test_linear_regression_accepts_one_dimensional_inputs() -> None:
@@ -114,10 +141,10 @@ def test_linear_regression_accepts_one_dimensional_inputs() -> None:
     y = 5.0 * x + 1.0
     weights = jnp.ones_like(x)

-    scalars, projections = linear_regression(x, y, weights)
+    scalars, arrays = linear_regression(x, y, weights)

     assert pytest.approx(1.0) == scalars["r2"]
-    chex.assert_trees_all_close(projections["projected"], y[:, None])
+    chex.assert_trees_all_close(arrays["projected"], y[:, None])


 def test_linear_regression_rejects_high_rank_inputs() -> None:
@@ -174,7 +201,7 @@ def test_layer_linear_regression_svd_runs_end_to_end() -> None:
     weights = jnp.ones(3) / 3.0
     beliefs = 2.0 * x.sum(axis=1, keepdims=True)

-    scalars, projections = layer_linear_regression_svd(
+    scalars, arrays = layer_linear_regression_svd(
         x,
         weights,
         beliefs,
def test_layer_linear_regression_svd_runs_end_to_end() -> None: ) assert pytest.approx(1.0, abs=1e-6) == scalars["r2"] - chex.assert_trees_all_close(projections["projected"], beliefs) + chex.assert_trees_all_close(arrays["projected"], beliefs) -def test_layer_linear_regression_to_factors_basic() -> None: - """Layer regression with to_factors should regress to each factor separately.""" +def test_layer_linear_regression_belief_states_tuple_default() -> None: + """By default, layer regression should regress to each factor separately if given a tuple of belief states.""" x = jnp.arange(12.0).reshape(4, 3) # 4 samples, 3 features weights = jnp.ones(4) / 4.0 @@ -195,11 +222,10 @@ def test_layer_linear_regression_to_factors_basic() -> None: factor_1 = jnp.array([[0.2, 0.3, 0.5], [0.1, 0.6, 0.3], [0.4, 0.4, 0.2], [0.3, 0.3, 0.4]]) # [4, 3] factored_beliefs = (factor_0, factor_1) - scalars, projections = layer_linear_regression( + scalars, arrays = layer_linear_regression( x, weights, factored_beliefs, - to_factors=True, ) # Should have separate metrics for each factor @@ -213,16 +239,29 @@ def test_layer_linear_regression_to_factors_basic() -> None: assert "factor_1/dist" in scalars # Should have separate projections for each factor - assert "factor_0/projected" in projections - assert "factor_1/projected" in projections + assert "factor_0/projected" in arrays + assert "factor_1/projected" in arrays + + # Should have separate parameters for each factor + assert "factor_0/coeffs" in arrays + assert "factor_1/coeffs" in arrays + + # Should have separate intercepts for each factor by default + assert "factor_0/intercept" in arrays + assert "factor_1/intercept" in arrays # Check shapes - assert projections["factor_0/projected"].shape == factor_0.shape - assert projections["factor_1/projected"].shape == factor_1.shape + assert arrays["factor_0/projected"].shape == factor_0.shape + assert arrays["factor_1/projected"].shape == factor_1.shape + assert arrays["factor_0/coeffs"].shape == (x.shape[1], factor_0.shape[1]) + assert arrays["factor_1/coeffs"].shape == (x.shape[1], factor_1.shape[1]) + assert arrays["factor_0/intercept"].shape == (1, factor_0.shape[1]) + assert arrays["factor_1/intercept"].shape == (1, factor_1.shape[1]) -def test_layer_linear_regression_svd_to_factors_basic() -> None: - """Layer regression SVD with to_factors should regress to each factor separately.""" + +def test_layer_linear_regression_svd_belief_states_tuple_default() -> None: + """By default, layer regression SVD should regress to each factor separately if given a tuple of belief states.""" x = jnp.arange(12.0).reshape(4, 3) # 4 samples, 3 features weights = jnp.ones(4) / 4.0 @@ -231,30 +270,43 @@ def test_layer_linear_regression_svd_to_factors_basic() -> None: factor_1 = jnp.array([[0.2, 0.3, 0.5], [0.1, 0.6, 0.3], [0.4, 0.4, 0.2], [0.3, 0.3, 0.4]]) # [4, 3] factored_beliefs = (factor_0, factor_1) - scalars, projections = layer_linear_regression_svd( + scalars, arrays = layer_linear_regression_svd( x, weights, factored_beliefs, - to_factors=True, rcond_values=[1e-6], ) - # Should have separate metrics for each factor including best_rcond - assert "factor_0/r2" in scalars - assert "factor_1/r2" in scalars - assert "factor_0/best_rcond" in scalars - assert "factor_1/best_rcond" in scalars + # Should have ALL regression metrics for each factor including best_rcond + for factor in [0, 1]: + assert f"factor_{factor}/r2" in scalars + assert f"factor_{factor}/rmse" in scalars + assert f"factor_{factor}/mae" in scalars + assert 
f"factor_{factor}/dist" in scalars + assert f"factor_{factor}/best_rcond" in scalars # Should have separate projections for each factor - assert "factor_0/projected" in projections - assert "factor_1/projected" in projections + assert "factor_0/projected" in arrays + assert "factor_1/projected" in arrays + + # Should have separate coefficients for each factor + assert "factor_0/coeffs" in arrays + assert "factor_1/coeffs" in arrays + + # Should have separate intercepts for each factor by default + assert "factor_0/intercept" in arrays + assert "factor_1/intercept" in arrays # Check shapes - assert projections["factor_0/projected"].shape == factor_0.shape - assert projections["factor_1/projected"].shape == factor_1.shape + assert arrays["factor_0/projected"].shape == factor_0.shape + assert arrays["factor_1/projected"].shape == factor_1.shape + assert arrays["factor_0/coeffs"].shape == (x.shape[1], factor_0.shape[1]) + assert arrays["factor_1/coeffs"].shape == (x.shape[1], factor_1.shape[1]) + assert arrays["factor_0/intercept"].shape == (1, factor_0.shape[1]) + assert arrays["factor_1/intercept"].shape == (1, factor_1.shape[1]) -def test_layer_linear_regression_to_factors_single_factor() -> None: +def test_layer_linear_regression_belief_states_tuple_single_factor() -> None: """to_factors=True should work with a single factor tuple.""" x = jnp.arange(9.0).reshape(3, 3) weights = jnp.ones(3) / 3.0 @@ -263,126 +315,751 @@ def test_layer_linear_regression_to_factors_single_factor() -> None: factor_0 = jnp.array([[0.3, 0.7], [0.5, 0.5], [0.8, 0.2]]) factored_beliefs = (factor_0,) - scalars, projections = layer_linear_regression( + scalars, arrays = layer_linear_regression( x, weights, factored_beliefs, - to_factors=True, ) - # Should have metrics for single factor + # Should have ALL metrics, projections, coefficients, and intercept for single factor assert "factor_0/r2" in scalars - assert "factor_0/projected" in projections - assert projections["factor_0/projected"].shape == factor_0.shape + assert "factor_0/rmse" in scalars + assert "factor_0/mae" in scalars + assert "factor_0/dist" in scalars + assert "factor_0/projected" in arrays + assert "factor_0/coeffs" in arrays + assert "factor_0/intercept" in arrays + assert arrays["factor_0/projected"].shape == factor_0.shape + assert arrays["factor_0/coeffs"].shape == (x.shape[1], factor_0.shape[1]) + assert arrays["factor_0/intercept"].shape == (1, factor_0.shape[1]) + + +def test_orthogonality_with_orthogonal_subspaces() -> None: + """Orthogonal factors constructed explicitly should have near-zero overlap.""" + + # Create truly orthogonal coefficient matrices by construction + n_samples, n_features = 100, 6 + key = jax.random.PRNGKey(0) + x = jax.random.normal(key, (n_samples, n_features)) + + # Define orthogonal coefficient matrices + # W_0 uses first 3 features, W_1 uses last 3 features + W_0 = jnp.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0]]) # (6, 2) + W_1 = jnp.array([[0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]) # (6, 2) + + # Generate factors using orthogonal subspaces (no intercept for simplicity) + factor_0 = x @ W_0 # (100, 2) + factor_1 = x @ W_1 # (100, 2) + factored_beliefs = (factor_0, factor_1) + weights = jnp.ones(n_samples) / n_samples + scalars, arrays = layer_linear_regression( + x, + weights, + factored_beliefs, + compute_subspace_orthogonality=True, + fit_intercept=False, # No intercept for cleaner test + ) -def 
test_layer_linear_regression_to_factors_requires_tuple() -> None: - """to_factors=True requires belief_states to be a tuple.""" - x = jnp.ones((3, 2)) - weights = jnp.ones(3) / 3.0 - beliefs_array = jnp.ones((3, 2)) + # Should have standard factor metrics with perfect fit + assert scalars["factor_0/r2"] > 0.99 # Should fit nearly perfectly + assert scalars["factor_1/r2"] > 0.99 + + # Should have ALL orthogonality metrics + assert "orthogonality_0_1/subspace_overlap" in scalars + assert "orthogonality_0_1/max_singular_value" in scalars + assert "orthogonality_0_1/min_singular_value" in scalars + assert "orthogonality_0_1/participation_ratio" in scalars + assert "orthogonality_0_1/entropy" in scalars + assert "orthogonality_0_1/effective_rank" in scalars + + # Compute principled threshold based on machine precision and problem size + threshold = _compute_orthogonality_threshold(x, factor_0, factor_1) + + # Should indicate near-zero overlap (orthogonal by construction) + assert scalars["orthogonality_0_1/subspace_overlap"] < threshold + assert scalars["orthogonality_0_1/max_singular_value"] < threshold + + # Should have singular values in arrays + assert "orthogonality_0_1/singular_values" in arrays + # Both factors have 2 dimensions, so min(2, 2) = 2 singular values + assert arrays["orthogonality_0_1/singular_values"].shape[0] == 2 + # All singular values should be near zero (orthogonal) + assert jnp.all(arrays["orthogonality_0_1/singular_values"] < threshold) + + +def test_orthogonality_with_aligned_subspaces() -> None: + """Aligned factors with identical column spaces should have high overlap.""" + + # Create truly aligned coefficient matrices by construction + n_samples, n_features = 100, 6 + key = jax.random.PRNGKey(0) + x = jax.random.normal(key, (n_samples, n_features)) + + # Define aligned coefficient matrices - W_1 = W_0 @ A for invertible A + # This ensures span(W_1) = span(W_0) + W_0 = jnp.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0]]) # (6, 2) + W_1 = jnp.array([[0.5, 1.0], [1.0, 0.5], [1.5, 1.5], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0]]) # (6, 2) + + # Generate factors using aligned subspaces (no intercept for simplicity) + factor_0 = x @ W_0 # (100, 2) + factor_1 = x @ W_1 # (100, 2) + factored_beliefs = (factor_0, factor_1) + weights = jnp.ones(n_samples) / n_samples - with pytest.raises(ValueError, match="belief_states must be a tuple when to_factors is True"): - layer_linear_regression(x, weights, beliefs_array, to_factors=True) + scalars, arrays = layer_linear_regression( + x, + weights, + factored_beliefs, + compute_subspace_orthogonality=True, + fit_intercept=False, # No intercept for cleaner test + ) - with pytest.raises(ValueError, match="belief_states must be a tuple when to_factors is True"): - layer_linear_regression_svd(x, weights, beliefs_array, to_factors=True) + # Should have standard factor metrics with perfect fit + assert scalars["factor_0/r2"] > 0.99 # Should fit nearly perfectly + assert scalars["factor_1/r2"] > 0.99 + + # Should have ALL orthogonality metrics + assert "orthogonality_0_1/subspace_overlap" in scalars + assert "orthogonality_0_1/max_singular_value" in scalars + assert "orthogonality_0_1/min_singular_value" in scalars + assert "orthogonality_0_1/participation_ratio" in scalars + assert "orthogonality_0_1/entropy" in scalars + assert "orthogonality_0_1/effective_rank" in scalars + + # Should indicate high overlap (aligned by construction) + assert scalars["orthogonality_0_1/subspace_overlap"] > 0.99 + assert 
scalars["orthogonality_0_1/max_singular_value"] > 0.99 + + # Should have singular values in arrays + assert "orthogonality_0_1/singular_values" in arrays + # Both factors have 2 dimensions, so min(2, 2) = 2 singular values + assert arrays["orthogonality_0_1/singular_values"].shape[0] == 2 + # All singular values should be near 1.0 (perfectly aligned) + assert jnp.all(arrays["orthogonality_0_1/singular_values"] > 0.99) + + +def test_orthogonality_with_three_factors() -> None: + """Three factors should produce all pairwise orthogonality combinations.""" + + # Create three mutually orthogonal coefficient matrices + n_samples, n_features = 100, 6 + key = jax.random.PRNGKey(0) + x = jax.random.normal(key, (n_samples, n_features)) + + # Define three orthogonal coefficient matrices using disjoint features + W_0 = jnp.array([[1.0, 0.0], [0.0, 1.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0]]) # Uses features 0-1 + W_1 = jnp.array([[0.0, 0.0], [0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [0.0, 0.0], [0.0, 0.0]]) # Uses features 2-3 + W_2 = jnp.array([[0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [1.0, 0.0], [0.0, 1.0]]) # Uses features 4-5 + + # Generate factors using orthogonal subspaces + factor_0 = x @ W_0 # (100, 2) + factor_1 = x @ W_1 # (100, 2) + factor_2 = x @ W_2 # (100, 2) + factored_beliefs = (factor_0, factor_1, factor_2) + weights = jnp.ones(n_samples) / n_samples + + scalars, arrays = layer_linear_regression( + x, + weights, + factored_beliefs, + compute_subspace_orthogonality=True, + fit_intercept=False, + ) + # Should have standard factor metrics for all three factors + assert scalars["factor_0/r2"] > 0.99 + assert scalars["factor_1/r2"] > 0.99 + assert scalars["factor_2/r2"] > 0.99 -def test_layer_linear_regression_to_factors_validates_tuple_contents() -> None: - """to_factors=True requires all elements in tuple to be jax.Arrays.""" - x = jnp.ones((3, 2)) - weights = jnp.ones(3) / 3.0 + # Compute principled threshold based on machine precision and problem size + threshold = _compute_orthogonality_threshold(x, factor_0, factor_1, factor_2) - # Invalid: tuple contains non-array - invalid_beliefs = (jnp.ones((3, 2)), "not an array") # type: ignore + # Should have ALL three pairwise orthogonality combinations + pairwise_keys = ["orthogonality_0_1", "orthogonality_0_2", "orthogonality_1_2"] + for pair_key in pairwise_keys: + assert f"{pair_key}/subspace_overlap" in scalars + assert f"{pair_key}/max_singular_value" in scalars + assert f"{pair_key}/min_singular_value" in scalars + assert f"{pair_key}/participation_ratio" in scalars + assert f"{pair_key}/entropy" in scalars + assert f"{pair_key}/effective_rank" in scalars + assert f"{pair_key}/singular_values" in arrays - with pytest.raises(ValueError, match="Each factor in belief_states must be a jax.Array"): - layer_linear_regression(x, weights, invalid_beliefs, to_factors=True) # type: ignore + # All pairs should be orthogonal (near-zero overlap) + overlap = scalars[f"{pair_key}/subspace_overlap"] + assert overlap < threshold, f"{pair_key} subspace_overlap={overlap} >= threshold={threshold}" - with pytest.raises(ValueError, match="Each factor in belief_states must be a jax.Array"): - layer_linear_regression_svd(x, weights, invalid_beliefs, to_factors=True) # type: ignore + max_sv = scalars[f"{pair_key}/max_singular_value"] + assert max_sv < threshold, f"{pair_key} max_singular_value={max_sv} >= threshold={threshold}" + # Each pair has 2D subspaces, so 2 singular values + assert arrays[f"{pair_key}/singular_values"].shape[0] == 2 + svs = 
arrays[f"{pair_key}/singular_values"] + assert jnp.all(svs < threshold), f"{pair_key} singular_values={svs} not all < threshold={threshold}" -def test_layer_linear_regression_to_factors_false_works() -> None: - """to_factors=False requires belief_states to be a single array, not a tuple.""" - x = jnp.ones((3, 2)) - weights = jnp.ones(3) / 3.0 - # Invalid: tuple when to_factors=False - factored_beliefs = (jnp.ones((3, 2)), jnp.ones((3, 3))) +def test_orthogonality_not_computed_by_default() -> None: + """Orthogonality metrics should not be computed when compute_subspace_orthogonality=False.""" - scalars, projections = layer_linear_regression(x, weights, factored_beliefs, to_factors=False) - assert "r2" in scalars - assert "projected" in projections - assert projections["projected"].shape == (3, 5) + # Setup two-factor regression + n_samples, n_features = 50, 4 + key = jax.random.PRNGKey(0) + x = jax.random.normal(key, (n_samples, n_features)) - scalars, projections = layer_linear_regression_svd(x, weights, factored_beliefs, to_factors=False) + W_0 = jnp.array([[1.0, 0.0], [0.0, 1.0], [0.0, 0.0], [0.0, 0.0]]) + W_1 = jnp.array([[0.0, 0.0], [0.0, 0.0], [1.0, 0.0], [0.0, 1.0]]) + + factor_0 = x @ W_0 + factor_1 = x @ W_1 + factored_beliefs = (factor_0, factor_1) + weights = jnp.ones(n_samples) / n_samples + + # Run WITHOUT compute_subspace_orthogonality (default is False) + scalars, arrays = layer_linear_regression( + x, + weights, + factored_beliefs, + fit_intercept=False, + ) + + # Should have standard factor metrics + assert "factor_0/r2" in scalars + assert "factor_1/r2" in scalars + + # Should NOT have any orthogonality metrics + orthogonality_keys = [ + "orthogonality_0_1/subspace_overlap", + "orthogonality_0_1/max_singular_value", + "orthogonality_0_1/min_singular_value", + "orthogonality_0_1/participation_ratio", + "orthogonality_0_1/entropy", + "orthogonality_0_1/effective_rank", + ] + for key in orthogonality_keys: + assert key not in scalars + + # Should NOT have orthogonality singular values in arrays + assert "orthogonality_0_1/singular_values" not in arrays + + +def test_orthogonality_warning_for_single_belief_state(caplog: pytest.LogCaptureFixture) -> None: + """Should warn when requesting orthogonality with a single belief state.""" + + # Setup single-factor regression + n_samples, n_features = 30, 4 + key = jax.random.PRNGKey(0) + x = jax.random.normal(key, (n_samples, n_features)) + belief_state = jax.random.normal(key, (n_samples, 2)) + weights = jnp.ones(n_samples) / n_samples + + # Request orthogonality with single belief state (not a tuple) + with caplog.at_level("WARNING"): + scalars, arrays = layer_linear_regression( + x, + weights, + belief_state, + compute_subspace_orthogonality=True, + fit_intercept=False, + ) + + # Should have logged a warning + assert "Subspace orthogonality cannot be computed for a single belief state" in caplog.text + + # Should still run regression successfully assert "r2" in scalars - assert "projected" in projections - assert projections["projected"].shape == (3, 5) + assert "projected" in arrays + # Should NOT have orthogonality metrics + assert "orthogonality_0_1/subspace_overlap" not in scalars + assert "orthogonality_0_1/singular_values" not in arrays -def test_factored_regression_perfect_linear_fit() -> None: - """Test factored regression with perfectly linear targets achieves perfect fit. - Uses targets that are exact linear combinations of features to verify - the regression machinery works correctly for the factored case. 
- """ - # 5 samples, 4 features - x = jnp.array( - [ - [1.0, 2.0, 3.0, 4.0], - [2.0, 3.0, 4.0, 5.0], - [3.0, 4.0, 5.0, 6.0], - [4.0, 5.0, 6.0, 7.0], - [5.0, 6.0, 7.0, 8.0], - ] +def test_use_svd_flag_equivalence() -> None: + """layer_linear_regression with use_svd=True should match layer_linear_regression_svd.""" + + n_samples, n_features = 40, 4 + key = jax.random.PRNGKey(0) + x = jax.random.normal(key, (n_samples, n_features)) + + # Test with single belief state + belief_state = jax.random.normal(key, (n_samples, 3)) + weights = jnp.ones(n_samples) / n_samples + rcond_values = [1e-6, 1e-4] + + # Method 1: use_svd=True + scalars_flag, arrays_flag = layer_linear_regression( + x, + weights, + belief_state, + use_svd=True, + rcond_values=rcond_values, + ) + + # Method 2: layer_linear_regression_svd + scalars_wrapper, arrays_wrapper = layer_linear_regression_svd( + x, + weights, + belief_state, + rcond_values=rcond_values, + ) + + # Should produce identical results + assert scalars_flag.keys() == scalars_wrapper.keys() + for key in scalars_flag: + assert scalars_flag[key] == pytest.approx(scalars_wrapper[key]) + + assert arrays_flag.keys() == arrays_wrapper.keys() + for key in arrays_flag: + chex.assert_trees_all_close(arrays_flag[key], arrays_wrapper[key]) + + # Test with factored belief states + W_0 = jnp.array([[1.0, 0.0], [0.0, 1.0], [0.0, 0.0], [0.0, 0.0]]) + W_1 = jnp.array([[0.0, 0.0], [0.0, 0.0], [1.0, 0.0], [0.0, 1.0]]) + factor_0 = x @ W_0 + factor_1 = x @ W_1 + factored_beliefs = (factor_0, factor_1) + + # Method 1: use_svd=True with factored beliefs + scalars_flag_fact, arrays_flag_fact = layer_linear_regression( + x, + weights, + factored_beliefs, + use_svd=True, + rcond_values=rcond_values, + ) + + # Method 2: layer_linear_regression_svd with factored beliefs + scalars_wrapper_fact, arrays_wrapper_fact = layer_linear_regression_svd( + x, + weights, + factored_beliefs, + rcond_values=rcond_values, + ) + + # Should produce identical results + assert scalars_flag_fact.keys() == scalars_wrapper_fact.keys() + for key in scalars_flag_fact: + assert scalars_flag_fact[key] == pytest.approx(scalars_wrapper_fact[key]) + + assert arrays_flag_fact.keys() == arrays_wrapper_fact.keys() + for key in arrays_flag_fact: + chex.assert_trees_all_close(arrays_flag_fact[key], arrays_wrapper_fact[key]) + + +def test_use_svd_with_orthogonality() -> None: + """SVD regression should work with orthogonality computation.""" + + n_samples, n_features = 80, 6 + key = jax.random.PRNGKey(0) + x = jax.random.normal(key, (n_samples, n_features)) + + # Create orthogonal coefficient matrices + W_0 = jnp.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0]]) + W_1 = jnp.array([[0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]) + + factor_0 = x @ W_0 + factor_1 = x @ W_1 + factored_beliefs = (factor_0, factor_1) + weights = jnp.ones(n_samples) / n_samples + + # Run SVD regression with orthogonality computation + scalars, arrays = layer_linear_regression( + x, + weights, + factored_beliefs, + use_svd=True, + compute_subspace_orthogonality=True, + rcond_values=[1e-6], + fit_intercept=False, + ) + + # Should have standard factor metrics with SVD + assert "factor_0/r2" in scalars + assert "factor_1/r2" in scalars + assert "factor_0/best_rcond" in scalars + assert "factor_1/best_rcond" in scalars + + # Should have orthogonality metrics + assert "orthogonality_0_1/subspace_overlap" in scalars + assert "orthogonality_0_1/max_singular_value" in scalars + assert 
"orthogonality_0_1/singular_values" in arrays + + # Compute principled threshold + threshold = _compute_orthogonality_threshold(x, factor_0, factor_1) + + # Should indicate near-zero overlap (orthogonal by construction) + assert scalars["orthogonality_0_1/subspace_overlap"] < threshold + assert scalars["orthogonality_0_1/max_singular_value"] < threshold + + # Should have good regression fit + assert scalars["factor_0/r2"] > 0.99 + assert scalars["factor_1/r2"] > 0.99 + + +def test_orthogonality_with_different_subspace_dimensions() -> None: + """Orthogonality should work when factors have different output dimensions.""" + + n_samples, n_features = 100, 8 + key = jax.random.PRNGKey(0) + x = jax.random.normal(key, (n_samples, n_features)) + + # Create orthogonal coefficient matrices with different output dimensions + # factor_0 has 2 output dimensions, factor_1 has 5 output dimensions + W_0 = jnp.array([ + [1.0, 0.0], + [0.0, 1.0], + [1.0, 1.0], + [0.0, 0.0], + [0.0, 0.0], + [0.0, 0.0], + [0.0, 0.0], + [0.0, 0.0], + ]) # (8, 2) + W_1 = jnp.array([ + [0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0], + [1.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 1.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 1.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 1.0], + ]) # (8, 5) + + factor_0 = x @ W_0 # (100, 2) + factor_1 = x @ W_1 # (100, 5) + factored_beliefs = (factor_0, factor_1) + weights = jnp.ones(n_samples) / n_samples + + scalars, arrays = layer_linear_regression( + x, + weights, + factored_beliefs, + compute_subspace_orthogonality=True, + fit_intercept=False, + ) + + # Should have standard factor metrics + assert scalars["factor_0/r2"] > 0.99 + assert scalars["factor_1/r2"] > 0.99 + + # Should have orthogonality metrics + assert "orthogonality_0_1/subspace_overlap" in scalars + assert "orthogonality_0_1/max_singular_value" in scalars + assert "orthogonality_0_1/singular_values" in arrays + + # Compute principled threshold + threshold = _compute_orthogonality_threshold(x, factor_0, factor_1) + + # Should indicate near-zero overlap (orthogonal by construction) + assert scalars["orthogonality_0_1/subspace_overlap"] < threshold + assert scalars["orthogonality_0_1/max_singular_value"] < threshold + + # Singular values shape should be min(2, 5) = 2 + assert arrays["orthogonality_0_1/singular_values"].shape[0] == 2 + assert jnp.all(arrays["orthogonality_0_1/singular_values"] < threshold) + + +def test_orthogonality_with_contained_subspace() -> None: + """Smaller subspace fully contained in larger subspace should show high alignment.""" + + n_samples, n_features = 100, 8 + key = jax.random.PRNGKey(0) + x = jax.random.normal(key, (n_samples, n_features)) + + # Create coefficient matrices where factor_0's subspace is contained in factor_1's + # factor_0: 2D subspace using features [0, 1] + # factor_1: 3D subspace using features [0, 1, 2] (contains factor_0's space) + W_0 = jnp.array([ + [1.0, 0.0], + [0.0, 1.0], + [0.0, 0.0], + [0.0, 0.0], + [0.0, 0.0], + [0.0, 0.0], + [0.0, 0.0], + [0.0, 0.0], + ]) # (8, 2) + W_1 = jnp.array([ + [1.0, 0.0, 0.0], + [0.0, 1.0, 0.0], + [0.0, 0.0, 1.0], + [0.0, 0.0, 0.0], + [0.0, 0.0, 0.0], + [0.0, 0.0, 0.0], + [0.0, 0.0, 0.0], + [0.0, 0.0, 0.0], + ]) # (8, 3) + + factor_0 = x @ W_0 # (100, 2) + factor_1 = x @ W_1 # (100, 3) + factored_beliefs = (factor_0, factor_1) + weights = jnp.ones(n_samples) / n_samples + + scalars, arrays = layer_linear_regression( + x, + weights, + factored_beliefs, + compute_subspace_orthogonality=True, + fit_intercept=False, 
+ ) + + # Should have standard factor metrics + assert scalars["factor_0/r2"] > 0.99 + assert scalars["factor_1/r2"] > 0.99 + + # Should have orthogonality metrics + assert "orthogonality_0_1/subspace_overlap" in scalars + assert "orthogonality_0_1/max_singular_value" in scalars + assert "orthogonality_0_1/singular_values" in arrays + + # Singular values shape should be min(2, 3) = 2 + assert arrays["orthogonality_0_1/singular_values"].shape[0] == 2 + + # Since factor_0's subspace is contained in factor_1's, singular values should be near 1.0 + # (indicating perfect alignment in the 2D shared subspace) + assert scalars["orthogonality_0_1/subspace_overlap"] > 0.99 + assert scalars["orthogonality_0_1/max_singular_value"] > 0.99 + assert scalars["orthogonality_0_1/min_singular_value"] > 0.99 + assert jnp.all(arrays["orthogonality_0_1/singular_values"] > 0.99) + + +def test_orthogonality_excludes_intercept() -> None: + """Orthogonality should be computed using only coefficients, not intercept.""" + + n_samples, n_features = 100, 6 + key = jax.random.PRNGKey(0) + x = jax.random.normal(key, (n_samples, n_features)) + + # Create orthogonal coefficient matrices + W_0 = jnp.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0]]) + W_1 = jnp.array([[0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]) + + # Add different intercepts to the factors + intercept_0 = jnp.array([[5.0, -3.0]]) + intercept_1 = jnp.array([[10.0, 7.0]]) + + factor_0 = x @ W_0 + intercept_0 # (100, 2) + factor_1 = x @ W_1 + intercept_1 # (100, 2) + factored_beliefs = (factor_0, factor_1) + weights = jnp.ones(n_samples) / n_samples + + # Run with fit_intercept=True + scalars, arrays = layer_linear_regression( + x, + weights, + factored_beliefs, + compute_subspace_orthogonality=True, + fit_intercept=True, ) - weights = jnp.ones(5) / 5.0 - # Factor 0: 3 states, exact linear combination (with intercept) - # y0 = [x0 + 1, x1 + 2, x2 + 3] - factor_0 = jnp.stack([x[:, 0] + 1, x[:, 1] + 2, x[:, 2] + 3], axis=1) + # Should have intercepts for both factors + assert "factor_0/intercept" in arrays + assert "factor_1/intercept" in arrays + + # Should have good regression fit + assert scalars["factor_0/r2"] > 0.99 + assert scalars["factor_1/r2"] > 0.99 + + # Orthogonality should still be near-zero (computed from coefficients only, not intercepts) + threshold = _compute_orthogonality_threshold(x, factor_0, factor_1) - # Factor 1: 2 states, exact linear combination - # y1 = [x1, x3] - factor_1 = jnp.stack([x[:, 1], x[:, 3]], axis=1) + assert "orthogonality_0_1/subspace_overlap" in scalars + assert "orthogonality_0_1/max_singular_value" in scalars - scalars, projections = layer_linear_regression(x, weights, (factor_0, factor_1), to_factors=True) + overlap = scalars["orthogonality_0_1/subspace_overlap"] + assert overlap < threshold, f"subspace_overlap={overlap} >= threshold={threshold}" - # Should achieve perfect R² since targets are exact linear combinations - assert scalars["factor_0/r2"] > 0.99, f"factor_0 R² too low: {scalars['factor_0/r2']}" - assert scalars["factor_1/r2"] > 0.99, f"factor_1 R² too low: {scalars['factor_1/r2']}" + max_sv = scalars["orthogonality_0_1/max_singular_value"] + assert max_sv < threshold, f"max_singular_value={max_sv} >= threshold={threshold}" - # Projections should match targets very closely - chex.assert_trees_all_close(projections["factor_0/projected"], factor_0, atol=1e-4) - chex.assert_trees_all_close(projections["factor_1/projected"], factor_1, atol=1e-4) + # The 
different intercepts should not affect orthogonality + svs = arrays["orthogonality_0_1/singular_values"] + assert jnp.all(svs < threshold), f"singular_values={svs} not all < threshold={threshold}" -def test_factored_regression_different_state_counts() -> None: - """Test factored regression with factors having different numbers of states. +def test_linear_regression_constant_targets_r2_and_dist() -> None: + """Constant targets should yield r2==0, and dist matches weighted residual norm. - This reproduces a scenario where factors have different dimensionality, - which is common in factored generative processes. + With intercept: perfect fit to constant -> zero residuals but r2 fallback to 0.0. + Without intercept: nonzero residuals; verify `dist` against manual computation. """ - x = jnp.arange(24.0).reshape(6, 4) # 6 samples, 4 features - weights = jnp.ones(6) / 6.0 + x = jnp.arange(4.0)[:, None] + y = jnp.ones_like(x) * 3.0 + weights = jnp.array([0.1, 0.2, 0.3, 0.4]) + + # With intercept -> perfect constant fit, but r2 should fallback to 0.0 when variance is zero + scalars, arrays = linear_regression(x, y, weights) + assert scalars["r2"] == 0.0 + assert jnp.isclose(scalars["rmse"], 0.0, atol=1e-6, rtol=0.0).item() + assert jnp.isclose(scalars["mae"], 0.0, atol=1e-6, rtol=0.0).item() + assert jnp.isclose(scalars["dist"], 0.0, atol=1e-6, rtol=0.0).item() + + # Without intercept -> cannot fit a constant perfectly; r2 still 0.0, and dist should match manual computation + scalars_no_int, arrays_no_int = linear_regression(x, y, weights, fit_intercept=False) + assert scalars_no_int["r2"] == 0.0 + residuals = arrays_no_int["projected"] - y + per_sample = jnp.sqrt(jnp.sum(residuals**2, axis=1)) + expected_dist = float(jnp.sum(per_sample * weights)) + assert jnp.isclose(scalars_no_int["dist"], expected_dist, atol=1e-6, rtol=0.0).item() + + +def test_linear_regression_intercept_and_shapes_both_solvers() -> None: + """Validate intercept presence/absence and array shapes for both solvers.""" + n, d, t = 5, 3, 2 + x = jnp.arange(float(n * d)).reshape(n, d) + # Construct multi-target y with known linear relation and intercept + true_coeffs = jnp.array([[1.0, 2.0], [0.5, -1.0], [3.0, 0.0]]) # (d, t) + true_intercept = jnp.array([[0.7, -0.3]]) # (1, t) + y = x @ true_coeffs + true_intercept + weights = jnp.ones(n) / n + + # Standard solver, with intercept + scalars, arrays = linear_regression(x, y, weights, fit_intercept=True) + assert "projected" in arrays + assert "coeffs" in arrays + assert "intercept" in arrays + assert arrays["projected"].shape == (n, t) + assert arrays["coeffs"].shape == (d, t) + assert arrays["intercept"].shape == (1, t) + + # Standard solver, without intercept + scalars_no_int, arrays_no_int = linear_regression(x, y, weights, fit_intercept=False) + assert "projected" in arrays_no_int + assert "coeffs" in arrays_no_int + assert "intercept" not in arrays_no_int + assert arrays_no_int["projected"].shape == (n, t) + assert arrays_no_int["coeffs"].shape == (d, t) + + # SVD solver, with intercept + scalars_svd, arrays_svd = linear_regression_svd(x, y, weights, fit_intercept=True) + assert "projected" in arrays_svd + assert "coeffs" in arrays_svd + assert "intercept" in arrays_svd + assert arrays_svd["projected"].shape == (n, t) + assert arrays_svd["coeffs"].shape == (d, t) + assert arrays_svd["intercept"].shape == (1, t) + + # SVD solver, without intercept + scalars_svd_no_int, arrays_svd_no_int = linear_regression_svd(x, y, weights, fit_intercept=False) + assert "projected" in 
arrays_svd_no_int + assert "coeffs" in arrays_svd_no_int + assert "intercept" not in arrays_svd_no_int + assert arrays_svd_no_int["projected"].shape == (n, t) + assert arrays_svd_no_int["coeffs"].shape == (d, t) + + +def test_layer_linear_regression_concat_vs_separate_equivalence() -> None: + """Concat and separate factor regressions should yield identical per-factor arrays.""" + n, d = 6, 3 + x = jnp.arange(float(n * d)).reshape(n, d) + # Two factors with different output dims + W0 = jnp.array([[1.0, 0.5], [0.0, -1.0], [2.0, 1.0]]) # (d, 2) + b0 = jnp.array([[0.3, -0.2]]) # (1, 2) + factor_0 = x @ W0 + b0 + + W1 = jnp.array([[0.2, 0.0, -0.5], [1.0, 1.0, 0.0], [-1.0, 0.5, 0.3]]) # (d, 3) + b1 = jnp.array([[0.1, 0.2, -0.1]]) # (1, 3) + factor_1 = x @ W1 + b1 - # Factor 0: 3 states (like "mess3") - factor_0_raw = x[:, :3] - factor_0 = factor_0_raw / factor_0_raw.sum(axis=1, keepdims=True) + factored_beliefs = (factor_0, factor_1) + weights = jnp.array([0.05, 0.10, 0.15, 0.20, 0.25, 0.25]) + + # Separate per-factor regression + scalars_sep, arrays_sep = layer_linear_regression( + x, + weights, + factored_beliefs, + concat_belief_states=False, + ) - # Factor 1: 2 states (like "tom quantum") - factor_1_raw = x[:, :2] - factor_1 = factor_1_raw / factor_1_raw.sum(axis=1, keepdims=True) + # Concatenated regression with splitting + scalars_cat, arrays_cat = layer_linear_regression( + x, + weights, + factored_beliefs, + concat_belief_states=True, + ) - scalars, projections = layer_linear_regression(x, weights, (factor_0, factor_1), to_factors=True) + # Concat path should also provide combined arrays + assert "concat/projected" in arrays_cat + assert "concat/coeffs" in arrays_cat + assert "concat/intercept" in arrays_cat - # Verify shapes are correct - assert projections["factor_0/projected"].shape == (6, 3) - assert projections["factor_1/projected"].shape == (6, 2) + # Per-factor arrays should match between separate and concatenated flows + for k in ["projected", "coeffs", "intercept"]: + chex.assert_trees_all_close(arrays_sep[f"factor_0/{k}"], arrays_cat[f"factor_0/{k}"]) + chex.assert_trees_all_close(arrays_sep[f"factor_1/{k}"], arrays_cat[f"factor_1/{k}"]) + + +def test_layer_linear_regression_svd_concat_vs_separate_equivalence_best_rcond() -> None: + """SVD regression: concat-split vs separate produce identical per-factor arrays. + + If belief concatenation is enabled, we only report rcond for the concatenated fit as "concat/best_rcond". + If belief concatenation is disabled, we report rcond for each factor as "factor_k/best_rcond". 
+ """ + n, d = 6, 3 + x = jnp.arange(float(n * d)).reshape(n, d) + # Two factors with different output dims + W0 = jnp.array([[1.0, 0.5], [0.0, -1.0], [2.0, 1.0]]) # (d, 2) + b0 = jnp.array([[0.3, -0.2]]) # (1, 2) + factor_0 = x @ W0 + b0 + + W1 = jnp.array([[0.2, 0.0, -0.5], [1.0, 1.0, 0.0], [-1.0, 0.5, 0.3]]) # (d, 3) + b1 = jnp.array([[0.1, 0.2, -0.1]]) # (1, 3) + factor_1 = x @ W1 + b1 + + factored_beliefs = (factor_0, factor_1) + weights = jnp.array([0.05, 0.10, 0.15, 0.20, 0.25, 0.25]) + + # Separate per-factor SVD regression + scalars_sep, arrays_sep = layer_linear_regression( + x, + weights, + factored_beliefs, + concat_belief_states=False, + use_svd=True, + rcond_values=[1e-3], + ) + + # Concatenated SVD regression with splitting + scalars_cat, arrays_cat = layer_linear_regression( + x, + weights, + factored_beliefs, + concat_belief_states=True, + use_svd=True, + rcond_values=[1e-3], + ) - # Both should achieve reasonable fit - assert scalars["factor_0/r2"] > 0.5, f"factor_0 R² too low: {scalars['factor_0/r2']}" - assert scalars["factor_1/r2"] > 0.5, f"factor_1 R² too low: {scalars['factor_1/r2']}" + # Concat path should provide combined arrays and best_rcond + assert "concat/projected" in arrays_cat + assert "concat/coeffs" in arrays_cat + assert "concat/intercept" in arrays_cat + assert "concat/best_rcond" in scalars_cat + assert scalars_cat["concat/best_rcond"] == pytest.approx(1e-3) + + # Separate path should include per-factor best_rcond; concat-split path should not + assert "factor_0/best_rcond" in scalars_sep and "factor_1/best_rcond" in scalars_sep + assert "factor_0/best_rcond" not in scalars_cat and "factor_1/best_rcond" not in scalars_cat + + # Per-factor arrays should match between separate and concat-split flows + for k in ["projected", "coeffs", "intercept"]: + chex.assert_trees_all_close(arrays_sep[f"factor_0/{k}"], arrays_cat[f"factor_0/{k}"]) + chex.assert_trees_all_close(arrays_sep[f"factor_1/{k}"], arrays_cat[f"factor_1/{k}"]) + + # Overlapping scalar metrics should agree closely across flows + for metric in ["r2", "rmse", "mae", "dist"]: + assert jnp.isclose( + jnp.asarray(scalars_sep[f"factor_0/{metric}"]), + jnp.asarray(scalars_cat[f"factor_0/{metric}"]), + atol=1e-6, + rtol=0.0, + ).item() + assert jnp.isclose( + jnp.asarray(scalars_sep[f"factor_1/{metric}"]), + jnp.asarray(scalars_cat[f"factor_1/{metric}"]), + atol=1e-6, + rtol=0.0, + ).item() From d43935a538fda2b9822160f0c334ac16901f79be Mon Sep 17 00:00:00 2001 From: Eric Alt Date: Fri, 12 Dec 2025 11:52:43 -0800 Subject: [PATCH 03/48] Organize imports --- simplexity/analysis/linear_regression.py | 60 ++++++++++++------------ 1 file changed, 31 insertions(+), 29 deletions(-) diff --git a/simplexity/analysis/linear_regression.py b/simplexity/analysis/linear_regression.py index 8b1577cf..e767469d 100644 --- a/simplexity/analysis/linear_regression.py +++ b/simplexity/analysis/linear_regression.py @@ -2,6 +2,7 @@ from __future__ import annotations +import itertools from collections.abc import Callable, Mapping, Sequence from typing import Any @@ -75,13 +76,13 @@ def linear_regression( if fit_intercept: arrays = { "projected": predictions, - "coeffs": beta[1:], # Linear coefficients (excluding intercept) + "coeffs": beta[1:], # Linear coefficients (excluding intercept) "intercept": beta[0:1], # Intercept term (keep 2D: [1, n_targets]) } else: arrays = { "projected": predictions, - "coeffs": beta, # All parameters are coefficients when no intercept + "coeffs": beta, # All parameters are coefficients when no 
intercept } return scalars, arrays @@ -178,24 +179,24 @@ def linear_regression_svd( if fit_intercept: arrays = { "projected": best_pred, - "coeffs": best_beta[1:], # Linear coefficients (excluding intercept) + "coeffs": best_beta[1:], # Linear coefficients (excluding intercept) "intercept": best_beta[0:1], # Intercept term (keep 2D: [1, n_targets]) } else: arrays = { "projected": best_pred, - "coeffs": best_beta, # All parameters are coefficients when no intercept + "coeffs": best_beta, # All parameters are coefficients when no intercept } return scalars, arrays def _process_individual_factors( - layer_activations: jax.Array, - belief_states: tuple[jax.Array, ...], - weights: jax.Array, - use_svd: bool, - **kwargs: Any, + layer_activations: jax.Array, + belief_states: tuple[jax.Array, ...], + weights: jax.Array, + use_svd: bool, + **kwargs: Any, ) -> list[tuple[Mapping[str, float], Mapping[str, jax.Array]]]: """Process each factor individually using either standard or SVD regression.""" results = [] @@ -244,7 +245,7 @@ def _split_concat_results( # Only recompute scalar metrics, reuse projections and coefficients # Filter out rcond_values from kwargs (only relevant for SVD during fitting, not metrics) - metrics_kwargs = {k: v for k, v in kwargs.items() if k != 'rcond_values'} + metrics_kwargs = {k: v for k, v in kwargs.items() if k != "rcond_values"} fit_intercept = kwargs.get("fit_intercept", True) results = [] @@ -303,7 +304,7 @@ def _compute_subspace_orthogonality( min_singular_value = jnp.min(singular_values) # Compute the participation ratio - participation_ratio = jnp.sum(singular_values**2)**2 / jnp.sum(singular_values**4) + participation_ratio = jnp.sum(singular_values**2) ** 2 / jnp.sum(singular_values**4) # Compute the entropy probs = singular_values**2 / jnp.sum(singular_values**2) @@ -346,23 +347,21 @@ def _compute_all_pairwise_orthogonality( coeffs_list[j], ] orthogonality_scalars, orthogonality_singular_values = _compute_subspace_orthogonality(coeffs_pair) - scalars.update({ - f"orthogonality_{i}_{j}/{key}": value for key, value in orthogonality_scalars.items() - }) - singular_values.update({ - f"orthogonality_{i}_{j}/{key}": value for key, value in orthogonality_singular_values.items() - }) + scalars.update({f"orthogonality_{i}_{j}/{key}": value for key, value in orthogonality_scalars.items()}) + singular_values.update( + {f"orthogonality_{i}_{j}/{key}": value for key, value in orthogonality_singular_values.items()} + ) return scalars, singular_values def _handle_factored_regression( - layer_activations: jax.Array, - weights: jax.Array, - belief_states: tuple[jax.Array, ...], - concat_belief_states: bool, - compute_subspace_orthogonality: bool, - use_svd: bool, - **kwargs: Any, + layer_activations: jax.Array, + weights: jax.Array, + belief_states: tuple[jax.Array, ...], + concat_belief_states: bool, + compute_subspace_orthogonality: bool, + use_svd: bool, + **kwargs: Any, ) -> tuple[Mapping[str, float], Mapping[str, jax.Array]]: """Handle regression for factored belief states using either standard or SVD method.""" scalars: dict[str, float] = {} @@ -385,9 +384,7 @@ def _handle_factored_regression( **kwargs, ) else: - factor_results = _process_individual_factors( - layer_activations, belief_states, weights, use_svd, **kwargs - ) + factor_results = _process_individual_factors(layer_activations, belief_states, weights, use_svd, **kwargs) for factor_idx, factor_result in enumerate(factor_results): _merge_results_with_prefix(scalars, arrays, factor_result, f"factor_{factor_idx}") 
@@ -467,8 +464,13 @@ def layer_linear_regression( return scalars, arrays return _handle_factored_regression( - layer_activations, weights, belief_states, - concat_belief_states, compute_subspace_orthogonality, use_svd, **kwargs + layer_activations, + weights, + belief_states, + concat_belief_states, + compute_subspace_orthogonality, + use_svd, + **kwargs, ) From 9e600a47ff28ff4d46adc6e6b6244b729088e1d0 Mon Sep 17 00:00:00 2001 From: Eric Alt Date: Fri, 12 Dec 2025 11:53:57 -0800 Subject: [PATCH 04/48] Fix lint issues --- simplexity/analysis/linear_regression.py | 17 +++++++---------- 1 file changed, 7 insertions(+), 10 deletions(-) diff --git a/simplexity/analysis/linear_regression.py b/simplexity/analysis/linear_regression.py index e767469d..52fc0dbf 100644 --- a/simplexity/analysis/linear_regression.py +++ b/simplexity/analysis/linear_regression.py @@ -246,10 +246,11 @@ def _split_concat_results( # Only recompute scalar metrics, reuse projections and coefficients # Filter out rcond_values from kwargs (only relevant for SVD during fitting, not metrics) metrics_kwargs = {k: v for k, v in kwargs.items() if k != "rcond_values"} - fit_intercept = kwargs.get("fit_intercept", True) results = [] - for factor, coeffs, intercept, projections in zip(belief_states, coeffs_list, intercepts_list, projections_list): + for factor, coeffs, intercept, projections in zip( + belief_states, coeffs_list, intercepts_list, projections_list, strict=True + ): # Reconstruct full beta for metrics computation if intercept is not None: beta = jnp.concatenate([intercept, coeffs], axis=0) @@ -277,8 +278,7 @@ def _split_concat_results( def _compute_subspace_orthogonality( coeffs_pair: list[jax.Array], ) -> tuple[dict[str, float], dict[str, jax.Array]]: - """ - Compute orthogonality metrics between two coefficient subspaces. + """Compute orthogonality metrics between two coefficient subspaces. Args: coeffs_pair: List of two coefficient matrices (excludes intercept) @@ -332,8 +332,7 @@ def _compute_subspace_orthogonality( def _compute_all_pairwise_orthogonality( coeffs_list: list[jax.Array], ) -> tuple[dict[str, float], dict[str, jax.Array]]: - """ - Compute pairwise orthogonality metrics for all factor pairs. + """Compute pairwise orthogonality metrics for all factor pairs. Args: coeffs_list: List of coefficient matrices (one per factor, excludes intercepts) @@ -436,8 +435,7 @@ def layer_linear_regression( use_svd: bool = False, **kwargs: Any, ) -> tuple[Mapping[str, float], Mapping[str, jax.Array]]: - """ - Layer-wise regression helper that wraps :func:`linear_regression` or :func:`linear_regression_svd`. + """Layer-wise regression helper that wraps :func:`linear_regression` or :func:`linear_regression_svd`. Args: layer_activations: Neural network activations for a single layer @@ -482,8 +480,7 @@ def layer_linear_regression_svd( compute_subspace_orthogonality: bool = False, **kwargs: Any, ) -> tuple[Mapping[str, float], Mapping[str, jax.Array]]: - """ - Layer-wise SVD regression helper (wrapper around layer_linear_regression with use_svd=True). + """Layer-wise SVD regression helper (wrapper around layer_linear_regression with use_svd=True). This function is provided for backward compatibility and convenience. Consider using layer_linear_regression with use_svd=True for new code. 
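
Note (illustrative sketch, not part of the patches themselves): the orthogonality
metrics these patches keep refining boil down to principal-angle cosines between
the column spaces of two coefficient matrices. A minimal standalone version of
that computation, assuming only jax is installed; the helper name
subspace_overlap is invented for this example and is not part of simplexity:

    import jax.numpy as jnp

    def subspace_overlap(coeffs_a, coeffs_b):
        """Mean squared principal-angle cosine between span(coeffs_a) and span(coeffs_b)."""
        # Orthonormal bases for the two column spaces (intercept rows excluded by the caller).
        q_a, _ = jnp.linalg.qr(coeffs_a)
        q_b, _ = jnp.linalg.qr(coeffs_b)
        # Singular values of q_a.T @ q_b are the principal-angle cosines, clipped to [0, 1].
        cosines = jnp.clip(jnp.linalg.svd(q_a.T @ q_b, compute_uv=False), 0.0, 1.0)
        # 0.0 for orthogonal subspaces; 1.0 when the smaller space lies inside the larger.
        return jnp.sum(cosines**2) / min(q_a.shape[1], q_b.shape[1]), cosines

    # Orthogonal 2D subspaces of R^4 give overlap ~0.0; identical subspaces give ~1.0.
    w_a = jnp.eye(4)[:, :2]
    w_b = jnp.eye(4)[:, 2:]
    print(subspace_overlap(w_a, w_b)[0], subspace_overlap(w_a, w_a)[0])
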
From edba4fe579c39e35a409346ba3a6d9b054a6cf87 Mon Sep 17 00:00:00 2001 From: Eric Alt Date: Fri, 12 Dec 2025 11:55:32 -0800 Subject: [PATCH 05/48] Fix slices --- simplexity/analysis/linear_regression.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/simplexity/analysis/linear_regression.py b/simplexity/analysis/linear_regression.py index 52fc0dbf..44fadb36 100644 --- a/simplexity/analysis/linear_regression.py +++ b/simplexity/analysis/linear_regression.py @@ -77,7 +77,7 @@ def linear_regression( arrays = { "projected": predictions, "coeffs": beta[1:], # Linear coefficients (excluding intercept) - "intercept": beta[0:1], # Intercept term (keep 2D: [1, n_targets]) + "intercept": beta[:1], # Intercept term (keep 2D: [1, n_targets]) } else: arrays = { @@ -180,7 +180,7 @@ def linear_regression_svd( arrays = { "projected": best_pred, "coeffs": best_beta[1:], # Linear coefficients (excluding intercept) - "intercept": best_beta[0:1], # Intercept term (keep 2D: [1, n_targets]) + "intercept": best_beta[:1], # Intercept term (keep 2D: [1, n_targets]) } else: arrays = { From 70eb56e77b06015cc4dc29a8c2485e90ec33d24f Mon Sep 17 00:00:00 2001 From: Eric Alt Date: Fri, 12 Dec 2025 12:12:05 -0800 Subject: [PATCH 06/48] Simplify lr kwarg validation --- simplexity/analysis/layerwise_analysis.py | 36 ++++++++++------------- 1 file changed, 16 insertions(+), 20 deletions(-) diff --git a/simplexity/analysis/layerwise_analysis.py b/simplexity/analysis/layerwise_analysis.py index e22c6f4c..e3a0f458 100644 --- a/simplexity/analysis/layerwise_analysis.py +++ b/simplexity/analysis/layerwise_analysis.py @@ -32,33 +32,16 @@ class AnalysisRegistration: validator: ValidatorFn -def _validate_linear_regression_kwargs(kwargs: Mapping[str, Any] | None) -> dict[str, Any]: +def _base_validate_linear_regression_kwargs(kwargs: Mapping[str, Any] | None) -> dict[str, Any]: provided = dict(kwargs or {}) - allowed = {"fit_intercept", "concat_belief_states", "compute_subspace_orthogonality", "use_svd"} - unexpected = set(provided) - allowed - if unexpected: - raise ValueError(f"Unexpected linear_regression kwargs: {sorted(unexpected)}") - fit_intercept = bool(provided.get("fit_intercept", True)) - concat_belief_states = bool(provided.get("concat_belief_states", False)) - compute_subspace_orthogonality = bool(provided.get("compute_subspace_orthogonality", False)) - use_svd = bool(provided.get("use_svd", False)) - return { - "fit_intercept": fit_intercept, - "concat_belief_states": concat_belief_states, - "compute_subspace_orthogonality": compute_subspace_orthogonality, - "use_svd": use_svd, - } - - -def _validate_linear_regression_svd_kwargs(kwargs: Mapping[str, Any] | None) -> dict[str, Any]: - provided = dict(kwargs or {}) - allowed = {"fit_intercept", "rcond_values", "concat_belief_states", "compute_subspace_orthogonality"} + allowed = {"fit_intercept", "concat_belief_states", "compute_subspace_orthogonality", "use_svd", "rcond_values"} unexpected = set(provided) - allowed if unexpected: raise ValueError(f"Unexpected linear_regression_svd kwargs: {sorted(unexpected)}") fit_intercept = bool(provided.get("fit_intercept", True)) concat_belief_states = bool(provided.get("concat_belief_states", False)) compute_subspace_orthogonality = bool(provided.get("compute_subspace_orthogonality", False)) + use_svd = bool(provided.get("use_svd", False)) rcond_values = provided.get("rcond_values") if rcond_values is not None: if not isinstance(rcond_values, (list, tuple)): @@ -70,10 +53,23 @@ def 
_validate_linear_regression_svd_kwargs(kwargs: Mapping[str, Any] | None) -> "fit_intercept": fit_intercept, "concat_belief_states": concat_belief_states, "compute_subspace_orthogonality": compute_subspace_orthogonality, + "use_svd": use_svd, "rcond_values": rcond_values, } +def _validate_linear_regression_kwargs(kwargs: Mapping[str, Any] | None) -> dict[str, Any]: + kwargs = _base_validate_linear_regression_kwargs(kwargs) + kwargs.pop("rcond_values") + return kwargs + + +def _validate_linear_regression_svd_kwargs(kwargs: Mapping[str, Any] | None) -> dict[str, Any]: + kwargs = _base_validate_linear_regression_kwargs(kwargs) + kwargs.pop("use_svd") + return kwargs + + def _validate_pca_kwargs(kwargs: Mapping[str, Any] | None) -> dict[str, Any]: provided = dict(kwargs or {}) allowed = {"n_components", "variance_thresholds"} From 9cc9810e03e1576891d40470d0b0896ad402fbc4 Mon Sep 17 00:00:00 2001 From: Eric Alt Date: Fri, 12 Dec 2025 12:12:54 -0800 Subject: [PATCH 07/48] Add return type --- simplexity/analysis/linear_regression.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/simplexity/analysis/linear_regression.py b/simplexity/analysis/linear_regression.py index 44fadb36..ac1f1c17 100644 --- a/simplexity/analysis/linear_regression.py +++ b/simplexity/analysis/linear_regression.py @@ -96,7 +96,7 @@ def _compute_regression_metrics( predictions: jax.Array | None = None, *, fit_intercept: bool = True, -): +) -> Mapping[str, float]: x_arr = standardize_features(x) y_arr = standardize_targets(y) if x_arr.shape[0] != y_arr.shape[0]: From d403bc727f46b04517f56dda5a4ea215ced1fbe2 Mon Sep 17 00:00:00 2001 From: Eric Alt Date: Fri, 12 Dec 2025 12:13:44 -0800 Subject: [PATCH 08/48] Add pylint ignore --- simplexity/analysis/linear_regression.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/simplexity/analysis/linear_regression.py b/simplexity/analysis/linear_regression.py index ac1f1c17..193916ba 100644 --- a/simplexity/analysis/linear_regression.py +++ b/simplexity/analysis/linear_regression.py @@ -88,7 +88,7 @@ def linear_regression( return scalars, arrays -def _compute_regression_metrics( +def _compute_regression_metrics( # pylint: disable=too-many-arguments x: jax.Array, y: jax.Array, weights: jax.Array | np.ndarray | None, From 1c55be037b265ae9edb42897fa902966ca094141 Mon Sep 17 00:00:00 2001 From: Eric Alt Date: Fri, 12 Dec 2025 12:15:45 -0800 Subject: [PATCH 09/48] Fix potential division by zero --- simplexity/analysis/linear_regression.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/simplexity/analysis/linear_regression.py b/simplexity/analysis/linear_regression.py index 193916ba..ce2c2e79 100644 --- a/simplexity/analysis/linear_regression.py +++ b/simplexity/analysis/linear_regression.py @@ -304,8 +304,8 @@ def _compute_subspace_orthogonality( min_singular_value = jnp.min(singular_values) # Compute the participation ratio - participation_ratio = jnp.sum(singular_values**2) ** 2 / jnp.sum(singular_values**4) - + denom = jnp.sum(singular_values**4) + participation_ratio = jnp.where(denom == 0, 0.0, jnp.sum(singular_values**2) ** 2 / denom) # Compute the entropy probs = singular_values**2 / jnp.sum(singular_values**2) entropy = -jnp.sum(probs * jnp.log(probs)) From c3d070ca4594a21fc9854c877d5e284b10d47a2d Mon Sep 17 00:00:00 2001 From: Eric Alt Date: Fri, 12 Dec 2025 12:19:50 -0800 Subject: [PATCH 10/48] Fix potential log(0) issue --- simplexity/analysis/linear_regression.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff 
--git a/simplexity/analysis/linear_regression.py b/simplexity/analysis/linear_regression.py index ce2c2e79..43af7d78 100644 --- a/simplexity/analysis/linear_regression.py +++ b/simplexity/analysis/linear_regression.py @@ -308,7 +308,8 @@ def _compute_subspace_orthogonality( participation_ratio = jnp.where(denom == 0, 0.0, jnp.sum(singular_values**2) ** 2 / denom) # Compute the entropy probs = singular_values**2 / jnp.sum(singular_values**2) - entropy = -jnp.sum(probs * jnp.log(probs)) + log_probs = jnp.where(probs > 0, jnp.log(probs), 0.0) + entropy = -jnp.sum(probs * log_probs) # Compute the effective rank effective_rank = jnp.exp(entropy) From 0c9a37f613f1c3abf42a7ea65af8f54e41d68238 Mon Sep 17 00:00:00 2001 From: Eric Alt Date: Fri, 12 Dec 2025 12:23:08 -0800 Subject: [PATCH 11/48] Enhance subspace orthogonality computation by adding a check for multiple belief states. Log a warning if only one belief state is present, preventing unnecessary calculations. --- simplexity/analysis/linear_regression.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/simplexity/analysis/linear_regression.py b/simplexity/analysis/linear_regression.py index 43af7d78..6f5a6716 100644 --- a/simplexity/analysis/linear_regression.py +++ b/simplexity/analysis/linear_regression.py @@ -390,11 +390,14 @@ def _handle_factored_regression( _merge_results_with_prefix(scalars, arrays, factor_result, f"factor_{factor_idx}") if compute_subspace_orthogonality: - # Extract coefficients (excludes intercept) for orthogonality computation - coeffs_list = [factor_arrays["coeffs"] for _, factor_arrays in factor_results] - orthogonality_scalars, orthogonality_singular_values = _compute_all_pairwise_orthogonality(coeffs_list) - scalars.update(orthogonality_scalars) - arrays.update(orthogonality_singular_values) + if len(belief_states) > 1: + # Extract coefficients (excludes intercept) for orthogonality computation + coeffs_list = [factor_arrays["coeffs"] for _, factor_arrays in factor_results] + orthogonality_scalars, orthogonality_singular_values = _compute_all_pairwise_orthogonality(coeffs_list) + scalars.update(orthogonality_scalars) + arrays.update(orthogonality_singular_values) + else: + SIMPLEXITY_LOGGER.warning("Subspace orthogonality cannot be computed for a single belief state") return scalars, arrays From 74c676034d3b60edbdbd1636c498b8aad2af9afb Mon Sep 17 00:00:00 2001 From: Eric Alt Date: Fri, 12 Dec 2025 12:23:56 -0800 Subject: [PATCH 12/48] Fix docstring inconsistency --- tests/analysis/test_linear_regression.py | 91 +++++++++++++----------- 1 file changed, 49 insertions(+), 42 deletions(-) diff --git a/tests/analysis/test_linear_regression.py b/tests/analysis/test_linear_regression.py index 63a6fd22..2549aa17 100644 --- a/tests/analysis/test_linear_regression.py +++ b/tests/analysis/test_linear_regression.py @@ -27,7 +27,7 @@ def _compute_orthogonality_threshold( Args: x: Input features array (used for dtype and dimension) *factors: Factor arrays being compared (used for output dimensions) - safety_factor: Multiplicative safety factor (default 1000) + safety_factor: Multiplicative safety factor (default 10) Returns: Threshold value for considering singular values as effectively zero @@ -259,7 +259,6 @@ def test_layer_linear_regression_belief_states_tuple_default() -> None: assert arrays["factor_1/intercept"].shape == (1, factor_1.shape[1]) - def test_layer_linear_regression_svd_belief_states_tuple_default() -> None: """By default, layer regression SVD should regress to each factor 
separately if given a tuple of belief states.""" x = jnp.arange(12.0).reshape(4, 3) # 4 samples, 3 features @@ -706,26 +705,30 @@ def test_orthogonality_with_different_subspace_dimensions() -> None: # Create orthogonal coefficient matrices with different output dimensions # factor_0 has 2 output dimensions, factor_1 has 5 output dimensions - W_0 = jnp.array([ - [1.0, 0.0], - [0.0, 1.0], - [1.0, 1.0], - [0.0, 0.0], - [0.0, 0.0], - [0.0, 0.0], - [0.0, 0.0], - [0.0, 0.0], - ]) # (8, 2) - W_1 = jnp.array([ - [0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0], - [1.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 1.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 1.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 1.0], - ]) # (8, 5) + W_0 = jnp.array( + [ + [1.0, 0.0], + [0.0, 1.0], + [1.0, 1.0], + [0.0, 0.0], + [0.0, 0.0], + [0.0, 0.0], + [0.0, 0.0], + [0.0, 0.0], + ] + ) # (8, 2) + W_1 = jnp.array( + [ + [0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0], + [1.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 1.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 1.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 1.0], + ] + ) # (8, 5) factor_0 = x @ W_0 # (100, 2) factor_1 = x @ W_1 # (100, 5) @@ -771,26 +774,30 @@ def test_orthogonality_with_contained_subspace() -> None: # Create coefficient matrices where factor_0's subspace is contained in factor_1's # factor_0: 2D subspace using features [0, 1] # factor_1: 3D subspace using features [0, 1, 2] (contains factor_0's space) - W_0 = jnp.array([ - [1.0, 0.0], - [0.0, 1.0], - [0.0, 0.0], - [0.0, 0.0], - [0.0, 0.0], - [0.0, 0.0], - [0.0, 0.0], - [0.0, 0.0], - ]) # (8, 2) - W_1 = jnp.array([ - [1.0, 0.0, 0.0], - [0.0, 1.0, 0.0], - [0.0, 0.0, 1.0], - [0.0, 0.0, 0.0], - [0.0, 0.0, 0.0], - [0.0, 0.0, 0.0], - [0.0, 0.0, 0.0], - [0.0, 0.0, 0.0], - ]) # (8, 3) + W_0 = jnp.array( + [ + [1.0, 0.0], + [0.0, 1.0], + [0.0, 0.0], + [0.0, 0.0], + [0.0, 0.0], + [0.0, 0.0], + [0.0, 0.0], + [0.0, 0.0], + ] + ) # (8, 2) + W_1 = jnp.array( + [ + [1.0, 0.0, 0.0], + [0.0, 1.0, 0.0], + [0.0, 0.0, 1.0], + [0.0, 0.0, 0.0], + [0.0, 0.0, 0.0], + [0.0, 0.0, 0.0], + [0.0, 0.0, 0.0], + [0.0, 0.0, 0.0], + ] + ) # (8, 3) factor_0 = x @ W_0 # (100, 2) factor_1 = x @ W_1 # (100, 3) From d3b02352089f7fb532fd3f97dbfbcd5b33332274 Mon Sep 17 00:00:00 2001 From: Eric Alt Date: Fri, 12 Dec 2025 12:24:26 -0800 Subject: [PATCH 13/48] Update docstring --- tests/analysis/test_linear_regression.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/analysis/test_linear_regression.py b/tests/analysis/test_linear_regression.py index 2549aa17..cebcca32 100644 --- a/tests/analysis/test_linear_regression.py +++ b/tests/analysis/test_linear_regression.py @@ -306,7 +306,7 @@ def test_layer_linear_regression_svd_belief_states_tuple_default() -> None: def test_layer_linear_regression_belief_states_tuple_single_factor() -> None: - """to_factors=True should work with a single factor tuple.""" + """Handles a single factor provided as a tuple of belief states.""" x = jnp.arange(9.0).reshape(3, 3) weights = jnp.ones(3) / 3.0 From 2d4a97f728329d1eda48b6879aed1b3f56a82687 Mon Sep 17 00:00:00 2001 From: Eric Alt Date: Fri, 12 Dec 2025 12:29:09 -0800 Subject: [PATCH 14/48] Fix lint issues --- tests/analysis/test_linear_regression.py | 136 ++++++++++++----------- 1 file changed, 70 insertions(+), 66 deletions(-) diff --git a/tests/analysis/test_linear_regression.py b/tests/analysis/test_linear_regression.py index cebcca32..cb3ce270 100644 --- 
a/tests/analysis/test_linear_regression.py +++ b/tests/analysis/test_linear_regression.py @@ -1,5 +1,7 @@ """Tests for reusable linear regression helpers.""" +# pylint: disable=too-many-lines + import chex import jax import jax.numpy as jnp @@ -342,13 +344,13 @@ def test_orthogonality_with_orthogonal_subspaces() -> None: x = jax.random.normal(key, (n_samples, n_features)) # Define orthogonal coefficient matrices - # W_0 uses first 3 features, W_1 uses last 3 features - W_0 = jnp.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0]]) # (6, 2) - W_1 = jnp.array([[0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]) # (6, 2) + # w_0 uses first 3 features, w_1 uses last 3 features + w_0 = jnp.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0]]) # (6, 2) + w_1 = jnp.array([[0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]) # (6, 2) # Generate factors using orthogonal subspaces (no intercept for simplicity) - factor_0 = x @ W_0 # (100, 2) - factor_1 = x @ W_1 # (100, 2) + factor_0 = x @ w_0 # (100, 2) + factor_1 = x @ w_1 # (100, 2) factored_beliefs = (factor_0, factor_1) weights = jnp.ones(n_samples) / n_samples @@ -395,14 +397,14 @@ def test_orthogonality_with_aligned_subspaces() -> None: key = jax.random.PRNGKey(0) x = jax.random.normal(key, (n_samples, n_features)) - # Define aligned coefficient matrices - W_1 = W_0 @ A for invertible A - # This ensures span(W_1) = span(W_0) - W_0 = jnp.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0]]) # (6, 2) - W_1 = jnp.array([[0.5, 1.0], [1.0, 0.5], [1.5, 1.5], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0]]) # (6, 2) + # Define aligned coefficient matrices - w_1 = w_0 @ A for invertible A + # This ensures span(w_1) = span(w_0) + w_0 = jnp.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0]]) # (6, 2) + w_1 = jnp.array([[0.5, 1.0], [1.0, 0.5], [1.5, 1.5], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0]]) # (6, 2) # Generate factors using aligned subspaces (no intercept for simplicity) - factor_0 = x @ W_0 # (100, 2) - factor_1 = x @ W_1 # (100, 2) + factor_0 = x @ w_0 # (100, 2) + factor_1 = x @ w_1 # (100, 2) factored_beliefs = (factor_0, factor_1) weights = jnp.ones(n_samples) / n_samples @@ -447,14 +449,14 @@ def test_orthogonality_with_three_factors() -> None: x = jax.random.normal(key, (n_samples, n_features)) # Define three orthogonal coefficient matrices using disjoint features - W_0 = jnp.array([[1.0, 0.0], [0.0, 1.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0]]) # Uses features 0-1 - W_1 = jnp.array([[0.0, 0.0], [0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [0.0, 0.0], [0.0, 0.0]]) # Uses features 2-3 - W_2 = jnp.array([[0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [1.0, 0.0], [0.0, 1.0]]) # Uses features 4-5 + w_0 = jnp.array([[1.0, 0.0], [0.0, 1.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0]]) # Uses features 0-1 + w_1 = jnp.array([[0.0, 0.0], [0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [0.0, 0.0], [0.0, 0.0]]) # Uses features 2-3 + w_2 = jnp.array([[0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [1.0, 0.0], [0.0, 1.0]]) # Uses features 4-5 # Generate factors using orthogonal subspaces - factor_0 = x @ W_0 # (100, 2) - factor_1 = x @ W_1 # (100, 2) - factor_2 = x @ W_2 # (100, 2) + factor_0 = x @ w_0 # (100, 2) + factor_1 = x @ w_1 # (100, 2) + factor_2 = x @ w_2 # (100, 2) factored_beliefs = (factor_0, factor_1, factor_2) weights = jnp.ones(n_samples) / n_samples @@ -506,11 +508,11 @@ def 
test_orthogonality_not_computed_by_default() -> None: key = jax.random.PRNGKey(0) x = jax.random.normal(key, (n_samples, n_features)) - W_0 = jnp.array([[1.0, 0.0], [0.0, 1.0], [0.0, 0.0], [0.0, 0.0]]) - W_1 = jnp.array([[0.0, 0.0], [0.0, 0.0], [1.0, 0.0], [0.0, 1.0]]) + w_0 = jnp.array([[1.0, 0.0], [0.0, 1.0], [0.0, 0.0], [0.0, 0.0]]) + w_1 = jnp.array([[0.0, 0.0], [0.0, 0.0], [1.0, 0.0], [0.0, 1.0]]) - factor_0 = x @ W_0 - factor_1 = x @ W_1 + factor_0 = x @ w_0 + factor_1 = x @ w_1 factored_beliefs = (factor_0, factor_1) weights = jnp.ones(n_samples) / n_samples @@ -605,18 +607,18 @@ def test_use_svd_flag_equivalence() -> None: # Should produce identical results assert scalars_flag.keys() == scalars_wrapper.keys() - for key in scalars_flag: - assert scalars_flag[key] == pytest.approx(scalars_wrapper[key]) + for key, value in scalars_flag.items(): + assert value == pytest.approx(scalars_wrapper[key]) assert arrays_flag.keys() == arrays_wrapper.keys() - for key in arrays_flag: - chex.assert_trees_all_close(arrays_flag[key], arrays_wrapper[key]) + for key, value in arrays_flag.items(): + chex.assert_trees_all_close(value, arrays_wrapper[key]) # Test with factored belief states - W_0 = jnp.array([[1.0, 0.0], [0.0, 1.0], [0.0, 0.0], [0.0, 0.0]]) - W_1 = jnp.array([[0.0, 0.0], [0.0, 0.0], [1.0, 0.0], [0.0, 1.0]]) - factor_0 = x @ W_0 - factor_1 = x @ W_1 + w_0 = jnp.array([[1.0, 0.0], [0.0, 1.0], [0.0, 0.0], [0.0, 0.0]]) + w_1 = jnp.array([[0.0, 0.0], [0.0, 0.0], [1.0, 0.0], [0.0, 1.0]]) + factor_0 = x @ w_0 + factor_1 = x @ w_1 factored_beliefs = (factor_0, factor_1) # Method 1: use_svd=True with factored beliefs @@ -638,12 +640,12 @@ def test_use_svd_flag_equivalence() -> None: # Should produce identical results assert scalars_flag_fact.keys() == scalars_wrapper_fact.keys() - for key in scalars_flag_fact: - assert scalars_flag_fact[key] == pytest.approx(scalars_wrapper_fact[key]) + for key, value in scalars_flag_fact.items(): + assert value == pytest.approx(scalars_wrapper_fact[key]) assert arrays_flag_fact.keys() == arrays_wrapper_fact.keys() - for key in arrays_flag_fact: - chex.assert_trees_all_close(arrays_flag_fact[key], arrays_wrapper_fact[key]) + for key, value in arrays_flag_fact.items(): + chex.assert_trees_all_close(value, arrays_wrapper_fact[key]) def test_use_svd_with_orthogonality() -> None: @@ -654,11 +656,11 @@ def test_use_svd_with_orthogonality() -> None: x = jax.random.normal(key, (n_samples, n_features)) # Create orthogonal coefficient matrices - W_0 = jnp.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0]]) - W_1 = jnp.array([[0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]) + w_0 = jnp.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0]]) + w_1 = jnp.array([[0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]) - factor_0 = x @ W_0 - factor_1 = x @ W_1 + factor_0 = x @ w_0 + factor_1 = x @ w_1 factored_beliefs = (factor_0, factor_1) weights = jnp.ones(n_samples) / n_samples @@ -705,7 +707,7 @@ def test_orthogonality_with_different_subspace_dimensions() -> None: # Create orthogonal coefficient matrices with different output dimensions # factor_0 has 2 output dimensions, factor_1 has 5 output dimensions - W_0 = jnp.array( + w_0 = jnp.array( [ [1.0, 0.0], [0.0, 1.0], @@ -717,7 +719,7 @@ def test_orthogonality_with_different_subspace_dimensions() -> None: [0.0, 0.0], ] ) # (8, 2) - W_1 = jnp.array( + w_1 = jnp.array( [ [0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0], @@ -730,8 
+732,8 @@ def test_orthogonality_with_different_subspace_dimensions() -> None: ] ) # (8, 5) - factor_0 = x @ W_0 # (100, 2) - factor_1 = x @ W_1 # (100, 5) + factor_0 = x @ w_0 # (100, 2) + factor_1 = x @ w_1 # (100, 5) factored_beliefs = (factor_0, factor_1) weights = jnp.ones(n_samples) / n_samples @@ -774,7 +776,7 @@ def test_orthogonality_with_contained_subspace() -> None: # Create coefficient matrices where factor_0's subspace is contained in factor_1's # factor_0: 2D subspace using features [0, 1] # factor_1: 3D subspace using features [0, 1, 2] (contains factor_0's space) - W_0 = jnp.array( + w_0 = jnp.array( [ [1.0, 0.0], [0.0, 1.0], @@ -786,7 +788,7 @@ def test_orthogonality_with_contained_subspace() -> None: [0.0, 0.0], ] ) # (8, 2) - W_1 = jnp.array( + w_1 = jnp.array( [ [1.0, 0.0, 0.0], [0.0, 1.0, 0.0], @@ -799,8 +801,8 @@ def test_orthogonality_with_contained_subspace() -> None: ] ) # (8, 3) - factor_0 = x @ W_0 # (100, 2) - factor_1 = x @ W_1 # (100, 3) + factor_0 = x @ w_0 # (100, 2) + factor_1 = x @ w_1 # (100, 3) factored_beliefs = (factor_0, factor_1) weights = jnp.ones(n_samples) / n_samples @@ -840,15 +842,15 @@ def test_orthogonality_excludes_intercept() -> None: x = jax.random.normal(key, (n_samples, n_features)) # Create orthogonal coefficient matrices - W_0 = jnp.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0]]) - W_1 = jnp.array([[0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]) + w_0 = jnp.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0]]) + w_1 = jnp.array([[0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]) # Add different intercepts to the factors intercept_0 = jnp.array([[5.0, -3.0]]) intercept_1 = jnp.array([[10.0, 7.0]]) - factor_0 = x @ W_0 + intercept_0 # (100, 2) - factor_1 = x @ W_1 + intercept_1 # (100, 2) + factor_0 = x @ w_0 + intercept_0 # (100, 2) + factor_1 = x @ w_1 + intercept_1 # (100, 2) factored_beliefs = (factor_0, factor_1) weights = jnp.ones(n_samples) / n_samples @@ -897,7 +899,7 @@ def test_linear_regression_constant_targets_r2_and_dist() -> None: weights = jnp.array([0.1, 0.2, 0.3, 0.4]) # With intercept -> perfect constant fit, but r2 should fallback to 0.0 when variance is zero - scalars, arrays = linear_regression(x, y, weights) + scalars, _ = linear_regression(x, y, weights) assert scalars["r2"] == 0.0 assert jnp.isclose(scalars["rmse"], 0.0, atol=1e-6, rtol=0.0).item() assert jnp.isclose(scalars["mae"], 0.0, atol=1e-6, rtol=0.0).item() @@ -923,7 +925,7 @@ def test_linear_regression_intercept_and_shapes_both_solvers() -> None: weights = jnp.ones(n) / n # Standard solver, with intercept - scalars, arrays = linear_regression(x, y, weights, fit_intercept=True) + _, arrays = linear_regression(x, y, weights, fit_intercept=True) assert "projected" in arrays assert "coeffs" in arrays assert "intercept" in arrays @@ -932,7 +934,7 @@ def test_linear_regression_intercept_and_shapes_both_solvers() -> None: assert arrays["intercept"].shape == (1, t) # Standard solver, without intercept - scalars_no_int, arrays_no_int = linear_regression(x, y, weights, fit_intercept=False) + _, arrays_no_int = linear_regression(x, y, weights, fit_intercept=False) assert "projected" in arrays_no_int assert "coeffs" in arrays_no_int assert "intercept" not in arrays_no_int @@ -940,7 +942,7 @@ def test_linear_regression_intercept_and_shapes_both_solvers() -> None: assert arrays_no_int["coeffs"].shape == (d, t) # SVD solver, with intercept - scalars_svd, arrays_svd = 
linear_regression_svd(x, y, weights, fit_intercept=True) + _, arrays_svd = linear_regression_svd(x, y, weights, fit_intercept=True) assert "projected" in arrays_svd assert "coeffs" in arrays_svd assert "intercept" in arrays_svd @@ -949,7 +951,7 @@ def test_linear_regression_intercept_and_shapes_both_solvers() -> None: assert arrays_svd["intercept"].shape == (1, t) # SVD solver, without intercept - scalars_svd_no_int, arrays_svd_no_int = linear_regression_svd(x, y, weights, fit_intercept=False) + _, arrays_svd_no_int = linear_regression_svd(x, y, weights, fit_intercept=False) assert "projected" in arrays_svd_no_int assert "coeffs" in arrays_svd_no_int assert "intercept" not in arrays_svd_no_int @@ -962,19 +964,19 @@ def test_layer_linear_regression_concat_vs_separate_equivalence() -> None: n, d = 6, 3 x = jnp.arange(float(n * d)).reshape(n, d) # Two factors with different output dims - W0 = jnp.array([[1.0, 0.5], [0.0, -1.0], [2.0, 1.0]]) # (d, 2) + w_0 = jnp.array([[1.0, 0.5], [0.0, -1.0], [2.0, 1.0]]) # (d, 2) b0 = jnp.array([[0.3, -0.2]]) # (1, 2) - factor_0 = x @ W0 + b0 + factor_0 = x @ w_0 + b0 - W1 = jnp.array([[0.2, 0.0, -0.5], [1.0, 1.0, 0.0], [-1.0, 0.5, 0.3]]) # (d, 3) + w_1 = jnp.array([[0.2, 0.0, -0.5], [1.0, 1.0, 0.0], [-1.0, 0.5, 0.3]]) # (d, 3) b1 = jnp.array([[0.1, 0.2, -0.1]]) # (1, 3) - factor_1 = x @ W1 + b1 + factor_1 = x @ w_1 + b1 factored_beliefs = (factor_0, factor_1) weights = jnp.array([0.05, 0.10, 0.15, 0.20, 0.25, 0.25]) # Separate per-factor regression - scalars_sep, arrays_sep = layer_linear_regression( + _, arrays_sep = layer_linear_regression( x, weights, factored_beliefs, @@ -982,7 +984,7 @@ def test_layer_linear_regression_concat_vs_separate_equivalence() -> None: ) # Concatenated regression with splitting - scalars_cat, arrays_cat = layer_linear_regression( + _, arrays_cat = layer_linear_regression( x, weights, factored_beliefs, @@ -1009,13 +1011,13 @@ def test_layer_linear_regression_svd_concat_vs_separate_equivalence_best_rcond() n, d = 6, 3 x = jnp.arange(float(n * d)).reshape(n, d) # Two factors with different output dims - W0 = jnp.array([[1.0, 0.5], [0.0, -1.0], [2.0, 1.0]]) # (d, 2) + w_0 = jnp.array([[1.0, 0.5], [0.0, -1.0], [2.0, 1.0]]) # (d, 2) b0 = jnp.array([[0.3, -0.2]]) # (1, 2) - factor_0 = x @ W0 + b0 + factor_0 = x @ w_0 + b0 - W1 = jnp.array([[0.2, 0.0, -0.5], [1.0, 1.0, 0.0], [-1.0, 0.5, 0.3]]) # (d, 3) + w_1 = jnp.array([[0.2, 0.0, -0.5], [1.0, 1.0, 0.0], [-1.0, 0.5, 0.3]]) # (d, 3) b1 = jnp.array([[0.1, 0.2, -0.1]]) # (1, 3) - factor_1 = x @ W1 + b1 + factor_1 = x @ w_1 + b1 factored_beliefs = (factor_0, factor_1) weights = jnp.array([0.05, 0.10, 0.15, 0.20, 0.25, 0.25]) @@ -1048,8 +1050,10 @@ def test_layer_linear_regression_svd_concat_vs_separate_equivalence_best_rcond() assert scalars_cat["concat/best_rcond"] == pytest.approx(1e-3) # Separate path should include per-factor best_rcond; concat-split path should not - assert "factor_0/best_rcond" in scalars_sep and "factor_1/best_rcond" in scalars_sep - assert "factor_0/best_rcond" not in scalars_cat and "factor_1/best_rcond" not in scalars_cat + assert "factor_0/best_rcond" in scalars_sep + assert "factor_1/best_rcond" in scalars_sep + assert "factor_0/best_rcond" not in scalars_cat + assert "factor_1/best_rcond" not in scalars_cat # Per-factor arrays should match between separate and concat-split flows for k in ["projected", "coeffs", "intercept"]: From 335d21048d993425b7061cd254c326589c7938e3 Mon Sep 17 00:00:00 2001 From: Eric Alt Date: Fri, 12 Dec 2025 15:57:11 -0800 Subject: 
[PATCH 15/48] Refactor linear regression kwargs validation and improve logging. Temporarily disable pylint checks during AST traversal to avoid crashes related to package imports. --- simplexity/analysis/layerwise_analysis.py | 26 +++++++++++------------ 1 file changed, 12 insertions(+), 14 deletions(-) diff --git a/simplexity/analysis/layerwise_analysis.py b/simplexity/analysis/layerwise_analysis.py index e3a0f458..a18fae41 100644 --- a/simplexity/analysis/layerwise_analysis.py +++ b/simplexity/analysis/layerwise_analysis.py @@ -1,5 +1,12 @@ """Composable layer-wise analysis orchestration.""" +# pylint: disable=all # Temporarily disable all pylint checkers during AST traversal to prevent crash. +# The imports checker crashes when resolving simplexity package imports due to a bug +# in pylint/astroid: https://github.com/pylint-dev/pylint/issues/10185 +# pylint: enable=all # Re-enable all pylint checkers for the checking phase. This allows other checks +# (code quality, style, undefined names, etc.) to run normally while bypassing +# the problematic imports checker that would crash during AST traversal. + from __future__ import annotations from collections.abc import Callable, Mapping, Sequence @@ -16,6 +23,7 @@ DEFAULT_VARIANCE_THRESHOLDS, layer_pca_analysis, ) +from simplexity.logger import SIMPLEXITY_LOGGER AnalysisFn = Callable[..., tuple[Mapping[str, float], Mapping[str, jax.Array]]] @@ -32,7 +40,7 @@ class AnalysisRegistration: validator: ValidatorFn -def _base_validate_linear_regression_kwargs(kwargs: Mapping[str, Any] | None) -> dict[str, Any]: +def _validate_linear_regression_kwargs(kwargs: Mapping[str, Any] | None) -> dict[str, Any]: provided = dict(kwargs or {}) allowed = {"fit_intercept", "concat_belief_states", "compute_subspace_orthogonality", "use_svd", "rcond_values"} unexpected = set(provided) - allowed @@ -48,6 +56,8 @@ def _base_validate_linear_regression_kwargs(kwargs: Mapping[str, Any] | None) -> raise TypeError("rcond_values must be a sequence of floats") if len(rcond_values) == 0: raise ValueError("rcond_values must not be empty") + if not use_svd: + SIMPLEXITY_LOGGER.warning("rcond_values are only used when use_svd is True") rcond_values = tuple(float(v) for v in rcond_values) return { "fit_intercept": fit_intercept, @@ -58,18 +68,6 @@ def _base_validate_linear_regression_kwargs(kwargs: Mapping[str, Any] | None) -> } -def _validate_linear_regression_kwargs(kwargs: Mapping[str, Any] | None) -> dict[str, Any]: - kwargs = _base_validate_linear_regression_kwargs(kwargs) - kwargs.pop("rcond_values") - return kwargs - - -def _validate_linear_regression_svd_kwargs(kwargs: Mapping[str, Any] | None) -> dict[str, Any]: - kwargs = _base_validate_linear_regression_kwargs(kwargs) - kwargs.pop("use_svd") - return kwargs - - def _validate_pca_kwargs(kwargs: Mapping[str, Any] | None) -> dict[str, Any]: provided = dict(kwargs or {}) allowed = {"n_components", "variance_thresholds"} @@ -104,7 +102,7 @@ def _validate_pca_kwargs(kwargs: Mapping[str, Any] | None) -> dict[str, Any]: "linear_regression_svd": AnalysisRegistration( fn=layer_linear_regression_svd, requires_belief_states=True, - validator=_validate_linear_regression_svd_kwargs, + validator=_validate_linear_regression_kwargs, ), "pca": AnalysisRegistration( fn=layer_pca_analysis, From 358985ce044253786d937fa9e2ba078dd72f4de7 Mon Sep 17 00:00:00 2001 From: loren-ac Date: Fri, 12 Dec 2025 12:52:31 -0800 Subject: [PATCH 16/48] Fix merge conflict --- simplexity/analysis/linear_regression.py | 51 ++++++++++++++++++++---- 
tests/analysis/test_linear_regression.py | 35 ++++++++++------
 2 files changed, 66 insertions(+), 20 deletions(-)

diff --git a/simplexity/analysis/linear_regression.py b/simplexity/analysis/linear_regression.py
index 6f5a6716..b0209f13 100644
--- a/simplexity/analysis/linear_regression.py
+++ b/simplexity/analysis/linear_regression.py
@@ -303,13 +303,37 @@ def _compute_subspace_orthogonality(
     # Compute the min singular value
     min_singular_value = jnp.min(singular_values)
 
+    # Check if subspace is degenerate
+    if max_singular_value == 0:
+        SIMPLEXITY_LOGGER.warning(
+            "Degenerate subspace detected during orthogonality computation. "
+            "All singular values are zero. "
+            "Setting probability values to zero. "
+            "Setting participation ratio to zero."
+        )
+        pratio_denominator = 1.0
+        probs_denominator = 1.0
+    else:
+        pratio_denominator = jnp.sum(singular_values**4)
+        probs_denominator = jnp.sum(singular_values**2)
+
     # Compute the participation ratio
-    denom = jnp.sum(singular_values**4)
-    participation_ratio = jnp.where(denom == 0, 0.0, jnp.sum(singular_values**2) ** 2 / denom)
+    participation_ratio = jnp.sum(singular_values**2) ** 2 / pratio_denominator
+
     # Compute the entropy
-    probs = singular_values**2 / jnp.sum(singular_values**2)
-    log_probs = jnp.where(probs > 0, jnp.log(probs), 0.0)
-    entropy = -jnp.sum(probs * log_probs)
+    probs = singular_values**2 / probs_denominator
+    num_zeros = jnp.sum(probs == 0)
+    if num_zeros > 0:
+        SIMPLEXITY_LOGGER.warning(
+            f"Encountered {num_zeros} probability values of zero during entropy computation. "
+            "This is likely due to numerical instability. "
+            "Setting corresponding entropy contribution to zero."
+        )
+        nonzero_probs = probs[probs > 0]
+        entropy = -jnp.sum(nonzero_probs * jnp.log(nonzero_probs))
+    else:
+        entropy = -jnp.sum(probs * jnp.log(probs))
 
     # Compute the effective rank
     effective_rank = jnp.exp(entropy)
@@ -454,14 +478,25 @@ def layer_linear_regression(
         scalars: Dictionary of scalar metrics
         arrays: Dictionary of arrays (projected predictions, parameters, singular values if orthogonality computed)
     """
-    if belief_states is None:
+
+    # If no belief states are provided, raise an error
+    if (
+        belief_states is None
+        or (isinstance(belief_states, tuple) and len(belief_states) == 0)
+        or (isinstance(belief_states, jax.Array) and belief_states.size == 0)
+    ):
         raise ValueError("linear_regression requires belief_states")
 
     regression_fn = linear_regression_svd if use_svd else linear_regression
 
-    if not isinstance(belief_states, tuple):
+    if not isinstance(belief_states, tuple) or len(belief_states) == 1:
         if compute_subspace_orthogonality:
-            SIMPLEXITY_LOGGER.warning("Subspace orthogonality cannot be computed for a single belief state")
+            SIMPLEXITY_LOGGER.warning(
+                "Subspace orthogonality requires multiple factors. "
+ "Received single factor of type %s; skipping orthogonality metrics.", + type(belief_states).__name__, + ) + belief_states = belief_states[0] if isinstance(belief_states, tuple) else belief_states scalars, arrays = regression_fn(layer_activations, belief_states, weights, **kwargs) return scalars, arrays diff --git a/tests/analysis/test_linear_regression.py b/tests/analysis/test_linear_regression.py index cb3ce270..eebea8fd 100644 --- a/tests/analysis/test_linear_regression.py +++ b/tests/analysis/test_linear_regression.py @@ -308,7 +308,11 @@ def test_layer_linear_regression_svd_belief_states_tuple_default() -> None: def test_layer_linear_regression_belief_states_tuple_single_factor() -> None: +<<<<<<< HEAD """Handles a single factor provided as a tuple of belief states.""" +======= + """Single-element tuple should behave the same as passing a single array.""" +>>>>>>> 8aad089 (Change single factor regre) x = jnp.arange(9.0).reshape(3, 3) weights = jnp.ones(3) / 3.0 @@ -322,17 +326,24 @@ def test_layer_linear_regression_belief_states_tuple_single_factor() -> None: factored_beliefs, ) - # Should have ALL metrics, projections, coefficients, and intercept for single factor - assert "factor_0/r2" in scalars - assert "factor_0/rmse" in scalars - assert "factor_0/mae" in scalars - assert "factor_0/dist" in scalars - assert "factor_0/projected" in arrays - assert "factor_0/coeffs" in arrays - assert "factor_0/intercept" in arrays - assert arrays["factor_0/projected"].shape == factor_0.shape - assert arrays["factor_0/coeffs"].shape == (x.shape[1], factor_0.shape[1]) - assert arrays["factor_0/intercept"].shape == (1, factor_0.shape[1]) + # Should have same structure as non-tuple case + assert "r2" in scalars + assert "rmse" in scalars + assert "mae" in scalars + assert "dist" in scalars + assert "projected" in arrays + assert "coeffs" in arrays + assert "intercept" in arrays + + # Verify it matches non-tuple behavior + scalars_non_tuple, arrays_non_tuple = layer_linear_regression(x, weights, factor_0) + + assert scalars.keys() == scalars_non_tuple.keys() + assert arrays.keys() == arrays_non_tuple.keys() + for key in scalars.keys(): + assert scalars[key] == pytest.approx(scalars_non_tuple[key]) + for key in arrays.keys(): + assert arrays[key] == pytest.approx(arrays_non_tuple[key]) def test_orthogonality_with_orthogonal_subspaces() -> None: @@ -565,7 +576,7 @@ def test_orthogonality_warning_for_single_belief_state(caplog: pytest.LogCapture ) # Should have logged a warning - assert "Subspace orthogonality cannot be computed for a single belief state" in caplog.text + assert "Subspace orthogonality requires multiple factors." 
in caplog.text
+
+    # Should still run regression successfully
+    assert "r2" in scalars

From d6d71419dd0ffa20334cbc4eeb5b92f62d641db2 Mon Sep 17 00:00:00 2001
From: loren-ac
Date: Fri, 12 Dec 2025 12:55:24 -0800
Subject: [PATCH 17/48] Amended unseen merge conflict in linear_regression
 tests

---
 tests/analysis/test_linear_regression.py | 4 ----
 1 file changed, 4 deletions(-)

diff --git a/tests/analysis/test_linear_regression.py b/tests/analysis/test_linear_regression.py
index eebea8fd..b27d8c24 100644
--- a/tests/analysis/test_linear_regression.py
+++ b/tests/analysis/test_linear_regression.py
@@ -308,11 +308,7 @@ def test_layer_linear_regression_svd_belief_states_tuple_default() -> None:
 
 
 def test_layer_linear_regression_belief_states_tuple_single_factor() -> None:
-<<<<<<< HEAD
-    """Handles a single factor provided as a tuple of belief states."""
-=======
     """Single-element tuple should behave the same as passing a single array."""
->>>>>>> 8aad089 (Change single factor regre)
     x = jnp.arange(9.0).reshape(3, 3)
     weights = jnp.ones(3) / 3.0
 

From 9a71da4ab1fd208ccb1eae9e5c254a8bf1aac4e4 Mon Sep 17 00:00:00 2001
From: loren-ac
Date: Fri, 12 Dec 2025 13:03:28 -0800
Subject: [PATCH 18/48] Rename to_factors parameter to concat_belief_states in
 activation analyses

---
 simplexity/activations/activation_analyses.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/simplexity/activations/activation_analyses.py b/simplexity/activations/activation_analyses.py
index 51fdb7d9..ce724574 100644
--- a/simplexity/activations/activation_analyses.py
+++ b/simplexity/activations/activation_analyses.py
@@ -84,7 +84,7 @@ def __init__(
         use_probs_as_weights: bool = True,
         skip_first_token: bool = False,
         fit_intercept: bool = True,
-        to_factors: bool = False,
+        concat_belief_states: bool = False,
     ) -> None:
         super().__init__(
             analysis_type="linear_regression",
@@ -92,7 +92,7 @@ def __init__(
             concat_layers=concat_layers,
             use_probs_as_weights=use_probs_as_weights,
             skip_first_token=skip_first_token,
-            analysis_kwargs={"fit_intercept": fit_intercept, "to_factors": to_factors},
+            analysis_kwargs={"fit_intercept": fit_intercept, "concat_belief_states": concat_belief_states},
         )
 
 
@@ -108,9 +108,9 @@ def __init__(
         skip_first_token: bool = False,
         rcond_values: Sequence[float] | None = None,
         fit_intercept: bool = True,
-        to_factors: bool = False,
+        concat_belief_states: bool = False,
     ) -> None:
-        analysis_kwargs: dict[str, Any] = {"fit_intercept": fit_intercept, "to_factors": to_factors}
+        analysis_kwargs: dict[str, Any] = {"fit_intercept": fit_intercept, "concat_belief_states": concat_belief_states}
         if rcond_values is not None:
             analysis_kwargs["rcond_values"] = tuple(rcond_values)
         super().__init__(

From ecfa55c96625c4ef312c88ab7cdc776e488389e9 Mon Sep 17 00:00:00 2001
From: loren-ac
Date: Fri, 12 Dec 2025 13:18:10 -0800
Subject: [PATCH 19/48] Update activation analysis tests for
 concat_belief_states semantics

---
 tests/activations/test_activation_analysis.py | 381 ++++++++++++++++++
 1 file changed, 381 insertions(+)

diff --git a/tests/activations/test_activation_analysis.py b/tests/activations/test_activation_analysis.py
index b31c8a8a..f802cc2c 100644
--- a/tests/activations/test_activation_analysis.py
+++ b/tests/activations/test_activation_analysis.py
@@ -740,6 +740,387 @@ def test_controls_accumulate_steps_conflict(self):
             ActivationVisualizationControlsConfig(slider="step", accumulate_steps=True)
 
 
+class TestTupleBeliefStates:
+    """Test activation tracker with tuple belief states for factored processes."""
+
+    
@pytest.fixture + def factored_belief_data(self): + """Create synthetic data with factored belief states.""" + batch_size = 4 + seq_len = 5 + d_layer0 = 8 + d_layer1 = 12 + + inputs = jnp.array( + [ + [1, 2, 3, 4, 5], + [1, 2, 3, 6, 7], + [1, 2, 8, 9, 10], + [1, 2, 3, 4, 11], + ] + ) + + # Factored beliefs: 2 factors with dimensions 3 and 2 + factor_0 = jnp.ones((batch_size, seq_len, 3)) * 0.3 + factor_1 = jnp.ones((batch_size, seq_len, 2)) * 0.7 + factored_beliefs = (factor_0, factor_1) + + probs = jnp.ones((batch_size, seq_len)) * 0.1 + + activations = { + "layer_0": jnp.ones((batch_size, seq_len, d_layer0)) * 0.3, + "layer_1": jnp.ones((batch_size, seq_len, d_layer1)) * 0.7, + } + + return { + "inputs": inputs, + "factored_beliefs": factored_beliefs, + "probs": probs, + "activations": activations, + "batch_size": batch_size, + "seq_len": seq_len, + "factor_0_dim": 3, + "factor_1_dim": 2, + "d_layer0": d_layer0, + "d_layer1": d_layer1, + } + + def test_prepare_activations_accepts_tuple_beliefs(self, factored_belief_data): + """prepare_activations should accept and preserve tuple belief states.""" + result = prepare_activations( + factored_belief_data["inputs"], + factored_belief_data["factored_beliefs"], + factored_belief_data["probs"], + factored_belief_data["activations"], + prepare_options=PrepareOptions( + last_token_only=True, + concat_layers=False, + use_probs_as_weights=False, + ), + ) + + assert result.belief_states is not None + assert isinstance(result.belief_states, tuple) + assert len(result.belief_states) == 2 + + batch_size = factored_belief_data["batch_size"] + assert result.belief_states[0].shape == (batch_size, factored_belief_data["factor_0_dim"]) + assert result.belief_states[1].shape == (batch_size, factored_belief_data["factor_1_dim"]) + + def test_prepare_activations_tuple_beliefs_all_tokens(self, factored_belief_data): + """Tuple beliefs should work with all tokens mode.""" + result = prepare_activations( + factored_belief_data["inputs"], + factored_belief_data["factored_beliefs"], + factored_belief_data["probs"], + factored_belief_data["activations"], + prepare_options=PrepareOptions( + last_token_only=False, + concat_layers=False, + use_probs_as_weights=False, + ), + ) + + assert result.belief_states is not None + assert isinstance(result.belief_states, tuple) + assert len(result.belief_states) == 2 + + # With deduplication, we expect fewer samples than batch_size * seq_len + n_prefixes = result.belief_states[0].shape[0] + assert result.belief_states[0].shape == (n_prefixes, factored_belief_data["factor_0_dim"]) + assert result.belief_states[1].shape == (n_prefixes, factored_belief_data["factor_1_dim"]) + assert result.activations["layer_0"].shape[0] == n_prefixes + + def test_prepare_activations_torch_tuple_beliefs(self, factored_belief_data): + """prepare_activations should accept tuple of PyTorch tensors.""" + torch = pytest.importorskip("torch") + + torch_factor_0 = torch.tensor(np.asarray(factored_belief_data["factored_beliefs"][0])) + torch_factor_1 = torch.tensor(np.asarray(factored_belief_data["factored_beliefs"][1])) + torch_beliefs = (torch_factor_0, torch_factor_1) + + result = prepare_activations( + factored_belief_data["inputs"], + torch_beliefs, + factored_belief_data["probs"], + factored_belief_data["activations"], + prepare_options=PrepareOptions( + last_token_only=True, + concat_layers=False, + use_probs_as_weights=False, + ), + ) + + assert result.belief_states is not None + assert isinstance(result.belief_states, tuple) + assert 
len(result.belief_states) == 2 + # Should be converted to JAX arrays + assert isinstance(result.belief_states[0], jnp.ndarray) + assert isinstance(result.belief_states[1], jnp.ndarray) + + def test_prepare_activations_numpy_tuple_beliefs(self, factored_belief_data): + """prepare_activations should accept tuple of numpy arrays.""" + np_factor_0 = np.asarray(factored_belief_data["factored_beliefs"][0]) + np_factor_1 = np.asarray(factored_belief_data["factored_beliefs"][1]) + np_beliefs = (np_factor_0, np_factor_1) + + result = prepare_activations( + factored_belief_data["inputs"], + np_beliefs, + factored_belief_data["probs"], + factored_belief_data["activations"], + prepare_options=PrepareOptions( + last_token_only=True, + concat_layers=False, + use_probs_as_weights=False, + ), + ) + + assert result.belief_states is not None + assert isinstance(result.belief_states, tuple) + assert len(result.belief_states) == 2 + # Should be converted to JAX arrays + assert isinstance(result.belief_states[0], jnp.ndarray) + assert isinstance(result.belief_states[1], jnp.ndarray) + + def test_linear_regression_with_multiple_factors(self, factored_belief_data): + """LinearRegressionAnalysis with multi-factor tuple should regress to each factor separately.""" + analysis = LinearRegressionAnalysis() + + prepared = prepare_activations( + factored_belief_data["inputs"], + factored_belief_data["factored_beliefs"], + factored_belief_data["probs"], + factored_belief_data["activations"], + prepare_options=PrepareOptions( + last_token_only=True, + concat_layers=False, + use_probs_as_weights=False, + ), + ) + + scalars, projections = analysis.analyze( + activations=prepared.activations, + belief_states=prepared.belief_states, + weights=prepared.weights, + ) + + # Should have separate metrics for each factor + # Format is: layer_name_factor_idx/metric_name + assert "layer_0_factor_0/r2" in scalars + assert "layer_0_factor_1/r2" in scalars + assert "layer_0_factor_0/rmse" in scalars + assert "layer_0_factor_1/rmse" in scalars + assert "layer_0_factor_0/mae" in scalars + assert "layer_0_factor_1/mae" in scalars + assert "layer_0_factor_0/dist" in scalars + assert "layer_0_factor_1/dist" in scalars + + assert "layer_1_factor_0/r2" in scalars + assert "layer_1_factor_1/r2" in scalars + + # Should have separate projections for each factor + assert "layer_0_factor_0/projected" in projections + assert "layer_0_factor_1/projected" in projections + assert "layer_1_factor_0/projected" in projections + assert "layer_1_factor_1/projected" in projections + + # Check projection shapes + batch_size = factored_belief_data["batch_size"] + assert projections["layer_0_factor_0/projected"].shape == (batch_size, factored_belief_data["factor_0_dim"]) + assert projections["layer_0_factor_1/projected"].shape == (batch_size, factored_belief_data["factor_1_dim"]) + + def test_linear_regression_svd_with_multiple_factors(self, factored_belief_data): + """LinearRegressionSVDAnalysis with multi-factor tuple should regress to each factor separately.""" + analysis = LinearRegressionSVDAnalysis(rcond_values=[1e-10]) + + prepared = prepare_activations( + factored_belief_data["inputs"], + factored_belief_data["factored_beliefs"], + factored_belief_data["probs"], + factored_belief_data["activations"], + prepare_options=PrepareOptions( + last_token_only=True, + concat_layers=False, + use_probs_as_weights=False, + ), + ) + + scalars, projections = analysis.analyze( + activations=prepared.activations, + belief_states=prepared.belief_states, + 
weights=prepared.weights, + ) + + # Should have separate metrics for each factor including best_rcond + assert "layer_0_factor_0/r2" in scalars + assert "layer_0_factor_1/r2" in scalars + assert "layer_0_factor_0/best_rcond" in scalars + assert "layer_0_factor_1/best_rcond" in scalars + + # Should have separate projections for each factor + assert "layer_0_factor_0/projected" in projections + assert "layer_0_factor_1/projected" in projections + + def test_tracker_with_factored_beliefs(self, factored_belief_data): + """ActivationTracker should work with tuple belief states.""" + tracker = ActivationTracker( + { + "regression": LinearRegressionAnalysis( + last_token_only=True, + concat_layers=False, + ), + "pca": PcaAnalysis( + n_components=2, + last_token_only=True, + concat_layers=False, + ), + } + ) + + scalars, projections = tracker.analyze( + inputs=factored_belief_data["inputs"], + beliefs=factored_belief_data["factored_beliefs"], + probs=factored_belief_data["probs"], + activations=factored_belief_data["activations"], + ) + + # Regression should have per-factor metrics + assert "regression/layer_0_factor_0/r2" in scalars + assert "regression/layer_0_factor_1/r2" in scalars + + # PCA should still work (doesn't use belief states) + assert "pca/layer_0_variance_explained" in scalars + + # Projections should be present + assert "regression/layer_0_factor_0/projected" in projections + assert "regression/layer_0_factor_1/projected" in projections + assert "pca/layer_0_pca" in projections + + def test_single_factor_tuple(self, synthetic_data): + """Test with a single-factor tuple (edge case).""" + # Create single-factor tuple + single_factor = (synthetic_data["beliefs"],) + + result = prepare_activations( + synthetic_data["inputs"], + single_factor, + synthetic_data["probs"], + synthetic_data["activations"], + prepare_options=PrepareOptions( + last_token_only=True, + concat_layers=False, + use_probs_as_weights=False, + ), + ) + + assert result.belief_states is not None + assert isinstance(result.belief_states, tuple) + assert len(result.belief_states) == 1 + assert result.belief_states[0].shape == (synthetic_data["batch_size"], synthetic_data["belief_dim"]) + + def test_linear_regression_single_factor_tuple_behaves_like_non_tuple(self, synthetic_data): + """LinearRegressionAnalysis with single-factor tuple should behave like non-tuple (no factor keys).""" + single_factor = (synthetic_data["beliefs"],) + analysis = LinearRegressionAnalysis() + + prepared = prepare_activations( + synthetic_data["inputs"], + single_factor, + synthetic_data["probs"], + synthetic_data["activations"], + prepare_options=PrepareOptions( + last_token_only=True, + concat_layers=False, + use_probs_as_weights=False, + ), + ) + + scalars, projections = analysis.analyze( + activations=prepared.activations, + belief_states=prepared.belief_states, + weights=prepared.weights, + ) + + # Should have simple keys without "factor_" prefix + assert "layer_0_r2" in scalars + assert "layer_0_rmse" in scalars + assert "layer_0_projected" in projections + + # Should NOT have factor keys + assert "layer_0_factor_0/r2" not in scalars + assert "layer_0_factor_0/projected" not in projections + + def test_linear_regression_concat_belief_states(self, factored_belief_data): + """LinearRegressionAnalysis with concat_belief_states=True should return both factor and concat results.""" + analysis = LinearRegressionAnalysis(concat_belief_states=True) + + prepared = prepare_activations( + factored_belief_data["inputs"], + 
factored_belief_data["factored_beliefs"], + factored_belief_data["probs"], + factored_belief_data["activations"], + prepare_options=PrepareOptions( + last_token_only=True, + concat_layers=False, + use_probs_as_weights=False, + ), + ) + + scalars, projections = analysis.analyze( + activations=prepared.activations, + belief_states=prepared.belief_states, + weights=prepared.weights, + ) + + # Should have per-factor results + assert "layer_0_factor_0/r2" in scalars + assert "layer_0_factor_1/r2" in scalars + assert "layer_0_factor_0/projected" in projections + assert "layer_0_factor_1/projected" in projections + + # Should ALSO have concatenated results + assert "layer_0_concat/r2" in scalars + assert "layer_0_concat/rmse" in scalars + assert "layer_0_concat/projected" in projections + + # Check concatenated projection shape (should be sum of factor dimensions) + batch_size = factored_belief_data["batch_size"] + total_dim = factored_belief_data["factor_0_dim"] + factored_belief_data["factor_1_dim"] + assert projections["layer_0_concat/projected"].shape == (batch_size, total_dim) + + def test_three_factor_tuple(self, factored_belief_data): + """Test with three factors to ensure generalization.""" + batch_size = factored_belief_data["batch_size"] + seq_len = factored_belief_data["seq_len"] + + # Add a third factor + factor_0 = jnp.ones((batch_size, seq_len, 3)) * 0.3 + factor_1 = jnp.ones((batch_size, seq_len, 2)) * 0.5 + factor_2 = jnp.ones((batch_size, seq_len, 4)) * 0.7 + three_factor_beliefs = (factor_0, factor_1, factor_2) + + result = prepare_activations( + factored_belief_data["inputs"], + three_factor_beliefs, + factored_belief_data["probs"], + factored_belief_data["activations"], + prepare_options=PrepareOptions( + last_token_only=True, + concat_layers=False, + use_probs_as_weights=False, + ), + ) + + assert result.belief_states is not None + assert isinstance(result.belief_states, tuple) + assert len(result.belief_states) == 3 + assert result.belief_states[0].shape == (batch_size, 3) + assert result.belief_states[1].shape == (batch_size, 2) + assert result.belief_states[2].shape == (batch_size, 4) + + + class TestScalarSeriesMapping: """Tests for scalar_series dataframe construction.""" From 8a16ab71254a7ce0101a7292b0eacfae7980f297 Mon Sep 17 00:00:00 2001 From: loren-ac Date: Fri, 12 Dec 2025 13:37:46 -0800 Subject: [PATCH 20/48] Fix validator error message and fix linting issues --- simplexity/analysis/layerwise_analysis.py | 2 +- simplexity/analysis/linear_regression.py | 4 +--- tests/analysis/test_layerwise_analysis.py | 14 ++++++++++++-- tests/analysis/test_linear_regression.py | 10 +++++----- 4 files changed, 19 insertions(+), 11 deletions(-) diff --git a/simplexity/analysis/layerwise_analysis.py b/simplexity/analysis/layerwise_analysis.py index a18fae41..f6d3cbb9 100644 --- a/simplexity/analysis/layerwise_analysis.py +++ b/simplexity/analysis/layerwise_analysis.py @@ -45,7 +45,7 @@ def _validate_linear_regression_kwargs(kwargs: Mapping[str, Any] | None) -> dict allowed = {"fit_intercept", "concat_belief_states", "compute_subspace_orthogonality", "use_svd", "rcond_values"} unexpected = set(provided) - allowed if unexpected: - raise ValueError(f"Unexpected linear_regression_svd kwargs: {sorted(unexpected)}") + raise ValueError(f"Unexpected linear_regression kwargs: {sorted(unexpected)}") fit_intercept = bool(provided.get("fit_intercept", True)) concat_belief_states = bool(provided.get("concat_belief_states", False)) compute_subspace_orthogonality = 
bool(provided.get("compute_subspace_orthogonality", False)) diff --git a/simplexity/analysis/linear_regression.py b/simplexity/analysis/linear_regression.py index b0209f13..3fbff516 100644 --- a/simplexity/analysis/linear_regression.py +++ b/simplexity/analysis/linear_regression.py @@ -303,7 +303,6 @@ def _compute_subspace_orthogonality( # Compute the min singular value min_singular_value = jnp.min(singular_values) - # Check if subspace is degenerate if max_singular_value == 0: SIMPLEXITY_LOGGER.warning( @@ -319,7 +318,7 @@ def _compute_subspace_orthogonality( probs_denominator = jnp.sum(singular_values**2) # Compute the participation ratio - participation_ratio = jnp.sum(singular_values**2)**2 / pratio_denominator + participation_ratio = jnp.sum(singular_values**2) ** 2 / pratio_denominator # Compute the entropy probs = singular_values**2 / probs_denominator @@ -478,7 +477,6 @@ def layer_linear_regression( scalars: Dictionary of scalar metrics arrays: Dictionary of arrays (projected predictions, parameters, singular values if orthogonality computed) """ - # If no belief states are provided, raise an error if ( belief_states is None diff --git a/tests/analysis/test_layerwise_analysis.py b/tests/analysis/test_layerwise_analysis.py index 2df502d2..cd3a6b82 100644 --- a/tests/analysis/test_layerwise_analysis.py +++ b/tests/analysis/test_layerwise_analysis.py @@ -30,7 +30,14 @@ def test_layerwise_analysis_linear_regression_namespacing(analysis_inputs) -> No ) assert set(scalars) >= {"layer_a_r2", "layer_b_r2"} - assert set(projections) == {"layer_a_projected", "layer_b_projected", "layer_a_coeffs", "layer_b_coeffs", "layer_a_intercept", "layer_b_intercept"} + assert set(projections) == { + "layer_a_projected", + "layer_b_projected", + "layer_a_coeffs", + "layer_b_coeffs", + "layer_a_intercept", + "layer_b_intercept", + } def test_layerwise_analysis_requires_targets(analysis_inputs) -> None: @@ -100,7 +107,7 @@ def test_linear_regression_svd_kwargs_validation_errors() -> None: def test_linear_regression_svd_rejects_unexpected_kwargs() -> None: """Unexpected SVD kwargs should raise clear errors.""" - with pytest.raises(ValueError, match="Unexpected linear_regression_svd kwargs"): + with pytest.raises(ValueError, match="Unexpected linear_regression kwargs"): LayerwiseAnalysis( "linear_regression_svd", analysis_kwargs={"bad": True}, @@ -189,6 +196,7 @@ def test_linear_regression_concat_belief_states_defaults_false() -> None: assert params["concat_belief_states"] is False + def test_linear_regression_accepts_compute_subspace_orthogonality() -> None: """linear_regression validator should accept compute_subspace_orthogonality parameter.""" validator = ANALYSIS_REGISTRY["linear_regression"].validator @@ -197,6 +205,7 @@ def test_linear_regression_accepts_compute_subspace_orthogonality() -> None: assert params["fit_intercept"] is True assert params["compute_subspace_orthogonality"] is True + def test_linear_regression_svd_accepts_compute_subspace_orthogonality() -> None: """linear_regression_svd validator should accept compute_subspace_orthogonality parameter.""" validator = ANALYSIS_REGISTRY["linear_regression_svd"].validator @@ -206,6 +215,7 @@ def test_linear_regression_svd_accepts_compute_subspace_orthogonality() -> None: assert params["compute_subspace_orthogonality"] is True assert params["rcond_values"] == (0.001,) + def test_linear_regression_compute_subspace_orthogonality_defaults_false() -> None: """compute_subspace_orthogonality should default to False when not provided.""" validator = 
ANALYSIS_REGISTRY["linear_regression"].validator

diff --git a/tests/analysis/test_linear_regression.py b/tests/analysis/test_linear_regression.py
index b27d8c24..b5e2f0bf 100644
--- a/tests/analysis/test_linear_regression.py
+++ b/tests/analysis/test_linear_regression.py
@@ -333,13 +333,13 @@ def test_layer_linear_regression_belief_states_tuple_single_factor() -> None:
 
     # Verify it matches non-tuple behavior
     scalars_non_tuple, arrays_non_tuple = layer_linear_regression(x, weights, factor_0)
-    
+
     assert scalars.keys() == scalars_non_tuple.keys()
     assert arrays.keys() == arrays_non_tuple.keys()
-    for key in scalars.keys():
-        assert scalars[key] == pytest.approx(scalars_non_tuple[key])
-    for key in arrays.keys():
-        assert arrays[key] == pytest.approx(arrays_non_tuple[key])
+    for key, value in scalars_non_tuple.items():
+        assert scalars[key] == pytest.approx(value)
+    for key, value in arrays_non_tuple.items():
+        assert arrays[key] == pytest.approx(value)
 
 
 def test_orthogonality_with_orthogonal_subspaces() -> None:

From 5b6247d6db5a8634e75bd2874e2b18ca946bc076 Mon Sep 17 00:00:00 2001
From: loren-ac
Date: Fri, 12 Dec 2025 13:47:23 -0800
Subject: [PATCH 21/48] Add check requiring 2+ factors in _handle_factored_regression and remove redundant orthogonality computations warning
---
 simplexity/analysis/linear_regression.py | 18 +++++++++---------
 1 file changed, 9 insertions(+), 9 deletions(-)

diff --git a/simplexity/analysis/linear_regression.py b/simplexity/analysis/linear_regression.py
index 3fbff516..1cfc940d 100644
--- a/simplexity/analysis/linear_regression.py
+++ b/simplexity/analysis/linear_regression.py
@@ -386,7 +386,10 @@ def _handle_factored_regression(
     use_svd: bool,
     **kwargs: Any,
 ) -> tuple[Mapping[str, float], Mapping[str, jax.Array]]:
-    """Handle regression for factored belief states using either standard or SVD method."""
+    """Handle regression for two or more factored belief states using either standard or SVD method."""
+    if len(belief_states) < 2:
+        raise ValueError("At least two factors are required for factored regression")
+
     scalars: dict[str, float] = {}
     arrays: dict[str, jax.Array] = {}
 
@@ -413,14 +416,11 @@ def _handle_factored_regression(
         _merge_results_with_prefix(scalars, arrays, factor_result, f"factor_{factor_idx}")
 
     if compute_subspace_orthogonality:
-        if len(belief_states) > 1:
-            # Extract coefficients (excludes intercept) for orthogonality computation
-            coeffs_list = [factor_arrays["coeffs"] for _, factor_arrays in factor_results]
-            orthogonality_scalars, orthogonality_singular_values = _compute_all_pairwise_orthogonality(coeffs_list)
-            scalars.update(orthogonality_scalars)
-            arrays.update(orthogonality_singular_values)
-        else:
-            SIMPLEXITY_LOGGER.warning("Subspace orthogonality cannot be computed for a single belief state")
+        # Extract coefficients (excludes intercept) for orthogonality computation
+        coeffs_list = [factor_arrays["coeffs"] for _, factor_arrays in factor_results]
+        orthogonality_scalars, orthogonality_singular_values = _compute_all_pairwise_orthogonality(coeffs_list)
+        scalars.update(orthogonality_scalars)
+        arrays.update(orthogonality_singular_values)
 
     return scalars, arrays
 

From 43123af296d8e33b90b4660a18c36d03d4abbcdd Mon Sep 17 00:00:00 2001
From: loren-ac
Date: Fri, 12 Dec 2025 13:50:03 -0800
Subject: [PATCH 22/48] Add proper spacing to warning messages
---
 simplexity/analysis/linear_regression.py | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/simplexity/analysis/linear_regression.py 
b/simplexity/analysis/linear_regression.py index 1cfc940d..5e36511e 100644 --- a/simplexity/analysis/linear_regression.py +++ b/simplexity/analysis/linear_regression.py @@ -307,9 +307,9 @@ def _compute_subspace_orthogonality( if max_singular_value == 0: SIMPLEXITY_LOGGER.warning( "Degenerate subspace detected during orthogonality computation." - "All singular values are zero." - "Setting probability values to zero." - "Setting participation ratio to zero." + " All singular values are zero." + " Setting probability values to zero." + " Setting participation ratio to zero." ) pratio_denominator = 1.0 probs_denominator = 1.0 @@ -326,8 +326,8 @@ def _compute_subspace_orthogonality( if num_zeros > 0: SIMPLEXITY_LOGGER.warning( f"Encountered {num_zeros} probability values of zero during entropy computation." - "This is likely due to numerical instability." - "Setting corresponding entropy contribution to zero." + " This is likely due to numerical instability." + " Setting corresponding entropy contribution to zero." ) nonzero_probs = probs[probs > 0] entropy = -jnp.sum(nonzero_probs * jnp.log(nonzero_probs)) @@ -491,7 +491,7 @@ def layer_linear_regression( if compute_subspace_orthogonality: SIMPLEXITY_LOGGER.warning( "Subspace orthogonality requires multiple factors." - "Received single factor of type %s; skipping orthogonality metrics.", + " Received single factor of type %s; skipping orthogonality metrics.", type(belief_states).__name__, ) belief_states = belief_states[0] if isinstance(belief_states, tuple) else belief_states From 729222dddb236f74ea14697a68f3e08cde50b0da Mon Sep 17 00:00:00 2001 From: loren-ac Date: Fri, 12 Dec 2025 13:56:06 -0800 Subject: [PATCH 23/48] Fix dictionary equivalence check in test_linear_regression and add blank line after docstring in test_layerwise_analysis --- tests/analysis/test_layerwise_analysis.py | 22 ++++++++++++++++++++++ tests/analysis/test_linear_regression.py | 10 +++------- 2 files changed, 25 insertions(+), 7 deletions(-) diff --git a/tests/analysis/test_layerwise_analysis.py b/tests/analysis/test_layerwise_analysis.py index cd3a6b82..f3358885 100644 --- a/tests/analysis/test_layerwise_analysis.py +++ b/tests/analysis/test_layerwise_analysis.py @@ -9,6 +9,7 @@ @pytest.fixture def analysis_inputs() -> tuple[dict[str, jnp.ndarray], jnp.ndarray, jnp.ndarray]: """Provides sample activations, weights, and belief states for analysis tests.""" + activations = { "layer_a": jnp.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]), "layer_b": jnp.array([[2.0, 1.0], [1.0, 2.0], [0.0, 1.0]]), @@ -20,6 +21,7 @@ def analysis_inputs() -> tuple[dict[str, jnp.ndarray], jnp.ndarray, jnp.ndarray] def test_layerwise_analysis_linear_regression_namespacing(analysis_inputs) -> None: """Metrics and projections should be namespace-qualified per layer.""" + activations, weights, belief_states = analysis_inputs analysis = LayerwiseAnalysis("linear_regression", last_token_only=True) @@ -42,6 +44,7 @@ def test_layerwise_analysis_linear_regression_namespacing(analysis_inputs) -> No def test_layerwise_analysis_requires_targets(analysis_inputs) -> None: """Analyses that need belief states should validate input.""" + activations, weights, _ = analysis_inputs analysis = LayerwiseAnalysis("linear_regression") @@ -51,12 +54,14 @@ def test_layerwise_analysis_requires_targets(analysis_inputs) -> None: def test_invalid_analysis_type_raises() -> None: """Unknown analysis types should raise clear errors.""" + with pytest.raises(ValueError, match="Unknown analysis_type"): 
LayerwiseAnalysis("unknown") def test_invalid_kwargs_validation() -> None: """Validator rejects unsupported kwargs for a registered analysis.""" + with pytest.raises(ValueError, match="Unexpected linear_regression kwargs"): LayerwiseAnalysis( "linear_regression", @@ -66,6 +71,7 @@ def test_invalid_kwargs_validation() -> None: def test_pca_analysis_does_not_require_beliefs(analysis_inputs) -> None: """PCA analysis should run without belief states and namespace results.""" + activations, weights, _ = analysis_inputs analysis = LayerwiseAnalysis( "pca", @@ -83,6 +89,7 @@ def test_pca_analysis_does_not_require_beliefs(analysis_inputs) -> None: def test_invalid_pca_kwargs() -> None: """Invalid PCA kwargs should raise helpful errors.""" + with pytest.raises(ValueError, match="n_components must be positive"): LayerwiseAnalysis( "pca", @@ -92,6 +99,7 @@ def test_invalid_pca_kwargs() -> None: def test_linear_regression_svd_kwargs_validation_errors() -> None: """SVD-specific validators should reject unsupported inputs.""" + with pytest.raises(TypeError, match="rcond_values must be a sequence"): LayerwiseAnalysis( "linear_regression_svd", @@ -107,6 +115,7 @@ def test_linear_regression_svd_kwargs_validation_errors() -> None: def test_linear_regression_svd_rejects_unexpected_kwargs() -> None: """Unexpected SVD kwargs should raise clear errors.""" + with pytest.raises(ValueError, match="Unexpected linear_regression kwargs"): LayerwiseAnalysis( "linear_regression_svd", @@ -116,6 +125,7 @@ def test_linear_regression_svd_rejects_unexpected_kwargs() -> None: def test_linear_regression_svd_kwargs_are_normalized() -> None: """Validator should coerce mixed numeric types to floats.""" + validator = ANALYSIS_REGISTRY["linear_regression_svd"].validator params = validator({"rcond_values": [1, 1e-3]}) @@ -124,6 +134,7 @@ def test_linear_regression_svd_kwargs_are_normalized() -> None: def test_pca_kwargs_require_int_components() -> None: """PCA validator should enforce integral n_components.""" + with pytest.raises(TypeError, match="n_components must be an int or None"): LayerwiseAnalysis( "pca", @@ -133,6 +144,7 @@ def test_pca_kwargs_require_int_components() -> None: def test_pca_kwargs_require_sequence_thresholds() -> None: """Variance thresholds must be sequences with valid ranges.""" + with pytest.raises(TypeError, match="variance_thresholds must be a sequence"): LayerwiseAnalysis( "pca", @@ -148,6 +160,7 @@ def test_pca_kwargs_require_sequence_thresholds() -> None: def test_pca_rejects_unexpected_kwargs() -> None: """Unexpected PCA kwargs should surface informative errors.""" + with pytest.raises(ValueError, match="Unexpected pca kwargs"): LayerwiseAnalysis( "pca", @@ -157,6 +170,7 @@ def test_pca_rejects_unexpected_kwargs() -> None: def test_layerwise_analysis_property_accessors() -> None: """Constructor flags should surface via property accessors.""" + analysis = LayerwiseAnalysis( "pca", last_token_only=True, @@ -172,6 +186,7 @@ def test_layerwise_analysis_property_accessors() -> None: def test_linear_regression_accepts_concat_belief_states() -> None: """linear_regression validator should accept concat_belief_states parameter.""" + validator = ANALYSIS_REGISTRY["linear_regression"].validator params = validator({"fit_intercept": False, "concat_belief_states": True}) @@ -181,6 +196,7 @@ def test_linear_regression_accepts_concat_belief_states() -> None: def test_linear_regression_svd_accepts_concat_belief_states() -> None: """linear_regression_svd validator should accept concat_belief_states parameter.""" + 
validator = ANALYSIS_REGISTRY["linear_regression_svd"].validator params = validator({"fit_intercept": True, "concat_belief_states": True, "rcond_values": [1e-3]}) @@ -191,6 +207,7 @@ def test_linear_regression_svd_accepts_concat_belief_states() -> None: def test_linear_regression_concat_belief_states_defaults_false() -> None: """concat_belief_states should default to False when not provided.""" + validator = ANALYSIS_REGISTRY["linear_regression"].validator params = validator({"fit_intercept": True}) @@ -199,6 +216,7 @@ def test_linear_regression_concat_belief_states_defaults_false() -> None: def test_linear_regression_accepts_compute_subspace_orthogonality() -> None: """linear_regression validator should accept compute_subspace_orthogonality parameter.""" + validator = ANALYSIS_REGISTRY["linear_regression"].validator params = validator({"fit_intercept": True, "compute_subspace_orthogonality": True}) @@ -208,6 +226,7 @@ def test_linear_regression_accepts_compute_subspace_orthogonality() -> None: def test_linear_regression_svd_accepts_compute_subspace_orthogonality() -> None: """linear_regression_svd validator should accept compute_subspace_orthogonality parameter.""" + validator = ANALYSIS_REGISTRY["linear_regression_svd"].validator params = validator({"fit_intercept": True, "compute_subspace_orthogonality": True, "rcond_values": [1e-3]}) @@ -218,6 +237,7 @@ def test_linear_regression_svd_accepts_compute_subspace_orthogonality() -> None: def test_linear_regression_compute_subspace_orthogonality_defaults_false() -> None: """compute_subspace_orthogonality should default to False when not provided.""" + validator = ANALYSIS_REGISTRY["linear_regression"].validator params = validator({"fit_intercept": True}) @@ -226,6 +246,7 @@ def test_linear_regression_compute_subspace_orthogonality_defaults_false() -> No def test_linear_regression_accepts_use_svd() -> None: """linear_regression validator should accept use_svd parameter.""" + validator = ANALYSIS_REGISTRY["linear_regression"].validator params = validator({"use_svd": True}) @@ -234,6 +255,7 @@ def test_linear_regression_accepts_use_svd() -> None: def test_linear_regression_use_svd_defaults_false() -> None: """use_svd should default to False when not provided.""" + validator = ANALYSIS_REGISTRY["linear_regression"].validator params = validator({}) diff --git a/tests/analysis/test_linear_regression.py b/tests/analysis/test_linear_regression.py index b5e2f0bf..11958dbd 100644 --- a/tests/analysis/test_linear_regression.py +++ b/tests/analysis/test_linear_regression.py @@ -333,13 +333,9 @@ def test_layer_linear_regression_belief_states_tuple_single_factor() -> None: # Verify it matches non-tuple behavior scalars_non_tuple, arrays_non_tuple = layer_linear_regression(x, weights, factor_0) - - assert scalars.keys() == scalars_non_tuple.keys() - assert arrays.keys() == arrays_non_tuple.keys() - for key, value in scalars_non_tuple.items(): - assert scalars[key] == pytest.approx(value) - for key, value in arrays_non_tuple.items(): - assert arrays[key] == pytest.approx(value) + + chex.assert_trees_all_close(scalars, scalars_non_tuple) + chex.assert_trees_all_close(arrays, arrays_non_tuple) def test_orthogonality_with_orthogonal_subspaces() -> None: From 2e8829f8800790a5326e567f0892c3762dda331d Mon Sep 17 00:00:00 2001 From: loren-ac Date: Fri, 12 Dec 2025 14:31:49 -0800 Subject: [PATCH 24/48] Refactor subspace orthogonality computation for JIT compatibility --- simplexity/analysis/linear_regression.py | 56 ++++++++++++------------ 1 file changed, 27 
insertions(+), 29 deletions(-) diff --git a/simplexity/analysis/linear_regression.py b/simplexity/analysis/linear_regression.py index 5e36511e..e6daed67 100644 --- a/simplexity/analysis/linear_regression.py +++ b/simplexity/analysis/linear_regression.py @@ -7,6 +7,7 @@ from typing import Any import jax +from jax.debug import callback import jax.numpy as jnp import numpy as np @@ -289,58 +290,55 @@ def _compute_subspace_orthogonality( # Compute the singular values of the interaction matrix interaction_matrix = q1.T @ q2 singular_values = jnp.linalg.svd(interaction_matrix, compute_uv=False) - - # Clip the singular values to the range [0, 1] singular_values = jnp.clip(singular_values, 0, 1) # Compute the subspace overlap score min_dim = min(q1.shape[1], q2.shape[1]) - subspace_overlap_score = jnp.sum(singular_values**2) / min_dim - - # Compute the max singular value - max_singular_value = jnp.max(singular_values) + sum_sq_sv = jnp.sum(singular_values**2) + sum_quad_sv = jnp.sum(singular_values**4) - # Compute the min singular value - min_singular_value = jnp.min(singular_values) + is_degenerate = sum_quad_sv == 0 - # Check if subspace is degenerate - if max_singular_value == 0: + def log_all_zeros(_): SIMPLEXITY_LOGGER.warning( "Degenerate subspace detected during orthogonality computation." " All singular values are zero." - " Setting probability values to zero." - " Setting participation ratio to zero." + " Setting probability values and participation ratio to zero." ) - pratio_denominator = 1.0 - probs_denominator = 1.0 - else: - pratio_denominator = jnp.sum(singular_values**4) - probs_denominator = jnp.sum(singular_values**2) + + callback(log_all_zeros, sum_sq_sv, ordered=True, when=is_degenerate) - # Compute the participation ratio - participation_ratio = jnp.sum(singular_values**2) ** 2 / pratio_denominator + pratio_denominator_safe = jnp.where(is_degenerate, 1.0, sum_quad_sv) + probs_denominator_safe = jnp.where(is_degenerate, 1.0, sum_sq_sv) + participation_ratio = sum_sq_sv**2 / pratio_denominator_safe - # Compute the entropy - probs = singular_values**2 / probs_denominator - num_zeros = jnp.sum(probs == 0) - if num_zeros > 0: + subspace_overlap_score = sum_sq_sv / min_dim + + # Compute the entropy probabilities + probs = singular_values**2 / probs_denominator_safe + + def log_some_zeros(num_zeros_array: jax.Array) -> None: + num_zeros = num_zeros_array.item() SIMPLEXITY_LOGGER.warning( f"Encountered {num_zeros} probability values of zero during entropy computation." " This is likely due to numerical instability." " Setting corresponding entropy contribution to zero." 
) - nonzero_probs = probs[probs > 0] - entropy = -jnp.sum(nonzero_probs * jnp.log(nonzero_probs)) - else: - entropy = -jnp.sum(probs * jnp.log(probs)) + + num_zeros = jnp.sum(probs == 0) + has_some_zeros = num_zeros > 0 + callback(log_some_zeros, num_zeros, ordered=True, when=has_some_zeros) + + p_log_p = probs * jnp.log(probs) + entropy = -jnp.sum(jnp.where(probs > 0, p_log_p, 0.0)) # Compute the effective rank effective_rank = jnp.exp(entropy) scalars = { "subspace_overlap": float(subspace_overlap_score), - "max_singular_value": float(max_singular_value), - "min_singular_value": float(min_singular_value), + "max_singular_value": float(jnp.max(singular_values)), + "min_singular_value": float(jnp.min(singular_values)), "participation_ratio": float(participation_ratio), "entropy": float(entropy), "effective_rank": float(effective_rank), From 4136030d426552134ce8a9fa1fea071127eaf63f Mon Sep 17 00:00:00 2001 From: loren-ac Date: Fri, 12 Dec 2025 14:52:03 -0800 Subject: [PATCH 25/48] Fix conditional callback execution using jax.lax.cond --- simplexity/analysis/linear_regression.py | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/simplexity/analysis/linear_regression.py b/simplexity/analysis/linear_regression.py index e6daed67..b26712fc 100644 --- a/simplexity/analysis/linear_regression.py +++ b/simplexity/analysis/linear_regression.py @@ -298,6 +298,15 @@ def _compute_subspace_orthogonality( sum_quad_sv = jnp.sum(singular_values**4) is_degenerate = sum_quad_sv == 0 + + # Define the False branch function (does nothing) + def do_nothing_branch(x): + return None + + # Define the True branch function (runs the callback) + def execute_all_zeros_warning_branch(x): + callback(log_all_zeros, x) + return None def log_all_zeros(_): SIMPLEXITY_LOGGER.warning( @@ -306,7 +315,7 @@ def log_all_zeros(_): " Setting probability values and participation ratio to zero." 
) - callback(log_all_zeros, sum_sq_sv, ordered=True, when=is_degenerate) + jax.lax.cond(is_degenerate, execute_all_zeros_warning_branch, do_nothing_branch, sum_sq_sv) pratio_denominator_safe = jnp.where(is_degenerate, 1.0, sum_quad_sv) probs_denominator_safe = jnp.where(is_degenerate, 1.0, sum_sq_sv) @@ -317,6 +326,11 @@ def log_all_zeros(_): # Compute the entropy probabilities probs = singular_values**2 / probs_denominator_safe + def execute_some_zeros_warning_branch(x): + # This correctly calls the log_some_zeros function + callback(log_some_zeros, x) + return None + def log_some_zeros(num_zeros_array: jax.Array) -> None: num_zeros = num_zeros_array.item() SIMPLEXITY_LOGGER.warning( @@ -327,7 +341,7 @@ def log_some_zeros(num_zeros_array: jax.Array) -> None: num_zeros = jnp.sum(probs == 0) has_some_zeros = num_zeros > 0 - callback(log_some_zeros, num_zeros, ordered=True, when=has_some_zeros) + jax.lax.cond(has_some_zeros, execute_some_zeros_warning_branch, do_nothing_branch, num_zeros) p_log_p = probs * jnp.log(probs) entropy = -jnp.sum(jnp.where(probs > 0, p_log_p, 0.0)) From 2be20325452ad3e91e88055aad4e0be54381df14 Mon Sep 17 00:00:00 2001 From: loren-ac Date: Fri, 12 Dec 2025 14:57:58 -0800 Subject: [PATCH 26/48] Fix linting and formatting issues --- simplexity/analysis/linear_regression.py | 2 +- tests/analysis/test_layerwise_analysis.py | 2 +- tests/analysis/test_linear_regression.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/simplexity/analysis/linear_regression.py b/simplexity/analysis/linear_regression.py index b26712fc..2ce06f13 100644 --- a/simplexity/analysis/linear_regression.py +++ b/simplexity/analysis/linear_regression.py @@ -7,9 +7,9 @@ from typing import Any import jax -from jax.debug import callback import jax.numpy as jnp import numpy as np +from jax.debug import callback from simplexity.analysis.normalization import normalize_weights, standardize_features, standardize_targets from simplexity.logger import SIMPLEXITY_LOGGER diff --git a/tests/analysis/test_layerwise_analysis.py b/tests/analysis/test_layerwise_analysis.py index f3358885..5be4db4f 100644 --- a/tests/analysis/test_layerwise_analysis.py +++ b/tests/analysis/test_layerwise_analysis.py @@ -9,7 +9,7 @@ @pytest.fixture def analysis_inputs() -> tuple[dict[str, jnp.ndarray], jnp.ndarray, jnp.ndarray]: """Provides sample activations, weights, and belief states for analysis tests.""" - + activations = { "layer_a": jnp.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]), "layer_b": jnp.array([[2.0, 1.0], [1.0, 2.0], [0.0, 1.0]]), diff --git a/tests/analysis/test_linear_regression.py b/tests/analysis/test_linear_regression.py index 11958dbd..38dad455 100644 --- a/tests/analysis/test_linear_regression.py +++ b/tests/analysis/test_linear_regression.py @@ -333,7 +333,7 @@ def test_layer_linear_regression_belief_states_tuple_single_factor() -> None: # Verify it matches non-tuple behavior scalars_non_tuple, arrays_non_tuple = layer_linear_regression(x, weights, factor_0) - + chex.assert_trees_all_close(scalars, scalars_non_tuple) chex.assert_trees_all_close(arrays, arrays_non_tuple) From f77f2f5eaa3f0713d3640795af4edf7527037de0 Mon Sep 17 00:00:00 2001 From: loren-ac Date: Fri, 12 Dec 2025 15:01:54 -0800 Subject: [PATCH 27/48] Fix formatting issues --- simplexity/analysis/linear_regression.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/simplexity/analysis/linear_regression.py b/simplexity/analysis/linear_regression.py index 2ce06f13..16b938ae 100644 --- 
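The pattern that PATCH 25 settles on is worth isolating: under jit, a Python "if" on a traced value fails, so the warning is routed through jax.lax.cond to a host-side jax.debug.callback, while the arithmetic itself is made safe with jnp.where. A self-contained sketch of the idea (the _warn helper and example values are illustrative, not code from this repository; the real code calls SIMPLEXITY_LOGGER):

import jax
import jax.numpy as jnp
from jax.debug import callback

def _warn(num_zeros):
    # Host-side side effect; runs only when the cond predicate is True.
    print(f"{int(num_zeros)} zero probabilities encountered")

@jax.jit
def safe_entropy(probs):
    num_zeros = jnp.sum(probs == 0)
    # Both branches of lax.cond must return matching pytrees (here: None).
    jax.lax.cond(num_zeros > 0, lambda n: callback(_warn, n), lambda n: None, num_zeros)
    p_log_p = probs * jnp.log(probs)  # nan at probs == 0, masked out below
    return -jnp.sum(jnp.where(probs > 0, p_log_p, 0.0))

print(safe_entropy(jnp.array([0.5, 0.5, 0.0])))  # ~0.693, with a warning printed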
a/simplexity/analysis/linear_regression.py +++ b/simplexity/analysis/linear_regression.py @@ -298,7 +298,7 @@ def _compute_subspace_orthogonality( sum_quad_sv = jnp.sum(singular_values**4) is_degenerate = sum_quad_sv == 0 - + # Define the False branch function (does nothing) def do_nothing_branch(x): return None @@ -306,7 +306,7 @@ def do_nothing_branch(x): # Define the True branch function (runs the callback) def execute_all_zeros_warning_branch(x): callback(log_all_zeros, x) - return None + return None def log_all_zeros(_): SIMPLEXITY_LOGGER.warning( @@ -314,7 +314,7 @@ def log_all_zeros(_): " All singular values are zero." " Setting probability values and participation ratio to zero." ) - + jax.lax.cond(is_degenerate, execute_all_zeros_warning_branch, do_nothing_branch, sum_sq_sv) pratio_denominator_safe = jnp.where(is_degenerate, 1.0, sum_quad_sv) @@ -338,7 +338,7 @@ def log_some_zeros(num_zeros_array: jax.Array) -> None: " This is likely due to numerical instability." " Setting corresponding entropy contribution to zero." ) - + num_zeros = jnp.sum(probs == 0) has_some_zeros = num_zeros > 0 jax.lax.cond(has_some_zeros, execute_some_zeros_warning_branch, do_nothing_branch, num_zeros) @@ -401,7 +401,7 @@ def _handle_factored_regression( """Handle regression for two or more factored belief states using either standard or SVD method.""" if len(belief_states) < 2: raise ValueError("At least two factors are required for factored regression") - + scalars: dict[str, float] = {} arrays: dict[str, jax.Array] = {} From 7af2bc426842fa298c0b11f39ed2e9ddd1fce159 Mon Sep 17 00:00:00 2001 From: loren-ac Date: Fri, 12 Dec 2025 15:10:24 -0800 Subject: [PATCH 28/48] Disable too-many-locals linting issue in test_linear_regression.py --- tests/analysis/test_linear_regression.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/analysis/test_linear_regression.py b/tests/analysis/test_linear_regression.py index 38dad455..ea6c937c 100644 --- a/tests/analysis/test_linear_regression.py +++ b/tests/analysis/test_linear_regression.py @@ -1,6 +1,7 @@ """Tests for reusable linear regression helpers.""" # pylint: disable=too-many-lines +# pylint: disable=too-many-locals import chex import jax From 6ee64fa8bbd9a0d8a285d24fc52e3dd605bdf9e0 Mon Sep 17 00:00:00 2001 From: loren-ac Date: Fri, 12 Dec 2025 15:25:27 -0800 Subject: [PATCH 29/48] Change name of return dict from singular_values -> arrays for clarity --- simplexity/analysis/linear_regression.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/simplexity/analysis/linear_regression.py b/simplexity/analysis/linear_regression.py index 16b938ae..6446b747 100644 --- a/simplexity/analysis/linear_regression.py +++ b/simplexity/analysis/linear_regression.py @@ -358,11 +358,11 @@ def log_some_zeros(num_zeros_array: jax.Array) -> None: "effective_rank": float(effective_rank), } - singular_values = { + arrays = { "singular_values": singular_values, } - return scalars, singular_values + return scalars, arrays def _compute_all_pairwise_orthogonality( From 84006da8ef10dd6f5b56946527e12b452a5f330d Mon Sep 17 00:00:00 2001 From: loren-ac Date: Fri, 12 Dec 2025 15:31:11 -0800 Subject: [PATCH 30/48] Add docstring describing return values for _compute_all_pairwise_orthogonality function --- simplexity/analysis/linear_regression.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/simplexity/analysis/linear_regression.py b/simplexity/analysis/linear_regression.py index 6446b747..abcbabd5 100644 --- 
a/simplexity/analysis/linear_regression.py +++ b/simplexity/analysis/linear_regression.py @@ -372,21 +372,26 @@ def _compute_all_pairwise_orthogonality( Args: coeffs_list: List of coefficient matrices (one per factor, excludes intercepts) + + Returns: + Tuple[dict[str, float], dict[str, jax.Array]]: + - scalars: Dictionary mapping keys of the form "orthogonality_{i}_{j}/" to scalar float metrics for + each pair of factors (i, j). + - arrays: Dictionary mapping keys of the form "orthogonality_{i}_{j}/" to array-valued + metrics for each pair of factors (i, j). """ scalars = {} - singular_values = {} + arrays = {} factor_pairs = list(itertools.combinations(range(len(coeffs_list)), 2)) for i, j in factor_pairs: coeffs_pair = [ coeffs_list[i], coeffs_list[j], ] - orthogonality_scalars, orthogonality_singular_values = _compute_subspace_orthogonality(coeffs_pair) + orthogonality_scalars, orthogonality_arrays = _compute_subspace_orthogonality(coeffs_pair) scalars.update({f"orthogonality_{i}_{j}/{key}": value for key, value in orthogonality_scalars.items()}) - singular_values.update( - {f"orthogonality_{i}_{j}/{key}": value for key, value in orthogonality_singular_values.items()} - ) - return scalars, singular_values + arrays.update({f"orthogonality_{i}_{j}/{key}": value for key, value in orthogonality_arrays.items()}) + return scalars, arrays def _handle_factored_regression( From 556fede4e9ca040272667774fae8d0224ab162fa Mon Sep 17 00:00:00 2001 From: loren-ac Date: Fri, 12 Dec 2025 15:34:00 -0800 Subject: [PATCH 31/48] Add docstring describing relevance of the do_nothing_branch function --- simplexity/analysis/linear_regression.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/simplexity/analysis/linear_regression.py b/simplexity/analysis/linear_regression.py index abcbabd5..df7f5e17 100644 --- a/simplexity/analysis/linear_regression.py +++ b/simplexity/analysis/linear_regression.py @@ -301,6 +301,8 @@ def _compute_subspace_orthogonality( # Define the False branch function (does nothing) def do_nothing_branch(x): + """JAX 'False' branch function. 
Serves only to return a value that matches the 'True' branch's type (None) for + jax.lax.cond.""" return None # Define the True branch function (runs the callback) @@ -327,7 +329,6 @@ def log_all_zeros(_): probs = singular_values**2 / probs_denominator_safe def execute_some_zeros_warning_branch(x): - # This correctly calls the log_some_zeros function callback(log_some_zeros, x) return None From 5b9801dc07b8b02ce04796b5a4ed5cb00e473a17 Mon Sep 17 00:00:00 2001 From: loren-ac Date: Fri, 12 Dec 2025 15:43:30 -0800 Subject: [PATCH 32/48] Refactor key removal method in kwarg validator and fix docstring format --- simplexity/analysis/layerwise_analysis.py | 10 ++++++++++ simplexity/analysis/linear_regression.py | 6 ++++-- 2 files changed, 14 insertions(+), 2 deletions(-) diff --git a/simplexity/analysis/layerwise_analysis.py b/simplexity/analysis/layerwise_analysis.py index f6d3cbb9..b2a1f719 100644 --- a/simplexity/analysis/layerwise_analysis.py +++ b/simplexity/analysis/layerwise_analysis.py @@ -68,6 +68,16 @@ def _validate_linear_regression_kwargs(kwargs: Mapping[str, Any] | None) -> dict } +def _validate_linear_regression_kwargs(kwargs: Mapping[str, Any] | None) -> dict[str, Any]: + kwargs = _base_validate_linear_regression_kwargs(kwargs) + return {k: v for k, v in kwargs.items() if k != "rcond_values"} + + +def _validate_linear_regression_svd_kwargs(kwargs: Mapping[str, Any] | None) -> dict[str, Any]: + kwargs = _base_validate_linear_regression_kwargs(kwargs) + return {k: v for k, v in kwargs.items() if k != "use_svd"} + + def _validate_pca_kwargs(kwargs: Mapping[str, Any] | None) -> dict[str, Any]: provided = dict(kwargs or {}) allowed = {"n_components", "variance_thresholds"} diff --git a/simplexity/analysis/linear_regression.py b/simplexity/analysis/linear_regression.py index df7f5e17..6265484c 100644 --- a/simplexity/analysis/linear_regression.py +++ b/simplexity/analysis/linear_regression.py @@ -301,8 +301,10 @@ def _compute_subspace_orthogonality( # Define the False branch function (does nothing) def do_nothing_branch(x): - """JAX 'False' branch function. Serves only to return a value that matches the 'True' branch's type (None) for - jax.lax.cond.""" + """JAX 'False' branch function. + + Serves only to return a value that matches the 'True' branch's type (None) for jax.lax.cond. + """ return None # Define the True branch function (runs the callback) From 06c7692270c13005db53edadc2aa99d7ff7c427a Mon Sep 17 00:00:00 2001 From: Eric Alt Date: Fri, 12 Dec 2025 16:01:08 -0800 Subject: [PATCH 33/48] Temporarily disable pylint checks during AST traversal in linear_regression.py to prevent crashes. Remove deprecated layer_linear_regression_svd function for cleaner code and encourage use of layer_linear_regression with use_svd=True. --- simplexity/analysis/linear_regression.py | 31 ++++++------------------ 1 file changed, 7 insertions(+), 24 deletions(-) diff --git a/simplexity/analysis/linear_regression.py b/simplexity/analysis/linear_regression.py index 6265484c..cba2b1b0 100644 --- a/simplexity/analysis/linear_regression.py +++ b/simplexity/analysis/linear_regression.py @@ -1,5 +1,12 @@ """Reusable linear regression utilities for activation analysis.""" +# pylint: disable=all # Temporarily disable all pylint checkers during AST traversal to prevent crash. +# The imports checker crashes when resolving simplexity package imports due to a bug +# in pylint/astroid: https://github.com/pylint-dev/pylint/issues/10185 +# pylint: enable=all # Re-enable all pylint checkers for the checking phase. 
This allows other checks +# (code quality, style, undefined names, etc.) to run normally while bypassing +# the problematic imports checker that would crash during AST traversal. + from __future__ import annotations import itertools @@ -527,27 +534,3 @@ def layer_linear_regression( use_svd, **kwargs, ) - - -def layer_linear_regression_svd( - layer_activations: jax.Array, - weights: jax.Array, - belief_states: jax.Array | tuple[jax.Array, ...] | None, - concat_belief_states: bool = False, - compute_subspace_orthogonality: bool = False, - **kwargs: Any, -) -> tuple[Mapping[str, float], Mapping[str, jax.Array]]: - """Layer-wise SVD regression helper (wrapper around layer_linear_regression with use_svd=True). - - This function is provided for backward compatibility and convenience. - Consider using layer_linear_regression with use_svd=True for new code. - """ - return layer_linear_regression( - layer_activations=layer_activations, - weights=weights, - belief_states=belief_states, - concat_belief_states=concat_belief_states, - compute_subspace_orthogonality=compute_subspace_orthogonality, - use_svd=True, - **kwargs, - ) From 5bcbe03d1793680ce345019ce8b3af45d52c3d76 Mon Sep 17 00:00:00 2001 From: Eric Alt Date: Fri, 12 Dec 2025 16:01:18 -0800 Subject: [PATCH 34/48] Refactor linear regression analysis registration to use partial application of layer_linear_regression with use_svd=True, removing the deprecated layer_linear_regression_svd function for improved clarity and consistency. --- simplexity/analysis/layerwise_analysis.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/simplexity/analysis/layerwise_analysis.py b/simplexity/analysis/layerwise_analysis.py index b2a1f719..3b6c83be 100644 --- a/simplexity/analysis/layerwise_analysis.py +++ b/simplexity/analysis/layerwise_analysis.py @@ -11,14 +11,12 @@ from collections.abc import Callable, Mapping, Sequence from dataclasses import dataclass +from functools import partial from typing import Any import jax -from simplexity.analysis.linear_regression import ( - layer_linear_regression, - layer_linear_regression_svd, -) +from simplexity.analysis.linear_regression import layer_linear_regression from simplexity.analysis.pca import ( DEFAULT_VARIANCE_THRESHOLDS, layer_pca_analysis, @@ -110,7 +108,7 @@ def _validate_pca_kwargs(kwargs: Mapping[str, Any] | None) -> dict[str, Any]: validator=_validate_linear_regression_kwargs, ), "linear_regression_svd": AnalysisRegistration( - fn=layer_linear_regression_svd, + fn=partial(layer_linear_regression, use_svd=True), requires_belief_states=True, validator=_validate_linear_regression_kwargs, ), From ed6981492e8503238abc272414f1e34680cc1518 Mon Sep 17 00:00:00 2001 From: Eric Alt Date: Fri, 12 Dec 2025 16:03:30 -0800 Subject: [PATCH 35/48] Fix tests --- tests/analysis/test_layerwise_analysis.py | 8 ++++++++ tests/analysis/test_linear_regression.py | 25 +++++++++++++++-------- 2 files changed, 24 insertions(+), 9 deletions(-) diff --git a/tests/analysis/test_layerwise_analysis.py b/tests/analysis/test_layerwise_analysis.py index 5be4db4f..b4c67121 100644 --- a/tests/analysis/test_layerwise_analysis.py +++ b/tests/analysis/test_layerwise_analysis.py @@ -1,5 +1,13 @@ """Tests for the LayerwiseAnalysis orchestrator.""" +# pylint: disable=all # Temporarily disable all pylint checkers during AST traversal to prevent crash. 
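A compact illustration of the registry refactor in PATCH 34: functools.partial binds use_svd=True, so a single regression implementation backs both registry entries (names simplified here; this is not the actual simplexity registry):

from functools import partial

def regress(x, y, *, use_svd: bool = False, rcond_values=None):
    # One implementation; the solver path is selected by the flag.
    return "svd" if use_svd else "lstsq"

REGISTRY = {
    "linear_regression": regress,
    "linear_regression_svd": partial(regress, use_svd=True),
}

assert REGISTRY["linear_regression"](None, None) == "lstsq"
assert REGISTRY["linear_regression_svd"](None, None) == "svd"

One consequence, handled in PATCH 40 below, is that a caller-supplied use_svd kwarg would collide with the keyword already bound by partial, so the SVD validator has to strip use_svd from its output.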
+# The imports checker crashes when resolving simplexity package imports due to a bug +# in pylint/astroid: https://github.com/pylint-dev/pylint/issues/10185 +# pylint: enable=all # Re-enable all pylint checkers for the checking phase. This allows other checks +# (code quality, style, undefined names, etc.) to run normally while bypassing +# the problematic imports checker that would crash during AST traversal. + +import jax import jax.numpy as jnp import pytest diff --git a/tests/analysis/test_linear_regression.py b/tests/analysis/test_linear_regression.py index ea6c937c..e90f100b 100644 --- a/tests/analysis/test_linear_regression.py +++ b/tests/analysis/test_linear_regression.py @@ -1,5 +1,12 @@ """Tests for reusable linear regression helpers.""" +# pylint: disable=all # Temporarily disable all pylint checkers during AST traversal to prevent crash. +# The imports checker crashes when resolving simplexity package imports due to a bug +# in pylint/astroid: https://github.com/pylint-dev/pylint/issues/10185 +# pylint: enable=all # Re-enable all pylint checkers for the checking phase. This allows other checks +# (code quality, style, undefined names, etc.) to run normally while bypassing +# the problematic imports checker that would crash during AST traversal. + # pylint: disable=too-many-lines # pylint: disable=too-many-locals @@ -10,7 +17,6 @@ from simplexity.analysis.linear_regression import ( layer_linear_regression, - layer_linear_regression_svd, linear_regression, linear_regression_svd, ) @@ -81,9 +87,6 @@ def test_layer_regression_requires_targets() -> None: with pytest.raises(ValueError, match="requires belief_states"): layer_linear_regression(x, weights, None) - with pytest.raises(ValueError, match="requires belief_states"): - layer_linear_regression_svd(x, weights, None) - def test_linear_regression_rejects_mismatched_weights() -> None: """Weights must align with the sample dimension.""" @@ -198,16 +201,17 @@ def test_linear_regression_svd_falls_back_to_default_rcond() -> None: assert scalars["best_rcond"] == pytest.approx(1e-15) -def test_layer_linear_regression_svd_runs_end_to_end() -> None: +def test_layer_linear_regression_runs_end_to_end() -> None: """Layer helper should proxy through to the base implementation.""" x = jnp.arange(6.0).reshape(3, 2) weights = jnp.ones(3) / 3.0 beliefs = 2.0 * x.sum(axis=1, keepdims=True) - scalars, arrays = layer_linear_regression_svd( + scalars, arrays = layer_linear_regression( x, weights, beliefs, + use_svd=True, rcond_values=[1e-3], ) @@ -272,10 +276,11 @@ def test_layer_linear_regression_svd_belief_states_tuple_default() -> None: factor_1 = jnp.array([[0.2, 0.3, 0.5], [0.1, 0.6, 0.3], [0.4, 0.4, 0.2], [0.3, 0.3, 0.4]]) # [4, 3] factored_beliefs = (factor_0, factor_1) - scalars, arrays = layer_linear_regression_svd( + scalars, arrays = layer_linear_regression( x, weights, factored_beliefs, + use_svd=True, rcond_values=[1e-6], ) @@ -602,10 +607,11 @@ def test_use_svd_flag_equivalence() -> None: ) # Method 2: layer_linear_regression_svd - scalars_wrapper, arrays_wrapper = layer_linear_regression_svd( + scalars_wrapper, arrays_wrapper = layer_linear_regression( x, weights, belief_state, + use_svd=True, rcond_values=rcond_values, ) @@ -635,10 +641,11 @@ def test_use_svd_flag_equivalence() -> None: ) # Method 2: layer_linear_regression_svd with factored beliefs - scalars_wrapper_fact, arrays_wrapper_fact = layer_linear_regression_svd( + scalars_wrapper_fact, arrays_wrapper_fact = layer_linear_regression( x, weights, factored_beliefs, + 
use_svd=True, rcond_values=rcond_values, ) From 46ce191d1e987949be81ede9cfcde14298aa76e8 Mon Sep 17 00:00:00 2001 From: Eric Alt Date: Fri, 12 Dec 2025 16:05:34 -0800 Subject: [PATCH 36/48] Add detailed docstring to _compute_subspace_orthogonality function, specifying return values and their meanings for improved clarity and documentation. --- simplexity/analysis/linear_regression.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/simplexity/analysis/linear_regression.py b/simplexity/analysis/linear_regression.py index cba2b1b0..02929777 100644 --- a/simplexity/analysis/linear_regression.py +++ b/simplexity/analysis/linear_regression.py @@ -290,6 +290,18 @@ def _compute_subspace_orthogonality( Args: coeffs_pair: List of two coefficient matrices (excludes intercept) + + Returns: + Tuple[dict[str, float], dict[str, jax.Array]]: A tuple containing: + - scalars: A dictionary with the following keys and float values: + - 'subspace_overlap': Average squared singular value (overlap score). + - 'max_singular_value': Largest singular value. + - 'min_singular_value': Smallest singular value. + - 'participation_ratio': Participation ratio of the singular values. + - 'entropy': Entropy of the squared singular values. + - 'effective_rank': Effective rank (exp(entropy)) of the singular value distribution. + - singular_values: A dictionary with a single key: + - 'singular_values': jax.Array of the singular values between the two subspaces. """ # Compute the orthonormal bases for the two subspaces using QR decomposition q1, _ = jnp.linalg.qr(coeffs_pair[0]) From 049b6d68720bf86a2fd60ce298e482c5f06a74b6 Mon Sep 17 00:00:00 2001 From: Eric Alt Date: Fri, 12 Dec 2025 16:09:54 -0800 Subject: [PATCH 37/48] Add todo --- simplexity/analysis/linear_regression.py | 1 + 1 file changed, 1 insertion(+) diff --git a/simplexity/analysis/linear_regression.py b/simplexity/analysis/linear_regression.py index 02929777..73bd7f8d 100644 --- a/simplexity/analysis/linear_regression.py +++ b/simplexity/analysis/linear_regression.py @@ -303,6 +303,7 @@ def _compute_subspace_orthogonality( - singular_values: A dictionary with a single key: - 'singular_values': jax.Array of the singular values between the two subspaces. 
""" + # TODO: assumes coeff matrices are full ranks, should verify # Compute the orthonormal bases for the two subspaces using QR decomposition q1, _ = jnp.linalg.qr(coeffs_pair[0]) q2, _ = jnp.linalg.qr(coeffs_pair[1]) From c890e362363d1ea7658e44fd7e1ac42ed0e69921 Mon Sep 17 00:00:00 2001 From: Eric Alt Date: Fri, 12 Dec 2025 16:26:44 -0800 Subject: [PATCH 38/48] Fix kwarg validation --- simplexity/analysis/layerwise_analysis.py | 36 +++++++++++------------ 1 file changed, 18 insertions(+), 18 deletions(-) diff --git a/simplexity/analysis/layerwise_analysis.py b/simplexity/analysis/layerwise_analysis.py index 3b6c83be..04d8595f 100644 --- a/simplexity/analysis/layerwise_analysis.py +++ b/simplexity/analysis/layerwise_analysis.py @@ -44,26 +44,26 @@ def _validate_linear_regression_kwargs(kwargs: Mapping[str, Any] | None) -> dict unexpected = set(provided) - allowed if unexpected: raise ValueError(f"Unexpected linear_regression kwargs: {sorted(unexpected)}") - fit_intercept = bool(provided.get("fit_intercept", True)) - concat_belief_states = bool(provided.get("concat_belief_states", False)) - compute_subspace_orthogonality = bool(provided.get("compute_subspace_orthogonality", False)) + resolved_kwargs = {} + resolved_kwargs["fit_intercept"] = bool(provided.get("fit_intercept", True)) + resolved_kwargs["concat_belief_states"] = bool(provided.get("concat_belief_states", False)) + resolved_kwargs["compute_subspace_orthogonality"] = bool(provided.get("compute_subspace_orthogonality", False)) use_svd = bool(provided.get("use_svd", False)) + resolved_kwargs["use_svd"] = use_svd rcond_values = provided.get("rcond_values") - if rcond_values is not None: - if not isinstance(rcond_values, (list, tuple)): - raise TypeError("rcond_values must be a sequence of floats") - if len(rcond_values) == 0: - raise ValueError("rcond_values must not be empty") - if not use_svd: - SIMPLEXITY_LOGGER.warning("rcond_values are only used when use_svd is True") - rcond_values = tuple(float(v) for v in rcond_values) - return { - "fit_intercept": fit_intercept, - "concat_belief_states": concat_belief_states, - "compute_subspace_orthogonality": compute_subspace_orthogonality, - "use_svd": use_svd, - "rcond_values": rcond_values, - } + if use_svd: + if rcond_values is not None: + if not isinstance(rcond_values, (list, tuple)): + raise TypeError("rcond_values must be a sequence of floats") + if len(rcond_values) == 0: + raise ValueError("rcond_values must not be empty") + if not use_svd: + SIMPLEXITY_LOGGER.warning("rcond_values are only used when use_svd is True") + rcond_values = tuple(float(v) for v in rcond_values) + resolved_kwargs["rcond_values"] = rcond_values + elif rcond_values is not None: + raise ValueError("rcond_values are only used when use_svd is True") + return resolved_kwargs def _validate_linear_regression_kwargs(kwargs: Mapping[str, Any] | None) -> dict[str, Any]: From 3a5a8e28dfbe00f48df90be247fc9dc8d801a3f7 Mon Sep 17 00:00:00 2001 From: Eric Alt Date: Fri, 12 Dec 2025 16:31:18 -0800 Subject: [PATCH 39/48] Fix tests --- simplexity/analysis/layerwise_analysis.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/simplexity/analysis/layerwise_analysis.py b/simplexity/analysis/layerwise_analysis.py index 04d8595f..94ad5b9b 100644 --- a/simplexity/analysis/layerwise_analysis.py +++ b/simplexity/analysis/layerwise_analysis.py @@ -48,9 +48,10 @@ def _validate_linear_regression_kwargs(kwargs: Mapping[str, Any] | None) -> dict resolved_kwargs["fit_intercept"] = 
From 09876978c6f811d8fca1e9dfcbeebcb4ebd93ab1 Mon Sep 17 00:00:00 2001
From: Eric Alt
Date: Fri, 12 Dec 2025 16:54:59 -0800
Subject: [PATCH 40/48] Add validator decorator for linear_regression_svd to enforce use_svd=True and exclude it from output. Enhance tests to validate behavior.

---
 simplexity/analysis/layerwise_analysis.py | 21 ++++++++++++++-------
 tests/analysis/test_layerwise_analysis.py | 23 +++++++++++++++++++++++
 2 files changed, 37 insertions(+), 7 deletions(-)

diff --git a/simplexity/analysis/layerwise_analysis.py b/simplexity/analysis/layerwise_analysis.py
index 94ad5b9b..60aa5cb7 100644
--- a/simplexity/analysis/layerwise_analysis.py
+++ b/simplexity/analysis/layerwise_analysis.py
@@ -65,14 +65,21 @@ def _validate_linear_regression_kwargs(kwargs: Mapping[str, Any] | None) -> dict
     return resolved_kwargs
 
 
-def _validate_linear_regression_kwargs(kwargs: Mapping[str, Any] | None) -> dict[str, Any]:
-    kwargs = _base_validate_linear_regression_kwargs(kwargs)
-    return {k: v for k, v in kwargs.items() if k != "rcond_values"}
+def set_use_svd(
+    fn: ValidatorFn,
+) -> ValidatorFn:
+    """Decorator that forces use_svd=True and removes it from the resolved kwargs to avoid duplicating the argument bound by partial."""
 
+    def wrapper(kwargs: Mapping[str, Any] | None) -> dict[str, Any]:
+        if kwargs and "use_svd" in kwargs and not kwargs["use_svd"]:
+            raise ValueError("use_svd cannot be set to False for linear_regression_svd")
+        modified_kwargs = dict(kwargs) if kwargs else {}  # Make a copy to avoid mutating the input
+        modified_kwargs["use_svd"] = True
+        resolved = fn(modified_kwargs)
+        resolved.pop("use_svd", None)  # Remove use_svd to avoid duplicate argument with partial
+        return resolved
 
-def _validate_linear_regression_svd_kwargs(kwargs: Mapping[str, Any] | None) -> dict[str, Any]:
-    kwargs = _base_validate_linear_regression_kwargs(kwargs)
-    return {k: v for k, v in kwargs.items() if k != "use_svd"}
+    return wrapper
 
 
 def _validate_pca_kwargs(kwargs: Mapping[str, Any] | None) -> dict[str, Any]:
@@ -109,7 +116,7 @@ def _validate_pca_kwargs(kwargs: Mapping[str, Any] | None) -> dict[str, Any]:
     "linear_regression_svd": AnalysisRegistration(
         fn=partial(layer_linear_regression, use_svd=True),
         requires_belief_states=True,
-        validator=_validate_linear_regression_kwargs,
+        validator=set_use_svd(_validate_linear_regression_kwargs),
     ),
     "pca": AnalysisRegistration(
         fn=layer_pca_analysis,
diff --git a/tests/analysis/test_layerwise_analysis.py b/tests/analysis/test_layerwise_analysis.py
index b4c67121..51226f21 100644
--- a/tests/analysis/test_layerwise_analysis.py
+++ b/tests/analysis/test_layerwise_analysis.py
@@ -243,6 +243,29 @@ def test_linear_regression_svd_accepts_compute_subspace_orthogonality() -> None:
     assert params["rcond_values"] == (0.001,)
 
 
+def test_linear_regression_svd_rejects_use_svd() -> None:
+    """linear_regression_svd validator should reject explicit use_svd parameter since it's bound in partial."""
+
+    validator = ANALYSIS_REGISTRY["linear_regression_svd"].validator
+
+    with pytest.raises(ValueError, match="use_svd cannot be specified for linear_regression_svd"):
+        validator({"use_svd": True})
+
+    with pytest.raises(ValueError, match="use_svd cannot be specified for linear_regression_svd"):
+        validator({"use_svd": False})
+
+
+def test_linear_regression_svd_excludes_use_svd_from_output() -> None:
+    """linear_regression_svd validator should not include use_svd in resolved kwargs."""
+
+    validator = ANALYSIS_REGISTRY["linear_regression_svd"].validator
+    params = validator({"rcond_values": [1e-3]})
+
+    # use_svd should not be in the output since it's already bound in the partial
+    assert "use_svd" not in params
+    assert params["rcond_values"] == (0.001,)
+
+
 def test_linear_regression_compute_subspace_orthogonality_defaults_false() -> None:
     """compute_subspace_orthogonality should default to False when not provided."""

From 0f37809be10c9ce43c0fecb5e01cd1c3b1cdc977 Mon Sep 17 00:00:00 2001
From: Eric Alt
Date: Fri, 12 Dec 2025 16:56:08 -0800
Subject: [PATCH 41/48] Fix test

---
 tests/analysis/test_layerwise_analysis.py | 7 +++----
 1 file changed, 3 insertions(+), 4 deletions(-)

diff --git a/tests/analysis/test_layerwise_analysis.py b/tests/analysis/test_layerwise_analysis.py
index 51226f21..c80b2a5c 100644
--- a/tests/analysis/test_layerwise_analysis.py
+++ b/tests/analysis/test_layerwise_analysis.py
@@ -243,15 +243,14 @@ def test_linear_regression_svd_accepts_compute_subspace_orthogonality() -> None:
     assert params["rcond_values"] == (0.001,)
 
 
-def test_linear_regression_svd_rejects_use_svd() -> None:
+def test_linear_regression_svd_rejects_false_use_svd() -> None:
     """linear_regression_svd validator should reject explicit use_svd parameter since it's bound in partial."""
 
     validator = ANALYSIS_REGISTRY["linear_regression_svd"].validator
 
-    with pytest.raises(ValueError, match="use_svd cannot be specified for linear_regression_svd"):
-        validator({"use_svd": True})
+    validator({"use_svd": True})
 
-    with pytest.raises(ValueError, match="use_svd cannot be specified for linear_regression_svd"):
+    with pytest.raises(ValueError, match="use_svd cannot be set to False for linear_regression_svd"):
         validator({"use_svd": False})
 

From 028e047c709b0afcfe72ef5b810077f85198b36a Mon Sep 17 00:00:00 2001
From: loren-ac
Date: Fri, 12 Dec 2025 16:55:54 -0800
Subject: [PATCH 42/48] Add get_robust_basis for robust orthonormal basis extraction

---
 simplexity/analysis/linear_regression.py |  17 +++
 tests/analysis/test_linear_regression.py | 126 +++++++++++++++++++++++
 2 files changed, 143 insertions(+)

diff --git a/simplexity/analysis/linear_regression.py b/simplexity/analysis/linear_regression.py
index 73bd7f8d..ddcfc97d 100644
--- a/simplexity/analysis/linear_regression.py
+++ b/simplexity/analysis/linear_regression.py
@@ -283,6 +283,23 @@ def _split_concat_results(
     return results
 
 
+def get_robust_basis(matrix: jax.Array) -> jax.Array:
+    """Extracts an orthonormal basis for the column space of the matrix.
+
+    Handles rank deficiency gracefully by discarding directions associated with singular values below a
+    certain tolerance.
+ """ + u, s, _ = jnp.linalg.svd(matrix, full_matrices=False) + + max_dim = max(matrix.shape) + eps = jnp.finfo(matrix.dtype).eps + tol = s[0] * max_dim * eps + + valid_dims = s > tol + basis = u[:, valid_dims] + return basis + + def _compute_subspace_orthogonality( coeffs_pair: list[jax.Array], ) -> tuple[dict[str, float], dict[str, jax.Array]]: diff --git a/tests/analysis/test_linear_regression.py b/tests/analysis/test_linear_regression.py index e90f100b..c32766c1 100644 --- a/tests/analysis/test_linear_regression.py +++ b/tests/analysis/test_linear_regression.py @@ -16,6 +16,7 @@ import pytest from simplexity.analysis.linear_regression import ( + get_robust_basis, layer_linear_regression, linear_regression, linear_regression_svd, @@ -1085,3 +1086,128 @@ def test_layer_linear_regression_svd_concat_vs_separate_equivalence_best_rcond() atol=1e-6, rtol=0.0, ).item() + + +def test_get_robust_basis_full_rank(): + """Full rank matrix should return all basis vectors.""" + # Create a full rank 5x3 matrix + key = jax.random.PRNGKey(42) + matrix = jax.random.normal(key, (5, 3)) + + basis = get_robust_basis(matrix) + + # Should return 3 basis vectors (all columns are linearly independent) + assert basis.shape == (5, 3) + + # Basis should be orthonormal + # Error in Gram matrix scales with: n_basis * eps + eps = jnp.finfo(basis.dtype).eps + tol = basis.shape[1] * eps + gram = basis.T @ basis + assert jnp.allclose(gram, jnp.eye(3), atol=tol) + + +def test_get_robust_basis_rank_deficient(): + """Rank deficient matrix should filter out zero singular value directions.""" + # Create a rank-2 matrix with 3 columns (third is linear combination) + col1 = jnp.array([[1.0], [0.0], [0.0], [0.0]]) + col2 = jnp.array([[0.0], [1.0], [0.0], [0.0]]) + col3 = 2.0 * col1 + 3.0 * col2 # Linear combination, rank deficient + matrix = jnp.hstack([col1, col2, col3]) + + basis = get_robust_basis(matrix) + + # Should return only 2 basis vectors (true rank is 2) + assert basis.shape[1] == 2 + + # Basis should be orthonormal + # Error in Gram matrix scales with: n_basis * eps + eps = jnp.finfo(basis.dtype).eps + tol = basis.shape[1] * eps + gram = basis.T @ basis + assert jnp.allclose(gram, jnp.eye(2), atol=tol) + + +def test_get_robust_basis_zero_matrix(): + """Zero matrix should return empty basis.""" + matrix = jnp.zeros((5, 3)) + basis = get_robust_basis(matrix) + + # Should return empty basis (no valid directions) + assert basis.shape == (5, 0) + + +def test_get_robust_basis_near_rank_deficient(): + """Matrix with very small singular value should filter it out.""" + # Create matrix with controlled singular values using SVD construction + key = jax.random.PRNGKey(123) + u = jax.random.normal(key, (6, 3)) + u, _ = jnp.linalg.qr(u) # Orthonormalize + + # Singular values: [10.0, 1.0, 1e-10] - last one is tiny + s = jnp.array([10.0, 1.0, 1e-10]) + v = jnp.eye(3) + + matrix = u @ jnp.diag(s) @ v + basis = get_robust_basis(matrix) + + # Should filter out the tiny singular value, keeping only 2 vectors + assert basis.shape[1] == 2 + + # Basis should be orthonormal + # Error in Gram matrix scales with: n_basis * eps + eps = jnp.finfo(basis.dtype).eps + tol = basis.shape[1] * eps + gram = basis.T @ basis + assert jnp.allclose(gram, jnp.eye(2), atol=tol) + + +def test_get_robust_basis_preserves_column_space(): + """Basis should span the same space as the original matrix's columns.""" + # Create a known rank-2 matrix + col1 = jnp.array([[1.0], [0.0], [0.0], [0.0]]) + col2 = jnp.array([[0.0], [1.0], [0.0], [0.0]]) + col3 = 2 * col1 + 
3 * col2 # Linear combination + matrix = jnp.hstack([col1, col2, col3]) + + basis = get_robust_basis(matrix) + + # Basis should be rank 2 + assert basis.shape[1] == 2 + + # Compute principled tolerance based on matrix properties + # Error in projection scales with: max_dim * eps * max_singular_value + max_dim = max(matrix.shape) + eps = jnp.finfo(matrix.dtype).eps + max_sv = jnp.linalg.svd(matrix, compute_uv=False)[0] + tol = max_dim * eps * max_sv + + # Each original column should be expressible as linear combination of basis + for i in range(3): + col = matrix[:, i : i + 1] + # Project onto basis + projection = basis @ (basis.T @ col) + # Should be very close to original (within numerical tolerance) + assert jnp.allclose(projection, col, atol=tol) + + +def test_get_robust_basis_single_vector(): + """Single non-zero column should return normalized version.""" + vector = jnp.array([[3.0], [4.0], [0.0]]) + basis = get_robust_basis(vector) + + # Should return one basis vector + assert basis.shape == (3, 1) + + # Should be unit norm + # Error in norm computation scales with: dimension * eps + dim = vector.shape[0] + eps = jnp.finfo(vector.dtype).eps + norm_tol = dim * eps + assert jnp.allclose(jnp.linalg.norm(basis), 1.0, atol=norm_tol) + + # Should be parallel to input + # Error in dot product scales with: dimension * eps * magnitude + expected_norm = jnp.linalg.norm(vector) + parallel_tol = dim * eps * expected_norm + assert jnp.allclose(jnp.abs(basis.T @ vector), expected_norm, atol=parallel_tol) From 0532cd2ce8ce6196c8cac672d75ad7a5da233bca Mon Sep 17 00:00:00 2001 From: loren-ac Date: Fri, 12 Dec 2025 17:17:11 -0800 Subject: [PATCH 43/48] Pass pair of bases instead of coefficient matrices to _compute_subspace_orthogonality --- simplexity/analysis/linear_regression.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/simplexity/analysis/linear_regression.py b/simplexity/analysis/linear_regression.py index ddcfc97d..0c6792ea 100644 --- a/simplexity/analysis/linear_regression.py +++ b/simplexity/analysis/linear_regression.py @@ -301,7 +301,7 @@ def get_robust_basis(matrix: jax.Array) -> jax.Array: def _compute_subspace_orthogonality( - coeffs_pair: list[jax.Array], + basis_pair: list[jax.Array], ) -> tuple[dict[str, float], dict[str, jax.Array]]: """Compute orthogonality metrics between two coefficient subspaces. @@ -320,10 +320,10 @@ def _compute_subspace_orthogonality( - singular_values: A dictionary with a single key: - 'singular_values': jax.Array of the singular values between the two subspaces. 
""" - # TODO: assumes coeff matrices are full ranks, should verify - # Compute the orthonormal bases for the two subspaces using QR decomposition - q1, _ = jnp.linalg.qr(coeffs_pair[0]) - q2, _ = jnp.linalg.qr(coeffs_pair[1]) + + q1 = basis_pair[0] + q2 = basis_pair[1] + # Compute the singular values of the interaction matrix interaction_matrix = q1.T @ q2 singular_values = jnp.linalg.svd(interaction_matrix, compute_uv=False) From 95060d1288e5c9dbedfac96f93d065f949bfa46b Mon Sep 17 00:00:00 2001 From: loren-ac Date: Fri, 12 Dec 2025 17:22:41 -0800 Subject: [PATCH 44/48] Compute full rank and orthonormal basis of coeff matrices before passing bases to subspace analysis --- simplexity/analysis/linear_regression.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/simplexity/analysis/linear_regression.py b/simplexity/analysis/linear_regression.py index 0c6792ea..d0bbb9e5 100644 --- a/simplexity/analysis/linear_regression.py +++ b/simplexity/analysis/linear_regression.py @@ -320,7 +320,7 @@ def _compute_subspace_orthogonality( - singular_values: A dictionary with a single key: - 'singular_values': jax.Array of the singular values between the two subspaces. """ - + q1 = basis_pair[0] q2 = basis_pair[1] @@ -423,12 +423,10 @@ def _compute_all_pairwise_orthogonality( scalars = {} arrays = {} factor_pairs = list(itertools.combinations(range(len(coeffs_list)), 2)) + basis_list = [get_robust_basis(coeffs) for coeffs in coeffs_list] # ensures full rank and orthonormal basis for i, j in factor_pairs: - coeffs_pair = [ - coeffs_list[i], - coeffs_list[j], - ] - orthogonality_scalars, orthogonality_arrays = _compute_subspace_orthogonality(coeffs_pair) + basis_pair = [basis_list[i], basis_list[j]] + orthogonality_scalars, orthogonality_arrays = _compute_subspace_orthogonality(basis_pair) scalars.update({f"orthogonality_{i}_{j}/{key}": value for key, value in orthogonality_scalars.items()}) arrays.update({f"orthogonality_{i}_{j}/{key}": value for key, value in orthogonality_arrays.items()}) return scalars, arrays From b0ecb64abff927e3ac31af5d0f429d7308850420 Mon Sep 17 00:00:00 2001 From: loren-ac Date: Fri, 12 Dec 2025 17:30:56 -0800 Subject: [PATCH 45/48] Fix formatting and docstring --- simplexity/analysis/linear_regression.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/simplexity/analysis/linear_regression.py b/simplexity/analysis/linear_regression.py index d0bbb9e5..b06ea732 100644 --- a/simplexity/analysis/linear_regression.py +++ b/simplexity/analysis/linear_regression.py @@ -306,7 +306,7 @@ def _compute_subspace_orthogonality( """Compute orthogonality metrics between two coefficient subspaces. Args: - coeffs_pair: List of two coefficient matrices (excludes intercept) + basis_pair: List of two orthonormal basis matrices Returns: Tuple[dict[str, float], dict[str, jax.Array]]: A tuple containing: @@ -320,7 +320,6 @@ def _compute_subspace_orthogonality( - singular_values: A dictionary with a single key: - 'singular_values': jax.Array of the singular values between the two subspaces. 
""" - q1 = basis_pair[0] q2 = basis_pair[1] @@ -423,7 +422,7 @@ def _compute_all_pairwise_orthogonality( scalars = {} arrays = {} factor_pairs = list(itertools.combinations(range(len(coeffs_list)), 2)) - basis_list = [get_robust_basis(coeffs) for coeffs in coeffs_list] # ensures full rank and orthonormal basis + basis_list = [get_robust_basis(coeffs) for coeffs in coeffs_list] # ensures full rank and orthonormal basis for i, j in factor_pairs: basis_pair = [basis_list[i], basis_list[j]] orthogonality_scalars, orthogonality_arrays = _compute_subspace_orthogonality(basis_pair) From 7a026024db5c48ad87a82298fc34507ba5a202c7 Mon Sep 17 00:00:00 2001 From: loren-ac Date: Fri, 12 Dec 2025 17:33:04 -0800 Subject: [PATCH 46/48] Update comment --- simplexity/analysis/linear_regression.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/simplexity/analysis/linear_regression.py b/simplexity/analysis/linear_regression.py index b06ea732..a0ef6eee 100644 --- a/simplexity/analysis/linear_regression.py +++ b/simplexity/analysis/linear_regression.py @@ -422,7 +422,7 @@ def _compute_all_pairwise_orthogonality( scalars = {} arrays = {} factor_pairs = list(itertools.combinations(range(len(coeffs_list)), 2)) - basis_list = [get_robust_basis(coeffs) for coeffs in coeffs_list] # ensures full rank and orthonormal basis + basis_list = [get_robust_basis(coeffs) for coeffs in coeffs_list] # computes orthonormal basis of coeff matrix for i, j in factor_pairs: basis_pair = [basis_list[i], basis_list[j]] orthogonality_scalars, orthogonality_arrays = _compute_subspace_orthogonality(basis_pair) From 69ff3e4394213d13a40b2391edafc920d094fa89 Mon Sep 17 00:00:00 2001 From: loren-ac Date: Mon, 15 Dec 2025 18:47:58 -0800 Subject: [PATCH 47/48] Fix issues due to API changes in activation and dataframe tests --- tests/activations/test_activation_analysis.py | 2 +- tests/activations/test_dataframe_integration.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/activations/test_activation_analysis.py b/tests/activations/test_activation_analysis.py index f802cc2c..f51dd203 100644 --- a/tests/activations/test_activation_analysis.py +++ b/tests/activations/test_activation_analysis.py @@ -978,7 +978,7 @@ def test_tracker_with_factored_beliefs(self, factored_belief_data): } ) - scalars, projections = tracker.analyze( + scalars, projections, _ = tracker.analyze( inputs=factored_belief_data["inputs"], beliefs=factored_belief_data["factored_beliefs"], probs=factored_belief_data["probs"], diff --git a/tests/activations/test_dataframe_integration.py b/tests/activations/test_dataframe_integration.py index 4ca363fc..1fa8727d 100644 --- a/tests/activations/test_dataframe_integration.py +++ b/tests/activations/test_dataframe_integration.py @@ -15,7 +15,7 @@ ActivationVisualizationFieldRef, CombinedMappingSection, ) -from simplexity.analysis.linear_regression import layer_linear_regression_svd +from simplexity.analysis.linear_regression import layer_linear_regression from simplexity.exceptions import ConfigValidationError @@ -339,8 +339,8 @@ def test_linear_regression_projections_match_beliefs(self): beliefs_softmax = beliefs_softmax / beliefs_softmax.sum(axis=2, keepdims=True) belief_states = tuple(jnp.array(beliefs_softmax[:, f, :]) for f in range(n_factors)) - scalars, projections = layer_linear_regression_svd( - jnp.array(ds), jnp.ones(n_samples) / n_samples, belief_states, to_factors=True + scalars, projections = layer_linear_regression( + jnp.array(ds), jnp.ones(n_samples) / n_samples, 
belief_states, use_svd=True ) for f in range(n_factors): From 8e1efa4745d0a7cb263c0e09e66e91f947bc3ff4 Mon Sep 17 00:00:00 2001 From: loren-ac Date: Mon, 15 Dec 2025 19:02:26 -0800 Subject: [PATCH 48/48] Fix formatting issues --- tests/activations/test_activation_analysis.py | 1 - tests/analysis/test_layerwise_analysis.py | 1 - 2 files changed, 2 deletions(-) diff --git a/tests/activations/test_activation_analysis.py b/tests/activations/test_activation_analysis.py index f51dd203..b2288716 100644 --- a/tests/activations/test_activation_analysis.py +++ b/tests/activations/test_activation_analysis.py @@ -1120,7 +1120,6 @@ def test_three_factor_tuple(self, factored_belief_data): assert result.belief_states[2].shape == (batch_size, 4) - class TestScalarSeriesMapping: """Tests for scalar_series dataframe construction.""" diff --git a/tests/analysis/test_layerwise_analysis.py b/tests/analysis/test_layerwise_analysis.py index c80b2a5c..5b2c9dbd 100644 --- a/tests/analysis/test_layerwise_analysis.py +++ b/tests/analysis/test_layerwise_analysis.py @@ -7,7 +7,6 @@ # (code quality, style, undefined names, etc.) to run normally while bypassing # the problematic imports checker that would crash during AST traversal. -import jax import jax.numpy as jnp import pytest
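
A closing note on the pipeline these patches converge on: get_robust_basis turns each coefficient matrix into an SVD-based orthonormal basis, and _compute_subspace_orthogonality then reads every metric off the singular values of q1.T @ q2, which are the cosines of the principal angles between the two subspaces. The following is a self-contained sketch under those definitions; it is an illustrative reimplementation with hypothetical names (robust_basis, subspace_metrics), not the module's API, and only jax.numpy is assumed:

    import jax.numpy as jnp

    def robust_basis(matrix: jnp.ndarray) -> jnp.ndarray:
        # Orthonormal basis of the column space; directions whose singular value
        # falls below s_max * max_dim * eps are dropped, as in get_robust_basis.
        u, s, _ = jnp.linalg.svd(matrix, full_matrices=False)
        tol = s[0] * max(matrix.shape) * jnp.finfo(matrix.dtype).eps
        return u[:, s > tol]

    def subspace_metrics(q1: jnp.ndarray, q2: jnp.ndarray) -> dict[str, float]:
        # Singular values of q1.T @ q2 are the cosines of the principal angles.
        s = jnp.clip(jnp.linalg.svd(q1.T @ q2, compute_uv=False), 0.0, 1.0)
        probs = s**2 / jnp.sum(s**2)
        # Caveat: a principal angle of exactly 90 degrees makes an entry of probs
        # zero, so 0 * log(0) yields NaN entropy; the patched code shares this edge case.
        entropy = -jnp.sum(probs * jnp.log(probs))
        return {
            "subspace_overlap": float(jnp.sum(s**2) / min(q1.shape[1], q2.shape[1])),
            "participation_ratio": float(jnp.sum(s**2) ** 2 / jnp.sum(s**4)),
            "entropy": float(entropy),
            "effective_rank": float(jnp.exp(entropy)),
        }

    theta = jnp.pi / 4
    q1 = jnp.eye(4)[:, :2]  # span{e1, e2}
    q2 = jnp.stack(
        [jnp.array([jnp.cos(theta), 0.0, jnp.sin(theta), 0.0]), jnp.array([0.0, 1.0, 0.0, 0.0])],
        axis=1,
    )  # e2, plus e1 tilted 45 degrees out of the plane
    print(subspace_metrics(q1, q2))  # overlap 0.75, participation ratio 1.8, effective rank ~1.89

Identical subspaces drive subspace_overlap to 1 and fully orthogonal ones to 0, which is why the registry reports it per factor pair.
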