diff --git a/torchao/experimental/kernels/cpu/aarch64/matmul/fp32_a_input_channelwise_8bit_b_1x16x4_f32_impl.h b/torchao/experimental/kernels/cpu/aarch64/matmul/fp32_a_input_channelwise_8bit_b_1x16x4_f32_impl.h
index 389abb32a5..bdad1b4a47 100644
--- a/torchao/experimental/kernels/cpu/aarch64/matmul/fp32_a_input_channelwise_8bit_b_1x16x4_f32_impl.h
+++ b/torchao/experimental/kernels/cpu/aarch64/matmul/fp32_a_input_channelwise_8bit_b_1x16x4_f32_impl.h
@@ -101,6 +101,12 @@ struct KernelImpl {
       const int rhs_qparams_stride);
 };
 
+/*
+Documentation of param meaning:
+rhs_stride_n: Since rhs is not transposed, the expected shape of rhs is
+k x n. Thus rhs_stride_n is the stride of the k dim, i.e. how many bytes
+apart elements in the k dim are.
+*/
 template <>
 struct KernelImpl<true, false, false> {
   static void run(
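To make the rhs_stride_n contract documented above concrete, here is a minimal addressing sketch (not part of the patch; the helper name is illustrative), assuming a non-transposed, row-major k x n int8 rhs:

// Element (k_idx, n_idx) of a k x n rhs is read at
// rhs[k_idx * rhs_stride_n + n_idx]; a fully contiguous rhs therefore has
// rhs_stride_n == n (elements and bytes coincide for int8).
inline int8_t rhs_element(
    const int8_t* rhs, int rhs_stride_n, int k_idx, int n_idx) {
  return rhs[k_idx * rhs_stride_n + n_idx];
}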
diff --git a/torchao/experimental/kernels/cpu/aarch64/tests/test_qmatmul.cpp b/torchao/experimental/kernels/cpu/aarch64/tests/test_qmatmul.cpp
index 344b2c4915..05dbf13aac 100644
--- a/torchao/experimental/kernels/cpu/aarch64/tests/test_qmatmul.cpp
+++ b/torchao/experimental/kernels/cpu/aarch64/tests/test_qmatmul.cpp
@@ -509,4 +509,91 @@ TEST(
   test_8bit_per_token_q_at_k_matmul_attention(1, 8, 8, 7, 3, false);
 }
 
+static void test_fp32_attn_scores_at_v_matmul_attention(
+    int b,
+    int s_attn,
+    int s_v,
+    int h,
+    int d,
+    bool transpose_v = true) {
+  auto test_case =
+      torchao::fp32_a_channelwise_8bit_b_attn_scores_at_v_test_case::generate(
+          b, s_attn, s_v, h, d, transpose_v);
+
+  using namespace torchao::kernels::cpu::aarch64::quantized_matmul::
+      fp32_a_input_channelwise_8bit_b_1x16x4_f32;
+
+  size_t attn_b_stride = test_case.b_attn_stride;
+  size_t attn_h_stride = test_case.h_attn_stride;
+  size_t attn_s_q_stride = test_case.s_attn_stride;
+
+  size_t v_b_stride = test_case.b_v_stride;
+  size_t v_h_stride = test_case.h_v_stride;
+  size_t v_s_v_stride = test_case.s_v_stride;
+  size_t v_scale_zp_b_stride = test_case.b_v_qparams_stride;
+  size_t v_scale_zp_h_stride = test_case.h_v_qparams_stride;
+  size_t v_scale_zp_s_stride = test_case.s_v_qparams_stride;
+
+  std::vector<float> output(b * s_attn * h * d);
+  size_t output_b_stride = s_attn * h * d;
+  size_t output_s_attn_stride = h * d;
+  size_t output_h_stride = d;
+
+  for (int b_idx = 0; b_idx < b; b_idx++) {
+    for (int h_idx = 0; h_idx < h; h_idx++) {
+      kernel(
+          s_attn,
+          d,
+          s_v,
+          test_case.attn_scores.data() + b_idx * attn_b_stride +
+              h_idx * attn_h_stride,
+          attn_s_q_stride /*lhs_stride_m*/,
+          test_case.v_qvals.data() + b_idx * v_b_stride + h_idx * v_h_stride,
+          v_s_v_stride /*rhs_stride_n*/,
+          output.data() + b_idx * output_b_stride + h_idx * output_h_stride,
+          output_s_attn_stride /*out_stride_n*/,
+          test_case.v_zeros.data() + b_idx * v_scale_zp_b_stride +
+              h_idx * v_scale_zp_h_stride,
+          test_case.v_scales.data() + b_idx * v_scale_zp_b_stride +
+              h_idx * v_scale_zp_h_stride,
+          0.0 /*beta*/,
+          v_scale_zp_s_stride /*rhs qparams stride*/);
+    }
+  }
+
+  for (int i = 0; i < b * s_attn * h * d; i++) {
+    EXPECT_NEAR(output[i], test_case.expected_output[i], kTol);
+  }
+}
+
+TEST(test_fp32_attn_scores_at_v_matmul_attention, Basic) {
+  test_fp32_attn_scores_at_v_matmul_attention(1, 16, 16, 8, 16);
+}
+
+TEST(test_fp32_attn_scores_at_v_matmul_attention, PrimeHeadsAndHeadDim) {
+  test_fp32_attn_scores_at_v_matmul_attention(1, 8, 8, 7, 33);
+}
+
+TEST(test_fp32_attn_scores_at_v_matmul_attention, PrimeSequenceDim) {
+  test_fp32_attn_scores_at_v_matmul_attention(1, 7, 9, 7, 33);
+}
+
+TEST(test_fp32_attn_scores_at_v_matmul_attention, PrimeHeadsAndSmallHeadDim) {
+  test_fp32_attn_scores_at_v_matmul_attention(1, 8, 8, 7, 17);
+}
+
+TEST(test_fp32_attn_scores_at_v_matmul_attention, BasicNoTranspose) {
+  test_fp32_attn_scores_at_v_matmul_attention(1, 16, 16, 8, 16, false);
+}
+
+TEST(
+    test_fp32_attn_scores_at_v_matmul_attention,
+    PrimeHeadsAndSmallHeadDimNoTranspose) {
+  test_fp32_attn_scores_at_v_matmul_attention(1, 8, 8, 7, 17, false);
+}
+
+TEST(test_fp32_attn_scores_at_v_matmul_attention, PrimeSequenceDimNoTranspose) {
+  test_fp32_attn_scores_at_v_matmul_attention(1, 7, 9, 7, 33, false);
+}
+
 #endif // defined(__aarch64__) || defined(__ARM_NEON)
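Each kernel invocation in the tests above covers one (batch, head) pair, so the 4-D tensors reduce to 2-D slices via pointer offsets, and the transposed vs. non-transposed v layouts differ only in the strides handed to the kernel. A minimal sketch of that relationship, assuming the row-major layouts the test case generates:

// transpose_v == true : v is laid out b x h x s_v x d, so consecutive
//                        rows along s_v are d elements apart.
// transpose_v == false: v is laid out b x s_v x h x d, so consecutive
//                        rows along s_v are h * d elements apart.
size_t v_s_v_stride = transpose_v ? d : static_cast<size_t>(h) * d;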
diff --git a/torchao/experimental/kernels/cpu/aarch64/tests/test_utils_quantized_attention.h b/torchao/experimental/kernels/cpu/aarch64/tests/test_utils_quantized_attention.h
index 9ca86ece76..52fb0851bc 100644
--- a/torchao/experimental/kernels/cpu/aarch64/tests/test_utils_quantized_attention.h
+++ b/torchao/experimental/kernels/cpu/aarch64/tests/test_utils_quantized_attention.h
@@ -230,6 +230,174 @@ struct channelwise_8bit_a_channelwise_8bit_b_q_at_k_attention_test_case {
   }
 };
 
+struct fp32_a_channelwise_8bit_b_attn_scores_at_v_test_case {
+  int b;
+  int s_attn;
+  int s_v;
+  int h;
+  int d;
+  size_t b_attn_stride;
+  size_t h_attn_stride;
+  size_t s_attn_stride;
+  size_t b_v_stride;
+  size_t h_v_stride;
+  size_t s_v_stride;
+  size_t b_v_qparams_stride;
+  size_t h_v_qparams_stride;
+  size_t s_v_qparams_stride;
+
+  std::vector<float> expected_output;
+
+  std::vector<float> attn_scores;
+
+  std::vector<float> v;
+  std::vector<int8_t> v_qvals;
+  std::vector<float> v_scales;
+  std::vector<int8_t> v_zeros;
+
+  fp32_a_channelwise_8bit_b_attn_scores_at_v_test_case(
+      int b_,
+      int s_attn_,
+      int s_v_,
+      int h_,
+      int d_,
+      size_t b_attn_stride_,
+      size_t h_attn_stride_,
+      size_t s_attn_stride_,
+      size_t b_v_stride_,
+      size_t h_v_stride_,
+      size_t s_v_stride_,
+      size_t b_v_qparams_stride_,
+      size_t h_v_qparams_stride_,
+      size_t s_v_qparams_stride_,
+      std::vector<float> expected_output_,
+      std::vector<float> attn_scores_,
+      std::vector<float> v_,
+      std::vector<int8_t> v_qvals_,
+      std::vector<float> v_scales_,
+      std::vector<int8_t> v_zeros_)
+      : b(b_),
+        s_attn(s_attn_),
+        s_v(s_v_),
+        h(h_),
+        d(d_),
+        b_attn_stride(b_attn_stride_),
+        h_attn_stride(h_attn_stride_),
+        s_attn_stride(s_attn_stride_),
+        b_v_stride(b_v_stride_),
+        h_v_stride(h_v_stride_),
+        s_v_stride(s_v_stride_),
+        b_v_qparams_stride(b_v_qparams_stride_),
+        h_v_qparams_stride(h_v_qparams_stride_),
+        s_v_qparams_stride(s_v_qparams_stride_),
+        expected_output(expected_output_),
+        attn_scores(attn_scores_),
+        v(v_),
+        v_qvals(v_qvals_),
+        v_scales(v_scales_),
+        v_zeros(v_zeros_) {
+    assert(expected_output.size() == b * s_attn * h * d);
+    assert(attn_scores.size() == b * h * s_attn * s_v);
+    assert(v.size() == b * h * s_v * d);
+    assert(v_qvals.size() == b * h * s_v * d);
+    assert(v_scales.size() == b * h * s_v);
+    assert(v_zeros.size() == b * h * s_v);
+  }
+
+  static fp32_a_channelwise_8bit_b_attn_scores_at_v_test_case
+  generate(int b, int s_attn, int s_v, int h, int d, bool transposed_v = true) {
+    // Generate activations
+    auto lhs = get_random_vector(b * h * s_attn * s_v, -1.0, 1.0);
+
+    auto [rhs, rhs_qvals, rhs_scales, rhs_zeros] =
+        torchao::test_utils::generate_per_token_quantized_tensor(
+            b * h * s_v, d);
+    // The function above produces an n x k matrix; to produce k x n you need
+    // transposed = true. We do !rhs_is_transposed because when
+    // rhs_is_transposed = true the shape should be n x k instead of k x n.
+
+    size_t b_attn_stride = h * s_attn * s_v;
+    size_t h_attn_stride = s_attn * s_v;
+    size_t s_attn_stride = s_v;
+
+    size_t b_v_stride = h * s_v * d;
+    size_t h_v_stride = s_v * d;
+    size_t s_v_stride = d;
+
+    size_t b_v_qparams_stride = h * s_v;
+    size_t h_v_qparams_stride = s_v;
+    size_t s_v_qparams_stride = 1;
+
+    if (!transposed_v) {
+      h_v_stride = d;
+      s_v_stride = h * d;
+
+      s_v_qparams_stride = h;
+      h_v_qparams_stride = 1;
+    }
+
+    // Compute expected output
+    // Note that while the inputs are in shape b x h x s_attn x s_v
+    // and b x h x s_v x d, the output is not in b x h x s_attn x d
+    // but rather in b x s_attn x h x d. This is because the output of
+    // SDPA would normally be in b x h x s_attn x d, but we want to
+    // avoid any transposes, so we aim to write the output in
+    // b x s_attn x h x d. This is just for testing purposes; the kernel
+    // can actually write the output in [B, H, S, D] if needed.
+    std::vector<float> expected_output(b * s_attn * h * d);
+    size_t b_out_stride = s_attn * h * d;
+    size_t s_attn_out_stride = h * d;
+    size_t h_out_stride = d;
+
+    for (int b_idx = 0; b_idx < b; b_idx++) {
+      for (int s_attn_idx = 0; s_attn_idx < s_attn; s_attn_idx++) {
+        for (int h_idx = 0; h_idx < h; h_idx++) {
+          for (int d_idx = 0; d_idx < d; d_idx++) {
+            float res = 0.0;
+            for (int s_v_idx = 0; s_v_idx < s_v; s_v_idx++) {
+              int lhs_idx = b_idx * b_attn_stride + s_attn_idx * s_attn_stride +
+                  h_idx * h_attn_stride + s_v_idx;
+              int rhs_idx = b_idx * b_v_stride + h_idx * h_v_stride + d_idx +
+                  s_v_idx * s_v_stride;
+              int rhs_scales_zp_idx = b_idx * b_v_qparams_stride +
+                  h_idx * h_v_qparams_stride + s_v_idx * s_v_qparams_stride;
+              float rhs_dequant = rhs_scales[rhs_scales_zp_idx] *
+                  (rhs_qvals[rhs_idx] - rhs_zeros[rhs_scales_zp_idx]);
+
+              res += lhs[lhs_idx] * rhs_dequant;
+            }
+            expected_output
+                [b_idx * b_out_stride + s_attn_idx * s_attn_out_stride +
+                 h_idx * h_out_stride + d_idx] = res;
+          }
+        }
+      }
+    }
+
+    // Return test case
+    return fp32_a_channelwise_8bit_b_attn_scores_at_v_test_case(
+        b,
+        s_attn,
+        s_v,
+        h,
+        d,
+        b_attn_stride,
+        h_attn_stride,
+        s_attn_stride,
+        b_v_stride,
+        h_v_stride,
+        s_v_stride,
+        b_v_qparams_stride,
+        h_v_qparams_stride,
+        s_v_qparams_stride,
+        expected_output,
+        lhs,
+        rhs,
+        rhs_qvals,
+        rhs_scales,
+        rhs_zeros);
+  }
+};
 } // namespace torchao
 
 #endif // defined(__aarch64__) || defined(__ARM_NEON)
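For reference, here is the math that the expected-output loop above encodes, written as a standalone scalar sketch of what the 1x16x4 kernel computes per (batch, head) slice. This is illustrative only, not the NEON implementation, and the beta term is an assumption based on the 0.0 /*beta*/ argument in the test (existing output scaled by beta before accumulation):

#include <cstdint>

// Scalar reference: out = beta * out + lhs @ dequant(rhs), where rhs is
// 8-bit and quantized per input channel (the k dim), matching the stride
// conventions used in test_fp32_attn_scores_at_v_matmul_attention.
void reference_fp32_a_channelwise_8bit_b_matmul(
    int m, int n, int k,
    const float* lhs, int lhs_stride_m,
    const std::int8_t* rhs, int rhs_stride_n,
    float* out, int out_stride_m,
    const std::int8_t* rhs_zeros, const float* rhs_scales,
    float beta, int rhs_qparams_stride) {
  for (int m_idx = 0; m_idx < m; m_idx++) {
    for (int n_idx = 0; n_idx < n; n_idx++) {
      float acc = 0.0f;
      for (int k_idx = 0; k_idx < k; k_idx++) {
        // One scale/zero-point pair per k index (channelwise on rhs).
        const float dq = rhs_scales[k_idx * rhs_qparams_stride] *
            (static_cast<float>(rhs[k_idx * rhs_stride_n + n_idx]) -
             static_cast<float>(rhs_zeros[k_idx * rhs_qparams_stride]));
        acc += lhs[m_idx * lhs_stride_m + k_idx] * dq;
      }
      out[m_idx * out_stride_m + n_idx] =
          beta * out[m_idx * out_stride_m + n_idx] + acc;
    }
  }
}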