Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
48 commits
Select commit Hold shift + click to select a range
60d1380
Refactor regression code to incorporate optional computation of pairw…
loren-ac Dec 11, 2025
09c1d89
Refine regression API and add comprehensive orthogonality tests
loren-ac Dec 12, 2025
d43935a
Organize imports
ealt Dec 12, 2025
9e600a4
Fix lint issues
ealt Dec 12, 2025
edba4fe
Fix slices
ealt Dec 12, 2025
70eb56e
Simplify lr kwarg validation
ealt Dec 12, 2025
9cc9810
Add return type
ealt Dec 12, 2025
d403bc7
Add pylint ignore
ealt Dec 12, 2025
1c55be0
Fix potential division by zero
ealt Dec 12, 2025
c3d070c
Fix potential log(0) issue
ealt Dec 12, 2025
0c9a37f
Enhance subspace orthogonality computation by adding a check for mult…
ealt Dec 12, 2025
74c6760
Fix docstring inconsistency
ealt Dec 12, 2025
d3b0235
Update docstring
ealt Dec 12, 2025
2d4a97f
Fix lint issues
ealt Dec 12, 2025
335d210
Refactor linear regression kwargs validation and improve logging. Tem…
ealt Dec 12, 2025
358985c
Fix merge conflict
loren-ac Dec 12, 2025
d6d7141
Amended unseen merge conflict in linear_regression tests
loren-ac Dec 12, 2025
9a71da4
Rename to_factors parameter to concat_belief_states in activation ana…
loren-ac Dec 12, 2025
ecfa55c
Update activation analysis tests for concat_belief_states semantics
loren-ac Dec 12, 2025
8a16ab7
Fix validator error message and fix linting issues
loren-ac Dec 12, 2025
5b6247d
Add check requiring 2+ factors in _handle_factored_regression and rem…
loren-ac Dec 12, 2025
43123af
Add proper spacing to warning messages
loren-ac Dec 12, 2025
729222d
Fix dictionary equivalence check in test_linear_regression and add bl…
loren-ac Dec 12, 2025
2e8829f
Refactor subspace orthogonality computation for JIT compatibility
loren-ac Dec 12, 2025
4136030
Fix conditional callback execution using jax.lax.cond
loren-ac Dec 12, 2025
2be2032
Fix linting and formatting issues
loren-ac Dec 12, 2025
f77f2f5
Fix formatting issues
loren-ac Dec 12, 2025
7af2bc4
Disable too-many-locals linting issue in test_linear_regression.py
loren-ac Dec 12, 2025
6ee64fa
Change name of return dict from singular_values -> arrays for clarity
loren-ac Dec 12, 2025
84006da
Add docstring describing return values for _compute_all_pairwise_orth…
loren-ac Dec 12, 2025
556fede
Add docstring describing relevance of the do_nothing_branch function
loren-ac Dec 12, 2025
5b9801d
Refactor key removal method in kwarg validator and fix docstring format
loren-ac Dec 12, 2025
06c7692
Temporarily disable pylint checks during AST traversal in linear_regr…
ealt Dec 13, 2025
5bcbe03
Refactor linear regression analysis registration to use partial appli…
ealt Dec 13, 2025
ed69814
Fix tests
ealt Dec 13, 2025
46ce191
Add detailed docstring to _compute_subspace_orthogonality function, s…
ealt Dec 13, 2025
049b6d6
Add todo
ealt Dec 13, 2025
c890e36
Fix kwarg validation
ealt Dec 13, 2025
3a5a8e2
Fix tests
ealt Dec 13, 2025
0987697
Add validator decorator for linear_regression_svd to enforce use_svd=…
ealt Dec 13, 2025
0f37809
Fix test
ealt Dec 13, 2025
028e047
Add get_robust_basis for robust orthonormal basis extraction
loren-ac Dec 13, 2025
0532cd2
Pass pair of bases instead of coefficient matrices to _compute_subspa…
loren-ac Dec 13, 2025
95060d1
Compute full rank and orthonormal basis of coeff matrices before pass…
loren-ac Dec 13, 2025
b0ecb64
Fix formatting and docstring
loren-ac Dec 13, 2025
7a02602
Update comment
loren-ac Dec 13, 2025
69ff3e4
Fix issues due to API changes in activation and dataframe tests
loren-ac Dec 16, 2025
8e1efa4
Fix formatting issues
loren-ac Dec 16, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions simplexity/activations/activation_analyses.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,15 +84,15 @@ def __init__(
use_probs_as_weights: bool = True,
skip_first_token: bool = False,
fit_intercept: bool = True,
to_factors: bool = False,
concat_belief_states: bool = False,
) -> None:
super().__init__(
analysis_type="linear_regression",
last_token_only=last_token_only,
concat_layers=concat_layers,
use_probs_as_weights=use_probs_as_weights,
skip_first_token=skip_first_token,
analysis_kwargs={"fit_intercept": fit_intercept, "to_factors": to_factors},
analysis_kwargs={"fit_intercept": fit_intercept, "concat_belief_states": concat_belief_states},
)


Expand All @@ -108,9 +108,9 @@ def __init__(
skip_first_token: bool = False,
rcond_values: Sequence[float] | None = None,
fit_intercept: bool = True,
to_factors: bool = False,
concat_belief_states: bool = False,
) -> None:
analysis_kwargs: dict[str, Any] = {"fit_intercept": fit_intercept, "to_factors": to_factors}
analysis_kwargs: dict[str, Any] = {"fit_intercept": fit_intercept, "concat_belief_states": concat_belief_states}
if rcond_values is not None:
analysis_kwargs["rcond_values"] = tuple(rcond_values)
super().__init__(
Expand Down
81 changes: 50 additions & 31 deletions simplexity/analysis/layerwise_analysis.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,27 @@
"""Composable layer-wise analysis orchestration."""

# pylint: disable=all # Temporarily disable all pylint checkers during AST traversal to prevent crash.
# The imports checker crashes when resolving simplexity package imports due to a bug
# in pylint/astroid: https://github.com/pylint-dev/pylint/issues/10185
# pylint: enable=all # Re-enable all pylint checkers for the checking phase. This allows other checks
# (code quality, style, undefined names, etc.) to run normally while bypassing
# the problematic imports checker that would crash during AST traversal.

from __future__ import annotations

from collections.abc import Callable, Mapping, Sequence
from dataclasses import dataclass
from functools import partial
from typing import Any

import jax

from simplexity.analysis.linear_regression import (
layer_linear_regression,
layer_linear_regression_svd,
)
from simplexity.analysis.linear_regression import layer_linear_regression
from simplexity.analysis.pca import (
DEFAULT_VARIANCE_THRESHOLDS,
layer_pca_analysis,
)
from simplexity.logger import SIMPLEXITY_LOGGER

AnalysisFn = Callable[..., tuple[Mapping[str, float], Mapping[str, jax.Array]]]

Expand All @@ -34,35 +40,48 @@ class AnalysisRegistration:

def _validate_linear_regression_kwargs(kwargs: Mapping[str, Any] | None) -> dict[str, Any]:
provided = dict(kwargs or {})
allowed = {"fit_intercept", "to_factors"}
allowed = {"fit_intercept", "concat_belief_states", "compute_subspace_orthogonality", "use_svd", "rcond_values"}
unexpected = set(provided) - allowed
if unexpected:
raise ValueError(f"Unexpected linear_regression kwargs: {sorted(unexpected)}")
fit_intercept = bool(provided.get("fit_intercept", True))
to_factors = bool(provided.get("to_factors", False))
return {"fit_intercept": fit_intercept, "to_factors": to_factors}


def _validate_linear_regression_svd_kwargs(kwargs: Mapping[str, Any] | None) -> dict[str, Any]:
provided = dict(kwargs or {})
allowed = {"fit_intercept", "rcond_values", "to_factors"}
unexpected = set(provided) - allowed
if unexpected:
raise ValueError(f"Unexpected linear_regression_svd kwargs: {sorted(unexpected)}")
fit_intercept = bool(provided.get("fit_intercept", True))
to_factors = bool(provided.get("to_factors", False))
resolved_kwargs = {}
resolved_kwargs["fit_intercept"] = bool(provided.get("fit_intercept", True))
resolved_kwargs["concat_belief_states"] = bool(provided.get("concat_belief_states", False))
resolved_kwargs["compute_subspace_orthogonality"] = bool(provided.get("compute_subspace_orthogonality", False))
rcond_values = provided.get("rcond_values")
if rcond_values is not None:
if not isinstance(rcond_values, (list, tuple)):
raise TypeError("rcond_values must be a sequence of floats")
if len(rcond_values) == 0:
raise ValueError("rcond_values must not be empty")
rcond_values = tuple(float(v) for v in rcond_values)
return {
"fit_intercept": fit_intercept,
"to_factors": to_factors,
"rcond_values": rcond_values,
}
should_use_svd = rcond_values is not None
use_svd = bool(provided.get("use_svd", should_use_svd))
resolved_kwargs["use_svd"] = use_svd
if use_svd:
if rcond_values is not None:
if not isinstance(rcond_values, (list, tuple)):
raise TypeError("rcond_values must be a sequence of floats")
if len(rcond_values) == 0:
raise ValueError("rcond_values must not be empty")
if not use_svd:
SIMPLEXITY_LOGGER.warning("rcond_values are only used when use_svd is True")
rcond_values = tuple(float(v) for v in rcond_values)
resolved_kwargs["rcond_values"] = rcond_values
elif rcond_values is not None:
raise ValueError("rcond_values are only used when use_svd is True")
return resolved_kwargs


def set_use_svd(
    fn: ValidatorFn,
) -> ValidatorFn:
    """Wrap a kwargs validator so that it always validates with ``use_svd=True``.

    The returned validator copies the incoming kwargs, forces ``use_svd=True`` on the copy,
    delegates to ``fn``, and then removes ``use_svd`` from the resolved kwargs. Removing the key
    avoids a duplicate keyword argument when the analysis function already has ``use_svd`` bound
    (for example via ``functools.partial``).

    Args:
        fn: Validator taking an optional kwargs mapping and returning a dict of resolved kwargs.

    Returns:
        A validator with the same call signature as ``fn``; it carries ``fn``'s metadata
        (``__name__``, ``__doc__``, ``__wrapped__``) so it stays identifiable in logs and tracebacks.

    Raises:
        ValueError: Raised by the returned validator when the caller passes a falsy ``use_svd``.
    """
    # Local import: only this decorator needs ``wraps``.
    from functools import wraps  # pylint: disable=import-outside-toplevel

    @wraps(fn)
    def wrapper(kwargs: Mapping[str, Any] | None) -> dict[str, Any]:
        # Reject an explicit request to disable SVD instead of silently overriding it.
        if kwargs and "use_svd" in kwargs and not kwargs["use_svd"]:
            raise ValueError("use_svd cannot be set to False for linear_regression_svd")
        modified_kwargs = dict(kwargs) if kwargs else {}  # Make a copy to avoid mutating the input
        modified_kwargs["use_svd"] = True
        resolved = fn(modified_kwargs)
        resolved.pop("use_svd", None)  # Remove use_svd to avoid duplicate argument with partial
        return resolved

    return wrapper


def _validate_pca_kwargs(kwargs: Mapping[str, Any] | None) -> dict[str, Any]:
Expand Down Expand Up @@ -97,9 +116,9 @@ def _validate_pca_kwargs(kwargs: Mapping[str, Any] | None) -> dict[str, Any]:
validator=_validate_linear_regression_kwargs,
),
"linear_regression_svd": AnalysisRegistration(
fn=layer_linear_regression_svd,
fn=partial(layer_linear_regression, use_svd=True),
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Bug: Duplicate keyword argument causes runtime error for SVD analysis

The linear_regression_svd registry entry uses partial(layer_linear_regression, use_svd=True) which binds use_svd=True as a keyword argument. However, the validator _validate_linear_regression_kwargs always sets resolved_kwargs["use_svd"] at line 54. When the analysis is called, both the partial and the resolved kwargs provide use_svd, causing a TypeError: got multiple values for keyword argument 'use_svd' at runtime.

Additional Locations (1)

Fix in Cursor Fix in Web

requires_belief_states=True,
validator=_validate_linear_regression_svd_kwargs,
validator=set_use_svd(_validate_linear_regression_kwargs),
),
"pca": AnalysisRegistration(
fn=layer_pca_analysis,
Expand Down
Loading
Loading