
Speedup with torch compile #322

Draft
danbraunai-goodfire wants to merge 1 commit into main from play/compile

Conversation


@danbraunai-goodfire (Collaborator) commented Jan 4, 2026

An extremely messy test of whether we can speed things up using torch.compile(). This requires getting rid of our hooks and monkey-patching the target model to insert the components directly (see the sketch at the end of this comment). Not sure it'll work with identity components. Anyway, the results are apparently:

| Batch | Seq Len | Mode            | Eager    | Compiled | Speedup       |
|-------|---------|-----------------|----------|----------|---------------|
| 16    | 128     | reduce-overhead | 10.56 ms | 9.64 ms  | 1.09x (9.5%)  |
| 32    | 256     | reduce-overhead | 12.61 ms | 11.20 ms | 1.13x (12.6%) |
| 64    | 256     | reduce-overhead | 20.70 ms | 15.96 ms | 1.30x (30%)   |
| 128   | 256     | reduce-overhead | 36.19 ms | 28.34 ms | 1.28x (28%)   |
| 64    | 256     | max-autotune    | 20.70 ms | 15.38 ms | 1.35x (35%)   |

These numbers are for a very stripped-down version of our computation that does just masked forward and backward passes through the full ss_llama_simple_mlp (4L); the timing methodology is sketched below.
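
A minimal sketch of how timings like these are typically gathered (CUDA events after warm-up); `run_step` and `batch` are hypothetical stand-ins, not the actual harness:

```python
import torch

def benchmark(fn, *args, warmup: int = 10, iters: int = 50) -> float:
    """Mean wall time per call in ms, measured with CUDA events after warm-up."""
    for _ in range(warmup):
        fn(*args)
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn(*args)
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters

# `run_step` is a hypothetical masked forward+backward step:
# eager_ms = benchmark(run_step, batch)
# compiled_ms = benchmark(torch.compile(run_step, mode="reduce-overhead"), batch)
```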
I think our batch size will be <64 in practice to avoid OOMs in our full training setup.
Probably going to leave this on the shelf: not keen on the drastic core-code change, and skeptical that we could get a >1.2x speedup for our workflows. Might be worth picking up when we move to bigger models, though.
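
For context, the shape of the monkey-patching change: instead of running components via forward hooks (which torch.compile has historically handled poorly), the target submodule is swapped for a wrapper that runs the component inline, so the whole forward pass is traceable as one graph. Not the actual diff, just a minimal sketch with hypothetical names (`ComponentWrapper`, `patch_model`):

```python
import torch
import torch.nn as nn

class ComponentWrapper(nn.Module):
    """Runs a component inline on a submodule's output, replacing a forward hook."""

    def __init__(self, target: nn.Module, component: nn.Module):
        super().__init__()
        self.target = target
        self.component = component

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.component(self.target(x))

def patch_model(model: nn.Module, module_name: str, component: nn.Module) -> None:
    # Monkey-patch: swap the named submodule for the wrapped version, keeping the
    # forward pass hook-free so torch.compile sees it without graph breaks.
    parent_name, _, child_name = module_name.rpartition(".")
    parent = model.get_submodule(parent_name) if parent_name else model
    setattr(parent, child_name, ComponentWrapper(getattr(parent, child_name), component))

# Hypothetical usage:
# patch_model(model, "layers.0.mlp", my_component)
# compiled_model = torch.compile(model, mode="reduce-overhead")
```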
