@@ -524,7 +524,6 @@ def partition_types(self) -> List[OpOverload]:
524524
525525
526526class  SoftmaxPattern (QuantizationPattern ):
527- 
528527    def  partition_types (self ) ->  List [OpOverload ]:
529528        return  [torch .ops .aten ._softmax .default ]
530529
@@ -546,3 +545,57 @@ def get_anchors(
546545
547546    def  replacement_op (self ) ->  OpOverload :
548547        return  torch .ops .cadence .quantized_softmax .default 
548+ 
549+ 
class MixedW8A32LinearPattern(QuantizationPattern):
    """Pattern for mixed-precision linear: int8 weights with fp32 activations.

    Matches ``aten.linear`` nodes and anchors only the weight and bias for
    quantization; the activation input is listed under ``others`` so it stays
    in floating point. Replaced by ``cadence.quantized_w8a32_linear``.
    """

    def partition_types(self) -> List[OpOverload]:
        """Return the aten ops this pattern matches."""
        return [torch.ops.aten.linear.default]

    def get_anchors(
        self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
    ) -> Tuple[PartitionAnchors, fx.Node]:
        """Locate quantization anchors for a matched linear node.

        Returns empty anchors (bailing out of the pattern) when the node does
        not have exactly (input, weight, bias) positional args, when the inner
        dimension is not SIMD-friendly, or when the input is not a vector.
        """
        # pyre-ignore[29]
        linear_layer = fused_partition[0].nodes[-1]

        # Bail if the arguments have different shapes than expected
        # (we require exactly input, weight, bias and no kwargs).
        if len(linear_layer.args) != 3 or len(linear_layer.kwargs) > 0:
            return (
                PartitionAnchors(
                    empty=True,
                ),
                linear_layer,
            )

        input_node = linear_layer.args[0]
        input_shape = input_node.meta["tensor_meta"].shape

        # Bail if the inner dimension is not a multiple of 4 (SIMD width).
        if input_shape[-1] % 4 != 0:
            return (
                PartitionAnchors(
                    empty=True,
                ),
                linear_layer,
            )
        # Currently only supporting vector-matrix multiplication.
        # NOTE: the guard must be `> 1`, not `> 0` — for a 1-D input (a plain
        # vector, which is the supported case) indexing input_shape[-2] would
        # raise IndexError. A 1-D input falls through and is accepted.
        if len(input_shape) > 1 and input_shape[-2] != 1:
            return (
                PartitionAnchors(
                    empty=True,
                ),
                linear_layer,
            )

        # Only weight (arg 1) and bias (arg 2) are quantized; the activation
        # (arg 0) is kept fp32 and recorded under `others`.
        return (
            PartitionAnchors(
                inputs=[],
                weights=[(linear_layer, 1)],
                biases=[(linear_layer, 2)],
                output=[],
                others=[(linear_layer, 0)],
            ),
            linear_layer,
        )

    def replacement_op(self) -> OpOverload:
        """Return the quantized Cadence op that replaces the matched linear."""
        return torch.ops.cadence.quantized_w8a32_linear.default
0 commit comments