Skip to content

Commit

Permalink
Add support for reading OpTranspose
Browse files Browse the repository at this point in the history
Signed-off-by: Qinglai Xiao <q.xiao@think-silicon.com>
  • Loading branch information
jigsawecho authored and AlexeySotkin committed Dec 6, 2019
1 parent bc9fa54 commit d09b6b6
Show file tree
Hide file tree
Showing 6 changed files with 297 additions and 1 deletion.
62 changes: 62 additions & 0 deletions lib/SPIRV/SPIRVReader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1790,6 +1790,68 @@ Value *SPIRVToLLVM::transValueWithoutDecoration(SPIRVValue *BV, Function *F,
return mapValue(BV, Res);
}

case OpTranspose: {
auto TR = static_cast<SPIRVTranspose *>(BV);
IRBuilder<> Builder(BB);
auto Matrix = transValue(TR->getMatrix(), F, BB);
unsigned ColNum = Matrix->getType()->getArrayNumElements();
VectorType *ColTy =
cast<VectorType>(cast<ArrayType>(Matrix->getType())->getElementType());
unsigned RowNum = ColTy->getVectorNumElements();

auto VTy = VectorType::get(ColTy->getElementType(), ColNum);
auto ResultTy = ArrayType::get(VTy, RowNum);
Value *V = UndefValue::get(ResultTy);

SmallVector<Value *, 16> MCache;
MCache.reserve(ColNum);
for (unsigned Idx = 0; Idx != ColNum; ++Idx)
MCache.push_back(Builder.CreateExtractValue(Matrix, Idx));

if (ColNum == RowNum) {
// Fastpath
switch (ColNum) {
case 2: {
Value *V1 = Builder.CreateShuffleVector(MCache[0], MCache[1], {0, 2});
V = Builder.CreateInsertValue(V, V1, 0);
Value *V2 = Builder.CreateShuffleVector(MCache[0], MCache[1], {1, 3});
V = Builder.CreateInsertValue(V, V2, 1);
return mapValue(BV, V);
}

case 4: {
for (unsigned Idx = 0; Idx < 4; ++Idx) {
Value *V1 =
Builder.CreateShuffleVector(MCache[0], MCache[1], {Idx, Idx + 4});
Value *V2 =
Builder.CreateShuffleVector(MCache[2], MCache[3], {Idx, Idx + 4});
Value *V3 = Builder.CreateShuffleVector(V1, V2, {0, 1, 2, 3});
V = Builder.CreateInsertValue(V, V3, Idx);
}
return mapValue(BV, V);
}

default:
break;
}
}

// Slowpath
for (unsigned Idx = 0; Idx != RowNum; ++Idx) {
Value *Vec = UndefValue::get(VTy);

for (unsigned Idx2 = 0; Idx2 != ColNum; ++Idx2) {
Value *S =
Builder.CreateExtractElement(MCache[Idx2], Builder.getInt32(Idx));
Vec = Builder.CreateInsertElement(Vec, S, Idx2);
}

V = Builder.CreateInsertValue(V, Vec, Idx);
}

return mapValue(BV, V);
}

case OpCopyObject: {
SPIRVCopyObject *CO = static_cast<SPIRVCopyObject *>(BV);
AllocaInst *AI =
Expand Down
1 change: 0 additions & 1 deletion lib/SPIRV/libSPIRV/SPIRVEntry.h
Original file line number Diff line number Diff line change
Expand Up @@ -784,7 +784,6 @@ _SPIRV_OP(ImageFetch)
_SPIRV_OP(ImageGather)
_SPIRV_OP(ImageDrefGather)
_SPIRV_OP(QuantizeToF16)
_SPIRV_OP(Transpose)
_SPIRV_OP(ArrayLength)
_SPIRV_OP(OuterProduct)
_SPIRV_OP(IAddCarry)
Expand Down
49 changes: 49 additions & 0 deletions lib/SPIRV/libSPIRV/SPIRVInstruction.h
Original file line number Diff line number Diff line change
Expand Up @@ -1506,6 +1506,55 @@ class SPIRVMatrixTimesMatrix : public SPIRVInstruction {
SPIRVId RightMatrix;
};

class SPIRVTranspose : public SPIRVInstruction {
public:
static const Op OC = OpTranspose;
static const SPIRVWord FixedWordCount = 3;

// Complete constructor
SPIRVTranspose(SPIRVType *TheType, SPIRVId TheId, SPIRVId TheMatrix,
SPIRVBasicBlock *BB)
: SPIRVInstruction(4, OC, TheType, TheId, BB), Matrix(TheMatrix) {
validate();
assert(BB && "Invalid BB");
}

// Incomplete constructor
SPIRVTranspose() : SPIRVInstruction(OC), Matrix(SPIRVID_INVALID) {}

SPIRVValue *getMatrix() const { return getValue(Matrix); }

std::vector<SPIRVValue *> getOperands() override {
std::vector<SPIRVId> Operands;
Operands.push_back(Matrix);
return getValues(Operands);
}

void setWordCount(SPIRVWord FixedWordCount) override {
SPIRVEntry::setWordCount(FixedWordCount);
}

_SPIRV_DEF_ENCDEC3(Type, Id, Matrix)

void validate() const override {
SPIRVInstruction::validate();
if (getValue(Matrix)->isForward())
return;

SPIRVType *Ty = getType()->getScalarType();
SPIRVType *MTy = getValueType(Matrix)->getScalarType();

(void)Ty;
(void)MTy;

assert(Ty->isTypeFloat() && "Invalid result type for OpTranspose");
assert(Ty == MTy && "Mismatch float type");
}

private:
SPIRVId Matrix;
};

class SPIRVUnary : public SPIRVInstTemplateBase {
protected:
void validate() const override {
Expand Down
9 changes: 9 additions & 0 deletions lib/SPIRV/libSPIRV/SPIRVModule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -388,6 +388,8 @@ class SPIRVModuleImpl : public SPIRVModule {
SPIRVInstruction *addMatrixTimesMatrixInst(SPIRVType *TheType, SPIRVId M1,
SPIRVId M2,
SPIRVBasicBlock *BB) override;
SPIRVInstruction *addTransposeInst(SPIRVType *TheType, SPIRVId TheMatrix,
SPIRVBasicBlock *BB) override;
SPIRVInstruction *addUnaryInst(Op, SPIRVType *, SPIRVValue *,
SPIRVBasicBlock *) override;
SPIRVInstruction *addVariable(SPIRVType *, bool, SPIRVLinkageTypeKind,
Expand Down Expand Up @@ -1109,6 +1111,13 @@ SPIRVModuleImpl::addMatrixTimesMatrixInst(SPIRVType *TheType, SPIRVId M1,
new SPIRVMatrixTimesMatrix(TheType, getId(), M1, M2, BB));
}

SPIRVInstruction *SPIRVModuleImpl::addTransposeInst(SPIRVType *TheType,
SPIRVId TheMatrix,
SPIRVBasicBlock *BB) {
return BB->addInstruction(
new SPIRVTranspose(TheType, getId(), TheMatrix, BB));
}

SPIRVInstruction *
SPIRVModuleImpl::addGroupInst(Op OpCode, SPIRVType *Type, Scope Scope,
const std::vector<SPIRVValue *> &Ops,
Expand Down
3 changes: 3 additions & 0 deletions lib/SPIRV/libSPIRV/SPIRVModule.h
Original file line number Diff line number Diff line change
Expand Up @@ -389,6 +389,9 @@ class SPIRVModule {
virtual SPIRVInstruction *addMatrixTimesMatrixInst(SPIRVType *TheType,
SPIRVId M1, SPIRVId M2,
SPIRVBasicBlock *BB) = 0;
virtual SPIRVInstruction *addTransposeInst(SPIRVType *TheType,
SPIRVId TheMatrix,
SPIRVBasicBlock *BB) = 0;
virtual SPIRVInstruction *addUnaryInst(Op, SPIRVType *, SPIRVValue *,
SPIRVBasicBlock *) = 0;
virtual SPIRVInstruction *addVariable(SPIRVType *, bool, SPIRVLinkageTypeKind,
Expand Down
174 changes: 174 additions & 0 deletions test/matrix_transpose.spt
Original file line number Diff line number Diff line change
@@ -0,0 +1,174 @@
; RUN: llvm-spirv %s -to-binary -o %t.spv
; RUN: spirv-val %t.spv
; RUN: llvm-spirv -r %t.spv -o %t.bc
; RUN: llvm-dis < %t.bc | FileCheck %s --check-prefix=CHECK-LLVM

; CHECK-LLVM: %1 = load [4 x <4 x float>], [4 x <4 x float>]* %mtx4
; CHECK-LLVM: %2 = extractvalue [4 x <4 x float>] %1, 0
; CHECK-LLVM: %3 = extractvalue [4 x <4 x float>] %1, 1
; CHECK-LLVM: %4 = extractvalue [4 x <4 x float>] %1, 2
; CHECK-LLVM: %5 = extractvalue [4 x <4 x float>] %1, 3
; CHECK-LLVM: %6 = shufflevector <4 x float> %2, <4 x float> %3, <2 x i32> <i32 0, i32 4>
; CHECK-LLVM: %7 = shufflevector <4 x float> %4, <4 x float> %5, <2 x i32> <i32 0, i32 4>
; CHECK-LLVM: %8 = shufflevector <2 x float> %6, <2 x float> %7, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
; CHECK-LLVM: %9 = insertvalue [4 x <4 x float>] undef, <4 x float> %8, 0
; CHECK-LLVM: %10 = shufflevector <4 x float> %2, <4 x float> %3, <2 x i32> <i32 1, i32 5>
; CHECK-LLVM: %11 = shufflevector <4 x float> %4, <4 x float> %5, <2 x i32> <i32 1, i32 5>
; CHECK-LLVM: %12 = shufflevector <2 x float> %10, <2 x float> %11, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
; CHECK-LLVM: %13 = insertvalue [4 x <4 x float>] %9, <4 x float> %12, 1
; CHECK-LLVM: %14 = shufflevector <4 x float> %2, <4 x float> %3, <2 x i32> <i32 2, i32 6>
; CHECK-LLVM: %15 = shufflevector <4 x float> %4, <4 x float> %5, <2 x i32> <i32 2, i32 6>
; CHECK-LLVM: %16 = shufflevector <2 x float> %14, <2 x float> %15, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
; CHECK-LLVM: %17 = insertvalue [4 x <4 x float>] %13, <4 x float> %16, 2
; CHECK-LLVM: %18 = shufflevector <4 x float> %2, <4 x float> %3, <2 x i32> <i32 3, i32 7>
; CHECK-LLVM: %19 = shufflevector <4 x float> %4, <4 x float> %5, <2 x i32> <i32 3, i32 7>
; CHECK-LLVM: %20 = shufflevector <2 x float> %18, <2 x float> %19, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
; CHECK-LLVM: %21 = insertvalue [4 x <4 x float>] %17, <4 x float> %20, 3
; CHECK-LLVM: store [4 x <4 x float>] %21, [4 x <4 x float>]* %res4

; CHECK-LLVM: %22 = load [2 x <2 x float>], [2 x <2 x float>]* %mtx2
; CHECK-LLVM: %23 = extractvalue [2 x <2 x float>] %22, 0
; CHECK-LLVM: %24 = extractvalue [2 x <2 x float>] %22, 1
; CHECK-LLVM: %25 = shufflevector <2 x float> %23, <2 x float> %24, <2 x i32> <i32 0, i32 2>
; CHECK-LLVM: %26 = insertvalue [2 x <2 x float>] undef, <2 x float> %25, 0
; CHECK-LLVM: %27 = shufflevector <2 x float> %23, <2 x float> %24, <2 x i32> <i32 1, i32 3>
; CHECK-LLVM: %28 = insertvalue [2 x <2 x float>] %26, <2 x float> %27, 1
; CHECK-LLVM: store [2 x <2 x float>] %28, [2 x <2 x float>]* %res2

; CHECK-LLVM: %29 = load [3 x <3 x float>], [3 x <3 x float>]* %mtx3
; CHECK-LLVM: %30 = extractvalue [3 x <3 x float>] %29, 0
; CHECK-LLVM: %31 = extractvalue [3 x <3 x float>] %29, 1
; CHECK-LLVM: %32 = extractvalue [3 x <3 x float>] %29, 2
; CHECK-LLVM: %33 = extractelement <3 x float> %30, i32 0
; CHECK-LLVM: %34 = insertelement <3 x float> undef, float %33, i64 0
; CHECK-LLVM: %35 = extractelement <3 x float> %31, i32 0
; CHECK-LLVM: %36 = insertelement <3 x float> %34, float %35, i64 1
; CHECK-LLVM: %37 = extractelement <3 x float> %32, i32 0
; CHECK-LLVM: %38 = insertelement <3 x float> %36, float %37, i64 2
; CHECK-LLVM: %39 = insertvalue [3 x <3 x float>] undef, <3 x float> %38, 0
; CHECK-LLVM: %40 = extractelement <3 x float> %30, i32 1
; CHECK-LLVM: %41 = insertelement <3 x float> undef, float %40, i64 0
; CHECK-LLVM: %42 = extractelement <3 x float> %31, i32 1
; CHECK-LLVM: %43 = insertelement <3 x float> %41, float %42, i64 1
; CHECK-LLVM: %44 = extractelement <3 x float> %32, i32 1
; CHECK-LLVM: %45 = insertelement <3 x float> %43, float %44, i64 2
; CHECK-LLVM: %46 = insertvalue [3 x <3 x float>] %39, <3 x float> %45, 1
; CHECK-LLVM: %47 = extractelement <3 x float> %30, i32 2
; CHECK-LLVM: %48 = insertelement <3 x float> undef, float %47, i64 0
; CHECK-LLVM: %49 = extractelement <3 x float> %31, i32 2
; CHECK-LLVM: %50 = insertelement <3 x float> %48, float %49, i64 1
; CHECK-LLVM: %51 = extractelement <3 x float> %32, i32 2
; CHECK-LLVM: %52 = insertelement <3 x float> %50, float %51, i64 2
; CHECK-LLVM: %53 = insertvalue [3 x <3 x float>] %46, <3 x float> %52, 2
; CHECK-LLVM: store [3 x <3 x float>] %53, [3 x <3 x float>]* %res3

; CHECK-LLVM: %54 = load [4 x <3 x float>], [4 x <3 x float>]* %mtx43
; CHECK-LLVM: %55 = extractvalue [4 x <3 x float>] %54, 0
; CHECK-LLVM: %56 = extractvalue [4 x <3 x float>] %54, 1
; CHECK-LLVM: %57 = extractvalue [4 x <3 x float>] %54, 2
; CHECK-LLVM: %58 = extractvalue [4 x <3 x float>] %54, 3
; CHECK-LLVM: %59 = extractelement <3 x float> %55, i32 0
; CHECK-LLVM: %60 = insertelement <4 x float> undef, float %59, i64 0
; CHECK-LLVM: %61 = extractelement <3 x float> %56, i32 0
; CHECK-LLVM: %62 = insertelement <4 x float> %60, float %61, i64 1
; CHECK-LLVM: %63 = extractelement <3 x float> %57, i32 0
; CHECK-LLVM: %64 = insertelement <4 x float> %62, float %63, i64 2
; CHECK-LLVM: %65 = extractelement <3 x float> %58, i32 0
; CHECK-LLVM: %66 = insertelement <4 x float> %64, float %65, i64 3
; CHECK-LLVM: %67 = insertvalue [3 x <4 x float>] undef, <4 x float> %66, 0
; CHECK-LLVM: %68 = extractelement <3 x float> %55, i32 1
; CHECK-LLVM: %69 = insertelement <4 x float> undef, float %68, i64 0
; CHECK-LLVM: %70 = extractelement <3 x float> %56, i32 1
; CHECK-LLVM: %71 = insertelement <4 x float> %69, float %70, i64 1
; CHECK-LLVM: %72 = extractelement <3 x float> %57, i32 1
; CHECK-LLVM: %73 = insertelement <4 x float> %71, float %72, i64 2
; CHECK-LLVM: %74 = extractelement <3 x float> %58, i32 1
; CHECK-LLVM: %75 = insertelement <4 x float> %73, float %74, i64 3
; CHECK-LLVM: %76 = insertvalue [3 x <4 x float>] %67, <4 x float> %75, 1
; CHECK-LLVM: %77 = extractelement <3 x float> %55, i32 2
; CHECK-LLVM: %78 = insertelement <4 x float> undef, float %77, i64 0
; CHECK-LLVM: %79 = extractelement <3 x float> %56, i32 2
; CHECK-LLVM: %80 = insertelement <4 x float> %78, float %79, i64 1
; CHECK-LLVM: %81 = extractelement <3 x float> %57, i32 2
; CHECK-LLVM: %82 = insertelement <4 x float> %80, float %81, i64 2
; CHECK-LLVM: %83 = extractelement <3 x float> %58, i32 2
; CHECK-LLVM: %84 = insertelement <4 x float> %82, float %83, i64 3
; CHECK-LLVM: %85 = insertvalue [3 x <4 x float>] %76, <4 x float> %84, 2
; CHECK-LLVM: store [3 x <4 x float>] %85, [3 x <4 x float>]* %res34
; CHECK-LLVM: ret void

119734787 65536 458752 51 0
2 Capability Addresses
2 Capability Linkage
2 Capability Kernel
2 Capability Float64
2 Capability Matrix
3 MemoryModel 2 2
8 EntryPoint 6 40 "matrix_transpose"
3 Source 3 102000
3 Name 30 "res4"
3 Name 31 "mtx4"
3 Name 32 "res2"
3 Name 33 "mtx2"
3 Name 34 "res3"
3 Name 35 "mtx3"
3 Name 36 "res34"
3 Name 37 "mtx43"

2 TypeVoid 5
3 TypeFloat 6 32
4 TypeVector 7 6 4
4 TypeMatrix 8 7 4
4 TypePointer 9 7 8 ; 9 : Pointer to Matrix4x4

4 TypeVector 10 6 2
4 TypeMatrix 11 10 2
4 TypePointer 12 7 11 ; 12 : Pointer to Matrix2x2

4 TypeVector 13 6 3
4 TypeMatrix 14 13 3
4 TypePointer 15 7 14 ; 15 : Pointer to Matrix3x3

4 TypeMatrix 17 13 4
4 TypePointer 18 7 17 ; 18 : Pointer to Matrix4x3

4 TypeMatrix 20 7 3
4 TypePointer 21 7 20 ; 21 : Pointer to Matrix3x4

11 TypeFunction 29 5 9 9 12 12 15 15 21 18

5 Function 5 40 0 29
3 FunctionParameter 9 30
3 FunctionParameter 9 31
3 FunctionParameter 12 32
3 FunctionParameter 12 33
3 FunctionParameter 15 34
3 FunctionParameter 15 35
3 FunctionParameter 21 36
3 FunctionParameter 18 37

2 Label 50

; 4x4, fastpath
4 Load 8 41 31
4 Transpose 8 42 41
3 Store 30 42

; 2x2, fastpath
4 Load 11 43 33
4 Transpose 11 44 43
3 Store 32 44

; 3x3, slowpath
4 Load 14 45 35
4 Transpose 14 46 45
3 Store 34 46

; 3x4, slowpath
4 Load 17 47 37
4 Transpose 20 48 47
3 Store 36 48

1 Return

1 FunctionEnd

0 comments on commit d09b6b6

Please sign in to comment.