From 5e4d50ca1f1ea8564b372e5a28536834366bf09b Mon Sep 17 00:00:00 2001 From: Scott Roy <161522778+metascroy@users.noreply.github.com> Date: Tue, 1 Apr 2025 13:43:54 -0700 Subject: [PATCH 01/19] Reintroduce has_weight_zeros as a template param Differential Revision: D71503133 Pull Request resolved: https://github.com/pytorch/ao/pull/1991 --- ..._8bit_activation_groupwise_lowbit_weight.h | 8 +-- .../kernel_1x8x16_f32_neondot-impl.h | 5 +- .../kernels/cpu/aarch64/linear/linear.h | 7 +- .../kernels/cpu/aarch64/tests/test_linear.cpp | 44 ++++++------ .../embedding_xbit/op_embedding_xbit-impl.h | 13 ++-- .../kernel_selector.h | 71 +++++++++++++------ .../packed_weights_format.h | 6 +- .../test_linear_8bit_act_xbit_weight.cpp | 2 +- 8 files changed, 92 insertions(+), 64 deletions(-) diff --git a/torchao/experimental/kernels/cpu/aarch64/linear/channelwise_8bit_activation_groupwise_lowbit_weight/channelwise_8bit_activation_groupwise_lowbit_weight.h b/torchao/experimental/kernels/cpu/aarch64/linear/channelwise_8bit_activation_groupwise_lowbit_weight/channelwise_8bit_activation_groupwise_lowbit_weight.h index 4ca9cef54d..9ff75e3344 100644 --- a/torchao/experimental/kernels/cpu/aarch64/linear/channelwise_8bit_activation_groupwise_lowbit_weight/channelwise_8bit_activation_groupwise_lowbit_weight.h +++ b/torchao/experimental/kernels/cpu/aarch64/linear/channelwise_8bit_activation_groupwise_lowbit_weight/channelwise_8bit_activation_groupwise_lowbit_weight.h @@ -245,7 +245,7 @@ void kernel_1x4x16_f32_neondot( has_clamp); } -template +template void kernel_1x8x16_f32_neondot( // Outputs float32_t* output, @@ -260,10 +260,11 @@ void kernel_1x8x16_f32_neondot( // Ignored if has_clamp = false float clamp_min, float clamp_max, - bool has_weight_zeros, + bool has_weight_zeros_, bool has_bias, bool has_clamp) { - kernel::kernel_1x8x16_f32_neondot( + (void)has_weight_zeros_; // unused + kernel::kernel_1x8x16_f32_neondot( output, output_m_stride, m, @@ -274,7 +275,6 @@ void kernel_1x8x16_f32_neondot( packed_activations, clamp_min, clamp_max, - has_weight_zeros, has_bias, has_clamp); } diff --git a/torchao/experimental/kernels/cpu/aarch64/linear/channelwise_8bit_activation_groupwise_lowbit_weight/kernel_1x8x16_f32_neondot-impl.h b/torchao/experimental/kernels/cpu/aarch64/linear/channelwise_8bit_activation_groupwise_lowbit_weight/kernel_1x8x16_f32_neondot-impl.h index 81f6e6b023..7a53c7302c 100644 --- a/torchao/experimental/kernels/cpu/aarch64/linear/channelwise_8bit_activation_groupwise_lowbit_weight/kernel_1x8x16_f32_neondot-impl.h +++ b/torchao/experimental/kernels/cpu/aarch64/linear/channelwise_8bit_activation_groupwise_lowbit_weight/kernel_1x8x16_f32_neondot-impl.h @@ -58,7 +58,7 @@ vec_clamp(float32x4_t x, float32x4_t vec_min, float32x4_t vec_max) { // Roughly inspired by // https://gitlab.arm.com/kleidi/kleidiai/-/blob/main/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod.c?ref_type=heads -template +template void kernel_1x8x16_f32_neondot( // Outputs float32_t* output, @@ -73,7 +73,6 @@ void kernel_1x8x16_f32_neondot( // Ignored if has_clamp is false float clamp_min, float clamp_max, - bool has_weight_zeros, bool has_bias, bool has_clamp) { assert(k % group_size == 0); @@ -267,7 +266,7 @@ void kernel_1x8x16_f32_neondot( int32x4_t term1_4567 = vmulq_n_s32(weight_qvals_sum, activation_zero); - if (has_weight_zeros) { + if constexpr (has_weight_zeros) { // Compute term2 and term3 int32_t activation_qvals_sum = *((int32_t*)activation_ptr); diff --git a/torchao/experimental/kernels/cpu/aarch64/linear/linear.h b/torchao/experimental/kernels/cpu/aarch64/linear/linear.h index cd816dba46..7b983a1929 100644 --- a/torchao/experimental/kernels/cpu/aarch64/linear/linear.h +++ b/torchao/experimental/kernels/cpu/aarch64/linear/linear.h @@ -320,7 +320,7 @@ void prepare_weight_data( bias); } -template +template void kernel( // Outputs float32_t* output, @@ -335,12 +335,13 @@ void kernel( // Ignored if has_clamp = false float clamp_min, float clamp_max, - bool has_weight_zeros, + bool has_weight_zeros_, bool has_bias, bool has_clamp) { + (void)has_weight_zeros_; // unused torchao::kernels::cpu::aarch64::linear:: channelwise_8bit_activation_groupwise_lowbit_weight:: - kernel_1x8x16_f32_neondot( + kernel_1x8x16_f32_neondot( output, output_m_stride, m, diff --git a/torchao/experimental/kernels/cpu/aarch64/tests/test_linear.cpp b/torchao/experimental/kernels/cpu/aarch64/tests/test_linear.cpp index 2e19a524e5..0157769fec 100644 --- a/torchao/experimental/kernels/cpu/aarch64/tests/test_linear.cpp +++ b/torchao/experimental/kernels/cpu/aarch64/tests/test_linear.cpp @@ -311,7 +311,7 @@ void test_channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot bias_ptr); std::vector output(m * n); - kernel( + kernel( output.data(), /*output_m_stride=*/n, m, @@ -388,13 +388,12 @@ TEST( } } -template +template void test_channelwise_8bit_activation_groupwise_lowbit_weight_lut( int m, int k, int n, int group_size, - bool has_weight_zeros, bool has_bias, bool has_clamp) { constexpr int mr = 1; @@ -453,7 +452,7 @@ void test_channelwise_8bit_activation_groupwise_lowbit_weight_lut( has_bias ? test_case.bias.data() : nullptr); std::vector output(m * n); - kernel_1x8x16_f32_neondot( + kernel_1x8x16_f32_neondot( output.data(), /*output_m_stride=*/n, m, @@ -476,85 +475,90 @@ void test_channelwise_8bit_activation_groupwise_lowbit_weight_lut( TEST(test_channelwise_8bit_activation_groupwise_lowbit_weight, LUT) { constexpr int weight_nbit = 4; - test_channelwise_8bit_activation_groupwise_lowbit_weight_lut( + test_channelwise_8bit_activation_groupwise_lowbit_weight_lut< + weight_nbit, + /*has_weight_zeros*/ false>( /*m=*/7, /*k=*/64, /*n=*/13, /*group_size=*/16, - /*has_weight_zeros=*/false, /*has_bias=*/false, /*has_clamp=*/false); // has_weight_zeros - test_channelwise_8bit_activation_groupwise_lowbit_weight_lut( + test_channelwise_8bit_activation_groupwise_lowbit_weight_lut< + weight_nbit, + /*has_weight_zeros*/ true>( /*m=*/7, /*k=*/64, /*n=*/13, /*group_size=*/16, - /*has_weight_zeros=*/true, /*has_bias=*/false, /*has_clamp=*/false); // has_bias - test_channelwise_8bit_activation_groupwise_lowbit_weight_lut( + test_channelwise_8bit_activation_groupwise_lowbit_weight_lut< + weight_nbit, + /*has_weight_zeros=*/false>( /*m=*/7, /*k=*/64, /*n=*/13, /*group_size=*/16, - /*has_weight_zeros=*/false, /*has_bias=*/true, /*has_clamp=*/false); // has_clamp - test_channelwise_8bit_activation_groupwise_lowbit_weight_lut( + test_channelwise_8bit_activation_groupwise_lowbit_weight_lut< + weight_nbit, + /*has_weight_zeros*/ false>( /*m=*/7, /*k=*/64, /*n=*/13, /*group_size=*/16, - /*has_weight_zeros=*/false, /*has_bias=*/false, /*has_clamp=*/true); // n less than 8 (nr) for (int n = 1; n < 8; n++) { - test_channelwise_8bit_activation_groupwise_lowbit_weight_lut( + test_channelwise_8bit_activation_groupwise_lowbit_weight_lut< + weight_nbit, + /*has_weight_zeros=*/false>( /*m=*/7, /*k=*/64, /*n=*/n, /*group_size=*/16, - /*has_weight_zeros=*/false, /*has_bias=*/false, /*has_clamp=*/false); } // Other bitwidths test_channelwise_8bit_activation_groupwise_lowbit_weight_lut< - /*weight_nbit*/ 1>( + /*weight_nbit*/ 1, + /*has_weight_zeros=*/false>( /*m=*/7, /*k=*/64, /*n=*/13, /*group_size=*/16, - /*has_weight_zeros=*/false, /*has_bias=*/false, /*has_clamp=*/false); test_channelwise_8bit_activation_groupwise_lowbit_weight_lut< - /*weight_nbit*/ 2>( + /*weight_nbit*/ 2, + /*has_weight_zeros=*/false>( /*m=*/7, /*k=*/64, /*n=*/13, /*group_size=*/16, - /*has_weight_zeros=*/false, /*has_bias=*/false, /*has_clamp=*/false); test_channelwise_8bit_activation_groupwise_lowbit_weight_lut< - /*weight_nbit*/ 3>( + /*weight_nbit*/ 3, + /*has_weight_zeros=*/false>( /*m=*/7, /*k=*/64, /*n=*/13, /*group_size=*/16, - /*has_weight_zeros=*/false, /*has_bias=*/false, /*has_clamp=*/false); } diff --git a/torchao/experimental/ops/embedding_xbit/op_embedding_xbit-impl.h b/torchao/experimental/ops/embedding_xbit/op_embedding_xbit-impl.h index 22b87cfb9e..8113a0566b 100644 --- a/torchao/experimental/ops/embedding_xbit/op_embedding_xbit-impl.h +++ b/torchao/experimental/ops/embedding_xbit/op_embedding_xbit-impl.h @@ -253,9 +253,11 @@ Tensor shared_embedding_out_cpu( torchao::ops::PackedWeightsHeader::read(packed_weights.const_data_ptr()); auto format = torchao::ops::linear_8bit_act_xbit_weight::PackedWeightsFormat:: from_packed_weights_header(header); - torchao::ops::linear_8bit_act_xbit_weight::check_format( + + torchao::ops::linear_8bit_act_xbit_weight::check_format( format, - torchao::ops::PackedWeightsType::linear_8bit_act_xbit_weight_universal); + torchao::ops::PackedWeightsType::linear_8bit_act_xbit_weight_universal, + weight_nbit); constexpr int nr = 8; constexpr int kr = 16; constexpr int sr = 2; @@ -316,12 +318,7 @@ Tensor shared_embedding_cpu( const Tensor& indices) { Tensor output_tensor = torch::empty({}, torch::kFloat32); shared_embedding_out_cpu( - packed_weights, - group_size, - n, - k, - indices, - output_tensor); + packed_weights, group_size, n, k, indices, output_tensor); return output_tensor; } #endif // USE_ATEN diff --git a/torchao/experimental/ops/linear_8bit_act_xbit_weight/kernel_selector.h b/torchao/experimental/ops/linear_8bit_act_xbit_weight/kernel_selector.h index 17d7ec13b1..e960a918d8 100644 --- a/torchao/experimental/ops/linear_8bit_act_xbit_weight/kernel_selector.h +++ b/torchao/experimental/ops/linear_8bit_act_xbit_weight/kernel_selector.h @@ -89,9 +89,11 @@ void register_ukernel_config_universal( if (!cpuinfo_initialize()) { throw std::runtime_error("Failed to initialize cpuinfo!"); } - check_format( + + check_format( format, - torchao::ops::PackedWeightsType::linear_8bit_act_xbit_weight_universal); + torchao::ops::PackedWeightsType::linear_8bit_act_xbit_weight_universal, + weight_nbit); if (format.nr == 8 && format.kr == 16 && format.sr == 2) { #if defined(TORCHAO_BUILD_CPU_AARCH64) @@ -99,25 +101,50 @@ void register_ukernel_config_universal( log_registration(format, "universal"); namespace kernel = torchao::kernels::cpu::aarch64::linear:: channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot; - table.register_ukernel_config( - format, - uarch, - UKernelConfig{ - /*preferred_alignment*/ 16, - /*nr*/ 8, - /*weight_packing_config*/ - {/*weight_data_size_fn*/ - &kernel::weight_data_size, - /*prepare_weight_data_fn*/ - &kernel::prepare_weight_data}, - /*linear_configs*/ - {{{/*mr*/ 1, - /*activation_data_size_fn*/ - &kernel::activation_data_size, - /*prepare_activation_data_fn*/ - &kernel::prepare_activation_data, - /*kernel*/ - &kernel::kernel}}}}); + + if (format.has_weight_zeros) { + constexpr bool has_weight_zeros = true; + table.register_ukernel_config( + format, + uarch, + UKernelConfig{ + /*preferred_alignment*/ 16, + /*nr*/ 8, + /*weight_packing_config*/ + {/*weight_data_size_fn*/ + &kernel::weight_data_size, + /*prepare_weight_data_fn*/ + &kernel::prepare_weight_data}, + /*linear_configs*/ + {{{/*mr*/ 1, + /*activation_data_size_fn*/ + &kernel::activation_data_size, + /*prepare_activation_data_fn*/ + &kernel::prepare_activation_data, + /*kernel*/ + &kernel::kernel}}}}); + } else { + constexpr bool has_weight_zeros = false; + table.register_ukernel_config( + format, + uarch, + UKernelConfig{ + /*preferred_alignment*/ 16, + /*nr*/ 8, + /*weight_packing_config*/ + {/*weight_data_size_fn*/ + &kernel::weight_data_size, + /*prepare_weight_data_fn*/ + &kernel::prepare_weight_data}, + /*linear_configs*/ + {{{/*mr*/ 1, + /*activation_data_size_fn*/ + &kernel::activation_data_size, + /*prepare_activation_data_fn*/ + &kernel::prepare_activation_data, + /*kernel*/ + &kernel::kernel}}}}); + } return; } #endif // TORCHAO_BUILD_CPU_AARCH64 @@ -166,7 +193,7 @@ void register_ukernel_config_kleidi( if (!cpuinfo_initialize()) { throw std::runtime_error("Failed to initialize cpuinfo!"); } - check_format(format, torchao::ops::PackedWeightsType::kleidi_ai); + check_format(format, torchao::ops::PackedWeightsType::kleidi_ai, weight_nbit); namespace op = torchao::kernels::cpu::aarch64::kleidi:: kai_matmul_clamp_f32_qai8dxp_qsi4c32p; diff --git a/torchao/experimental/ops/linear_8bit_act_xbit_weight/packed_weights_format.h b/torchao/experimental/ops/linear_8bit_act_xbit_weight/packed_weights_format.h index 82beea43fb..e22082f9f1 100644 --- a/torchao/experimental/ops/linear_8bit_act_xbit_weight/packed_weights_format.h +++ b/torchao/experimental/ops/linear_8bit_act_xbit_weight/packed_weights_format.h @@ -53,10 +53,10 @@ struct PackedWeightsFormat { } }; -template -void check_format( +inline void check_format( PackedWeightsFormat format, - torchao::ops::PackedWeightsType type) { + torchao::ops::PackedWeightsType type, + int weight_nbit) { if (format.type != type) { throw std::runtime_error( "Kernel expects packed_weights type=" + diff --git a/torchao/experimental/ops/tests/test_linear_8bit_act_xbit_weight.cpp b/torchao/experimental/ops/tests/test_linear_8bit_act_xbit_weight.cpp index ae11b56e42..caaf8baf74 100644 --- a/torchao/experimental/ops/tests/test_linear_8bit_act_xbit_weight.cpp +++ b/torchao/experimental/ops/tests/test_linear_8bit_act_xbit_weight.cpp @@ -42,7 +42,7 @@ UKernelConfig get_ukernel_config() { /*prepare_activation_data_fn*/ &kernel::prepare_activation_data, /*kernel*/ - &kernel::kernel}}}}; + &kernel::kernel}}}}; } template < From c9b1490971ab3b95c20f076e8dd3a679e27883dc Mon Sep 17 00:00:00 2001 From: Scott Roy <161522778+metascroy@users.noreply.github.com> Date: Tue, 1 Apr 2025 16:09:36 -0700 Subject: [PATCH 02/19] Claen up op interface Differential Revision: D72179480 Pull Request resolved: https://github.com/pytorch/ao/pull/1998 --- .../kai_matmul_clamp_f32_qai8dxp_qsi4c32p.h | 95 +++- ..._8bit_activation_groupwise_lowbit_weight.h | 24 +- .../kernels/cpu/aarch64/linear/linear.h | 365 ------------- .../kernels/cpu/aarch64/tests/test_linear.cpp | 485 ++++++++++-------- .../kernel_config.h | 238 +++++++++ .../kernel_selector.h | 230 ++++----- .../linear_8bit_act_xbit_weight.cpp | 416 +++++---------- .../linear_8bit_act_xbit_weight.h | 144 +----- .../op_linear_8bit_act_xbit_weight-impl.h | 112 ++-- .../test_linear_8bit_act_xbit_weight.cpp | 326 +++++------- 10 files changed, 1006 insertions(+), 1429 deletions(-) delete mode 100644 torchao/experimental/kernels/cpu/aarch64/linear/linear.h create mode 100644 torchao/experimental/ops/linear_8bit_act_xbit_weight/kernel_config.h diff --git a/torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp_qsi4c32p.h b/torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp_qsi4c32p.h index 2e8d0aa453..2a8e668fa7 100644 --- a/torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp_qsi4c32p.h +++ b/torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp_qsi4c32p.h @@ -60,27 +60,47 @@ namespace kai_matmul_clamp_f32_qai8dxp_qsi4c32p { using Ukernel = struct kai_matmul_clamp_f32_qai8dxp_qsi4c32p_ukernel; -template -size_t -activation_data_size(int m, int k, int group_size, bool has_weight_zeros) { +size_t packed_activations_size( + int m, + int k, + int group_size, + bool has_weight_zeros, + int mr, + int kr, + int sr) { (void)group_size; // unused (void)has_weight_zeros; // unused auto lhs_packing = get_lhs_packing(); return lhs_packing.get_lhs_packed_size(m, k, mr, kr, sr); } -template -void prepare_activation_data( - void* activation_data, +size_t packed_activations_offset( + int m_idx, + int k, + int group_size, + bool has_weight_zeros, + int mr, + int kr, + int sr) { + (void)group_size; // unused + (void)has_weight_zeros; // unused + auto lhs_pack = get_lhs_packing(); + return lhs_pack.get_lhs_packed_offset(m_idx, k, mr, kr, sr); +} + +void pack_activations( + void* packed_activations, int m, int k, int group_size, const float* activations, - bool has_weight_zeros) { + bool has_weight_zeros, + int mr, + int kr, + int sr) { (void)group_size; // unused (void)has_weight_zeros; // unused auto lhs_pack = get_lhs_packing(); - lhs_pack.run_lhs_pack( m, k, @@ -90,33 +110,62 @@ void prepare_activation_data( /*m_index_start=*/0, activations, /*lhs_stride=*/k * sizeof(float), - activation_data); + packed_activations); } -template -size_t weight_data_size( +size_t packed_weights_size( int n, int k, int group_size, + int weight_nbit, bool has_weight_zeros, - bool has_bias) { + bool has_bias, + int nr, + int kr, + int sr) { + (void)weight_nbit; // unused (void)has_weight_zeros; // unused (void)has_bias; // unused auto rhs_pack = get_rhs_packing(); return rhs_pack.get_rhs_packed_size( - n, k, nr, kr, sr, group_size, kai_datatype::kai_dt_bf16); + internal::adjust_n(n), + k, + nr, + kr, + sr, + group_size, + kai_datatype::kai_dt_bf16); +} + +size_t packed_weights_offset( + int n_idx, + int k, + int group_size, + int weight_nbit, + bool has_weight_zeros, + bool has_bias, + int nr, + int kr, + int sr) { + (void)has_weight_zeros; // unused + (void)has_bias; // unused + auto rhs_pack = get_rhs_packing(); + return rhs_pack.get_rhs_packed_offset( + n_idx, k, nr, kr, sr, group_size, kai_datatype::kai_dt_bf16); } -template -void prepare_weight_data( - void* weight_data, +void pack_weights( + void* packed_weights, int n, int k, int group_size, const int8_t* weight_qvals, const float* weight_scales, const int8_t* weight_zeros, - const float* bias) { + const float* bias, + int nr, + int kr, + int sr) { if (group_size % 32 != 0) { throw std::runtime_error( "Group size must be a multiple of 32, but got group_size=" + @@ -187,7 +236,7 @@ void prepare_weight_data( reinterpret_cast(weight_scales_bf16_padded.data()), /*scale_stride=*/sizeof(uint16_t) * (internal::roundup(k, group_size) / group_size), - /*rhs_packed=*/weight_data, + /*rhs_packed=*/packed_weights, /*extra_bytes=*/0, /*qparams=*/&qparams); } @@ -220,8 +269,8 @@ size_t get_preferred_alignement() { int n, \ int k, \ int group_size, \ - const void* weight_data, \ - const void* activation_data, \ + const void* packed_weights, \ + const void* packed_activations, \ float clamp_min, \ float clamp_max, \ bool has_weight_zeros, \ @@ -235,11 +284,11 @@ size_t get_preferred_alignement() { } \ get_ukernel().run_matmul( \ m, \ - internal::adjust_n(n), \ + n, \ k, \ group_size, \ - activation_data, \ - weight_data, \ + packed_activations, \ + packed_weights, \ output, \ /*dst_stride_row=*/output_m_stride * sizeof(float), \ /*dst_stride_col=*/sizeof(float), \ diff --git a/torchao/experimental/kernels/cpu/aarch64/linear/channelwise_8bit_activation_groupwise_lowbit_weight/channelwise_8bit_activation_groupwise_lowbit_weight.h b/torchao/experimental/kernels/cpu/aarch64/linear/channelwise_8bit_activation_groupwise_lowbit_weight/channelwise_8bit_activation_groupwise_lowbit_weight.h index 9ff75e3344..95ecb79dc0 100644 --- a/torchao/experimental/kernels/cpu/aarch64/linear/channelwise_8bit_activation_groupwise_lowbit_weight/channelwise_8bit_activation_groupwise_lowbit_weight.h +++ b/torchao/experimental/kernels/cpu/aarch64/linear/channelwise_8bit_activation_groupwise_lowbit_weight/channelwise_8bit_activation_groupwise_lowbit_weight.h @@ -49,15 +49,21 @@ inline size_t packed_activations_offset( return (m_idx / mr) * packed_activations_size_mr_rows; } -template +template void pack_activations( void* packed_activations, int m, int k, int group_size, const float* activations, - bool has_weight_zeros) { - activation_packing::pack_activations( + bool has_weight_zeros, + int mr, + int kr, + int sr) { + (void)mr; // unused + (void)kr; // unused + (void)sr; // unused + activation_packing::pack_activations( packed_activations, m, k, group_size, activations, has_weight_zeros); } @@ -93,7 +99,7 @@ inline size_t packed_weights_offset( return (n_idx / nr) * packed_weights_size_nr_cols; } -template +template void pack_weights( void* packed_weights, int n, @@ -102,8 +108,14 @@ void pack_weights( const int8_t* weight_qvals, const float* weight_scales, const int8_t* weight_zeros, - const float* bias) { - weight_packing::pack_weights( + const float* bias, + int nr, + int kr, + int sr) { + (void)nr; // unused + (void)kr; // unused + (void)sr; // unused + weight_packing::pack_weights( packed_weights, n, k, diff --git a/torchao/experimental/kernels/cpu/aarch64/linear/linear.h b/torchao/experimental/kernels/cpu/aarch64/linear/linear.h deleted file mode 100644 index 7b983a1929..0000000000 --- a/torchao/experimental/kernels/cpu/aarch64/linear/linear.h +++ /dev/null @@ -1,365 +0,0 @@ -// Copyright (c) Meta Platforms, Inc. and affiliates. -// All rights reserved. -// -// This source code is licensed under the license found in the -// LICENSE file in the root directory of this source tree. - -// TODO: this file will be deleted and replaced by -// torchao/experimental/kernels/cpu/aarch64/linear/channelwise_8bit_activation_groupwise_lowbit_weight/include.h -// It exists now to prevent breaking existing code in the interim. - -#pragma once - -#if defined(__aarch64__) || defined(__ARM_NEON) - -#include -#include -#include - -namespace torchao::kernels::cpu::aarch64::linear { -namespace channelwise_8bit_activation_groupwise_lowbit_weight_1x1x32_f32_neondot { - -inline size_t -activation_data_size(int m, int k, int group_size, bool has_weight_zeros) { - return torchao::kernels::cpu::aarch64::linear:: - channelwise_8bit_activation_groupwise_lowbit_weight:: - packed_activations_size( - m, - k, - group_size, - has_weight_zeros, - /*mr*/ 1, - /*kr*/ 32, - /*sr*/ 1); -} - -inline void prepare_activation_data( - void* activation_data, - // Inputs - int m, - int k, - // Ignored if has_weight_zeros = false - int group_size, - const float* activations, - bool has_weight_zeros) { - torchao::kernels::cpu::aarch64::linear:: - channelwise_8bit_activation_groupwise_lowbit_weight:: - pack_activations( - activation_data, m, k, group_size, activations, has_weight_zeros); -} - -template -size_t weight_data_size( - int n, - int k, - int group_size, - bool has_weight_zeros, - bool has_bias) { - return torchao::kernels::cpu::aarch64::linear:: - channelwise_8bit_activation_groupwise_lowbit_weight::packed_weights_size( - n, - k, - group_size, - weight_nbit, - has_weight_zeros, - has_bias, - /*nr*/ 1, - /*kr*/ 32, - /*sr*/ 1); -} - -template -void prepare_weight_data( - void* weight_data, - // Inputs - int n, - int k, - int group_size, - const int8_t* weight_qvals, - const float* weight_scales, - const int8_t* weight_zeros, - const float* bias) { - torchao::kernels::cpu::aarch64::linear:: - channelwise_8bit_activation_groupwise_lowbit_weight:: - pack_weights( - weight_data, - n, - k, - group_size, - weight_qvals, - weight_scales, - weight_zeros, - bias); -} - -template -void kernel( - // Outputs - float32_t* output, - // Inputs - int output_m_stride, - int m, - int n, - int k, - int group_size, - const void* weight_data, - const void* activation_data, - // Ignored if has_clamp = false - float clamp_min, - float clamp_max, - bool has_weight_zeros, - bool has_bias, - bool has_clamp) { - torchao::kernels::cpu::aarch64::linear:: - channelwise_8bit_activation_groupwise_lowbit_weight:: - kernel_1x1x32_f32_neondot( - output, - output_m_stride, - m, - n, - k, - group_size, - weight_data, - activation_data, - clamp_min, - clamp_max, - has_weight_zeros, - has_bias, - has_clamp); -} - -} // namespace - // channelwise_8bit_activation_groupwise_lowbit_weight_1x1x32_f32_neondot - -namespace channelwise_8bit_activation_groupwise_lowbit_weight_1x4x16_f32_neondot { - -inline size_t -activation_data_size(int m, int k, int group_size, bool has_weight_zeros) { - return torchao::kernels::cpu::aarch64::linear:: - channelwise_8bit_activation_groupwise_lowbit_weight:: - packed_activations_size( - m, - k, - group_size, - has_weight_zeros, - /*mr*/ 1, - /*kr*/ 16, - /*sr*/ 2); -} - -inline void prepare_activation_data( - void* activation_data, - // Inputs - int m, - int k, - // Ignored if has_weight_zeros = false - int group_size, - const float* activations, - bool has_weight_zeros) { - torchao::kernels::cpu::aarch64::linear:: - channelwise_8bit_activation_groupwise_lowbit_weight:: - pack_activations( - activation_data, m, k, group_size, activations, has_weight_zeros); -} - -template -inline size_t weight_data_size( - int n, - int k, - int group_size, - bool has_weight_zeros, - bool has_bias) { - return torchao::kernels::cpu::aarch64::linear:: - channelwise_8bit_activation_groupwise_lowbit_weight::packed_weights_size( - n, - k, - group_size, - weight_nbit, - has_weight_zeros, - has_bias, - /*nr*/ 4, - /*kr*/ 16, - /*sr*/ 2); -} - -template -inline void prepare_weight_data( - void* weight_data, - // Inputs - int n, - int k, - int group_size, - const int8_t* weight_qvals, - const float* weight_scales, - const int8_t* weight_zeros, - const float* bias) { - torchao::kernels::cpu::aarch64::linear:: - channelwise_8bit_activation_groupwise_lowbit_weight:: - pack_weights( - weight_data, - n, - k, - group_size, - weight_qvals, - weight_scales, - weight_zeros, - bias); -} - -template -void kernel( - // Outputs - float32_t* output, - // Inputs - int output_m_stride, - int m, - int n, - int k, - int group_size, - const void* weight_data, - const void* activation_data, - // Ignored if has_clamp = false - float clamp_min, - float clamp_max, - bool has_weight_zeros, - bool has_bias, - bool has_clamp) { - torchao::kernels::cpu::aarch64::linear:: - channelwise_8bit_activation_groupwise_lowbit_weight:: - kernel_1x4x16_f32_neondot( - output, - output_m_stride, - m, - n, - k, - group_size, - weight_data, - activation_data, - clamp_min, - clamp_max, - has_weight_zeros, - has_bias, - has_clamp); -} - -} // namespace - // channelwise_8bit_activation_groupwise_lowbit_weight_1x4x16_f32_neondot - -namespace channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot { - -inline size_t -activation_data_size(int m, int k, int group_size, bool has_weight_zeros) { - return torchao::kernels::cpu::aarch64::linear:: - channelwise_8bit_activation_groupwise_lowbit_weight:: - packed_activations_size( - m, - k, - group_size, - has_weight_zeros, - /*mr*/ 1, - /*kr*/ 16, - /*sr*/ 2); -} - -inline void prepare_activation_data( - void* activation_data, - // Inputs - int m, - int k, - // Ignored if has_weight_zeros = false - int group_size, - const float* activations, - bool has_weight_zeros) { - torchao::kernels::cpu::aarch64::linear:: - channelwise_8bit_activation_groupwise_lowbit_weight:: - pack_activations( - activation_data, m, k, group_size, activations, has_weight_zeros); -} - -template -size_t weight_data_size( - int n, - int k, - int group_size, - bool has_weight_zeros, - bool has_bias) { - return torchao::kernels::cpu::aarch64::linear:: - channelwise_8bit_activation_groupwise_lowbit_weight::packed_weights_size( - n, - k, - group_size, - weight_nbit, - has_weight_zeros, - has_bias, - /*nr*/ 8, - /*kr*/ 16, - /*sr*/ 2); -} - -template -void prepare_weight_data( - void* weight_data, - // Inputs - int n, - int k, - int group_size, - const int8_t* weight_qvals, - const float* weight_scales, - const int8_t* weight_zeros, - const float* bias) { - torchao::kernels::cpu::aarch64::linear:: - channelwise_8bit_activation_groupwise_lowbit_weight:: - pack_weights( - weight_data, - n, - k, - group_size, - weight_qvals, - weight_scales, - weight_zeros, - bias); -} - -template -void kernel( - // Outputs - float32_t* output, - // Inputs - int output_m_stride, - int m, - int n, - int k, - int group_size, - const void* weight_data, - const void* activation_data, - // Ignored if has_clamp = false - float clamp_min, - float clamp_max, - bool has_weight_zeros_, - bool has_bias, - bool has_clamp) { - (void)has_weight_zeros_; // unused - torchao::kernels::cpu::aarch64::linear:: - channelwise_8bit_activation_groupwise_lowbit_weight:: - kernel_1x8x16_f32_neondot( - output, - output_m_stride, - m, - n, - k, - group_size, - weight_data, - activation_data, - clamp_min, - clamp_max, - has_weight_zeros, - has_bias, - has_clamp); -} - -} // namespace - // channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot - -} // namespace torchao::kernels::cpu::aarch64::linear - -#endif // defined(__aarch64__) || defined(__ARM_NEON) diff --git a/torchao/experimental/kernels/cpu/aarch64/tests/test_linear.cpp b/torchao/experimental/kernels/cpu/aarch64/tests/test_linear.cpp index 0157769fec..671ee3f0b9 100644 --- a/torchao/experimental/kernels/cpu/aarch64/tests/test_linear.cpp +++ b/torchao/experimental/kernels/cpu/aarch64/tests/test_linear.cpp @@ -12,17 +12,23 @@ #include #include #include -#include #include float kTol = 0.0001; -template -void test_channelwise_8bit_activation_groupwise_lowbit_weight_1x1x32_f32_neondot( +template +void test_channelwise_8bit_activation_groupwise_lowbit_weight_1x1x32( int m, int k, int n, - int group_size) { + int group_size, + bool has_bias, + bool has_clamp) { + constexpr int mr = 1; + constexpr int nr = 1; + constexpr int kr = 32; + constexpr int sr = 1; + auto test_case = torchao:: channelwise_8bit_activation_groupwise_lowbit_weight_test_case::generate( m, @@ -35,48 +41,46 @@ void test_channelwise_8bit_activation_groupwise_lowbit_weight_1x1x32_f32_neondot has_clamp); using namespace torchao::kernels::cpu::aarch64::linear:: - channelwise_8bit_activation_groupwise_lowbit_weight_1x1x32_f32_neondot; + channelwise_8bit_activation_groupwise_lowbit_weight; - std::vector activation_data( - activation_data_size(m, k, group_size, has_weight_zeros)); - prepare_activation_data( - (void*)activation_data.data(), + std::vector packed_activations( + packed_activations_size(m, k, group_size, has_weight_zeros, mr, kr, sr)); + pack_activations( + (void*)packed_activations.data(), m, k, group_size, test_case.activations.data(), - has_weight_zeros); + has_weight_zeros, + mr, + kr, + sr); - std::vector weight_data(weight_data_size( - n, k, group_size, has_weight_zeros, has_bias)); - int8_t* weight_zeros_ptr = nullptr; - if (has_weight_zeros) { - weight_zeros_ptr = test_case.weight_zeros.data(); - } - float* bias_ptr = nullptr; - if (has_bias) { - bias_ptr = test_case.bias.data(); - } - prepare_weight_data( - (void*)weight_data.data(), + std::vector packed_weights(packed_weights_size( + n, k, group_size, weight_nbit, has_weight_zeros, has_bias, nr, kr, sr)); + pack_weights( + (void*)packed_weights.data(), n, k, group_size, test_case.weight_qvals.data(), test_case.weight_scales.data(), - /*weight_zeros=*/weight_zeros_ptr, - bias_ptr); + has_weight_zeros ? test_case.weight_zeros.data() : nullptr, + has_bias ? test_case.bias.data() : nullptr, + nr, + kr, + sr); std::vector output(m * n); - kernel( + kernel_1x1x32_f32_neondot( output.data(), /*output_m_stride=*/n, m, n, k, group_size, - weight_data.data(), - activation_data.data(), + packed_weights.data(), + packed_activations.data(), /*clamp_min=*/test_case.clamp_min, /*clamp_max=*/test_case.clamp_max, has_weight_zeros, @@ -88,56 +92,19 @@ void test_channelwise_8bit_activation_groupwise_lowbit_weight_1x1x32_f32_neondot } } -TEST( - test_channelwise_8bit_activation_groupwise_lowbit_weight_1x1x32_f32_neondot, - Standard) { - test_channelwise_8bit_activation_groupwise_lowbit_weight_1x1x32_f32_neondot< - 4 /*weight_nbit*/, - false /*has_weight_zeros*/, - false /*has_bias*/, - false /*has_clamp*/>( - /*m=*/7, /*k=*/128, /*n=*/13, /*group_size=*/32); -} - -TEST( - test_channelwise_8bit_activation_groupwise_lowbit_weight_1x1x32_f32_neondot, - HasWeightZeros) { - test_channelwise_8bit_activation_groupwise_lowbit_weight_1x1x32_f32_neondot< - 4 /*weight_nbit*/, - true /*has_weight_zeros*/, - false /*has_bias*/, - false /*has_clamp*/>( - /*m=*/7, /*k=*/128, /*n=*/13, /*group_size=*/32); -} - -TEST( - test_channelwise_8bit_activation_groupwise_lowbit_weight_1x1x32_f32_neondot, - HasBias) { - test_channelwise_8bit_activation_groupwise_lowbit_weight_1x1x32_f32_neondot< - 4 /*weight_nbit*/, - false /*has_weight_zeros*/, - true /*has_bias*/, - false /*has_clamp*/>( - /*m=*/7, /*k=*/128, /*n=*/13, /*group_size=*/32); -} - -TEST( - test_channelwise_8bit_activation_groupwise_lowbit_weight_1x1x32_f32_neondot, - HasClamp) { - test_channelwise_8bit_activation_groupwise_lowbit_weight_1x1x32_f32_neondot< - 4 /*weight_nbit*/, - false /*has_weight_zeros*/, - false /*has_bias*/, - true /*has_clamp*/>( - /*m=*/7, /*k=*/128, /*n=*/13, /*group_size=*/32); -} - -template -void test_channelwise_8bit_activation_groupwise_lowbit_weight_1x4x16_f32_neondot( +template +void test_channelwise_8bit_activation_groupwise_lowbit_weight_1x4x16( int m, int k, int n, - int group_size) { + int group_size, + bool has_bias, + bool has_clamp) { + constexpr int mr = 1; + constexpr int nr = 4; + constexpr int kr = 16; + constexpr int sr = 2; + auto test_case = torchao:: channelwise_8bit_activation_groupwise_lowbit_weight_test_case::generate( m, @@ -150,48 +117,46 @@ void test_channelwise_8bit_activation_groupwise_lowbit_weight_1x4x16_f32_neondot has_clamp); using namespace torchao::kernels::cpu::aarch64::linear:: - channelwise_8bit_activation_groupwise_lowbit_weight_1x4x16_f32_neondot; + channelwise_8bit_activation_groupwise_lowbit_weight; - std::vector activation_data( - activation_data_size(m, k, group_size, has_weight_zeros)); - prepare_activation_data( - (void*)activation_data.data(), + std::vector packed_activations( + packed_activations_size(m, k, group_size, has_weight_zeros, mr, kr, sr)); + pack_activations( + (void*)packed_activations.data(), m, k, group_size, test_case.activations.data(), - has_weight_zeros); + has_weight_zeros, + mr, + kr, + sr); - std::vector weight_data(weight_data_size( - n, k, group_size, has_weight_zeros, has_bias)); - int8_t* weight_zeros_ptr = nullptr; - if (has_weight_zeros) { - weight_zeros_ptr = test_case.weight_zeros.data(); - } - float* bias_ptr = nullptr; - if (has_bias) { - bias_ptr = test_case.bias.data(); - } - prepare_weight_data( - (void*)weight_data.data(), + std::vector packed_weights(packed_weights_size( + n, k, group_size, weight_nbit, has_weight_zeros, has_bias, nr, kr, sr)); + pack_weights( + (void*)packed_weights.data(), n, k, group_size, test_case.weight_qvals.data(), test_case.weight_scales.data(), - /*weight_zeros=*/weight_zeros_ptr, - bias_ptr); + has_weight_zeros ? test_case.weight_zeros.data() : nullptr, + has_bias ? test_case.bias.data() : nullptr, + nr, + kr, + sr); std::vector output(m * n); - kernel( + kernel_1x4x16_f32_neondot( output.data(), /*output_m_stride=*/n, m, n, k, group_size, - weight_data.data(), - activation_data.data(), + packed_weights.data(), + packed_activations.data(), /*clamp_min=*/test_case.clamp_min, /*clamp_max=*/test_case.clamp_max, has_weight_zeros, @@ -203,69 +168,19 @@ void test_channelwise_8bit_activation_groupwise_lowbit_weight_1x4x16_f32_neondot } } -TEST( - test_channelwise_8bit_activation_groupwise_lowbit_weight_1x4x16_f32_neondot, - Standard) { - test_channelwise_8bit_activation_groupwise_lowbit_weight_1x4x16_f32_neondot< - 4 /*weight_nbit*/, - false /*has_weight_zeros*/, - false /*has_bias*/, - false /*has_clamp*/>( - /*m=*/7, /*k=*/64, /*n=*/13, /*group_size=*/16); -} - -TEST( - test_channelwise_8bit_activation_groupwise_lowbit_weight_1x4x16_f32_neondot, - HasWeightZeros) { - test_channelwise_8bit_activation_groupwise_lowbit_weight_1x4x16_f32_neondot< - 4 /*weight_nbit*/, - true /*has_weight_zeros*/, - false /*has_bias*/, - false /*has_clamp*/>( - /*m=*/7, /*k=*/64, /*n=*/13, /*group_size=*/16); -} - -TEST( - test_channelwise_8bit_activation_groupwise_lowbit_weight_1x4x16_f32_neondot, - HasBias) { - test_channelwise_8bit_activation_groupwise_lowbit_weight_1x4x16_f32_neondot< - 4 /*weight_nbit*/, - false /*has_weight_zeros*/, - true /*has_bias*/, - false /*has_clamp*/>( - /*m=*/7, /*k=*/64, /*n=*/13, /*group_size=*/16); -} - -TEST( - test_channelwise_8bit_activation_groupwise_lowbit_weight_1x4x16_f32_neondot, - HasClamp) { - test_channelwise_8bit_activation_groupwise_lowbit_weight_1x4x16_f32_neondot< - 4 /*weight_nbit*/, - false /*has_weight_zeros*/, - false /*has_bias*/, - true /*has_clamp*/>( - /*m=*/7, /*k=*/64, /*n=*/13, /*group_size=*/16); -} - -TEST( - test_channelwise_8bit_activation_groupwise_lowbit_weight_1x4x16_f32_neondot, - NLessThan4) { - for (int n = 1; n < 4; n++) { - test_channelwise_8bit_activation_groupwise_lowbit_weight_1x4x16_f32_neondot< - 4 /*weight_nbit*/, - false /*has_weight_zeros*/, - false /*has_bias*/, - true /*has_clamp*/>( - /*m=*/7, /*k=*/64, /*n=*/n, /*group_size=*/16); - } -} - -template -void test_channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot( +template +void test_channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16( int m, int k, int n, - int group_size) { + int group_size, + bool has_bias, + bool has_clamp) { + constexpr int mr = 1; + constexpr int nr = 8; + constexpr int kr = 16; + constexpr int sr = 2; + auto test_case = torchao:: channelwise_8bit_activation_groupwise_lowbit_weight_test_case::generate( m, @@ -278,48 +193,46 @@ void test_channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot has_clamp); using namespace torchao::kernels::cpu::aarch64::linear:: - channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot; + channelwise_8bit_activation_groupwise_lowbit_weight; - std::vector activation_data( - activation_data_size(m, k, group_size, has_weight_zeros)); - prepare_activation_data( - (void*)activation_data.data(), + std::vector packed_activations( + packed_activations_size(m, k, group_size, has_weight_zeros, mr, kr, sr)); + pack_activations( + (void*)packed_activations.data(), m, k, group_size, test_case.activations.data(), - has_weight_zeros); + has_weight_zeros, + mr, + kr, + sr); - std::vector weight_data(weight_data_size( - n, k, group_size, has_weight_zeros, has_bias)); - int8_t* weight_zeros_ptr = nullptr; - if (has_weight_zeros) { - weight_zeros_ptr = test_case.weight_zeros.data(); - } - float* bias_ptr = nullptr; - if (has_bias) { - bias_ptr = test_case.bias.data(); - } - prepare_weight_data( - (void*)weight_data.data(), + std::vector packed_weights(packed_weights_size( + n, k, group_size, weight_nbit, has_weight_zeros, has_bias, nr, kr, sr)); + pack_weights( + (void*)packed_weights.data(), n, k, group_size, test_case.weight_qvals.data(), test_case.weight_scales.data(), - /*weight_zeros=*/weight_zeros_ptr, - bias_ptr); + has_weight_zeros ? test_case.weight_zeros.data() : nullptr, + has_bias ? test_case.bias.data() : nullptr, + nr, + kr, + sr); std::vector output(m * n); - kernel( + kernel_1x8x16_f32_neondot( output.data(), /*output_m_stride=*/n, m, n, k, group_size, - weight_data.data(), - activation_data.data(), + packed_weights.data(), + packed_activations.data(), /*clamp_min=*/test_case.clamp_min, /*clamp_max=*/test_case.clamp_max, has_weight_zeros, @@ -331,60 +244,173 @@ void test_channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot } } -TEST( - test_channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot, - Standard) { - test_channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot< - 4 /*weight_nbit*/, - false /*has_weight_zeros*/, - false /*has_bias*/, - false /*has_clamp*/>( - /*m=*/7, /*k=*/64, /*n=*/13, /*group_size=*/16); -} +TEST(test_channelwise_8bit_activation_groupwise_lowbit_weight, tile_1x1x32) { + constexpr int weight_nbit = 4; -TEST( - test_channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot, - HasWeightZeros) { - test_channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot< - 4 /*weight_nbit*/, - true /*has_weight_zeros*/, - false /*has_bias*/, - false /*has_clamp*/>( - /*m=*/7, /*k=*/64, /*n=*/13, /*group_size=*/16); -} + // Standard + test_channelwise_8bit_activation_groupwise_lowbit_weight_1x1x32< + weight_nbit, + /*has_weight_zeros=*/false>( + /*m=*/7, + /*k=*/64, + /*n=*/13, + /*group_size=*/32, + /*has_bias=*/false, + /*has_clamp=*/false); -TEST( - test_channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot, - HasBias) { - test_channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot< - 4 /*weight_nbit*/, - false /*has_weight_zeros*/, - true /*has_bias*/, - false /*has_clamp*/>( - /*m=*/7, /*k=*/64, /*n=*/13, /*group_size=*/16); + // With weight zeros + test_channelwise_8bit_activation_groupwise_lowbit_weight_1x1x32< + weight_nbit, + /*has_weight_zeros=*/true>( + /*m=*/7, + /*k=*/64, + /*n=*/13, + /*group_size=*/32, + /*has_bias=*/false, + /*has_clamp=*/false); + + // With bias + test_channelwise_8bit_activation_groupwise_lowbit_weight_1x1x32< + weight_nbit, + /*has_weight_zeros=*/false>( + /*m=*/7, + /*k=*/64, + /*n=*/13, + /*group_size=*/32, + /*has_bias=*/true, + /*has_clamp=*/false); + + // With clamp + test_channelwise_8bit_activation_groupwise_lowbit_weight_1x1x32< + weight_nbit, + /*has_weight_zeros=*/false>( + /*m=*/7, + /*k=*/64, + /*n=*/13, + /*group_size=*/32, + /*has_bias=*/false, + /*has_clamp=*/true); } -TEST( - test_channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot, - HasClamp) { - test_channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot< - 4 /*weight_nbit*/, - false /*has_weight_zeros*/, - false /*has_bias*/, - true /*has_clamp*/>( - /*m=*/7, /*k=*/64, /*n=*/13, /*group_size=*/16); +TEST(test_channelwise_8bit_activation_groupwise_lowbit_weight, tile_1x4x16) { + constexpr int weight_nbit = 4; + + // Standard + test_channelwise_8bit_activation_groupwise_lowbit_weight_1x4x16< + weight_nbit, + /*has_weight_zeros=*/false>( + /*m=*/7, + /*k=*/64, + /*n=*/13, + /*group_size=*/16, + /*has_bias=*/false, + /*has_clamp=*/false); + + // With weight zeros + test_channelwise_8bit_activation_groupwise_lowbit_weight_1x4x16< + weight_nbit, + /*has_weight_zeros=*/true>( + /*m=*/7, + /*k=*/64, + /*n=*/13, + /*group_size=*/16, + /*has_bias=*/false, + /*has_clamp=*/false); + + // With bias + test_channelwise_8bit_activation_groupwise_lowbit_weight_1x4x16< + weight_nbit, + /*has_weight_zeros=*/false>( + /*m=*/7, + /*k=*/64, + /*n=*/13, + /*group_size=*/16, + /*has_bias=*/true, + /*has_clamp=*/false); + + // With clamp + test_channelwise_8bit_activation_groupwise_lowbit_weight_1x4x16< + weight_nbit, + /*has_weight_zeros=*/false>( + /*m=*/7, + /*k=*/64, + /*n=*/13, + /*group_size=*/16, + /*has_bias=*/false, + /*has_clamp=*/true); + + // n less than 4 + for (int n = 1; n < 4; n++) { + test_channelwise_8bit_activation_groupwise_lowbit_weight_1x4x16< + weight_nbit, + /*has_weight_zeros=*/false>( + /*m=*/7, + /*k=*/64, + /*n=*/n, + /*group_size=*/16, + /*has_bias=*/false, + /*has_clamp=*/false); + } } -TEST( - test_channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot, - NLessThan8) { +TEST(test_channelwise_8bit_activation_groupwise_lowbit_weight, tile_1x8x16) { + constexpr int weight_nbit = 4; + + // Standard + test_channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16< + weight_nbit, + /*has_weight_zeros=*/false>( + /*m=*/7, + /*k=*/64, + /*n=*/13, + /*group_size=*/16, + /*has_bias=*/false, + /*has_clamp=*/false); + + // With weight zeros + test_channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16< + weight_nbit, + /*has_weight_zeros=*/true>( + /*m=*/7, + /*k=*/64, + /*n=*/13, + /*group_size=*/16, + /*has_bias=*/false, + /*has_clamp=*/false); + + // With bias + test_channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16< + weight_nbit, + /*has_weight_zeros=*/false>( + /*m=*/7, + /*k=*/64, + /*n=*/13, + /*group_size=*/16, + /*has_bias=*/true, + /*has_clamp=*/false); + + // With clamp + test_channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16< + weight_nbit, + /*has_weight_zeros=*/false>( + /*m=*/7, + /*k=*/64, + /*n=*/13, + /*group_size=*/16, + /*has_bias=*/false, + /*has_clamp=*/true); + + // n less than 8 for (int n = 1; n < 8; n++) { - test_channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot< - 4 /*weight_nbit*/, - false /*has_weight_zeros*/, - false /*has_bias*/, - true /*has_clamp*/>( - /*m=*/7, /*k=*/64, /*n=*/n, /*group_size=*/16); + test_channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16< + weight_nbit, + /*has_weight_zeros=*/false>( + /*m=*/7, + /*k=*/64, + /*n=*/n, + /*group_size=*/16, + /*has_bias=*/false, + /*has_clamp=*/false); } } @@ -423,7 +449,10 @@ void test_channelwise_8bit_activation_groupwise_lowbit_weight_lut( k, group_size, test_case.activations.data(), - has_weight_zeros); + has_weight_zeros, + mr, + kr, + sr); // Define equivalent LUT for affine quantization constexpr int lut_size = (1 << weight_nbit); diff --git a/torchao/experimental/ops/linear_8bit_act_xbit_weight/kernel_config.h b/torchao/experimental/ops/linear_8bit_act_xbit_weight/kernel_config.h new file mode 100644 index 0000000000..1e4a9ef670 --- /dev/null +++ b/torchao/experimental/ops/linear_8bit_act_xbit_weight/kernel_config.h @@ -0,0 +1,238 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the license found in the +// LICENSE file in the root directory of this source tree. + +#pragma once +#include +#include +#include +#include + +namespace torchao::ops::linear_8bit_act_xbit_weight { + +constexpr int kMaxLinearConfigs = 4; +struct UKernelConfig { + // Size of packed_activations buffer + using packed_activations_size_fn_type = size_t (*)( + int m, + int k, + int group_size, + bool has_weight_zeros, + int mr, + int kr, + int sr); + + // Offset in packed_activations buffer for a given m_idx + // m_idx is index in unpacked activations matrix; it will be a multiple of + // m_step + using packed_activations_offset_fn_type = size_t (*)( + int m_idx, + int k, + int group_size, + bool has_weight_zeros, + int mr, + int kr, + int sr); + + // Pack activations into packed_activations buffer + using pack_activations_fn_type = void (*)( + void* packed_activations, + int m, + int k, + int group_size, + const float* activations, + bool has_weight_zeros, + int mr, + int kr, + int sr); + + // Size of packed_weights buffer + using packed_weights_size_fn_type = size_t (*)( + int n, + int k, + int group_size, + int weight_nbit, + bool has_weight_zeros, + bool has_bias, + int nr, + int kr, + int sr); + + // Offset in packed_weights buffer for a given n_idx + // n_inx is index in unpacked weights matrix; it will be a multiple of n_step + using packed_weights_offset_fn_type = size_t (*)( + int n_idx, + int k, + int group_size, + int weight_nbit, + bool has_weight_zeros, + bool has_bias, + int nr, + int kr, + int sr); + + // Pack weights into packed_weights buffer + using pack_weights_fn_type = void (*)( + void* packed_weights, + int n, + int k, + int group_size, + const int8_t* weight_qvals, + const float* weight_scales, + const int8_t* weight_zeros, + const float* bias, + int nr, + int kr, + int sr); + + // Run matmul kernel + using kernel_fn_type = void (*)( + float* output, + int output_m_stride, + int m, + int n, + int k, + int group_size, + const void* packed_weights, + const void* packed_activations, + float clamp_min, + float clamp_max, + bool has_weight_zeros, + bool has_bias, + bool has_clamp); + + struct linear_config_type { + int m_step{0}; // m_idx will be a multiple of this + int mr{0}; + packed_activations_size_fn_type packed_activations_size{nullptr}; + packed_activations_offset_fn_type packed_activations_offset{nullptr}; + pack_activations_fn_type pack_activations{nullptr}; + kernel_fn_type kernel{nullptr}; + }; + + // preferred_alignment for packed_activations and packed_weights + // Integration surfaces are not required to respect this alignment, and the + // kernel must behave correctly no matter how buffers are aligned + size_t preferred_alignment{0}; + int n_step{0}; // n_idx will be a multiple of this + int nr{0}; + int kr{0}; + int sr{0}; + int weight_nbit{0}; + bool has_weight_zeros{false}; + bool has_bias{false}; + packed_weights_size_fn_type packed_weights_size{nullptr}; + packed_weights_offset_fn_type packed_weights_offset{nullptr}; + pack_weights_fn_type pack_weights{nullptr}; + + // linear_configs must be sorted in ascending m_step + std::array linear_configs; + + static UKernelConfig make( + size_t preferred_alignment, + int n_step, + int nr, + int kr, + int sr, + int weight_nbit, + bool has_weight_zeros, + bool has_bias, + packed_weights_size_fn_type packed_weights_size, + packed_weights_offset_fn_type packed_weights_offset, + pack_weights_fn_type pack_weights, + std::array linear_configs); + + inline void validate() const { + TORCHAO_CHECK(preferred_alignment >= 1, "preferred_alignment must be >= 1"); + TORCHAO_CHECK(n_step >= 1, "n_step must be >= 1"); + TORCHAO_CHECK(nr >= 1, "nr must be >= 1"); + TORCHAO_CHECK(kr >= 1, "kr must be >= 1"); + TORCHAO_CHECK(sr >= 1, "sr must be >= 1"); + TORCHAO_CHECK(weight_nbit >= 1, "weight_nbit must be >= 1"); + TORCHAO_CHECK( + packed_weights_size != nullptr, "packed_weights_size must be set"); + TORCHAO_CHECK( + packed_weights_offset != nullptr, "packed_weights_offset must be set"); + TORCHAO_CHECK(pack_weights != nullptr, "pack_weights must be set"); + + bool linear_configs_set = true; // first linear config must be set + for (int i = 0; i < linear_configs.size(); i++) { + if (linear_configs_set) { + TORCHAO_CHECK( + linear_configs[i].m_step >= 1, + "linear_configs[i].m_step must be >= 1"); + TORCHAO_CHECK( + linear_configs[i].mr >= 1, "linear_configs[i].mr must be >= 1"); + TORCHAO_CHECK( + linear_configs[i].packed_activations_size != nullptr, + "linear_configs[i].packed_activations_size must be set"); + TORCHAO_CHECK( + linear_configs[i].packed_activations_offset != nullptr, + "linear_configs[i].packed_activations_offset must be set"); + TORCHAO_CHECK( + linear_configs[i].pack_activations != nullptr, + "linear_configs[i].pack_activations must be set"); + TORCHAO_CHECK( + linear_configs[i].kernel != nullptr, + "linear_configs[i].kernel must be set"); + if (i >= 1) { + TORCHAO_CHECK( + linear_configs[i - 1].m_step < linear_configs[i].m_step, + "set linear_configs must be increasing in m_step"); + } + if (i + 1 < linear_configs.size()) { + linear_configs_set = (linear_configs[i + 1].m_step >= 1); + } + } + } + } + + inline int select_linear_config_idx(int m) const { + assert(m >= 1); + assert(linear_configs[0].m_step >= 1); + + int i = 0; + while (i + 1 < linear_configs.size() && linear_configs[i + 1].m_step >= 1 && + linear_configs[i + 1].m_step <= m) { + assert(linear_configs[i].m_step < linear_configs[i + 1].m_step); + i++; + } + + assert(i < linear_configs.size()); + assert(linear_configs[i].m_step >= 1); + assert(i == 0 || linear_configs[i].m_step <= m); + return i; + } +}; + +inline UKernelConfig UKernelConfig::make( + size_t preferred_alignment, + int n_step, + int nr, + int kr, + int sr, + int weight_nbit, + bool has_weight_zeros, + bool has_bias, + packed_weights_size_fn_type packed_weights_size, + packed_weights_offset_fn_type packed_weights_offset, + pack_weights_fn_type pack_weights, + std::array linear_configs) { + return UKernelConfig{ + preferred_alignment, + n_step, + nr, + kr, + sr, + weight_nbit, + has_weight_zeros, + has_bias, + packed_weights_size, + packed_weights_offset, + pack_weights, + std::move(linear_configs)}; +} + +} // namespace torchao::ops::linear_8bit_act_xbit_weight diff --git a/torchao/experimental/ops/linear_8bit_act_xbit_weight/kernel_selector.h b/torchao/experimental/ops/linear_8bit_act_xbit_weight/kernel_selector.h index e960a918d8..719c2e01e4 100644 --- a/torchao/experimental/ops/linear_8bit_act_xbit_weight/kernel_selector.h +++ b/torchao/experimental/ops/linear_8bit_act_xbit_weight/kernel_selector.h @@ -6,11 +6,11 @@ #pragma once #include -#include +#include #include #if defined(TORCHAO_BUILD_CPU_AARCH64) -#include +#include #endif // TORCHAO_BUILD_CPU_AARCH64 #include @@ -50,6 +50,7 @@ struct UKernelConfigRegistrationTable { throw std::runtime_error( "UKernelConfig is already registered for this format"); } + config.validate(); registration_table_[key] = config; } std::optional get_ukernel_config( @@ -95,94 +96,90 @@ void register_ukernel_config_universal( torchao::ops::PackedWeightsType::linear_8bit_act_xbit_weight_universal, weight_nbit); + namespace kernel = torchao::kernels::cpu::aarch64::linear:: + channelwise_8bit_activation_groupwise_lowbit_weight; + + constexpr bool has_lut = false; + int preferred_alignment = 16; + if (format.nr == 8 && format.kr == 16 && format.sr == 2) { + constexpr int n_step = 8; + constexpr int nr = 8; + constexpr int kr = 16; + constexpr int sr = 2; + constexpr int mr = 1; + constexpr int m_step = 1; + #if defined(TORCHAO_BUILD_CPU_AARCH64) if (cpuinfo_has_arm_neon_dot()) { - log_registration(format, "universal"); - namespace kernel = torchao::kernels::cpu::aarch64::linear:: - channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot; + log_registration(format, "universal: kernel_1x8x16_f32_neondot"); + auto uk = UKernelConfig::make( + preferred_alignment, + n_step, + nr, + kr, + sr, + weight_nbit, + format.has_weight_zeros, + format.has_bias, + &kernel::packed_weights_size, + &kernel::packed_weights_offset, + &kernel::pack_weights, + /*linear_configs*/ {}); if (format.has_weight_zeros) { constexpr bool has_weight_zeros = true; - table.register_ukernel_config( - format, - uarch, - UKernelConfig{ - /*preferred_alignment*/ 16, - /*nr*/ 8, - /*weight_packing_config*/ - {/*weight_data_size_fn*/ - &kernel::weight_data_size, - /*prepare_weight_data_fn*/ - &kernel::prepare_weight_data}, - /*linear_configs*/ - {{{/*mr*/ 1, - /*activation_data_size_fn*/ - &kernel::activation_data_size, - /*prepare_activation_data_fn*/ - &kernel::prepare_activation_data, - /*kernel*/ - &kernel::kernel}}}}); + uk.linear_configs[0] = UKernelConfig::linear_config_type( + {m_step, + mr, + &kernel::packed_activations_size, + &kernel::packed_activations_offset, + &kernel::pack_activations, + &kernel::kernel_1x8x16_f32_neondot< + weight_nbit, + has_weight_zeros, + has_lut>}); + + table.register_ukernel_config(format, uarch, std::move(uk)); + return; } else { constexpr bool has_weight_zeros = false; - table.register_ukernel_config( - format, - uarch, - UKernelConfig{ - /*preferred_alignment*/ 16, - /*nr*/ 8, - /*weight_packing_config*/ - {/*weight_data_size_fn*/ - &kernel::weight_data_size, - /*prepare_weight_data_fn*/ - &kernel::prepare_weight_data}, - /*linear_configs*/ - {{{/*mr*/ 1, - /*activation_data_size_fn*/ - &kernel::activation_data_size, - /*prepare_activation_data_fn*/ - &kernel::prepare_activation_data, - /*kernel*/ - &kernel::kernel}}}}); + uk.linear_configs[0] = UKernelConfig::linear_config_type( + {m_step, + mr, + &kernel::packed_activations_size, + &kernel::packed_activations_offset, + &kernel::pack_activations, + &kernel::kernel_1x8x16_f32_neondot< + weight_nbit, + has_weight_zeros, + has_lut>}); + + table.register_ukernel_config(format, uarch, std::move(uk)); + return; } - return; } #endif // TORCHAO_BUILD_CPU_AARCH64 } } #if defined(TORCHAO_ENABLE_KLEIDI) -template < - typename kernel_struct, - int m_step, - int mr, - int n_step, - int nr, - int kr, - int sr> -UKernelConfig::linear_config_type get_linear_config_kleidi() { +template +UKernelConfig::linear_config_type +get_linear_config_kleidi(int n_step, int nr, int kr, int sr) { namespace op = torchao::kernels::cpu::aarch64::kleidi:: kai_matmul_clamp_f32_qai8dxp_qsi4c32p; - assert(m_step == kernel_struct::get_ukernel().get_m_step()); - assert(mr == kernel_struct::get_ukernel().get_mr()); assert(n_step == kernel_struct::get_ukernel().get_n_step()); assert(nr == kernel_struct::get_ukernel().get_nr()); assert(kr == kernel_struct::get_ukernel().get_kr()); assert(sr == kernel_struct::get_ukernel().get_sr()); - return UKernelConfig::linear_config_type{ - /*mr*/ m_step, - /*activation_data_size_fn*/ &op::activation_data_size, - /*prepare_activation_data_fn*/ &op::prepare_activation_data, - /*kernel*/ &kernel_struct::kernel}; -} - -template -UKernelConfig::weight_packing_config_type get_weight_packing_config_kleidi() { - namespace op = torchao::kernels::cpu::aarch64::kleidi:: - kai_matmul_clamp_f32_qai8dxp_qsi4c32p; - return UKernelConfig::weight_packing_config_type( - {/*weight_data_size_fn*/ &op::weight_data_size, - /*prepare_weight_data_fn*/ &op::prepare_weight_data}); + return UKernelConfig::linear_config_type( + {static_cast(kernel_struct::get_ukernel().get_m_step()), + static_cast(kernel_struct::get_ukernel().get_mr()), + &op::packed_activations_size, + &op::packed_activations_offset, + &op::pack_activations, + &kernel_struct::kernel}); } template @@ -197,89 +194,62 @@ void register_ukernel_config_kleidi( namespace op = torchao::kernels::cpu::aarch64::kleidi:: kai_matmul_clamp_f32_qai8dxp_qsi4c32p; + auto uk = UKernelConfig::make( + /*preferred_alignment*/ op::get_preferred_alignement(), + /*n_step*/ format.nr, + format.nr, + format.kr, + format.sr, + format.weight_nbit, + format.has_weight_zeros, + format.has_bias, + &op::packed_weights_size, + &op::packed_weights_offset, + &op::pack_weights, + {} /*linear_configs*/); + if (format.nr == 8 && format.kr == 16 && format.sr == 2) { - constexpr int nr = 8; - constexpr int kr = 16; - constexpr int sr = 2; + uk.n_step = 8; + #if defined(TORCHAO_ENABLE_ARM_I8MM) if (cpuinfo_has_arm_i8mm()) { - constexpr int n_step = 8; + /*m_step=4*/ + uk.linear_configs[0] = get_linear_config_kleidi< + op::matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm>( + uk.n_step, uk.nr, uk.kr, uk.sr); log_registration( format, "kleidiai: matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm"); - table.register_ukernel_config( - format, - uarch, - UKernelConfig{ - /*preferred_alignment*/ op::get_preferred_alignement(), - /*nr*/ n_step, - /*weight_packing_config*/ - get_weight_packing_config_kleidi(), - /*linear_configs*/ - {{get_linear_config_kleidi< - op::matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm, - /*m_step*/ 4, - /*mr*/ 4, - n_step, - nr, - kr, - sr>()}}}); + table.register_ukernel_config(format, uarch, std::move(uk)); return; } #endif // TORCHAO_ENABLE_ARM_I8MM if (cpuinfo_has_arm_neon_dot()) { - constexpr int n_step = 8; log_registration( format, "kleidiai: matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod"); - table.register_ukernel_config( - format, - uarch, - UKernelConfig{ - /*preferred_alignment*/ op::get_preferred_alignement(), - /*nr*/ n_step, - /*weight_packing_config*/ - get_weight_packing_config_kleidi(), - /*linear_configs*/ - {{get_linear_config_kleidi< - op::matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod, - /*m_step*/ 1, - /*mr*/ 1, - n_step, - nr, - kr, - sr>()}}}); + /*m_step=1*/ + uk.linear_configs[0] = get_linear_config_kleidi< + op::matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod>( + uk.n_step, uk.nr, uk.kr, uk.sr); + table.register_ukernel_config(format, uarch, std::move(uk)); return; } } if (format.nr == 4 && format.kr == 16 && format.sr == 2) { - constexpr int nr = 4; - constexpr int kr = 16; - constexpr int sr = 2; + uk.n_step = 4; if (cpuinfo_has_arm_neon_dot()) { - constexpr int n_step = 4; + /*m_step=1*/ + uk.linear_configs[0] = get_linear_config_kleidi< + op::matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod>( + uk.n_step, uk.nr, uk.kr, uk.sr); + log_registration( format, "kleidiai: matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod"); - table.register_ukernel_config( - format, - uarch, - UKernelConfig{ - /*preferred_alignment*/ op::get_preferred_alignement(), - /*nr*/ n_step, - /*weight_packing_config*/ - get_weight_packing_config_kleidi(), - /*linear_configs*/ - {{get_linear_config_kleidi< - op::matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod, - /*m_step*/ 1, - /*mr*/ 1, - n_step, - nr, - kr, - sr>()}}}); + table.register_ukernel_config(format, uarch, std::move(uk)); return; } } @@ -361,7 +331,7 @@ PackedWeightsFormat select_packed_weights_format( torchao::ops::PackedWeightsType::kleidi_ai, weight_nbit, has_weight_zeros, - /*has_bias*/ true, + has_bias, /*nr*/ 8, /*kr*/ 16, /*sr*/ 2); diff --git a/torchao/experimental/ops/linear_8bit_act_xbit_weight/linear_8bit_act_xbit_weight.cpp b/torchao/experimental/ops/linear_8bit_act_xbit_weight/linear_8bit_act_xbit_weight.cpp index 0421e6a25f..6929e6e4a4 100644 --- a/torchao/experimental/ops/linear_8bit_act_xbit_weight/linear_8bit_act_xbit_weight.cpp +++ b/torchao/experimental/ops/linear_8bit_act_xbit_weight/linear_8bit_act_xbit_weight.cpp @@ -7,43 +7,19 @@ #include #include #include +#include #include #include #include #include +#include namespace torchao::ops::linear_8bit_act_xbit_weight { -PackWeightDataTilingParams get_default_pack_weight_data_tiling_params( - const UKernelConfig& ukernel_config, - int n, - int target_panels_per_thread) { - TORCHAO_CHECK(n >= 1, "n must be >= 1"); - TORCHAO_CHECK( - target_panels_per_thread >= 1, "target_panels_per_thread must be >= 1"); - - PackWeightDataTilingParams tiling_params; - int nr = ukernel_config.nr; - int num_threads = torchao::get_num_threads(); - int numerator = n; - int denominator = num_threads * target_panels_per_thread; - - // Set nc = ceil(numerator / denominator) - int nc = (numerator + denominator - 1) / denominator; - assert(nc >= 1); - - // Replace nc with the next number nr divides - nc = ((nc + nr - 1) / nr) * nr; - tiling_params.nc_by_nr = nc / nr; - - return tiling_params; -} - -void pack_weight_data_operator( - const UKernelConfig& ukernel_config, - const PackWeightDataTilingParams& tiling_params, +void pack_weights_operator( + const UKernelConfig& uk, // Outputs - void* weight_data, + void* packed_weights, // Inputs int n, int k, @@ -54,12 +30,14 @@ void pack_weight_data_operator( const float* bias) { TORCHAO_CHECK(group_size % 16 == 0, "group_size must be a multiple of 16"); TORCHAO_CHECK(k % group_size == 0, "group_size must divide k"); + TORCHAO_CHECK( + uk.has_bias == (bias != nullptr), "bias/has_bias is inconsistent"); + TORCHAO_CHECK( + uk.has_weight_zeros == (weight_zeros != nullptr), + "weight_zeros/has_weight_zeros is inconsistent"); - bool has_weight_zeros = (weight_zeros != nullptr); - bool has_bias = (bias != nullptr); - - int nr = ukernel_config.nr; - int nc = std::min(n, tiling_params.nc_by_nr * ukernel_config.nr); + int n_step = uk.n_step; + int nc = std::min(n, n_step); int num_nc_panels = (n + nc - 1) / nc; torchao::parallel_1d(0, num_nc_panels, [&](int64_t idx) { @@ -67,50 +45,53 @@ void pack_weight_data_operator( int n_idx = nc_tile_idx * nc; int nc_tile_size = std::min(nc, n - n_idx); - int weight_data_offset = (n_idx / nr) * - ukernel_config.weight_packing_config.weight_data_size_fn( - nr, k, group_size, has_weight_zeros, has_bias); + auto packed_weights_offset = uk.packed_weights_offset( + n_idx, + k, + group_size, + uk.weight_nbit, + uk.has_weight_zeros, + uk.has_bias, + uk.nr, + uk.kr, + uk.sr); + int weight_qvals_offset = n_idx * k; int weight_scales_and_zeros_offset = (n_idx * k / group_size); - - const int8_t* weight_zeros_ptr = nullptr; - if (weight_zeros != nullptr) { - weight_zeros_ptr = weight_zeros + weight_scales_and_zeros_offset; - } - const float* bias_ptr = nullptr; - if (bias != nullptr) { - bias_ptr = bias + n_idx; - } - - ukernel_config.weight_packing_config.prepare_weight_data_fn( - (char*)weight_data + weight_data_offset, + uk.pack_weights( + (char*)packed_weights + packed_weights_offset, /*n=*/nc_tile_size, k, group_size, weight_qvals + weight_qvals_offset, weight_scales + weight_scales_and_zeros_offset, - weight_zeros_ptr, - bias_ptr); + (weight_zeros == nullptr) + ? nullptr + : (weight_zeros + weight_scales_and_zeros_offset), + (bias == nullptr) ? nullptr : (bias + n_idx), + uk.nr, + uk.kr, + uk.sr); }); } -// This default mimics XNNPACK behavior if target_tiles_per_thread = 5 -LinearTilingParams get_default_linear_tiling_params( - const UKernelConfig& ukernel_config, +LinearTilingParams LinearTilingParams::from_target_tiles_per_thread( int m, + int m_step, int n, + int n_step, int target_tiles_per_thread) { TORCHAO_CHECK(m >= 1, "m must be >= 1"); + TORCHAO_CHECK(m_step >= 1, "m_step must be >= 1"); + TORCHAO_CHECK(n >= 1, "n must be >= 1"); + TORCHAO_CHECK(n_step >= 1, "n_step must be >= 1"); TORCHAO_CHECK( target_tiles_per_thread >= 1, "target_tiles_per_thread must be >= 1"); - - LinearTilingParams tiling_params; auto num_threads = torchao::get_num_threads(); TORCHAO_CHECK(num_threads >= 1, "num_threads must be >= 1"); - tiling_params.mc_by_mr = 1; - int mc = tiling_params.mc_by_mr * ukernel_config.linear_configs[0].mr; + int mc = m_step; int num_mc_panels = (m + mc - 1) / mc; int numerator = n * num_mc_panels; @@ -120,50 +101,25 @@ LinearTilingParams get_default_linear_tiling_params( int nc = (numerator + denominator - 1) / denominator; assert(nc >= 1); - // Replace nc with next number nr divides - int nr = ukernel_config.nr; - nc = ((nc + nr - 1) / nr) * nr; - assert(nc % nr == 0); - tiling_params.nc_by_nr = nc / nr; + // Replace nc with next number n_step divides + nc = ((nc + n_step - 1) / n_step) * n_step; - assert(tiling_params.mc_by_mr >= 1); - assert(tiling_params.nc_by_nr >= 1); - return tiling_params; -} - -namespace internal { + // Clamp mc, nc to be no larger than m, n + mc = std::min(m, mc); + nc = std::min(n, nc); -inline size_t -get_activation_data_buffer_size_with_tile_schedule_policy_single_mc_parallel_nc( - const UKernelConfig& ukernel_config, - const LinearTilingParams& tiling_params, - int m, - int k, - int group_size, - bool has_weight_zeros) { - return ukernel_config.linear_configs[0].activation_data_size_fn( - tiling_params.mc_by_mr * ukernel_config.linear_configs[0].mr, - k, - group_size, - has_weight_zeros); -} + assert((mc == m) || (mc % m_step == 0)); + assert((nc == n) || (nc % n_step == 0)); -inline size_t -get_activation_data_buffer_size_with_tile_schedule_policy_parallel_mc_parallel_nc( - const UKernelConfig& ukernel_config, - const LinearTilingParams& tiling_params, - int m, - int k, - int group_size, - bool has_weight_zeros) { - return ukernel_config.linear_configs[0].activation_data_size_fn( - m, k, group_size, has_weight_zeros); + LinearTilingParams tiling_params; + tiling_params.mc = mc; + tiling_params.nc = nc; + return tiling_params; } -inline void linear_operator_with_tile_schedule_policy_single_mc_parallel_nc( - const UKernelConfig& ukernel_config, - const LinearTilingParams& tiling_params, - char* activation_data_buffer, +void linear_operator( + const UKernelConfig& uk, + const std::optional& tiling_params, // Outputs float* output, // Inputs @@ -171,237 +127,101 @@ inline void linear_operator_with_tile_schedule_policy_single_mc_parallel_nc( int n, int k, int group_size, - const void* weight_data, + const void* packed_weights, const float* activations, - // Ignored if has_clamp = false + bool has_clamp, float clamp_min, - float clamp_max, - bool has_weight_zeros, - bool has_bias, - bool has_clamp) { - int nr = ukernel_config.nr; - int mc = - std::min(m, tiling_params.mc_by_mr * ukernel_config.linear_configs[0].mr); - int nc = std::min(n, tiling_params.nc_by_nr * nr); + float clamp_max) { + TORCHAO_CHECK(group_size % 16 == 0, "group_size must be a multiple of 16"); + TORCHAO_CHECK(k % group_size == 0, "group_size must divide k"); + + // Select linear config based on m + int linear_config_idx = uk.select_linear_config_idx(m); + auto& linear_config = uk.linear_configs[linear_config_idx]; + int n_step = uk.n_step; + int m_step = linear_config.m_step; + + // Choose tiling params + int mc, nc; + if (tiling_params.has_value()) { + mc = tiling_params->mc; + nc = tiling_params->nc; + } else { + auto params = LinearTilingParams::from_target_tiles_per_thread( + m, + m_step, + n, + n_step, + /*target_tiles_per_thread=*/5); + mc = params.mc; + nc = params.nc; + } + TORCHAO_CHECK(mc >= 1, "mc must be >= 1"); + TORCHAO_CHECK(nc >= 1, "nc must be >= 1"); + TORCHAO_CHECK( + (mc == m) || (mc % m_step == 0), + "mc from tiling_params must be m or a multiple of m_step"); + TORCHAO_CHECK( + (nc == n) || (nc % n_step == 0), + "nc from tiling_params must be n or a multiple of n_step"); + int num_mc_panels = (m + mc - 1) / mc; int num_nc_panels = (n + nc - 1) / nc; - size_t weight_data_size = - ukernel_config.weight_packing_config.weight_data_size_fn( - nr, k, group_size, has_weight_zeros, has_bias); + auto packed_activations_size = linear_config.packed_activations_size( + mc, k, group_size, uk.has_weight_zeros, linear_config.mr, uk.kr, uk.sr); + + auto packed_activations = torchao::make_aligned_byte_ptr( + uk.preferred_alignment, packed_activations_size); for (int mc_tile_idx = 0; mc_tile_idx < num_mc_panels; mc_tile_idx++) { int m_idx = mc_tile_idx * mc; int mc_tile_size = std::min(mc, m - m_idx); int activations_offset = m_idx * k; - ukernel_config.linear_configs[0].prepare_activation_data_fn( - activation_data_buffer, + + linear_config.pack_activations( + packed_activations.get(), /*m=*/mc_tile_size, k, group_size, activations + activations_offset, - has_weight_zeros); + uk.has_weight_zeros, + linear_config.mr, + uk.kr, + uk.sr); torchao::parallel_1d(0, num_nc_panels, [&](int64_t idx) { int nc_tile_idx = idx; int n_idx = nc_tile_idx * nc; int nc_tile_size = std::min(nc, n - n_idx); - int output_offset = m_idx * n + n_idx; - int weight_data_offset = (n_idx / nr) * weight_data_size; - ukernel_config.linear_configs[0].kernel_fn( + auto packed_weights_offset = uk.packed_weights_offset( + n_idx, + k, + group_size, + uk.weight_nbit, + uk.has_weight_zeros, + uk.has_bias, + uk.nr, + uk.kr, + uk.sr); + + linear_config.kernel( output + output_offset, /*output_m_stride=*/n, /*m=*/mc_tile_size, /*n=*/nc_tile_size, k, group_size, - /*weight_data=*/(char*)weight_data + weight_data_offset, - /*activation_data=*/activation_data_buffer, + /*packed_weights=*/(char*)packed_weights + packed_weights_offset, + /*packed_activations=*/packed_activations.get(), clamp_min, clamp_max, - has_weight_zeros, - has_bias, + uk.has_weight_zeros, + uk.has_bias, has_clamp); }); } } -inline void linear_operator_with_tile_schedule_policy_parallel_mc_parallel_nc( - const UKernelConfig& ukernel_config, - const LinearTilingParams& tiling_params, - char* activation_data_buffer, - // Outputs - float* output, - // Inputs - int m, - int n, - int k, - int group_size, - const void* weight_data, - const float* activations, - float clamp_min, - float clamp_max, - bool has_weight_zeros, - bool has_bias, - bool has_clamp) { - int mr = ukernel_config.linear_configs[0].mr; - int nr = ukernel_config.nr; - int mc = std::min(m, tiling_params.mc_by_mr * mr); - int nc = std::min(n, tiling_params.nc_by_nr * nr); - int num_mc_panels = (m + mc - 1) / mc; - int num_nc_panels = (n + nc - 1) / nc; - - size_t weight_data_size = - ukernel_config.weight_packing_config.weight_data_size_fn( - nr, k, group_size, has_weight_zeros, has_bias); - size_t activation_data_size = - ukernel_config.linear_configs[0].activation_data_size_fn( - mr, k, group_size, has_weight_zeros); - - torchao::parallel_1d(0, num_mc_panels, [&](int64_t idx) { - int mc_tile_idx = idx; - int m_idx = mc_tile_idx * mc; - int mc_tile_size = std::min(mc, m - m_idx); - int activations_offset = m_idx * k; - int activation_data_offset = (m_idx / mr) * activation_data_size; - - ukernel_config.linear_configs[0].prepare_activation_data_fn( - activation_data_buffer + activation_data_offset, - /*m=*/mc_tile_size, - k, - group_size, - activations + activations_offset, - has_weight_zeros); - }); - - torchao::parallel_1d(0, num_mc_panels * num_nc_panels, [&](int64_t idx) { - int mc_tile_idx = idx / num_nc_panels; - int m_idx = mc_tile_idx * mc; - int mc_tile_size = std::min(mc, m - m_idx); - - int nc_tile_idx = idx % num_nc_panels; - int n_idx = nc_tile_idx * nc; - int nc_tile_size = std::min(nc, n - n_idx); - - int activation_data_offset = (m_idx / mr) * activation_data_size; - int output_offset = m_idx * n + n_idx; - int weight_data_offset = (n_idx / nr) * weight_data_size; - - ukernel_config.linear_configs[0].kernel_fn( - output + output_offset, - /*output_m_stride=*/n, - /*m=*/mc_tile_size, - /*n=*/nc_tile_size, - k, - group_size, - /*weight_data=*/(char*)weight_data + weight_data_offset, - /*activation_data=*/activation_data_buffer + activation_data_offset, - clamp_min, - clamp_max, - has_weight_zeros, - has_bias, - has_clamp); - }); -} -} // namespace internal - -void linear_operator( - const UKernelConfig& ukernel_config, - const LinearTilingParams& tiling_params, - LinearTileSchedulingPolicy scheduling_policy, - char* activation_data_buffer, - // Outputs - float* output, - // Inputs - int m, - int n, - int k, - int group_size, - const void* weight_data, - const float* activations, - // Ignored if has_clamp = false - float clamp_min, - float clamp_max, - bool has_weight_zeros, - bool has_bias, - bool has_clamp) { - TORCHAO_CHECK(group_size % 16 == 0, "group_size must be a multiple of 16"); - TORCHAO_CHECK(k % group_size == 0, "group_size must divide k"); - switch (scheduling_policy) { - case LinearTileSchedulingPolicy::single_mc_parallel_nc: - internal::linear_operator_with_tile_schedule_policy_single_mc_parallel_nc( - ukernel_config, - tiling_params, - activation_data_buffer, - output, - m, - n, - k, - group_size, - weight_data, - activations, - clamp_min, - clamp_max, - has_weight_zeros, - has_bias, - has_clamp); - break; - case LinearTileSchedulingPolicy::parallel_mc_parallel_nc: - internal:: - linear_operator_with_tile_schedule_policy_parallel_mc_parallel_nc( - ukernel_config, - tiling_params, - activation_data_buffer, - output, - m, - n, - k, - group_size, - weight_data, - activations, - clamp_min, - clamp_max, - has_weight_zeros, - has_bias, - has_clamp); - break; - default: - TORCHAO_CHECK(false, "Unimplemented LinearTileSchedulingPolicy"); - } -} - -size_t get_activation_data_buffer_size( - const UKernelConfig& ukernel_config, - const LinearTilingParams& tiling_params, - LinearTileSchedulingPolicy scheduling_policy, - int m, - int k, - int group_size, - bool has_weight_zeros) { - switch (scheduling_policy) { - case LinearTileSchedulingPolicy::single_mc_parallel_nc: - return internal:: - get_activation_data_buffer_size_with_tile_schedule_policy_single_mc_parallel_nc( - ukernel_config, - tiling_params, - m, - k, - group_size, - has_weight_zeros); - case LinearTileSchedulingPolicy::parallel_mc_parallel_nc: - return internal:: - get_activation_data_buffer_size_with_tile_schedule_policy_parallel_mc_parallel_nc( - ukernel_config, - tiling_params, - m, - k, - group_size, - has_weight_zeros); - default: - TORCHAO_CHECK(false, "Unimplemented LinearTileSchedulingPolicy"); - } -} - -} // namespace - // torchao::ops::linear_8bit_act_xbit_weight +} // namespace torchao::ops::linear_8bit_act_xbit_weight diff --git a/torchao/experimental/ops/linear_8bit_act_xbit_weight/linear_8bit_act_xbit_weight.h b/torchao/experimental/ops/linear_8bit_act_xbit_weight/linear_8bit_act_xbit_weight.h index dba0adb32d..accc5be5a1 100644 --- a/torchao/experimental/ops/linear_8bit_act_xbit_weight/linear_8bit_act_xbit_weight.h +++ b/torchao/experimental/ops/linear_8bit_act_xbit_weight/linear_8bit_act_xbit_weight.h @@ -7,102 +7,17 @@ #pragma once #include #include +#include #include #include +#include namespace torchao::ops::linear_8bit_act_xbit_weight { -struct UKernelConfig { - using activation_data_size_fn_type = - size_t (*)(int m, int k, int group_size, bool has_weight_zeros); - using prepare_activation_data_fn_type = void (*)( - void* activation_data, - int m, - int k, - int group_size, - const float* activations, - bool has_weight_zeros); - using weight_data_size_fn_type = size_t (*)( - int n, - int k, - int group_size, - bool has_weight_zeros, - bool has_bias); - using prepare_weight_data_fn_type = void (*)( - void* weight_data, - int n, - int k, - int group_size, - const int8_t* weight_qvals, - const float* weight_scales, - const int8_t* weight_zeros, - const float* bias); - using kernel_fn_type = void (*)( - float* output, - int output_m_stride, - int m, - int n, - int k, - int group_size, - const void* weight_data, - const void* activation_data, - float clamp_min, - float clamp_max, - bool has_weight_zeros, - bool has_bias, - bool has_clamp); - - struct weight_packing_config_type { - weight_data_size_fn_type weight_data_size_fn{nullptr}; - prepare_weight_data_fn_type prepare_weight_data_fn{nullptr}; - }; - struct linear_config_type { - int mr{0}; - activation_data_size_fn_type activation_data_size_fn{nullptr}; - prepare_activation_data_fn_type prepare_activation_data_fn{nullptr}; - kernel_fn_type kernel_fn{nullptr}; - }; - - // preferred_alignment for activation and weight data - // Integration surfaces are not required to respect this alignment, and the - // ukernel must behave correctly no matter how buffers are aligned - size_t preferred_alignment{0}; - int nr{0}; - weight_packing_config_type weight_packing_config; - std::array linear_configs; -}; - -// Pack weight functions -struct PackWeightDataTilingParams { - int nc_by_nr{1}; -}; - -PackWeightDataTilingParams get_default_pack_weight_data_tiling_params( - const UKernelConfig& ukernel_config, - int n, - int target_panels_per_thread = 1); - -inline size_t get_packed_weight_data_size( - const UKernelConfig& ukernel_config, - int n, - int k, - int group_size, - bool has_weight_zeros, - bool has_bias) { - return ukernel_config.weight_packing_config.weight_data_size_fn( - n, k, group_size, has_weight_zeros, has_bias); -} - -inline size_t get_preferred_packed_weight_data_alignment( - const UKernelConfig& ukernel_config) { - return ukernel_config.preferred_alignment; -} - -void pack_weight_data_operator( - const UKernelConfig& ukernel_config, - const PackWeightDataTilingParams& tiling_params, +void pack_weights_operator( + const UKernelConfig& uk, // Outputs - void* weight_data, + void* packed_weights, // Inputs int n, int k, @@ -114,40 +29,23 @@ void pack_weight_data_operator( // Linear functions struct LinearTilingParams { - int mc_by_mr{1}; - int nc_by_nr{1}; -}; - -LinearTilingParams get_default_linear_tiling_params( - const UKernelConfig& ukernel_config, - int m, - int n, - int target_tiles_per_thread = 5); + int mc{0}; + int nc{0}; -enum class LinearTileSchedulingPolicy { - single_mc_parallel_nc, - parallel_mc_parallel_nc + // Returns LinearTilingParams with mc and nc chosen so that there are + // approximately target_tiles_per_thread tiles per thread. The method + // guarantees 1. mc = m or mc % m_step == 0, and 2. nc = n or nc % n_step == 0 + static LinearTilingParams from_target_tiles_per_thread( + int m, + int m_step, + int n, + int n_step, + int target_tiles_per_thread); }; -size_t get_activation_data_buffer_size( - const UKernelConfig& ukernel_config, - const LinearTilingParams& tiling_params, - LinearTileSchedulingPolicy scheduling_policy, - int m, - int k, - int group_size, - bool has_weight_zeros); - -inline size_t get_preferred_activation_data_buffer_alignment( - const UKernelConfig& ukernel_config) { - return ukernel_config.preferred_alignment; -} - void linear_operator( const UKernelConfig& ukernel_config, - const LinearTilingParams& tiling_params, - LinearTileSchedulingPolicy scheduling_policy, - char* activation_data_buffer, + const std::optional& tiling_params, // Outputs float* output, // Inputs @@ -155,13 +53,11 @@ void linear_operator( int n, int k, int group_size, - const void* weight_data, + const void* packed_weights, const float* activations, + bool has_clamp, float clamp_min, - float clamp_max, - bool has_weight_zeros, - bool has_bias, - bool has_clamp); + float clamp_max); } // namespace // torchao::ops::linear_8bit_act_xbit_weight diff --git a/torchao/experimental/ops/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight-impl.h b/torchao/experimental/ops/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight-impl.h index 636fc01c64..065a5b0319 100644 --- a/torchao/experimental/ops/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight-impl.h +++ b/torchao/experimental/ops/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight-impl.h @@ -69,29 +69,31 @@ Tensor pack_weights_cpu( bias_ptr = bias.value().const_data_ptr(); } - using namespace torchao::ops::linear_8bit_act_xbit_weight; - - auto packed_weights_format = select_packed_weights_format( - target, has_weight_zeros, has_bias); + auto packed_weights_format = + torchao::ops::linear_8bit_act_xbit_weight::select_packed_weights_format< + weight_nbit>(target, has_weight_zeros, has_bias); auto packed_weights_header = packed_weights_format.to_packed_weights_header(); - auto ukernel_config = - select_ukernel_config(packed_weights_header); - - auto pack_weight_tiling_params = get_default_pack_weight_data_tiling_params( - ukernel_config, n, /*target_panels_per_thread=*/1); + auto uk = torchao::ops::linear_8bit_act_xbit_weight::select_ukernel_config< + weight_nbit>(packed_weights_header); + + auto packed_weight_data_size = torchao::ops::PackedWeightsHeader::size() + + uk.packed_weights_size( + n, + k, + group_size, + weight_nbit, + has_weight_zeros, + has_bias, + uk.nr, + uk.kr, + uk.sr); - auto packed_weight_data_size = - torchao::ops::PackedWeightsHeader::size() + - get_packed_weight_data_size( - ukernel_config, n, k, group_size, has_weight_zeros, has_bias); Tensor packed_weights = torch::empty( {static_cast(packed_weight_data_size)}, torch::kInt8); packed_weights_header.write(packed_weights.mutable_data_ptr()); - // TODO: support passing in bias in future - pack_weight_data_operator( - ukernel_config, - pack_weight_tiling_params, + torchao::ops::linear_8bit_act_xbit_weight::pack_weights_operator( + uk, packed_weights.mutable_data_ptr() + torchao::ops::PackedWeightsHeader::size(), n, @@ -122,18 +124,26 @@ Tensor pack_weights_meta( int n = weight_qvals.size(0); int k = weight_qvals.size(1); - using namespace torchao::ops::linear_8bit_act_xbit_weight; - - auto packed_weights_format = select_packed_weights_format( - target, has_weight_zeros, has_bias); - auto ukernel_config = - select_ukernel_config(packed_weights_format); - - auto packed_weight_data_size = - torchao::ops::PackedWeightsHeader::size() + - get_packed_weight_data_size( - ukernel_config, n, k, group_size, has_weight_zeros, has_bias); - auto options = torch::TensorOptions().device(c10::DeviceType::Meta).dtype(torch::kInt8); + auto packed_weights_format = + torchao::ops::linear_8bit_act_xbit_weight::select_packed_weights_format< + weight_nbit>(target, has_weight_zeros, has_bias); + auto uk = torchao::ops::linear_8bit_act_xbit_weight::select_ukernel_config< + weight_nbit>(packed_weights_format); + + auto packed_weight_data_size = torchao::ops::PackedWeightsHeader::size() + + uk.packed_weights_size( + n, + k, + group_size, + weight_nbit, + has_weight_zeros, + has_bias, + uk.nr, + uk.kr, + uk.sr); + + auto options = + torch::TensorOptions().device(c10::DeviceType::Meta).dtype(torch::kInt8); return torch::empty({static_cast(packed_weight_data_size)}, options); } #endif // USE_ATEN @@ -169,8 +179,6 @@ Tensor linear_out_cpu( // Explicit cast from int64_t to int is required for Executorch TORCHAO_RESIZE_TENSOR(out, {(int)m, (int)n}); - using namespace torchao::ops::linear_8bit_act_xbit_weight; - TORCHAO_CHECK(packed_weights.dim() == 1, "packed_weights must be 1D"); #ifdef USE_ATEN TORCHAO_CHECK( @@ -182,36 +190,12 @@ Tensor linear_out_cpu( auto header = torchao::ops::PackedWeightsHeader::read(packed_weights.const_data_ptr()); - auto format = torchao::ops::linear_8bit_act_xbit_weight::PackedWeightsFormat:: - from_packed_weights_header(header); - - auto ukernel_config = select_ukernel_config(header); - - auto linear_tiling_params = get_default_linear_tiling_params( - ukernel_config, - m, - n, - /*target_tiles_per_thread=*/5); - - auto linear_scheduling_policy = - LinearTileSchedulingPolicy::single_mc_parallel_nc; - - auto activation_data_buffer_size = get_activation_data_buffer_size( - ukernel_config, - linear_tiling_params, - linear_scheduling_policy, - m, - k, - group_size, - format.has_weight_zeros); - - std::vector activation_data_buffer(activation_data_buffer_size); + auto uk = torchao::ops::linear_8bit_act_xbit_weight::select_ukernel_config< + weight_nbit>(header); - linear_operator( - ukernel_config, - linear_tiling_params, - linear_scheduling_policy, - activation_data_buffer.data(), + torchao::ops::linear_8bit_act_xbit_weight::linear_operator( + uk, + std::nullopt, out.mutable_data_ptr(), m, n, @@ -220,13 +204,9 @@ Tensor linear_out_cpu( packed_weights.const_data_ptr() + torchao::ops::PackedWeightsHeader::size(), activations.const_data_ptr(), - // Clamp parameters are ignored because config is created from - // has_clamp = false + /*has_clamp=*/false, /*clamp_min=*/0.0, - /*clamp_max=*/0.0, - format.has_weight_zeros, - format.has_bias, - /*has_clamp*/ false); + /*clamp_max=*/0.0); return out; } diff --git a/torchao/experimental/ops/tests/test_linear_8bit_act_xbit_weight.cpp b/torchao/experimental/ops/tests/test_linear_8bit_act_xbit_weight.cpp index caaf8baf74..980228a1a8 100644 --- a/torchao/experimental/ops/tests/test_linear_8bit_act_xbit_weight.cpp +++ b/torchao/experimental/ops/tests/test_linear_8bit_act_xbit_weight.cpp @@ -6,7 +6,9 @@ #include // TODO: move test_utils.h out of aarch64 -#include +#if defined(TORCHAO_BUILD_CPU_AARCH64) +#include +#endif // TORCHAO_BUILD_CPU_AARCH64 #include #include #include @@ -26,23 +28,41 @@ using namespace torchao::ops::linear_8bit_act_xbit_weight; template UKernelConfig get_ukernel_config() { namespace kernel = torchao::kernels::cpu::aarch64::linear:: - channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot; - return UKernelConfig{ - /*preferred_alignment*/ 16, - /*nr*/ 8, - /*weight_packing_config*/ - {/*weight_data_size_fn*/ - &kernel::weight_data_size, - /*prepare_weight_data_fn*/ - &kernel::prepare_weight_data}, - /*linear_configs*/ - {{{/*mr*/ 1, - /*activation_data_size_fn*/ - &kernel::activation_data_size, - /*prepare_activation_data_fn*/ - &kernel::prepare_activation_data, - /*kernel*/ - &kernel::kernel}}}}; + channelwise_8bit_activation_groupwise_lowbit_weight; + + int preferred_alignment = 16; + int n_step = 8; + constexpr int nr = 8; + constexpr int kr = 16; + constexpr int sr = 2; + constexpr int mr = 1; + int m_step = 1; + constexpr bool has_lut = false; + + auto uk = UKernelConfig::make( + preferred_alignment, + n_step, + nr, + kr, + sr, + weight_nbit, + has_weight_zeros, + has_bias, + &kernel::packed_weights_size, + &kernel::packed_weights_offset, + &kernel::pack_weights, + /*linear_configs*/ {}); + + uk.linear_configs[0] = UKernelConfig::linear_config_type{ + m_step, + mr, + &kernel::packed_activations_size, + &kernel::packed_activations_offset, + &kernel::pack_activations, + &kernel:: + kernel_1x8x16_f32_neondot}; + + return uk; } template < @@ -82,87 +102,68 @@ void test_linear_8bit_act_xbit_weight( auto output = std::vector(m * n); - for (auto linear_scheduling_policy : - {LinearTileSchedulingPolicy::single_mc_parallel_nc, - LinearTileSchedulingPolicy::parallel_mc_parallel_nc}) { - for (auto num_threads : {1, 4, 500}) { - torchao::set_num_threads(num_threads); - EXPECT_EQ(torchao::get_num_threads(), num_threads); - - // Pack weights - auto pack_weight_data_tiling_params = - get_default_pack_weight_data_tiling_params(ukernel_config, n); - auto packed_weight_data_size = get_packed_weight_data_size( - ukernel_config, n, k, group_size, has_weight_zeros, has_bias); - auto preferred_packed_weight_data_alignment = - get_preferred_packed_weight_data_alignment(ukernel_config); - auto packed_weight_data = torchao::make_aligned_byte_ptr( - preferred_packed_weight_data_alignment, packed_weight_data_size); - - int8_t* weight_zeros_ptr = nullptr; - if (has_weight_zeros) { - weight_zeros_ptr = test_case.weight_zeros.data(); - } - float* bias_ptr = nullptr; - if (has_bias) { - bias_ptr = test_case.bias.data(); - } - pack_weight_data_operator( - ukernel_config, - pack_weight_data_tiling_params, - packed_weight_data.get(), - n, - k, - group_size, - test_case.weight_qvals.data(), - test_case.weight_scales.data(), - weight_zeros_ptr, - bias_ptr); - - // Allocate activation buffer - auto linear_tiling_params = - get_default_linear_tiling_params(ukernel_config, m, n); - - auto activation_data_buffer_size = get_activation_data_buffer_size( - ukernel_config, - linear_tiling_params, - linear_scheduling_policy, - m, - k, - group_size, - has_weight_zeros); - auto activation_data_buffer_alignment = - get_preferred_activation_data_buffer_alignment(ukernel_config); - auto activation_data_buffer = torchao::make_aligned_byte_ptr( - activation_data_buffer_alignment, activation_data_buffer_size); - - // Run linear - linear_operator( - ukernel_config, - linear_tiling_params, - linear_scheduling_policy, - activation_data_buffer.get(), - output.data(), - m, - n, - k, - group_size, - packed_weight_data.get(), - test_case.activations.data(), - test_case.clamp_min, - test_case.clamp_max, - has_weight_zeros, - has_bias, - has_clamp); - - // Test correctness - float tol = kTol; - if (has_kleidi) { - tol = kTolKleidiAI; - } - for (int i = 0; i < m * n; i++) { - EXPECT_NEAR(output[i], test_case.expected_output[i], tol); - } + for (auto num_threads : {1, 4, 500}) { + torchao::set_num_threads(num_threads); + EXPECT_EQ(torchao::get_num_threads(), num_threads); + + // Pack weights + auto packed_weight_data_size = ukernel_config.packed_weights_size( + n, + k, + group_size, + weight_nbit, + has_weight_zeros, + has_bias, + ukernel_config.nr, + ukernel_config.kr, + ukernel_config.sr); + auto preferred_packed_weight_data_alignment = + ukernel_config.preferred_alignment; + auto packed_weights = torchao::make_aligned_byte_ptr( + preferred_packed_weight_data_alignment, packed_weight_data_size); + + int8_t* weight_zeros_ptr = nullptr; + if (has_weight_zeros) { + weight_zeros_ptr = test_case.weight_zeros.data(); + } + float* bias_ptr = nullptr; + // kleidi always has bias in these tests + if (has_bias || has_kleidi) { + bias_ptr = test_case.bias.data(); + } + + pack_weights_operator( + ukernel_config, + packed_weights.get(), + n, + k, + group_size, + test_case.weight_qvals.data(), + test_case.weight_scales.data(), + weight_zeros_ptr, + bias_ptr); + + linear_operator( + ukernel_config, + std::nullopt, + output.data(), + m, + n, + k, + group_size, + packed_weights.get(), + test_case.activations.data(), + has_clamp, + test_case.clamp_min, + test_case.clamp_max); + + // Test correctness + float tol = kTol; + if (has_kleidi) { + tol = kTolKleidiAI; + } + for (int i = 0; i < m * n; i++) { + EXPECT_NEAR(output[i], test_case.expected_output[i], tol); } } } @@ -176,102 +177,56 @@ enum kai_kernel_id { i8mm_8x4x32 }; -template < - typename kernel_struct, - int m_step, - int mr, - int n_step, - int nr, - int kr, - int sr> -UKernelConfig get_ukernel_config_kleidi() { +template +UKernelConfig get_ukernel_config_kleidi_impl() { namespace op = torchao::kernels::cpu::aarch64::kleidi:: kai_matmul_clamp_f32_qai8dxp_qsi4c32p; + auto uk = kernel_struct::get_ukernel(); - assert(m_step == uk.get_m_step()); - assert(mr == uk.get_mr()); - assert(n_step == uk.get_n_step()); - assert(nr == uk.get_nr()); - assert(kr == uk.get_kr()); - assert(sr == uk.get_sr()); - return UKernelConfig{ + auto ukernel_config = UKernelConfig::make( op::get_preferred_alignement(), - n_step, - {/*weight_data_size_fn*/ &op::weight_data_size, - /*prepare_weight_data_fn*/ &op::prepare_weight_data}, - {{{m_step, - &op::activation_data_size, - &op::prepare_activation_data, - &kernel_struct::kernel}}}}; + uk.get_n_step(), + uk.get_nr(), + uk.get_kr(), + uk.get_sr(), + /*weight_nbit*/ 4, + /*has_weight_zeros*/ false, + /*has_bias*/ true, + &op::packed_weights_size, + &op::packed_weights_offset, + &op::pack_weights, + /*linear_configs*/ {}); + + ukernel_config.linear_configs[0] = UKernelConfig::linear_config_type{ + static_cast(uk.get_m_step()), + static_cast(uk.get_mr()), + &op::packed_activations_size, + &op::packed_activations_offset, + &op::pack_activations, + &kernel_struct::kernel}; + + return ukernel_config; } template UKernelConfig get_ukernel_config_kleidi() { #if defined(TORCHAO_ENABLE_ARM_I8MM) if constexpr (kernel_id == i8mm_4x8x32) { - constexpr int m_step = 4; - constexpr int mr = 4; - constexpr int n_step = 8; - constexpr int nr = 8; - constexpr int kr = 16; - constexpr int sr = 2; - return get_ukernel_config_kleidi< - matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm, - m_step, - mr, - n_step, - nr, - kr, - sr>(); + return get_ukernel_config_kleidi_impl< + matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm>(); } if constexpr (kernel_id == i8mm_8x4x32) { - constexpr int m_step = 8; - constexpr int mr = 8; - constexpr int n_step = 4; - constexpr int nr = 4; - constexpr int kr = 16; - constexpr int sr = 2; - return get_ukernel_config_kleidi< - matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm, - m_step, - mr, - n_step, - nr, - kr, - sr>(); + return get_ukernel_config_kleidi_impl< + matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm>(); } #endif // TORCHAO_ENABLE_ARM_I8MM if constexpr (kernel_id == dotprod_1x8x32) { - constexpr int m_step = 1; - constexpr int mr = 1; - constexpr int n_step = 8; - constexpr int nr = 8; - constexpr int kr = 16; - constexpr int sr = 2; - return get_ukernel_config_kleidi< - matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod, - m_step, - mr, - n_step, - nr, - kr, - sr>(); + return get_ukernel_config_kleidi_impl< + matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod>(); } if constexpr (kernel_id == dotprod_1x4x32) { - constexpr int m_step = 1; - constexpr int mr = 1; - constexpr int n_step = 4; - constexpr int nr = 4; - constexpr int kr = 16; - constexpr int sr = 2; - return get_ukernel_config_kleidi< - matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod, - m_step, - mr, - n_step, - nr, - kr, - sr>(); + return get_ukernel_config_kleidi_impl< + matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod>(); } throw std::runtime_error("Unsupported kernel_id"); } @@ -332,15 +287,11 @@ TEST(test_linear_8bit_act_xbit_weight, KNotDivisibleByGroupSize) { true /*has_weight_zeros*/, true /*has_bias*/, true /*has_clamp*/>(); - auto pack_weight_data_tiling_params = - get_default_pack_weight_data_tiling_params(ukernel_config, n); - EXPECT_THROW( { - pack_weight_data_operator( + pack_weights_operator( ukernel_config, - pack_weight_data_tiling_params, - /*packed_weight_data=*/nullptr, + /*packed_weights=*/nullptr, n, k, group_size, @@ -362,15 +313,12 @@ TEST(test_linear_8bit_act_xbit_weight, GroupSizeNotDivisibleBy16) { true /*has_weight_zeros*/, true /*has_bias*/, true /*has_clamp*/>(); - auto pack_weight_data_tiling_params = - get_default_pack_weight_data_tiling_params(ukernel_config, n); EXPECT_THROW( { - pack_weight_data_operator( + pack_weights_operator( ukernel_config, - pack_weight_data_tiling_params, - /*packed_weight_data=*/nullptr, + /*packed_weights=*/nullptr, n, k, group_size, From 8776dd32259eb088e9211d7453ffd77148a102f4 Mon Sep 17 00:00:00 2001 From: Kimish Patel Date: Wed, 2 Apr 2025 09:51:27 -0700 Subject: [PATCH 03/19] quantized matmul Differential Revision: D71370592 Pull Request resolved: https://github.com/pytorch/ao/pull/1994 --- ...hannelwise_8bit_b_1x16x16_f32_smlal-impl.h | 384 ++++++++++++++++++ ...annelwise_8bit_b_1x8x16_f32_neondot-impl.h | 336 +++++++++++++++ .../kernels/cpu/aarch64/matmul/matmul.h | 74 ++++ .../kernels/cpu/aarch64/matmul/matmul_utils.h | 70 ++++ .../cpu/aarch64/quantization/quantize.cpp | 23 +- .../kernels/cpu/aarch64/tests/CMakeLists.txt | 9 + .../cpu/aarch64/tests/build_and_run_tests.sh | 1 + .../cpu/aarch64/tests/test_qmatmul.cpp | 229 +++++++++++ .../kernels/cpu/aarch64/tests/test_utils.h | 229 +++++++++-- 9 files changed, 1324 insertions(+), 31 deletions(-) create mode 100644 torchao/experimental/kernels/cpu/aarch64/matmul/channelwise_8bit_a_channelwise_8bit_b_1x16x16_f32_smlal-impl.h create mode 100644 torchao/experimental/kernels/cpu/aarch64/matmul/channelwise_8bit_a_channelwise_8bit_b_1x8x16_f32_neondot-impl.h create mode 100644 torchao/experimental/kernels/cpu/aarch64/matmul/matmul.h create mode 100644 torchao/experimental/kernels/cpu/aarch64/matmul/matmul_utils.h create mode 100644 torchao/experimental/kernels/cpu/aarch64/tests/test_qmatmul.cpp diff --git a/torchao/experimental/kernels/cpu/aarch64/matmul/channelwise_8bit_a_channelwise_8bit_b_1x16x16_f32_smlal-impl.h b/torchao/experimental/kernels/cpu/aarch64/matmul/channelwise_8bit_a_channelwise_8bit_b_1x16x16_f32_smlal-impl.h new file mode 100644 index 0000000000..b83c28143f --- /dev/null +++ b/torchao/experimental/kernels/cpu/aarch64/matmul/channelwise_8bit_a_channelwise_8bit_b_1x16x16_f32_smlal-impl.h @@ -0,0 +1,384 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the license found in the +// LICENSE file in the root directory of this source tree. + +#pragma once + +#if defined(__aarch64__) || defined(__ARM_NEON) + +#include +#include +#include + +#include +#include +#include + +namespace torchao::kernels::cpu::aarch64::quantized_matmul { +namespace channelwise_8bit_a_channelwise_8bit_b_1x16x16_f32_smlal::internal { + +namespace { +/* +This function loads int8x16_t value from a, and 8 int8x16_t values from b. +For each int8x16_t of b: +- subl to subtarct a_zero_point from a, to get a_low, a_high +- 4 int32x4 accumulated values +- for i in [0, 8]: + - load b[i] + - subl to subtarct b_zero_point from b, to get b_low, b_high + - smlal_lane to multiply a_low[i] and b_low_low. + - smlal_lane to multiply a_low[i] and b_low_high. + - smlal_lane to multiply a_low[i] and b_high_low. + - smlal_lane to multiply a_low[i] and b_high_high. + - This produces 2 int32x4_t values +- for i in [0, 8]: + - load b[i] + - subl to subtarct b_zero_point from b, to get b_low, b_high + - smlal_lane to multiply a_low[i] and b_low_low. + - smlal_lane to multiply a_low[i] and b_low_high. + - smlal_lane to multiply a_low[i] and b_high_low. + - smlal_lane to multiply a_low[i] and b_high_high. + - This produces 2 int32x4_t values +Possibly better to transpose 16x16 of b and use dotprod. Left for future. +*/ + +template +TORCHAO_ALWAYS_INLINE void block_mul_1x16x1( + const int16x4_t& a_vec, + const int8x16_t& b_vec, + const int8x16_t& b_zero_point_vec, + int32x4_t (&partial_sums)[4]) { + int16x8_t b_vec_low = + vsubl_s8(vget_low_s8(b_vec), vget_low_s8(b_zero_point_vec)); + int16x8_t b_vec_high = + vsubl_s8(vget_high_s8(b_vec), vget_high_s8(b_zero_point_vec)); + partial_sums[0] = + vmlal_lane_s16(partial_sums[0], vget_low_s16(b_vec_low), a_vec, lane); + partial_sums[1] = + vmlal_lane_s16(partial_sums[1], vget_high_s16(b_vec_low), a_vec, lane); + partial_sums[2] = + vmlal_lane_s16(partial_sums[2], vget_low_s16(b_vec_high), a_vec, lane); + partial_sums[3] = + vmlal_lane_s16(partial_sums[3], vget_high_s16(b_vec_high), a_vec, lane); +} + +void block_mul_1x16x16( + const int8_t* a, + const int8_t* b, + const size_t ldb, + const int8_t a_zero_point, + const int8_t* b_zero_point, + int32x4_t (&partial_sums)[4]) { + int8x16_t a_vec = vld1q_s8(a); + int8x8_t a_zero_point_vec = vdup_n_s8(a_zero_point); + int8x16_t b_zero_point_vec = vld1q_s8(b_zero_point); + int16x8_t a_vec_low = vsubl_s8(vget_low_s8(a_vec), a_zero_point_vec); + int16x8_t a_vec_high = vsubl_s8(vget_high_s8(a_vec), a_zero_point_vec); + + int8x16_t b_vec = vld1q_s8(b + 0 * ldb); + block_mul_1x16x1<0>( + vget_low_s16(a_vec_low), b_vec, b_zero_point_vec, partial_sums); + b_vec = vld1q_s8(b + 1 * ldb); + block_mul_1x16x1<1>( + vget_low_s16(a_vec_low), b_vec, b_zero_point_vec, partial_sums); + b_vec = vld1q_s8(b + 2 * ldb); + block_mul_1x16x1<2>( + vget_low_s16(a_vec_low), b_vec, b_zero_point_vec, partial_sums); + b_vec = vld1q_s8(b + 3 * ldb); + block_mul_1x16x1<3>( + vget_low_s16(a_vec_low), b_vec, b_zero_point_vec, partial_sums); + b_vec = vld1q_s8(b + 4 * ldb); + block_mul_1x16x1<0>( + vget_high_s16(a_vec_low), b_vec, b_zero_point_vec, partial_sums); + b_vec = vld1q_s8(b + 5 * ldb); + block_mul_1x16x1<1>( + vget_high_s16(a_vec_low), b_vec, b_zero_point_vec, partial_sums); + b_vec = vld1q_s8(b + 6 * ldb); + block_mul_1x16x1<2>( + vget_high_s16(a_vec_low), b_vec, b_zero_point_vec, partial_sums); + b_vec = vld1q_s8(b + 7 * ldb); + block_mul_1x16x1<3>( + vget_high_s16(a_vec_low), b_vec, b_zero_point_vec, partial_sums); + + // Second set of 8 channels + b_vec = vld1q_s8(b + 8 * ldb); + block_mul_1x16x1<0>( + vget_low_s16(a_vec_high), b_vec, b_zero_point_vec, partial_sums); + b_vec = vld1q_s8(b + 9 * ldb); + block_mul_1x16x1<1>( + vget_low_s16(a_vec_high), b_vec, b_zero_point_vec, partial_sums); + b_vec = vld1q_s8(b + 10 * ldb); + block_mul_1x16x1<2>( + vget_low_s16(a_vec_high), b_vec, b_zero_point_vec, partial_sums); + b_vec = vld1q_s8(b + 11 * ldb); + block_mul_1x16x1<3>( + vget_low_s16(a_vec_high), b_vec, b_zero_point_vec, partial_sums); + b_vec = vld1q_s8(b + 12 * ldb); + block_mul_1x16x1<0>( + vget_high_s16(a_vec_high), b_vec, b_zero_point_vec, partial_sums); + b_vec = vld1q_s8(b + 13 * ldb); + block_mul_1x16x1<1>( + vget_high_s16(a_vec_high), b_vec, b_zero_point_vec, partial_sums); + b_vec = vld1q_s8(b + 14 * ldb); + block_mul_1x16x1<2>( + vget_high_s16(a_vec_high), b_vec, b_zero_point_vec, partial_sums); + b_vec = vld1q_s8(b + 15 * ldb); + block_mul_1x16x1<3>( + vget_high_s16(a_vec_high), b_vec, b_zero_point_vec, partial_sums); +} + +TORCHAO_ALWAYS_INLINE void dequantize_1x16_int32_t( + const int32x4_t (&sums)[4], + const float* lhs_scales, + const float* rhs_scales, + float32x4_t (&outputs)[4]) { + float32x4_t scales_0123 = vmulq_n_f32(vld1q_f32(rhs_scales), lhs_scales[0]); + float32x4_t scales_4567 = + vmulq_n_f32(vld1q_f32(rhs_scales + 4), lhs_scales[0]); + float32x4_t scales_89ab = + vmulq_n_f32(vld1q_f32(rhs_scales + 8), lhs_scales[0]); + float32x4_t scales_cdef = + vmulq_n_f32(vld1q_f32(rhs_scales + 12), lhs_scales[0]); + + outputs[0] = vmulq_f32(vcvtq_f32_s32(sums[0]), scales_0123); + outputs[1] = vmulq_f32(vcvtq_f32_s32(sums[1]), scales_4567); + outputs[2] = vmulq_f32(vcvtq_f32_s32(sums[2]), scales_89ab); + outputs[3] = vmulq_f32(vcvtq_f32_s32(sums[3]), scales_cdef); +} + +template < + bool a_has_zeros, + bool b_has_zeros, + bool a_transposed, + bool b_transposed> +struct KernelImpl { + static void run( + int m, + int n, + int k, + const void* lhs, + int lhs_stride_m, + const void* rhs, + int rhs_stride_n, + float32_t* output, + int out_stride_m, + const int8_t* lhs_zero_points, + const int8_t* rhs_zero_points, + const float* lhs_scales, + const float* rhs_scales, + const int lhs_qparams_stride, + const int rhs_qparams_stride); +}; + +template <> +struct KernelImpl { + /** + * @brief Implements quantized matrix multiplication for 8-bit channelwise + * quantized matrices + * + * This specialized implementation handles the case where: + * - Both LHS and RHS have zero points (true, true) + * - Neither LHS nor RHS are transposed (false, false) + * + * The function performs a quantized matrix multiplication C = A * B where: + * - A is an mƗk matrix (LHS) + * - B is a kƗn matrix (RHS) + * - C is an mƗn matrix (output) + * + * The implementation uses NEON intrinsics for vectorized computation and + * processes data in blocks of 16Ɨ16 for optimal performance on ARM + * architecture. + * + * @param m Number of rows in LHS and output + * @param n Number of columns in RHS and output + * @param k Number of columns in LHS and rows in RHS + * @param lhs Pointer to LHS matrix data (int8_t) + * @param lhs_stride_m Stride between rows of LHS + * @param rhs Pointer to RHS matrix data (int8_t) + * @param rhs_stride_n Stride between rows of RHS + * @param output Pointer to output matrix (float32_t) + * @param out_stride_m Stride between rows of output + * @param lhs_zero_points Zero points for LHS quantization (per-channel) + * @param rhs_zero_points Zero points for RHS quantization (per-channel) + * @param lhs_scales Scales for LHS quantization (per-channel) + * @param rhs_scales Scales for RHS quantization (per-channel) + * @param lhs_qparams_stride Stride for LHS quantization parameters + * @param rhs_qparams_stride Stride for RHS quantization parameters + */ + static void run( + int m, + int n, + int k, + const void* lhs, + int lhs_stride_m, + const void* rhs, + int rhs_stride_n, + float32_t* output, + int out_stride_m, + const int8_t* lhs_zero_points, + const int8_t* rhs_zero_points, + const float* lhs_scales, + const float* rhs_scales, + const int lhs_qparams_stride, + const int rhs_qparams_stride) { + // If lhs_zero_points and rhs_zero_points are not contiguous, transpose + std::unique_ptr lhs_zero_points_transposed = + std::make_unique(m); + std::unique_ptr lhs_scales_transposed = + std::make_unique(m); + if (lhs_qparams_stride > 1) { + utils::transpose_scales_and_zero_points( + lhs_zero_points, + lhs_scales, + lhs_zero_points_transposed.get(), + lhs_scales_transposed.get(), + m, + lhs_qparams_stride); + lhs_zero_points = lhs_zero_points_transposed.get(); + lhs_scales = lhs_scales_transposed.get(); + } + std::unique_ptr rhs_zero_points_transposed = + std::make_unique(n); + std::unique_ptr rhs_scales_transposed = + std::make_unique(n); + if (rhs_qparams_stride > 1) { + utils::transpose_scales_and_zero_points( + rhs_zero_points, + rhs_scales, + rhs_zero_points_transposed.get(), + rhs_scales_transposed.get(), + n, + rhs_qparams_stride); + rhs_zero_points = rhs_zero_points_transposed.get(); + rhs_scales = rhs_scales_transposed.get(); + } + + for (int m_idx = 0; m_idx < m; m_idx++) { + // Loop over 16 cols at a time + // Access to partial tiles must be protected:w + constexpr int nr = 16; + constexpr int kr = 16; + assert(n >= nr); + for (int n_idx = 0; n_idx < n; n_idx += nr) { + // If remaining is < nr, that must mean that (nr - remaining) items + // dont need to be computed. + // In order to avoid out-of-bounds access, we need to rewind n_indx a + // bit + // |-------------------|-------------------| + // 0-------------------8-------------------16 + // 0-------------------8-----10 + // If n = 10 and nr = 8 then at n_idx = 8, we need to rewind n_idx to + // 8 - (8 - 10) = 2 + int remaining = std::min(n - n_idx, nr); + n_idx = n_idx - (nr - remaining); + // Set activation_ptr to start of activation qvals for row m_idx + const int8_t* lhs_ptr = (const int8_t*)lhs + m_idx * lhs_stride_m; + const int8_t* rhs_ptr = (const int8_t*)rhs + n_idx; + int32x4_t int32_sums[nr / 4] = {vdupq_n_s32(0)}; + + // Loop k_idx by group + int k_idx = 0; + for (; (k_idx + kr) <= k; k_idx += kr) { + block_mul_1x16x16( + lhs_ptr, + rhs_ptr, + rhs_stride_n, + lhs_zero_points[m_idx], + rhs_zero_points + n_idx, + int32_sums); + lhs_ptr += kr; + rhs_ptr += kr * rhs_stride_n; + } + + int8x16_t b_zero_point_vec = vld1q_s8(rhs_zero_points + n_idx); + for (int ki = 0; ki < (k - k_idx); ++ki) { + // For each of the remaining k values + // Load 1 int8_t from lhs + // Load 16 int8_t from rhs + // And multiply + add into the 16 accumulators + // arranged as int32x4_t[4] + int16_t a_val = static_cast(lhs_ptr[ki]) - + static_cast(lhs_zero_points[m_idx]); + int8x16_t b_vec = vld1q_s8(rhs_ptr + ki * rhs_stride_n); + int16x8_t b_vec_low = + vsubl_s8(vget_low_s8(b_vec), vget_low_s8(b_zero_point_vec)); + int16x8_t b_vec_high = + vsubl_s8(vget_high_s8(b_vec), vget_high_s8(b_zero_point_vec)); + int32_sums[0] = + vmlal_n_s16(int32_sums[0], vget_low_s16(b_vec_low), a_val); + int32_sums[1] = + vmlal_n_s16(int32_sums[1], vget_high_s16(b_vec_low), a_val); + int32_sums[2] = + vmlal_n_s16(int32_sums[2], vget_low_s16(b_vec_high), a_val); + int32_sums[3] = + vmlal_n_s16(int32_sums[3], vget_high_s16(b_vec_high), a_val); + } + + float32x4_t res[4]; + dequantize_1x16_int32_t( + int32_sums, lhs_scales + m_idx, rhs_scales + n_idx, res); + + // Store result + // Because we adjust n_idx, we may end up writing the same location + // twice + float* store_loc = output + m_idx * out_stride_m + n_idx; + vst1q_f32(store_loc, res[0]); + vst1q_f32(store_loc + 4, res[1]); + vst1q_f32(store_loc + 8, res[2]); + vst1q_f32(store_loc + 12, res[3]); + } // n_idx + } // m_idx + } +}; + +} // namespace + +} // namespace channelwise_8bit_a_channelwise_8bit_b_1x16x16_f32_smlal::internal + +namespace channelwise_8bit_a_channelwise_8bit_b_1x16x16_f32_smlal { +template < + bool a_has_zeros, + bool b_has_zeros, + bool a_transposed, + bool b_transposed> +void kernel( + int m, + int n, + int k, + const void* lhs, + int lhs_stride_m, + const void* rhs, + int rhs_stride_n, + float32_t* output, + int out_stride_m, + const int8_t* lhs_zero_points, + const int8_t* rhs_zero_points, + const float* lhs_scales, + const float* rhs_scales, + const int lhs_qparams_stride, + const int rhs_qparams_stride) { + torchao::kernels::cpu::aarch64::quantized_matmul:: + channelwise_8bit_a_channelwise_8bit_b_1x16x16_f32_smlal::internal:: + KernelImpl::run( + m, + n, + k, + lhs, + lhs_stride_m, + rhs, + rhs_stride_n, + output, + out_stride_m, + lhs_zero_points, + rhs_zero_points, + lhs_scales, + rhs_scales, + lhs_qparams_stride, + rhs_qparams_stride); +} +} // namespace channelwise_8bit_a_channelwise_8bit_b_1x16x16_f32_smlal +} // namespace torchao::kernels::cpu::aarch64::quantized_matmul + +#endif // defined(__aarch64__) || defined(__ARM_NEON) diff --git a/torchao/experimental/kernels/cpu/aarch64/matmul/channelwise_8bit_a_channelwise_8bit_b_1x8x16_f32_neondot-impl.h b/torchao/experimental/kernels/cpu/aarch64/matmul/channelwise_8bit_a_channelwise_8bit_b_1x8x16_f32_neondot-impl.h new file mode 100644 index 0000000000..123b7723e4 --- /dev/null +++ b/torchao/experimental/kernels/cpu/aarch64/matmul/channelwise_8bit_a_channelwise_8bit_b_1x8x16_f32_neondot-impl.h @@ -0,0 +1,336 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the license found in the +// LICENSE file in the root directory of this source tree. + +#pragma once + +#if defined(__aarch64__) || defined(__ARM_NEON) + +#include +#include +#include + +#include +#include +#include + +namespace torchao::kernels::cpu::aarch64::quantized_matmul { +namespace channelwise_8bit_a_channelwise_8bit_b_1x8x16_f32_neondot::internal { + +/* +This function loads int8x16_t value from a, and 8 int8x16_t values from b, and +computes 8 dot products, resulting in 8 int32x4_t values. +Furthermore the int8x16_t values from a are reduced via summing, resulting in +int32_t row_sum_a. Similar int8x16_t values from b are reduced via summing, +resulting in int32_t row_sum_b. +*/ +TORCHAO_ALWAYS_INLINE static void block_mul_1x8x16( + const int8_t* a, + const int8_t* b, + const size_t ldb, + int32x4_t (&partial_sums)[8], + int32_t& row_sum_a, + int32_t (&row_sum_b)[8]) { + int8x16_t a_vec = vld1q_s8(a); + row_sum_a = row_sum_a + vaddlvq_s8(a_vec); + +// godbolt (https://godbolt.org/z/9vbq1d1qY) shows this loops doesnt quantize +// get optimized by moving all the loads up in the unrolled loop. Just hoping +// OOO machine will take care of things Late replace this with macros so as to +// deconstruct the loop and do manual optimization. Or just write assembly. +#pragma unroll(8) + for (int i = 0; i < 8; ++i) { + int8x16_t b_vec = vld1q_s8(b + i * ldb); + row_sum_b[i] = row_sum_b[i] + vaddlvq_s8(b_vec); + partial_sums[i] = vdotq_s32(partial_sums[i], a_vec, b_vec); + } +} + +TORCHAO_ALWAYS_INLINE static void reduce_1x8_int32x4_t_sums( + const int32x4_t (&partial_sums)[8], + int32_t (&sums)[8]) { +#pragma unroll(8) + for (int i = 0; i < 8; ++i) { + sums[i] = vaddvq_s32(partial_sums[i]); + } +} + +TORCHAO_ALWAYS_INLINE static void dequantize_1x8_int32_t( + const int32_t (&sums)[8], + int32_t& row_sum_lhs, + int32_t (&row_sum_rhs)[8], + const int8_t* lhs_zero_points, + const int8_t* rhs_zero_points, + const float* lhs_scales, + const float* rhs_scales, + const int32_t k, + float32x4x2_t& outputs) { + int32x4_t vec_sum_0123 = vld1q_s32(sums); + int32x4_t vec_sum_4567 = vld1q_s32(sums + 4); + + int32x4_t row_sum_rhs_x_lhs_zp_0123 = + vmulq_n_s32(vld1q_s32(row_sum_rhs), (int32_t)lhs_zero_points[0]); + int32x4_t row_sum_rhs_x_lhs_zp_4567 = + vmulq_n_s32(vld1q_s32(row_sum_rhs + 4), (int32_t)lhs_zero_points[0]); + + // Extract rhs zero point in int8x8_t and convert to int32x4_t + int16x8_t rhs_zero_points_vec_01234567 = vmovl_s8(vld1_s8(rhs_zero_points)); + int32x4_t rhs_zero_points_vec_0123 = + vmovl_s16(vget_low_s16(rhs_zero_points_vec_01234567)); + int32x4_t rhs_zero_points_vec_4567 = + vmovl_s16(vget_high_s16(rhs_zero_points_vec_01234567)); + int32x4_t row_sum_lhs_x_rhs_zp_0123 = + vmulq_n_s32(rhs_zero_points_vec_0123, row_sum_lhs); + int32x4_t row_sum_lhs_x_rhs_zp_4567 = + vmulq_n_s32(rhs_zero_points_vec_4567, row_sum_lhs); + + int32x4_t zp_rhs_x_zp_lhs_0123 = + vmulq_n_s32(rhs_zero_points_vec_0123, k * (int32_t)lhs_zero_points[0]); + int32x4_t zp_rhs_x_zp_lhs_4567 = + vmulq_n_s32(rhs_zero_points_vec_4567, k * (int32_t)lhs_zero_points[0]); + + vec_sum_0123 = vsubq_s32(vec_sum_0123, row_sum_rhs_x_lhs_zp_0123); + vec_sum_0123 = vsubq_s32(vec_sum_0123, row_sum_lhs_x_rhs_zp_0123); + vec_sum_0123 = vaddq_s32(vec_sum_0123, zp_rhs_x_zp_lhs_0123); + + vec_sum_4567 = vsubq_s32(vec_sum_4567, row_sum_rhs_x_lhs_zp_4567); + vec_sum_4567 = vsubq_s32(vec_sum_4567, row_sum_lhs_x_rhs_zp_4567); + vec_sum_4567 = vaddq_s32(vec_sum_4567, zp_rhs_x_zp_lhs_4567); + + float32x4_t scales_0123 = vmulq_n_f32(vld1q_f32(rhs_scales), lhs_scales[0]); + float32x4_t scales_4567 = + vmulq_n_f32(vld1q_f32(rhs_scales + 4), lhs_scales[0]); + + outputs.val[0] = vmulq_f32(vcvtq_f32_s32(vec_sum_0123), scales_0123); + outputs.val[1] = vmulq_f32(vcvtq_f32_s32(vec_sum_4567), scales_4567); +} + +template < + bool a_has_zeros, + bool b_has_zeros, + bool a_transposed, + bool b_transposed> +struct KernelImpl { + static void run( + int m, + int n, + int k, + const void* lhs, + int lhs_stride_m, + const void* rhs, + int rhs_stride_n, + float32_t* output, + int out_stride_m, + const int8_t* lhs_zero_points, + const int8_t* rhs_zero_points, + const float* lhs_scales, + const float* rhs_scales, + const int lhs_qparams_stride, + const int rhs_qparams_stride); +}; + +template <> +struct KernelImpl { + /** + * @brief Executes a quantized matrix multiplication with channelwise + * quantization parameters + * + * This function performs matrix multiplication between two 8-bit quantized + * matrices with per-channel quantization parameters. It handles the following + * operations: + * 1. Transposes quantization parameters if they're not contiguous + * 2. Processes the matrices in blocks of 8 columns at a time + * 3. Uses NEON dot product instructions for efficient computation + * 4. Handles edge cases for remaining elements + * 5. Dequantizes the results to floating point + * + * @param m Number of rows in the output matrix + * @param n Number of columns in the output matrix + * @param k Number of columns in lhs / rows in rhs + * @param lhs Pointer to the left-hand side matrix (quantized int8) + * @param lhs_stride_m Stride between rows of the lhs matrix + * @param rhs Pointer to the right-hand side matrix (quantized int8) + * @param rhs_stride_n Stride between rows of the rhs matrix. Expects matrix + * to be transposed. Thus of size [n x k] + * @param output Pointer to the output matrix (float32) + * @param out_stride_m Stride between rows of the output matrix + * @param lhs_zero_points Zero points for lhs quantization (per-channel) + * @param rhs_zero_points Zero points for rhs quantization (per-channel) + * @param lhs_scales Scales for lhs quantization (per-channel) + * @param rhs_scales Scales for rhs quantization (per-channel) + * @param lhs_qparams_stride Stride for lhs quantization parameters + * @param rhs_qparams_stride Stride for rhs quantization parameters + */ + static void run( + int m, + int n, + int k, + const void* lhs, + int lhs_stride_m, + const void* rhs, + int rhs_stride_n, + float32_t* output, + int out_stride_m, + const int8_t* lhs_zero_points, + const int8_t* rhs_zero_points, + const float* lhs_scales, + const float* rhs_scales, + const int lhs_qparams_stride, + const int rhs_qparams_stride) { + // If lhs_zero_points and rhs_zero_points are not contiguous, transpose + std::unique_ptr lhs_zero_points_transposed = + std::make_unique(m); + std::unique_ptr lhs_scales_transposed = + std::make_unique(m); + if (lhs_qparams_stride > 1) { + utils::transpose_scales_and_zero_points( + lhs_zero_points, + lhs_scales, + lhs_zero_points_transposed.get(), + lhs_scales_transposed.get(), + m, + lhs_qparams_stride); + lhs_zero_points = lhs_zero_points_transposed.get(); + lhs_scales = lhs_scales_transposed.get(); + } + std::unique_ptr rhs_zero_points_transposed = + std::make_unique(n); + std::unique_ptr rhs_scales_transposed = + std::make_unique(n); + if (rhs_qparams_stride > 1) { + utils::transpose_scales_and_zero_points( + rhs_zero_points, + rhs_scales, + rhs_zero_points_transposed.get(), + rhs_scales_transposed.get(), + n, + rhs_qparams_stride); + rhs_zero_points = rhs_zero_points_transposed.get(); + rhs_scales = rhs_scales_transposed.get(); + } + + for (int m_idx = 0; m_idx < m; m_idx++) { + // Loop over 8 cols at a time + // Access to partial tiles must be protected:w + constexpr int nr = 8; + constexpr int kr = 16; + assert(n >= nr); + for (int n_idx = 0; n_idx < n; n_idx += nr) { + // If remaining is < nr, that must mean that (nr - remaining) items + // dont need to be computed. + // In order to avoid out-of-bounds access, we need to rewind n_indx a + // bit + // |-------------------|-------------------| + // 0-------------------8-------------------16 + // 0-------------------8-----10 + // If n = 10 and nr = 8 then at n_idx = 8, we need to rewind n_idx to + // 8 - (8 - 10) = 2 + int remaining = std::min(n - n_idx, nr); + n_idx = n_idx - (nr - remaining); + // Set activation_ptr to start of activation qvals for row m_idx + const int8_t* lhs_ptr = (const int8_t*)lhs + m_idx * lhs_stride_m; + const int8_t* rhs_ptr = (const int8_t*)rhs + n_idx * rhs_stride_n; + int32x4_t int32_sums[nr] = {vdupq_n_s32(0)}; + int32_t row_sum_lhs = 0; + int32_t row_sum_rhs[nr] = {0, 0, 0, 0, 0, 0, 0, 0}; + int32_t sums[nr]; + + // Loop k_idx by group + int k_idx = 0; + for (; (k_idx + kr) <= k; k_idx += kr) { + block_mul_1x8x16( + lhs_ptr, + rhs_ptr, + rhs_stride_n, + int32_sums, + row_sum_lhs, + row_sum_rhs); + lhs_ptr += kr; + rhs_ptr += kr; + } + + reduce_1x8_int32x4_t_sums(int32_sums, sums); + for (int ki = 0; ki < (k - k_idx); ++ki) { + row_sum_lhs += (int32_t)lhs_ptr[ki]; + } + for (int ni = 0; ni < nr; ++ni) { + for (int ki = 0; ki < (k - k_idx); ++ki) { + sums[ni] += (int32_t)lhs_ptr[ki] * + (int32_t)(rhs_ptr + ni * rhs_stride_n)[ki]; + row_sum_rhs[ni] += (int32_t)(rhs_ptr + ni * rhs_stride_n)[ki]; + } + } + + float32x4x2_t res; + dequantize_1x8_int32_t( + sums, + row_sum_lhs, + row_sum_rhs, + lhs_zero_points + m_idx, + rhs_zero_points + n_idx, + lhs_scales + m_idx, + rhs_scales + n_idx, + k, + res); + + // Store result + // Because we adjust n_idx, we may end up writing the same location + // twice + float* store_loc = output + m_idx * out_stride_m + n_idx; + vst1q_f32(store_loc, res.val[0]); + vst1q_f32(store_loc + 4, res.val[1]); + } // n_idx + } // m_idx + } +}; + +} // namespace + // channelwise_8bit_a_channelwise_8bit_b_1x8x16_f32_neondot::internal + +namespace channelwise_8bit_a_channelwise_8bit_b_1x8x16_f32_neondot { +template < + bool a_has_zeros, + bool b_has_zeros, + bool a_transposed, + bool b_transposed> +void kernel( + int m, + int n, + int k, + const void* lhs, + int lhs_stride_m, + const void* rhs, + int rhs_stride_n, + float32_t* output, + int out_stride_m, + const int8_t* lhs_zero_points, + const int8_t* rhs_zero_points, + const float* lhs_scales, + const float* rhs_scales, + const int lhs_qparams_stride, + const int rhs_qparams_stride) { + torchao::kernels::cpu::aarch64::quantized_matmul:: + channelwise_8bit_a_channelwise_8bit_b_1x8x16_f32_neondot::internal:: + KernelImpl::run( + m, + n, + k, + lhs, + lhs_stride_m, + rhs, + rhs_stride_n, + output, + out_stride_m, + lhs_zero_points, + rhs_zero_points, + lhs_scales, + rhs_scales, + lhs_qparams_stride, + rhs_qparams_stride); +} +} // namespace channelwise_8bit_a_channelwise_8bit_b_1x8x16_f32_neondot +} // namespace torchao::kernels::cpu::aarch64::quantized_matmul + +#endif // defined(__aarch64__) || defined(__ARM_NEON) diff --git a/torchao/experimental/kernels/cpu/aarch64/matmul/matmul.h b/torchao/experimental/kernels/cpu/aarch64/matmul/matmul.h new file mode 100644 index 0000000000..4005dee564 --- /dev/null +++ b/torchao/experimental/kernels/cpu/aarch64/matmul/matmul.h @@ -0,0 +1,74 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the license found in the +// LICENSE file in the root directory of this source tree. + +// TODO: this file will be deleted and replaced by +// torchao/experimental/kernels/cpu/aarch64/linear/channelwise_8bit_activation_groupwise_lowbit_weight/include.h +// It exists now to prevent breaking existing code in the interim. + +#pragma once + +#if defined(__aarch64__) || defined(__ARM_NEON) + +#include + +namespace torchao::kernels::cpu::aarch64::quantized_matmul { +namespace channelwise_8bit_a_channelwise_8bit_b_1x8x16_f32_neondot { + +template < + bool a_has_zeros, + bool b_has_zeros, + bool a_transposed, + bool b_tranposed> +void kernel( + int m, + int n, + int k, + const void* lhs, + int lhs_stride_m, + const void* rhs, + int rhs_stride_n, + float32_t* output, + int out_stride_m, + const int8_t* lhs_zero_points, + const int8_t* rhs_zero_points, + const float* lhs_scales, + const float* rhs_scales, + const int lhs_qparams_stride, + const int rhs_qparams_stride); + +} // namespace channelwise_8bit_a_channelwise_8bit_b_1x8x16_f32_neondot + +namespace channelwise_8bit_a_channelwise_8bit_b_1x16x16_f32_smlal { + +template < + bool a_has_zeros, + bool b_has_zeros, + bool a_transposed, + bool b_tranposed> +void kernel( + int m, + int n, + int k, + const void* lhs, + int lhs_stride_m, + const void* rhs, + int rhs_stride_n, + float32_t* output, + int out_stride_m, + const int8_t* lhs_zero_points, + const int8_t* rhs_zero_points, + const float* lhs_scales, + const float* rhs_scales, + const int lhs_qparams_stride, + const int rhs_qparams_stride); + +} // namespace channelwise_8bit_a_channelwise_8bit_b_1x16x16_f32_smlal +} // namespace torchao::kernels::cpu::aarch64::quantized_matmul + +#include +#include + +#endif // defined(__aarch64__) || defined(__ARM_NEON) diff --git a/torchao/experimental/kernels/cpu/aarch64/matmul/matmul_utils.h b/torchao/experimental/kernels/cpu/aarch64/matmul/matmul_utils.h new file mode 100644 index 0000000000..68ab912705 --- /dev/null +++ b/torchao/experimental/kernels/cpu/aarch64/matmul/matmul_utils.h @@ -0,0 +1,70 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the license found in the +// LICENSE file in the root directory of this source tree. + +#pragma once + +#if defined(__aarch64__) || defined(__ARM_NEON) + +#include +#include + +namespace torchao::kernels::cpu::aarch64::quantized_matmul { +namespace utils { + +TORCHAO_ALWAYS_INLINE static void transpose_scales_and_zero_points( + const int8_t* zero_points, + const float* scales, + int8_t* zero_points_transposed, + float* scales_transposed, + const int m, + const int stride_m) { + // Process 8 elements at a time using NEON + int i = 0; + for (; i + 8 <= m; i += 8) { + // Load 8 zero points with stride_m + int8x8_t zp = { + zero_points[0 * stride_m], + zero_points[1 * stride_m], + zero_points[2 * stride_m], + zero_points[3 * stride_m], + zero_points[4 * stride_m], + zero_points[5 * stride_m], + zero_points[6 * stride_m], + zero_points[7 * stride_m]}; + zero_points += 8 * stride_m; + // Store contiguously + vst1_s8(zero_points_transposed + i, zp); + + // Load 8 scales with stride_m + float32x4_t scales_lo = { + scales[0 * stride_m], + scales[1 * stride_m], + scales[2 * stride_m], + scales[3 * stride_m]}; + float32x4_t scales_hi = { + scales[4 * stride_m], + scales[5 * stride_m], + scales[6 * stride_m], + scales[7 * stride_m]}; + scales += 8 * stride_m; + // Store contiguously + vst1q_f32(scales_transposed + i, scales_lo); + vst1q_f32(scales_transposed + i + 4, scales_hi); + } + + // Handle remaining elements + for (; i < m; i++) { + zero_points_transposed[i] = zero_points[0]; + scales_transposed[i] = scales[0]; + zero_points += stride_m; + scales += stride_m; + } +} + +} // namespace utils +} // namespace torchao::kernels::cpu::aarch64::quantized_matmul + +#endif // defined(__aarch64__) || defined(__ARM_NEON) diff --git a/torchao/experimental/kernels/cpu/aarch64/quantization/quantize.cpp b/torchao/experimental/kernels/cpu/aarch64/quantization/quantize.cpp index 65416fdf1d..3460d67fba 100644 --- a/torchao/experimental/kernels/cpu/aarch64/quantization/quantize.cpp +++ b/torchao/experimental/kernels/cpu/aarch64/quantization/quantize.cpp @@ -9,6 +9,7 @@ #include #include #include +#include #include void torchao::quantization::get_qvals_range( @@ -64,8 +65,6 @@ void torchao::kernels::cpu::aarch64::quantization::quantize( int8_t zero, int8_t qmin, int8_t qmax) { - assert(size % 8 == 0); - float32_t invScale = 1.0 / (scale + 1e-16); float32x4_t vec_zero = vdupq_n_f32(zero); float32x4_t vec_invScale = vdupq_n_f32(invScale); @@ -78,7 +77,8 @@ void torchao::kernels::cpu::aarch64::quantization::quantize( int16x4_t vec_qval_s16_0; int16x4_t vec_qval_s16_1; - for (int i = 0; i < size; i += 8) { + int i = 0; + for (; (i + 8) < size; i += 8) { ////////////////////////////////////// // Quantize first 4 element chunk to int16 ////////////////////////////////////// @@ -112,6 +112,23 @@ void torchao::kernels::cpu::aarch64::quantization::quantize( int8x8_t vec_qval_s8_01 = vqmovn_s16(vec_qval_s16_01); vst1_s8(qvals + i, vec_qval_s8_01); } + auto curr_rounding_mode = fegetround(); + fesetround(FE_TONEAREST); + for (; i < size; ++i) { + // Quantize remaining elements using scalar code + float32_t val = vals[i]; + float32_t qval_f32 = zero + val * invScale; + int32_t qval_s32 = static_cast(std::nearbyint(qval_f32)); + + // Clip to qmin and qmax + qval_s32 = std::max( + static_cast(qmin), + std::min(qval_s32, static_cast(qmax))); + + // Store the quantized value + qvals[i] = static_cast(qval_s32); + } + fesetround(int(curr_rounding_mode)); } #endif // defined(__aarch64__) || defined(__ARM_NEON) diff --git a/torchao/experimental/kernels/cpu/aarch64/tests/CMakeLists.txt b/torchao/experimental/kernels/cpu/aarch64/tests/CMakeLists.txt index 5b6ba2ab98..a01afac68f 100644 --- a/torchao/experimental/kernels/cpu/aarch64/tests/CMakeLists.txt +++ b/torchao/experimental/kernels/cpu/aarch64/tests/CMakeLists.txt @@ -119,6 +119,14 @@ target_link_libraries( dep ) +add_executable(test_qmatmul test_qmatmul.cpp) +target_link_libraries( + test_qmatmul + PRIVATE + GTest::gtest_main + dep +) + include(GoogleTest) gtest_discover_tests(test_quantization) gtest_discover_tests(test_reduction) @@ -127,3 +135,4 @@ gtest_discover_tests(test_linear) gtest_discover_tests(test_valpacking) gtest_discover_tests(test_embedding) gtest_discover_tests(test_weight_packing) +gtest_discover_tests(test_qmatmul) diff --git a/torchao/experimental/kernels/cpu/aarch64/tests/build_and_run_tests.sh b/torchao/experimental/kernels/cpu/aarch64/tests/build_and_run_tests.sh index 4b2181d7cc..1898e8b535 100644 --- a/torchao/experimental/kernels/cpu/aarch64/tests/build_and_run_tests.sh +++ b/torchao/experimental/kernels/cpu/aarch64/tests/build_and_run_tests.sh @@ -61,3 +61,4 @@ ${CMAKE_OUT}/test_linear ${CMAKE_OUT}/test_valpacking ${CMAKE_OUT}/test_embedding ${CMAKE_OUT}/test_weight_packing +${CMAKE_OUT}/test_qmatmul diff --git a/torchao/experimental/kernels/cpu/aarch64/tests/test_qmatmul.cpp b/torchao/experimental/kernels/cpu/aarch64/tests/test_qmatmul.cpp new file mode 100644 index 0000000000..1b3e11156f --- /dev/null +++ b/torchao/experimental/kernels/cpu/aarch64/tests/test_qmatmul.cpp @@ -0,0 +1,229 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the license found in the +// LICENSE file in the root directory of this source tree. + +#if defined(__aarch64__) || defined(__ARM_NEON) + +#include +#include + +#include +#include +#include + +float kTol = 0.0001; + +template < + bool a_has_zeros, + bool b_has_zeros, + bool a_transposed, + bool b_transposed> +struct test_channelwise_8bit_channelwise_8bit_b { + static void Run(int m, int k, int n); +}; + +template +struct test_channelwise_8bit_channelwise_8bit_b< + a_has_zeros, + b_has_zeros, + false, + true> { + static void Run(int m, int k, int n, int stride = 1) { + auto test_case = + torchao::channelwise_8bit_a_channelwise_8bit_b_qmatmul_test_case:: + generate(m, k, n, a_has_zeros, a_has_zeros, false, true, stride); + + using namespace torchao::kernels::cpu::aarch64::quantized_matmul:: + channelwise_8bit_a_channelwise_8bit_b_1x8x16_f32_neondot; + + std::vector output(m * n); + kernel( + m, + n, + k, + test_case.lhs_qvals.data(), + k * stride /*lsh_stride_m*/, + test_case.rhs_qvals.data(), + k * stride /*rsh_stride_n*/, + output.data(), + n /*out_stride_n*/, + test_case.lhs_zeros.data(), + test_case.rhs_zeros.data(), + test_case.lhs_scales.data(), + test_case.rhs_scales.data(), + stride, /*lhs qparams stride*/ + stride /*rhs qparams stride*/); + + for (int i = 0; i < m * n; i++) { + EXPECT_NEAR(output[i], test_case.expected_output[i], kTol); + } + } +}; + +template +struct test_channelwise_8bit_channelwise_8bit_b< + a_has_zeros, + b_has_zeros, + false, + false> { + static void Run(int m, int k, int n, int stride = 1) { + auto test_case = + torchao::channelwise_8bit_a_channelwise_8bit_b_qmatmul_test_case:: + generate(m, k, n, a_has_zeros, a_has_zeros, false, false); + + using namespace torchao::kernels::cpu::aarch64::quantized_matmul:: + channelwise_8bit_a_channelwise_8bit_b_1x16x16_f32_smlal; + + std::vector output(m * n); + kernel( + m, + n, + k, + test_case.lhs_qvals.data(), + k /*lsh_stride_m*/, + test_case.rhs_qvals.data(), + n /*rsh_stride_n*/, + output.data(), + n /*out_stride_n*/, + test_case.lhs_zeros.data(), + test_case.rhs_zeros.data(), + test_case.lhs_scales.data(), + test_case.rhs_scales.data(), + stride, /*lhs qparams stride*/ + stride /*rhs qparams stride*/); + + for (int i = 0; i < m * n; i++) { + EXPECT_NEAR(output[i], test_case.expected_output[i], kTol); + } + } +}; + +TEST(test_channelwise_8bit_channelwise_8bit_b, TransposedBWithZeroPoints) { + test_channelwise_8bit_channelwise_8bit_b< + true /*a_has_zeros*/, + true /*b_has_zeros*/, + false /*a_transposed*/, + true /*b_transposed*/>:: + Run( + /*m=*/1, /*k=*/128, /*n=*/16); +} + +TEST(test_channelwise_8bit_channelwise_8bit_b, TransposeBWithZeroPointsLargeM) { + test_channelwise_8bit_channelwise_8bit_b< + true /*a_has_zeros*/, + true /*b_has_zeros*/, + false /*a_transposed*/, + true /*b_transposed*/>:: + Run( + /*m=*/4, /*k=*/128, /*n=*/16); +} + +TEST( + test_channelwise_8bit_channelwise_8bit_b, + TransposedBWithZeroPointsOddSizes) { + test_channelwise_8bit_channelwise_8bit_b< + true /*a_has_zeros*/, + true /*b_has_zeros*/, + false /*a_transposed*/, + true /*b_transposed*/>:: + Run( + /*m=*/4, /*k=*/37, /*n=*/24); +} + +TEST( + test_channelwise_8bit_channelwise_8bit_b, + TransposedBWithZeroPointsOddSizes2) { + test_channelwise_8bit_channelwise_8bit_b< + true /*a_has_zeros*/, + true /*b_has_zeros*/, + false /*a_transposed*/, + true /*b_transposed*/>:: + Run( + /*m=*/4, /*k=*/37, /*n=*/19); +} + +TEST( + test_channelwise_8bit_channelwise_8bit_b, + TransposedBWithZeroPointsStrided) { + test_channelwise_8bit_channelwise_8bit_b< + true /*a_has_zeros*/, + true /*b_has_zeros*/, + false /*a_transposed*/, + true /*b_transposed*/>:: + Run( + /*m=*/1, /*k=*/128, /*n=*/16, 5); +} + +TEST( + test_channelwise_8bit_channelwise_8bit_b, + TransposedBWithZeroPointsOddSizes2Strided) { + test_channelwise_8bit_channelwise_8bit_b< + true /*a_has_zeros*/, + true /*b_has_zeros*/, + false /*a_transposed*/, + true /*b_transposed*/>:: + Run( + /*m=*/4, /*k=*/37, /*n=*/19, 10); +} + +TEST( + test_channelwise_8bit_channelwise_8bit_b, + TransposedBWithZeroPointsOddSizes2Strided2) { + test_channelwise_8bit_channelwise_8bit_b< + true /*a_has_zeros*/, + true /*b_has_zeros*/, + false /*a_transposed*/, + true /*b_transposed*/>:: + Run( + /*m=*/3, /*k=*/64, /*n=*/24, 7); +} + +TEST(test_channelwise_8bit_channelwise_8bit_b, NoTransposedWithZeroPoints) { + test_channelwise_8bit_channelwise_8bit_b< + true /*a_has_zeros*/, + true /*b_has_zeros*/, + false /*a_transposed*/, + false /*b_transposed*/>:: + Run( + /*m=*/1, /*k=*/128, /*n=*/16); +} + +TEST( + test_channelwise_8bit_channelwise_8bit_b, + NoTransposedWithZeroPointsLargeM) { + test_channelwise_8bit_channelwise_8bit_b< + true /*a_has_zeros*/, + true /*b_has_zeros*/, + false /*a_transposed*/, + false /*b_transposed*/>:: + Run( + /*m=*/4, /*k=*/128, /*n=*/16); +} + +TEST( + test_channelwise_8bit_channelwise_8bit_b, + NoTransposedWithZeroPointsOddSizes) { + test_channelwise_8bit_channelwise_8bit_b< + true /*a_has_zeros*/, + true /*b_has_zeros*/, + false /*a_transposed*/, + false /*b_transposed*/>:: + Run( + /*m=*/4, /*k=*/37, /*n=*/24); +} + +TEST( + test_channelwise_8bit_channelwise_8bit_b, + NoTransposedWithZeroPointsOddSizes2) { + test_channelwise_8bit_channelwise_8bit_b< + true /*a_has_zeros*/, + true /*b_has_zeros*/, + false /*a_transposed*/, + false /*b_transposed*/>:: + Run( + /*m=*/4, /*k=*/37, /*n=*/19); +} + +#endif // defined(__aarch64__) || defined(__ARM_NEON) diff --git a/torchao/experimental/kernels/cpu/aarch64/tests/test_utils.h b/torchao/experimental/kernels/cpu/aarch64/tests/test_utils.h index 4720b68fb0..80ddcb690d 100644 --- a/torchao/experimental/kernels/cpu/aarch64/tests/test_utils.h +++ b/torchao/experimental/kernels/cpu/aarch64/tests/test_utils.h @@ -84,6 +84,59 @@ inline float get_float_from_bf16(uint16_t bf16) { return f; } +namespace { +auto generate_per_token_quantized_tensor( + int m, + int n, + bool transposed = false) { + auto activations = get_random_vector(m * n, -1.0, 1.0); + auto activation_qvals = std::vector(m * n, 0); + auto activation_scales = std::vector(m, 0); + auto activation_zeros = std::vector(m, 0); + + // Quantize activations with 8-bit asymmetric + // TODO: replace with generic function that does not use aarch64 + // quantize method after we combine with torchao + int qmin, qmax, zero; + float vmin, vmax, scale; + torchao::quantization::get_qvals_range( + qmin, qmax, /*nbit=*/8, /*is_symmetric=*/false); + for (int m_idx = 0; m_idx < m; m_idx++) { + torchao::kernels::cpu::aarch64::reduction::find_min_and_max( + vmin, vmax, /*vals=*/activations.data() + m_idx * n, /*size=*/n); + torchao::quantization::get_scale_and_zero( + scale, zero, vmin, vmax, qmin, qmax); + activation_scales[m_idx] = scale; + activation_zeros[m_idx] = zero; + torchao::kernels::cpu::aarch64::quantization::quantize( + /*qvals=*/activation_qvals.data() + m_idx * n, + /*vals=*/activations.data() + m_idx * n, + /*size=*/n, + scale, + zero, + qmin, + qmax); + } + if (transposed) { + auto activations_t = std::vector(m * n, 0); + auto activation_qvals_t = std::vector(m * n, 0); + for (int m_idx = 0; m_idx < m; m_idx++) { + for (int n_idx = 0; n_idx < n; n_idx++) { + int activation_idx = m_idx * n + n_idx; + int tranposed_idx = n_idx * m + m_idx; + activations_t[tranposed_idx] = activations[activation_idx]; + activation_qvals_t[tranposed_idx] = activation_qvals[activation_idx]; + } + } + activations = activations_t; + activation_qvals = activation_qvals_t; + } + + return std::make_tuple( + activations, activation_qvals, activation_scales, activation_zeros); +} +} // namespace + struct channelwise_8bit_activation_groupwise_lowbit_weight_test_case { int m; int k; @@ -182,34 +235,8 @@ struct channelwise_8bit_activation_groupwise_lowbit_weight_test_case { // weights is k x n (stored in column-major) // Generate activations - auto activations = get_random_vector(m * k, -1.0, 1.0); - auto activation_qvals = std::vector(m * k, 0); - auto activation_scales = std::vector(m, 0); - auto activation_zeros = std::vector(m, 0); - - // Quantize activations with 8-bit asymmetric - // TODO: replace with generic function that does not use aarch64 - // quantize method after we combine with torchao - int qmin, qmax, zero; - float vmin, vmax, scale; - torchao::quantization::get_qvals_range( - qmin, qmax, /*nbit=*/8, /*is_symmetric=*/false); - for (int m_idx = 0; m_idx < m; m_idx++) { - torchao::kernels::cpu::aarch64::reduction::find_min_and_max( - vmin, vmax, /*vals=*/activations.data() + m_idx * k, /*size=*/k); - torchao::quantization::get_scale_and_zero( - scale, zero, vmin, vmax, qmin, qmax); - activation_scales[m_idx] = scale; - activation_zeros[m_idx] = zero; - torchao::kernels::cpu::aarch64::quantization::quantize( - /*qvals=*/activation_qvals.data() + m_idx * k, - /*vals=*/activations.data() + m_idx * k, - /*size=*/k, - scale, - zero, - qmin, - qmax); - } + auto [activations, activation_qvals, activation_scales, activation_zeros] = + generate_per_token_quantized_tensor(m, k); // Generate weights assert(k % weight_group_size == 0); @@ -219,6 +246,8 @@ struct channelwise_8bit_activation_groupwise_lowbit_weight_test_case { auto weight_scales = std::vector(n_weight_groups, 0.0); auto weight_zeros = std::vector(n_weight_groups, 0); + int qmin, qmax, zero; + float vmin, vmax, scale; // Quantize weights with weight_nbit // TODO: replace with generic function that does not use aarch64 // quantize method after we combine with torchao @@ -322,6 +351,150 @@ struct channelwise_8bit_activation_groupwise_lowbit_weight_test_case { } }; +struct channelwise_8bit_a_channelwise_8bit_b_qmatmul_test_case { + int m; + int k; + int n; + int stride; + + bool lhs_has_zeros; + bool rhs_has_zeros; + bool lhs_is_transposed; + bool rhs_is_transposed; + + std::vector expected_output; + + std::vector lhs; + std::vector lhs_qvals; + std::vector lhs_scales; + std::vector lhs_zeros; + + std::vector rhs; + std::vector rhs_qvals; + std::vector rhs_scales; + std::vector rhs_zeros; + + channelwise_8bit_a_channelwise_8bit_b_qmatmul_test_case( + int m_, + int k_, + int n_, + int stride_, + bool lhs_has_zeros_, + bool rhs_has_zeros_, + bool lhs_is_transposed_, + bool rhs_is_transposed_, + std::vector expected_output_, + std::vector lhs_, + std::vector lhs_qvals_, + std::vector lhs_scales_, + std::vector lhs_zeros_, + std::vector rhs_, + std::vector rhs_qvals_, + std::vector rhs_scales_, + std::vector rhs_zeros_) + : m(m_), + k(k_), + n(n_), + stride(stride_), + lhs_has_zeros(lhs_has_zeros_), + rhs_has_zeros(rhs_has_zeros_), + lhs_is_transposed(lhs_is_transposed_), + rhs_is_transposed(rhs_is_transposed_), + expected_output(expected_output_), + lhs(lhs_), + lhs_qvals(lhs_qvals_), + lhs_scales(lhs_scales_), + lhs_zeros(lhs_zeros_), + rhs(rhs_), + rhs_qvals(rhs_qvals_), + rhs_scales(rhs_scales_), + rhs_zeros(rhs_zeros_) { + assert(expected_output.size() == m * n); + assert(lhs.size() == m * stride * k); + assert(lhs_qvals.size() == m * stride * k); + assert(lhs_scales.size() == m * stride); + assert(lhs_zeros.size() == m * stride); + assert(rhs.size() == n * stride * k); + assert(rhs_qvals.size() == n * stride * k); + assert(rhs_scales.size() == n * stride); + assert(rhs_zeros.size() == n * stride); + } + + static channelwise_8bit_a_channelwise_8bit_b_qmatmul_test_case generate( + int m, + int k, + int n, + bool lhs_has_zeros, + bool rhs_has_zeros, + bool lhs_is_transposed, + // rhs_is_transposed means generated b matrix is mxk instead of kxm + bool rhs_is_transposed, + int stride = 1) { + assert(!lhs_is_transposed); + assert(lhs_has_zeros); + assert(rhs_has_zeros); + // !Rhs transposed was considered if we were doing quantized(softmax(q@k)) @ + // quantized(v) Since v would have been [B, H, S, D]. And [S, D] would be + // rhs matrix which is not transposed when considered matmul terminology + // because for matmul we would have A[S_q, S] x B[S, D]. + // It would have been transposed if A[S_q, S] x B[D, S]. + assert(rhs_is_transposed || stride == 1); + // Generate activations + auto [lhs, lhs_qvals, lhs_scales, lhs_zeros] = + generate_per_token_quantized_tensor(m * stride, k); + + auto [rhs, rhs_qvals, rhs_scales, rhs_zeros] = + generate_per_token_quantized_tensor(n * stride, k, !rhs_is_transposed); + // Above function produces nxk matrix and to produce kxn you need transposed + // = true. we do !rhs_is_transposed becaues when rhs_is_transposed = true + // the shape should be nxk instead of kxn. + + // Compute expected output + std::vector expected_output(m * n); + + for (int m_idx = 0; m_idx < m; m_idx++) { + for (int n_idx = 0; n_idx < n; n_idx++) { + float res = 0.0; + for (int k_idx = 0; k_idx < k; k_idx++) { + int lhs_idx = m_idx * stride * k + k_idx; + int rhs_idx = k_idx * stride * n + n_idx * stride; + if (rhs_is_transposed) { + rhs_idx = n_idx * stride * k + k_idx; + } + float lhs_dequant = lhs_scales[m_idx * stride] * + (lhs_qvals[lhs_idx] - lhs_zeros[m_idx * stride]); + + float rhs_dequant = rhs_scales[n_idx * stride] * + (rhs_qvals[rhs_idx] - rhs_zeros[n_idx * stride]); + + res += lhs_dequant * rhs_dequant; + } + expected_output[m_idx * n + n_idx] = res; + } + } + + // Return test case + return channelwise_8bit_a_channelwise_8bit_b_qmatmul_test_case( + m, + k, + n, + stride, + lhs_has_zeros, + rhs_has_zeros, + lhs_is_transposed, + rhs_is_transposed, + expected_output, + lhs, + lhs_qvals, + lhs_scales, + lhs_zeros, + rhs, + rhs_qvals, + rhs_scales, + rhs_zeros); + } +}; + template struct lowbit_embedding_test_case { int num_embeddings; From e4eff3aa7aa0f977c261da6cb83a28dd53b3a303 Mon Sep 17 00:00:00 2001 From: Driss Guessous <32754868+drisspg@users.noreply.github.com> Date: Wed, 2 Apr 2025 10:25:19 -0700 Subject: [PATCH 04/19] Allow builds on less than sm75 raise runtime failure (#1999) stack-info: PR: https://github.com/pytorch/ao/pull/1999, branch: drisspg/stack/45 --- torchao/csrc/cuda/fp6_llm/fp6_linear.cu | 52 +++++++++++++++------ torchao/csrc/cuda/fp6_llm/kernel_matmul.cuh | 19 ++++++-- torchao/ops.py | 13 ++++++ 3 files changed, 65 insertions(+), 19 deletions(-) diff --git a/torchao/csrc/cuda/fp6_llm/fp6_linear.cu b/torchao/csrc/cuda/fp6_llm/fp6_linear.cu index 531e1ba7e6..26f6494220 100644 --- a/torchao/csrc/cuda/fp6_llm/fp6_linear.cu +++ b/torchao/csrc/cuda/fp6_llm/fp6_linear.cu @@ -21,6 +21,7 @@ // // MODIFICATION NOTE (2024-09-25): added SM75 support (https://github.com/pytorch/ao/pull/942): // - Modified the TilingConfig parameters for SM75 to deal with smaller shared memory +// - Added proper architecture check at both host and device level // @@ -98,7 +99,24 @@ void fpx_linear_kernel(cudaStream_t stream, static_assert(std::is_same::value || std::is_same::value, "Type must be 'half' or '__nv_bfloat16'"); assert(M_Global % 256 == 0); assert(K_Global % 64 == 0); - assert(N_Global>0); + assert(N_Global > 0); + + // Check GPU Compute Capability before proceeding + int device, major, minor; + CHECK_CUDA(cudaGetDevice(&device)); + CHECK_CUDA(cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, device)); + CHECK_CUDA(cudaDeviceGetAttribute(&minor, cudaDevAttrComputeCapabilityMinor, device)); + + // Early exit with error for unsupported architectures + if ((major < 7) || (major == 7 && minor < 5)) { + TORCH_CHECK(false, "Quant-LLM Error: This kernel requires GPU with SM75 (Turing) or higher architecture. " + "Your current device has SM", major, minor, " which is not supported."); + } + + const bool is_sm75_gpu = (major == 7) && (minor == 5); + if (is_sm75_gpu && std::is_same::value) { + TORCH_CHECK(false, "Quant-LLM Error: BFloat16 inputs are not supported on SM75 (Turing) GPUs."); + } // Work around to support more N shapes: size_t N_PowerOf2; @@ -109,17 +127,6 @@ void fpx_linear_kernel(cudaStream_t stream, if(N_Global>64 && N_Global<=128) N_PowerOf2 = 128; if(N_Global>128) N_PowerOf2 = ((N_Global-1)/128+1) * 128; - // Check GPU Compute Capability - int device, major, minor; - CHECK_CUDA(cudaGetDevice(&device)); - CHECK_CUDA(cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, device)); - CHECK_CUDA(cudaDeviceGetAttribute(&minor, cudaDevAttrComputeCapabilityMinor, device)); - const bool is_sm75_gpu = (major == 7) && (minor == 5); - if (is_sm75_gpu && std::is_same::value) - TORCH_CHECK(false, "Bfloat16 inputs are not supported for SM75"); - if ((major < 7) || (major == 7 && minor < 5)) - TORCH_CHECK(false, "FP6LLM_API Error: FP6LLM requires GPU with SM75 or higher!\n"); - if (is_sm75_gpu && (N_PowerOf2 == 64 || N_PowerOf2 == 128 || N_PowerOf2 % 128 == 0)) { // For SM75 and N >= 64, we use a different TilingConfig to deal with smaller shared memory. if (Split_K == 1) { @@ -136,7 +143,7 @@ void fpx_linear_kernel(cudaStream_t stream, case 64: Kernel_Ex, InputDataType, InputDataType, EXPONENT, MANTISSA>(stream, Weight, Scales, B, C, M_Global, N_Global, K_Global, Split_K); break; case 128: Kernel_Ex, InputDataType, InputDataType, EXPONENT, MANTISSA>(stream, Weight, Scales, B, C, M_Global, N_Global, K_Global, Split_K); break; default: if (N_PowerOf2 % 128 != 0) { - TORCH_CHECK(false, "FP6LLM_API Error: Unsupported N dimension ", N_PowerOf2); + TORCH_CHECK(false, "Quant-LLM Error: Unsupported N dimension ", N_PowerOf2); } Kernel_Ex, InputDataType, InputDataType, EXPONENT, MANTISSA>(stream, Weight, Scales, B, C, M_Global, N_Global, K_Global, Split_K); break; } @@ -149,7 +156,7 @@ void fpx_linear_kernel(cudaStream_t stream, case 64: Kernel_Ex, InputDataType, float, EXPONENT, MANTISSA>(stream, Weight, Scales, B, Reduction_Workspace, M_Global, N_Global, K_Global, Split_K); break; case 128: Kernel_Ex, InputDataType, float, EXPONENT, MANTISSA>(stream, Weight, Scales, B, Reduction_Workspace, M_Global, N_Global, K_Global, Split_K); break; default: if (N_PowerOf2 % 128 != 0) { - TORCH_CHECK(false, "FP6LLM_API Error: Unsupported N dimension ", N_PowerOf2); + TORCH_CHECK(false, "Quant-LLM Error: Unsupported N dimension ", N_PowerOf2); } Kernel_Ex, InputDataType, float, EXPONENT, MANTISSA>(stream, Weight, Scales, B, Reduction_Workspace, M_Global, N_Global, K_Global, Split_K); break; } @@ -210,6 +217,23 @@ torch::Tensor fp_eXmY_linear_forward_cuda( torch::Tensor _scales, int64_t splitK=1) { + // Check GPU Compute Capability before proceeding + int device, major, minor; + CHECK_CUDA(cudaGetDevice(&device)); + CHECK_CUDA(cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, device)); + CHECK_CUDA(cudaDeviceGetAttribute(&minor, cudaDevAttrComputeCapabilityMinor, device)); + + // Early exit with error for unsupported architectures + if ((major < 7) || (major == 7 && minor < 5)) { + TORCH_CHECK(false, "Quant-LLM Error: This kernel requires GPU with SM75 (Turing) or higher architecture. " + "Your current device has SM", major, minor, " which is not supported."); + } + + const bool is_sm75_gpu = (major == 7) && (minor == 5); + if (is_sm75_gpu && _in_feats.scalar_type() == at::ScalarType::BFloat16) { + TORCH_CHECK(false, "Quant-LLM Error: BFloat16 inputs are not supported on SM75 (Turing) GPUs."); + } + const int64_t NBITS = 1 + EXPONENT + MANTISSA; int num_in_feats = _in_feats.size(0); int num_in_channels = _in_feats.size(1); diff --git a/torchao/csrc/cuda/fp6_llm/kernel_matmul.cuh b/torchao/csrc/cuda/fp6_llm/kernel_matmul.cuh index d4be92b227..096bdc0d7f 100644 --- a/torchao/csrc/cuda/fp6_llm/kernel_matmul.cuh +++ b/torchao/csrc/cuda/fp6_llm/kernel_matmul.cuh @@ -51,17 +51,14 @@ * B: col major, FP16 * C: col major, FP16 */ - template +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 750 +template __global__ void QUANT_GEMM_Kernel(const uint4* Weight, const half* Scales, const half *B, OutputDataType* C, const size_t M_Global, const size_t N_Global, const size_t K_Global, int Split_K) { - #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 750 - static_assert(false, "Quant-LLM kernel: At least Turing generation (sm75) is required."); - // __trap(); // fails at runtime instead of compile time - #endif #ifdef DEBUG_MODE assert(K_Global%TilingConfig::TILE_K==0); assert(M_Global%TilingConfig::TILE_M==0); @@ -233,3 +230,15 @@ __global__ void QUANT_GEMM_Kernel(const uint4* Weight, const half* Scales, } } } +#else +// Stub implementation for older architectures +template +__global__ void QUANT_GEMM_Kernel(const uint4* Weight, const half* Scales, + const half *B, + OutputDataType* C, + const size_t M_Global, const size_t N_Global, const size_t K_Global, + int Split_K) +{ +// NOOP, should never actually be called +} +#endif diff --git a/torchao/ops.py b/torchao/ops.py index 34a97d03f5..5bc71321ac 100644 --- a/torchao/ops.py +++ b/torchao/ops.py @@ -71,6 +71,13 @@ def decorator(func): return decorator +@functools.lru_cache +def cached_compute_capability(): + device_props = torch.cuda.get_device_properties(torch.cuda.current_device()) + compute_capability = device_props.major * 10 + device_props.minor + return compute_capability + + def quant_llm_linear( EXPONENT: int, MANTISSA: int, @@ -93,6 +100,12 @@ def quant_llm_linear( Returns output of linear layer """ + # Check if we're on a supported architecture (sm7.5 or higher) + compute_capability = cached_compute_capability() + torch._check( + compute_capability >= 75, + lambda: f"quant_llm_linear requires sm7.5+ GPU architecture, but current device has sm{compute_capability}", + ) return torch.ops.torchao.quant_llm_linear.default( EXPONENT, MANTISSA, _in_feats, _weights, _scales, splitK ) From 9a9ecde937b989e46d081eff427785965b103f7b Mon Sep 17 00:00:00 2001 From: Jerry Zhang Date: Wed, 2 Apr 2025 11:27:34 -0700 Subject: [PATCH 05/19] Skip galore test if not cuda (#2003) Summary: fixing CI before branch cut Test Plan: python test/quantization/test_galore_quant.py and CI Reviewers: Subscribers: Tasks: Tags: --- test/quantization/test_galore_quant.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/test/quantization/test_galore_quant.py b/test/quantization/test_galore_quant.py index a67f7775b1..0ebc356114 100644 --- a/test/quantization/test_galore_quant.py +++ b/test/quantization/test_galore_quant.py @@ -38,6 +38,7 @@ @pytest.mark.skip("skipping for now, see comments below") +@pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available") @pytest.mark.parametrize( "dim1,dim2,dtype,signed,blocksize", TEST_CONFIGS, @@ -89,6 +90,7 @@ def test_galore_quantize_blockwise(dim1, dim2, dtype, signed, blocksize): TEST_CONFIGS, ) @skip_if_rocm("ROCm enablement in progress") +@pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available") def test_galore_dequant_blockwise(dim1, dim2, dtype, signed, blocksize): g = torch.randn(dim1, dim2, device="cuda", dtype=dtype) * 0.01 From e2369d34b1d21959837045c0d95c47638ee173a5 Mon Sep 17 00:00:00 2001 From: Scott Roy <161522778+metascroy@users.noreply.github.com> Date: Wed, 2 Apr 2025 11:28:07 -0700 Subject: [PATCH 06/19] Fix experimental CI (#2005) * up * up --- .github/workflows/torchao_experimental_test.yml | 11 +++-------- dev-requirements.txt | 3 +++ 2 files changed, 6 insertions(+), 8 deletions(-) diff --git a/.github/workflows/torchao_experimental_test.yml b/.github/workflows/torchao_experimental_test.yml index 8d274b62e7..0cb470901e 100644 --- a/.github/workflows/torchao_experimental_test.yml +++ b/.github/workflows/torchao_experimental_test.yml @@ -37,9 +37,7 @@ jobs: # of torch and torchao, which we do not want to use pip install executorch pip install torch==2.7.0.dev20250311 --index-url "https://download.pytorch.org/whl/nightly/cpu" --force-reinstall - pip install numpy - pip install pytest - pip install parameterized + pip install -r dev-requirements.txt USE_CPP=1 TOCHAO_BUILD_KLEIDIAI=1 pip install . - name: Run python tests run: | @@ -99,11 +97,8 @@ jobs: python -c "import torch; print(torch.__version__)" - name: Install requirements run: | - pip install cmake - pip install parameterized - pip install pyyaml - pip install numpy - pip install importlib-metadata + pip install -r dev-requirements.txt + pip install pyyaml importlib-metadata - name: Print pip freeze run: | pip freeze diff --git a/dev-requirements.txt b/dev-requirements.txt index f5b1599ffa..1982d76795 100644 --- a/dev-requirements.txt +++ b/dev-requirements.txt @@ -26,6 +26,9 @@ importlib_metadata # Custom CUDA Extensions ninja +# CPU kernels +cmake<4.0.0,>=3.19.0 + # Linting ruff==0.6.8 pre-commit From b49f23c673701c7319ec72df430b8c038a549a61 Mon Sep 17 00:00:00 2001 From: Kimish Patel Date: Wed, 2 Apr 2025 12:15:41 -0700 Subject: [PATCH 07/19] Add fp32xint8 matmul Differential Revision: D71370597 Pull Request resolved: https://github.com/pytorch/ao/pull/2004 --- ...input_channelwise_8bit_b_1x16x4_f32_impl.h | 275 ++++++++++++++++++ .../kernels/cpu/aarch64/matmul/matmul.h | 21 ++ .../cpu/aarch64/tests/test_qmatmul.cpp | 185 ++++++++++++ .../kernels/cpu/aarch64/tests/test_utils.h | 17 +- 4 files changed, 489 insertions(+), 9 deletions(-) create mode 100644 torchao/experimental/kernels/cpu/aarch64/matmul/fp32_a_input_channelwise_8bit_b_1x16x4_f32_impl.h diff --git a/torchao/experimental/kernels/cpu/aarch64/matmul/fp32_a_input_channelwise_8bit_b_1x16x4_f32_impl.h b/torchao/experimental/kernels/cpu/aarch64/matmul/fp32_a_input_channelwise_8bit_b_1x16x4_f32_impl.h new file mode 100644 index 0000000000..389abb32a5 --- /dev/null +++ b/torchao/experimental/kernels/cpu/aarch64/matmul/fp32_a_input_channelwise_8bit_b_1x16x4_f32_impl.h @@ -0,0 +1,275 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the license found in the +// LICENSE file in the root directory of this source tree. + +#pragma once + +#if defined(__aarch64__) || defined(__ARM_NEON) + +#include +#include +#include + +#include +#include +#include + +namespace torchao::kernels::cpu::aarch64::quantized_matmul { +namespace fp32_a_input_channelwise_8bit_b_1x16x4_f32::internal { + +namespace { + +/* +This function loads float32x4_t value from a, and 16 int8x16_t values from b. +For each int8x16_t of b: +- 4 float32x4 accumulated values +- load 4 a in float32x4_t +- [The following repeats for each of the 4 lanes of a] +- for i in [0, 4]: + - load b[i] in int8x16_t + - subl to subtract b_zero_point from b, to get b_low, b_high + - vmovl to get b_low_low, b_low_high, b_high_low, b_high_high + - vcvtq to convert to float32x4_t, we will have 4 of these. +- for i in [0, 4]: for each of the 4 float32x4_t of b: + - vfmaq_lane_fp32 to multiply a[lane] and b[i] + - vfmaq_lane_fp32 to multiply a[lane] and b[i] + - vfmaq_lane_fp32 to multiply a[lane] and b[i] + - vfmaq_lane_fp32 to multiply a[lane] and b[i] +- By doing the above 4 times (lane=[0-3]), we used all values along k dim of a + and accumulated 4 float32x4_t values +*/ +TORCHAO_ALWAYS_INLINE void block_mul_1x16x1( + const float32_t a, + const int8x16_t& b_vec, + const int8_t b_zero_point, + const float b_scale, + float32x4_t (&partial_sums)[4]) { + int8x8_t b_zero_point_vec = vdup_n_s8(b_zero_point); + int16x8_t b_vec_low = vsubl_s8(vget_low_s8(b_vec), b_zero_point_vec); + int16x8_t b_vec_high = vsubl_s8(vget_high_s8(b_vec), b_zero_point_vec); + float32x4_t b_vec_low_low = vcvtq_f32_s32(vmovl_s16(vget_low_s16(b_vec_low))); + float32x4_t b_vec_low_high = + vcvtq_f32_s32(vmovl_s16(vget_high_s16(b_vec_low))); + float32x4_t b_vec_high_low = + vcvtq_f32_s32(vmovl_s16(vget_low_s16(b_vec_high))); + float32x4_t b_vec_high_high = + vcvtq_f32_s32(vmovl_s16(vget_high_s16(b_vec_high))); + b_vec_low_low = vmulq_n_f32(b_vec_low_low, b_scale); + b_vec_low_high = vmulq_n_f32(b_vec_low_high, b_scale); + b_vec_high_low = vmulq_n_f32(b_vec_high_low, b_scale); + b_vec_high_high = vmulq_n_f32(b_vec_high_high, b_scale); + + partial_sums[0] = vfmaq_n_f32(partial_sums[0], b_vec_low_low, a); + partial_sums[1] = vfmaq_n_f32(partial_sums[1], b_vec_low_high, a); + partial_sums[2] = vfmaq_n_f32(partial_sums[2], b_vec_high_low, a); + partial_sums[3] = vfmaq_n_f32(partial_sums[3], b_vec_high_high, a); +} + +void block_mul_1x16x4( + const float32_t* a, + const int8_t* b, + const size_t ldb, + const int8_t* b_zero_point, + const float* b_scale, + float32x4_t (&partial_sums)[4]) { + #pragma unroll(8) + for (int i = 0; i < 4; i++) { + int8x16_t b_vec = vld1q_s8(b + i * ldb); + block_mul_1x16x1(a[i], b_vec, b_zero_point[i], b_scale[i], partial_sums); + } +} + +} // namespace + +template +struct KernelImpl { + static void run( + int m, + int n, + int k, + const void* lhs, + int lhs_stride_m, + const void* rhs, + int rhs_stride_n, + float32_t* output, + int out_stride_m, + const int8_t* rhs_zero_points, + const float* rhs_scales, + const float beta, + const int rhs_qparams_stride); +}; + +template <> +struct KernelImpl { + static void run( + int m, + int n, + int k, + const float* lhs, + int lhs_stride_m, + const int8_t* rhs, + int rhs_stride_n, + float32_t* output, + int out_stride_m, + const int8_t* rhs_zero_points, + const float* rhs_scales, + const float beta, + const int rhs_qparams_stride) { + std::unique_ptr rhs_zero_points_transposed = std::make_unique(k); + std::unique_ptr rhs_scales_transposed = std::make_unique(k); + if (rhs_qparams_stride > 1) { + utils::transpose_scales_and_zero_points( + rhs_zero_points, + rhs_scales, + rhs_zero_points_transposed.get(), + rhs_scales_transposed.get(), + k, + rhs_qparams_stride); + rhs_zero_points = rhs_zero_points_transposed.get(); + rhs_scales = rhs_scales_transposed.get(); + } + + constexpr int nr = 16; + constexpr int kr = 4; + for (int m_idx = 0; m_idx < m; m_idx++) { + // Loop over 16 cols at a time + // Access to partial tiles must be protected:w + assert(n >= nr); + for (int n_idx = 0; n_idx < n; n_idx += nr) { + // If remaining is < nr, that must mean that (nr - remaining) items + // dont need to be computed. + // In order to avoid out-of-bounds access, we need to rewind n_indx a + // bit + // |-------------------|-------------------| + // 0-------------------8-------------------16 + // 0-------------------8-----10 + // If n = 10 and nr = 8 then at n_idx = 8, we need to rewind n_idx to + // 8 - (8 - 10) = 2 + int remaining = std::min(n - n_idx, nr); + n_idx = n_idx - (nr - remaining); + // Set activation_ptr to start of activation qvals for row m_idx + const float* lhs_ptr = lhs + m_idx * lhs_stride_m; + const int8_t* rhs_ptr = rhs + n_idx; + float32x4_t sums[nr / 4] = {vdupq_n_f32(0)}; + + // Loop k_idx by group + int k_idx = 0; + for (; (k_idx + kr) <= k; k_idx += kr) { + block_mul_1x16x4( + lhs_ptr, + rhs_ptr, + rhs_stride_n, + rhs_zero_points + k_idx, + rhs_scales + k_idx, + sums); + lhs_ptr += kr; + rhs_ptr += kr * rhs_stride_n; + } + + for (int ki = 0; ki < (k - k_idx); ++ki) { + // For each of the remaining k values + // Load 1 int8_t from lhs + // Load 16 int8_t from rhs + // And multiply + add into the 16 accumulators + // arranged as int32x4_t[4] + int8x16_t rhs_vec = vld1q_s8(rhs_ptr + ki * rhs_stride_n); + block_mul_1x16x1( + lhs_ptr[ki], + rhs_vec, + rhs_zero_points[k_idx + ki], + rhs_scales[k_idx + ki], + sums); + } + + // Store result + // Because we adjust n_idx, we may end up writing the same location + // twice + // Note that the reason this case is being handled only for this kernel + // and not others in this directory is because only for this kernel + // we support accumulation. + float* store_loc = output + m_idx * out_stride_m + n_idx; + if (remaining < 16) { + // If remaining is < 16, then not all of the 16 accumulators are + // valid. That is not all of float32x4_t[4] are valid. We need to + // find the first valid one, and then store the rest of the + // accumulators in the same order. + // First valid one is at 3 - ((remaining - 1) / 4) because: + // If remaining is say 10 then first 6 are not valid. + // Thus first group of 4 at sums[0] is not valid. + // In the second group of 4, the first 2 are not valid. + // Rest are valid. + int start_sum_idx = 3 - ((remaining - 1) / 4); + // If remaining is 11, then the sums[1] has 3 valid values + // so 3 - (11 -1) % 4 = 3 - 10 % 4 = 3 - 2 = 1 + // Thus there is 1 invalid value in the first group of 4 + int invalid_values_in_32x4_reg = 3 - (remaining - 1) % 4; + store_loc += start_sum_idx * 4; + store_loc += invalid_values_in_32x4_reg; + if (invalid_values_in_32x4_reg > 0) { + for (int val_idx = invalid_values_in_32x4_reg; val_idx < 4; + ++val_idx) { + *store_loc = sums[start_sum_idx][val_idx] + (*store_loc) * beta; + store_loc += 1; + } + start_sum_idx++; + } + for (int out_idx = 0, sum_idx = start_sum_idx; sum_idx < nr / 4; + out_idx += 4, ++sum_idx) { + float32x4_t sum_val = vld1q_f32(store_loc + out_idx); + sums[sum_idx] = vfmaq_n_f32(sums[sum_idx], sum_val, beta); + vst1q_f32(store_loc + out_idx, sums[sum_idx]); + } + } else { + for (int out_idx = 0, sum_idx = 0; out_idx < nr; + out_idx += 4, ++sum_idx) { + float32x4_t sum_val = vld1q_f32(store_loc + out_idx); + sums[sum_idx] = vfmaq_n_f32(sums[sum_idx], sum_val, beta); + vst1q_f32(store_loc + out_idx, sums[sum_idx]); + } + } + } // n_idx + } // m_idx + } +}; + +} // namespace fp32_a_input_channelwise_8bit_b_1x16x4_f32::internal + +namespace fp32_a_input_channelwise_8bit_b_1x16x4_f32 { +template +void kernel( + int m, + int n, + int k, + const float* lhs, + int lhs_stride_m, + const int8_t* rhs, + int rhs_stride_n, + float32_t* output, + int out_stride_m, + const int8_t* rhs_zero_points, + const float* rhs_scales, + const float beta, + const int rhs_qparams_stride) { + torchao::kernels::cpu::aarch64::quantized_matmul:: + fp32_a_input_channelwise_8bit_b_1x16x4_f32::internal:: + KernelImpl::run( + m, + n, + k, + lhs, + lhs_stride_m, + rhs, + rhs_stride_n, + output, + out_stride_m, + rhs_zero_points, + rhs_scales, + beta, + rhs_qparams_stride); +} +} // namespace fp32_a_input_channelwise_8bit_b_1x16x4_f32 +} // namespace torchao::kernels::cpu::aarch64::quantized_matmul + +#endif // defined(__aarch64__) || defined(__ARM_NEON) diff --git a/torchao/experimental/kernels/cpu/aarch64/matmul/matmul.h b/torchao/experimental/kernels/cpu/aarch64/matmul/matmul.h index 4005dee564..43f3dd4bce 100644 --- a/torchao/experimental/kernels/cpu/aarch64/matmul/matmul.h +++ b/torchao/experimental/kernels/cpu/aarch64/matmul/matmul.h @@ -66,9 +66,30 @@ void kernel( const int rhs_qparams_stride); } // namespace channelwise_8bit_a_channelwise_8bit_b_1x16x16_f32_smlal + +namespace fp32_a_input_channelwise_8bit_b_1x16x4_f32 { + +template +void kernel( + int m, + int n, + int k, + const float* lhs, + int lhs_stride_m, + const int8_t* rhs, + int rhs_stride_n, + float32_t* output, + int out_stride_m, + const int8_t* rhs_zero_points, + const float* rhs_scales, + const float beta, + const int rhs_qparams_stride); + +} // namespace fp32_a_input_channelwise_8bit_b_1x16x4_f32 } // namespace torchao::kernels::cpu::aarch64::quantized_matmul #include #include +#include #endif // defined(__aarch64__) || defined(__ARM_NEON) diff --git a/torchao/experimental/kernels/cpu/aarch64/tests/test_qmatmul.cpp b/torchao/experimental/kernels/cpu/aarch64/tests/test_qmatmul.cpp index 1b3e11156f..e7e2d09c64 100644 --- a/torchao/experimental/kernels/cpu/aarch64/tests/test_qmatmul.cpp +++ b/torchao/experimental/kernels/cpu/aarch64/tests/test_qmatmul.cpp @@ -226,4 +226,189 @@ TEST( /*m=*/4, /*k=*/37, /*n=*/19); } +class FP32A_QuantizedB_FP32C_Test : public ::testing::TestWithParam { + public: + int m; + int k; + int n; + int stride; + + bool rhs_has_zeros; + bool lhs_is_transposed; + bool rhs_is_transposed; + + std::vector init_output; + std::vector expected_output; + + std::vector lhs; + + std::vector rhs; + std::vector rhs_qvals; + std::vector rhs_scales; + std::vector rhs_zeros; + + void generate( + int m_, + int k_, + int n_, + bool rhs_has_zeros_, + bool lhs_is_transposed_, + bool rhs_is_transposed_, + int stride_ = 1) { + // Here stride is only applicable to rhs + // and it means that k elements are stride * napart for k x n matrix + // and stride apart for n x k matrix + assert(!lhs_is_transposed_); + assert(rhs_has_zeros_); + m = m_; + k = k_; + n = n_; + stride = stride_; + rhs_has_zeros = rhs_has_zeros_; + lhs_is_transposed = lhs_is_transposed_; + rhs_is_transposed = rhs_is_transposed_; + + assert(!rhs_is_transposed || stride == 1); + + // Generate activations + lhs = torchao::get_random_vector(m * k, -1.0, 1.0); + + // The strange thing this is doing is that instead of quantizing + // each output channel separately, we are quantizing each input channel + // Reason why we do !rhs_is_transposed is because + // we actually want k x n matrix not n x k matrix + // because each input channel is quantized separately + std::tie(rhs, rhs_qvals, rhs_scales, rhs_zeros) = + torchao::test_utils::generate_per_token_quantized_tensor( + k * stride, n, rhs_is_transposed); + + // Compute expected output + init_output = torchao::get_random_vector(m * n, -1.0, 1.0); + + assert(init_output.size() == m * n); + assert(lhs.size() == m * k); + assert(rhs.size() == n * stride * k); + assert(rhs_qvals.size() == n * stride * k); + assert(rhs_scales.size() == k * stride); + assert(rhs_zeros.size() == k * stride); + } + + void execute(float beta) { + // Compute expected output + expected_output = init_output; + + for (int m_idx = 0; m_idx < m; m_idx++) { + for (int n_idx = 0; n_idx < n; n_idx++) { + float res = 0.0; + for (int k_idx = 0; k_idx < k; k_idx++) { + int lhs_idx = m_idx * k + k_idx; + int rhs_idx = k_idx * stride * n + n_idx; + if (rhs_is_transposed) { + rhs_idx = n_idx * k * stride + k_idx * stride; + } + float rhs_dequant = rhs_scales[k_idx * stride] * + (static_cast(rhs_qvals[rhs_idx]) - + static_cast(rhs_zeros[k_idx * stride])); + + res += lhs[lhs_idx] * rhs_dequant; + } + expected_output[m_idx * n + n_idx] = + expected_output[m_idx * n + n_idx] * beta + res; + } + } + } + + float beta() const { + return GetParam(); + } +}; + +static void test_fp32_a_input_channelwise_8bit_b( + int m, + int k, + int n, + float beta, + FP32A_QuantizedB_FP32C_Test& test_case, + int stride = 1) { + test_case.execute(beta); + + using namespace torchao::kernels::cpu::aarch64::quantized_matmul:: + fp32_a_input_channelwise_8bit_b_1x16x4_f32; + + std::vector output(test_case.init_output); + kernel( + m, + n, + k, + test_case.lhs.data(), + k /*lhs_stride_m*/, + test_case.rhs_qvals.data(), + n * stride /*rhs_stride_n*/, + output.data(), + n /*out_stride_n*/, + test_case.rhs_zeros.data(), + test_case.rhs_scales.data(), + beta, + stride /*rhs qparams stride*/); + + for (int i = 0; i < m * n; i++) { + EXPECT_NEAR(output[i], test_case.expected_output[i], kTol); + } +} + +TEST_P(FP32A_QuantizedB_FP32C_Test, BTranposedWithZeroPoints) { + generate(1, 128, 16, true, false, false); + test_fp32_a_input_channelwise_8bit_b( + /*m=*/1, /*k=*/128, /*n=*/16, beta(), *this); +} + +TEST_P(FP32A_QuantizedB_FP32C_Test, BTranposedWithZeroPointsLargeM) { + generate(4, 128, 16, true, false, false); + test_fp32_a_input_channelwise_8bit_b( + /*m=*/4, /*k=*/128, /*n=*/16, beta(), *this); +} + +TEST_P(FP32A_QuantizedB_FP32C_Test, BTranposedWithZeroPointsOddSizes) { + generate(4, 37, 24, true, false, false); + test_fp32_a_input_channelwise_8bit_b( + /*m=*/4, /*k=*/37, /*n=*/24, beta(), *this); +} + +TEST_P(FP32A_QuantizedB_FP32C_Test, BTranposedWithZeroPointsOddSizes2) { + generate(4, 37, 19, true, false, false); + test_fp32_a_input_channelwise_8bit_b( + /*m=*/4, /*k=*/37, /*n=*/19, beta(), *this); +} + +TEST_P(FP32A_QuantizedB_FP32C_Test, BTranposedWithZeroPointsOddSizes3) { + generate(4, 27, 21, true, false, false); + test_fp32_a_input_channelwise_8bit_b( + /*m=*/4, /*k=*/27, /*n=*/21, beta(), *this); +} + +TEST_P(FP32A_QuantizedB_FP32C_Test, BTranposedWithZeroPointsAlpha) { + generate(1, 128, 16, true, false, false); + test_fp32_a_input_channelwise_8bit_b( + /*m=*/1, /*k=*/128, /*n=*/16, beta(), *this); +} + +TEST_P(FP32A_QuantizedB_FP32C_Test, BTranposedWithZeroPointsWithStrides) { + stride = 5; + generate(1, 128, 16, true, false, false, stride); + test_fp32_a_input_channelwise_8bit_b( + /*m=*/1, /*k=*/128, /*n=*/16, beta(), *this, stride); +} + +TEST_P(FP32A_QuantizedB_FP32C_Test, BTranposedWithZeroPointsOddSizes2Strides) { + stride = 11; + generate(7, 37, 19, true, false, false, stride); + test_fp32_a_input_channelwise_8bit_b( + /*m=*/7, /*k=*/37, /*n=*/19, beta(), *this, stride); +} + +INSTANTIATE_TEST_SUITE_P( + F32AInt8BFP32CTest, + FP32A_QuantizedB_FP32C_Test, + ::testing::Values(0.0, 1.0, 2.69)); + #endif // defined(__aarch64__) || defined(__ARM_NEON) diff --git a/torchao/experimental/kernels/cpu/aarch64/tests/test_utils.h b/torchao/experimental/kernels/cpu/aarch64/tests/test_utils.h index 80ddcb690d..e411211eb4 100644 --- a/torchao/experimental/kernels/cpu/aarch64/tests/test_utils.h +++ b/torchao/experimental/kernels/cpu/aarch64/tests/test_utils.h @@ -84,11 +84,9 @@ inline float get_float_from_bf16(uint16_t bf16) { return f; } -namespace { -auto generate_per_token_quantized_tensor( - int m, - int n, - bool transposed = false) { +namespace test_utils { +auto generate_per_token_quantized_tensor(int m, int n, bool transposed = false); +auto generate_per_token_quantized_tensor(int m, int n, bool transposed) { auto activations = get_random_vector(m * n, -1.0, 1.0); auto activation_qvals = std::vector(m * n, 0); auto activation_scales = std::vector(m, 0); @@ -135,7 +133,7 @@ auto generate_per_token_quantized_tensor( return std::make_tuple( activations, activation_qvals, activation_scales, activation_zeros); } -} // namespace +} // namespace test_utils struct channelwise_8bit_activation_groupwise_lowbit_weight_test_case { int m; @@ -236,7 +234,7 @@ struct channelwise_8bit_activation_groupwise_lowbit_weight_test_case { // Generate activations auto [activations, activation_qvals, activation_scales, activation_zeros] = - generate_per_token_quantized_tensor(m, k); + test_utils::generate_per_token_quantized_tensor(m, k); // Generate weights assert(k % weight_group_size == 0); @@ -441,10 +439,11 @@ struct channelwise_8bit_a_channelwise_8bit_b_qmatmul_test_case { assert(rhs_is_transposed || stride == 1); // Generate activations auto [lhs, lhs_qvals, lhs_scales, lhs_zeros] = - generate_per_token_quantized_tensor(m * stride, k); + test_utils::generate_per_token_quantized_tensor(m * stride, k); auto [rhs, rhs_qvals, rhs_scales, rhs_zeros] = - generate_per_token_quantized_tensor(n * stride, k, !rhs_is_transposed); + test_utils::generate_per_token_quantized_tensor( + n * stride, k, !rhs_is_transposed); // Above function produces nxk matrix and to produce kxn you need transposed // = true. we do !rhs_is_transposed becaues when rhs_is_transposed = true // the shape should be nxk instead of kxn. From 8e8472cb31d8431c91bdf235f1f3457df3094dc0 Mon Sep 17 00:00:00 2001 From: Kimish Patel Date: Wed, 2 Apr 2025 13:56:54 -0700 Subject: [PATCH 08/19] Add quantized q @ k test for intented used in quantized attention Differential Revision: D71370604 Pull Request resolved: https://github.com/pytorch/ao/pull/2006 --- .../cpu/aarch64/tests/test_qmatmul.cpp | 98 ++++++++ .../kernels/cpu/aarch64/tests/test_utils.h | 1 + .../tests/test_utils_quantized_attention.h | 235 ++++++++++++++++++ 3 files changed, 334 insertions(+) create mode 100644 torchao/experimental/kernels/cpu/aarch64/tests/test_utils_quantized_attention.h diff --git a/torchao/experimental/kernels/cpu/aarch64/tests/test_qmatmul.cpp b/torchao/experimental/kernels/cpu/aarch64/tests/test_qmatmul.cpp index e7e2d09c64..344b2c4915 100644 --- a/torchao/experimental/kernels/cpu/aarch64/tests/test_qmatmul.cpp +++ b/torchao/experimental/kernels/cpu/aarch64/tests/test_qmatmul.cpp @@ -12,6 +12,7 @@ #include #include #include +#include float kTol = 0.0001; @@ -411,4 +412,101 @@ INSTANTIATE_TEST_SUITE_P( FP32A_QuantizedB_FP32C_Test, ::testing::Values(0.0, 1.0, 2.69)); +static void test_8bit_per_token_q_at_k_matmul_attention( + int b, + int s_q, + int s_k, + int h, + int d, + bool transpose = true) { + auto test_case = torchao:: + channelwise_8bit_a_channelwise_8bit_b_q_at_k_attention_test_case:: + generate(b, s_q, s_k, h, d, transpose); + + using namespace torchao::kernels::cpu::aarch64::quantized_matmul:: + channelwise_8bit_a_channelwise_8bit_b_1x8x16_f32_neondot; + + size_t q_b_stride = test_case.b_q_stride; + size_t q_h_stride = test_case.h_q_stride; + size_t q_s_q_stride = test_case.s_q_stride; + size_t q_scale_zp_b_stride = test_case.b_q_qparams_stride; + size_t q_scale_zp_h_stride = test_case.h_q_qparams_stride; + size_t q_scale_zp_s_stride = test_case.s_q_qparams_stride; + + size_t k_b_stride = test_case.b_k_stride; + size_t k_h_stride = test_case.h_k_stride; + size_t k_s_k_stride = test_case.s_k_stride; + size_t k_scale_zp_b_stride = test_case.b_k_qparams_stride; + size_t k_scale_zp_h_stride = test_case.h_k_qparams_stride; + size_t k_scale_zp_s_stride = test_case.s_k_qparams_stride; + + std::vector output(b * h * s_q * s_k); + size_t output_b_stride = h * s_q * s_k; + size_t output_h_stride = s_q * s_k; + size_t output_s_q_stride = s_k; + + for (int b_idx = 0; b_idx < b; b_idx++) { + for (int h_idx = 0; h_idx < h; h_idx++) { + kernel( + s_q, + s_k, + d, + test_case.q_qvals.data() + b_idx * q_b_stride + h_idx * q_h_stride, + q_s_q_stride /*lhs_stride_m*/, + test_case.k_qvals.data() + b_idx * k_b_stride + h_idx * k_h_stride, + k_s_k_stride /*rhs_stride_n*/, + output.data() + b_idx * output_b_stride + h_idx * output_h_stride, + output_s_q_stride /*out_stride_n*/, + test_case.q_zeros.data() + b_idx * q_scale_zp_b_stride + + h_idx * q_scale_zp_h_stride, + test_case.k_zeros.data() + b_idx * k_scale_zp_b_stride + + h_idx * k_scale_zp_h_stride, + test_case.q_scales.data() + b_idx * q_scale_zp_b_stride + + h_idx * q_scale_zp_h_stride, + test_case.k_scales.data() + b_idx * k_scale_zp_b_stride + + h_idx * k_scale_zp_h_stride, + q_scale_zp_s_stride /*lhs qparams stride*/, + k_scale_zp_s_stride /*rhs qparams stride*/); + } + } + + for (int i = 0; i < b * h * s_q * s_k; i++) { + EXPECT_NEAR(output[i], test_case.expected_output[i], kTol); + } +} + +TEST(test_8bit_per_token_q_at_k_matmul_attention, Basic) { + test_8bit_per_token_q_at_k_matmul_attention(1, 16, 16, 8, 16); +} + +TEST(test_8bit_per_token_q_at_k_matmul_attention, PrimeHeadsAndHeadDim) { + test_8bit_per_token_q_at_k_matmul_attention(1, 8, 8, 7, 33); +} + +TEST( + test_8bit_per_token_q_at_k_matmul_attention, + PrimeHeadsAndHeadDimDiffSqSk) { + test_8bit_per_token_q_at_k_matmul_attention(1, 7, 16, 7, 33); +} + +TEST(test_8bit_per_token_q_at_k_matmul_attention, PrimeHeadsAndSmallHeadDim) { + test_8bit_per_token_q_at_k_matmul_attention(1, 8, 8, 7, 3); +} + +TEST(test_8bit_per_token_q_at_k_matmul_attention, BasicNoTransposed) { + test_8bit_per_token_q_at_k_matmul_attention(1, 16, 16, 8, 16, false); +} + +TEST( + test_8bit_per_token_q_at_k_matmul_attention, + PrimeHeadsAndHeadDimDiffSqSkNoTranspose) { + test_8bit_per_token_q_at_k_matmul_attention(1, 7, 16, 7, 33, false); +} + +TEST( + test_8bit_per_token_q_at_k_matmul_attention, + PrimeHeadsAndSmallHeadDimNoTranspose) { + test_8bit_per_token_q_at_k_matmul_attention(1, 8, 8, 7, 3, false); +} + #endif // defined(__aarch64__) || defined(__ARM_NEON) diff --git a/torchao/experimental/kernels/cpu/aarch64/tests/test_utils.h b/torchao/experimental/kernels/cpu/aarch64/tests/test_utils.h index e411211eb4..4f96f8bf96 100644 --- a/torchao/experimental/kernels/cpu/aarch64/tests/test_utils.h +++ b/torchao/experimental/kernels/cpu/aarch64/tests/test_utils.h @@ -86,6 +86,7 @@ inline float get_float_from_bf16(uint16_t bf16) { namespace test_utils { auto generate_per_token_quantized_tensor(int m, int n, bool transposed = false); + auto generate_per_token_quantized_tensor(int m, int n, bool transposed) { auto activations = get_random_vector(m * n, -1.0, 1.0); auto activation_qvals = std::vector(m * n, 0); diff --git a/torchao/experimental/kernels/cpu/aarch64/tests/test_utils_quantized_attention.h b/torchao/experimental/kernels/cpu/aarch64/tests/test_utils_quantized_attention.h new file mode 100644 index 0000000000..9ca86ece76 --- /dev/null +++ b/torchao/experimental/kernels/cpu/aarch64/tests/test_utils_quantized_attention.h @@ -0,0 +1,235 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the license found in the +// LICENSE file in the root directory of this source tree. + +#pragma once + +#if defined(__aarch64__) || defined(__ARM_NEON) + +#include +#include +#include +#include +#include +#include +#include + +namespace torchao { +struct channelwise_8bit_a_channelwise_8bit_b_q_at_k_attention_test_case { + int b; + int s_q; + int s_k; + int h; + int d; + bool tranposed; + + size_t b_q_stride; + size_t h_q_stride; + size_t s_q_stride; + + size_t b_k_stride; + size_t h_k_stride; + size_t s_k_stride; + + size_t b_q_qparams_stride; + size_t h_q_qparams_stride; + size_t s_q_qparams_stride; + + size_t b_k_qparams_stride; + size_t h_k_qparams_stride; + size_t s_k_qparams_stride; + + std::vector expected_output; + + std::vector q; + std::vector q_qvals; + std::vector q_scales; + std::vector q_zeros; + + std::vector k; + std::vector k_qvals; + std::vector k_scales; + std::vector k_zeros; + + channelwise_8bit_a_channelwise_8bit_b_q_at_k_attention_test_case( + int b_, + int s_q_, + int s_k_, + int h_, + int d_, + int transposed_, + size_t b_q_stride_, + size_t h_q_stride_, + size_t s_q_stride_, + size_t b_k_stride_, + size_t h_k_stride_, + size_t s_k_stride_, + size_t b_q_qparams_stride_, + size_t h_q_qparams_stride_, + size_t s_q_qparams_stride_, + size_t b_k_qparams_stride_, + size_t h_k_qparams_stride_, + size_t s_k_qparams_stride_, + std::vector expected_output_, + std::vector q_, + std::vector q_qvals_, + std::vector q_scales_, + std::vector q_zeros_, + std::vector k_, + std::vector k_qvals_, + std::vector k_scales_, + std::vector k_zeros_) + : b(b_), + s_q(s_q_), + s_k(s_k_), + h(h_), + d(d_), + tranposed(transposed_), + b_q_stride(b_q_stride_), + h_q_stride(h_q_stride_), + s_q_stride(s_q_stride_), + b_k_stride(b_k_stride_), + h_k_stride(h_k_stride_), + s_k_stride(s_k_stride_), + b_q_qparams_stride(b_q_qparams_stride_), + h_q_qparams_stride(h_q_qparams_stride_), + s_q_qparams_stride(s_q_qparams_stride_), + b_k_qparams_stride(b_k_qparams_stride_), + h_k_qparams_stride(h_k_qparams_stride_), + s_k_qparams_stride(s_k_qparams_stride_), + expected_output(expected_output_), + q(q_), + q_qvals(q_qvals_), + q_scales(q_scales_), + q_zeros(q_zeros_), + k(k_), + k_qvals(k_qvals_), + k_scales(k_scales_), + k_zeros(k_zeros_) { + assert(expected_output.size() == b * s_q * h * s_k); + assert(q.size() == b * s_q * h * d); + assert(q_qvals.size() == b * s_q * h * d); + assert(q_scales.size() == b * s_q * h); + assert(q_zeros.size() == b * s_q * h); + assert(k.size() == b * s_k * h * d); + assert(k_qvals.size() == b * s_k * h * d); + assert(k_scales.size() == b * s_k * h); + assert(k_zeros.size() == b * s_k * h); + } + + static channelwise_8bit_a_channelwise_8bit_b_q_at_k_attention_test_case + generate(int b, int s_q, int s_k, int h, int d, bool transposed = true) { + // Generate activations + auto [lhs, lhs_qvals, lhs_scales, lhs_zeros] = + torchao::test_utils::generate_per_token_quantized_tensor( + b * s_q * h, d); + + auto [rhs, rhs_qvals, rhs_scales, rhs_zeros] = + torchao::test_utils::generate_per_token_quantized_tensor( + b * s_k * h, d); + // Above function produces nxk matrix and to produce kxn you need transposed + // = true. we do !rhs_is_transposed becaues when rhs_is_transposed = true + // the shape should be nxk instead of kxn. + + size_t b_q_stride = h * s_q * d; + size_t h_q_stride = s_q * d; + size_t s_q_stride = d; + + size_t b_k_stride = h * s_k * d; + size_t h_k_stride = s_k * d; + size_t s_k_stride = d; + + size_t b_q_qparams_stride = h * s_q; + size_t h_q_qparams_stride = s_q; + size_t s_q_qparams_stride = 1; + + size_t b_k_qparams_stride = h * s_k; + size_t h_k_qparams_stride = s_k; + size_t s_k_qparams_stride = 1; + + if (!transposed) { + h_q_stride = d; + s_q_stride = h * d; + h_k_stride = d; + s_k_stride = h * d; + + s_q_qparams_stride = h; + h_q_qparams_stride = 1; + + s_k_qparams_stride = h; + h_k_qparams_stride = 1; + } + + // Compute expected output + std::vector expected_output(b * h * s_q * s_k); + size_t b_out_stride = h * s_q * s_k; + size_t h_out_stride = s_q * s_k; + size_t s_q_out_stride = s_k; + + for (int b_idx = 0; b_idx < b; b_idx++) { + for (int s_q_idx = 0; s_q_idx < s_q; s_q_idx++) { + for (int h_idx = 0; h_idx < h; h_idx++) { + for (int s_k_idx = 0; s_k_idx < s_k; s_k_idx++) { + float res = 0.0; + for (int d_idx = 0; d_idx < d; d_idx++) { + int lhs_idx = b_idx * b_q_stride + s_q_idx * s_q_stride + + h_idx * h_q_stride + d_idx; + int rhs_idx = b_idx * b_k_stride + s_k_idx * s_k_stride + + h_idx * h_k_stride + d_idx; + int lhs_scales_zp_idx = b_idx * b_q_qparams_stride + + h_idx * h_q_qparams_stride + s_q_idx * s_q_qparams_stride; + int rhs_scales_zp_idx = b_idx * b_k_qparams_stride * h + + h_idx * h_k_qparams_stride + s_k_idx * s_k_qparams_stride; + float lhs_dequant = lhs_scales[lhs_scales_zp_idx] * + (lhs_qvals[lhs_idx] - lhs_zeros[lhs_scales_zp_idx]); + + float rhs_dequant = rhs_scales[rhs_scales_zp_idx] * + (rhs_qvals[rhs_idx] - rhs_zeros[rhs_scales_zp_idx]); + + res += lhs_dequant * rhs_dequant; + } + expected_output + [b_idx * b_out_stride + s_q_idx * s_q_out_stride + + h_idx * h_out_stride + s_k_idx] = res; + } + } + } + } + + // Return test case + return channelwise_8bit_a_channelwise_8bit_b_q_at_k_attention_test_case( + b, + s_q, + s_k, + h, + d, + transposed, + b_q_stride, + h_q_stride, + s_q_stride, + b_k_stride, + h_k_stride, + s_k_stride, + b_q_qparams_stride, + h_q_qparams_stride, + s_q_qparams_stride, + b_k_qparams_stride, + h_k_qparams_stride, + s_k_qparams_stride, + expected_output, + lhs, + lhs_qvals, + lhs_scales, + lhs_zeros, + rhs, + rhs_qvals, + rhs_scales, + rhs_zeros); + } +}; + +} // namespace torchao + +#endif // defined(__aarch64__) || defined(__ARM_NEON) From e52867a39caa90702c2840b495cc77bfcb3ba769 Mon Sep 17 00:00:00 2001 From: Jerry Zhang Date: Wed, 2 Apr 2025 15:57:03 -0700 Subject: [PATCH 09/19] Update version.txt (#2009) --- version.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/version.txt b/version.txt index 78bc1abd14..d9df1bbc0c 100644 --- a/version.txt +++ b/version.txt @@ -1 +1 @@ -0.10.0 +0.11.0 From 620356dee0c1959c2429b34edaebc39b6c199200 Mon Sep 17 00:00:00 2001 From: Daniel Vega-Myhre Date: Wed, 2 Apr 2025 16:20:53 -0700 Subject: [PATCH 10/19] Initial prototype of differentiable _scaled_grouped_mm function (#1969) --- .../prototype/scaled_grouped_mm/__init__.py | 3 + .../scaled_grouped_mm/scaled_grouped_mm.py | 361 ++++++++++++++++++ .../test_scaled_grouped_mm.py | 196 ++++++++++ 3 files changed, 560 insertions(+) create mode 100644 torchao/prototype/scaled_grouped_mm/__init__.py create mode 100644 torchao/prototype/scaled_grouped_mm/scaled_grouped_mm.py create mode 100644 torchao/prototype/scaled_grouped_mm/test_scaled_grouped_mm.py diff --git a/torchao/prototype/scaled_grouped_mm/__init__.py b/torchao/prototype/scaled_grouped_mm/__init__.py new file mode 100644 index 0000000000..9c6278884a --- /dev/null +++ b/torchao/prototype/scaled_grouped_mm/__init__.py @@ -0,0 +1,3 @@ +from torchao.prototype.scaled_grouped_mm.scaled_grouped_mm import ( + _scaled_grouped_mm as _scaled_grouped_mm, +) diff --git a/torchao/prototype/scaled_grouped_mm/scaled_grouped_mm.py b/torchao/prototype/scaled_grouped_mm/scaled_grouped_mm.py new file mode 100644 index 0000000000..a431288c07 --- /dev/null +++ b/torchao/prototype/scaled_grouped_mm/scaled_grouped_mm.py @@ -0,0 +1,361 @@ +from typing import Optional, Tuple + +import torch + +from torchao.float8.config import ScalingGranularity +from torchao.float8.float8_utils import tensor_to_scale, to_fp8_saturated + + +def _scaled_grouped_mm( + A: torch.Tensor, + B_t: torch.Tensor, + offs: torch.Tensor, + out_dtype: Optional[torch.dtype] = None, +) -> torch.Tensor: + """ + This function performs dynamic float8 quantization with row-wise scaling + on the input tensors A and B, then performs a scaled grouped GEMM and returns the results. + + Args: + A (bf16/float32 torch.Tensor): The first high-precision input tensor, which must be a 2D tensor of shape (M * num_groups, K) + and in row-major memory layout. + B_t (bf16/float32 torch.Tensor): The second high-precision input tensor which must be 3D, which must be shape (B, K, N) + and in column-major memory layout. + offs (int32 torch.Tensor): The offsets to use to mark the starting index of each group along dim0 of the A tensor. + out_dtype (Optional[torch.dtype]): The dtype of the output tensor. Currently only torch.bfloat16 is supported. + """ + return _Float8GroupedMM.apply( + A, + B_t, + offs, + out_dtype, + ) + + +class _Float8GroupedMM(torch.autograd.Function): + """Differentiable implementation of grouped GEMM with dynamic float8 quantization.""" + + @staticmethod + def forward( + ctx, + A: torch.Tensor, + B_t: torch.Tensor, + offs: torch.Tensor, + out_dtype: Optional[torch.dtype] = None, + ) -> torch.Tensor: + # torchao _scaled_grouped_mm only supports A=2D, B=3D. + assert A.ndim == 2, "A must be 2D" + assert B_t.ndim == 3, "B must be 3D" + + assert ( + A.size(-1) % 16 == 0 + ), f"A must have a last dim divisible by 16, but got shape: {A.shape}" + assert ( + B_t.size(-2) % 16 == 0 and B_t.size(-1) % 16 == 0 + ), f"B must have last 2 dims divisible by 16, but got shape: {B_t.shape}" + + # Assert input tensors are in high-precision dtypes. + assert ( + A.dtype == torch.float32 or A.dtype == torch.bfloat16 + ), "A must be float32 or bfloat16" + assert ( + B_t.dtype == torch.float32 or B_t.dtype == torch.bfloat16 + ), "B must be float32 or bfloat16" + assert offs.dtype == torch.int32, "offs must be int32" + + # Assert A and B dims are compatible for a scaled grouped GEMM. + assert A.size(-1) == B_t.size( + -2 + ), f"shape {A.shape} and {B_t.shape} are not compatible for _scaled_grouped_mm" + + # The left operand in the scaled grouped GEMM must be row-major due to hardware requirements. + assert not _is_column_major(A), "A must be row-major" + + # Due to hardware requirements, the right operand in a scaled grouped GEMM must be column-major. + assert _is_column_major(B_t), "B must be column-major" + + # Convert high precision input tensor to float8, row-major for left operand of grouped GEMM. + # A shape: (M, K) + # A_scales shape: (M,1) + A_scales = tensor_to_scale( + A, + torch.float8_e4m3fn, + scaling_granularity=ScalingGranularity.AXISWISE, + axiswise_dim=-1, + round_scales_to_power_of_2=True, + ) + A_scaled = A.to(torch.float32) * A_scales + A_fp8_row_major = to_fp8_saturated(A_scaled, torch.float8_e4m3fn) + + # Convert B to float8, column-major for right operand of grouped GEMM. + # B shape: (B, K, N) + # B scales must be computed rowwise keeping the outer/final dim, so: + # B_scales shape: (B, 1, N) + B_t_scales = tensor_to_scale( + B_t, + torch.float8_e4m3fn, + scaling_granularity=ScalingGranularity.AXISWISE, + axiswise_dim=-2, + round_scales_to_power_of_2=True, + ) + B_t_scaled = B_t.to(torch.float32) * B_t_scales + B_t_fp8_col_major = to_fp8_saturated(B_t_scaled, torch.float8_e4m3fn) + + # Precompute non-transposed B column-major for backward, to save memory by storing the + # low precision B tensor instead of the high precision B tensor. + # In the backward this is needed for grad_A: grad_output @ B. + B = B_t.contiguous().transpose(-2, -1) + + # - B shape: (B, K, N) + # - B scales must be computed rowwise keeping the outer/final dim, so: + # - B_scale shape: (B, 1, N) + B_scales = tensor_to_scale( + B, + torch.float8_e4m3fn, + scaling_granularity=ScalingGranularity.AXISWISE, + axiswise_dim=-2, + round_scales_to_power_of_2=True, + ) + B_scaled = B.to(torch.float32) * B_scales + B_fp8_col_major = to_fp8_saturated(B_scaled, torch.float8_e4m3fn) + + # Store what we need for backward. + ctx.save_for_backward(A, B_fp8_col_major, B_scales, offs) + ctx.out_dtype = out_dtype + + # Perform scaled grouped GEMM and return result. + # output shape: scaled grouped mm of (M,K) @ (B,K,N) = (M,N) + return torch._scaled_grouped_mm( + A_fp8_row_major, + B_t_fp8_col_major, + A_scales.squeeze().reciprocal(), + B_t_scales.squeeze().reciprocal(), + offs, + out_dtype=out_dtype, + use_fast_accum=True, + ) + + @staticmethod + def backward(ctx, grad_output: torch.Tensor): + A, B_fp8_col_major, B_scales, offs = ctx.saved_tensors + out_dtype = ctx.out_dtype + + # Convert grad_output to float8, row-major for left operand of grouped GEMM + # needed for grad_A: grad_output @ B + # + # grad_output shape: (M, N) + # grad_output_scale shape: (M, 1) + grad_output_scales = tensor_to_scale( + grad_output, + torch.float8_e4m3fn, + scaling_granularity=ScalingGranularity.AXISWISE, + axiswise_dim=-1, + round_scales_to_power_of_2=True, + ) + grad_output_scaled = grad_output.to(torch.float32) * grad_output_scales + grad_output_fp8_row_major = to_fp8_saturated( + grad_output_scaled, torch.float8_e4m3fn + ) + + # Compute grad_A. + # + # grad_A = grad_output @ B + # grad_A = scaled grouped mm of (M,N) @ (B,N,K) = (M,K) + grad_A = torch._scaled_grouped_mm( + grad_output_fp8_row_major, + B_fp8_col_major, + grad_output_scales.squeeze().reciprocal(), + B_scales.squeeze().reciprocal(), + offs, + out_dtype=out_dtype, + use_fast_accum=True, + ) + + # Convert tranpose of grad_output to float8, row-major for left operand of grouped GEMM + # needed for grad_B: grad_output_t @ A + grad_output_t_row_major = grad_output.transpose(-2, -1).contiguous() + + # Convert A to float8, column-major for right operand of grouped GEMM: + # needed for grad_B: grad_output @ A + A_col_major = A.transpose(-2, -1).contiguous().transpose(-2, -1) + + # grad_B is a special case. both operands of the grouped gemm will be 2D with offsets determing the "groups." + # Compute scales for grad_output_t and A, which are both 2D tensors with offsets which define the "jagged" groups. + grad_output_t_fp8_row_major, grad_output_t_scales = ( + _to_2d_jagged_float8_tensor_rowwise( + grad_output_t_row_major, + offs, + target_dtype=torch.float8_e4m3fn, + round_scales_to_power_of_2=True, + ) + ) + A_fp8_col_major, A_scales = _to_2d_jagged_float8_tensor_colwise( + A_col_major, + offs, + target_dtype=torch.float8_e4m3fn, + round_scales_to_power_of_2=True, + ) + + # Compute grad_B = grad_output_t @ A. + # grad_B = grad_output_t @ A + # grad_B = (N,M) @ (M,K) = (N,K) + grad_B = torch._scaled_grouped_mm( + grad_output_t_fp8_row_major, + A_fp8_col_major, + grad_output_t_scales.reciprocal(), + A_scales.reciprocal(), + offs, + out_dtype=out_dtype, + use_fast_accum=True, + ) + return grad_A, grad_B.transpose(-2, -1), None, None, None, None + + +def _to_2d_jagged_float8_tensor_colwise( + A_col_major: torch.Tensor, + offs: torch.Tensor, + target_dtype: torch.dtype = torch.float8_e4m3fn, + round_scales_to_power_of_2: bool = False, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + This function converts the 2D input tensor A to a jagged float8 tensor, + with scales computed along *logical columns* for each group individually, + where groups are determined based on the offsets. + + For the right operand of a normal scaled GEMM, the rowwise scales are computed over logical columns. + (i.e., a tensor of (K,N) will have scales of shape (1,N). + + However, for a 2D right operand of a grouped GEMM, these logical columns go through multiple distinct + groups/subtensors, for which we want to compute scales individually. So we cannot take one set of scales + along the logical columns and apply it to the entire tensor. + + Instead, we need to compute scales for each subtensor individually. For a tensor of shape (K,N) this results + in scales of shape (1,N * num_groups). + + Args: + A (torch.Tensor): The input tensor to be converted to a jagged float8 tensor. + + Returns: + A tuple containing the jagged float8 tensor and the scales used for the conversion. + """ + assert A_col_major.ndim == 2, "A must be 2D" + + num_groups = offs.numel() + A_fp8_col_major = torch.empty_like(A_col_major, dtype=target_dtype) + A_scales = torch.empty( + A_fp8_col_major.size(1) * num_groups, + dtype=torch.float32, + device=A_fp8_col_major.device, + ) + + start_idx = 0 + next_scale_idx = 0 + for end_idx in offs.tolist(): + # Get the subtensor of A for this group, fetching the next group of rows, with all columns for each. + subtensor = A_col_major[start_idx:end_idx, :] # (local_group_size, K) + + # Compute local rowwise scales for this subtensor, which are along logical columns for the right operand. + subtensor_scales = tensor_to_scale( + subtensor, + target_dtype, + scaling_granularity=ScalingGranularity.AXISWISE, + axiswise_dim=0, + round_scales_to_power_of_2=round_scales_to_power_of_2, + ) + + # Apply scales to subtensor and convert to float8. + tensor_scaled = subtensor.to(torch.float32) * subtensor_scales + float8_subtensor = to_fp8_saturated(tensor_scaled, target_dtype) + + # Store this portion of the resulting float8 tensor and scales. + A_fp8_col_major[start_idx:end_idx, :] = float8_subtensor + A_scales[next_scale_idx : next_scale_idx + subtensor_scales.numel()] = ( + subtensor_scales.squeeze() + ) + + # Update start index for next group. + start_idx = end_idx + next_scale_idx += subtensor_scales.numel() + + return A_fp8_col_major, A_scales + + +def _to_2d_jagged_float8_tensor_rowwise( + x: torch.Tensor, + offs: torch.Tensor, + target_dtype: torch.dtype, + round_scales_to_power_of_2: bool = False, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + This function converts the 2D input tensor to a jagged float8 tensor, + with scales computed along *logical rows* for each group individually, + where groups are determined based on the offsets. + + For a 2D *left* operand of a normal scaled GEMM, the rowwise scales are computed over logical rows. + (i.e., a tensor of (M,K) will have scales of shape (M,1). + + However, for a 2D left operand of a grouped GEMM, these logical rows go through multiple distinct + groups/subtensors, for which we want to compute scales individually. So we cannot take one set of scales + along the logical rows and apply it to the entire tensor. + + Instead, we need to compute scales for each subtensor individually. For a tensor of shape (M,K) this results + in scales of shape (M * num_groups, 1). + + Args: + A (torch.Tensor): The input tensor to be converted to a jagged float8 tensor. + + Returns: + A tuple containing the jagged float8 tensor and the scales used for the conversion. + """ + assert x.ndim == 2, "input tensor must be 2D" + + num_groups = offs.numel() + x_fp8 = torch.empty_like(x, dtype=target_dtype) + x_scales = torch.empty( + x_fp8.size(0) * num_groups, dtype=torch.float32, device=x_fp8.device + ) + + start_idx = 0 + next_scale_idx = 0 + for end_idx in offs.tolist(): + # Get the subtensor of A for this group, fetching all rows with the next group of rows. + subtensor = x[:, start_idx:end_idx] # (M, local_group_size) + + # Compute local rowwise scales for this subtensor, which are along logical rows for the left operand. + subtensor_scales = tensor_to_scale( + subtensor, + target_dtype, + scaling_granularity=ScalingGranularity.AXISWISE, + axiswise_dim=-1, + round_scales_to_power_of_2=round_scales_to_power_of_2, + ) + + # Apply scales to subtensor and convert to float8. + tensor_scaled = subtensor.to(torch.float32) * subtensor_scales + float8_subtensor = to_fp8_saturated(tensor_scaled, target_dtype) + + # Store this portion of the resulting float8 tensor and scales. + x_fp8[:, start_idx:end_idx] = float8_subtensor + x_scales[next_scale_idx : next_scale_idx + subtensor_scales.numel()] = ( + subtensor_scales.squeeze() + ) + + # Update start index for next group. + start_idx = end_idx + next_scale_idx += subtensor_scales.numel() + + return x_fp8, x_scales + + +def _is_column_major(x: torch.Tensor) -> bool: + """ + This function checks if the input tensor is column-major. + + Args: + x (torch.Tensor): The input tensor to be checked. + + Returns: + A boolean indicating whether the input tensor is column-major. + """ + assert x.ndim == 2 or x.ndim == 3, "input tensor must be 2D or 3D" + return x.stride(-2) == 1 and x.stride(-1) > 1 diff --git a/torchao/prototype/scaled_grouped_mm/test_scaled_grouped_mm.py b/torchao/prototype/scaled_grouped_mm/test_scaled_grouped_mm.py new file mode 100644 index 0000000000..cd347c3d9d --- /dev/null +++ b/torchao/prototype/scaled_grouped_mm/test_scaled_grouped_mm.py @@ -0,0 +1,196 @@ +import pytest +import torch + +from torchao.float8.config import ( + Float8LinearConfig, + Float8LinearRecipeName, +) +from torchao.float8.float8_linear import matmul_with_hp_or_float8_args +from torchao.float8.float8_tensor import LinearMMConfig +from torchao.float8.float8_utils import tensor_to_scale, to_fp8_saturated +from torchao.prototype.scaled_grouped_mm.scaled_grouped_mm import ( + _scaled_grouped_mm, +) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +def test_valid_scaled_grouped_mm_2d_3d(): + out_dtype = torch.bfloat16 + device = "cuda" + m, n, k, n_groups = 16, 32, 16, 4 + a = torch.randn( + m * n_groups, + k, + device=device, + requires_grad=True, + dtype=torch.bfloat16, + ) + b = torch.randn( + n_groups, + n, + k, + device=device, + dtype=torch.bfloat16, + ) + offs = torch.arange(m, n_groups * m + 1, m, device="cuda", dtype=torch.int32) + + # b must be transposed and in column major format. + b_t = b.contiguous().transpose(-2, -1).requires_grad_(True) + + # Compute output. + out = _scaled_grouped_mm( + a, + b_t, + offs=offs, + out_dtype=out_dtype, + ) + + # Validate result. + ref_a = a.detach().clone().requires_grad_(True) + ref_b_t = b_t.detach().clone().requires_grad_(True) + ref_out = compute_reference_forward( + out, + ref_a, + ref_b_t, + n_groups, + out_dtype, + offs, + ) + assert torch.equal(out, ref_out) + + # Run backward pass. + out.sum().backward() + ref_out.sum().backward() + + # Validate gradients. + assert torch.equal(a.grad, ref_a.grad) + assert torch.equal(b_t.grad, ref_b_t.grad) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +@pytest.mark.parametrize("m", [16, 17]) +@pytest.mark.parametrize("k", [16, 18]) +@pytest.mark.parametrize("n", [32, 33]) +def test_K_or_N_dim_not_multiple_of_16(m, n, k): + # - Leading dim of A doesn't have to be divisible by 16, since it will be + # divided up into groups based on offset anyway. + # - Trailing dim of A must be divisible by 16. + # - Leading dim of B (n_groups) doesn't need to be divisible by 16. + # - Last 2 dims of B must be divisible by 16. + if n % 16 == 0 and k % 16 == 0: + return + out_dtype = torch.bfloat16 + device = "cuda" + n_groups = 4 + a = torch.randn( + m * n_groups, + k, + device=device, + requires_grad=True, + dtype=torch.bfloat16, + ) + b = torch.randn( + n_groups, + n, + k, + device=device, + requires_grad=True, + dtype=torch.bfloat16, + ) + + # b must be transposed and in column major format. + b_t = b.transpose(-2, -1) + b_t = b_t.transpose(-2, -1).contiguous().transpose(-2, -1) + + offs = torch.arange(m, n_groups * m + 1, m, device="cuda", dtype=torch.int32) + + # Compute output. + with pytest.raises(AssertionError): + _scaled_grouped_mm( + a, + b_t, + offs=offs, + out_dtype=out_dtype, + ) + + +def compute_reference_forward( + result: torch.Tensor, + A: torch.Tensor, + B_t: torch.Tensor, + n_groups: int, + out_dtype: torch.dtype, + offs: torch.Tensor, +): + assert result.dtype == out_dtype + + # Use official rowwise recipe as reference to ensure implementation is correct. + float8_config = Float8LinearConfig.from_recipe_name(Float8LinearRecipeName.ROWWISE) + + # Convert A to fp8. + A_scales = tensor_to_scale( + A, + float8_config.cast_config_input.target_dtype, + scaling_granularity=float8_config.cast_config_input.scaling_granularity, + axiswise_dim=-1, + round_scales_to_power_of_2=float8_config.round_scales_to_power_of_2, + ) + A_scaled = A.to(torch.float32) * A_scales + A_fp8 = to_fp8_saturated(A_scaled, torch.float8_e4m3fn) + + # Convert B^t to fp8. + B_t_scales = tensor_to_scale( + B_t, + float8_config.cast_config_weight.target_dtype, + scaling_granularity=float8_config.cast_config_weight.scaling_granularity, + axiswise_dim=-2, + round_scales_to_power_of_2=float8_config.round_scales_to_power_of_2, + ) + B_t_scaled = B_t.to(torch.float32) * B_t_scales + B_t_fp8 = to_fp8_saturated( + B_t_scaled, + torch.float8_e4m3fn, + ) + + # Split A and result into chunks, one for each group. + offs_cpu = offs.cpu() + A_list, A_list_fp8, A_scale_list, result_list = [], [], [], [] + start = 0 + for i in range(n_groups): + A_list.append(A[start : offs_cpu[i]]) + A_list_fp8.append(A_fp8[start : offs_cpu[i]]) + A_scale_list.append(A_scales[start : offs_cpu[i]]) + result_list.append(result[start : offs_cpu[i]]) + start = offs_cpu[i] + + # Validate each actual result group from the _scaled_grouped_mm is equal to: + # 1. A manual _scaled_mm for the group. + # 2. A matmul_with_hp_or_float8_args for the group (which is differentiable, and thus used to validate gradients). + outputs = [] + list1 = list(zip(A_list_fp8, B_t_fp8, A_scale_list, B_t_scales, result_list)) + list2 = list(zip(A_list, B_t, result_list)) + for i in range(len(list1)): + a1, b1, a1scale, b1scale, result1 = list1[i] + ref_group_result1 = torch._scaled_mm( + a1, + b1, + a1scale.reciprocal(), + b1scale.reciprocal(), + out_dtype=out_dtype, + bias=None, + use_fast_accum=float8_config.gemm_config_output.use_fast_accum, + ) + a2, b2, result2 = list2[i] + ref_group_result2 = matmul_with_hp_or_float8_args.apply( + a2, + b2, + LinearMMConfig(), + float8_config, + ) + assert torch.equal(result1, ref_group_result1) + assert torch.equal(result2, ref_group_result2) + outputs.append(ref_group_result2) + + # Concatenate the outputs and verify the full result is correct. + output_ref = torch.cat(outputs, dim=0) + return output_ref From 6987576a0a9a44a3018d149fe3339780bd1beb6d Mon Sep 17 00:00:00 2001 From: Kimish Patel Date: Wed, 2 Apr 2025 17:17:56 -0700 Subject: [PATCH 11/19] Add quantized attn_scores @ v test for intented used in quantized attention Differential Revision: D71370603 Pull Request resolved: https://github.com/pytorch/ao/pull/2008 --- ...input_channelwise_8bit_b_1x16x4_f32_impl.h | 6 + .../cpu/aarch64/tests/test_qmatmul.cpp | 87 +++++++++ .../tests/test_utils_quantized_attention.h | 168 ++++++++++++++++++ 3 files changed, 261 insertions(+) diff --git a/torchao/experimental/kernels/cpu/aarch64/matmul/fp32_a_input_channelwise_8bit_b_1x16x4_f32_impl.h b/torchao/experimental/kernels/cpu/aarch64/matmul/fp32_a_input_channelwise_8bit_b_1x16x4_f32_impl.h index 389abb32a5..bdad1b4a47 100644 --- a/torchao/experimental/kernels/cpu/aarch64/matmul/fp32_a_input_channelwise_8bit_b_1x16x4_f32_impl.h +++ b/torchao/experimental/kernels/cpu/aarch64/matmul/fp32_a_input_channelwise_8bit_b_1x16x4_f32_impl.h @@ -101,6 +101,12 @@ struct KernelImpl { const int rhs_qparams_stride); }; +/* +Document param meaning +rhs_stride_n: Since rhs transposed == false, the expected shape of rhs is k x n. +Thus rhs_stride_n is the stride of k dim, that how many bytes aparts elements +in k dim are. +*/ template <> struct KernelImpl { static void run( diff --git a/torchao/experimental/kernels/cpu/aarch64/tests/test_qmatmul.cpp b/torchao/experimental/kernels/cpu/aarch64/tests/test_qmatmul.cpp index 344b2c4915..05dbf13aac 100644 --- a/torchao/experimental/kernels/cpu/aarch64/tests/test_qmatmul.cpp +++ b/torchao/experimental/kernels/cpu/aarch64/tests/test_qmatmul.cpp @@ -509,4 +509,91 @@ TEST( test_8bit_per_token_q_at_k_matmul_attention(1, 8, 8, 7, 3, false); } +static void test_fp32_attn_scores_at_v_matmul_attention( + int b, + int s_attn, + int s_v, + int h, + int d, + bool transpose_v = true) { + auto test_case = + torchao::fp32_a_channelwise_8bit_b_attn_scores_at_v_test_case::generate( + b, s_attn, s_v, h, d, transpose_v); + + using namespace torchao::kernels::cpu::aarch64::quantized_matmul:: + fp32_a_input_channelwise_8bit_b_1x16x4_f32; + + size_t attn_b_stride = test_case.b_attn_stride; + size_t attn_h_stride = test_case.h_attn_stride; + size_t attn_s_q_stride = test_case.s_attn_stride; + + size_t v_b_stride = test_case.b_v_stride; + size_t v_h_stride = test_case.h_v_stride; + size_t v_s_v_stride = test_case.s_v_stride; + size_t v_scale_zp_b_stride = test_case.b_v_qparams_stride; + size_t v_scale_zp_h_stride = test_case.h_v_qparams_stride; + size_t v_scale_zp_s_stride = test_case.s_v_qparams_stride; + + std::vector output(b * s_attn * h * d); + size_t output_b_stride = s_attn * h * d; + size_t output_s_attn_stride = h * d; + size_t output_h_stride = d; + + for (int b_idx = 0; b_idx < b; b_idx++) { + for (int h_idx = 0; h_idx < h; h_idx++) { + kernel( + s_attn, + d, + s_v, + test_case.attn_scores.data() + b_idx * attn_b_stride + + h_idx * attn_h_stride, + attn_s_q_stride /*lhs_stride_m*/, + test_case.v_qvals.data() + b_idx * v_b_stride + h_idx * v_h_stride, + v_s_v_stride /*rhs_stride_n*/, + output.data() + b_idx * output_b_stride + h_idx * output_h_stride, + output_s_attn_stride /*out_stride_n*/, + test_case.v_zeros.data() + b_idx * v_scale_zp_b_stride + + h_idx * v_scale_zp_h_stride, + test_case.v_scales.data() + b_idx * v_scale_zp_b_stride + + h_idx * v_scale_zp_h_stride, + 0.0 /*beta*/, + v_scale_zp_s_stride /*rhs qparams stride*/); + } + } + + for (int i = 0; i < b * s_attn * h * d; i++) { + EXPECT_NEAR(output[i], test_case.expected_output[i], kTol); + } +} + +TEST(test_fp32_attn_scores_at_v_matmul_attention, Basic) { + test_fp32_attn_scores_at_v_matmul_attention(1, 16, 16, 8, 16); +} + +TEST(test_fp32_attn_scores_at_v_matmul_attention, PrimeHeadsAndHeadDim) { + test_fp32_attn_scores_at_v_matmul_attention(1, 8, 8, 7, 33); +} + +TEST(test_fp32_attn_scores_at_v_matmul_attention, PrimeSequenceDim) { + test_fp32_attn_scores_at_v_matmul_attention(1, 7, 9, 7, 33); +} + +TEST(test_fp32_attn_scores_at_v_matmul_attention, PrimeHeadsAndSmallHeadDim) { + test_fp32_attn_scores_at_v_matmul_attention(1, 8, 8, 7, 17); +} + +TEST(test_fp32_attn_scores_at_v_matmul_attention, BasicNoTranspose) { + test_fp32_attn_scores_at_v_matmul_attention(1, 16, 16, 8, 16, false); +} + +TEST( + test_fp32_attn_scores_at_v_matmul_attention, + PrimeHeadsAndSmallHeadDimNoTranspose) { + test_fp32_attn_scores_at_v_matmul_attention(1, 8, 8, 7, 17, false); +} + +TEST(test_fp32_attn_scores_at_v_matmul_attention, PrimeSequenceDimNoTranspose) { + test_fp32_attn_scores_at_v_matmul_attention(1, 7, 9, 7, 33, false); +} + #endif // defined(__aarch64__) || defined(__ARM_NEON) diff --git a/torchao/experimental/kernels/cpu/aarch64/tests/test_utils_quantized_attention.h b/torchao/experimental/kernels/cpu/aarch64/tests/test_utils_quantized_attention.h index 9ca86ece76..52fb0851bc 100644 --- a/torchao/experimental/kernels/cpu/aarch64/tests/test_utils_quantized_attention.h +++ b/torchao/experimental/kernels/cpu/aarch64/tests/test_utils_quantized_attention.h @@ -230,6 +230,174 @@ struct channelwise_8bit_a_channelwise_8bit_b_q_at_k_attention_test_case { } }; +struct fp32_a_channelwise_8bit_b_attn_scores_at_v_test_case { + int b; + int s_attn; + int s_v; + int h; + int d; + size_t b_attn_stride; + size_t h_attn_stride; + size_t s_attn_stride; + size_t b_v_stride; + size_t h_v_stride; + size_t s_v_stride; + size_t b_v_qparams_stride; + size_t h_v_qparams_stride; + size_t s_v_qparams_stride; + + std::vector expected_output; + + std::vector attn_scores; + + std::vector v; + std::vector v_qvals; + std::vector v_scales; + std::vector v_zeros; + + fp32_a_channelwise_8bit_b_attn_scores_at_v_test_case( + int b_, + int s_attn_, + int s_v_, + int h_, + int d_, + size_t b_attn_stride_, + size_t h_attn_stride_, + size_t s_attn_stride_, + size_t b_v_stride_, + size_t h_v_stride_, + size_t s_v_stride_, + size_t b_v_qparams_stride_, + size_t h_v_qparams_stride_, + size_t s_v_qparams_stride_, + std::vector expected_output_, + std::vector attn_scores_, + std::vector v_, + std::vector v_qvals_, + std::vector v_scales_, + std::vector v_zeros_) + : b(b_), + s_attn(s_attn_), + s_v(s_v_), + h(h_), + d(d_), + b_attn_stride(b_attn_stride_), + h_attn_stride(h_attn_stride_), + s_attn_stride(s_attn_stride_), + b_v_stride(b_v_stride_), + h_v_stride(h_v_stride_), + s_v_stride(s_v_stride_), + b_v_qparams_stride(b_v_qparams_stride_), + h_v_qparams_stride(h_v_qparams_stride_), + s_v_qparams_stride(s_v_qparams_stride_), + expected_output(expected_output_), + attn_scores(attn_scores_), + v(v_), + v_qvals(v_qvals_), + v_scales(v_scales_), + v_zeros(v_zeros_) { + assert(expected_output.size() == b * s_attn * h * d); + assert(attn_scores.size() == b * h * s_attn * s_v); + assert(v.size() == b * h * s_v * d); + assert(v_qvals.size() == b * h * s_v * d); + assert(v_scales.size() == b * h * s_v); + assert(v_zeros.size() == b * h * s_v); + } + + static fp32_a_channelwise_8bit_b_attn_scores_at_v_test_case + generate(int b, int s_attn, int s_v, int h, int d, bool transposed_v = true) { + // Generate activations + auto lhs = get_random_vector(b * h * s_attn * s_v, -1.0, 1.0); + + auto [rhs, rhs_qvals, rhs_scales, rhs_zeros] = + torchao::test_utils::generate_per_token_quantized_tensor( + b * h * s_v, d); + // Above function produces nxk matrix and to produce kxn you need transposed + // = true. we do !rhs_is_transposed becaues when rhs_is_transposed = true + // the shape should be nxk instead of kxn. + + size_t b_attn_stride = h * s_attn * s_v; + size_t h_attn_stride = s_attn * s_v; + size_t s_attn_stride = s_v; + + size_t b_v_stride = h * s_v * d; + size_t h_v_stride = s_v * d; + size_t s_v_stride = d; + + size_t b_v_qparams_stride = h * s_v; + size_t h_v_qparams_stride = s_v; + size_t s_v_qparams_stride = 1; + + if (!transposed_v) { + h_v_stride = d; + s_v_stride = h * d; + + s_v_qparams_stride = h; + h_v_qparams_stride = 1; + } + + // Compute expected output + // Note that while the inputs can be in shape b x h x s_attn x s_v, + // and b x h x s_v x d the output is not in b x h x s_attn x s_v + // but rather b x s_attn x h x d. This is because the output of + // SDPA will normally be in b x h x s_attn x d, but we want to + // avoid any tranposes. Thus just aim to output in b x s_attn x h x d + // This is just for testing purposes. Kernel can actually write output + // in [B, H, S, D] if needed. + std::vector expected_output(b * s_attn * h * d); + size_t b_out_stride = s_attn * h * d; + size_t s_attn_out_stride = h * d; + size_t h_out_stride = d; + + for (int b_idx = 0; b_idx < b; b_idx++) { + for (int s_attn_idx = 0; s_attn_idx < s_attn; s_attn_idx++) { + for (int h_idx = 0; h_idx < h; h_idx++) { + for (int d_idx = 0; d_idx < d; d_idx++) { + float res = 0.0; + for (int s_v_idx = 0; s_v_idx < s_v; s_v_idx++) { + int lhs_idx = b_idx * b_attn_stride + s_attn_idx * s_attn_stride + + h_idx * h_attn_stride + s_v_idx; + int rhs_idx = b_idx * b_v_stride + h_idx * h_v_stride + d_idx + + s_v_idx * s_v_stride; + int rhs_scales_zp_idx = b_idx * b_v_qparams_stride + + h_idx * h_v_qparams_stride + s_v_idx * s_v_qparams_stride; + float rhs_dequant = rhs_scales[rhs_scales_zp_idx] * + (rhs_qvals[rhs_idx] - rhs_zeros[rhs_scales_zp_idx]); + + res += lhs[lhs_idx] * rhs_dequant; + } + expected_output + [b_idx * b_out_stride + s_attn_idx * s_attn_out_stride + + h_idx * h_out_stride + d_idx] = res; + } + } + } + } + + // Return test case + return fp32_a_channelwise_8bit_b_attn_scores_at_v_test_case( + b, + s_attn, + s_v, + h, + d, + b_attn_stride, + h_attn_stride, + s_attn_stride, + b_v_stride, + h_v_stride, + s_v_stride, + b_v_qparams_stride, + h_v_qparams_stride, + s_v_qparams_stride, + expected_output, + lhs, + rhs, + rhs_qvals, + rhs_scales, + rhs_zeros); + } +}; } // namespace torchao #endif // defined(__aarch64__) || defined(__ARM_NEON) From 97d6d7400054898cebb2d86492a9ca1041fe1645 Mon Sep 17 00:00:00 2001 From: Kimish Patel Date: Thu, 3 Apr 2025 09:08:58 -0700 Subject: [PATCH 12/19] add fallback kernel and interface Differential Revision: D71370598 Pull Request resolved: https://github.com/pytorch/ao/pull/2010 --- .../cpu/aarch64/tests/test_qmatmul.cpp | 1 + .../channelwise_8bit_a_channelwise_8bit_b.h | 133 ++++++ .../kernels/cpu/interface/quantized_matmul.h | 88 ++++ .../cpu/interface/test_qmatmul_interface.cpp | 448 ++++++++++++++++++ 4 files changed, 670 insertions(+) create mode 100644 torchao/experimental/kernels/cpu/fallback/matmul/channelwise_8bit_a_channelwise_8bit_b.h create mode 100644 torchao/experimental/kernels/cpu/interface/quantized_matmul.h create mode 100644 torchao/experimental/kernels/cpu/interface/test_qmatmul_interface.cpp diff --git a/torchao/experimental/kernels/cpu/aarch64/tests/test_qmatmul.cpp b/torchao/experimental/kernels/cpu/aarch64/tests/test_qmatmul.cpp index 05dbf13aac..ff4f915b2d 100644 --- a/torchao/experimental/kernels/cpu/aarch64/tests/test_qmatmul.cpp +++ b/torchao/experimental/kernels/cpu/aarch64/tests/test_qmatmul.cpp @@ -70,6 +70,7 @@ struct test_channelwise_8bit_channelwise_8bit_b< false, false> { static void Run(int m, int k, int n, int stride = 1) { + // TODO: make use of stride for this kernel auto test_case = torchao::channelwise_8bit_a_channelwise_8bit_b_qmatmul_test_case:: generate(m, k, n, a_has_zeros, a_has_zeros, false, false); diff --git a/torchao/experimental/kernels/cpu/fallback/matmul/channelwise_8bit_a_channelwise_8bit_b.h b/torchao/experimental/kernels/cpu/fallback/matmul/channelwise_8bit_a_channelwise_8bit_b.h new file mode 100644 index 0000000000..3b070eb2b3 --- /dev/null +++ b/torchao/experimental/kernels/cpu/fallback/matmul/channelwise_8bit_a_channelwise_8bit_b.h @@ -0,0 +1,133 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the license found in the +// LICENSE file in the root directory of this source tree. + +#pragma once + +#include + +namespace torchao::kernels::cpu::fallback::quantized_matmul { +namespace channelwise_8bit_a_channelwise_8bit_b::internal { + +template < + bool a_has_zeros, + bool b_has_zeros, + bool a_transposed, + bool b_tranposed> +struct KernelImpl { + static void run( + int m, + int n, + int k, + const void* lhs, + int lhs_stride_m, + const void* rhs, + int rhs_stride_n, + float* output, + int out_stride_m, + const int8_t* lhs_zero_points, + const int8_t* rhs_zero_points, + const float* lhs_scales, + const float* rhs_scales, + const int lhs_qparams_stride, + const int rhs_qparams_stride); +}; + +template +struct KernelImpl { + static void run( + int m, + int n, + int k, + const void* lhs, + int lhs_stride_m, + const void* rhs, + int rhs_stride_n, + float* output, + int out_stride_m, + const int8_t* lhs_zero_points, + const int8_t* rhs_zero_points, + const float* lhs_scales, + const float* rhs_scales, + const int lhs_qparams_stride, + const int rhs_qparams_stride) { + const int8_t* lhs_qvals = static_cast(lhs); + const int8_t* rhs_qvals = static_cast(rhs); + for (int m_idx = 0; m_idx < m; m_idx++) { + for (int n_idx = 0; n_idx < n; n_idx++) { + float res = 0.0; + for (int k_idx = 0; k_idx < k; k_idx++) { + int lhs_idx = m_idx * lhs_stride_m + k_idx; + int rhs_idx = k_idx * rhs_stride_n + n_idx; + if (b_transposed) { + rhs_idx = n_idx * rhs_stride_n + k_idx; + } + + float lhs_dequant = lhs_scales[m_idx * lhs_qparams_stride] * + (static_cast(lhs_qvals[lhs_idx]) - + static_cast( + lhs_zero_points[m_idx * lhs_qparams_stride])); + + float rhs_dequant = rhs_scales[n_idx * rhs_qparams_stride] * + (static_cast(rhs_qvals[rhs_idx]) - + static_cast( + rhs_zero_points[n_idx * rhs_qparams_stride])); + + res += lhs_dequant * rhs_dequant; + } + output[m_idx * n + n_idx] = res; + } + } + } +}; + +} // namespace + // channelwise_8bit_a_channelwise_8bit_b::internal +} // namespace torchao::kernels::cpu::fallback::quantized_matmul + +// TODO: Remove all ::kernels. No need for extra namespace. +namespace torchao::kernels::cpu::fallback::quantized_matmul { +namespace channelwise_8bit_a_channelwise_8bit_b { +template < + bool a_has_zeros, + bool b_has_zeros, + bool a_transposed, + bool b_transposed> +void kernel( + int m, + int n, + int k, + const void* lhs, + int lhs_stride_m, + const void* rhs, + int rhs_stride_n, + float* output, + int out_stride_m, + const int8_t* lhs_zero_points, + const int8_t* rhs_zero_points, + const float* lhs_scales, + const float* rhs_scales, + const int lhs_qparams_stride, + const int rhs_qparams_stride) { + channelwise_8bit_a_channelwise_8bit_b::internal:: + KernelImpl::run( + m, + n, + k, + lhs, + lhs_stride_m, + rhs, + rhs_stride_n, + output, + out_stride_m, + lhs_zero_points, + rhs_zero_points, + lhs_scales, + rhs_scales, + lhs_qparams_stride, + rhs_qparams_stride); +} +} // namespace channelwise_8bit_a_channelwise_8bit_b +} // namespace torchao::kernels::cpu::fallback::quantized_matmul diff --git a/torchao/experimental/kernels/cpu/interface/quantized_matmul.h b/torchao/experimental/kernels/cpu/interface/quantized_matmul.h new file mode 100644 index 0000000000..01a4c704c5 --- /dev/null +++ b/torchao/experimental/kernels/cpu/interface/quantized_matmul.h @@ -0,0 +1,88 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the license found in the +// LICENSE file in the root directory of this source tree. + +#pragma once + +#include + +#include +#if defined(__aarch64__) || defined(__ARM_NEON) +#include +#include +#include +#endif // defined(__aarch64__) || defined(__ARM_NEON) + +namespace torchao::kernels::cpu::quantized_matmul { + +/* +a_stride_m: stride of a in memory to indiciate how far apart each row is. +b_stride_n: stride of b in memory to indiciate how far apart each row is. +If b is transposed (n x k), then this is how many bytes to skip to get to the +next row. If b is not transposed (k x n), then this is how many bytes to skip to +get to the next row. + +It also returns the stride of a and b, that should be used in the kernel. + +Will need to think of a better way to find the right +ukernel. Perhaps via ukernelconfig + registry?. +*/ +using int8_a_int8_b_channelwise_fp32_c_qmatmul_type = void (*)( + int, + int, + int, + const void*, + int, + const void*, + int, + float*, + int, + const int8_t*, + const int8_t*, + const float*, + const float*, + const int, + const int); + +int8_a_int8_b_channelwise_fp32_c_qmatmul_type +get_int8_a_int8_b_channelwise_qmatmul( + int m, + int n, + int k, + bool a_transposed, + bool b_transposed, + int& a_stride_m, + int& b_stride_n); + +int8_a_int8_b_channelwise_fp32_c_qmatmul_type +get_int8_a_int8_b_channelwise_qmatmul( + int m, + int n, + int k, + bool a_transposed, + bool b_transposed, + int& a_stride_m, + int& b_stride_n) { +#if defined(__aarch64__) || defined(__ARM_NEON) + if (!a_transposed && b_transposed && n >= 8) { + a_stride_m = k; + b_stride_n = k; + return aarch64::quantized_matmul:: + channelwise_8bit_a_channelwise_8bit_b_1x8x16_f32_neondot:: + kernel; + } +#endif // defined(__aarch64__) || defined(__ARM_NEON) + assert(!a_transposed); + if (b_transposed) { + a_stride_m = k; + b_stride_n = k; + return torchao::kernels::cpu::fallback::quantized_matmul:: + channelwise_8bit_a_channelwise_8bit_b::kernel; + } else { + return torchao::kernels::cpu::fallback::quantized_matmul:: + channelwise_8bit_a_channelwise_8bit_b::kernel; + } +} +} // namespace torchao::kernels::cpu::quantized_matmul diff --git a/torchao/experimental/kernels/cpu/interface/test_qmatmul_interface.cpp b/torchao/experimental/kernels/cpu/interface/test_qmatmul_interface.cpp new file mode 100644 index 0000000000..3629f0960b --- /dev/null +++ b/torchao/experimental/kernels/cpu/interface/test_qmatmul_interface.cpp @@ -0,0 +1,448 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the license found in the +// LICENSE file in the root directory of this source tree. + +#include +#include +#include +#include +#include + +#include +#include + +float kTol = 0.0001; + +// This is unfortunately had to be copied over because code in test_utils.h +// depends on quantization kernels which are only buildable for ARM. +// I would like the testing code in this folder to be independent of the arch. +namespace { +void get_qvals_range(int& qmin, int& qmax, int nbit, bool is_symmetric) { + if (is_symmetric) { + qmin = -(1 << (nbit - 1)) + 1; + qmax = -qmin; + } else { + qmin = -(1 << (nbit - 1)); + qmax = (1 << (nbit - 1)) - 1; + } +} + +void get_scale_and_zero( + float& scale, + int& zero, + float vmin, + float vmax, + int qmin, + int qmax) { + assert(qmin < qmax); + assert(vmin < vmax); + scale = (vmax - vmin) / (qmax - qmin); + zero = qmin - std::round(vmin / scale); +} + +inline std::vector +get_random_vector(int size, float min = -1.0, float max = 1.0) { + assert(min < max); + std::random_device random_device; + auto rng = std::mt19937(random_device()); + auto dist = std::bind(std::uniform_real_distribution(min, max), rng); + std::vector res(size); + std::generate(res.begin(), res.end(), std::ref(dist)); + return res; +} + +void quantize( + // Output + int8_t* qvals, + // Inputs + const float* vals, + int size, + float scale, + int8_t zero, + int8_t qmin, + int8_t qmax) { + float invScale = 1.0 / (scale + 1e-16); + int i = 0; + auto curr_rounding_mode = fegetround(); + fesetround(FE_TONEAREST); + for (; i < size; ++i) { + // Quantize remaining elements using scalar code + float val = vals[i]; + float qval_f32 = zero + val * invScale; + int32_t qval_s32 = static_cast(std::nearbyint(qval_f32)); + + // Clip to qmin and qmax + qval_s32 = std::max( + static_cast(qmin), + std::min(qval_s32, static_cast(qmax))); + + // Store the quantized value + qvals[i] = static_cast(qval_s32); + } + fesetround(int(curr_rounding_mode)); +} + +auto generate_per_token_quantized_tensor( + int m, + int n, + bool transposed = false) { + auto activations = get_random_vector(m * n, -1.0, 1.0); + auto activation_qvals = std::vector(m * n, 0); + auto activation_scales = std::vector(m, 0); + auto activation_zeros = std::vector(m, 0); + + // Quantize activations with 8-bit asymmetric + // TODO: replace with generic function that does not use aarch64 + // quantize method after we combine with torchao + int qmin, qmax, zero; + float vmin, vmax, scale; + get_qvals_range(qmin, qmax, /*nbit=*/8, /*is_symmetric=*/false); + for (int m_idx = 0; m_idx < m; m_idx++) { + auto minmax = std::minmax_element( + activations.data() + m_idx * n, activations.data() + (m_idx + 1) * n); + vmin = *minmax.first; + vmax = *minmax.second; + get_scale_and_zero(scale, zero, vmin, vmax, qmin, qmax); + activation_scales[m_idx] = scale; + activation_zeros[m_idx] = zero; + quantize( + /*qvals=*/activation_qvals.data() + m_idx * n, + /*vals=*/activations.data() + m_idx * n, + /*size=*/n, + scale, + zero, + qmin, + qmax); + } + + if (transposed) { + auto activations_t = std::vector(m * n, 0); + auto activation_qvals_t = std::vector(m * n, 0); + for (int m_idx = 0; m_idx < m; m_idx++) { + for (int n_idx = 0; n_idx < n; n_idx++) { + int activation_idx = m_idx * n + n_idx; + int tranposed_idx = n_idx * m + m_idx; + activations_t[tranposed_idx] = activations[activation_idx]; + activation_qvals_t[tranposed_idx] = activation_qvals[activation_idx]; + } + } + activations = activations_t; + activation_qvals = activation_qvals_t; + } + + return std::make_tuple( + activations, activation_qvals, activation_scales, activation_zeros); +} + +struct channelwise_8bit_a_channelwise_8bit_b_qmatmul_test_case { + int m; + int k; + int n; + int stride; + + bool lhs_has_zeros; + bool rhs_has_zeros; + bool lhs_is_transposed; + bool rhs_is_transposed; + + std::vector expected_output; + + std::vector lhs; + std::vector lhs_qvals; + std::vector lhs_scales; + std::vector lhs_zeros; + + std::vector rhs; + std::vector rhs_qvals; + std::vector rhs_scales; + std::vector rhs_zeros; + + channelwise_8bit_a_channelwise_8bit_b_qmatmul_test_case( + int m_, + int k_, + int n_, + int stride_, + bool lhs_has_zeros_, + bool rhs_has_zeros_, + bool lhs_is_transposed_, + bool rhs_is_transposed_, + std::vector expected_output_, + std::vector lhs_, + std::vector lhs_qvals_, + std::vector lhs_scales_, + std::vector lhs_zeros_, + std::vector rhs_, + std::vector rhs_qvals_, + std::vector rhs_scales_, + std::vector rhs_zeros_) + : m(m_), + k(k_), + n(n_), + stride(stride_), + lhs_has_zeros(lhs_has_zeros_), + rhs_has_zeros(rhs_has_zeros_), + lhs_is_transposed(lhs_is_transposed_), + rhs_is_transposed(rhs_is_transposed_), + expected_output(expected_output_), + lhs(lhs_), + lhs_qvals(lhs_qvals_), + lhs_scales(lhs_scales_), + lhs_zeros(lhs_zeros_), + rhs(rhs_), + rhs_qvals(rhs_qvals_), + rhs_scales(rhs_scales_), + rhs_zeros(rhs_zeros_) { + assert(expected_output.size() == m * n); + assert(lhs.size() == m * stride * k); + assert(lhs_qvals.size() == m * stride * k); + assert(lhs_scales.size() == m * stride); + assert(lhs_zeros.size() == m * stride); + assert(rhs.size() == n * stride * k); + assert(rhs_qvals.size() == n * stride * k); + assert(rhs_scales.size() == n * stride); + assert(rhs_zeros.size() == n * stride); + } + + static channelwise_8bit_a_channelwise_8bit_b_qmatmul_test_case generate( + int m, + int k, + int n, + bool lhs_has_zeros, + bool rhs_has_zeros, + bool lhs_is_transposed, + // rhs_is_transposed means generated b matrix is mxk instead of kxm + bool rhs_is_transposed, + int stride = 1) { + assert(!lhs_is_transposed); + assert(lhs_has_zeros); + assert(rhs_has_zeros); + assert(rhs_is_transposed || stride == 1); + // Generate activations + auto [lhs, lhs_qvals, lhs_scales, lhs_zeros] = + generate_per_token_quantized_tensor(m * stride, k); + + auto [rhs, rhs_qvals, rhs_scales, rhs_zeros] = + generate_per_token_quantized_tensor(n * stride, k, !rhs_is_transposed); + // Above function produces nxk matrix and to produce kxn you need transposed + // = true. we do !rhs_is_transposed becaues when rhs_is_transposed = true + // the shape should be nxk instead of kxn. + + // Compute expected output + std::vector expected_output(m * n); + + for (int m_idx = 0; m_idx < m; m_idx++) { + for (int n_idx = 0; n_idx < n; n_idx++) { + float res = 0.0; + for (int k_idx = 0; k_idx < k; k_idx++) { + int lhs_idx = m_idx * stride * k + k_idx; + int rhs_idx = k_idx * stride * n + n_idx * stride; + if (rhs_is_transposed) { + rhs_idx = n_idx * stride * k + k_idx; + } + float lhs_dequant = lhs_scales[m_idx * stride] * + (lhs_qvals[lhs_idx] - lhs_zeros[m_idx * stride]); + + float rhs_dequant = rhs_scales[n_idx * stride] * + (rhs_qvals[rhs_idx] - rhs_zeros[n_idx * stride]); + + res += lhs_dequant * rhs_dequant; + } + expected_output[m_idx * n + n_idx] = res; + } + } + + // Return test case + return channelwise_8bit_a_channelwise_8bit_b_qmatmul_test_case( + m, + k, + n, + stride, + lhs_has_zeros, + rhs_has_zeros, + lhs_is_transposed, + rhs_is_transposed, + expected_output, + lhs, + lhs_qvals, + lhs_scales, + lhs_zeros, + rhs, + rhs_qvals, + rhs_scales, + rhs_zeros); + } +}; +} // namespace + +template < + bool a_has_zeros, + bool b_has_zeros, + bool a_transposed, + bool b_transposed> +struct test_channelwise_8bit_channelwise_8bit_b { + static void Run(int m, int k, int n); +}; + +template +struct test_channelwise_8bit_channelwise_8bit_b< + a_has_zeros, + b_has_zeros, + false, + true> { + static void Run(int m, int k, int n, int stride = 1) { + auto test_case = + channelwise_8bit_a_channelwise_8bit_b_qmatmul_test_case::generate( + m, k, n, a_has_zeros, a_has_zeros, false, true, stride); + + int a_stride_m, b_stride_n; + auto kernel = torchao::kernels::cpu::quantized_matmul:: + get_int8_a_int8_b_channelwise_qmatmul( + m, n, k, false, true, a_stride_m, b_stride_n); + a_stride_m = a_stride_m * stride; + b_stride_n = b_stride_n * stride; + + std::vector output(m * n); + kernel( + m, + n, + k, + test_case.lhs_qvals.data(), + a_stride_m /*lsh_stride_m*/, + test_case.rhs_qvals.data(), + b_stride_n /*rsh_stride_n*/, + output.data(), + n /*out_stride_n*/, + test_case.lhs_zeros.data(), + test_case.rhs_zeros.data(), + test_case.lhs_scales.data(), + test_case.rhs_scales.data(), + stride, /*lhs qparams stride*/ + stride /*rhs qparams stride*/); + + for (int i = 0; i < m * n; i++) { + EXPECT_NEAR(output[i], test_case.expected_output[i], kTol); + } + } +}; + +TEST(test_channelwise_8bit_channelwise_8bit_b, TranposedBWithZeroPoints) { + test_channelwise_8bit_channelwise_8bit_b< + true /*a_has_zeros*/, + true /*b_has_zeros*/, + false /*a_transposed*/, + true /*b_transposed*/>:: + Run( + /*m=*/1, /*k=*/128, /*n=*/16); +} + +TEST(test_channelwise_8bit_channelwise_8bit_b, TranposeBWithZeroPointsLargeM) { + test_channelwise_8bit_channelwise_8bit_b< + true /*a_has_zeros*/, + true /*b_has_zeros*/, + false /*a_transposed*/, + true /*b_transposed*/>:: + Run( + /*m=*/4, /*k=*/128, /*n=*/16); +} + +TEST( + test_channelwise_8bit_channelwise_8bit_b, + TranposedBWithZeroPointsOddSizes) { + test_channelwise_8bit_channelwise_8bit_b< + true /*a_has_zeros*/, + true /*b_has_zeros*/, + false /*a_transposed*/, + true /*b_transposed*/>:: + Run( + /*m=*/4, /*k=*/37, /*n=*/24); +} + +TEST( + test_channelwise_8bit_channelwise_8bit_b, + TranposedBWithZeroPointsOddSizes2) { + test_channelwise_8bit_channelwise_8bit_b< + true /*a_has_zeros*/, + true /*b_has_zeros*/, + false /*a_transposed*/, + true /*b_transposed*/>:: + Run( + /*m=*/4, /*k=*/37, /*n=*/19); +} + +// Test shapes for which we have to use fallback kernel +TEST( + test_channelwise_8bit_channelwise_8bit_b, + TranposedBWithZeroPointsOddSizesFallback) { + test_channelwise_8bit_channelwise_8bit_b< + true /*a_has_zeros*/, + true /*b_has_zeros*/, + false /*a_transposed*/, + true /*b_transposed*/>:: + Run( + /*m=*/4, /*k=*/37, /*n=*/5); +} + +// Test shapes for which we have to use fallback kernel +TEST( + test_channelwise_8bit_channelwise_8bit_b, + TranposedBWithZeroPointsOddSizesFallback2) { + test_channelwise_8bit_channelwise_8bit_b< + true /*a_has_zeros*/, + true /*b_has_zeros*/, + false /*a_transposed*/, + true /*b_transposed*/>:: + Run( + /*m=*/4, /*k=*/2, /*n=*/1); +} + +TEST( + test_channelwise_8bit_channelwise_8bit_b, + TranposeBWithZeroPointsLargeMStrided) { + test_channelwise_8bit_channelwise_8bit_b< + true /*a_has_zeros*/, + true /*b_has_zeros*/, + false /*a_transposed*/, + true /*b_transposed*/>:: + Run( + /*m=*/4, /*k=*/128, /*n=*/16, 5); +} + +TEST( + test_channelwise_8bit_channelwise_8bit_b, + TranposedBWithZeroPointsOddSizes2Strided) { + test_channelwise_8bit_channelwise_8bit_b< + true /*a_has_zeros*/, + true /*b_has_zeros*/, + false /*a_transposed*/, + true /*b_transposed*/>:: + Run( + /*m=*/4, /*k=*/37, /*n=*/19, 16); +} + +// Test shapes for which we have to use fallback kernel +TEST( + test_channelwise_8bit_channelwise_8bit_b, + TranposedBWithZeroPointsOddSizesFallbackStrided) { + test_channelwise_8bit_channelwise_8bit_b< + true /*a_has_zeros*/, + true /*b_has_zeros*/, + false /*a_transposed*/, + true /*b_transposed*/>:: + Run( + /*m=*/4, /*k=*/37, /*n=*/5, 7); +} + +// Test shapes for which we have to use fallback kernel +TEST( + test_channelwise_8bit_channelwise_8bit_b, + TranposedBWithZeroPointsOddSizesFallback2Strided) { + test_channelwise_8bit_channelwise_8bit_b< + true /*a_has_zeros*/, + true /*b_has_zeros*/, + false /*a_transposed*/, + true /*b_transposed*/>:: + Run( + /*m=*/4, /*k=*/2, /*n=*/1, 32); +} From 83d58e3035062e977ba444b31176c3314a351dd4 Mon Sep 17 00:00:00 2001 From: Kimish Patel Date: Thu, 3 Apr 2025 10:23:32 -0700 Subject: [PATCH 13/19] Add fallback kernel and interface for rhs only quantized matmul Differential Revision: D71370602 Pull Request resolved: https://github.com/pytorch/ao/pull/2011 --- .../matmul/fp32_a_channelwise_8bit_b_fp32_c.h | 50 +++++ .../kernels/cpu/interface/quantized_matmul.h | 70 +++++++ .../cpu/interface/test_qmatmul_interface.cpp | 182 ++++++++++++++++++ 3 files changed, 302 insertions(+) create mode 100644 torchao/experimental/kernels/cpu/fallback/matmul/fp32_a_channelwise_8bit_b_fp32_c.h diff --git a/torchao/experimental/kernels/cpu/fallback/matmul/fp32_a_channelwise_8bit_b_fp32_c.h b/torchao/experimental/kernels/cpu/fallback/matmul/fp32_a_channelwise_8bit_b_fp32_c.h new file mode 100644 index 0000000000..58e2853617 --- /dev/null +++ b/torchao/experimental/kernels/cpu/fallback/matmul/fp32_a_channelwise_8bit_b_fp32_c.h @@ -0,0 +1,50 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the license found in the +// LICENSE file in the root directory of this source tree. + +#pragma once + +#include + +// TODO: Remove all ::kernels. No need for extra namespace. +namespace torchao::kernels::cpu::fallback::quantized_matmul { +namespace fp32_a_input_channelwise_8bit_b_fp32 { +template +void kernel( + int m, + int n, + int k, + const float* lhs, + int lhs_stride_m, + const int8_t* rhs, + int rhs_stride_n, + float* output, + int out_stride_m, + const int8_t* rhs_zero_points, + const float* rhs_scales, + const float beta, + const int rhs_qparams_stride) { + assert(a_transposed == false); + for (int m_idx = 0; m_idx < m; m_idx++) { + for (int n_idx = 0; n_idx < n; n_idx++) { + float res = 0.0; + for (int k_idx = 0; k_idx < k; k_idx++) { + int lhs_idx = m_idx * lhs_stride_m + k_idx; + int rhs_idx = k_idx * rhs_stride_n + n_idx; + if (b_transposed) { + rhs_idx = n_idx * rhs_stride_n + k_idx; + } + float rhs_dequant = rhs_scales[k_idx * rhs_qparams_stride] * + (static_cast(rhs[rhs_idx]) - + static_cast(rhs_zero_points[k_idx * rhs_qparams_stride])); + + res += lhs[lhs_idx] * rhs_dequant; + } + output[m_idx * n + n_idx] = output[m_idx * n + n_idx] * beta + res; + } + } +} +} // namespace fp32_a_input_channelwise_8bit_b_fp32 +} // namespace torchao::kernels::cpu::fallback::quantized_matmul diff --git a/torchao/experimental/kernels/cpu/interface/quantized_matmul.h b/torchao/experimental/kernels/cpu/interface/quantized_matmul.h index 01a4c704c5..718f7eaad9 100644 --- a/torchao/experimental/kernels/cpu/interface/quantized_matmul.h +++ b/torchao/experimental/kernels/cpu/interface/quantized_matmul.h @@ -9,6 +9,8 @@ #include #include +#include + #if defined(__aarch64__) || defined(__ARM_NEON) #include #include @@ -85,4 +87,72 @@ get_int8_a_int8_b_channelwise_qmatmul( channelwise_8bit_a_channelwise_8bit_b::kernel; } } + +/* +a_stride_m: stride of a in memory to indiciate how far apart each row is. +b_stride_n: stride of b in memory to indiciate how far apart each row is. +If b is transposed (n x k), then this is how many bytes to skip to get to the +next row. If b is not transposed (k x n), then this is how many bytes to skip to +get to the next row. + +It also returns the stride of a and b, that should be used in the kernel. + +Will need to think of a better way to find the right +ukernel. Perhaps via ukernelconfig + registry?. +*/ +using fp32_a_input_channelwise_8bit_b_f32_c_matmul_type = void (*)( + int, + int, + int, + const float*, + int, + const int8_t*, + int, + float*, + int, + const int8_t*, + const float*, + const float, + const int); + +fp32_a_input_channelwise_8bit_b_f32_c_matmul_type +get_fp32_a_input_channelwise_8bit_b_f32_c_matmul( + int m, + int n, + int k, + bool a_transposed, + bool b_transposed, + int& a_stride_m, + int& b_stride_n); + +fp32_a_input_channelwise_8bit_b_f32_c_matmul_type +get_fp32_a_input_channelwise_8bit_b_f32_c_matmul( + int m, + int n, + int k, + bool a_transposed, + bool b_transposed, + int& a_stride_m, + int& b_stride_n) { +#if defined(__aarch64__) || defined(__ARM_NEON) + if (!a_transposed && !b_transposed && n >= 16) { + a_stride_m = k; + b_stride_n = n; + return aarch64::quantized_matmul:: + fp32_a_input_channelwise_8bit_b_1x16x4_f32::kernel; + } +#endif // defined(__aarch64__) || defined(__ARM_NEON) + assert(!a_transposed); + if (b_transposed) { + a_stride_m = k; + b_stride_n = k; + return torchao::kernels::cpu::fallback::quantized_matmul:: + fp32_a_input_channelwise_8bit_b_fp32::kernel; + } else { + a_stride_m = k; + b_stride_n = n; + return torchao::kernels::cpu::fallback::quantized_matmul:: + fp32_a_input_channelwise_8bit_b_fp32::kernel; + } +} } // namespace torchao::kernels::cpu::quantized_matmul diff --git a/torchao/experimental/kernels/cpu/interface/test_qmatmul_interface.cpp b/torchao/experimental/kernels/cpu/interface/test_qmatmul_interface.cpp index 3629f0960b..4024f3f1de 100644 --- a/torchao/experimental/kernels/cpu/interface/test_qmatmul_interface.cpp +++ b/torchao/experimental/kernels/cpu/interface/test_qmatmul_interface.cpp @@ -446,3 +446,185 @@ TEST( Run( /*m=*/4, /*k=*/2, /*n=*/1, 32); } + +class FP32A_QuantizedB_FP32C_Interface_Test + : public ::testing::TestWithParam { + public: + int m; + int k; + int n; + int stride; + + bool rhs_has_zeros; + bool lhs_is_transposed; + bool rhs_is_transposed; + + std::vector init_output; + std::vector expected_output; + + std::vector lhs; + + std::vector rhs; + std::vector rhs_qvals; + std::vector rhs_scales; + std::vector rhs_zeros; + + void generate( + int m_, + int k_, + int n_, + bool rhs_has_zeros_, + bool lhs_is_transposed_, + bool rhs_is_transposed_, + int stride_ = 1) { + assert(!lhs_is_transposed_); + assert(rhs_has_zeros_); + m = m_; + k = k_; + n = n_; + stride = stride_; + rhs_has_zeros = rhs_has_zeros_; + lhs_is_transposed = lhs_is_transposed_; + rhs_is_transposed = rhs_is_transposed_; + + assert(!rhs_is_transposed || stride == 1); + + // Generate activations + lhs = get_random_vector(m * k, -1.0, 1.0); + + // The strange thing this is doing is that instead of quantizing + // each output channel separately, we are quantizing each input channel + // Reason why we do !rhs_is_transposed is because + // we actually want k x n matrix not n x k matrix + // because each input channel is quantized separately + std::tie(rhs, rhs_qvals, rhs_scales, rhs_zeros) = + generate_per_token_quantized_tensor(k * stride, n, rhs_is_transposed); + + // Compute expected output + init_output = get_random_vector(m * n, -1.0, 1.0); + + assert(init_output.size() == m * n); + assert(lhs.size() == m * k); + assert(rhs.size() == n * stride * k); + assert(rhs_qvals.size() == n * stride * k); + assert(rhs_scales.size() == k * stride); + assert(rhs_zeros.size() == k * stride); + } + + void execute(float beta) { + // Compute expected output + expected_output = init_output; + + for (int m_idx = 0; m_idx < m; m_idx++) { + for (int n_idx = 0; n_idx < n; n_idx++) { + float res = 0.0; + for (int k_idx = 0; k_idx < k; k_idx++) { + int lhs_idx = m_idx * k + k_idx; + int rhs_idx = k_idx * stride * n + n_idx; + if (rhs_is_transposed) { + rhs_idx = n_idx * k * stride + k_idx * stride; + } + float rhs_dequant = rhs_scales[k_idx * stride] * + (static_cast(rhs_qvals[rhs_idx]) - + static_cast(rhs_zeros[k_idx * stride])); + + res += lhs[lhs_idx] * rhs_dequant; + } + expected_output[m_idx * n + n_idx] = + expected_output[m_idx * n + n_idx] * beta + res; + } + } + } + + float beta() const { + return GetParam(); + } +}; + +static void test_fp32_a_input_channelwise_8bit_b( + int m, + int k, + int n, + float beta, + FP32A_QuantizedB_FP32C_Interface_Test& test_case, + int stride = 1) { + test_case.execute(beta); + + int a_stride_m, b_stride_n; + auto kernel = torchao::kernels::cpu::quantized_matmul:: + get_fp32_a_input_channelwise_8bit_b_f32_c_matmul( + m, n, k, false, false, a_stride_m, b_stride_n); + b_stride_n = b_stride_n * stride; + + std::vector output(test_case.init_output); + kernel( + m, + n, + k, + test_case.lhs.data(), + a_stride_m /*lhs_stride_m*/, + test_case.rhs_qvals.data(), + b_stride_n /*rhs_stride_n*/, + output.data(), + n /*out_stride_n*/, + test_case.rhs_zeros.data(), + test_case.rhs_scales.data(), + beta, + stride /*rhs qparams stride*/); + + for (int i = 0; i < m * n; i++) { + EXPECT_NEAR(output[i], test_case.expected_output[i], kTol); + } +} + +TEST_P(FP32A_QuantizedB_FP32C_Interface_Test, BTranposedWithZeroPoints) { + generate(3, 128, 16, true, false, false); + test_fp32_a_input_channelwise_8bit_b( + /*m=*/3, /*k=*/128, /*n=*/16, beta(), *this); +} + +TEST_P( + FP32A_QuantizedB_FP32C_Interface_Test, + BTranposedWithZeroPointsOddSizes) { + generate(4, 37, 19, true, false, false); + test_fp32_a_input_channelwise_8bit_b( + /*m=*/4, /*k=*/37, /*n=*/19, beta(), *this); +} + +// Test shapes for which we have to use fallback kernel +TEST_P( + FP32A_QuantizedB_FP32C_Interface_Test, + BTranposedWithZeroPointsOddSizesFallback) { + generate(4, 37, 3, true, false, false); + test_fp32_a_input_channelwise_8bit_b( + /*m=*/4, /*k=*/37, /*n=*/3, beta(), *this); +} + +TEST_P( + FP32A_QuantizedB_FP32C_Interface_Test, + BTranposedWithZeroPointsOddSizes2Fallback) { + generate(4, 1, 3, true, false, false); + test_fp32_a_input_channelwise_8bit_b( + /*m=*/4, /*k=*/1, /*n=*/3, beta(), *this); +} + +TEST_P( + FP32A_QuantizedB_FP32C_Interface_Test, + BTranposedWithZeroPointsOddSizesStrided) { + generate(4, 37, 19, true, false, false, 32); + test_fp32_a_input_channelwise_8bit_b( + /*m=*/4, /*k=*/37, /*n=*/19, beta(), *this, 32); +} + +TEST_P( + FP32A_QuantizedB_FP32C_Interface_Test, + BTranposedWithZeroPointsOddSizes2FallbackStrided) { + generate(4, 5, 3, true, false, false, 32); + test_fp32_a_input_channelwise_8bit_b( + /*m=*/4, /*k=*/5, /*n=*/3, beta(), *this, 32); +} + +INSTANTIATE_TEST_SUITE_P( + F32AInt8BFP32CTest, + FP32A_QuantizedB_FP32C_Interface_Test, + ::testing::Values(0.0, 1.0, 3.1)); From 76ec450666004bf4aab9bda466da00f783932219 Mon Sep 17 00:00:00 2001 From: Scott Roy <161522778+metascroy@users.noreply.github.com> Date: Thu, 3 Apr 2025 11:16:06 -0700 Subject: [PATCH 14/19] Add KleidiAI gemm kernels (#2000) Add KleidiAI gemm kernels (#2000) Summary: This PR pulls in two new KleidiAI kernels: * kai_matmul_clamp_f32_qai8dxp1x4_qsi4c32p8x4_1x8_neon_dotprod (GEMV) * kai_matmul_clamp_f32_qai8dxp4x4_qsi4c32p8x4_4x8_neon_dotprod (GEMM) and adds them for automatic mr-based kernel selection when TORCHAO_ENABLE_ARM_NEON_DOT is set. It also adds new tests for these kernels, and refactors the kleidiai testing code so that in future new kleidiai kernels can be tested with a one line addition: ``` TEST( test_linear_8bit_act_xbit_weight, matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod) { test_linear_8bit_act_xbit_weight_kleidiai< matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod>(); } ``` The exisitng testing code (still exists for more coverage) depended on code generation. Reviewed By: Jack-Khuu Differential Revision: D72179835 --- .../workflows/torchao_experimental_test.yml | 2 +- torchao/experimental/CMakeLists.txt | 1 + .../kernels/cpu/aarch64/CMakeLists.txt | 2 +- .../kai_matmul_clamp_f32_qai8dxp_qsi4c32p.h | 11 ++- .../kernels/cpu/aarch64/tests/CMakeLists.txt | 1 + .../kernel_selector.h | 65 ++++++++++----- torchao/experimental/ops/tests/CMakeLists.txt | 1 + .../test_linear_8bit_act_xbit_weight.cpp | 82 ++++++++++++++++++- ...est_int8_dynamic_activation_intx_weight.py | 70 +++++++++++++++- 9 files changed, 207 insertions(+), 28 deletions(-) diff --git a/.github/workflows/torchao_experimental_test.yml b/.github/workflows/torchao_experimental_test.yml index 0cb470901e..2187eed8e3 100644 --- a/.github/workflows/torchao_experimental_test.yml +++ b/.github/workflows/torchao_experimental_test.yml @@ -38,7 +38,7 @@ jobs: pip install executorch pip install torch==2.7.0.dev20250311 --index-url "https://download.pytorch.org/whl/nightly/cpu" --force-reinstall pip install -r dev-requirements.txt - USE_CPP=1 TOCHAO_BUILD_KLEIDIAI=1 pip install . + USE_CPP=1 TORCHAO_BUILD_KLEIDIAI=1 pip install . - name: Run python tests run: | conda activate venv diff --git a/torchao/experimental/CMakeLists.txt b/torchao/experimental/CMakeLists.txt index f05e6b392f..e6b2a6aff0 100644 --- a/torchao/experimental/CMakeLists.txt +++ b/torchao/experimental/CMakeLists.txt @@ -40,6 +40,7 @@ include_directories(${TORCHAO_INCLUDE_DIRS}) if(TORCHAO_BUILD_CPU_AARCH64) message(STATUS "Building with cpu/aarch64") add_compile_definitions(TORCHAO_BUILD_CPU_AARCH64) + add_compile_definitions(TORCHAO_ENABLE_ARM_NEON_DOT) # Defines torchao_kernels_aarch64 add_subdirectory(kernels/cpu/aarch64) diff --git a/torchao/experimental/kernels/cpu/aarch64/CMakeLists.txt b/torchao/experimental/kernels/cpu/aarch64/CMakeLists.txt index 3cca338cbf..f38794d4a8 100644 --- a/torchao/experimental/kernels/cpu/aarch64/CMakeLists.txt +++ b/torchao/experimental/kernels/cpu/aarch64/CMakeLists.txt @@ -19,7 +19,7 @@ if (TORCHAO_BUILD_CPU_AARCH64) # intelligence (AI) workloads tailored for ArmĀ® CPUs. FetchContent_Declare(kleidiai GIT_REPOSITORY https://git.gitlab.arm.com/kleidi/kleidiai.git - GIT_TAG v1.2.0) + GIT_TAG v1.5.0) FetchContent_MakeAvailable(kleidiai) target_link_libraries(torchao_kernels_aarch64 PUBLIC kleidiai) diff --git a/torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp_qsi4c32p.h b/torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp_qsi4c32p.h index 2a8e668fa7..aa338fc165 100644 --- a/torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp_qsi4c32p.h +++ b/torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp_qsi4c32p.h @@ -14,9 +14,14 @@ #include #include +#include + +#ifdef TORCHAO_ENABLE_ARM_NEON_DOT +#include #include #include -#include +#include +#endif // TORCHAO_ENABLE_ARM_NEON_DOT #ifdef TORCHAO_ENABLE_ARM_I8MM #include @@ -297,10 +302,14 @@ size_t get_preferred_alignement() { } \ } +#ifdef TORCHAO_ENABLE_ARM_NEON_DOT DEFINE_KERNEL_STRUCT( matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod); DEFINE_KERNEL_STRUCT( matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod); +DEFINE_KERNEL_STRUCT(matmul_clamp_f32_qai8dxp1x4_qsi4c32p8x4_1x8_neon_dotprod); +DEFINE_KERNEL_STRUCT(matmul_clamp_f32_qai8dxp4x4_qsi4c32p8x4_4x8_neon_dotprod); +#endif // TORCHAO_ENABLE_ARM_NEON_DOT #ifdef TORCHAO_ENABLE_ARM_I8MM DEFINE_KERNEL_STRUCT(matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm); diff --git a/torchao/experimental/kernels/cpu/aarch64/tests/CMakeLists.txt b/torchao/experimental/kernels/cpu/aarch64/tests/CMakeLists.txt index a01afac68f..db736d84a3 100644 --- a/torchao/experimental/kernels/cpu/aarch64/tests/CMakeLists.txt +++ b/torchao/experimental/kernels/cpu/aarch64/tests/CMakeLists.txt @@ -42,6 +42,7 @@ add_subdirectory(${TORCHAO_LIBRARIES}/torchao/experimental/kernels/cpu/aarch64 $ if(TORCHAO_BUILD_KLEIDIAI) add_compile_definitions(TORCHAO_ENABLE_KLEIDI) + add_compile_definitions(TORCHAO_ENABLE_ARM_NEON_DOT) endif() if(TORCHAO_BUILD_ARM_I8MM) diff --git a/torchao/experimental/ops/linear_8bit_act_xbit_weight/kernel_selector.h b/torchao/experimental/ops/linear_8bit_act_xbit_weight/kernel_selector.h index 719c2e01e4..ffdd62f7a7 100644 --- a/torchao/experimental/ops/linear_8bit_act_xbit_weight/kernel_selector.h +++ b/torchao/experimental/ops/linear_8bit_act_xbit_weight/kernel_selector.h @@ -8,18 +8,19 @@ #include #include #include - -#if defined(TORCHAO_BUILD_CPU_AARCH64) -#include -#endif // TORCHAO_BUILD_CPU_AARCH64 - #include #include #include +#if defined(TORCHAO_BUILD_CPU_AARCH64) +#if defined(TORCHAO_ENABLE_ARM_NEON_DOT) +#include +#endif // TORCHAO_ENABLE_ARM_NEON_DOT + #if defined(TORCHAO_ENABLE_KLEIDI) #include #endif // TORCHAO_ENABLE_KLEIDI +#endif // TORCHAO_BUILD_CPU_AARCH64 namespace torchao::ops::linear_8bit_act_xbit_weight { @@ -110,7 +111,7 @@ void register_ukernel_config_universal( constexpr int mr = 1; constexpr int m_step = 1; -#if defined(TORCHAO_BUILD_CPU_AARCH64) +#if defined(TORCHAO_ENABLE_ARM_NEON_DOT) if (cpuinfo_has_arm_neon_dot()) { log_registration(format, "universal: kernel_1x8x16_f32_neondot"); auto uk = UKernelConfig::make( @@ -159,7 +160,7 @@ void register_ukernel_config_universal( return; } } -#endif // TORCHAO_BUILD_CPU_AARCH64 +#endif // TORCHAO_ENABLE_ARM_NEON_DOT } } @@ -213,18 +214,24 @@ void register_ukernel_config_kleidi( #if defined(TORCHAO_ENABLE_ARM_I8MM) if (cpuinfo_has_arm_i8mm()) { - /*m_step=4*/ + log_registration( + format, + "kleidiai: matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod, matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm"); + /*m_step=1*/ uk.linear_configs[0] = get_linear_config_kleidi< + op::matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod>( + uk.n_step, uk.nr, uk.kr, uk.sr); + + /*m_step=4*/ + uk.linear_configs[1] = get_linear_config_kleidi< op::matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm>( uk.n_step, uk.nr, uk.kr, uk.sr); - log_registration( - format, - "kleidiai: matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm"); table.register_ukernel_config(format, uarch, std::move(uk)); return; } #endif // TORCHAO_ENABLE_ARM_I8MM +#if defined(TORCHAO_ENABLE_ARM_NEON_DOT) if (cpuinfo_has_arm_neon_dot()) { log_registration( format, @@ -236,22 +243,27 @@ void register_ukernel_config_kleidi( table.register_ukernel_config(format, uarch, std::move(uk)); return; } +#endif // TORCHAO_ENABLE_ARM_NEON_DOT } - if (format.nr == 4 && format.kr == 16 && format.sr == 2) { - uk.n_step = 4; + if (format.nr == 8 && format.kr == 8 && format.sr == 2) { +#if defined(TORCHAO_ENABLE_ARM_NEON_DOT) if (cpuinfo_has_arm_neon_dot()) { - /*m_step=1*/ - uk.linear_configs[0] = get_linear_config_kleidi< - op::matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod>( - uk.n_step, uk.nr, uk.kr, uk.sr); - log_registration( format, - "kleidiai: matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod"); + "kleidiai: matmul_clamp_f32_qai8dxp1x4_qsi4c32p8x4_1x8_neon_dotprod, matmul_clamp_f32_qai8dxp4x4_qsi4c32p8x4_4x8_neon_dotprod"); + // m_step 1 + uk.linear_configs[0] = get_linear_config_kleidi< + op::matmul_clamp_f32_qai8dxp1x4_qsi4c32p8x4_1x8_neon_dotprod>( + uk.n_step, uk.nr, uk.kr, uk.sr); + // m_step 4 + uk.linear_configs[1] = get_linear_config_kleidi< + op::matmul_clamp_f32_qai8dxp4x4_qsi4c32p8x4_4x8_neon_dotprod>( + uk.n_step, uk.nr, uk.kr, uk.sr); table.register_ukernel_config(format, uarch, std::move(uk)); return; } +#endif // TORCHAO_ENABLE_ARM_NEON_DOT } } #endif // TORCHAO_ENABLE_KLEIDI @@ -325,8 +337,7 @@ PackedWeightsFormat select_packed_weights_format( #if defined(TORCHAO_ENABLE_KLEIDI) if (!target || *target == "kleidiai") { if (weight_nbit == 4 && (!has_weight_zeros)) { - // KleidiAI will pack bias with weights always, - // even if bias is not provided 0s will be packed +#if defined(TORCHAO_ENABLE_ARM_I8MM) return PackedWeightsFormat( torchao::ops::PackedWeightsType::kleidi_ai, weight_nbit, @@ -335,12 +346,23 @@ PackedWeightsFormat select_packed_weights_format( /*nr*/ 8, /*kr*/ 16, /*sr*/ 2); +#elif defined(TORCHAO_ENABLE_ARM_NEON_DOT) + return PackedWeightsFormat( + torchao::ops::PackedWeightsType::kleidi_ai, + weight_nbit, + has_weight_zeros, + has_bias, + /*nr*/ 8, + /*kr*/ 8, + /*sr*/ 2); +#endif } } #endif // defined(TORCHAO_ENABLE_KLEIDI) // Select universal format if (!target || *target == "universal") { +#if defined(TORCHAO_ENABLE_ARM_NEON_DOT) return PackedWeightsFormat( torchao::ops::PackedWeightsType::linear_8bit_act_xbit_weight_universal, weight_nbit, @@ -349,6 +371,7 @@ PackedWeightsFormat select_packed_weights_format( /*nr*/ 8, /*kr*/ 16, /*sr*/ 2); +#endif // defined(TORCHAO_ENABLE_ARM_NEON_DOT) } throw std::runtime_error("No packed_weights_format was selected"); diff --git a/torchao/experimental/ops/tests/CMakeLists.txt b/torchao/experimental/ops/tests/CMakeLists.txt index 8a9ad08f23..8245fdd746 100644 --- a/torchao/experimental/ops/tests/CMakeLists.txt +++ b/torchao/experimental/ops/tests/CMakeLists.txt @@ -24,6 +24,7 @@ enable_testing() if(TORCHAO_BUILD_CPU_AARCH64) add_compile_definitions(TORCHAO_BUILD_CPU_AARCH64=1) + add_compile_definitions(TORCHAO_ENABLE_ARM_NEON_DOT) endif() if(TORCHAO_BUILD_KLEIDIAI) diff --git a/torchao/experimental/ops/tests/test_linear_8bit_act_xbit_weight.cpp b/torchao/experimental/ops/tests/test_linear_8bit_act_xbit_weight.cpp index 980228a1a8..1d4127a43e 100644 --- a/torchao/experimental/ops/tests/test_linear_8bit_act_xbit_weight.cpp +++ b/torchao/experimental/ops/tests/test_linear_8bit_act_xbit_weight.cpp @@ -21,7 +21,7 @@ using namespace torchao::kernels::cpu::aarch64::kleidi:: #endif // TORCHAO_ENABLE_KLEIDI const float kTol = 1.0e-5; -const float kTolKleidiAI = 1.0e-2; +const float kTolKleidiAI = 5.0e-2; using namespace torchao::ops::linear_8bit_act_xbit_weight; @@ -208,6 +208,86 @@ UKernelConfig get_ukernel_config_kleidi_impl() { return ukernel_config; } +template +void test_linear_8bit_act_xbit_weight_kleidiai() { + constexpr int weight_nbit = 4; + constexpr bool has_kleidi = true; + constexpr bool has_weight_zeros = false; + constexpr bool has_bias = true; + auto uk = get_ukernel_config_kleidi_impl(); + + for (auto m : {1, 3, 4, 8, 9, 13, 21, 43, 101}) { + for (auto n : + {1, + 2, + 3, + 4, + 5, + 6, + 7, + 8, + 4 * 13, + 4 * 13 + 3, + 8 * 13, + 8 * 13 + 3, + 16 * 13, + 16 * 13 + 3}) { + for (auto k : {32, 64, 128}) { + int group_size = 32; + test_linear_8bit_act_xbit_weight< + weight_nbit, + has_weight_zeros, + has_bias, + /*has_clamp*/ true, + has_kleidi>(m, n, k, group_size, &uk); + test_linear_8bit_act_xbit_weight< + weight_nbit, + has_weight_zeros, + has_bias, + /*has_clamp*/ false, + has_kleidi>(m, n, k, group_size, &uk); + + if (k >= 64) { + group_size = 64; + test_linear_8bit_act_xbit_weight< + weight_nbit, + has_weight_zeros, + has_bias, + /*has_clamp*/ true, + has_kleidi>(m, n, k, group_size, &uk); + } + } + } + } +} + +#if defined(TORCHAO_ENABLE_ARM_NEON_DOT) +TEST( + test_linear_8bit_act_xbit_weight, + matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod) { + test_linear_8bit_act_xbit_weight_kleidiai< + matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod>(); +} +TEST( + test_linear_8bit_act_xbit_weight, + matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod) { + test_linear_8bit_act_xbit_weight_kleidiai< + matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod>(); +} +TEST( + test_linear_8bit_act_xbit_weight, + matmul_clamp_f32_qai8dxp1x4_qsi4c32p8x4_1x8_neon_dotprod) { + test_linear_8bit_act_xbit_weight_kleidiai< + matmul_clamp_f32_qai8dxp1x4_qsi4c32p8x4_1x8_neon_dotprod>(); +} +TEST( + test_linear_8bit_act_xbit_weight, + matmul_clamp_f32_qai8dxp4x4_qsi4c32p8x4_4x8_neon_dotprod) { + test_linear_8bit_act_xbit_weight_kleidiai< + matmul_clamp_f32_qai8dxp4x4_qsi4c32p8x4_4x8_neon_dotprod>(); +} +#endif // TORCHAO_ENABLE_ARM_NEON_DOT + template UKernelConfig get_ukernel_config_kleidi() { #if defined(TORCHAO_ENABLE_ARM_I8MM) diff --git a/torchao/experimental/tests/test_int8_dynamic_activation_intx_weight.py b/torchao/experimental/tests/test_int8_dynamic_activation_intx_weight.py index 098fc09696..dcd8eb74d5 100644 --- a/torchao/experimental/tests/test_int8_dynamic_activation_intx_weight.py +++ b/torchao/experimental/tests/test_int8_dynamic_activation_intx_weight.py @@ -105,6 +105,58 @@ def test_accuracy(self, layout, weight_dtype, has_weight_zeros, granularity): expected_result = quantized_model_reference(activations) self._assert_close(result, expected_result) + def test_accuracy_kleidiai(self): + n = 1071 + k = 2048 + model = torch.nn.Sequential( + *[torch.nn.Linear(k, k, bias=False), torch.nn.Linear(k, n, bias=True)] + ) + weight_dtype = torch.int4 + granularity = PerGroup(128) + has_weight_zeros = False + + # We set round_weight_scale_to_bf16 to True for accuracy testing because + # some KleidiAI kernels do this internally + round_weight_scale_to_bf16 = True + + quantized_model = copy.deepcopy(model) + quantize_( + quantized_model, + int8_dynamic_activation_intx_weight( + weight_dtype=weight_dtype, + granularity=granularity, + has_weight_zeros=has_weight_zeros, + layout=PackedLinearInt8DynamicActivationIntxWeightLayout( + target="kleidiai" + ), + round_weight_scale_to_bf16=round_weight_scale_to_bf16, + ), + ) + + quantized_model_reference = copy.deepcopy(model) + quantize_( + quantized_model_reference, + int8_dynamic_activation_intx_weight( + weight_dtype=weight_dtype, + granularity=granularity, + has_weight_zeros=has_weight_zeros, + layout=self._reference_layout(), + round_weight_scale_to_bf16=round_weight_scale_to_bf16, + ), + ) + + with torch.no_grad(): + for m in [1, 3, 5, 9, 13]: + activations = torch.randn(m, k) + result = quantized_model(activations) + expected_result = quantized_model_reference(activations) + + # KleidiAI kernels require much higher tolerance when comparing to reference, + # especially for GEMM kernels + self._assert_close( + result, expected_result, mse_tol=1e-2, atol=1e-2, rtol=1 + ) + def test_accuracy_aten(self): m = 3 n = 1024 @@ -151,9 +203,21 @@ def test_accuracy_aten(self): self._assert_close(result, expected_result) - def _assert_close(self, result, expected_result): - self.assertTrue(torch.nn.functional.mse_loss(result, expected_result) <= 1e-6) - self.assertTrue(torch.allclose(result, expected_result, atol=1e-2)) + def _assert_close( + self, result, expected_result, mse_tol=1e-6, atol=1e-2, rtol=1e-5 + ): + mse_loss = torch.nn.functional.mse_loss(result, expected_result) + self.assertTrue( + mse_loss <= mse_tol, + f"Got mse_loss={mse_loss}, above mse tolerance {mse_tol}", + ) + + n_rand_idxs = 5 + rand_idxs = torch.randint(0, result.numel(), (n_rand_idxs,)) + self.assertTrue( + torch.allclose(result, expected_result, atol=atol, rtol=rtol), + f"Failed allclose at atol={atol}, rtol={rtol}. On {n_rand_idxs} random indices, we have result={result.reshape(-1)[rand_idxs]} vs expected_result={expected_result.reshape(-1)[rand_idxs]}.", + ) def _reference_layout(self): return PlainLayout() From 0231a681af2ffb32f5ec93ea182356da806a87aa Mon Sep 17 00:00:00 2001 From: Daniel Vega-Myhre Date: Thu, 3 Apr 2025 17:00:59 -0700 Subject: [PATCH 15/19] Update float8nocompile test code to use new float8 matmul function (#2013) --- .../prototype/float8nocompile/float8nocompile_linear_test.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchao/prototype/float8nocompile/float8nocompile_linear_test.py b/torchao/prototype/float8nocompile/float8nocompile_linear_test.py index f62569cbb4..7df5ce768c 100644 --- a/torchao/prototype/float8nocompile/float8nocompile_linear_test.py +++ b/torchao/prototype/float8nocompile/float8nocompile_linear_test.py @@ -7,7 +7,7 @@ import torch from torchao.float8.config import Float8LinearConfig -from torchao.float8.float8_linear import manual_float8_matmul_with_args_in_hp +from torchao.float8.float8_linear import matmul_with_hp_or_float8_args from torchao.float8.float8_tensor import LinearMMConfig, ScaledMMConfig from torchao.prototype.float8nocompile.float8nocompile_linear import ( matmul_with_args_in_hp, @@ -72,7 +72,7 @@ def test_matmul_with_args_in_hp(input_shape: tuple[int, int]): ) # prod forward. expects transposed weight. - out_prod = manual_float8_matmul_with_args_in_hp.apply( + out_prod = matmul_with_hp_or_float8_args.apply( prod_input_bf16, prod_weight_bf16.t(), linear_mm_config, config ) From b375781a7f9cad7fa1649abf92f95374f8531c53 Mon Sep 17 00:00:00 2001 From: Daniel Vega-Myhre Date: Thu, 3 Apr 2025 20:41:37 -0700 Subject: [PATCH 16/19] Remove float8nocompile CI (#1976) remove float8nocompmile CI since it's flaky on sm89 --- .github/workflows/float8nocompile_test.yaml | 53 --------------------- 1 file changed, 53 deletions(-) delete mode 100644 .github/workflows/float8nocompile_test.yaml diff --git a/.github/workflows/float8nocompile_test.yaml b/.github/workflows/float8nocompile_test.yaml deleted file mode 100644 index b8707c148e..0000000000 --- a/.github/workflows/float8nocompile_test.yaml +++ /dev/null @@ -1,53 +0,0 @@ -name: Run Float8nocompile Tests - -on: - push: - branches: - - main - - 'gh/**' - paths: - - 'torchao/prototype/float8nocompile/**' - pull_request: - branches: - - main - - 'gh/**' - paths: - - 'torchao/prototype/float8nocompile/**' - -concurrency: - group: floatnocompile_test-${{ github.workflow }}-${{ github.ref == 'refs/heads/main' && github.run_number || github.ref }} - cancel-in-progress: true - -env: - HF_TOKEN: ${{ secrets.HF_TOKEN }} - -# jobs: -# test: -# strategy: -# fail-fast: false -# matrix: -# include: -# - name: H100 -# runs-on: linux.aws.h100 -# torch-spec: '--pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu124' -# gpu-arch-type: "cuda" -# gpu-arch-version: "12.4" - -# uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main -# with: -# timeout: 300 -# runner: ${{ matrix.runs-on }} -# gpu-arch-type: ${{ matrix.gpu-arch-type }} -# gpu-arch-version: ${{ matrix.gpu-arch-version }} -# submodules: recursive -# script: | -# conda create -n venv python=3.9 -y -# conda activate venv -# export PATH=/opt/rh/devtoolset-10/root/usr/bin/:$PATH -# python -m pip install --upgrade pip -# pip install ${{ matrix.torch-spec }} -# pip install -r dev-requirements.txt -# pip install . -# cd torchao/prototype/float8nocompile -# pytest kernels/ --verbose -s -# pytest test/train_test.py --verbose -s From 66d6a646203c8fd89eac64f5a1aacd53ba0da992 Mon Sep 17 00:00:00 2001 From: Jerry Zhang Date: Fri, 4 Apr 2025 06:20:39 -0700 Subject: [PATCH 17/19] Update clean_release_notes.py (#2014) --- scripts/clean_release_notes.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/clean_release_notes.py b/scripts/clean_release_notes.py index 2caef0735b..92ce5996cc 100644 --- a/scripts/clean_release_notes.py +++ b/scripts/clean_release_notes.py @@ -223,7 +223,7 @@ def format_commit(commit_line: str) -> str: After: * Commit title (https://github.com/pytorch/ao/pull/123) """ # Remove author, put PR link in parentheses - commit_line = re.sub(" by @.* in (.*)", r" (\\g<1>)", commit_line) + commit_line = re.sub(" by @.* in (.*)", r" (\g<1>)", commit_line) # Capitalize first letter commit_line = commit_line.lstrip("* ") commit_line = "* " + commit_line[0].upper() + commit_line[1:] From 6922733cbd06aa04a0db4b3b3f29606662f6fe75 Mon Sep 17 00:00:00 2001 From: andrewor14 Date: Fri, 4 Apr 2025 10:56:28 -0400 Subject: [PATCH 18/19] Match QAT prepare and convert numerics exactly (#1964) **Summary:** Previously, `Int8DynActInt4QATQuantizer` had slightly diverging numerics between the prepare and convert steps. This is because the prepare step uses quantization primitives shared with AQT (specifically `quantize_affine` and `dequantize_affine`), while the convert step relies on old ops from the `torch.ops.quantized_decomposed` namespace. The diverging numerics is negligible for small models, but the quantization errors begin to compound for larger models with many linear layers. More specifically, there are three different places where the divergence occurs during activation quantization: 1. **Choose qparams.** The prepare step casts the qparams to `torch.float32`, whereas the convert step casts the scales to `torch.float64` and zero points to `torch.int64`. 2. **Quantize.** The prepare step performs round before adding zero points and uses torch functions, while the convert step adds before rounding and uses torch tensor methods. ``` x = torch.clamp( torch.round(x * (1.0 / scale)) + zero_point, qmin, qmax, ) x = ( x.mul(1.0 / scale) .add(zero_point) .round() .clamp(qmin, qmax) .to(quantize_dtype) ) ``` 3. **Dequantize.** The prepare step casts to `torch.int32` before adding the zero points, and casts back to the original dtype before multiplying the scale. The convert step only casts at the very end. ``` x = x.to(torch.int32) - zero_point.to(torch.int32) x = x.to(orig_dtype) x = x * scale x = x - zero_point x = x * scale x = x.to(orig_dtype) ``` This commit makes the convert path use the same torchao quantization primitives as the prepare path, thereby resolving the 3 above differences. Now, the prepare and convert steps match exactly in terms of numerics over many trials. **Test Plan:** python test/quantization/test_qat.py -k test_fake_quantize_per_token_vs_convert python test/quantization/test_qat.py -k test_qat_8da4w_prepare_vs_convert --- test/quantization/test_qat.py | 71 +++++++++++++++++++++++++++++++++++ torchao/_executorch_ops.py | 2 + torchao/quantization/GPTQ.py | 17 +++++---- torchao/quantization/utils.py | 60 +++++++++++++++++------------ 4 files changed, 118 insertions(+), 32 deletions(-) diff --git a/test/quantization/test_qat.py b/test/quantization/test_qat.py index 3c29028898..fcd4969bbf 100644 --- a/test/quantization/test_qat.py +++ b/test/quantization/test_qat.py @@ -133,6 +133,18 @@ def forward(self, x): return x +class M4(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(512, 256, bias=False).to(torch.float) + + def example_inputs(self): + return (torch.randn(1, 512).to(torch.float),) + + def forward(self, x): + return self.linear(x) + + class ModelWithLinearBias(torch.nn.Module): def __init__(self): super().__init__() @@ -1389,6 +1401,65 @@ def test_qat_linear_bias(self): example_inputs = m.example_inputs() m(*example_inputs) + @unittest.skipIf( + not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" + ) + def test_fake_quantize_per_token_vs_convert(self): + """ + Test that the following produce the exact same numerics: + 1. FakeQuantizer with asymmetric per_token config + 2. torchao.quantization.utils.per_token_dynamic_quant + """ + from torchao.quantization.utils import per_token_dynamic_quant + + torch.manual_seed(self.SEED) + x = torch.randn(1, 235, 2048) + config = FakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False) + fake_quantizer = FakeQuantizer(config) + fake_quantizer_out = fake_quantizer(x) + baseline_out = per_token_dynamic_quant(x) + torch.testing.assert_close(fake_quantizer_out, baseline_out, atol=0, rtol=0) + + @unittest.skipIf( + not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" + ) + def test_qat_8da4w_prepare_vs_convert(self): + """ + Test that the prepare and convert steps of Int8DynActInt4QATQuantizer produces + numerics that match exactly over N trials. + """ + from torchao.quantization.qat import Int8DynActInt4WeightQATQuantizer + from torchao.quantization.utils import compute_error + + num_trials = 1000 + group_size = 16 + non_inf_sqnr = [] + + for seed in range(self.SEED, self.SEED + num_trials): + torch.manual_seed(seed) + m = M4() + torch.manual_seed(seed) + x = m.example_inputs() + + quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=group_size) + prepared = quantizer.prepare(m) + prepared_out = prepared(*x) + converted = quantizer.convert(prepared) + converted_out = converted(*x) + sqnr = compute_error(prepared_out, converted_out).item() + if sqnr != float("inf"): + non_inf_sqnr.append(sqnr) + + avg_sqnr = ( + sum(non_inf_sqnr) / len(non_inf_sqnr) if len(non_inf_sqnr) > 0 else -1 + ) + fail_message = "%s/%s trials did not match exactly, average sqnr = %s" % ( + len(non_inf_sqnr), + num_trials, + avg_sqnr, + ) + self.assertEqual(len(non_inf_sqnr), 0, fail_message) + if __name__ == "__main__": unittest.main() diff --git a/torchao/_executorch_ops.py b/torchao/_executorch_ops.py index 29339bba8c..4b761ad725 100644 --- a/torchao/_executorch_ops.py +++ b/torchao/_executorch_ops.py @@ -5,6 +5,8 @@ # LICENSE file in the root directory of this source tree. import torch +# TODO: delete these ops + def _quantized_decomposed_quantize_per_channel_group_wrapper(*args, **kwargs): """ diff --git a/torchao/quantization/GPTQ.py b/torchao/quantization/GPTQ.py index 6c63937051..63b1da440d 100644 --- a/torchao/quantization/GPTQ.py +++ b/torchao/quantization/GPTQ.py @@ -24,7 +24,10 @@ find_multiple, ) -from .quant_primitives import MappingType +from .quant_primitives import ( + MappingType, + dequantize_affine, +) from .unified import Quantizer from .utils import ( _MultiInput, @@ -940,19 +943,17 @@ def linear_forward_8da4w( n_bit = 4 quant_min = -(2 ** (n_bit - 1)) quant_max = 2 ** (n_bit - 1) - 1 - from torchao._executorch_ops import ( - _quantized_decomposed_dequantize_per_channel_group_wrapper, - ) + block_size = (1, groupsize) - w_dq = _quantized_decomposed_dequantize_per_channel_group_wrapper( + w_dq = dequantize_affine( weight_int8, + block_size, scales, zeros, + torch.int8, quant_min, quant_max, - torch.int8, - groupsize, - precision, + output_dtype=precision, ) # x = x.to(torch.float16) diff --git a/torchao/quantization/utils.py b/torchao/quantization/utils.py index 74c136ad00..b23f39c6d7 100644 --- a/torchao/quantization/utils.py +++ b/torchao/quantization/utils.py @@ -539,36 +539,48 @@ def group_quantize_tensor_symmetric( return w_int8, scales, zeros -def per_token_dynamic_quant(input: torch.Tensor) -> torch.Tensor: - orig_dtype = input.dtype - # TODO: we may need to make the choose_qparams op configurable - from torchao._executorch_ops import ( - _quantized_decomposed_choose_qparams_per_token_asymmetric_wrapper, - ) - - ( - scales, - zero_points, - ) = _quantized_decomposed_choose_qparams_per_token_asymmetric_wrapper( - input, torch.int8 - ) - - # TODO: get these from torch.int8 +def per_token_dynamic_quant( + input: torch.Tensor, + scale_dtype: torch.dtype = torch.float32, + zero_point_dtype: torch.dtype = torch.float32, +) -> torch.Tensor: + mapping_type = MappingType.ASYMMETRIC + block_size = _get_per_token_block_size(input) quant_min = -128 quant_max = 127 - from torchao._executorch_ops import _quantized_decomposed_quantize_per_token_wrapper + quant_dtype = torch.int8 + output_dtype = input.dtype - input = _quantized_decomposed_quantize_per_token_wrapper( - input, scales, zero_points, quant_min, quant_max, torch.int8 + scales, zero_points = choose_qparams_affine( + input, + mapping_type, + block_size, + quant_dtype, + quant_min, + quant_max, + scale_dtype=scale_dtype, + zero_point_dtype=zero_point_dtype, ) - from torchao._executorch_ops import ( - _quantized_decomposed_dequantize_per_token_wrapper, + q = quantize_affine( + input, + block_size, + scales, + zero_points, + quant_dtype, + quant_min, + quant_max, ) - - input = _quantized_decomposed_dequantize_per_token_wrapper( - input, scales, zero_points, quant_min, quant_max, torch.int8, orig_dtype + dq = dequantize_affine( + q, + block_size, + scales, + zero_points, + quant_dtype, + quant_min, + quant_max, + output_dtype=output_dtype, ) - return input.to(orig_dtype) + return dq def recommended_inductor_config_setter(): From faf3c0ffc654759127d7b62d6be01bf93f8d41b3 Mon Sep 17 00:00:00 2001 From: jainapurva Date: Fri, 4 Apr 2025 09:44:48 -0700 Subject: [PATCH 19/19] Update [ghstack-poisoned] --- .../microbenchmarks/test/benchmark_config.yml | 28 +++++ benchmarks/microbenchmarks/test/test_utils.py | 115 ++++++++++++++++++ benchmarks/microbenchmarks/utils.py | 110 +++++++++++++++++ 3 files changed, 253 insertions(+) diff --git a/benchmarks/microbenchmarks/test/benchmark_config.yml b/benchmarks/microbenchmarks/test/benchmark_config.yml index 227cb90948..4394d0208b 100644 --- a/benchmarks/microbenchmarks/test/benchmark_config.yml +++ b/benchmarks/microbenchmarks/test/benchmark_config.yml @@ -50,3 +50,31 @@ model_params: # device: "cpu" # model_type: "linear" # enable_profiler: true # Enable profiling for this model + + - name: "bf16_rms_norm_linear_activation" + matrix_shapes: + - name: "custom" + shapes: [ + [2048, 4096, 1024], + ] + high_precision_dtype: "torch.bfloat16" + use_torch_compile: true + torch_compile_mode: "max-autotune" + device: "cuda" + model_type: "rms_norm_linear_activation" + enable_profiler: true + enable_memory_profile: true + + - name: "bf16_transformer_block" + matrix_shapes: + - name: "custom" + shapes: [ + [2048, 4096, 1024], # For transformer_block, k is the hidden dimension + ] + high_precision_dtype: "torch.bfloat16" + use_torch_compile: true + torch_compile_mode: "max-autotune" + device: "cuda" + model_type: "transformer_block" + enable_profiler: true + enable_memory_profile: true diff --git a/benchmarks/microbenchmarks/test/test_utils.py b/benchmarks/microbenchmarks/test/test_utils.py index 14f226bd7e..46f6a74685 100644 --- a/benchmarks/microbenchmarks/test/test_utils.py +++ b/benchmarks/microbenchmarks/test/test_utils.py @@ -17,8 +17,11 @@ Float8DynamicActivationFloat8SemiSparseWeightConfig, Int4WeightOnlyConfig, LNLinearSigmoid, + RMSNorm, + RMSNormLinearActivation, SemiSparseWeightConfig, ToyLinearModel, + TransformerBlock, clean_caches, create_model_and_input, generate_results_csv, @@ -162,6 +165,61 @@ def test_ln_linear_sigmoid(self): torch.all((out >= 0) & (out <= 1)) ) # Check sigmoid output range + def test_rms_norm(self): + # Test RMSNorm + rms_norm = RMSNorm(dim=64) + x = torch.randn(16, 64) + out = rms_norm(x) + self.assertEqual(out.shape, (16, 64)) + + # Test with different eps + rms_norm = RMSNorm(dim=64, eps=1e-5) + out = rms_norm(x) + self.assertEqual(out.shape, (16, 64)) + + def test_rms_norm_linear_activation(self): + # Test with default GELU activation + model = RMSNormLinearActivation(fc_dim1=64, fc_dim2=32, dtype=torch.float32) + x = torch.randn(16, 64) + out = model(x) + self.assertEqual(out.shape, (16, 32)) + self.assertEqual(out.dtype, torch.float32) + + # Test with ReLU activation + model = RMSNormLinearActivation(fc_dim1=64, fc_dim2=32, dtype=torch.float32, activation="relu") + out = model(x) + self.assertEqual(out.shape, (16, 32)) + self.assertTrue(torch.all(out >= 0)) # Check ReLU output range + + # Test with SiLU activation + model = RMSNormLinearActivation(fc_dim1=64, fc_dim2=32, dtype=torch.float32, activation="silu") + out = model(x) + self.assertEqual(out.shape, (16, 32)) + + # Test with invalid activation + with self.assertRaises(ValueError): + RMSNormLinearActivation(fc_dim1=64, fc_dim2=32, dtype=torch.float32, activation="invalid") + + def test_transformer_block(self): + # Test with default parameters + model = TransformerBlock(hidden_dim=64, num_heads=8, mlp_ratio=4, dtype=torch.float32) + x = torch.randn(16, 16, 64) # [batch_size, seq_len, hidden_dim] + out = model(x) + self.assertEqual(out.shape, (16, 16, 64)) + self.assertEqual(out.dtype, torch.float32) + + # Test with different parameters + model = TransformerBlock(hidden_dim=128, num_heads=4, mlp_ratio=2, dtype=torch.float32) + x = torch.randn(8, 32, 128) + out = model(x) + self.assertEqual(out.shape, (8, 32, 128)) + + # Test with different head dimensions + model = TransformerBlock(hidden_dim=96, num_heads=6, mlp_ratio=3, dtype=torch.float32) + x = torch.randn(4, 8, 96) + out = model(x) + self.assertEqual(out.shape, (4, 8, 96)) + def test_create_model_and_input(self): m, k, n = 16, 64, 32 model, input_data = create_model_and_input( @@ -186,6 +244,63 @@ def test_create_model_and_input(self): self.assertIsInstance(model, LNLinearSigmoid) self.assertEqual(input_data.shape, (m, k)) + # Test RMSNormLinearActivation + model, input_data = create_model_and_input( + model_type="rms_norm_linear_activation", + m=m, + k=k, + n=n, + high_precision_dtype=torch.float32, + device="cpu", + ) + self.assertIsInstance(model, RMSNormLinearActivation) + self.assertEqual(input_data.shape, (m, k)) + + # Test TransformerBlock + model, input_data = create_model_and_input( + model_type="transformer_block", + m=m, + k=k, + n=n, # n is not used for transformer_block + high_precision_dtype=torch.float32, + device="cpu", + ) + self.assertIsInstance(model, TransformerBlock) + self.assertEqual(input_data.shape, (m, 16, k)) # [batch_size, seq_len, hidden_dim] + + def test_quantization_on_models(self): + # Test quantization on RMSNormLinearActivation + model = RMSNormLinearActivation(fc_dim1=64, fc_dim2=32, dtype=torch.float32) + x = torch.randn(16, 64) + + # Test with Int8WeightOnlyConfig + config = string_to_config(quantization="int8wo", sparsity=None) + if config is not None: + # Skip quantization test if torchao.quantization.quantize is not available + try: + from torchao.quantization import quantize + quantized_model = quantize(model, config) + out = quantized_model(x) + self.assertEqual(out.shape, (16, 32)) + except ImportError: + print("Skipping quantization test: torchao.quantization.quantize not available") + + # Test quantization on TransformerBlock + model = TransformerBlock(hidden_dim=64, num_heads=8, mlp_ratio=4, dtype=torch.float32) + x = torch.randn(16, 16, 64) + + # Test with Int8WeightOnlyConfig + config = string_to_config(quantization="int8wo", sparsity=None) + if config is not None: + # Skip quantization test if torchao.quantization.quantize is not available + try: + from torchao.quantization import quantize + quantized_model = quantize(model, config) + out = quantized_model(x) + self.assertEqual(out.shape, (16, 16, 64)) + except ImportError: + print("Skipping quantization test: torchao.quantization.quantize not available") + def test_generate_results_csv(self): results = [ BenchmarkResult( diff --git a/benchmarks/microbenchmarks/utils.py b/benchmarks/microbenchmarks/utils.py index 677f66ac75..9e978f70fa 100644 --- a/benchmarks/microbenchmarks/utils.py +++ b/benchmarks/microbenchmarks/utils.py @@ -383,6 +383,108 @@ def forward(self, x): return x +class RMSNorm(torch.nn.Module): + def __init__(self, dim, eps=1e-6, dtype=torch.bfloat16): + super().__init__() + self.eps = eps + self.weight = torch.nn.Parameter(torch.ones(dim, dtype=dtype)) + + def forward(self, x): + norm = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) + return x * norm * self.weight + + +class RMSNormLinearActivation(torch.nn.Module): + def __init__(self, fc_dim1, fc_dim2, dtype=torch.bfloat16, activation="gelu"): + super().__init__() + self.rms_norm = RMSNorm(fc_dim1, dtype=dtype) + self.fc = torch.nn.Linear(fc_dim1, fc_dim2, bias=False).to(dtype) + + if activation == "gelu": + self.activation = torch.nn.GELU() + elif activation == "relu": + self.activation = torch.nn.ReLU() + elif activation == "silu": + self.activation = torch.nn.SiLU() + else: + raise ValueError(f"Unsupported activation: {activation}") + + def forward(self, x): + x = self.rms_norm(x) + x = self.fc(x) + x = self.activation(x) + return x + + +class TransformerBlock(torch.nn.Module): + def __init__(self, hidden_dim, num_heads=8, mlp_ratio=4, dtype=torch.bfloat16): + super().__init__() + self.hidden_dim = hidden_dim + self.num_heads = num_heads + self.head_dim = hidden_dim // num_heads + + # Self-attention + self.qkv = torch.nn.Linear(hidden_dim, 3 * hidden_dim, bias=False).to(dtype) + self.proj = torch.nn.Linear(hidden_dim, hidden_dim, bias=False).to(dtype) + + # MLP + self.mlp_ratio = mlp_ratio + self.mlp_hidden_dim = int(hidden_dim * mlp_ratio) + self.mlp_fc1 = torch.nn.Linear(hidden_dim, self.mlp_hidden_dim, bias=False).to(dtype) + self.mlp_fc2 = torch.nn.Linear(self.mlp_hidden_dim, hidden_dim, bias=False).to(dtype) + + # Layer norms + self.norm1 = RMSNorm(hidden_dim, dtype=dtype) + self.norm2 = RMSNorm(hidden_dim, dtype=dtype) + + # Activation + self.activation = torch.nn.GELU() + + def forward(self, x): + batch_size, seq_len, _ = x.shape + + # Self-attention + residual = x + x = self.norm1(x) + + # Reshape qkv projection for better memory layout + qkv = self.qkv(x) # [batch_size, seq_len, 3 * hidden_dim] + qkv = qkv.reshape(batch_size, seq_len, 3, self.num_heads, self.head_dim) + qkv = qkv.permute(2, 0, 3, 1, 4) # [3, batch_size, num_heads, seq_len, head_dim] + q, k, v = qkv # Each has shape [batch_size, num_heads, seq_len, head_dim] + + # Scaled dot-product attention with proper reshaping + # Reshape for better memory layout and avoid broadcasting issues + q = q.reshape(batch_size * self.num_heads, seq_len, self.head_dim) + k = k.reshape(batch_size * self.num_heads, seq_len, self.head_dim) + v = v.reshape(batch_size * self.num_heads, seq_len, self.head_dim) + + # Compute attention scores + attn = (q @ k.transpose(-2, -1)) * (1.0 / (self.head_dim ** 0.5)) + attn = torch.softmax(attn, dim=-1) + + # Apply attention to values + x = attn @ v # [batch_size * num_heads, seq_len, head_dim] + + # Reshape back to original dimensions + x = x.reshape(batch_size, self.num_heads, seq_len, self.head_dim) + x = x.transpose(1, 2).reshape(batch_size, seq_len, self.hidden_dim) + + # Project back to hidden dimension + x = self.proj(x) + x = residual + x + + # MLP + residual = x + x = self.norm2(x) + x = self.mlp_fc1(x) + x = self.activation(x) + x = self.mlp_fc2(x) + x = residual + x + + return x + + def string_to_config( quantization: Optional[str], sparsity: Optional[str], **kwargs ) -> AOBaseConfig: @@ -576,6 +678,14 @@ def create_model_and_input( elif model_type == "ln_linear_sigmoid": model = LNLinearSigmoid(k, n, high_precision_dtype).to(device) input_data = torch.randn(m, k, device=device, dtype=high_precision_dtype) + elif model_type == "rms_norm_linear_activation": + model = RMSNormLinearActivation(k, n, high_precision_dtype).to(device) + input_data = torch.randn(m, k, device=device, dtype=high_precision_dtype) + elif model_type == "transformer_block": + # For transformer block, k is the hidden dimension + model = TransformerBlock(k, num_heads=8, mlp_ratio=4, dtype=high_precision_dtype).to(device) + # Input shape for transformer is [batch_size, seq_len, hidden_dim] + input_data = torch.randn(m, 16, k, device=device, dtype=high_precision_dtype) else: raise ValueError(f"Unknown model type: {model_type}") return model, input_data