Support training Transcoders #45

@danbraunai-apollo

Add a script to train transcoders.

It's probably going to be easier to first start with a single transcoder between two layers, then expand to multiple transcoders in a row. The first PR can handle just the single transcoder case, but the multiple transcoder case should be kept in mind when writing.

The structure could be something like:

  1. Load pretrained SAEs (i.e. the output of one or more run_train_tlens_saes.py runs) into an SAETransformer class.
  2. Create a new Transcoder/Skeleton/WhateverName class which is simply a sequence of linear layers (in the beginning, just one linear layer).
  3. The training loop will then look like:
    • Load raw tokens
    • Do a forward pass of the tlens model through the first SAE and get the c activations. These will be the inputs to the (first) transcoder.
    • Do a forward pass of the tlens model via the second SAE only, i.e. not the first SAE, and get the c activations. These will be the labels for the first transcoder. For this, we will need to tweak the SAETransformer.forward method so that it can pass through only specific SAEs rather than all of them.
    • loss.backward and repeat.
    • Note that we'll probably want many of the same metrics we use in run_train_tlens_saes.py, so we can maybe use that train loop as a template.
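
A minimal sketch of the single-transcoder case described above, in PyTorch. Everything named here (`Transcoder`, `train_step`, the dictionary sizes, the MSE loss) is an assumption for illustration and not existing code in the repo. The c activations from the two SAEs are passed in as plain tensors, since the tweak to SAETransformer.forward that would produce them is not written yet. The sketch also assumes the pretrained SAEs stay frozen, so both tensors are detached before use.

```python
import torch
from torch import nn


class Transcoder(nn.Module):
    """Hypothetical transcoder: one linear map from the first SAE's
    c activations to the second SAE's c activations."""

    def __init__(self, n_dict_in: int, n_dict_out: int):
        super().__init__()
        # A single linear layer for now; the multi-transcoder case could
        # hold several of these in an nn.ModuleList.
        self.linear = nn.Linear(n_dict_in, n_dict_out)

    def forward(self, c_in: torch.Tensor) -> torch.Tensor:
        return self.linear(c_in)


def train_step(
    transcoder: Transcoder,
    optimizer: torch.optim.Optimizer,
    c_in: torch.Tensor,
    c_target: torch.Tensor,
) -> float:
    """One optimisation step.

    c_in:     c activations from the forward pass through the first SAE.
    c_target: c activations from the forward pass through the second SAE only.
    """
    optimizer.zero_grad()
    # Assumption: the SAEs are frozen, so no gradients flow back into them.
    pred = transcoder(c_in.detach())
    # Assumption: plain MSE between predicted and actual c activations.
    loss = nn.functional.mse_loss(pred, c_target.detach())
    loss.backward()
    optimizer.step()
    return loss.item()
```

In a real script, `c_in` and `c_target` would come from two SAETransformer forward passes over the same batch of raw tokens, and `train_step` would sit inside a loop modelled on the one in run_train_tlens_saes.py.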
