diff --git a/torchao/experimental/kernels/cpu/aarch64/linear/groupwise_lowbit_weight/groupwise_lowbit_weight_lut.h b/torchao/experimental/kernels/cpu/aarch64/linear/groupwise_lowbit_weight/groupwise_lowbit_weight_lut.h index 9227410b28..897ec44549 100644 --- a/torchao/experimental/kernels/cpu/aarch64/linear/groupwise_lowbit_weight/groupwise_lowbit_weight_lut.h +++ b/torchao/experimental/kernels/cpu/aarch64/linear/groupwise_lowbit_weight/groupwise_lowbit_weight_lut.h @@ -44,7 +44,17 @@ chunked and interleaved during the packing process. * @param input Pointer to the source activation matrix (float32, row-major). */ template -inline void pack_activations(float* output, int m, int k, const float* input) { +inline void pack_activations( + float* output, + int m, + int k, + const float* input, + int mr, + int kr, + int sr) { + (void)mr; // unused + (void)kr; // unused + (void)sr; // unused { activation_packing::pack_activations(output, m, k, input); } @@ -100,7 +110,7 @@ row-major). * @param bias Pointer to the bias vector (float32, row-major). */ template -void pack_weights_for_groupwise_lut_kernel( +void pack_weights( /*output*/ void* packed_weights_ptr, /*inputs*/ @@ -113,7 +123,13 @@ void pack_weights_for_groupwise_lut_kernel( int lut_group_size, bool has_scales, bool has_bias, - const float* bias) { + const float* bias, + int nr, + int kr, + int sr) { + (void)nr; // unused + (void)kr; // unused + (void)sr; // unused weight_packing::pack_weights( packed_weights_ptr, weight_qvals_indices, @@ -190,7 +206,11 @@ inline void groupwise_lowbit_weight_lut_kernel_1x4x32( * @param k The K dimension (width) of the activation matrix. * @return The byte offset from the start of the buffer. */ -inline size_t packed_activations_offset(int m_idx, int k) { +inline size_t +packed_activations_offset(int m_idx, int k, int mr, int kr, int sr) { + (void)mr; // unused + (void)kr; // unused + (void)sr; // unused // For a simple padded row-major format, the offset is just m_idx * k. return sizeof(float) * m_idx * k; } diff --git a/torchao/experimental/kernels/cpu/aarch64/tests/test_lut.cpp b/torchao/experimental/kernels/cpu/aarch64/tests/test_lut.cpp index 059c62c027..6cd9ee8dfa 100644 --- a/torchao/experimental/kernels/cpu/aarch64/tests/test_lut.cpp +++ b/torchao/experimental/kernels/cpu/aarch64/tests/test_lut.cpp @@ -71,7 +71,13 @@ void test_groupwise_lowbit_lut_kernel( std::vector packed_activations_buffer( kernel_api::packed_activations_size(m, k, mr_, kr_, sr_)); kernel_api::pack_activations( - packed_activations_buffer.data(), m, k, source_activations.data()); + packed_activations_buffer.data(), + m, + k, + source_activations.data(), + mr_, + kr_, + sr_); // 3. Pack Weights std::vector packed_weights(kernel_api::packed_weights_size( n, @@ -83,19 +89,21 @@ void test_groupwise_lowbit_lut_kernel( nr_, kr_, sr_)); - kernel_api:: - pack_weights_for_groupwise_lut_kernel( - packed_weights.data(), - test_case.weight_qval_indices.data(), - test_case.weight_scales.data(), - test_case.weight_luts.data(), - n, - k, - flat_scale_group_size, - flat_lut_group_size, - has_scales_, - has_bias, - test_case.bias.data()); + kernel_api::pack_weights( + packed_weights.data(), + test_case.weight_qval_indices.data(), + test_case.weight_scales.data(), + test_case.weight_luts.data(), + n, + k, + flat_scale_group_size, + flat_lut_group_size, + has_scales_, + has_bias, + test_case.bias.data(), + nr_, + kr_, + sr_); // 4. Run the kernel std::vector output(m * n); diff --git a/torchao/experimental/kernels/cpu/aarch64/tests/test_utils.h b/torchao/experimental/kernels/cpu/aarch64/tests/test_utils.h index aeb9042210..159a6d6dac 100644 --- a/torchao/experimental/kernels/cpu/aarch64/tests/test_utils.h +++ b/torchao/experimental/kernels/cpu/aarch64/tests/test_utils.h @@ -640,11 +640,10 @@ struct groupwise_lowbit_weight_lut_test_case { const int total_weights = n * k; // Frequencies are controlled by their group sizes. assert(total_weights % scale_group_size == 0); - assert(total_weights % lut_group_size == 0); // The number of unique scales/LUTs is derived directly from their group size. const int num_scales = total_weights / scale_group_size; - const int num_luts = total_weights / lut_group_size; + const int num_luts = (total_weights + lut_group_size - 1) / lut_group_size; const int lut_size = 1 << weight_nbit; std::mt19937 gen(std::random_device{}()); @@ -726,9 +725,6 @@ struct groupwise_lowbit_weight_lut_test_case { int weight_nbit, bool has_scales, bool has_bias, bool has_clamp) { - std::cout << "[Generator Info] Using 'Per-Group' model.\n" - << " - Both scales and LUTs will switch every " << group_size << " weights." << std::endl; - // Just call the decoupled generator with the same group size for both. return _generate_master( m, k, n, @@ -748,10 +744,6 @@ struct groupwise_lowbit_weight_lut_test_case { int scale_group_size, int lut_group_size, int weight_nbit, bool has_scales, bool has_bias, bool has_clamp) { - std::cout << "[Generator Info] Using 'Decoupled Grouping' model.\n" - << " - Scales will switch every " << scale_group_size << " weights.\n" - << " - LUTs will switch every " << lut_group_size << " weights." << std::endl; - return _generate_master( m, k, n, scale_group_size, lut_group_size, diff --git a/torchao/experimental/ops/groupwise_lowbit_weight_lut/groupwise_lowbit_weight_lut.cpp b/torchao/experimental/ops/groupwise_lowbit_weight_lut/groupwise_lowbit_weight_lut.cpp index e5c37ea7a6..c0d452c95b 100644 --- a/torchao/experimental/ops/groupwise_lowbit_weight_lut/groupwise_lowbit_weight_lut.cpp +++ b/torchao/experimental/ops/groupwise_lowbit_weight_lut/groupwise_lowbit_weight_lut.cpp @@ -28,10 +28,12 @@ void pack_weights_operator( const float* weight_scales, const float* weight_luts, const float* bias) { - TORCHAO_CHECK( - lut_group_size % scale_group_size == 0, - "scale_group_size must devide lut_group_size"); - TORCHAO_CHECK(k % scale_group_size == 0, "scale_group_size must divide k"); + if (uk.has_scales) { + TORCHAO_CHECK( + lut_group_size % scale_group_size == 0, + "scale_group_size must devide lut_group_size"); + TORCHAO_CHECK(k % scale_group_size == 0, "scale_group_size must divide k"); + } TORCHAO_CHECK( lut_group_size % (k * uk.nr) == 0, "lut_group_size must be a multiple of k*nr"); @@ -139,14 +141,17 @@ void groupwise_lowbit_weight_lut_parallel_operator( bool has_clamp, float clamp_min, float clamp_max) { - TORCHAO_CHECK( - lut_group_size % scale_group_size == 0, - "scale_group_size must divide lut_group_size"); - TORCHAO_CHECK(k % scale_group_size == 0, "scale_group_size must divide k"); + if (uk.has_scales) { + TORCHAO_CHECK( + lut_group_size % scale_group_size == 0, + "scale_group_size must divide lut_group_size"); + TORCHAO_CHECK(k % scale_group_size == 0, "scale_group_size must divide k"); + TORCHAO_CHECK( + scale_group_size % uk.kr == 0, "kr must divide scale_group_size"); + } + TORCHAO_CHECK( lut_group_size % (k * uk.nr) == 0, "(k * nr) must divide lut_group_size"); - TORCHAO_CHECK( - scale_group_size % uk.kr == 0, "kr must divide scale_group_size"); int config_idx = uk.select_config_idx(m); auto& kernel_config = uk.configs[config_idx]; int n_step = uk.n_step; @@ -191,7 +196,7 @@ void groupwise_lowbit_weight_lut_parallel_operator( mc_tile_size, k, activation_row_ptr, - kernel_config.mr, + uk.nr, uk.kr, uk.sr); diff --git a/torchao/experimental/ops/groupwise_lowbit_weight_lut/kernel_config.h b/torchao/experimental/ops/groupwise_lowbit_weight_lut/kernel_config.h index 2a27110174..6b3ab28310 100644 --- a/torchao/experimental/ops/groupwise_lowbit_weight_lut/kernel_config.h +++ b/torchao/experimental/ops/groupwise_lowbit_weight_lut/kernel_config.h @@ -150,32 +150,37 @@ struct UKernelConfig { packed_weights_offset != nullptr, "packed_weights_offset_fn_type must be set"); TORCHAO_CHECK(pack_weights != nullptr, "pack_weights must be set"); - // 2. Validate the Array of Linear Configurations // At least one configuration must be defined. TORCHAO_CHECK( !configs.empty(), "At least one valid kernel configuration must be provided."); + bool configs_set = true; // first linear config must be set for (size_t i = 0; i < configs.size(); ++i) { - const auto& config = configs[i]; - - TORCHAO_CHECK( - config.packed_activations_size != nullptr, - "config.packed_activations_size must be set"); - TORCHAO_CHECK( - config.pack_activations != nullptr, - "config.pack_activations must be set"); - TORCHAO_CHECK(config.kernel != nullptr, "config.kernel must be set"); - - if (i > 0) { - const auto& prev_config = configs[i - 1]; + if (configs_set) { + const auto& config = configs[i]; + TORCHAO_CHECK( - prev_config.m_step > 0, - "There cannot be a gap in configurations (m_step=0 followed by m_step>0)"); + config.packed_activations_size != nullptr, + "config.packed_activations_size must be set"); TORCHAO_CHECK( - prev_config.m_step < config.m_step, - "m_step values in configs must be strictly increasing."); + config.pack_activations != nullptr, + "config.pack_activations must be set"); + TORCHAO_CHECK(config.kernel != nullptr, "config.kernel must be set"); + + if (i > 0) { + const auto& prev_config = configs[i - 1]; + TORCHAO_CHECK( + prev_config.m_step > 0, + "There cannot be a gap in configurations (m_step=0 followed by m_step>0)"); + TORCHAO_CHECK( + prev_config.m_step < config.m_step, + "m_step values in configs must be strictly increasing."); + } + if (i + 1 < configs.size()) { + configs_set = (configs[i + 1].m_step >= 1); + } } } } diff --git a/torchao/experimental/ops/groupwise_lowbit_weight_lut/kernel_selector.h b/torchao/experimental/ops/groupwise_lowbit_weight_lut/kernel_selector.h index ae1b568994..e898ba5af4 100644 --- a/torchao/experimental/ops/groupwise_lowbit_weight_lut/kernel_selector.h +++ b/torchao/experimental/ops/groupwise_lowbit_weight_lut/kernel_selector.h @@ -13,9 +13,7 @@ #include #if defined(TORCHAO_BUILD_CPU_AARCH64) -#if defined(TORCHAO_ENABLE_ARM_NEON_DOT) -#include -#endif // TORCHAO_ENABLE_ARM_NEON_DOT +#include #endif // TORCHAO_BUILD_CPU_AARCH64 namespace torchao::ops::groupwise_lowbit_weight_lut { @@ -122,19 +120,22 @@ void register_ukernel_config( torchao::kernels::cpu::aarch64::linear::groupwise_lowbit_weight_lut; using kernel_fn_ptr_t = - decltype(&kernel_api::kernel_lowbit_1x4x32_f32); + decltype(&kernel_api::groupwise_lowbit_weight_lut_kernel_1x4x32< + weight_nbit, + true>); kernel_fn_ptr_t kernel_dispatcher; if (format.has_scales) { - kernel_dispatcher = - &kernel_api::kernel_lowbit_1x4x32_f32; + kernel_dispatcher = &kernel_api::groupwise_lowbit_weight_lut_kernel_1x4x32< + weight_nbit, + /*has_scales=*/true>; } else { - kernel_dispatcher = - &kernel_api:: - kernel_lowbit_1x4x32_f32; + kernel_dispatcher = &kernel_api::groupwise_lowbit_weight_lut_kernel_1x4x32< + weight_nbit, + /*has_scales=*/false>; } if (format.nr == 4 && format.kr == 32 && format.sr == 8) { - log_registration(format, "lut: kernel_lowbit_1x4x32_f32"); + log_registration(format, "lut: groupwise_lowbit_weight_lut_kernel_1x4x32"); constexpr int nr = 4; constexpr int kr = 32; constexpr int sr = 8; @@ -152,22 +153,25 @@ void register_ukernel_config( /*has_scales=*/format.has_scales, /*has_bias=*/format.has_bias, /*packed_weights_size_fn_type=*/ - &kernel_api::packed_weights_size, + &kernel_api::packed_weights_size, + /*packed_weights_offset_fn_type=*/ + &kernel_api::packed_weights_offset, /*pack_weights_fn_type=*/ &kernel_api:: - pack_weights_for_groupwise_lut_kernel, + pack_weights, /*configs=*/{}); - uk.configs[0] = UKernelConfig::group_config_type( + uk.configs[0] = UKernelConfig::config_type {m_step, mr, &kernel_api::packed_activations_size, &kernel_api::packed_activations_offset, &kernel_api::pack_activations, - kernel_dispatcher}); + kernel_dispatcher}; // Resgister the kernel config. table.register_ukernel_config(format, uarch, std::move(uk)); + return; } } #endif // TORCHAO_BUILD_CPU_AARCH64 @@ -206,7 +210,9 @@ UKernelConfig select_ukernel_config(torchao::ops::PackedWeightsHeader header) { register_ukernel_config(table, format, uarch); ukernel = table.get_ukernel_config(header, uarch); - assert(ukernel.has_value() && "Kernel registration failed for the current CPU microarchitecture."); + assert( + ukernel.has_value() && + "Kernel registration failed for the current CPU microarchitecture."); return ukernel.value(); #else throw std::runtime_error( diff --git a/torchao/experimental/ops/groupwise_lowbit_weight_lut/packed_weights_format.h b/torchao/experimental/ops/groupwise_lowbit_weight_lut/packed_weights_format.h index 4fba6edb09..9ea50425b7 100644 --- a/torchao/experimental/ops/groupwise_lowbit_weight_lut/packed_weights_format.h +++ b/torchao/experimental/ops/groupwise_lowbit_weight_lut/packed_weights_format.h @@ -63,7 +63,7 @@ struct PackedWeightsFormat { static_cast(header.params[4]), // has_bias header.params[5], // nr header.params[6], // kr - header.params[7], // sr + header.params[7] // sr ); }