@@ -286,6 +286,20 @@ def _float8_linear_impl(
        f"Don't expect to reach here with an override other than weight currently, {type(input_tensor)} {type(weight_tensor)}"
    )

+    # TODO: make this better
+    # During the backward pass, we transpose the weight tensor,
+    # so if the weight tensor was originally rowwise quantized,
+    # now it becomes colwise. In this case, simply dequantize
+    # the tensor and do a bf16 matmul
+    is_backward = (
+        weight_tensor.block_size[0] == weight_tensor.shape[0] and
+        weight_tensor.block_size[1] == 1
+    )
+    if is_backward:
+        return torch.nn.functional.linear(
+            input_tensor, weight_tensor.dequantize(), bias,
+        )
+
    act_quant_kwargs = weight_tensor.act_quant_kwargs
    # quantizing activation, if `act_quant_kwargs` is specified
    if act_quant_kwargs is not None:
@@ -321,8 +335,7 @@ def _float8_linear_impl(
            wq = weight_tensor.qdata
            x_scale = input_tensor.scale
            w_scale = weight_tensor.scale
-            # TODO: fix this?
-            if True:  # _is_rowwise_scaled(weight_tensor):
+            if _is_rowwise_scaled(weight_tensor):
                assert _is_rowwise_scaled(input_tensor), (
                    "Input tensor must be rowwise block size"
                )
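
The added `is_backward` check keys off the `block_size` bookkeeping: a row-wise scaled weight of shape `[N, K]` carries `block_size = [1, K]`, so once it is transposed for the backward pass the shape becomes `[K, N]` and the block size `[K, 1]`, which is exactly the pattern the check matches. Below is a minimal standalone sketch of that bookkeeping; the helper name `looks_colwise_scaled` is made up for illustration and is not part of torchao:

```python
# Illustrative sketch only: mirrors the block_size reasoning in the diff above
# without depending on torchao internals. The helper name is hypothetical.
def looks_colwise_scaled(shape, block_size):
    # One scale per column, i.e. what a row-wise scaled weight looks like
    # after it has been transposed for the backward pass.
    return block_size[0] == shape[0] and block_size[1] == 1

n, k = 128, 256

# Forward: [N, K] weight, row-wise scaled -> block_size [1, K]; check is False.
assert not looks_colwise_scaled((n, k), (1, k))

# Backward: the transposed weight is [K, N] with block_size [K, 1],
# so the check fires and we fall back to dequantize + bf16 matmul.
assert looks_colwise_scaled((k, n), (k, 1))
```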