Add a script to train transcoders.
It's probably going to be easier to start with a single transcoder between two layers, then expand to multiple transcoders in a row. The first PR can handle just the single-transcoder case, but the multiple-transcoder case should be kept in mind when writing.
The structure could be something like:
- Load pretrained SAEs (i.e. the output of one or more `run_train_tlens_saes.py` runs) into an `SAETransformer` class.
- Create a new `Transcoder`/Skeleton/WhateverName class which is simply a sequence of linear layers (in the beginning, just one linear layer); see the sketch after this list.
- The training loop (also sketched after this list) will then look like:
  - Load raw tokens.
  - Do a forward pass of the tlens model through the first SAE and get the c activations. These will be the inputs to the (first) transcoder.
  - Do a forward pass of the tlens model via the second SAE only (i.e. not through the first SAE) and get the c activations. These will be the labels for the first transcoder. We will need to tweak the `SAETransformer.forward` method to handle passing through only specific SAEs, as opposed to all of them.
  - `loss.backward()` and repeat.
- Note that we'll probably want many of the same metrics we use in `run_train_tlens_saes.py`, so that train loop can serve as a template.
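
A minimal sketch of what the Transcoder class could look like, assuming a plain PyTorch `nn.Module`. The class name, constructor arguments, and the `hidden_dims` option are placeholders, not existing API in this repo:

```python
import torch
from torch import nn


class Transcoder(nn.Module):
    """Maps the c activations of one SAE to those of a later SAE.

    Placeholder sketch: starts as a single linear layer; `hidden_dims`
    allows growing it into a sequence of linear layers later.
    """

    def __init__(
        self, input_dim: int, output_dim: int, hidden_dims: list[int] | None = None
    ):
        super().__init__()
        dims = [input_dim, *(hidden_dims or []), output_dim]
        self.layers = nn.ModuleList(
            nn.Linear(d_in, d_out) for d_in, d_out in zip(dims[:-1], dims[1:])
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for layer in self.layers:
            x = layer(x)
        return x
```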
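
And a rough sketch of the training loop, assuming a plain MSE loss between predicted and target c activations. `get_sae_acts` is a hypothetical helper standing in for the tweaked `SAETransformer.forward` that runs only the specified SAE; `model`, `dataloader`, the SAE positions, and the dims are assumed to exist in the surrounding script:

```python
import torch
import torch.nn.functional as F

# Hypothetical names: dims and SAE positions would come from the loaded SAEs/config.
transcoder = Transcoder(input_dim=d_sae_1, output_dim=d_sae_2)
optimizer = torch.optim.Adam(transcoder.parameters(), lr=1e-4)

for tokens in dataloader:  # batches of raw tokens
    with torch.no_grad():
        # c activations from the first SAE: the transcoder's inputs.
        input_acts = get_sae_acts(model, tokens, sae_position=first_sae_pos)
        # c activations from the second SAE only (first SAE not applied): the labels.
        target_acts = get_sae_acts(model, tokens, sae_position=second_sae_pos)

    pred_acts = transcoder(input_acts)
    loss = F.mse_loss(pred_acts, target_acts)  # assumed reconstruction loss

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
```

Computing both activation sets under `no_grad` keeps gradients from flowing into the tlens model and SAEs, so only the transcoder's parameters are trained.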