2727
2828#include " arrow/util/bit_util.h"
2929#include " arrow/util/bpacking_dispatch_internal.h"
30+ #include " arrow/util/type_traits.h"
3031
3132namespace arrow ::internal {
3233
@@ -38,7 +39,6 @@ namespace arrow::internal {
3839// - array to batch constant to xsimd
3940// - Shifts per swizzle can be improved when self.packed_max_byte_spread == 1 and the
4041// byte can be reused (when val_bit_width divides packed_max_byte_spread).
41- // - Try for uint16_t and uint8_t and bool (currently copy)
4242// - Add unpack_exact to benchmarks
4343// - Reduce input size on small bit width using a broadcast.
4444// - For Avx2:
@@ -112,7 +112,10 @@ struct KernelTraits {
112112 };
113113
114114 using unpacked_type = UnpackedUint;
115- using simd_batch = xsimd::make_sized_batch_t <unpacked_type, kShape .unpacked_per_simd()>;
115+ // The integer type to compute with: `unpacked_type` itself, or a `sizeof(bool)`-sized integer when `unpacked_type` is bool.
116+ using uint_type = std::conditional_t <std::is_same_v<unpacked_type, bool >,
117+ SizedUint<sizeof (bool )>, unpacked_type>;
118+ using simd_batch = xsimd::make_sized_batch_t <uint_type, kShape .unpacked_per_simd()>;
116119 using simd_bytes = xsimd::make_sized_batch_t <uint8_t , kShape .simd_byte_size()>;
117120 using arch_type = typename simd_batch::arch_type;
118121};
@@ -184,6 +187,7 @@ constexpr MediumKernelPlanSize BuildMediumPlanSize(const KernelShape& shape) {
184187template <typename UnpackedUint, int kPackedBitSize , int kSimdBitSize >
185188struct MediumKernelPlan {
186189 using Traits = KernelTraits<UnpackedUint, kPackedBitSize , kSimdBitSize >;
190+ using uint_type = typename Traits::uint_type;
187191 static constexpr auto kShape = Traits::kShape ;
188192 static constexpr auto kPlanSize = BuildMediumPlanSize(kShape );
189193
@@ -193,7 +197,7 @@ struct MediumKernelPlan {
193197 using SwizzlesPerRead = std::array<Swizzle, kPlanSize .swizzles_per_read()>;
194198 using SwizzlesPerKernel = std::array<SwizzlesPerRead, kPlanSize .reads_per_kernel()>;
195199
196- using Shift = std::array<UnpackedUint , kShape .unpacked_per_simd()>;
200+ using Shift = std::array<uint_type , kShape .unpacked_per_simd()>;
197201 using ShiftsPerSwizzle = std::array<Shift, kPlanSize .shifts_per_swizzle()>;
198202 using ShiftsPerRead = std::array<ShiftsPerSwizzle, kPlanSize .swizzles_per_read()>;
199203 using ShitsPerKernel = std::array<ShiftsPerRead, kPlanSize .reads_per_kernel()>;
@@ -212,7 +216,7 @@ struct MediumKernelPlan {
212216 ReadsPerKernel reads;
213217 SwizzlesPerKernel swizzles;
214218 ShitsPerKernel shifts;
215- UnpackedUint mask = bit_util::LeastSignificantBitMask<UnpackedUint>(kPackedBitSize );
219+ uint_type mask = bit_util::LeastSignificantBitMask<uint_type>(kPackedBitSize );  // kPackedBitSize low bits set; built in uint_type (never bool), matching BuildLargePlan
216220};
217221
218222template <typename UnpackedUint, int kPackedBitSize , int kSimdBitSize >
@@ -427,6 +431,7 @@ struct MediumKernel {
427431 static constexpr auto kShape = kPlan .kShape ;
428432 using Traits = typename decltype (kPlan )::Traits;
429433 using unpacked_type = typename Traits::unpacked_type;
434+ using uint_type = typename Traits::uint_type;
430435 using simd_batch = typename Traits::simd_batch;
431436 using simd_bytes = typename Traits::simd_bytes;
432437 using arch_type = typename Traits::arch_type;
@@ -448,7 +453,12 @@ struct MediumKernel {
448453 // can use the fallback on these platforms.
449454 const auto shifted = right_shift_by_excess (words, kRightShifts );
450455 const auto vals = shifted & kMask ;
451- xsimd::store_unaligned (out + kOutOffset , vals);
456+ if constexpr (std::is_same_v<unpacked_type, bool >) {
457+ const xsimd::batch_bool<uint_type, arch_type> bools (vals);
458+ bools.store_unaligned (out + kOutOffset );
459+ } else {
460+ vals.store_unaligned (out + kOutOffset );
461+ }
452462 }
453463
454464 template <int kReadIdx , int kSwizzleIdx , int ... kShiftIds >
@@ -458,7 +468,7 @@ struct MediumKernel {
458468 constexpr auto kSwizzles = make_batch_constant<kSwizzlesArr , arch_type>();
459469
460470 const auto swizzled = swizzle_bytes (bytes, kSwizzles );
461- const auto words = xsimd::bitwise_cast<unpacked_type >(swizzled);
471+ const auto words = xsimd::bitwise_cast<uint_type >(swizzled);
462472 (unpack_one_shift_impl<kReadIdx , kSwizzleIdx , kShiftIds >(words, out), ...);
463473 }
464474
@@ -487,6 +497,7 @@ struct MediumKernel {
487497template <typename UnpackedUint, int kPackedBitSize , int kSimdBitSize >
488498struct LargeKernelPlan {
489499 using Traits = KernelTraits<UnpackedUint, kPackedBitSize , kSimdBitSize >;
500+ using uint_type = typename Traits::uint_type;
490501 static constexpr auto kShape = Traits::kShape ;
491502
492503 static constexpr int kUnpackedPerkernel = std::lcm(kShape .unpacked_per_simd(), 8 );
@@ -498,20 +509,21 @@ struct LargeKernelPlan {
498509 using Swizzle = std::array<uint8_t , kShape .simd_byte_size()>;
499510 using SwizzlesPerKernel = std::array<Swizzle, kReadsPerKernel >;
500511
501- using Shift = std::array<UnpackedUint , kShape .unpacked_per_simd()>;
512+ using Shift = std::array<uint_type , kShape .unpacked_per_simd()>;
502513 using ShitsPerKernel = std::array<Shift, kReadsPerKernel >;
503514
504515 ReadsPerKernel reads;
505516 SwizzlesPerKernel low_swizzles;
506517 SwizzlesPerKernel high_swizzles;
507518 ShitsPerKernel low_rshifts;
508519 ShitsPerKernel high_lshifts;
509- UnpackedUint mask;
520+ uint_type mask;
510521};
511522
512523template <typename UnpackedUint, int kPackedBitSize , int kSimdBitSize >
513524constexpr LargeKernelPlan<UnpackedUint, kPackedBitSize , kSimdBitSize > BuildLargePlan () {
514525 using Plan = LargeKernelPlan<UnpackedUint, kPackedBitSize , kSimdBitSize >;
526+ using uint_type = typename Plan::Traits::uint_type;
515527 constexpr auto kShape = Plan::kShape ;
516528 static_assert (kShape .is_large ());
517529 constexpr int kOverBytes =
@@ -550,7 +562,7 @@ constexpr LargeKernelPlan<UnpackedUint, kPackedBitSize, kSimdBitSize> BuildLarge
550562 }
551563 }
552564
553- plan.mask = bit_util::LeastSignificantBitMask<UnpackedUint >(kPackedBitSize );
565+ plan.mask = bit_util::LeastSignificantBitMask<uint_type >(kPackedBitSize );
554566
555567 return plan;
556568}
@@ -648,21 +660,4 @@ struct Kernel : DispatchKernelType<UnpackedUint, kPackedBitSize, kSimdBitSize> {
648660 using Base::unpack;
649661};
650662
651- template <int kPackedBitSize , int kSimdBitSize >
652- struct Kernel <bool , kPackedBitSize , kSimdBitSize >
653- : Kernel<uint16_t , kPackedBitSize , kSimdBitSize > {
654- using Base = DispatchKernelType<uint16_t , kPackedBitSize , kSimdBitSize >;
655- using Base::kValuesUnpacked ;
656- using unpacked_type = bool ;
657-
658- static const uint8_t * unpack (const uint8_t * in, unpacked_type* out) {
659- uint16_t buffer[kValuesUnpacked ] = {};
660- in = Base::unpack (in, buffer);
661- for (int k = 0 ; k < kValuesUnpacked ; ++k) {
662- out[k] = static_cast <unpacked_type>(buffer[k]);
663- }
664- return in;
665- }
666- };
667-
668663} // namespace arrow::internal
0 commit comments