Skip to content

Commit

Permalink
Add pass to lower Bitcast to nonstandard type instructions (KhronosGr…
Browse files Browse the repository at this point in the history
…oup#1117)

* Add pass to lower Bitcast to nonstandard type instructions

This pass lowers bitcast instructions that produce non-standard SPIR-V types.
It covers only known issues. At the moment there is only one pattern that
needs to be covered: the use of a vector with an unsupported number of
elements. This pattern should be handled as follows:

%0 = bitcast <3 x i64> addrspace(1)* @id to <6 x i32> addrspace(1)*
%1 = addrspacecast <6 x i32> addrspace(1)* %0 to <6 x i32> addrspace(4)*
%2 = load <6 x i32>, <6 x i32> addrspace(4)* %1, align 32
%conv = extractelement <6 x i32> %2, i32 1
%conv1 = sitofp i32 %conv to float

Must be replaced by:
%0 = addrspacecast <3 x i64> addrspace(1)* @id to  <3 x i64> addrspace(4)*
%1 = load <3 x i64>, <3 x i64> addrspace(4)* %0, align 32
%2 = extractelement <3 x i64> %1, i32 0
%conv = trunc i64 %2 to i32
%conv1 = sitofp i32 %conv to float

It is assumed that the pass will be further developed as new patterns arise.
  • Loading branch information
KornevNikita authored and AlexeySotkin committed Aug 20, 2021
1 parent 7569447 commit 0770392
Show file tree
Hide file tree
Showing 5 changed files with 279 additions and 0 deletions.
6 changes: 6 additions & 0 deletions include/LLVMSPIRVLib.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,10 @@ void initializeSPIRVRegularizeLLVMLegacyPass(PassRegistry &);
void initializeSPIRVToOCL12LegacyPass(PassRegistry &);
void initializeSPIRVToOCL20LegacyPass(PassRegistry &);
void initializePreprocessMetadataLegacyPass(PassRegistry &);
void initializeSPIRVLowerBitCastToNonStandardTypeLegacyPass(PassRegistry &);

class ModulePass;
class FunctionPass;
} // namespace llvm

#include "llvm/IR/Module.h"
Expand Down Expand Up @@ -215,6 +217,10 @@ ModulePass *createSPIRVWriterPass(std::ostream &Str);
ModulePass *createSPIRVWriterPass(std::ostream &Str,
const SPIRV::TranslatorOpts &Opts);

/// Create a pass for removing bitcast instructions to non-standard SPIR-V
/// types
FunctionPass *createSPIRVLowerBitCastToNonStandardTypeLegacy();

} // namespace llvm

#endif // SPIRV_H
1 change: 1 addition & 0 deletions lib/SPIRV/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ add_llvm_library(LLVMSPIRVLib
OCLTypeToSPIRV.cpp
OCLUtil.cpp
VectorComputeUtil.cpp
SPIRVLowerBitCastToNonStandardType.cpp
SPIRVLowerBool.cpp
SPIRVLowerConstExpr.cpp
SPIRVLowerMemmove.cpp
Expand Down
221 changes: 221 additions & 0 deletions lib/SPIRV/SPIRVLowerBitCastToNonStandardType.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,221 @@
//===============- SPIRVLowerBitCastToNonStandardType.cpp -================//
//
// The LLVM/SPIRV Translator
//
// This file is distributed under the University of Illinois Open Source
// License. See LICENSE.TXT for details.
//
// Copyright (c) 2021 Intel Corporation. All rights reserved.
//
// Permission is hereby granted, free of charge, to any person obtaining a
// copy of this software and associated documentation files (the "Software"),
// to deal with the Software without restriction, including without limitation
// the rights to use, copy, modify, merge, publish, distribute, sublicense,
// and/or sell copies of the Software, and to permit persons to whom the
// Software is furnished to do so, subject to the following conditions:
//
// Redistributions of source code must retain the above copyright notice,
// this list of conditions and the following disclaimers.
// Redistributions in binary form must reproduce the above copyright notice,
// this list of conditions and the following disclaimers in the documentation
// and/or other materials provided with the distribution.
// Neither the names of Intel Corporation, nor the names of its
// contributors may be used to endorse or promote products derived from this
// Software without specific prior written permission.
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// CONTRIBUTORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS WITH
// THE SOFTWARE.
//
//===----------------------------------------------------------------------===//
//
// This file implements lowering of BitCast to nonstandard types. LLVM
// transformations bitcast some vector types to scalar types, which are not
// universally supported across all targets. We need to ensure that
// "optimized" LLVM IR doesn't have primitive types other than those supported
// by the SPIR target (i.e. "scalar 8/16/32/64-bit integer and 16/32/64-bit
// floating point types, 2/3/4/8/16-element vector of scalar types").
//
//===----------------------------------------------------------------------===//
#define DEBUG_TYPE "spv-lower-bitcast-to-nonstandard-type"

#include "SPIRVInternal.h"

#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/PassManager.h"

#include <utility>

using namespace llvm;

namespace SPIRV {

/// Returns \p Ty viewed as a vector type, looking through one level of
/// pointer: for a pointer type the pointee type is inspected instead.
/// Yields nullptr when the (pointee) type is not a vector.
static VectorType *getVectorType(Type *Ty) {
  assert(Ty != nullptr && "Expected non-null type");
  Type *Candidate = Ty;
  if (auto *PtrTy = dyn_cast<PointerType>(Candidate))
    Candidate = PtrTy->getElementType();
  return dyn_cast<VectorType>(Candidate);
}

/// Recursively re-creates the users of a bitcast to a non-standard vector type
/// so that they operate on the bitcast's source vector type instead.
///
/// SPIR-V does not support vectors with a non-standard number of elements, so
/// the chain "bitcast -> [addrspacecast] -> load -> extractelement" is rebuilt
/// on top of the original vector type; an extracted element is then recovered
/// with an optional logical shift right followed by a trunc.
///
/// \param OldInst  Instruction producing the non-standard typed value (the
///                 bitcast on the first call, one of its lowered users on
///                 recursive calls).
/// \param NewInst  Value that replaces \p OldInst.
/// \param OldVecTy The non-standard vector type produced by the bitcast.
/// \param InstsToErase Collects every replaced instruction; users are appended
///                 before the instruction they use. The caller erases them.
/// \param Builder  Builder used to create the replacement instructions.
/// \param RecursionDepth Current recursion depth; exceeding the internal limit
///                 is reported as a fatal error.
/// \returns Always true: \p OldInst is always scheduled for erasure.
///
/// Any user other than addrspacecast, load or extractelement, as well as an
/// extractelement that cannot be lowered (non-constant index, non-integer
/// element types, incompatible element sizes), is reported as a fatal error
/// rather than leaving dangling uses of an erased instruction behind.
bool lowerBitCastToNonStdVec(Instruction *OldInst, Value *NewInst,
                             const VectorType *OldVecTy,
                             std::vector<Instruction *> &InstsToErase,
                             IRBuilder<> &Builder,
                             unsigned RecursionDepth = 0) {
  static constexpr unsigned MaxRecursionDepth = 16;
  if (RecursionDepth++ > MaxRecursionDepth)
    report_fatal_error(
        "The depth of recursion exceeds the maximum possible depth", false);

  bool Changed = false;
  VectorType *NewVecTy = getVectorType(NewInst->getType());
  if (NewVecTy) {
    // NOTE(review): the recursive calls below move the builder's insertion
    // point and it is not restored before the next user is processed. Confirm
    // that replacements for sibling users are always inserted at a point
    // dominated by their operands.
    Builder.SetInsertPoint(OldInst);
    for (auto *U : OldInst->users()) {
      // Handle addrspacecast instruction after bitcast if present
      if (auto *ASCastInst = dyn_cast<AddrSpaceCastInst>(U)) {
        unsigned DestAS = ASCastInst->getDestAddressSpace();
        auto *NewVecPtrTy = NewVecTy->getPointerTo(DestAS);
        // AddrSpaceCast is created explicitly instead of using method
        // IRBuilder<>.CreateAddrSpaceCast because IRBuilder doesn't create
        // separate instruction for constant values. Whereas SPIR-V translator
        // doesn't like several nested instructions in one.
        Value *LocalValue = new AddrSpaceCastInst(NewInst, NewVecPtrTy);
        Builder.Insert(LocalValue);
        Changed |=
            lowerBitCastToNonStdVec(ASCastInst, LocalValue, OldVecTy,
                                    InstsToErase, Builder, RecursionDepth);
      }
      // Handle load instruction which is following the bitcast in the pattern
      else if (auto *LI = dyn_cast<LoadInst>(U)) {
        Value *LocalValue = Builder.CreateLoad(NewVecTy, NewInst);
        Changed |= lowerBitCastToNonStdVec(
            LI, LocalValue, OldVecTy, InstsToErase, Builder, RecursionDepth);
      }
      // Handle extractelement instruction which is following the load
      else if (auto *EEI = dyn_cast<ExtractElementInst>(U)) {
        // Validate the pattern up front instead of relying on cast<>, which
        // only asserts in debug builds.
        auto *IdxConst = dyn_cast<ConstantInt>(EEI->getIndexOperand());
        auto *OldElemTy = dyn_cast<IntegerType>(OldVecTy->getElementType());
        auto *NewElemTy = dyn_cast<IntegerType>(NewVecTy->getElementType());
        if (!IdxConst || !OldElemTy || !NewElemTy)
          report_fatal_error(
              "Unsupported extractelement from a non-standard vector type",
              false);

        uint64_t NumElemsInOldVec = OldVecTy->getElementCount().getValue();
        uint64_t NumElemsInNewVec = NewVecTy->getElementCount().getValue();
        unsigned OldVecElemBitWidth = OldElemTy->getBitWidth();
        unsigned NewVecElemBitWidth = NewElemTy->getBitWidth();
        uint64_t ElemCountRatio = NumElemsInOldVec / NumElemsInNewVec;
        unsigned BitWidthRatio = NewVecElemBitWidth / OldVecElemBitWidth;
        // Both ratios are used as divisors below; a zero ratio means the
        // non-standard vector has fewer or wider elements than the source
        // vector, which this lowering cannot express.
        if (ElemCountRatio == 0 || BitWidthRatio == 0)
          report_fatal_error(
              "Unsupported element layout of a non-standard vector type",
              false);

        uint64_t OldElemIdx = IdxConst->getZExtValue();
        uint64_t NewElemIdx = OldElemIdx / ElemCountRatio;
        Value *LocalValue = Builder.CreateExtractElement(NewInst, NewElemIdx);
        // The trunc instruction truncates the high order bits in value, so it
        // may be necessary to shift right high order bits, if required bits are
        // not at the end of extracted value.
        // NOTE(review): this treats the last narrow element of each wide
        // element as the low-order bits. On a little-endian target one would
        // expect narrow element 0 in the low-order bits; confirm the intended
        // layout. The existing results for a width ratio of 2 are preserved.
        uint64_t SubElemIdx = OldElemIdx % BitWidthRatio;
        if (SubElemIdx != BitWidthRatio - 1) {
          uint64_t Shift =
              OldVecElemBitWidth * (BitWidthRatio - 1 - SubElemIdx);
          LocalValue = Builder.CreateLShr(LocalValue, Shift);
        }
        LocalValue =
            Builder.CreateTrunc(LocalValue, OldVecTy->getElementType());
        Changed |= lowerBitCastToNonStdVec(
            EEI, LocalValue, OldVecTy, InstsToErase, Builder, RecursionDepth);
      }
      // Any other user would keep referring to OldInst after it is erased.
      else {
        report_fatal_error(
            "Unsupported user of a value of non-standard vector type", false);
      }
    }
  }
  InstsToErase.push_back(OldInst);
  if (!Changed) {
    // Nothing was lowered below this point: forward the remaining uses to the
    // replacement value, which is only legal when the types agree.
    if (OldInst->getType() == NewInst->getType())
      OldInst->replaceAllUsesWith(NewInst);
    else if (!OldInst->use_empty())
      report_fatal_error(
          "Unsupported use of a bitcast to a non-standard vector type", false);
  }
  return true;
}

/// New pass manager style implementation: finds bitcasts whose destination is
/// a vector (or pointer to a vector) with an unsupported number of elements
/// and rewrites their users via lowerBitCastToNonStdVec.
class SPIRVLowerBitCastToNonStandardTypePass
    : public llvm::PassInfoMixin<SPIRVLowerBitCastToNonStandardTypePass> {
public:
  SPIRVLowerBitCastToNonStandardTypePass() = default;

  /// Processes a single function. A bitcast whose *source* vector already has
  /// an unsupported element count is reported as a fatal error. Returns
  /// PreservedAnalyses::none() if anything was lowered, all() otherwise.
  PreservedAnalyses
  runLowerBitCastToNonStandardType(Function &F, FunctionAnalysisManager &FAM) {
    // Only the known uses of non-standard types are handled here. Such types
    // are introduced by optimizations, so they are assumed never to be passed
    // to a function as a parameter.
    std::vector<Instruction *> Worklist;
    for (BasicBlock &Block : F) {
      for (Instruction &Inst : Block) {
        auto *Cast = dyn_cast<BitCastInst>(&Inst);
        if (Cast == nullptr)
          continue;

        // The source of the cast must already be a supported vector.
        if (VectorType *FromVecTy = getVectorType(Cast->getSrcTy())) {
          const uint64_t FromCount = FromVecTy->getElementCount().getValue();
          if (!isValidVectorSize(FromCount))
            report_fatal_error("Unsupported vector type with the size of: " +
                                   std::to_string(FromCount),
                               false);
        }

        // Remember casts that produce a vector of unsupported size.
        if (VectorType *ToVecTy = getVectorType(Cast->getDestTy())) {
          const uint64_t ToCount = ToVecTy->getElementCount().getValue();
          if (!isValidVectorSize(ToCount))
            Worklist.push_back(&Inst);
        }
      }
    }

    bool Modified = false;
    std::vector<Instruction *> DeadInsts;
    IRBuilder<> Builder(F.getContext());
    for (Instruction *Cast : Worklist) {
      Value *Source = Cast->getOperand(0);
      VectorType *NonStdVecTy = getVectorType(Cast->getType());
      const bool Lowered =
          lowerBitCastToNonStdVec(Cast, Source, NonStdVecTy, DeadInsts, Builder);
      Modified = Modified || Lowered;
    }

    // Replaced instructions are removed only after all rewriting is done.
    for (Instruction *Dead : DeadInsts)
      Dead->eraseFromParent();

    if (Modified)
      return PreservedAnalyses::none();
    return PreservedAnalyses::all();
  }
};

class SPIRVLowerBitCastToNonStandardTypeLegacy : public FunctionPass {
public:
static char ID;
SPIRVLowerBitCastToNonStandardTypeLegacy() : FunctionPass(ID) {}

bool runOnFunction(Function &F) override {
FunctionAnalysisManager FAM;
auto PA = Impl.runLowerBitCastToNonStandardType(F, FAM);
return !PA.areAllPreserved();
}

bool doFinalization(Module &M) override {
verifyRegularizationPass(M, "SPIRVLowerBitCastToNonStandardType");
return false;
}

StringRef getPassName() const override { return "Lower nonstandard type"; }

private:
SPIRVLowerBitCastToNonStandardTypePass Impl;
};

char SPIRVLowerBitCastToNonStandardTypeLegacy::ID = 0;

} // namespace SPIRV

// Registers the legacy pass under the command-line name
// "spv-lower-bitcast-to-nonstandard-type" with the description below. The two
// trailing flags are both passed as false.
INITIALIZE_PASS(SPIRVLowerBitCastToNonStandardTypeLegacy,
"spv-lower-bitcast-to-nonstandard-type",
"Remove bitcast to nonstandard types", false, false)

/// Factory for the legacy pass; returns a newly allocated pass instance.
llvm::FunctionPass *llvm::createSPIRVLowerBitCastToNonStandardTypeLegacy() {
  using LegacyPass = SPIRVLowerBitCastToNonStandardTypeLegacy;
  return new LegacyPass();
}
1 change: 1 addition & 0 deletions lib/SPIRV/SPIRVWriter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4266,6 +4266,7 @@ void addPassesForSPIRV(legacy::PassManager &PassMgr,
PassMgr.add(createSPIRVLowerBoolLegacy());
PassMgr.add(createSPIRVLowerMemmoveLegacy());
PassMgr.add(createSPIRVLowerSaddWithOverflowLegacy());
PassMgr.add(createSPIRVLowerBitCastToNonStandardTypeLegacy());
}

bool isValidLLVMModule(Module *M, SPIRVErrorLog &ErrorLog) {
Expand Down
50 changes: 50 additions & 0 deletions test/lower-non-standard-types.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
; RUN: llvm-as %s -o %t.bc
; RUN: llvm-spirv -s %t.bc -o - | llvm-dis -o - | FileCheck %s --implicit-check-not="<6 x i32>"

; CHECK: [[ASCastInst:%.*]] = addrspacecast <3 x i64> addrspace(1)* @Id to <3 x i64> addrspace(4)*
; CHECK: [[LoadInst1:%.*]] = load <3 x i64>, <3 x i64> addrspace(4)* [[ASCastInst]], align 32
; CHECK: [[ExtrElInst1:%.*]] = extractelement <3 x i64> [[LoadInst1]], i64 0
; CHECK: [[TruncInst1:%.*]] = trunc i64 [[ExtrElInst1]] to i32
; CHECK: [[LoadInst2:%.*]] = load <3 x i64>, <3 x i64> addrspace(4)* [[ASCastInst]], align 32
; CHECK: [[ExtrElInst2:%.*]] = extractelement <3 x i64> [[LoadInst2]], i64 2
; CHECK: [[LShrInst:%.*]] = lshr i64 [[ExtrElInst2]], 32
; CHECK: [[TruncInst2:%.*]] = trunc i64 [[LShrInst]] to i32
; CHECK: %conv1 = sitofp i32 [[TruncInst1]] to float
; CHECK: %conv2 = sitofp i32 [[TruncInst2]] to float

; ModuleID = 'lower-non-standard-types'
source_filename = "lower-non-standard-types.cpp"
target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-n8:16:32:64"
target triple = "spir64-unknown-unknown-sycldevice"

@Id = external dso_local local_unnamed_addr addrspace(1) constant <3 x i64>, align 32

; Test input: @Id (a <3 x i64> global in addrspace(1)) is loaded twice as
; <6 x i32> through a constant-expression bitcast + addrspacecast to
; addrspace(4). Elements 1 and 4 of the loaded <6 x i32> values are extracted
; and converted to float. The RUN line requires that no "<6 x i32>" remains in
; the output.
; Function Attrs: convergent norecurse
define dso_local spir_func void @vmult2() local_unnamed_addr #0 !sycl_explicit_simd !4 !intel_reqd_sub_group_size !6 {
entry:
%0 = load <6 x i32>, <6 x i32> addrspace(4)* addrspacecast (<6 x i32> addrspace(1)* bitcast (<3 x i64> addrspace(1)* @Id to <6 x i32> addrspace(1)*) to <6 x i32> addrspace(4)*), align 32
%1 = load <6 x i32>, <6 x i32> addrspace(4)* addrspacecast (<6 x i32> addrspace(1)* bitcast (<3 x i64> addrspace(1)* @Id to <6 x i32> addrspace(1)*) to <6 x i32> addrspace(4)*), align 32
; Element 1 of the first load and element 4 of the second load are the values
; the CHECK lines expect to be rebuilt from <3 x i64> elements 0 and 2.
%2 = extractelement <6 x i32> %0, i32 1
%3 = extractelement <6 x i32> %1, i32 4
%conv1 = sitofp i32 %2 to float
%conv2 = sitofp i32 %3 to float
ret void
}

attributes #0 = { convergent norecurse "frame-pointer"="all" "min-legal-vector-width"="256" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "sycl-module-id"="lower-external-funcs-with-z.cpp" }

!llvm.module.flags = !{!0, !1}
!opencl.spir.version = !{!2}
!spirv.Source = !{!3}
!opencl.used.extensions = !{!4}
!opencl.used.optional.core.features = !{!4}
!opencl.compiler.options = !{!4}
!llvm.ident = !{!5}

!0 = !{i32 1, !"wchar_size", i32 4}
!1 = !{i32 7, !"frame-pointer", i32 2}
!2 = !{i32 1, i32 2}
!3 = !{i32 0, i32 100000}
!4 = !{}
!5 = !{!"Compiler"}
!6 = !{i32 1}

0 comments on commit 0770392

Please sign in to comment.