1616)
1717
1818
19- if _is_quanto_greater_than_0_2_5 := is_quanto_greater ("0.2.5" , accept_dev = True ):
20- from optimum .quanto import MaxOptimizer , qint2 , qint4 , quantize_weight
21-
2219if is_hqq_available ():
2320 from hqq .core .quantize import Quantizer as HQQQuantizer
2421
@@ -558,7 +555,7 @@ def __init__(
558555 q_group_size : int = 64 ,
559556 residual_length : int = 128 ,
560557 ):
561- super ().__init__ (self )
558+ super ().__init__ ()
562559 self .nbits = nbits
563560 self .axis_key = axis_key
564561 self .axis_value = axis_value
@@ -635,10 +632,12 @@ def __init__(
635632 residual_length = residual_length ,
636633 )
637634
638- if not _is_quanto_greater_than_0_2_5 :
635+ # We need to import quanto here to avoid circular imports due to optimum/quanto/models/transformers_models.py
636+ if is_quanto_greater ("0.2.5" , accept_dev = True ):
637+ from optimum .quanto import MaxOptimizer , qint2 , qint4
638+ else :
639639 raise ImportError (
640640 "You need optimum-quanto package version to be greater or equal than 0.2.5 to use `QuantoQuantizedCache`. "
641- "Detected version {optimum_quanto_version}."
642641 )
643642
644643 if self .nbits not in [2 , 4 ]:
@@ -656,6 +655,8 @@ def __init__(
656655 self .optimizer = MaxOptimizer () # hardcode as it's the only one for per-channel quantization
657656
658657 def _quantize (self , tensor , axis ):
658+ from optimum .quanto import quantize_weight
659+
659660 scale , zeropoint = self .optimizer (tensor , self .qtype , axis , self .q_group_size )
660661 qtensor = quantize_weight (tensor , self .qtype , axis , scale , zeropoint , self .q_group_size )
661662 return qtensor
0 commit comments