Mask the full weight for shared_across_batch#323

Draft
danbraunai-goodfire wants to merge 1 commit into main from feature/pgd-weight-mask

Conversation

@danbraunai-goodfire
Collaborator

@danbraunai-goodfire danbraunai-goodfire commented Jan 4, 2026

Description

A messy test of a speedup that takes advantage of the fact that PGD shared_across_batch masks have no values for the batch and sequence positions. This lets us mask V directly instead of computing x @ V and then masking the activations. Results on ss_llama_simple_mlp (4L) are below:

  • 1 PGD step, with weight mask: 16:10; without weight mask: 18:55.
  • 4 PGD steps, with weight mask: 24:00; without weight mask: 30:00.
  • It also very slightly reduces VRAM consumption.

So this is a reasonable speedup. The code is quite gross, but there are probably ways to improve it so that it doesn't add much bloat to main.
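The equivalence the speedup relies on can be sketched as follows (shapes and variable names here are illustrative, not the repo's actual code): because a shared_across_batch mask has no batch or sequence axes, it broadcasts over only the output-feature dimension, so scaling the columns of V once is the same as scaling the activations after every projection.

```python
import numpy as np

rng = np.random.default_rng(0)
batch, seq, d_in, d_out = 2, 3, 4, 5

x = rng.normal(size=(batch, seq, d_in))      # input activations
V = rng.normal(size=(d_in, d_out))           # weight to be masked
# shared_across_batch: one mask value per output dim, no batch/seq axes
mask = rng.uniform(size=(d_out,))

acts_masked = (x @ V) * mask     # baseline: project, then mask the acts
weight_masked = x @ (V * mask)   # speedup: mask the weight once, then project

# Identical up to floating-point error
assert np.allclose(acts_masked, weight_masked)
```

The weight-masked form does the elementwise multiply once on a (d_in, d_out) matrix instead of on every (batch, seq, d_out) activation tensor, which is where the time and memory savings would come from. This identity breaks if the mask ever gains batch or sequence dimensions, which is presumably why it is restricted to shared_across_batch.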

Related Issue

Motivation and Context

How Has This Been Tested?

Does this PR introduce a breaking change?

