modify train_max_of_n to remove all biases
tkwa committed Sep 13, 2023

1 parent b1a5bf5 commit 0b1c487
Showing 1 changed file: training/train_max_of_n.py (81 additions, 125 deletions)
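
The substantive change in this diff is a loop that freezes every bias parameter right after the model is constructed; TransformerLens names bias parameters with a b_ prefix (b_Q, b_K, b_V, b_O, b_U, ...), and they start at zero under its default initialization, so freezing them trains an effectively bias-free model. A minimal standalone sketch of the pattern (the attn_only and normalization_type settings here are assumptions chosen to mirror D_MLP = None in the diff, not values copied from the unchanged parts of the file):

import torch
from transformer_lens import HookedTransformer, HookedTransformerConfig

cfg = HookedTransformerConfig(
    n_layers=1, n_heads=1, d_model=32, d_head=32,
    n_ctx=5, d_vocab=64, seed=123, device="cpu",
    attn_only=True,           # assumption: D_MLP = None in the diff suggests no MLP block
    normalization_type=None,  # assumption: avoids LayerNorm's separately-named bias "b"
)
model = HookedTransformer(cfg)

# Freeze every bias; they are zero-initialized, so the trained model
# has no effective bias terms while keeping the parameter layout intact.
for name, param in model.named_parameters():
    if "b_" in name:
        param.requires_grad = False

assert all(not p.requires_grad for n, p in model.named_parameters() if "b_" in n)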
@@ -1,57 +1,35 @@
# %%
import sys
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
import plotly.express as px
import plotly.io as pio
import transformer_lens
from transformer_lens import HookedTransformer, HookedTransformerConfig
import tqdm.auto as tqdm
import circuitsvis as cv
from fancy_einsum import einsum
import dataclasses
from pathlib import Path
import wandb
import datetime

from coq_export_utils import strify
# from analysis_utils import line, summarize, plot_QK_cosine_similarity, \
# analyze_svd, calculate_OV_of_pos_embed, calculate_attn, calculate_attn_by_pos, \
# calculate_copying, calculate_copying_with_pos, calculate_embed_and_pos_embed_overlap, \
# calculate_embed_overlap, calculate_pos_embed_overlap, check_monotonicity, \
# compute_slack, plot_avg_qk_heatmap, plot_qk_heatmap, plot_qk_heatmaps_normed, plot_unembed_cosine_similarity
from training_utils import train_or_load_model, make_testset_trainset, make_generator_from_data
from max_of_n import acc_fn, loss_fn, large_data_gen
from training_utils import make_testset_trainset, make_generator_from_data, DEFAULT_WANDB_ENTITY
from coq_export_utils import coq_export_params
from max_of_n import acc_fn, loss_fn, train_model, large_data_gen
from training_utils import compute_all_tokens, make_testset_trainset, make_generator_from_data

import os, sys
from importlib import reload

from scipy.optimize import curve_fit



# %%

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
N_LAYERS = 1
N_HEADS = 1
D_MODEL = 32
D_HEAD = 32
D_MLP = None
D_VOCAB = 64
SEED = 123
N_EPOCHS = 100
DETERMINISTIC = False # @param
DEVICE = "cuda" if torch.cuda.is_available() and not DETERMINISTIC else "cpu"
N_LAYERS = 1 # @param
N_HEADS = 1 # @param
D_MODEL = 32 # @param
D_HEAD = 32 # @param
D_MLP = None # @param
D_VOCAB = 64 # @param
SEED = 123 # @param
N_EPOCHS = 50000 # @param
N_CTX = 5 # @param
ADJACENT_FRACTION=0.3 # @param
BATCH_SIZE = 128 # @param
FAIL_IF_CANT_LOAD = '--fail-if-cant-load' in sys.argv[1:] # @param

ALWAYS_TRAIN_MODEL = False # @param
SAVE_IN_GOOGLE_DRIVE = False # @param
OVERWRITE_DATA = False # @param
TRAIN_MODEL_IF_CANT_LOAD = True # @param

ALWAYS_TRAIN_MODEL = False
IN_COLAB = False
SAVE_IN_GOOGLE_DRIVE = False
OVERWRITE_DATA = True
TRAIN_MODEL_IF_CANT_LOAD = True


# %%

# %%
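
Note the reworked DEVICE line above: when DETERMINISTIC is set, the script pins itself to CPU, since some CUDA kernels have no deterministic implementation; the flag is also forwarded as deterministic=DETERMINISTIC to train_or_load_model further down. A sketch of the standard PyTorch setup such a flag typically gates (the actual handling lives in training_utils and is not shown in this diff):

import torch

torch.manual_seed(123)                    # SEED above
torch.use_deterministic_algorithms(True)  # raise on ops with no deterministic kernel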

@@ -60,7 +38,7 @@
n_layers=N_LAYERS,
n_heads=N_HEADS,
d_head=D_HEAD,
n_ctx=5,
n_ctx=N_CTX,
d_vocab=D_VOCAB,
seed=SEED,
device=DEVICE,
@@ -71,87 +49,65 @@

model = HookedTransformer(simpler_cfg).to(DEVICE)

for name, param in model.named_parameters():
    if "b_" in name:
        param.requires_grad = False
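# Note: "b_" matches every TransformerLens bias parameter (b_Q, b_K, b_V, b_O,
# b_U, ...); biases are zero-initialized, so freezing them here is what the
# commit title means by removing all biases.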

model_is_trained = False


# %%

def train(fail_if_cant_load=FAIL_IF_CANT_LOAD, train_if_cant_load=TRAIN_MODEL_IF_CANT_LOAD, overwrite_data=OVERWRITE_DATA,
          always_train_model=ALWAYS_TRAIN_MODEL,
          wandb_entity=DEFAULT_WANDB_ENTITY,
          save_in_google_drive=SAVE_IN_GOOGLE_DRIVE):

    global model_is_trained
    train_data_gen = large_data_gen(n_digits=model.cfg.d_vocab, sequence_length=model.cfg.n_ctx, batch_size=BATCH_SIZE, context="train", device=DEVICE, adjacent_fraction=ADJACENT_FRACTION)
    test_data_gen = large_data_gen(n_digits=model.cfg.d_vocab, sequence_length=model.cfg.n_ctx, batch_size=BATCH_SIZE * 20, context="test", adjacent_fraction=ADJACENT_FRACTION)
    data_test = next(test_data_gen)
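    # data_test is a single fixed batch, 20x the training batch size; the
    # context="test" argument presumably selects a separate evaluation stream.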

    training_losses, model_pth_path = train_or_load_model(
        f'neural-net-coq-interp-max-{model.cfg.n_ctx}-epochs-{N_EPOCHS}',
        model,
        loss_fn=loss_fn,
        acc_fn=acc_fn,
        train_data_gen_maybe_lambda=train_data_gen,
        train_data_gen_is_lambda=False,
        data_test=data_test,
        n_epochs=N_EPOCHS,
        batch_size=BATCH_SIZE,
        adjacent_fraction=1,
        use_complete_data=True,
        batches_per_epoch=10,
        wandb_project=f'neural-net-coq-interp-max-{model.cfg.n_ctx}-epochs-{N_EPOCHS}',
        deterministic=DETERMINISTIC,
        save_in_google_drive=save_in_google_drive,
        overwrite_data=overwrite_data,
        train_model_if_cant_load=train_if_cant_load,
        model_description=f"trained max of {model.cfg.n_ctx} model on {DEVICE}",
        save_model=True,
        force_train=always_train_model,
        wandb_entity=wandb_entity,
        fail_if_cant_load=fail_if_cant_load,
    )

    model_is_trained = True
    return training_losses, model_pth_path

# %%

# test large_data_gen
gen = large_data_gen(n_digits=10, sequence_length=5, batch_size=128, context="train", device=DEVICE, adjacent_fraction=0.5)
gen.__next__()
def get_model(train_if_necessary=False, **kwargs):
    train(fail_if_cant_load=not train_if_necessary, train_if_cant_load=train_if_necessary, **kwargs)
    return model
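
A hypothetical downstream use of the new helper (the import path is an assumption; adjust to wherever this file lives on sys.path):

from train_max_of_n import get_model

model = get_model(train_if_necessary=True)  # load the cached model, training only if no checkpoint loads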


# %%
if __name__ == '__main__':
    training_losses, model_pth_path = train()
    print(coq_export_params(model))

# where we save the model
if IN_COLAB:
    # if SAVE_IN_GOOGLE_DRIVE:
    #     from google.colab import drive
    #     drive.mount('/content/drive/')
    #     PTH_BASE_PATH = Path('/content/drive/MyDrive/Colab Notebooks/')
    # else:
    #     PTH_BASE_PATH = Path("/workspace/_scratch/")
    pass
else:
    PTH_BASE_PATH = Path(os.getcwd())

PTH_BASE_PATH = PTH_BASE_PATH / 'trained-models'

if not os.path.exists(PTH_BASE_PATH):
    os.makedirs(PTH_BASE_PATH)

cfg_dict = simpler_cfg.to_dict()
cfg_str = '-'.join([f'{k}={cfg_dict[k]}' for k in 'n_ctx d_model n_layers n_heads d_head d_vocab'.split(' ')])
datetime_str = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")

MODEL_PTH_PATH = PTH_BASE_PATH / f'max-of-n-{cfg_str}-{datetime_str}.pth'


TRAIN_MODEL = ALWAYS_TRAIN_MODEL
if not ALWAYS_TRAIN_MODEL:
    try:
        cached_data = torch.load(MODEL_PTH_PATH)
        model.load_state_dict(cached_data['model'])
        # model_checkpoints = cached_data["checkpoints"]
        # checkpoint_epochs = cached_data["checkpoint_epochs"]
        # test_losses = cached_data['test_losses']
        simpler_train_losses = cached_data['train_losses']
        # train_indices = cached_data["train_indices"]
        # test_indices = cached_data["test_indices"]
    except Exception as e:
        print(e)
        TRAIN_MODEL = TRAIN_MODEL_IF_CANT_LOAD


# In[ ]:


if TRAIN_MODEL:
    wandb.init(project=f'neural-net-coq-interp-max-{model.cfg.n_ctx}')

    simpler_train_losses = train_model(model, n_epochs=100, batch_size=256, batches_per_epoch=10,
                                       adjacent_fraction=0.3, use_complete_data=False, device=DEVICE, use_wandb=True)

    wandb.finish()


# In[ ]:


if TRAIN_MODEL:
    data = {
        "model": model.state_dict(),
        "config": model.cfg,
        "train_losses": simpler_train_losses,
    }
    if OVERWRITE_DATA or not os.path.exists(MODEL_PTH_PATH):
        torch.save(
            data,
            MODEL_PTH_PATH)
    else:
        print(f'WARNING: Not overwriting {MODEL_PTH_PATH} because it already exists.')
        ext = 0
        while os.path.exists(f"{MODEL_PTH_PATH}.{ext}"):
            ext += 1
        torch.save(
            data,
            f"{MODEL_PTH_PATH}.{ext}")
        print(f'WARNING: Wrote to {MODEL_PTH_PATH}.{ext} instead.')

# %%
# %%
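
Usage note: the only command-line flag the new script reads (via sys.argv at the top of the diff) is --fail-if-cant-load, e.g.

python training/train_max_of_n.py --fail-if-cant-load

which makes train_or_load_model error out instead of retraining when no saved checkpoint can be loaded.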
