diff --git a/main.py b/main.py
index 833373d..d985ad8 100644
--- a/main.py
+++ b/main.py
@@ -55,34 +55,38 @@ def train(model, sae, ds, learning_rate, l0_coefficient):
     i = 0
     total_tokens = 0
     for input in ds["train"]:
-        input = input["text"]
-        tokens = tokenizer(input)["input_ids"]
-        total_tokens += len(tokens)
-        _, cache = model.run_with_cache(torch.tensor(tokens), remove_batch_dim=True)
-        x = cache[hook_point]
-
-        x_hat, h = sae(x)
-
-        reconstruction_loss = criterion(x_hat, x)
-        l0_loss = sae.expected_l0_loss(h)
-        loss = reconstruction_loss + l0_coefficient * l0_loss
-
-        optimizer.zero_grad()
-        loss.backward()
-        optimizer.step()
-
-        if i % 10 == 0:
-            wandb.log(
-                {
-                    "loss": loss.item(),
-                    "reconstruction_loss": reconstruction_loss.item(),
-                    "l0_loss": l0_loss.item(),
-                    "total_tokens": total_tokens,
-                }
-            )
-        i += 1
-        if total_tokens > training_tokens:
-            break
+        try:
+            input = input["text"]
+            tokens = tokenizer(input)["input_ids"]
+            total_tokens += len(tokens)
+            _, cache = model.run_with_cache(torch.tensor(tokens), remove_batch_dim=True)
+            x = cache[hook_point]
+
+            x_hat, h = sae(x)
+
+            reconstruction_loss = criterion(x_hat, x)
+            l0_loss = sae.expected_l0_loss(h)
+            loss = reconstruction_loss + l0_coefficient * l0_loss
+
+            optimizer.zero_grad()
+            loss.backward()
+            optimizer.step()
+
+            if i % 10 == 0:
+                wandb.log(
+                    {
+                        "loss": loss.item(),
+                        "reconstruction_loss": reconstruction_loss.item(),
+                        "l0_loss": l0_loss.item(),
+                        "total_tokens": total_tokens,
+                    }
+                )
+            i += 1
+            if total_tokens > training_tokens:
+                break
+        except BaseException as e:
+            print(e)
+            pass
 
 
 if __name__ == "__main__":