From 1b9ab0ad27e2cf06beaacd5d42d4cf1447e104b5 Mon Sep 17 00:00:00 2001 From: Sven van Haastregt Date: Wed, 4 Sep 2024 16:27:59 +0200 Subject: [PATCH] [Backport to 19] SPIRVReader: handle direct types with CooperativeMatrixLengthKHR (#2695) (#2707) Translation of the attached test would currently fail due to the SPIRVReader attempting to process the `%matTy` operand as a regular value instead of a type. `OpCooperativeMatrixLengthKHR` seems to be pretty unique in taking an additional type operand beyond the result type, so special-case it in the reader. The translator currently accepts a non-type operand for `OpCooperativeMatrixLengthKHR` too, even though that's not within the specification; see various TODOs in the existing SPV_KHR_cooperative_matrix tests. Leave that relaxation in place, by only translating the operand as a type when it is an `OpTypeCooperativeMatrixKHR`. (cherry picked from commit 2b5f15d871aa39bcc9d2667883dd989afa32a146) --- lib/SPIRV/SPIRVReader.cpp | 12 +++++++-- lib/SPIRV/libSPIRV/SPIRVInstruction.cpp | 2 ++ .../SPV_KHR_cooperative_matrix/length.spvasm | 25 +++++++++++++++++++ 3 files changed, 37 insertions(+), 2 deletions(-) create mode 100644 test/extensions/KHR/SPV_KHR_cooperative_matrix/length.spvasm diff --git a/lib/SPIRV/SPIRVReader.cpp b/lib/SPIRV/SPIRVReader.cpp index 6f6452237c..6155a01843 100644 --- a/lib/SPIRV/SPIRVReader.cpp +++ b/lib/SPIRV/SPIRVReader.cpp @@ -3293,8 +3293,16 @@ Instruction *SPIRVToLLVM::transBuiltinFromInst(const std::string &FuncName, OC == OpControlBarrier) Func->addFnAttr(Attribute::Convergent); } - auto *Call = - CallInst::Create(Func, transValue(Ops, BB->getParent(), BB), "", BB); + CallInst *Call; + if (BI->getOpCode() == OpCooperativeMatrixLengthKHR && + Ops[0]->getOpCode() == OpTypeCooperativeMatrixKHR) { + // OpCooperativeMatrixLengthKHR needs special handling as its operand is + // a Type instead of a Value. + llvm::Type *MatTy = transType(reinterpret_cast(Ops[0])); + Call = CallInst::Create(Func, Constant::getNullValue(MatTy), "", BB); + } else { + Call = CallInst::Create(Func, transValue(Ops, BB->getParent(), BB), "", BB); + } setName(Call, BI); setAttrByCalledFunc(Call); SPIRVDBG(spvdbgs() << "[transInstToBuiltinCall] " << *BI << " -> "; diff --git a/lib/SPIRV/libSPIRV/SPIRVInstruction.cpp b/lib/SPIRV/libSPIRV/SPIRVInstruction.cpp index dd7adf39b5..67d8b1ca24 100644 --- a/lib/SPIRV/libSPIRV/SPIRVInstruction.cpp +++ b/lib/SPIRV/libSPIRV/SPIRVInstruction.cpp @@ -146,6 +146,8 @@ SPIRVInstruction::getOperandTypes(const std::vector &Ops) { SPIRVType *Ty = nullptr; if (I->getOpCode() == OpFunction) Ty = reinterpret_cast(I)->getFunctionType(); + else if (I->getOpCode() == OpTypeCooperativeMatrixKHR) + Ty = reinterpret_cast(I); else Ty = I->getType(); diff --git a/test/extensions/KHR/SPV_KHR_cooperative_matrix/length.spvasm b/test/extensions/KHR/SPV_KHR_cooperative_matrix/length.spvasm new file mode 100644 index 0000000000..5be3e2e14d --- /dev/null +++ b/test/extensions/KHR/SPV_KHR_cooperative_matrix/length.spvasm @@ -0,0 +1,25 @@ +; RUN: spirv-as --target-env spv1.0 -o %t.spv %s +; RUN: spirv-val %t.spv +; RUN: llvm-spirv -r -o - %t.spv | llvm-dis | FileCheck %s + +OpCapability Addresses +OpCapability Kernel +OpCapability CooperativeMatrixKHR +OpExtension "SPV_KHR_cooperative_matrix" +OpMemoryModel Physical64 OpenCL +OpEntryPoint Kernel %1 "testCoopMat" +%void = OpTypeVoid +%float = OpTypeFloat 32 +%fnTy = OpTypeFunction %void +%uint = OpTypeInt 32 0 +%uint_3 = OpConstant %uint 3 +%uint_0 = OpConstant %uint 0 +%uint_8 = OpConstant %uint 8 +%matTy = OpTypeCooperativeMatrixKHR %float %uint_3 %uint_8 %uint_8 %uint_0 +%1 = OpFunction %void None %fnTy +%2 = OpLabel +%3 = OpCooperativeMatrixLengthKHR %uint %matTy +OpReturn +OpFunctionEnd + +; CHECK: call spir_func i32 @_Z34__spirv_CooperativeMatrixLengthKHRPU3AS143__spirv_CooperativeMatrixKHR__float_3_8_8_0(target("spirv.CooperativeMatrixKHR", float, 3, 8, 8, 0) zeroinitializer)