Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions include/fbgemm/Fbgemm.h
Original file line number Diff line number Diff line change
Expand Up @@ -616,7 +616,7 @@ class FBGEMM_API PackWeightsForConv {
return W_im2col_packed_;
}

#if defined(FBGEMM_FBCODE) || !defined(__aarch64__)
#if !defined(__aarch64__)
std::shared_ptr<PackedDepthWiseConvMatrix> getPackedWForDepthwise() {
return W_dw_packed_;
}
Expand Down Expand Up @@ -672,7 +672,7 @@ class FBGEMM_API PackWeightsForConv {
const conv_param_t<SPATIAL_DIM> conv_param_;
// Packed weights if we use im2col based convolution implementation
std::shared_ptr<PackBMatrix<T, accT>> W_im2col_packed_;
#if defined(FBGEMM_FBCODE) || !defined(__aarch64__)
#if !defined(__aarch64__)
// Packed weights if we use depthwise convolution implementation
std::shared_ptr<PackedDepthWiseConvMatrix> W_dw_packed_;
#endif // __aarch64__
Expand Down
8 changes: 3 additions & 5 deletions include/fbgemm/FbgemmConvert.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ FloatToBfloat16_simd(const float* src, bfloat16* dst, size_t size);
FBGEMM_API void
Bfloat16ToFloat_simd(const bfloat16* src, float* dst, size_t size);

#if !defined(__aarch64__)
/**
* @brief AVX2 implementation to convert fp32 numbers to bf16 numbers.
*
Expand All @@ -58,10 +59,8 @@ FloatToBfloat16_avx2(const float* src, bfloat16* dst, size_t size);
* @brief AVX512 implementation to convert fp32 numbers to bf16 numbers.
*
*/
#if defined(FBGEMM_FBCODE) || !defined(__aarch64__)
FBGEMM_API void
FloatToBfloat16_avx512(const float* src, bfloat16* dst, size_t size);
#endif

/**
* @brief AVX2 implementation to convert bf16 numbers to fp32 numbers.
Expand All @@ -74,7 +73,6 @@ Bfloat16ToFloat_avx2(const bfloat16* src, float* dst, size_t size);
* @brief AVX512 implementation to convert bf16 numbers to fp32 numbers.
*
*/
#if defined(FBGEMM_FBCODE) || !defined(__aarch64__)
FBGEMM_API void
Bfloat16ToFloat_avx512(const bfloat16* src, float* dst, size_t size);
#endif
Expand Down Expand Up @@ -124,6 +122,7 @@ Float16ToFloat_simd(const float16* src, float* dst, size_t size);
* @brief AVX2 implementation to convert fp32 numbers to fp16 numbers.
*
*/
#if !defined(__aarch64__)
FBGEMM_API void FloatToFloat16_avx2(
const float* src,
float16* dst,
Expand All @@ -134,7 +133,6 @@ FBGEMM_API void FloatToFloat16_avx2(
* @brief AVX512 implementation to convert fp32 numbers to fp16 numbers.
*
*/
#if defined(FBGEMM_FBCODE) || !defined(__aarch64__)
FBGEMM_API void FloatToFloat16_avx512(
const float* src,
float16* dst,
Expand All @@ -152,6 +150,7 @@ FBGEMM_API void FloatToFloat16_sve2(
size_t size,
bool do_clip = false);

#if !defined(__aarch64__)
/**
* @brief AVX2 implementation to convert fp16 numbers to fp32 numbers.
*
Expand All @@ -163,7 +162,6 @@ Float16ToFloat_avx2(const float16* src, float* dst, size_t size);
* @brief AVX512 implementation to convert fp16 numbers to fp32 numbers.
*
*/
#if defined(FBGEMM_FBCODE) || !defined(__aarch64__)
FBGEMM_API void
Float16ToFloat_avx512(const float16* src, float* dst, size_t size);
#endif
Expand Down
2 changes: 1 addition & 1 deletion include/fbgemm/FbgemmEmbedding.h
Original file line number Diff line number Diff line change
Expand Up @@ -349,7 +349,7 @@ FBGEMM_API bool EmbeddingSpMDMBlockSize1_(
bool use_offsets = true,
bool is_bf16 = false);

#if defined(FBGEMM_FBCODE) || !defined(__aarch64__)
#if !defined(__aarch64__)
template <typename IndexType, bool HAS_WEIGHTS>
void compressed_indices_remap_avx512(
std::int32_t offsets_numel,
Expand Down
4 changes: 4 additions & 0 deletions include/fbgemm/FbgemmI8DepthwiseAvx2.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@

#pragma once

#if !defined(__aarch64__)

#include <cstdint>
#include "fbgemm/ConvUtils.h"
#include "fbgemm/FbgemmBuild.h"
Expand Down Expand Up @@ -110,3 +112,5 @@ FBGEMM_API void depthwise_3d_same_pad(
int num_threads = 1);

} // namespace fbgemm

#endif // !defined(__aarch64__)
2 changes: 1 addition & 1 deletion include/fbgemm/FbgemmSparse.h
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ void SparseDenseMMAvx2(
int ldc,
bool accum = false);

#if defined(FBGEMM_FBCODE) || !defined(__aarch64__)
#if !defined(__aarch64__)
void SparseDenseMMAvx512(
int M,
int N,
Expand Down
4 changes: 2 additions & 2 deletions include/fbgemm/OutputProcessing-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ ReQuantizeOutput<FUSE_RELU, Q_GRAN, BIAS_TYPE, outT, inT, nextOPType>::f(
}
}

#if defined(FBGEMM_FBCODE) || !defined(__aarch64__)
#if !defined(__aarch64__)

} else if constexpr (
instSet == inst_set_t::avx2 || instSet == inst_set_t::avx512) {
Expand Down Expand Up @@ -249,7 +249,7 @@ inline int ReQuantizeForFloat<FUSE_RELU, Q_GRAN, outT, inT, nextOPType>::f(
}
}

#if defined(FBGEMM_FBCODE) || !defined(__aarch64__)
#if !defined(__aarch64__)
} else if constexpr (
instSet == inst_set_t::avx2 || instSet == inst_set_t::avx512) {
bool b_symmetric =
Expand Down
14 changes: 9 additions & 5 deletions include/fbgemm/QuantUtilsAvx2.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,13 @@ struct FBGEMM_API RequantizationParams {
TensorQuantizationParams target_qparams;
};

/// @ingroup fbgemm-quant-utils-avx2
///
/// @brief Find the min and max value in a float matrix.
void FBGEMM_API FindMinMax(const float* m, float* min, float* max, int64_t len);

#if !defined(__aarch64__)

////////////////////////////////////////////////////////////////////////////////
// Utility functions
////////////////////////////////////////////////////////////////////////////////
Expand All @@ -77,11 +84,6 @@ void FusedQuantizeDequantizeAvx2(
/// <a href="https://www.jstatsoft.org/v08/i14/paper">this paper</a>.
uint32_t FBGEMM_API Xor128();

/// @ingroup fbgemm-quant-utils-avx2
///
/// @brief Find the min and max value in a float matrix.
void FBGEMM_API FindMinMax(const float* m, float* min, float* max, int64_t len);

void RequantizeFixedPointAvx2(
const std::int32_t* src,
std::uint8_t* dst,
Expand Down Expand Up @@ -176,4 +178,6 @@ void Fused8BitRowwiseQuantizedSBFloatToFloatOrHalfAvx2(
int input_columns,
OutputType* output);

#endif // !defined(__aarch64__)

} // namespace fbgemm
2 changes: 1 addition & 1 deletion include/fbgemm/QuantUtilsAvx512.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
#pragma once

#include "Types.h"
#if defined(FBGEMM_FBCODE) || !defined(__aarch64__)
#if !defined(__aarch64__)

#include <cstdint>
#include "./FbgemmBuild.h" // @manual
Expand Down
4 changes: 2 additions & 2 deletions src/FbgemmBfloat16Convert.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ namespace fbgemm {
void FloatToBfloat16_simd(const float* src, bfloat16* dst, size_t size) {
// Run time CPU detection
if (cpuinfo_initialize()) {
#if defined(FBGEMM_FBCODE) || !defined(__aarch64__)
#if !defined(__aarch64__)
if (fbgemmHasAvx512Support()) {
FloatToBfloat16_avx512(src, dst, size);
} else if (fbgemmHasAvx2Support()) {
Expand All @@ -48,7 +48,7 @@ void FloatToBfloat16_simd(const float* src, bfloat16* dst, size_t size) {
void Bfloat16ToFloat_simd(const bfloat16* src, float* dst, size_t size) {
// Run time CPU detection
if (cpuinfo_initialize()) {
#if defined(FBGEMM_FBCODE) || !defined(__aarch64__)
#if !defined(__aarch64__)
if (fbgemmHasAvx512Support()) {
Bfloat16ToFloat_avx512(src, dst, size);
} else if (fbgemmHasAvx2Support()) {
Expand Down
4 changes: 2 additions & 2 deletions src/FbgemmFP16.cc
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ namespace {
// the restrictions of ymm register numbers (16).
constexpr kernel_array_t<float16> kernel_fp16_avx2 = {
nullptr,
#if defined(FBGEMM_FBCODE) || !defined(__aarch64__)
#if !defined(__aarch64__)
gemmkernel_1x2_Avx2_fp16_fA0fB0fC0,
gemmkernel_2x2_Avx2_fp16_fA0fB0fC0,
gemmkernel_3x2_Avx2_fp16_fA0fB0fC0,
Expand Down Expand Up @@ -79,7 +79,7 @@ constexpr kernel_array_t<float16> kernel_fp16_neon = {

constexpr kernel_array_t<float16> kernel_fp16_avx512_256 = {
nullptr,
#if defined(FBGEMM_FBCODE) || !defined(__aarch64__)
#if !defined(__aarch64__)
gemmkernel_1x2_Avx2_fp16_fA0fB0fC0,
gemmkernel_2x2_Avx2_fp16_fA0fB0fC0,
gemmkernel_3x2_Avx2_fp16_fA0fB0fC0,
Expand Down
4 changes: 4 additions & 0 deletions src/FbgemmFP16UKernelsAvx2.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,15 @@ namespace fbgemm {

using GemmParamsFP16 = GemmParams<float16>;

#if !defined(__aarch64__)

void NOINLINE gemmkernel_1x2_Avx2_fp16_fA0fB0fC0(GemmParamsFP16* gp);
void NOINLINE gemmkernel_2x2_Avx2_fp16_fA0fB0fC0(GemmParamsFP16* gp);
void NOINLINE gemmkernel_3x2_Avx2_fp16_fA0fB0fC0(GemmParamsFP16* gp);
void NOINLINE gemmkernel_4x2_Avx2_fp16_fA0fB0fC0(GemmParamsFP16* gp);
void NOINLINE gemmkernel_5x2_Avx2_fp16_fA0fB0fC0(GemmParamsFP16* gp);
void NOINLINE gemmkernel_6x2_Avx2_fp16_fA0fB0fC0(GemmParamsFP16* gp);

#endif // !defined(__aarch64__)

} // namespace fbgemm
4 changes: 2 additions & 2 deletions src/FbgemmFloat16Convert.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ void FloatToFloat16_simd(
bool do_clip) {
// Run time CPU detection
if (cpuinfo_initialize()) {
#if defined(FBGEMM_FBCODE) || !defined(__aarch64__)
#if !defined(__aarch64__)
if (fbgemmHasAvx512Support()) {
FloatToFloat16_avx512(src, dst, size, do_clip);
} else if (fbgemmHasAvx2Support()) {
Expand All @@ -42,7 +42,7 @@ void FloatToFloat16_simd(
void Float16ToFloat_simd(const float16* src, float* dst, size_t size) {
// Run time CPU detection
if (cpuinfo_initialize()) {
#if defined(FBGEMM_FBCODE) || !defined(__aarch64__)
#if !defined(__aarch64__)
if (fbgemmHasAvx512Support()) {
Float16ToFloat_avx512(src, dst, size);
} else if (fbgemmHasAvx2Support()) {
Expand Down
4 changes: 2 additions & 2 deletions src/FbgemmSparseDense.cc
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ void SparseDenseMM(
float* C,
int ldc,
bool accum) {
#if defined(FBGEMM_FBCODE) || !defined(__aarch64__)
#if !defined(__aarch64__)
// Run time CPU detection
static const auto iset = fbgemmInstructionSet();

Expand Down Expand Up @@ -229,7 +229,7 @@ FBGEMM_API void fbgemmSparseDenseInt8MM(
return;
}

#if defined(FBGEMM_FBCODE) || !defined(__aarch64__)
#if !defined(__aarch64__)
// Run time CPU detection
static const auto iset = fbgemmInstructionSet();

Expand Down
4 changes: 2 additions & 2 deletions src/GroupwiseConv.cc
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ static jit_conv_kernel_fp getOrCreateConvKernel(
accum);

if (cpuinfo_initialize()) {
#if defined(FBGEMM_FBCODE) || !defined(__aarch64__)
#if !defined(__aarch64__)
if (fbgemmHasAvx512VnniSupport()) {
return GenConvKernel<SPATIAL_DIM, inst_set_t::avx512_vnni>::codeCache_
.getOrCreate(kernelSig, [&]() {
Expand Down Expand Up @@ -954,7 +954,7 @@ static void dispatchOutputProcessing(
}

if (cpuinfo_initialize()) {
#if defined(FBGEMM_FBCODE) || !defined(__aarch64__)
#if !defined(__aarch64__)
if (fbgemmHasAvx512Support() || fbgemmHasAvx512VnniSupport()) {
REQUANTIZE_C_PER_G(Avx512);
} else if (fbgemmHasAvx2Support() || fbgemmHasArmNeonSupport()) {
Expand Down
4 changes: 2 additions & 2 deletions src/PackWeightsForConv.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ PackWeightsForConv<SPATIAL_DIM, T, accT>::PackWeightsForConv(
// FbgemmConv.cc
switch (ConvFastPath<SPATIAL_DIM, accT>(conv_p)) {
case optimized_conv_t::depthwise: {
#if !defined(FBGEMM_FBCODE) && defined(__aarch64__)
#if defined(__aarch64__)
throw std::runtime_error(
"PackWeightsForConv<SPATIAL_DIM, T, accT>::PackWeightsForConv(): No fallback available for aarch64");
#else
Expand Down Expand Up @@ -98,7 +98,7 @@ PackWeightsForConv<SPATIAL_DIM, T, accT>::PackWeightsForConv(

template <int SPATIAL_DIM, typename T, typename accT>
void PackWeightsForConv<SPATIAL_DIM, T, accT>::unpack(T* origin_buf) {
#if defined(FBGEMM_FBCODE) || !defined(__aarch64__)
#if !defined(__aarch64__)
if (W_dw_packed_) {
W_dw_packed_->unpack(origin_buf);
} else
Expand Down
4 changes: 2 additions & 2 deletions src/PackWeightsForDirectConv.cc
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,7 @@ void fbgemmDirectConv(
return;
}

#if !defined(FBGEMM_FBCODE) && defined(__aarch64__)
#if defined(__aarch64__)
throw std::runtime_error(
"fbgemmDirectConv<SPATIAL_DIM, Q_GRAN, FUSE_RELU, BIAS_TYPE>(): No fallback available for aarch64");
#else
Expand Down Expand Up @@ -459,7 +459,7 @@ void fbgemmDirectConv(
}
} // else SPATIAL_DIM

#endif // defined(FBGEMM_FBCODE) || !defined(__aarch64__)
#endif // !defined(__aarch64__)
}

#define INSTANTIATE_REQUANTIZE_SPATIAL_DIM( \
Expand Down
52 changes: 50 additions & 2 deletions src/QuantUtilsNeon.cc
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,15 @@
* LICENSE file in the root directory of this source tree.
*/

#include "fbgemm/Utils.h"
#if defined(__aarch64__)

#if HAVE_SVE
#include "fbgemm/Utils.h"

#define FBGEMM_EXPORTS
#include <arm_neon.h> // @manual
#if HAVE_SVE
#include <arm_sve.h> // @manual
#endif

#include <arm_neon_sve_bridge.h> // @manual
#include <algorithm> //for std::min/std::max
Expand All @@ -31,6 +33,50 @@ using namespace std;
////////////////////////////////////////////////////////////////////////////////
// Utility functions

/// @brief Find the minimum and maximum value in a float array using NEON.
///
/// Processes the bulk of the array 8 floats at a time with paired NEON
/// min/max accumulators, reduces them with vminvq/vmaxvq, then handles the
/// remaining (len % 8) elements with a scalar tail loop.
///
/// @param m   Pointer to the input array (must be valid for len elements).
/// @param min Output: minimum element; 0.0f when len <= 0.
/// @param max Output: maximum element; 0.0f when len <= 0.
/// @param len Number of elements in m.
void FindMinMax(const float* m, float* min, float* max, int64_t len) {
  if (__builtin_expect(len <= 0, 0)) {
    *min = 0.0f;
    *max = 0.0f;
    return;
  }

  // Seed all accumulators with the first element so the reduction is
  // correct even when len < 8 (the vector loop then never runs).
  float first = *m;

  float32x4_t temp_min_0 = vdupq_n_f32(first);
  float32x4_t temp_min_1 = vdupq_n_f32(first);
  float32x4_t temp_max_0 = vdupq_n_f32(first);
  float32x4_t temp_max_1 = vdupq_n_f32(first);
  uint64_t i = 0;
  uint64_t count = static_cast<uint64_t>(len);
  uint64_t loopBound = count - (count % 8);

  // Two independent accumulator pairs per iteration to expose more
  // instruction-level parallelism.
  for (; i < loopBound; i += 8) {
    float32x4_t v0 = vld1q_f32(m + i);
    float32x4_t v1 = vld1q_f32(m + i + 4);
    temp_min_0 = vminq_f32(temp_min_0, v0);
    temp_min_1 = vminq_f32(temp_min_1, v1);
    temp_max_0 = vmaxq_f32(temp_max_0, v0);
    temp_max_1 = vmaxq_f32(temp_max_1, v1);
  }

  temp_min_0 = vminq_f32(temp_min_0, temp_min_1);
  temp_max_0 = vmaxq_f32(temp_max_0, temp_max_1);

  // Horizontal reduction of the vector accumulators to scalars.
  float tmp_min_s = vminvq_f32(temp_min_0);
  float tmp_max_s = vmaxvq_f32(temp_max_0);

  // Scalar tail for the remaining (count % 8) elements.
  // BUG FIX: the original read `*m` (element 0) on every tail iteration,
  // so min/max of the trailing elements was ignored for len % 8 != 0.
  for (; i < count; i++) {
    float tmp = m[i];
    tmp_min_s = std::min(tmp_min_s, tmp);
    tmp_max_s = std::max(tmp_max_s, tmp);
  }

  *min = tmp_min_s;
  *max = tmp_max_s;
}

#if HAVE_SVE

template <typename OutputType>
void Fused8BitRowwiseQuantizedSBFloatToFloatOrHalfNeon(
const std::uint8_t* input,
Expand Down Expand Up @@ -141,6 +187,8 @@ INSTANTIATE_QuantizationNeonFunctions8Bits(float16)
// clang-format on
#undef INSTANTIATE_QuantizationNeonFunctions8Bits

#endif // HAVE_SVE

} // namespace fbgemm

#endif // __aarch64__
Loading
Loading