Commit 59a3296

Update cache_utils.py
1 parent 9ff0cb6 commit 59a3296

1 file changed

src/transformers/cache_utils.py

Lines changed: 41 additions & 12 deletions
@@ -1236,7 +1236,44 @@ def __init__(self, max_cache_len: int, config: PretrainedConfig):
         self.offloading = True
 
 
-class QuantoQuantizedCache(Cache):
+class QuantizedCache(Cache):
+    """
+    A quantizer cache similar to what is described in the [KIVI: A Tuning-Free Asymmetric 2bit Quantization for KV Cache paper](https://arxiv.org/abs/2402.02750).
+    It allows the model to generate longer sequence length without allocating too much memory for Key and Value cache by applying quantization.
+    The cache has two types of storage, one for original precision and one for the quantized cache. A `residual_length` is set as a maximum capacity for the
+    original precision cache. When the length goes beyond maximum capacity, the original precision cache is discarded and moved into the quantized cache. The
+    quantization is done per-channel with a set `q_group_size` for both Keys and Values, in contrast to what was described in the paper.
+    It stores Keys and Values as a list of quantized tensors (tuples in case we need to store metadata), one for each layer. Additionally, it stores the Key and
+    Value in original precision states as a list of tensors, one for each layer. The size of each tensor
+    is `[batch_size, num_heads, seq_len - residual_length, head_dim]`.
+    See `Cache` for details on common methods that are implemented by all cache classes.
+    """
+
+    def __init__(
+        self,
+        backend: str,
+        config: PretrainedConfig,
+        nbits: int = 4,
+        axis_key: int = 0,
+        axis_value: int = 0,
+        q_group_size: int = 64,
+        residual_length: int = 128,
+    ):
+        if backend == "quanto":
+            layer_class = QuantoQuantizedLayer
+        elif backend == "hqq":
+            layer_class = HQQQuantizedLayer
+        else:
+            raise ValueError(f"Unknown quantization backend `{backend}`")
+
+        layers = [
+            layer_class(nbits, axis_key, axis_value, q_group_size, residual_length)
+            for _ in range(config.num_hidden_layers)
+        ]
+        super().__init__(layers=layers)
+
+
+class QuantoQuantizedCache(QuantizedCache):
     """
     A quantizer cache similar to what is described in the [KIVI: A Tuning-Free Asymmetric 2bit Quantization for KV Cache paper](https://huggingface.co/papers/2402.02750).
     It allows the model to generate longer sequence length without allocating too much memory for Key and Value cache by applying quantization.
@@ -1276,14 +1313,10 @@ def __init__(
         q_group_size: int = 64,
         residual_length: int = 128,
     ):
-        layers = [
-            QuantoQuantizedLayer(nbits, axis_key, axis_value, q_group_size, residual_length)
-            for _ in range(config.num_hidden_layers)
-        ]
-        super().__init__(layers=layers)
+        super().__init__("quanto", config, nbits, axis_key, axis_value, q_group_size, residual_length)
 
 
-class HQQQuantizedCache(Cache):
+class HQQQuantizedCache(QuantizedCache):
     """
     A quantizer cache similar to what is described in the [KIVI: A Tuning-Free Asymmetric 2bit Quantization for KV Cache paper](https://arxiv.org/abs/2402.02750).
     It allows the model to generate longer sequence length without allocating too much memory for Key and Value cache by applying quantization.
@@ -1323,11 +1356,7 @@ def __init__(
         q_group_size: int = 64,
         residual_length: int = 128,
     ):
-        layers = [
-            HQQQuantizedLayer(nbits, axis_key, axis_value, q_group_size, residual_length)
-            for _ in range(config.num_hidden_layers)
-        ]
-        super().__init__(layers=layers)
+        super().__init__("hqq", config, nbits, axis_key, axis_value, q_group_size, residual_length)
 
 
 class EncoderDecoderCache(Cache):
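
This commit factors the backend dispatch out of QuantoQuantizedCache and HQQQuantizedCache into a shared QuantizedCache base class; both subclasses now simply forward their arguments to it with a fixed backend string ("quanto" or "hqq"). As a minimal usage sketch, not part of the commit, assuming QuantizedCache is importable from transformers.cache_utils after this change, that the chosen backend library (optimum-quanto here) is installed, and using a placeholder checkpoint name, the new constructor could be exercised like this:

from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.cache_utils import QuantizedCache

model_id = "my-org/my-causal-lm"  # placeholder; any causal LM checkpoint works
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

# Build the cache from the new base class: `backend` picks the quantized layer
# implementation ("quanto" or "hqq"); the remaining arguments mirror the
# QuantizedCache.__init__ signature shown in the diff above.
cache = QuantizedCache(
    backend="quanto",
    config=model.config,
    nbits=4,
    axis_key=0,
    axis_value=0,
    q_group_size=64,
    residual_length=128,
)

inputs = tokenizer("Quantizing the KV cache saves memory.", return_tensors="pt")
# Passing an explicit Cache object via `past_key_values` is the standard generate() hook.
outputs = model.generate(**inputs, max_new_tokens=32, past_key_values=cache)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))

After the refactor, QuantoQuantizedCache(config=model.config) and HQQQuantizedCache(config=model.config) still work and resolve to this same code path through the base class.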
