
Commit 9c12198

swolchok authored and pytorchmergebot committed
[PyTorch] Port ExecuTorch bfdot improvement back to ATen BlasKernel, Try #2 (pytorch#137377)
ExecuTorch's fork of BlasKernel.cpp grew bfdot support, complete with demonstration that it helps. Port it back to PyTorch. First attempt was pytorch#136331.

Differential Revision: [D63923166](https://our.internmc.facebook.com/intern/diff/D63923166/)

Pull Request resolved: pytorch#137377
Approved by: https://github.com/malfet
1 parent 080f02a · commit 9c12198

File tree

1 file changed: +187 −55 lines changed


aten/src/ATen/native/BlasKernel.cpp

Lines changed: 187 additions & 55 deletions
@@ -15,6 +15,7 @@
 
 #if defined(__aarch64__) && !defined(C10_MOBILE)
 #include <arm_neon.h>
+#include <cpuinfo.h>
 #endif
 
 C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wunused-function")
@@ -301,7 +302,7 @@ static constexpr auto kF16RegistersPerIterationShift = kF16ElementsPerIterationS
 static constexpr auto kF16RegistersPerIteration = 1 << kF16RegistersPerIterationShift;
 static_assert(kF16RegistersPerIteration == kF16ElementsPerIteration / kF16ElementsPerRegister);
 
-static inline double reduce(float16x8_t x[kF16RegistersPerIteration]) {
+static inline float reduce(float16x8_t x[kF16RegistersPerIteration]) {
   int offset = kF16RegistersPerIteration;
   c10::ForcedUnroll<kF16RegistersPerIterationShift>{}([&offset, &x](auto idx) {
     offset /= 2;
@@ -311,7 +312,7 @@ static inline double reduce(float16x8_t x[kF16RegistersPerIteration]) {
   });
   const float32x4_t t0 = vcvt_f32_f16(vget_low_f16(x[0]));
   const float32x4_t t1 = vcvt_f32_f16(vget_high_f16(x[0]));
-  return (double)vaddvq_f32(vaddq_f32(t0, t1));
+  return vaddvq_f32(vaddq_f32(t0, t1));
 }
 
 static inline float16x8_t f16_fma(float16x8_t a, float16x8_t b, float16x8_t c) {
@@ -333,12 +334,12 @@ static float fp16_dot_with_fp16_arith(const float16_t* x, const float16_t* a, in
       sum[k] = f16_fma(sum[k], temp_x, temp_a);
     }
   }
-  auto reducedSum = reduce(sum);
+  auto reduced_sum = reduce(sum);
 
   for (int j = len_aligned; j < len; ++j) {
-    reducedSum += x[j] * a[j];
+    reduced_sum += x[j] * a[j];
   }
-  return reducedSum;
+  return reduced_sum;
 }
 
 // Rather than unrolling to process multiple rows (transposed columns)
@@ -352,7 +353,7 @@ static void fp16_gemv_trans_fp16_arith_by_dot_products(const int m, const int n,
   });
 }
 
-#endif
+#endif // __ARM_FEATURE_FP16_SCALAR_ARITHMETIC
 
 static inline float reduce(float32x4_t x) {
   auto sum = vpaddq_f32(x, x);
@@ -412,7 +413,7 @@ static constexpr auto kF32RegistersPerIterationShift = 3;
 static_assert(kF32RegistersPerIteration == kF32ElementsPerIteration / kF32ElementsPerRegister);
 static_assert(kF32RegistersPerIteration == 1 << kF32RegistersPerIterationShift);
 
-static inline double reduce(float32x4_t x[kF32RegistersPerIteration]) {
+static inline float reduce(float32x4_t x[kF32RegistersPerIteration]) {
   int offset = kF32RegistersPerIteration;
   c10::ForcedUnroll<kF32RegistersPerIterationShift>{}([&offset, &x](auto idx) {
     offset /= 2;
@@ -423,7 +424,7 @@ static inline double reduce(float32x4_t x[kF32RegistersPerIteration]) {
   return vaddvq_f32(x[0]);
 }
 
-static C10_ALWAYS_INLINE void dot_with_fp32_arith_main_inner_loop(
+static C10_ALWAYS_INLINE void dot_with_fp32_arith_main_inner_loop_no_bfdot(
     const float16_t* vec1,
     const float16_t* vec2,
     float32x4_t sum[kF32RegistersPerIteration],
@@ -436,86 +437,217 @@ static C10_ALWAYS_INLINE void dot_with_fp32_arith_main_inner_loop(
   sum[2 * registerPairIndex + 1] = f32_fma_high_f16(sum[2 * registerPairIndex + 1], temp_vec1, temp_vec2);
 }
 
-static C10_ALWAYS_INLINE void dot_with_fp32_arith_vectorized_tail_inner_loop(
-    const float16_t* vec1,
-    const float16_t* vec2,
-    float32x4_t* tailSum,
-    int idx) {
+static C10_ALWAYS_INLINE void dot_with_fp32_arith_vectorized_tail_inner_loop_no_bfdot(
+    const float16_t* vec1,
+    const float16_t* vec2,
+    float32x4_t* tail_sum,
+    int idx) {
   const auto temp_vec1 = vld1_f16(&vec1[idx]);
   const auto temp_vec2 = vld1_f16(&vec2[idx]);
-  *tailSum = f32_fma_f16(*tailSum, temp_vec1, temp_vec2);
+  *tail_sum = f32_fma_f16(*tail_sum, temp_vec1, temp_vec2);
 }
 
-static C10_ALWAYS_INLINE float32x4_t to_bfloat16(uint16x4_t u16) {
+static float32x4_t to_bfloat16(uint16x4_t u16) {
   int32x4_t shift = vdupq_n_s32(16);
   return vreinterpretq_f32_u32(vshlq_u32(vmovl_u16(u16), shift));
 }
 
-static C10_ALWAYS_INLINE float32x4_t f32_fma_bf16(float32x4_t a, uint16x4_t b, uint16x4_t c) {
+static float32x4_t f32_fma_bf16(float32x4_t a, uint16x4_t b, uint16x4_t c) {
  return f32_fma(a, to_bfloat16(b), to_bfloat16(c));
 }
 
-static C10_ALWAYS_INLINE void dot_with_fp32_arith_main_inner_loop(
-    const at::BFloat16* vec1,
-    const at::BFloat16* vec2,
-    float32x4_t sum[kF32RegistersPerIteration],
-    int registerPairIndex) {
-  // TODO: detect intrinsic availability, use them if they're available. __ARM_FEATURE_BF16
-  // Load a pair of f32 registers at a time.
-  const uint16x8_t temp_vec1 = vld1q_u16(reinterpret_cast<const uint16_t*>(&vec1[registerPairIndex * 2 * kF32ElementsPerRegister]));
-  const uint16x8_t temp_vec2 = vld1q_u16(reinterpret_cast<const uint16_t*>(&vec2[registerPairIndex * 2 * kF32ElementsPerRegister]));
+#if defined(__clang__) && __clang_major__ > 15
+// https://godbolt.org/z/z8P4Yncra
+#define COMPILER_SUPPORTS_BF16_TARGET 1
+#elif !defined(__clang__) && defined(__GNUC__) && __GNUC__ >= 10
+// https://gcc.gnu.org/gcc-10/changes.html
+// https://godbolt.org/z/cdGG7vn8o
+#define COMPILER_SUPPORTS_BF16_TARGET 1
+#else
+#define COMPILER_SUPPORTS_BF16_TARGET 0
+#endif
+
+#if COMPILER_SUPPORTS_BF16_TARGET
+#define TARGET_ARM_BF16_ATTRIBUTE __attribute__((target("arch=armv8.2-a+bf16")))
+
+TARGET_ARM_BF16_ATTRIBUTE static C10_ALWAYS_INLINE float32x4_t
+f32_dot_bf16(float32x4_t a, bfloat16x8_t b, bfloat16x8_t c) {
+  return vbfdotq_f32(a, b, c);
+}
 
-  sum[2 * registerPairIndex] = f32_fma_bf16(sum[2 * registerPairIndex], vget_low_u16(temp_vec1), vget_low_u16(temp_vec2));
-  sum[2 * registerPairIndex + 1] = f32_fma_bf16(sum[2 * registerPairIndex + 1], vget_high_u16(temp_vec1), vget_high_u16(temp_vec2));
+TARGET_ARM_BF16_ATTRIBUTE static C10_ALWAYS_INLINE void
+dot_with_fp32_arith_main_inner_loop_bfdot(
+    const BFloat16* vec1,
+    const BFloat16* vec2,
+    float32x4_t sum[kF32RegistersPerIteration],
+    int registerPairIndex) {
+  const bfloat16x8_t temp_vec1 = vld1q_bf16(reinterpret_cast<const __bf16*>(
+      &vec1[registerPairIndex * 2 * kF32ElementsPerRegister]));
+  const bfloat16x8_t temp_vec2 = vld1q_bf16(reinterpret_cast<const __bf16*>(
+      &vec2[registerPairIndex * 2 * kF32ElementsPerRegister]));
+  sum[registerPairIndex] =
+      f32_dot_bf16(sum[registerPairIndex], temp_vec1, temp_vec2);
 }
 
-static C10_ALWAYS_INLINE void dot_with_fp32_arith_vectorized_tail_inner_loop(
-    const at::BFloat16* vec1,
-    const at::BFloat16* vec2,
-    float32x4_t* tailSum,
-    int idx) {
+// See NOTE [GCC code duplication] below for why we have _bfdot and
+// _no_bfdot versions of
+// dot_with_fp32_arith_vectorized_tail_inner_loop.
+TARGET_ARM_BF16_ATTRIBUTE C10_ALWAYS_INLINE
+static void dot_with_fp32_arith_vectorized_tail_inner_loop_bfdot(
+    const at::BFloat16* vec1,
+    const at::BFloat16* vec2,
+    float32x4_t* tail_sum,
+    int idx) {
   const auto temp_vec1 = vld1_u16(reinterpret_cast<const uint16_t*>(&vec1[idx]));
   const auto temp_vec2 = vld1_u16(reinterpret_cast<const uint16_t*>(&vec2[idx]));
-  *tailSum = f32_fma_bf16(*tailSum, temp_vec1, temp_vec2);
+  *tail_sum = f32_fma_bf16(*tail_sum, temp_vec1, temp_vec2);
 }
 
-template <typename T>
-float dot_with_fp32_arith(const T* vec1, const T* vec2, int64_t len) {
+#else
+#define TARGET_ARM_BF16_ATTRIBUTE
+#endif // COMPILER_SUPPORTS_BF16_TARGET
+
+static C10_ALWAYS_INLINE void dot_with_fp32_arith_main_inner_loop_no_bfdot(
+    const BFloat16* vec1,
+    const BFloat16* vec2,
+    float32x4_t sum[kF32RegistersPerIteration],
+    int registerPairIndex) {
+  const uint16x8_t temp_vec1 = vld1q_u16(reinterpret_cast<const uint16_t*>(
+      &vec1[registerPairIndex * 2 * kF32ElementsPerRegister]));
+  const uint16x8_t temp_vec2 = vld1q_u16(reinterpret_cast<const uint16_t*>(
+      &vec2[registerPairIndex * 2 * kF32ElementsPerRegister]));
+
+  sum[2 * registerPairIndex] = f32_fma_bf16(
+      sum[2 * registerPairIndex],
+      vget_low_u16(temp_vec1),
+      vget_low_u16(temp_vec2));
+  sum[2 * registerPairIndex + 1] = f32_fma_bf16(
+      sum[2 * registerPairIndex + 1],
+      vget_high_u16(temp_vec1),
+      vget_high_u16(temp_vec2));
+}
+
+static C10_ALWAYS_INLINE void dot_with_fp32_arith_vectorized_tail_inner_loop_no_bfdot(
+    const at::BFloat16* vec1,
+    const at::BFloat16* vec2,
+    float32x4_t* tail_sum,
+    int idx) {
+  const auto temp_vec1 = vld1_u16(reinterpret_cast<const uint16_t*>(&vec1[idx]));
+  const auto temp_vec2 = vld1_u16(reinterpret_cast<const uint16_t*>(&vec2[idx]));
+  *tail_sum = f32_fma_bf16(*tail_sum, temp_vec1, temp_vec2);
+}
+
+namespace {
+#if COMPILER_SUPPORTS_BF16_TARGET
+template <int n>
+struct ForcedUnrollTargetBFloat16 {
+  template <typename Func>
+  TARGET_ARM_BF16_ATTRIBUTE C10_ALWAYS_INLINE void operator()(const Func& f) const {
+    ForcedUnrollTargetBFloat16<n - 1>{}(f);
+    f(n - 1);
+  }
+};
+
+template <>
+struct ForcedUnrollTargetBFloat16<1> {
+  template <typename Func>
+  TARGET_ARM_BF16_ATTRIBUTE C10_ALWAYS_INLINE void operator()(const Func& f) const {
+    f(0);
+  }
+};
+
+C10_ALWAYS_INLINE TARGET_ARM_BF16_ATTRIBUTE auto
+dot_with_fp32_arith_main_loop_bfdot(
+    const BFloat16* vec1,
+    const BFloat16* vec2,
+    int64_t len) {
   float32x4_t sum[kF32RegistersPerIteration] = {vdupq_n_f32(0)};
   const auto len_aligned = len & ~(kF32ElementsPerIteration - 1);
   for (int j = 0; j < len_aligned ; j += kF32ElementsPerIteration) {
     const auto* vec1_ = vec1 + j;
     const auto* vec2_ = vec2 + j;
-    c10::ForcedUnroll<kF32RegisterPairsPerIteration>{}([vec1_, vec2_, &sum](auto k) {
-      dot_with_fp32_arith_main_inner_loop(vec1_, vec2_, sum, k);
+    ForcedUnrollTargetBFloat16<kF32RegisterPairsPerIteration>{}([vec1_, vec2_, &sum](auto k)
+        C10_ALWAYS_INLINE_ATTRIBUTE TARGET_ARM_BF16_ATTRIBUTE {
+      dot_with_fp32_arith_main_inner_loop_bfdot(vec1_, vec2_, sum, k);
     });
   }
-  auto reducedSum = reduce(sum);
-
-  // First-tier tail fixup: make sure we handle workloads that can
-  // benefit from vectorization, but don't fit into our fully unrolled
-  // loop above.
-  float32x4_t tailSum = vdupq_n_f32(0);
-  const auto len_aligned_4 = len & ~3;
-  for (int j = len_aligned; j < len_aligned_4; j += 4) {
-    dot_with_fp32_arith_vectorized_tail_inner_loop(vec1, vec2, &tailSum, j);
-  }
-  auto reducedTail = vpaddq_f32(tailSum, tailSum);
-  reducedSum += vgetq_lane_f32(vpaddq_f32(reducedTail, reducedTail), 0);
+  return reduce(sum);
+}
+#endif // COMPILER_SUPPORTS_BF16_TARGET
 
-  // Second-tier tail fixup: handle all workloads.
-  for (int j = len_aligned_4; j < len; ++j) {
-    reducedSum += vec1[j] * vec2[j];
+template <typename T>
+C10_ALWAYS_INLINE auto
+dot_with_fp32_arith_main_loop_no_bfdot(
+    const T* vec1,
+    const T* vec2,
+    int64_t len) {
+  float32x4_t sum[kF32RegistersPerIteration] = {vdupq_n_f32(0)};
+  const auto len_aligned = len & ~(kF32ElementsPerIteration - 1);
+  for (int j = 0; j < len_aligned ; j += kF32ElementsPerIteration) {
+    const auto* vec1_ = vec1 + j;
+    const auto* vec2_ = vec2 + j;
+    c10::ForcedUnroll<kF32RegisterPairsPerIteration>{}([vec1_, vec2_, &sum](auto k) C10_ALWAYS_INLINE_ATTRIBUTE {
+      dot_with_fp32_arith_main_inner_loop_no_bfdot(vec1_, vec2_, sum, k);
+    });
   }
-  return reducedSum;
+  return reduce(sum);
+}
+
+// NOTE [GCC code duplication]: The first attempt at landing BFDOT support with
+// TARGET_ARM_BF16_ATTRIBUTE failed because unlike clang, GCC will not
+// allow inlining a non-bf16-specific function into a bf16-specific
+// function. We can work around this by duplicating the code into the
+// bfdot and non-bfdot callsites. The code is in this macro to avoid
+// actual copy/paste.
+#define DOT_WITH_FP32_ARITH_TAIL_AFTER_MAIN_LOOP_BODY(bfdot_suffix) \
+  /* First-tier tail fixup: make sure we handle workloads that can */ \
+  /* benefit from vectorization, but don't fit into our fully unrolled */ \
+  /* loop above. */ \
+  float32x4_t tail_sum = vdupq_n_f32(0); \
+  const auto len_aligned = len & ~(kF32ElementsPerIteration - 1); \
+  const auto len_aligned_4 = len & ~3; \
+  for (int j = len_aligned; j < len_aligned_4; j += 4) { \
+    dot_with_fp32_arith_vectorized_tail_inner_loop##bfdot_suffix(vec1, vec2, &tail_sum, j); \
+  } \
+  auto reduced_tail = vpaddq_f32(tail_sum, tail_sum); \
+  reduced_sum += vgetq_lane_f32(vpaddq_f32(reduced_tail, reduced_tail), 0); \
+  \
+  /* Second-tier tail fixup: handle all workloads. */ \
+  for (int j = len_aligned_4; j < len; ++j) { \
+    reduced_sum += vec1[j] * vec2[j]; \
+  } \
+  return reduced_sum
+
+#if COMPILER_SUPPORTS_BF16_TARGET
+TARGET_ARM_BF16_ATTRIBUTE float
+dot_with_fp32_arith_bfdot(const BFloat16* vec1, const BFloat16* vec2, int64_t len) {
+  auto reduced_sum = dot_with_fp32_arith_main_loop_bfdot(vec1, vec2, len);
+  DOT_WITH_FP32_ARITH_TAIL_AFTER_MAIN_LOOP_BODY(_bfdot);
+}
+#endif // COMPILER_SUPPORTS_BF16_TARGET
+
+template <typename T>
+C10_ALWAYS_INLINE float
+dot_with_fp32_arith_no_bfdot(const T* vec1, const T* vec2, int64_t len) {
+  auto reduced_sum = dot_with_fp32_arith_main_loop_no_bfdot(vec1, vec2, len);
+  DOT_WITH_FP32_ARITH_TAIL_AFTER_MAIN_LOOP_BODY(_no_bfdot);
 }
+#undef DOT_WITH_FP32_ARITH_TAIL_AFTER_MAIN_LOOP_BODY
+} // namespace
 
 float fp16_dot_with_fp32_arith(const float16_t* vec1, const float16_t* vec2, int64_t len) {
-  return dot_with_fp32_arith(vec1, vec2, len);
+  return dot_with_fp32_arith_no_bfdot(vec1, vec2, len);
 }
 
 float bf16_dot_with_fp32_arith(const at::BFloat16* vec1, const at::BFloat16* vec2, int64_t len) {
-  return dot_with_fp32_arith(vec1, vec2, len);
+#if COMPILER_SUPPORTS_BF16_TARGET
+  if (cpuinfo_has_arm_bf16()) {
+    return dot_with_fp32_arith_bfdot(vec1, vec2, len);
+  } else
+#endif
+  {
+    return dot_with_fp32_arith_no_bfdot(vec1, vec2, len);
+  }
 }
 
 // On my Apple M1 Macbook (which is ARM v8.5 and thus has the
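
To make the new structure easier to follow, here is a minimal, self-contained sketch (not the ATen code itself) of the pattern this commit introduces: a BFDOT kernel compiled with a per-function target attribute, a plain fallback, and runtime dispatch through cpuinfo. The names dot_bfdot, dot_no_bfdot, dot_dispatch, bf16_to_f32, and BF16_TARGET are illustrative only; the sketch assumes an aarch64 toolchain that accepts target("arch=armv8.2-a+bf16") (clang > 15 or GCC >= 10) and linking against the cpuinfo library.

#include <arm_neon.h>
#include <cpuinfo.h>

#include <cstdint>
#include <cstdio>
#include <cstring>

// Widen one bf16 value (stored as raw uint16_t bits) to fp32 by placing its
// bits in the high half of a 32-bit word.
static float bf16_to_f32(uint16_t x) {
  uint32_t bits = static_cast<uint32_t>(x) << 16;
  float f;
  std::memcpy(&f, &bits, sizeof(f));
  return f;
}

// Portable fallback: scalar fp32 multiply-accumulate (the ATen _no_bfdot path
// uses NEON instead; a scalar loop keeps this sketch short).
static float dot_no_bfdot(const uint16_t* a, const uint16_t* b, int64_t len) {
  float sum = 0.0f;
  for (int64_t j = 0; j < len; ++j) {
    sum += bf16_to_f32(a[j]) * bf16_to_f32(b[j]);
  }
  return sum;
}

// Same idea as TARGET_ARM_BF16_ATTRIBUTE in the diff: only this function is
// compiled with bf16 enabled, so the translation unit still builds for plain
// armv8-a targets.
#define BF16_TARGET __attribute__((target("arch=armv8.2-a+bf16")))

BF16_TARGET static float dot_bfdot(const uint16_t* a, const uint16_t* b, int64_t len) {
  float32x4_t acc = vdupq_n_f32(0);
  int64_t j = 0;
  for (; j + 8 <= len; j += 8) {
    // One BFDOT consumes 8 bf16 pairs and adds the 4 pairwise sums into the
    // fp32 accumulator lanes.
    const bfloat16x8_t va = vld1q_bf16(reinterpret_cast<const __bf16*>(a + j));
    const bfloat16x8_t vb = vld1q_bf16(reinterpret_cast<const __bf16*>(b + j));
    acc = vbfdotq_f32(acc, va, vb);
  }
  float sum = vaddvq_f32(acc);
  for (; j < len; ++j) {  // scalar tail
    sum += bf16_to_f32(a[j]) * bf16_to_f32(b[j]);
  }
  return sum;
}

// Runtime dispatch in the same shape as the new bf16_dot_with_fp32_arith:
// take the BFDOT kernel only when the CPU actually reports the bf16 feature.
static float dot_dispatch(const uint16_t* a, const uint16_t* b, int64_t len) {
  if (cpuinfo_initialize() && cpuinfo_has_arm_bf16()) {
    return dot_bfdot(a, b, len);
  }
  return dot_no_bfdot(a, b, len);
}

int main() {
  uint16_t a[16], b[16];
  for (int i = 0; i < 16; ++i) {
    a[i] = 0x3f80;  // 1.0f in bf16
    b[i] = 0x4000;  // 2.0f in bf16
  }
  std::printf("dot = %f\n", dot_dispatch(a, b, 16));  // expect 32.0
  return 0;
}

Because each vbfdotq_f32 multiplies eight bf16 pairs and accumulates the four pairwise sums directly into fp32 lanes, the bfdot inner loop in the diff needs only one accumulator register per register pair, where the _no_bfdot path needs two f32_fma_bf16 calls and two accumulators.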
