@@ -15,7 +15,7 @@
     quantize_per_channel_group,
 )

-from torchao.quantization.granularity import PerGroup, PerRow
+from torchao.quantization.granularity import Granularity, PerAxis, PerGroup, PerRow
 from torchao.utils import TORCH_VERSION_AT_LEAST_2_6

 logger = logging.getLogger(__name__)
@@ -366,32 +366,44 @@ def __init__(
     ):
         super().__init__()
         self.bit_width = bit_width
-        self.pack_weights_op = getattr(
-            torch.ops.torchao, f"_pack_embedding_{bit_width}bit"
-        )
-        self.embedding_op = getattr(torch.ops.torchao, f"_embedding_{bit_width}bit")

     def quantize_and_pack_weights(self, weights, group_size, has_weight_zeros):
         assert has_weight_zeros, "has_weight_zeros must be True for QuantizedEmbedding"
         num_embeddings, embedding_dim = weights.shape
-        if group_size == -1:
-            group_size = embedding_dim
-        self.group_size = group_size

-        weight_qvals, weight_scales, weight_zeros = _quantize(
-            weights, self.group_size, self.bit_width, has_weight_zeros=True
+        embedding = torch.nn.Embedding(num_embeddings, embedding_dim)
+        embedding.weight = weights
+        quantize_(
+            embedding,
+            IntxWeightOnlyConfig(
+                weight_dtype=getattr(torch, f"int{self.bit_width}"),
+                granularity=PerGroup(group_size) if group_size > 0 else PerAxis(0),
+                zero_point_domain=ZeroPointDomain.INT
+                if has_weight_zeros
+                else ZeroPointDomain.NONE,
+                mapping_type=MappingType.ASYMMETRIC,
+            ),
+            lambda m, fqn: isinstance(m, torch.nn.Embedding),
+        )
+        weight_qvals, weight_scales, weight_zeros = (
+            embedding.weight.tensor_impl.get_plain()
         )
+        weight_scales = weight_scales.reshape(num_embeddings, -1)
+        weight_zeros = weight_zeros.reshape(num_embeddings, -1).to(torch.int8)
         self.register_buffer(
-            "packed_weight_qvals", self.pack_weights_op(weight_qvals.to(torch.int8))
+            "packed_weight_qvals",
+            getattr(torch.ops.torchao, f"_pack_embedding_{self.bit_width}bit")(
+                weight_qvals.to(torch.int8)
+            ),
         )
         self.num_embeddings = num_embeddings
         self.embedding_dim = embedding_dim
         self.register_buffer("weight_scales", weight_scales)
-        self.register_buffer("weight_zeros", weight_zeros.to(torch.int8))
+        self.register_buffer("weight_zeros", weight_zeros)

     def forward(self, x):
         shape = x.shape
-        return self.embedding_op(
+        return getattr(torch.ops.torchao, f"_embedding_{self.bit_width}bit")(
             self.packed_weight_qvals,
             self.num_embeddings,
             self.embedding_dim,
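
Note: the packed path above now derives qvals/scales/zeros from torchao's quantize_ flow instead of the removed local _quantize helper. A minimal standalone sketch of that flow, built only from calls the diff itself uses (the 128x64 shape and the int4/PerGroup(32) settings are illustrative, not from the PR):

    import torch
    from torchao.quantization.granularity import PerGroup
    from torchao.quantization.quant_api import (
        IntxWeightOnlyConfig,
        MappingType,
        ZeroPointDomain,
        quantize_,
    )

    # Illustrative sizes; any (num_embeddings, embedding_dim) divisible
    # by the group size works.
    embedding = torch.nn.Embedding(128, 64)
    quantize_(
        embedding,
        IntxWeightOnlyConfig(
            weight_dtype=torch.int4,
            granularity=PerGroup(32),
            zero_point_domain=ZeroPointDomain.INT,
            mapping_type=MappingType.ASYMMETRIC,
        ),
        lambda m, fqn: isinstance(m, torch.nn.Embedding),
    )
    # The quantized weight exposes its raw integer values, scales, and
    # zero points through its tensor_impl, as the diff does above:
    qvals, scales, zeros = embedding.weight.tensor_impl.get_plain()
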
@@ -410,38 +422,23 @@ def __init__(
         self.bit_width = bit_width

     def quantize_and_pack_weights(self, weights, group_size, has_weight_zeros):
-        assert (
-            has_weight_zeros
-        ), "has_weight_zeros must be True for QuantizedEmbeddingFallback"
-        num_embeddings, embedding_dim = weights.shape
-        if group_size == -1:
-            group_size = embedding_dim
-        self.group_size = group_size
-
-        weight_qvals, weight_scales, weight_zeros = _quantize(
-            weights, self.group_size, self.bit_width, has_weight_zeros=True
+        self.embedding = torch.nn.Embedding(*weights.shape)
+        self.embedding.weight = weights
+        quantize_(
+            self.embedding,
+            IntxWeightOnlyConfig(
+                weight_dtype=getattr(torch, f"int{self.bit_width}"),
+                granularity=PerGroup(group_size) if group_size > 0 else PerAxis(0),
+                zero_point_domain=ZeroPointDomain.INT
+                if has_weight_zeros
+                else ZeroPointDomain.NONE,
+                mapping_type=MappingType.ASYMMETRIC,
+            ),
+            lambda m, fqn: isinstance(m, torch.nn.Embedding),
         )
-        self.weight_qvals = weight_qvals.to(torch.int32)
-        self.weight_scales = weight_scales
-        self.weight_zeros = weight_zeros.to(torch.int32)

     def forward(self, x):
-        shape = x.shape
-        res = []
-        for i in x:
-            res.append(
-                dequantize_per_channel_group(
-                    w_int8=self.weight_qvals[i, :].reshape(1, -1),
-                    scales=self.weight_scales[i, :].reshape(1, -1),
-                    zero_points=self.weight_zeros[i, :].reshape(1, -1),
-                    quant_min=None,  # TODO: why is this an arg for this function
-                    quant_max=None,  # TODO: why is this an arg for this function
-                    dtype=None,  # TODO: why is this an arg for this function
-                    group_size=self.group_size,
-                    output_dtype=torch.float32,
-                ).reshape(-1)
-            )
-        return torch.stack(res).reshape(*shape, -1)
+        return self.embedding(x)


 class QuantizedSharedEmbedding(nn.Module):
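
Note: since quantize_ swaps self.embedding.weight for an affine-quantized tensor subclass that dequantizes on access, the fallback forward reduces to a plain lookup. A hedged smoke test, assuming (as the hunk above suggests) that the constructor takes only bit_width:

    import torch

    weights = torch.nn.Parameter(torch.randn(100, 64))
    fallback = QuantizedEmbeddingFallback(bit_width=4)  # class from this file
    fallback.quantize_and_pack_weights(weights, group_size=32, has_weight_zeros=True)

    idx = torch.randint(0, 100, (8,))
    out = fallback(idx)  # just self.embedding(idx) over the quantized weight
    assert out.shape == (8, 64)
    # out matches weights[idx] up to quantization error.
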
@@ -586,15 +583,16 @@ class EmbeddingQuantizer:
     def __init__(
         self,
         weight_dtype: torch.dtype = torch.int4,
-        granularity: Union[PerRow, PerGroup] = PerRow(),
+        granularity: Granularity = PerAxis(0),
         has_weight_zeros: bool = True,
         use_fallback: bool = False,
     ):
         bit_width = _dtype_to_bit_width(weight_dtype)

         if isinstance(granularity, PerGroup):
             group_size = granularity.group_size
-        elif isinstance(granularity, PerRow):
+        elif isinstance(granularity, PerAxis):
+            assert granularity.axis == 0
             group_size = -1
         else:
             raise ValueError(f"Unsupported granularity: {granularity}")
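
Note: with PerRow replaced by PerAxis, row-wise (per-embedding) quantization is now spelled PerAxis(0), and grouped quantization keeps PerGroup. An illustrative sketch of both constructions (the group size 32 is an assumption, not from the PR):

    import torch
    from torchao.quantization.granularity import PerAxis, PerGroup

    # One scale/zero point per embedding row (mapped to group_size == -1 above):
    row_quantizer = EmbeddingQuantizer(weight_dtype=torch.int4, granularity=PerAxis(0))

    # One scale/zero point per group of 32 consecutive columns in each row:
    group_quantizer = EmbeddingQuantizer(weight_dtype=torch.int4, granularity=PerGroup(32))
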
@@ -630,6 +628,7 @@ def quantize(self, model: nn.Module) -> nn.Module:
             to_linear_activation_quantized,
         )
         from torchao.quantization.quant_api import (
+            IntxWeightOnlyConfig,
             MappingType,
             ZeroPointDomain,
             to_affine_quantized_intx,
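
Note: end to end, usage is unchanged apart from the granularity spelling. A hedged example of applying the quantizer to a model containing embeddings (the model shape is illustrative):

    import torch
    import torch.nn as nn
    from torchao.quantization.granularity import PerGroup

    model = nn.Sequential(nn.Embedding(1000, 128))
    quantizer = EmbeddingQuantizer(
        weight_dtype=torch.int4,
        granularity=PerGroup(32),
        has_weight_zeros=True,
    )
    model = quantizer.quantize(model)  # returns the model with embeddings quantized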