Feature/global ci fn #344

Draft
danbraunai-goodfire wants to merge 145 commits into dev from feature/global-ci-fn

@danbraunai-goodfire danbraunai-goodfire commented Jan 20, 2026

Copied over from #336 with the new dev base branch.

Description

Add global CI function mode as an alternative to the existing layerwise CI functions.

Key changes:

  • Add discriminated union config pattern (LayerwiseCiConfig | GlobalCiConfig) for CI function configuration
  • Implement GlobalSharedMLPCiFn that concatenates activations across all layers, processes through a shared MLP, and splits outputs back to per-layer CI values
  • Add unified wrapper classes (LayerwiseCiFnWrapper, GlobalCiFnWrapper) that provide a consistent interface regardless of CI mode
  • Update all experiment configs to new ci_config format
  • Add comprehensive tests for both CI modes and checkpoint compatibility validation
  • Update canonical runs in registry with new format
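
The concatenate, shared-MLP, split flow described for GlobalSharedMLPCiFn can be sketched roughly as below. This is a dependency-free toy, not the PR's implementation: the class name ToyGlobalCiFn, the random weights, the single hidden layer, and the sigmoid output are all assumptions made for illustration, and the real class presumably operates on batched PyTorch tensors.

```python
import math
import random


def _rand_matrix(rows, cols, rng):
    """Small random weight matrix stored as a list of rows."""
    return [[rng.uniform(-0.5, 0.5) for _ in range(cols)] for _ in range(rows)]


def _matvec(matrix, vec):
    """Plain matrix-vector product."""
    return [sum(w * v for w, v in zip(row, vec)) for row in matrix]


class ToyGlobalCiFn:
    """Hypothetical stand-in: one shared MLP over all layers' activations."""

    def __init__(self, input_dims, n_components, hidden_dim=8, seed=0):
        if input_dims.keys() != n_components.keys():
            raise ValueError("input_dims and n_components must name the same layers")
        rng = random.Random(seed)
        # Sorted keys give a deterministic layer order for concat and split.
        self.layer_names = sorted(input_dims)
        self.input_dims = dict(input_dims)
        self.n_components = dict(n_components)
        total_in = sum(input_dims.values())
        total_out = sum(n_components.values())
        self.w1 = _rand_matrix(hidden_dim, total_in, rng)
        self.w2 = _rand_matrix(total_out, hidden_dim, rng)

    def __call__(self, acts):
        for name in self.layer_names:
            if len(acts[name]) != self.input_dims[name]:
                raise ValueError(f"unexpected activation size for {name}")
        # 1. Concatenate activations across layers in the fixed order.
        x = [v for name in self.layer_names for v in acts[name]]
        # 2. Shared MLP: ReLU hidden layer, sigmoid output (assumed choices).
        hidden = [max(0.0, s) for s in _matvec(self.w1, x)]
        scores = [1.0 / (1.0 + math.exp(-s)) for s in _matvec(self.w2, hidden)]
        # 3. Split the flat output back into per-layer CI values.
        out, start = {}, 0
        for name in self.layer_names:
            end = start + self.n_components[name]
            out[name] = scores[start:end]
            start = end
        return out


ci_fn = ToyGlobalCiFn({"layer_b": 3, "layer_a": 2}, {"layer_b": 4, "layer_a": 5})
ci = ci_fn({"layer_b": [0.1, -0.2, 0.3], "layer_a": [1.0, 2.0]})
```

Because the layer order is fixed by sorted keys, the result does not depend on the insertion order of the input dict.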

Related Issue

N/A

Motivation and Context

The existing layerwise CI functions learn separate importance functions per layer, which limits parameter sharing and may miss cross-layer patterns. Global CI provides an alternative where a single network sees all layer activations simultaneously, enabling:

  • Better parameter efficiency for models with many layers
  • Potential to learn cross-layer relationships in component importance
  • Simpler architecture for experiments where layer-specific CI isn't needed

How Has This Been Tested?

  • All tests pass (380 passed locally, CI green)
  • New tests added for:
    • Global CI function creation and forward pass
    • Checkpoint compatibility validation between CI modes
    • State dict save/load for both wrapper types
    • Integration with full ComponentModel pipeline
  • New canonical runs created with updated config format

Does this PR introduce a breaking change?

Yes. Checkpoints from before this PR are incompatible due to:

  • Config format change from flat ci_fn_type/ci_fn_hidden_dims to nested ci_config discriminated union
  • State dict key prefix change (now under ci_fn._ci_fns or ci_fn._global_ci_fn)

New canonical runs have been created to replace the old ones.
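
A checkpoint's CI mode could, in principle, be inferred from the two state dict prefixes named above. The sketch below shows one way to do that; detect_ci_mode and check_checkpoint_compatible are hypothetical helpers, the key suffixes in the example are made up, and the PR's actual validation logic may differ.

```python
from typing import Iterable, Optional

# Prefixes taken from the PR description.
LAYERWISE_PREFIX = "ci_fn._ci_fns"
GLOBAL_PREFIX = "ci_fn._global_ci_fn"


def detect_ci_mode(state_dict_keys: Iterable[str]) -> Optional[str]:
    """Return "layerwise", "global", or None if no CI keys are found."""
    keys = list(state_dict_keys)
    has_layerwise = any(k.startswith(LAYERWISE_PREFIX) for k in keys)
    has_global = any(k.startswith(GLOBAL_PREFIX) for k in keys)
    if has_layerwise and has_global:
        raise ValueError("checkpoint contains both layerwise and global CI keys")
    if has_layerwise:
        return "layerwise"
    if has_global:
        return "global"
    return None


def check_checkpoint_compatible(state_dict_keys: Iterable[str], expected_mode: str) -> None:
    """Raise ValueError if the checkpoint's CI mode does not match the config."""
    found = detect_ci_mode(state_dict_keys)
    if found is None:
        raise ValueError("no CI keys under the known prefixes; checkpoint may predate this format")
    if found != expected_mode:
        raise ValueError(f"checkpoint uses {found} CI but config expects {expected_mode}")


# Example with made-up key suffixes.
mode = detect_ci_mode(["ci_fn._ci_fns.layer0.weight", "components.layer0.U"])
```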

New Canonical Runs

| Experiment | Run ID | WandB Link |
|---|---|---|
| tms_5-2 | s-38e1a3e2 | https://wandb.ai/goodfire/spd/runs/s-38e1a3e2 |
| tms_5-2-id | s-a1c0e9e2 | https://wandb.ai/goodfire/spd/runs/s-a1c0e9e2 |
| tms_40-10 | s-7387fc20 | https://wandb.ai/goodfire/spd/runs/s-7387fc20 |
| tms_40-10-id | s-2a2b5a57 | https://wandb.ai/goodfire/spd/runs/s-2a2b5a57 |
| resid_mlp1 | s-62fce8c4 | https://wandb.ai/goodfire/spd/runs/s-62fce8c4 |
| resid_mlp2 | s-a9ad193d | https://wandb.ai/goodfire/spd/runs/s-a9ad193d |

ocg-goodfire and others added 30 commits December 5, 2025 13:29
- Add global exception handlers for RequestValidationError, HTTPException, and Exception
- Add request/response logging middleware
- Replace silent JSONResponse error returns with HTTPException raises
- Ensures all errors log tracebacks and are visible in server logs

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
lee-goodfire and others added 30 commits January 14, 2026 18:29
Replace separate ci_fn_type, use_global_ci, and ci_fn_hidden_dims fields
with a single ci_config field using discriminated unions:

- LayerwiseCiConfig: mode="layerwise", fn_type in [mlp, vector_mlp, shared_mlp]
- GlobalCiConfig: mode="global", fn_type in [global_shared_mlp]

This separates the type namespaces for layerwise and global CI functions,
making it explicit that they are distinct systems rather than suggesting
every layerwise CI fn has a global variant.

Updated ComponentModel to use match statement on ci_config type, and
updated all test files to use the new config format.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
* Allow building new graph in interventions tab

* Add 'Generating' message

* Prevent multiple manual graphs with the same nodes

* Remove unused graph name everywhere

* Address some PR review comments

* More PR fixes

* Remove comments
Creates per-layer scatter plots showing normalized component activation
values for datapoints where CI exceeds a threshold. Components are ranked
by median activation on the x-axis.

Usage: python scripts/plot_component_activations.py <run_id> --ci-threshold 0.1

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
- Add second set of plots ordered by firing frequency (from harvest data)
- Transform y-values to |normalized - 0.5| for frequency plots
- Organize outputs: scripts/outputs/<run-id>/component-act-scatter/order-by-{median,freq}/
- Include run_id in plot titles
- Use pre-calculated firing_counts from token_stats.pt

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
…nticity plots) (#343)

* Add script to plot component activations vs component rank

Creates per-layer scatter plots showing normalized component activation
values for datapoints where CI exceeds a threshold. Components are ranked
by median activation on the x-axis.

Usage: python scripts/plot_component_activations.py <run_id> --ci-threshold 0.1

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>

* Add frequency-ordered plots with abs distance from midpoint

- Add second set of plots ordered by firing frequency (from harvest data)
- Transform y-values to |normalized - 0.5| for frequency plots
- Organize outputs: scripts/outputs/<run-id>/component-act-scatter/order-by-{median,freq}/
- Include run_id in plot titles
- Use pre-calculated firing_counts from token_stats.pt

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>

---------

Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
…ormat

- Resolve merge conflict in spd/models/component_model.py (keep methods from both branches)
- Update all YAML configs to use new ci_config discriminated union format:
  - Old: ci_fn_type + ci_fn_hidden_dims
  - New: ci_config with mode, fn_type, hidden_dims
- Update test_grid_search.py inline configs to include ci_config field
- Includes all changes from dev/app branch (app improvements, renames, etc.)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
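
The old-to-new format change could be applied to a loaded config dict along the lines sketched below. The helper migrate_flat_ci_config is hypothetical and not part of the PR; it assumes that every old flat config corresponds to layerwise mode (global mode being new), and the "seed" key in the example is made up.

```python
def migrate_flat_ci_config(old: dict) -> dict:
    """Convert flat ci_fn_type / ci_fn_hidden_dims keys to a nested ci_config."""
    if "ci_config" in old:
        # Already in the new format; return an unchanged copy.
        return dict(old)
    flat_keys = ("ci_fn_type", "ci_fn_hidden_dims")
    new = {k: v for k, v in old.items() if k not in flat_keys}
    new["ci_config"] = {
        "mode": "layerwise",  # assumption: old configs were always layerwise
        "fn_type": old["ci_fn_type"],
        "hidden_dims": list(old["ci_fn_hidden_dims"]),
    }
    return new


old_cfg = {"seed": 0, "ci_fn_type": "mlp", "ci_fn_hidden_dims": [16]}
new_cfg = migrate_flat_ci_config(old_cfg)
```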
Resolve conflict in spd/models/component_model.py by keeping methods from both branches:
- HEAD: _calc_layerwise_causal_importances, _calc_global_causal_importances
- dev/app: get_all_component_acts

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Resolves merge conflicts in PromptCardHeader.svelte and ss_llama_simple_mlp-2L-wide.yaml

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
- Fix incorrect tensor assertions (tensor.all() <= 1.0 → (tensor <= 1.0).all())
- Fix non-deterministic layer ordering (use sorted keys)
- Add config validation for CI fn_type/mode compatibility
- Add checkpoint compatibility validation between CI modes
- Extract _get_module_input_dim() helper to reduce duplication
- Improve type annotation for global_ci_fn
- Add 17 comprehensive tests for global CI functionality

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
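
The first fix in this commit is easy to misread, so here is a plain-Python analogue of the pitfall (lists and the built-in all() stand in for tensors; the values are made up). Reducing first and comparing afterwards only compares a single boolean against the bound, so an out-of-range value slips through.

```python
values = [0.5, 7.0]  # 7.0 violates the intended upper bound of 1.0

# Buggy pattern, analogous to `tensor.all() <= 1.0`:
# all(values) is True because every entry is nonzero, and True <= 1.0 holds.
buggy_check = all(values) <= 1.0

# Fixed pattern, analogous to `(tensor <= 1.0).all()`:
# compare elementwise first, then reduce.
correct_check = all(v <= 1.0 for v in values)
```

Here buggy_check evaluates to True despite the violation, while correct_check evaluates to False.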
- Remove broken RENAMED_CONFIG_KEYS entries (gate_type, gate_hidden_dims)
- Remove redundant _validate_ci_config() - type system already enforces this
- Simplify checkpoint validation error messages to be accurate
- Remove redundant comments in GlobalSharedMLPCiFn and _calc_global_causal_importances
- Add test for binomial sampling with global CI

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Remove accidentally committed large PNG output files (48 files, ~100MB).
Add scripts/outputs/ to .gitignore to prevent future accidents.

Note: Files remain in git history on this branch. For complete removal,
run git-filter-repo on the main repo after merge.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
- Add usage examples to docstring
- Move script from scripts/ to spd/scripts/
- Change output directory to use Path(__file__).parent / "out"

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Replace dual tracking of global/layerwise CI with single unified interface:

- Add wrapper classes (GlobalCiFnWrapper, LayerwiseCiFnWrapper) to
  components.py
- Replace is_global_ci flag and separate ci_fns/global_ci_fn attributes
  with single ci_fn attribute in ComponentModel
- Remove _calc_layerwise_causal_importances and _calc_global_causal_importances
  methods - logic now in wrapper forward() methods
- Simplify parameter collection in run_spd.py and gradient logging in
  logging_utils.py
- Update tests to use new wrapper-based assertions

This eliminates the boolean flag anti-pattern and reduces 5 CI-related
attributes to 1, making the code cleaner and more maintainable.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Rename methods and variables to be explicit about layerwise vs global:
- _create_ci_fn → _create_layerwise_ci_fn
- _create_ci_fns → _create_layerwise_ci_fns
- has_ci_fns → has_layerwise_ci_fns
- ci_fns (local var) → layerwise_ci_fns

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
For consistency with config classes and wrapper class definitions,
put layerwise CI cases before global CI cases in match statements.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Change from old flat format (ci_fn_type, ci_fn_hidden_dims) to new
nested ci_config format.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
- Update all 6 canonical runs in registry.py with new run IDs
- Update clustering test to use resid_mlp2 run with new format

New canonical runs:
- tms_5-2: s-38e1a3e2
- tms_5-2-id: s-a1c0e9e2
- tms_40-10: s-7387fc20
- tms_40-10-id: s-2a2b5a57
- resid_mlp1: s-62fce8c4
- resid_mlp2: s-a9ad193d

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Streaming dataset loading only works for 'lm' tasks, not resid_mlp.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
4 participants