From c7cf7558cc92e12fc6b9c36d5c3a231390ab0fc8 Mon Sep 17 00:00:00 2001 From: "Larsen, Steffen" Date: Mon, 31 Mar 2025 05:20:10 -0700 Subject: [PATCH 1/2] [SYCL][Docs] Add sycl_ext_oneapi_ternary_bitwise extension This commit adds the specification and implementation of a new bitwise operation taking three operands. The exact bitwise operation is determined by the LUTIndex template argument, which can be calculated by applying the bitwise operation to a predefined set of arguments. Signed-off-by: Larsen, Steffen --- clang/lib/Driver/ToolChains/Clang.cpp | 3 +- clang/lib/Sema/SPIRVBuiltins.td | 3 + clang/test/Driver/sycl-spirv-ext-old-model.c | 2 + clang/test/Driver/sycl-spirv-ext.c | 2 + .../ClangLinkerWrapper.cpp | 3 +- .../clang-sycl-linker/ClangSYCLLinker.cpp | 3 +- llvm-spirv/include/LLVMSPIRVExtensions.inc | 1 + .../lib/SPIRV/libSPIRV/SPIRVInstruction.h | 55 ++ .../lib/SPIRV/libSPIRV/SPIRVNameMapEnum.h | 2 + .../SPIRV/libSPIRV/SPIRVOpCodeEnumInternal.h | 1 + .../lib/SPIRV/libSPIRV/spirv_internal.hpp | 7 +- .../bitwise_function.ll | 78 +++ .../sycl_ext_oneapi_ternary_bitwise.asciidoc | 115 ++++ .../oneapi/experimental/ternary_bitwise.hpp | 571 ++++++++++++++++++ sycl/include/sycl/sycl.hpp | 1 + sycl/source/feature_test.hpp.in | 1 + .../test-e2e/Experimental/ternary_bitwise.cpp | 158 +++++ 17 files changed, 1002 insertions(+), 4 deletions(-) create mode 100644 llvm-spirv/test/extensions/INTEL/SPV_INTEL_ternary_bitwise_function/bitwise_function.ll create mode 100644 sycl/doc/extensions/experimental/sycl_ext_oneapi_ternary_bitwise.asciidoc create mode 100644 sycl/include/sycl/ext/oneapi/experimental/ternary_bitwise.hpp create mode 100644 sycl/test-e2e/Experimental/ternary_bitwise.cpp diff --git a/clang/lib/Driver/ToolChains/Clang.cpp b/clang/lib/Driver/ToolChains/Clang.cpp index c789a8253a32a..360c59eb561ef 100644 --- a/clang/lib/Driver/ToolChains/Clang.cpp +++ b/clang/lib/Driver/ToolChains/Clang.cpp @@ -10852,7 +10852,8 @@ static void 
getTripleBasedSPIRVTransOpts(Compilation &C, ",+SPV_INTEL_fpga_latency_control" ",+SPV_KHR_shader_clock" ",+SPV_INTEL_bindless_images" - ",+SPV_INTEL_task_sequence"; + ",+SPV_INTEL_task_sequence" + ",+SPV_INTEL_ternary_bitwise_function"; ExtArg = ExtArg + DefaultExtArg + INTELExtArg; if (C.getDriver().IsFPGAHWMode()) // Enable several extensions on FPGA H/W exclusively diff --git a/clang/lib/Sema/SPIRVBuiltins.td b/clang/lib/Sema/SPIRVBuiltins.td index bba559f0f255a..72a0f0ece626b 100644 --- a/clang/lib/Sema/SPIRVBuiltins.td +++ b/clang/lib/Sema/SPIRVBuiltins.td @@ -901,6 +901,9 @@ foreach name = ["BitCount"] in { def : SPVBuiltin; } +def : SPVBuiltin<"BitwiseFunctionINTEL", + [AIGenTypeN, AIGenTypeN, AIGenTypeN, AIGenTypeN, Int]>; + // 3.32.20. Barrier Instructions foreach name = ["ControlBarrier"] in { diff --git a/clang/test/Driver/sycl-spirv-ext-old-model.c b/clang/test/Driver/sycl-spirv-ext-old-model.c index f3c920979841b..d18f41d49db4c 100644 --- a/clang/test/Driver/sycl-spirv-ext-old-model.c +++ b/clang/test/Driver/sycl-spirv-ext-old-model.c @@ -36,6 +36,7 @@ // CHECK-DEFAULT-SAME:,+SPV_KHR_shader_clock // CHECK-DEFAULT-SAME:,+SPV_INTEL_bindless_images // CHECK-DEFAULT-SAME:,+SPV_INTEL_task_sequence +// CHECK-DEFAULT-SAME:,+SPV_INTEL_ternary_bitwise_function // CHECK-DEFAULT-SAME:,+SPV_INTEL_bfloat16_conversion // CHECK-DEFAULT-SAME:,+SPV_INTEL_joint_matrix // CHECK-DEFAULT-SAME:,+SPV_INTEL_hw_thread_queries @@ -73,6 +74,7 @@ // CHECK-CPU-SAME:,+SPV_INTEL_fpga_invocation_pipelining_attributes // CHECK-CPU-SAME:,+SPV_INTEL_fpga_latency_control // CHECK-CPU-SAME:,+SPV_INTEL_task_sequence +// CHECK-CPU-SAME:,+SPV_INTEL_ternary_bitwise_function // CHECK-CPU-SAME:,+SPV_INTEL_bfloat16_conversion // CHECK-CPU-SAME:,+SPV_INTEL_joint_matrix // CHECK-CPU-SAME:,+SPV_INTEL_hw_thread_queries diff --git a/clang/test/Driver/sycl-spirv-ext.c b/clang/test/Driver/sycl-spirv-ext.c index a8394d1ece837..71f7f20a6e768 100644 --- a/clang/test/Driver/sycl-spirv-ext.c +++ 
b/clang/test/Driver/sycl-spirv-ext.c @@ -53,6 +53,7 @@ // CHECK-DEFAULT-SAME:,+SPV_KHR_shader_clock // CHECK-DEFAULT-SAME:,+SPV_INTEL_bindless_images // CHECK-DEFAULT-SAME:,+SPV_INTEL_task_sequence +// CHECK-DEFAULT-SAME:,+SPV_INTEL_ternary_bitwise_function // CHECK-DEFAULT-SAME:,+SPV_INTEL_bfloat16_conversion // CHECK-DEFAULT-SAME:,+SPV_INTEL_joint_matrix // CHECK-DEFAULT-SAME:,+SPV_INTEL_hw_thread_queries @@ -90,6 +91,7 @@ // CHECK-CPU-SAME:,+SPV_INTEL_fpga_invocation_pipelining_attributes // CHECK-CPU-SAME:,+SPV_INTEL_fpga_latency_control // CHECK-CPU-SAME:,+SPV_INTEL_task_sequence +// CHECK-CPU-SAME:,+SPV_INTEL_ternary_bitwise_function // CHECK-CPU-SAME:,+SPV_INTEL_bfloat16_conversion // CHECK-CPU-SAME:,+SPV_INTEL_joint_matrix // CHECK-CPU-SAME:,+SPV_INTEL_hw_thread_queries diff --git a/clang/tools/clang-linker-wrapper/ClangLinkerWrapper.cpp b/clang/tools/clang-linker-wrapper/ClangLinkerWrapper.cpp index ded1e3d6a4d5b..a40a9864a9983 100644 --- a/clang/tools/clang-linker-wrapper/ClangLinkerWrapper.cpp +++ b/clang/tools/clang-linker-wrapper/ClangLinkerWrapper.cpp @@ -873,7 +873,8 @@ getTripleBasedSPIRVTransOpts(const ArgList &Args, ",+SPV_INTEL_fpga_latency_control" ",+SPV_KHR_shader_clock" ",+SPV_INTEL_bindless_images" - ",+SPV_INTEL_task_sequence"; + ",+SPV_INTEL_task_sequence" + ",+SPV_INTEL_ternary_bitwise_function"; ExtArg = ExtArg + DefaultExtArg + INTELExtArg; ExtArg += ",+SPV_INTEL_bfloat16_conversion" ",+SPV_INTEL_joint_matrix" diff --git a/clang/tools/clang-sycl-linker/ClangSYCLLinker.cpp b/clang/tools/clang-sycl-linker/ClangSYCLLinker.cpp index 3d1fa65da7750..e9974309a54da 100644 --- a/clang/tools/clang-sycl-linker/ClangSYCLLinker.cpp +++ b/clang/tools/clang-sycl-linker/ClangSYCLLinker.cpp @@ -355,7 +355,8 @@ static void getSPIRVTransOpts(const ArgList &Args, ",+SPV_INTEL_fpga_latency_control" ",+SPV_INTEL_task_sequence" ",+SPV_KHR_shader_clock" - ",+SPV_INTEL_bindless_images"; + ",+SPV_INTEL_bindless_images" + ",+SPV_INTEL_ternary_bitwise_function"; 
ExtArg = ExtArg + DefaultExtArg + INTELExtArg; ExtArg += ",+SPV_INTEL_token_type" ",+SPV_INTEL_bfloat16_conversion" diff --git a/llvm-spirv/include/LLVMSPIRVExtensions.inc b/llvm-spirv/include/LLVMSPIRVExtensions.inc index 75f83715a4119..c9a7409128d11 100644 --- a/llvm-spirv/include/LLVMSPIRVExtensions.inc +++ b/llvm-spirv/include/LLVMSPIRVExtensions.inc @@ -76,3 +76,4 @@ EXT(SPV_INTEL_maximum_registers) EXT(SPV_INTEL_bindless_images) EXT(SPV_INTEL_2d_block_io) EXT(SPV_INTEL_subgroup_matrix_multiply_accumulate) +EXT(SPV_INTEL_ternary_bitwise_function) diff --git a/llvm-spirv/lib/SPIRV/libSPIRV/SPIRVInstruction.h b/llvm-spirv/lib/SPIRV/libSPIRV/SPIRVInstruction.h index 21d1f1ae67dd9..3c3f62f99fbf9 100644 --- a/llvm-spirv/lib/SPIRV/libSPIRV/SPIRVInstruction.h +++ b/llvm-spirv/lib/SPIRV/libSPIRV/SPIRVInstruction.h @@ -4446,5 +4446,60 @@ class SPIRVSubgroupMatrixMultiplyAccumulateINTELInst _SPIRV_OP(SubgroupMatrixMultiplyAccumulate, true, 7, true, 4) #undef _SPIRV_OP +class SPIRVTernaryBitwiseFunctionINTELInst : public SPIRVInstTemplateBase { +public: + void validate() const override { + SPIRVInstruction::validate(); + SPIRVErrorLog &SPVErrLog = this->getModule()->getErrorLog(); + std::string InstName = "BitwiseFunctionINTEL"; + + const SPIRVType *ResTy = this->getType(); + SPVErrLog.checkError( + ResTy->isTypeInt() || (ResTy->isTypeVector() && + ResTy->getVectorComponentType()->isTypeInt()), + SPIRVEC_InvalidInstruction, + InstName + "\nResult type must be an integer scalar or vector.\n"); + + auto CommonArgCheck = [this, ResTy, &InstName, + &SPVErrLog](size_t ArgI, const char *ArgPlacement) { + SPIRVValue *Arg = + const_cast(this)->getOperand( + ArgI); + SPVErrLog.checkError( + Arg->getType() == ResTy, SPIRVEC_InvalidInstruction, + InstName + "\n" + ArgPlacement + + " argument must be the same as the result type.\n"); + }; + + CommonArgCheck(0, "First"); + CommonArgCheck(1, "Second"); + CommonArgCheck(2, "Third"); + + SPIRVValue *LUTIndexArg = + 
const_cast(this)->getOperand(3); + const SPIRVType *LUTIndexArgTy = LUTIndexArg->getType(); + SPVErrLog.checkError( + LUTIndexArgTy->isTypeInt(32), SPIRVEC_InvalidInstruction, + InstName + "\nFourth argument must be a 32-bit integer scalar.\n"); + SPVErrLog.checkError( + isConstantOpCode(LUTIndexArg->getOpCode()), SPIRVEC_InvalidInstruction, + InstName + "\nFourth argument must be constant instruction.\n"); + } + + std::optional getRequiredExtension() const override { + return ExtensionID::SPV_INTEL_ternary_bitwise_function; + } + SPIRVCapVec getRequiredCapability() const override { + return getVec(internal::CapabilityTernaryBitwiseFunctionINTEL); + } +}; + +#define _SPIRV_OP(x, ...) \ + typedef SPIRVInstTemplate \ + SPIRV##x##INTEL; +_SPIRV_OP(BitwiseFunction, true, 7) +#undef _SPIRV_OP + } // namespace SPIRV #endif // SPIRV_LIBSPIRV_SPIRVINSTRUCTION_H diff --git a/llvm-spirv/lib/SPIRV/libSPIRV/SPIRVNameMapEnum.h b/llvm-spirv/lib/SPIRV/libSPIRV/SPIRVNameMapEnum.h index 0a92b2b81d70e..32d31dea69234 100644 --- a/llvm-spirv/lib/SPIRV/libSPIRV/SPIRVNameMapEnum.h +++ b/llvm-spirv/lib/SPIRV/libSPIRV/SPIRVNameMapEnum.h @@ -696,6 +696,8 @@ template <> inline void SPIRVMap::init() { "SubgroupRequirementsINTEL"); add(internal::CapabilityTaskSequenceINTEL, "TaskSequenceINTEL"); add(internal::CapabilityBindlessImagesINTEL, "BindlessImagesINTEL"); + add(internal::CapabilityTernaryBitwiseFunctionINTEL, + "TernaryBitwiseFunctionINTEL"); } SPIRV_DEF_NAMEMAP(Capability, SPIRVCapabilityNameMap) diff --git a/llvm-spirv/lib/SPIRV/libSPIRV/SPIRVOpCodeEnumInternal.h b/llvm-spirv/lib/SPIRV/libSPIRV/SPIRVOpCodeEnumInternal.h index f1187cfa23226..6adbf255d6967 100644 --- a/llvm-spirv/lib/SPIRV/libSPIRV/SPIRVOpCodeEnumInternal.h +++ b/llvm-spirv/lib/SPIRV/libSPIRV/SPIRVOpCodeEnumInternal.h @@ -44,3 +44,4 @@ _SPIRV_OP_INTERNAL(ConvertHandleToSamplerINTEL, internal::ConvertHandleToSamplerINTEL) _SPIRV_OP_INTERNAL(ConvertHandleToSampledImageINTEL, internal::ConvertHandleToSampledImageINTEL) 
+_SPIRV_OP_INTERNAL(BitwiseFunctionINTEL, internal::BitwiseFunctionINTEL) diff --git a/llvm-spirv/lib/SPIRV/libSPIRV/spirv_internal.hpp b/llvm-spirv/lib/SPIRV/libSPIRV/spirv_internal.hpp index 93a85e01e440e..2ded482aca87c 100644 --- a/llvm-spirv/lib/SPIRV/libSPIRV/spirv_internal.hpp +++ b/llvm-spirv/lib/SPIRV/libSPIRV/spirv_internal.hpp @@ -89,6 +89,7 @@ enum InternalOp { IOpConvertHandleToImageINTEL = 6529, IOpConvertHandleToSamplerINTEL = 6530, IOpConvertHandleToSampledImageINTEL = 6531, + IOpBitwiseFunctionINTEL = 6242, IOpPrev = OpMax - 2, IOpForward }; @@ -124,7 +125,8 @@ enum InternalCapability { ICapabilityJointMatrixPackedInt4ComponentTypeINTEL = 6439, ICapabilityCacheControlsINTEL = 6441, ICapabilitySubgroupRequirementsINTEL = 6445, - ICapabilityBindlessImagesINTEL = 6528 + ICapabilityBindlessImagesINTEL = 6528, + ICapabilityTernaryBitwiseFunctionINTEL = 6241 }; enum InternalFunctionControlMask { IFunctionControlOptNoneINTELMask = 0x10000 }; @@ -222,6 +224,9 @@ _SPIRV_OP(Capability, BindlessImagesINTEL) _SPIRV_OP(Op, ConvertHandleToImageINTEL) _SPIRV_OP(Op, ConvertHandleToSamplerINTEL) _SPIRV_OP(Op, ConvertHandleToSampledImageINTEL) + +_SPIRV_OP(Capability, TernaryBitwiseFunctionINTEL) +_SPIRV_OP(Op, BitwiseFunctionINTEL) #undef _SPIRV_OP constexpr SourceLanguage SourceLanguagePython = diff --git a/llvm-spirv/test/extensions/INTEL/SPV_INTEL_ternary_bitwise_function/bitwise_function.ll b/llvm-spirv/test/extensions/INTEL/SPV_INTEL_ternary_bitwise_function/bitwise_function.ll new file mode 100644 index 0000000000000..d5b087b370e41 --- /dev/null +++ b/llvm-spirv/test/extensions/INTEL/SPV_INTEL_ternary_bitwise_function/bitwise_function.ll @@ -0,0 +1,78 @@ +; RUN: llvm-as %s -o %t.bc +; RUN: llvm-spirv %t.bc --spirv-ext=+SPV_INTEL_ternary_bitwise_function -o %t.spv +; RUN: llvm-spirv %t.spv --to-text -o %t.spt +; RUN: FileCheck < %t.spt %s --check-prefix=CHECK-SPIRV + +; RUN: llvm-spirv -r %t.spv -o %t.rev.bc +; RUN: llvm-dis < %t.rev.bc | FileCheck %s 
--check-prefix=CHECK-LLVM + +; RUN: not llvm-spirv %t.bc 2>&1 | FileCheck %s --check-prefix=CHECK-ERROR +; CHECK-ERROR: RequiresExtension: Feature requires the following SPIR-V extension: +; CHECK-ERROR-NEXT: SPV_INTEL_ternary_bitwise_function + +; CHECK-SPIRV-NOT: Name [[#]] "_Z28__spirv_BitwiseFunctionINTELiiij" +; CHECK-SPIRV-NOT: Name [[#]] "_Z28__spirv_BitwiseFunctionINTELDv4_iS_S_j" + +; CHECK-SPIRV-DAG: Capability TernaryBitwiseFunctionINTEL +; CHECK-SPIRV-DAG: Extension "SPV_INTEL_ternary_bitwise_function" + +; CHECK-SPIRV-DAG: TypeInt [[#TYPEINT:]] 32 0 +; CHECK-SPIRV-DAG: TypeVector [[#TYPEINTVEC4:]] [[#TYPEINT]] 4 +; CHECK-SPIRV-DAG: Constant [[#TYPEINT]] [[#ScalarLUT:]] 24 +; CHECK-SPIRV-DAG: Constant [[#TYPEINT]] [[#VecLUT:]] 42 + +; CHECK-SPIRV: Load [[#TYPEINT]] [[#ScalarA:]] +; CHECK-SPIRV: Load [[#TYPEINT]] [[#ScalarB:]] +; CHECK-SPIRV: Load [[#TYPEINT]] [[#ScalarC:]] +; CHECK-SPIRV: BitwiseFunctionINTEL [[#TYPEINT]] {{.*}} [[#ScalarA]] [[#ScalarB]] [[#ScalarC]] [[#ScalarLUT]] +; CHECK-SPIRV: Load [[#TYPEINTVEC4]] [[#VecA:]] +; CHECK-SPIRV: Load [[#TYPEINTVEC4]] [[#VecB:]] +; CHECK-SPIRV: Load [[#TYPEINTVEC4]] [[#VecC:]] +; CHECK-SPIRV: BitwiseFunctionINTEL [[#TYPEINTVEC4]] {{.*}} [[#VecA]] [[#VecB]] [[#VecC]] [[#VecLUT]] + +; CHECK-LLVM: %[[ScalarA:.*]] = load i32, ptr +; CHECK-LLVM: %[[ScalarB:.*]] = load i32, ptr +; CHECK-LLVM: %[[ScalarC:.*]] = load i32, ptr +; CHECK-LLVM: call spir_func i32 @_Z28__spirv_BitwiseFunctionINTELiiii(i32 %[[ScalarA]], i32 %[[ScalarB]], i32 %[[ScalarC]], i32 24) +; CHECK-LLVM: %[[VecA:.*]] = load <4 x i32>, ptr +; CHECK-LLVM: %[[VecB:.*]] = load <4 x i32>, ptr +; CHECK-LLVM: %[[VecC:.*]] = load <4 x i32>, ptr +; CHECK-LLVM: call spir_func <4 x i32> @_Z28__spirv_BitwiseFunctionINTELDv4_iS_S_i(<4 x i32> %[[VecA]], <4 x i32> %[[VecB]], <4 x i32> %[[VecC]], i32 42) + +target datalayout = "e-p:32:32-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024" +target triple = "spir" + +; Function 
Attrs: nounwind readnone +define spir_kernel void @fooScalar() { +entry: + %argA = alloca i32 + %argB = alloca i32 + %argC = alloca i32 + %A = load i32, ptr %argA + %B = load i32, ptr %argB + %C = load i32, ptr %argC + %res = call spir_func i32 @_Z28__spirv_BitwiseFunctionINTELiiii(i32 %A, i32 %B, i32 %C, i32 24) + ret void +} + +; Function Attrs: nounwind readnone +define spir_kernel void @fooVec() { +entry: + %argA = alloca <4 x i32> + %argB = alloca <4 x i32> + %argC = alloca <4 x i32> + %A = load <4 x i32>, ptr %argA + %B = load <4 x i32>, ptr %argB + %C = load <4 x i32>, ptr %argC + %res = call spir_func <4 x i32> @_Z28__spirv_BitwiseFunctionINTELDv4_iS_S_i(<4 x i32> %A, <4 x i32> %B, <4 x i32> %C, i32 42) + ret void +} + +declare dso_local spir_func i32 @_Z28__spirv_BitwiseFunctionINTELiiii(i32, i32, i32, i32) +declare dso_local spir_func <4 x i32> @_Z28__spirv_BitwiseFunctionINTELDv4_iS_S_i(<4 x i32>, <4 x i32>, <4 x i32>, i32) + +!llvm.module.flags = !{!0} +!opencl.spir.version = !{!1} + +!0 = !{i32 1, !"wchar_size", i32 4} +!1 = !{i32 1, i32 2} diff --git a/sycl/doc/extensions/experimental/sycl_ext_oneapi_ternary_bitwise.asciidoc b/sycl/doc/extensions/experimental/sycl_ext_oneapi_ternary_bitwise.asciidoc new file mode 100644 index 0000000000000..b6b370e3c2f46 --- /dev/null +++ b/sycl/doc/extensions/experimental/sycl_ext_oneapi_ternary_bitwise.asciidoc @@ -0,0 +1,115 @@ += sycl_ext_oneapi_ternary_bitwise + +:source-highlighter: coderay +:coderay-linenums-mode: table + +// This section needs to be after the document title. +:doctype: book +:toc2: +:toc: left +:encoding: utf-8 +:lang: en +:dpcpp: pass:[DPC++] +:endnote: —{nbsp}end{nbsp}note + +// Set the default source code type in this document to C++, +// for syntax highlighting purposes. This is needed because +// docbook uses c++ and html5 uses cpp. +:language: {basebackend@docbook:c++:cpp} + + +== Notice + +[%hardbreaks] +Copyright (C) 2025 Intel Corporation. All rights reserved. 
+ +Khronos(R) is a registered trademark and SYCL(TM) and SPIR(TM) are trademarks +of The Khronos Group Inc. OpenCL(TM) is a trademark of Apple Inc. used by +permission by Khronos. + + +== Contact + +To report problems with this extension, please open a new issue at: + +https://github.com/intel/llvm/issues + + +== Dependencies + +This extension is written against the SYCL 2020 revision 9 specification. All +references below to the "core SYCL specification" or to section numbers in the +SYCL specification refer to that revision. + + +== Status + +This is an experimental extension specification, intended to provide early +access to features and gather community feedback. Interfaces defined in this +specification are implemented in {dpcpp}, but they are not finalized and may +change incompatibly in future versions of {dpcpp} without prior notice. +*Shipping software products should not rely on APIs defined in this +specification.* + + +== Overview + +Some hardware offers efficient bitwise operations on three arguments. To expose +these bitwise operations in SYCL, this extension adds a new `ternary_bitwise` +function, where the bitwise operation computed is controlled through a +look-up table (LUT) index computed by applying the bitwise operation to a +predefined set of operands. + + +== Specification + +=== Feature test macro + +This extension provides a feature-test macro as described in the core SYCL +specification. An implementation supporting this extension must predefine the +macro `SYCL_EXT_ONEAPI_TERNARY_BITWISE` to one of the values defined in the table +below. Applications can test for the existence of this macro to determine if +the implementation supports this feature, or applications can test the macro's +value to determine which of the extension's features the implementation +supports. + +[%header,cols="1,5"] +|=== +|Value +|Description + +|1 +|The APIs of this experimental extension are not versioned, so the + feature-test macro always has this value. 
+|=== + +=== New `ternary_bitwise` function + + +|==== +a| +[frame=all,grid=none] +!==== +a! +[source] +---- +namespace sycl::ext::oneapi::experimental { + +template T ternary_bitwise(T A, T B, T C) + +} // namespace sycl::ext::oneapi::experimental +---- +!==== + +_Constraints:_ The type `T` must be a generic integer type, as listed in section +link:https://registry.khronos.org/SYCL/specs/sycl-2020/html/sycl-2020.html#sec:integer-functions[4.17.7. Integer functions] +of the SYCL 2020 specification. + +_Returns:_ The result of the bitwise operation identified by `LUTIndex` applied +to `A`, `B` and `C`. For any `I` it holds that +`ternary_bitwise<I>(0xAA, 0xCC, 0xF0) == I` evaluates to `true`, so the +`LUTIndex` for a bitwise operation can be determined by applying the bitwise +operation to `0xAA`, `0xCC` and `0xF0` in place of `A`, `B` and `C` +respectively. +|==== + diff --git a/sycl/include/sycl/ext/oneapi/experimental/ternary_bitwise.hpp b/sycl/include/sycl/ext/oneapi/experimental/ternary_bitwise.hpp new file mode 100644 index 0000000000000..6e83cb0c4614e --- /dev/null +++ b/sycl/include/sycl/ext/oneapi/experimental/ternary_bitwise.hpp @@ -0,0 +1,571 @@ +//==- ternary_bitwise.hpp --- SYCL extension for ternary bitwise functions -==// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include +#include +#include +#include + +#include +#include + +#pragma once + +namespace sycl { +inline namespace _V1 { +namespace detail { +#if !defined(__SYCL_DEVICE_ONLY__) || defined(__NVPTX__) || defined(__AMDGCN__) +// Host implementation of the ternary bitwise operation LUT.
+template T applyTernaryBitwise(T A, T B, T C) {
+  switch (LUTIndex) { // LUTIndex == op(0xAA, 0xCC, 0xF0)
+  case 0:
+    return T(0);
+  case 1:
+    return (~(~C & B) & (~C & ~A));
+  case 2:
+    return (~(~C & B) & (~C & A));
+  case 3:
+    return (~C & ~B);
+  case 4:
+    return ~(~(~C & B) | (~C & A));
+  case 5:
+    return (~C & ~A);
+  case 6:
+    return (~(~C & B) ^ ~(~C & A));
+  case 7:
+    return ((~C & ~B) | (~C & ~A));
+  case 8:
+    return ~(~(~C & B) | ~(~C & A));
+  case 9:
+    return ~(~(~C & B) ^ (~C & ~A));
+  case 10:
+    return (~C & A);
+  case 11:
+    return ((~C & ~B) | (~C & A));
+  case 12:
+    return (~C & B);
+  case 13:
+    return ~(~(~C & B) & ~(~C & ~A));
+  case 14:
+    return ~(~(~C & B) & ~(~C & A));
+  case 15:
+    return ~C;
+  case 16:
+    return (~(~C | B) & ~(~C | A));
+  case 17:
+    return (~B & ~A);
+  case 18:
+    return ~(~(~C | B) ^ ~(~B & A));
+  case 19:
+    return ((~C & ~B) | (~B & ~A));
+  case 20:
+    return ~(~(~C ^ ~B) | A);
+  case 21:
+    return ((~C | ~B) & ~A);
+  case 22:
+    return ~((~(~C ^ ~B) & ~(~C & A)) | ~(~(~C ^ ~B) ^ ~A));
+  case 23:
+    return ~(((~C & ~B) | ~(~C ^ ~A)) ^ ~(~(~C | B) | ~(~C | ~A)));
+  case 24:
+    return ~(~(~C ^ ~B) | ~(~C ^ ~A));
+  case 25:
+    return ((~C | ~B) & ~(~B ^ ~A));
+  case 26:
+    return ((~C | ~B) & (~C ^ ~A));
+  case 27:
+    return ~(~(~C & A) ^ (~B & ~A));
+  case 28:
+    return ~(~(~C ^ ~B) | (~B & A));
+  case 29:
+    return ~(~(~C & B) ^ (~B & ~A));
+  case 30:
+    return (~(~C ^ ~B) ^ ~(~B & A));
+  case 31:
+    return ((~C | ~B) & (~C | ~A));
+  case 32:
+    return (~(~C | B) & (~C | A));
+  case 33:
+    return (~(~C | B) ^ (~B & ~A));
+  case 34:
+    return (~B & A);
+  case 35:
+    return ((~C & ~B) | (~B & A));
+  case 36:
+    return ~(~(~C ^ ~B) | (~C ^ ~A));
+  case 37:
+    return ((~C | ~B) & ~(~C ^ ~A));
+  case 38:
+    return ((~C | ~B) & (~B ^ ~A));
+  case 39:
+    return (~(~C | A) ^ (~B | ~A));
+  case 40:
+    return ~(~(~C ^ ~B) | ~A);
+  case 41:
+    return (~(~(~C ^ ~B) & (~C | A)) ^ ~A);
+  case 42:
+    return ((~C | ~B) & A);
+  case 43:
+    return ~((~(~C ^ ~B) & ~(~C & A)) ^ ((~C | ~B) & ~A));
+  case 44:
+    return ~(~(~C ^ ~B) | ~(~C | A));
+  case 45:
+    return ~(~(~C ^ ~B) ^ (~B & ~A));
+  case 46:
+    return (~(~C & B) ^ ~(~B & A));
+  case 47:
+    return ((~C | ~B) & (~C | A));
+  case 48:
+    return ~(~C | B);
+  case 49:
+    return (~(~C | B) | (~B & ~A));
+  case 50:
+    return (~(~C | B) | (~B & A));
+  case 51:
+    return ~B;
+  case 52:
+    return ~(~(~C ^ ~B) | (~C & A));
+  case 53:
+    return (~(~C | B) ^ (~C & ~A));
+  case 54:
+    return (~(~C ^ ~B) ^ ~(~C & A));
+  case 55:
+    return ((~C | ~B) & (~B | ~A));
+  case 56:
+    return ~(~(~C ^ ~B) | ~(~B | A));
+  case 57:
+    return ~(~(~C ^ ~B) ^ (~C & ~A));
+  case 58:
+    return ~(~(~C | B) ^ ~(~C & A));
+  case 59:
+    return ((~C | ~B) & (~B | A));
+  case 60:
+    return (~C ^ ~B);
+  case 61:
+    return ~(~(~C ^ ~B) & ~(~C & ~A));
+  case 62:
+    return ~(~(~C ^ ~B) & ~(~C & A));
+  case 63:
+    return (~C | ~B);
+  case 64:
+    return (~(~C & B) & ~(~B | A));
+  case 65:
+    return (~(~C ^ ~B) & ~A);
+  case 66:
+    return (~(~C ^ ~B) & (~C ^ ~A));
+  case 67:
+    return (~(~C ^ ~B) & (~C | ~A));
+  case 68:
+    return ~(~B | A);
+  case 69:
+    return ~(~(~C | B) | A);
+  case 70:
+    return ~(~(~C | B) | ~(~B ^ ~A));
+  case 71:
+    return (~(~C | B) ^ (~B | ~A));
+  case 72:
+    return ~(~(~C & B) ^ ~(~B | A));
+  case 73:
+    return ~((~(~C ^ ~B) & ~(~C & A)) ^ (~B | ~A));
+  case 74:
+    return ~(~(~C | B) | ~(~C ^ ~A));
+  case 75:
+    return ~(~(~C ^ ~B) ^ (~B | ~A));
+  case 76:
+    return ~(~(~C & B) & (~B | A));
+  case 77:
+    return ~((~(~C ^ ~B) & ~(~C & A)) ^ (~(~C & B) & (~B | ~A)));
+  case 78:
+    return ~(~(~C & A) ^ ~(~B | A));
+  case 79:
+    return ~(~(~C | B) | ~(~C | ~A));
+  case 80:
+    return ~(~C | A);
+  case 81:
+    return (~(~C & B) & ~A);
+  case 82:
+    return (~(~C & B) & (~C ^ ~A));
+  case 83:
+    return ~(~(~C & B) ^ (~C | ~A));
+  case 84:
+    return ~((~C & ~B) | A);
+  case 85:
+    return ~A;
+  case 86:
+    return (~(~C & B) ^ ~(~C ^ ~A));
+  case 87:
+    return ((~C & ~B) | ~A);
+  case 88:
+    return ~((~C & ~B) | ~(~C ^ ~A));
+  case 89:
+    return ~(~(~C & B) ^ ~A);
+  case 90:
+    return (~C ^ ~A);
+  case 91:
+    return ((~C & ~B) | (~C ^ ~A));
+  case 92:
+    return ~(~(~C & B) ^ ~(~C | A));
+  case 93:
+    return ~(~(~C & B) & A);
+  case 94:
+    return ~(~(~C & B) & ~(~C ^ ~A));
+  case 95:
+    return (~C | ~A);
+  case 96:
+    return (~(~C | B) ^ ~(~C | A));
+  case 97:
+    return ~((~(~C ^ ~B) & ~(~C & A)) ^ (~C | ~A));
+  case 98:
+    return (~(~C & B) & (~B ^ ~A));
+  case 99:
+    return ~(~(~C ^ ~B) ^ (~C | ~A));
+  case 100:
+    return ~((~C & ~B) | ~(~B ^ ~A));
+  case 101:
+    return (~(~C | B) ^ ~A);
+  case 102:
+    return (~B ^ ~A);
+  case 103:
+    return ((~C & ~B) | (~B ^ ~A));
+  case 104:
+    return ((~(~C ^ ~B) & ~(~C & A)) ^ ~(~(~C & B) ^ ~(~C ^ ~A)));
+  case 105:
+    return ~(~(~C ^ ~B) ^ ~A);
+  case 106:
+    return (~(~C & B) ^ ~(~B ^ ~A));
+  case 107:
+    return ~((~(~C ^ ~B) & ~(~C & A)) ^ ~A);
+  case 108:
+    return ~(~(~C ^ ~B) ^ ~(~C | A));
+  case 109:
+    return ~((~(~C ^ ~B) & ~(~C & A)) ^ ~(~(~C & B) ^ (~C | ~A)));
+  case 110:
+    return ~(~(~C & B) & ~(~B ^ ~A));
+  case 111:
+    return (~(~C | B) ^ (~C | ~A));
+  case 112:
+    return (~(~C | B) | ~(~C | A));
+  case 113:
+    return ~((~(~C ^ ~B) & ~(~C & A)) ^ ~(~(~C | B) | ~(~C | ~A)));
+  case 114:
+    return ~(~(~C | A) ^ ~(~B & A));
+  case 115:
+    return (~(~C & B) & (~B | ~A));
+  case 116:
+    return (~(~C | B) ^ ~(~B | A));
+  case 117:
+    return (~(~C | B) | ~A);
+  case 118:
+    return (~(~C | B) | (~B ^ ~A));
+  case 119:
+    return (~B | ~A);
+  case 120:
+    return ~(~(~C ^ ~B) ^ ~(~B | A));
+  case 121:
+    return ((~(~C ^ ~B) & ~(~C & A)) ^ ~(~(~C | B) ^ (~B | ~A)));
+  case 122:
+    return (~(~C | B) | (~C ^ ~A));
+  case 123:
+    return ~(~(~C & B) ^ (~B | ~A));
+  case 124:
+    return ~(~(~C ^ ~B) & (~C | A));
+  case 125:
+    return ~(~(~C ^ ~B) & A);
+  case 126:
+    return ~(~(~C ^ ~B) & ~(~C ^ ~A));
+  case 127:
+    return ~(~(~C & B) & ~(~B | ~A));
+  case 128:
+    return (~(~C & B) & ~(~B | ~A));
+  case 129:
+    return (~(~C ^ ~B) & ~(~C ^ ~A));
+  case 130:
+    return (~(~C ^ ~B) & A);
+  case 131:
+    return (~(~C ^ ~B) & (~C | A));
+  case 132:
+    return (~(~C & B) ^ (~B | ~A));
+  case 133:
+    return ~(~(~C | B) | (~C ^ ~A));
+  case 134:
+    return ~((~(~C ^ ~B) & ~(~C & A)) ^ ~(~(~C | B) ^ (~B | ~A)));
+  case 135:
+    return (~(~C ^ ~B) ^ ~(~B | A));
+  case 136:
+    return ~(~B | ~A);
+  case 137:
+    return ~(~(~C | B) | (~B ^ ~A));
+  case 138:
+    return ~(~(~C | B) | ~A);
+  case 139:
+    return ~(~(~C | B) ^ ~(~B | A));
+  case 140:
+    return ~(~(~C & B) & (~B | ~A));
+  case 141:
+    return (~(~C | A) ^ ~(~B & A));
+  case 142:
+    return ((~(~C ^ ~B) & ~(~C & A)) ^ ~(~(~C | B) | ~(~C | ~A)));
+  case 143:
+    return ~(~(~C | B) | ~(~C | A));
+  case 144:
+    return ~(~(~C | B) ^ (~C | ~A));
+  case 145:
+    return (~(~C & B) & ~(~B ^ ~A));
+  case 146:
+    return ((~(~C ^ ~B) & ~(~C & A)) ^ ~(~(~C & B) ^ (~C | ~A)));
+  case 147:
+    return (~(~C ^ ~B) ^ ~(~C | A));
+  case 148:
+    return ((~(~C ^ ~B) & ~(~C & A)) ^ ~A);
+  case 149:
+    return ~(~(~C & B) ^ ~(~B ^ ~A));
+  case 150:
+    return (~(~C ^ ~B) ^ ~A);
+  case 151:
+    return ~((~(~C ^ ~B) & ~(~C & A)) ^ ~(~(~C & B) ^ ~(~C ^ ~A)));
+  case 152:
+    return ~((~C & ~B) | (~B ^ ~A));
+  case 153:
+    return ~(~B ^ ~A);
+  case 154:
+    return ~(~(~C | B) ^ ~A);
+  case 155:
+    return ((~C & ~B) | ~(~B ^ ~A));
+  case 156:
+    return (~(~C ^ ~B) ^ (~C | ~A));
+  case 157:
+    return ~(~(~C & B) & (~B ^ ~A));
+  case 158:
+    return ((~(~C ^ ~B) & ~(~C & A)) ^ (~C | ~A));
+  case 159:
+    return ~(~(~C | B) ^ ~(~C | A));
+  case 160:
+    return ~(~C | ~A);
+  case 161:
+    return (~(~C & B) & ~(~C ^ ~A));
+  case 162:
+    return (~(~C & B) & A);
+  case 163:
+    return (~(~C & B) ^ ~(~C | A));
+  case 164:
+    return ~((~C & ~B) | (~C ^ ~A));
+  case 165:
+    return ~(~C ^ ~A);
+  case 166:
+    return (~(~C & B) ^ ~A);
+  case 167:
+    return ((~C & ~B) | ~(~C ^ ~A));
+  case 168:
+    return ~((~C & ~B) | ~A);
+  case 169:
+    return ~(~(~C & B) ^ ~(~C ^ ~A));
+  case 170:
+    return A;
+  case 171:
+    return ((~C & ~B) | A);
+  case 172:
+    return (~(~C & B) ^ (~C | ~A));
+  case 173:
+    return ~(~(~C & B) & (~C ^ ~A));
+  case 174:
+    return ~(~(~C & B) & ~A);
+  case 175:
+    return (~C | A);
+  case 176:
+    return (~(~C | B) | ~(~C | ~A));
+  case 177:
+    return (~(~C & A) ^ ~(~B | A));
+  case 178:
+    return ((~(~C ^ ~B) & ~(~C & A)) ^ (~(~C & B) & (~B | ~A)));
+  case 179:
+    return (~(~C & B) & (~B | A));
+  case 180:
+    return (~(~C ^ ~B) ^ (~B | ~A));
+  case 181:
+    return (~(~C | B) | ~(~C ^ ~A));
+  case 182:
+    return ((~(~C ^ ~B) & ~(~C & A)) ^ (~B | ~A));
+  case 183:
+    return (~(~C & B) ^ ~(~B | A));
+  case 184:
+    return ~(~(~C | B) ^ (~B | ~A));
+  case 185:
+    return (~(~C | B) | ~(~B ^ ~A));
+  case 186:
+    return (~(~C | B) | A);
+  case 187:
+    return (~B | A);
+  case 188:
+    return ~(~(~C ^ ~B) & (~C | ~A));
+  case 189:
+    return ~(~(~C ^ ~B) & (~C ^ ~A));
+  case 190:
+    return ~(~(~C ^ ~B) & ~A);
+  case 191:
+    return ~(~(~C & B) & ~(~B | A));
+  case 192:
+    return ~(~C | ~B);
+  case 193:
+    return (~(~C ^ ~B) & ~(~C & A));
+  case 194:
+    return (~(~C ^ ~B) & ~(~C & ~A));
+  case 195:
+    return ~(~C ^ ~B);
+  case 196:
+    return ~((~C | ~B) & (~B | A));
+  case 197:
+    return (~(~C | B) ^ ~(~C & A));
+  case 198:
+    return (~(~C ^ ~B) ^ (~C & ~A));
+  case 199:
+    return (~(~C ^ ~B) | ~(~B | A));
+  case 200:
+    return ~((~C | ~B) & (~B | ~A));
+  case 201:
+    return ~(~(~C ^ ~B) ^ ~(~C & A));
+  case 202:
+    return ~(~(~C | B) ^ (~C & ~A));
+  case 203:
+    return (~(~C ^ ~B) | (~C & A));
+  case 204:
+    return B;
+  case 205:
+    return ~(~(~C | B) | (~B & A));
+  case 206:
+    return ~(~(~C | B) | (~B & ~A));
+  case 207:
+    return (~C | B);
+  case 208:
+    return ~((~C | ~B) & (~C | A));
+  case 209:
+    return ~(~(~C & B) ^ ~(~B & A));
+  case 210:
+    return (~(~C ^ ~B) ^ (~B & ~A));
+  case 211:
+    return (~(~C ^ ~B) | ~(~C | A));
+  case 212:
+    return ((~(~C ^ ~B) & ~(~C & A)) ^ ((~C | ~B) & ~A));
+  case 213:
+    return ~((~C | ~B) & A);
+  case 214:
+    return ~(~(~(~C ^ ~B) & (~C | A)) ^ ~A);
+  case 215:
+    return (~(~C ^ ~B) | ~A);
+  case 216:
+    return ~(~(~C | A) ^ (~B | ~A));
+  case 217:
+    return ~((~C | ~B) & (~B ^ ~A));
+  case 218:
+    return ~((~C | ~B) & ~(~C ^ ~A));
+  case 219:
+    return (~(~C ^ ~B) | (~C ^ ~A));
+  case 220:
+    return ~((~C & ~B) | (~B & A));
+  case 221:
+    return ~(~B & A);
+  case 222:
+    return ~(~(~C | B) ^ (~B & ~A));
+  case 223:
+    return ~(~(~C | B) & (~C | A));
+  case 224:
+    return ~((~C | ~B) & (~C | ~A));
+  case 225:
+    return ~(~(~C ^ ~B) ^ ~(~B & A));
+  case 226:
+    return (~(~C & B) ^ (~B & ~A));
+  case 227:
+    return (~(~C ^ ~B) | (~B & A));
+  case 228:
+    return (~(~C & A) ^ (~B & ~A));
+  case 229:
+    return ~((~C | ~B) & (~C ^ ~A));
+  case 230:
+    return ~((~C | ~B) & ~(~B ^ ~A));
+  case 231:
+    return (~(~C ^ ~B) | ~(~C ^ ~A));
+  case 232:
+    return (((~C & ~B) | ~(~C ^ ~A)) ^ ~(~(~C | B) | ~(~C | ~A)));
+  case 233:
+    return ((~(~C ^ ~B) & ~(~C & A)) | ~(~(~C ^ ~B) ^ ~A));
+  case 234:
+    return ~((~C | ~B) & ~A);
+  case 235:
+    return (~(~C ^ ~B) | A);
+  case 236:
+    return ~((~C & ~B) | (~B & ~A));
+  case 237:
+    return (~(~C | B) ^ ~(~B & A));
+  case 238:
+    return ~(~B & ~A);
+  case 239:
+    return ~(~(~C | B) & ~(~C | A));
+  case 240:
+    return C;
+  case 241:
+    return (~(~C & B) & ~(~C & A));
+  case 242:
+    return (~(~C & B) & ~(~C & ~A));
+  case 243:
+    return ~(~C & B);
+  case 244:
+    return ~((~C & ~B) | (~C & A));
+  case 245:
+    return ~(~C & A);
+  case 246:
+    return (~(~C & B) ^ (~C & ~A));
+  case 247:
+    return (~(~C & B) | ~(~C & A));
+  case 248:
+    return ~((~C & ~B) | (~C & ~A));
+  case 249:
+    return ~(~(~C & B) ^ ~(~C & A));
+  case 250:
+    return ~(~C & ~A);
+  case 251:
+    return (~(~C & B) | (~C & A));
+  case 252:
+    return ~(~C & ~B);
+  case 253:
+    return ~(~(~C & B) & (~C & A));
+  case 254:
+    return ~(~(~C & B) & (~C & ~A));
+  case 255:
+    return ~T(0); // every input combination maps to 1: all bits set
+  }
+}
+#endif
+} // namespace detail
+
+namespace ext::oneapi::experimental {
+
+template
+sycl::detail::builtin_enable_integer_t ternary_bitwise(T A, T B, T C) {
+  if constexpr (sycl::detail::is_marray_v) {
+    return sycl::detail::builtin_marray_impl(
+        [](auto SA, auto SB, auto SC) {
+          return ternary_bitwise(SA, SB, SC);
+        },
+        A, B, C);
+  } else {
+#if defined(__SYCL_DEVICE_ONLY__) && !defined(__NVPTX__) && !defined(__AMDGCN__)
+    // TODO: Implement __spirv_BitwiseFunctionINTEL for NVPTX and AMDGCN.
+    return __spirv_BitwiseFunctionINTEL(
+        sycl::detail::simplify_if_swizzle_t{A},
+        sycl::detail::simplify_if_swizzle_t{B},
+        sycl::detail::simplify_if_swizzle_t{C},
+        static_cast(LUTIndex));
+#else
+    return sycl::detail::applyTernaryBitwise(
+        sycl::detail::simplify_if_swizzle_t{A},
+        sycl::detail::simplify_if_swizzle_t{B},
+        sycl::detail::simplify_if_swizzle_t{C});
+#endif
+  }
+}
+} // namespace ext::oneapi::experimental
+} // namespace _V1
+} // namespace sycl
diff --git a/sycl/include/sycl/sycl.hpp b/sycl/include/sycl/sycl.hpp
index bbedde0f2e188..7f1323d8f9448 100644
--- a/sycl/include/sycl/sycl.hpp
+++ b/sycl/include/sycl/sycl.hpp
@@ -106,6 +106,7 @@
 #include
 #include
 #include
+#include
 #include
 #include
 #include
diff --git a/sycl/source/feature_test.hpp.in b/sycl/source/feature_test.hpp.in
index cccc8d462be50..1a736ccbdd1b0 100644
--- a/sycl/source/feature_test.hpp.in
+++ b/sycl/source/feature_test.hpp.in
@@ -116,6 +116,7 @@ inline namespace _V1 {
 #define SYCL_EXT_ONEAPI_NUM_COMPUTE_UNITS 1
 #define SYCL_EXT_ONEAPI_DEVICE_IMAGE_BACKEND_CONTENT 1
 #define SYCL_EXT_ONEAPI_CURRENT_DEVICE 1
+#define SYCL_EXT_ONEAPI_TERNARY_BITWISE 1
 #define SYCL_KHR_FREE_FUNCTION_COMMANDS 1
 // In progress yet
 #define SYCL_EXT_ONEAPI_ATOMIC16 0
diff --git a/sycl/test-e2e/Experimental/ternary_bitwise.cpp b/sycl/test-e2e/Experimental/ternary_bitwise.cpp
new file mode 100644
index 0000000000000..c19cd8c973339
--- /dev/null
+++ b/sycl/test-e2e/Experimental/ternary_bitwise.cpp
@@ -0,0 +1,158 @@
+// REQUIRES: aspect-usm_shared_allocations
+
+// XFAIL: opencl && cpu
+// XFAIL-TRACKER: TODO
+
+// RUN: %{build} -o %t.out
+// RUN: %{run} %t.out
+
+// Checks the results of the ternary bitwise function extension.
+ +#include +#include +#include + +#include + +namespace syclex = sycl::ext::oneapi::experimental; + +constexpr size_t NumOps = 256; + +template +std::array apply(T *A, T *B, T *C, std::index_sequence) { + return std::array{ + syclex::ternary_bitwise(A[Is], B[Is], C[Is])...}; +} + +template void fillRandom(T *Dest) { + std::random_device RDev; + std::mt19937 RNG(RDev()); + std::uniform_int_distribution Dist; + for (size_t I = 0; I < NumOps; ++I) + Dest[I] = Dist(RNG); +} + +template void fillRandom(sycl::vec *Dest) { + std::random_device RDev; + std::mt19937 RNG(RDev()); + std::uniform_int_distribution Dist; + for (size_t I = 0; I < NumOps; ++I) + for (size_t J = 0; J < N; ++J) + Dest[I][J] = Dist(RNG); +} + +template void fillRandom(sycl::marray *Dest) { + std::random_device RDev; + std::mt19937 RNG(RDev()); + std::uniform_int_distribution Dist; + for (size_t I = 0; I < NumOps; ++I) + for (size_t J = 0; J < N; ++J) + Dest[I][J] = Dist(RNG); +} + +bool allTrue(bool B) { return B; } + +template bool allTrue(sycl::vec B) { + for (size_t I = 0; I < N; ++I) + if (!static_cast(B[I])) + return false; + return true; +} + +template bool allTrue(sycl::marray B) { + return std::all_of(B.begin(), B.end(), [](bool b) { return b; }); +} + +template std::string toString(T X) { return std::to_string(X); } + +template std::string toString(sycl::vec X) { + std::string Result = "{" + toString(X[0]); + for (size_t I = 1; I < N; ++I) + Result += "," + toString(X[I]); + return Result + "}"; +} + +template std::string toString(sycl::marray X) { + std::string Result = "{" + toString(X[0]); + for (size_t I = 1; I < N; ++I) + Result += "," + toString(X[I]); + return Result + "}"; +} + +template int Check(sycl::queue &Q, std::string_view TName) { + constexpr auto IdxSeq = std::make_index_sequence{}; + + int Failed = 0; + + T *A = sycl::malloc_shared(NumOps, Q); + T *B = sycl::malloc_shared(NumOps, Q); + T *C = sycl::malloc_shared(NumOps, Q); + auto *Out = sycl::malloc_shared>(1, Q); + + 
fillRandom(A); + fillRandom(B); + fillRandom(C); + + Q.single_task([=]() { *Out = apply(A, B, C, IdxSeq); }).wait_and_throw(); + + std::array DevResults = *Out; + std::array HostResults = apply(A, B, C, IdxSeq); + + for (size_t I = 0; I < NumOps; ++I) { + if (allTrue(DevResults[I] != HostResults[I])) { + std::cout << "Failed check for type " << TName << " at index " << I + << ": " << toString(DevResults[I]) + << " != " << toString(HostResults[I]) << std::endl; + ++Failed; + } + } + + sycl::free(A, Q); + sycl::free(B, Q); + sycl::free(C, Q); + sycl::free(Out, Q); + + return Failed; +} + +int main() { + sycl::queue Q; + + int Failed = 0; +#define CHECK(...) Failed += Check<__VA_ARGS__>(Q, #__VA_ARGS__); + CHECK(char) + CHECK(signed char) + CHECK(unsigned char) + CHECK(short) + CHECK(unsigned short) + CHECK(int) + CHECK(unsigned int) + CHECK(long) + CHECK(unsigned long) + CHECK(sycl::vec) + CHECK(sycl::vec) + CHECK(sycl::vec) + CHECK(sycl::vec) + CHECK(sycl::vec) + CHECK(sycl::vec) + CHECK(sycl::vec) + CHECK(sycl::vec) + CHECK(sycl::vec) + CHECK(sycl::vec) + CHECK(sycl::vec) + CHECK(sycl::vec) + CHECK(sycl::vec) + CHECK(sycl::vec) + CHECK(sycl::vec) + CHECK(sycl::vec) + CHECK(sycl::marray) + CHECK(sycl::marray) + CHECK(sycl::marray) + CHECK(sycl::marray) + CHECK(sycl::marray) + CHECK(sycl::marray) + CHECK(sycl::marray) + CHECK(sycl::marray) + CHECK(sycl::marray) + return Failed; +} From e46f314e09fffd77d918e9386cd3e5f07104d3b0 Mon Sep 17 00:00:00 2001 From: "Larsen, Steffen" Date: Fri, 4 Apr 2025 01:10:25 -0700 Subject: [PATCH 2/2] Cut down tests and make a single kernel Signed-off-by: Larsen, Steffen --- .../test-e2e/Experimental/ternary_bitwise.cpp | 170 +++++++++++------- 1 file changed, 105 insertions(+), 65 deletions(-) diff --git a/sycl/test-e2e/Experimental/ternary_bitwise.cpp b/sycl/test-e2e/Experimental/ternary_bitwise.cpp index c19cd8c973339..c00f7e208c1ad 100644 --- a/sycl/test-e2e/Experimental/ternary_bitwise.cpp +++ 
b/sycl/test-e2e/Experimental/ternary_bitwise.cpp @@ -17,24 +17,26 @@ namespace syclex = sycl::ext::oneapi::experimental; constexpr size_t NumOps = 256; +constexpr auto IdxSeq = std::make_index_sequence{}; + +static std::random_device RDev; +static std::mt19937 RNG(RDev()); template -std::array apply(T *A, T *B, T *C, std::index_sequence) { +std::array apply(std::array &A, std::array &B, + std::array &C, + std::index_sequence) { return std::array{ syclex::ternary_bitwise(A[Is], B[Is], C[Is])...}; } template void fillRandom(T *Dest) { - std::random_device RDev; - std::mt19937 RNG(RDev()); std::uniform_int_distribution Dist; for (size_t I = 0; I < NumOps; ++I) Dest[I] = Dist(RNG); } template void fillRandom(sycl::vec *Dest) { - std::random_device RDev; - std::mt19937 RNG(RDev()); std::uniform_int_distribution Dist; for (size_t I = 0; I < NumOps; ++I) for (size_t J = 0; J < N; ++J) @@ -42,8 +44,6 @@ template void fillRandom(sycl::vec *Dest) { } template void fillRandom(sycl::marray *Dest) { - std::random_device RDev; - std::mt19937 RNG(RDev()); std::uniform_int_distribution Dist; for (size_t I = 0; I < NumOps; ++I) for (size_t J = 0; J < N; ++J) @@ -79,25 +79,26 @@ template std::string toString(sycl::marray X) { return Result + "}"; } -template int Check(sycl::queue &Q, std::string_view TName) { - constexpr auto IdxSeq = std::make_index_sequence{}; - - int Failed = 0; - - T *A = sycl::malloc_shared(NumOps, Q); - T *B = sycl::malloc_shared(NumOps, Q); - T *C = sycl::malloc_shared(NumOps, Q); - auto *Out = sycl::malloc_shared>(1, Q); - - fillRandom(A); - fillRandom(B); - fillRandom(C); - - Q.single_task([=]() { *Out = apply(A, B, C, IdxSeq); }).wait_and_throw(); +template struct MemObj { + std::array A; + std::array B; + std::array C; + std::array Out; +}; + +template MemObj *createMem(sycl::queue &Q) { + MemObj *Obj = sycl::malloc_shared>(NumOps, Q); + fillRandom(Obj->A.data()); + fillRandom(Obj->B.data()); + fillRandom(Obj->C.data()); + return Obj; +} - std::array 
DevResults = *Out; - std::array HostResults = apply(A, B, C, IdxSeq); +template int checkResult(MemObj &Mem, std::string_view TName) { + std::array &DevResults = Mem.Out; + std::array HostResults = apply(Mem.A, Mem.B, Mem.C, IdxSeq); + int Failed = 0; for (size_t I = 0; I < NumOps; ++I) { if (allTrue(DevResults[I] != HostResults[I])) { std::cout << "Failed check for type " << TName << " at index " << I @@ -106,53 +107,92 @@ template int Check(sycl::queue &Q, std::string_view TName) { ++Failed; } } - - sycl::free(A, Q); - sycl::free(B, Q); - sycl::free(C, Q); - sycl::free(Out, Q); - return Failed; } int main() { sycl::queue Q; + auto *CharObj = createMem(Q); + auto *SCharObj = createMem(Q); + auto *UCharObj = createMem(Q); + auto *ShortObj = createMem(Q); + auto *UShortObj = createMem(Q); + auto *IntObj = createMem(Q); + auto *UIntObj = createMem(Q); + auto *LongObj = createMem(Q); + auto *ULongObj = createMem(Q); + auto *SChar2Obj = createMem>(Q); + auto *UShort8Obj = createMem>(Q); + auto *Int2Obj = createMem>(Q); + auto *ULong8Obj = createMem>(Q); + auto *CharMarrayObj = createMem>(Q); + auto *UShortMarrayObj = createMem>(Q); + auto *IntMarrayObj = createMem>(Q); + auto *ULongMarrayObj = createMem>(Q); + + Q.parallel_for(17, [=](sycl::id<1> Idx) { + // We let the ID determine which memory object the work-item processes. + size_t WorkCounter = 0; +#define APPLY(MEM_OBJ) \ + if ((WorkCounter++) == Idx[0]) \ + MEM_OBJ->Out = apply(MEM_OBJ->A, MEM_OBJ->B, MEM_OBJ->C, IdxSeq); + APPLY(CharObj) + APPLY(SCharObj) + APPLY(UCharObj) + APPLY(ShortObj) + APPLY(UShortObj) + APPLY(IntObj) + APPLY(UIntObj) + APPLY(LongObj) + APPLY(ULongObj) + APPLY(SChar2Obj) + APPLY(UShort8Obj) + APPLY(Int2Obj) + APPLY(ULong8Obj) + APPLY(CharMarrayObj) + APPLY(UShortMarrayObj) + APPLY(IntMarrayObj) + APPLY(ULongMarrayObj) + }).wait_and_throw(); + int Failed = 0; -#define CHECK(...) 
Failed += Check<__VA_ARGS__>(Q, #__VA_ARGS__); - CHECK(char) - CHECK(signed char) - CHECK(unsigned char) - CHECK(short) - CHECK(unsigned short) - CHECK(int) - CHECK(unsigned int) - CHECK(long) - CHECK(unsigned long) - CHECK(sycl::vec) - CHECK(sycl::vec) - CHECK(sycl::vec) - CHECK(sycl::vec) - CHECK(sycl::vec) - CHECK(sycl::vec) - CHECK(sycl::vec) - CHECK(sycl::vec) - CHECK(sycl::vec) - CHECK(sycl::vec) - CHECK(sycl::vec) - CHECK(sycl::vec) - CHECK(sycl::vec) - CHECK(sycl::vec) - CHECK(sycl::vec) - CHECK(sycl::vec) - CHECK(sycl::marray) - CHECK(sycl::marray) - CHECK(sycl::marray) - CHECK(sycl::marray) - CHECK(sycl::marray) - CHECK(sycl::marray) - CHECK(sycl::marray) - CHECK(sycl::marray) - CHECK(sycl::marray) + + Failed += checkResult(*CharObj, "char"); + Failed += checkResult(*SCharObj, "signed char"); + Failed += checkResult(*UCharObj, "unsigned char"); + Failed += checkResult(*ShortObj, "short"); + Failed += checkResult(*UShortObj, "unsigned short"); + Failed += checkResult(*IntObj, "int"); + Failed += checkResult(*UIntObj, "unsigned int"); + Failed += checkResult(*LongObj, "long"); + Failed += checkResult(*ULongObj, "unsigned long"); + Failed += checkResult(*SChar2Obj, "sycl::vec"); + Failed += checkResult(*UShort8Obj, "sycl::vec"); + Failed += checkResult(*Int2Obj, "sycl::vec"); + Failed += checkResult(*ULong8Obj, "sycl::vec"); + Failed += checkResult(*CharMarrayObj, "sycl::marray"); + Failed += checkResult(*UShortMarrayObj, "sycl::marray"); + Failed += checkResult(*IntMarrayObj, "sycl::marray"); + Failed += checkResult(*ULongMarrayObj, "sycl::marray"); + + sycl::free(CharObj, Q); + sycl::free(SCharObj, Q); + sycl::free(UCharObj, Q); + sycl::free(ShortObj, Q); + sycl::free(UShortObj, Q); + sycl::free(IntObj, Q); + sycl::free(UIntObj, Q); + sycl::free(LongObj, Q); + sycl::free(ULongObj, Q); + sycl::free(SChar2Obj, Q); + sycl::free(UShort8Obj, Q); + sycl::free(Int2Obj, Q); + sycl::free(ULong8Obj, Q); + sycl::free(CharMarrayObj, Q); + sycl::free(UShortMarrayObj, 
Q); + sycl::free(IntMarrayObj, Q); + sycl::free(ULongMarrayObj, Q); + return Failed; }