Skip to content

Commit

Permalink
Fix
Browse files Browse the repository at this point in the history
  • Loading branch information
SimJeg committed Jan 13, 2025
1 parent 9edb610 commit 3ac3df2
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 101 deletions.
2 changes: 0 additions & 2 deletions kvpress/presses/key_rerotation_press.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,6 @@ def compress(
if self.press.compression_ratio == 0:
return keys, values

assert isinstance(self.press, ScorerPress)

# Compute scores from base press
scores = self.press.score(module, hidden_states, keys, values, attentions, kwargs)

Expand Down
109 changes: 10 additions & 99 deletions notebooks/new_press.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
"outputs": [],
"source": [
"from dataclasses import dataclass\n",
"from contextlib import contextmanager\n",
"\n",
"import torch\n",
"from torch import nn\n",
Expand All @@ -29,101 +30,11 @@
"execution_count": 2,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "e6758cc9db344df3840d72945b23f5d2",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"model.safetensors: 59%|#####8 | 1.81G/3.09G [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"You are attempting to use Flash Attention 2.0 with a model not initialized on GPU. Make sure to move the model to GPU after initializing it on CPU with `model.to('cuda')`.\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "0cfded89ef6a414ab7f64abd7852edac",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"generation_config.json: 0%| | 0.00/242 [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "092093357aa544718e7bafb54880d767",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"tokenizer_config.json: 0%| | 0.00/7.30k [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "1bd20e2a47424e38ad3096f69137f6bf",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"vocab.json: 0%| | 0.00/2.78M [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "3a96cb348c3e413885d7c58e8edb4c8f",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"merges.txt: 0%| | 0.00/1.67M [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "c7e4079761954c6c9a695bc98a9b513d",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"tokenizer.json: 0%| | 0.00/7.03M [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"You are attempting to use Flash Attention 2.0 with a model not initialized on GPU. Make sure to move the model to GPU after initializing it on CPU with `model.to('cuda')`.\n",
"Device set to use cuda:0\n"
]
}
Expand Down Expand Up @@ -336,14 +247,14 @@
"\n",
"Since 0.2.0, kvpress supports head-wise compression, where the KV cache of each head might be compressed by a different compression ratio. \n",
"\n",
"To achieve proper head-wise compression, one should implement a new kernel for attention along with a custom cache class. Instead, the current implementation fakes head-wise compression by updating the pruned keys by a fake key so that the output of the attention layer is not affected (see attention_patch.py). \n",
"To achieve proper head-wise compression, one should implement a new kernel for attention along with a custom cache class. Instead, the current implementation fakes head-wise compression by replacing the pruned keys with a fake key so that the output of the attention layer is not affected. This is implemented through `kvpress.attention_patch.patch_attention_functions`.\n",
"\n",
"To implement a method that compresses the KV cache head-wise, one should instantiate the `masked_key_indices` as outlined below."
]
},
{
"cell_type": "code",
"execution_count": 11,
"execution_count": 4,
"metadata": {},
"outputs": [
{
Expand All @@ -361,10 +272,10 @@
"Answer: The purpose of this step-by-step guide is to provide a comprehensive and easy-to-follow tutorial on how to create a new press in the KVPress platform. The guide is designed to help users understand the process of setting up a new press, including the\n",
"\n",
"compression_ratio: 0.25\n",
"Answer: The purpose of this guide is to teach you how to create a new press in kvpress.\n",
"Answer: The purpose of this guide is to provide a step-by-step process for creating a new press in KVPRESS, which is a popular open-source web server. The guide will cover the necessary steps to set up and configure a new press, including installing\n",
"\n",
"compression_ratio: 0.8\n",
"Answer: This guide is designed to help you understand the purpose of this guide.\n"
"compression_ratio: 0.9\n",
"Answer: This guide is not a guide. It is a guide. It is a guide. It is a guide. It is a guide. It is a guide. It is a guide. It is a guide. It is a guide. It is a\n"
]
}
],
Expand All @@ -373,16 +284,16 @@
"class RandomHeadPress(BasePress):\n",
"\n",
" compression_ratio: float = 0.0\n",
"\n",
" def compress(self, module, hidden_states, keys, values, attentions, kwargs):\n",
" assert keys.shape[0] == 1, \"Only batch size 1 is supported\"\n",
" scores = torch.rand(keys.shape[:-1], device=keys.device)\n",
" mask = scores < torch.quantile(scores, self.compression_ratio)\n",
" module.masked_key_indices = torch.nonzero(mask, as_tuple=True)\n",
"\n",
" \n",
" return keys, values\n",
"\n",
"\n",
"for compression_ratio in [0, 0.25, 0.8]:\n",
"for compression_ratio in [0, 0.25, 0.9]:\n",
" press = RandomHeadPress(compression_ratio)\n",
" print(f\"\\ncompression_ratio: {compression_ratio}\")\n",
" print(f\"Answer: {pipe(context, question=question, press=press)['answer']}\")"
Expand Down

0 comments on commit 3ac3df2

Please sign in to comment.