-
Notifications
You must be signed in to change notification settings - Fork 520
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add NEON and SVE implementations for Float16 conversions
- Loading branch information
Showing
5 changed files
with
277 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,122 @@ | ||
/* | ||
* SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliate <open-source-office@arm.com> | ||
* SPDX-License-Identifier: BSD-3-Clause | ||
*/ | ||
|
||
#include <arm_neon.h> | ||
#define FBGEMM_EXPORTS | ||
#include "fbgemm/FbgemmConvert.h" | ||
|
||
namespace fbgemm { | ||
|
||
void FloatToFloat16_neon( | ||
const float* src, | ||
float16* dst, | ||
size_t size, | ||
bool do_clip) { | ||
if (do_clip) { | ||
constexpr float FP16_MAX = 65504.f; | ||
auto vpos = vdupq_n_f32(FP16_MAX); | ||
auto vneg = vdupq_n_f32(-FP16_MAX); | ||
size_t i = 0; | ||
for (; i + 16 < size; i += 16) { | ||
auto f32_vec1 = vld1q_f32(src + i); | ||
auto f32_vec2 = vld1q_f32(src + i + 4); | ||
auto f32_vec3 = vld1q_f32(src + i + 8); | ||
auto f32_vec4 = vld1q_f32(src + i + 12); | ||
f32_vec1 = vmaxq_f32(vminq_f32(f32_vec1, vpos), vneg); | ||
f32_vec2 = vmaxq_f32(vminq_f32(f32_vec2, vpos), vneg); | ||
f32_vec3 = vmaxq_f32(vminq_f32(f32_vec3, vpos), vneg); | ||
f32_vec4 = vmaxq_f32(vminq_f32(f32_vec4, vpos), vneg); | ||
auto f16_vec1 = vcvt_f16_f32(f32_vec1); | ||
auto f16_vec2 = vcvt_f16_f32(f32_vec2); | ||
auto f16_vec3 = vcvt_f16_f32(f32_vec3); | ||
auto f16_vec4 = vcvt_f16_f32(f32_vec4); | ||
vst1_f16((__fp16*)dst + i, f16_vec1); | ||
vst1_f16((__fp16*)dst + i + 4, f16_vec2); | ||
vst1_f16((__fp16*)dst + i + 8, f16_vec3); | ||
vst1_f16((__fp16*)dst + i + 12, f16_vec4); | ||
} | ||
for (; i + 8 < size; i += 8) { | ||
auto f32_vec1 = vld1q_f32(src + i); | ||
auto f32_vec2 = vld1q_f32(src + i + 4); | ||
f32_vec1 = vmaxq_f32(vminq_f32(f32_vec1, vpos), vneg); | ||
f32_vec2 = vmaxq_f32(vminq_f32(f32_vec2, vpos), vneg); | ||
auto f16_vec1 = vcvt_f16_f32(f32_vec1); | ||
auto f16_vec2 = vcvt_f16_f32(f32_vec2); | ||
vst1_f16((__fp16*)dst + i, f16_vec1); | ||
vst1_f16((__fp16*)dst + i + 4, f16_vec2); | ||
} | ||
for (; i + 4 < size; i += 4) { | ||
auto f32_vec = vld1q_f32(src + i); | ||
f32_vec = vmaxq_f32(vminq_f32(f32_vec, vpos), vneg); | ||
auto f16_vec = vcvt_f16_f32(f32_vec); | ||
vst1_f16((__fp16*)dst + i, f16_vec); | ||
} | ||
FloatToFloat16_ref(src + i, dst + i, size - i, do_clip); | ||
} else { | ||
size_t i = 0; | ||
for (; i + 16 < size; i += 16) { | ||
auto f32_vec1 = vld1q_f32(src + i); | ||
auto f32_vec2 = vld1q_f32(src + i + 4); | ||
auto f32_vec3 = vld1q_f32(src + i + 8); | ||
auto f32_vec4 = vld1q_f32(src + i + 12); | ||
auto f16_vec1 = vcvt_f16_f32(f32_vec1); | ||
auto f16_vec2 = vcvt_f16_f32(f32_vec2); | ||
auto f16_vec3 = vcvt_f16_f32(f32_vec3); | ||
auto f16_vec4 = vcvt_f16_f32(f32_vec4); | ||
vst1_f16((__fp16*)dst + i, f16_vec1); | ||
vst1_f16((__fp16*)dst + i + 4, f16_vec2); | ||
vst1_f16((__fp16*)dst + i + 8, f16_vec3); | ||
vst1_f16((__fp16*)dst + i + 12, f16_vec4); | ||
} | ||
for (; i + 8 < size; i += 8) { | ||
auto f32_vec1 = vld1q_f32(src + i); | ||
auto f32_vec2 = vld1q_f32(src + i + 4); | ||
auto f16_vec1 = vcvt_f16_f32(f32_vec1); | ||
auto f16_vec2 = vcvt_f16_f32(f32_vec2); | ||
vst1_f16((__fp16*)dst + i, f16_vec1); | ||
vst1_f16((__fp16*)dst + i + 4, f16_vec2); | ||
} | ||
for (; i + 4 < size; i += 4) { | ||
auto f32_vec = vld1q_f32(src + i); | ||
auto f16_vec = vcvt_f16_f32(f32_vec); | ||
vst1_f16((__fp16*)dst + i, f16_vec); | ||
} | ||
FloatToFloat16_ref(src + i, dst + i, size - i); | ||
} | ||
} | ||
|
||
void Float16ToFloat_neon(const float16* src, float* dst, size_t size) { | ||
size_t i = 0; | ||
for (; i + 16 < size; i += 16) { | ||
auto f16_vec1 = vld1_f16((__fp16*)src + i); | ||
auto f16_vec2 = vld1_f16((__fp16*)src + i + 4); | ||
auto f16_vec3 = vld1_f16((__fp16*)src + i + 8); | ||
auto f16_vec4 = vld1_f16((__fp16*)src + i + 12); | ||
auto f32_vec1 = vcvt_f32_f16(f16_vec1); | ||
auto f32_vec2 = vcvt_f32_f16(f16_vec2); | ||
auto f32_vec3 = vcvt_f32_f16(f16_vec3); | ||
auto f32_vec4 = vcvt_f32_f16(f16_vec4); | ||
vst1q_f32(dst + i, f32_vec1); | ||
vst1q_f32(dst + i + 4, f32_vec2); | ||
vst1q_f32(dst + i + 8, f32_vec3); | ||
vst1q_f32(dst + i + 12, f32_vec4); | ||
} | ||
for (; i + 8 < size; i += 8) { | ||
auto f16_vec1 = vld1_f16((__fp16*)src + i); | ||
auto f16_vec2 = vld1_f16((__fp16*)src + i + 4); | ||
auto f32_vec1 = vcvt_f32_f16(f16_vec1); | ||
auto f32_vec2 = vcvt_f32_f16(f16_vec2); | ||
vst1q_f32(dst + i, f32_vec1); | ||
vst1q_f32(dst + i + 4, f32_vec2); | ||
} | ||
for (; i + 4 < size; i += 4) { | ||
auto f16_vec = vld1_f16((__fp16*)src + i); | ||
auto f32_vec = vcvt_f32_f16(f16_vec); | ||
vst1q_f32(dst + i, f32_vec); | ||
} | ||
Float16ToFloat_ref(src + i, dst + i, size - i); | ||
} | ||
|
||
} // namespace fbgemm |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,99 @@ | ||
/* | ||
* SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliate <open-source-office@arm.com> | ||
* SPDX-License-Identifier: BSD-3-Clause | ||
*/ | ||
|
||
#include <arm_sve.h> | ||
#define FBGEMM_EXPORTS | ||
#include "fbgemm/FbgemmConvert.h" | ||
|
||
namespace fbgemm { | ||
|
||
void FloatToFloat16_sve( | ||
const float* src, | ||
float16* dst, | ||
size_t size, | ||
bool do_clip) { | ||
if (do_clip) { | ||
constexpr float FP16_MAX = 65504.f; | ||
size_t i = 0; | ||
int lanes = svcntw(); | ||
auto p_32 = svptrue_b32(); | ||
auto p_16 = svptrue_b16(); | ||
auto pfalse = svpfalse(); | ||
auto p_16_half = svuzp1_b16(p_16, pfalse); | ||
while (i + 2 * lanes < size) { | ||
auto f32_vec1 = svld1_f32(p_32, src + i); | ||
auto f32_vec2 = svld1_f32(p_32, src + i + lanes); | ||
f32_vec1 = svmax_n_f32_x(p_32, svmin_n_f32_x(p_32, f32_vec1, FP16_MAX), -FP16_MAX); | ||
f32_vec2 = svmax_n_f32_x(p_32, svmin_n_f32_x(p_32, f32_vec2, FP16_MAX), -FP16_MAX); | ||
auto f16_vec1 = svcvt_f16_f32_x(p_32, f32_vec1); | ||
auto f16_vec2 = svcvt_f16_f32_x(p_32, f32_vec2); | ||
auto f16_vec = svuzp1_f16(f16_vec1, f16_vec2); | ||
svst1_f16(p_16, (__fp16*)dst + i, f16_vec); | ||
i += 2 * lanes; | ||
} | ||
while (i + lanes < size) { | ||
auto f32_vec = svld1_f32(p_32, src + i); | ||
f32_vec = svmax_n_f32_x(p_32, svmin_n_f32_x(p_32, f32_vec, FP16_MAX), -FP16_MAX); | ||
auto f16_vec = svcvt_f16_f32_x(p_16, f32_vec); | ||
f16_vec = svuzp1_f16(f16_vec, f16_vec); | ||
svst1_f16(p_16_half, (__fp16*)dst + i, f16_vec); | ||
i += lanes; | ||
} | ||
FloatToFloat16_ref(src + i, dst + i, size - i, do_clip); | ||
} else { | ||
size_t i = 0; | ||
int lanes = svcntw(); | ||
auto p_32 = svptrue_b32(); | ||
auto p_16 = svptrue_b16(); | ||
auto pfalse = svpfalse(); | ||
auto p_16_half = svuzp1_b16(p_16, pfalse); | ||
while (i + 2 * lanes < size) { | ||
auto f32_vec1 = svld1_f32(p_32, src + i); | ||
auto f32_vec2 = svld1_f32(p_32, src + i + lanes); | ||
auto f16_vec1 = svcvt_f16_f32_x(p_32, f32_vec1); | ||
auto f16_vec2 = svcvt_f16_f32_x(p_32, f32_vec2); | ||
auto f16_vec = svuzp1_f16(f16_vec1, f16_vec2); | ||
svst1_f16(p_16, (__fp16*)dst + i, f16_vec); | ||
i += 2 * lanes; | ||
} | ||
while (i + lanes < size) { | ||
auto f32_vec = svld1_f32(p_32, src + i); | ||
auto f16_vec = svcvt_f16_f32_x(p_32, f32_vec); | ||
f16_vec = svuzp1_f16(f16_vec, f16_vec); | ||
svst1_f16(p_16_half, (__fp16*)dst + i, f16_vec); | ||
i += lanes; | ||
} | ||
FloatToFloat16_ref(src + i, dst + i, size - i); | ||
} | ||
} | ||
|
||
void Float16ToFloat_sve(const float16* src, float* dst, size_t size) { | ||
size_t i = 0; | ||
int lanes = svcntw(); | ||
auto p_32 = svptrue_b32(); | ||
auto p_16 = svptrue_b16(); | ||
auto pfalse = svpfalse(); | ||
auto p_16_half = svuzp1_b16(p_16, pfalse); | ||
while (i + 2 * lanes < size) { | ||
auto f16_vec = svld1_f16(p_16, (__fp16*)src + i); | ||
auto f16_vec1 = svzip1(f16_vec, f16_vec); | ||
auto f16_vec2 = svzip2(f16_vec, f16_vec); | ||
auto f32_vec1 = svcvt_f32_f16_x(p_16, f16_vec1); | ||
auto f32_vec2 = svcvt_f32_f16_x(p_16, f16_vec2); | ||
svst1_f32(p_32, dst + i, f32_vec1); | ||
svst1_f32(p_32, dst + i + lanes, f32_vec2); | ||
i += 2 * lanes; | ||
} | ||
while (i + lanes < size) { | ||
auto f16_vec = svld1_f16(p_16_half, (__fp16*)src + i); | ||
f16_vec = svzip1_f16(f16_vec, f16_vec); | ||
auto f32_vec = svcvt_f32_f16_x(p_32, f16_vec); | ||
svst1_f32(p_32, dst + i, f32_vec); | ||
i += lanes; | ||
} | ||
Float16ToFloat_ref(src + i, dst + i, size - i); | ||
} | ||
|
||
} // namespace fbgemm |