diff --git a/lib/SPIRV/SPIRVReader.cpp b/lib/SPIRV/SPIRVReader.cpp
index 2bf2ba5783..cb27067e30 100644
--- a/lib/SPIRV/SPIRVReader.cpp
+++ b/lib/SPIRV/SPIRVReader.cpp
@@ -1607,6 +1607,46 @@ Value *SPIRVToLLVM::transValueWithoutDecoration(SPIRVValue *BV, Function *F,
     return mapValue(BV, V);
   }
 
+  case OpMatrixTimesVector: {
+    auto *MTV = static_cast<SPIRVMatrixTimesVector *>(BV);
+    IRBuilder<> Builder(BB);
+    Value *Mat = transValue(MTV->getMatrix(), F, BB);
+    Value *Vec = transValue(MTV->getVector(), F, BB);
+
+    // Result is similar to Matrix * Matrix
+    // Mat is of M columns and N rows.
+    // Mat consists of vectors: V_1, V_2, ..., V_M
+    // where each vector is of size N.
+    //
+    // Vec is of size M.
+    // The product is a vector of size N.
+    //
+    //                |------- N ----------|
+    // Result = sum ( {Vec_1, Vec_1, ..., Vec_1} * V_1,
+    //                {Vec_2, Vec_2, ..., Vec_2} * V_2,
+    //                ...
+    //                {Vec_M, Vec_M, ..., Vec_M} * V_M );
+    //
+    // where sum is defined as vector sum.
+
+    unsigned M = Mat->getType()->getArrayNumElements();
+    VectorType *VTy =
+        cast<VectorType>(cast<ArrayType>(Mat->getType())->getElementType());
+    unsigned N = VTy->getVectorNumElements();
+    auto ETy = VTy->getElementType();
+    Value *V = Builder.CreateVectorSplat(N, ConstantFP::get(ETy, 0.0));
+
+    for (unsigned Idx = 0; Idx != M; ++Idx) {
+      Value *S = Builder.CreateExtractElement(Vec, Builder.getInt32(Idx));
+      Value *Lhs = Builder.CreateVectorSplat(N, S);
+      Value *Vx = Builder.CreateExtractValue(Mat, Idx);
+      Value *Mul = Builder.CreateFMul(Lhs, Vx);
+      V = Builder.CreateFAdd(V, Mul);
+    }
+
+    return mapValue(BV, V);
+  }
+
   case OpCopyObject: {
     SPIRVCopyObject *CO = static_cast<SPIRVCopyObject *>(BV);
     AllocaInst *AI =
diff --git a/lib/SPIRV/libSPIRV/SPIRVEntry.h b/lib/SPIRV/libSPIRV/SPIRVEntry.h
index 1bdfca924a..ccdf2a20fc 100644
--- a/lib/SPIRV/libSPIRV/SPIRVEntry.h
+++ b/lib/SPIRV/libSPIRV/SPIRVEntry.h
@@ -787,7 +787,6 @@ _SPIRV_OP(QuantizeToF16)
 _SPIRV_OP(Transpose)
 _SPIRV_OP(ArrayLength)
 _SPIRV_OP(VectorTimesMatrix)
-_SPIRV_OP(MatrixTimesVector)
 _SPIRV_OP(MatrixTimesMatrix)
 _SPIRV_OP(OuterProduct)
 _SPIRV_OP(IAddCarry)
diff --git a/lib/SPIRV/libSPIRV/SPIRVInstruction.h b/lib/SPIRV/libSPIRV/SPIRVInstruction.h
index 4a41d9631e..82325048d9 100644
--- a/lib/SPIRV/libSPIRV/SPIRVInstruction.h
+++ b/lib/SPIRV/libSPIRV/SPIRVInstruction.h
@@ -1316,13 +1316,76 @@ class SPIRVMatrixTimesScalar : public SPIRVInstruction {
     assert(Ty->isTypeFloat() && "Invalid result type for OpMatrixTimesScalar");
     assert(MTy->isTypeFloat() && "Invalid Matrix type for OpMatrixTimesScalar");
     assert(STy->isTypeFloat() && "Invalid Scalar type for OpMatrixTimesScalar");
 
+    assert(Ty == MTy && Ty == STy && "Mismatch float type");
+  }
+
+private:
+  SPIRVId Matrix;
+  SPIRVId Scalar;
+};
+
+// In-memory form of OpMatrixTimesVector. Holds the <id>s of the matrix and
+// vector operands; validate() asserts that the result, matrix and vector all
+// share one floating-point scalar type.
+class SPIRVMatrixTimesVector : public SPIRVInstruction {
+public:
+  static const Op OC = OpMatrixTimesVector;
+  static const SPIRVWord FixedWordCount = 4;
+
+  // Complete constructor
+  SPIRVMatrixTimesVector(SPIRVType *TheType, SPIRVId TheId, SPIRVId TheMatrix,
+                         SPIRVId TheVector, SPIRVBasicBlock *BB)
+      : SPIRVInstruction(5, OC, TheType, TheId, BB), Matrix(TheMatrix),
+        Vector(TheVector) {
+    validate();
+    assert(BB && "Invalid BB");
+  }
+
+  // Incomplete constructor
+  SPIRVMatrixTimesVector()
+      : SPIRVInstruction(OC), Matrix(SPIRVID_INVALID), Vector(SPIRVID_INVALID) {
+  }
+
+  SPIRVValue *getMatrix() const { return getValue(Matrix); }
+
+  SPIRVValue *getVector() const { return getValue(Vector); }
+  std::vector<SPIRVValue *> getOperands() override {
+    std::vector<SPIRVId> Operands;
+    Operands.push_back(Matrix);
+    Operands.push_back(Vector);
+    return getValues(Operands);
+  }
+
+  void setWordCount(SPIRVWord FixedWordCount) override {
+    SPIRVEntry::setWordCount(FixedWordCount);
+  }
+
+  _SPIRV_DEF_ENCDEC4(Type, Id, Matrix, Vector)
+
+  void validate() const override {
     SPIRVInstruction::validate();
+    // Skip the type checks while either operand is still a forward reference.
+    if (getValue(Matrix)->isForward() || getValue(Vector)->isForward())
+      return;
+    SPIRVType *Ty = getType()->getScalarType();
+    SPIRVType *MTy = getValueType(Matrix)->getScalarType();
+    SPIRVType *VTy = getValueType(Vector)->getScalarType();
+
+    // The locals are only read inside assert(); keep them marked as used.
+    (void)Ty;
+    (void)MTy;
+    (void)VTy;
+    assert(Ty->isTypeFloat() && "Invalid result type for OpMatrixTimesVector");
+    assert(MTy->isTypeFloat() && "Invalid Matrix type for OpMatrixTimesVector");
+    assert(VTy->isTypeFloat() && "Invalid Vector type for OpMatrixTimesVector");
+
+    assert(Ty == MTy && Ty == VTy && "Mismatch float type");
   }
 
 private:
   SPIRVId Matrix;
-  SPIRVId Scalar;
+  SPIRVId Vector;
 };
 
 class SPIRVUnary : public SPIRVInstTemplateBase {
diff --git a/lib/SPIRV/libSPIRV/SPIRVModule.cpp b/lib/SPIRV/libSPIRV/SPIRVModule.cpp
index 25f2401c65..2cb364ea7d 100644
--- a/lib/SPIRV/libSPIRV/SPIRVModule.cpp
+++ b/lib/SPIRV/libSPIRV/SPIRVModule.cpp
@@ -377,6 +377,10 @@ class SPIRVModuleImpl : public SPIRVModule {
                                              SPIRVId TheMatrix,
                                              SPIRVId TheScalar,
                                              SPIRVBasicBlock *BB) override;
+  SPIRVInstruction *addMatrixTimesVectorInst(SPIRVType *TheType,
+                                             SPIRVId TheMatrix,
+                                             SPIRVId TheVector,
+                                             SPIRVBasicBlock *BB) override;
   SPIRVInstruction *addUnaryInst(Op, SPIRVType *, SPIRVValue *,
                                  SPIRVBasicBlock *) override;
   SPIRVInstruction *addVariable(SPIRVType *, bool, SPIRVLinkageTypeKind,
@@ -1075,6 +1079,14 @@ SPIRVModuleImpl::addMatrixTimesScalarInst(SPIRVType *TheType, SPIRVId TheMatrix,
       new SPIRVMatrixTimesScalar(TheType, getId(), TheMatrix, TheScalar, BB));
 }
 
+SPIRVInstruction *
+SPIRVModuleImpl::addMatrixTimesVectorInst(SPIRVType *TheType, SPIRVId TheMatrix,
+                                          SPIRVId TheVector,
+                                          SPIRVBasicBlock *BB) {
+  return BB->addInstruction(
+      new SPIRVMatrixTimesVector(TheType, getId(), TheMatrix, TheVector, BB));
+}
+
 SPIRVInstruction *
 SPIRVModuleImpl::addGroupInst(Op OpCode, SPIRVType *Type, Scope Scope,
                               const std::vector<SPIRVValue *> &Ops,
diff --git a/lib/SPIRV/libSPIRV/SPIRVModule.h b/lib/SPIRV/libSPIRV/SPIRVModule.h
index d89c469d65..225f1f77d5 100644
--- a/lib/SPIRV/libSPIRV/SPIRVModule.h
+++ b/lib/SPIRV/libSPIRV/SPIRVModule.h
@@ -378,6 +378,10 @@ class SPIRVModule {
                                                      SPIRVId TheMatrix,
                                                      SPIRVId TheScalar,
                                                      SPIRVBasicBlock *BB) = 0;
+  virtual SPIRVInstruction *addMatrixTimesVectorInst(SPIRVType *TheType,
+                                                     SPIRVId TheMatrix,
+                                                     SPIRVId TheVector,
+                                                     SPIRVBasicBlock *BB) = 0;
   virtual SPIRVInstruction *addUnaryInst(Op, SPIRVType *, SPIRVValue *,
                                          SPIRVBasicBlock *) = 0;
   virtual SPIRVInstruction *addVariable(SPIRVType *, bool, SPIRVLinkageTypeKind,
diff --git a/test/matrix_times_vector.spt b/test/matrix_times_vector.spt
new file mode 100644
index 0000000000..0726381d44
--- /dev/null
+++ b/test/matrix_times_vector.spt
@@ -0,0 +1,68 @@
+119734787 65536 458752 21 0
+2 Capability Addresses
+2 Capability Linkage
+2 Capability Kernel
+2 Capability Float64
+2 Capability Matrix
+3 MemoryModel 2 2
+8 EntryPoint 6 20 "matrix_times_vector"
+3 Source 3 102000
+3 Name 12 "res"
+3 Name 13 "lhs"
+3 Name 14 "rhs"
+
+2 TypeVoid 5
+3 TypeFloat 6 32
+4 TypeVector 7 6 4
+4 TypeMatrix 8 7 4
+4 TypePointer 9 7 8
+4 TypePointer 10 7 7
+6 TypeFunction 11 5 10 9 10
+
+5 Function 5 20 0 11
+3 FunctionParameter 10 12
+3 FunctionParameter 9 13
+3 FunctionParameter 10 14
+
+2 Label 15
+4 Load 8 16 13
+4 Load 7 17 14
+5 MatrixTimesVector 7 18 16 17
+3 Store 12 18
+1 Return
+
+1 FunctionEnd
+
+; RUN: llvm-spirv %s -to-binary -o %t.spv
+; RUN: spirv-val %t.spv
+; RUN: llvm-spirv -r %t.spv -o %t.bc
+; RUN: llvm-dis < %t.bc | FileCheck %s --check-prefix=CHECK-LLVM
+
+; CHECK-LLVM: %1 = load [4 x <4 x float>], [4 x <4 x float>]* %lhs
+; CHECK-LLVM: %2 = load <4 x float>, <4 x float>* %rhs
+; CHECK-LLVM: %3 = extractelement <4 x float> %2, i32 0
+; CHECK-LLVM: %.splatinsert = insertelement <4 x float> undef, float %3, i32 0
+; CHECK-LLVM: %.splat = shufflevector <4 x float> %.splatinsert, <4 x float> undef, <4 x i32> zeroinitializer
+; CHECK-LLVM: %4 = extractvalue [4 x <4 x float>] %1, 0
+; CHECK-LLVM: %5 = fmul <4 x float> %.splat, %4
+; CHECK-LLVM: %6 = fadd <4 x float> zeroinitializer, %5
+; CHECK-LLVM: %7 = extractelement <4 x float> %2, i32 1
+; CHECK-LLVM: %.splatinsert1 = insertelement <4 x float> undef, float %7, i32 0
+; CHECK-LLVM: %.splat2 = shufflevector <4 x float> %.splatinsert1, <4 x float> undef, <4 x i32> zeroinitializer
+; CHECK-LLVM: %8 = extractvalue [4 x <4 x float>] %1, 1
+; CHECK-LLVM: %9 = fmul <4 x float> %.splat2, %8
+; CHECK-LLVM: %10 = fadd <4 x float> %6, %9
+; CHECK-LLVM: %11 = extractelement <4 x float> %2, i32 2
+; CHECK-LLVM: %.splatinsert3 = insertelement <4 x float> undef, float %11, i32 0
+; CHECK-LLVM: %.splat4 = shufflevector <4 x float> %.splatinsert3, <4 x float> undef, <4 x i32> zeroinitializer
+; CHECK-LLVM: %12 = extractvalue [4 x <4 x float>] %1, 2
+; CHECK-LLVM: %13 = fmul <4 x float> %.splat4, %12
+; CHECK-LLVM: %14 = fadd <4 x float> %10, %13
+; CHECK-LLVM: %15 = extractelement <4 x float> %2, i32 3
+; CHECK-LLVM: %.splatinsert5 = insertelement <4 x float> undef, float %15, i32 0
+; CHECK-LLVM: %.splat6 = shufflevector <4 x float> %.splatinsert5, <4 x float> undef, <4 x i32> zeroinitializer
+; CHECK-LLVM: %16 = extractvalue [4 x <4 x float>] %1, 3
+; CHECK-LLVM: %17 = fmul <4 x float> %.splat6, %16
+; CHECK-LLVM: %18 = fadd <4 x float> %14, %17
+; CHECK-LLVM: store <4 x float> %18, <4 x float>* %res
+; CHECK-LLVM: ret void