From 52f51f1102eb7b616337ad470ae37b2a9745a1cc Mon Sep 17 00:00:00 2001
From: drisspg
Date: Tue, 30 Jul 2024 19:29:08 -0700
Subject: [PATCH] some tweaks

---
 examples/flex_viz.py              | 26 ++++++++++++++++++++------
 transformer_nuggets/flex/utils.py | 15 ++++++++-------
 2 files changed, 28 insertions(+), 13 deletions(-)

diff --git a/examples/flex_viz.py b/examples/flex_viz.py
index 2a7e35c..d6ab42c 100644
--- a/examples/flex_viz.py
+++ b/examples/flex_viz.py
@@ -36,13 +36,15 @@ def sliding_window_causal(b, h, q_idx, kv_idx):
     return causal_mask & window_mask
 
 
-document_masks = torch.full((2, 6), 0, dtype=torch.int32, device="cuda")
-document_masks[:, 3:] = 1
+def create_doc_mask_func(document_id):
+    def doc_mask_wrapper(b, h, q_idx, kv_idx):
+        return document_id[q_idx] == document_id[kv_idx]
+
+    return doc_mask_wrapper
 
 
-def doc_mask(b, h, q, kv):
-    same_doc = document_masks[b, q] == document_masks[b, kv]
-    return same_doc
+def causal_mask(b, h, q_idx, kv_idx):
+    return q_idx >= kv_idx
 
 
 # Main execution
@@ -62,7 +64,19 @@ def main():
         query, key, mask_mod=sliding_window_causal, name="sliding_window_causal"
     )
 
-    visualize_attention_scores(query, key, mask_mod=doc_mask, name="document_masks")
+    # Example usage:
+    document_id = torch.tensor([0, 0, 0, 1, 1, 2, 2, 2, 2, 2, 2], dtype=torch.int32, device="cuda")
+
+    doc_mask = create_doc_mask_func(document_id)
+
+    visualize_attention_scores(
+        torch.ones(B, H, document_id.numel(), HEAD_DIM, device="cuda"),
+        torch.ones(B, H, document_id.numel(), HEAD_DIM, device="cuda"),
+        mask_mod=doc_mask,
+        name="document_mask",
+    )
+
+    visualize_attention_scores(query, key, mask_mod=causal_mask, name="causal_mask")
 
 
 if __name__ == "__main__":
diff --git a/transformer_nuggets/flex/utils.py b/transformer_nuggets/flex/utils.py
index 5f54148..1c5948d 100644
--- a/transformer_nuggets/flex/utils.py
+++ b/transformer_nuggets/flex/utils.py
@@ -111,18 +111,19 @@ def visualize_attention_scores(
         head_idx=head_idx,
     )
 
-    suffix_title = f"Batch {batch_idx}, Head {head_idx}"
+    suffix_title = f"Batch {batch_idx}, Head {head_idx}" if batch_idx != 0 or head_idx != 0 else ""
 
     fig, ax = plt.subplots(figsize=(12, 10))
-    im = ax.imshow(scores_viz.cpu().detach()[0, 0, :, :], aspect="auto", cmap="viridis")
+    color = "viridis" if score_mod is not None else "cividis"
+    im = ax.imshow(scores_viz.cpu().detach()[0, 0, :, :], aspect="auto", cmap=color)
     fig.colorbar(im)
 
     title = _name_to_title(name)
     file_path = Path(name).with_suffix(".png") if path is None else path.with_suffix(".png")
-    ax.set_title(f"{title}\n{suffix_title}")
+    ax.set_title(f"{title}\n{suffix_title}", fontsize=20)
 
-    ax.set_xlabel("Key Tokens")
-    ax.set_ylabel("Query Tokens")
+    ax.set_xlabel("Key Tokens", fontsize=18)
+    ax.set_ylabel("Query Tokens", fontsize=18)
 
     # Move y-axis ticks and labels to the top
     ax.tick_params(axis="x", top=True, labeltop=True, bottom=False, labelbottom=False)
@@ -131,9 +132,9 @@
     num_query_tokens, num_kv_tokens = scores_viz.shape[-2:]
     if num_query_tokens <= 32 and num_kv_tokens <= 32:
         ax.set_xticks(range(num_kv_tokens))
-        ax.set_xticklabels([f"KV{i}" for i in range(num_kv_tokens)])
+        ax.set_xticklabels([f"KV{i}" for i in range(num_kv_tokens)], fontsize=16)
         ax.set_yticks(range(num_query_tokens))
-        ax.set_yticklabels([f"Q{i}" for i in range(num_query_tokens)])
+        ax.set_yticklabels([f"Q{i}" for i in range(num_query_tokens)], fontsize=16)
     # Align grid with pixel boundaries
     ax.set_xticks(np.arange(-0.5, num_kv_tokens, 1), minor=True)
     ax.set_yticks(np.arange(-0.5, num_query_tokens, 1), minor=True)
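
Note: beyond the visualizer, a minimal sketch of how the new create_doc_mask_func could
drive FlexAttention itself, assuming a PyTorch build that ships
torch.nn.attention.flex_attention (available in nightlies around the time of this patch).
The import path, shapes, and document layout below are illustrative, not part of this
change; the sequence length is kept at a multiple of the default 128-token block size
since some builds require that.

import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

# Hypothetical import path; create_doc_mask_func is the helper added in this patch.
from examples.flex_viz import create_doc_mask_func

B, H, HEAD_DIM = 1, 1, 64

# 4 documents of 32 tokens each -> 128 tokens total.
document_id = torch.repeat_interleave(torch.arange(4, device="cuda"), 32).to(torch.int32)
S = document_id.numel()

doc_mask = create_doc_mask_func(document_id)

# Materialize the sparse block mask once; it can be reused across forward passes.
block_mask = create_block_mask(doc_mask, B=B, H=H, Q_LEN=S, KV_LEN=S, device="cuda")

q, k, v = (torch.randn(B, H, S, HEAD_DIM, device="cuda") for _ in range(3))

# Attention is restricted to query/key pairs that share a document id.
out = flex_attention(q, k, v, block_mask=block_mask)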