@@ -190,23 +190,53 @@ def _apply_fn_to_data(self, fn):
 def _linear_fp8_act_fp8_weight_sparse_cutlass_check(input_tensor, weight_tensor, bias):
     from torchao.dtypes.floatx import Float8Layout
 
-    return (
-        isinstance(input_tensor, AffineQuantizedTensor)
-        and isinstance(input_tensor._layout, Float8Layout)
-        and input_tensor.dtype in (torch.float16, torch.bfloat16)
-        and len(input_tensor.shape) >= 2
-        and input_tensor.tensor_impl.scale.dtype == torch.float32
-        and len(input_tensor.tensor_impl.scale.shape) == len(input_tensor.shape) - 1
-        and isinstance(weight_tensor, AffineQuantizedTensor)
-        and isinstance(weight_tensor._layout, CutlassSemiSparseLayout)
-        and weight_tensor.dtype == input_tensor.dtype
-        and len(weight_tensor.shape) == 2
-        and weight_tensor.tensor_impl.scale.dtype == torch.float32
-        and len(weight_tensor.tensor_impl.scale.shape) == 1
-        and (bias is None or bias.dtype == input_tensor.dtype)
-        and (bias is None or len(bias.shape) == 1)
+    base_check = (
+        isinstance(input_tensor, AffineQuantizedTensor)
+        and isinstance(input_tensor._layout, Float8Layout)
+        and input_tensor.dtype in (torch.float16, torch.bfloat16)
+        and len(input_tensor.shape) >= 2
+        and input_tensor.tensor_impl.scale.dtype == torch.float32
+        and isinstance(weight_tensor, AffineQuantizedTensor)
+        and isinstance(weight_tensor._layout, CutlassSemiSparseLayout)
+        and weight_tensor.dtype == input_tensor.dtype
+        and len(weight_tensor.shape) == 2
+        and weight_tensor.tensor_impl.scale.dtype == torch.float32
+        and (bias is None or bias.dtype == input_tensor.dtype)
+        and (bias is None or len(bias.shape) == 1)
     )
 
+    if base_check:
+        # Do the extra shape check, squeezing a trailing singleton dim off the
+        # scales if needed so they match the ranks the kernel expects.
+        input_tensor_squeezed = False
+        if (
+            len(input_tensor.tensor_impl.scale.shape) == len(input_tensor.shape)
+            and len(input_tensor.tensor_impl.scale.shape) > 1
+            and input_tensor.tensor_impl.scale.shape[-1] == 1
+        ):
+            input_tensor.tensor_impl.scale = torch.squeeze(input_tensor.tensor_impl.scale, dim=-1)
+            input_tensor_squeezed = True
+
+        weight_tensor_squeezed = False
+        if (
+            len(weight_tensor.tensor_impl.scale.shape) == 2
+            and weight_tensor.tensor_impl.scale.shape[-1] == 1
+        ):
+            weight_tensor.tensor_impl.scale = torch.squeeze(weight_tensor.tensor_impl.scale, dim=-1)
+            weight_tensor_squeezed = True
+
+        extra_check = (
+            len(input_tensor.tensor_impl.scale.shape) == len(input_tensor.shape) - 1
+            and len(weight_tensor.tensor_impl.scale.shape) == 1
+        )
+
+        if not extra_check:  # revert the squeezes if the extra check failed
+            if input_tensor_squeezed:
+                input_tensor.tensor_impl.scale = torch.unsqueeze(input_tensor.tensor_impl.scale, dim=-1)
+            if weight_tensor_squeezed:
+                weight_tensor.tensor_impl.scale = torch.unsqueeze(weight_tensor.tensor_impl.scale, dim=-1)
+
+        return extra_check
+
+    else:
+        return False
+
 
 def _linear_fp8_act_fp8_weight_sparse_cutlass_impl(input_tensor, weight_tensor, bias):
     from torchao.ops import rowwise_scaled_linear_sparse_cutlass_f8f8
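
For reference, a minimal standalone sketch of the squeeze-and-revert pattern the new check applies to the scale tensors (the `scale` variable and the shapes below are illustrative assumptions, not torchao internals):

    import torch

    # Row-wise scale stored with a trailing singleton dim, e.g. shape (8, 1).
    scale = torch.rand(8, 1, dtype=torch.float32)

    # Squeeze the trailing dim, remembering whether anything changed
    # so the original shape can be restored if the follow-up check fails.
    squeezed = False
    if scale.dim() > 1 and scale.shape[-1] == 1:
        scale = torch.squeeze(scale, dim=-1)  # (8, 1) -> (8,)
        squeezed = True

    # Stand-in for the rank requirement (e.g. a weight scale must be 1-D).
    extra_check = scale.dim() == 1

    if not extra_check and squeezed:
        scale = torch.unsqueeze(scale, dim=-1)  # restore (8, 1)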