Skip to content

Commit

Permalink
Add support for reading OpMatrixTimesVector
Browse files Browse the repository at this point in the history
As well as new type SPIRVMatrixTimesVector.

Signed-off-by: Qinglai Xiao <q.xiao@think-silicon.com>
  • Loading branch information
jigsawecho authored and AlexeySachkov committed Oct 3, 2019
1 parent a97c061 commit e63a7cb
Show file tree
Hide file tree
Showing 6 changed files with 183 additions and 2 deletions.
40 changes: 40 additions & 0 deletions lib/SPIRV/SPIRVReader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1607,6 +1607,46 @@ Value *SPIRVToLLVM::transValueWithoutDecoration(SPIRVValue *BV, Function *F,
return mapValue(BV, V);
}

// Translate OpMatrixTimesVector into LLVM IR: the matrix-vector product is
// expanded into one splat + fmul + fadd sequence per matrix column.
case OpMatrixTimesVector: {
auto *MTV = static_cast<SPIRVMatrixTimesVector *>(BV);
IRBuilder<> Builder(BB);
Value *Mat = transValue(MTV->getMatrix(), F, BB);
Value *Vec = transValue(MTV->getVector(), F, BB);

// Result is similar to Matrix * Matrix
// Mat is of M columns and N rows.
// Mat consists of vectors: V_1, V_2, ..., V_M
// where each vector is of size N.
//
// Vec is of size M.
// The product is a vector of size N.
//
//                |------- N ----------|
// Result = sum ( {Vec_1, Vec_1, ..., Vec_1} * V_1,
//                {Vec_2, Vec_2, ..., Vec_2} * V_2,
//                ...
//                {Vec_M, Vec_M, ..., Vec_M} * V_M );
//
// where sum is defined as vector sum.

// The translated matrix value is an LLVM array whose elements are vectors:
// the array length gives M and the element vector length gives N.
unsigned M = Mat->getType()->getArrayNumElements();
VectorType *VTy =
cast<VectorType>(cast<ArrayType>(Mat->getType())->getElementType());
unsigned N = VTy->getVectorNumElements();
auto ETy = VTy->getElementType();
// Accumulator for the vector sum, starting from an N-wide splat of 0.0.
Value *V = Builder.CreateVectorSplat(N, ConstantFP::get(ETy, 0.0));

for (unsigned Idx = 0; Idx != M; ++Idx) {
// Broadcast the Idx-th component of Vec across N lanes.
Value *S = Builder.CreateExtractElement(Vec, Builder.getInt32(Idx));
Value *Lhs = Builder.CreateVectorSplat(N, S);
// Fetch the Idx-th column of the matrix, scale it and accumulate.
Value *Vx = Builder.CreateExtractValue(Mat, Idx);
Value *Mul = Builder.CreateFMul(Lhs, Vx);
V = Builder.CreateFAdd(V, Mul);
}

return mapValue(BV, V);
}

case OpCopyObject: {
SPIRVCopyObject *CO = static_cast<SPIRVCopyObject *>(BV);
AllocaInst *AI =
Expand Down
1 change: 0 additions & 1 deletion lib/SPIRV/libSPIRV/SPIRVEntry.h
Original file line number Diff line number Diff line change
Expand Up @@ -787,7 +787,6 @@ _SPIRV_OP(QuantizeToF16)
_SPIRV_OP(Transpose)
_SPIRV_OP(ArrayLength)
_SPIRV_OP(VectorTimesMatrix)
_SPIRV_OP(MatrixTimesVector)
_SPIRV_OP(MatrixTimesMatrix)
_SPIRV_OP(OuterProduct)
_SPIRV_OP(IAddCarry)
Expand Down
60 changes: 59 additions & 1 deletion lib/SPIRV/libSPIRV/SPIRVInstruction.h
Original file line number Diff line number Diff line change
Expand Up @@ -1316,13 +1316,71 @@ class SPIRVMatrixTimesScalar : public SPIRVInstruction {
assert(Ty->isTypeFloat() && "Invalid result type for OpMatrixTimesScalar");
assert(MTy->isTypeFloat() && "Invalid Matrix type for OpMatrixTimesScalar");
assert(STy->isTypeFloat() && "Invalid Scalar type for OpMatrixTimesScalar");
assert(Ty == MTy && Ty == STy && "Mismatch float type");
}

private:
SPIRVId Matrix;
SPIRVId Scalar;
};

/// Representation of the SPIR-V OpMatrixTimesVector instruction.
///
/// Holds the ids of the matrix and vector operands and exposes them as
/// SPIRVValue pointers through getMatrix() / getVector().
class SPIRVMatrixTimesVector : public SPIRVInstruction {
public:
  static const Op OC = OpMatrixTimesVector;
  // NOTE(review): the complete constructor passes 5 to SPIRVInstruction while
  // this constant is 4 -- confirm which value is intended.
  static const SPIRVWord FixedWordCount = 4;

  /// Complete constructor: records both operand ids and validates the
  /// instruction immediately. \p BB must not be null.
  SPIRVMatrixTimesVector(SPIRVType *TheType, SPIRVId TheId, SPIRVId TheMatrix,
                         SPIRVId TheVector, SPIRVBasicBlock *BB)
      : SPIRVInstruction(5, OC, TheType, TheId, BB), Matrix(TheMatrix),
        Vector(TheVector) {
    validate();
    assert(BB && "Invalid BB");
  }

  /// Incomplete constructor: both operand ids start out as SPIRVID_INVALID.
  SPIRVMatrixTimesVector()
      : SPIRVInstruction(OC), Matrix(SPIRVID_INVALID), Vector(SPIRVID_INVALID) {
  }

  /// \returns the value referenced by the matrix operand id.
  SPIRVValue *getMatrix() const { return getValue(Matrix); }

  /// \returns the value referenced by the vector operand id.
  SPIRVValue *getVector() const { return getValue(Vector); }

  /// \returns the operands in instruction order: matrix first, then vector.
  std::vector<SPIRVValue *> getOperands() override {
    std::vector<SPIRVId> Operands;
    Operands.push_back(Matrix);
    Operands.push_back(Vector);
    return getValues(Operands);
  }

  /// Forwards the given word count to SPIRVEntry. The parameter shadows the
  /// static FixedWordCount constant; the argument value is what is forwarded.
  void setWordCount(SPIRVWord FixedWordCount) override {
    SPIRVEntry::setWordCount(FixedWordCount);
  }

  _SPIRV_DEF_ENCDEC4(Type, Id, Matrix, Vector)

  /// Checks that the result, matrix and vector all share the same floating
  /// point scalar type. The operand checks are skipped while either operand
  /// is still a forward reference.
  void validate() const override {
    SPIRVInstruction::validate();
    if (getValue(Matrix)->isForward() || getValue(Vector)->isForward())
      return;
    SPIRVType *Ty = getType()->getScalarType();
    SPIRVType *MTy = getValueType(Matrix)->getScalarType();
    SPIRVType *VTy = getValueType(Vector)->getScalarType();

    // The locals are only read inside assert(); the void casts keep them
    // "used" when assertions are compiled out.
    (void)Ty;
    (void)MTy;
    (void)VTy;
    assert(Ty->isTypeFloat() && "Invalid result type for OpMatrixTimesVector");
    assert(MTy->isTypeFloat() && "Invalid Matrix type for OpMatrixTimesVector");
    assert(VTy->isTypeFloat() && "Invalid Vector type for OpMatrixTimesVector");

    assert(Ty == MTy && Ty == VTy && "Mismatch float type");
  }

private:
  // Operand ids. The stray, never-initialized `Scalar` member has been
  // dropped: this instruction only has a matrix and a vector operand.
  SPIRVId Matrix;
  SPIRVId Vector;
};

class SPIRVUnary : public SPIRVInstTemplateBase {
Expand Down
12 changes: 12 additions & 0 deletions lib/SPIRV/libSPIRV/SPIRVModule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -377,6 +377,10 @@ class SPIRVModuleImpl : public SPIRVModule {
SPIRVId TheMatrix,
SPIRVId TheScalar,
SPIRVBasicBlock *BB) override;
/// Creates an OpMatrixTimesVector instruction with result type \p TheType
/// and operand ids \p TheMatrix and \p TheVector, and adds it to \p BB.
SPIRVInstruction *addMatrixTimesVectorInst(SPIRVType *TheType,
SPIRVId TheMatrix,
SPIRVId TheVector,
SPIRVBasicBlock *BB) override;
SPIRVInstruction *addUnaryInst(Op, SPIRVType *, SPIRVValue *,
SPIRVBasicBlock *) override;
SPIRVInstruction *addVariable(SPIRVType *, bool, SPIRVLinkageTypeKind,
Expand Down Expand Up @@ -1075,6 +1079,14 @@ SPIRVModuleImpl::addMatrixTimesScalarInst(SPIRVType *TheType, SPIRVId TheMatrix,
new SPIRVMatrixTimesScalar(TheType, getId(), TheMatrix, TheScalar, BB));
}

/// Builds a SPIRVMatrixTimesVector with a freshly obtained id and hands it to
/// \p BB, returning whatever SPIRVBasicBlock::addInstruction returns.
SPIRVInstruction *
SPIRVModuleImpl::addMatrixTimesVectorInst(SPIRVType *TheType, SPIRVId TheMatrix,
                                          SPIRVId TheVector,
                                          SPIRVBasicBlock *BB) {
  auto *NewInst =
      new SPIRVMatrixTimesVector(TheType, getId(), TheMatrix, TheVector, BB);
  return BB->addInstruction(NewInst);
}

SPIRVInstruction *
SPIRVModuleImpl::addGroupInst(Op OpCode, SPIRVType *Type, Scope Scope,
const std::vector<SPIRVValue *> &Ops,
Expand Down
4 changes: 4 additions & 0 deletions lib/SPIRV/libSPIRV/SPIRVModule.h
Original file line number Diff line number Diff line change
Expand Up @@ -378,6 +378,10 @@ class SPIRVModule {
SPIRVId TheMatrix,
SPIRVId TheScalar,
SPIRVBasicBlock *BB) = 0;
/// Adds an OpMatrixTimesVector instruction to \p BB. \p TheType is the
/// result type; \p TheMatrix and \p TheVector are the operand ids.
virtual SPIRVInstruction *addMatrixTimesVectorInst(SPIRVType *TheType,
SPIRVId TheMatrix,
SPIRVId TheVector,
SPIRVBasicBlock *BB) = 0;
virtual SPIRVInstruction *addUnaryInst(Op, SPIRVType *, SPIRVValue *,
SPIRVBasicBlock *) = 0;
virtual SPIRVInstruction *addVariable(SPIRVType *, bool, SPIRVLinkageTypeKind,
Expand Down
68 changes: 68 additions & 0 deletions test/matrix_times_vector.spt
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
119734787 65536 458752 21 0
2 Capability Addresses
2 Capability Linkage
2 Capability Kernel
2 Capability Float64
2 Capability Matrix
3 MemoryModel 2 2
8 EntryPoint 6 20 "matrix_times_vector"
3 Source 3 102000
3 Name 12 "res"
3 Name 13 "lhs"
3 Name 14 "rhs"

2 TypeVoid 5
3 TypeFloat 6 32
4 TypeVector 7 6 4
4 TypeMatrix 8 7 4
4 TypePointer 9 7 8
4 TypePointer 10 7 7
6 TypeFunction 11 5 10 9 10

5 Function 5 20 0 11
3 FunctionParameter 10 12
3 FunctionParameter 9 13
3 FunctionParameter 10 14

2 Label 15
4 Load 8 16 13
4 Load 7 17 14
5 MatrixTimesVector 7 18 16 17
3 Store 12 18
1 Return

1 FunctionEnd

; RUN: llvm-spirv %s -to-binary -o %t.spv
; RUN: spirv-val %t.spv
; RUN: llvm-spirv -r %t.spv -o %t.bc
; RUN: llvm-dis < %t.bc | FileCheck %s --check-prefix=CHECK-LLVM

; CHECK-LLVM: %1 = load [4 x <4 x float>], [4 x <4 x float>]* %lhs
; CHECK-LLVM: %2 = load <4 x float>, <4 x float>* %rhs
; CHECK-LLVM: %3 = extractelement <4 x float> %2, i32 0
; CHECK-LLVM: %.splatinsert = insertelement <4 x float> undef, float %3, i32 0
; CHECK-LLVM: %.splat = shufflevector <4 x float> %.splatinsert, <4 x float> undef, <4 x i32> zeroinitializer
; CHECK-LLVM: %4 = extractvalue [4 x <4 x float>] %1, 0
; CHECK-LLVM: %5 = fmul <4 x float> %.splat, %4
; CHECK-LLVM: %6 = fadd <4 x float> zeroinitializer, %5
; CHECK-LLVM: %7 = extractelement <4 x float> %2, i32 1
; CHECK-LLVM: %.splatinsert1 = insertelement <4 x float> undef, float %7, i32 0
; CHECK-LLVM: %.splat2 = shufflevector <4 x float> %.splatinsert1, <4 x float> undef, <4 x i32> zeroinitializer
; CHECK-LLVM: %8 = extractvalue [4 x <4 x float>] %1, 1
; CHECK-LLVM: %9 = fmul <4 x float> %.splat2, %8
; CHECK-LLVM: %10 = fadd <4 x float> %6, %9
; CHECK-LLVM: %11 = extractelement <4 x float> %2, i32 2
; CHECK-LLVM: %.splatinsert3 = insertelement <4 x float> undef, float %11, i32 0
; CHECK-LLVM: %.splat4 = shufflevector <4 x float> %.splatinsert3, <4 x float> undef, <4 x i32> zeroinitializer
; CHECK-LLVM: %12 = extractvalue [4 x <4 x float>] %1, 2
; CHECK-LLVM: %13 = fmul <4 x float> %.splat4, %12
; CHECK-LLVM: %14 = fadd <4 x float> %10, %13
; CHECK-LLVM: %15 = extractelement <4 x float> %2, i32 3
; CHECK-LLVM: %.splatinsert5 = insertelement <4 x float> undef, float %15, i32 0
; CHECK-LLVM: %.splat6 = shufflevector <4 x float> %.splatinsert5, <4 x float> undef, <4 x i32> zeroinitializer
; CHECK-LLVM: %16 = extractvalue [4 x <4 x float>] %1, 3
; CHECK-LLVM: %17 = fmul <4 x float> %.splat6, %16
; CHECK-LLVM: %18 = fadd <4 x float> %14, %17
; CHECK-LLVM: store <4 x float> %18, <4 x float>* %res
; CHECK-LLVM: ret void

0 comments on commit e63a7cb

Please sign in to comment.