Skip to content

Commit

Permalink
Quantized KV cache: update quanto (huggingface#31052)
Browse files Browse the repository at this point in the history
* quanto latest version was refactored

* add error msg

* incorrect compare sign

* Update src/transformers/cache_utils.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

---------

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
  • Loading branch information
zucchini-nlp and amyeroberts authored May 29, 2024
1 parent a564d10 commit d521ba5
Showing 1 changed file with 15 additions and 2 deletions.
17 changes: 15 additions & 2 deletions src/transformers/cache_utils.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,21 @@
import copy
import importlib.metadata
import json
import os
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
from packaging import version

from .configuration_utils import PretrainedConfig
from .utils import is_hqq_available, is_quanto_available, logging


if is_quanto_available():
from quanto import QBitsTensor, qint2, qint4
quanto_version = version.parse(importlib.metadata.version("quanto"))
if quanto_version >= version.parse("0.2.0"):
from quanto import AffineQuantizer, MaxOptimizer, qint2, qint4

if is_hqq_available():
from hqq.core.quantize import Quantizer as HQQQuantizer
Expand Down Expand Up @@ -488,6 +492,13 @@ class QuantoQuantizedCache(QuantizedCache):

def __init__(self, cache_config: CacheConfig) -> None:
super().__init__(cache_config)
quanto_version = version.parse(importlib.metadata.version("quanto"))
if quanto_version < version.parse("0.2.0"):
raise ImportError(
f"You need quanto package version to be greater or equal than 0.2.0 to use `QuantoQuantizedCache`. Detected version {quanto_version}. "
f"Please upgrade quanto with `pip install -U quanto`"
)

if self.nbits not in [2, 4]:
raise ValueError(f"`nbits` for `quanto` backend has to be one of [`2`, `4`] but got {self.nbits}")

Expand All @@ -500,9 +511,11 @@ def __init__(self, cache_config: CacheConfig) -> None:
)

self.qtype = qint4 if self.nbits == 4 else qint2
self.optimizer = MaxOptimizer() # hardcode as it's the only one for per-channel quantization

def _quantize(self, tensor, axis):
qtensor = QBitsTensor.quantize(tensor, axis=axis, qtype=self.qtype, group_size=self.q_group_size)
scale, zeropoint = self.optimizer(tensor, self.qtype.bits, axis, self.q_group_size)
qtensor = AffineQuantizer.apply(tensor, self.qtype, axis, self.q_group_size, scale, zeropoint)
return qtensor

def _dequantize(self, qtensor):
Expand Down

0 comments on commit d521ba5

Please sign in to comment.