Skip to content

Commit

Permalink
Clean up indexing slightly in interp_max_utils
Browse files Browse the repository at this point in the history
Also remove unused imports and comment out old code in
proof_max2_01_exhaustive
  • Loading branch information
JasonGross authored and tkwa committed Sep 14, 2023
1 parent 092d2ac commit d02b030
Show file tree
Hide file tree
Showing 2 changed files with 175 additions and 199 deletions.
8 changes: 4 additions & 4 deletions training/interp_max_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,16 +30,16 @@ def logit_delta(model: HookedTransformer, renderer=None, histogram_all_incorrect
"""

all_tokens = compute_all_tokens(model=model)
predicted_logits = model(all_tokens)[:,-1].detach().cpu()
predicted_logits = model(all_tokens)[:,-1,:].detach().cpu()

# Extract statistics for each row
# Use values in all_tokens as indices to gather correct logits
indices_of_max = all_tokens.max(dim=1, keepdim=True).values
correct_logits = torch.gather(predicted_logits, 1, indices_of_max)
indices_of_max = all_tokens.max(dim=-1, keepdim=True).values
correct_logits = torch.gather(predicted_logits, -1, indices_of_max)
logits_above_correct = correct_logits - predicted_logits
# replace correct logit indices with large number so that they don't get picked up by the min
logits_above_correct[torch.arange(logits_above_correct.shape[0]), indices_of_max.squeeze()] = float('inf')
min_incorrect_logit = logits_above_correct.min(dim=1).values
min_incorrect_logit = logits_above_correct.min(dim=-1).values

if histogram_all_incorrect_logit_differences:
all_incorrect_logits = logits_above_correct[logits_above_correct != float('inf')]
Expand Down
Loading

0 comments on commit d02b030

Please sign in to comment.