Autocast to bf16 for up to 25% SPD speedup #359
Conversation
Previously, TF32 was always enabled. This adds a `use_tf32` config field (default True) so it can be disabled when needed. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Same training performance as TF32 but uses less memory (20509MB vs 23265MB for bs=128 ss_llama_simple_mlp-1L). Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
More directly describes the mechanism (torch.autocast with bf16). Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
How much faster is it?
Depends on batch size. I think we'll get about a 25% speedup, assuming the batch size is big enough that GPU memory transfers aren't the bottleneck for training. I've made some notes about this in the description.
Wraps forward passes (not backward passes) with `torch.autocast(device_type="cuda", dtype=torch.bfloat16)` in:

- spd/harvest/harvest.py: main inference loop
- spd/dataset_attributions/harvester.py: both forward passes
- spd/app/backend/compute.py: all forward pass sites
- spd/app/backend/optim_cis.py: forward passes and loss computation

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
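For illustration, a minimal sketch of the wrapping pattern this commit describes, using a toy model as a stand-in (none of these names are from the repository):

```python
import torch
import torch.nn as nn

# Toy stand-ins for the real model and batch.
model = nn.Linear(16, 4).cuda()
batch = torch.randn(8, 16, device="cuda")
targets = torch.randint(0, 4, (8,), device="cuda")

# Forward pass runs in bf16 under autocast; loss and backward stay outside the block.
with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
    logits = model(batch)
loss = nn.functional.cross_entropy(logits, targets)
loss.backward()
```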
@claude can you review this PR? Note that your review should cover the scope of this PR only. If you spot things unrelated to this PR, feel free to bring them up and we'll consider them for new issues. I want you to think about possible ways that the implementation could be improved. I want critical comments, do feel comfortable to give me the hard truths.
Claude finished @danbraunai-goodfire's task.

PR Review: Autocast to bf16 for up to 25% SPD speedup
Review Summary

This PR implements autocast to bfloat16 consistently across the codebase for performance improvements. The implementation shows good consistency and appropriate usage patterns. However, there are several areas for improvement.

Critical Issues

1. Inconsistent CPU Fallback Behavior

The autocast implementation doesn't handle CPU execution properly. On CPU, `torch.autocast(device_type="cuda", dtype=torch.bfloat16)` will not behave as intended.

Problem locations:
Suggested fix:

```python
# In run_spd.py
autocast_ctx = (
    torch.autocast(device_type=device.split(':')[0], dtype=torch.bfloat16)
    if config.autocast_bf16 and device.startswith('cuda')
    else contextlib.nullcontext()
)

# For other files, determine device dynamically:
device_type = "cuda" if torch.cuda.is_available() and next(model.parameters()).is_cuda else "cpu"
autocast_ctx = torch.autocast(device_type=device_type, dtype=torch.bfloat16)
```

2. Potential Mixed Precision Training Issues

The current implementation only applies autocast to forward passes, but there could be implications for gradient computation.

Problem:
Suggestion: Verify that the gradient computation behavior is intended and document why backward passes are outside autocast.

Code Quality Issues

3. Hardcoded Device Type String

Multiple files use a hardcoded "cuda" device-type string.
Suggested improvement: Extract to a constant or utility function for maintainability:

```python
# In spd/utils/device_utils.py
import contextlib

import torch


def get_autocast_context(enable: bool = True) -> contextlib.AbstractContextManager:
    device_type = "cuda" if torch.cuda.is_available() else "cpu"
    return (
        torch.autocast(device_type=device_type, dtype=torch.bfloat16)
        if enable and torch.cuda.is_available()
        else contextlib.nullcontext()
    )
```

4. Missing Error Handling

No graceful fallback if bfloat16 isn't supported on the hardware (see the sketch below).

5. Config Documentation Gap

The config field `autocast_bf16` is not documented.
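Building on the suggested helper above, a hedged sketch of the bfloat16-support guard that item 4 asks for (illustrative only, not the repository's code):

```python
import contextlib

import torch


def get_autocast_context(enable: bool = True) -> contextlib.AbstractContextManager:
    # No-op unless CUDA is available and the device supports bfloat16.
    if enable and torch.cuda.is_available() and torch.cuda.is_bf16_supported():
        return torch.autocast(device_type="cuda", dtype=torch.bfloat16)
    return contextlib.nullcontext()
```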
Performance & Testing

6. Memory Usage Claims Need Verification

The PR description claims memory reduction (20509MB vs 23265MB), but this needs verification across different model sizes and batch sizes. Consider adding memory profiling to tests.

7. Missing Performance Regression Testing

No automated way to detect if autocast accidentally gets disabled or causes performance regressions.

Minor Issues

8. Code Duplication

The same autocast wrapping pattern is repeated across several files.

9. Inconsistent Context Manager Style

Some places assign the autocast context manager to a variable while others construct it inline.

Positive Aspects
Recommendations
The core implementation is solid and the performance gains are compelling, but the CPU compatibility issue needs to be addressed before merging.
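As a hedged illustration of the memory profiling suggested in item 6 (a toy step, not existing test code):

```python
import torch
import torch.nn as nn

# Measure peak CUDA memory around one toy forward/backward step.
model = nn.Linear(1024, 1024).cuda()
x = torch.randn(256, 1024, device="cuda")

torch.cuda.reset_peak_memory_stats()
with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
    out = model(x)
out.float().sum().backward()
peak_mb = torch.cuda.max_memory_allocated() / 1024**2
print(f"peak allocated: {peak_mb:.0f}MB")
```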
Addressed the Claude comments about handling CPU.
spd/app/backend/optim_cis.py (outdated)

```python
params = ci_params.get_parameters()
optimizer = optim.AdamW(params, lr=config.lr, weight_decay=config.weight_decay)

autocast_ctx = bf16_autocast()
```
Why do it like this instead of inline?
I did this to avoid having to create the object every time in the for loop. But it is a negligible cost, and it's a bit cleaner with it inline, so I'll just put it inline.
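For context, a rough sketch of the two styles being discussed, shown with `torch.autocast` directly (the loop bodies are placeholders):

```python
import torch

# Style A: build the context manager once and reuse it each iteration.
autocast_ctx = torch.autocast(device_type="cuda", dtype=torch.bfloat16)
for _ in range(3):
    with autocast_ctx:
        pass  # forward pass here

# Style B: construct it inline each iteration; the construction cost is negligible.
for _ in range(3):
    with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
        pass  # forward pass here
```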
spd/run_spd.py (outdated)

```python
"""Run the optimization loop for LM decomposition."""

torch.set_float32_matmul_precision("high")
autocast_ctx = bf16_autocast(enabled=config.autocast_bf16)
```
same question here, why not inline?
see above
No functional change — the context manager is lightweight and creating it each iteration is equivalent. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Description
Add an `autocast_bf16` config option (default true) for SPD training, and add bf16 autocast to harvest, dataset attributions, and the app backend. All forward passes (not backward) are wrapped with `torch.autocast(device_type="cuda", dtype=torch.bfloat16)`.

Files with autocast:

- `spd/run_spd.py` — main SPD optimization loop (configurable via `autocast_bf16`)
- `spd/harvest/harvest.py` — inference loop
- `spd/dataset_attributions/harvester.py` — both forward passes
- `spd/app/backend/compute.py` — all forward pass sites (CI, attributions, interventions)
- `spd/app/backend/optim_cis.py` — forward passes and loss computation in CI optimization

Motivation and Context
Speed.
How Has This Been Tested?
I have not tested anything other than SPD training; I'm assuming the other call sites will be fine because they perform the same computations.
Testing autocast on bs=128 ss_llama_simple_mlp-1L. In the below, TF32 means using `torch.set_float32_matmul_precision("high")`, which reduces the precision of float32 matmuls. It does roughly the same thing as autocast, but doesn't reduce memory consumption the way `autocast_bf16` does.

- With TF32, with autocast: 1h37m
- With TF32, without autocast: 1h35m. VRAM 23265MB
- Without TF32, with autocast: 1h36m. VRAM 20509MB
- Without TF32, without autocast: 2h6m
Run without TF32 without autocast
Run with TF32 without autocast
Run without TF32 with autocast.
All look about identical. L0 is the same. Images below show the losses (dark pink = autocast, light pink = TF32, brown = neither).
Does this PR introduce a breaking change?
No. The new `autocast_bf16` option defaults to true. Harvest, dataset attributions, and the app backend always use autocast (no-op on CPU).
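For reference, a hedged sketch of what a bf16 autocast helper that no-ops on CPU could look like. It matches the `bf16_autocast(enabled=...)` signature seen in the diffs above, but is not necessarily the repository's implementation:

```python
import contextlib

import torch


def bf16_autocast(enabled: bool = True):
    """Return a bf16 autocast context on CUDA, or a no-op context otherwise."""
    if enabled and torch.cuda.is_available():
        return torch.autocast(device_type="cuda", dtype=torch.bfloat16)
    return contextlib.nullcontext()
```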