Pass on exceptions.
```
root@ceb53da4d023:~/sae_expected_l0# python3 main.py
/usr/local/lib/python3.10/dist-packages/transformers/tokenization_utils_base.py:1601: FutureWarning: `clean_up_tokenization_spaces` was not set. It will be set to `True` by default. This behavior will be depracted in transformers v4.45, and will be then set to `False` by default. For more details check this issue: huggingface/transformers#31884
  warnings.warn(
/usr/local/lib/python3.10/dist-packages/transformers/tokenization_utils_base.py:1601: FutureWarning: `clean_up_tokenization_spaces` was not set. It will be set to `True` by default. This behavior will be depracted in transformers v4.45, and will be then set to `False` by default. For more details check this issue: huggingface/transformers#31884
  warnings.warn(
Loaded pretrained model roneneldan/TinyStories-1M into HookedTransformer
Repo card metadata block was not found. Setting CardData to empty.
sigma:   0%|          | 0/5 [00:00<?, ?it/s]
wandb: Currently logged in as: joelb. Use `wandb login --relogin` to force relogin
wandb: Tracking run with wandb version 0.17.6
wandb: Run data is saved locally in /root/sae_expected_l0/wandb/run-20240811_194131-exu7ftic
wandb: Run `wandb offline` to turn off syncing.
wandb: Syncing run unique-snow-4
wandb: ⭐️ View project at https://wandb.ai/joelb/sae_expected_l0
wandb: 🚀 View run at https://wandb.ai/joelb/sae_expected_l0/runs/exu7ftic
sigma:   0%|          | 0/5 [2:00:45<?, ?it/s]
Traceback (most recent call last):
  File "/root/sae_expected_l0/main.py", line 113, in <module>
    train(model, sae, ds, learning_rate, l0_coefficient)
  File "/root/sae_expected_l0/main.py", line 61, in train
    _, cache = model.run_with_cache(torch.tensor(tokens), remove_batch_dim=True)
  File "/usr/local/lib/python3.10/dist-packages/transformer_lens/HookedTransformer.py", line 631, in run_with_cache
    out, cache_dict = super().run_with_cache(
  File "/usr/local/lib/python3.10/dist-packages/transformer_lens/hook_points.py", line 566, in run_with_cache
    model_out = self(*model_args, **model_kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/transformer_lens/HookedTransformer.py", line 522, in forward
    ) = self.input_to_embed(
  File "/usr/local/lib/python3.10/dist-packages/transformer_lens/HookedTransformer.py", line 330, in input_to_embed
    embed = self.hook_embed(self.embed(tokens))  # [batch, pos, d_model]
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/transformer_lens/components/embed.py", line 34, in forward
    return self.W_E[tokens, :]
IndexError: tensors used as indices must be long, int, byte or bool tensors
wandb: | 0.011 MB of 0.011 MB uploaded
wandb: Run history:
wandb:             l0_loss ▃▅▃▂▃▂▂▂█▂▁▁▂▃▂▃▁▂▂▂▁▂▂▂▃▂▁▃▁▂▁▂▁▃▁▂▁▂▂▁
wandb:                loss ▃▅▃▂▃▂▂▂█▂▁▁▂▃▂▃▁▂▂▂▁▂▂▂▃▂▁▃▁▂▁▂▁▃▁▂▁▂▂▁
wandb: reconstruction_loss █▃▂▂▂▂▂▂▃▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
wandb:        total_tokens ▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███
wandb:
wandb: Run summary:
wandb:             l0_loss 29664.35352
wandb:                loss 29.66452
wandb: reconstruction_loss 0.00016
wandb:        total_tokens 7541274
wandb:
wandb: 🚀 View run unique-snow-4 at: https://wandb.ai/joelb/sae_expected_l0/runs/exu7ftic
wandb: ⭐️ View project at: https://wandb.ai/joelb/sae_expected_l0
wandb: Synced 5 W&B file(s), 0 media file(s), 0 artifact file(s) and 0 other file(s)
wandb: Find logs at: ./wandb/run-20240811_194131-exu7ftic/logs
wandb: WARNING The new W&B backend becomes opt-out in version 0.18.0; try it out with `wandb.require("core")`! See https://wandb.me/wandb-core for more information.
```
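The crash happens in TransformerLens's embedding lookup, `self.W_E[tokens, :]`, reached from `torch.tensor(tokens)` in `train`. One plausible trigger, an assumption rather than something the log confirms, is a document that tokenizes to an empty list: `torch.tensor([])` defaults to `float32`, and a float tensor cannot be used as an index. A minimal standalone sketch of that failure mode (the `W_E` stand-in and its shape are made up for illustration):

```python
import torch

# Hypothetical stand-in for the model's [d_vocab, d_model] embedding matrix
# (sizes are arbitrary, not TinyStories-1M's actual dimensions).
W_E = torch.randn(1000, 64)

tokens = torch.tensor([])        # an empty token list defaults to float32
print(tokens.dtype)              # torch.float32

try:
    W_E[tokens, :]               # float indices are rejected
except IndexError as e:
    print(e)                     # same message as in the traceback above

# Forcing an integer dtype avoids the dtype error; an empty sequence may still
# need to be skipped before the forward pass.
tokens = torch.tensor([], dtype=torch.long)
print(W_E[tokens, :].shape)      # torch.Size([0, 64])
```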
joelburget committed Aug 11, 2024
1 parent f0553f9 commit 39208d5
Showing 1 changed file with 32 additions and 28 deletions.
main.py: 32 additions & 28 deletions
```diff
@@ -55,34 +55,38 @@ def train(model, sae, ds, learning_rate, l0_coefficient):
     i = 0
     total_tokens = 0
     for input in ds["train"]:
-        input = input["text"]
-        tokens = tokenizer(input)["input_ids"]
-        total_tokens += len(tokens)
-        _, cache = model.run_with_cache(torch.tensor(tokens), remove_batch_dim=True)
-        x = cache[hook_point]
-
-        x_hat, h = sae(x)
-
-        reconstruction_loss = criterion(x_hat, x)
-        l0_loss = sae.expected_l0_loss(h)
-        loss = reconstruction_loss + l0_coefficient * l0_loss
-
-        optimizer.zero_grad()
-        loss.backward()
-        optimizer.step()
-
-        if i % 10 == 0:
-            wandb.log(
-                {
-                    "loss": loss.item(),
-                    "reconstruction_loss": reconstruction_loss.item(),
-                    "l0_loss": l0_loss.item(),
-                    "total_tokens": total_tokens,
-                }
-            )
-        i += 1
-        if total_tokens > training_tokens:
-            break
+        try:
+            input = input["text"]
+            tokens = tokenizer(input)["input_ids"]
+            total_tokens += len(tokens)
+            _, cache = model.run_with_cache(torch.tensor(tokens), remove_batch_dim=True)
+            x = cache[hook_point]
+
+            x_hat, h = sae(x)
+
+            reconstruction_loss = criterion(x_hat, x)
+            l0_loss = sae.expected_l0_loss(h)
+            loss = reconstruction_loss + l0_coefficient * l0_loss
+
+            optimizer.zero_grad()
+            loss.backward()
+            optimizer.step()
+
+            if i % 10 == 0:
+                wandb.log(
+                    {
+                        "loss": loss.item(),
+                        "reconstruction_loss": reconstruction_loss.item(),
+                        "l0_loss": l0_loss.item(),
+                        "total_tokens": total_tokens,
+                    }
+                )
+            i += 1
+            if total_tokens > training_tokens:
+                break
+        except BaseException as e:
+            print(e)
+            pass
 
 
 if __name__ == "__main__":
```
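The fix wraps the whole training step in `try`/`except BaseException` and prints the exception. As a design note, `BaseException` also catches `KeyboardInterrupt` and `SystemExit`, so `except Exception` is the narrower, more conventional choice when the goal is only to skip bad examples. A self-contained sketch of that skip-and-continue pattern, with hypothetical names (not the repository's code):

```python
from typing import Callable, Iterable

def train_on(examples: Iterable[dict], step: Callable[[dict], None]) -> int:
    """Apply `step` to each example, skipping (and counting) examples that raise."""
    skipped = 0
    for example in examples:
        try:
            step(example)
        except Exception as exc:  # narrower than BaseException: Ctrl-C still stops the run
            skipped += 1
            print(f"skipping example: {exc!r}")
    return skipped

if __name__ == "__main__":
    def step(example: dict) -> None:
        if not example["text"]:
            raise ValueError("empty example")

    n = train_on([{"text": "once upon a time"}, {"text": ""}], step)
    print(f"skipped {n} example(s)")  # skipped 1 example(s)
```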
