add fallback kernel and interface #2010

Merged · 1 commit · Apr 3, 2025

@@ -70,6 +70,7 @@ struct test_channelwise_8bit_channelwise_8bit_b<
false,
false> {
static void Run(int m, int k, int n, int stride = 1) {
// TODO: make use of stride for this kernel
auto test_case =
torchao::channelwise_8bit_a_channelwise_8bit_b_qmatmul_test_case::
generate(m, k, n, a_has_zeros, a_has_zeros, false, false);

133 changes: 133 additions & 0 deletions torchao/experimental/kernels/cpu/fallback/matmul/channelwise_8bit_a_channelwise_8bit_b.h
@@ -0,0 +1,133 @@
// Copyright (c) Meta Platforms, Inc. and affiliates.
// All rights reserved.
//
// This source code is licensed under the license found in the
// LICENSE file in the root directory of this source tree.

#pragma once

#include <cstdint>

namespace torchao::kernels::cpu::fallback::quantized_matmul {
namespace channelwise_8bit_a_channelwise_8bit_b::internal {

template <
bool a_has_zeros,
bool b_has_zeros,
bool a_transposed,
bool b_transposed>
struct KernelImpl {
static void run(
int m,
int n,
int k,
const void* lhs,
int lhs_stride_m,
const void* rhs,
int rhs_stride_n,
float* output,
int out_stride_m,
const int8_t* lhs_zero_points,
const int8_t* rhs_zero_points,
const float* lhs_scales,
const float* rhs_scales,
const int lhs_qparams_stride,
const int rhs_qparams_stride);
};

template <bool b_transposed>
struct KernelImpl<true, true, false, b_transposed> {
static void run(
int m,
int n,
int k,
const void* lhs,
int lhs_stride_m,
const void* rhs,
int rhs_stride_n,
float* output,
int out_stride_m,
const int8_t* lhs_zero_points,
const int8_t* rhs_zero_points,
const float* lhs_scales,
const float* rhs_scales,
const int lhs_qparams_stride,
const int rhs_qparams_stride) {
const int8_t* lhs_qvals = static_cast<const int8_t*>(lhs);
const int8_t* rhs_qvals = static_cast<const int8_t*>(rhs);
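// Reference implementation: a plain triple loop over (m, n, k) that
// dequantizes both operands on the fly and accumulates in fp32.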
for (int m_idx = 0; m_idx < m; m_idx++) {
for (int n_idx = 0; n_idx < n; n_idx++) {
float res = 0.0;
for (int k_idx = 0; k_idx < k; k_idx++) {
int lhs_idx = m_idx * lhs_stride_m + k_idx;
int rhs_idx = k_idx * rhs_stride_n + n_idx;
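// When b is transposed its layout is n x k, so the roles of n_idx and
// k_idx swap when indexing into it.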
if (b_transposed) {
rhs_idx = n_idx * rhs_stride_n + k_idx;
}

float lhs_dequant = lhs_scales[m_idx * lhs_qparams_stride] *
(static_cast<int16_t>(lhs_qvals[lhs_idx]) -
static_cast<int16_t>(
lhs_zero_points[m_idx * lhs_qparams_stride]));

float rhs_dequant = rhs_scales[n_idx * rhs_qparams_stride] *
(static_cast<int16_t>(rhs_qvals[rhs_idx]) -
static_cast<int16_t>(
rhs_zero_points[n_idx * rhs_qparams_stride]));

res += lhs_dequant * rhs_dequant;
}
output[m_idx * out_stride_m + n_idx] = res;
}
}
}
};

} // namespace channelwise_8bit_a_channelwise_8bit_b::internal
} // namespace torchao::kernels::cpu::fallback::quantized_matmul

// TODO: Remove all ::kernels. No need for extra namespace.
namespace torchao::kernels::cpu::fallback::quantized_matmul {
namespace channelwise_8bit_a_channelwise_8bit_b {
template <
bool a_has_zeros,
bool b_has_zeros,
bool a_transposed,
bool b_transposed>
void kernel(
int m,
int n,
int k,
const void* lhs,
int lhs_stride_m,
const void* rhs,
int rhs_stride_n,
float* output,
int out_stride_m,
const int8_t* lhs_zero_points,
const int8_t* rhs_zero_points,
const float* lhs_scales,
const float* rhs_scales,
const int lhs_qparams_stride,
const int rhs_qparams_stride) {
channelwise_8bit_a_channelwise_8bit_b::internal::
KernelImpl<a_has_zeros, b_has_zeros, a_transposed, b_transposed>::run(
m,
n,
k,
lhs,
lhs_stride_m,
rhs,
rhs_stride_n,
output,
out_stride_m,
lhs_zero_points,
rhs_zero_points,
lhs_scales,
rhs_scales,
lhs_qparams_stride,
rhs_qparams_stride);
}
} // namespace channelwise_8bit_a_channelwise_8bit_b
} // namespace torchao::kernels::cpu::fallback::quantized_matmul
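
Below is a minimal usage sketch of the fallback kernel above; it is a sanity check on the signature, not part of the diff. The buffer names, the wrapper function, and the contiguous quantization parameters (qparams stride of 1) are assumptions for illustration.

// Illustrative usage of the fallback kernel (not part of the PR).
// Computes an m x n fp32 output from channelwise-quantized int8 operands,
// with b stored transposed (n x k).
#include <cstdint>
#include <vector>
#include <torchao/experimental/kernels/cpu/fallback/matmul/channelwise_8bit_a_channelwise_8bit_b.h>

void run_fallback_qmatmul(int m, int n, int k) {
  std::vector<int8_t> a(m * k);                 // quantized lhs, m x k
  std::vector<int8_t> b(n * k);                 // quantized rhs, stored n x k
  std::vector<int8_t> a_zeros(m), b_zeros(n);   // channelwise zero points
  std::vector<float> a_scales(m), b_scales(n);  // channelwise scales
  std::vector<float> out(m * n);
  torchao::kernels::cpu::fallback::quantized_matmul::
      channelwise_8bit_a_channelwise_8bit_b::
          kernel<true, true, false, true>(
              m, n, k,
              a.data(), /*lhs_stride_m=*/k,
              b.data(), /*rhs_stride_n=*/k,
              out.data(), /*out_stride_m=*/n,
              a_zeros.data(), b_zeros.data(),
              a_scales.data(), b_scales.data(),
              /*lhs_qparams_stride=*/1,
              /*rhs_qparams_stride=*/1);
}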
88 changes: 88 additions & 0 deletions torchao/experimental/kernels/cpu/interface/quantized_matmul.h
@@ -0,0 +1,88 @@
// Copyright (c) Meta Platforms, Inc. and affiliates.
// All rights reserved.
//
// This source code is licensed under the license found in the
// LICENSE file in the root directory of this source tree.

#pragma once

#include <cassert>

#include <torchao/experimental/kernels/cpu/fallback/matmul/channelwise_8bit_a_channelwise_8bit_b.h>
#if defined(__aarch64__) || defined(__ARM_NEON)
#include <torchao/experimental/kernels/cpu/aarch64/matmul/channelwise_8bit_a_channelwise_8bit_b_1x16x16_f32_smlal-impl.h>
#include <torchao/experimental/kernels/cpu/aarch64/matmul/channelwise_8bit_a_channelwise_8bit_b_1x8x16_f32_neondot-impl.h>
#include <torchao/experimental/kernels/cpu/aarch64/matmul/fp32_a_input_channelwise_8bit_b_1x16x4_f32_impl.h>
#endif // defined(__aarch64__) || defined(__ARM_NEON)

namespace torchao::kernels::cpu::quantized_matmul {

/*
a_stride_m: stride of a in memory, i.e. how far apart consecutive rows of a are.
b_stride_n: stride of b in memory, i.e. how far apart consecutive rows of b are.
If b is transposed (n x k), this is the number of elements to skip to reach the
next length-k row; if b is not transposed (k x n), it is the number of elements
to skip to reach the next length-n row.

The selection function below also returns, through a_stride_m and b_stride_n,
the strides that should be passed to the chosen kernel.

TODO: Find a better way to pick the right ukernel, perhaps via a ukernel
config + registry.
*/
using int8_a_int8_b_channelwise_fp32_c_qmatmul_type = void (*)(
int,
int,
int,
const void*,
int,
const void*,
int,
float*,
int,
const int8_t*,
const int8_t*,
const float*,
const float*,
const int,
const int);

inline int8_a_int8_b_channelwise_fp32_c_qmatmul_type
get_int8_a_int8_b_channelwise_qmatmul(
int m,
int n,
int k,
bool a_transposed,
bool b_transposed,
int& a_stride_m,
int& b_stride_n) {
#if defined(__aarch64__) || defined(__ARM_NEON)
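  // Prefer the NEON dot-product ukernel when a is not transposed, b is
  // transposed, and the output has at least 8 columns.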
if (!a_transposed && b_transposed && n >= 8) {
a_stride_m = k;
b_stride_n = k;
return aarch64::quantized_matmul::
channelwise_8bit_a_channelwise_8bit_b_1x8x16_f32_neondot::
kernel<true, true, false, true>;
}
#endif // defined(__aarch64__) || defined(__ARM_NEON)
assert(!a_transposed);
if (b_transposed) {
a_stride_m = k;
b_stride_n = k;
return torchao::kernels::cpu::fallback::quantized_matmul::
channelwise_8bit_a_channelwise_8bit_b::kernel<true, true, false, true>;
  } else {
    a_stride_m = k;
    b_stride_n = n;
    return torchao::kernels::cpu::fallback::quantized_matmul::
        channelwise_8bit_a_channelwise_8bit_b::kernel<true, true, false, false>;
}
}
} // namespace torchao::kernels::cpu::quantized_matmul
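
For completeness, a minimal caller-side sketch of the interface above. It is not part of the diff; the wrapper function, buffer names, and contiguous qparams (stride 1) are illustrative assumptions.

// Illustrative caller: select the best available kernel (NEON on aarch64,
// fallback otherwise) and invoke it with the strides the interface reported.
#include <cstdint>
#include <vector>
#include <torchao/experimental/kernels/cpu/interface/quantized_matmul.h>

void run_qmatmul_via_interface(
    int m, int n, int k,
    const std::vector<int8_t>& a_qvals,        // m x k
    const std::vector<int8_t>& b_qvals,        // n x k (transposed)
    const std::vector<int8_t>& a_zero_points,  // per-row of a
    const std::vector<int8_t>& b_zero_points,  // per-column of the output
    const std::vector<float>& a_scales,
    const std::vector<float>& b_scales,
    std::vector<float>& out) {                 // m x n
  int a_stride_m = 0;
  int b_stride_n = 0;
  auto* qmatmul = torchao::kernels::cpu::quantized_matmul::
      get_int8_a_int8_b_channelwise_qmatmul(
          m, n, k, /*a_transposed=*/false, /*b_transposed=*/true,
          a_stride_m, b_stride_n);
  qmatmul(m, n, k,
          a_qvals.data(), a_stride_m,
          b_qvals.data(), b_stride_n,
          out.data(), /*out_stride_m=*/n,
          a_zero_points.data(), b_zero_points.data(),
          a_scales.data(), b_scales.data(),
          /*lhs_qparams_stride=*/1,
          /*rhs_qparams_stride=*/1);
}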