diff --git a/include/fbgemm/FbgemmConvert.h b/include/fbgemm/FbgemmConvert.h
index 298d539a9b..a20e56b0ad 100644
--- a/include/fbgemm/FbgemmConvert.h
+++ b/include/fbgemm/FbgemmConvert.h
@@ -1,5 +1,6 @@
 /*
  * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * Copyright 2024 Arm Limited and/or its affiliates
  * All rights reserved.
  *
  * This source code is licensed under the BSD-style license found in the
@@ -135,6 +136,26 @@ FBGEMM_API void FloatToFloat16_avx512(
     size_t size,
     bool do_clip = false);
 
+/**
+ * @brief SVE implementation to convert fp32 numbers to fp16 numbers.
+ *
+ */
+FBGEMM_API void FloatToFloat16_sve(
+    const float* src,
+    float16* dst,
+    size_t size,
+    bool do_clip = false);
+
+/**
+ * @brief NEON implementation to convert fp32 numbers to fp16 numbers.
+ *
+ */
+FBGEMM_API void FloatToFloat16_neon(
+    const float* src,
+    float16* dst,
+    size_t size,
+    bool do_clip = false);
+
 /**
  * @brief AVX2 implementation to convert fp16 numbers to fp32 numbers.
  *
@@ -149,6 +170,20 @@ Float16ToFloat_avx2(const float16* src, float* dst, size_t size);
 FBGEMM_API void
 Float16ToFloat_avx512(const float16* src, float* dst, size_t size);
 
+/**
+ * @brief SVE implementation to convert fp16 numbers to fp32 numbers.
+ *
+ */
+FBGEMM_API void
+Float16ToFloat_sve(const float16* src, float* dst, size_t size);
+
+/**
+ * @brief NEON implementation to convert fp16 numbers to fp32 numbers.
+ *
+ */
+FBGEMM_API void
+Float16ToFloat_neon(const float16* src, float* dst, size_t size);
+
 /**
  * @brief Transform all entries in a matrix from fp32 to float16 and back to
  * fp32.
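A note on the `do_clip` parameter these declarations carry over from the AVX variants: when set, values outside the representable fp16 range are saturated to ±65504 (the largest finite fp16 value) instead of narrowing to ±infinity. A scalar sketch of that semantics, for illustration only (not FBGEMM's reference implementation):

```cpp
#include <algorithm>

// Illustration of do_clip semantics: saturate to the finite fp16 range
// before narrowing, so 70000.f becomes 65504.f rather than +inf.
inline float clip_to_fp16_range(float x) {
  constexpr float FP16_MAX = 65504.f; // largest finite fp16 value
  return std::max(std::min(x, FP16_MAX), -FP16_MAX);
}
```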
diff --git a/src/FbgemmFloat16Convert.cc b/src/FbgemmFloat16Convert.cc
index d2d3756038..483f5377d8 100644
--- a/src/FbgemmFloat16Convert.cc
+++ b/src/FbgemmFloat16Convert.cc
@@ -1,5 +1,6 @@
 /*
  * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * Copyright 2024 Arm Limited and/or its affiliates
  * All rights reserved.
  *
  * This source code is licensed under the BSD-style license found in the
@@ -39,10 +40,17 @@ void FloatToFloat16_simd(
     bool do_clip) {
   // Run time CPU detection
   if (cpuinfo_initialize()) {
+#ifdef __aarch64__
+    if (fbgemmHasArmSveSupport()) {
+      FloatToFloat16_sve(src, dst, size, do_clip);
+    } else if (fbgemmHasArmNeonSupport()) {
+      FloatToFloat16_neon(src, dst, size, do_clip);
+#else
     if (fbgemmHasAvx512Support()) {
       FloatToFloat16_avx512(src, dst, size, do_clip);
     } else if (fbgemmHasAvx2Support()) {
       FloatToFloat16_avx2(src, dst, size, do_clip);
+#endif
     } else {
       FloatToFloat16_ref(src, dst, size, do_clip);
       return;
@@ -55,10 +63,17 @@
 void Float16ToFloat_simd(const float16* src, float* dst, size_t size) {
   // Run time CPU detection
   if (cpuinfo_initialize()) {
+#ifdef __aarch64__
+    if (fbgemmHasArmSveSupport()) {
+      Float16ToFloat_sve(src, dst, size);
+    } else if (fbgemmHasArmNeonSupport()) {
+      Float16ToFloat_neon(src, dst, size);
+#else
     if (fbgemmHasAvx512Support()) {
       Float16ToFloat_avx512(src, dst, size);
     } else if (fbgemmHasAvx2Support()) {
       Float16ToFloat_avx2(src, dst, size);
+#endif
     } else {
       Float16ToFloat_ref(src, dst, size);
       return;
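The dispatch change keeps the public entry points unchanged: callers still use the `_simd` functions, the AVX branches compile out on AArch64 builds, and the SVE/NEON kernels are selected at run time via cpuinfo. A minimal round-trip usage sketch (assuming the `float16` type and `_simd` declarations exposed by `FbgemmConvert.h`):

```cpp
#include <cstdio>
#include <vector>

#include "fbgemm/FbgemmConvert.h"

int main() {
  std::vector<float> src = {0.5f, -1.25f, 70000.f, 3.14159f};
  std::vector<fbgemm::float16> half(src.size());
  std::vector<float> back(src.size());

  // do_clip=true saturates 70000.f to 65504.f instead of +inf.
  fbgemm::FloatToFloat16_simd(src.data(), half.data(), src.size(), true);
  fbgemm::Float16ToFloat_simd(half.data(), back.data(), half.size());

  std::printf("round trip of 0.5f -> %f\n", back[0]);
  return 0;
}
```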
diff --git a/src/FbgemmFloat16ConvertNeon.cc b/src/FbgemmFloat16ConvertNeon.cc
new file mode 100644
index 0000000000..ee175b6c2a
--- /dev/null
+++ b/src/FbgemmFloat16ConvertNeon.cc
@@ -0,0 +1,122 @@
+/*
+ * SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates
+ * SPDX-License-Identifier: BSD-3-Clause
+ */
+
+#include <arm_neon.h>
+#define FBGEMM_EXPORTS
+#include "fbgemm/FbgemmConvert.h"
+
+namespace fbgemm {
+
+void FloatToFloat16_neon(
+    const float* src,
+    float16* dst,
+    size_t size,
+    bool do_clip) {
+  if (do_clip) {
+    constexpr float FP16_MAX = 65504.f;
+    auto vpos = vdupq_n_f32(FP16_MAX);
+    auto vneg = vdupq_n_f32(-FP16_MAX);
+    size_t i = 0;
+    for (; i + 16 < size; i += 16) {
+      auto f32_vec1 = vld1q_f32(src + i);
+      auto f32_vec2 = vld1q_f32(src + i + 4);
+      auto f32_vec3 = vld1q_f32(src + i + 8);
+      auto f32_vec4 = vld1q_f32(src + i + 12);
+      f32_vec1 = vmaxq_f32(vminq_f32(f32_vec1, vpos), vneg);
+      f32_vec2 = vmaxq_f32(vminq_f32(f32_vec2, vpos), vneg);
+      f32_vec3 = vmaxq_f32(vminq_f32(f32_vec3, vpos), vneg);
+      f32_vec4 = vmaxq_f32(vminq_f32(f32_vec4, vpos), vneg);
+      auto f16_vec1 = vcvt_f16_f32(f32_vec1);
+      auto f16_vec2 = vcvt_f16_f32(f32_vec2);
+      auto f16_vec3 = vcvt_f16_f32(f32_vec3);
+      auto f16_vec4 = vcvt_f16_f32(f32_vec4);
+      vst1_f16((__fp16*)dst + i, f16_vec1);
+      vst1_f16((__fp16*)dst + i + 4, f16_vec2);
+      vst1_f16((__fp16*)dst + i + 8, f16_vec3);
+      vst1_f16((__fp16*)dst + i + 12, f16_vec4);
+    }
+    for (; i + 8 < size; i += 8) {
+      auto f32_vec1 = vld1q_f32(src + i);
+      auto f32_vec2 = vld1q_f32(src + i + 4);
+      f32_vec1 = vmaxq_f32(vminq_f32(f32_vec1, vpos), vneg);
+      f32_vec2 = vmaxq_f32(vminq_f32(f32_vec2, vpos), vneg);
+      auto f16_vec1 = vcvt_f16_f32(f32_vec1);
+      auto f16_vec2 = vcvt_f16_f32(f32_vec2);
+      vst1_f16((__fp16*)dst + i, f16_vec1);
+      vst1_f16((__fp16*)dst + i + 4, f16_vec2);
+    }
+    for (; i + 4 < size; i += 4) {
+      auto f32_vec = vld1q_f32(src + i);
+      f32_vec = vmaxq_f32(vminq_f32(f32_vec, vpos), vneg);
+      auto f16_vec = vcvt_f16_f32(f32_vec);
+      vst1_f16((__fp16*)dst + i, f16_vec);
+    }
+    FloatToFloat16_ref(src + i, dst + i, size - i, do_clip);
+  } else {
+    size_t i = 0;
+    for (; i + 16 < size; i += 16) {
+      auto f32_vec1 = vld1q_f32(src + i);
+      auto f32_vec2 = vld1q_f32(src + i + 4);
+      auto f32_vec3 = vld1q_f32(src + i + 8);
+      auto f32_vec4 = vld1q_f32(src + i + 12);
+      auto f16_vec1 = vcvt_f16_f32(f32_vec1);
+      auto f16_vec2 = vcvt_f16_f32(f32_vec2);
+      auto f16_vec3 = vcvt_f16_f32(f32_vec3);
+      auto f16_vec4 = vcvt_f16_f32(f32_vec4);
+      vst1_f16((__fp16*)dst + i, f16_vec1);
+      vst1_f16((__fp16*)dst + i + 4, f16_vec2);
+      vst1_f16((__fp16*)dst + i + 8, f16_vec3);
+      vst1_f16((__fp16*)dst + i + 12, f16_vec4);
+    }
+    for (; i + 8 < size; i += 8) {
+      auto f32_vec1 = vld1q_f32(src + i);
+      auto f32_vec2 = vld1q_f32(src + i + 4);
+      auto f16_vec1 = vcvt_f16_f32(f32_vec1);
+      auto f16_vec2 = vcvt_f16_f32(f32_vec2);
+      vst1_f16((__fp16*)dst + i, f16_vec1);
+      vst1_f16((__fp16*)dst + i + 4, f16_vec2);
+    }
+    for (; i + 4 < size; i += 4) {
+      auto f32_vec = vld1q_f32(src + i);
+      auto f16_vec = vcvt_f16_f32(f32_vec);
+      vst1_f16((__fp16*)dst + i, f16_vec);
+    }
+    FloatToFloat16_ref(src + i, dst + i, size - i);
+  }
+}
+
+void Float16ToFloat_neon(const float16* src, float* dst, size_t size) {
+  size_t i = 0;
+  for (; i + 16 < size; i += 16) {
+    auto f16_vec1 = vld1_f16((__fp16*)src + i);
+    auto f16_vec2 = vld1_f16((__fp16*)src + i + 4);
+    auto f16_vec3 = vld1_f16((__fp16*)src + i + 8);
+    auto f16_vec4 = vld1_f16((__fp16*)src + i + 12);
+    auto f32_vec1 = vcvt_f32_f16(f16_vec1);
+    auto f32_vec2 = vcvt_f32_f16(f16_vec2);
+    auto f32_vec3 = vcvt_f32_f16(f16_vec3);
+    auto f32_vec4 = vcvt_f32_f16(f16_vec4);
+    vst1q_f32(dst + i, f32_vec1);
+    vst1q_f32(dst + i + 4, f32_vec2);
+    vst1q_f32(dst + i + 8, f32_vec3);
+    vst1q_f32(dst + i + 12, f32_vec4);
+  }
+  for (; i + 8 < size; i += 8) {
+    auto f16_vec1 = vld1_f16((__fp16*)src + i);
+    auto f16_vec2 = vld1_f16((__fp16*)src + i + 4);
+    auto f32_vec1 = vcvt_f32_f16(f16_vec1);
+    auto f32_vec2 = vcvt_f32_f16(f16_vec2);
+    vst1q_f32(dst + i, f32_vec1);
+    vst1q_f32(dst + i + 4, f32_vec2);
+  }
+  for (; i + 4 < size; i += 4) {
+    auto f16_vec = vld1_f16((__fp16*)src + i);
+    auto f32_vec = vcvt_f32_f16(f16_vec);
+    vst1q_f32(dst + i, f32_vec);
+  }
+  Float16ToFloat_ref(src + i, dst + i, size - i);
+}
+
+} // namespace fbgemm
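One behavioral note on the NEON loop structure: the guards are strict (`i + 16 < size`, etc.), so a block that would exactly fill a loop falls through to the narrower loops, and the scalar `_ref` tail always receives between 1 and 4 elements for any non-zero `size`. The result is correct either way; only the last few elements take the scalar path. A small sketch that mirrors the guards and shows how work is distributed (hypothetical helper, illustration only):

```cpp
#include <cstddef>
#include <cstdio>

// Hypothetical helper mirroring the kernel's strict '<' loop guards:
// reports how many elements each vector width would process.
static void tail_breakdown(size_t size) {
  size_t i = 0, by16 = 0, by8 = 0, by4 = 0;
  for (; i + 16 < size; i += 16) by16 += 16;
  for (; i + 8 < size; i += 8) by8 += 8;
  for (; i + 4 < size; i += 4) by4 += 4;
  std::printf("size=%zu simd16=%zu simd8=%zu simd4=%zu scalar=%zu\n",
              size, by16, by8, by4, size - i);
}

int main() {
  tail_breakdown(16); // 8 via the 8-wide loop, 4 via the 4-wide, 4 scalar
  tail_breakdown(35); // 32 via the 16-wide loop, 3 scalar
  return 0;
}
```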
diff --git a/src/FbgemmFloat16ConvertSve.cc b/src/FbgemmFloat16ConvertSve.cc
new file mode 100644
index 0000000000..7ecdc23c6e
--- /dev/null
+++ b/src/FbgemmFloat16ConvertSve.cc
@@ -0,0 +1,99 @@
+/*
+ * SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates
+ * SPDX-License-Identifier: BSD-3-Clause
+ */
+
+#include <arm_sve.h>
+#define FBGEMM_EXPORTS
+#include "fbgemm/FbgemmConvert.h"
+
+namespace fbgemm {
+
+void FloatToFloat16_sve(
+    const float* src,
+    float16* dst,
+    size_t size,
+    bool do_clip) {
+  if (do_clip) {
+    constexpr float FP16_MAX = 65504.f;
+    size_t i = 0;
+    int lanes = svcntw();
+    auto p_32 = svptrue_b32();
+    auto p_16 = svptrue_b16();
+    auto pfalse = svpfalse();
+    auto p_16_half = svuzp1_b16(p_16, pfalse);
+    while (i + 2 * lanes < size) {
+      auto f32_vec1 = svld1_f32(p_32, src + i);
+      auto f32_vec2 = svld1_f32(p_32, src + i + lanes);
+      f32_vec1 = svmax_n_f32_x(p_32, svmin_n_f32_x(p_32, f32_vec1, FP16_MAX), -FP16_MAX);
+      f32_vec2 = svmax_n_f32_x(p_32, svmin_n_f32_x(p_32, f32_vec2, FP16_MAX), -FP16_MAX);
+      auto f16_vec1 = svcvt_f16_f32_x(p_32, f32_vec1);
+      auto f16_vec2 = svcvt_f16_f32_x(p_32, f32_vec2);
+      auto f16_vec = svuzp1_f16(f16_vec1, f16_vec2);
+      svst1_f16(p_16, (__fp16*)dst + i, f16_vec);
+      i += 2 * lanes;
+    }
+    while (i + lanes < size) {
+      auto f32_vec = svld1_f32(p_32, src + i);
+      f32_vec = svmax_n_f32_x(p_32, svmin_n_f32_x(p_32, f32_vec, FP16_MAX), -FP16_MAX);
+      auto f16_vec = svcvt_f16_f32_x(p_16, f32_vec);
+      f16_vec = svuzp1_f16(f16_vec, f16_vec);
+      svst1_f16(p_16_half, (__fp16*)dst + i, f16_vec);
+      i += lanes;
+    }
+    FloatToFloat16_ref(src + i, dst + i, size - i, do_clip);
+  } else {
+    size_t i = 0;
+    int lanes = svcntw();
+    auto p_32 = svptrue_b32();
+    auto p_16 = svptrue_b16();
+    auto pfalse = svpfalse();
+    auto p_16_half = svuzp1_b16(p_16, pfalse);
+    while (i + 2 * lanes < size) {
+      auto f32_vec1 = svld1_f32(p_32, src + i);
+      auto f32_vec2 = svld1_f32(p_32, src + i + lanes);
+      auto f16_vec1 = svcvt_f16_f32_x(p_32, f32_vec1);
+      auto f16_vec2 = svcvt_f16_f32_x(p_32, f32_vec2);
+      auto f16_vec = svuzp1_f16(f16_vec1, f16_vec2);
+      svst1_f16(p_16, (__fp16*)dst + i, f16_vec);
+      i += 2 * lanes;
+    }
+    while (i + lanes < size) {
+      auto f32_vec = svld1_f32(p_32, src + i);
+      auto f16_vec = svcvt_f16_f32_x(p_32, f32_vec);
+      f16_vec = svuzp1_f16(f16_vec, f16_vec);
+      svst1_f16(p_16_half, (__fp16*)dst + i, f16_vec);
+      i += lanes;
+    }
+    FloatToFloat16_ref(src + i, dst + i, size - i);
+  }
+}
+
+void Float16ToFloat_sve(const float16* src, float* dst, size_t size) {
+  size_t i = 0;
+  int lanes = svcntw();
+  auto p_32 = svptrue_b32();
+  auto p_16 = svptrue_b16();
+  auto pfalse = svpfalse();
+  auto p_16_half = svuzp1_b16(p_16, pfalse);
+  while (i + 2 * lanes < size) {
+    auto f16_vec = svld1_f16(p_16, (__fp16*)src + i);
+    auto f16_vec1 = svzip1(f16_vec, f16_vec);
+    auto f16_vec2 = svzip2(f16_vec, f16_vec);
+    auto f32_vec1 = svcvt_f32_f16_x(p_16, f16_vec1);
+    auto f32_vec2 = svcvt_f32_f16_x(p_16, f16_vec2);
+    svst1_f32(p_32, dst + i, f32_vec1);
+    svst1_f32(p_32, dst + i + lanes, f32_vec2);
+    i += 2 * lanes;
+  }
+  while (i + lanes < size) {
+    auto f16_vec = svld1_f16(p_16_half, (__fp16*)src + i);
+    f16_vec = svzip1_f16(f16_vec, f16_vec);
+    auto f32_vec = svcvt_f32_f16_x(p_32, f16_vec);
+    svst1_f32(p_32, dst + i, f32_vec);
+    i += lanes;
+  }
+  Float16ToFloat_ref(src + i, dst + i, size - i);
+}
+
+} // namespace fbgemm
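Unlike the fixed-width NEON kernel, the SVE path is vector-length agnostic: it processes `svcntw()` lanes per 32-bit vector. SVE's fp16/fp32 conversions read and write the low halfword of each 32-bit container, which is why the kernel packs results with `svuzp1` after narrowing and duplicates halfwords with `svzip1`/`svzip2` before widening. A round-trip smoke test exercising the two-vector loop, the one-vector loop, and the scalar tail might look like this (a sketch against the dispatching `_simd` entry points, not part of FBGEMM's test suite):

```cpp
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

#include "fbgemm/FbgemmConvert.h"

int main() {
  // Sizes chosen to hit the 2x-vector loop, the 1x-vector loop, and
  // the scalar tail for typical SVE vector lengths.
  for (size_t n : {1, 4, 16, 33, 128}) {
    std::vector<float> src(n), back(n);
    std::vector<fbgemm::float16> half(n);
    for (size_t i = 0; i < n; ++i)
      src[i] = 0.37f * static_cast<float>(i) - 5.f;

    fbgemm::FloatToFloat16_simd(src.data(), half.data(), n);
    fbgemm::Float16ToFloat_simd(half.data(), back.data(), n);

    float max_err = 0.f;
    for (size_t i = 0; i < n; ++i)
      max_err = std::max(max_err, std::fabs(back[i] - src[i]));
    std::printf("n=%zu max_abs_err=%g\n", n, max_err);
  }
  return 0;
}
```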