From f2fc493149d75f0be13207bc1893a48c7fab84a3 Mon Sep 17 00:00:00 2001 From: "Wang, Phoebe" Date: Sun, 10 Nov 2024 22:37:15 +0800 Subject: [PATCH 1/2] [X86][AMX] Support AMX-TRANSPOSE, part 2 Ref.: https://cdrdv2.intel.com/v1/dl/getContent/671368 --- clang/include/clang/Basic/BuiltinsX86_64.def | 12 + clang/lib/Headers/CMakeLists.txt | 3 + clang/lib/Headers/amxbf16transposeintrin.h | 94 ++++++ clang/lib/Headers/amxcomplextransposeintrin.h | 301 ++++++++++++++++++ clang/lib/Headers/amxfp16intrin.h | 35 ++ clang/lib/Headers/amxfp16transposeintrin.h | 94 ++++++ clang/lib/Headers/amxintrin.h | 32 -- clang/lib/Headers/immintrin.h | 22 +- clang/lib/Sema/SemaX86.cpp | 6 + clang/test/CodeGen/X86/amx_transpose.c | 39 +++ clang/test/CodeGen/X86/amx_transpose_api.c | 50 ++- clang/test/CodeGen/X86/amx_transpose_errors.c | 50 ++- llvm/include/llvm/IR/IntrinsicsX86.td | 57 ++++ llvm/lib/Target/X86/X86ExpandPseudo.cpp | 28 +- llvm/lib/Target/X86/X86ISelLowering.cpp | 58 ++-- llvm/lib/Target/X86/X86InstrAMX.td | 89 ++++++ llvm/lib/Target/X86/X86LowerAMXType.cpp | 24 +- llvm/lib/Target/X86/X86RegisterInfo.cpp | 8 +- .../CodeGen/X86/amx_transpose_intrinsics.ll | 77 ++++- .../MC/Disassembler/X86/amx-transpose-att.txt | 48 +++ llvm/test/MC/X86/amx-transpose-att.s | 48 +++ llvm/test/MC/X86/amx-transpose-intel.s | 48 +++ 22 files changed, 1154 insertions(+), 69 deletions(-) create mode 100644 clang/lib/Headers/amxbf16transposeintrin.h create mode 100644 clang/lib/Headers/amxcomplextransposeintrin.h create mode 100644 clang/lib/Headers/amxfp16transposeintrin.h diff --git a/clang/include/clang/Basic/BuiltinsX86_64.def b/clang/include/clang/Basic/BuiltinsX86_64.def index 9f7462b1e0d962..cc8637ed9c50da 100644 --- a/clang/include/clang/Basic/BuiltinsX86_64.def +++ b/clang/include/clang/Basic/BuiltinsX86_64.def @@ -133,6 +133,12 @@ TARGET_BUILTIN(__builtin_ia32_t2rpntlvwz0t1_internal, "vUsUsUsV256i*V256i*vC*z", TARGET_BUILTIN(__builtin_ia32_t2rpntlvwz1_internal, "vUsUsUsV256i*V256i*vC*z", "n", "amx-transpose") TARGET_BUILTIN(__builtin_ia32_t2rpntlvwz1t1_internal, "vUsUsUsV256i*V256i*vC*z", "n", "amx-transpose") TARGET_BUILTIN(__builtin_ia32_ttransposed_internal, "V256iUsUsV256i", "n", "amx-transpose") +TARGET_BUILTIN(__builtin_ia32_ttdpbf16ps_internal, "V256iUsUsUsV256iV256iV256i", "n", "amx-bf16,amx-transpose") +TARGET_BUILTIN(__builtin_ia32_ttdpfp16ps_internal, "V256iUsUsUsV256iV256iV256i", "n", "amx-fp16,amx-transpose") +TARGET_BUILTIN(__builtin_ia32_ttcmmimfp16ps_internal, "V256iUsUsUsV256iV256iV256i", "n", "amx-complex,amx-transpose") +TARGET_BUILTIN(__builtin_ia32_ttcmmrlfp16ps_internal, "V256iUsUsUsV256iV256iV256i", "n", "amx-complex,amx-transpose") +TARGET_BUILTIN(__builtin_ia32_tconjtcmmimfp16ps_internal, "V256iUsUsUsV256iV256iV256i", "n", "amx-complex,amx-transpose") +TARGET_BUILTIN(__builtin_ia32_tconjtfp16_internal, "V256iUsUsV256i", "n", "amx-complex,amx-transpose") TARGET_BUILTIN(__builtin_ia32_tcvtrowd2ps_internal, "V16fUsUsV256iUi", "n", "amx-avx512,avx10.2-512") TARGET_BUILTIN(__builtin_ia32_tcvtrowps2pbf16h_internal, "V32yUsUsV256iUi", "n", "amx-avx512,avx10.2-512") TARGET_BUILTIN(__builtin_ia32_tcvtrowps2pbf16l_internal, "V32yUsUsV256iUi", "n", "amx-avx512,avx10.2-512") @@ -164,6 +170,12 @@ TARGET_BUILTIN(__builtin_ia32_t2rpntlvwz0t1, "vIUcvC*z", "n","amx-transpose") TARGET_BUILTIN(__builtin_ia32_t2rpntlvwz1, "vIUcvC*z", "n", "amx-transpose") TARGET_BUILTIN(__builtin_ia32_t2rpntlvwz1t1, "vIUcvC*z", "n","amx-transpose") TARGET_BUILTIN(__builtin_ia32_ttransposed, "vIUcIUc", "n", "amx-transpose") +TARGET_BUILTIN(__builtin_ia32_ttdpbf16ps, "vIUcIUcIUc", "n", "amx-bf16,amx-transpose") +TARGET_BUILTIN(__builtin_ia32_ttdpfp16ps, "vIUcIUcIUc", "n", "amx-fp16,amx-transpose") +TARGET_BUILTIN(__builtin_ia32_ttcmmimfp16ps, "vIUcIUcIUc", "n", "amx-complex,amx-transpose") +TARGET_BUILTIN(__builtin_ia32_ttcmmrlfp16ps, "vIUcIUcIUc", "n", "amx-complex,amx-transpose") +TARGET_BUILTIN(__builtin_ia32_tconjtcmmimfp16ps, "vIUcIUcIUc", "n", "amx-complex,amx-transpose") +TARGET_BUILTIN(__builtin_ia32_tconjtfp16, "vIUcIUc", "n", "amx-complex,amx-transpose") TARGET_BUILTIN(__builtin_ia32_tcvtrowd2ps, "V16fIUcUi", "n", "amx-avx512,avx10.2-512") TARGET_BUILTIN(__builtin_ia32_tcvtrowps2pbf16h, "V32yIUcUi", "n", "amx-avx512,avx10.2-512") diff --git a/clang/lib/Headers/CMakeLists.txt b/clang/lib/Headers/CMakeLists.txt index 76366ca1f108e9..19013d37f46ef7 100644 --- a/clang/lib/Headers/CMakeLists.txt +++ b/clang/lib/Headers/CMakeLists.txt @@ -147,8 +147,11 @@ set(x86_files adxintrin.h ammintrin.h amxavx512intrin.h + amxbf16transposeintrin.h amxcomplexintrin.h + amxcomplextransposeintrin.h amxfp16intrin.h + amxfp16transposeintrin.h amxfp8intrin.h amxintrin.h amxtransposeintrin.h diff --git a/clang/lib/Headers/amxbf16transposeintrin.h b/clang/lib/Headers/amxbf16transposeintrin.h new file mode 100644 index 00000000000000..7d31384e317988 --- /dev/null +++ b/clang/lib/Headers/amxbf16transposeintrin.h @@ -0,0 +1,94 @@ +/*===----- amxbf16transposeintrin.h - AMX-BF16 and AMX-TRANSPOSE ------------=== + * + * Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. + * See https://llvm.org/LICENSE.txt for license information. + * SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + * + *===------------------------------------------------------------------------=== + */ + +#ifndef __IMMINTRIN_H +#error \ + "Never use directly; use instead." +#endif /* __IMMINTRIN_H */ + +#ifndef __AMX_BF16TRANSPOSEINTRIN_H +#define __AMX_BF16TRANSPOSEINTRIN_H +#ifdef __x86_64__ + +/* Define the default attributes for the functions in this file. */ +#define __DEFAULT_FN_ATTRS \ + __attribute__((__always_inline__, __nodebug__, \ + __target__("amx-bf16,amx-transpose"))) + +/// Compute transpose and dot-product of BF16 (16-bit) floating-point pairs in +/// tiles \a a and \a b, accumulating the intermediate single-precision +/// (32-bit) floating-point elements with elements in \a dst, and store the +/// 32-bit result back to tile \a dst. +/// +/// \headerfile +/// +/// \code +/// void _tile_tdpbf16ps (__tile dst, __tile a, __tile b) +/// \endcode +/// +/// \code{.operation} +/// FOR m := 0 TO dst.rows - 1 +/// tmp := dst.row[m] +/// FOR k := 0 TO (a.colsb / 4) - 1 +/// FOR n := 0 TO (dst.colsb / 4) - 1 +/// tmp.bf32[n] += FP32(a.row[m].bf16[2*k+0]) * +/// FP32(b.row[k].bf16[2*n+0]) +/// tmp.bf32[n] += FP32(a.row[m].bf16[2*k+1]) * +/// FP32(b.row[k].bf16[2*n+1]) +/// ENDFOR +/// ENDFOR +/// write_row_and_zero(dst, m, tmp, dst.colsb) +/// ENDFOR +/// zero_upper_rows(dst, dst.rows) +/// zero_tileconfig_start() +/// \endcode +/// +/// This intrinsic corresponds to the \c TTDPBF16PS instruction. +/// +/// \param dst +/// The destination tile. Max size is 1024 Bytes. +/// \param a +/// The 1st source tile. Max size is 1024 Bytes. +/// \param b +/// The 2nd source tile. Max size is 1024 Bytes. +#define _tile_tdpbf16ps(dst, a, b) __builtin_ia32_ttdpbf16ps(dst, a, b) + +/// This is internal intrinsic. C/C++ user should avoid calling it directly. +static __inline__ _tile1024i __DEFAULT_FN_ATTRS +_tile_tdpbf16ps_internal(unsigned short m, unsigned short n, unsigned short k, + _tile1024i dst, _tile1024i src1, _tile1024i src2) { + return __builtin_ia32_ttdpbf16ps_internal(m, n, k, dst, src1, src2); +} + +/// Compute transpose and dot-product of BF16 (16-bit) floating-point pairs in +/// tiles src0 and src1, accumulating the intermediate single-precision +/// (32-bit) floating-point elements with elements in "dst", and store the +/// 32-bit result back to tile "dst". +/// +/// \headerfile +/// +/// This intrinsic corresponds to the TTDPBF16PS instruction. +/// +/// \param dst +/// The destination tile. Max size is 1024 Bytes. +/// \param src0 +/// The 1st source tile. Max size is 1024 Bytes. +/// \param src1 +/// The 2nd source tile. Max size is 1024 Bytes. +__DEFAULT_FN_ATTRS +static __inline__ void __tile_tdpbf16ps(__tile1024i *dst, __tile1024i src0, + __tile1024i src1) { + dst->tile = _tile_tdpbf16ps_internal(src0.row, src1.col, src0.col, dst->tile, + src0.tile, src1.tile); +} + +#undef __DEFAULT_FN_ATTRS + +#endif /* __x86_64__ */ +#endif /* __AMX_BF16TRANSPOSEINTRIN_H */ diff --git a/clang/lib/Headers/amxcomplextransposeintrin.h b/clang/lib/Headers/amxcomplextransposeintrin.h new file mode 100644 index 00000000000000..06fb53e4deadcd --- /dev/null +++ b/clang/lib/Headers/amxcomplextransposeintrin.h @@ -0,0 +1,301 @@ +/*===----- amxcomplextransposeintrin.h - AMX-COMPLEX and AMX-TRANSPOSE ------=== + * + * Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. + * See https://llvm.org/LICENSE.txt for license information. + * SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + * + *===------------------------------------------------------------------------=== + */ + +#ifndef __IMMINTRIN_H +#error \ + "Never use directly; include instead." +#endif // __IMMINTRIN_H + +#ifndef __AMX_COMPLEXTRANSPOSEINTRIN_H +#define __AMX_COMPLEXTRANSPOSEINTRIN_H +#ifdef __x86_64__ + +#define __DEFAULT_FN_ATTRS \ + __attribute__((__always_inline__, __nodebug__, \ + __target__("amx-complex,amx-transpose"))) + +/// Perform matrix multiplication of two tiles containing complex elements and +/// accumulate the results into a packed single precision tile. Each dword +/// element in input tiles \a a and \a b is interpreted as a complex number +/// with FP16 real part and FP16 imaginary part. +/// Calculates the imaginary part of the result. For each possible combination +/// of (transposed column of \a a, column of \a b), it performs a set of +/// multiplication and accumulations on all corresponding complex numbers +/// (one from \a a and one from \a b). The imaginary part of the \a a element +/// is multiplied with the real part of the corresponding \a b element, and +/// the real part of the \a a element is multiplied with the imaginary part +/// of the corresponding \a b elements. The two accumulated results are +/// added, and then accumulated into the corresponding row and column of +/// \a dst. +/// +/// \headerfile +/// +/// \code +/// void _tile_tcmmimfp16ps(__tile dst, __tile a, __tile b); +/// \endcode +/// +/// \code{.operation} +/// FOR m := 0 TO dst.rows - 1 +/// tmp := dst.row[m] +/// FOR k := 0 TO a.rows - 1 +/// FOR n := 0 TO (dst.colsb / 4) - 1 +/// tmp.fp32[n] += FP32(a.row[m].fp16[2*k+0]) * FP32(b.row[k].fp16[2*n+1]) +/// tmp.fp32[n] += FP32(a.row[m].fp16[2*k+1]) * FP32(b.row[k].fp16[2*n+0]) +/// ENDFOR +/// ENDFOR +/// write_row_and_zero(dst, m, tmp, dst.colsb) +/// ENDFOR +/// zero_upper_rows(dst, dst.rows) +/// zero_tileconfig_start() +/// \endcode +/// +/// This intrinsic corresponds to the \c TTCMMIMFP16PS instruction. +/// +/// \param dst +/// The destination tile. Max size is 1024 Bytes. +/// \param a +/// The 1st source tile. Max size is 1024 Bytes. +/// \param b +/// The 2nd source tile. Max size is 1024 Bytes. +#define _tile_tcmmimfp16ps(dst, a, b) __builtin_ia32_ttcmmimfp16ps(dst, a, b) + +/// Perform matrix multiplication of two tiles containing complex elements and +/// accumulate the results into a packed single precision tile. Each dword +/// element in input tiles \a a and \a b is interpreted as a complex number +/// with FP16 real part and FP16 imaginary part. +/// Calculates the real part of the result. For each possible combination +/// of (rtransposed colum of \a a, column of \a b), it performs a set of +/// multiplication and accumulations on all corresponding complex numbers +/// (one from \a a and one from \a b). The real part of the \a a element is +/// multiplied with the real part of the corresponding \a b element, and the +/// negated imaginary part of the \a a element is multiplied with the +/// imaginary part of the corresponding \a b elements. The two accumulated +/// results are added, and then accumulated into the corresponding row and +/// column of \a dst. +/// +/// \headerfile +/// +/// \code +/// void _tile_tcmmrlfp16ps(__tile dst, __tile a, __tile b); +/// \endcode +/// +/// \code{.operation} +/// FOR m := 0 TO dst.rows - 1 +/// tmp := dst.row[m] +/// FOR k := 0 TO a.rows - 1 +/// FOR n := 0 TO (dst.colsb / 4) - 1 +/// tmp.fp32[n] += FP32(a.row[m].fp16[2*k+0]) * FP32(b.row[k].fp16[2*n+0]) +/// tmp.fp32[n] += FP32(-a.row[m].fp16[2*k+1]) * FP32(b.row[k].fp16[2*n+1]) +/// ENDFOR +/// ENDFOR +/// write_row_and_zero(dst, m, tmp, dst.colsb) +/// ENDFOR +/// zero_upper_rows(dst, dst.rows) +/// zero_tileconfig_start() +/// \endcode +/// +/// This intrinsic corresponds to the \c TTCMMIMFP16PS instruction. +/// +/// \param dst +/// The destination tile. Max size is 1024 Bytes. +/// \param a +/// The 1st source tile. Max size is 1024 Bytes. +/// \param b +/// The 2nd source tile. Max size is 1024 Bytes. +#define _tile_tcmmrlfp16ps(dst, a, b) __builtin_ia32_ttcmmrlfp16ps(dst, a, b) + +/// Perform matrix conjugate transpose and multiplication of two tiles +/// containing complex elements and accumulate the results into a packed +/// single precision tile. Each dword element in input tiles \a a and \a b +/// is interpreted as a complex number with FP16 real part and FP16 imaginary +/// part. +/// Calculates the imaginary part of the result. For each possible combination +/// of (transposed column of \a a, column of \a b), it performs a set of +/// multiplication and accumulations on all corresponding complex numbers +/// (one from \a a and one from \a b). The negated imaginary part of the \a a +/// element is multiplied with the real part of the corresponding \a b +/// element, and the real part of the \a a element is multiplied with the +/// imaginary part of the corresponding \a b elements. The two accumulated +/// results are added, and then accumulated into the corresponding row and +/// column of \a dst. +/// +/// \headerfile +/// +/// \code +/// void _tile_conjtcmmimfp16ps(__tile dst, __tile a, __tile b); +/// \endcode +/// +/// \code{.operation} +/// FOR m := 0 TO dst.rows - 1 +/// tmp := dst.row[m] +/// FOR k := 0 TO a.rows - 1 +/// FOR n := 0 TO (dst.colsb / 4) - 1 +/// tmp.fp32[n] += FP32(a.row[m].fp16[2*k+0]) * FP32(b.row[k].fp16[2*n+1]) +/// tmp.fp32[n] += FP32(-a.row[m].fp16[2*k+1]) * FP32(b.row[k].fp16[2*n+0]) +/// ENDFOR +/// ENDFOR +/// write_row_and_zero(dst, m, tmp, dst.colsb) +/// ENDFOR +/// zero_upper_rows(dst, dst.rows) +/// zero_tileconfig_start() +/// \endcode +/// +/// This intrinsic corresponds to the \c TCONJTCMMIMFP16PS instruction. +/// +/// \param dst +/// The destination tile. Max size is 1024 Bytes. +/// \param a +/// The 1st source tile. Max size is 1024 Bytes. +/// \param b +/// The 2nd source tile. Max size is 1024 Bytes. +#define _tile_conjtcmmimfp16ps(dst, a, b) \ + __builtin_ia32_tconjtcmmimfp16ps(dst, a, b) + +/// Perform conjugate transpose of an FP16-pair of complex elements from \a a +/// and writes the result to \a dst. +/// +/// \headerfile +/// +/// \code +/// void _tile_conjtfp16(__tile dst, __tile a); +/// \endcode +/// +/// \code{.operation} +/// FOR i := 0 TO dst.rows - 1 +/// FOR j := 0 TO (dst.colsb / 4) - 1 +/// tmp.fp16[2*j+0] := a.row[j].fp16[2*i+0] +/// tmp.fp16[2*j+1] := -a.row[j].fp16[2*i+1] +/// ENDFOR +/// write_row_and_zero(dst, i, tmp, dst.colsb) +/// ENDFOR +/// zero_upper_rows(dst, dst.rows) +/// zero_tileconfig_start() +/// \endcode +/// +/// This intrinsic corresponds to the \c TCONJTFP16 instruction. +/// +/// \param dst +/// The destination tile. Max size is 1024 Bytes. +/// \param a +/// The source tile. Max size is 1024 Bytes. +#define _tile_conjtfp16(dst, a) __builtin_ia32_tconjtfp16(dst, a) + +static __inline__ _tile1024i __DEFAULT_FN_ATTRS _tile_tcmmimfp16ps_internal( + unsigned short m, unsigned short n, unsigned short k, _tile1024i dst, + _tile1024i src1, _tile1024i src2) { + return __builtin_ia32_ttcmmimfp16ps_internal(m, n, k, dst, src1, src2); +} + +static __inline__ _tile1024i __DEFAULT_FN_ATTRS _tile_tcmmrlfp16ps_internal( + unsigned short m, unsigned short n, unsigned short k, _tile1024i dst, + _tile1024i src1, _tile1024i src2) { + return __builtin_ia32_ttcmmrlfp16ps_internal(m, n, k, dst, src1, src2); +} + +static __inline__ _tile1024i __DEFAULT_FN_ATTRS _tile_conjtcmmimfp16ps_internal( + unsigned short m, unsigned short n, unsigned short k, _tile1024i dst, + _tile1024i src1, _tile1024i src2) { + return __builtin_ia32_tconjtcmmimfp16ps_internal(m, n, k, dst, src1, src2); +} + +static __inline__ _tile1024i __DEFAULT_FN_ATTRS _tile_conjtfp16_internal( + unsigned short m, unsigned short n, _tile1024i src) { + return __builtin_ia32_tconjtfp16_internal(m, n, src); +} + +/// Perform matrix multiplication of two tiles containing complex elements and +/// accumulate the results into a packed single precision tile. Each dword +/// element in input tiles src0 and src1 is interpreted as a complex number +/// with FP16 real part and FP16 imaginary part. +/// This function calculates the imaginary part of the result. +/// +/// \headerfile +/// +/// This intrinsic corresponds to the TTCMMIMFP16PS instruction. +/// +/// \param dst +/// The destination tile. Max size is 1024 Bytes. +/// \param src0 +/// The 1st source tile. Max size is 1024 Bytes. +/// \param src1 +/// The 2nd source tile. Max size is 1024 Bytes. +__DEFAULT_FN_ATTRS +static void __tile_tcmmimfp16ps(__tile1024i *dst, __tile1024i src0, + __tile1024i src1) { + dst->tile = _tile_tcmmimfp16ps_internal(src0.row, src1.col, src0.col, + dst->tile, src0.tile, src1.tile); +} + +/// Perform matrix multiplication of two tiles containing complex elements and +/// accumulate the results into a packed single precision tile. Each dword +/// element in input tiles src0 and src1 is interpreted as a complex number +/// with FP16 real part and FP16 imaginary part. +/// This function calculates the real part of the result. +/// +/// \headerfile +/// +/// This intrinsic corresponds to the TTCMMRLFP16PS instruction. +/// +/// \param dst +/// The destination tile. Max size is 1024 Bytes. +/// \param src0 +/// The 1st source tile. Max size is 1024 Bytes. +/// \param src1 +/// The 2nd source tile. Max size is 1024 Bytes. +__DEFAULT_FN_ATTRS +static void __tile_tcmmrlfp16ps(__tile1024i *dst, __tile1024i src0, + __tile1024i src1) { + dst->tile = _tile_tcmmrlfp16ps_internal(src0.row, src1.col, src0.col, + dst->tile, src0.tile, src1.tile); +} + +/// Perform matrix conjugate transpose and multiplication of two tiles +/// containing complex elements and accumulate the results into a packed +/// single precision tile. Each dword element in input tiles src0 and src1 +/// is interpreted as a complex number with FP16 real part and FP16 imaginary +/// part. +/// This function calculates the imaginary part of the result. +/// +/// \headerfile +/// +/// This intrinsic corresponds to the TCONJTCMMIMFP16PS instruction. +/// +/// \param dst +/// The destination tile. Max size is 1024 Bytes. +/// \param src0 +/// The 1st source tile. Max size is 1024 Bytes. +/// \param src1 +/// The 2nd source tile. Max size is 1024 Bytes. +__DEFAULT_FN_ATTRS +static void __tile_conjtcmmimfp16ps(__tile1024i *dst, __tile1024i src0, + __tile1024i src1) { + dst->tile = _tile_conjtcmmimfp16ps_internal(src0.row, src1.col, src0.col, + dst->tile, src0.tile, src1.tile); +} + +/// Perform conjugate transpose of an FP16-pair of complex elements from src and +/// writes the result to dst. +/// +/// \headerfile +/// +/// This intrinsic corresponds to the TCONJTFP16 instruction. +/// +/// \param dst +/// The destination tile. Max size is 1024 Bytes. +/// \param src +/// The source tile. Max size is 1024 Bytes. +__DEFAULT_FN_ATTRS +static void __tile_conjtfp16(__tile1024i *dst, __tile1024i src) { + dst->tile = _tile_conjtfp16_internal(src.row, src.col, src.tile); +} + +#undef __DEFAULT_FN_ATTRS + +#endif // __x86_64__ +#endif // __AMX_COMPLEXTRANSPOSEINTRIN_H diff --git a/clang/lib/Headers/amxfp16intrin.h b/clang/lib/Headers/amxfp16intrin.h index ed798245d41efb..bb4bc31fdafd50 100644 --- a/clang/lib/Headers/amxfp16intrin.h +++ b/clang/lib/Headers/amxfp16intrin.h @@ -15,6 +15,10 @@ #define __AMX_FP16INTRIN_H #ifdef __x86_64__ +/* Define the default attributes for the functions in this file. */ +#define __DEFAULT_FN_ATTRS \ + __attribute__((__always_inline__, __nodebug__, __target__("amx-fp16"))) + /// Compute dot-product of FP16 (16-bit) floating-point pairs in tiles \a a /// and \a b, accumulating the intermediate single-precision (32-bit) /// floating-point elements with elements in \a dst, and store the 32-bit @@ -54,5 +58,36 @@ #define _tile_dpfp16ps(dst, a, b) \ __builtin_ia32_tdpfp16ps(dst, a, b) +/// This is internal intrinsic. C/C++ user should avoid calling it directly. +static __inline__ _tile1024i __DEFAULT_FN_ATTRS +_tile_dpfp16ps_internal(unsigned short m, unsigned short n, unsigned short k, + _tile1024i dst, _tile1024i src1, _tile1024i src2) { + return __builtin_ia32_tdpfp16ps_internal(m, n, k, dst, src1, src2); +} + +/// Compute dot-product of FP16 (16-bit) floating-point pairs in tiles src0 and +/// src1, accumulating the intermediate single-precision (32-bit) floating-point +/// elements with elements in "dst", and store the 32-bit result back to tile +/// "dst". +/// +/// \headerfile +/// +/// This intrinsic corresponds to the TDPFP16PS instruction. +/// +/// \param dst +/// The destination tile. Max size is 1024 Bytes. +/// \param src0 +/// The 1st source tile. Max size is 1024 Bytes. +/// \param src1 +/// The 2nd source tile. Max size is 1024 Bytes. +__DEFAULT_FN_ATTRS +static __inline__ void __tile_dpfp16ps(__tile1024i *dst, __tile1024i src0, + __tile1024i src1) { + dst->tile = _tile_dpfp16ps_internal(src0.row, src1.col, src0.col, dst->tile, + src0.tile, src1.tile); +} + +#undef __DEFAULT_FN_ATTRS + #endif /* __x86_64__ */ #endif /* __AMX_FP16INTRIN_H */ diff --git a/clang/lib/Headers/amxfp16transposeintrin.h b/clang/lib/Headers/amxfp16transposeintrin.h new file mode 100644 index 00000000000000..c07c5516301983 --- /dev/null +++ b/clang/lib/Headers/amxfp16transposeintrin.h @@ -0,0 +1,94 @@ +/*===----- amxfp16transposeintrin.h - AMX-FP16 and AMX-TRANSPOSE ------------=== + * + * Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. + * See https://llvm.org/LICENSE.txt for license information. + * SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + * + *===------------------------------------------------------------------------=== + */ + +#ifndef __IMMINTRIN_H +#error \ + "Never use directly; use instead." +#endif /* __IMMINTRIN_H */ + +#ifndef __AMX_FP16TRANSPOSEINTRIN_H +#define __AMX_FP16TRANSPOSEINTRIN_H +#ifdef __x86_64__ + +/* Define the default attributes for the functions in this file. */ +#define __DEFAULT_FN_ATTRS \ + __attribute__((__always_inline__, __nodebug__, \ + __target__("amx-fp16,amx-transpose"))) + +/// Compute transpose and dot-product of FP16 (16-bit) floating-point pairs in +/// tiles \a a and \a b, accumulating the intermediate single-precision +/// (32-bit) floating-point elements with elements in \a dst, and store the +/// 32-bit result back to tile \a dst. +/// +/// \headerfile +/// +/// \code +/// void _tile_tdpfp16ps (__tile dst, __tile a, __tile b) +/// \endcode +/// +/// \code{.operation} +/// FOR m := 0 TO dst.rows - 1 +/// tmp := dst.row[m] +/// FOR k := 0 TO (a.colsb / 4) - 1 +/// FOR n := 0 TO (dst.colsb / 4) - 1 +/// tmp.fp32[n] += FP32(a.row[m].fp16[2*k+0]) * +/// FP32(b.row[k].fp16[2*n+0]) +/// tmp.fp32[n] += FP32(a.row[m].fp16[2*k+1]) * +/// FP32(b.row[k].fp16[2*n+1]) +/// ENDFOR +/// ENDFOR +/// write_row_and_zero(dst, m, tmp, dst.colsb) +/// ENDFOR +/// zero_upper_rows(dst, dst.rows) +/// zero_tileconfig_start() +/// \endcode +/// +/// This intrinsic corresponds to the \c TTDPFP16PS instruction. +/// +/// \param dst +/// The destination tile. Max size is 1024 Bytes. +/// \param a +/// The 1st source tile. Max size is 1024 Bytes. +/// \param b +/// The 2nd source tile. Max size is 1024 Bytes. +#define _tile_tdpfp16ps(dst, a, b) __builtin_ia32_ttdpfp16ps(dst, a, b) + +/// This is internal intrinsic. C/C++ user should avoid calling it directly. +static __inline__ _tile1024i __DEFAULT_FN_ATTRS +_tile_tdpfp16ps_internal(unsigned short m, unsigned short n, unsigned short k, + _tile1024i dst, _tile1024i src1, _tile1024i src2) { + return __builtin_ia32_ttdpfp16ps_internal(m, n, k, dst, src1, src2); +} + +/// Compute transpose and dot-product of FP16 (16-bit) floating-point pairs in +/// tiles src0 and src1, accumulating the intermediate single-precision +/// (32-bit) floating-point elements with elements in "dst", and store the +/// 32-bit result back to tile "dst". +/// +/// \headerfile +/// +/// This intrinsic corresponds to the TTDPFP16PS instruction. +/// +/// \param dst +/// The destination tile. Max size is 1024 Bytes. +/// \param src0 +/// The 1st source tile. Max size is 1024 Bytes. +/// \param src1 +/// The 2nd source tile. Max size is 1024 Bytes. +__DEFAULT_FN_ATTRS +static __inline__ void __tile_tdpfp16ps(__tile1024i *dst, __tile1024i src0, + __tile1024i src1) { + dst->tile = _tile_tdpfp16ps_internal(src0.row, src1.col, src0.col, dst->tile, + src0.tile, src1.tile); +} + +#undef __DEFAULT_FN_ATTRS + +#endif /* __x86_64__ */ +#endif /* __AMX_FP16TRANSPOSEINTRIN_H */ diff --git a/clang/lib/Headers/amxintrin.h b/clang/lib/Headers/amxintrin.h index f07a5689011853..b0140615677f27 100644 --- a/clang/lib/Headers/amxintrin.h +++ b/clang/lib/Headers/amxintrin.h @@ -22,8 +22,6 @@ __attribute__((__always_inline__, __nodebug__, __target__("amx-int8"))) #define __DEFAULT_FN_ATTRS_BF16 \ __attribute__((__always_inline__, __nodebug__, __target__("amx-bf16"))) -#define __DEFAULT_FN_ATTRS_FP16 \ - __attribute__((__always_inline__, __nodebug__, __target__("amx-fp16"))) /// Load tile configuration from a 64-byte memory location specified by /// "mem_addr". The tile configuration includes the tile type palette, the @@ -294,13 +292,6 @@ _tile_dpbf16ps_internal(unsigned short m, unsigned short n, unsigned short k, return __builtin_ia32_tdpbf16ps_internal(m, n, k, dst, src1, src2); } -/// This is internal intrinsic. C/C++ user should avoid calling it directly. -static __inline__ _tile1024i __DEFAULT_FN_ATTRS_FP16 -_tile_dpfp16ps_internal(unsigned short m, unsigned short n, unsigned short k, - _tile1024i dst, _tile1024i src1, _tile1024i src2) { - return __builtin_ia32_tdpfp16ps_internal(m, n, k, dst, src1, src2); -} - /// This struct pack the shape and tile data together for user. We suggest /// initializing the struct as early as possible, because compiler depends /// on the shape information to do configure. The constant value is preferred @@ -495,32 +486,9 @@ static __inline__ void __tile_dpbf16ps(__tile1024i *dst, __tile1024i src0, src0.tile, src1.tile); } -/// Compute dot-product of FP16 (16-bit) floating-point pairs in tiles src0 and -/// src1, accumulating the intermediate single-precision (32-bit) floating-point -/// elements with elements in "dst", and store the 32-bit result back to tile -/// "dst". -/// -/// \headerfile -/// -/// This intrinsic corresponds to the TDPFP16PS instruction. -/// -/// \param dst -/// The destination tile. Max size is 1024 Bytes. -/// \param src0 -/// The 1st source tile. Max size is 1024 Bytes. -/// \param src1 -/// The 2nd source tile. Max size is 1024 Bytes. -__DEFAULT_FN_ATTRS_FP16 -static __inline__ void __tile_dpfp16ps(__tile1024i *dst, __tile1024i src0, - __tile1024i src1) { - dst->tile = _tile_dpfp16ps_internal(src0.row, src1.col, src0.col, dst->tile, - src0.tile, src1.tile); -} - #undef __DEFAULT_FN_ATTRS_TILE #undef __DEFAULT_FN_ATTRS_INT8 #undef __DEFAULT_FN_ATTRS_BF16 -#undef __DEFAULT_FN_ATTRS_FP16 #endif /* __x86_64__ */ #endif /* __AMXINTRIN_H */ diff --git a/clang/lib/Headers/immintrin.h b/clang/lib/Headers/immintrin.h index bc240e28d59142..a6945406dc6610 100644 --- a/clang/lib/Headers/immintrin.h +++ b/clang/lib/Headers/immintrin.h @@ -630,9 +630,6 @@ _storebe_i64(void * __P, long long __D) { #if !defined(__SCE__) || __has_feature(modules) || defined(__INVPCID__) #include #endif -#if !defined(__SCE__) || __has_feature(modules) || defined(__AMX_FP16__) -#include -#endif #if !defined(__SCE__) || __has_feature(modules) || defined(__KL__) || \ defined(__WIDEKL__) @@ -644,6 +641,10 @@ _storebe_i64(void * __P, long long __D) { #include #endif +#if !defined(__SCE__) || __has_feature(modules) || defined(__AMX_FP16__) +#include +#endif + #if !defined(__SCE__) || __has_feature(modules) || defined(__AMX_COMPLEX__) #include #endif @@ -660,6 +661,21 @@ _storebe_i64(void * __P, long long __D) { #include #endif +#if !defined(__SCE__) || __has_feature(modules) || \ + (defined(__AMX_BF16__) && defined(__AMX_TRANSPOSE__)) +#include +#endif + +#if !defined(__SCE__) || __has_feature(modules) || \ + (defined(__AMX_FP16__) && defined(__AMX_TRANSPOSE__)) +#include +#endif + +#if !defined(__SCE__) || __has_feature(modules) || \ + (defined(__AMX_COMPLEX__) && defined(__AMX_TRANSPOSE__)) +#include +#endif + #if !defined(__SCE__) || __has_feature(modules) || \ defined(__AVX512VP2INTERSECT__) #include diff --git a/clang/lib/Sema/SemaX86.cpp b/clang/lib/Sema/SemaX86.cpp index 1155a5edc73c34..4290f48d497290 100644 --- a/clang/lib/Sema/SemaX86.cpp +++ b/clang/lib/Sema/SemaX86.cpp @@ -654,8 +654,14 @@ bool SemaX86::CheckBuiltinTileArguments(unsigned BuiltinID, CallExpr *TheCall) { case X86::BI__builtin_ia32_tdpbhf8ps: case X86::BI__builtin_ia32_tdphbf8ps: case X86::BI__builtin_ia32_tdphf8ps: + case X86::BI__builtin_ia32_ttdpbf16ps: + case X86::BI__builtin_ia32_ttdpfp16ps: + case X86::BI__builtin_ia32_ttcmmimfp16ps: + case X86::BI__builtin_ia32_ttcmmrlfp16ps: + case X86::BI__builtin_ia32_tconjtcmmimfp16ps: return CheckBuiltinTileRangeAndDuplicate(TheCall, {0, 1, 2}); case X86::BI__builtin_ia32_ttransposed: + case X86::BI__builtin_ia32_tconjtfp16: return CheckBuiltinTileArgumentsRange(TheCall, {0, 1}); } } diff --git a/clang/test/CodeGen/X86/amx_transpose.c b/clang/test/CodeGen/X86/amx_transpose.c index deefc592c7ae66..7e88fd80592d62 100644 --- a/clang/test/CodeGen/X86/amx_transpose.c +++ b/clang/test/CodeGen/X86/amx_transpose.c @@ -1,4 +1,5 @@ // RUN: %clang_cc1 %s -ffreestanding -triple=x86_64-unknown-unknown -target-feature +amx-transpose \ +// RUN: -target-feature +amx-bf16 -target-feature +amx-fp16 -target-feature +amx-complex \ // RUN: -target-feature +avx512f -emit-llvm -o - -Wall -Werror -pedantic -Wno-gnu-statement-expression| FileCheck %s #include @@ -34,3 +35,41 @@ void test_tile_transposed(void) // CHECK: call void @llvm.x86.ttransposed(i8 1, i8 2) _tile_transposed(1, 2); } + +void test_tile_tdpbf16ps(void) +{ + // CHECK-LABEL: @test_tile_tdpbf16ps + // CHECK: call void @llvm.x86.ttdpbf16ps(i8 1, i8 2, i8 3) + _tile_tdpbf16ps(1, 2, 3); +} + +void test_tile_tdpfp16ps(void) +{ + // CHECK-LABEL: @test_tile_tdpfp16ps + // CHECK: call void @llvm.x86.ttdpfp16ps(i8 4, i8 5, i8 6) + _tile_tdpfp16ps(4, 5, 6); +} + +void test_tile_tcmmimfp16ps(void) { + // CHECK-LABEL: @test_tile_tcmmimfp16ps + // CHECK: call void @llvm.x86.ttcmmimfp16ps(i8 1, i8 2, i8 3) + _tile_tcmmimfp16ps(1, 2, 3); +} + +void test_tile_tcmmrlfp16ps(void) { + // CHECK-LABEL: @test_tile_tcmmrlfp16ps + // CHECK: call void @llvm.x86.ttcmmrlfp16ps(i8 1, i8 2, i8 3) + _tile_tcmmrlfp16ps(1, 2, 3); +} + +void test_tile_conjtcmmimfp16ps(void) { + // CHECK-LABEL: @test_tile_conjtcmmimfp16ps + // CHECK: call void @llvm.x86.tconjtcmmimfp16ps(i8 1, i8 2, i8 3) + _tile_conjtcmmimfp16ps(1, 2, 3); +} + +void test_tile_conjtfp16(void) { + // CHECK-LABEL: @test_tile_conjtfp16 + // CHECK: call void @llvm.x86.tconjtfp16(i8 1, i8 2) + _tile_conjtfp16(1, 2); +} diff --git a/clang/test/CodeGen/X86/amx_transpose_api.c b/clang/test/CodeGen/X86/amx_transpose_api.c index 10310c2332b7a6..dc3ef5104252ca 100644 --- a/clang/test/CodeGen/X86/amx_transpose_api.c +++ b/clang/test/CodeGen/X86/amx_transpose_api.c @@ -1,5 +1,5 @@ // RUN: %clang_cc1 %s -flax-vector-conversions=none -ffreestanding -triple=x86_64-unknown-unknown -target-feature +avx512f \ -// RUN: -target-feature +amx-transpose \ +// RUN: -target-feature +amx-transpose -target-feature +amx-bf16 -target-feature +amx-fp16 -target-feature +amx-complex \ // RUN: -emit-llvm -o - -Werror -pedantic | FileCheck %s --check-prefixes=CHECK #include @@ -64,3 +64,51 @@ void test_tile_transposed(__tile1024i dst, __tile1024i src) { //CHECK-DAG: call <256 x i32> @llvm.x86.cast.tile.to.vector.v256i32(x86_amx {{%.*}}) __tile_transposed(&dst, src); } + +void test_tile_tdpbf16ps(__tile1024i a, __tile1024i b, __tile1024i c) { + //CHECK-LABEL: @test_tile_tdpbf16ps + //CHECK-DAG: call x86_amx @llvm.x86.cast.vector.to.tile.v256i32(<256 x i32> {{%.*}}) + //CHECK-DAG: call x86_amx @llvm.x86.ttdpbf16ps.internal + //CHECK-DAG: call <256 x i32> @llvm.x86.cast.tile.to.vector.v256i32(x86_amx {{%.*}}) + __tile_tdpbf16ps(&c, a, b); +} + +void test_tile_tdpfp16ps(__tile1024i a, __tile1024i b, __tile1024i c) { + //CHECK-LABEL: @test_tile_tdpfp16ps + //CHECK-DAG: call x86_amx @llvm.x86.cast.vector.to.tile.v256i32(<256 x i32> {{%.*}}) + //CHECK-DAG: call x86_amx @llvm.x86.ttdpfp16ps.internal + //CHECK-DAG: call <256 x i32> @llvm.x86.cast.tile.to.vector.v256i32(x86_amx {{%.*}}) + __tile_tdpfp16ps(&c, a, b); +} + +void test_tile_tcmmimfp16ps(__tile1024i a, __tile1024i b, __tile1024i c) { + //CHECK-LABEL: @test_tile_tcmmimfp16ps + //CHECK-DAG: call x86_amx @llvm.x86.cast.vector.to.tile.v256i32(<256 x i32> {{%.*}}) + //CHECK-DAG: call x86_amx @llvm.x86.ttcmmimfp16ps.internal + //CHECK-DAG: call <256 x i32> @llvm.x86.cast.tile.to.vector.v256i32(x86_amx {{%.*}}) + __tile_tcmmimfp16ps(&c, a, b); +} + +void test_tile_tcmmrlfp16ps(__tile1024i a, __tile1024i b, __tile1024i c) { + //CHECK-LABEL: @test_tile_tcmmrlfp16ps + //CHECK-DAG: call x86_amx @llvm.x86.cast.vector.to.tile.v256i32(<256 x i32> {{%.*}}) + //CHECK-DAG: call x86_amx @llvm.x86.ttcmmrlfp16ps.internal + //CHECK-DAG: call <256 x i32> @llvm.x86.cast.tile.to.vector.v256i32(x86_amx {{%.*}}) + __tile_tcmmrlfp16ps(&c, a, b); +} + +void test_tile_conjtcmmimfp16ps(__tile1024i a, __tile1024i b, __tile1024i c) { + //CHECK-LABEL: @test_tile_conjtcmmimfp16ps + //CHECK-DAG: call x86_amx @llvm.x86.cast.vector.to.tile.v256i32(<256 x i32> {{%.*}}) + //CHECK-DAG: call x86_amx @llvm.x86.tconjtcmmimfp16ps.internal + //CHECK-DAG: call <256 x i32> @llvm.x86.cast.tile.to.vector.v256i32(x86_amx {{%.*}}) + __tile_conjtcmmimfp16ps(&c, a, b); +} + +void test_tile_conjtfp16(__tile1024i dst, __tile1024i src) { + //CHECK-LABEL: @test_tile_conjtfp16 + //CHECK-DAG: call x86_amx @llvm.x86.cast.vector.to.tile.v256i32(<256 x i32> {{%.*}}) + //CHECK-DAG: call x86_amx @llvm.x86.tconjtfp16.internal + //CHECK-DAG: call <256 x i32> @llvm.x86.cast.tile.to.vector.v256i32(x86_amx {{%.*}}) + __tile_conjtfp16(&dst, src); +} diff --git a/clang/test/CodeGen/X86/amx_transpose_errors.c b/clang/test/CodeGen/X86/amx_transpose_errors.c index 80084c42a240dd..80368c580c793a 100644 --- a/clang/test/CodeGen/X86/amx_transpose_errors.c +++ b/clang/test/CodeGen/X86/amx_transpose_errors.c @@ -1,9 +1,7 @@ // RUN: %clang_cc1 %s -ffreestanding -triple=x86_64-unknown-unknown \ // RUN: -target-feature +amx-int8 -target-feature +amx-bf16 -target-feature +amx-transpose \ -// RUN: -target-feature +avx512f -target-feature +amx-element-evex -verify +// RUN: -target-feature +avx512f -target-feature +amx-fp16 -target-feature +amx-complex -verify -#include -#include #include #include @@ -24,8 +22,54 @@ void test_tile_2rpntlvwz1t1(const void *A, size_t B) { _tile_2rpntlvwz1t1(8, A, B); // expected-error {{argument value 8 is outside the valid range [0, 7]}} } +void test_tile_tdpbf16ps() +{ + _tile_tdpbf16ps(8, 2, 3); // expected-error {{argument value 8 is outside the valid range [0, 7]}} + _tile_tdpbf16ps(1, 8, 3); // expected-error {{argument value 8 is outside the valid range [0, 7]}} + _tile_tdpbf16ps(1, 2, 8); // expected-error {{argument value 8 is outside the valid range [0, 7]}} + _tile_tdpbf16ps(1, 1, 3); // expected-error {{tile arguments must refer to different tiles}} + _tile_tdpbf16ps(1, 2, 1); // expected-error {{tile arguments must refer to different tiles}} + _tile_tdpbf16ps(1, 2, 2); // expected-error {{tile arguments must refer to different tiles}} +} + +void test_tile_tdpfp16ps() +{ + _tile_tdpfp16ps(8, 5, 6); // expected-error {{argument value 8 is outside the valid range [0, 7]}} + _tile_tdpfp16ps(1, 8, 6); // expected-error {{argument value 8 is outside the valid range [0, 7]}} + _tile_tdpfp16ps(1, 5, 8); // expected-error {{argument value 8 is outside the valid range [0, 7]}} + _tile_tdpfp16ps(1, 1, 3); // expected-error {{tile arguments must refer to different tiles}} + _tile_tdpfp16ps(1, 2, 1); // expected-error {{tile arguments must refer to different tiles}} + _tile_tdpfp16ps(1, 2, 2); // expected-error {{tile arguments must refer to different tiles}} +} + void test_tile_transposed() { _tile_transposed(8, 2); // expected-error {{argument value 8 is outside the valid range [0, 7]}} _tile_transposed(1, 8); // expected-error {{argument value 8 is outside the valid range [0, 7]}} } + +void test_tile_tcmmimfp16ps() { + _tile_tcmmimfp16ps(16, 2, 3); // expected-error {{argument value 16 is outside the valid range [0, 7]}} + _tile_tcmmimfp16ps(1, 26, 3); // expected-error {{argument value 26 is outside the valid range [0, 7]}} + _tile_tcmmimfp16ps(1, 2, 36); // expected-error {{argument value 36 is outside the valid range [0, 7]}} + _tile_tcmmimfp16ps(1, 1, 3); // expected-error {{tile arguments must refer to different tiles}} +} + +void test_tile_tcmmrlfp16ps() { + _tile_tcmmrlfp16ps(16, 2, 3); // expected-error {{argument value 16 is outside the valid range [0, 7]}} + _tile_tcmmrlfp16ps(1, 26, 3); // expected-error {{argument value 26 is outside the valid range [0, 7]}} + _tile_tcmmrlfp16ps(1, 2, 36); // expected-error {{argument value 36 is outside the valid range [0, 7]}} + _tile_tcmmrlfp16ps(1, 1, 3); // expected-error {{tile arguments must refer to different tiles}} +} + +void test_tile_conjtcmmimfp16ps() { + _tile_conjtcmmimfp16ps(16, 2, 3); // expected-error {{argument value 16 is outside the valid range [0, 7]}} + _tile_conjtcmmimfp16ps(1, 26, 3); // expected-error {{argument value 26 is outside the valid range [0, 7]}} + _tile_conjtcmmimfp16ps(1, 2, 36); // expected-error {{argument value 36 is outside the valid range [0, 7]}} + _tile_conjtcmmimfp16ps(1, 2, 1); // expected-error {{tile arguments must refer to different tiles}} +} + +void test_tile_conjtfp16() { + _tile_conjtfp16(16, 2); // expected-error {{argument value 16 is outside the valid range [0, 7]}} + _tile_conjtfp16(1, 26); // expected-error {{argument value 26 is outside the valid range [0, 7]}} +} diff --git a/llvm/include/llvm/IR/IntrinsicsX86.td b/llvm/include/llvm/IR/IntrinsicsX86.td index 3003f9887e239c..21d4ece9d92f63 100644 --- a/llvm/include/llvm/IR/IntrinsicsX86.td +++ b/llvm/include/llvm/IR/IntrinsicsX86.td @@ -5951,6 +5951,29 @@ let TargetPrefix = "x86" in { def int_x86_ttransposed : ClangBuiltin<"__builtin_ia32_ttransposed">, Intrinsic<[], [llvm_i8_ty, llvm_i8_ty], [ImmArg>, ImmArg>]>; + def int_x86_ttdpbf16ps : ClangBuiltin<"__builtin_ia32_ttdpbf16ps">, + Intrinsic<[], [llvm_i8_ty, llvm_i8_ty, llvm_i8_ty], + [ImmArg>, ImmArg>, + ImmArg>]>; + def int_x86_ttdpfp16ps : ClangBuiltin<"__builtin_ia32_ttdpfp16ps">, + Intrinsic<[], [llvm_i8_ty, llvm_i8_ty, llvm_i8_ty], + [ImmArg>, ImmArg>, + ImmArg>]>; + def int_x86_ttcmmimfp16ps : ClangBuiltin<"__builtin_ia32_ttcmmimfp16ps">, + Intrinsic<[], [llvm_i8_ty, llvm_i8_ty, llvm_i8_ty], + [ImmArg>, ImmArg>, + ImmArg>]>; + def int_x86_ttcmmrlfp16ps : ClangBuiltin<"__builtin_ia32_ttcmmrlfp16ps">, + Intrinsic<[], [llvm_i8_ty, llvm_i8_ty, llvm_i8_ty], + [ImmArg>, ImmArg>, + ImmArg>]>; + def int_x86_tconjtcmmimfp16ps : ClangBuiltin<"__builtin_ia32_tconjtcmmimfp16ps">, + Intrinsic<[], [llvm_i8_ty, llvm_i8_ty, llvm_i8_ty], + [ImmArg>, ImmArg>, + ImmArg>]>; + def int_x86_tconjtfp16 : ClangBuiltin<"__builtin_ia32_tconjtfp16">, + Intrinsic<[], [llvm_i8_ty, llvm_i8_ty], + [ImmArg>, ImmArg>]>; // AMX-AVX512 def int_x86_tcvtrowd2ps : ClangBuiltin<"__builtin_ia32_tcvtrowd2ps">, @@ -6070,6 +6093,40 @@ let TargetPrefix = "x86" in { ClangBuiltin<"__builtin_ia32_ttransposed_internal">, Intrinsic<[llvm_x86amx_ty], [llvm_i16_ty, llvm_i16_ty, llvm_x86amx_ty], []>; + def int_x86_ttdpbf16ps_internal : + ClangBuiltin<"__builtin_ia32_ttdpbf16ps_internal">, + Intrinsic<[llvm_x86amx_ty], + [llvm_i16_ty, llvm_i16_ty, llvm_i16_ty, + llvm_x86amx_ty, llvm_x86amx_ty, + llvm_x86amx_ty], []>; + def int_x86_ttdpfp16ps_internal : + ClangBuiltin<"__builtin_ia32_ttdpfp16ps_internal">, + Intrinsic<[llvm_x86amx_ty], + [llvm_i16_ty, llvm_i16_ty, llvm_i16_ty, + llvm_x86amx_ty, llvm_x86amx_ty, + llvm_x86amx_ty], []>; + def int_x86_ttcmmimfp16ps_internal : + ClangBuiltin<"__builtin_ia32_ttcmmimfp16ps_internal">, + Intrinsic<[llvm_x86amx_ty], + [llvm_i16_ty, llvm_i16_ty, llvm_i16_ty, + llvm_x86amx_ty, llvm_x86amx_ty, + llvm_x86amx_ty], []>; + def int_x86_ttcmmrlfp16ps_internal : + ClangBuiltin<"__builtin_ia32_ttcmmrlfp16ps_internal">, + Intrinsic<[llvm_x86amx_ty], + [llvm_i16_ty, llvm_i16_ty, llvm_i16_ty, + llvm_x86amx_ty, llvm_x86amx_ty, + llvm_x86amx_ty], []>; + def int_x86_tconjtcmmimfp16ps_internal : + ClangBuiltin<"__builtin_ia32_tconjtcmmimfp16ps_internal">, + Intrinsic<[llvm_x86amx_ty], + [llvm_i16_ty, llvm_i16_ty, llvm_i16_ty, + llvm_x86amx_ty, llvm_x86amx_ty, + llvm_x86amx_ty], []>; + def int_x86_tconjtfp16_internal : + ClangBuiltin<"__builtin_ia32_tconjtfp16_internal">, + Intrinsic<[llvm_x86amx_ty], + [llvm_i16_ty, llvm_i16_ty, llvm_x86amx_ty], []>; def int_x86_tcvtrowd2ps_internal : ClangBuiltin<"__builtin_ia32_tcvtrowd2ps_internal">, diff --git a/llvm/lib/Target/X86/X86ExpandPseudo.cpp b/llvm/lib/Target/X86/X86ExpandPseudo.cpp index a6096e5032e89c..f68291f28afeb4 100644 --- a/llvm/lib/Target/X86/X86ExpandPseudo.cpp +++ b/llvm/lib/Target/X86/X86ExpandPseudo.cpp @@ -742,10 +742,12 @@ bool X86ExpandPseudo::expandMI(MachineBasicBlock &MBB, MI.setDesc(TII->get(Opc)); return true; } - case X86::PTTRANSPOSEDV: { + case X86::PTTRANSPOSEDV: + case X86::PTCONJTFP16V: { for (int i = 2; i > 0; --i) MI.removeOperand(i); - MI.setDesc(TII->get(X86::TTRANSPOSED)); + MI.setDesc(TII->get(Opcode == X86::PTTRANSPOSEDV ? X86::TTRANSPOSED + : X86::TCONJTFP16)); return true; } case X86::PTCMMIMFP16PSV: @@ -755,7 +757,12 @@ bool X86ExpandPseudo::expandMI(MachineBasicBlock &MBB, case X86::PTDPBUSDV: case X86::PTDPBUUDV: case X86::PTDPBF16PSV: - case X86::PTDPFP16PSV: { + case X86::PTDPFP16PSV: + case X86::PTTDPBF16PSV: + case X86::PTTDPFP16PSV: + case X86::PTTCMMIMFP16PSV: + case X86::PTTCMMRLFP16PSV: + case X86::PTCONJTCMMIMFP16PSV: { MI.untieRegOperand(4); for (unsigned i = 3; i > 0; --i) MI.removeOperand(i); @@ -769,6 +776,21 @@ bool X86ExpandPseudo::expandMI(MachineBasicBlock &MBB, case X86::PTDPBUUDV: Opc = X86::TDPBUUD; break; case X86::PTDPBF16PSV: Opc = X86::TDPBF16PS; break; case X86::PTDPFP16PSV: Opc = X86::TDPFP16PS; break; + case X86::PTTDPBF16PSV: + Opc = X86::TTDPBF16PS; + break; + case X86::PTTDPFP16PSV: + Opc = X86::TTDPFP16PS; + break; + case X86::PTTCMMIMFP16PSV: + Opc = X86::TTCMMIMFP16PS; + break; + case X86::PTTCMMRLFP16PSV: + Opc = X86::TTCMMRLFP16PS; + break; + case X86::PTCONJTCMMIMFP16PSV: + Opc = X86::TCONJTCMMIMFP16PS; + break; default: llvm_unreachable("Unexpected Opcode"); } diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp index 3888d207206ec8..a0faf4769dfc31 100644 --- a/llvm/lib/Target/X86/X86ISelLowering.cpp +++ b/llvm/lib/Target/X86/X86ISelLowering.cpp @@ -37465,13 +37465,19 @@ X86TargetLowering::EmitInstrWithCustomInserter(MachineInstr &MI, case X86::PTDPBUUD: case X86::PTDPBF16PS: case X86::PTDPFP16PS: + case X86::PTCMMIMFP16PS: + case X86::PTCMMRLFP16PS: case X86::PTDPBF8PS: case X86::PTDPBHF8PS: case X86::PTDPHBF8PS: - case X86::PTDPHF8PS: { + case X86::PTDPHF8PS: + case X86::PTTDPBF16PS: + case X86::PTTDPFP16PS: + case X86::PTTCMMIMFP16PS: + case X86::PTTCMMRLFP16PS: + case X86::PTCONJTCMMIMFP16PS: { unsigned Opc; switch (MI.getOpcode()) { - // clang-format off default: llvm_unreachable("illegal opcode!"); case X86::PTDPBSSD: Opc = X86::TDPBSSD; break; case X86::PTDPBSUD: Opc = X86::TDPBSUD; break; @@ -37479,11 +37485,31 @@ X86TargetLowering::EmitInstrWithCustomInserter(MachineInstr &MI, case X86::PTDPBUUD: Opc = X86::TDPBUUD; break; case X86::PTDPBF16PS: Opc = X86::TDPBF16PS; break; case X86::PTDPFP16PS: Opc = X86::TDPFP16PS; break; + case X86::PTCMMIMFP16PS: + Opc = X86::TCMMIMFP16PS; + break; + case X86::PTCMMRLFP16PS: + Opc = X86::TCMMRLFP16PS; + break; case X86::PTDPBF8PS: Opc = X86::TDPBF8PS; break; case X86::PTDPBHF8PS: Opc = X86::TDPBHF8PS; break; case X86::PTDPHBF8PS: Opc = X86::TDPHBF8PS; break; case X86::PTDPHF8PS: Opc = X86::TDPHF8PS; break; - // clang-format on + case X86::PTTDPBF16PS: + Opc = X86::TTDPBF16PS; + break; + case X86::PTTDPFP16PS: + Opc = X86::TTDPFP16PS; + break; + case X86::PTTCMMIMFP16PS: + Opc = X86::TTCMMIMFP16PS; + break; + case X86::PTTCMMRLFP16PS: + Opc = X86::TTCMMRLFP16PS; + break; + case X86::PTCONJTCMMIMFP16PS: + Opc = X86::TCONJTCMMIMFP16PS; + break; } MachineInstrBuilder MIB = BuildMI(*BB, MI, MIMD, TII->get(Opc)); @@ -37546,25 +37572,6 @@ X86TargetLowering::EmitInstrWithCustomInserter(MachineInstr &MI, MI.eraseFromParent(); // The pseudo is gone now. return BB; } - case X86::PTCMMIMFP16PS: - case X86::PTCMMRLFP16PS: { - const MIMetadata MIMD(MI); - unsigned Opc; - switch (MI.getOpcode()) { - // clang-format off - default: llvm_unreachable("Unexpected instruction!"); - case X86::PTCMMIMFP16PS: Opc = X86::TCMMIMFP16PS; break; - case X86::PTCMMRLFP16PS: Opc = X86::TCMMRLFP16PS; break; - // clang-format on - } - MachineInstrBuilder MIB = BuildMI(*BB, MI, MIMD, TII->get(Opc)); - MIB.addReg(TMMImmToTMMReg(MI.getOperand(0).getImm()), RegState::Define); - MIB.addReg(TMMImmToTMMReg(MI.getOperand(0).getImm()), RegState::Undef); - MIB.addReg(TMMImmToTMMReg(MI.getOperand(1).getImm()), RegState::Undef); - MIB.addReg(TMMImmToTMMReg(MI.getOperand(2).getImm()), RegState::Undef); - MI.eraseFromParent(); // The pseudo is gone now. - return BB; - } case X86::PT2RPNTLVWZ0: case X86::PT2RPNTLVWZ0T1: case X86::PT2RPNTLVWZ1: @@ -37598,10 +37605,13 @@ X86TargetLowering::EmitInstrWithCustomInserter(MachineInstr &MI, MI.eraseFromParent(); // The pseudo is gone now. return BB; } - case X86::PTTRANSPOSED: { + case X86::PTTRANSPOSED: + case X86::PTCONJTFP16: { const DebugLoc &DL = MI.getDebugLoc(); + unsigned Opc = MI.getOpcode() == X86::PTTRANSPOSED ? X86::TTRANSPOSED + : X86::TCONJTFP16; - MachineInstrBuilder MIB = BuildMI(*BB, MI, DL, TII->get(X86::TTRANSPOSED)); + MachineInstrBuilder MIB = BuildMI(*BB, MI, DL, TII->get(Opc)); MIB.addReg(TMMImmToTMMReg(MI.getOperand(0).getImm()), RegState::Define); MIB.addReg(TMMImmToTMMReg(MI.getOperand(1).getImm()), RegState::Undef); diff --git a/llvm/lib/Target/X86/X86InstrAMX.td b/llvm/lib/Target/X86/X86InstrAMX.td index b954c977f8c6c9..428d1823355b49 100644 --- a/llvm/lib/Target/X86/X86InstrAMX.td +++ b/llvm/lib/Target/X86/X86InstrAMX.td @@ -370,6 +370,95 @@ let Predicates = [HasAMXTRANSPOSE, In64BitMode] in { } } // HasAMXTILE, HasAMXTRANSPOSE +let Predicates = [HasAMXBF16, HasAMXTRANSPOSE, In64BitMode], SchedRW = [WriteSystem] in { + let Constraints = "$src1 = $dst" in + def TTDPBF16PS : I<0x6c, MRMSrcReg4VOp3, (outs TILE:$dst), + (ins TILE:$src1, TILE:$src2, TILE:$src3), + "ttdpbf16ps\t{$src3, $src2, $dst|$dst, $src2, $src3}", + []>, VEX, VVVV, T8,XS; + let Constraints = "$src4 = $dst" in + def PTTDPBF16PSV : PseudoI<(outs TILE:$dst), (ins GR16:$src1, + GR16:$src2, GR16:$src3, TILE:$src4, + TILE:$src5, TILE:$src6), + [(set TILE: $dst, + (int_x86_ttdpbf16ps_internal GR16:$src1, GR16:$src2, + GR16:$src3, TILE:$src4, TILE:$src5, TILE:$src6))]>; + let usesCustomInserter = 1 in + def PTTDPBF16PS : PseudoI<(outs), (ins u8imm:$src1, u8imm:$src2, u8imm:$src3), + [(int_x86_ttdpbf16ps timm:$src1, timm:$src2, timm:$src3)]>; +} + +let Predicates = [HasAMXFP16, HasAMXTRANSPOSE, In64BitMode], SchedRW = [WriteSystem] in { + let Constraints = "$src1 = $dst" in + def TTDPFP16PS : I<0x6c, MRMSrcReg4VOp3, (outs TILE:$dst), + (ins TILE:$src1, TILE:$src2, TILE:$src3), + "ttdpfp16ps\t{$src3, $src2, $dst|$dst, $src2, $src3}", + []>, VEX, VVVV, T8,XD; + let Constraints = "$src4 = $dst" in + def PTTDPFP16PSV : PseudoI<(outs TILE:$dst), (ins GR16:$src1, + GR16:$src2, GR16:$src3, TILE:$src4, + TILE:$src5, TILE:$src6), + [(set TILE: $dst, + (int_x86_ttdpfp16ps_internal GR16:$src1, GR16:$src2, + GR16:$src3, TILE:$src4, TILE:$src5, TILE:$src6))]>; + let usesCustomInserter = 1 in + def PTTDPFP16PS : PseudoI<(outs), (ins u8imm:$src1, u8imm:$src2, u8imm:$src3), + [(int_x86_ttdpfp16ps timm:$src1, timm:$src2, timm:$src3)]>; +} + +let Predicates = [HasAMXCOMPLEX, HasAMXTRANSPOSE, In64BitMode], SchedRW = [WriteSystem] in { + let Constraints = "$src1 = $dst" in { + def TTCMMIMFP16PS : I<0x6b, MRMSrcReg4VOp3, (outs TILE:$dst), + (ins TILE:$src1, TILE:$src2, TILE:$src3), + "ttcmmimfp16ps\t{$src3, $src2, $src1|$src1, $src2, $src3}", + []>, VEX, VVVV, T8,XD; + def TTCMMRLFP16PS: I<0x6b, MRMSrcReg4VOp3, (outs TILE:$dst), + (ins TILE:$src1, TILE:$src2, TILE:$src3), + "ttcmmrlfp16ps\t{$src3, $src2, $src1|$src1, $src2, $src3}", + []>, VEX, VVVV, T8,XS; + def TCONJTCMMIMFP16PS : I<0x6b, MRMSrcReg4VOp3, (outs TILE:$dst), + (ins TILE:$src1, TILE:$src2, TILE:$src3), + "tconjtcmmimfp16ps\t{$src3, $src2, $src1|$src1, $src2, $src3}", + []>, VEX, VVVV, WIG, T8,PS; + } + def TCONJTFP16 : I<0x6b, MRMSrcReg, (outs TILE:$dst), (ins TILE:$src), + "tconjtfp16\t{$src, $dst|$dst, $src}", []>, VEX, T8,PD; + + let Constraints = "$src4 = $dst" in { + def PTTCMMIMFP16PSV : PseudoI<(outs TILE:$dst), (ins GR16:$src1, + GR16:$src2, GR16:$src3, TILE:$src4, + TILE:$src5, TILE:$src6), + [(set TILE: $dst, + (int_x86_ttcmmimfp16ps_internal GR16:$src1, GR16:$src2, + GR16:$src3, TILE:$src4, TILE:$src5, TILE:$src6))]>; + def PTTCMMRLFP16PSV : PseudoI<(outs TILE:$dst), (ins GR16:$src1, + GR16:$src2, GR16:$src3, TILE:$src4, + TILE:$src5, TILE:$src6), + [(set TILE: $dst, + (int_x86_ttcmmrlfp16ps_internal GR16:$src1, GR16:$src2, + GR16:$src3, TILE:$src4, TILE:$src5, TILE:$src6))]>; + def PTCONJTCMMIMFP16PSV : PseudoI<(outs TILE:$dst), (ins GR16:$src1, + GR16:$src2, GR16:$src3, TILE:$src4, + TILE:$src5, TILE:$src6), + [(set TILE: $dst, + (int_x86_tconjtcmmimfp16ps_internal GR16:$src1, GR16:$src2, + GR16:$src3, TILE:$src4, TILE:$src5, TILE:$src6))]>; + } + def PTCONJTFP16V : PseudoI<(outs TILE:$dst), (ins GR16:$src1, GR16:$src2, TILE:$src3), + [(set TILE: $dst, (int_x86_tconjtfp16_internal GR16:$src1, GR16:$src2, TILE:$src3))]>; + + let usesCustomInserter = 1 in { + def PTTCMMIMFP16PS : PseudoI<(outs), (ins u8imm:$src1, u8imm:$src2, u8imm:$src3), + [(int_x86_ttcmmimfp16ps timm:$src1, timm:$src2, timm:$src3)]>; + def PTTCMMRLFP16PS : PseudoI<(outs), (ins u8imm:$src1, u8imm:$src2, u8imm:$src3), + [(int_x86_ttcmmrlfp16ps timm:$src1, timm:$src2, timm:$src3)]>; + def PTCONJTCMMIMFP16PS : PseudoI<(outs), (ins u8imm:$src1, u8imm:$src2, u8imm:$src3), + [(int_x86_tconjtcmmimfp16ps timm:$src1, timm:$src2, timm:$src3)]>; + def PTCONJTFP16 : PseudoI<(outs), (ins u8imm:$dst, u8imm:$src), + [(int_x86_tconjtfp16 timm:$dst, timm:$src)]>; + } +} + multiclass m_tcvtrowd2ps { let Predicates = [HasAMXAVX512, HasAVX10_2_512, In64BitMode] in { let SchedRW = [WriteSystem] in { diff --git a/llvm/lib/Target/X86/X86LowerAMXType.cpp b/llvm/lib/Target/X86/X86LowerAMXType.cpp index 08c065c39ee1e3..37a27bb7af3d9c 100644 --- a/llvm/lib/Target/X86/X86LowerAMXType.cpp +++ b/llvm/lib/Target/X86/X86LowerAMXType.cpp @@ -258,7 +258,8 @@ std::pair ShapeCalculator::getShape(IntrinsicInst *II, } break; } - case Intrinsic::x86_ttransposed_internal: { + case Intrinsic::x86_ttransposed_internal: + case Intrinsic::x86_tconjtfp16_internal: { assert((OpNo == 2) && "Illegal Operand Number."); Row = getRowFromCol(II, II->getArgOperand(1), 4); Col = getColFromRow(II, II->getArgOperand(0), 4); @@ -275,6 +276,27 @@ std::pair ShapeCalculator::getShape(IntrinsicInst *II, Col = II->getArgOperand(1); break; } + case Intrinsic::x86_ttdpbf16ps_internal: + case Intrinsic::x86_ttdpfp16ps_internal: + case Intrinsic::x86_ttcmmimfp16ps_internal: + case Intrinsic::x86_ttcmmrlfp16ps_internal: + case Intrinsic::x86_tconjtcmmimfp16ps_internal: { + switch (OpNo) { + case 3: + Row = II->getArgOperand(0); + Col = II->getArgOperand(1); + break; + case 4: + Row = getRowFromCol(II, II->getArgOperand(2), 4); + Col = getColFromRow(II, II->getArgOperand(0), 4); + break; + case 5: + Row = getRowFromCol(II, II->getArgOperand(2), 4); + Col = II->getArgOperand(1); + break; + } + break; + } } return std::make_pair(Row, Col); diff --git a/llvm/lib/Target/X86/X86RegisterInfo.cpp b/llvm/lib/Target/X86/X86RegisterInfo.cpp index 1b2192e3891fc5..d931ffb6c994b0 100644 --- a/llvm/lib/Target/X86/X86RegisterInfo.cpp +++ b/llvm/lib/Target/X86/X86RegisterInfo.cpp @@ -1076,7 +1076,13 @@ static ShapeT getTileShape(Register VirtReg, VirtRegMap *VRM, case X86::PTDPFP16PSV: case X86::PTCMMIMFP16PSV: case X86::PTCMMRLFP16PSV: - case X86::PTTRANSPOSEDV: { + case X86::PTTRANSPOSEDV: + case X86::PTTDPBF16PSV: + case X86::PTTDPFP16PSV: + case X86::PTTCMMIMFP16PSV: + case X86::PTTCMMRLFP16PSV: + case X86::PTCONJTCMMIMFP16PSV: + case X86::PTCONJTFP16V: { MachineOperand &MO1 = MI->getOperand(1); MachineOperand &MO2 = MI->getOperand(2); ShapeT Shape(&MO1, &MO2, MRI); diff --git a/llvm/test/CodeGen/X86/amx_transpose_intrinsics.ll b/llvm/test/CodeGen/X86/amx_transpose_intrinsics.ll index 2025ee94a97405..cc4360317db7db 100644 --- a/llvm/test/CodeGen/X86/amx_transpose_intrinsics.ll +++ b/llvm/test/CodeGen/X86/amx_transpose_intrinsics.ll @@ -1,5 +1,5 @@ ; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py -; RUN: llc < %s -mtriple=x86_64-unknown-unknown -mattr=+avx512f,+amx-tile,+amx-bf16,+amx-int8,+amx-transpose | FileCheck %s +; RUN: llc < %s -mtriple=x86_64-unknown-unknown -mattr=+avx512f,+amx-bf16,+amx-fp16,+amx-complex,+amx-transpose | FileCheck %s define void @test_amx(i32 %rv32, i64 %stride, i64 %rvalue, i8* %addr1, <4 x float> %xmm) #0 { ; CHECK-LABEL: test_amx: @@ -9,12 +9,24 @@ define void @test_amx(i32 %rv32, i64 %stride, i64 %rvalue, i8* %addr1, <4 x floa ; CHECK-NEXT: t2rpntlvwz1 (%rcx,%rsi), %tmm0 ; CHECK-NEXT: t2rpntlvwz1t1 (%rcx,%rsi), %tmm2 ; CHECK-NEXT: ttransposed %tmm3, %tmm1 +; CHECK-NEXT: ttdpbf16ps %tmm3, %tmm2, %tmm1 +; CHECK-NEXT: ttdpfp16ps %tmm6, %tmm5, %tmm4 +; CHECK-NEXT: ttcmmimfp16ps %tmm3, %tmm2, %tmm1 +; CHECK-NEXT: ttcmmrlfp16ps %tmm3, %tmm2, %tmm1 +; CHECK-NEXT: tconjtcmmimfp16ps %tmm3, %tmm2, %tmm1 +; CHECK-NEXT: tconjtfp16 %tmm2, %tmm1 ; CHECK-NEXT: retq call void @llvm.x86.t2rpntlvwz0(i8 1, i8* %addr1, i64 %stride) call void @llvm.x86.t2rpntlvwz0t1(i8 2, i8* %addr1, i64 %stride) call void @llvm.x86.t2rpntlvwz1(i8 1, i8* %addr1, i64 %stride) call void @llvm.x86.t2rpntlvwz1t1(i8 2, i8* %addr1, i64 %stride) call void @llvm.x86.ttransposed(i8 1, i8 3) + call void @llvm.x86.ttdpbf16ps(i8 1, i8 2, i8 3) + call void @llvm.x86.ttdpfp16ps(i8 4, i8 5, i8 6) + call void @llvm.x86.ttcmmimfp16ps(i8 1, i8 2, i8 3) + call void @llvm.x86.ttcmmrlfp16ps(i8 1, i8 2, i8 3) + call void @llvm.x86.tconjtcmmimfp16ps(i8 1, i8 2, i8 3) + call void @llvm.x86.tconjtfp16(i8 1, i8 2) ret void } @@ -23,6 +35,63 @@ declare void @llvm.x86.t2rpntlvwz0t1(i8 %tile1, i8* %addr1, i64 %stride) declare void @llvm.x86.t2rpntlvwz1(i8 %tile1, i8* %addr1, i64 %stride) declare void @llvm.x86.t2rpntlvwz1t1(i8 %tile1, i8* %addr1, i64 %stride) declare void @llvm.x86.ttransposed(i8 %tile0, i8 %tile1) +declare void @llvm.x86.ttdpbf16ps(i8 %tile0, i8 %tile1, i8 %tile2) +declare void @llvm.x86.ttdpfp16ps(i8 %tile0, i8 %tile1, i8 %tile2) +declare void @llvm.x86.ttcmmimfp16ps(i8 %A, i8 %B, i8 %C) +declare void @llvm.x86.ttcmmrlfp16ps(i8 %A, i8 %B, i8 %C) +declare void @llvm.x86.tconjtcmmimfp16ps(i8 %A, i8 %B, i8 %C) +declare void @llvm.x86.tconjtfp16(i8 %A, i8 %B) + +define void @test_amx2(i8* %pointer, i8* %base, i64 %stride) #0 { +; CHECK-LABEL: test_amx2: +; CHECK: # %bb.0: +; CHECK-NEXT: pushq %rbp +; CHECK-NEXT: subq $2928, %rsp # imm = 0xB70 +; CHECK-NEXT: vxorps %xmm0, %xmm0, %xmm0 +; CHECK-NEXT: vmovups %zmm0, {{[0-9]+}}(%rsp) +; CHECK-NEXT: movb $1, {{[0-9]+}}(%rsp) +; CHECK-NEXT: movb $8, {{[0-9]+}}(%rsp) +; CHECK-NEXT: movw $8, {{[0-9]+}}(%rsp) +; CHECK-NEXT: movb $8, {{[0-9]+}}(%rsp) +; CHECK-NEXT: movw $8, {{[0-9]+}}(%rsp) +; CHECK-NEXT: movb $8, {{[0-9]+}}(%rsp) +; CHECK-NEXT: movw $8, {{[0-9]+}}(%rsp) +; CHECK-NEXT: movb $8, {{[0-9]+}}(%rsp) +; CHECK-NEXT: movw $8, {{[0-9]+}}(%rsp) +; CHECK-NEXT: ldtilecfg {{[0-9]+}}(%rsp) +; CHECK-NEXT: movw $8, %ax +; CHECK-NEXT: tileloadd (%rsi,%rdx), %tmm0 +; CHECK-NEXT: tilezero %tmm1 +; CHECK-NEXT: tilezero %tmm2 +; CHECK-NEXT: ttdpbf16ps %tmm1, %tmm0, %tmm2 +; CHECK-NEXT: ttdpfp16ps %tmm1, %tmm0, %tmm2 +; CHECK-NEXT: ttcmmimfp16ps %tmm1, %tmm0, %tmm2 +; CHECK-NEXT: ttcmmrlfp16ps %tmm1, %tmm0, %tmm2 +; CHECK-NEXT: movabsq $64, %rbp +; CHECK-NEXT: tilestored %tmm2, 896(%rsp,%rbp) # 1024-byte Folded Spill +; CHECK-NEXT: tileloadd 896(%rsp,%rbp), %tmm3 # 1024-byte Folded Reload +; CHECK-NEXT: tconjtcmmimfp16ps %tmm1, %tmm0, %tmm3 +; CHECK-NEXT: tconjtfp16 %tmm3, %tmm0 +; CHECK-NEXT: tilestored %tmm2, (%rdi,%rdx) +; CHECK-NEXT: addq $2928, %rsp # imm = 0xB70 +; CHECK-NEXT: popq %rbp +; CHECK-NEXT: tilerelease +; CHECK-NEXT: vzeroupper +; CHECK-NEXT: retq + + %a = call x86_amx @llvm.x86.tileloadd64.internal(i16 8, i16 8, i8* %base, i64 %stride) + %b = call x86_amx @llvm.x86.tilezero.internal(i16 8, i16 8) + %c = call x86_amx @llvm.x86.tilezero.internal(i16 8, i16 8) + %c1 = call x86_amx @llvm.x86.ttdpbf16ps.internal(i16 8, i16 8, i16 8, x86_amx %c, x86_amx %a, x86_amx %b) + %c2 = call x86_amx @llvm.x86.ttdpfp16ps.internal(i16 8, i16 8, i16 8, x86_amx %c1, x86_amx %a, x86_amx %b) + %c3 = call x86_amx @llvm.x86.ttcmmimfp16ps.internal(i16 8, i16 8, i16 8, x86_amx %c2, x86_amx %a, x86_amx %b) + %c4 = call x86_amx @llvm.x86.ttcmmrlfp16ps.internal(i16 8, i16 8, i16 8, x86_amx %c3, x86_amx %a, x86_amx %b) + %c5 = call x86_amx @llvm.x86.tconjtcmmimfp16ps.internal(i16 8, i16 8, i16 8, x86_amx %c4, x86_amx %a, x86_amx %b) + %c6 = call x86_amx @llvm.x86.tconjtfp16.internal(i16 8, i16 8, x86_amx %c5) + + call void @llvm.x86.tilestored64.internal(i16 8, i16 8, i8* %pointer, i64 %stride, x86_amx %c4) + ret void +} define void @test_amx3(i8* %pointer, i8* %base, i64 %stride) #0 { ; CHECK-LABEL: test_amx3: @@ -146,5 +215,11 @@ declare { x86_amx, x86_amx } @llvm.x86.t2rpntlvwz0t1.internal(i16, i16, i16, i8* declare { x86_amx, x86_amx } @llvm.x86.t2rpntlvwz1.internal(i16, i16, i16, i8*, i64) declare { x86_amx, x86_amx } @llvm.x86.t2rpntlvwz1t1.internal(i16, i16, i16, i8*, i64) declare x86_amx @llvm.x86.ttransposed.internal(i16, i16, x86_amx) +declare x86_amx @llvm.x86.ttdpbf16ps.internal(i16, i16, i16, x86_amx, x86_amx, x86_amx) +declare x86_amx @llvm.x86.ttdpfp16ps.internal(i16, i16, i16, x86_amx, x86_amx, x86_amx) +declare x86_amx @llvm.x86.ttcmmimfp16ps.internal(i16, i16, i16, x86_amx, x86_amx, x86_amx) +declare x86_amx @llvm.x86.ttcmmrlfp16ps.internal(i16, i16, i16, x86_amx, x86_amx, x86_amx) +declare x86_amx @llvm.x86.tconjtcmmimfp16ps.internal(i16, i16, i16, x86_amx, x86_amx, x86_amx) +declare x86_amx @llvm.x86.tconjtfp16.internal(i16, i16, x86_amx) attributes #0 = { nounwind } diff --git a/llvm/test/MC/Disassembler/X86/amx-transpose-att.txt b/llvm/test/MC/Disassembler/X86/amx-transpose-att.txt index e4f1689639ef9a..8c6f1be80ba2dc 100644 --- a/llvm/test/MC/Disassembler/X86/amx-transpose-att.txt +++ b/llvm/test/MC/Disassembler/X86/amx-transpose-att.txt @@ -56,3 +56,51 @@ # ATT: ttransposed %tmm2, %tmm3 # INTEL: ttransposed tmm3, tmm2 0xc4,0xe2,0x7a,0x5f,0xda + +# ATT: ttdpbf16ps %tmm7, %tmm6, %tmm5 +# INTEL: ttdpbf16ps tmm5, tmm6, tmm7 +0xc4,0xe2,0x42,0x6c,0xee + +# ATT: ttdpbf16ps %tmm1, %tmm2, %tmm3 +# INTEL: ttdpbf16ps tmm3, tmm2, tmm1 +0xc4,0xe2,0x72,0x6c,0xda + +# ATT: ttdpfp16ps %tmm7, %tmm6, %tmm5 +# INTEL: ttdpfp16ps tmm5, tmm6, tmm7 +0xc4,0xe2,0x43,0x6c,0xee + +# ATT: ttdpfp16ps %tmm1, %tmm2, %tmm3 +# INTEL: ttdpfp16ps tmm3, tmm2, tmm1 +0xc4,0xe2,0x73,0x6c,0xda + +# ATT: ttcmmimfp16ps %tmm4, %tmm5, %tmm6 +# INTEL: ttcmmimfp16ps tmm6, tmm5, tmm4 +0xc4,0xe2,0x5b,0x6b,0xf5 + +# ATT: ttcmmimfp16ps %tmm1, %tmm2, %tmm3 +# INTEL: ttcmmimfp16ps tmm3, tmm2, tmm1 +0xc4,0xe2,0x73,0x6b,0xda + +# ATT: ttcmmrlfp16ps %tmm4, %tmm5, %tmm6 +# INTEL: ttcmmrlfp16ps tmm6, tmm5, tmm4 +0xc4,0xe2,0x5a,0x6b,0xf5 + +# ATT: ttcmmrlfp16ps %tmm1, %tmm2, %tmm3 +# INTEL: ttcmmrlfp16ps tmm3, tmm2, tmm1 +0xc4,0xe2,0x72,0x6b,0xda + +# ATT: tconjtcmmimfp16ps %tmm4, %tmm5, %tmm6 +# INTEL: tconjtcmmimfp16ps tmm6, tmm5, tmm4 +0xc4,0xe2,0x58,0x6b,0xf5 + +# ATT: tconjtcmmimfp16ps %tmm1, %tmm2, %tmm3 +# INTEL: tconjtcmmimfp16ps tmm3, tmm2, tmm1 +0xc4,0xe2,0x70,0x6b,0xda + +# ATT: tconjtfp16 %tmm5, %tmm6 +# INTEL: tconjtfp16 tmm6, tmm5 +0xc4,0xe2,0x79,0x6b,0xf5 + +# ATT: tconjtfp16 %tmm2, %tmm3 +# INTEL: tconjtfp16 tmm3, tmm2 +0xc4,0xe2,0x79,0x6b,0xda diff --git a/llvm/test/MC/X86/amx-transpose-att.s b/llvm/test/MC/X86/amx-transpose-att.s index da3fa95ef6dd06..21bbf258ac6ef8 100644 --- a/llvm/test/MC/X86/amx-transpose-att.s +++ b/llvm/test/MC/X86/amx-transpose-att.s @@ -55,3 +55,51 @@ // CHECK: ttransposed %tmm2, %tmm3 // CHECK: encoding: [0xc4,0xe2,0x7a,0x5f,0xda] ttransposed %tmm2, %tmm3 + +// CHECK: ttdpbf16ps %tmm1, %tmm2, %tmm5 +// CHECK: encoding: [0xc4,0xe2,0x72,0x6c,0xea] + ttdpbf16ps %tmm1, %tmm2, %tmm5 + +// CHECK: ttdpbf16ps %tmm1, %tmm2, %tmm3 +// CHECK: encoding: [0xc4,0xe2,0x72,0x6c,0xda] + ttdpbf16ps %tmm1, %tmm2, %tmm3 + +// CHECK: ttdpfp16ps %tmm3, %tmm4, %tmm5 +// CHECK: encoding: [0xc4,0xe2,0x63,0x6c,0xec] + ttdpfp16ps %tmm3, %tmm4, %tmm5 + +// CHECK: ttdpfp16ps %tmm1, %tmm2, %tmm3 +// CHECK: encoding: [0xc4,0xe2,0x73,0x6c,0xda] + ttdpfp16ps %tmm1, %tmm2, %tmm3 + +// CHECK: ttcmmimfp16ps %tmm4, %tmm5, %tmm6 +// CHECK: encoding: [0xc4,0xe2,0x5b,0x6b,0xf5] + ttcmmimfp16ps %tmm4, %tmm5, %tmm6 + +// CHECK: ttcmmimfp16ps %tmm1, %tmm2, %tmm3 +// CHECK: encoding: [0xc4,0xe2,0x73,0x6b,0xda] + ttcmmimfp16ps %tmm1, %tmm2, %tmm3 + +// CHECK: ttcmmrlfp16ps %tmm4, %tmm5, %tmm6 +// CHECK: encoding: [0xc4,0xe2,0x5a,0x6b,0xf5] + ttcmmrlfp16ps %tmm4, %tmm5, %tmm6 + +// CHECK: ttcmmrlfp16ps %tmm1, %tmm2, %tmm3 +// CHECK: encoding: [0xc4,0xe2,0x72,0x6b,0xda] + ttcmmrlfp16ps %tmm1, %tmm2, %tmm3 + +// CHECK: tconjtcmmimfp16ps %tmm4, %tmm5, %tmm6 +// CHECK: encoding: [0xc4,0xe2,0x58,0x6b,0xf5] + tconjtcmmimfp16ps %tmm4, %tmm5, %tmm6 + +// CHECK: tconjtcmmimfp16ps %tmm1, %tmm2, %tmm3 +// CHECK: encoding: [0xc4,0xe2,0x70,0x6b,0xda] + tconjtcmmimfp16ps %tmm1, %tmm2, %tmm3 + +// CHECK: tconjtfp16 %tmm5, %tmm6 +// CHECK: encoding: [0xc4,0xe2,0x79,0x6b,0xf5] + tconjtfp16 %tmm5, %tmm6 + +// CHECK: tconjtfp16 %tmm2, %tmm3 +// CHECK: encoding: [0xc4,0xe2,0x79,0x6b,0xda] + tconjtfp16 %tmm2, %tmm3 diff --git a/llvm/test/MC/X86/amx-transpose-intel.s b/llvm/test/MC/X86/amx-transpose-intel.s index 3b8dfaed313d61..a772232ddbbf2e 100644 --- a/llvm/test/MC/X86/amx-transpose-intel.s +++ b/llvm/test/MC/X86/amx-transpose-intel.s @@ -55,3 +55,51 @@ // CHECK: ttransposed tmm3, tmm2 // CHECK: encoding: [0xc4,0xe2,0x7a,0x5f,0xda] ttransposed tmm3, tmm2 + +// CHECK: ttdpbf16ps tmm5, tmm0, tmm4 +// CHECK: encoding: [0xc4,0xe2,0x5a,0x6c,0xe8] + ttdpbf16ps tmm5, tmm0, tmm4 + +// CHECK: ttdpbf16ps tmm3, tmm2, tmm1 +// CHECK: encoding: [0xc4,0xe2,0x72,0x6c,0xda] + ttdpbf16ps tmm3, tmm2, tmm1 + +// CHECK: ttdpfp16ps tmm1, tmm0, tmm4 +// CHECK: encoding: [0xc4,0xe2,0x5b,0x6c,0xc8] + ttdpfp16ps tmm1, tmm0, tmm4 + +// CHECK: ttdpfp16ps tmm3, tmm2, tmm1 +// CHECK: encoding: [0xc4,0xe2,0x73,0x6c,0xda] + ttdpfp16ps tmm3, tmm2, tmm1 + +// CHECK: ttcmmimfp16ps tmm6, tmm5, tmm4 +// CHECK: encoding: [0xc4,0xe2,0x5b,0x6b,0xf5] + ttcmmimfp16ps tmm6, tmm5, tmm4 + +// CHECK: ttcmmimfp16ps tmm3, tmm2, tmm1 +// CHECK: encoding: [0xc4,0xe2,0x73,0x6b,0xda] + ttcmmimfp16ps tmm3, tmm2, tmm1 + +// CHECK: ttcmmrlfp16ps tmm6, tmm5, tmm4 +// CHECK: encoding: [0xc4,0xe2,0x5a,0x6b,0xf5] + ttcmmrlfp16ps tmm6, tmm5, tmm4 + +// CHECK: ttcmmrlfp16ps tmm3, tmm2, tmm1 +// CHECK: encoding: [0xc4,0xe2,0x72,0x6b,0xda] + ttcmmrlfp16ps tmm3, tmm2, tmm1 + +// CHECK: tconjtcmmimfp16ps tmm6, tmm5, tmm4 +// CHECK: encoding: [0xc4,0xe2,0x58,0x6b,0xf5] + tconjtcmmimfp16ps tmm6, tmm5, tmm4 + +// CHECK: tconjtcmmimfp16ps tmm3, tmm2, tmm1 +// CHECK: encoding: [0xc4,0xe2,0x70,0x6b,0xda] + tconjtcmmimfp16ps tmm3, tmm2, tmm1 + +// CHECK: tconjtfp16 tmm6, tmm5 +// CHECK: encoding: [0xc4,0xe2,0x79,0x6b,0xf5] + tconjtfp16 tmm6, tmm5 + +// CHECK: tconjtfp16 tmm3, tmm2 +// CHECK: encoding: [0xc4,0xe2,0x79,0x6b,0xda] + tconjtfp16 tmm3, tmm2 From a8a00400bd5cf0b8a2d114708b20ffabd5b9bb9e Mon Sep 17 00:00:00 2001 From: "Wang, Phoebe" Date: Thu, 14 Nov 2024 10:43:03 +0800 Subject: [PATCH 2/2] Address review comments --- clang/lib/Headers/amxbf16transposeintrin.h | 2 +- clang/lib/Headers/amxcomplextransposeintrin.h | 14 ++++++++------ clang/lib/Headers/amxfp16transposeintrin.h | 2 +- llvm/lib/Target/X86/X86ISelLowering.cpp | 8 ++++++-- 4 files changed, 16 insertions(+), 10 deletions(-) diff --git a/clang/lib/Headers/amxbf16transposeintrin.h b/clang/lib/Headers/amxbf16transposeintrin.h index 7d31384e317988..86f09f2ad8db2b 100644 --- a/clang/lib/Headers/amxbf16transposeintrin.h +++ b/clang/lib/Headers/amxbf16transposeintrin.h @@ -57,7 +57,7 @@ /// The 1st source tile. Max size is 1024 Bytes. /// \param b /// The 2nd source tile. Max size is 1024 Bytes. -#define _tile_tdpbf16ps(dst, a, b) __builtin_ia32_ttdpbf16ps(dst, a, b) +#define _tile_tdpbf16ps(dst, a, b) __builtin_ia32_ttdpbf16ps((dst), (a), (b)) /// This is internal intrinsic. C/C++ user should avoid calling it directly. static __inline__ _tile1024i __DEFAULT_FN_ATTRS diff --git a/clang/lib/Headers/amxcomplextransposeintrin.h b/clang/lib/Headers/amxcomplextransposeintrin.h index 06fb53e4deadcd..11abaf98e93719 100644 --- a/clang/lib/Headers/amxcomplextransposeintrin.h +++ b/clang/lib/Headers/amxcomplextransposeintrin.h @@ -63,7 +63,8 @@ /// The 1st source tile. Max size is 1024 Bytes. /// \param b /// The 2nd source tile. Max size is 1024 Bytes. -#define _tile_tcmmimfp16ps(dst, a, b) __builtin_ia32_ttcmmimfp16ps(dst, a, b) +#define _tile_tcmmimfp16ps(dst, a, b) \ + __builtin_ia32_ttcmmimfp16ps((dst), (a), (b)) /// Perform matrix multiplication of two tiles containing complex elements and /// accumulate the results into a packed single precision tile. Each dword @@ -108,7 +109,8 @@ /// The 1st source tile. Max size is 1024 Bytes. /// \param b /// The 2nd source tile. Max size is 1024 Bytes. -#define _tile_tcmmrlfp16ps(dst, a, b) __builtin_ia32_ttcmmrlfp16ps(dst, a, b) +#define _tile_tcmmrlfp16ps(dst, a, b) \ + __builtin_ia32_ttcmmrlfp16ps((dst), (a), (b)) /// Perform matrix conjugate transpose and multiplication of two tiles /// containing complex elements and accumulate the results into a packed @@ -155,7 +157,7 @@ /// \param b /// The 2nd source tile. Max size is 1024 Bytes. #define _tile_conjtcmmimfp16ps(dst, a, b) \ - __builtin_ia32_tconjtcmmimfp16ps(dst, a, b) + __builtin_ia32_tconjtcmmimfp16ps((dst), (a), (b)) /// Perform conjugate transpose of an FP16-pair of complex elements from \a a /// and writes the result to \a dst. @@ -184,7 +186,7 @@ /// The destination tile. Max size is 1024 Bytes. /// \param a /// The source tile. Max size is 1024 Bytes. -#define _tile_conjtfp16(dst, a) __builtin_ia32_tconjtfp16(dst, a) +#define _tile_conjtfp16(dst, a) __builtin_ia32_tconjtfp16((dst), (a)) static __inline__ _tile1024i __DEFAULT_FN_ATTRS _tile_tcmmimfp16ps_internal( unsigned short m, unsigned short n, unsigned short k, _tile1024i dst, @@ -204,8 +206,8 @@ static __inline__ _tile1024i __DEFAULT_FN_ATTRS _tile_conjtcmmimfp16ps_internal( return __builtin_ia32_tconjtcmmimfp16ps_internal(m, n, k, dst, src1, src2); } -static __inline__ _tile1024i __DEFAULT_FN_ATTRS _tile_conjtfp16_internal( - unsigned short m, unsigned short n, _tile1024i src) { +static __inline__ _tile1024i __DEFAULT_FN_ATTRS +_tile_conjtfp16_internal(unsigned short m, unsigned short n, _tile1024i src) { return __builtin_ia32_tconjtfp16_internal(m, n, src); } diff --git a/clang/lib/Headers/amxfp16transposeintrin.h b/clang/lib/Headers/amxfp16transposeintrin.h index c07c5516301983..191f8c6097a2cc 100644 --- a/clang/lib/Headers/amxfp16transposeintrin.h +++ b/clang/lib/Headers/amxfp16transposeintrin.h @@ -57,7 +57,7 @@ /// The 1st source tile. Max size is 1024 Bytes. /// \param b /// The 2nd source tile. Max size is 1024 Bytes. -#define _tile_tdpfp16ps(dst, a, b) __builtin_ia32_ttdpfp16ps(dst, a, b) +#define _tile_tdpfp16ps(dst, a, b) __builtin_ia32_ttdpfp16ps((dst), (a), (b)) /// This is internal intrinsic. C/C++ user should avoid calling it directly. static __inline__ _tile1024i __DEFAULT_FN_ATTRS diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp index 0a612acee7b1aa..f559c1b15ac712 100644 --- a/llvm/lib/Target/X86/X86ISelLowering.cpp +++ b/llvm/lib/Target/X86/X86ISelLowering.cpp @@ -37549,8 +37549,12 @@ X86TargetLowering::EmitInstrWithCustomInserter(MachineInstr &MI, case X86::PTCONJTCMMIMFP16PS: Opc = X86::TCONJTCMMIMFP16PS; break; - case X86::PTMMULTF32PS: Opc = X86::TMMULTF32PS; break; - case X86::PTTMMULTF32PS: Opc = X86::TTMMULTF32PS; break; + case X86::PTMMULTF32PS: + Opc = X86::TMMULTF32PS; + break; + case X86::PTTMMULTF32PS: + Opc = X86::TTMMULTF32PS; + break; } MachineInstrBuilder MIB = BuildMI(*BB, MI, MIMD, TII->get(Opc));