Skip to content

Commit

Permalink
Add support for reading OpVectorTimesMatrix
Browse files Browse the repository at this point in the history
Signed-off-by: Qinglai Xiao <q.xiao@think-silicon.com>
  • Loading branch information
jigsawecho authored and AlexeySachkov committed Oct 25, 2019
1 parent da3b1e4 commit 4581205
Show file tree
Hide file tree
Showing 6 changed files with 229 additions and 1 deletion.
42 changes: 42 additions & 0 deletions lib/SPIRV/SPIRVReader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1587,6 +1587,48 @@ Value *SPIRVToLLVM::transValueWithoutDecoration(SPIRVValue *BV, Function *F,
return mapValue(BV, Scale);
}

case OpVectorTimesMatrix: {
  // Lower OpVectorTimesMatrix (vector * matrix) to plain LLVM vector
  // arithmetic.
  //
  // The translated matrix is an LLVM array of column vectors and the
  // translated vector has one component per matrix row:
  //
  //   NumRows = number of components in the vector operand
  //   NumCols = number of columns (array elements) in the matrix operand
  //
  // The result has NumCols components and is accumulated row by row:
  //
  //   Sum = 0
  //   for each row r:
  //     Sum += splat(Vec[r]) * { Col_0[r], Col_1[r], ..., Col_{NumCols-1}[r] }
  auto *VecTimesMat = static_cast<SPIRVVectorTimesMatrix *>(BV);
  IRBuilder<> Builder(BB);
  Value *MatVal = transValue(VecTimesMat->getMatrix(), F, BB);
  Value *VecVal = transValue(VecTimesMat->getVector(), F, BB);

  unsigned NumCols = MatVal->getType()->getArrayNumElements();
  unsigned NumRows = VecVal->getType()->getVectorNumElements();

  // The result shares the element type of the vector operand but has one
  // lane per matrix column.
  VectorType *ResultTy =
      VectorType::get(VecVal->getType()->getVectorElementType(), NumCols);
  auto ElemTy = ResultTy->getElementType();

  // Running sum, starting from an all-zero vector.
  Value *Sum = Builder.CreateVectorSplat(NumCols, ConstantFP::get(ElemTy, 0.0));

  for (unsigned Row = 0; Row < NumRows; ++Row) {
    // Broadcast the Row-th vector component across all result lanes.
    Value *Component =
        Builder.CreateExtractElement(VecVal, Builder.getInt32(Row));
    Value *Broadcast = Builder.CreateVectorSplat(NumCols, Component);

    // Gather the Row-th element of every column into a single vector.
    Value *MatRow = UndefValue::get(ResultTy);
    for (unsigned Col = 0; Col < NumCols; ++Col) {
      Value *Column = Builder.CreateExtractValue(MatVal, Col);
      Value *Elem = Builder.CreateExtractElement(Column, Builder.getInt32(Row));
      MatRow = Builder.CreateInsertElement(MatRow, Elem, Builder.getInt32(Col));
    }

    // Multiply first, then fold the product into the running sum.
    Sum = Builder.CreateFAdd(Sum, Builder.CreateFMul(Broadcast, MatRow));
  }

  return mapValue(BV, Sum);
}

case OpMatrixTimesScalar: {
auto MTS = static_cast<SPIRVMatrixTimesScalar *>(BV);
IRBuilder<> Builder(BB);
Expand Down
1 change: 0 additions & 1 deletion lib/SPIRV/libSPIRV/SPIRVEntry.h
Original file line number Diff line number Diff line change
Expand Up @@ -786,7 +786,6 @@ _SPIRV_OP(ImageDrefGather)
_SPIRV_OP(QuantizeToF16)
_SPIRV_OP(Transpose)
_SPIRV_OP(ArrayLength)
_SPIRV_OP(VectorTimesMatrix)
_SPIRV_OP(MatrixTimesMatrix)
_SPIRV_OP(OuterProduct)
_SPIRV_OP(IAddCarry)
Expand Down
59 changes: 59 additions & 0 deletions lib/SPIRV/libSPIRV/SPIRVInstruction.h
Original file line number Diff line number Diff line change
Expand Up @@ -1269,6 +1269,65 @@ class SPIRVVectorTimesScalar : public SPIRVInstruction {
SPIRVId Scalar;
};

class SPIRVVectorTimesMatrix : public SPIRVInstruction {
public:
static const Op OC = OpVectorTimesMatrix;
static const SPIRVWord FixedWordCount = 4;

// Complete constructor
SPIRVVectorTimesMatrix(SPIRVType *TheType, SPIRVId TheId, SPIRVId TheVector,
SPIRVId TheMatrix, SPIRVBasicBlock *BB)
: SPIRVInstruction(5, OC, TheType, TheId, BB), Vector(TheVector),
Matrix(TheMatrix) {
validate();
assert(BB && "Invalid BB");
}

// Incomplete constructor
SPIRVVectorTimesMatrix()
: SPIRVInstruction(OC), Vector(SPIRVID_INVALID), Matrix(SPIRVID_INVALID) {
}

SPIRVValue *getVector() const { return getValue(Vector); }
SPIRVValue *getMatrix() const { return getValue(Matrix); }

std::vector<SPIRVValue *> getOperands() override {
std::vector<SPIRVId> Operands;
Operands.push_back(Vector);
Operands.push_back(Matrix);
return getValues(Operands);
}

void setWordCount(SPIRVWord FixedWordCount) override {
SPIRVEntry::setWordCount(FixedWordCount);
}

_SPIRV_DEF_ENCDEC4(Type, Id, Vector, Matrix)

void validate() const override {
SPIRVInstruction::validate();
if (getValue(Vector)->isForward() || getValue(Matrix)->isForward())
return;

SPIRVType *Ty = getType()->getScalarType();
SPIRVType *MTy = getValueType(Matrix)->getScalarType();
SPIRVType *VTy = getValueType(Vector)->getScalarType();

(void)Ty;
(void)MTy;
(void)VTy;
assert(Ty->isTypeFloat() && "Invalid result type for OpVectorTimesMatrix");
assert(VTy->isTypeFloat() && "Invalid Vector type for OpVectorTimesMatrix");
assert(MTy->isTypeFloat() && "Invalid Matrix type for OpVectorTimesMatrix");

assert(Ty == MTy && Ty == VTy && "Mismatch float type");
}

private:
SPIRVId Vector;
SPIRVId Matrix;
};

class SPIRVMatrixTimesScalar : public SPIRVInstruction {
public:
static const Op OC = OpMatrixTimesScalar;
Expand Down
12 changes: 12 additions & 0 deletions lib/SPIRV/libSPIRV/SPIRVModule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -373,6 +373,10 @@ class SPIRVModuleImpl : public SPIRVModule {
SPIRVId TheVector,
SPIRVId TheScalar,
SPIRVBasicBlock *BB) override;
// Adds an OpVectorTimesMatrix instruction to BB. The third parameter is the
// id of the matrix operand, matching the out-of-class definition of this
// method.
SPIRVInstruction *addVectorTimesMatrixInst(SPIRVType *TheType,
                                           SPIRVId TheVector,
                                           SPIRVId TheMatrix,
                                           SPIRVBasicBlock *BB) override;
SPIRVInstruction *addMatrixTimesScalarInst(SPIRVType *TheType,
SPIRVId TheMatrix,
SPIRVId TheScalar,
Expand Down Expand Up @@ -1071,6 +1075,14 @@ SPIRVModuleImpl::addVectorTimesScalarInst(SPIRVType *TheType, SPIRVId TheVector,
new SPIRVVectorTimesScalar(TheType, getId(), TheVector, TheScalar, BB));
}

// Builds a SPIRVVectorTimesMatrix with a freshly obtained id and hands it to
// BB, returning whatever BB->addInstruction() returns.
SPIRVInstruction *
SPIRVModuleImpl::addVectorTimesMatrixInst(SPIRVType *TheType, SPIRVId TheVector,
                                          SPIRVId TheMatrix,
                                          SPIRVBasicBlock *BB) {
  auto *Inst =
      new SPIRVVectorTimesMatrix(TheType, getId(), TheVector, TheMatrix, BB);
  return BB->addInstruction(Inst);
}

SPIRVInstruction *
SPIRVModuleImpl::addMatrixTimesScalarInst(SPIRVType *TheType, SPIRVId TheMatrix,
SPIRVId TheScalar,
Expand Down
4 changes: 4 additions & 0 deletions lib/SPIRV/libSPIRV/SPIRVModule.h
Original file line number Diff line number Diff line change
Expand Up @@ -374,6 +374,10 @@ class SPIRVModule {
SPIRVId TheVector,
SPIRVId TheScalar,
SPIRVBasicBlock *BB) = 0;
// Adds an OpVectorTimesMatrix instruction to BB.
// TheType is the result type; TheVector and TheMatrix are the ids of the
// vector and matrix operands. Returns the added instruction.
virtual SPIRVInstruction *addVectorTimesMatrixInst(SPIRVType *TheType,
SPIRVId TheVector,
SPIRVId TheMatrix,
SPIRVBasicBlock *BB) = 0;
virtual SPIRVInstruction *addMatrixTimesScalarInst(SPIRVType *TheType,
SPIRVId TheMatrix,
SPIRVId TheScalar,
Expand Down
112 changes: 112 additions & 0 deletions test/vector_times_matrix.spt
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
119734787 65536 458752 21 0
2 Capability Addresses
2 Capability Linkage
2 Capability Kernel
2 Capability Float64
2 Capability Matrix
3 MemoryModel 2 2
8 EntryPoint 6 20 "vector_times_matrix"
3 Source 3 102000
3 Name 12 "res"
3 Name 13 "lhs"
3 Name 14 "rhs"

2 TypeVoid 5
3 TypeFloat 6 32
4 TypeVector 7 6 4
4 TypeMatrix 8 7 4
4 TypePointer 9 7 8
4 TypePointer 10 7 7
6 TypeFunction 11 5 10 10 9

5 Function 5 20 0 11
3 FunctionParameter 10 12
3 FunctionParameter 10 13
3 FunctionParameter 9 14

2 Label 15
4 Load 7 16 13
4 Load 8 17 14
5 VectorTimesMatrix 7 18 16 17
3 Store 12 18
1 Return

1 FunctionEnd

; RUN: llvm-spirv %s -to-binary -o %t.spv
; RUN: spirv-val %t.spv
; RUN: llvm-spirv -r %t.spv -o %t.bc
; RUN: llvm-dis < %t.bc | FileCheck %s --check-prefix=CHECK-LLVM

; CHECK-LLVM: %1 = load <4 x float>, <4 x float>* %lhs
; CHECK-LLVM: %2 = load [4 x <4 x float>], [4 x <4 x float>]* %rhs
; CHECK-LLVM: %3 = extractelement <4 x float> %1, i32 0
; CHECK-LLVM: %.splatinsert = insertelement <4 x float> undef, float %3, i32 0
; CHECK-LLVM: %.splat = shufflevector <4 x float> %.splatinsert, <4 x float> undef, <4 x i32> zeroinitializer
; CHECK-LLVM: %4 = extractvalue [4 x <4 x float>] %2, 0
; CHECK-LLVM: %5 = extractelement <4 x float> %4, i32 0
; CHECK-LLVM: %6 = insertelement <4 x float> undef, float %5, i32 0
; CHECK-LLVM: %7 = extractvalue [4 x <4 x float>] %2, 1
; CHECK-LLVM: %8 = extractelement <4 x float> %7, i32 0
; CHECK-LLVM: %9 = insertelement <4 x float> %6, float %8, i32 1
; CHECK-LLVM: %10 = extractvalue [4 x <4 x float>] %2, 2
; CHECK-LLVM: %11 = extractelement <4 x float> %10, i32 0
; CHECK-LLVM: %12 = insertelement <4 x float> %9, float %11, i32 2
; CHECK-LLVM: %13 = extractvalue [4 x <4 x float>] %2, 3
; CHECK-LLVM: %14 = extractelement <4 x float> %13, i32 0
; CHECK-LLVM: %15 = insertelement <4 x float> %12, float %14, i32 3
; CHECK-LLVM: %16 = fmul <4 x float> %.splat, %15
; CHECK-LLVM: %17 = fadd <4 x float> zeroinitializer, %16
; CHECK-LLVM: %18 = extractelement <4 x float> %1, i32 1
; CHECK-LLVM: %.splatinsert1 = insertelement <4 x float> undef, float %18, i32 0
; CHECK-LLVM: %.splat2 = shufflevector <4 x float> %.splatinsert1, <4 x float> undef, <4 x i32> zeroinitializer
; CHECK-LLVM: %19 = extractvalue [4 x <4 x float>] %2, 0
; CHECK-LLVM: %20 = extractelement <4 x float> %19, i32 1
; CHECK-LLVM: %21 = insertelement <4 x float> undef, float %20, i32 0
; CHECK-LLVM: %22 = extractvalue [4 x <4 x float>] %2, 1
; CHECK-LLVM: %23 = extractelement <4 x float> %22, i32 1
; CHECK-LLVM: %24 = insertelement <4 x float> %21, float %23, i32 1
; CHECK-LLVM: %25 = extractvalue [4 x <4 x float>] %2, 2
; CHECK-LLVM: %26 = extractelement <4 x float> %25, i32 1
; CHECK-LLVM: %27 = insertelement <4 x float> %24, float %26, i32 2
; CHECK-LLVM: %28 = extractvalue [4 x <4 x float>] %2, 3
; CHECK-LLVM: %29 = extractelement <4 x float> %28, i32 1
; CHECK-LLVM: %30 = insertelement <4 x float> %27, float %29, i32 3
; CHECK-LLVM: %31 = fmul <4 x float> %.splat2, %30
; CHECK-LLVM: %32 = fadd <4 x float> %17, %31
; CHECK-LLVM: %33 = extractelement <4 x float> %1, i32 2
; CHECK-LLVM: %.splatinsert3 = insertelement <4 x float> undef, float %33, i32 0
; CHECK-LLVM: %.splat4 = shufflevector <4 x float> %.splatinsert3, <4 x float> undef, <4 x i32> zeroinitializer
; CHECK-LLVM: %34 = extractvalue [4 x <4 x float>] %2, 0
; CHECK-LLVM: %35 = extractelement <4 x float> %34, i32 2
; CHECK-LLVM: %36 = insertelement <4 x float> undef, float %35, i32 0
; CHECK-LLVM: %37 = extractvalue [4 x <4 x float>] %2, 1
; CHECK-LLVM: %38 = extractelement <4 x float> %37, i32 2
; CHECK-LLVM: %39 = insertelement <4 x float> %36, float %38, i32 1
; CHECK-LLVM: %40 = extractvalue [4 x <4 x float>] %2, 2
; CHECK-LLVM: %41 = extractelement <4 x float> %40, i32 2
; CHECK-LLVM: %42 = insertelement <4 x float> %39, float %41, i32 2
; CHECK-LLVM: %43 = extractvalue [4 x <4 x float>] %2, 3
; CHECK-LLVM: %44 = extractelement <4 x float> %43, i32 2
; CHECK-LLVM: %45 = insertelement <4 x float> %42, float %44, i32 3
; CHECK-LLVM: %46 = fmul <4 x float> %.splat4, %45
; CHECK-LLVM: %47 = fadd <4 x float> %32, %46
; CHECK-LLVM: %48 = extractelement <4 x float> %1, i32 3
; CHECK-LLVM: %.splatinsert5 = insertelement <4 x float> undef, float %48, i32 0
; CHECK-LLVM: %.splat6 = shufflevector <4 x float> %.splatinsert5, <4 x float> undef, <4 x i32> zeroinitializer
; CHECK-LLVM: %49 = extractvalue [4 x <4 x float>] %2, 0
; CHECK-LLVM: %50 = extractelement <4 x float> %49, i32 3
; CHECK-LLVM: %51 = insertelement <4 x float> undef, float %50, i32 0
; CHECK-LLVM: %52 = extractvalue [4 x <4 x float>] %2, 1
; CHECK-LLVM: %53 = extractelement <4 x float> %52, i32 3
; CHECK-LLVM: %54 = insertelement <4 x float> %51, float %53, i32 1
; CHECK-LLVM: %55 = extractvalue [4 x <4 x float>] %2, 2
; CHECK-LLVM: %56 = extractelement <4 x float> %55, i32 3
; CHECK-LLVM: %57 = insertelement <4 x float> %54, float %56, i32 2
; CHECK-LLVM: %58 = extractvalue [4 x <4 x float>] %2, 3
; CHECK-LLVM: %59 = extractelement <4 x float> %58, i32 3
; CHECK-LLVM: %60 = insertelement <4 x float> %57, float %59, i32 3
; CHECK-LLVM: %61 = fmul <4 x float> %.splat6, %60
; CHECK-LLVM: %62 = fadd <4 x float> %47, %61
; CHECK-LLVM: store <4 x float> %62, <4 x float>* %res
; CHECK-LLVM: ret void

0 comments on commit 4581205

Please sign in to comment.