
Commit

some tweaks
drisspg committed Jul 31, 2024
1 parent 8a41dec commit 52f51f1
Showing 2 changed files with 28 additions and 13 deletions.
26 changes: 20 additions & 6 deletions examples/flex_viz.py
@@ -36,13 +36,15 @@ def sliding_window_causal(b, h, q_idx, kv_idx):
     return causal_mask & window_mask
 
 
-document_masks = torch.full((2, 6), 0, dtype=torch.int32, device="cuda")
-document_masks[:, 3:] = 1
+def create_doc_mask_func(document_id):
+    def doc_mask_wrapper(b, h, q_idx, kv_idx):
+        return document_id[q_idx] == document_id[kv_idx]
+
+    return doc_mask_wrapper
 
-def doc_mask(b, h, q, kv):
-    same_doc = document_masks[b, q] == document_masks[b, kv]
-    return same_doc
+def causal_mask(b, h, q_idx, kv_idx):
+    return q_idx >= kv_idx
 
 
 # Main execution
@@ -62,7 +64,19 @@ def main():
         query, key, mask_mod=sliding_window_causal, name="sliding_window_causal"
     )
 
-    visualize_attention_scores(query, key, mask_mod=doc_mask, name="document_masks")
+    # Example usage:
+    document_id = torch.tensor([0, 0, 0, 1, 1, 2, 2, 2, 2, 2, 2], dtype=torch.int32, device="cuda")
+
+    doc_mask = create_doc_mask_func(document_id)
+
+    visualize_attention_scores(
+        torch.ones(B, H, document_id.numel(), HEAD_DIM, device="cuda"),
+        torch.ones(B, H, document_id.numel(), HEAD_DIM, device="cuda"),
+        mask_mod=doc_mask,
+        name="document_mask",
+    )
+
+    visualize_attention_scores(query, key, mask_mod=causal_mask, name="causal_mask")
 
 
 if __name__ == "__main__":
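For reference, a minimal sketch (not part of the commit) of what the new factory produces. It assumes only create_doc_mask_func as defined above plus plain tensor indexing: broadcasting a column of query indices against a row of key/value indices materializes the full boolean mask, which is block-diagonal over documents.

import torch

# Hypothetical toy input: three documents of lengths 3, 2, and 1.
document_id = torch.tensor([0, 0, 0, 1, 1, 2], dtype=torch.int32)
doc_mask = create_doc_mask_func(document_id)

# b and h are unused by the wrapper, so None is fine here.
q_idx = torch.arange(document_id.numel()).view(-1, 1)   # query positions, as a column
kv_idx = torch.arange(document_id.numel()).view(1, -1)  # kv positions, as a row
print(doc_mask(None, None, q_idx, kv_idx))  # (6, 6) block-diagonal boolean matrix

Closing over document_id this way avoids the previous hard-coded global document_masks table, so the same helper works for any batch of document layouts.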
15 changes: 8 additions & 7 deletions transformer_nuggets/flex/utils.py
@@ -111,18 +111,19 @@ def visualize_attention_scores(
         head_idx=head_idx,
     )
 
-    suffix_title = f"Batch {batch_idx}, Head {head_idx}"
+    suffix_title = f"Batch {batch_idx}, Head {head_idx}" if batch_idx != 0 or head_idx != 0 else ""
 
     fig, ax = plt.subplots(figsize=(12, 10))
-    im = ax.imshow(scores_viz.cpu().detach()[0, 0, :, :], aspect="auto", cmap="viridis")
+    color = "viridis" if score_mod is not None else "cividis"
+    im = ax.imshow(scores_viz.cpu().detach()[0, 0, :, :], aspect="auto", cmap=color)
     fig.colorbar(im)
 
     title = _name_to_title(name)
     file_path = Path(name).with_suffix(".png") if path is None else path.with_suffix(".png")
-    ax.set_title(f"{title}\n{suffix_title}")
+    ax.set_title(f"{title}\n{suffix_title}", fontsize=20)
 
-    ax.set_xlabel("Key Tokens")
-    ax.set_ylabel("Query Tokens")
+    ax.set_xlabel("Key Tokens", fontsize=18)
+    ax.set_ylabel("Query Tokens", fontsize=18)
 
     # Move y-axis ticks and labels to the top
     ax.tick_params(axis="x", top=True, labeltop=True, bottom=False, labelbottom=False)
@@ -131,9 +132,9 @@ def visualize_attention_scores(
     num_query_tokens, num_kv_tokens = scores_viz.shape[-2:]
     if num_query_tokens <= 32 and num_kv_tokens <= 32:
         ax.set_xticks(range(num_kv_tokens))
-        ax.set_xticklabels([f"KV{i}" for i in range(num_kv_tokens)])
+        ax.set_xticklabels([f"KV{i}" for i in range(num_kv_tokens)], fontsize=16)
         ax.set_yticks(range(num_query_tokens))
-        ax.set_yticklabels([f"Q{i}" for i in range(num_query_tokens)])
+        ax.set_yticklabels([f"Q{i}" for i in range(num_query_tokens)], fontsize=16)
         # Align grid with pixel boundaries
         ax.set_xticks(np.arange(-0.5, num_kv_tokens, 1), minor=True)
         ax.set_yticks(np.arange(-0.5, num_query_tokens, 1), minor=True)
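A short usage sketch of the retuned plot defaults (an assumption based on the signature visible in this diff: mask_mod/score_mod keywords, with batch_idx and head_idx presumed to default to 0):

import torch
from transformer_nuggets.flex.utils import visualize_attention_scores

def causal_mask(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx

q = torch.ones(1, 1, 8, 16, device="cuda")
k = torch.ones(1, 1, 8, 16, device="cuda")

# With a mask_mod and no score_mod, the plot now uses the "cividis" colormap,
# and the "Batch 0, Head 0" subtitle is suppressed at the default indices.
visualize_attention_scores(q, k, mask_mod=causal_mask, name="causal_mask")

The cividis/viridis split makes pure boolean-mask plots visually distinct from plots that show actual modified scores.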
