@@ -202,6 +202,49 @@ void dequantize_per_channel_optimized(
202202 }
203203}
204204
205+ void dequant_and_gemm (
206+ const int64_t m,
207+ const int64_t n,
208+ const int64_t k,
209+ float * qk_data,
210+ const int64_t qk_stride_m,
211+ const MaybeQuantizedMatrixData& v_data,
212+ const int64_t v_stride_n,
213+ float * o_data,
214+ const int64_t o_stride_m,
215+ const float beta) {
216+ std::vector<float > dequantized_v_data (v_data.m * v_data.n );
217+ dequantize_per_channel_optimized (
218+ static_cast <const int8_t *>(v_data.data ),
219+ static_cast <const float *>(v_data.scales ),
220+ static_cast <const int8_t *>(v_data.zero_points ),
221+ dequantized_v_data.data (),
222+ -128 ,
223+ 127 ,
224+ 1 ,
225+ 0 ,
226+ 0 ,
227+ v_data.m ,
228+ v_stride_n,
229+ v_data.n ,
230+ v_data.n ,
231+ v_data.zero_points_stride );
232+ ::executorch::cpublas::gemm (
233+ ::executorch::cpublas::TransposeType::NoTranspose,
234+ ::executorch::cpublas::TransposeType::NoTranspose,
235+ n,
236+ m,
237+ k,
238+ static_cast <float >(1 ),
239+ dequantized_v_data.data(),
240+ v_data.n,
241+ qk_data,
242+ qk_stride_m,
243+ beta,
244+ o_data,
245+ o_stride_m);
246+ }
247+
205248template <typename accum_t >
206249void _qk_at_v_gemm (
207250 const int64_t m,
@@ -216,36 +259,41 @@ void _qk_at_v_gemm(
216259 const accum_t beta) {
217260 if (v_data.dtype == ScalarType::Char) {
218261 if constexpr (std::is_same<accum_t , float >::value) {
219- std::vector<float > dequantized_v_data (v_data.m * v_data.n );
220- dequantize_per_channel_optimized (
221- static_cast <const int8_t *>(v_data.data ),
222- static_cast <const float *>(v_data.scales ),
223- static_cast <const int8_t *>(v_data.zero_points ),
224- dequantized_v_data.data (),
225- -128 ,
226- 127 ,
227- 1 ,
228- 0 ,
229- 0 ,
230- v_data.m ,
231- v_stride_n,
232- v_data.n ,
233- v_data.n ,
234- v_data.zero_points_stride );
235- ::executorch::cpublas::gemm (
236- ::executorch::cpublas::TransposeType::NoTranspose,
237- ::executorch::cpublas::TransposeType::NoTranspose,
238- n,
239- m,
240- k,
241- static_cast <accum_t >(1 ),
242- dequantized_v_data.data(),
243- v_data.n,
244- qk_data,
245- qk_stride_m,
246- beta,
247- o_data,
248- o_stride_m);
262+ if (m > 4 ) {
263+ // For larger batch sizes, dequantize and use BLAS for better
264+ // performance
265+ dequant_and_gemm (
266+ m,
267+ n,
268+ k,
269+ const_cast <float *>(qk_data),
270+ qk_stride_m,
271+ v_data,
272+ v_stride_n,
273+ o_data,
274+ o_stride_m,
275+ beta);
276+ } else {
277+ // For smaller batch sizes, use quantized gemm
278+ int a_stride_m_tmp, b_stride_n_tmp;
279+ auto kernel = torchao::kernels::cpu::quantized_matmul::
280+ get_fp32_a_input_channelwise_8bit_b_f32_c_matmul (
281+ m, n, k, false , false , a_stride_m_tmp, b_stride_n_tmp);
282+ kernel (
283+ m,
284+ n,
285+ k,
286+ qk_data,
287+ qk_stride_m /* lhs_stride_m*/ ,
288+ static_cast <const int8_t *>(v_data.data ),
289+ v_stride_n /* rhs_stride_n*/ ,
290+ o_data,
291+ o_stride_m /* out_stride_n*/ ,
292+ static_cast <const int8_t *>(v_data.zero_points ),
293+ static_cast <const float *>(v_data.scales ),
294+ beta,
295+ v_data.zero_points_stride );
296+ }
249297 } else {
250298 ET_CHECK_MSG (
251299 false , " Accumulation in dtype other than float not supported yet" );
0 commit comments