Skip to content

Commit

Permalink
Introducing 1-bit quantization for Llama in torchchat (#910)
Browse files Browse the repository at this point in the history
Differential Revision: D63052325

Pull Request resolved: #911
  • Loading branch information
vaishnavi17 authored Sep 20, 2024
1 parent 23321fb commit 4bce694
Show file tree
Hide file tree
Showing 7 changed files with 505 additions and 22 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include <benchmark/benchmark.h>

#include <torchao/experimental/kernels/cpu/aarch64/bitpacking/bitpack.h>
#include <torchao/experimental/kernels/cpu/aarch64/bitpacking/uint1.h>
#include <torchao/experimental/kernels/cpu/aarch64/bitpacking/uint2.h>
#include <torchao/experimental/kernels/cpu/aarch64/bitpacking/uint3.h>
#include <torchao/experimental/kernels/cpu/aarch64/bitpacking/uint4.h>
Expand All @@ -16,6 +17,128 @@

namespace {

// Benchmark utility to compare variants of uint1 packing
void pack_uint1_values(
uint8_t* packed,
uint8_t* unpacked,
int packed_size,
int unpacked_size,
int variant) {
constexpr int nbit = 1;
constexpr int bitsPerByte = 8;
assert(unpacked_size * nbit / bitsPerByte == packed_size);
assert(packed_size % variant == 0);

uint8x16_t unpacked0;
uint8x16_t unpacked1;
uint8x16_t unpacked2;
uint8x16_t unpacked3;
uint8x16_t unpacked4;
uint8x16_t unpacked5;
uint8x16_t unpacked6;
uint8x16_t unpacked7;

switch (variant) {
case 8:
for (int i = 0; i < unpacked_size; i += 8) {
torchao::bitpacking::internal::pack_8_uint1_values(
packed + ((i * nbit) / bitsPerByte), unpacked + i);
}
break;
case 64:
for (int i = 0; i < unpacked_size; i += 64) {
torchao::bitpacking::internal::vec_load_64_uint8_values(
unpacked0, unpacked1, unpacked2, unpacked3, unpacked + i);
torchao::bitpacking::internal::vec_pack_64_uint1_values(
packed + ((i * nbit) / bitsPerByte),
unpacked0,
unpacked1,
unpacked2,
unpacked3);
}
break;
case 128:
for (int i = 0; i < unpacked_size; i += 128) {
torchao::bitpacking::internal::vec_load_64_uint8_values(
unpacked0, unpacked1, unpacked2, unpacked3, unpacked + i);
torchao::bitpacking::internal::vec_load_64_uint8_values(
unpacked4, unpacked5, unpacked6, unpacked7, unpacked + i + 64);
torchao::bitpacking::internal::vec_pack_128_uint1_values(
packed + ((i * nbit) / bitsPerByte),
unpacked0,
unpacked1,
unpacked2,
unpacked3,
unpacked4,
unpacked5,
unpacked6,
unpacked7);
}
break;
}
}

// Benchmark utility to compare variants of uint1 packing
void unpack_uint1_values(
uint8_t* unpacked,
uint8_t* packed,
int unpacked_size,
int packed_size,
int variant) {
constexpr int nbit = 1;
constexpr int bitsPerByte = 8;
assert(unpacked_size * nbit / bitsPerByte == packed_size);
assert(packed_size % variant == 0);

uint8x16_t unpacked0;
uint8x16_t unpacked1;
uint8x16_t unpacked2;
uint8x16_t unpacked3;
uint8x16_t unpacked4;
uint8x16_t unpacked5;
uint8x16_t unpacked6;
uint8x16_t unpacked7;

switch (variant) {
case 8:
for (int i = 0; i < unpacked_size; i += 8) {
torchao::bitpacking::internal::unpack_8_uint1_values(
unpacked + i, packed + ((i * nbit) / bitsPerByte));
}
break;
case 64:
for (int i = 0; i < unpacked_size; i += 64) {
torchao::bitpacking::internal::vec_unpack_64_uint1_values(
unpacked0,
unpacked1,
unpacked2,
unpacked3,
packed + ((i * nbit) / bitsPerByte));
torchao::bitpacking::internal::vec_store_64_uint8_values(
unpacked + i, unpacked0, unpacked1, unpacked2, unpacked3);
}
break;
case 128:
for (int i = 0; i < unpacked_size; i += 128) {
torchao::bitpacking::internal::vec_unpack_128_uint1_values(
unpacked0,
unpacked1,
unpacked2,
unpacked3,
unpacked4,
unpacked5,
unpacked6,
unpacked7,
packed + ((i * nbit) / bitsPerByte));
torchao::bitpacking::internal::vec_store_64_uint8_values(
unpacked + i, unpacked0, unpacked1, unpacked2, unpacked3);
torchao::bitpacking::internal::vec_store_64_uint8_values(
unpacked + i + 64, unpacked4, unpacked5, unpacked6, unpacked7);
}
break;
}
}

// Benchmark utility to compare variants of uint2 packing
void pack_uint2_values(
uint8_t* packed,
Expand Down Expand Up @@ -470,6 +593,44 @@ void unpack_uint5_values(

} // namespace

static void benchmark_pack_uint1_values(benchmark::State& state) {
int unpacked_size = state.range(0);
int variant = state.range(1);
int nbit = 1;

assert(unpacked_size % 8 == 0);
int packed_size = (unpacked_size / 8) * nbit;

auto packed = std::vector<uint8_t>(packed_size, 0);
auto unpacked = torchao::get_random_lowbit_vector(unpacked_size, nbit);

for (auto _ : state) {
pack_uint1_values(
packed.data(), unpacked.data(), packed_size, unpacked_size, variant);
}
}

static void benchmark_unpack_uint1_values(benchmark::State& state) {
int unpacked_size = state.range(0);
int variant = state.range(1);
int nbit = 1;

assert(unpacked_size % 8 == 0);
int packed_size = (unpacked_size / 8) * nbit;

auto packed = torchao::get_random_lowbit_vector(packed_size, 8);
auto unpacked = std::vector<uint8_t>(unpacked_size, 0);

for (auto _ : state) {
unpack_uint1_values(
unpacked.data(),
packed.data(),
unpacked.size(),
packed.size(),
variant);
}
}

static void benchmark_pack_uint2_values(benchmark::State& state) {
int unpacked_size = state.range(0);
int variant = state.range(1);
Expand All @@ -478,8 +639,8 @@ static void benchmark_pack_uint2_values(benchmark::State& state) {
assert(unpacked_size % 8 == 0);
int packed_size = (unpacked_size / 8) * nbit;

auto packed = std::vector<uint8_t>(unpacked_size, 0);
auto unpacked = torchao::get_random_lowbit_vector(packed_size, 8);
auto packed = std::vector<uint8_t>(packed_size, 0);
auto unpacked = torchao::get_random_lowbit_vector(unpacked_size, nbit);

for (auto _ : state) {
pack_uint2_values(
Expand Down Expand Up @@ -516,8 +677,8 @@ static void benchmark_pack_uint3_values(benchmark::State& state) {
assert(unpacked_size % 8 == 0);
int packed_size = (unpacked_size / 8) * nbit;

auto packed = std::vector<uint8_t>(unpacked_size, 0);
auto unpacked = torchao::get_random_lowbit_vector(packed_size, 8);
auto packed = std::vector<uint8_t>(packed_size, 0);
auto unpacked = torchao::get_random_lowbit_vector(unpacked_size, nbit);

for (auto _ : state) {
pack_uint3_values(
Expand Down Expand Up @@ -554,8 +715,8 @@ static void benchmark_pack_uint4_values(benchmark::State& state) {
assert(unpacked_size % 8 == 0);
int packed_size = (unpacked_size / 8) * nbit;

auto packed = std::vector<uint8_t>(unpacked_size, 0);
auto unpacked = torchao::get_random_lowbit_vector(packed_size, 8);
auto packed = std::vector<uint8_t>(packed_size, 0);
auto unpacked = torchao::get_random_lowbit_vector(unpacked_size, nbit);

for (auto _ : state) {
pack_uint4_values(
Expand Down Expand Up @@ -592,8 +753,8 @@ static void benchmark_pack_uint5_values(benchmark::State& state) {
assert(unpacked_size % 8 == 0);
int packed_size = (unpacked_size / 8) * nbit;

auto packed = std::vector<uint8_t>(unpacked_size, 0);
auto unpacked = torchao::get_random_lowbit_vector(packed_size, 8);
auto packed = std::vector<uint8_t>(packed_size, 0);
auto unpacked = torchao::get_random_lowbit_vector(unpacked_size, nbit);

for (auto _ : state) {
pack_uint5_values(
Expand Down Expand Up @@ -622,6 +783,8 @@ static void benchmark_unpack_uint5_values(benchmark::State& state) {
}
}

BENCHMARK(benchmark_pack_uint1_values)->ArgsProduct({{128}, {8, 64, 128}});
BENCHMARK(benchmark_unpack_uint1_values)->ArgsProduct({{128}, {8, 64, 128}});
BENCHMARK(benchmark_pack_uint2_values)->ArgsProduct({{128}, {4, 32, 64}});
BENCHMARK(benchmark_unpack_uint2_values)->ArgsProduct({{128}, {4, 32, 64}});
BENCHMARK(benchmark_pack_uint3_values)->ArgsProduct({{128}, {8, 64, 128}});
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,8 @@ channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot(
false>) \
->ArgsProduct(BENCHMARK_PARAMS)

BENCHMARK_CHANNELWISE_8BIT_ACTIVATION_GROUPWISE_LOWBIT_WEIGHT_1x1x32_F32_NEONDOT(
1);
BENCHMARK_CHANNELWISE_8BIT_ACTIVATION_GROUPWISE_LOWBIT_WEIGHT_1x1x32_F32_NEONDOT(
2);
BENCHMARK_CHANNELWISE_8BIT_ACTIVATION_GROUPWISE_LOWBIT_WEIGHT_1x1x32_F32_NEONDOT(
Expand All @@ -236,6 +238,8 @@ BENCHMARK_CHANNELWISE_8BIT_ACTIVATION_GROUPWISE_LOWBIT_WEIGHT_1x1x32_F32_NEONDOT
4);
BENCHMARK_CHANNELWISE_8BIT_ACTIVATION_GROUPWISE_LOWBIT_WEIGHT_1x1x32_F32_NEONDOT(
5);
BENCHMARK_CHANNELWISE_8BIT_ACTIVATION_GROUPWISE_LOWBIT_WEIGHT_1x4x16_F32_NEONDOT(
1);
BENCHMARK_CHANNELWISE_8BIT_ACTIVATION_GROUPWISE_LOWBIT_WEIGHT_1x4x16_F32_NEONDOT(
2);
BENCHMARK_CHANNELWISE_8BIT_ACTIVATION_GROUPWISE_LOWBIT_WEIGHT_1x4x16_F32_NEONDOT(
Expand All @@ -244,6 +248,8 @@ BENCHMARK_CHANNELWISE_8BIT_ACTIVATION_GROUPWISE_LOWBIT_WEIGHT_1x4x16_F32_NEONDOT
4);
BENCHMARK_CHANNELWISE_8BIT_ACTIVATION_GROUPWISE_LOWBIT_WEIGHT_1x4x16_F32_NEONDOT(
5);
BENCHMARK_CHANNELWISE_8BIT_ACTIVATION_GROUPWISE_LOWBIT_WEIGHT_1x4x16_F32_NEONDOT(
1);
BENCHMARK_CHANNELWISE_8BIT_ACTIVATION_GROUPWISE_LOWBIT_WEIGHT_1x8x16_F32_NEONDOT(
2);
BENCHMARK_CHANNELWISE_8BIT_ACTIVATION_GROUPWISE_LOWBIT_WEIGHT_1x8x16_F32_NEONDOT(
Expand Down
Loading

0 comments on commit 4bce694

Please sign in to comment.