diff --git a/torchao/experimental/kernels/cpu/aarch64/benchmarks/benchmark_bitpacking.cpp b/torchao/experimental/kernels/cpu/aarch64/benchmarks/benchmark_bitpacking.cpp index 88bf8988a..16096a6c4 100644 --- a/torchao/experimental/kernels/cpu/aarch64/benchmarks/benchmark_bitpacking.cpp +++ b/torchao/experimental/kernels/cpu/aarch64/benchmarks/benchmark_bitpacking.cpp @@ -8,6 +8,7 @@ #include #include +#include #include #include #include @@ -16,6 +17,128 @@ namespace { +// Benchmark utility to compare variants of uint1 packing (variant = number of values packed per kernel call: 8, 64 or 128; unpacked_size must be a multiple of it) +void pack_uint1_values( + uint8_t* packed, + uint8_t* unpacked, + int packed_size, + int unpacked_size, + int variant) { + constexpr int nbit = 1; + constexpr int bitsPerByte = 8; + assert(unpacked_size * nbit / bitsPerByte == packed_size); + assert(unpacked_size % variant == 0); + + uint8x16_t unpacked0; + uint8x16_t unpacked1; + uint8x16_t unpacked2; + uint8x16_t unpacked3; + uint8x16_t unpacked4; + uint8x16_t unpacked5; + uint8x16_t unpacked6; + uint8x16_t unpacked7; + + switch (variant) { + case 8: + for (int i = 0; i < unpacked_size; i += 8) { + torchao::bitpacking::internal::pack_8_uint1_values( + packed + ((i * nbit) / bitsPerByte), unpacked + i); + } + break; + case 64: + for (int i = 0; i < unpacked_size; i += 64) { + torchao::bitpacking::internal::vec_load_64_uint8_values( + unpacked0, unpacked1, unpacked2, unpacked3, unpacked + i); + torchao::bitpacking::internal::vec_pack_64_uint1_values( + packed + ((i * nbit) / bitsPerByte), + unpacked0, + unpacked1, + unpacked2, + unpacked3); + } + break; + case 128: + for (int i = 0; i < unpacked_size; i += 128) { + torchao::bitpacking::internal::vec_load_64_uint8_values( + unpacked0, unpacked1, unpacked2, unpacked3, unpacked + i); + torchao::bitpacking::internal::vec_load_64_uint8_values( + unpacked4, unpacked5, unpacked6, unpacked7, unpacked + i + 64); + torchao::bitpacking::internal::vec_pack_128_uint1_values( + packed + ((i * nbit) / bitsPerByte), + unpacked0, + unpacked1, + unpacked2, + 
unpacked3, + unpacked4, + unpacked5, + unpacked6, + unpacked7); + } + break; + } +} + +// Benchmark utility to compare variants of uint1 unpacking (variant = number of values unpacked per kernel call: 8, 64 or 128; unpacked_size must be a multiple of it) +void unpack_uint1_values( + uint8_t* unpacked, + uint8_t* packed, + int unpacked_size, + int packed_size, + int variant) { + constexpr int nbit = 1; + constexpr int bitsPerByte = 8; + assert(unpacked_size * nbit / bitsPerByte == packed_size); + assert(unpacked_size % variant == 0); + + uint8x16_t unpacked0; + uint8x16_t unpacked1; + uint8x16_t unpacked2; + uint8x16_t unpacked3; + uint8x16_t unpacked4; + uint8x16_t unpacked5; + uint8x16_t unpacked6; + uint8x16_t unpacked7; + + switch (variant) { + case 8: + for (int i = 0; i < unpacked_size; i += 8) { + torchao::bitpacking::internal::unpack_8_uint1_values( + unpacked + i, packed + ((i * nbit) / bitsPerByte)); + } + break; + case 64: + for (int i = 0; i < unpacked_size; i += 64) { + torchao::bitpacking::internal::vec_unpack_64_uint1_values( + unpacked0, + unpacked1, + unpacked2, + unpacked3, + packed + ((i * nbit) / bitsPerByte)); + torchao::bitpacking::internal::vec_store_64_uint8_values( + unpacked + i, unpacked0, unpacked1, unpacked2, unpacked3); + } + break; + case 128: + for (int i = 0; i < unpacked_size; i += 128) { + torchao::bitpacking::internal::vec_unpack_128_uint1_values( + unpacked0, + unpacked1, + unpacked2, + unpacked3, + unpacked4, + unpacked5, + unpacked6, + unpacked7, + packed + ((i * nbit) / bitsPerByte)); + torchao::bitpacking::internal::vec_store_64_uint8_values( + unpacked + i, unpacked0, unpacked1, unpacked2, unpacked3); + torchao::bitpacking::internal::vec_store_64_uint8_values( + unpacked + i + 64, unpacked4, unpacked5, unpacked6, unpacked7); + } + break; + } +} + // Benchmark utility to compare variants of uint2 packing void pack_uint2_values( uint8_t* packed, @@ -470,6 +593,44 @@ void unpack_uint5_values( } // namespace +static void benchmark_pack_uint1_values(benchmark::State& state) { + int unpacked_size = state.range(0); + int variant = 
state.range(1); + int nbit = 1; + + assert(unpacked_size % 8 == 0); + int packed_size = (unpacked_size / 8) * nbit; + + auto packed = std::vector(packed_size, 0); + auto unpacked = torchao::get_random_lowbit_vector(unpacked_size, nbit); + + for (auto _ : state) { + pack_uint1_values( + packed.data(), unpacked.data(), packed_size, unpacked_size, variant); + } +} + +static void benchmark_unpack_uint1_values(benchmark::State& state) { + int unpacked_size = state.range(0); + int variant = state.range(1); + int nbit = 1; + + assert(unpacked_size % 8 == 0); + int packed_size = (unpacked_size / 8) * nbit; + + auto packed = torchao::get_random_lowbit_vector(packed_size, 8); + auto unpacked = std::vector(unpacked_size, 0); + + for (auto _ : state) { + unpack_uint1_values( + unpacked.data(), + packed.data(), + unpacked.size(), + packed.size(), + variant); + } +} + static void benchmark_pack_uint2_values(benchmark::State& state) { int unpacked_size = state.range(0); int variant = state.range(1); @@ -478,8 +639,8 @@ static void benchmark_pack_uint2_values(benchmark::State& state) { assert(unpacked_size % 8 == 0); int packed_size = (unpacked_size / 8) * nbit; - auto packed = std::vector(unpacked_size, 0); - auto unpacked = torchao::get_random_lowbit_vector(packed_size, 8); + auto packed = std::vector(packed_size, 0); + auto unpacked = torchao::get_random_lowbit_vector(unpacked_size, nbit); for (auto _ : state) { pack_uint2_values( @@ -516,8 +677,8 @@ static void benchmark_pack_uint3_values(benchmark::State& state) { assert(unpacked_size % 8 == 0); int packed_size = (unpacked_size / 8) * nbit; - auto packed = std::vector(unpacked_size, 0); - auto unpacked = torchao::get_random_lowbit_vector(packed_size, 8); + auto packed = std::vector(packed_size, 0); + auto unpacked = torchao::get_random_lowbit_vector(unpacked_size, nbit); for (auto _ : state) { pack_uint3_values( @@ -554,8 +715,8 @@ static void benchmark_pack_uint4_values(benchmark::State& state) { assert(unpacked_size % 8 == 
0); int packed_size = (unpacked_size / 8) * nbit; - auto packed = std::vector(unpacked_size, 0); - auto unpacked = torchao::get_random_lowbit_vector(packed_size, 8); + auto packed = std::vector(packed_size, 0); + auto unpacked = torchao::get_random_lowbit_vector(unpacked_size, nbit); for (auto _ : state) { pack_uint4_values( @@ -592,8 +753,8 @@ static void benchmark_pack_uint5_values(benchmark::State& state) { assert(unpacked_size % 8 == 0); int packed_size = (unpacked_size / 8) * nbit; - auto packed = std::vector(unpacked_size, 0); - auto unpacked = torchao::get_random_lowbit_vector(packed_size, 8); + auto packed = std::vector(packed_size, 0); + auto unpacked = torchao::get_random_lowbit_vector(unpacked_size, nbit); for (auto _ : state) { pack_uint5_values( @@ -622,6 +783,8 @@ static void benchmark_unpack_uint5_values(benchmark::State& state) { } } +BENCHMARK(benchmark_pack_uint1_values)->ArgsProduct({{128}, {8, 64, 128}}); +BENCHMARK(benchmark_unpack_uint1_values)->ArgsProduct({{128}, {8, 64, 128}}); BENCHMARK(benchmark_pack_uint2_values)->ArgsProduct({{128}, {4, 32, 64}}); BENCHMARK(benchmark_unpack_uint2_values)->ArgsProduct({{128}, {4, 32, 64}}); BENCHMARK(benchmark_pack_uint3_values)->ArgsProduct({{128}, {8, 64, 128}}); diff --git a/torchao/experimental/kernels/cpu/aarch64/benchmarks/benchmark_linear.cpp b/torchao/experimental/kernels/cpu/aarch64/benchmarks/benchmark_linear.cpp index 0d21bc5e5..02a8d7ac9 100644 --- a/torchao/experimental/kernels/cpu/aarch64/benchmarks/benchmark_linear.cpp +++ b/torchao/experimental/kernels/cpu/aarch64/benchmarks/benchmark_linear.cpp @@ -228,6 +228,8 @@ channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot( false>) \ ->ArgsProduct(BENCHMARK_PARAMS) +BENCHMARK_CHANNELWISE_8BIT_ACTIVATION_GROUPWISE_LOWBIT_WEIGHT_1x1x32_F32_NEONDOT( + 1); BENCHMARK_CHANNELWISE_8BIT_ACTIVATION_GROUPWISE_LOWBIT_WEIGHT_1x1x32_F32_NEONDOT( 2); BENCHMARK_CHANNELWISE_8BIT_ACTIVATION_GROUPWISE_LOWBIT_WEIGHT_1x1x32_F32_NEONDOT( @@ -236,6 
+238,8 @@ BENCHMARK_CHANNELWISE_8BIT_ACTIVATION_GROUPWISE_LOWBIT_WEIGHT_1x1x32_F32_NEONDOT 4); BENCHMARK_CHANNELWISE_8BIT_ACTIVATION_GROUPWISE_LOWBIT_WEIGHT_1x1x32_F32_NEONDOT( 5); +BENCHMARK_CHANNELWISE_8BIT_ACTIVATION_GROUPWISE_LOWBIT_WEIGHT_1x4x16_F32_NEONDOT( + 1); BENCHMARK_CHANNELWISE_8BIT_ACTIVATION_GROUPWISE_LOWBIT_WEIGHT_1x4x16_F32_NEONDOT( 2); BENCHMARK_CHANNELWISE_8BIT_ACTIVATION_GROUPWISE_LOWBIT_WEIGHT_1x4x16_F32_NEONDOT( @@ -244,6 +248,8 @@ BENCHMARK_CHANNELWISE_8BIT_ACTIVATION_GROUPWISE_LOWBIT_WEIGHT_1x4x16_F32_NEONDOT 4); BENCHMARK_CHANNELWISE_8BIT_ACTIVATION_GROUPWISE_LOWBIT_WEIGHT_1x4x16_F32_NEONDOT( 5); +BENCHMARK_CHANNELWISE_8BIT_ACTIVATION_GROUPWISE_LOWBIT_WEIGHT_1x8x16_F32_NEONDOT( + 1); BENCHMARK_CHANNELWISE_8BIT_ACTIVATION_GROUPWISE_LOWBIT_WEIGHT_1x8x16_F32_NEONDOT( 2); BENCHMARK_CHANNELWISE_8BIT_ACTIVATION_GROUPWISE_LOWBIT_WEIGHT_1x8x16_F32_NEONDOT( diff --git a/torchao/experimental/kernels/cpu/aarch64/bitpacking/bitpack.h b/torchao/experimental/kernels/cpu/aarch64/bitpacking/bitpack.h index 37db7926a..ae5a716a5 100644 --- a/torchao/experimental/kernels/cpu/aarch64/bitpacking/bitpack.h +++ b/torchao/experimental/kernels/cpu/aarch64/bitpacking/bitpack.h @@ -7,6 +7,7 @@ #pragma once #include #include +#include #include #include #include @@ -72,10 +73,10 @@ TORCHAO_ALWAYS_INLINE inline void vec_pack_32_lowbit_values( const int8x16_t& unpacked0, const int8x16_t& unpacked1) { static_assert(nbit < 8); - static_assert(nbit >= 2); + static_assert(nbit >= 1); // Currently supported values - static_assert(nbit >= 2); + static_assert(nbit >= 1); static_assert(nbit <= 5); // Shift unpacked values to nonnegative range @@ -84,6 +85,19 @@ TORCHAO_ALWAYS_INLINE inline void vec_pack_32_lowbit_values( uint8x16_t shifted1 = vreinterpretq_u8_s8(vaddq_s8(unpacked1, shift)); switch (nbit) { + case 1: + uint8_t buffer1[32]; + vst1q_u8(buffer1, shifted0); + vst1q_u8(buffer1 + 16, shifted1); + + torchao::bitpacking::internal::pack_8_uint1_values(packed, buffer1); + 
torchao::bitpacking::internal::pack_8_uint1_values( + packed + 1, buffer1 + 8); + torchao::bitpacking::internal::pack_8_uint1_values( + packed + 2, buffer1 + 16); + torchao::bitpacking::internal::pack_8_uint1_values( + packed + 3, buffer1 + 24); + break; case 2: torchao::bitpacking::internal::vec_pack_32_uint2_values( packed, @@ -132,16 +146,28 @@ TORCHAO_ALWAYS_INLINE inline void vec_unpack_32_lowbit_values( int8x16_t& unpacked1, uint8_t* packed) { static_assert(nbit < 8); - static_assert(nbit >= 2); + static_assert(nbit >= 1); // Currently supported values - static_assert(nbit >= 2); + static_assert(nbit >= 1); static_assert(nbit <= 5); uint8x16_t shifted0; uint8x16_t shifted1; switch (nbit) { + case 1: + uint8_t buffer1[32]; + torchao::bitpacking::internal::unpack_8_uint1_values(buffer1, packed); + torchao::bitpacking::internal::unpack_8_uint1_values( + buffer1 + 8, packed + 1); + torchao::bitpacking::internal::unpack_8_uint1_values( + buffer1 + 16, packed + 2); + torchao::bitpacking::internal::unpack_8_uint1_values( + buffer1 + 24, packed + 3); + shifted0 = vld1q_u8(buffer1); + shifted1 = vld1q_u8(buffer1 + 16); + break; case 2: uint8x8_t shifted0_low; uint8x8_t shifted0_high; @@ -197,10 +223,10 @@ TORCHAO_ALWAYS_INLINE inline void vec_pack_64_lowbit_values( const int8x16_t& unpacked2, const int8x16_t& unpacked3) { static_assert(nbit < 8); - static_assert(nbit >= 2); + static_assert(nbit >= 1); // Currently supported values - static_assert(nbit >= 2); + static_assert(nbit >= 1); static_assert(nbit <= 5); // Shift unpacked values to nonnegative range @@ -211,6 +237,10 @@ TORCHAO_ALWAYS_INLINE inline void vec_pack_64_lowbit_values( uint8x16_t shifted3 = vreinterpretq_u8_s8(vaddq_s8(unpacked3, shift)); switch (nbit) { + case 1: + torchao::bitpacking::internal::vec_pack_64_uint1_values( + packed, shifted0, shifted1, shifted2, shifted3); + break; case 2: torchao::bitpacking::internal::vec_pack_64_uint2_values( packed, shifted0, shifted1, shifted2, shifted3); @@ 
-242,10 +272,10 @@ TORCHAO_ALWAYS_INLINE inline void vec_unpack_64_lowbit_values( int8x16_t& unpacked3, uint8_t* packed) { static_assert(nbit < 8); - static_assert(nbit >= 2); + static_assert(nbit >= 1); // Currently supported values - static_assert(nbit >= 2); + static_assert(nbit >= 1); static_assert(nbit <= 5); uint8x16_t shifted0; @@ -254,6 +284,10 @@ TORCHAO_ALWAYS_INLINE inline void vec_unpack_64_lowbit_values( uint8x16_t shifted3; switch (nbit) { + case 1: + torchao::bitpacking::internal::vec_unpack_64_uint1_values( + shifted0, shifted1, shifted2, shifted3, packed); + break; case 2: torchao::bitpacking::internal::vec_unpack_64_uint2_values( shifted0, shifted1, shifted2, shifted3, packed); @@ -296,10 +330,10 @@ TORCHAO_ALWAYS_INLINE inline void vec_pack_128_lowbit_values( const int8x16_t& unpacked6, const int8x16_t& unpacked7) { static_assert(nbit < 8); - static_assert(nbit >= 2); + static_assert(nbit >= 1); // Currently supported values - static_assert(nbit >= 2); + static_assert(nbit >= 1); static_assert(nbit <= 5); // Shift unpacked values to nonnegative range @@ -314,6 +348,18 @@ TORCHAO_ALWAYS_INLINE inline void vec_pack_128_lowbit_values( uint8x16_t shifted7 = vreinterpretq_u8_s8(vaddq_s8(unpacked7, shift)); switch (nbit) { + case 1: + torchao::bitpacking::internal::vec_pack_128_uint1_values( + packed, + shifted0, + shifted1, + shifted2, + shifted3, + shifted4, + shifted5, + shifted6, + shifted7); + break; case 2: torchao::bitpacking::internal::vec_pack_64_uint2_values( packed, shifted0, shifted1, shifted2, shifted3); @@ -371,10 +417,10 @@ TORCHAO_ALWAYS_INLINE inline void vec_unpack_128_lowbit_values( int8x16_t& unpacked7, uint8_t* packed) { static_assert(nbit < 8); - static_assert(nbit >= 2); + static_assert(nbit >= 1); // Currently supported values - static_assert(nbit >= 2); + static_assert(nbit >= 1); static_assert(nbit <= 5); uint8x16_t shifted0; @@ -387,6 +433,18 @@ TORCHAO_ALWAYS_INLINE inline void vec_unpack_128_lowbit_values( uint8x16_t 
shifted7; switch (nbit) { + case 1: + torchao::bitpacking::internal::vec_unpack_128_uint1_values( + shifted0, + shifted1, + shifted2, + shifted3, + shifted4, + shifted5, + shifted6, + shifted7, + packed); + break; case 2: torchao::bitpacking::internal::vec_unpack_64_uint2_values( shifted0, shifted1, shifted2, shifted3, packed); diff --git a/torchao/experimental/kernels/cpu/aarch64/bitpacking/uint1.h b/torchao/experimental/kernels/cpu/aarch64/bitpacking/uint1.h new file mode 100644 index 000000000..0a16c7398 --- /dev/null +++ b/torchao/experimental/kernels/cpu/aarch64/bitpacking/uint1.h @@ -0,0 +1,142 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the license found in the +// LICENSE file in the root directory of this source tree. + +#pragma once +#include +#include + +// This file contains bitpacking and unpacking methods for uint1. +// These are not intended to be used outside of bitpacking directory. +// See bitpack.h for the interface. + +namespace torchao { +namespace bitpacking { +namespace internal { + +TORCHAO_ALWAYS_INLINE inline void pack_8_uint1_values( + uint8_t* packed, + const uint8_t* unpacked) { + // Input is 8 bytes + // Output is 1 byte + packed[0] = 0; + for (int i = 0; i < 8; i++) { + packed[0] |= (unpacked[i] << (7 - i)); + } +} + +TORCHAO_ALWAYS_INLINE inline void unpack_8_uint1_values( + uint8_t* unpacked, + const uint8_t* packed) { + // Unpacks data packed by pack_8_uint1_values + // + // Input is 8 bits = 1 byte + // Output is 8 bytes + for (int i = 0; i < 8; i++) { + unpacked[i] = (packed[0] >> (7 - i)) & 1; + } +} + +// This function is a vectorized version of pack_8_uint1_values +// To understand it, please see pack_8_uint1_values first. 
+// +// Input is 64 bytes +// Output is 64 bits = 8 bytes +TORCHAO_ALWAYS_INLINE inline void vec_pack_64_uint1_values( + uint8_t* packed, + const uint8x16_t& unpacked0, + const uint8x16_t& unpacked1, + const uint8x16_t& unpacked2, + const uint8x16_t& unpacked3) { + uint8x16_t vec_packed; + uint8x8_t vec_packed_low; + uint8x8_t vec_packed_high; + vec_packed = vshlq_n_u8(unpacked0, 3); + vec_packed = vorrq_u8(vec_packed, vshlq_n_u8(unpacked1, 2)); + vec_packed = vorrq_u8(vec_packed, vshlq_n_u8(unpacked2, 1)); + vec_packed = vorrq_u8(vec_packed, unpacked3); + + vec_packed_low = vget_low_u8(vec_packed); + vec_packed_high = vget_high_u8(vec_packed); + + vst1_u8(packed, vsli_n_u8(vec_packed_low, vec_packed_high, 4)); +} + +TORCHAO_ALWAYS_INLINE inline void vec_unpack_64_uint1_values( + uint8x16_t& unpacked0, + uint8x16_t& unpacked1, + uint8x16_t& unpacked2, + uint8x16_t& unpacked3, + const uint8_t* packed) { + uint8x8_t vec_packed; + vec_packed = vld1_u8(packed); + + uint8x8_t vec_packed_low; + uint8x8_t vec_packed_high; + vec_packed_low = vand_u8(vec_packed, vdup_n_u8(0xF)); + vec_packed_high = vshr_n_u8(vec_packed, 4); + + uint8x16_t combined = vcombine_u8(vec_packed_low, vec_packed_high); + unpacked0 = vshrq_n_u8(vandq_u8(combined, vdupq_n_u8(8)), 3); + unpacked1 = vshrq_n_u8(vandq_u8(combined, vdupq_n_u8(4)), 2); + unpacked2 = vshrq_n_u8(vandq_u8(combined, vdupq_n_u8(2)), 1); + unpacked3 = vandq_u8(combined, vdupq_n_u8(1)); +} + +// This function is a vectorized version of pack_8_uint1_values +// To understand it, please see `pack_8_uint1_values` first. 
+// +// Input is 128 bytes +// Output is 128 bytes * 1 bit/8bits = 16 bytes +TORCHAO_ALWAYS_INLINE inline void vec_pack_128_uint1_values( + uint8_t* packed, + const uint8x16_t& unpacked0, + const uint8x16_t& unpacked1, + const uint8x16_t& unpacked2, + const uint8x16_t& unpacked3, + const uint8x16_t& unpacked4, + const uint8x16_t& unpacked5, + const uint8x16_t& unpacked6, + const uint8x16_t& unpacked7) { + uint8x16_t vec_packed; + + vec_packed = vshlq_n_u8(unpacked0, 7); + vec_packed = vorrq_u8(vec_packed, vshlq_n_u8(unpacked1, 6)); + vec_packed = vorrq_u8(vec_packed, vshlq_n_u8(unpacked2, 5)); + vec_packed = vorrq_u8(vec_packed, vshlq_n_u8(unpacked3, 4)); + vec_packed = vorrq_u8(vec_packed, vshlq_n_u8(unpacked4, 3)); + vec_packed = vorrq_u8(vec_packed, vshlq_n_u8(unpacked5, 2)); + vec_packed = vorrq_u8(vec_packed, vshlq_n_u8(unpacked6, 1)); + vec_packed = vorrq_u8(vec_packed, unpacked7); + + vst1q_u8(packed, vec_packed); +} + +TORCHAO_ALWAYS_INLINE inline void vec_unpack_128_uint1_values( + uint8x16_t& unpacked0, + uint8x16_t& unpacked1, + uint8x16_t& unpacked2, + uint8x16_t& unpacked3, + uint8x16_t& unpacked4, + uint8x16_t& unpacked5, + uint8x16_t& unpacked6, + uint8x16_t& unpacked7, + const uint8_t* packed) { + uint8x16_t vec_packed; + vec_packed = vld1q_u8(packed); + + unpacked0 = vandq_u8(vshrq_n_u8(vec_packed, 7), vdupq_n_u8(1)); + unpacked1 = vandq_u8(vshrq_n_u8(vec_packed, 6), vdupq_n_u8(1)); + unpacked2 = vandq_u8(vshrq_n_u8(vec_packed, 5), vdupq_n_u8(1)); + unpacked3 = vandq_u8(vshrq_n_u8(vec_packed, 4), vdupq_n_u8(1)); + unpacked4 = vandq_u8(vshrq_n_u8(vec_packed, 3), vdupq_n_u8(1)); + unpacked5 = vandq_u8(vshrq_n_u8(vec_packed, 2), vdupq_n_u8(1)); + unpacked6 = vandq_u8(vshrq_n_u8(vec_packed, 1), vdupq_n_u8(1)); + unpacked7 = vandq_u8(vec_packed, vdupq_n_u8(1)); +} + +} // namespace internal +} // namespace bitpacking +} // namespace torchao diff --git a/torchao/experimental/kernels/cpu/aarch64/bitpacking/uint5.h 
b/torchao/experimental/kernels/cpu/aarch64/bitpacking/uint5.h index 0c6bd8f22..0e8e101ea 100644 --- a/torchao/experimental/kernels/cpu/aarch64/bitpacking/uint5.h +++ b/torchao/experimental/kernels/cpu/aarch64/bitpacking/uint5.h @@ -19,7 +19,7 @@ namespace internal { TORCHAO_ALWAYS_INLINE inline void pack_8_uint5_values( uint8_t* packed, const uint8_t* unpacked) { - // Given 8 unpacked uint3 values: 0abcd, 1efgh, 2ijkl, 3mnop, 4qrst, 5uvwx, + // Given 8 unpacked uint5 values: 0abcd, 1efgh, 2ijkl, 3mnop, 4qrst, 5uvwx, // 6yzAB, 7CDEF, this function packs them as: // b4: 7|6|5|4|3|2|1|0 (upper bits for all values) // b3210_0: efgh|abcd (lower 4 bits for first 2 values) diff --git a/torchao/experimental/kernels/cpu/aarch64/tests/test_bitpacking.cpp b/torchao/experimental/kernels/cpu/aarch64/tests/test_bitpacking.cpp index 4c53ec28d..581c3b3e3 100644 --- a/torchao/experimental/kernels/cpu/aarch64/tests/test_bitpacking.cpp +++ b/torchao/experimental/kernels/cpu/aarch64/tests/test_bitpacking.cpp @@ -7,6 +7,7 @@ #include #include #include +#include #include #include #include @@ -14,6 +15,116 @@ #include #include +TEST(test_bitpacking_8_uint1_values, PackUnpackAreSame) { + int unpacked_bytes = 8; + int packed_bytes = 1; + auto input = torchao::get_random_lowbit_vector(unpacked_bytes, 1); + std::vector packed(packed_bytes, 0); + std::vector unpacked(unpacked_bytes, 0); + + torchao::bitpacking::internal::pack_8_uint1_values( + packed.data(), input.data()); + torchao::bitpacking::internal::unpack_8_uint1_values( + unpacked.data(), packed.data()); + for (int i = 0; i < unpacked_bytes; ++i) { + EXPECT_EQ(input[i], unpacked[i]); + } +} + +TEST(test_bitpacking_64_uint1_values, PackUnpackAreSame) { + int unpacked_bytes = 64; + int packed_bytes = 8; + auto input = torchao::get_random_lowbit_vector(unpacked_bytes, 1); + std::vector packed(packed_bytes, 0); + + uint8x16_t input0; + uint8x16_t input1; + uint8x16_t input2; + uint8x16_t input3; + + uint8x16_t unpacked0; + uint8x16_t 
unpacked1; + uint8x16_t unpacked2; + uint8x16_t unpacked3; + + torchao::bitpacking::internal::vec_load_64_uint8_values( + input0, input1, input2, input3, input.data()); + + torchao::bitpacking::internal::vec_pack_64_uint1_values( + packed.data(), input0, input1, input2, input3); + + torchao::bitpacking::internal::vec_unpack_64_uint1_values( + unpacked0, unpacked1, unpacked2, unpacked3, packed.data()); + + for (int i = 0; i < 16; ++i) { + EXPECT_EQ(input0[i], unpacked0[i]); + EXPECT_EQ(input1[i], unpacked1[i]); + EXPECT_EQ(input2[i], unpacked2[i]); + EXPECT_EQ(input3[i], unpacked3[i]); + } +} + +TEST(test_bitpacking_128_uint1_values, PackUnpackAreSame) { + int unpacked_bytes = 128; + int packed_bytes = 16; + auto input = torchao::get_random_lowbit_vector(unpacked_bytes, 1); + std::vector packed(packed_bytes, 0); + + uint8x16_t input0; + uint8x16_t input1; + uint8x16_t input2; + uint8x16_t input3; + uint8x16_t input4; + uint8x16_t input5; + uint8x16_t input6; + uint8x16_t input7; + + uint8x16_t unpacked0; + uint8x16_t unpacked1; + uint8x16_t unpacked2; + uint8x16_t unpacked3; + uint8x16_t unpacked4; + uint8x16_t unpacked5; + uint8x16_t unpacked6; + uint8x16_t unpacked7; + + torchao::bitpacking::internal::vec_load_64_uint8_values( + input0, input1, input2, input3, input.data()); + torchao::bitpacking::internal::vec_load_64_uint8_values( + input4, input5, input6, input7, input.data() + 64); + torchao::bitpacking::internal::vec_pack_128_uint1_values( + packed.data(), + input0, + input1, + input2, + input3, + input4, + input5, + input6, + input7); + torchao::bitpacking::internal::vec_unpack_128_uint1_values( + unpacked0, + unpacked1, + unpacked2, + unpacked3, + unpacked4, + unpacked5, + unpacked6, + unpacked7, + packed.data()); + + for (int i = 0; i < 16; ++i) { + EXPECT_EQ(input0[i], unpacked0[i]); + EXPECT_EQ(input1[i], unpacked1[i]); + EXPECT_EQ(input2[i], unpacked2[i]); + EXPECT_EQ(input3[i], unpacked3[i]); + EXPECT_EQ(input4[i], unpacked4[i]); + EXPECT_EQ(input5[i], 
unpacked5[i]); + EXPECT_EQ(input6[i], unpacked6[i]); + EXPECT_EQ(input7[i], unpacked7[i]); + } +} + TEST(test_bitpacking_4_uint2_values, PackUnpackAreSame) { int unpacked_bytes = 4; int packed_bytes = 1; @@ -534,16 +645,19 @@ void test_bitpacking_128_lowbit_values() { test_bitpacking_128_lowbit_values(); \ } +TEST_BITPACKING_32_LOWBIT_VALUES(1); TEST_BITPACKING_32_LOWBIT_VALUES(2); TEST_BITPACKING_32_LOWBIT_VALUES(3); TEST_BITPACKING_32_LOWBIT_VALUES(4); TEST_BITPACKING_32_LOWBIT_VALUES(5); +TEST_BITPACKING_64_LOWBIT_VALUES(1); TEST_BITPACKING_64_LOWBIT_VALUES(2); TEST_BITPACKING_64_LOWBIT_VALUES(3); TEST_BITPACKING_64_LOWBIT_VALUES(4); TEST_BITPACKING_64_LOWBIT_VALUES(5); +TEST_BITPACKING_128_LOWBIT_VALUES(1); TEST_BITPACKING_128_LOWBIT_VALUES(2); TEST_BITPACKING_128_LOWBIT_VALUES(3); TEST_BITPACKING_128_LOWBIT_VALUES(4); diff --git a/torchao/experimental/kernels/cpu/aarch64/tests/test_utils.h b/torchao/experimental/kernels/cpu/aarch64/tests/test_utils.h index 4e5083d9e..b9b03c777 100644 --- a/torchao/experimental/kernels/cpu/aarch64/tests/test_utils.h +++ b/torchao/experimental/kernels/cpu/aarch64/tests/test_utils.h @@ -25,7 +25,7 @@ get_random_vector(int size, float min = -1.0, float max = 1.0) { } inline std::vector get_random_lowbit_vector(int size, int nbit) { - assert(nbit >= 2); + assert(nbit >= 1); assert(nbit <= 8); int min = 0;