Fix errors from flake8
```
git ls-files "*.py" | xargs flake8 --count --select=E9,F63,F7,F82 --show-source --statistics
```
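For reference, `--select=E9,F63,F7,F82` narrows flake8 to hard errors rather than style nits: E9xx are syntax and indentation errors, F63x are invalid comparisons and assertions, F7xx are statement-level errors (e.g. `break` outside a loop), and F82x are undefined names. The edits below look like F821 ("undefined name") fixes: cells referencing names never defined in the script (`W_E`, `W_pos`, `d_model`, …) are commented out or rewritten to read the value from the model config.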
JasonGross authored and tkwa committed Sep 10, 2023
1 parent e44360c commit 4fa4d39
Showing 3 changed files with 14 additions and 14 deletions.
2 changes: 1 addition & 1 deletion training/Proving_How_A_Transformer_Takes_Max.py
```diff
@@ -323,7 +323,7 @@

# In[ ]:


all_integers_result = simpler_model(all_integers)
print(f"loss: {loss_fn(all_integers_result, all_integers)}")
print(f"acc: {acc_fn(all_integers_result, all_integers)}")

```
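For readers without the repo checked out, here is a minimal sketch of what `loss_fn`/`acc_fn` plausibly compute for this max-of-n task. This is an assumption for illustration (cross-entropy and argmax accuracy at the final position, against the sequence maximum), not the repository's actual definitions:

```python
# Hypothetical stand-ins for the repo's loss_fn/acc_fn (assumed, not taken
# from this commit): the model should output the max of its input sequence
# at the final position.
import torch
import torch.nn.functional as F

def loss_fn(logits: torch.Tensor, tokens: torch.Tensor) -> torch.Tensor:
    # logits: (batch, seq, d_vocab); tokens: (batch, seq)
    targets = tokens.max(dim=-1).values          # per-sequence maximum
    return F.cross_entropy(logits[:, -1, :], targets)

def acc_fn(logits: torch.Tensor, tokens: torch.Tensor) -> torch.Tensor:
    targets = tokens.max(dim=-1).values
    return (logits[:, -1, :].argmax(dim=-1) == targets).float().mean()
```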
24 changes: 12 additions & 12 deletions training/analyze_maxn.py
```diff
@@ -140,11 +140,11 @@ def find_d_score_coeff(model) -> float:
# plt.hist(points.flatten())
# %%
# 2d plot of x_scores
-plt.imshow(x_scores.detach().cpu().numpy())
+#plt.imshow(x_scores.detach().cpu().numpy())
# Set axis labels
-plt.title("Attention scores")
-plt.xlabel("Key token")
-plt.ylabel("Query token")
+#plt.title("Attention scores")
+#plt.xlabel("Key token")
+#plt.ylabel("Query token")

# %%
list(enumerate(model(torch.tensor([1, 1, 1, 18, 19]))[0, -1, :]))
@@ -181,14 +181,14 @@ def find_d_score_coeff(model) -> float:
plt.text(i,j,f'{label:.3f}',ha='center',va='center')
# %%

-last_resid = (W_E + W_pos[-1]) # (d_vocab, d_model). Rows = possible residual streams.
-key_tok_resid = (W_E + W_pos[0]) # (d_model, d_vocab). Rows = possible residual streams.
-q = last_resid @ W_Q[0, 0, :, :] # (d_vocab, d_model).
-k = key_tok_resid @ W_K[0, 0, :, :] # (d_vocab, d_model).
-x_scores = q @ k.T # (d_vocab, d_vocab).
+#last_resid = (W_E + W_pos[-1]) # (d_vocab, d_model). Rows = possible residual streams.
+#key_tok_resid = (W_E + W_pos[0]) # (d_model, d_vocab). Rows = possible residual streams.
+#q = last_resid @ W_Q[0, 0, :, :] # (d_vocab, d_model).
+#k = key_tok_resid @ W_K[0, 0, :, :] # (d_vocab, d_model).
+#x_scores = q @ k.T # (d_vocab, d_vocab).

-scores = x_scores.detach().cpu().numpy()
-print(f"{scores[25, 23]=}, {scores[25, 25]=}")
+#scores = x_scores.detach().cpu().numpy()
+#print(f"{scores[25, 23]=}, {scores[25, 25]=}")
# %%
# There's some kind of mismatch between cached scores and the attention influences
# calculated above.
@@ -199,7 +199,7 @@ def find_d_score_coeff(model) -> float:
k_cached = cache['k', 0].detach().cpu().numpy()[0, :, 0, :]
k_cached.shape # (n_ctx, d_model)

-scores_cached = q_cached @ k_cached.T / np.sqrt(d_model)
+scores_cached = q_cached @ k_cached.T / np.sqrt(model.cfg.d_model)
# %%
plt.imshow(scores_cached[-1:, :])
for (j, i), label in np.ndenumerate(scores_cached[-1:, :]):
```
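The hunks above do two things: comment out the weight-space score computation (whose inputs `W_E`, `W_pos`, `W_Q`, `W_K` were presumably flagged as undefined names) and make the cached-score scaling read `d_model` from the model config instead of an undefined local. For reference, a self-contained sketch of that weight-space computation, assuming a single-head TransformerLens `HookedTransformer` named `model` as elsewhere in the file:

```python
import numpy as np

def weight_space_scores(model):
    """Recompute attention scores from the QK circuit alone (a sketch;
    biases b_Q/b_K are ignored for brevity)."""
    W_E, W_pos = model.W_E, model.W_pos   # token and positional embeddings
    last_resid = W_E + W_pos[-1]          # (d_vocab, d_model): query-side residuals
    key_tok_resid = W_E + W_pos[0]        # (d_vocab, d_model): key-side residuals
    q = last_resid @ model.W_Q[0, 0]      # (d_vocab, d_head)
    k = key_tok_resid @ model.W_K[0, 0]   # (d_vocab, d_head)
    # Same 1/sqrt(d_model) scaling the commit applies to the cached scores,
    # so the two matrices are comparable on the same footing.
    return (q @ k.T / np.sqrt(model.cfg.d_model)).detach().cpu().numpy()
```

A result of `weight_space_scores(model)` could then be compared entry-wise against `scores_cached` above when chasing down the mismatch noted in the comments.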
2 changes: 1 addition & 1 deletion training/undertrained_max2.py
```diff
@@ -322,7 +322,7 @@

# In[ ]:


all_integers_result = undertrain_simpler_model(all_integers)
print(f"loss: {loss_fn(all_integers_result, all_integers)}")
print(f"acc: {acc_fn(all_integers_result, all_integers)}")

```
