Skip to content

Commit

Permalink
[Backport to 19] SPIRVReader: handle direct types with CooperativeMat…
Browse files Browse the repository at this point in the history
…rixLengthKHR (#2695) (#2707)

Translation of the attached test would currently fail due to the
SPIRVReader attempting to process the `%matTy` operand as a regular
value instead of a type.  `OpCooperativeMatrixLengthKHR` seems to be
pretty unique in taking an additional type operand beyond the result
type, so special-case it in the reader.

The translator currently accepts a non-type operand for
`OpCooperativeMatrixLengthKHR` too, even though that's not within the
specification; see various TODOs in the existing
SPV_KHR_cooperative_matrix tests.  Leave that relaxation in place, by
only translating the operand as a type when it is an
`OpTypeCooperativeMatrixKHR`.

(cherry picked from commit 2b5f15d)
  • Loading branch information
svenvh committed Sep 4, 2024
1 parent d65c25a commit 1b9ab0a
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 2 deletions.
12 changes: 10 additions & 2 deletions lib/SPIRV/SPIRVReader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3293,8 +3293,16 @@ Instruction *SPIRVToLLVM::transBuiltinFromInst(const std::string &FuncName,
OC == OpControlBarrier)
Func->addFnAttr(Attribute::Convergent);
}
auto *Call =
CallInst::Create(Func, transValue(Ops, BB->getParent(), BB), "", BB);
CallInst *Call;
if (BI->getOpCode() == OpCooperativeMatrixLengthKHR &&
Ops[0]->getOpCode() == OpTypeCooperativeMatrixKHR) {
// OpCooperativeMatrixLengthKHR needs special handling as its operand is
// a Type instead of a Value.
llvm::Type *MatTy = transType(reinterpret_cast<SPIRVType *>(Ops[0]));
Call = CallInst::Create(Func, Constant::getNullValue(MatTy), "", BB);
} else {
Call = CallInst::Create(Func, transValue(Ops, BB->getParent(), BB), "", BB);
}
setName(Call, BI);
setAttrByCalledFunc(Call);
SPIRVDBG(spvdbgs() << "[transInstToBuiltinCall] " << *BI << " -> ";
Expand Down
2 changes: 2 additions & 0 deletions lib/SPIRV/libSPIRV/SPIRVInstruction.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,8 @@ SPIRVInstruction::getOperandTypes(const std::vector<SPIRVValue *> &Ops) {
SPIRVType *Ty = nullptr;
if (I->getOpCode() == OpFunction)
Ty = reinterpret_cast<SPIRVFunction *>(I)->getFunctionType();
else if (I->getOpCode() == OpTypeCooperativeMatrixKHR)
Ty = reinterpret_cast<SPIRVType *>(I);
else
Ty = I->getType();

Expand Down
25 changes: 25 additions & 0 deletions test/extensions/KHR/SPV_KHR_cooperative_matrix/length.spvasm
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
; RUN: spirv-as --target-env spv1.0 -o %t.spv %s
; RUN: spirv-val %t.spv
; RUN: llvm-spirv -r -o - %t.spv | llvm-dis | FileCheck %s

OpCapability Addresses
OpCapability Kernel
OpCapability CooperativeMatrixKHR
OpExtension "SPV_KHR_cooperative_matrix"
OpMemoryModel Physical64 OpenCL
OpEntryPoint Kernel %1 "testCoopMat"
%void = OpTypeVoid
%float = OpTypeFloat 32
%fnTy = OpTypeFunction %void
%uint = OpTypeInt 32 0
%uint_3 = OpConstant %uint 3
%uint_0 = OpConstant %uint 0
%uint_8 = OpConstant %uint 8
%matTy = OpTypeCooperativeMatrixKHR %float %uint_3 %uint_8 %uint_8 %uint_0
%1 = OpFunction %void None %fnTy
%2 = OpLabel
%3 = OpCooperativeMatrixLengthKHR %uint %matTy
OpReturn
OpFunctionEnd

; CHECK: call spir_func i32 @_Z34__spirv_CooperativeMatrixLengthKHRPU3AS143__spirv_CooperativeMatrixKHR__float_3_8_8_0(target("spirv.CooperativeMatrixKHR", float, 3, 8, 8, 0) zeroinitializer)

0 comments on commit 1b9ab0a

Please sign in to comment.