@@ -89,26 +89,22 @@ def _linear_bf16_act_uint4_weight_float_zero_impl(input_tensor, weight_tensor, b
     return y.to(orig_dtype)


-def _linear_bf16_act_uint4_weight_int8_zero_check(input_tensor, weight_tensor, bias):
+def _linear_fp_act_uint4_weight_int8_zero_check(input_tensor, weight_tensor, bias):
     return (
-        # input is native bfloat16 tensor
         not is_traceable_wrapper_subclass(input_tensor)
-        and input_tensor.dtype == torch.bfloat16
         and
         # weight is uint4, group quantized tensor_core_tiled tensor impl affine quantized tensor
         isinstance(weight_tensor, AffineQuantizedTensor)
         and _aqt_is_xpu_layout_uint4(weight_tensor)
-        and weight_tensor.dtype == torch.bfloat16
         and len(weight_tensor.shape) == 2
         and weight_tensor.zero_point_domain == ZeroPointDomain.INT
         and weight_tensor.tensor_impl.scale_and_zero is None
-        and weight_tensor.tensor_impl.scale.dtype == torch.bfloat16
         and weight_tensor.tensor_impl.zero.dtype == torch.int8
         and isinstance(weight_tensor._layout, Int4XPULayout)
     )


-def _linear_bf16_act_uint4_weight_int8_zero_impl(input_tensor, weight_tensor, bias):
+def _linear_fp_act_uint4_weight_int8_zero_impl(input_tensor, weight_tensor, bias):
     assert weight_tensor.block_size[0] == 1, (
         f"Requires groupwise quantization, got block_size: {weight_tensor.block_size}"
     )
@@ -129,7 +125,7 @@ def _linear_bf16_act_uint4_weight_int8_zero_impl(input_tensor, weight_tensor, bi
     orig_act_size = act_mat.size()
     orig_dtype = act_mat.dtype

-    act_mat = act_mat.reshape(-1, act_mat.shape[-1]).to(torch.bfloat16)
+    act_mat = act_mat.reshape(-1, act_mat.shape[-1])

     # groupwise int4 quantization
     groupsize = weight_tensor.block_size[1]
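After this change the dispatch check no longer pins the activation, weight, or scale dtype to bfloat16, and the impl no longer casts the activation to bfloat16 before the matmul; the `fp` in the new names signals that any floating-point activation is expected to take the XPU uint4-weight / int8-zero-point path in its original dtype. A minimal sketch of that relaxation is below; the helper is hypothetical and only mirrors the dtype-related part of the real check, which still applies all of the weight and layout conditions shown in the diff.

```python
import torch

# Hypothetical helper, not torchao API: illustrates that the renamed check
# accepts any floating-point activation instead of requiring torch.bfloat16.
def _activation_dtype_ok(input_tensor: torch.Tensor) -> bool:
    return input_tensor.dtype.is_floating_point

for dtype in (torch.bfloat16, torch.float16, torch.float32, torch.int8):
    x = torch.zeros(2, 8, dtype=dtype)
    print(dtype, _activation_dtype_ok(x))  # int8 -> False, the rest -> True
```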