Skip to content

Commit

Permalink
Fix compilation with GCC and Clang
Browse files Browse the repository at this point in the history
Signed-off-by: Sv. Lockal <lockalsash@gmail.com>
  • Loading branch information
AngryLoki committed Apr 9, 2024
1 parent 3406cd2 commit 8039b65
Show file tree
Hide file tree
Showing 8 changed files with 106 additions and 79 deletions.
12 changes: 9 additions & 3 deletions cmake/cpu/BuildFlags.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ if(MSVC)
else(MSVC)
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -fPIC")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fPIC")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wall")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-narrowing")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-missing-field-initializers")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-type-limits")
Expand All @@ -58,9 +59,14 @@ else(MSVC)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-error=pedantic")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-error=redundant-decls")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-error=old-style-cast")
# Eigen fails to build with some versions, so convert this to a warning
# Details at http://eigen.tuxfamily.org/bz/show_bug.cgi?id=1459
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wall")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-ignored-attributes")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-pessimizing-move")

if (CMAKE_CXX_COMPILER_ID STREQUAL "Clang")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-vla-cxx-extension")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-deprecated-builtins")
endif()

if (CMAKE_COMPILER_IS_GNUCXX AND NOT (CMAKE_CXX_COMPILER_VERSION VERSION_LESS 7.0.0))
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-stringop-overflow")
endif()
Expand Down
2 changes: 1 addition & 1 deletion csrc/cpu/aten/kernels/MaskedMultiHeadAttentionKrnl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -425,7 +425,7 @@ void mul_attenion_weights_and_value_of_head_half(
auto hsi = 0;
auto vec_size = 32; // 512/16
for (hsi = 0; hsi <= head_size - vec_size; hsi += vec_size) {
auto attn_w_vec = _mm512_set1_ph(attn_w);
auto attn_w_vec = _mm512_set1_ph(*(_Float16*)&attn_w);
auto v_vec = _mm512_loadu_ph(v_ptr_start + hsi);
if (accumulate) {
auto attn_out_vec = _mm512_loadu_ph(attn_out_start + hsi);
Expand Down
2 changes: 1 addition & 1 deletion csrc/cpu/jit/codegen/LlgaTensorImpl.h
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,7 @@ struct LlgaTensorDesc {
}

bool operator!=(const LlgaTensorDesc& desc) const {
return *this != desc;
return !(*this == desc);
}

static size_t hash(const LlgaTensorDesc& desc) {
Expand Down
20 changes: 20 additions & 0 deletions csrc/cpu/tpp/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,29 @@ typedef at::Half half;
rvalue of type X'. Keep the original code as backup:
*/
#ifdef __clang__

#ifdef __INTEL_LLVM_COMPILER
#define DECL_VLA_PTR_PT(type, name, dims, t) \
auto name = (type(*) dims)(t.data_ptr<type>())
#else
/*
Workaround for Clang crash https://github.com/llvm/llvm-project/issues/75428.
This generates 2 statements, which is unsafe for `if (...) DECL_VLA_PTR_PT(...)`
blocks. However, due to first line it will be a compile error
(unknown type name 'vla_type_42'), not a runtime error.
*/
#define DECL_VLA_PTR_PT_MERGE_(a, b, type, name, dims, t) \
using a##b = type (*)dims; \
a##b name = reinterpret_cast<a##b>(t.data_ptr<type>())

#define DECL_VLA_PTR_PT_LABEL_(cnt, type, name, dims, t) \
DECL_VLA_PTR_PT_MERGE_(vla_type_, cnt, type, name, dims, t)

#define DECL_VLA_PTR_PT(type, name, dims, t) \
DECL_VLA_PTR_PT_LABEL_(__COUNTER__, type, name, dims, t)
#endif

#else // not clang
#define DECL_VLA_PTR_PT(type, name, dims, t) \
type(*name) dims = (type(*) dims)(t.data_ptr<type>())
#endif
Expand Down
131 changes: 67 additions & 64 deletions csrc/cpu/tpp/woq/vec.h
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,9 @@ struct VecOps<__m512h> {
static inline __m512h set1(ST v) {
return _mm512_set1_ph(v);
}
static inline __m512h set1(float v) {
return _mm512_set1_ph(ST(v));
}
static inline __m512h setzero() {
return _mm512_setzero_ph();
}
Expand All @@ -133,73 +136,73 @@ struct VecOps<__m512h> {
}
static inline __m512h set_0_to_15() {
return _mm512_set_ph(
15.0f,
14.0f,
13.0f,
12.0f,
11.0f,
10.0f,
9.0f,
8.0f,
7.0f,
6.0f,
5.0f,
4.0f,
3.0f,
2.0f,
1.0f,
0.0f,
15.0f,
14.0f,
13.0f,
12.0f,
11.0f,
10.0f,
9.0f,
8.0f,
7.0f,
6.0f,
5.0f,
4.0f,
3.0f,
2.0f,
1.0f,
0.0f);
15.0f16,
14.0f16,
13.0f16,
12.0f16,
11.0f16,
10.0f16,
9.0f16,
8.0f16,
7.0f16,
6.0f16,
5.0f16,
4.0f16,
3.0f16,
2.0f16,
1.0f16,
0.0f16,
15.0f16,
14.0f16,
13.0f16,
12.0f16,
11.0f16,
10.0f16,
9.0f16,
8.0f16,
7.0f16,
6.0f16,
5.0f16,
4.0f16,
3.0f16,
2.0f16,
1.0f16,
0.0f16);
}
static inline __m512h set_nf4_lut() {
return _mm512_set_ph(
1.0f,
0.7229568362236023,
0.5626170039176941,
0.44070982933044434,
0.33791524171829224,
0.24611230194568634,
0.16093020141124725,
0.07958029955625534,
0.0f,
-0.09105003625154495,
-0.18477343022823334,
-0.28444138169288635,
-0.39491748809814453,
-0.5250730514526367,
-0.6961928009986877,
-1.0f,
1.0f,
0.7229568362236023,
0.5626170039176941,
0.44070982933044434,
0.33791524171829224,
0.24611230194568634,
0.16093020141124725,
0.07958029955625534,
0.0f,
-0.09105003625154495,
-0.18477343022823334,
-0.28444138169288635,
-0.39491748809814453,
-0.5250730514526367,
-0.6961928009986877,
-1.0f);
1.0f16,
0.7229568362236023f16,
0.5626170039176941f16,
0.44070982933044434f16,
0.33791524171829224f16,
0.24611230194568634f16,
0.16093020141124725f16,
0.07958029955625534f16,
0.0f16,
-0.09105003625154495f16,
-0.18477343022823334f16,
-0.28444138169288635f16,
-0.39491748809814453f16,
-0.5250730514526367f16,
-0.6961928009986877f16,
-1.0f16,
1.0f16,
0.7229568362236023f16,
0.5626170039176941f16,
0.44070982933044434f16,
0.33791524171829224f16,
0.24611230194568634f16,
0.16093020141124725f16,
0.07958029955625534f16,
0.0f16,
-0.09105003625154495f16,
-0.18477343022823334f16,
-0.28444138169288635f16,
-0.39491748809814453f16,
-0.5250730514526367f16,
-0.6961928009986877f16,
-1.0f16);
}
};

Expand Down
3 changes: 0 additions & 3 deletions csrc/cpu/utils/library.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,6 @@

#include <c10/util/Logging.h>

// The flag is used to control the pytorch log level and defined at c10
extern int FLAGS_caffe2_log_level;

// This FLAGS_caffe2_log_level flag is used to control the log level and defined
// at Pytorch c10 library. The default is log level is warn. But it triggers the
// warn to override the kernel under a particular dispatch key. Unfortunately,
Expand Down
12 changes: 6 additions & 6 deletions csrc/cpu/vec/vec512/perf_kernel/add_softmax.h
Original file line number Diff line number Diff line change
Expand Up @@ -195,13 +195,13 @@ inline void _dil_div_add_reduce_max_fusion_kernel_half(
const int& size,
at::Half* out,
at::Half& max) {
auto vec_ps_min = _mm512_set1_ph((at::Half)(-65504.0));
auto vec_ps_min = _mm512_set1_ph(-65504.0f16);
auto vec_a = vec_ps_min;
auto vec_b = vec_ps_min;
auto vec_out = vec_ps_min;

int i = 0;
auto vec_r_dim_per_head = _mm512_set1_ph((at::Half)(1.0 / dim_per_head));
auto vec_r_dim_per_head = _mm512_set1_ph(_Float16(1.0 / dim_per_head));
for (; i <= size - 32; i += 32) {
vec_a = _loadu_half(a + i);
vec_b = _loadu_half(b + i);
Expand Down Expand Up @@ -304,9 +304,9 @@ inline void _dil_exp_reduce_sum_fusion_kernel_half(
const int& size,
at::Half* out,
at::Half& val) {
static auto vec_zero = _mm512_set1_ph((at::Half)(0.0));
auto vec_max = _mm512_set1_ph(val);
auto vec_sum = _mm512_set1_ph(at::Half(0.0));
static auto vec_zero = _mm512_set1_ph(0.0f16);
auto vec_max = _mm512_set1_ph(*(_Float16*)&val);
auto vec_sum = _mm512_set1_ph(0.0f16);
__m512h vec_a = {};
__m512h vec_out = {};

Expand Down Expand Up @@ -364,7 +364,7 @@ inline void _dil_normalization_kernel_half(
const at::Half& sum,
const int& size,
at::Half* out) {
auto vec_sum = _mm512_set1_ph(sum);
auto vec_sum = _mm512_set1_ph(*(_Float16*)&sum);
__m512h vec_a = {};
__m512h vec_out = {};

Expand Down
3 changes: 2 additions & 1 deletion tests/cpu/cpp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@ if(NOT EXISTS ${TORCH_INSTALL_PREFIX})
endif()
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${TORCH_CXX_FLAGS}")

set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wl,-Bsymbolic-functions")
set(CMAKE_SHARED_LINKER_FLAGS "${CMAKE_CXX_FLAGS} -Wl,-Bsymbolic-functions")
set(CMAKE_EXE_LINKER_FLAGS "${CMAKE_CXX_FLAGS} -Wl,-Bsymbolic-functions")

# Set the include dir
include_directories(${PYTORCH_INSTALL_DIR}/include)
Expand Down

0 comments on commit 8039b65

Please sign in to comment.