@@ -584,7 +584,14 @@ struct LargeKernel {
584584
585585 const auto low_swizzled = swizzle_bytes (bytes, kLowSwizzles );
586586 const auto low_words = xsimd::bitwise_cast<unpacked_type>(low_swizzled);
587- const auto low_shifted = right_shift_by_excess (low_words, kLowRShifts );
587+ simd_batch low_shifted;
588+ if constexpr (kShape .unpacked_byte_size () == 1 ) {
589+ // The logic of the fallback in right_shift_by_excess does not work for this single
590+ // byte case case, so we use directly xsimd and its scalar fallback.
591+ low_shifted = low_words >> kLowRShifts ;
592+ } else {
593+ low_shifted = right_shift_by_excess (low_words, kLowRShifts );
594+ }
588595
589596 const auto high_swizzled = swizzle_bytes (bytes, kHighSwizzles );
590597 const auto high_words = xsimd::bitwise_cast<unpacked_type>(high_swizzled);
@@ -641,23 +648,6 @@ struct Kernel : DispatchKernelType<UnpackedUint, kPackedBitSize, kSimdBitSize> {
641648 using Base::unpack;
642649};
643650
644- template <int kPackedBitSize , int kSimdBitSize >
645- struct Kernel <uint8_t , kPackedBitSize , kSimdBitSize >
646- : Kernel<uint16_t , kPackedBitSize , kSimdBitSize > {
647- using Base = DispatchKernelType<uint16_t , kPackedBitSize , kSimdBitSize >;
648- using Base::kValuesUnpacked ;
649- using unpacked_type = uint8_t ;
650-
651- static const uint8_t * unpack (const uint8_t * in, unpacked_type* out) {
652- uint16_t buffer[kValuesUnpacked ] = {};
653- in = Base::unpack (in, buffer);
654- for (int k = 0 ; k < kValuesUnpacked ; ++k) {
655- out[k] = static_cast <unpacked_type>(buffer[k]);
656- }
657- return in;
658- }
659- };
660-
661651template <int kPackedBitSize , int kSimdBitSize >
662652struct Kernel <bool , kPackedBitSize , kSimdBitSize >
663653 : Kernel<uint16_t , kPackedBitSize , kSimdBitSize > {
0 commit comments