@@ -437,9 +437,10 @@ struct ComputeTile_W8A16_PerC_MtilexNtilex32_multistage_SM8x_SplitK {
437437 for (int n_idx = 0 ; n_idx < WARP_NITER; ++n_idx) {
438438 #pragma unroll
439439 for (int k_idx = 0 ; k_idx < 2 ; ++k_idx) {
440- FType low16 = static_cast <FType>(C_frag[m_idx][n_idx][k_idx * 2 ]);
440+ FType low16 =
441+ ScalarType<FType>::float2num (C_frag[m_idx][n_idx][k_idx * 2 ]);
441442 FType high16 =
442- static_cast <FType>(C_frag[m_idx][n_idx][k_idx * 2 + 1 ]);
443+ ScalarType <FType>:: float2num (C_frag[m_idx][n_idx][k_idx * 2 + 1 ]);
443444 uint32_t tmp = (reinterpret_cast <uint32_t &>(low16) & 0xffff ) |
444445 (reinterpret_cast <uint32_t &>(high16) << 16 );
445446 int sts_offset =
@@ -793,7 +794,7 @@ __global__ void restore_N32_K16_dequantize_rhs_w8a16_perc_kernel(
793794 FT scale_reg[4 ];
794795 *(reinterpret_cast <uint2 *>(scale_reg)) =
795796 *(reinterpret_cast <const uint2 *>(scales + params_nidx));
796- FT zero_reg[4 ] = { 0 } ;
797+ FT zero_reg[4 ];
797798 if (zeros != nullptr ) {
798799 *(reinterpret_cast <uint2 *>(zero_reg)) =
799800 *(reinterpret_cast <const uint2 *>(zeros + params_nidx));
@@ -809,8 +810,10 @@ __global__ void restore_N32_K16_dequantize_rhs_w8a16_perc_kernel(
809810 reinterpret_cast <typename HalfType<FT>::T2*>(&(fval_reg[ni * 4 ])));
810811 #pragma unroll
811812 for (int ki = 0 ; ki < 4 ; ++ki) {
812- fval_reg[ni * 4 + ki] =
813- (fval_reg[ni * 4 + ki] - zero_reg[ni]) * scale_reg[ni];
813+ if (zeros != nullptr ) {
814+ fval_reg[ni * 4 + ki] = __hsub (fval_reg[ni * 4 + ki], zero_reg[ni]);
815+ }
816+ fval_reg[ni * 4 + ki] = __hmul (fval_reg[ni * 4 + ki], scale_reg[ni]);
814817 int sts_offset = sts_base_offset + ((ki / 2 ) * 8 + (ki % 2 )) * 32 +
815818 ((ni + lane_id % 4 ) % 4 ) * 8 ;
816819 smem[sts_offset] = fval_reg[ni * 4 + ki];
0 commit comments