Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
39 commits
Select commit Hold shift + click to select a range
23f6efb
pylint (#98)
ealt Nov 12, 2025
26a3862
add compatibility for factored states
casperlchristensen Nov 12, 2025
a15ef3b
concrete examples and alternating process
casperlchristensen Nov 13, 2025
c90ebbc
tweaks to vocab sizes
casperlchristensen Nov 13, 2025
731eb5d
update naming
casperlchristensen Nov 18, 2025
c88d596
lock
casperlchristensen Dec 2, 2025
013ebfe
painful merge
casperlchristensen Dec 2, 2025
ba50000
full merge, renaming
casperlchristensen Dec 2, 2025
72c114d
test factored representation
casperlchristensen Dec 3, 2025
0bd219d
finalise gen-process PR
casperlchristensen Dec 3, 2025
46af230
update after merge
casperlchristensen Dec 3, 2025
d62fd5d
static analysis
casperlchristensen Dec 3, 2025
534b1ff
static analysis tweaks
casperlchristensen Dec 3, 2025
77076d6
arg name
casperlchristensen Dec 3, 2025
1074e5a
better test coverage
casperlchristensen Dec 3, 2025
7111f30
factor input args
casperlchristensen Dec 3, 2025
e7d2a92
ruff
casperlchristensen Dec 3, 2025
1babe79
better linting
casperlchristensen Dec 3, 2025
04b4110
bind i
casperlchristensen Dec 3, 2025
36baf6d
elipsis to protocol
casperlchristensen Dec 3, 2025
1dd7d44
simplify protocol
casperlchristensen Dec 3, 2025
dba0782
format
casperlchristensen Dec 3, 2025
4d957f6
Minor fixes
ealt Dec 5, 2025
da6392c
Minor fixes
ealt Dec 5, 2025
f290aef
jnp.ndarray -> jax.Array
ealt Dec 5, 2025
3c13032
Fix JIT compilation issue
ealt Dec 5, 2025
b666cc1
Refactor generative process config tests to use a helper method for c…
ealt Dec 5, 2025
1d23c0a
Add docstrings
ealt Dec 5, 2025
eac6194
Add match strings to value errors in tests
ealt Dec 5, 2025
d3be479
add better factor handling and allow regression to individual factors
casperlchristensen Dec 8, 2025
3fac6cc
pass device
casperlchristensen Dec 8, 2025
ea5d28f
static analysis
casperlchristensen Dec 8, 2025
87dd5be
merge from main
casperlchristensen Dec 8, 2025
3bb6984
better output format
casperlchristensen Dec 8, 2025
b56776f
to_factor in validation
casperlchristensen Dec 8, 2025
e484396
update returns and concatenations
casperlchristensen Dec 9, 2025
4d56734
tuple handling
casperlchristensen Dec 9, 2025
bcee48c
fix typehint
casperlchristensen Dec 9, 2025
ebbe21d
improve test coverage
casperlchristensen Dec 9, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions simplexity/activations/activation_analyses.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def analyze(
self,
activations: Mapping[str, jax.Array],
weights: jax.Array,
belief_states: jax.Array | None = None,
belief_states: jax.Array | tuple[jax.Array, ...] | None = None,
) -> tuple[Mapping[str, float], Mapping[str, jax.Array]]:
"""Analyze activations and return scalar metrics and projections."""
...
Expand Down Expand Up @@ -76,13 +76,14 @@ def __init__(
concat_layers: bool = False,
use_probs_as_weights: bool = True,
fit_intercept: bool = True,
to_factors: bool = False,
) -> None:
super().__init__(
analysis_type="linear_regression",
last_token_only=last_token_only,
concat_layers=concat_layers,
use_probs_as_weights=use_probs_as_weights,
analysis_kwargs={"fit_intercept": fit_intercept},
analysis_kwargs={"fit_intercept": fit_intercept, "to_factors": to_factors},
)


Expand All @@ -97,8 +98,9 @@ def __init__(
use_probs_as_weights: bool = True,
rcond_values: Sequence[float] | None = None,
fit_intercept: bool = True,
to_factors: bool = False,
) -> None:
analysis_kwargs: dict[str, Any] = {"fit_intercept": fit_intercept}
analysis_kwargs: dict[str, Any] = {"fit_intercept": fit_intercept, "to_factors": to_factors}
if rcond_values is not None:
analysis_kwargs["rcond_values"] = tuple(rcond_values)
super().__init__(
Expand Down
25 changes: 20 additions & 5 deletions simplexity/activations/activation_tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ class PreparedActivations:
"""Prepared activations with belief states and sample weights."""

activations: Mapping[str, jax.Array]
belief_states: jax.Array | None
belief_states: jax.Array | tuple[jax.Array, ...] | None
weights: jax.Array


Expand Down Expand Up @@ -48,16 +48,26 @@ def _to_jax_array(value: Any) -> jax.Array:
return jnp.asarray(value)


def _convert_tuple_to_jax_array(value: tuple[Any, ...]) -> tuple[jax.Array, ...]:
"""Convert a tuple of supported tensor types to JAX arrays."""
return tuple(_to_jax_array(v) for v in value)


def prepare_activations(
inputs: jax.Array | torch.Tensor | np.ndarray,
beliefs: jax.Array | torch.Tensor | np.ndarray,
beliefs: jax.Array
| torch.Tensor
| np.ndarray
| tuple[jax.Array, ...]
| tuple[torch.Tensor, ...]
| tuple[np.ndarray, ...],
probs: jax.Array | torch.Tensor | np.ndarray,
activations: Mapping[str, jax.Array | torch.Tensor | np.ndarray],
prepare_options: PrepareOptions,
) -> PreparedActivations:
"""Preprocess activations by deduplicating sequences, selecting tokens/layers, and computing weights."""
inputs = _to_jax_array(inputs)
beliefs = _to_jax_array(beliefs)
beliefs = _convert_tuple_to_jax_array(beliefs) if isinstance(beliefs, tuple) else _to_jax_array(beliefs)
probs = _to_jax_array(probs)
activations = {name: _to_jax_array(layer) for name, layer in activations.items()}

Expand All @@ -74,7 +84,7 @@ def prepare_activations(
weights = (
dataset.probs
if prepare_options.use_probs_as_weights
else _get_uniform_weights(belief_states.shape[0], belief_states.dtype)
else _get_uniform_weights(dataset.probs.shape[0], dataset.probs.dtype)
)

if prepare_options.concat_layers:
Expand All @@ -98,7 +108,12 @@ def __init__(self, analyses: Mapping[str, ActivationAnalysis]):
def analyze(
self,
inputs: jax.Array | torch.Tensor | np.ndarray,
beliefs: jax.Array | torch.Tensor | np.ndarray,
beliefs: jax.Array
| torch.Tensor
| np.ndarray
| tuple[jax.Array, ...]
| tuple[torch.Tensor, ...]
| tuple[np.ndarray, ...],
probs: jax.Array | torch.Tensor | np.ndarray,
activations: Mapping[str, jax.Array | torch.Tensor | np.ndarray],
) -> tuple[Mapping[str, float], Mapping[str, jax.Array]]:
Expand Down
11 changes: 7 additions & 4 deletions simplexity/analysis/layerwise_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,21 +34,23 @@ class AnalysisRegistration:

def _validate_linear_regression_kwargs(kwargs: Mapping[str, Any] | None) -> dict[str, Any]:
provided = dict(kwargs or {})
allowed = {"fit_intercept"}
allowed = {"fit_intercept", "to_factors"}
unexpected = set(provided) - allowed
if unexpected:
raise ValueError(f"Unexpected linear_regression kwargs: {sorted(unexpected)}")
fit_intercept = bool(provided.get("fit_intercept", True))
return {"fit_intercept": fit_intercept}
to_factors = bool(provided.get("to_factors", False))
return {"fit_intercept": fit_intercept, "to_factors": to_factors}


def _validate_linear_regression_svd_kwargs(kwargs: Mapping[str, Any] | None) -> dict[str, Any]:
provided = dict(kwargs or {})
allowed = {"fit_intercept", "rcond_values"}
allowed = {"fit_intercept", "rcond_values", "to_factors"}
unexpected = set(provided) - allowed
if unexpected:
raise ValueError(f"Unexpected linear_regression_svd kwargs: {sorted(unexpected)}")
fit_intercept = bool(provided.get("fit_intercept", True))
to_factors = bool(provided.get("to_factors", False))
rcond_values = provided.get("rcond_values")
if rcond_values is not None:
if not isinstance(rcond_values, (list, tuple)):
Expand All @@ -58,6 +60,7 @@ def _validate_linear_regression_svd_kwargs(kwargs: Mapping[str, Any] | None) ->
rcond_values = tuple(float(v) for v in rcond_values)
return {
"fit_intercept": fit_intercept,
"to_factors": to_factors,
"rcond_values": rcond_values,
}

Expand Down Expand Up @@ -152,7 +155,7 @@ def analyze(
self,
activations: Mapping[str, jax.Array],
weights: jax.Array,
belief_states: jax.Array | None = None,
belief_states: jax.Array | tuple[jax.Array, ...] | None = None,
) -> tuple[Mapping[str, float], Mapping[str, jax.Array]]:
"""Analyze activations and return namespaced scalar metrics and projections."""
if self._requires_belief_states and belief_states is None:
Expand Down
42 changes: 38 additions & 4 deletions simplexity/analysis/linear_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,22 +139,56 @@ def linear_regression_svd(
def layer_linear_regression(
layer_activations: jax.Array,
weights: jax.Array,
belief_states: jax.Array | None,
belief_states: jax.Array | tuple[jax.Array, ...] | None,
to_factors: bool = False,
**kwargs: Any,
) -> tuple[Mapping[str, float], Mapping[str, jax.Array]]:
"""Layer-wise regression helper that wraps :func:`linear_regression`."""
if belief_states is None:
raise ValueError("linear_regression requires belief_states")
return linear_regression(layer_activations, belief_states, weights, **kwargs)

if to_factors:
scalars, projections = {}, {}
if not isinstance(belief_states, tuple):
raise ValueError("belief_states must be a tuple when to_factors is True")
for factor_idx, factor in enumerate(belief_states):
if not isinstance(factor, jax.Array):
raise ValueError("Each factor in belief_states must be a jax.Array")
factor_scalars, factor_projections = linear_regression(layer_activations, factor, weights, **kwargs)
for key, value in factor_scalars.items():
scalars[f"factor_{factor_idx}/{key}"] = value
for key, value in factor_projections.items():
projections[f"factor_{factor_idx}/{key}"] = value
return scalars, projections
else:
belief_states = jnp.concatenate(belief_states, axis=-1) if isinstance(belief_states, tuple) else belief_states
return linear_regression(layer_activations, belief_states, weights, **kwargs)


def layer_linear_regression_svd(
layer_activations: jax.Array,
weights: jax.Array,
belief_states: jax.Array | None,
belief_states: jax.Array | tuple[jax.Array, ...] | None,
to_factors: bool = False,
**kwargs: Any,
) -> tuple[Mapping[str, float], Mapping[str, jax.Array]]:
"""Layer-wise regression helper that wraps :func:`linear_regression_svd`."""
if belief_states is None:
raise ValueError("linear_regression_svd requires belief_states")
return linear_regression_svd(layer_activations, belief_states, weights, **kwargs)

if to_factors:
scalars, projections = {}, {}
if not isinstance(belief_states, tuple):
raise ValueError("belief_states must be a tuple when to_factors is True")
for factor_idx, factor in enumerate(belief_states):
if not isinstance(factor, jax.Array):
raise ValueError("Each factor in belief_states must be a jax.Array")
factor_scalars, factor_projections = linear_regression_svd(layer_activations, factor, weights, **kwargs)
for key, value in factor_scalars.items():
scalars[f"factor_{factor_idx}/{key}"] = value
for key, value in factor_projections.items():
projections[f"factor_{factor_idx}/{key}"] = value
return scalars, projections
else:
belief_states = jnp.concatenate(belief_states, axis=-1) if isinstance(belief_states, tuple) else belief_states
return linear_regression_svd(layer_activations, belief_states, weights, **kwargs)
Loading
Loading