Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ESIMD] Add set_kernel_properties API and use_double_grf property. #6182

Merged
merged 15 commits into from
Jun 17, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 39 additions & 0 deletions llvm/include/llvm/SYCLLowerIR/ESIMD/ESIMDUtils.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
//===----------- ESIMDUtils.h - ESIMD t-forms-related utility functions --===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
// Utility functions for processing ESIMD code.
//===----------------------------------------------------------------------===//

#ifndef LLVM_SYCLLOWERIR_ESIMDUTILS_H
#define LLVM_SYCLLOWERIR_ESIMDUTILS_H

#include "llvm/IR/Function.h"

#include <functional>

namespace llvm {
namespace esimd {

// Attribute name string "esimd-double-grf".
// NOTE(review): presumably attached to functions that request the
// use_double_grf kernel property - confirm against the pass that sets it.
constexpr char ATTR_DOUBLE_GRF[] = "esimd-double-grf";

// Type-erased callback invoked for every function met during call graph
// traversal.
using CallGraphNodeAction = std::function<void(Function *)>;

// Traverses call graph starting from given function up the call chain applying
// given action to each function met on the way. If \c ErrorOnNonCallUse
// parameter is true, then no functions' uses are allowed except calls.
// Otherwise, any function where use of the current one happened is added to the
// call graph as if the use was a call.
void traverseCallgraphUp(llvm::Function *F, CallGraphNodeAction NodeF,
                         bool ErrorOnNonCallUse);

// Convenience overload accepting any callable (lambda, functor) as the action.
// It wraps \c ActionF into a \c CallGraphNodeAction and forwards to the
// non-template overload above; \c ErrorOnNonCallUse defaults to true here.
template <class CallGraphNodeActionF>
void traverseCallgraphUp(Function *F, CallGraphNodeActionF ActionF,
                         bool ErrorOnNonCallUse = true) {
  traverseCallgraphUp(F, CallGraphNodeAction{ActionF}, ErrorOnNonCallUse);
}

// Tells whether given function is an ESIMD kernel.
bool isESIMDKernel(const Function &F);

} // namespace esimd
} // namespace llvm

#endif // LLVM_SYCLLOWERIR_ESIMDUTILS_H
7 changes: 7 additions & 0 deletions llvm/include/llvm/SYCLLowerIR/ESIMD/LowerESIMD.h
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,13 @@ class ESIMDLowerVecArgPass : public PassInfoMixin<ESIMDLowerVecArgPass> {
ModulePass *createESIMDLowerVecArgPass();
void initializeESIMDLowerVecArgLegacyPassPass(PassRegistry &);

/// New-pass-manager module pass which lowers calls to
/// __esimd_set_kernel_properties.
class SYCLLowerESIMDKernelPropsPass
    : public PassInfoMixin<SYCLLowerESIMDKernelPropsPass> {
public:
  /// Pass entry point: processes module \p M and returns the set of analyses
  /// preserved by the run.
  PreservedAnalyses run(Module &M, ModuleAnalysisManager &MAM);
};

} // namespace llvm

#endif // LLVM_SYCLLOWERIR_LOWERESIMD_H
1 change: 1 addition & 0 deletions llvm/lib/Passes/PassRegistry.def
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@ MODULE_PASS("memprof-module", ModuleMemProfilerPass())
MODULE_PASS("poison-checking", PoisonCheckingPass())
MODULE_PASS("pseudo-probe-update", PseudoProbeUpdatePass())
MODULE_PASS("LowerESIMD", SYCLLowerESIMDPass())
MODULE_PASS("lower-esimd-kernel-props", SYCLLowerESIMDKernelPropsPass())
MODULE_PASS("ESIMDLowerVecArg", ESIMDLowerVecArgPass())
MODULE_PASS("esimd-verifier", ESIMDVerifierPass())
MODULE_PASS("lower-invoke-simd", SYCLLowerInvokeSimdPass())
Expand Down
2 changes: 2 additions & 0 deletions llvm/lib/SYCLLowerIR/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,10 @@ set_property(GLOBAL PROPERTY LLVMGenXIntrinsics_BINARY_PROP ${LLVMGenXIntrinsics

add_llvm_component_library(LLVMSYCLLowerIR
ESIMD/LowerESIMD.cpp
ESIMD/LowerESIMDKernelProps.cpp
ESIMD/LowerESIMDVLoadVStore.cpp
ESIMD/LowerESIMDVecArg.cpp
ESIMD/ESIMDUtils.cpp
ESIMD/ESIMDVerifier.cpp
LowerInvokeSimd.cpp
LowerWGScope.cpp
Expand Down
68 changes: 68 additions & 0 deletions llvm/lib/SYCLLowerIR/ESIMD/ESIMDUtils.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
#include "llvm/SYCLLowerIR/ESIMD/ESIMDUtils.h"

#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/IR/Instructions.h"
#include "llvm/Support/Casting.h"

namespace llvm {
namespace esimd {

// Walks the call graph upwards (from callee to callers) starting at \p F and
// applies \p ActionF exactly once to every function reached, \p F included.
//
// \param F                 Function to start the traversal from.
// \param ActionF           Callback invoked once per reached function.
// \param ErrorOnNonCallUse If true, any use of a traversed function which is
//                          not "callee of a call instruction" is a fatal
//                          error. If false, the function containing such a use
//                          is treated as a caller; users which are not
//                          instructions are ignored.
void traverseCallgraphUp(llvm::Function *F, CallGraphNodeAction ActionF,
                         bool ErrorOnNonCallUse) {
  // The diagnostic is built at the report site: a Twine is meant to live only
  // as a temporary within a single expression, never in a local variable.
  auto ReportBadUse = []() {
    llvm::report_fatal_error(
        llvm::Twine(__FILE__ " ") +
        "Function use other than call detected while traversing call\n"
        "graph up to a kernel");
  };

  SmallPtrSet<Function *, 32> FunctionsVisited;
  SmallVector<Function *, 32> Worklist;

  // A function is marked visited when it is queued, not when it is popped.
  // Otherwise a caller with several call sites (or one reachable via several
  // paths) would be queued multiple times and the action would be applied to
  // it repeatedly.
  auto Enqueue = [&](Function *Fn) {
    if (FunctionsVisited.insert(Fn).second)
      Worklist.push_back(Fn);
  };
  Enqueue(F);

  while (!Worklist.empty()) {
    Function *CurF = Worklist.pop_back_val();
    // Apply the action function.
    ActionF(CurF);

    // Queue all functions using the current one.
    for (auto *FUser : CurF->users()) {
      if (auto *CI = dyn_cast<CallInst>(FUser)) {
        // CurF may appear in a call as an argument rather than as the callee.
        if ((CI->getCalledFunction() != CurF) && ErrorOnNonCallUse)
          ReportBadUse(); // does not return
        Enqueue(CI->getFunction());
        continue;
      }
      // A use other than a call is met...
      if (ErrorOnNonCallUse)
        ReportBadUse(); // ... non-call is an error - report, does not return
      // ... non-call is OK - add using function to the worklist
      if (auto *I = dyn_cast<Instruction>(FUser))
        Enqueue(I->getFunction());
    }
  }
}

// Returns true only for functions which both use the SPIR_KERNEL calling
// convention and carry the "sycl_explicit_simd" metadata.
bool isESIMDKernel(const Function &F) {
  if (F.getCallingConv() != CallingConv::SPIR_KERNEL)
    return false;
  const bool HasESIMDMarker = F.getMetadata("sycl_explicit_simd") != nullptr;
  return HasESIMDMarker;
}

} // namespace esimd
} // namespace llvm
113 changes: 61 additions & 52 deletions llvm/lib/SYCLLowerIR/ESIMD/LowerESIMD.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
//===----------------------------------------------------------------------===//

#include "llvm/SYCLLowerIR/ESIMD/LowerESIMD.h"
#include "llvm/SYCLLowerIR/ESIMD/ESIMDUtils.h"

#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/DenseSet.h"
Expand Down Expand Up @@ -71,7 +72,8 @@ class SYCLLowerESIMDLegacyPass : public ModulePass {

char SYCLLowerESIMDLegacyPass::ID = 0;
INITIALIZE_PASS(SYCLLowerESIMDLegacyPass, "LowerESIMD",
"Lower constructs specific to Close To Metal", false, false)
"Lower constructs specific to the 'explicit SIMD' extension",
false, false)

// Public interface to the SYCLLowerESIMDPass.
ModulePass *llvm::createSYCLLowerESIMDPass() {
Expand Down Expand Up @@ -899,59 +901,63 @@ static inline llvm::Metadata *getMD(llvm::Value *V) {
return llvm::ValueAsMetadata::get(V);
}

/// Updates genx.kernels metadata attribute \p MD for the given function \p F.
/// The value of the attribute is updated only if the new value \p NewVal is
/// bigger than what is already stored in the attribute.
// TODO: 1) In general this function is supposed to handle intrinsics
// translated into kernel's metadata. So, the primary/intended usage model is
// when such intrinsics are called from kernels.
// 2) For now such intrinsics are also handled in functions directly called
// from kernels and are translated into those caller kernels' metadata, even
// though such behaviour is not fully specified/documented.
// 3) This code (or the code in FE) must verify that slm_init or other such
// intrinsic is not called from another module because kernels in that other
// module would not get updated meta data attributes.
static void updateGenXMDNodes(llvm::Function *F, genx::KernelMDOp MD,
uint64_t NewVal) {
llvm::NamedMDNode *GenXKernelMD =
F->getParent()->getNamedMetadata(GENX_KERNEL_METADATA);
assert(GenXKernelMD && "invalid genx.kernels metadata");

SmallPtrSet<Function *, 32> FunctionsVisited;
SmallVector<Function *, 32> Worklist{F};
while (!Worklist.empty()) {
Function *CurF = Worklist.pop_back_val();
FunctionsVisited.insert(CurF);

// Update the meta data attribute for the current function.
// A functor which updates ESIMD kernel's uint64_t metadata in case it is less
// than the given one. Used in callgraph traversal to update nbarriers or SLM
// size metadata. Update is performed by the '()' operator and happens only
// when given function matches one of the kernels - thus, only reachable kernels
// are updated.
struct UpdateUint64MetaDataToMaxValue {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This refactors the original updateGenXMDNodes:

  1. Call graph traversal is factored out into traverseCallgraphUp above. This functor represents call graph action.
  2. It is slightly optimized to pre-select candidate nodes for fewer actions in the node action function.

Module &M;
// The uint64_t metadata key to update.
genx::KernelMDOp Key;
// The new metadata value. Must be greater than the old for update to happen.
uint64_t NewVal;
// Nodes pre-selected from GENX_KERNEL_METADATA - the only ones which can
// potentially be updated.
SmallVector<MDNode *, 4> CandidatesToUpdate;

UpdateUint64MetaDataToMaxValue(Module &M, genx::KernelMDOp Key,
uint64_t NewVal)
: M(M), Key(Key), NewVal(NewVal) {
// Pre-select nodes for update to do less work in the '()' operator.
llvm::NamedMDNode *GenXKernelMD = M.getNamedMetadata(GENX_KERNEL_METADATA);
assert(GenXKernelMD && "invalid genx.kernels metadata");
for (auto Node : GenXKernelMD->operands()) {
if (Node->getNumOperands() <= MD ||
getVal(Node->getOperand(genx::KernelMDOp::FunctionRef)) != CurF)
if (Node->getNumOperands() <= (unsigned)Key) {
continue;

llvm::Value *Old = getVal(Node->getOperand(MD));
}
llvm::Value *Old = getVal(Node->getOperand(Key));
uint64_t OldVal = cast<llvm::ConstantInt>(Old)->getZExtValue();

if (OldVal < NewVal) {
llvm::Value *New = llvm::ConstantInt::get(Old->getType(), NewVal);
Node->replaceOperandWith(MD, getMD(New));
CandidatesToUpdate.push_back(Node);
}
}
}

void operator()(Function *F) {
// Update the meta data attribute for the current function.
for (auto Node : CandidatesToUpdate) {
assert(Node->getNumOperands() > (unsigned)Key);

// Update all callers as well.
for (auto It = CurF->use_begin(); It != CurF->use_end(); It++) {
auto FCall = It->getUser();
if (!isa<CallInst>(FCall))
llvm::report_fatal_error(
llvm::Twine(__FILE__ " ") +
"Found an intrinsic violating assumption on usage from a kernel or "
"a func directly called from a kernel");

auto FCaller = cast<CallInst>(FCall)->getFunction();
if (!FunctionsVisited.count(FCaller))
Worklist.push_back(FCaller);
if (getVal(Node->getOperand(genx::KernelMDOp::FunctionRef)) != F) {
continue;
}
llvm::Value *Old = getVal(Node->getOperand(Key));
#ifndef NDEBUG
uint64_t OldVal = cast<llvm::ConstantInt>(Old)->getZExtValue();
assert(OldVal < NewVal);
#endif // NDEBUG
llvm::Value *New = llvm::ConstantInt::get(Old->getType(), NewVal);
Node->replaceOperandWith(Key, getMD(New));
}
}
}
};

// TODO Specify and document the behavior of slm_init and nbarrier_init when:
// 1) they are called from functions other than kernels
// 2) there are multiple such calls reachable from a kernel
// 3) the call is in an external function linked by the Back-End

// This function sets/updates VCSLMSize attribute to the kernels
// calling this intrinsic initializing SLM memory.
Expand All @@ -964,7 +970,9 @@ static void translateSLMInit(CallInst &CI) {

uint64_t NewVal = cast<llvm::ConstantInt>(ArgV)->getZExtValue();
assert(NewVal != 0 && "zero slm bytes being requested");
updateGenXMDNodes(F, genx::KernelMDOp::SLMSize, NewVal);
UpdateUint64MetaDataToMaxValue SetMaxSLMSize{
*F->getParent(), genx::KernelMDOp::SLMSize, NewVal};
esimd::traverseCallgraphUp(F, SetMaxSLMSize);
}

// This function sets/updates VCNamedBarrierCount attribute to the kernels
Expand All @@ -979,7 +987,9 @@ static void translateNbarrierInit(CallInst &CI) {

auto NewVal = cast<llvm::ConstantInt>(ArgV)->getZExtValue();
assert(NewVal != 0 && "zero named barrier count being requested");
updateGenXMDNodes(F, genx::KernelMDOp::NBarrierCnt, NewVal);
UpdateUint64MetaDataToMaxValue SetMaxNBarrierCnt{
*F->getParent(), genx::KernelMDOp::NBarrierCnt, NewVal};
esimd::traverseCallgraphUp(F, SetMaxNBarrierCnt);
}

static void translatePackMask(CallInst &CI) {
Expand Down Expand Up @@ -1514,8 +1524,7 @@ void generateKernelMetadata(Module &M) {

for (auto &F : M.functions()) {
// Skip non-SIMD kernels.
if (F.getCallingConv() != CallingConv::SPIR_KERNEL ||
F.getMetadata("sycl_explicit_simd") == nullptr)
if (!esimd::isESIMDKernel(F))
continue;

// Metadata node containing N i32s, where N is the number of kernel
Expand Down Expand Up @@ -1708,15 +1717,14 @@ size_t SYCLLowerESIMDPass::runOnFunction(Function &F,

// process ESIMD builtins that go through special handling instead of
// the translation procedure
// TODO FIXME slm_init should be made top-level __esimd_slm_init

if (Name.startswith("__esimd_slm_init") &&
isa<ConstantInt>(CI->getArgOperand(0))) {
// tag the kernel with meta-data SLMSize, and remove this builtin
translateSLMInit(*CI);
ToErase.push_back(CI);
continue;
}

if (Name.startswith("__esimd_nbarrier_init")) {
translateNbarrierInit(*CI);
ToErase.push_back(CI);
Expand Down Expand Up @@ -1748,12 +1756,13 @@ size_t SYCLLowerESIMDPass::runOnFunction(Function &F,
continue;
}
}

if (Name.startswith("__esimd_get_surface_index")) {
translateGetSurfaceIndex(*CI);
ToErase.push_back(CI);
continue;
}
assert(!Name.startswith("__esimd_set_kernel_properties") &&
"__esimd_set_kernel_properties must have been lowered");

if (Name.empty() || !Name.startswith(ESIMD_INTRIN_PREF1))
continue;
Expand Down
Loading