diff --git a/torchao/experimental/kernels/cpu/aarch64/tests/test_qmatmul.cpp b/torchao/experimental/kernels/cpu/aarch64/tests/test_qmatmul.cpp
index 05dbf13aac..ff4f915b2d 100644
--- a/torchao/experimental/kernels/cpu/aarch64/tests/test_qmatmul.cpp
+++ b/torchao/experimental/kernels/cpu/aarch64/tests/test_qmatmul.cpp
@@ -70,6 +70,7 @@ struct test_channelwise_8bit_channelwise_8bit_b<
     false,
     false> {
   static void Run(int m, int k, int n, int stride = 1) {
+    // TODO: make use of stride for this kernel
    auto test_case =
        torchao::channelwise_8bit_a_channelwise_8bit_b_qmatmul_test_case::
            generate(m, k, n, a_has_zeros, a_has_zeros, false, false);
diff --git a/torchao/experimental/kernels/cpu/fallback/matmul/channelwise_8bit_a_channelwise_8bit_b.h b/torchao/experimental/kernels/cpu/fallback/matmul/channelwise_8bit_a_channelwise_8bit_b.h
new file mode 100644
index 0000000000..3b070eb2b3
--- /dev/null
+++ b/torchao/experimental/kernels/cpu/fallback/matmul/channelwise_8bit_a_channelwise_8bit_b.h
@@ -0,0 +1,133 @@
+// Copyright (c) Meta Platforms, Inc. and affiliates.
+// All rights reserved.
+//
+// This source code is licensed under the license found in the
+// LICENSE file in the root directory of this source tree.
+
+#pragma once
+
+#include <cstdint>
+
+namespace torchao::kernels::cpu::fallback::quantized_matmul {
+namespace channelwise_8bit_a_channelwise_8bit_b::internal {
+
+template <
+    bool a_has_zeros,
+    bool b_has_zeros,
+    bool a_transposed,
+    bool b_transposed>
+struct KernelImpl {
+  static void run(
+      int m,
+      int n,
+      int k,
+      const void* lhs,
+      int lhs_stride_m,
+      const void* rhs,
+      int rhs_stride_n,
+      float* output,
+      int out_stride_m,
+      const int8_t* lhs_zero_points,
+      const int8_t* rhs_zero_points,
+      const float* lhs_scales,
+      const float* rhs_scales,
+      const int lhs_qparams_stride,
+      const int rhs_qparams_stride);
+};
+
+template <bool b_transposed>
+struct KernelImpl<true, true, false, b_transposed> {
+  static void run(
+      int m,
+      int n,
+      int k,
+      const void* lhs,
+      int lhs_stride_m,
+      const void* rhs,
+      int rhs_stride_n,
+      float* output,
+      int out_stride_m,
+      const int8_t* lhs_zero_points,
+      const int8_t* rhs_zero_points,
+      const float* lhs_scales,
+      const float* rhs_scales,
+      const int lhs_qparams_stride,
+      const int rhs_qparams_stride) {
+    const int8_t* lhs_qvals = static_cast<const int8_t*>(lhs);
+    const int8_t* rhs_qvals = static_cast<const int8_t*>(rhs);
+    for (int m_idx = 0; m_idx < m; m_idx++) {
+      for (int n_idx = 0; n_idx < n; n_idx++) {
+        float res = 0.0;
+        for (int k_idx = 0; k_idx < k; k_idx++) {
+          int lhs_idx = m_idx * lhs_stride_m + k_idx;
+          int rhs_idx = k_idx * rhs_stride_n + n_idx;
+          if (b_transposed) {
+            rhs_idx = n_idx * rhs_stride_n + k_idx;
+          }
+
+          float lhs_dequant = lhs_scales[m_idx * lhs_qparams_stride] *
+              (static_cast<float>(lhs_qvals[lhs_idx]) -
+               static_cast<float>(
+                   lhs_zero_points[m_idx * lhs_qparams_stride]));
+
+          float rhs_dequant = rhs_scales[n_idx * rhs_qparams_stride] *
+              (static_cast<float>(rhs_qvals[rhs_idx]) -
+               static_cast<float>(
+                   rhs_zero_points[n_idx * rhs_qparams_stride]));
+
+          res += lhs_dequant * rhs_dequant;
+        }
+        output[m_idx * n + n_idx] = res;
+      }
+    }
+  }
+};
+
+} // namespace
+  // channelwise_8bit_a_channelwise_8bit_b::internal
+} // namespace torchao::kernels::cpu::fallback::quantized_matmul
+
+// TODO: Remove all ::kernels. No need for extra namespace.
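The fallback above is the plain reference computation: dequantize both operands and accumulate in fp32. With q, s, z denoting quantized values, scales, and zero points (one pair per LHS row and per RHS output column), it computes

    output[i][j] = sum_k s_lhs[i] * (q_lhs[i][k] - z_lhs[i]) *
                         s_rhs[j] * (q_rhs[k][j] - z_rhs[j])

where q_rhs is indexed as q_rhs[j][k] when b_transposed is true.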
+namespace torchao::kernels::cpu::fallback::quantized_matmul {
+namespace channelwise_8bit_a_channelwise_8bit_b {
+template <
+    bool a_has_zeros,
+    bool b_has_zeros,
+    bool a_transposed,
+    bool b_transposed>
+void kernel(
+    int m,
+    int n,
+    int k,
+    const void* lhs,
+    int lhs_stride_m,
+    const void* rhs,
+    int rhs_stride_n,
+    float* output,
+    int out_stride_m,
+    const int8_t* lhs_zero_points,
+    const int8_t* rhs_zero_points,
+    const float* lhs_scales,
+    const float* rhs_scales,
+    const int lhs_qparams_stride,
+    const int rhs_qparams_stride) {
+  channelwise_8bit_a_channelwise_8bit_b::internal::
+      KernelImpl<a_has_zeros, b_has_zeros, a_transposed, b_transposed>::run(
+          m,
+          n,
+          k,
+          lhs,
+          lhs_stride_m,
+          rhs,
+          rhs_stride_n,
+          output,
+          out_stride_m,
+          lhs_zero_points,
+          rhs_zero_points,
+          lhs_scales,
+          rhs_scales,
+          lhs_qparams_stride,
+          rhs_qparams_stride);
+}
+} // namespace channelwise_8bit_a_channelwise_8bit_b
+} // namespace torchao::kernels::cpu::fallback::quantized_matmul
diff --git a/torchao/experimental/kernels/cpu/interface/quantized_matmul.h b/torchao/experimental/kernels/cpu/interface/quantized_matmul.h
new file mode 100644
index 0000000000..01a4c704c5
--- /dev/null
+++ b/torchao/experimental/kernels/cpu/interface/quantized_matmul.h
@@ -0,0 +1,88 @@
+// Copyright (c) Meta Platforms, Inc. and affiliates.
+// All rights reserved.
+//
+// This source code is licensed under the license found in the
+// LICENSE file in the root directory of this source tree.
+
+#pragma once
+
+#include <cassert>
+
+#include <torchao/experimental/kernels/cpu/fallback/matmul/channelwise_8bit_a_channelwise_8bit_b.h>
+#if defined(__aarch64__) || defined(__ARM_NEON)
+#include <torchao/experimental/kernels/cpu/aarch64/matmul/channelwise_8bit_a_channelwise_8bit_b_1x16x16_f32_smlal-impl.h>
+#include <torchao/experimental/kernels/cpu/aarch64/matmul/channelwise_8bit_a_channelwise_8bit_b_1x8x16_f32_neondot-impl.h>
+#include <torchao/experimental/kernels/cpu/aarch64/matmul/matmul.h>
+#endif // defined(__aarch64__) || defined(__ARM_NEON)
+
+namespace torchao::kernels::cpu::quantized_matmul {
+
+/*
+a_stride_m: stride of a in memory to indicate how far apart each row is.
+b_stride_n: stride of b in memory to indicate how far apart each row is.
+If b is transposed (n x k), then this is how many bytes to skip to get to the
+next row. If b is not transposed (k x n), then this is likewise how many bytes
+to skip to get to the next row; in both layouts it is the row stride of b as
+stored in memory.
+
+The getter below also returns, via a_stride_m and b_stride_n, the strides of a
+and b that should be passed to the returned kernel.
+
+Will need to think of a better way to find the right
+ukernel. Perhaps via ukernelconfig + registry?
+*/
+using int8_a_int8_b_channelwise_fp32_c_qmatmul_type = void (*)(
+    int,
+    int,
+    int,
+    const void*,
+    int,
+    const void*,
+    int,
+    float*,
+    int,
+    const int8_t*,
+    const int8_t*,
+    const float*,
+    const float*,
+    const int,
+    const int);
+
+int8_a_int8_b_channelwise_fp32_c_qmatmul_type
+get_int8_a_int8_b_channelwise_qmatmul(
+    int m,
+    int n,
+    int k,
+    bool a_transposed,
+    bool b_transposed,
+    int& a_stride_m,
+    int& b_stride_n);
+
+int8_a_int8_b_channelwise_fp32_c_qmatmul_type
+get_int8_a_int8_b_channelwise_qmatmul(
+    int m,
+    int n,
+    int k,
+    bool a_transposed,
+    bool b_transposed,
+    int& a_stride_m,
+    int& b_stride_n) {
+#if defined(__aarch64__) || defined(__ARM_NEON)
+  if (!a_transposed && b_transposed && n >= 8) {
+    a_stride_m = k;
+    b_stride_n = k;
+    return aarch64::quantized_matmul::
+        channelwise_8bit_a_channelwise_8bit_b_1x8x16_f32_neondot::
+            kernel<true, true, false, true>;
+  }
+#endif // defined(__aarch64__) || defined(__ARM_NEON)
+  assert(!a_transposed);
+  if (b_transposed) {
+    a_stride_m = k;
+    b_stride_n = k;
+    return torchao::kernels::cpu::fallback::quantized_matmul::
+        channelwise_8bit_a_channelwise_8bit_b::kernel<true, true, false, true>;
+  } else {
+    return torchao::kernels::cpu::fallback::quantized_matmul::
+        channelwise_8bit_a_channelwise_8bit_b::kernel<true, true, false, false>;
+  }
+}
+} // namespace torchao::kernels::cpu::quantized_matmul
diff --git a/torchao/experimental/kernels/cpu/interface/test_qmatmul_interface.cpp b/torchao/experimental/kernels/cpu/interface/test_qmatmul_interface.cpp
new file mode 100644
index 0000000000..3629f0960b
--- /dev/null
+++ b/torchao/experimental/kernels/cpu/interface/test_qmatmul_interface.cpp
@@ -0,0 +1,448 @@
+// Copyright (c) Meta Platforms, Inc. and affiliates.
+// All rights reserved.
+//
+// This source code is licensed under the license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <algorithm>
+#include <cfenv>
+#include <cmath>
+#include <functional>
+#include <random>
+#include <vector>
+
+#include <gtest/gtest.h>
+#include <torchao/experimental/kernels/cpu/interface/quantized_matmul.h>
+
+float kTol = 0.0001;
+
+// This unfortunately had to be copied over because code in test_utils.h
+// depends on quantization kernels which are only buildable for ARM.
+// I would like the testing code in this folder to be independent of the arch.
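As a usage illustration of the interface above (a minimal sketch only; the function and buffer names here are hypothetical and not part of the diff, assuming row-major int8 A of shape m x k, int8 B already transposed to n x k, precomputed per-row/per-column quantization parameters, and that the quantized_matmul.h header is included):

// Hypothetical example, not part of the patch.
void qmatmul_example(
    int m, int n, int k,
    const int8_t* a_qvals, const float* a_scales, const int8_t* a_zeros,
    const int8_t* bt_qvals, const float* b_scales, const int8_t* b_zeros,
    float* out /* m x n */) {
  int a_stride_m, b_stride_n;
  // Pick a ukernel for this shape; the getter also reports the strides to use.
  auto kernel = torchao::kernels::cpu::quantized_matmul::
      get_int8_a_int8_b_channelwise_qmatmul(
          m, n, k, /*a_transposed=*/false, /*b_transposed=*/true,
          a_stride_m, b_stride_n);
  kernel(
      m, n, k,
      a_qvals, a_stride_m,
      bt_qvals, b_stride_n,
      out, /*out_stride_m=*/n,
      a_zeros, b_zeros,
      a_scales, b_scales,
      /*lhs_qparams_stride=*/1, /*rhs_qparams_stride=*/1);
}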
+namespace {
+void get_qvals_range(int& qmin, int& qmax, int nbit, bool is_symmetric) {
+  if (is_symmetric) {
+    qmin = -(1 << (nbit - 1)) + 1;
+    qmax = -qmin;
+  } else {
+    qmin = -(1 << (nbit - 1));
+    qmax = (1 << (nbit - 1)) - 1;
+  }
+}
+
+void get_scale_and_zero(
+    float& scale,
+    int& zero,
+    float vmin,
+    float vmax,
+    int qmin,
+    int qmax) {
+  assert(qmin < qmax);
+  assert(vmin < vmax);
+  scale = (vmax - vmin) / (qmax - qmin);
+  zero = qmin - std::round(vmin / scale);
+}
+
+inline std::vector<float>
+get_random_vector(int size, float min = -1.0, float max = 1.0) {
+  assert(min < max);
+  std::random_device random_device;
+  auto rng = std::mt19937(random_device());
+  auto dist = std::bind(std::uniform_real_distribution<float>(min, max), rng);
+  std::vector<float> res(size);
+  std::generate(res.begin(), res.end(), std::ref(dist));
+  return res;
+}
+
+void quantize(
+    // Output
+    int8_t* qvals,
+    // Inputs
+    const float* vals,
+    int size,
+    float scale,
+    int8_t zero,
+    int8_t qmin,
+    int8_t qmax) {
+  float invScale = 1.0 / (scale + 1e-16);
+  int i = 0;
+  auto curr_rounding_mode = fegetround();
+  fesetround(FE_TONEAREST);
+  for (; i < size; ++i) {
+    // Quantize remaining elements using scalar code
+    float val = vals[i];
+    float qval_f32 = zero + val * invScale;
+    int32_t qval_s32 = static_cast<int32_t>(std::nearbyint(qval_f32));
+
+    // Clip to qmin and qmax
+    qval_s32 = std::max(
+        static_cast<int32_t>(qmin),
+        std::min(qval_s32, static_cast<int32_t>(qmax)));
+
+    // Store the quantized value
+    qvals[i] = static_cast<int8_t>(qval_s32);
+  }
+  fesetround(int(curr_rounding_mode));
+}
+
+auto generate_per_token_quantized_tensor(
+    int m,
+    int n,
+    bool transposed = false) {
+  auto activations = get_random_vector(m * n, -1.0, 1.0);
+  auto activation_qvals = std::vector<int8_t>(m * n, 0);
+  auto activation_scales = std::vector<float>(m, 0);
+  auto activation_zeros = std::vector<int8_t>(m, 0);
+
+  // Quantize activations with 8-bit asymmetric
+  // TODO: replace with generic function that does not use aarch64
+  // quantize method after we combine with torchao
+  int qmin, qmax, zero;
+  float vmin, vmax, scale;
+  get_qvals_range(qmin, qmax, /*nbit=*/8, /*is_symmetric=*/false);
+  for (int m_idx = 0; m_idx < m; m_idx++) {
+    auto minmax = std::minmax_element(
+        activations.data() + m_idx * n, activations.data() + (m_idx + 1) * n);
+    vmin = *minmax.first;
+    vmax = *minmax.second;
+    get_scale_and_zero(scale, zero, vmin, vmax, qmin, qmax);
+    activation_scales[m_idx] = scale;
+    activation_zeros[m_idx] = zero;
+    quantize(
+        /*qvals=*/activation_qvals.data() + m_idx * n,
+        /*vals=*/activations.data() + m_idx * n,
+        /*size=*/n,
+        scale,
+        zero,
+        qmin,
+        qmax);
+  }
+
+  if (transposed) {
+    auto activations_t = std::vector<float>(m * n, 0);
+    auto activation_qvals_t = std::vector<int8_t>(m * n, 0);
+    for (int m_idx = 0; m_idx < m; m_idx++) {
+      for (int n_idx = 0; n_idx < n; n_idx++) {
+        int activation_idx = m_idx * n + n_idx;
+        int transposed_idx = n_idx * m + m_idx;
+        activations_t[transposed_idx] = activations[activation_idx];
+        activation_qvals_t[transposed_idx] = activation_qvals[activation_idx];
+      }
+    }
+    activations = activations_t;
+    activation_qvals = activation_qvals_t;
+  }
+
+  return std::make_tuple(
+      activations, activation_qvals, activation_scales, activation_zeros);
+}
+
+struct channelwise_8bit_a_channelwise_8bit_b_qmatmul_test_case {
+  int m;
+  int k;
+  int n;
+  int stride;
+
+  bool lhs_has_zeros;
+  bool rhs_has_zeros;
+  bool lhs_is_transposed;
+  bool rhs_is_transposed;
+
+  std::vector<float> expected_output;
+
+  std::vector<float> lhs;
+  std::vector<int8_t> lhs_qvals;
+  std::vector<float> lhs_scales;
+  std::vector<int8_t> lhs_zeros;
+
+  std::vector<float> rhs;
+  std::vector<int8_t> rhs_qvals;
+  std::vector<float> rhs_scales;
+  std::vector<int8_t> rhs_zeros;
+
+  channelwise_8bit_a_channelwise_8bit_b_qmatmul_test_case(
+      int m_,
+      int k_,
+      int n_,
+      int stride_,
+      bool lhs_has_zeros_,
+      bool rhs_has_zeros_,
+      bool lhs_is_transposed_,
+      bool rhs_is_transposed_,
+      std::vector<float> expected_output_,
+      std::vector<float> lhs_,
+      std::vector<int8_t> lhs_qvals_,
+      std::vector<float> lhs_scales_,
+      std::vector<int8_t> lhs_zeros_,
+      std::vector<float> rhs_,
+      std::vector<int8_t> rhs_qvals_,
+      std::vector<float> rhs_scales_,
+      std::vector<int8_t> rhs_zeros_)
+      : m(m_),
+        k(k_),
+        n(n_),
+        stride(stride_),
+        lhs_has_zeros(lhs_has_zeros_),
+        rhs_has_zeros(rhs_has_zeros_),
+        lhs_is_transposed(lhs_is_transposed_),
+        rhs_is_transposed(rhs_is_transposed_),
+        expected_output(expected_output_),
+        lhs(lhs_),
+        lhs_qvals(lhs_qvals_),
+        lhs_scales(lhs_scales_),
+        lhs_zeros(lhs_zeros_),
+        rhs(rhs_),
+        rhs_qvals(rhs_qvals_),
+        rhs_scales(rhs_scales_),
+        rhs_zeros(rhs_zeros_) {
+    assert(expected_output.size() == m * n);
+    assert(lhs.size() == m * stride * k);
+    assert(lhs_qvals.size() == m * stride * k);
+    assert(lhs_scales.size() == m * stride);
+    assert(lhs_zeros.size() == m * stride);
+    assert(rhs.size() == n * stride * k);
+    assert(rhs_qvals.size() == n * stride * k);
+    assert(rhs_scales.size() == n * stride);
+    assert(rhs_zeros.size() == n * stride);
+  }
+
+  static channelwise_8bit_a_channelwise_8bit_b_qmatmul_test_case generate(
+      int m,
+      int k,
+      int n,
+      bool lhs_has_zeros,
+      bool rhs_has_zeros,
+      bool lhs_is_transposed,
+      // rhs_is_transposed means the generated b matrix is n x k instead of
+      // k x n
+      bool rhs_is_transposed,
+      int stride = 1) {
+    assert(!lhs_is_transposed);
+    assert(lhs_has_zeros);
+    assert(rhs_has_zeros);
+    assert(rhs_is_transposed || stride == 1);
+    // Generate activations
+    auto [lhs, lhs_qvals, lhs_scales, lhs_zeros] =
+        generate_per_token_quantized_tensor(m * stride, k);
+
+    auto [rhs, rhs_qvals, rhs_scales, rhs_zeros] =
+        generate_per_token_quantized_tensor(n * stride, k, !rhs_is_transposed);
+    // The function above produces an n x k matrix; to get k x n you need
+    // transposed = true. We pass !rhs_is_transposed because when
+    // rhs_is_transposed = true the shape should be n x k instead of k x n.
+
+    // Compute expected output
+    std::vector<float> expected_output(m * n);
+
+    for (int m_idx = 0; m_idx < m; m_idx++) {
+      for (int n_idx = 0; n_idx < n; n_idx++) {
+        float res = 0.0;
+        for (int k_idx = 0; k_idx < k; k_idx++) {
+          int lhs_idx = m_idx * stride * k + k_idx;
+          int rhs_idx = k_idx * stride * n + n_idx * stride;
+          if (rhs_is_transposed) {
+            rhs_idx = n_idx * stride * k + k_idx;
+          }
+          float lhs_dequant = lhs_scales[m_idx * stride] *
+              (lhs_qvals[lhs_idx] - lhs_zeros[m_idx * stride]);
+
+          float rhs_dequant = rhs_scales[n_idx * stride] *
+              (rhs_qvals[rhs_idx] - rhs_zeros[n_idx * stride]);
+
+          res += lhs_dequant * rhs_dequant;
+        }
+        expected_output[m_idx * n + n_idx] = res;
+      }
+    }
+
+    // Return test case
+    return channelwise_8bit_a_channelwise_8bit_b_qmatmul_test_case(
+        m,
+        k,
+        n,
+        stride,
+        lhs_has_zeros,
+        rhs_has_zeros,
+        lhs_is_transposed,
+        rhs_is_transposed,
+        expected_output,
+        lhs,
+        lhs_qvals,
+        lhs_scales,
+        lhs_zeros,
+        rhs,
+        rhs_qvals,
+        rhs_scales,
+        rhs_zeros);
+  }
+};
+} // namespace
+
+template <
+    bool a_has_zeros,
+    bool b_has_zeros,
+    bool a_transposed,
+    bool b_transposed>
+struct test_channelwise_8bit_channelwise_8bit_b {
+  static void Run(int m, int k, int n);
+};
+
+template <bool a_has_zeros, bool b_has_zeros>
+struct test_channelwise_8bit_channelwise_8bit_b<
+    a_has_zeros,
+    b_has_zeros,
+    false,
+    true> {
+  static void Run(int m, int k, int n, int stride = 1) {
+    auto test_case =
+        channelwise_8bit_a_channelwise_8bit_b_qmatmul_test_case::generate(
+            m, k, n, a_has_zeros, a_has_zeros, false, true, stride);
+
+    int a_stride_m, b_stride_n;
+    auto kernel = torchao::kernels::cpu::quantized_matmul::
+        get_int8_a_int8_b_channelwise_qmatmul(
+            m, n, k, false, true, a_stride_m, b_stride_n);
+    a_stride_m = a_stride_m * stride;
+    b_stride_n = b_stride_n * stride;
+
+    std::vector<float> output(m * n);
+    kernel(
+        m,
+        n,
+        k,
+        test_case.lhs_qvals.data(),
+        a_stride_m /*lhs_stride_m*/,
+        test_case.rhs_qvals.data(),
+        b_stride_n /*rhs_stride_n*/,
+        output.data(),
+        n /*out_stride_m*/,
+        test_case.lhs_zeros.data(),
+        test_case.rhs_zeros.data(),
+        test_case.lhs_scales.data(),
+        test_case.rhs_scales.data(),
+        stride, /*lhs qparams stride*/
+        stride /*rhs qparams stride*/);
+
+    for (int i = 0; i < m * n; i++) {
+      EXPECT_NEAR(output[i], test_case.expected_output[i], kTol);
+    }
+  }
+};
+
+TEST(test_channelwise_8bit_channelwise_8bit_b, TranposedBWithZeroPoints) {
+  test_channelwise_8bit_channelwise_8bit_b<
+      true /*a_has_zeros*/,
+      true /*b_has_zeros*/,
+      false /*a_transposed*/,
+      true /*b_transposed*/>::
+      Run(
+          /*m=*/1, /*k=*/128, /*n=*/16);
+}
+
+TEST(test_channelwise_8bit_channelwise_8bit_b, TranposeBWithZeroPointsLargeM) {
+  test_channelwise_8bit_channelwise_8bit_b<
+      true /*a_has_zeros*/,
+      true /*b_has_zeros*/,
+      false /*a_transposed*/,
+      true /*b_transposed*/>::
+      Run(
+          /*m=*/4, /*k=*/128, /*n=*/16);
+}
+
+TEST(
+    test_channelwise_8bit_channelwise_8bit_b,
+    TranposedBWithZeroPointsOddSizes) {
+  test_channelwise_8bit_channelwise_8bit_b<
+      true /*a_has_zeros*/,
+      true /*b_has_zeros*/,
+      false /*a_transposed*/,
+      true /*b_transposed*/>::
+      Run(
+          /*m=*/4, /*k=*/37, /*n=*/24);
+}
+
+TEST(
+    test_channelwise_8bit_channelwise_8bit_b,
+    TranposedBWithZeroPointsOddSizes2) {
+  test_channelwise_8bit_channelwise_8bit_b<
+      true /*a_has_zeros*/,
+      true /*b_has_zeros*/,
+      false /*a_transposed*/,
+      true /*b_transposed*/>::
+      Run(
+          /*m=*/4, /*k=*/37, /*n=*/19);
+}
+
+// Test shapes for which we have to use fallback kernel
+TEST(
+    test_channelwise_8bit_channelwise_8bit_b,
+    TranposedBWithZeroPointsOddSizesFallback) {
+  test_channelwise_8bit_channelwise_8bit_b<
+      true /*a_has_zeros*/,
+      true /*b_has_zeros*/,
+      false /*a_transposed*/,
+      true /*b_transposed*/>::
+      Run(
+          /*m=*/4, /*k=*/37, /*n=*/5);
+}
+
+// Test shapes for which we have to use fallback kernel
+TEST(
+    test_channelwise_8bit_channelwise_8bit_b,
+    TranposedBWithZeroPointsOddSizesFallback2) {
+  test_channelwise_8bit_channelwise_8bit_b<
+      true /*a_has_zeros*/,
+      true /*b_has_zeros*/,
+      false /*a_transposed*/,
+      true /*b_transposed*/>::
+      Run(
+          /*m=*/4, /*k=*/2, /*n=*/1);
+}
+
+TEST(
+    test_channelwise_8bit_channelwise_8bit_b,
+    TranposeBWithZeroPointsLargeMStrided) {
+  test_channelwise_8bit_channelwise_8bit_b<
+      true /*a_has_zeros*/,
+      true /*b_has_zeros*/,
+      false /*a_transposed*/,
+      true /*b_transposed*/>::
+      Run(
+          /*m=*/4, /*k=*/128, /*n=*/16, 5);
+}
+
+TEST(
+    test_channelwise_8bit_channelwise_8bit_b,
+    TranposedBWithZeroPointsOddSizes2Strided) {
+  test_channelwise_8bit_channelwise_8bit_b<
+      true /*a_has_zeros*/,
+      true /*b_has_zeros*/,
+      false /*a_transposed*/,
+      true /*b_transposed*/>::
+      Run(
+          /*m=*/4, /*k=*/37, /*n=*/19, 16);
+}
+
+// Test shapes for which we have to use fallback kernel
+TEST(
+    test_channelwise_8bit_channelwise_8bit_b,
+    TranposedBWithZeroPointsOddSizesFallbackStrided) {
+  test_channelwise_8bit_channelwise_8bit_b<
+      true /*a_has_zeros*/,
+      true /*b_has_zeros*/,
+      false /*a_transposed*/,
+      true /*b_transposed*/>::
+      Run(
+          /*m=*/4, /*k=*/37, /*n=*/5, 7);
+}
+
+// Test shapes for which we have to use fallback kernel
+TEST(
+    test_channelwise_8bit_channelwise_8bit_b,
+    TranposedBWithZeroPointsOddSizesFallback2Strided) {
+  test_channelwise_8bit_channelwise_8bit_b<
+      true /*a_has_zeros*/,
+      true /*b_has_zeros*/,
+      false /*a_transposed*/,
+      true /*b_transposed*/>::
+      Run(
+          /*m=*/4, /*k=*/2, /*n=*/1, 32);
+}
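A note on the *Strided tests above, as a worked illustration of the indexing already present in the kernel and in Run(): generate(m, k, n, ..., stride) allocates (m * stride) x k activations and (n * stride) x k weights, Run() multiplies the row strides returned by the getter by stride, and passes stride as the qparams stride. The kernel therefore only reads every stride-th row and every stride-th (scale, zero point) pair; for output element (m_idx, n_idx) with b transposed it reads

    lhs_qvals[m_idx * k * stride + k_idx], lhs_scales[m_idx * stride], lhs_zeros[m_idx * stride]
    rhs_qvals[n_idx * k * stride + k_idx], rhs_scales[n_idx * stride], rhs_zeros[n_idx * stride]

since both row strides returned by the getter are k for the transposed-B path.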