2 files changed: +29 -3

model_executor/layers/quantization/utils
@@ -559,15 +559,15 @@ def cutlass_scaled_mm(a: torch.Tensor,
     scale_a.shape * [1, 128] == a.shape
     scale_b.shape * [128, 128] == b.shape
     """
-    assert (b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0)
     assert (out_dtype is torch.bfloat16 or out_dtype is torch.float16)
     assert bias is None or bias.shape[0] == b.shape[
         1] and bias.dtype == out_dtype

     m = a.shape[0]
     n = b.shape[1]

-    if current_platform.is_rocm():
+    cutlass_compatible_b = (b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0)
+    if current_platform.is_rocm() or not cutlass_compatible_b:
         triton_scaled_mm_module = importlib.import_module(
             "vllm.model_executor.layers.quantization.compressed_tensors."
             "triton_scaled_mm")
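
In effect, the hard assertion on b's dimensions becomes a runtime dispatch: shapes that CUTLASS cannot handle (either dimension of b not a multiple of 16) now fall back to the Triton scaled-mm path instead of raising. A minimal sketch of the new predicate, using an illustrative helper name (needs_triton_fallback) that does not exist in the PR:

import torch

def needs_triton_fallback(b: torch.Tensor, is_rocm: bool) -> bool:
    # Hypothetical helper mirroring the new condition: CUTLASS requires both
    # dimensions of b to be multiples of 16; otherwise use the Triton kernel.
    cutlass_compatible_b = (b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0)
    return is_rocm or not cutlass_compatible_b

# A (100, 100) weight previously tripped the removed assert; it now routes to
# the triton_scaled_mm path instead of failing.
assert needs_triton_fallback(torch.empty(100, 100, dtype=torch.int8), is_rocm=False)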
Second changed file:
@@ -85,6 +85,32 @@ def block_dequant(
     return x_dq_block


+if current_platform.is_rocm():
+    from triton.language import core
+
+    # NOTE: This can be removed when hip.libdevice.round() is available.
+    @core.extern
+    def round_f32(arg0, _builder=None):
+        return core.extern_elementwise("",
+                                       "", [arg0], {
+                                           (core.dtype("fp32"), ):
+                                           ("llvm.round", core.dtype("fp32")),
+                                           (core.dtype("fp64"), ):
+                                           ("llvm.round", core.dtype("fp64")),
+                                       },
+                                       is_pure=True,
+                                       _builder=_builder)
+
+    @triton.jit
+    def round_int8(x):
+        return round_f32(x).to(tl.int8)
+else:
+
+    @triton.jit
+    def round_int8(x):
+        return tl.extra.cuda.libdevice.round(x).to(tl.int8)
+
+
 @triton.jit
 def _per_token_quant_int8(
     x_ptr,
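
One subtlety worth noting about the helper added above: the llvm.round intrinsic used on ROCm, like the CUDA libdevice round() it stands in for, rounds halfway cases away from zero, whereas Python's built-in round() ties to even. A small host-side illustration (not part of the diff):

import math

def round_half_away(x: float) -> float:
    # Same ties-away-from-zero convention as llvm.round / libdevice round().
    return math.floor(x + 0.5) if x >= 0 else math.ceil(x - 0.5)

assert round_half_away(2.5) == 3 and round_half_away(-2.5) == -3
assert round(2.5) == 2 and round(-2.5) == -2  # Python rounds ties to even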
@@ -106,7 +132,7 @@ def _per_token_quant_int8(
     absmax = tl.maximum(tl.max(tl.abs(x)), 1e-10)
     scale_x = absmax / 127
     x_q = x * (127 / absmax)
-    x_q = tl.extra.cuda.libdevice.round(x_q).to(tl.int8)
+    x_q = round_int8(x_q)

     tl.store(xq_ptr + row_id * stride_xq + cols, x_q, mask=mask)
     tl.store(scale_ptr + row_id, scale_x)
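
For reference, the kernel above performs symmetric per-token (per-row) int8 quantization: each row is scaled so that its absolute maximum maps to 127, rounded to int8, and the per-row scale is kept for dequantization. A rough host-side equivalent (an assumed sketch, not from the PR; torch.round ties to even rather than away from zero, a negligible difference here):

import torch

def per_token_quant_int8_ref(x: torch.Tensor):
    # One scale per row; the clamp avoids division by zero for all-zero rows.
    absmax = x.abs().amax(dim=-1, keepdim=True).clamp_min(1e-10)
    scale = absmax / 127
    x_q = torch.round(x / scale).to(torch.int8)
    return x_q, scale

x = torch.randn(4, 16)
x_q, scale = per_token_quant_int8_ref(x)
x_dq = x_q.to(torch.float32) * scale  # dequantize
assert torch.allclose(x, x_dq, atol=scale.max().item())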