
Add quantized attn_scores @ v test for intended use in quantized attention #2008


Merged: 1 commit, Apr 3, 2025
@@ -101,6 +101,12 @@ struct KernelImpl {
const int rhs_qparams_stride);
};

/*
Parameter documentation:
rhs_stride_n: Since rhs transposed == false, the expected shape of rhs is k x n.
Thus rhs_stride_n is the stride along the k dim, i.e. how many bytes apart
consecutive elements in the k dim are.
*/
template <>
struct KernelImpl<true, false, false> {
static void run(
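As a reading aid for the comment above, here is a minimal sketch (assuming an int8 RHS, as in these kernels) of how a non-transposed k x n RHS is addressed with rhs_stride_n; the helper name rhs_at is illustrative only, not part of the kernel:

#include <cstdint>

// Sketch only: rhs has shape k x n. rhs_stride_n is the distance in bytes
// (equal to elements for int8) between consecutive elements along the k dim,
// so element (k_idx, n_idx) lives at rhs + k_idx * rhs_stride_n + n_idx.
inline std::int8_t rhs_at(
    const std::int8_t* rhs, int k_idx, int n_idx, int rhs_stride_n) {
  return rhs[k_idx * rhs_stride_n + n_idx];
}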
87 changes: 87 additions & 0 deletions torchao/experimental/kernels/cpu/aarch64/tests/test_qmatmul.cpp
@@ -509,4 +509,91 @@ TEST(
test_8bit_per_token_q_at_k_matmul_attention(1, 8, 8, 7, 3, false);
}

static void test_fp32_attn_scores_at_v_matmul_attention(
int b,
int s_attn,
int s_v,
int h,
int d,
bool transpose_v = true) {
auto test_case =
torchao::fp32_a_channelwise_8bit_b_attn_scores_at_v_test_case::generate(
b, s_attn, s_v, h, d, transpose_v);

using namespace torchao::kernels::cpu::aarch64::quantized_matmul::
fp32_a_input_channelwise_8bit_b_1x16x4_f32;

size_t attn_b_stride = test_case.b_attn_stride;
size_t attn_h_stride = test_case.h_attn_stride;
size_t attn_s_q_stride = test_case.s_attn_stride;

size_t v_b_stride = test_case.b_v_stride;
size_t v_h_stride = test_case.h_v_stride;
size_t v_s_v_stride = test_case.s_v_stride;
size_t v_scale_zp_b_stride = test_case.b_v_qparams_stride;
size_t v_scale_zp_h_stride = test_case.h_v_qparams_stride;
size_t v_scale_zp_s_stride = test_case.s_v_qparams_stride;

std::vector<float> output(b * s_attn * h * d);
size_t output_b_stride = s_attn * h * d;
size_t output_s_attn_stride = h * d;
size_t output_h_stride = d;

for (int b_idx = 0; b_idx < b; b_idx++) {
for (int h_idx = 0; h_idx < h; h_idx++) {
kernel<true, false, false>(
s_attn,
d,
s_v,
test_case.attn_scores.data() + b_idx * attn_b_stride +
h_idx * attn_h_stride,
attn_s_q_stride /*lhs_stride_m*/,
test_case.v_qvals.data() + b_idx * v_b_stride + h_idx * v_h_stride,
v_s_v_stride /*rhs_stride_n*/,
output.data() + b_idx * output_b_stride + h_idx * output_h_stride,
output_s_attn_stride /*out_stride_m*/,
test_case.v_zeros.data() + b_idx * v_scale_zp_b_stride +
h_idx * v_scale_zp_h_stride,
test_case.v_scales.data() + b_idx * v_scale_zp_b_stride +
h_idx * v_scale_zp_h_stride,
0.0 /*beta*/,
v_scale_zp_s_stride /*rhs qparams stride*/);
}
}

for (int i = 0; i < b * s_attn * h * d; i++) {
EXPECT_NEAR(output[i], test_case.expected_output[i], kTol);
}
}

TEST(test_fp32_attn_scores_at_v_matmul_attention, Basic) {
test_fp32_attn_scores_at_v_matmul_attention(1, 16, 16, 8, 16);
}

TEST(test_fp32_attn_scores_at_v_matmul_attention, PrimeHeadsAndHeadDim) {
test_fp32_attn_scores_at_v_matmul_attention(1, 8, 8, 7, 33);
}

TEST(test_fp32_attn_scores_at_v_matmul_attention, PrimeSequenceDim) {
test_fp32_attn_scores_at_v_matmul_attention(1, 7, 9, 7, 33);
}

TEST(test_fp32_attn_scores_at_v_matmul_attention, PrimeHeadsAndSmallHeadDim) {
test_fp32_attn_scores_at_v_matmul_attention(1, 8, 8, 7, 17);
}

TEST(test_fp32_attn_scores_at_v_matmul_attention, BasicNoTranspose) {
test_fp32_attn_scores_at_v_matmul_attention(1, 16, 16, 8, 16, false);
}

TEST(
test_fp32_attn_scores_at_v_matmul_attention,
PrimeHeadsAndSmallHeadDimNoTranspose) {
test_fp32_attn_scores_at_v_matmul_attention(1, 8, 8, 7, 17, false);
}

TEST(test_fp32_attn_scores_at_v_matmul_attention, PrimeSequenceDimNoTranspose) {
test_fp32_attn_scores_at_v_matmul_attention(1, 7, 9, 7, 33, false);
}

#endif // defined(__aarch64__) || defined(__ARM_NEON)
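For orientation, the per-(b, h) slice computation that each kernel call above performs can be written as a plain scalar reference; this is a sketch mirroring the expected-output loop in the test utilities below, not the NEON implementation:

#include <cstddef>
#include <cstdint>

// Sketch: attn_scores (s_attn x s_v) times dequantized v (s_v x d) for one
// (b, h) slice, with channelwise (per s_v row) scales and zero points on v.
void reference_attn_scores_at_v(
    int s_attn, int d, int s_v,
    const float* attn, std::size_t attn_s_stride,
    const std::int8_t* v_qvals, std::size_t v_s_stride,
    const float* v_scales, const std::int8_t* v_zeros,
    std::size_t qparams_stride,
    float* out, std::size_t out_s_stride) {
  for (int m = 0; m < s_attn; m++) {
    for (int n = 0; n < d; n++) {
      float res = 0.0f;
      for (int k = 0; k < s_v; k++) {
        float v_dequant = v_scales[k * qparams_stride] *
            (v_qvals[k * v_s_stride + n] - v_zeros[k * qparams_stride]);
        res += attn[m * attn_s_stride + k] * v_dequant;
      }
      out[m * out_s_stride + n] = res;
    }
  }
}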
@@ -230,6 +230,174 @@ struct channelwise_8bit_a_channelwise_8bit_b_q_at_k_attention_test_case {
}
};

struct fp32_a_channelwise_8bit_b_attn_scores_at_v_test_case {
int b;
int s_attn;
int s_v;
int h;
int d;
size_t b_attn_stride;
size_t h_attn_stride;
size_t s_attn_stride;
size_t b_v_stride;
size_t h_v_stride;
size_t s_v_stride;
size_t b_v_qparams_stride;
size_t h_v_qparams_stride;
size_t s_v_qparams_stride;

std::vector<float> expected_output;

std::vector<float> attn_scores;

std::vector<float> v;
std::vector<int8_t> v_qvals;
std::vector<float> v_scales;
std::vector<int8_t> v_zeros;

fp32_a_channelwise_8bit_b_attn_scores_at_v_test_case(
int b_,
int s_attn_,
int s_v_,
int h_,
int d_,
size_t b_attn_stride_,
size_t h_attn_stride_,
size_t s_attn_stride_,
size_t b_v_stride_,
size_t h_v_stride_,
size_t s_v_stride_,
size_t b_v_qparams_stride_,
size_t h_v_qparams_stride_,
size_t s_v_qparams_stride_,
std::vector<float> expected_output_,
std::vector<float> attn_scores_,
std::vector<float> v_,
std::vector<int8_t> v_qvals_,
std::vector<float> v_scales_,
std::vector<int8_t> v_zeros_)
: b(b_),
s_attn(s_attn_),
s_v(s_v_),
h(h_),
d(d_),
b_attn_stride(b_attn_stride_),
h_attn_stride(h_attn_stride_),
s_attn_stride(s_attn_stride_),
b_v_stride(b_v_stride_),
h_v_stride(h_v_stride_),
s_v_stride(s_v_stride_),
b_v_qparams_stride(b_v_qparams_stride_),
h_v_qparams_stride(h_v_qparams_stride_),
s_v_qparams_stride(s_v_qparams_stride_),
expected_output(expected_output_),
attn_scores(attn_scores_),
v(v_),
v_qvals(v_qvals_),
v_scales(v_scales_),
v_zeros(v_zeros_) {
assert(expected_output.size() == b * s_attn * h * d);
assert(attn_scores.size() == b * h * s_attn * s_v);
assert(v.size() == b * h * s_v * d);
assert(v_qvals.size() == b * h * s_v * d);
assert(v_scales.size() == b * h * s_v);
assert(v_zeros.size() == b * h * s_v);
}

static fp32_a_channelwise_8bit_b_attn_scores_at_v_test_case
generate(int b, int s_attn, int s_v, int h, int d, bool transposed_v = true) {
// Generate activations
auto lhs = get_random_vector(b * h * s_attn * s_v, -1.0, 1.0);

auto [rhs, rhs_qvals, rhs_scales, rhs_zeros] =
torchao::test_utils::generate_per_token_quantized_tensor(
b * h * s_v, d);
// The function above produces an n x k matrix; producing k x n would
// require transposed = true. We use !transposed_v because when
// transposed_v == true the shape should be n x k instead of k x n.

size_t b_attn_stride = h * s_attn * s_v;
size_t h_attn_stride = s_attn * s_v;
size_t s_attn_stride = s_v;

size_t b_v_stride = h * s_v * d;
size_t h_v_stride = s_v * d;
size_t s_v_stride = d;

size_t b_v_qparams_stride = h * s_v;
size_t h_v_qparams_stride = s_v;
size_t s_v_qparams_stride = 1;
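
// Layout note: with transposed_v == true, v is laid out as [b, h, s_v, d]
// and its qparams as [b, h, s_v]. With transposed_v == false, the stride
// overrides below correspond to [b, s_v, h, d] and [b, s_v, h].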

if (!transposed_v) {
h_v_stride = d;
s_v_stride = h * d;

s_v_qparams_stride = h;
h_v_qparams_stride = 1;
}

// Compute the expected output.
// Note that while the inputs are in shapes b x h x s_attn x s_v
// and b x h x s_v x d, the output is not in b x h x s_attn x d
// but rather b x s_attn x h x d. This is because the output of
// SDPA would normally be in b x h x s_attn x d, but we want to
// avoid any transposes, so we aim to output b x s_attn x h x d.
// This is just for testing purposes; the kernel can write output
// in [B, H, S, D] if needed.
std::vector<float> expected_output(b * s_attn * h * d);
size_t b_out_stride = s_attn * h * d;
size_t s_attn_out_stride = h * d;
size_t h_out_stride = d;

for (int b_idx = 0; b_idx < b; b_idx++) {
for (int s_attn_idx = 0; s_attn_idx < s_attn; s_attn_idx++) {
for (int h_idx = 0; h_idx < h; h_idx++) {
for (int d_idx = 0; d_idx < d; d_idx++) {
float res = 0.0;
for (int s_v_idx = 0; s_v_idx < s_v; s_v_idx++) {
int lhs_idx = b_idx * b_attn_stride + s_attn_idx * s_attn_stride +
h_idx * h_attn_stride + s_v_idx;
int rhs_idx = b_idx * b_v_stride + h_idx * h_v_stride + d_idx +
s_v_idx * s_v_stride;
int rhs_scales_zp_idx = b_idx * b_v_qparams_stride +
h_idx * h_v_qparams_stride + s_v_idx * s_v_qparams_stride;
float rhs_dequant = rhs_scales[rhs_scales_zp_idx] *
(rhs_qvals[rhs_idx] - rhs_zeros[rhs_scales_zp_idx]);

res += lhs[lhs_idx] * rhs_dequant;
}
expected_output
[b_idx * b_out_stride + s_attn_idx * s_attn_out_stride +
h_idx * h_out_stride + d_idx] = res;
}
}
}
}

// Return test case
return fp32_a_channelwise_8bit_b_attn_scores_at_v_test_case(
b,
s_attn,
s_v,
h,
d,
b_attn_stride,
h_attn_stride,
s_attn_stride,
b_v_stride,
h_v_stride,
s_v_stride,
b_v_qparams_stride,
h_v_qparams_stride,
s_v_qparams_stride,
expected_output,
lhs,
rhs,
rhs_qvals,
rhs_scales,
rhs_zeros);
}
};
} // namespace torchao

#endif // defined(__aarch64__) || defined(__ARM_NEON)
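Finally, a hedged usage sketch of the new generator (assumes the test-utils header above is included; the sizes mirror the Basic test, and the checks restate the constructor's internal asserts):

#include <cassert>

int main() {
  auto tc = torchao::fp32_a_channelwise_8bit_b_attn_scores_at_v_test_case::
      generate(/*b=*/1, /*s_attn=*/16, /*s_v=*/16, /*h=*/8, /*d=*/16);
  // Restates the constructor's shape asserts for the chosen sizes.
  assert(tc.attn_scores.size() == 1u * 8 * 16 * 16);     // b * h * s_attn * s_v
  assert(tc.expected_output.size() == 1u * 16 * 8 * 16); // b * s_attn * h * d
  return 0;
}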