Skip to content

Commit

Permalink
Reverse translation of arithmetic instructions for cooperative matrix…
Browse files Browse the repository at this point in the history
…es (#2166)

Implement translation via SPIR-V friendly calls, as:

the LLVM instructions are not capable to accept target extension types;
matrix arithmetic instructions require additional carry additional rules, which LLVM can not perform (for example while technically Add for vectors and (flattened) matrices is the same - yet for matrices we need to perform extra checks, also mul instruction is complitely different).
As for now some BE would need to recognize and define what to do with a call to __spirv_FMul(matrixA, matrixB). Better option is to map such SPIR-V to an intrinsic or define an appropriate type in LLVM (hence defining rules for GEP and other instructions) , but it's off the table now.

Original commit:
KhronosGroup/SPIRV-LLVM-Translator@fdc961f
  • Loading branch information
vmaksimo authored and sys-ce-bb committed Oct 5, 2023
1 parent 0583230 commit 8b36ee8
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 9 deletions.
9 changes: 9 additions & 0 deletions llvm-spirv/lib/SPIRV/SPIRVReader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1096,6 +1096,9 @@ static void applyFPFastMathModeDecorations(const SPIRVValue *BV,
Value *SPIRVToLLVM::transShiftLogicalBitwiseInst(SPIRVValue *BV, BasicBlock *BB,
Function *F) {
SPIRVBinary *BBN = static_cast<SPIRVBinary *>(BV);
if (BV->getType()->isTypeCooperativeMatrixKHR()) {
return mapValue(BV, transSPIRVBuiltinFromInst(BBN, BB));
}
Instruction::BinaryOps BO;
auto OP = BBN->getOpCode();
if (isLogicalOpCode(OP))
Expand Down Expand Up @@ -2413,6 +2416,9 @@ Value *SPIRVToLLVM::transValueWithoutDecoration(SPIRVValue *BV, Function *F,
Builder.SetInsertPoint(BB);
}
SPIRVUnary *BC = static_cast<SPIRVUnary *>(BV);
if (BV->getType()->isTypeCooperativeMatrixKHR()) {
return mapValue(BV, transSPIRVBuiltinFromInst(BC, BB));
}
auto *Neg =
Builder.CreateNeg(transValue(BC->getOperand(0), F, BB), BV->getName());
if (auto *NegInst = dyn_cast<Instruction>(Neg)) {
Expand Down Expand Up @@ -2465,6 +2471,9 @@ Value *SPIRVToLLVM::transValueWithoutDecoration(SPIRVValue *BV, Function *F,

case OpFNegate: {
SPIRVUnary *BC = static_cast<SPIRVUnary *>(BV);
if (BV->getType()->isTypeCooperativeMatrixKHR()) {
return mapValue(BV, transSPIRVBuiltinFromInst(BC, BB));
}
auto *Neg = UnaryOperator::CreateFNeg(transValue(BC->getOperand(0), F, BB),
BV->getName(), BB);
applyFPFastMathModeDecorations(BV, Neg);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,9 @@
; TODO: Validation is disabled till the moment the tools in CI are updated (passes locally)
; R/UN: spirv-val %t.spv

; TODO: come up with an approach and implement reverse translation
; R/UN: llvm-spirv -r --spirv-target-env=SPV-IR %t.spv -o %t.rev.bc
; R/UN: llvm-dis %t.rev.bc
; R/UN: FileCheck < %t.rev.ll %s --check-prefix=CHECK-LLVM
; RUN: llvm-spirv -r --spirv-target-env=SPV-IR %t.spv -o %t.rev.bc
; RUN: llvm-dis %t.rev.bc
; RUN: FileCheck < %t.rev.ll %s --check-prefix=CHECK-LLVM

; CHECK-SPIRV: TypeInt [[#TypeInt:]] 32 0
; CHECK-SPIRV: TypeCooperativeMatrixKHR [[#MatrixTypeInt:]] [[#TypeInt]]
Expand All @@ -21,6 +20,8 @@ target triple = "spir-unknown-unknown"

; CHECK-SPIRV: CompositeConstruct [[#MatrixTypeInt]] [[#MatrixIn:]] [[#]] {{$}}
; CHECK-SPIRV: SNegate [[#MatrixTypeInt]] [[#]] [[#MatrixIn]]
; CHECK-LLVM: %1 = call spir_func target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 3) @_Z26__spirv_CompositeConstructi(i32 0)
; CHECK-LLVM: %call = call spir_func target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 3) @_Z15__spirv_SNegatePU3AS144__spirv_CooperativeMatrixKHR__uint_3_12_12_3(target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 3) %1)
define spir_kernel void @testSNegate(i32 %a) #0 !kernel_arg_addr_space !10 !kernel_arg_access_qual !11 !kernel_arg_type !12 !kernel_arg_type_qual !9 !kernel_arg_base_type !12 {
%1 = tail call spir_func noundef target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 3) @_Z26__spirv_CompositeConstructInt32(i32 0)
%call = call spir_func target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 3) @_Z15__spirv_SNegate(target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 3) %1)
Expand All @@ -29,7 +30,8 @@ define spir_kernel void @testSNegate(i32 %a) #0 !kernel_arg_addr_space !10 !kern

; CHECK-SPIRV: CompositeConstruct [[#MatrixTypeFloat]] [[#MatrixIn:]] [[#]] {{$}}
; CHECK-SPIRV: FNegate [[#MatrixTypeFloat]] [[#]] [[#MatrixIn]]
; CHECK-LLVM: fneg
; CHECK-LLVM: %0 = call spir_func target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) @_Z26__spirv_CompositeConstructf(float 0.000000e+00)
; CHECK-LLVM: %call = call spir_func target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) @_Z15__spirv_FNegatePU3AS145__spirv_CooperativeMatrixKHR__float_3_12_12_3(target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) %0)
define spir_kernel void @testFNeg(float %a) local_unnamed_addr #0 !kernel_arg_addr_space !2 !kernel_arg_access_qual !3 !kernel_arg_type !4 !kernel_arg_base_type !4 !kernel_arg_type_qual !9 {
entry:
%0 = tail call spir_func noundef target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) @_Z26__spirv_CompositeConstructFloat(float 0.000000e+00)
Expand All @@ -40,6 +42,9 @@ entry:
; CHECK-SPIRV: CompositeConstruct [[#MatrixTypeInt]] [[#MatrixA:]] [[#]] {{$}}
; CHECK-SPIRV: CompositeConstruct [[#MatrixTypeInt]] [[#MatrixB:]] [[#]] {{$}}
; CHECK-SPIRV: IAdd [[#MatrixTypeInt]] [[#]] [[#MatrixA]] [[#MatrixB]]
; CHECK-LLVM: %1 = call spir_func target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 3) @_Z26__spirv_CompositeConstructi(i32 0)
; CHECK-LLVM: %2 = call spir_func target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 3) @_Z26__spirv_CompositeConstructi(i32 0)
; CHECK-LLVM: %call = call spir_func target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 3) @_Z12__spirv_IAddPU3AS144__spirv_CooperativeMatrixKHR__uint_3_12_12_3S1_(target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 3) %1, target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 3) %2)
define spir_kernel void @testIAdd(i32 %a, i32 %b) #0 !kernel_arg_addr_space !4 !kernel_arg_access_qual !5 !kernel_arg_type !6 !kernel_arg_type_qual !7 !kernel_arg_base_type !6 {
%1 = tail call spir_func noundef target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 3) @_Z26__spirv_CompositeConstructInt32(i32 0)
%2 = tail call spir_func noundef target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 3) @_Z26__spirv_CompositeConstructInt32(i32 0)
Expand All @@ -50,6 +55,7 @@ define spir_kernel void @testIAdd(i32 %a, i32 %b) #0 !kernel_arg_addr_space !4 !
; CHECK-SPIRV: CompositeConstruct [[#MatrixTypeInt]] [[#MatrixA:]] [[#]] {{$}}
; CHECK-SPIRV: CompositeConstruct [[#MatrixTypeInt]] [[#MatrixB:]] [[#]] {{$}}
; CHECK-SPIRV: ISub [[#MatrixTypeInt]] [[#]] [[#MatrixA]] [[#MatrixB]]
; CHECK-LLVM: %call = call spir_func target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 3) @_Z12__spirv_ISubPU3AS144__spirv_CooperativeMatrixKHR__uint_3_12_12_3S1_(target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 3) %1, target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 3) %2)
define spir_kernel void @testISub(i32 %a, i32 %b) #0 !kernel_arg_addr_space !4 !kernel_arg_access_qual !5 !kernel_arg_type !6 !kernel_arg_type_qual !7 !kernel_arg_base_type !6 {
%1 = tail call spir_func noundef target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 3) @_Z26__spirv_CompositeConstructInt32(i32 0)
%2 = tail call spir_func noundef target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 3) @_Z26__spirv_CompositeConstructInt32(i32 0)
Expand All @@ -60,6 +66,7 @@ define spir_kernel void @testISub(i32 %a, i32 %b) #0 !kernel_arg_addr_space !4 !
; CHECK-SPIRV: CompositeConstruct [[#MatrixTypeInt]] [[#MatrixA:]] [[#]] {{$}}
; CHECK-SPIRV: CompositeConstruct [[#MatrixTypeInt]] [[#MatrixB:]] [[#]] {{$}}
; CHECK-SPIRV: IMul [[#MatrixTypeInt]] [[#]] [[#MatrixA]] [[#MatrixB]]
; CHECK-LLVM: %call = call spir_func target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 3) @_Z12__spirv_IMulPU3AS144__spirv_CooperativeMatrixKHR__uint_3_12_12_3S1_(target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 3) %1, target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 3) %2)
define spir_kernel void @testIMul(i32 %a, i32 %b) #0 !kernel_arg_addr_space !4 !kernel_arg_access_qual !5 !kernel_arg_type !6 !kernel_arg_type_qual !7 !kernel_arg_base_type !6 {
%1 = tail call spir_func noundef target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 3) @_Z26__spirv_CompositeConstructInt32(i32 0)
%2 = tail call spir_func noundef target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 3) @_Z26__spirv_CompositeConstructInt32(i32 0)
Expand All @@ -70,6 +77,7 @@ define spir_kernel void @testIMul(i32 %a, i32 %b) #0 !kernel_arg_addr_space !4 !
; CHECK-SPIRV: CompositeConstruct [[#MatrixTypeInt]] [[#MatrixA:]] [[#]] {{$}}
; CHECK-SPIRV: CompositeConstruct [[#MatrixTypeInt]] [[#MatrixB:]] [[#]] {{$}}
; CHECK-SPIRV: SDiv [[#MatrixTypeInt]] [[#]] [[#MatrixA]] [[#MatrixB]]
; CHECK-LLVM: %call = call spir_func target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 3) @_Z12__spirv_SDivPU3AS144__spirv_CooperativeMatrixKHR__uint_3_12_12_3S1_(target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 3) %1, target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 3) %2)
define void @testSDiv(i32 %a, i32 %b) {
%1 = tail call spir_func noundef target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 3) @_Z26__spirv_CompositeConstructInt32(i32 0)
%2 = tail call spir_func noundef target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 3) @_Z26__spirv_CompositeConstructInt32(i32 0)
Expand All @@ -80,6 +88,7 @@ define void @testSDiv(i32 %a, i32 %b) {
; CHECK-SPIRV: CompositeConstruct [[#MatrixTypeInt]] [[#MatrixA:]] [[#]] {{$}}
; CHECK-SPIRV: CompositeConstruct [[#MatrixTypeInt]] [[#MatrixB:]] [[#]] {{$}}
; CHECK-SPIRV: UDiv [[#MatrixTypeInt]] [[#]] [[#MatrixA]] [[#MatrixB]]
; CHECK-LLVM: %call = call spir_func target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 3) @_Z12__spirv_UDivPU3AS144__spirv_CooperativeMatrixKHR__uint_3_12_12_3S1_(target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 3) %1, target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 3) %2)
define void @testUDiv(i32 %a, i32 %b) {
%1 = tail call spir_func noundef target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 3) @_Z26__spirv_CompositeConstructInt32(i32 0)
%2 = tail call spir_func noundef target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 3) @_Z26__spirv_CompositeConstructInt32(i32 0)
Expand All @@ -91,44 +100,50 @@ define void @testUDiv(i32 %a, i32 %b) {
; CHECK-SPIRV: CompositeConstruct [[#MatrixTypeFloat]] [[#MatrixA:]] [[#]] {{$}}
; CHECK-SPIRV: CompositeConstruct [[#MatrixTypeFloat]] [[#MatrixB:]] [[#]] {{$}}
; CHECK-SPIRV: FAdd [[#MatrixTypeFloat]] [[#]] [[#MatrixA]] [[#MatrixB]]
; CHECK-LLVM: %0 = call spir_func target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) @_Z26__spirv_CompositeConstructf(float 0.000000e+00)
; CHECK-LLVM: %1 = call spir_func target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) @_Z26__spirv_CompositeConstructf(float 0.000000e+00)
; CHECK-LLVM: %call = call spir_func target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) @_Z12__spirv_FAddPU3AS145__spirv_CooperativeMatrixKHR__float_3_12_12_3S1_(target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) %0, target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) %1)
define spir_kernel void @testFAdd(float %a, float %b) local_unnamed_addr #0 !kernel_arg_addr_space !2 !kernel_arg_access_qual !3 !kernel_arg_type !4 !kernel_arg_base_type !4 !kernel_arg_type_qual !5 {
entry:
%0 = tail call spir_func noundef target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) @_Z26__spirv_CompositeConstructFloat(float 0.000000e+00)
%1 = tail call spir_func noundef target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) @_Z26__spirv_CompositeConstructFloat(float 0.000000e+00)
%call = call spir_func target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) @_Z12__spirv_FAdd(target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) %0, target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) %1)
%call = call spir_func target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) @_Z12__spirv_FAdd(target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) %0, target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) %1)
ret void
}

; CHECK-SPIRV: CompositeConstruct [[#MatrixTypeFloat]] [[#MatrixA:]] [[#]] {{$}}
; CHECK-SPIRV: CompositeConstruct [[#MatrixTypeFloat]] [[#MatrixB:]] [[#]] {{$}}
; CHECK-SPIRV: FSub [[#MatrixTypeFloat]] [[#]] [[#MatrixA]] [[#MatrixB]]
; CHECK-LLVM: %call = call spir_func target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) @_Z12__spirv_FSubPU3AS145__spirv_CooperativeMatrixKHR__float_3_12_12_3S1_(target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) %0, target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) %1)
define spir_kernel void @testFSub(float %a, float %b) local_unnamed_addr #0 !kernel_arg_addr_space !2 !kernel_arg_access_qual !3 !kernel_arg_type !4 !kernel_arg_base_type !4 !kernel_arg_type_qual !5 {
entry:
%0 = tail call spir_func noundef target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) @_Z26__spirv_CompositeConstructFloat(float 0.000000e+00)
%1 = tail call spir_func noundef target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) @_Z26__spirv_CompositeConstructFloat(float 0.000000e+00)
%call = call spir_func target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) @_Z12__spirv_FSub(target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) %0, target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) %1)
%call = call spir_func target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) @_Z12__spirv_FSub(target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) %0, target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) %1)
ret void
}

; CHECK-SPIRV: CompositeConstruct [[#MatrixTypeFloat]] [[#MatrixA:]] [[#]] {{$}}
; CHECK-SPIRV: CompositeConstruct [[#MatrixTypeFloat]] [[#MatrixB:]] [[#]] {{$}}
; CHECK-SPIRV: FMul [[#MatrixTypeFloat]] [[#]] [[#MatrixA]] [[#MatrixB]]
; CHECK-LLVM: %call = call spir_func target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) @_Z12__spirv_FMulPU3AS145__spirv_CooperativeMatrixKHR__float_3_12_12_3S1_(target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) %0, target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) %1)
define spir_kernel void @testFMul(float %a, float %b) local_unnamed_addr #0 !kernel_arg_addr_space !2 !kernel_arg_access_qual !3 !kernel_arg_type !4 !kernel_arg_base_type !4 !kernel_arg_type_qual !5 {
entry:
%0 = tail call spir_func noundef target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) @_Z26__spirv_CompositeConstructFloat(float 0.000000e+00)
%1 = tail call spir_func noundef target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) @_Z26__spirv_CompositeConstructFloat(float 0.000000e+00)
%call = call spir_func target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) @_Z12__spirv_FMul(target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) %0, target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) %1)
%call = call spir_func target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) @_Z12__spirv_FMul(target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) %0, target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) %1)
ret void
}

; CHECK-SPIRV: CompositeConstruct [[#MatrixTypeFloat]] [[#MatrixA:]] [[#]] {{$}}
; CHECK-SPIRV: CompositeConstruct [[#MatrixTypeFloat]] [[#MatrixB:]] [[#]] {{$}}
; CHECK-SPIRV: FDiv [[#MatrixTypeFloat]] [[#]] [[#MatrixA]] [[#MatrixB]]
; CHECK-LLVM: %call = call spir_func target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) @_Z12__spirv_FDivPU3AS145__spirv_CooperativeMatrixKHR__float_3_12_12_3S1_(target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) %0, target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) %1)
define spir_kernel void @testFDiv(float %a, float %b) local_unnamed_addr #0 !kernel_arg_addr_space !2 !kernel_arg_access_qual !3 !kernel_arg_type !4 !kernel_arg_base_type !4 !kernel_arg_type_qual !5 {
entry:
%0 = tail call spir_func noundef target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) @_Z26__spirv_CompositeConstructFloat(float 0.000000e+00)
%1 = tail call spir_func noundef target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) @_Z26__spirv_CompositeConstructFloat(float 0.000000e+00)
%call = call spir_func target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) @_Z12__spirv_FDiv(target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) %0, target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) %1)
%call = call spir_func target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) @_Z12__spirv_FDiv(target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) %0, target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) %1)
ret void
}

Expand Down

0 comments on commit 8b36ee8

Please sign in to comment.