Add subspace orthogonality analysis for factored processes #136
Conversation
Pull request overview
This PR adds functionality to compute orthogonality metrics between activation subspaces when models are trained on factored processes, where belief states are Cartesian products of subprocess belief states.
Key Changes:
- Implements subspace orthogonality computation using QR decomposition and SVD to measure overlap between learned coefficient subspaces
- Refactors API: replaces
to_factorswithconcat_belief_states, renamesprojectionstoarrays, separatescoeffsandinterceptin return values - Unifies SVD functionality through a
use_svdflag for consistent access to both regression methods
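The orthogonality computation described above can be sketched with a small self-contained helper: QR gives an orthonormal basis for each coefficient subspace, and the singular values of the product of the bases are the cosines of the principal angles between them (the function name and metric choices here are illustrative, not the PR's actual API):

```python
import numpy as np

def subspace_orthogonality(A: np.ndarray, B: np.ndarray) -> dict:
    """Overlap between the column spaces of A and B via principal angles.

    QR yields orthonormal bases Q_a, Q_b; the singular values of
    Q_a.T @ Q_b are the cosines of the principal angles:
    1.0 = fully aligned directions, 0.0 = orthogonal.
    """
    Qa, _ = np.linalg.qr(A)  # reduced QR: orthonormal basis of col(A)
    Qb, _ = np.linalg.qr(B)
    cosines = np.linalg.svd(Qa.T @ Qb, compute_uv=False)
    return {
        "max_cosine": float(cosines.max()),    # worst-case alignment
        "mean_cosine": float(cosines.mean()),
        "singular_values": cosines,
    }

# Spans of disjoint coordinate axes are exactly orthogonal:
A = np.eye(4)[:, :2]
B = np.eye(4)[:, 2:]
print(subspace_orthogonality(A, B)["max_cosine"])  # → 0.0
```

For a contained subspace (columns of one matrix lying inside the other's span) the largest cosine is 1.0 even though the dimensions differ, which is why the per-angle singular values are worth returning alongside the scalar summaries.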
Reviewed changes
Copilot reviewed 4 out of 4 changed files in this pull request and generated 8 comments.
| File | Description |
|---|---|
| `tests/analysis/test_linear_regression.py` | Adds 9 orthogonality tests covering orthogonal/aligned/contained subspaces, multi-factor scenarios, and edge cases; updates existing tests for the new API structure |
| `tests/analysis/test_layerwise_analysis.py` | Adds validation tests for the new parameters (`concat_belief_states`, `compute_subspace_orthogonality`, `use_svd`) and verifies proper default handling |
| `simplexity/analysis/linear_regression.py` | Implements the core orthogonality computation functions, refactors regression to support both concatenated and separate factor processing, and separates coefficients from intercepts in return values |
| `simplexity/analysis/layerwise_analysis.py` | Updates parameter validators to accept the new parameters and forward them to the regression functions |
Pull request overview
Copilot reviewed 4 out of 4 changed files in this pull request and generated 14 comments.
Copilot AI · Dec 12, 2025

```python
"Degenerate subspace detected during orthogonality computation."
"All singular values are zero."
"Setting probability values to zero."
```

These adjacent string literals are concatenated with no separator, so the warning message is missing spacing between sentences. Each literal except the last should end with a trailing space:

```python
"Degenerate subspace detected during orthogonality computation. "
"All singular values are zero. "
"Setting probability values to zero."
```
Copilot AI · Dec 12, 2025

```python
if compute_subspace_orthogonality:
    SIMPLEXITY_LOGGER.warning(
        "Subspace orthogonality requires multiple factors."
        "Received single factor of type %s; skipping orthogonality metrics.",
```

The warning message is missing a space between the sentences. Suggested change:

```python
        " Received single factor of type %s; skipping orthogonality metrics.",
```
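The underlying pitfall in both of these warnings is Python's implicit adjacent-string-literal concatenation, which joins the pieces with no separator. A minimal illustration (the process name is a placeholder):

```python
# Adjacent string literals fuse with no space between them:
msg = (
    "Subspace orthogonality requires multiple factors."
    "Received single factor of type %s; skipping orthogonality metrics."
)
# The sentences run together: "...multiple factors.Received single factor..."
print(msg % "MyProcess")

# Adding an explicit leading space to the second literal fixes it:
fixed = (
    "Subspace orthogonality requires multiple factors."
    " Received single factor of type %s; skipping orthogonality metrics."
)
print(fixed % "MyProcess")
```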
Copilot AI · Dec 12, 2025

```python
assert scalars.keys() == scalars_non_tuple.keys()
assert arrays.keys() == arrays_non_tuple.keys()
for key in scalars.keys():
    assert scalars[key] == pytest.approx(scalars_non_tuple[key])
for key in arrays.keys():
    assert arrays[key] == pytest.approx(arrays_non_tuple[key])
```

Comparing dictionary keys and floating point values using direct equality and `pytest.approx` can be fragile. Consider using `chex.assert_trees_all_close` for both scalars and arrays, which provides a more robust comparison in one call. Additionally, the trailing whitespace on line 336 should be removed. Suggested change:

```python
chex.assert_trees_all_close(scalars, scalars_non_tuple)
chex.assert_trees_all_close(arrays, arrays_non_tuple)
```
Copilot AI · Dec 12, 2025

```python
# Compute the entropy
probs = singular_values**2 / probs_denominator
num_zeros = jnp.sum(probs == 0)
if num_zeros > 0:
```

The comparison `num_zeros > 0` may not behave correctly when `num_zeros` is a traced JAX array: under `jit`, tracers cannot be converted to Python booleans, so the `if` raises at trace time (a concrete array outside `jit` does convert, but an explicit conversion makes the intent clear). Use `.item()` or `int(...)` to obtain a Python scalar, or restructure the branch with `jnp.where`. Suggested change:

```python
if num_zeros.item() > 0:
```
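To make the failure mode concrete: the truthiness problem only bites under tracing. A minimal sketch (assuming `jax` is installed; the entropy expression is illustrative, not the PR's exact code) contrasting the eager behavior with a trace-safe, branchless alternative:

```python
import jax
import jax.numpy as jnp

def count_zeros(probs):
    return jnp.sum(probs == 0)

probs = jnp.array([0.0, 0.5, 0.0, 0.5])

# Eagerly, the 0-d result is concrete, so int()/.item() works as expected:
assert int(count_zeros(probs)) == 2

@jax.jit
def entropy_branchless(probs):
    # Inside jit, `if count_zeros(probs) > 0:` would raise, because the
    # comparison yields a tracer with no concrete boolean value. Instead,
    # guard the zeros branchlessly: log(1.0) = 0 kills the zero terms.
    safe = jnp.where(probs == 0, 1.0, probs)
    return -jnp.sum(probs * jnp.log(safe))

print(float(entropy_branchless(probs)))  # → ~0.693 (= ln 2)
```

The `jnp.where` guard also keeps the function fully differentiable and jit-compatible, whereas `.item()` forces a host-side sync and only works outside traced code.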
Copilot AI · Dec 12, 2025

```python
        scalars.update(orthogonality_scalars)
        arrays.update(orthogonality_singular_values)
    else:
        SIMPLEXITY_LOGGER.warning("Subspace orthogonality cannot be computed for a single belief state")
```

This warning message is inconsistent with the one at lines 467–469: line 424 says "cannot be computed for a single belief state" while lines 467–469 say "requires multiple factors". These should use consistent phrasing. Suggested change:

```python
        SIMPLEXITY_LOGGER.warning("Subspace orthogonality requires multiple factors. Received single factor; skipping orthogonality metrics.")
```
Copilot AI · Dec 12, 2025

```python
allowed = {"fit_intercept", "concat_belief_states", "compute_subspace_orthogonality", "use_svd", "rcond_values"}
unexpected = set(provided) - allowed
if unexpected:
    raise ValueError(f"Unexpected linear_regression_svd kwargs: {sorted(unexpected)}")
```

The error message says "Unexpected linear_regression_svd kwargs", but this is now a base validator used by both `linear_regression` and `linear_regression_svd`. The message should be generic to accurately reflect both callers. Suggested change:

```python
    raise ValueError(f"Unexpected linear regression kwargs: {sorted(unexpected)}")
```
Pull request overview
Copilot reviewed 6 out of 6 changed files in this pull request and generated 8 comments.
Copilot AI · Dec 12, 2025

```python
singular_values = {
    "singular_values": singular_values,
}
return scalars, singular_values
```

The name `singular_values` is reused for the returned dictionary, shadowing the array computed earlier at line 292; readers must track which binding is live. Use a more descriptive name such as `singular_values_dict` or `arrays`. Suggested change:

```python
singular_values_dict = {
    "singular_values": singular_values,
}
return scalars, singular_values_dict
```
Copilot AI · Dec 12, 2025

```python
probs = singular_values**2 / probs_denominator_safe

def execute_some_zeros_warning_branch(x):
    # This correctly calls the log_some_zeros function
```

The comment "This correctly calls the log_some_zeros function" is redundant and adds no value; the code is self-explanatory. Consider removing it.
Copilot AI · Dec 12, 2025

```python
    kwargs = _base_validate_linear_regression_kwargs(kwargs)
    kwargs.pop("rcond_values")
    return kwargs


def _validate_linear_regression_svd_kwargs(kwargs: Mapping[str, Any] | None) -> dict[str, Any]:
    kwargs = _base_validate_linear_regression_kwargs(kwargs)
    kwargs.pop("use_svd")
    return kwargs
```

These validators mutate the dictionary returned by `_base_validate_linear_regression_kwargs` via `pop()`. While this works, it creates an implicit dependency on the base validator returning a mutable dict. Building a new dictionary by filtering out the unwanted key makes the intent clearer and avoids mutation. Suggested change:

```python
    base_kwargs = _base_validate_linear_regression_kwargs(kwargs)
    return {k: v for k, v in base_kwargs.items() if k != "rcond_values"}


def _validate_linear_regression_svd_kwargs(kwargs: Mapping[str, Any] | None) -> dict[str, Any]:
    base_kwargs = _base_validate_linear_regression_kwargs(kwargs)
    return {k: v for k, v in base_kwargs.items() if k != "use_svd"}
```
```diff
 fn=partial(layer_linear_regression, use_svd=True),
 requires_belief_states=True,
-validator=_validate_linear_regression_svd_kwargs,
+validator=_validate_linear_regression_kwargs,
```

Bug: pre-bound `use_svd` keyword silently overridden

The `linear_regression_svd` registry entry uses `partial(layer_linear_regression, use_svd=True)`, which pre-binds `use_svd=True`. However, `_validate_linear_regression_kwargs` always includes `use_svd` in its return dictionary (defaulting to `False`). When `LayerwiseAnalysis.analyze()` calls `self._analysis_fn(..., **self._analysis_kwargs)`, the call-time keyword overrides the partial's pre-bound value (`functools.partial` merges call-time keywords over pre-bound ones rather than raising), so the SVD path is silently disabled.

Additional Locations (1)
```python
"concat_belief_states": concat_belief_states,
"compute_subspace_orthogonality": compute_subspace_orthogonality,
"use_svd": use_svd,
"rcond_values": rcond_values,
```

Bug: SVD regression never used due to kwarg override

The validator `_validate_linear_regression_kwargs` always returns `use_svd` in its output (defaulting to `False`). When the `linear_regression_svd` registry entry uses `partial(layer_linear_regression, use_svd=True)` and `analyze()` calls it with `**self._analysis_kwargs` containing `use_svd=False`, the call-time kwarg overrides the partial's pre-bound value. This means `linear_regression_svd` will actually run non-SVD regression, breaking features like the `best_rcond` output and SVD-based regularization.

Additional Locations (1)
```diff
 ),
 "linear_regression_svd": AnalysisRegistration(
-    fn=layer_linear_regression_svd,
+    fn=partial(layer_linear_regression, use_svd=True),
```

Bug: pre-bound keyword overridden at call time for SVD analysis

The `linear_regression_svd` registry entry uses `partial(layer_linear_regression, use_svd=True)`, binding `use_svd=True` as a keyword argument. However, the validator `_validate_linear_regression_kwargs` always sets `resolved_kwargs["use_svd"]` at line 54. When the analysis is called, the validator-supplied keyword overrides the partial's pre-bound `use_svd=True` (call-time keywords win in `functools.partial`), so the SVD path never runs.

Additional Locations (1)
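The root cause in both reports is how `functools.partial` merges keywords: keyword arguments supplied at call time silently override pre-bound ones rather than raising. A minimal demonstration with a stub standing in for the analysis function:

```python
from functools import partial

def layer_linear_regression_stub(*, use_svd: bool = False) -> str:
    # Stand-in for the real analysis function; reports which path would run.
    return "svd" if use_svd else "lstsq"

svd_fn = partial(layer_linear_regression_stub, use_svd=True)
print(svd_fn())               # → svd
# Kwargs unpacked at call time win over the partial's pre-bound value:
print(svd_fn(use_svd=False))  # → lstsq  (the SVD path is silently disabled)
```

A `TypeError: got multiple values for keyword argument` only occurs when an argument bound *positionally* in the partial is re-passed by keyword; for keyword-bound arguments the override is silent, which is exactly why the validator must not re-emit `use_svd` for the `linear_regression_svd` entry.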
…ise subspace orthogonality metrics
- Separate coeffs/intercept in return structure (omit intercept key when `fit_intercept=False`)
- Rename `to_factors` → `concat_belief_states` for clarity
- Add 9 orthogonality tests with principled numerical thresholds (`safety_factor=10`)
- Test orthogonal, aligned, contained subspaces; multi-factor scenarios; edge cases
- Update validators and existing tests for new parameter structure
- Add informative assertion messages for debugging numerical precision
…iple belief states. Log a warning if only one belief state is present, preventing unnecessary calculations.
…ove redundant orthogonality computations warning
…ank line after docstring in test_layerwise_analysis
…ogonality function
…ession.py to prevent crashes. Remove deprecated layer_linear_regression_svd function for cleaner code and encourage use of layer_linear_regression with use_svd=True.
…cation of layer_linear_regression with use_svd=True, removing the deprecated layer_linear_regression_svd function for improved clarity and consistency.
…pecifying return values and their meanings for improved clarity and documentation.
…True and exclude it from output. Enhance tests to validate behavior.
…ing bases to subspace analysis
Force-pushed c37d006 to 69ff3e4 (Compare)
Summary
Adds functionality to compute orthogonality metrics between activation
subspaces when models are trained on factored processes (processes whose
belief state is a Cartesian product of subprocess belief states).
Key Features
- Computes orthogonality metrics between learned coefficient subspaces using QR decomposition and SVD
- Unified handling of regression and orthogonality
- Separates `coeffs` (linear transformation) and `intercept` (translation) in the return structure
Implementation

Core functionality (`linear_regression.py`):
- Adds a `compute_subspace_orthogonality` parameter to `layer_linear_regression()`
- Adds `_compute_all_pairwise_orthogonality()` and `_compute_subspace_orthogonality()` for pairwise subspace metrics in multi-factor scenarios
- Unifies both regression methods behind a `use_svd` flag

`layerwise_analysis.py`:
- Forwards the new `concat_belief_states`, `compute_subspace_orthogonality`, and `use_svd` parameters

API improvements:
- Removes the `to_factors` flag and introduces a `concat_belief_states` flag
- Uses `concat_belief_states` to determine whether regression runs on concatenated beliefs or factored beliefs
- Renames the `projections` dictionary to `arrays` (more descriptive)
- Splits `arrays` into `coeffs` and `intercept`; omits the `intercept` key when `fit_intercept=False`

Testing:
- Orthogonality computation tests (`test_linear_regression.py`)
- Parameter validation tests (`test_layerwise_analysis.py`): covers the new parameters (`concat_belief_states`, `compute_subspace_orthogonality`, `use_svd`) and the returned `coeffs`/`intercept` keys

Note

Adds pairwise subspace-orthogonality metrics for factored beliefs, separates coeffs/intercept in outputs, and unifies linear regression/SVD via a `use_svd` flag with updated validators and tests.
- Returns `coeffs` and `intercept` (omits `intercept` when `fit_intercept=False`)
- Uses `concat_belief_states` to fit jointly and split back; reuses params for metrics
- Supports `use_svd` (with optional `rcond_values`); exposes the best rcond appropriately (concat vs per-factor)
- Validates `concat_belief_states`, `compute_subspace_orthogonality`, `use_svd`, and `rcond_values`; enforces constraints and defaults
- Registers `linear_regression_svd` via `partial(layer_linear_regression, use_svd=True)` and guards `use_svd` in the validator
- Replaces `to_factors` with `concat_belief_states`
- Returns arrays under `projected`, `coeffs`, and `intercept`

Written by Cursor Bugbot for commit 51283f7. This will update automatically on new commits.