
#if defined(__aarch64__) && !defined(C10_MOBILE)
#include <arm_neon.h>
+#include <cpuinfo.h>
#endif

C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wunused-function")
@@ -301,7 +302,7 @@ static constexpr auto kF16RegistersPerIterationShift = kF16ElementsPerIterationS
static constexpr auto kF16RegistersPerIteration = 1 << kF16RegistersPerIterationShift;
static_assert(kF16RegistersPerIteration == kF16ElementsPerIteration / kF16ElementsPerRegister);

-static inline double reduce(float16x8_t x[kF16RegistersPerIteration]) {
+static inline float reduce(float16x8_t x[kF16RegistersPerIteration]) {
  int offset = kF16RegistersPerIteration;
  c10::ForcedUnroll<kF16RegistersPerIterationShift>{}([&offset, &x](auto idx) {
    offset /= 2;
@@ -311,7 +312,7 @@ static inline double reduce(float16x8_t x[kF16RegistersPerIteration]) {
  });
  const float32x4_t t0 = vcvt_f32_f16(vget_low_f16(x[0]));
  const float32x4_t t1 = vcvt_f32_f16(vget_high_f16(x[0]));
-  return (double)vaddvq_f32(vaddq_f32(t0, t1));
+  return vaddvq_f32(vaddq_f32(t0, t1));
}
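
The ForcedUnroll above implements a log2-depth tree reduction: each step folds the upper half of the live accumulators into the lower half. A minimal scalar sketch of the same pattern, with plain floats standing in for the float16x8_t registers (`tree_reduce` is an illustrative name, not part of the patch):

```c++
#include <array>

// Scalar model of the reduction above: fold 8 accumulators down to 1 in
// log2(8) = 3 halving steps, mirroring vaddq_f16(x[i], x[i + offset]).
float tree_reduce(std::array<float, 8> x) {
  for (int offset = 8 / 2; offset > 0; offset /= 2) {
    for (int i = 0; i < offset; ++i) {
      x[i] += x[i + offset];
    }
  }
  return x[0];
}
```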

static inline float16x8_t f16_fma(float16x8_t a, float16x8_t b, float16x8_t c) {
@@ -333,12 +334,12 @@ static float fp16_dot_with_fp16_arith(const float16_t* x, const float16_t* a, in
      sum[k] = f16_fma(sum[k], temp_x, temp_a);
    }
  }
-  auto reducedSum = reduce(sum);
+  auto reduced_sum = reduce(sum);

  for (int j = len_aligned; j < len; ++j) {
-    reducedSum += x[j] * a[j];
+    reduced_sum += x[j] * a[j];
  }
-  return reducedSum;
+  return reduced_sum;
}
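
The `len_aligned` mask used in this function rounds `len` down to a multiple of the unroll width, so the scalar loop only sees the leftovers. A small worked example (the width 32 is assumed here purely for illustration; the real value is kF16ElementsPerIteration):

```c++
// len & ~(width - 1) rounds down to a multiple of width when width is a
// power of two, e.g. for len = 70 and width = 32:
constexpr int64_t width = 32;  // stand-in for kF16ElementsPerIteration
constexpr int64_t len = 70;
constexpr int64_t len_aligned = len & ~(width - 1);
static_assert(len_aligned == 64);
// Elements [0, 64) go through the unrolled SIMD loop,
// elements [64, 70) through the scalar tail loop.
```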

// Rather than unrolling to process multiple rows (transposed columns)
@@ -352,7 +353,7 @@ static void fp16_gemv_trans_fp16_arith_by_dot_products(const int m, const int n,
  });
}
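
For the transposed GEMV, every output element is an independent dot product against one column of `a`, which is why the kernel dispatches whole dot products rather than unrolling across rows. A simplified reference sketch of the semantics (unit strides and no alpha scaling assumed; `fp16_gemv_trans_ref` is illustrative, not the function landed here):

```c++
// Reference semantics: y[i] = dot(x, column i of a). The real kernel
// parallelizes this outer loop (e.g. with at::parallel_for).
static void fp16_gemv_trans_ref(int m, int n, const float16_t* a, int lda,
                                const float16_t* x, float16_t* y) {
  for (int i = 0; i < n; ++i) {
    y[i] = fp16_dot_with_fp16_arith(x, a + lda * i, m);
  }
}
```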

-#endif
+#endif // __ARM_FEATURE_FP16_SCALAR_ARITHMETIC

static inline float reduce(float32x4_t x) {
  auto sum = vpaddq_f32(x, x);
@@ -412,7 +413,7 @@ static constexpr auto kF32RegistersPerIterationShift = 3;
static_assert(kF32RegistersPerIteration == kF32ElementsPerIteration / kF32ElementsPerRegister);
static_assert(kF32RegistersPerIteration == 1 << kF32RegistersPerIterationShift);

-static inline double reduce(float32x4_t x[kF32RegistersPerIteration]) {
+static inline float reduce(float32x4_t x[kF32RegistersPerIteration]) {
  int offset = kF32RegistersPerIteration;
  c10::ForcedUnroll<kF32RegistersPerIterationShift>{}([&offset, &x](auto idx) {
    offset /= 2;
@@ -423,7 +424,7 @@ static inline double reduce(float32x4_t x[kF32RegistersPerIteration]) {
  return vaddvq_f32(x[0]);
}

-static C10_ALWAYS_INLINE void dot_with_fp32_arith_main_inner_loop(
+static C10_ALWAYS_INLINE void dot_with_fp32_arith_main_inner_loop_no_bfdot(
    const float16_t* vec1,
    const float16_t* vec2,
    float32x4_t sum[kF32RegistersPerIteration],
@@ -436,86 +437,217 @@ static C10_ALWAYS_INLINE void dot_with_fp32_arith_main_inner_loop(
  sum[2 * registerPairIndex + 1] = f32_fma_high_f16(sum[2 * registerPairIndex + 1], temp_vec1, temp_vec2);
}

-static C10_ALWAYS_INLINE void dot_with_fp32_arith_vectorized_tail_inner_loop(
-    const float16_t* vec1,
-    const float16_t* vec2,
-    float32x4_t* tailSum,
-    int idx) {
+static C10_ALWAYS_INLINE void dot_with_fp32_arith_vectorized_tail_inner_loop_no_bfdot(
+    const float16_t* vec1,
+    const float16_t* vec2,
+    float32x4_t* tail_sum,
+    int idx) {
  const auto temp_vec1 = vld1_f16(&vec1[idx]);
  const auto temp_vec2 = vld1_f16(&vec2[idx]);
-  *tailSum = f32_fma_f16(*tailSum, temp_vec1, temp_vec2);
+  *tail_sum = f32_fma_f16(*tail_sum, temp_vec1, temp_vec2);
}

-static C10_ALWAYS_INLINE float32x4_t to_bfloat16(uint16x4_t u16) {
+static float32x4_t to_bfloat16(uint16x4_t u16) {
  int32x4_t shift = vdupq_n_s32(16);
  return vreinterpretq_f32_u32(vshlq_u32(vmovl_u16(u16), shift));
}
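
to_bfloat16 works because a bfloat16 value is exactly the top 16 bits of an IEEE float32, so widening is a 16-bit left shift plus a bit reinterpretation. The scalar equivalent (`bf16_bits_to_float` is an illustrative helper, not part of the patch):

```c++
#include <cstdint>
#include <cstring>

// Scalar equivalent of to_bfloat16(): place the bf16 bits in the high half
// of a u32 and reinterpret, just like vshlq_u32 + vreinterpretq_f32_u32.
float bf16_bits_to_float(uint16_t bits16) {
  uint32_t bits32 = static_cast<uint32_t>(bits16) << 16;
  float f;
  std::memcpy(&f, &bits32, sizeof(f));
  return f;
}
```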

-static C10_ALWAYS_INLINE float32x4_t f32_fma_bf16(float32x4_t a, uint16x4_t b, uint16x4_t c) {
+static float32x4_t f32_fma_bf16(float32x4_t a, uint16x4_t b, uint16x4_t c) {
  return f32_fma(a, to_bfloat16(b), to_bfloat16(c));
}

-static C10_ALWAYS_INLINE void dot_with_fp32_arith_main_inner_loop(
-    const at::BFloat16* vec1,
-    const at::BFloat16* vec2,
-    float32x4_t sum[kF32RegistersPerIteration],
-    int registerPairIndex) {
-  // TODO: detect intrinsic availability, use them if they're available. __ARM_FEATURE_BF16
-  // Load a pair of f32 registers at a time.
-  const uint16x8_t temp_vec1 = vld1q_u16(reinterpret_cast<const uint16_t*>(&vec1[registerPairIndex * 2 * kF32ElementsPerRegister]));
-  const uint16x8_t temp_vec2 = vld1q_u16(reinterpret_cast<const uint16_t*>(&vec2[registerPairIndex * 2 * kF32ElementsPerRegister]));
+#if defined(__clang__) && __clang_major__ > 15
+// https://godbolt.org/z/z8P4Yncra
+#define COMPILER_SUPPORTS_BF16_TARGET 1
+#elif !defined(__clang__) && defined(__GNUC__) && __GNUC__ >= 10
+// https://gcc.gnu.org/gcc-10/changes.html
+// https://godbolt.org/z/cdGG7vn8o
+#define COMPILER_SUPPORTS_BF16_TARGET 1
+#else
+#define COMPILER_SUPPORTS_BF16_TARGET 0
+#endif
+
+#if COMPILER_SUPPORTS_BF16_TARGET
+#define TARGET_ARM_BF16_ATTRIBUTE __attribute__((target("arch=armv8.2-a+bf16")))
+
+TARGET_ARM_BF16_ATTRIBUTE static C10_ALWAYS_INLINE float32x4_t
+f32_dot_bf16(float32x4_t a, bfloat16x8_t b, bfloat16x8_t c) {
+  return vbfdotq_f32(a, b, c);
+}
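
BFDOT changes the accumulator layout: vbfdotq_f32 accumulates a 2-element bf16 dot product into each of the four f32 lanes, so one bfloat16x8_t register feeds a single float32x4_t accumulator (hence `sum[registerPairIndex]` in the loop below, rather than a pair of sums). A scalar model, up to the instruction's internal rounding behavior, reusing the `bf16_bits_to_float` helper sketched earlier:

```c++
// Scalar model of vbfdotq_f32(a, b, c): each f32 lane accumulates the dot
// product of the corresponding pair of bf16 elements.
void bfdot_model(float a[4], const uint16_t b[8], const uint16_t c[8]) {
  for (int lane = 0; lane < 4; ++lane) {
    a[lane] += bf16_bits_to_float(b[2 * lane]) * bf16_bits_to_float(c[2 * lane]) +
               bf16_bits_to_float(b[2 * lane + 1]) * bf16_bits_to_float(c[2 * lane + 1]);
  }
}
```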

-  sum[2 * registerPairIndex] = f32_fma_bf16(sum[2 * registerPairIndex], vget_low_u16(temp_vec1), vget_low_u16(temp_vec2));
-  sum[2 * registerPairIndex + 1] = f32_fma_bf16(sum[2 * registerPairIndex + 1], vget_high_u16(temp_vec1), vget_high_u16(temp_vec2));
+TARGET_ARM_BF16_ATTRIBUTE static C10_ALWAYS_INLINE void
+dot_with_fp32_arith_main_inner_loop_bfdot(
+    const BFloat16* vec1,
+    const BFloat16* vec2,
+    float32x4_t sum[kF32RegistersPerIteration],
+    int registerPairIndex) {
+  const bfloat16x8_t temp_vec1 = vld1q_bf16(reinterpret_cast<const __bf16*>(
+      &vec1[registerPairIndex * 2 * kF32ElementsPerRegister]));
+  const bfloat16x8_t temp_vec2 = vld1q_bf16(reinterpret_cast<const __bf16*>(
+      &vec2[registerPairIndex * 2 * kF32ElementsPerRegister]));
+  sum[registerPairIndex] =
+      f32_dot_bf16(sum[registerPairIndex], temp_vec1, temp_vec2);
}

-static C10_ALWAYS_INLINE void dot_with_fp32_arith_vectorized_tail_inner_loop(
-    const at::BFloat16* vec1,
-    const at::BFloat16* vec2,
-    float32x4_t* tailSum,
-    int idx) {
+// See NOTE [GCC code duplication] below for why we have _bfdot and
+// _no_bfdot versions of
+// dot_with_fp32_arith_vectorized_tail_inner_loop.
+TARGET_ARM_BF16_ATTRIBUTE C10_ALWAYS_INLINE
+static void dot_with_fp32_arith_vectorized_tail_inner_loop_bfdot(
+    const at::BFloat16* vec1,
+    const at::BFloat16* vec2,
+    float32x4_t* tail_sum,
+    int idx) {
  const auto temp_vec1 = vld1_u16(reinterpret_cast<const uint16_t*>(&vec1[idx]));
  const auto temp_vec2 = vld1_u16(reinterpret_cast<const uint16_t*>(&vec2[idx]));
-  *tailSum = f32_fma_bf16(*tailSum, temp_vec1, temp_vec2);
+  *tail_sum = f32_fma_bf16(*tail_sum, temp_vec1, temp_vec2);
}

-template <typename T>
-float dot_with_fp32_arith(const T* vec1, const T* vec2, int64_t len) {
+#else
+#define TARGET_ARM_BF16_ATTRIBUTE
+#endif // COMPILER_SUPPORTS_BF16_TARGET
+
+static C10_ALWAYS_INLINE void dot_with_fp32_arith_main_inner_loop_no_bfdot(
+    const BFloat16* vec1,
+    const BFloat16* vec2,
+    float32x4_t sum[kF32RegistersPerIteration],
+    int registerPairIndex) {
+  const uint16x8_t temp_vec1 = vld1q_u16(reinterpret_cast<const uint16_t*>(
+      &vec1[registerPairIndex * 2 * kF32ElementsPerRegister]));
+  const uint16x8_t temp_vec2 = vld1q_u16(reinterpret_cast<const uint16_t*>(
+      &vec2[registerPairIndex * 2 * kF32ElementsPerRegister]));
+
+  sum[2 * registerPairIndex] = f32_fma_bf16(
+      sum[2 * registerPairIndex],
+      vget_low_u16(temp_vec1),
+      vget_low_u16(temp_vec2));
+  sum[2 * registerPairIndex + 1] = f32_fma_bf16(
+      sum[2 * registerPairIndex + 1],
+      vget_high_u16(temp_vec1),
+      vget_high_u16(temp_vec2));
+}
+
+static C10_ALWAYS_INLINE void dot_with_fp32_arith_vectorized_tail_inner_loop_no_bfdot(
+    const at::BFloat16* vec1,
+    const at::BFloat16* vec2,
+    float32x4_t* tail_sum,
+    int idx) {
+  const auto temp_vec1 = vld1_u16(reinterpret_cast<const uint16_t*>(&vec1[idx]));
+  const auto temp_vec2 = vld1_u16(reinterpret_cast<const uint16_t*>(&vec2[idx]));
+  *tail_sum = f32_fma_bf16(*tail_sum, temp_vec1, temp_vec2);
+}
+
+namespace {
+#if COMPILER_SUPPORTS_BF16_TARGET
+template <int n>
+struct ForcedUnrollTargetBFloat16 {
+  template <typename Func>
+  TARGET_ARM_BF16_ATTRIBUTE C10_ALWAYS_INLINE void operator()(const Func& f) const {
+    ForcedUnrollTargetBFloat16<n - 1>{}(f);
+    f(n - 1);
+  }
+};
+
+template <>
+struct ForcedUnrollTargetBFloat16<1> {
+  template <typename Func>
+  TARGET_ARM_BF16_ATTRIBUTE C10_ALWAYS_INLINE void operator()(const Func& f) const {
+    f(0);
+  }
+};
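
The recursion bottoms out at the <1> specialization, so the functor is invoked with indices 0 through n-1, and every call site carries the bf16 target attribute, which plain c10::ForcedUnroll cannot propagate on GCC. A usage sketch (`do_step` is a hypothetical per-index body):

```c++
// ForcedUnrollTargetBFloat16<4>{}(f) expands at compile time to
//   f(0); f(1); f(2); f(3);
ForcedUnrollTargetBFloat16<4>{}([](int k) C10_ALWAYS_INLINE_ATTRIBUTE TARGET_ARM_BF16_ATTRIBUTE {
  do_step(k);  // hypothetical per-index body
});
```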
+
+C10_ALWAYS_INLINE TARGET_ARM_BF16_ATTRIBUTE auto
+dot_with_fp32_arith_main_loop_bfdot(
+    const BFloat16* vec1,
+    const BFloat16* vec2,
+    int64_t len) {
  float32x4_t sum[kF32RegistersPerIteration] = {vdupq_n_f32(0)};
  const auto len_aligned = len & ~(kF32ElementsPerIteration - 1);
  for (int j = 0; j < len_aligned; j += kF32ElementsPerIteration) {
    const auto* vec1_ = vec1 + j;
    const auto* vec2_ = vec2 + j;
-    c10::ForcedUnroll<kF32RegisterPairsPerIteration>{}([vec1_, vec2_, &sum](auto k) {
-      dot_with_fp32_arith_main_inner_loop(vec1_, vec2_, sum, k);
+    ForcedUnrollTargetBFloat16<kF32RegisterPairsPerIteration>{}([vec1_, vec2_, &sum](auto k)
+        C10_ALWAYS_INLINE_ATTRIBUTE TARGET_ARM_BF16_ATTRIBUTE {
+      dot_with_fp32_arith_main_inner_loop_bfdot(vec1_, vec2_, sum, k);
    });
  }
-  auto reducedSum = reduce(sum);
-
-  // First-tier tail fixup: make sure we handle workloads that can
-  // benefit from vectorization, but don't fit into our fully unrolled
-  // loop above.
-  float32x4_t tailSum = vdupq_n_f32(0);
-  const auto len_aligned_4 = len & ~3;
-  for (int j = len_aligned; j < len_aligned_4; j += 4) {
-    dot_with_fp32_arith_vectorized_tail_inner_loop(vec1, vec2, &tailSum, j);
-  }
-  auto reducedTail = vpaddq_f32(tailSum, tailSum);
-  reducedSum += vgetq_lane_f32(vpaddq_f32(reducedTail, reducedTail), 0);
+  return reduce(sum);
+}
+#endif // COMPILER_SUPPORTS_BF16_TARGET

-  // Second-tier tail fixup: handle all workloads.
-  for (int j = len_aligned_4; j < len; ++j) {
-    reducedSum += vec1[j] * vec2[j];
+template <typename T>
+C10_ALWAYS_INLINE auto
+dot_with_fp32_arith_main_loop_no_bfdot(
+    const T* vec1,
+    const T* vec2,
+    int64_t len) {
+  float32x4_t sum[kF32RegistersPerIteration] = {vdupq_n_f32(0)};
+  const auto len_aligned = len & ~(kF32ElementsPerIteration - 1);
+  for (int j = 0; j < len_aligned; j += kF32ElementsPerIteration) {
+    const auto* vec1_ = vec1 + j;
+    const auto* vec2_ = vec2 + j;
+    c10::ForcedUnroll<kF32RegisterPairsPerIteration>{}([vec1_, vec2_, &sum](auto k) C10_ALWAYS_INLINE_ATTRIBUTE {
+      dot_with_fp32_arith_main_inner_loop_no_bfdot(vec1_, vec2_, sum, k);
+    });
  }
-  return reducedSum;
+  return reduce(sum);
+}
+
+// NOTE [GCC code duplication]: The first attempt at landing BFDOT support with
+// TARGET_ARM_BF16_ATTRIBUTE failed because unlike clang, GCC will not
+// allow inlining a non-bf16-specific function into a bf16-specific
+// function. We can work around this by duplicating the code into the
+// bfdot and non-bfdot callsites. The code is in this macro to avoid
+// actual copy/paste.
+#define DOT_WITH_FP32_ARITH_TAIL_AFTER_MAIN_LOOP_BODY(bfdot_suffix)                         \
+  /* First-tier tail fixup: make sure we handle workloads that can */                       \
+  /* benefit from vectorization, but don't fit into our fully unrolled */                   \
+  /* loop above. */                                                                         \
+  float32x4_t tail_sum = vdupq_n_f32(0);                                                    \
+  const auto len_aligned = len & ~(kF32ElementsPerIteration - 1);                           \
+  const auto len_aligned_4 = len & ~3;                                                      \
+  for (int j = len_aligned; j < len_aligned_4; j += 4) {                                    \
+    dot_with_fp32_arith_vectorized_tail_inner_loop##bfdot_suffix(vec1, vec2, &tail_sum, j); \
+  }                                                                                         \
+  auto reduced_tail = vpaddq_f32(tail_sum, tail_sum);                                       \
+  reduced_sum += vgetq_lane_f32(vpaddq_f32(reduced_tail, reduced_tail), 0);                 \
+                                                                                            \
+  /* Second-tier tail fixup: handle all workloads. */                                       \
+  for (int j = len_aligned_4; j < len; ++j) {                                               \
+    reduced_sum += vec1[j] * vec2[j];                                                       \
+  }                                                                                         \
+  return reduced_sum
+
+#if COMPILER_SUPPORTS_BF16_TARGET
+TARGET_ARM_BF16_ATTRIBUTE float
+dot_with_fp32_arith_bfdot(const BFloat16* vec1, const BFloat16* vec2, int64_t len) {
+  auto reduced_sum = dot_with_fp32_arith_main_loop_bfdot(vec1, vec2, len);
+  DOT_WITH_FP32_ARITH_TAIL_AFTER_MAIN_LOOP_BODY(_bfdot);
+}
+#endif // COMPILER_SUPPORTS_BF16_TARGET
+
+template <typename T>
+C10_ALWAYS_INLINE float
+dot_with_fp32_arith_no_bfdot(const T* vec1, const T* vec2, int64_t len) {
+  auto reduced_sum = dot_with_fp32_arith_main_loop_no_bfdot(vec1, vec2, len);
+  DOT_WITH_FP32_ARITH_TAIL_AFTER_MAIN_LOOP_BODY(_no_bfdot);
}
+#undef DOT_WITH_FP32_ARITH_TAIL_AFTER_MAIN_LOOP_BODY
+} // namespace

float fp16_dot_with_fp32_arith(const float16_t* vec1, const float16_t* vec2, int64_t len) {
-  return dot_with_fp32_arith(vec1, vec2, len);
+  return dot_with_fp32_arith_no_bfdot(vec1, vec2, len);
}

float bf16_dot_with_fp32_arith(const at::BFloat16* vec1, const at::BFloat16* vec2, int64_t len) {
-  return dot_with_fp32_arith(vec1, vec2, len);
+#if COMPILER_SUPPORTS_BF16_TARGET
+  if (cpuinfo_has_arm_bf16()) {
+    return dot_with_fp32_arith_bfdot(vec1, vec2, len);
+  } else
+#endif
+  {
+    return dot_with_fp32_arith_no_bfdot(vec1, vec2, len);
+  }
}
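
Putting it together, the public entry point degrades gracefully: COMPILER_SUPPORTS_BF16_TARGET gates the BFDOT kernel at compile time and cpuinfo gates it at runtime. A hypothetical call site (buffers and sizes invented for illustration; cpuinfo feature queries assume the library has already been initialized):

```c++
#include <vector>

// Hypothetical caller of the dispatcher above.
float example() {
  std::vector<at::BFloat16> x(1024, at::BFloat16(0.5f));
  std::vector<at::BFloat16> w(1024, at::BFloat16(2.0f));
  // Picks the BFDOT kernel iff the compiler and the CPU both support it;
  // each term is 0.5 * 2.0 == 1.0, so the result is ~1024.0f.
  return bf16_dot_with_fp32_arith(x.data(), w.data(), x.size());
}
```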

// On my Apple M1 Macbook (which is ARM v8.5 and thus has the