
Autocast to bf16 for up to 25% spd speedup #359

Merged
danbraunai-goodfire merged 13 commits into main from refactor/use-tf32-config
Feb 7, 2026

Conversation

@danbraunai-goodfire
Collaborator

danbraunai-goodfire commented Feb 6, 2026

Description

Adds an autocast_bf16 config option (default true) for SPD training, and adds bf16 autocast to harvest, dataset attributions, and the app backend. All forward passes (not backward passes) are wrapped with torch.autocast(device_type="cuda", dtype=torch.bfloat16); a minimal sketch of this wrapping pattern is shown after the file list below.

Files with autocast:

  • spd/run_spd.py — main SPD optimization loop (configurable via autocast_bf16)
  • spd/harvest/harvest.py — inference loop
  • spd/dataset_attributions/harvester.py — both forward passes
  • spd/app/backend/compute.py — all forward pass sites (CI, attributions, interventions)
  • spd/app/backend/optim_cis.py — forward passes and loss computation in CI optimization
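A minimal sketch of the wrapping pattern (illustrative only; the function and variable names here are hypothetical, not the PR's actual code):

import contextlib

import torch

def forward_with_optional_autocast(model, batch, autocast_bf16: bool = True):
    # Forward pass runs under bf16 autocast when enabled and CUDA is present;
    # otherwise fall back to a no-op context. Any backward pass is taken
    # outside the context.
    ctx = (
        torch.autocast(device_type="cuda", dtype=torch.bfloat16)
        if autocast_bf16 and torch.cuda.is_available()
        else contextlib.nullcontext()
    )
    with ctx:
        return model(batch)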

Motivation and Context

Speed.

How Has This Been Tested?

I have not tested anything other than SPD training. I'm assuming the other call sites will be fine because they run the same computations.

Testing autocast on bs=128 ss_llama_simple_mlp-1L. In the results below, TF32 means using torch.set_float32_matmul_precision("high"), which lowers the precision of float32 matmuls. It gives roughly the same speedup, but unlike autocast_bf16 it does not reduce memory consumption.
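For reference, a rough sketch of the two mechanisms being compared (standard PyTorch calls; the function and argument names are placeholders):

import torch

# TF32: tensors stay float32, but matmuls may use TensorFloat-32 kernels.
torch.set_float32_matmul_precision("high")

def bf16_forward(model, batch):
    # autocast_bf16: forward-pass ops run in bfloat16, which also shrinks
    # activation memory.
    with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
        return model(batch)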

  • With TF32, with autocast: 1h37m

  • With TF32, without autocast: 1h35m, VRAM 23265 MB

  • Without TF32, with autocast: 1h36m, VRAM 20509 MB

  • Without TF32, without autocast: 2h6m

  • Run without TF32, without autocast

  • Run with TF32, without autocast

  • Run without TF32, with autocast

All three look about identical, and L0 is the same. Images below show the losses (dark pink = autocast, light pink = TF32, brown = nothing).

[Screenshot 2026-02-06 17:39: loss curves for the three runs]

Does this PR introduce a breaking change?

No. The new autocast_bf16 option defaults to true. Harvest, dataset attributions, and the app backend always use autocast (a no-op on CPU).

Previously TF32 was always enabled unconditionally. This adds a `use_tf32`
config field (default True) so it can be disabled when needed.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
danbraunai-goodfire marked this pull request as draft February 6, 2026 16:38
Same training performance as TF32 but uses less memory (20509MB vs
23265MB for bs=128 ss_llama_simple_mlp-1L).

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
danbraunai-goodfire changed the title from "Add use_tf32 config option for TF32 matmul precision" to "Use bfloat16 autocast for mixed precision training" Feb 6, 2026
More directly describes the mechanism (torch.autocast with bf16).

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
@ocg-goodfire
Collaborator

how much faster is it?

@danbraunai-goodfire
Collaborator Author

danbraunai-goodfire commented Feb 6, 2026

@ocg-goodfire

how much faster is it?

Depends on batch size. I think we'll get about 25% faster, assuming the batch size is big enough that GPU memory transfers aren't the bottleneck for training. I've made some notes about this in the description.

danbraunai-goodfire changed the title from "Use bfloat16 autocast for mixed precision training" to "Autocast to bf16 for up to 25% spd speedup" Feb 6, 2026
danbraunai-goodfire and others added 2 commits February 6, 2026 17:18
Wraps forward passes (not backward passes) with
torch.autocast(device_type="cuda", dtype=torch.bfloat16) in:
- spd/harvest/harvest.py: main inference loop
- spd/dataset_attributions/harvester.py: both forward passes
- spd/app/backend/compute.py: all forward pass sites
- spd/app/backend/optim_cis.py: forward passes and loss computation

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
danbraunai-goodfire marked this pull request as ready for review February 6, 2026 17:37
@danbraunai-goodfire
Collaborator Author

@claude can you review this PR? Note that your review should cover the scope of this PR only. If you spot things unrelated to this PR, feel free to bring them up and we'll consider them for new issues. I want you to think about possible ways that the implementation could be improved. I want critical comments, so feel comfortable giving me the hard truths.

@claude
Contributor

claude bot commented Feb 6, 2026

Claude finished @danbraunai-goodfire's task


PR Review: Autocast to bf16 for up to 25% spd speedup

Todo List:

  • Set up environment and understand repo structure
  • Analyze config changes in spd/configs.py
  • Review main SPD optimization loop changes in spd/run_spd.py
  • Review harvest module changes in spd/harvest/harvest.py
  • Review dataset attributions changes in spd/dataset_attributions/harvester.py
  • Review app backend changes in spd/app/backend/compute.py
  • Review app backend optimization changes in spd/app/backend/optim_cis.py
  • Provide comprehensive code review feedback

Review Summary

This PR applies bfloat16 autocast consistently across the codebase for performance. The usage patterns are appropriate. However, there are several areas for improvement:

Critical Issues

1. Inconsistent CPU Fallback Behavior

The autocast implementation doesn't handle CPU execution properly. On CPU, torch.autocast(device_type="cuda", dtype=torch.bfloat16) will raise an error since CUDA autocast isn't available on CPU devices.

Problem locations:

  • spd/run_spd.py:128-132 - Conditional context manager doesn't check device type
  • All other locations use hardcoded device_type="cuda"

Suggested fix:

# In run_spd.py
autocast_ctx = (
    torch.autocast(device_type=device.split(':')[0], dtype=torch.bfloat16)
    if config.autocast_bf16 and device.startswith('cuda')
    else contextlib.nullcontext()
)

# For other files, determine device dynamically:
device_type = "cuda" if torch.cuda.is_available() and next(model.parameters()).is_cuda else "cpu"
autocast_ctx = torch.autocast(device_type=device_type, dtype=torch.bfloat16)

2. Potential Mixed Precision Training Issues

The current implementation only applies autocast to forward passes, but there could be implications for gradient computation in run_spd.py.

Problem:

  • Line 278: microbatch_total_loss.div_(config.gradient_accumulation_steps).backward() happens outside the autocast context
  • This could lead to gradient scaling issues or inconsistent precision

Suggestion: Verify that gradient computation behavior is intended and document why backward passes are outside autocast.
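For context, a minimal sketch of the standard pattern (placeholder names; per the PyTorch AMP docs, autocast should wrap only the forward pass and loss computation, and backward ops then run in the same dtypes autocast chose for the matching forward ops):

import torch

def training_step(model, batch, compute_loss):
    # Forward pass and loss computation under autocast; backward() outside,
    # as in this PR.
    with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
        loss = compute_loss(model, batch)
    loss.backward()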

Code Quality Issues

3. Hardcoded Device Type String

Multiple files use hardcoded device_type="cuda" string literals:

  • spd/harvest/harvest.py:244
  • spd/dataset_attributions/harvester.py:152, 163
  • spd/app/backend/compute.py:143, 171, etc. (9 locations)
  • spd/app/backend/optim_cis.py:149, 212, etc. (5 locations)

Suggested improvement: Extract to a constant or utility function for maintainability:

# In spd/utils/device_utils.py
import contextlib

import torch

def get_autocast_context(enable: bool = True) -> contextlib.AbstractContextManager:
    device_type = "cuda" if torch.cuda.is_available() else "cpu"
    return (
        torch.autocast(device_type=device_type, dtype=torch.bfloat16)
        if enable and torch.cuda.is_available()
        else contextlib.nullcontext()
    )
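Call sites could then read (hypothetical usage; config, model, and batch are placeholders):

with get_autocast_context(enable=config.autocast_bf16):
    out = model(batch)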

4. Missing Error Handling

No graceful fallback if bfloat16 isn't supported on the hardware.

5. Config Documentation Gap

The config field autocast_bf16 in spd/configs.py:396-399 has minimal documentation. Should clarify:

  • Performance implications
  • Hardware requirements
  • Interaction with TF32 settings

Performance & Testing

6. Memory Usage Claims Need Verification

The PR description claims memory reduction (20509MB vs 23265MB), but this needs verification across different model sizes and batch sizes. Consider adding memory profiling to tests.

7. Missing Performance Regression Testing

No automated way to detect if autocast accidentally gets disabled or causes performance regressions.

Minor Issues

8. Code Duplication

The same torch.autocast(device_type="cuda", dtype=torch.bfloat16) pattern is repeated 14+ times across files. Consider extracting to a utility.

9. Inconsistent Context Manager Style

Some places use with torch.no_grad(), torch.autocast(...) while others separate the contexts. Consider standardizing on one approach for readability.
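For illustration, the two styles in question (model and batch are placeholders):

# Combined in one statement:
with torch.no_grad(), torch.autocast(device_type="cuda", dtype=torch.bfloat16):
    out = model(batch)

# Nested/separate:
with torch.no_grad():
    with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
        out = model(batch)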

Positive Aspects

  • Consistent application across all forward pass sites
  • Good default value (True) for the config option
  • Comprehensive coverage of harvest, attributions, and app modules
  • Performance testing shows significant improvements
  • Backward compatibility maintained with configurable option

Recommendations

  1. Fix CPU compatibility as highest priority - this will cause runtime errors
  2. Extract autocast logic to utility function for maintainability
  3. Add error handling for unsupported hardware
  4. Document gradient computation behavior in mixed precision context
  5. Consider adding performance regression tests for CI

The core implementation is solid and the performance gains are compelling, but the CPU compatibility issue needs to be addressed before merging.


@danbraunai-goodfire
Collaborator Author

Addressed Claude's comments about handling CPU.

Collaborator

@ocg-goodfire left a comment

nice

params = ci_params.get_parameters()
optimizer = optim.AdamW(params, lr=config.lr, weight_decay=config.weight_decay)

autocast_ctx = bf16_autocast()
Collaborator

why do it like this instead of inline?

Collaborator Author

I did this to avoid having to create the object every time in the for loop. But it's a negligible cost, and it's a bit cleaner inline, so I'll just put it inline.
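For illustration, the inline version would look roughly like this (the loop and variable names are made up):

for batch in dataloader:
    with bf16_autocast():
        out = model(batch)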

spd/run_spd.py Outdated
"""Run the optimization loop for LM decomposition."""

torch.set_float32_matmul_precision("high")
autocast_ctx = bf16_autocast(enabled=config.autocast_bf16)
Collaborator

same question here, why not inline?

Collaborator Author

see above

danbraunai-goodfire merged commit 3687c16 into main Feb 7, 2026
2 checks passed
