diff --git a/llvm-spirv/lib/SPIRV/SPIRVReader.cpp b/llvm-spirv/lib/SPIRV/SPIRVReader.cpp index 29ddf7a248912..81f213813c845 100644 --- a/llvm-spirv/lib/SPIRV/SPIRVReader.cpp +++ b/llvm-spirv/lib/SPIRV/SPIRVReader.cpp @@ -1096,6 +1096,9 @@ static void applyFPFastMathModeDecorations(const SPIRVValue *BV, Value *SPIRVToLLVM::transShiftLogicalBitwiseInst(SPIRVValue *BV, BasicBlock *BB, Function *F) { SPIRVBinary *BBN = static_cast(BV); + if (BV->getType()->isTypeCooperativeMatrixKHR()) { + return mapValue(BV, transSPIRVBuiltinFromInst(BBN, BB)); + } Instruction::BinaryOps BO; auto OP = BBN->getOpCode(); if (isLogicalOpCode(OP)) @@ -2413,6 +2416,9 @@ Value *SPIRVToLLVM::transValueWithoutDecoration(SPIRVValue *BV, Function *F, Builder.SetInsertPoint(BB); } SPIRVUnary *BC = static_cast(BV); + if (BV->getType()->isTypeCooperativeMatrixKHR()) { + return mapValue(BV, transSPIRVBuiltinFromInst(BC, BB)); + } auto *Neg = Builder.CreateNeg(transValue(BC->getOperand(0), F, BB), BV->getName()); if (auto *NegInst = dyn_cast(Neg)) { @@ -2465,6 +2471,9 @@ Value *SPIRVToLLVM::transValueWithoutDecoration(SPIRVValue *BV, Function *F, case OpFNegate: { SPIRVUnary *BC = static_cast(BV); + if (BV->getType()->isTypeCooperativeMatrixKHR()) { + return mapValue(BV, transSPIRVBuiltinFromInst(BC, BB)); + } auto *Neg = UnaryOperator::CreateFNeg(transValue(BC->getOperand(0), F, BB), BV->getName(), BB); applyFPFastMathModeDecorations(BV, Neg); diff --git a/llvm-spirv/test/extensions/KHR/SPV_KHR_cooperative_matrix/arithmetic_instructions.ll b/llvm-spirv/test/extensions/KHR/SPV_KHR_cooperative_matrix/arithmetic_instructions.ll index 46c0767a4c0b4..87e4fc17dd18e 100644 --- a/llvm-spirv/test/extensions/KHR/SPV_KHR_cooperative_matrix/arithmetic_instructions.ll +++ b/llvm-spirv/test/extensions/KHR/SPV_KHR_cooperative_matrix/arithmetic_instructions.ll @@ -6,10 +6,9 @@ ; TODO: Validation is disabled till the moment the tools in CI are updated (passes locally) ; R/UN: spirv-val %t.spv -; TODO: come up with an approach and implement reverse translation -; R/UN: llvm-spirv -r --spirv-target-env=SPV-IR %t.spv -o %t.rev.bc -; R/UN: llvm-dis %t.rev.bc -; R/UN: FileCheck < %t.rev.ll %s --check-prefix=CHECK-LLVM +; RUN: llvm-spirv -r --spirv-target-env=SPV-IR %t.spv -o %t.rev.bc +; RUN: llvm-dis %t.rev.bc +; RUN: FileCheck < %t.rev.ll %s --check-prefix=CHECK-LLVM ; CHECK-SPIRV: TypeInt [[#TypeInt:]] 32 0 ; CHECK-SPIRV: TypeCooperativeMatrixKHR [[#MatrixTypeInt:]] [[#TypeInt]] @@ -21,6 +20,8 @@ target triple = "spir-unknown-unknown" ; CHECK-SPIRV: CompositeConstruct [[#MatrixTypeInt]] [[#MatrixIn:]] [[#]] {{$}} ; CHECK-SPIRV: SNegate [[#MatrixTypeInt]] [[#]] [[#MatrixIn]] +; CHECK-LLVM: %1 = call spir_func target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 3) @_Z26__spirv_CompositeConstructi(i32 0) +; CHECK-LLVM: %call = call spir_func target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 3) @_Z15__spirv_SNegatePU3AS144__spirv_CooperativeMatrixKHR__uint_3_12_12_3(target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 3) %1) define spir_kernel void @testSNegate(i32 %a) #0 !kernel_arg_addr_space !10 !kernel_arg_access_qual !11 !kernel_arg_type !12 !kernel_arg_type_qual !9 !kernel_arg_base_type !12 { %1 = tail call spir_func noundef target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 3) @_Z26__spirv_CompositeConstructInt32(i32 0) %call = call spir_func target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 3) @_Z15__spirv_SNegate(target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 3) %1) @@ -29,7 +30,8 @@ define spir_kernel void @testSNegate(i32 %a) #0 !kernel_arg_addr_space !10 !kern ; CHECK-SPIRV: CompositeConstruct [[#MatrixTypeFloat]] [[#MatrixIn:]] [[#]] {{$}} ; CHECK-SPIRV: FNegate [[#MatrixTypeFloat]] [[#]] [[#MatrixIn]] -; CHECK-LLVM: fneg +; CHECK-LLVM: %0 = call spir_func target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) @_Z26__spirv_CompositeConstructf(float 0.000000e+00) +; CHECK-LLVM: %call = call spir_func target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) @_Z15__spirv_FNegatePU3AS145__spirv_CooperativeMatrixKHR__float_3_12_12_3(target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) %0) define spir_kernel void @testFNeg(float %a) local_unnamed_addr #0 !kernel_arg_addr_space !2 !kernel_arg_access_qual !3 !kernel_arg_type !4 !kernel_arg_base_type !4 !kernel_arg_type_qual !9 { entry: %0 = tail call spir_func noundef target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) @_Z26__spirv_CompositeConstructFloat(float 0.000000e+00) @@ -40,6 +42,9 @@ entry: ; CHECK-SPIRV: CompositeConstruct [[#MatrixTypeInt]] [[#MatrixA:]] [[#]] {{$}} ; CHECK-SPIRV: CompositeConstruct [[#MatrixTypeInt]] [[#MatrixB:]] [[#]] {{$}} ; CHECK-SPIRV: IAdd [[#MatrixTypeInt]] [[#]] [[#MatrixA]] [[#MatrixB]] +; CHECK-LLVM: %1 = call spir_func target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 3) @_Z26__spirv_CompositeConstructi(i32 0) +; CHECK-LLVM: %2 = call spir_func target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 3) @_Z26__spirv_CompositeConstructi(i32 0) +; CHECK-LLVM: %call = call spir_func target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 3) @_Z12__spirv_IAddPU3AS144__spirv_CooperativeMatrixKHR__uint_3_12_12_3S1_(target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 3) %1, target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 3) %2) define spir_kernel void @testIAdd(i32 %a, i32 %b) #0 !kernel_arg_addr_space !4 !kernel_arg_access_qual !5 !kernel_arg_type !6 !kernel_arg_type_qual !7 !kernel_arg_base_type !6 { %1 = tail call spir_func noundef target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 3) @_Z26__spirv_CompositeConstructInt32(i32 0) %2 = tail call spir_func noundef target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 3) @_Z26__spirv_CompositeConstructInt32(i32 0) @@ -50,6 +55,7 @@ define spir_kernel void @testIAdd(i32 %a, i32 %b) #0 !kernel_arg_addr_space !4 ! ; CHECK-SPIRV: CompositeConstruct [[#MatrixTypeInt]] [[#MatrixA:]] [[#]] {{$}} ; CHECK-SPIRV: CompositeConstruct [[#MatrixTypeInt]] [[#MatrixB:]] [[#]] {{$}} ; CHECK-SPIRV: ISub [[#MatrixTypeInt]] [[#]] [[#MatrixA]] [[#MatrixB]] +; CHECK-LLVM: %call = call spir_func target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 3) @_Z12__spirv_ISubPU3AS144__spirv_CooperativeMatrixKHR__uint_3_12_12_3S1_(target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 3) %1, target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 3) %2) define spir_kernel void @testISub(i32 %a, i32 %b) #0 !kernel_arg_addr_space !4 !kernel_arg_access_qual !5 !kernel_arg_type !6 !kernel_arg_type_qual !7 !kernel_arg_base_type !6 { %1 = tail call spir_func noundef target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 3) @_Z26__spirv_CompositeConstructInt32(i32 0) %2 = tail call spir_func noundef target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 3) @_Z26__spirv_CompositeConstructInt32(i32 0) @@ -60,6 +66,7 @@ define spir_kernel void @testISub(i32 %a, i32 %b) #0 !kernel_arg_addr_space !4 ! ; CHECK-SPIRV: CompositeConstruct [[#MatrixTypeInt]] [[#MatrixA:]] [[#]] {{$}} ; CHECK-SPIRV: CompositeConstruct [[#MatrixTypeInt]] [[#MatrixB:]] [[#]] {{$}} ; CHECK-SPIRV: IMul [[#MatrixTypeInt]] [[#]] [[#MatrixA]] [[#MatrixB]] +; CHECK-LLVM: %call = call spir_func target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 3) @_Z12__spirv_IMulPU3AS144__spirv_CooperativeMatrixKHR__uint_3_12_12_3S1_(target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 3) %1, target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 3) %2) define spir_kernel void @testIMul(i32 %a, i32 %b) #0 !kernel_arg_addr_space !4 !kernel_arg_access_qual !5 !kernel_arg_type !6 !kernel_arg_type_qual !7 !kernel_arg_base_type !6 { %1 = tail call spir_func noundef target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 3) @_Z26__spirv_CompositeConstructInt32(i32 0) %2 = tail call spir_func noundef target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 3) @_Z26__spirv_CompositeConstructInt32(i32 0) @@ -70,6 +77,7 @@ define spir_kernel void @testIMul(i32 %a, i32 %b) #0 !kernel_arg_addr_space !4 ! ; CHECK-SPIRV: CompositeConstruct [[#MatrixTypeInt]] [[#MatrixA:]] [[#]] {{$}} ; CHECK-SPIRV: CompositeConstruct [[#MatrixTypeInt]] [[#MatrixB:]] [[#]] {{$}} ; CHECK-SPIRV: SDiv [[#MatrixTypeInt]] [[#]] [[#MatrixA]] [[#MatrixB]] +; CHECK-LLVM: %call = call spir_func target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 3) @_Z12__spirv_SDivPU3AS144__spirv_CooperativeMatrixKHR__uint_3_12_12_3S1_(target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 3) %1, target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 3) %2) define void @testSDiv(i32 %a, i32 %b) { %1 = tail call spir_func noundef target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 3) @_Z26__spirv_CompositeConstructInt32(i32 0) %2 = tail call spir_func noundef target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 3) @_Z26__spirv_CompositeConstructInt32(i32 0) @@ -80,6 +88,7 @@ define void @testSDiv(i32 %a, i32 %b) { ; CHECK-SPIRV: CompositeConstruct [[#MatrixTypeInt]] [[#MatrixA:]] [[#]] {{$}} ; CHECK-SPIRV: CompositeConstruct [[#MatrixTypeInt]] [[#MatrixB:]] [[#]] {{$}} ; CHECK-SPIRV: UDiv [[#MatrixTypeInt]] [[#]] [[#MatrixA]] [[#MatrixB]] +; CHECK-LLVM: %call = call spir_func target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 3) @_Z12__spirv_UDivPU3AS144__spirv_CooperativeMatrixKHR__uint_3_12_12_3S1_(target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 3) %1, target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 3) %2) define void @testUDiv(i32 %a, i32 %b) { %1 = tail call spir_func noundef target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 3) @_Z26__spirv_CompositeConstructInt32(i32 0) %2 = tail call spir_func noundef target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 3) @_Z26__spirv_CompositeConstructInt32(i32 0) @@ -91,44 +100,50 @@ define void @testUDiv(i32 %a, i32 %b) { ; CHECK-SPIRV: CompositeConstruct [[#MatrixTypeFloat]] [[#MatrixA:]] [[#]] {{$}} ; CHECK-SPIRV: CompositeConstruct [[#MatrixTypeFloat]] [[#MatrixB:]] [[#]] {{$}} ; CHECK-SPIRV: FAdd [[#MatrixTypeFloat]] [[#]] [[#MatrixA]] [[#MatrixB]] +; CHECK-LLVM: %0 = call spir_func target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) @_Z26__spirv_CompositeConstructf(float 0.000000e+00) +; CHECK-LLVM: %1 = call spir_func target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) @_Z26__spirv_CompositeConstructf(float 0.000000e+00) +; CHECK-LLVM: %call = call spir_func target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) @_Z12__spirv_FAddPU3AS145__spirv_CooperativeMatrixKHR__float_3_12_12_3S1_(target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) %0, target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) %1) define spir_kernel void @testFAdd(float %a, float %b) local_unnamed_addr #0 !kernel_arg_addr_space !2 !kernel_arg_access_qual !3 !kernel_arg_type !4 !kernel_arg_base_type !4 !kernel_arg_type_qual !5 { entry: %0 = tail call spir_func noundef target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) @_Z26__spirv_CompositeConstructFloat(float 0.000000e+00) %1 = tail call spir_func noundef target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) @_Z26__spirv_CompositeConstructFloat(float 0.000000e+00) - %call = call spir_func target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) @_Z12__spirv_FAdd(target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) %0, target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) %1) + %call = call spir_func target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) @_Z12__spirv_FAdd(target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) %0, target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) %1) ret void } ; CHECK-SPIRV: CompositeConstruct [[#MatrixTypeFloat]] [[#MatrixA:]] [[#]] {{$}} ; CHECK-SPIRV: CompositeConstruct [[#MatrixTypeFloat]] [[#MatrixB:]] [[#]] {{$}} ; CHECK-SPIRV: FSub [[#MatrixTypeFloat]] [[#]] [[#MatrixA]] [[#MatrixB]] +; CHECK-LLVM: %call = call spir_func target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) @_Z12__spirv_FSubPU3AS145__spirv_CooperativeMatrixKHR__float_3_12_12_3S1_(target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) %0, target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) %1) define spir_kernel void @testFSub(float %a, float %b) local_unnamed_addr #0 !kernel_arg_addr_space !2 !kernel_arg_access_qual !3 !kernel_arg_type !4 !kernel_arg_base_type !4 !kernel_arg_type_qual !5 { entry: %0 = tail call spir_func noundef target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) @_Z26__spirv_CompositeConstructFloat(float 0.000000e+00) %1 = tail call spir_func noundef target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) @_Z26__spirv_CompositeConstructFloat(float 0.000000e+00) - %call = call spir_func target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) @_Z12__spirv_FSub(target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) %0, target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) %1) + %call = call spir_func target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) @_Z12__spirv_FSub(target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) %0, target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) %1) ret void } ; CHECK-SPIRV: CompositeConstruct [[#MatrixTypeFloat]] [[#MatrixA:]] [[#]] {{$}} ; CHECK-SPIRV: CompositeConstruct [[#MatrixTypeFloat]] [[#MatrixB:]] [[#]] {{$}} ; CHECK-SPIRV: FMul [[#MatrixTypeFloat]] [[#]] [[#MatrixA]] [[#MatrixB]] +; CHECK-LLVM: %call = call spir_func target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) @_Z12__spirv_FMulPU3AS145__spirv_CooperativeMatrixKHR__float_3_12_12_3S1_(target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) %0, target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) %1) define spir_kernel void @testFMul(float %a, float %b) local_unnamed_addr #0 !kernel_arg_addr_space !2 !kernel_arg_access_qual !3 !kernel_arg_type !4 !kernel_arg_base_type !4 !kernel_arg_type_qual !5 { entry: %0 = tail call spir_func noundef target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) @_Z26__spirv_CompositeConstructFloat(float 0.000000e+00) %1 = tail call spir_func noundef target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) @_Z26__spirv_CompositeConstructFloat(float 0.000000e+00) - %call = call spir_func target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) @_Z12__spirv_FMul(target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) %0, target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) %1) + %call = call spir_func target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) @_Z12__spirv_FMul(target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) %0, target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) %1) ret void } ; CHECK-SPIRV: CompositeConstruct [[#MatrixTypeFloat]] [[#MatrixA:]] [[#]] {{$}} ; CHECK-SPIRV: CompositeConstruct [[#MatrixTypeFloat]] [[#MatrixB:]] [[#]] {{$}} ; CHECK-SPIRV: FDiv [[#MatrixTypeFloat]] [[#]] [[#MatrixA]] [[#MatrixB]] +; CHECK-LLVM: %call = call spir_func target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) @_Z12__spirv_FDivPU3AS145__spirv_CooperativeMatrixKHR__float_3_12_12_3S1_(target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) %0, target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) %1) define spir_kernel void @testFDiv(float %a, float %b) local_unnamed_addr #0 !kernel_arg_addr_space !2 !kernel_arg_access_qual !3 !kernel_arg_type !4 !kernel_arg_base_type !4 !kernel_arg_type_qual !5 { entry: %0 = tail call spir_func noundef target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) @_Z26__spirv_CompositeConstructFloat(float 0.000000e+00) %1 = tail call spir_func noundef target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) @_Z26__spirv_CompositeConstructFloat(float 0.000000e+00) - %call = call spir_func target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) @_Z12__spirv_FDiv(target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) %0, target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) %1) + %call = call spir_func target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) @_Z12__spirv_FDiv(target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) %0, target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) %1) ret void }