Skip to content

Commit

Permalink
update to add ticks
Browse files Browse the repository at this point in the history
  • Loading branch information
drisspg committed Jul 16, 2024
1 parent 3a0a218 commit 9e2ec80
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 1 deletion.
12 changes: 11 additions & 1 deletion examples/flex_viz.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,17 @@ def checkerboard(score, batch, head, q_idx, kv_idx):
return score


SLIDING_WINDOW = 2


def sliding_window_causal(score, b, h, q_idx, kv_idx):
causal_mask = q_idx >= kv_idx
window_mask = q_idx - kv_idx <= SLIDING_WINDOW
return torch.where(causal_mask & window_mask, score, -float("inf"))


if __name__ == "__main__":
B, H, SEQ_LEN, HEAD_DIM = 2, 2, 16, 64
B, H, SEQ_LEN, HEAD_DIM = 2, 2, 6, 64
make_tensor = partial(torch.ones, B, H, SEQ_LEN, HEAD_DIM, device="cuda")
query, key, value = make_tensor(), make_tensor(), make_tensor()

Expand All @@ -29,3 +38,4 @@ def checkerboard(score, batch, head, q_idx, kv_idx):
name="relative_positional",
)
visualize_attention_scores(query, key, checkerboard, name="checkerboard")
visualize_attention_scores(query, key, sliding_window_causal, name="sliding_window_causal")
5 changes: 5 additions & 0 deletions transformer_nuggets/flex/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from typing import Union, Callable, Optional
import matplotlib.pyplot as plt
from pathlib import Path
import numpy as np
import math
from torch.nn.attention._flex_attention import (
_score_mod_signature,
Expand Down Expand Up @@ -117,6 +118,10 @@ def visualize_attention_scores(
ax.set_xticklabels([f"KV{i}" for i in range(num_kv_tokens)])
ax.set_yticks(range(num_query_tokens))
ax.set_yticklabels([f"Q{i}" for i in range(num_query_tokens)])
# Align grid with pixel boundaries
ax.set_xticks(np.arange(-0.5, num_kv_tokens, 1), minor=True)
ax.set_yticks(np.arange(-0.5, num_query_tokens, 1), minor=True)
ax.grid(which="minor", color="black", linestyle="-", linewidth=2)

plt.tight_layout()
plt.savefig(file_path, dpi=300, bbox_inches="tight")
Expand Down

0 comments on commit 9e2ec80

Please sign in to comment.