Skip to content

Commit 6d4228c

Browse files
committed
Add NEON and SVE implementations for Float16 conversions
1 parent 357b54c commit 6d4228c

4 files changed

+271
-0
lines changed

include/fbgemm/FbgemmConvert.h

+35
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
/*
22
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* Copyright 2024 Arm Limited and/or its affiliates <open-source-office@arm.com>
34
* All rights reserved.
45
*
56
* This source code is licensed under the BSD-style license found in the
@@ -135,6 +136,26 @@ FBGEMM_API void FloatToFloat16_avx512(
135136
size_t size,
136137
bool do_clip = false);
137138

139+
/**
140+
* @brief SVE implementation to convert fp32 numbers to fp16 numbers.
* When do_clip is true, inputs are clamped to [-65504, 65504] before
* conversion.
141+
*
142+
*/
143+
FBGEMM_API void FloatToFloat16_sve(
144+
const float* src,
145+
float16* dst,
146+
size_t size,
147+
bool do_clip = false);
148+
149+
/**
150+
* @brief NEON implementation to convert fp32 numbers to fp16 numbers.
* When do_clip is true, inputs are clamped to [-65504, 65504] before
* conversion.
151+
*
152+
*/
153+
FBGEMM_API void FloatToFloat16_neon(
154+
const float* src,
155+
float16* dst,
156+
size_t size,
157+
bool do_clip = false);
158+
138159
/**
139160
* @brief AVX2 implementation to convert fp16 numbers to fp32 numbers.
140161
*
@@ -149,6 +170,20 @@ Float16ToFloat_avx2(const float16* src, float* dst, size_t size);
149170
FBGEMM_API void
150171
Float16ToFloat_avx512(const float16* src, float* dst, size_t size);
151172

173+
/**
174+
* @brief SVE implementation to convert fp16 numbers to fp32 numbers.
175+
*
176+
*/
177+
FBGEMM_API void
178+
Float16ToFloat_sve(const float16* src, float* dst, size_t size);
179+
180+
/**
181+
* @brief NEON implementation to convert fp16 numbers to fp32 numbers.
182+
*
183+
*/
184+
FBGEMM_API void
185+
Float16ToFloat_neon(const float16* src, float* dst, size_t size);
186+
152187
/**
153188
* @brief Transform all entries in a matrix from fp32 to float16 and back to
154189
* fp32.

src/FbgemmFloat16Convert.cc

+15
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
/*
22
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* Copyright 2024 Arm Limited and/or its affiliates <open-source-office@arm.com>
34
* All rights reserved.
45
*
56
* This source code is licensed under the BSD-style license found in the
@@ -39,10 +40,17 @@ void FloatToFloat16_simd(
3940
bool do_clip) {
4041
// Run time CPU detection
4142
if (cpuinfo_initialize()) {
43+
#ifdef __aarch64__
44+
if (fbgemmHasArmSveSupport()) {
45+
FloatToFloat16_sve(src, dst, size, do_clip);
46+
} else if (fbgemmHasArmNeonSupport()) {
47+
FloatToFloat16_neon(src, dst, size, do_clip);
48+
#else
4249
if (fbgemmHasAvx512Support()) {
4350
FloatToFloat16_avx512(src, dst, size, do_clip);
4451
} else if (fbgemmHasAvx2Support()) {
4552
FloatToFloat16_avx2(src, dst, size, do_clip);
53+
#endif
4654
} else {
4755
FloatToFloat16_ref(src, dst, size, do_clip);
4856
return;
@@ -55,10 +63,17 @@ void FloatToFloat16_simd(
5563
void Float16ToFloat_simd(const float16* src, float* dst, size_t size) {
5664
// Run time CPU detection
5765
if (cpuinfo_initialize()) {
66+
#ifdef __aarch64__
67+
if (fbgemmHasArmSveSupport()) {
68+
Float16ToFloat_sve(src, dst, size);
69+
} else if (fbgemmHasArmNeonSupport()) {
70+
Float16ToFloat_neon(src, dst, size);
71+
#else
5872
if (fbgemmHasAvx512Support()) {
5973
Float16ToFloat_avx512(src, dst, size);
6074
} else if (fbgemmHasAvx2Support()) {
6175
Float16ToFloat_avx2(src, dst, size);
76+
#endif
6277
} else {
6378
Float16ToFloat_ref(src, dst, size);
6479
return;

src/FbgemmFloat16ConvertNeon.cc

+122
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
/*
2+
* SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates <open-source-office@arm.com>
3+
* SPDX-License-Identifier: BSD-3-Clause
4+
*/
5+
6+
#include <arm_neon.h>
7+
#define FBGEMM_EXPORTS
8+
#include "fbgemm/FbgemmConvert.h"
9+
10+
namespace fbgemm {
11+
12+
void FloatToFloat16_neon(
13+
const float* src,
14+
float16* dst,
15+
size_t size,
16+
bool do_clip) {
17+
if (do_clip) {
18+
constexpr float FP16_MAX = 65504.f;
19+
auto vpos = vdupq_n_f32(FP16_MAX);
20+
auto vneg = vdupq_n_f32(-FP16_MAX);
21+
size_t i = 0;
22+
for (; i + 16 < size; i += 16) {
23+
auto f32_vec1 = vld1q_f32(src + i);
24+
auto f32_vec2 = vld1q_f32(src + i + 4);
25+
auto f32_vec3 = vld1q_f32(src + i + 8);
26+
auto f32_vec4 = vld1q_f32(src + i + 12);
27+
f32_vec1 = vmaxq_f32(vminq_f32(f32_vec1, vpos), vneg);
28+
f32_vec2 = vmaxq_f32(vminq_f32(f32_vec2, vpos), vneg);
29+
f32_vec3 = vmaxq_f32(vminq_f32(f32_vec3, vpos), vneg);
30+
f32_vec4 = vmaxq_f32(vminq_f32(f32_vec4, vpos), vneg);
31+
auto f16_vec1 = vcvt_f16_f32(f32_vec1);
32+
auto f16_vec2 = vcvt_f16_f32(f32_vec2);
33+
auto f16_vec3 = vcvt_f16_f32(f32_vec3);
34+
auto f16_vec4 = vcvt_f16_f32(f32_vec4);
35+
vst1_f16((__fp16*)dst + i, f16_vec1);
36+
vst1_f16((__fp16*)dst + i + 4, f16_vec2);
37+
vst1_f16((__fp16*)dst + i + 8, f16_vec3);
38+
vst1_f16((__fp16*)dst + i + 12, f16_vec4);
39+
}
40+
for (; i + 8 < size; i += 8) {
41+
auto f32_vec1 = vld1q_f32(src + i);
42+
auto f32_vec2 = vld1q_f32(src + i + 4);
43+
f32_vec1 = vmaxq_f32(vminq_f32(f32_vec1, vpos), vneg);
44+
f32_vec2 = vmaxq_f32(vminq_f32(f32_vec2, vpos), vneg);
45+
auto f16_vec1 = vcvt_f16_f32(f32_vec1);
46+
auto f16_vec2 = vcvt_f16_f32(f32_vec2);
47+
vst1_f16((__fp16*)dst + i, f16_vec1);
48+
vst1_f16((__fp16*)dst + i + 4, f16_vec2);
49+
}
50+
for (; i + 4 < size; i += 4) {
51+
auto f32_vec = vld1q_f32(src + i);
52+
f32_vec = vmaxq_f32(vminq_f32(f32_vec, vpos), vneg);
53+
auto f16_vec = vcvt_f16_f32(f32_vec);
54+
vst1_f16((__fp16*)dst + i, f16_vec);
55+
}
56+
FloatToFloat16_ref(src + i, dst + i, size - i, do_clip);
57+
} else {
58+
size_t i = 0;
59+
for (; i + 16 < size; i += 16) {
60+
auto f32_vec1 = vld1q_f32(src + i);
61+
auto f32_vec2 = vld1q_f32(src + i + 4);
62+
auto f32_vec3 = vld1q_f32(src + i + 8);
63+
auto f32_vec4 = vld1q_f32(src + i + 12);
64+
auto f16_vec1 = vcvt_f16_f32(f32_vec1);
65+
auto f16_vec2 = vcvt_f16_f32(f32_vec2);
66+
auto f16_vec3 = vcvt_f16_f32(f32_vec3);
67+
auto f16_vec4 = vcvt_f16_f32(f32_vec4);
68+
vst1_f16((__fp16*)dst + i, f16_vec1);
69+
vst1_f16((__fp16*)dst + i + 4, f16_vec2);
70+
vst1_f16((__fp16*)dst + i + 8, f16_vec3);
71+
vst1_f16((__fp16*)dst + i + 12, f16_vec4);
72+
}
73+
for (; i + 8 < size; i += 8) {
74+
auto f32_vec1 = vld1q_f32(src + i);
75+
auto f32_vec2 = vld1q_f32(src + i + 4);
76+
auto f16_vec1 = vcvt_f16_f32(f32_vec1);
77+
auto f16_vec2 = vcvt_f16_f32(f32_vec2);
78+
vst1_f16((__fp16*)dst + i, f16_vec1);
79+
vst1_f16((__fp16*)dst + i + 4, f16_vec2);
80+
}
81+
for (; i + 4 < size; i += 4) {
82+
auto f32_vec = vld1q_f32(src + i);
83+
auto f16_vec = vcvt_f16_f32(f32_vec);
84+
vst1_f16((__fp16*)dst + i, f16_vec);
85+
}
86+
FloatToFloat16_ref(src + i, dst + i, size - i);
87+
}
88+
}
89+
90+
void Float16ToFloat_neon(const float16* src, float* dst, size_t size) {
91+
size_t i = 0;
92+
for (; i + 16 < size; i += 16) {
93+
auto f16_vec1 = vld1_f16((__fp16*)src + i);
94+
auto f16_vec2 = vld1_f16((__fp16*)src + i + 4);
95+
auto f16_vec3 = vld1_f16((__fp16*)src + i + 8);
96+
auto f16_vec4 = vld1_f16((__fp16*)src + i + 12);
97+
auto f32_vec1 = vcvt_f32_f16(f16_vec1);
98+
auto f32_vec2 = vcvt_f32_f16(f16_vec2);
99+
auto f32_vec3 = vcvt_f32_f16(f16_vec3);
100+
auto f32_vec4 = vcvt_f32_f16(f16_vec4);
101+
vst1q_f32(dst + i, f32_vec1);
102+
vst1q_f32(dst + i + 4, f32_vec2);
103+
vst1q_f32(dst + i + 8, f32_vec3);
104+
vst1q_f32(dst + i + 12, f32_vec4);
105+
}
106+
for (; i + 8 < size; i += 8) {
107+
auto f16_vec1 = vld1_f16((__fp16*)src + i);
108+
auto f16_vec2 = vld1_f16((__fp16*)src + i + 4);
109+
auto f32_vec1 = vcvt_f32_f16(f16_vec1);
110+
auto f32_vec2 = vcvt_f32_f16(f16_vec2);
111+
vst1q_f32(dst + i, f32_vec1);
112+
vst1q_f32(dst + i + 4, f32_vec2);
113+
}
114+
for (; i + 4 < size; i += 4) {
115+
auto f16_vec = vld1_f16((__fp16*)src + i);
116+
auto f32_vec = vcvt_f32_f16(f16_vec);
117+
vst1q_f32(dst + i, f32_vec);
118+
}
119+
Float16ToFloat_ref(src + i, dst + i, size - i);
120+
}
121+
122+
} // namespace fbgemm

src/FbgemmFloat16ConvertSve.cc

+99
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
/*
2+
* SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates <open-source-office@arm.com>
3+
* SPDX-License-Identifier: BSD-3-Clause
4+
*/
5+
6+
#include <arm_sve.h>
7+
#define FBGEMM_EXPORTS
8+
#include "fbgemm/FbgemmConvert.h"
9+
10+
namespace fbgemm {
11+
12+
void FloatToFloat16_sve(
13+
const float* src,
14+
float16* dst,
15+
size_t size,
16+
bool do_clip) {
17+
if (do_clip) {
18+
constexpr float FP16_MAX = 65504.f;
19+
size_t i = 0;
20+
int lanes = svcntw();
21+
auto p_32 = svptrue_b32();
22+
auto p_16 = svptrue_b16();
23+
auto pfalse = svpfalse();
24+
auto p_16_half = svuzp1_b16(p_16, pfalse);
25+
while (i + 2 * lanes < size) {
26+
auto f32_vec1 = svld1_f32(p_32, src + i);
27+
auto f32_vec2 = svld1_f32(p_32, src + i + lanes);
28+
f32_vec1 = svmax_n_f32_x(p_32, svmin_n_f32_x(p_32, f32_vec1, FP16_MAX), -FP16_MAX);
29+
f32_vec2 = svmax_n_f32_x(p_32, svmin_n_f32_x(p_32, f32_vec2, FP16_MAX), -FP16_MAX);
30+
auto f16_vec1 = svcvt_f16_f32_x(p_32, f32_vec1);
31+
auto f16_vec2 = svcvt_f16_f32_x(p_32, f32_vec2);
32+
auto f16_vec = svuzp1_f16(f16_vec1, f16_vec2);
33+
svst1_f16(p_16, (__fp16*)dst + i, f16_vec);
34+
i += 2 * lanes;
35+
}
36+
while (i + lanes < size) {
37+
auto f32_vec = svld1_f32(p_32, src + i);
38+
f32_vec = svmax_n_f32_x(p_32, svmin_n_f32_x(p_32, f32_vec, FP16_MAX), -FP16_MAX);
39+
auto f16_vec = svcvt_f16_f32_x(p_16, f32_vec);
40+
f16_vec = svuzp1_f16(f16_vec, f16_vec);
41+
svst1_f16(p_16_half, (__fp16*)dst + i, f16_vec);
42+
i += lanes;
43+
}
44+
FloatToFloat16_ref(src + i, dst + i, size - i, do_clip);
45+
} else {
46+
size_t i = 0;
47+
int lanes = svcntw();
48+
auto p_32 = svptrue_b32();
49+
auto p_16 = svptrue_b16();
50+
auto pfalse = svpfalse();
51+
auto p_16_half = svuzp1_b16(p_16, pfalse);
52+
while (i + 2 * lanes < size) {
53+
auto f32_vec1 = svld1_f32(p_32, src + i);
54+
auto f32_vec2 = svld1_f32(p_32, src + i + lanes);
55+
auto f16_vec1 = svcvt_f16_f32_x(p_32, f32_vec1);
56+
auto f16_vec2 = svcvt_f16_f32_x(p_32, f32_vec2);
57+
auto f16_vec = svuzp1_f16(f16_vec1, f16_vec2);
58+
svst1_f16(p_16, (__fp16*)dst + i, f16_vec);
59+
i += 2 * lanes;
60+
}
61+
while (i + lanes < size) {
62+
auto f32_vec = svld1_f32(p_32, src + i);
63+
auto f16_vec = svcvt_f16_f32_x(p_32, f32_vec);
64+
f16_vec = svuzp1_f16(f16_vec, f16_vec);
65+
svst1_f16(p_16_half, (__fp16*)dst + i, f16_vec);
66+
i += lanes;
67+
}
68+
FloatToFloat16_ref(src + i, dst + i, size - i);
69+
}
70+
}
71+
72+
void Float16ToFloat_sve(const float16* src, float* dst, size_t size) {
73+
size_t i = 0;
74+
int lanes = svcntw();
75+
auto p_32 = svptrue_b32();
76+
auto p_16 = svptrue_b16();
77+
auto pfalse = svpfalse();
78+
auto p_16_half = svuzp1_b16(p_16, pfalse);
79+
while (i + 2 * lanes < size) {
80+
auto f16_vec = svld1_f16(p_16, (__fp16*)src + i);
81+
auto f16_vec1 = svzip1(f16_vec, f16_vec);
82+
auto f16_vec2 = svzip2(f16_vec, f16_vec);
83+
auto f32_vec1 = svcvt_f32_f16_x(p_16, f16_vec1);
84+
auto f32_vec2 = svcvt_f32_f16_x(p_16, f16_vec2);
85+
svst1_f32(p_32, dst + i, f32_vec1);
86+
svst1_f32(p_32, dst + i + lanes, f32_vec2);
87+
i += 2 * lanes;
88+
}
89+
while (i + lanes < size) {
90+
auto f16_vec = svld1_f16(p_16_half, (__fp16*)src + i);
91+
f16_vec = svzip1_f16(f16_vec, f16_vec);
92+
auto f32_vec = svcvt_f32_f16_x(p_32, f16_vec);
93+
svst1_f32(p_32, dst + i, f32_vec);
94+
i += lanes;
95+
}
96+
Float16ToFloat_ref(src + i, dst + i, size - i);
97+
}
98+
99+
} // namespace fbgemm

0 commit comments

Comments
 (0)