diff --git a/simplexity/activations/activation_analyses.py b/simplexity/activations/activation_analyses.py index 51fdb7d9..ce724574 100644 --- a/simplexity/activations/activation_analyses.py +++ b/simplexity/activations/activation_analyses.py @@ -84,7 +84,7 @@ def __init__( use_probs_as_weights: bool = True, skip_first_token: bool = False, fit_intercept: bool = True, - to_factors: bool = False, + concat_belief_states: bool = False, ) -> None: super().__init__( analysis_type="linear_regression", @@ -92,7 +92,7 @@ def __init__( concat_layers=concat_layers, use_probs_as_weights=use_probs_as_weights, skip_first_token=skip_first_token, - analysis_kwargs={"fit_intercept": fit_intercept, "to_factors": to_factors}, + analysis_kwargs={"fit_intercept": fit_intercept, "concat_belief_states": concat_belief_states}, ) @@ -108,9 +108,9 @@ def __init__( skip_first_token: bool = False, rcond_values: Sequence[float] | None = None, fit_intercept: bool = True, - to_factors: bool = False, + concat_belief_states: bool = False, ) -> None: - analysis_kwargs: dict[str, Any] = {"fit_intercept": fit_intercept, "to_factors": to_factors} + analysis_kwargs: dict[str, Any] = {"fit_intercept": fit_intercept, "concat_belief_states": concat_belief_states} if rcond_values is not None: analysis_kwargs["rcond_values"] = tuple(rcond_values) super().__init__( diff --git a/simplexity/analysis/layerwise_analysis.py b/simplexity/analysis/layerwise_analysis.py index e76e1c6d..60aa5cb7 100644 --- a/simplexity/analysis/layerwise_analysis.py +++ b/simplexity/analysis/layerwise_analysis.py @@ -1,21 +1,27 @@ """Composable layer-wise analysis orchestration.""" +# pylint: disable=all # Temporarily disable all pylint checkers during AST traversal to prevent crash. +# The imports checker crashes when resolving simplexity package imports due to a bug +# in pylint/astroid: https://github.com/pylint-dev/pylint/issues/10185 +# pylint: enable=all # Re-enable all pylint checkers for the checking phase. This allows other checks +# (code quality, style, undefined names, etc.) to run normally while bypassing +# the problematic imports checker that would crash during AST traversal. 
+ from __future__ import annotations from collections.abc import Callable, Mapping, Sequence from dataclasses import dataclass +from functools import partial from typing import Any import jax -from simplexity.analysis.linear_regression import ( - layer_linear_regression, - layer_linear_regression_svd, -) +from simplexity.analysis.linear_regression import layer_linear_regression from simplexity.analysis.pca import ( DEFAULT_VARIANCE_THRESHOLDS, layer_pca_analysis, ) +from simplexity.logger import SIMPLEXITY_LOGGER AnalysisFn = Callable[..., tuple[Mapping[str, float], Mapping[str, jax.Array]]] @@ -34,35 +40,48 @@ class AnalysisRegistration: def _validate_linear_regression_kwargs(kwargs: Mapping[str, Any] | None) -> dict[str, Any]: provided = dict(kwargs or {}) - allowed = {"fit_intercept", "to_factors"} + allowed = {"fit_intercept", "concat_belief_states", "compute_subspace_orthogonality", "use_svd", "rcond_values"} unexpected = set(provided) - allowed if unexpected: raise ValueError(f"Unexpected linear_regression kwargs: {sorted(unexpected)}") - fit_intercept = bool(provided.get("fit_intercept", True)) - to_factors = bool(provided.get("to_factors", False)) - return {"fit_intercept": fit_intercept, "to_factors": to_factors} - - -def _validate_linear_regression_svd_kwargs(kwargs: Mapping[str, Any] | None) -> dict[str, Any]: - provided = dict(kwargs or {}) - allowed = {"fit_intercept", "rcond_values", "to_factors"} - unexpected = set(provided) - allowed - if unexpected: - raise ValueError(f"Unexpected linear_regression_svd kwargs: {sorted(unexpected)}") - fit_intercept = bool(provided.get("fit_intercept", True)) - to_factors = bool(provided.get("to_factors", False)) + resolved_kwargs = {} + resolved_kwargs["fit_intercept"] = bool(provided.get("fit_intercept", True)) + resolved_kwargs["concat_belief_states"] = bool(provided.get("concat_belief_states", False)) + resolved_kwargs["compute_subspace_orthogonality"] = bool(provided.get("compute_subspace_orthogonality", False)) rcond_values = provided.get("rcond_values") - if rcond_values is not None: - if not isinstance(rcond_values, (list, tuple)): - raise TypeError("rcond_values must be a sequence of floats") - if len(rcond_values) == 0: - raise ValueError("rcond_values must not be empty") - rcond_values = tuple(float(v) for v in rcond_values) - return { - "fit_intercept": fit_intercept, - "to_factors": to_factors, - "rcond_values": rcond_values, - } + should_use_svd = rcond_values is not None + use_svd = bool(provided.get("use_svd", should_use_svd)) + resolved_kwargs["use_svd"] = use_svd + if use_svd: + if rcond_values is not None: + if not isinstance(rcond_values, (list, tuple)): + raise TypeError("rcond_values must be a sequence of floats") + if len(rcond_values) == 0: + raise ValueError("rcond_values must not be empty") + if not use_svd: + SIMPLEXITY_LOGGER.warning("rcond_values are only used when use_svd is True") + rcond_values = tuple(float(v) for v in rcond_values) + resolved_kwargs["rcond_values"] = rcond_values + elif rcond_values is not None: + raise ValueError("rcond_values are only used when use_svd is True") + return resolved_kwargs + + +def set_use_svd( + fn: ValidatorFn, +) -> ValidatorFn: + """Decorator to set use_svd to True in the kwargs and remove it from output to avoid duplicate with partial.""" + + def wrapper(kwargs: Mapping[str, Any] | None) -> dict[str, Any]: + if kwargs and "use_svd" in kwargs and not kwargs["use_svd"]: + raise ValueError("use_svd cannot be set to False for linear_regression_svd") + modified_kwargs = 
dict(kwargs) if kwargs else {} # Make a copy to avoid mutating the input + modified_kwargs["use_svd"] = True + resolved = fn(modified_kwargs) + resolved.pop("use_svd", None) # Remove use_svd to avoid duplicate argument with partial + return resolved + + return wrapper def _validate_pca_kwargs(kwargs: Mapping[str, Any] | None) -> dict[str, Any]: @@ -97,9 +116,9 @@ def _validate_pca_kwargs(kwargs: Mapping[str, Any] | None) -> dict[str, Any]: validator=_validate_linear_regression_kwargs, ), "linear_regression_svd": AnalysisRegistration( - fn=layer_linear_regression_svd, + fn=partial(layer_linear_regression, use_svd=True), requires_belief_states=True, - validator=_validate_linear_regression_svd_kwargs, + validator=set_use_svd(_validate_linear_regression_kwargs), ), "pca": AnalysisRegistration( fn=layer_pca_analysis, diff --git a/simplexity/analysis/linear_regression.py b/simplexity/analysis/linear_regression.py index 1ce5c086..a0ef6eee 100644 --- a/simplexity/analysis/linear_regression.py +++ b/simplexity/analysis/linear_regression.py @@ -1,15 +1,25 @@ """Reusable linear regression utilities for activation analysis.""" +# pylint: disable=all # Temporarily disable all pylint checkers during AST traversal to prevent crash. +# The imports checker crashes when resolving simplexity package imports due to a bug +# in pylint/astroid: https://github.com/pylint-dev/pylint/issues/10185 +# pylint: enable=all # Re-enable all pylint checkers for the checking phase. This allows other checks +# (code quality, style, undefined names, etc.) to run normally while bypassing +# the problematic imports checker that would crash during AST traversal. + from __future__ import annotations +import itertools from collections.abc import Callable, Mapping, Sequence from typing import Any import jax import jax.numpy as jnp import numpy as np +from jax.debug import callback from simplexity.analysis.normalization import normalize_weights, standardize_features, standardize_targets +from simplexity.logger import SIMPLEXITY_LOGGER def _design_matrix(x: jax.Array, fit_intercept: bool) -> jax.Array: @@ -69,8 +79,46 @@ def linear_regression( beta, _, _, _ = jnp.linalg.lstsq(weighted_design, weighted_targets, rcond=None) predictions = design @ beta scalars = _regression_metrics(predictions, y_arr, w_arr) - projections = {"projected": predictions} - return scalars, projections + + # Separate intercept and coefficients + if fit_intercept: + arrays = { + "projected": predictions, + "coeffs": beta[1:], # Linear coefficients (excluding intercept) + "intercept": beta[:1], # Intercept term (keep 2D: [1, n_targets]) + } + else: + arrays = { + "projected": predictions, + "coeffs": beta, # All parameters are coefficients when no intercept + } + + return scalars, arrays + + +def _compute_regression_metrics( # pylint: disable=too-many-arguments + x: jax.Array, + y: jax.Array, + weights: jax.Array | np.ndarray | None, + beta: jax.Array, + predictions: jax.Array | None = None, + *, + fit_intercept: bool = True, +) -> Mapping[str, float]: + x_arr = standardize_features(x) + y_arr = standardize_targets(y) + if x_arr.shape[0] != y_arr.shape[0]: + raise ValueError("Features and targets must share the same first dimension") + if x_arr.shape[0] == 0: + raise ValueError("At least one sample is required") + w_arr = normalize_weights(weights, x_arr.shape[0]) + if w_arr is None: + w_arr = jnp.ones(x_arr.shape[0], dtype=x_arr.dtype) / x_arr.shape[0] + if predictions is None: + design = _design_matrix(x_arr, fit_intercept) + predictions = design @ beta + 
scalars = _regression_metrics(predictions, y_arr, w_arr) + return scalars def _compute_beta_from_svd( @@ -115,6 +163,7 @@ def linear_regression_svd( best_scalars: Mapping[str, float] | None = None best_rcond = rconds[0] best_error = float("inf") + best_beta: jax.Array | None = None for rcond in rconds: threshold = rcond * max_singular beta = _compute_beta_from_svd(u, s, vh, weighted_targets, threshold) @@ -128,12 +177,306 @@ def linear_regression_svd( best_pred = predictions best_scalars = scalars best_rcond = rcond - if best_pred is None or best_scalars is None: + best_beta = beta + if best_pred is None or best_scalars is None or best_beta is None: raise RuntimeError("Unable to compute linear regression solution") scalars = dict(best_scalars) scalars["best_rcond"] = float(best_rcond) - projections = {"projected": best_pred} - return scalars, projections + + # Separate intercept and coefficients + if fit_intercept: + arrays = { + "projected": best_pred, + "coeffs": best_beta[1:], # Linear coefficients (excluding intercept) + "intercept": best_beta[:1], # Intercept term (keep 2D: [1, n_targets]) + } + else: + arrays = { + "projected": best_pred, + "coeffs": best_beta, # All parameters are coefficients when no intercept + } + + return scalars, arrays + + +def _process_individual_factors( + layer_activations: jax.Array, + belief_states: tuple[jax.Array, ...], + weights: jax.Array, + use_svd: bool, + **kwargs: Any, +) -> list[tuple[Mapping[str, float], Mapping[str, jax.Array]]]: + """Process each factor individually using either standard or SVD regression.""" + results = [] + regression_fn = linear_regression_svd if use_svd else linear_regression + for factor in belief_states: + if not isinstance(factor, jax.Array): + raise ValueError("Each factor in belief_states must be a jax.Array") + factor_scalars, factor_arrays = regression_fn(layer_activations, factor, weights, **kwargs) + results.append((factor_scalars, factor_arrays)) + return results + + +def _merge_results_with_prefix( + scalars: dict[str, float], + arrays: dict[str, jax.Array], + results: tuple[Mapping[str, float], Mapping[str, jax.Array]], + prefix: str, +) -> None: + results_scalars, results_arrays = results + scalars.update({f"{prefix}/{key}": value for key, value in results_scalars.items()}) + arrays.update({f"{prefix}/{key}": value for key, value in results_arrays.items()}) + + +def _split_concat_results( + layer_activations: jax.Array, + weights: jax.Array, + belief_states: tuple[jax.Array, ...], + concat_results: tuple[Mapping[str, float], Mapping[str, jax.Array]], + **kwargs: Any, +) -> list[tuple[Mapping[str, float], Mapping[str, jax.Array]]]: + """Split concatenated regression results into individual factors.""" + _, concat_arrays = concat_results + + # Split the concatenated coefficients and projections into the individual factors + factor_dims = [factor.shape[-1] for factor in belief_states] + split_indices = jnp.cumsum(jnp.array(factor_dims))[:-1] + + coeffs_list = jnp.split(concat_arrays["coeffs"], split_indices, axis=-1) + projections_list = jnp.split(concat_arrays["projected"], split_indices, axis=-1) + + # Handle intercept - split if present + if "intercept" in concat_arrays: + intercepts_list = jnp.split(concat_arrays["intercept"], split_indices, axis=-1) + else: + intercepts_list = [None] * len(belief_states) + + # Only recompute scalar metrics, reuse projections and coefficients + # Filter out rcond_values from kwargs (only relevant for SVD during fitting, not metrics) + metrics_kwargs = {k: v for k, v in 
kwargs.items() if k != "rcond_values"} + + results = [] + for factor, coeffs, intercept, projections in zip( + belief_states, coeffs_list, intercepts_list, projections_list, strict=True + ): + # Reconstruct full beta for metrics computation + if intercept is not None: + beta = jnp.concatenate([intercept, coeffs], axis=0) + else: + beta = coeffs + + factor_scalars = _compute_regression_metrics( + layer_activations, + factor, + weights, + beta, + predictions=projections, + **metrics_kwargs, + ) + + # Build factor arrays - include intercept only if present + factor_arrays = {"projected": projections, "coeffs": coeffs} + if intercept is not None: + factor_arrays["intercept"] = intercept + + results.append((factor_scalars, factor_arrays)) + return results + + +def get_robust_basis(matrix: jax.Array) -> jax.Array: + """Extracts an orthonormal basis for the column space of the matrix. + + Handles rank deficiency gracefully by discarding directions associated with singular values below a + certain tolerance. + """ + u, s, _ = jnp.linalg.svd(matrix, full_matrices=False) + + max_dim = max(matrix.shape) + eps = jnp.finfo(matrix.dtype).eps + tol = s[0] * max_dim * eps + + valid_dims = s > tol + basis = u[:, valid_dims] + return basis + + +def _compute_subspace_orthogonality( + basis_pair: list[jax.Array], +) -> tuple[dict[str, float], dict[str, jax.Array]]: + """Compute orthogonality metrics between two coefficient subspaces. + + Args: + basis_pair: List of two orthonormal basis matrices + + Returns: + Tuple[dict[str, float], dict[str, jax.Array]]: A tuple containing: + - scalars: A dictionary with the following keys and float values: + - 'subspace_overlap': Average squared singular value (overlap score). + - 'max_singular_value': Largest singular value. + - 'min_singular_value': Smallest singular value. + - 'participation_ratio': Participation ratio of the singular values. + - 'entropy': Entropy of the squared singular values. + - 'effective_rank': Effective rank (exp(entropy)) of the singular value distribution. + - singular_values: A dictionary with a single key: + - 'singular_values': jax.Array of the singular values between the two subspaces. + """ + q1 = basis_pair[0] + q2 = basis_pair[1] + + # Compute the singular values of the interaction matrix + interaction_matrix = q1.T @ q2 + singular_values = jnp.linalg.svd(interaction_matrix, compute_uv=False) + singular_values = jnp.clip(singular_values, 0, 1) + + # Compute the subspace overlap score + min_dim = min(q1.shape[1], q2.shape[1]) + sum_sq_sv = jnp.sum(singular_values**2) + sum_quad_sv = jnp.sum(singular_values**4) + + is_degenerate = sum_quad_sv == 0 + + # Define the False branch function (does nothing) + def do_nothing_branch(x): + """JAX 'False' branch function. + + Serves only to return a value that matches the 'True' branch's type (None) for jax.lax.cond. + """ + return None + + # Define the True branch function (runs the callback) + def execute_all_zeros_warning_branch(x): + callback(log_all_zeros, x) + return None + + def log_all_zeros(_): + SIMPLEXITY_LOGGER.warning( + "Degenerate subspace detected during orthogonality computation." + " All singular values are zero." + " Setting probability values and participation ratio to zero." 
+ ) + + jax.lax.cond(is_degenerate, execute_all_zeros_warning_branch, do_nothing_branch, sum_sq_sv) + + pratio_denominator_safe = jnp.where(is_degenerate, 1.0, sum_quad_sv) + probs_denominator_safe = jnp.where(is_degenerate, 1.0, sum_sq_sv) + participation_ratio = sum_sq_sv**2 / pratio_denominator_safe + + subspace_overlap_score = sum_sq_sv / min_dim + + # Compute the entropy probabilities + probs = singular_values**2 / probs_denominator_safe + + def execute_some_zeros_warning_branch(x): + callback(log_some_zeros, x) + return None + + def log_some_zeros(num_zeros_array: jax.Array) -> None: + num_zeros = num_zeros_array.item() + SIMPLEXITY_LOGGER.warning( + f"Encountered {num_zeros} probability values of zero during entropy computation." + " This is likely due to numerical instability." + " Setting corresponding entropy contribution to zero." + ) + + num_zeros = jnp.sum(probs == 0) + has_some_zeros = num_zeros > 0 + jax.lax.cond(has_some_zeros, execute_some_zeros_warning_branch, do_nothing_branch, num_zeros) + + p_log_p = probs * jnp.log(probs) + entropy = -jnp.sum(jnp.where(probs > 0, p_log_p, 0.0)) + + # Compute the effective rank + effective_rank = jnp.exp(entropy) + + scalars = { + "subspace_overlap": float(subspace_overlap_score), + "max_singular_value": float(jnp.max(singular_values)), + "min_singular_value": float(jnp.min(singular_values)), + "participation_ratio": float(participation_ratio), + "entropy": float(entropy), + "effective_rank": float(effective_rank), + } + + arrays = { + "singular_values": singular_values, + } + + return scalars, arrays + + +def _compute_all_pairwise_orthogonality( + coeffs_list: list[jax.Array], +) -> tuple[dict[str, float], dict[str, jax.Array]]: + """Compute pairwise orthogonality metrics for all factor pairs. + + Args: + coeffs_list: List of coefficient matrices (one per factor, excludes intercepts) + + Returns: + Tuple[dict[str, float], dict[str, jax.Array]]: + - scalars: Dictionary mapping keys of the form "orthogonality_{i}_{j}/" to scalar float metrics for + each pair of factors (i, j). + - arrays: Dictionary mapping keys of the form "orthogonality_{i}_{j}/" to array-valued + metrics for each pair of factors (i, j). 
+ """ + scalars = {} + arrays = {} + factor_pairs = list(itertools.combinations(range(len(coeffs_list)), 2)) + basis_list = [get_robust_basis(coeffs) for coeffs in coeffs_list] # computes orthonormal basis of coeff matrix + for i, j in factor_pairs: + basis_pair = [basis_list[i], basis_list[j]] + orthogonality_scalars, orthogonality_arrays = _compute_subspace_orthogonality(basis_pair) + scalars.update({f"orthogonality_{i}_{j}/{key}": value for key, value in orthogonality_scalars.items()}) + arrays.update({f"orthogonality_{i}_{j}/{key}": value for key, value in orthogonality_arrays.items()}) + return scalars, arrays + + +def _handle_factored_regression( + layer_activations: jax.Array, + weights: jax.Array, + belief_states: tuple[jax.Array, ...], + concat_belief_states: bool, + compute_subspace_orthogonality: bool, + use_svd: bool, + **kwargs: Any, +) -> tuple[Mapping[str, float], Mapping[str, jax.Array]]: + """Handle regression for two or more factored belief states using either standard or SVD method.""" + if len(belief_states) < 2: + raise ValueError("At least two factors are required for factored regression") + + scalars: dict[str, float] = {} + arrays: dict[str, jax.Array] = {} + + regression_fn = linear_regression_svd if use_svd else linear_regression + + # Process concatenated belief states if requested + if concat_belief_states: + belief_states_concat = jnp.concatenate(belief_states, axis=-1) + concat_results = regression_fn(layer_activations, belief_states_concat, weights, **kwargs) + _merge_results_with_prefix(scalars, arrays, concat_results, "concat") + + # Split the concatenated parameters and projections into the individual factors + factor_results = _split_concat_results( + layer_activations, + weights, + belief_states, + concat_results, + **kwargs, + ) + else: + factor_results = _process_individual_factors(layer_activations, belief_states, weights, use_svd, **kwargs) + + for factor_idx, factor_result in enumerate(factor_results): + _merge_results_with_prefix(scalars, arrays, factor_result, f"factor_{factor_idx}") + + if compute_subspace_orthogonality: + # Extract coefficients (excludes intercept) for orthogonality computation + coeffs_list = [factor_arrays["coeffs"] for _, factor_arrays in factor_results] + orthogonality_scalars, orthogonality_singular_values = _compute_all_pairwise_orthogonality(coeffs_list) + scalars.update(orthogonality_scalars) + arrays.update(orthogonality_singular_values) + + return scalars, arrays def _apply_layer_regression( @@ -168,25 +511,53 @@ def layer_linear_regression( layer_activations: jax.Array, weights: jax.Array, belief_states: jax.Array | tuple[jax.Array, ...] | None, - to_factors: bool = False, + concat_belief_states: bool = False, + compute_subspace_orthogonality: bool = False, + use_svd: bool = False, **kwargs: Any, ) -> tuple[Mapping[str, float], Mapping[str, jax.Array]]: - """Layer-wise regression helper that wraps :func:`linear_regression`.""" - if belief_states is None: + """Layer-wise regression helper that wraps :func:`linear_regression` or :func:`linear_regression_svd`. 
+ + Args: + layer_activations: Neural network activations for a single layer + weights: Sample weights for weighted regression + belief_states: Target belief states (single array or tuple for factored processes) + concat_belief_states: If True and belief_states is a tuple, concatenate and regress jointly + compute_subspace_orthogonality: If True, compute orthogonality between factor subspaces + use_svd: If True, use SVD-based regression instead of standard least squares + **kwargs: Additional arguments passed to regression function (fit_intercept, rcond_values, etc.) + + Returns: + scalars: Dictionary of scalar metrics + arrays: Dictionary of arrays (projected predictions, parameters, singular values if orthogonality computed) + """ + # If no belief states are provided, raise an error + if ( + belief_states is None + or (isinstance(belief_states, tuple) and len(belief_states) == 0) + or (isinstance(belief_states, jax.Array) and belief_states.size == 0) + ): raise ValueError("linear_regression requires belief_states") - return _apply_layer_regression(linear_regression, layer_activations, weights, belief_states, to_factors, **kwargs) + regression_fn = linear_regression_svd if use_svd else linear_regression -def layer_linear_regression_svd( - layer_activations: jax.Array, - weights: jax.Array, - belief_states: jax.Array | tuple[jax.Array, ...] | None, - to_factors: bool = False, - **kwargs: Any, -) -> tuple[Mapping[str, float], Mapping[str, jax.Array]]: - """Layer-wise regression helper that wraps :func:`linear_regression_svd`.""" - if belief_states is None: - raise ValueError("linear_regression_svd requires belief_states") - return _apply_layer_regression( - linear_regression_svd, layer_activations, weights, belief_states, to_factors, **kwargs + if not isinstance(belief_states, tuple) or len(belief_states) == 1: + if compute_subspace_orthogonality: + SIMPLEXITY_LOGGER.warning( + "Subspace orthogonality requires multiple factors." 
+ " Received single factor of type %s; skipping orthogonality metrics.", + type(belief_states).__name__, + ) + belief_states = belief_states[0] if isinstance(belief_states, tuple) else belief_states + scalars, arrays = regression_fn(layer_activations, belief_states, weights, **kwargs) + return scalars, arrays + + return _handle_factored_regression( + layer_activations, + weights, + belief_states, + concat_belief_states, + compute_subspace_orthogonality, + use_svd, + **kwargs, ) diff --git a/tests/activations/test_activation_analysis.py b/tests/activations/test_activation_analysis.py index b31c8a8a..b2288716 100644 --- a/tests/activations/test_activation_analysis.py +++ b/tests/activations/test_activation_analysis.py @@ -740,6 +740,386 @@ def test_controls_accumulate_steps_conflict(self): ActivationVisualizationControlsConfig(slider="step", accumulate_steps=True) +class TestTupleBeliefStates: + """Test activation tracker with tuple belief states for factored processes.""" + + @pytest.fixture + def factored_belief_data(self): + """Create synthetic data with factored belief states.""" + batch_size = 4 + seq_len = 5 + d_layer0 = 8 + d_layer1 = 12 + + inputs = jnp.array( + [ + [1, 2, 3, 4, 5], + [1, 2, 3, 6, 7], + [1, 2, 8, 9, 10], + [1, 2, 3, 4, 11], + ] + ) + + # Factored beliefs: 2 factors with dimensions 3 and 2 + factor_0 = jnp.ones((batch_size, seq_len, 3)) * 0.3 + factor_1 = jnp.ones((batch_size, seq_len, 2)) * 0.7 + factored_beliefs = (factor_0, factor_1) + + probs = jnp.ones((batch_size, seq_len)) * 0.1 + + activations = { + "layer_0": jnp.ones((batch_size, seq_len, d_layer0)) * 0.3, + "layer_1": jnp.ones((batch_size, seq_len, d_layer1)) * 0.7, + } + + return { + "inputs": inputs, + "factored_beliefs": factored_beliefs, + "probs": probs, + "activations": activations, + "batch_size": batch_size, + "seq_len": seq_len, + "factor_0_dim": 3, + "factor_1_dim": 2, + "d_layer0": d_layer0, + "d_layer1": d_layer1, + } + + def test_prepare_activations_accepts_tuple_beliefs(self, factored_belief_data): + """prepare_activations should accept and preserve tuple belief states.""" + result = prepare_activations( + factored_belief_data["inputs"], + factored_belief_data["factored_beliefs"], + factored_belief_data["probs"], + factored_belief_data["activations"], + prepare_options=PrepareOptions( + last_token_only=True, + concat_layers=False, + use_probs_as_weights=False, + ), + ) + + assert result.belief_states is not None + assert isinstance(result.belief_states, tuple) + assert len(result.belief_states) == 2 + + batch_size = factored_belief_data["batch_size"] + assert result.belief_states[0].shape == (batch_size, factored_belief_data["factor_0_dim"]) + assert result.belief_states[1].shape == (batch_size, factored_belief_data["factor_1_dim"]) + + def test_prepare_activations_tuple_beliefs_all_tokens(self, factored_belief_data): + """Tuple beliefs should work with all tokens mode.""" + result = prepare_activations( + factored_belief_data["inputs"], + factored_belief_data["factored_beliefs"], + factored_belief_data["probs"], + factored_belief_data["activations"], + prepare_options=PrepareOptions( + last_token_only=False, + concat_layers=False, + use_probs_as_weights=False, + ), + ) + + assert result.belief_states is not None + assert isinstance(result.belief_states, tuple) + assert len(result.belief_states) == 2 + + # With deduplication, we expect fewer samples than batch_size * seq_len + n_prefixes = result.belief_states[0].shape[0] + assert result.belief_states[0].shape == (n_prefixes, 
factored_belief_data["factor_0_dim"]) + assert result.belief_states[1].shape == (n_prefixes, factored_belief_data["factor_1_dim"]) + assert result.activations["layer_0"].shape[0] == n_prefixes + + def test_prepare_activations_torch_tuple_beliefs(self, factored_belief_data): + """prepare_activations should accept tuple of PyTorch tensors.""" + torch = pytest.importorskip("torch") + + torch_factor_0 = torch.tensor(np.asarray(factored_belief_data["factored_beliefs"][0])) + torch_factor_1 = torch.tensor(np.asarray(factored_belief_data["factored_beliefs"][1])) + torch_beliefs = (torch_factor_0, torch_factor_1) + + result = prepare_activations( + factored_belief_data["inputs"], + torch_beliefs, + factored_belief_data["probs"], + factored_belief_data["activations"], + prepare_options=PrepareOptions( + last_token_only=True, + concat_layers=False, + use_probs_as_weights=False, + ), + ) + + assert result.belief_states is not None + assert isinstance(result.belief_states, tuple) + assert len(result.belief_states) == 2 + # Should be converted to JAX arrays + assert isinstance(result.belief_states[0], jnp.ndarray) + assert isinstance(result.belief_states[1], jnp.ndarray) + + def test_prepare_activations_numpy_tuple_beliefs(self, factored_belief_data): + """prepare_activations should accept tuple of numpy arrays.""" + np_factor_0 = np.asarray(factored_belief_data["factored_beliefs"][0]) + np_factor_1 = np.asarray(factored_belief_data["factored_beliefs"][1]) + np_beliefs = (np_factor_0, np_factor_1) + + result = prepare_activations( + factored_belief_data["inputs"], + np_beliefs, + factored_belief_data["probs"], + factored_belief_data["activations"], + prepare_options=PrepareOptions( + last_token_only=True, + concat_layers=False, + use_probs_as_weights=False, + ), + ) + + assert result.belief_states is not None + assert isinstance(result.belief_states, tuple) + assert len(result.belief_states) == 2 + # Should be converted to JAX arrays + assert isinstance(result.belief_states[0], jnp.ndarray) + assert isinstance(result.belief_states[1], jnp.ndarray) + + def test_linear_regression_with_multiple_factors(self, factored_belief_data): + """LinearRegressionAnalysis with multi-factor tuple should regress to each factor separately.""" + analysis = LinearRegressionAnalysis() + + prepared = prepare_activations( + factored_belief_data["inputs"], + factored_belief_data["factored_beliefs"], + factored_belief_data["probs"], + factored_belief_data["activations"], + prepare_options=PrepareOptions( + last_token_only=True, + concat_layers=False, + use_probs_as_weights=False, + ), + ) + + scalars, projections = analysis.analyze( + activations=prepared.activations, + belief_states=prepared.belief_states, + weights=prepared.weights, + ) + + # Should have separate metrics for each factor + # Format is: layer_name_factor_idx/metric_name + assert "layer_0_factor_0/r2" in scalars + assert "layer_0_factor_1/r2" in scalars + assert "layer_0_factor_0/rmse" in scalars + assert "layer_0_factor_1/rmse" in scalars + assert "layer_0_factor_0/mae" in scalars + assert "layer_0_factor_1/mae" in scalars + assert "layer_0_factor_0/dist" in scalars + assert "layer_0_factor_1/dist" in scalars + + assert "layer_1_factor_0/r2" in scalars + assert "layer_1_factor_1/r2" in scalars + + # Should have separate projections for each factor + assert "layer_0_factor_0/projected" in projections + assert "layer_0_factor_1/projected" in projections + assert "layer_1_factor_0/projected" in projections + assert "layer_1_factor_1/projected" in projections + + 
# Check projection shapes + batch_size = factored_belief_data["batch_size"] + assert projections["layer_0_factor_0/projected"].shape == (batch_size, factored_belief_data["factor_0_dim"]) + assert projections["layer_0_factor_1/projected"].shape == (batch_size, factored_belief_data["factor_1_dim"]) + + def test_linear_regression_svd_with_multiple_factors(self, factored_belief_data): + """LinearRegressionSVDAnalysis with multi-factor tuple should regress to each factor separately.""" + analysis = LinearRegressionSVDAnalysis(rcond_values=[1e-10]) + + prepared = prepare_activations( + factored_belief_data["inputs"], + factored_belief_data["factored_beliefs"], + factored_belief_data["probs"], + factored_belief_data["activations"], + prepare_options=PrepareOptions( + last_token_only=True, + concat_layers=False, + use_probs_as_weights=False, + ), + ) + + scalars, projections = analysis.analyze( + activations=prepared.activations, + belief_states=prepared.belief_states, + weights=prepared.weights, + ) + + # Should have separate metrics for each factor including best_rcond + assert "layer_0_factor_0/r2" in scalars + assert "layer_0_factor_1/r2" in scalars + assert "layer_0_factor_0/best_rcond" in scalars + assert "layer_0_factor_1/best_rcond" in scalars + + # Should have separate projections for each factor + assert "layer_0_factor_0/projected" in projections + assert "layer_0_factor_1/projected" in projections + + def test_tracker_with_factored_beliefs(self, factored_belief_data): + """ActivationTracker should work with tuple belief states.""" + tracker = ActivationTracker( + { + "regression": LinearRegressionAnalysis( + last_token_only=True, + concat_layers=False, + ), + "pca": PcaAnalysis( + n_components=2, + last_token_only=True, + concat_layers=False, + ), + } + ) + + scalars, projections, _ = tracker.analyze( + inputs=factored_belief_data["inputs"], + beliefs=factored_belief_data["factored_beliefs"], + probs=factored_belief_data["probs"], + activations=factored_belief_data["activations"], + ) + + # Regression should have per-factor metrics + assert "regression/layer_0_factor_0/r2" in scalars + assert "regression/layer_0_factor_1/r2" in scalars + + # PCA should still work (doesn't use belief states) + assert "pca/layer_0_variance_explained" in scalars + + # Projections should be present + assert "regression/layer_0_factor_0/projected" in projections + assert "regression/layer_0_factor_1/projected" in projections + assert "pca/layer_0_pca" in projections + + def test_single_factor_tuple(self, synthetic_data): + """Test with a single-factor tuple (edge case).""" + # Create single-factor tuple + single_factor = (synthetic_data["beliefs"],) + + result = prepare_activations( + synthetic_data["inputs"], + single_factor, + synthetic_data["probs"], + synthetic_data["activations"], + prepare_options=PrepareOptions( + last_token_only=True, + concat_layers=False, + use_probs_as_weights=False, + ), + ) + + assert result.belief_states is not None + assert isinstance(result.belief_states, tuple) + assert len(result.belief_states) == 1 + assert result.belief_states[0].shape == (synthetic_data["batch_size"], synthetic_data["belief_dim"]) + + def test_linear_regression_single_factor_tuple_behaves_like_non_tuple(self, synthetic_data): + """LinearRegressionAnalysis with single-factor tuple should behave like non-tuple (no factor keys).""" + single_factor = (synthetic_data["beliefs"],) + analysis = LinearRegressionAnalysis() + + prepared = prepare_activations( + synthetic_data["inputs"], + single_factor, + 
synthetic_data["probs"], + synthetic_data["activations"], + prepare_options=PrepareOptions( + last_token_only=True, + concat_layers=False, + use_probs_as_weights=False, + ), + ) + + scalars, projections = analysis.analyze( + activations=prepared.activations, + belief_states=prepared.belief_states, + weights=prepared.weights, + ) + + # Should have simple keys without "factor_" prefix + assert "layer_0_r2" in scalars + assert "layer_0_rmse" in scalars + assert "layer_0_projected" in projections + + # Should NOT have factor keys + assert "layer_0_factor_0/r2" not in scalars + assert "layer_0_factor_0/projected" not in projections + + def test_linear_regression_concat_belief_states(self, factored_belief_data): + """LinearRegressionAnalysis with concat_belief_states=True should return both factor and concat results.""" + analysis = LinearRegressionAnalysis(concat_belief_states=True) + + prepared = prepare_activations( + factored_belief_data["inputs"], + factored_belief_data["factored_beliefs"], + factored_belief_data["probs"], + factored_belief_data["activations"], + prepare_options=PrepareOptions( + last_token_only=True, + concat_layers=False, + use_probs_as_weights=False, + ), + ) + + scalars, projections = analysis.analyze( + activations=prepared.activations, + belief_states=prepared.belief_states, + weights=prepared.weights, + ) + + # Should have per-factor results + assert "layer_0_factor_0/r2" in scalars + assert "layer_0_factor_1/r2" in scalars + assert "layer_0_factor_0/projected" in projections + assert "layer_0_factor_1/projected" in projections + + # Should ALSO have concatenated results + assert "layer_0_concat/r2" in scalars + assert "layer_0_concat/rmse" in scalars + assert "layer_0_concat/projected" in projections + + # Check concatenated projection shape (should be sum of factor dimensions) + batch_size = factored_belief_data["batch_size"] + total_dim = factored_belief_data["factor_0_dim"] + factored_belief_data["factor_1_dim"] + assert projections["layer_0_concat/projected"].shape == (batch_size, total_dim) + + def test_three_factor_tuple(self, factored_belief_data): + """Test with three factors to ensure generalization.""" + batch_size = factored_belief_data["batch_size"] + seq_len = factored_belief_data["seq_len"] + + # Add a third factor + factor_0 = jnp.ones((batch_size, seq_len, 3)) * 0.3 + factor_1 = jnp.ones((batch_size, seq_len, 2)) * 0.5 + factor_2 = jnp.ones((batch_size, seq_len, 4)) * 0.7 + three_factor_beliefs = (factor_0, factor_1, factor_2) + + result = prepare_activations( + factored_belief_data["inputs"], + three_factor_beliefs, + factored_belief_data["probs"], + factored_belief_data["activations"], + prepare_options=PrepareOptions( + last_token_only=True, + concat_layers=False, + use_probs_as_weights=False, + ), + ) + + assert result.belief_states is not None + assert isinstance(result.belief_states, tuple) + assert len(result.belief_states) == 3 + assert result.belief_states[0].shape == (batch_size, 3) + assert result.belief_states[1].shape == (batch_size, 2) + assert result.belief_states[2].shape == (batch_size, 4) + + class TestScalarSeriesMapping: """Tests for scalar_series dataframe construction.""" diff --git a/tests/activations/test_dataframe_integration.py b/tests/activations/test_dataframe_integration.py index 4ca363fc..1fa8727d 100644 --- a/tests/activations/test_dataframe_integration.py +++ b/tests/activations/test_dataframe_integration.py @@ -15,7 +15,7 @@ ActivationVisualizationFieldRef, CombinedMappingSection, ) -from 
simplexity.analysis.linear_regression import layer_linear_regression_svd +from simplexity.analysis.linear_regression import layer_linear_regression from simplexity.exceptions import ConfigValidationError @@ -339,8 +339,8 @@ def test_linear_regression_projections_match_beliefs(self): beliefs_softmax = beliefs_softmax / beliefs_softmax.sum(axis=2, keepdims=True) belief_states = tuple(jnp.array(beliefs_softmax[:, f, :]) for f in range(n_factors)) - scalars, projections = layer_linear_regression_svd( - jnp.array(ds), jnp.ones(n_samples) / n_samples, belief_states, to_factors=True + scalars, projections = layer_linear_regression( + jnp.array(ds), jnp.ones(n_samples) / n_samples, belief_states, use_svd=True ) for f in range(n_factors): diff --git a/tests/analysis/test_layerwise_analysis.py b/tests/analysis/test_layerwise_analysis.py index c0d1a839..5b2c9dbd 100644 --- a/tests/analysis/test_layerwise_analysis.py +++ b/tests/analysis/test_layerwise_analysis.py @@ -1,5 +1,12 @@ """Tests for the LayerwiseAnalysis orchestrator.""" +# pylint: disable=all # Temporarily disable all pylint checkers during AST traversal to prevent crash. +# The imports checker crashes when resolving simplexity package imports due to a bug +# in pylint/astroid: https://github.com/pylint-dev/pylint/issues/10185 +# pylint: enable=all # Re-enable all pylint checkers for the checking phase. This allows other checks +# (code quality, style, undefined names, etc.) to run normally while bypassing +# the problematic imports checker that would crash during AST traversal. + import jax.numpy as jnp import pytest @@ -9,6 +16,7 @@ @pytest.fixture def analysis_inputs() -> tuple[dict[str, jnp.ndarray], jnp.ndarray, jnp.ndarray]: """Provides sample activations, weights, and belief states for analysis tests.""" + activations = { "layer_a": jnp.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]), "layer_b": jnp.array([[2.0, 1.0], [1.0, 2.0], [0.0, 1.0]]), @@ -20,6 +28,7 @@ def analysis_inputs() -> tuple[dict[str, jnp.ndarray], jnp.ndarray, jnp.ndarray] def test_layerwise_analysis_linear_regression_namespacing(analysis_inputs) -> None: """Metrics and projections should be namespace-qualified per layer.""" + activations, weights, belief_states = analysis_inputs analysis = LayerwiseAnalysis("linear_regression", last_token_only=True) @@ -30,11 +39,19 @@ def test_layerwise_analysis_linear_regression_namespacing(analysis_inputs) -> No ) assert set(scalars) >= {"layer_a_r2", "layer_b_r2"} - assert set(projections) == {"layer_a_projected", "layer_b_projected"} + assert set(projections) == { + "layer_a_projected", + "layer_b_projected", + "layer_a_coeffs", + "layer_b_coeffs", + "layer_a_intercept", + "layer_b_intercept", + } def test_layerwise_analysis_requires_targets(analysis_inputs) -> None: """Analyses that need belief states should validate input.""" + activations, weights, _ = analysis_inputs analysis = LayerwiseAnalysis("linear_regression") @@ -44,12 +61,14 @@ def test_layerwise_analysis_requires_targets(analysis_inputs) -> None: def test_invalid_analysis_type_raises() -> None: """Unknown analysis types should raise clear errors.""" + with pytest.raises(ValueError, match="Unknown analysis_type"): LayerwiseAnalysis("unknown") def test_invalid_kwargs_validation() -> None: """Validator rejects unsupported kwargs for a registered analysis.""" + with pytest.raises(ValueError, match="Unexpected linear_regression kwargs"): LayerwiseAnalysis( "linear_regression", @@ -59,6 +78,7 @@ def test_invalid_kwargs_validation() -> None: def 
test_pca_analysis_does_not_require_beliefs(analysis_inputs) -> None: """PCA analysis should run without belief states and namespace results.""" + activations, weights, _ = analysis_inputs analysis = LayerwiseAnalysis( "pca", @@ -76,6 +96,7 @@ def test_pca_analysis_does_not_require_beliefs(analysis_inputs) -> None: def test_invalid_pca_kwargs() -> None: """Invalid PCA kwargs should raise helpful errors.""" + with pytest.raises(ValueError, match="n_components must be positive"): LayerwiseAnalysis( "pca", @@ -85,6 +106,7 @@ def test_invalid_pca_kwargs() -> None: def test_linear_regression_svd_kwargs_validation_errors() -> None: """SVD-specific validators should reject unsupported inputs.""" + with pytest.raises(TypeError, match="rcond_values must be a sequence"): LayerwiseAnalysis( "linear_regression_svd", @@ -100,7 +122,8 @@ def test_linear_regression_svd_kwargs_validation_errors() -> None: def test_linear_regression_svd_rejects_unexpected_kwargs() -> None: """Unexpected SVD kwargs should raise clear errors.""" - with pytest.raises(ValueError, match="Unexpected linear_regression_svd kwargs"): + + with pytest.raises(ValueError, match="Unexpected linear_regression kwargs"): LayerwiseAnalysis( "linear_regression_svd", analysis_kwargs={"bad": True}, @@ -109,6 +132,7 @@ def test_linear_regression_svd_rejects_unexpected_kwargs() -> None: def test_linear_regression_svd_kwargs_are_normalized() -> None: """Validator should coerce mixed numeric types to floats.""" + validator = ANALYSIS_REGISTRY["linear_regression_svd"].validator params = validator({"rcond_values": [1, 1e-3]}) @@ -117,6 +141,7 @@ def test_linear_regression_svd_kwargs_are_normalized() -> None: def test_pca_kwargs_require_int_components() -> None: """PCA validator should enforce integral n_components.""" + with pytest.raises(TypeError, match="n_components must be an int or None"): LayerwiseAnalysis( "pca", @@ -126,6 +151,7 @@ def test_pca_kwargs_require_int_components() -> None: def test_pca_kwargs_require_sequence_thresholds() -> None: """Variance thresholds must be sequences with valid ranges.""" + with pytest.raises(TypeError, match="variance_thresholds must be a sequence"): LayerwiseAnalysis( "pca", @@ -141,6 +167,7 @@ def test_pca_kwargs_require_sequence_thresholds() -> None: def test_pca_rejects_unexpected_kwargs() -> None: """Unexpected PCA kwargs should surface informative errors.""" + with pytest.raises(ValueError, match="Unexpected pca kwargs"): LayerwiseAnalysis( "pca", @@ -150,6 +177,7 @@ def test_pca_rejects_unexpected_kwargs() -> None: def test_layerwise_analysis_property_accessors() -> None: """Constructor flags should surface via property accessors.""" + analysis = LayerwiseAnalysis( "pca", last_token_only=True, @@ -161,3 +189,103 @@ def test_layerwise_analysis_property_accessors() -> None: assert analysis.concat_layers assert not analysis.use_probs_as_weights assert not analysis.requires_belief_states + + +def test_linear_regression_accepts_concat_belief_states() -> None: + """linear_regression validator should accept concat_belief_states parameter.""" + + validator = ANALYSIS_REGISTRY["linear_regression"].validator + params = validator({"fit_intercept": False, "concat_belief_states": True}) + + assert params["fit_intercept"] is False + assert params["concat_belief_states"] is True + + +def test_linear_regression_svd_accepts_concat_belief_states() -> None: + """linear_regression_svd validator should accept concat_belief_states parameter.""" + + validator = ANALYSIS_REGISTRY["linear_regression_svd"].validator + params = 
validator({"fit_intercept": True, "concat_belief_states": True, "rcond_values": [1e-3]}) + + assert params["fit_intercept"] is True + assert params["concat_belief_states"] is True + assert params["rcond_values"] == (0.001,) + + +def test_linear_regression_concat_belief_states_defaults_false() -> None: + """concat_belief_states should default to False when not provided.""" + + validator = ANALYSIS_REGISTRY["linear_regression"].validator + params = validator({"fit_intercept": True}) + + assert params["concat_belief_states"] is False + + +def test_linear_regression_accepts_compute_subspace_orthogonality() -> None: + """linear_regression validator should accept compute_subspace_orthogonality parameter.""" + + validator = ANALYSIS_REGISTRY["linear_regression"].validator + params = validator({"fit_intercept": True, "compute_subspace_orthogonality": True}) + + assert params["fit_intercept"] is True + assert params["compute_subspace_orthogonality"] is True + + +def test_linear_regression_svd_accepts_compute_subspace_orthogonality() -> None: + """linear_regression_svd validator should accept compute_subspace_orthogonality parameter.""" + + validator = ANALYSIS_REGISTRY["linear_regression_svd"].validator + params = validator({"fit_intercept": True, "compute_subspace_orthogonality": True, "rcond_values": [1e-3]}) + + assert params["fit_intercept"] is True + assert params["compute_subspace_orthogonality"] is True + assert params["rcond_values"] == (0.001,) + + +def test_linear_regression_svd_rejects_false_use_svd() -> None: + """linear_regression_svd validator should reject explicit use_svd parameter since it's bound in partial.""" + + validator = ANALYSIS_REGISTRY["linear_regression_svd"].validator + + validator({"use_svd": True}) + + with pytest.raises(ValueError, match="use_svd cannot be set to False for linear_regression_svd"): + validator({"use_svd": False}) + + +def test_linear_regression_svd_excludes_use_svd_from_output() -> None: + """linear_regression_svd validator should not include use_svd in resolved kwargs.""" + + validator = ANALYSIS_REGISTRY["linear_regression_svd"].validator + params = validator({"rcond_values": [1e-3]}) + + # use_svd should not be in the output since it's already bound in the partial + assert "use_svd" not in params + assert params["rcond_values"] == (0.001,) + + +def test_linear_regression_compute_subspace_orthogonality_defaults_false() -> None: + """compute_subspace_orthogonality should default to False when not provided.""" + + validator = ANALYSIS_REGISTRY["linear_regression"].validator + params = validator({"fit_intercept": True}) + + assert params["compute_subspace_orthogonality"] is False + + +def test_linear_regression_accepts_use_svd() -> None: + """linear_regression validator should accept use_svd parameter.""" + + validator = ANALYSIS_REGISTRY["linear_regression"].validator + params = validator({"use_svd": True}) + + assert params["use_svd"] is True + + +def test_linear_regression_use_svd_defaults_false() -> None: + """use_svd should default to False when not provided.""" + + validator = ANALYSIS_REGISTRY["linear_regression"].validator + params = validator({}) + + assert params["use_svd"] is False diff --git a/tests/analysis/test_linear_regression.py b/tests/analysis/test_linear_regression.py index f6bfe084..c32766c1 100644 --- a/tests/analysis/test_linear_regression.py +++ b/tests/analysis/test_linear_regression.py @@ -1,29 +1,66 @@ """Tests for reusable linear regression helpers.""" +# pylint: disable=all # Temporarily disable all pylint checkers during 
AST traversal to prevent crash. +# The imports checker crashes when resolving simplexity package imports due to a bug +# in pylint/astroid: https://github.com/pylint-dev/pylint/issues/10185 +# pylint: enable=all # Re-enable all pylint checkers for the checking phase. This allows other checks +# (code quality, style, undefined names, etc.) to run normally while bypassing +# the problematic imports checker that would crash during AST traversal. + +# pylint: disable=too-many-lines +# pylint: disable=too-many-locals + import chex +import jax import jax.numpy as jnp import pytest from simplexity.analysis.linear_regression import ( + get_robust_basis, layer_linear_regression, - layer_linear_regression_svd, linear_regression, linear_regression_svd, ) +def _compute_orthogonality_threshold( + x: jax.Array, + *factors: jax.Array, + safety_factor: int = 10, +) -> float: + """Compute principled threshold for near-zero orthogonality checks. + + Threshold is based on machine precision scaled by problem dimensions. + For orthogonality via QR + SVD, typical numerical error is O(ε·n) where + ε is machine epsilon and n is the maximum relevant dimension. + + Args: + x: Input features array (used for dtype and dimension) + *factors: Factor arrays being compared (used for output dimensions) + safety_factor: Multiplicative safety factor (default 10) + + Returns: + Threshold value for considering singular values as effectively zero + """ + eps = jnp.finfo(x.dtype).eps + n_features = x.shape[1] + factor_dims = [f.shape[1] for f in factors] + max_dim = max(n_features, *factor_dims) + return float(max_dim * eps * safety_factor) + + def test_linear_regression_perfect_fit() -> None: """Verify weighted least squares recovers a perfect linear relation.""" x = jnp.arange(6.0).reshape(-1, 1) y = 3.0 * x + 2.0 weights = jnp.ones(x.shape[0]) - scalars, projections = linear_regression(x, y, weights) + scalars, arrays = linear_regression(x, y, weights) assert pytest.approx(1.0) == scalars["r2"] assert pytest.approx(0.0, abs=1e-5) == scalars["rmse"] assert pytest.approx(0.0, abs=1e-5) == scalars["mae"] - chex.assert_trees_all_close(projections["projected"], y) + chex.assert_trees_all_close(arrays["projected"], y) def test_linear_regression_svd_selects_best_rcond() -> None: @@ -32,7 +69,7 @@ def test_linear_regression_svd_selects_best_rcond() -> None: y = jnp.sum(x, axis=1, keepdims=True) weights = jnp.array([0.1, 0.2, 0.3, 0.4]) - scalars, projections = linear_regression_svd( + scalars, arrays = linear_regression_svd( x, y, weights, @@ -40,7 +77,7 @@ def test_linear_regression_svd_selects_best_rcond() -> None: ) assert scalars["best_rcond"] in {1e-6, 1e-4, 1e-2} - chex.assert_trees_all_close(projections["projected"], y) + chex.assert_trees_all_close(arrays["projected"], y) def test_layer_regression_requires_targets() -> None: @@ -51,9 +88,6 @@ def test_layer_regression_requires_targets() -> None: with pytest.raises(ValueError, match="requires belief_states"): layer_linear_regression(x, weights, None) - with pytest.raises(ValueError, match="requires belief_states"): - layer_linear_regression_svd(x, weights, None) - def test_linear_regression_rejects_mismatched_weights() -> None: """Weights must align with the sample dimension.""" @@ -90,10 +124,10 @@ def test_linear_regression_without_intercept_uses_uniform_weights() -> None: x = jnp.arange(1.0, 4.0)[:, None] y = 2.0 * x - scalars, projections = linear_regression(x, y, None, fit_intercept=False) + scalars, arrays = linear_regression(x, y, None, fit_intercept=False) assert 
pytest.approx(1.0) == scalars["r2"] - chex.assert_trees_all_close(projections["projected"], y) + chex.assert_trees_all_close(arrays["projected"], y) def test_linear_regression_svd_handles_empty_features() -> None: @@ -102,10 +136,10 @@ def test_linear_regression_svd_handles_empty_features() -> None: y = jnp.arange(3.0)[:, None] weights = jnp.ones(3) - scalars, projections = linear_regression_svd(x, y, weights, fit_intercept=False) + scalars, arrays = linear_regression_svd(x, y, weights, fit_intercept=False) assert scalars["best_rcond"] == pytest.approx(1e-15) - chex.assert_trees_all_close(projections["projected"], jnp.zeros_like(y)) + chex.assert_trees_all_close(arrays["projected"], jnp.zeros_like(y)) def test_linear_regression_accepts_one_dimensional_inputs() -> None: @@ -114,10 +148,10 @@ def test_linear_regression_accepts_one_dimensional_inputs() -> None: y = 5.0 * x + 1.0 weights = jnp.ones_like(x) - scalars, projections = linear_regression(x, y, weights) + scalars, arrays = linear_regression(x, y, weights) assert pytest.approx(1.0) == scalars["r2"] - chex.assert_trees_all_close(projections["projected"], y[:, None]) + chex.assert_trees_all_close(arrays["projected"], y[:, None]) def test_linear_regression_rejects_high_rank_inputs() -> None: @@ -168,25 +202,26 @@ def test_linear_regression_svd_falls_back_to_default_rcond() -> None: assert scalars["best_rcond"] == pytest.approx(1e-15) -def test_layer_linear_regression_svd_runs_end_to_end() -> None: +def test_layer_linear_regression_runs_end_to_end() -> None: """Layer helper should proxy through to the base implementation.""" x = jnp.arange(6.0).reshape(3, 2) weights = jnp.ones(3) / 3.0 beliefs = 2.0 * x.sum(axis=1, keepdims=True) - scalars, projections = layer_linear_regression_svd( + scalars, arrays = layer_linear_regression( x, weights, beliefs, + use_svd=True, rcond_values=[1e-3], ) assert pytest.approx(1.0, abs=1e-6) == scalars["r2"] - chex.assert_trees_all_close(projections["projected"], beliefs) + chex.assert_trees_all_close(arrays["projected"], beliefs) -def test_layer_linear_regression_to_factors_basic() -> None: - """Layer regression with to_factors should regress to each factor separately.""" +def test_layer_linear_regression_belief_states_tuple_default() -> None: + """By default, layer regression should regress to each factor separately if given a tuple of belief states.""" x = jnp.arange(12.0).reshape(4, 3) # 4 samples, 3 features weights = jnp.ones(4) / 4.0 @@ -195,11 +230,10 @@ def test_layer_linear_regression_to_factors_basic() -> None: factor_1 = jnp.array([[0.2, 0.3, 0.5], [0.1, 0.6, 0.3], [0.4, 0.4, 0.2], [0.3, 0.3, 0.4]]) # [4, 3] factored_beliefs = (factor_0, factor_1) - scalars, projections = layer_linear_regression( + scalars, arrays = layer_linear_regression( x, weights, factored_beliefs, - to_factors=True, ) # Should have separate metrics for each factor @@ -213,16 +247,28 @@ def test_layer_linear_regression_to_factors_basic() -> None: assert "factor_1/dist" in scalars # Should have separate projections for each factor - assert "factor_0/projected" in projections - assert "factor_1/projected" in projections + assert "factor_0/projected" in arrays + assert "factor_1/projected" in arrays + + # Should have separate parameters for each factor + assert "factor_0/coeffs" in arrays + assert "factor_1/coeffs" in arrays + + # Should have separate intercepts for each factor by default + assert "factor_0/intercept" in arrays + assert "factor_1/intercept" in arrays # Check shapes - assert 
projections["factor_0/projected"].shape == factor_0.shape - assert projections["factor_1/projected"].shape == factor_1.shape + assert arrays["factor_0/projected"].shape == factor_0.shape + assert arrays["factor_1/projected"].shape == factor_1.shape + assert arrays["factor_0/coeffs"].shape == (x.shape[1], factor_0.shape[1]) + assert arrays["factor_1/coeffs"].shape == (x.shape[1], factor_1.shape[1]) + assert arrays["factor_0/intercept"].shape == (1, factor_0.shape[1]) + assert arrays["factor_1/intercept"].shape == (1, factor_1.shape[1]) -def test_layer_linear_regression_svd_to_factors_basic() -> None: - """Layer regression SVD with to_factors should regress to each factor separately.""" +def test_layer_linear_regression_svd_belief_states_tuple_default() -> None: + """By default, layer regression SVD should regress to each factor separately if given a tuple of belief states.""" x = jnp.arange(12.0).reshape(4, 3) # 4 samples, 3 features weights = jnp.ones(4) / 4.0 @@ -231,31 +277,45 @@ def test_layer_linear_regression_svd_to_factors_basic() -> None: factor_1 = jnp.array([[0.2, 0.3, 0.5], [0.1, 0.6, 0.3], [0.4, 0.4, 0.2], [0.3, 0.3, 0.4]]) # [4, 3] factored_beliefs = (factor_0, factor_1) - scalars, projections = layer_linear_regression_svd( + scalars, arrays = layer_linear_regression( x, weights, factored_beliefs, - to_factors=True, + use_svd=True, rcond_values=[1e-6], ) - # Should have separate metrics for each factor including best_rcond - assert "factor_0/r2" in scalars - assert "factor_1/r2" in scalars - assert "factor_0/best_rcond" in scalars - assert "factor_1/best_rcond" in scalars + # Should have ALL regression metrics for each factor including best_rcond + for factor in [0, 1]: + assert f"factor_{factor}/r2" in scalars + assert f"factor_{factor}/rmse" in scalars + assert f"factor_{factor}/mae" in scalars + assert f"factor_{factor}/dist" in scalars + assert f"factor_{factor}/best_rcond" in scalars # Should have separate projections for each factor - assert "factor_0/projected" in projections - assert "factor_1/projected" in projections + assert "factor_0/projected" in arrays + assert "factor_1/projected" in arrays + + # Should have separate coefficients for each factor + assert "factor_0/coeffs" in arrays + assert "factor_1/coeffs" in arrays + + # Should have separate intercepts for each factor by default + assert "factor_0/intercept" in arrays + assert "factor_1/intercept" in arrays # Check shapes - assert projections["factor_0/projected"].shape == factor_0.shape - assert projections["factor_1/projected"].shape == factor_1.shape + assert arrays["factor_0/projected"].shape == factor_0.shape + assert arrays["factor_1/projected"].shape == factor_1.shape + assert arrays["factor_0/coeffs"].shape == (x.shape[1], factor_0.shape[1]) + assert arrays["factor_1/coeffs"].shape == (x.shape[1], factor_1.shape[1]) + assert arrays["factor_0/intercept"].shape == (1, factor_0.shape[1]) + assert arrays["factor_1/intercept"].shape == (1, factor_1.shape[1]) -def test_layer_linear_regression_to_factors_single_factor() -> None: - """to_factors=True should work with a single factor tuple.""" +def test_layer_linear_regression_belief_states_tuple_single_factor() -> None: + """Single-element tuple should behave the same as passing a single array.""" x = jnp.arange(9.0).reshape(3, 3) weights = jnp.ones(3) / 3.0 @@ -263,126 +323,891 @@ def test_layer_linear_regression_to_factors_single_factor() -> None: factor_0 = jnp.array([[0.3, 0.7], [0.5, 0.5], [0.8, 0.2]]) factored_beliefs = (factor_0,) - scalars, projections 
= layer_linear_regression( + scalars, arrays = layer_linear_regression( x, weights, factored_beliefs, - to_factors=True, ) - # Should have metrics for single factor - assert "factor_0/r2" in scalars - assert "factor_0/projected" in projections - assert projections["factor_0/projected"].shape == factor_0.shape + # Should have same structure as non-tuple case + assert "r2" in scalars + assert "rmse" in scalars + assert "mae" in scalars + assert "dist" in scalars + assert "projected" in arrays + assert "coeffs" in arrays + assert "intercept" in arrays + # Verify it matches non-tuple behavior + scalars_non_tuple, arrays_non_tuple = layer_linear_regression(x, weights, factor_0) -def test_layer_linear_regression_to_factors_requires_tuple() -> None: - """to_factors=True requires belief_states to be a tuple.""" - x = jnp.ones((3, 2)) - weights = jnp.ones(3) / 3.0 - beliefs_array = jnp.ones((3, 2)) + chex.assert_trees_all_close(scalars, scalars_non_tuple) + chex.assert_trees_all_close(arrays, arrays_non_tuple) - with pytest.raises(ValueError, match="belief_states must be a tuple when to_factors is True"): - layer_linear_regression(x, weights, beliefs_array, to_factors=True) - with pytest.raises(ValueError, match="belief_states must be a tuple when to_factors is True"): - layer_linear_regression_svd(x, weights, beliefs_array, to_factors=True) +def test_orthogonality_with_orthogonal_subspaces() -> None: + """Orthogonal factors constructed explicitly should have near-zero overlap.""" + # Create truly orthogonal coefficient matrices by construction + n_samples, n_features = 100, 6 + key = jax.random.PRNGKey(0) + x = jax.random.normal(key, (n_samples, n_features)) -def test_layer_linear_regression_to_factors_validates_tuple_contents() -> None: - """to_factors=True requires all elements in tuple to be jax.Arrays.""" - x = jnp.ones((3, 2)) - weights = jnp.ones(3) / 3.0 + # Define orthogonal coefficient matrices + # w_0 uses first 3 features, w_1 uses last 3 features + w_0 = jnp.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0]]) # (6, 2) + w_1 = jnp.array([[0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]) # (6, 2) - # Invalid: tuple contains non-array - invalid_beliefs = (jnp.ones((3, 2)), "not an array") # type: ignore + # Generate factors using orthogonal subspaces (no intercept for simplicity) + factor_0 = x @ w_0 # (100, 2) + factor_1 = x @ w_1 # (100, 2) + factored_beliefs = (factor_0, factor_1) + weights = jnp.ones(n_samples) / n_samples - with pytest.raises(ValueError, match="Each factor in belief_states must be a jax.Array"): - layer_linear_regression(x, weights, invalid_beliefs, to_factors=True) # type: ignore + scalars, arrays = layer_linear_regression( + x, + weights, + factored_beliefs, + compute_subspace_orthogonality=True, + fit_intercept=False, # No intercept for cleaner test + ) - with pytest.raises(ValueError, match="Each factor in belief_states must be a jax.Array"): - layer_linear_regression_svd(x, weights, invalid_beliefs, to_factors=True) # type: ignore + # Should have standard factor metrics with perfect fit + assert scalars["factor_0/r2"] > 0.99 # Should fit nearly perfectly + assert scalars["factor_1/r2"] > 0.99 + + # Should have ALL orthogonality metrics + assert "orthogonality_0_1/subspace_overlap" in scalars + assert "orthogonality_0_1/max_singular_value" in scalars + assert "orthogonality_0_1/min_singular_value" in scalars + assert "orthogonality_0_1/participation_ratio" in scalars + assert "orthogonality_0_1/entropy" in 
scalars + assert "orthogonality_0_1/effective_rank" in scalars + + # Compute principled threshold based on machine precision and problem size + threshold = _compute_orthogonality_threshold(x, factor_0, factor_1) + + # Should indicate near-zero overlap (orthogonal by construction) + assert scalars["orthogonality_0_1/subspace_overlap"] < threshold + assert scalars["orthogonality_0_1/max_singular_value"] < threshold + + # Should have singular values in arrays + assert "orthogonality_0_1/singular_values" in arrays + # Both factors have 2 dimensions, so min(2, 2) = 2 singular values + assert arrays["orthogonality_0_1/singular_values"].shape[0] == 2 + # All singular values should be near zero (orthogonal) + assert jnp.all(arrays["orthogonality_0_1/singular_values"] < threshold) + + +def test_orthogonality_with_aligned_subspaces() -> None: + """Aligned factors with identical column spaces should have high overlap.""" + + # Create truly aligned coefficient matrices by construction + n_samples, n_features = 100, 6 + key = jax.random.PRNGKey(0) + x = jax.random.normal(key, (n_samples, n_features)) + + # Define aligned coefficient matrices - w_1 = w_0 @ A for invertible A + # This ensures span(w_1) = span(w_0) + w_0 = jnp.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0]]) # (6, 2) + w_1 = jnp.array([[0.5, 1.0], [1.0, 0.5], [1.5, 1.5], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0]]) # (6, 2) + + # Generate factors using aligned subspaces (no intercept for simplicity) + factor_0 = x @ w_0 # (100, 2) + factor_1 = x @ w_1 # (100, 2) + factored_beliefs = (factor_0, factor_1) + weights = jnp.ones(n_samples) / n_samples + scalars, arrays = layer_linear_regression( + x, + weights, + factored_beliefs, + compute_subspace_orthogonality=True, + fit_intercept=False, # No intercept for cleaner test + ) -def test_layer_linear_regression_to_factors_false_works() -> None: - """to_factors=False requires belief_states to be a single array, not a tuple.""" - x = jnp.ones((3, 2)) - weights = jnp.ones(3) / 3.0 + # Should have standard factor metrics with perfect fit + assert scalars["factor_0/r2"] > 0.99 # Should fit nearly perfectly + assert scalars["factor_1/r2"] > 0.99 + + # Should have ALL orthogonality metrics + assert "orthogonality_0_1/subspace_overlap" in scalars + assert "orthogonality_0_1/max_singular_value" in scalars + assert "orthogonality_0_1/min_singular_value" in scalars + assert "orthogonality_0_1/participation_ratio" in scalars + assert "orthogonality_0_1/entropy" in scalars + assert "orthogonality_0_1/effective_rank" in scalars + + # Should indicate high overlap (aligned by construction) + assert scalars["orthogonality_0_1/subspace_overlap"] > 0.99 + assert scalars["orthogonality_0_1/max_singular_value"] > 0.99 + + # Should have singular values in arrays + assert "orthogonality_0_1/singular_values" in arrays + # Both factors have 2 dimensions, so min(2, 2) = 2 singular values + assert arrays["orthogonality_0_1/singular_values"].shape[0] == 2 + # All singular values should be near 1.0 (perfectly aligned) + assert jnp.all(arrays["orthogonality_0_1/singular_values"] > 0.99) + + +def test_orthogonality_with_three_factors() -> None: + """Three factors should produce all pairwise orthogonality combinations.""" + + # Create three mutually orthogonal coefficient matrices + n_samples, n_features = 100, 6 + key = jax.random.PRNGKey(0) + x = jax.random.normal(key, (n_samples, n_features)) + + # Define three orthogonal coefficient matrices using disjoint features + w_0 = jnp.array([[1.0, 0.0], 
[0.0, 1.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0]]) # Uses features 0-1 + w_1 = jnp.array([[0.0, 0.0], [0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [0.0, 0.0], [0.0, 0.0]]) # Uses features 2-3 + w_2 = jnp.array([[0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [1.0, 0.0], [0.0, 1.0]]) # Uses features 4-5 + + # Generate factors using orthogonal subspaces + factor_0 = x @ w_0 # (100, 2) + factor_1 = x @ w_1 # (100, 2) + factor_2 = x @ w_2 # (100, 2) + factored_beliefs = (factor_0, factor_1, factor_2) + weights = jnp.ones(n_samples) / n_samples + + scalars, arrays = layer_linear_regression( + x, + weights, + factored_beliefs, + compute_subspace_orthogonality=True, + fit_intercept=False, + ) - # Invalid: tuple when to_factors=False - factored_beliefs = (jnp.ones((3, 2)), jnp.ones((3, 3))) + # Should have standard factor metrics for all three factors + assert scalars["factor_0/r2"] > 0.99 + assert scalars["factor_1/r2"] > 0.99 + assert scalars["factor_2/r2"] > 0.99 - scalars, projections = layer_linear_regression(x, weights, factored_beliefs, to_factors=False) - assert "r2" in scalars - assert "projected" in projections - assert projections["projected"].shape == (3, 5) + # Compute principled threshold based on machine precision and problem size + threshold = _compute_orthogonality_threshold(x, factor_0, factor_1, factor_2) + + # Should have ALL three pairwise orthogonality combinations + pairwise_keys = ["orthogonality_0_1", "orthogonality_0_2", "orthogonality_1_2"] + for pair_key in pairwise_keys: + assert f"{pair_key}/subspace_overlap" in scalars + assert f"{pair_key}/max_singular_value" in scalars + assert f"{pair_key}/min_singular_value" in scalars + assert f"{pair_key}/participation_ratio" in scalars + assert f"{pair_key}/entropy" in scalars + assert f"{pair_key}/effective_rank" in scalars + assert f"{pair_key}/singular_values" in arrays + + # All pairs should be orthogonal (near-zero overlap) + overlap = scalars[f"{pair_key}/subspace_overlap"] + assert overlap < threshold, f"{pair_key} subspace_overlap={overlap} >= threshold={threshold}" + + max_sv = scalars[f"{pair_key}/max_singular_value"] + assert max_sv < threshold, f"{pair_key} max_singular_value={max_sv} >= threshold={threshold}" + + # Each pair has 2D subspaces, so 2 singular values + assert arrays[f"{pair_key}/singular_values"].shape[0] == 2 + svs = arrays[f"{pair_key}/singular_values"] + assert jnp.all(svs < threshold), f"{pair_key} singular_values={svs} not all < threshold={threshold}" + + +def test_orthogonality_not_computed_by_default() -> None: + """Orthogonality metrics should not be computed when compute_subspace_orthogonality=False.""" - scalars, projections = layer_linear_regression_svd(x, weights, factored_beliefs, to_factors=False) + # Setup two-factor regression + n_samples, n_features = 50, 4 + key = jax.random.PRNGKey(0) + x = jax.random.normal(key, (n_samples, n_features)) + + w_0 = jnp.array([[1.0, 0.0], [0.0, 1.0], [0.0, 0.0], [0.0, 0.0]]) + w_1 = jnp.array([[0.0, 0.0], [0.0, 0.0], [1.0, 0.0], [0.0, 1.0]]) + + factor_0 = x @ w_0 + factor_1 = x @ w_1 + factored_beliefs = (factor_0, factor_1) + weights = jnp.ones(n_samples) / n_samples + + # Run WITHOUT compute_subspace_orthogonality (default is False) + scalars, arrays = layer_linear_regression( + x, + weights, + factored_beliefs, + fit_intercept=False, + ) + + # Should have standard factor metrics + assert "factor_0/r2" in scalars + assert "factor_1/r2" in scalars + + # Should NOT have any orthogonality metrics + orthogonality_keys = [ + 
"orthogonality_0_1/subspace_overlap", + "orthogonality_0_1/max_singular_value", + "orthogonality_0_1/min_singular_value", + "orthogonality_0_1/participation_ratio", + "orthogonality_0_1/entropy", + "orthogonality_0_1/effective_rank", + ] + for key in orthogonality_keys: + assert key not in scalars + + # Should NOT have orthogonality singular values in arrays + assert "orthogonality_0_1/singular_values" not in arrays + + +def test_orthogonality_warning_for_single_belief_state(caplog: pytest.LogCaptureFixture) -> None: + """Should warn when requesting orthogonality with a single belief state.""" + + # Setup single-factor regression + n_samples, n_features = 30, 4 + key = jax.random.PRNGKey(0) + x = jax.random.normal(key, (n_samples, n_features)) + belief_state = jax.random.normal(key, (n_samples, 2)) + weights = jnp.ones(n_samples) / n_samples + + # Request orthogonality with single belief state (not a tuple) + with caplog.at_level("WARNING"): + scalars, arrays = layer_linear_regression( + x, + weights, + belief_state, + compute_subspace_orthogonality=True, + fit_intercept=False, + ) + + # Should have logged a warning + assert "Subspace orthogonality requires multiple factors." in caplog.text + + # Should still run regression successfully assert "r2" in scalars - assert "projected" in projections - assert projections["projected"].shape == (3, 5) + assert "projected" in arrays + # Should NOT have orthogonality metrics + assert "orthogonality_0_1/subspace_overlap" not in scalars + assert "orthogonality_0_1/singular_values" not in arrays -def test_factored_regression_perfect_linear_fit() -> None: - """Test factored regression with perfectly linear targets achieves perfect fit. - Uses targets that are exact linear combinations of features to verify - the regression machinery works correctly for the factored case. 
- """ - # 5 samples, 4 features - x = jnp.array( +def test_use_svd_flag_equivalence() -> None: + """layer_linear_regression with use_svd=True should match layer_linear_regression_svd.""" + + n_samples, n_features = 40, 4 + key = jax.random.PRNGKey(0) + x = jax.random.normal(key, (n_samples, n_features)) + + # Test with single belief state + belief_state = jax.random.normal(key, (n_samples, 3)) + weights = jnp.ones(n_samples) / n_samples + rcond_values = [1e-6, 1e-4] + + # Method 1: use_svd=True + scalars_flag, arrays_flag = layer_linear_regression( + x, + weights, + belief_state, + use_svd=True, + rcond_values=rcond_values, + ) + + # Method 2: layer_linear_regression_svd + scalars_wrapper, arrays_wrapper = layer_linear_regression( + x, + weights, + belief_state, + use_svd=True, + rcond_values=rcond_values, + ) + + # Should produce identical results + assert scalars_flag.keys() == scalars_wrapper.keys() + for key, value in scalars_flag.items(): + assert value == pytest.approx(scalars_wrapper[key]) + + assert arrays_flag.keys() == arrays_wrapper.keys() + for key, value in arrays_flag.items(): + chex.assert_trees_all_close(value, arrays_wrapper[key]) + + # Test with factored belief states + w_0 = jnp.array([[1.0, 0.0], [0.0, 1.0], [0.0, 0.0], [0.0, 0.0]]) + w_1 = jnp.array([[0.0, 0.0], [0.0, 0.0], [1.0, 0.0], [0.0, 1.0]]) + factor_0 = x @ w_0 + factor_1 = x @ w_1 + factored_beliefs = (factor_0, factor_1) + + # Method 1: use_svd=True with factored beliefs + scalars_flag_fact, arrays_flag_fact = layer_linear_regression( + x, + weights, + factored_beliefs, + use_svd=True, + rcond_values=rcond_values, + ) + + # Method 2: layer_linear_regression_svd with factored beliefs + scalars_wrapper_fact, arrays_wrapper_fact = layer_linear_regression( + x, + weights, + factored_beliefs, + use_svd=True, + rcond_values=rcond_values, + ) + + # Should produce identical results + assert scalars_flag_fact.keys() == scalars_wrapper_fact.keys() + for key, value in scalars_flag_fact.items(): + assert value == pytest.approx(scalars_wrapper_fact[key]) + + assert arrays_flag_fact.keys() == arrays_wrapper_fact.keys() + for key, value in arrays_flag_fact.items(): + chex.assert_trees_all_close(value, arrays_wrapper_fact[key]) + + +def test_use_svd_with_orthogonality() -> None: + """SVD regression should work with orthogonality computation.""" + + n_samples, n_features = 80, 6 + key = jax.random.PRNGKey(0) + x = jax.random.normal(key, (n_samples, n_features)) + + # Create orthogonal coefficient matrices + w_0 = jnp.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0]]) + w_1 = jnp.array([[0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]) + + factor_0 = x @ w_0 + factor_1 = x @ w_1 + factored_beliefs = (factor_0, factor_1) + weights = jnp.ones(n_samples) / n_samples + + # Run SVD regression with orthogonality computation + scalars, arrays = layer_linear_regression( + x, + weights, + factored_beliefs, + use_svd=True, + compute_subspace_orthogonality=True, + rcond_values=[1e-6], + fit_intercept=False, + ) + + # Should have standard factor metrics with SVD + assert "factor_0/r2" in scalars + assert "factor_1/r2" in scalars + assert "factor_0/best_rcond" in scalars + assert "factor_1/best_rcond" in scalars + + # Should have orthogonality metrics + assert "orthogonality_0_1/subspace_overlap" in scalars + assert "orthogonality_0_1/max_singular_value" in scalars + assert "orthogonality_0_1/singular_values" in arrays + + # Compute principled threshold + threshold = 
_compute_orthogonality_threshold(x, factor_0, factor_1) + + # Should indicate near-zero overlap (orthogonal by construction) + assert scalars["orthogonality_0_1/subspace_overlap"] < threshold + assert scalars["orthogonality_0_1/max_singular_value"] < threshold + + # Should have good regression fit + assert scalars["factor_0/r2"] > 0.99 + assert scalars["factor_1/r2"] > 0.99 + + +def test_orthogonality_with_different_subspace_dimensions() -> None: + """Orthogonality should work when factors have different output dimensions.""" + + n_samples, n_features = 100, 8 + key = jax.random.PRNGKey(0) + x = jax.random.normal(key, (n_samples, n_features)) + + # Create orthogonal coefficient matrices with different output dimensions + # factor_0 has 2 output dimensions, factor_1 has 5 output dimensions + w_0 = jnp.array( + [ + [1.0, 0.0], + [0.0, 1.0], + [1.0, 1.0], + [0.0, 0.0], + [0.0, 0.0], + [0.0, 0.0], + [0.0, 0.0], + [0.0, 0.0], + ] + ) # (8, 2) + w_1 = jnp.array( [ - [1.0, 2.0, 3.0, 4.0], - [2.0, 3.0, 4.0, 5.0], - [3.0, 4.0, 5.0, 6.0], - [4.0, 5.0, 6.0, 7.0], - [5.0, 6.0, 7.0, 8.0], + [0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0], + [1.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 1.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 1.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 1.0], ] + ) # (8, 5) + + factor_0 = x @ w_0 # (100, 2) + factor_1 = x @ w_1 # (100, 5) + factored_beliefs = (factor_0, factor_1) + weights = jnp.ones(n_samples) / n_samples + + scalars, arrays = layer_linear_regression( + x, + weights, + factored_beliefs, + compute_subspace_orthogonality=True, + fit_intercept=False, + ) + + # Should have standard factor metrics + assert scalars["factor_0/r2"] > 0.99 + assert scalars["factor_1/r2"] > 0.99 + + # Should have orthogonality metrics + assert "orthogonality_0_1/subspace_overlap" in scalars + assert "orthogonality_0_1/max_singular_value" in scalars + assert "orthogonality_0_1/singular_values" in arrays + + # Compute principled threshold + threshold = _compute_orthogonality_threshold(x, factor_0, factor_1) + + # Should indicate near-zero overlap (orthogonal by construction) + assert scalars["orthogonality_0_1/subspace_overlap"] < threshold + assert scalars["orthogonality_0_1/max_singular_value"] < threshold + + # Singular values shape should be min(2, 5) = 2 + assert arrays["orthogonality_0_1/singular_values"].shape[0] == 2 + assert jnp.all(arrays["orthogonality_0_1/singular_values"] < threshold) + + +def test_orthogonality_with_contained_subspace() -> None: + """Smaller subspace fully contained in larger subspace should show high alignment.""" + + n_samples, n_features = 100, 8 + key = jax.random.PRNGKey(0) + x = jax.random.normal(key, (n_samples, n_features)) + + # Create coefficient matrices where factor_0's subspace is contained in factor_1's + # factor_0: 2D subspace using features [0, 1] + # factor_1: 3D subspace using features [0, 1, 2] (contains factor_0's space) + w_0 = jnp.array( + [ + [1.0, 0.0], + [0.0, 1.0], + [0.0, 0.0], + [0.0, 0.0], + [0.0, 0.0], + [0.0, 0.0], + [0.0, 0.0], + [0.0, 0.0], + ] + ) # (8, 2) + w_1 = jnp.array( + [ + [1.0, 0.0, 0.0], + [0.0, 1.0, 0.0], + [0.0, 0.0, 1.0], + [0.0, 0.0, 0.0], + [0.0, 0.0, 0.0], + [0.0, 0.0, 0.0], + [0.0, 0.0, 0.0], + [0.0, 0.0, 0.0], + ] + ) # (8, 3) + + factor_0 = x @ w_0 # (100, 2) + factor_1 = x @ w_1 # (100, 3) + factored_beliefs = (factor_0, factor_1) + weights = jnp.ones(n_samples) / n_samples + + scalars, arrays = layer_linear_regression( + x, + weights, + factored_beliefs, + 
compute_subspace_orthogonality=True, + fit_intercept=False, + ) + + # Should have standard factor metrics + assert scalars["factor_0/r2"] > 0.99 + assert scalars["factor_1/r2"] > 0.99 + + # Should have orthogonality metrics + assert "orthogonality_0_1/subspace_overlap" in scalars + assert "orthogonality_0_1/max_singular_value" in scalars + assert "orthogonality_0_1/singular_values" in arrays + + # Singular values shape should be min(2, 3) = 2 + assert arrays["orthogonality_0_1/singular_values"].shape[0] == 2 + + # Since factor_0's subspace is contained in factor_1's, singular values should be near 1.0 + # (indicating perfect alignment in the 2D shared subspace) + assert scalars["orthogonality_0_1/subspace_overlap"] > 0.99 + assert scalars["orthogonality_0_1/max_singular_value"] > 0.99 + assert scalars["orthogonality_0_1/min_singular_value"] > 0.99 + assert jnp.all(arrays["orthogonality_0_1/singular_values"] > 0.99) + + +def test_orthogonality_excludes_intercept() -> None: + """Orthogonality should be computed using only coefficients, not intercept.""" + + n_samples, n_features = 100, 6 + key = jax.random.PRNGKey(0) + x = jax.random.normal(key, (n_samples, n_features)) + + # Create orthogonal coefficient matrices + w_0 = jnp.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0]]) + w_1 = jnp.array([[0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]) + + # Add different intercepts to the factors + intercept_0 = jnp.array([[5.0, -3.0]]) + intercept_1 = jnp.array([[10.0, 7.0]]) + + factor_0 = x @ w_0 + intercept_0 # (100, 2) + factor_1 = x @ w_1 + intercept_1 # (100, 2) + factored_beliefs = (factor_0, factor_1) + weights = jnp.ones(n_samples) / n_samples + + # Run with fit_intercept=True + scalars, arrays = layer_linear_regression( + x, + weights, + factored_beliefs, + compute_subspace_orthogonality=True, + fit_intercept=True, ) - weights = jnp.ones(5) / 5.0 - # Factor 0: 3 states, exact linear combination (with intercept) - # y0 = [x0 + 1, x1 + 2, x2 + 3] - factor_0 = jnp.stack([x[:, 0] + 1, x[:, 1] + 2, x[:, 2] + 3], axis=1) + # Should have intercepts for both factors + assert "factor_0/intercept" in arrays + assert "factor_1/intercept" in arrays + + # Should have good regression fit + assert scalars["factor_0/r2"] > 0.99 + assert scalars["factor_1/r2"] > 0.99 - # Factor 1: 2 states, exact linear combination - # y1 = [x1, x3] - factor_1 = jnp.stack([x[:, 1], x[:, 3]], axis=1) + # Orthogonality should still be near-zero (computed from coefficients only, not intercepts) + threshold = _compute_orthogonality_threshold(x, factor_0, factor_1) - scalars, projections = layer_linear_regression(x, weights, (factor_0, factor_1), to_factors=True) + assert "orthogonality_0_1/subspace_overlap" in scalars + assert "orthogonality_0_1/max_singular_value" in scalars - # Should achieve perfect R² since targets are exact linear combinations - assert scalars["factor_0/r2"] > 0.99, f"factor_0 R² too low: {scalars['factor_0/r2']}" - assert scalars["factor_1/r2"] > 0.99, f"factor_1 R² too low: {scalars['factor_1/r2']}" + overlap = scalars["orthogonality_0_1/subspace_overlap"] + assert overlap < threshold, f"subspace_overlap={overlap} >= threshold={threshold}" - # Projections should match targets very closely - chex.assert_trees_all_close(projections["factor_0/projected"], factor_0, atol=1e-4) - chex.assert_trees_all_close(projections["factor_1/projected"], factor_1, atol=1e-4) + max_sv = scalars["orthogonality_0_1/max_singular_value"] + assert max_sv < threshold, 
f"max_singular_value={max_sv} >= threshold={threshold}" + # The different intercepts should not affect orthogonality + svs = arrays["orthogonality_0_1/singular_values"] + assert jnp.all(svs < threshold), f"singular_values={svs} not all < threshold={threshold}" -def test_factored_regression_different_state_counts() -> None: - """Test factored regression with factors having different numbers of states. - This reproduces a scenario where factors have different dimensionality, - which is common in factored generative processes. +def test_linear_regression_constant_targets_r2_and_dist() -> None: + """Constant targets should yield r2==0, and dist matches weighted residual norm. + + With intercept: perfect fit to constant -> zero residuals but r2 fallback to 0.0. + Without intercept: nonzero residuals; verify `dist` against manual computation. """ - x = jnp.arange(24.0).reshape(6, 4) # 6 samples, 4 features - weights = jnp.ones(6) / 6.0 + x = jnp.arange(4.0)[:, None] + y = jnp.ones_like(x) * 3.0 + weights = jnp.array([0.1, 0.2, 0.3, 0.4]) - # Factor 0: 3 states (like "mess3") - factor_0_raw = x[:, :3] - factor_0 = factor_0_raw / factor_0_raw.sum(axis=1, keepdims=True) + # With intercept -> perfect constant fit, but r2 should fallback to 0.0 when variance is zero + scalars, _ = linear_regression(x, y, weights) + assert scalars["r2"] == 0.0 + assert jnp.isclose(scalars["rmse"], 0.0, atol=1e-6, rtol=0.0).item() + assert jnp.isclose(scalars["mae"], 0.0, atol=1e-6, rtol=0.0).item() + assert jnp.isclose(scalars["dist"], 0.0, atol=1e-6, rtol=0.0).item() + + # Without intercept -> cannot fit a constant perfectly; r2 still 0.0, and dist should match manual computation + scalars_no_int, arrays_no_int = linear_regression(x, y, weights, fit_intercept=False) + assert scalars_no_int["r2"] == 0.0 + residuals = arrays_no_int["projected"] - y + per_sample = jnp.sqrt(jnp.sum(residuals**2, axis=1)) + expected_dist = float(jnp.sum(per_sample * weights)) + assert jnp.isclose(scalars_no_int["dist"], expected_dist, atol=1e-6, rtol=0.0).item() + + +def test_linear_regression_intercept_and_shapes_both_solvers() -> None: + """Validate intercept presence/absence and array shapes for both solvers.""" + n, d, t = 5, 3, 2 + x = jnp.arange(float(n * d)).reshape(n, d) + # Construct multi-target y with known linear relation and intercept + true_coeffs = jnp.array([[1.0, 2.0], [0.5, -1.0], [3.0, 0.0]]) # (d, t) + true_intercept = jnp.array([[0.7, -0.3]]) # (1, t) + y = x @ true_coeffs + true_intercept + weights = jnp.ones(n) / n + + # Standard solver, with intercept + _, arrays = linear_regression(x, y, weights, fit_intercept=True) + assert "projected" in arrays + assert "coeffs" in arrays + assert "intercept" in arrays + assert arrays["projected"].shape == (n, t) + assert arrays["coeffs"].shape == (d, t) + assert arrays["intercept"].shape == (1, t) + + # Standard solver, without intercept + _, arrays_no_int = linear_regression(x, y, weights, fit_intercept=False) + assert "projected" in arrays_no_int + assert "coeffs" in arrays_no_int + assert "intercept" not in arrays_no_int + assert arrays_no_int["projected"].shape == (n, t) + assert arrays_no_int["coeffs"].shape == (d, t) + + # SVD solver, with intercept + _, arrays_svd = linear_regression_svd(x, y, weights, fit_intercept=True) + assert "projected" in arrays_svd + assert "coeffs" in arrays_svd + assert "intercept" in arrays_svd + assert arrays_svd["projected"].shape == (n, t) + assert arrays_svd["coeffs"].shape == (d, t) + assert arrays_svd["intercept"].shape == (1, t) + + # 
SVD solver, without intercept + _, arrays_svd_no_int = linear_regression_svd(x, y, weights, fit_intercept=False) + assert "projected" in arrays_svd_no_int + assert "coeffs" in arrays_svd_no_int + assert "intercept" not in arrays_svd_no_int + assert arrays_svd_no_int["projected"].shape == (n, t) + assert arrays_svd_no_int["coeffs"].shape == (d, t) + + +def test_layer_linear_regression_concat_vs_separate_equivalence() -> None: + """Concat and separate factor regressions should yield identical per-factor arrays.""" + n, d = 6, 3 + x = jnp.arange(float(n * d)).reshape(n, d) + # Two factors with different output dims + w_0 = jnp.array([[1.0, 0.5], [0.0, -1.0], [2.0, 1.0]]) # (d, 2) + b0 = jnp.array([[0.3, -0.2]]) # (1, 2) + factor_0 = x @ w_0 + b0 + + w_1 = jnp.array([[0.2, 0.0, -0.5], [1.0, 1.0, 0.0], [-1.0, 0.5, 0.3]]) # (d, 3) + b1 = jnp.array([[0.1, 0.2, -0.1]]) # (1, 3) + factor_1 = x @ w_1 + b1 - # Factor 1: 2 states (like "tom quantum") - factor_1_raw = x[:, :2] - factor_1 = factor_1_raw / factor_1_raw.sum(axis=1, keepdims=True) + factored_beliefs = (factor_0, factor_1) + weights = jnp.array([0.05, 0.10, 0.15, 0.20, 0.25, 0.25]) - scalars, projections = layer_linear_regression(x, weights, (factor_0, factor_1), to_factors=True) + # Separate per-factor regression + _, arrays_sep = layer_linear_regression( + x, + weights, + factored_beliefs, + concat_belief_states=False, + ) - # Verify shapes are correct - assert projections["factor_0/projected"].shape == (6, 3) - assert projections["factor_1/projected"].shape == (6, 2) + # Concatenated regression with splitting + _, arrays_cat = layer_linear_regression( + x, + weights, + factored_beliefs, + concat_belief_states=True, + ) + + # Concat path should also provide combined arrays + assert "concat/projected" in arrays_cat + assert "concat/coeffs" in arrays_cat + assert "concat/intercept" in arrays_cat + + # Per-factor arrays should match between separate and concatenated flows + for k in ["projected", "coeffs", "intercept"]: + chex.assert_trees_all_close(arrays_sep[f"factor_0/{k}"], arrays_cat[f"factor_0/{k}"]) + chex.assert_trees_all_close(arrays_sep[f"factor_1/{k}"], arrays_cat[f"factor_1/{k}"]) + + +def test_layer_linear_regression_svd_concat_vs_separate_equivalence_best_rcond() -> None: + """SVD regression: concat-split vs separate produce identical per-factor arrays. + + If belief concatenation is enabled, we only report rcond for the concatenated fit as "concat/best_rcond". + If belief concatenation is disabled, we report rcond for each factor as "factor_k/best_rcond". 
+ """ + n, d = 6, 3 + x = jnp.arange(float(n * d)).reshape(n, d) + # Two factors with different output dims + w_0 = jnp.array([[1.0, 0.5], [0.0, -1.0], [2.0, 1.0]]) # (d, 2) + b0 = jnp.array([[0.3, -0.2]]) # (1, 2) + factor_0 = x @ w_0 + b0 + + w_1 = jnp.array([[0.2, 0.0, -0.5], [1.0, 1.0, 0.0], [-1.0, 0.5, 0.3]]) # (d, 3) + b1 = jnp.array([[0.1, 0.2, -0.1]]) # (1, 3) + factor_1 = x @ w_1 + b1 + + factored_beliefs = (factor_0, factor_1) + weights = jnp.array([0.05, 0.10, 0.15, 0.20, 0.25, 0.25]) + + # Separate per-factor SVD regression + scalars_sep, arrays_sep = layer_linear_regression( + x, + weights, + factored_beliefs, + concat_belief_states=False, + use_svd=True, + rcond_values=[1e-3], + ) + + # Concatenated SVD regression with splitting + scalars_cat, arrays_cat = layer_linear_regression( + x, + weights, + factored_beliefs, + concat_belief_states=True, + use_svd=True, + rcond_values=[1e-3], + ) - # Both should achieve reasonable fit - assert scalars["factor_0/r2"] > 0.5, f"factor_0 R² too low: {scalars['factor_0/r2']}" - assert scalars["factor_1/r2"] > 0.5, f"factor_1 R² too low: {scalars['factor_1/r2']}" + # Concat path should provide combined arrays and best_rcond + assert "concat/projected" in arrays_cat + assert "concat/coeffs" in arrays_cat + assert "concat/intercept" in arrays_cat + assert "concat/best_rcond" in scalars_cat + assert scalars_cat["concat/best_rcond"] == pytest.approx(1e-3) + + # Separate path should include per-factor best_rcond; concat-split path should not + assert "factor_0/best_rcond" in scalars_sep + assert "factor_1/best_rcond" in scalars_sep + assert "factor_0/best_rcond" not in scalars_cat + assert "factor_1/best_rcond" not in scalars_cat + + # Per-factor arrays should match between separate and concat-split flows + for k in ["projected", "coeffs", "intercept"]: + chex.assert_trees_all_close(arrays_sep[f"factor_0/{k}"], arrays_cat[f"factor_0/{k}"]) + chex.assert_trees_all_close(arrays_sep[f"factor_1/{k}"], arrays_cat[f"factor_1/{k}"]) + + # Overlapping scalar metrics should agree closely across flows + for metric in ["r2", "rmse", "mae", "dist"]: + assert jnp.isclose( + jnp.asarray(scalars_sep[f"factor_0/{metric}"]), + jnp.asarray(scalars_cat[f"factor_0/{metric}"]), + atol=1e-6, + rtol=0.0, + ).item() + assert jnp.isclose( + jnp.asarray(scalars_sep[f"factor_1/{metric}"]), + jnp.asarray(scalars_cat[f"factor_1/{metric}"]), + atol=1e-6, + rtol=0.0, + ).item() + + +def test_get_robust_basis_full_rank(): + """Full rank matrix should return all basis vectors.""" + # Create a full rank 5x3 matrix + key = jax.random.PRNGKey(42) + matrix = jax.random.normal(key, (5, 3)) + + basis = get_robust_basis(matrix) + + # Should return 3 basis vectors (all columns are linearly independent) + assert basis.shape == (5, 3) + + # Basis should be orthonormal + # Error in Gram matrix scales with: n_basis * eps + eps = jnp.finfo(basis.dtype).eps + tol = basis.shape[1] * eps + gram = basis.T @ basis + assert jnp.allclose(gram, jnp.eye(3), atol=tol) + + +def test_get_robust_basis_rank_deficient(): + """Rank deficient matrix should filter out zero singular value directions.""" + # Create a rank-2 matrix with 3 columns (third is linear combination) + col1 = jnp.array([[1.0], [0.0], [0.0], [0.0]]) + col2 = jnp.array([[0.0], [1.0], [0.0], [0.0]]) + col3 = 2.0 * col1 + 3.0 * col2 # Linear combination, rank deficient + matrix = jnp.hstack([col1, col2, col3]) + + basis = get_robust_basis(matrix) + + # Should return only 2 basis vectors (true rank is 2) + assert basis.shape[1] == 2 + + # 
Basis should be orthonormal + # Error in Gram matrix scales with: n_basis * eps + eps = jnp.finfo(basis.dtype).eps + tol = basis.shape[1] * eps + gram = basis.T @ basis + assert jnp.allclose(gram, jnp.eye(2), atol=tol) + + +def test_get_robust_basis_zero_matrix(): + """Zero matrix should return empty basis.""" + matrix = jnp.zeros((5, 3)) + basis = get_robust_basis(matrix) + + # Should return empty basis (no valid directions) + assert basis.shape == (5, 0) + + +def test_get_robust_basis_near_rank_deficient(): + """Matrix with very small singular value should filter it out.""" + # Create matrix with controlled singular values using SVD construction + key = jax.random.PRNGKey(123) + u = jax.random.normal(key, (6, 3)) + u, _ = jnp.linalg.qr(u) # Orthonormalize + + # Singular values: [10.0, 1.0, 1e-10] - last one is tiny + s = jnp.array([10.0, 1.0, 1e-10]) + v = jnp.eye(3) + + matrix = u @ jnp.diag(s) @ v + basis = get_robust_basis(matrix) + + # Should filter out the tiny singular value, keeping only 2 vectors + assert basis.shape[1] == 2 + + # Basis should be orthonormal + # Error in Gram matrix scales with: n_basis * eps + eps = jnp.finfo(basis.dtype).eps + tol = basis.shape[1] * eps + gram = basis.T @ basis + assert jnp.allclose(gram, jnp.eye(2), atol=tol) + + +def test_get_robust_basis_preserves_column_space(): + """Basis should span the same space as the original matrix's columns.""" + # Create a known rank-2 matrix + col1 = jnp.array([[1.0], [0.0], [0.0], [0.0]]) + col2 = jnp.array([[0.0], [1.0], [0.0], [0.0]]) + col3 = 2 * col1 + 3 * col2 # Linear combination + matrix = jnp.hstack([col1, col2, col3]) + + basis = get_robust_basis(matrix) + + # Basis should be rank 2 + assert basis.shape[1] == 2 + + # Compute principled tolerance based on matrix properties + # Error in projection scales with: max_dim * eps * max_singular_value + max_dim = max(matrix.shape) + eps = jnp.finfo(matrix.dtype).eps + max_sv = jnp.linalg.svd(matrix, compute_uv=False)[0] + tol = max_dim * eps * max_sv + + # Each original column should be expressible as linear combination of basis + for i in range(3): + col = matrix[:, i : i + 1] + # Project onto basis + projection = basis @ (basis.T @ col) + # Should be very close to original (within numerical tolerance) + assert jnp.allclose(projection, col, atol=tol) + + +def test_get_robust_basis_single_vector(): + """Single non-zero column should return normalized version.""" + vector = jnp.array([[3.0], [4.0], [0.0]]) + basis = get_robust_basis(vector) + + # Should return one basis vector + assert basis.shape == (3, 1) + + # Should be unit norm + # Error in norm computation scales with: dimension * eps + dim = vector.shape[0] + eps = jnp.finfo(vector.dtype).eps + norm_tol = dim * eps + assert jnp.allclose(jnp.linalg.norm(basis), 1.0, atol=norm_tol) + + # Should be parallel to input + # Error in dot product scales with: dimension * eps * magnitude + expected_norm = jnp.linalg.norm(vector) + parallel_tol = dim * eps * expected_norm + assert jnp.allclose(jnp.abs(basis.T @ vector), expected_norm, atol=parallel_tol)
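
Reviewer note: `get_robust_basis` and the pairwise `orthogonality_i_j/*` metrics exercised by these tests are defined elsewhere and are not shown in this patch. The sketch below is only a minimal reading of the behaviour the assertions above imply, included for orientation while reviewing; the matrix_rank-style tolerance, the helper names, and the use of principal-angle cosines are assumptions, not the library's actual implementation.

import jax.numpy as jnp


def get_robust_basis_sketch(matrix: jnp.ndarray) -> jnp.ndarray:
    """Orthonormal basis for the column space, dropping numerically negligible directions.

    Sketch only: assumes a numpy.linalg.matrix_rank-style cutoff of max(shape) * eps * s_max.
    Intended for eager (non-jit) use, since the output width depends on the data.
    """
    u, s, _ = jnp.linalg.svd(matrix, full_matrices=False)
    eps = jnp.finfo(matrix.dtype).eps
    # For the all-zero matrix s_max is 0, so nothing passes the cutoff and the basis has shape (n, 0).
    tol = max(matrix.shape) * eps * s[0]
    return u[:, s > tol]


def principal_angle_cosines_sketch(coeffs_a: jnp.ndarray, coeffs_b: jnp.ndarray) -> jnp.ndarray:
    """Cosines of the principal angles between two coefficient column spaces.

    Values near 0 indicate orthogonal subspaces; values near 1 indicate aligned (or contained)
    subspaces. There are min(rank_a, rank_b) of them, matching the singular_values shapes
    asserted in the orthogonality tests above.
    """
    q_a = get_robust_basis_sketch(coeffs_a)
    q_b = get_robust_basis_sketch(coeffs_b)
    return jnp.linalg.svd(q_a.T @ q_b, compute_uv=False)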