Skip to content

Commit 0c85362

Browse files
wdx727lifengxiang1025zcfh
authored andcommitted
Automerge: Adding Matching and Inference Functionality to Propeller-PR4: Implement matching and inference and create clusters (#165868)
Adding Matching and Inference Functionality to Propeller. For detailed information, please refer to the following RFC: https://discourse.llvm.org/t/rfc-adding-matching-and-inference-functionality-to-propeller/86238. This is the fourth PR, which is used to implement matching and inference and create the clusters. The associated PRs are: PR1: llvm/llvm-project#160706 PR2: llvm/llvm-project#162963 PR3: llvm/llvm-project#164223 co-authors: lifengxiang1025 [lifengxiang@kuaishou.com](mailto:lifengxiang@kuaishou.com); zcfh [wuminghui03@kuaishou.com](mailto:wuminghui03@kuaishou.com) Co-authored-by: lifengxiang1025 <lifengxiang@kuaishou.com> Co-authored-by: zcfh <wuminghui03@kuaishou.com>
2 parents 95ad781 + 1a88d04 commit 0c85362

13 files changed

+496
-6
lines changed
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
//===- llvm/CodeGen/BasicBlockMatchingAndInference.h ------------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// Infer weights for all basic blocks using matching and inference.
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
#ifndef LLVM_CODEGEN_BASIC_BLOCK_AND_INFERENCE_H
14+
#define LLVM_CODEGEN_BASIC_BLOCK_AND_INFERENCE_H
15+
16+
#include "llvm/CodeGen/BasicBlockSectionsProfileReader.h"
17+
#include "llvm/CodeGen/MachineFunctionPass.h"
18+
#include "llvm/Transforms/Utils/SampleProfileInference.h"
19+
20+
namespace llvm {
21+
22+
class BasicBlockMatchingAndInference : public MachineFunctionPass {
23+
private:
24+
using Edge = std::pair<const MachineBasicBlock *, const MachineBasicBlock *>;
25+
using BlockWeightMap = DenseMap<const MachineBasicBlock *, uint64_t>;
26+
using EdgeWeightMap = DenseMap<Edge, uint64_t>;
27+
using BlockEdgeMap = DenseMap<const MachineBasicBlock *,
28+
SmallVector<const MachineBasicBlock *, 8>>;
29+
30+
struct WeightInfo {
31+
// Weight of basic blocks.
32+
BlockWeightMap BlockWeights;
33+
// Weight of edges.
34+
EdgeWeightMap EdgeWeights;
35+
};
36+
37+
public:
38+
static char ID;
39+
BasicBlockMatchingAndInference();
40+
41+
StringRef getPassName() const override {
42+
return "Basic Block Matching and Inference";
43+
}
44+
45+
void getAnalysisUsage(AnalysisUsage &AU) const override;
46+
47+
bool runOnMachineFunction(MachineFunction &F) override;
48+
49+
std::optional<WeightInfo> getWeightInfo(StringRef FuncName) const;
50+
51+
private:
52+
StringMap<WeightInfo> ProgramWeightInfo;
53+
54+
WeightInfo initWeightInfoByMatching(MachineFunction &MF);
55+
56+
void generateWeightInfoByInference(MachineFunction &MF,
57+
WeightInfo &MatchWeight);
58+
};
59+
60+
} // end namespace llvm
61+
62+
#endif // LLVM_CODEGEN_BASIC_BLOCK_AND_INFERENCE_H

llvm/include/llvm/CodeGen/BasicBlockSectionsProfileReader.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,10 @@ class BasicBlockSectionsProfileReader {
8686
uint64_t getEdgeCount(StringRef FuncName, const UniqueBBID &SrcBBID,
8787
const UniqueBBID &SinkBBID) const;
8888

89+
// Return the complete function path and cluster info for the given function.
90+
std::pair<bool, FunctionPathAndClusterInfo>
91+
getFunctionPathAndClusterInfo(StringRef FuncName) const;
92+
8993
private:
9094
StringRef getAliasName(StringRef FuncName) const {
9195
auto R = FuncAliasMap.find(FuncName);
@@ -195,6 +199,9 @@ class BasicBlockSectionsProfileReaderWrapperPass : public ImmutablePass {
195199
uint64_t getEdgeCount(StringRef FuncName, const UniqueBBID &SrcBBID,
196200
const UniqueBBID &DestBBID) const;
197201

202+
std::pair<bool, FunctionPathAndClusterInfo>
203+
getFunctionPathAndClusterInfo(StringRef FuncName) const;
204+
198205
// Initializes the FunctionNameToDIFilename map for the current module and
199206
// then reads the profile for the matching functions.
200207
bool doInitialization(Module &M) override;

llvm/include/llvm/CodeGen/MachineBlockHashInfo.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,8 @@ struct BlendedBlockHash {
8080
return Dist;
8181
}
8282

83+
uint16_t getOpcodeHash() const { return OpcodeHash; }
84+
8385
private:
8486
/// The offset of the basic block from the function start.
8587
uint16_t Offset{0};

llvm/include/llvm/CodeGen/Passes.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,10 @@ LLVM_ABI MachineFunctionPass *createBasicBlockSectionsPass();
6969

7070
LLVM_ABI MachineFunctionPass *createBasicBlockPathCloningPass();
7171

72+
/// createBasicBlockMatchingAndInferencePass - This pass enables matching
73+
/// and inference when using propeller.
74+
LLVM_ABI MachineFunctionPass *createBasicBlockMatchingAndInferencePass();
75+
7276
/// createMachineBlockHashInfoPass - This pass computes basic block hashes.
7377
LLVM_ABI MachineFunctionPass *createMachineBlockHashInfoPass();
7478

llvm/include/llvm/InitializePasses.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ LLVM_ABI void initializeAlwaysInlinerLegacyPassPass(PassRegistry &);
5555
LLVM_ABI void initializeAssignmentTrackingAnalysisPass(PassRegistry &);
5656
LLVM_ABI void initializeAssumptionCacheTrackerPass(PassRegistry &);
5757
LLVM_ABI void initializeAtomicExpandLegacyPass(PassRegistry &);
58+
LLVM_ABI void initializeBasicBlockMatchingAndInferencePass(PassRegistry &);
5859
LLVM_ABI void initializeBasicBlockPathCloningPass(PassRegistry &);
5960
LLVM_ABI void
6061
initializeBasicBlockSectionsProfileReaderWrapperPassPass(PassRegistry &);

llvm/include/llvm/Transforms/Utils/SampleProfileInference.h

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,11 @@ template <typename FT> class SampleProfileInference {
130130
SampleProfileInference(FunctionT &F, BlockEdgeMap &Successors,
131131
BlockWeightMap &SampleBlockWeights)
132132
: F(F), Successors(Successors), SampleBlockWeights(SampleBlockWeights) {}
133+
SampleProfileInference(FunctionT &F, BlockEdgeMap &Successors,
134+
BlockWeightMap &SampleBlockWeights,
135+
EdgeWeightMap &SampleEdgeWeights)
136+
: F(F), Successors(Successors), SampleBlockWeights(SampleBlockWeights),
137+
SampleEdgeWeights(SampleEdgeWeights) {}
133138

134139
/// Apply the profile inference algorithm for a given function
135140
void apply(BlockWeightMap &BlockWeights, EdgeWeightMap &EdgeWeights);
@@ -157,6 +162,9 @@ template <typename FT> class SampleProfileInference {
157162

158163
/// Map basic blocks to their sampled weights.
159164
BlockWeightMap &SampleBlockWeights;
165+
166+
/// Map edges to their sampled weights.
167+
EdgeWeightMap SampleEdgeWeights;
160168
};
161169

162170
template <typename BT>
@@ -266,6 +274,14 @@ FlowFunction SampleProfileInference<BT>::createFlowFunction(
266274
FlowJump Jump;
267275
Jump.Source = BlockIndex[BB];
268276
Jump.Target = BlockIndex[Succ];
277+
auto It = SampleEdgeWeights.find(std::make_pair(BB, Succ));
278+
if (It != SampleEdgeWeights.end()) {
279+
Jump.HasUnknownWeight = false;
280+
Jump.Weight = It->second;
281+
} else {
282+
Jump.HasUnknownWeight = true;
283+
Jump.Weight = 0;
284+
}
269285
Func.Jumps.push_back(Jump);
270286
}
271287
}
Lines changed: 195 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,195 @@
1+
//===- llvm/CodeGen/BasicBlockMatchingAndInference.cpp ----------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// In Propeller's profile, we have already read the hash values of basic blocks,
10+
// as well as the weights of basic blocks and edges in the CFG. In this file,
11+
// we first match the basic blocks in the profile with those in the current
12+
// MachineFunction using the basic block hash, thereby obtaining the weights of
13+
// some basic blocks and edges. Subsequently, we infer the weights of all basic
14+
// blocks using an inference algorithm.
15+
//
16+
// TODO: Integrate part of the code in this file with BOLT's implementation into
17+
// the LLVM infrastructure, enabling both BOLT and Propeller to reuse it.
18+
//
19+
//===----------------------------------------------------------------------===//
20+
21+
#include "llvm/CodeGen/BasicBlockMatchingAndInference.h"
22+
#include "llvm/CodeGen/BasicBlockSectionsProfileReader.h"
23+
#include "llvm/CodeGen/MachineBlockHashInfo.h"
24+
#include "llvm/CodeGen/Passes.h"
25+
#include "llvm/InitializePasses.h"
26+
#include <llvm/Support/CommandLine.h>
27+
28+
using namespace llvm;
29+
30+
static cl::opt<float>
31+
PropellerInferThreshold("propeller-infer-threshold",
32+
cl::desc("Threshold for infer stale profile"),
33+
cl::init(0.6), cl::Optional);
34+
35+
/// The object is used to identify and match basic blocks given their hashes.
36+
class StaleMatcher {
37+
public:
38+
/// Initialize stale matcher.
39+
void init(const std::vector<MachineBasicBlock *> &Blocks,
40+
const std::vector<BlendedBlockHash> &Hashes) {
41+
assert(Blocks.size() == Hashes.size() &&
42+
"incorrect matcher initialization");
43+
for (size_t I = 0; I < Blocks.size(); I++) {
44+
MachineBasicBlock *Block = Blocks[I];
45+
uint16_t OpHash = Hashes[I].getOpcodeHash();
46+
OpHashToBlocks[OpHash].push_back(std::make_pair(Hashes[I], Block));
47+
}
48+
}
49+
50+
/// Find the most similar block for a given hash.
51+
MachineBasicBlock *matchBlock(BlendedBlockHash BlendedHash) const {
52+
auto BlockIt = OpHashToBlocks.find(BlendedHash.getOpcodeHash());
53+
if (BlockIt == OpHashToBlocks.end()) {
54+
return nullptr;
55+
}
56+
MachineBasicBlock *BestBlock = nullptr;
57+
uint64_t BestDist = std::numeric_limits<uint64_t>::max();
58+
for (auto It : BlockIt->second) {
59+
MachineBasicBlock *Block = It.second;
60+
BlendedBlockHash Hash = It.first;
61+
uint64_t Dist = Hash.distance(BlendedHash);
62+
if (BestBlock == nullptr || Dist < BestDist) {
63+
BestDist = Dist;
64+
BestBlock = Block;
65+
}
66+
}
67+
return BestBlock;
68+
}
69+
70+
private:
71+
using HashBlockPairType = std::pair<BlendedBlockHash, MachineBasicBlock *>;
72+
std::unordered_map<uint16_t, std::vector<HashBlockPairType>> OpHashToBlocks;
73+
};
74+
75+
INITIALIZE_PASS_BEGIN(BasicBlockMatchingAndInference,
76+
"machine-block-match-infer",
77+
"Machine Block Matching and Inference Analysis", true,
78+
true)
79+
INITIALIZE_PASS_DEPENDENCY(MachineBlockHashInfo)
80+
INITIALIZE_PASS_DEPENDENCY(BasicBlockSectionsProfileReaderWrapperPass)
81+
INITIALIZE_PASS_END(BasicBlockMatchingAndInference, "machine-block-match-infer",
82+
"Machine Block Matching and Inference Analysis", true, true)
83+
84+
char BasicBlockMatchingAndInference::ID = 0;
85+
86+
BasicBlockMatchingAndInference::BasicBlockMatchingAndInference()
87+
: MachineFunctionPass(ID) {
88+
initializeBasicBlockMatchingAndInferencePass(
89+
*PassRegistry::getPassRegistry());
90+
}
91+
92+
void BasicBlockMatchingAndInference::getAnalysisUsage(AnalysisUsage &AU) const {
93+
AU.addRequired<MachineBlockHashInfo>();
94+
AU.addRequired<BasicBlockSectionsProfileReaderWrapperPass>();
95+
AU.setPreservesAll();
96+
MachineFunctionPass::getAnalysisUsage(AU);
97+
}
98+
99+
std::optional<BasicBlockMatchingAndInference::WeightInfo>
100+
BasicBlockMatchingAndInference::getWeightInfo(StringRef FuncName) const {
101+
auto It = ProgramWeightInfo.find(FuncName);
102+
if (It == ProgramWeightInfo.end()) {
103+
return std::nullopt;
104+
}
105+
return It->second;
106+
}
107+
108+
BasicBlockMatchingAndInference::WeightInfo
109+
BasicBlockMatchingAndInference::initWeightInfoByMatching(MachineFunction &MF) {
110+
std::vector<MachineBasicBlock *> Blocks;
111+
std::vector<BlendedBlockHash> Hashes;
112+
auto BSPR = &getAnalysis<BasicBlockSectionsProfileReaderWrapperPass>();
113+
auto MBHI = &getAnalysis<MachineBlockHashInfo>();
114+
for (auto &Block : MF) {
115+
Blocks.push_back(&Block);
116+
Hashes.push_back(BlendedBlockHash(MBHI->getMBBHash(Block)));
117+
}
118+
StaleMatcher Matcher;
119+
Matcher.init(Blocks, Hashes);
120+
BasicBlockMatchingAndInference::WeightInfo MatchWeight;
121+
auto [IsValid, PathAndClusterInfo] =
122+
BSPR->getFunctionPathAndClusterInfo(MF.getName());
123+
if (!IsValid)
124+
return MatchWeight;
125+
for (auto &BlockCount : PathAndClusterInfo.NodeCounts) {
126+
if (PathAndClusterInfo.BBHashes.count(BlockCount.first.BaseID)) {
127+
auto Hash = PathAndClusterInfo.BBHashes[BlockCount.first.BaseID];
128+
MachineBasicBlock *Block = Matcher.matchBlock(BlendedBlockHash(Hash));
129+
// When a basic block has clone copies, sum their counts.
130+
if (Block != nullptr)
131+
MatchWeight.BlockWeights[Block] += BlockCount.second;
132+
}
133+
}
134+
for (auto &PredItem : PathAndClusterInfo.EdgeCounts) {
135+
auto PredID = PredItem.first.BaseID;
136+
if (!PathAndClusterInfo.BBHashes.count(PredID))
137+
continue;
138+
auto PredHash = PathAndClusterInfo.BBHashes[PredID];
139+
MachineBasicBlock *PredBlock =
140+
Matcher.matchBlock(BlendedBlockHash(PredHash));
141+
if (PredBlock == nullptr)
142+
continue;
143+
for (auto &SuccItem : PredItem.second) {
144+
auto SuccID = SuccItem.first.BaseID;
145+
auto EdgeWeight = SuccItem.second;
146+
if (PathAndClusterInfo.BBHashes.count(SuccID)) {
147+
auto SuccHash = PathAndClusterInfo.BBHashes[SuccID];
148+
MachineBasicBlock *SuccBlock =
149+
Matcher.matchBlock(BlendedBlockHash(SuccHash));
150+
// When an edge has clone copies, sum their counts.
151+
if (SuccBlock != nullptr)
152+
MatchWeight.EdgeWeights[std::make_pair(PredBlock, SuccBlock)] +=
153+
EdgeWeight;
154+
}
155+
}
156+
}
157+
return MatchWeight;
158+
}
159+
160+
void BasicBlockMatchingAndInference::generateWeightInfoByInference(
161+
MachineFunction &MF,
162+
BasicBlockMatchingAndInference::WeightInfo &MatchWeight) {
163+
BlockEdgeMap Successors;
164+
for (auto &Block : MF) {
165+
for (auto *Succ : Block.successors())
166+
Successors[&Block].push_back(Succ);
167+
}
168+
SampleProfileInference<MachineFunction> SPI(
169+
MF, Successors, MatchWeight.BlockWeights, MatchWeight.EdgeWeights);
170+
BlockWeightMap BlockWeights;
171+
EdgeWeightMap EdgeWeights;
172+
SPI.apply(BlockWeights, EdgeWeights);
173+
ProgramWeightInfo.try_emplace(
174+
MF.getName(), BasicBlockMatchingAndInference::WeightInfo{
175+
std::move(BlockWeights), std::move(EdgeWeights)});
176+
}
177+
178+
bool BasicBlockMatchingAndInference::runOnMachineFunction(MachineFunction &MF) {
179+
if (MF.empty())
180+
return false;
181+
auto MatchWeight = initWeightInfoByMatching(MF);
182+
// If the ratio of the number of MBBs in matching to the total number of MBBs
183+
// in the function is less than the threshold value, the processing should be
184+
// abandoned.
185+
if (static_cast<float>(MatchWeight.BlockWeights.size()) / MF.size() <
186+
PropellerInferThreshold) {
187+
return false;
188+
}
189+
generateWeightInfoByInference(MF, MatchWeight);
190+
return false;
191+
}
192+
193+
MachineFunctionPass *llvm::createBasicBlockMatchingAndInferencePass() {
194+
return new BasicBlockMatchingAndInference();
195+
}

0 commit comments

Comments
 (0)