@@ -190,23 +190,53 @@ def _apply_fn_to_data(self, fn):
def _linear_fp8_act_fp8_weight_sparse_cutlass_check(input_tensor, weight_tensor, bias):
    """Return True iff (input_tensor, weight_tensor, bias) can be dispatched to
    the rowwise-scaled FP8-activation / FP8-2:4-sparse-weight CUTLASS linear
    kernel.

    NOTE(side effect): when every other requirement is met and a quantization
    scale carries a trailing singleton dimension, that dimension is squeezed
    away in place (``tensor_impl.scale`` is rebound), so the matching
    ``..._impl`` sees an input scale of rank ``input.ndim - 1`` and a weight
    scale of rank 1.  Unlike a mutate-then-revert approach, the squeeze is
    only applied once the final decision is known to be True, so a failing
    check never touches the tensors.
    """
    from torchao.dtypes.floatx import Float8Layout

    # Basic layout / dtype / shape requirements, independent of scale rank.
    if not (
        isinstance(input_tensor, AffineQuantizedTensor)
        and isinstance(input_tensor._layout, Float8Layout)
        and input_tensor.dtype in (torch.float16, torch.bfloat16)
        and len(input_tensor.shape) >= 2
        and input_tensor.tensor_impl.scale.dtype == torch.float32
        and isinstance(weight_tensor, AffineQuantizedTensor)
        and isinstance(weight_tensor._layout, CutlassSemiSparseLayout)
        and weight_tensor.dtype == input_tensor.dtype
        and len(weight_tensor.shape) == 2
        and weight_tensor.tensor_impl.scale.dtype == torch.float32
        and (bias is None or (bias.dtype == input_tensor.dtype and len(bias.shape) == 1))
    ):
        return False

    input_scale = input_tensor.tensor_impl.scale
    weight_scale = weight_tensor.tensor_impl.scale

    # A trailing singleton dim on a scale may be squeezed to reach the rank
    # the kernel expects.  Decide from the shapes first; mutate only below,
    # once the overall check is known to succeed.
    # (len(input_scale.shape) > 1 is implied: it equals len(input.shape) >= 2.)
    squeeze_input = (
        len(input_scale.shape) == len(input_tensor.shape)
        and input_scale.shape[-1] == 1
    )
    squeeze_weight = len(weight_scale.shape) == 2 and weight_scale.shape[-1] == 1

    effective_input_rank = len(input_scale.shape) - (1 if squeeze_input else 0)
    effective_weight_rank = len(weight_scale.shape) - (1 if squeeze_weight else 0)

    # Kernel contract: rowwise scales — input scale rank = input rank - 1,
    # weight scale rank = 1.
    if (
        effective_input_rank != len(input_tensor.shape) - 1
        or effective_weight_rank != 1
    ):
        return False

    if squeeze_input:
        input_tensor.tensor_impl.scale = torch.squeeze(input_scale, dim=-1)
    if squeeze_weight:
        weight_tensor.tensor_impl.scale = torch.squeeze(weight_scale, dim=-1)
    return True
210240
211241def _linear_fp8_act_fp8_weight_sparse_cutlass_impl (input_tensor , weight_tensor , bias ):
212242 from torchao .ops import rowwise_scaled_linear_sparse_cutlass_f8f8
0 commit comments