Skip to content

Commit 9e3ce6e

Browse files
wdx727, lifengxiang1025, and zcfh
committed
Adding Matching and Inference Functionality to Propeller-PR2: Compute basic block hash and add Matching and Inference.
Co-authored-by: lifengxiang1025 <lifengxiang@kuaishou.com>
Co-authored-by: zcfh <wuminghui03@kuaishou.com>
1 parent 9a46060 commit 9e3ce6e

20 files changed

+682
-16
lines changed
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
//===- llvm/CodeGen/BasicBlockMatchingAndInference.h ------------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// Infer weights for all basic blocks using matching and inference.
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
#ifndef LLVM_CODEGEN_BASICBLOCKMATCHINGANDINFERENCE_H
#define LLVM_CODEGEN_BASICBLOCKMATCHINGANDINFERENCE_H

#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringMap.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/CodeGen/BasicBlockSectionsProfileReader.h"
#include "llvm/CodeGen/MachineFunctionPass.h"
#include "llvm/Transforms/Utils/SampleProfileInference.h"
#include <cstdint>
#include <optional>
#include <utility>

namespace llvm {

/// Machine function pass that infers weights for basic blocks using matching
/// and inference, and keeps the resulting block/edge weights in a map keyed by
/// string.
class BasicBlockMatchingAndInference : public MachineFunctionPass {
public:
  /// A control-flow edge, expressed as a (first, second) pair of blocks.
  /// NOTE(review): presumably (source, destination) -- confirm against the
  /// implementation.
  using Edge = std::pair<const MachineBasicBlock *, const MachineBasicBlock *>;
  /// Weight attached to each basic block.
  using BlockWeightMap = DenseMap<const MachineBasicBlock *, uint64_t>;
  /// Weight attached to each edge.
  using EdgeWeightMap = DenseMap<Edge, uint64_t>;
  /// For each basic block, a list of related basic blocks.
  /// NOTE(review): presumably the successor list -- confirm against the
  /// implementation.
  using BlockEdgeMap = DenseMap<const MachineBasicBlock *,
                                SmallVector<const MachineBasicBlock *, 8>>;

  /// Block and edge weights for one function. This type is public because
  /// getWeightInfo() hands it out to callers, who need to be able to name it.
  struct WeightInfo {
    // Weight of basic blocks.
    BlockWeightMap BlockWeights;
    // Weight of edges.
    EdgeWeightMap EdgeWeights;
  };

  static char ID;
  BasicBlockMatchingAndInference();

  StringRef getPassName() const override {
    return "Basic Block Matching and Inference";
  }

  void getAnalysisUsage(AnalysisUsage &AU) const override;

  bool runOnMachineFunction(MachineFunction &F) override;

  /// Look up the weight info associated with \p FuncName.
  /// NOTE(review): presumably returns std::nullopt when nothing was recorded
  /// for the function -- confirm against the implementation.
  std::optional<WeightInfo> getWeightInfo(StringRef FuncName) const;

private:
  /// Weight info entries, keyed by string.
  /// NOTE(review): the key is presumably the function name -- confirm.
  StringMap<WeightInfo> ProgramWeightInfo;

  /// Build the initial weight info for \p MF (matching step).
  WeightInfo initWeightInfoByMatching(MachineFunction &MF);

  /// Run the inference step for \p MF, starting from \p MatchWeight.
  void generateWeightInfoByInference(MachineFunction &MF,
                                     WeightInfo &MatchWeight);
};

} // end namespace llvm

#endif // LLVM_CODEGEN_BASICBLOCKMATCHINGANDINFERENCE_H

llvm/include/llvm/CodeGen/BasicBlockSectionsProfileReader.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,8 @@ struct FunctionPathAndClusterInfo {
5454
DenseMap<UniqueBBID, uint64_t> NodeCounts;
5555
// Edge counts for each edge, stored as a nested map.
5656
DenseMap<UniqueBBID, DenseMap<UniqueBBID, uint64_t>> EdgeCounts;
57+
// Hash for each basic block.
58+
DenseMap<unsigned, uint64_t> BBHashes;
5759
};
5860

5961
class BasicBlockSectionsProfileReader {
@@ -86,6 +88,10 @@ class BasicBlockSectionsProfileReader {
8688
uint64_t getEdgeCount(StringRef FuncName, const UniqueBBID &SrcBBID,
8789
const UniqueBBID &SinkBBID) const;
8890

91+
// Return the complete function path and cluster info for the given function.
92+
std::pair<bool, FunctionPathAndClusterInfo>
93+
getFunctionPathAndClusterInfo(StringRef FuncName) const;
94+
8995
private:
9096
StringRef getAliasName(StringRef FuncName) const {
9197
auto R = FuncAliasMap.find(FuncName);
@@ -195,6 +201,9 @@ class BasicBlockSectionsProfileReaderWrapperPass : public ImmutablePass {
195201
uint64_t getEdgeCount(StringRef FuncName, const UniqueBBID &SrcBBID,
196202
const UniqueBBID &DestBBID) const;
197203

204+
std::pair<bool, FunctionPathAndClusterInfo>
205+
getFunctionPathAndClusterInfo(StringRef FuncName) const;
206+
198207
// Initializes the FunctionNameToDIFilename map for the current module and
199208
// then reads the profile for the matching functions.
200209
bool doInitialization(Module &M) override;
Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
//===- llvm/CodeGen/MachineBlockHashInfo.h ----------------------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// Compute the hashes of basic blocks.
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
#ifndef LLVM_CODEGEN_MACHINEBLOCKHASHINFO_H
14+
#define LLVM_CODEGEN_MACHINEBLOCKHASHINFO_H
15+
16+
#include "llvm/CodeGen/MachineFunctionPass.h"
17+
18+
namespace llvm {
19+
20+
/// An object wrapping several components of a basic block hash. The combined
/// (blended) hash is represented and stored as one uint64_t, while individual
/// components are 16 bits wide each.
///
/// Bit layout of the combined value:
///   bits [ 0, 16)  Offset
///   bits [16, 32)  OpcodeHash
///   bits [32, 48)  InstrHash
///   bits [48, 64)  NeighborHash
struct BlendedBlockHash {
private:
  /// Pack four 16-bit components into a single 64-bit value. \p Hash1 ends up
  /// in the lowest 16 bits and \p Hash4 in the highest 16 bits.
  static uint64_t combineHashes(uint16_t Hash1, uint16_t Hash2, uint16_t Hash3,
                                uint16_t Hash4) {
    return (uint64_t(Hash4) << 48) | (uint64_t(Hash3) << 32) |
           (uint64_t(Hash2) << 16) | uint64_t(Hash1);
  }

  /// Inverse of combineHashes(): split \p Hash into its four 16-bit
  /// components, lowest 16 bits first.
  static void parseHashes(uint64_t Hash, uint16_t &Hash1, uint16_t &Hash2,
                          uint16_t &Hash3, uint16_t &Hash4) {
    Hash1 = static_cast<uint16_t>(Hash & 0xffff);
    Hash2 = static_cast<uint16_t>((Hash >> 16) & 0xffff);
    Hash3 = static_cast<uint16_t>((Hash >> 32) & 0xffff);
    Hash4 = static_cast<uint16_t>((Hash >> 48) & 0xffff);
  }

public:
  /// Create a hash with every component set to zero.
  explicit BlendedBlockHash() = default;

  /// Create a hash from a value previously produced by combine().
  explicit BlendedBlockHash(uint64_t CombinedHash) {
    parseHashes(CombinedHash, Offset, OpcodeHash, InstrHash, NeighborHash);
  }

  /// Combine the blended hash into uint64_t.
  uint64_t combine() const {
    return combineHashes(Offset, OpcodeHash, InstrHash, NeighborHash);
  }

  /// Compute a distance between two given blended hashes. The smaller the
  /// distance, the more similar two blocks are. For identical basic blocks,
  /// the distance is zero.
  ///
  /// Both hashes must have the same OpcodeHash (checked by an assertion).
  /// The result is ordered by significance: a NeighborHash mismatch sets
  /// bit 32, an InstrHash mismatch sets bit 16, and the absolute difference
  /// of the two offsets (at most 0xffff) occupies the low 16 bits.
  uint64_t distance(const BlendedBlockHash &BBH) const {
    assert(OpcodeHash == BBH.OpcodeHash &&
           "incorrect blended hash distance computation");
    uint64_t Dist = 0;
    // Account for NeighborHash
    Dist += NeighborHash == BBH.NeighborHash ? 0 : 1;
    Dist <<= 16;
    // Account for InstrHash
    Dist += InstrHash == BBH.InstrHash ? 0 : 1;
    Dist <<= 16;
    // Account for Offset
    Dist += (Offset >= BBH.Offset ? Offset - BBH.Offset : BBH.Offset - Offset);
    return Dist;
  }

  /// The offset of the basic block from the function start.
  uint16_t Offset{0};
  /// (Loose) Hash of the basic block instructions, excluding operands.
  uint16_t OpcodeHash{0};
  /// (Strong) Hash of the basic block instructions, including opcodes and
  /// operands.
  uint16_t InstrHash{0};
  /// Hash of the (loose) basic block together with (loose) hashes of its
  /// successors and predecessors.
  uint16_t NeighborHash{0};
};
99+
100+
class MachineBlockHashInfo : public MachineFunctionPass {
101+
DenseMap<unsigned, uint64_t> MBBHashInfo;
102+
103+
public:
104+
static char ID;
105+
MachineBlockHashInfo();
106+
107+
StringRef getPassName() const override { return "Basic Block Hash Compute"; }
108+
109+
void getAnalysisUsage(AnalysisUsage &AU) const override;
110+
111+
bool runOnMachineFunction(MachineFunction &F) override;
112+
113+
uint64_t getMBBHash(const MachineBasicBlock &MBB);
114+
};
115+
116+
} // end namespace llvm
117+
118+
#endif // LLVM_CODEGEN_MACHINEBLOCKHASHINFO_H

llvm/include/llvm/CodeGen/Passes.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,13 @@ LLVM_ABI MachineFunctionPass *createBasicBlockSectionsPass();
6969

7070
LLVM_ABI MachineFunctionPass *createBasicBlockPathCloningPass();
7171

72+
/// createBasicBlockMatchingAndInferencePass - This pass enables matching
73+
/// and inference when using propeller.
74+
MachineFunctionPass *createBasicBlockMatchingAndInferencePass();
75+
76+
/// createMachineBlockHashInfoPass - This pass computes basic block hashes.
77+
MachineFunctionPass *createMachineBlockHashInfoPass();
78+
7279
/// createMachineFunctionSplitterPass - This pass splits machine functions
7380
/// using profile information.
7481
LLVM_ABI MachineFunctionPass *createMachineFunctionSplitterPass();

llvm/include/llvm/InitializePasses.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,13 +55,15 @@ LLVM_ABI void initializeAlwaysInlinerLegacyPassPass(PassRegistry &);
5555
LLVM_ABI void initializeAssignmentTrackingAnalysisPass(PassRegistry &);
5656
LLVM_ABI void initializeAssumptionCacheTrackerPass(PassRegistry &);
5757
LLVM_ABI void initializeAtomicExpandLegacyPass(PassRegistry &);
58+
LLVM_ABI void initializeBasicBlockMatchingAndInferencePass(PassRegistry &);
5859
LLVM_ABI void initializeBasicBlockPathCloningPass(PassRegistry &);
5960
LLVM_ABI void
6061
initializeBasicBlockSectionsProfileReaderWrapperPassPass(PassRegistry &);
6162
LLVM_ABI void initializeBasicBlockSectionsPass(PassRegistry &);
6263
LLVM_ABI void initializeBarrierNoopPass(PassRegistry &);
6364
LLVM_ABI void initializeBasicAAWrapperPassPass(PassRegistry &);
6465
LLVM_ABI void initializeBlockFrequencyInfoWrapperPassPass(PassRegistry &);
66+
LLVM_ABI void initializeMachineBlockHashInfoPass(PassRegistry &);
6567
LLVM_ABI void initializeBranchFolderLegacyPass(PassRegistry &);
6668
LLVM_ABI void initializeBranchProbabilityInfoWrapperPassPass(PassRegistry &);
6769
LLVM_ABI void initializeBranchRelaxationLegacyPass(PassRegistry &);

llvm/include/llvm/MC/MCContext.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -175,7 +175,7 @@ class MCContext {
175175
unsigned GetInstance(unsigned LocalLabelVal);
176176

177177
/// SHT_LLVM_BB_ADDR_MAP version to emit.
178-
uint8_t BBAddrMapVersion = 3;
178+
uint8_t BBAddrMapVersion = 4;
179179

180180
/// The file name of the log file from the environment variable
181181
/// AS_SECURE_LOG_FILE. Which must be set before the .secure_log_unique

llvm/include/llvm/Transforms/Utils/SampleProfileInference.h

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,11 @@ template <typename FT> class SampleProfileInference {
130130
SampleProfileInference(FunctionT &F, BlockEdgeMap &Successors,
131131
BlockWeightMap &SampleBlockWeights)
132132
: F(F), Successors(Successors), SampleBlockWeights(SampleBlockWeights) {}
133+
SampleProfileInference(FunctionT &F, BlockEdgeMap &Successors,
134+
BlockWeightMap &SampleBlockWeights,
135+
EdgeWeightMap &SampleEdgeWeights)
136+
: F(F), Successors(Successors), SampleBlockWeights(SampleBlockWeights),
137+
SampleEdgeWeights(SampleEdgeWeights) {}
133138

134139
/// Apply the profile inference algorithm for a given function
135140
void apply(BlockWeightMap &BlockWeights, EdgeWeightMap &EdgeWeights);
@@ -157,6 +162,9 @@ template <typename FT> class SampleProfileInference {
157162

158163
/// Map basic blocks to their sampled weights.
159164
BlockWeightMap &SampleBlockWeights;
165+
166+
/// Map edges to their sampled weights.
167+
EdgeWeightMap SampleEdgeWeights;
160168
};
161169

162170
template <typename BT>
@@ -266,6 +274,14 @@ FlowFunction SampleProfileInference<BT>::createFlowFunction(
266274
FlowJump Jump;
267275
Jump.Source = BlockIndex[BB];
268276
Jump.Target = BlockIndex[Succ];
277+
auto It = SampleEdgeWeights.find(std::make_pair(BB, Succ));
278+
if (It != SampleEdgeWeights.end()) {
279+
Jump.HasUnknownWeight = false;
280+
Jump.Weight = It->second;
281+
} else {
282+
Jump.HasUnknownWeight = true;
283+
Jump.Weight = 0;
284+
}
269285
Func.Jumps.push_back(Jump);
270286
}
271287
}

llvm/lib/CodeGen/AsmPrinter/AsmPrinter.cpp

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
#include "llvm/CodeGen/GCMetadataPrinter.h"
4242
#include "llvm/CodeGen/LazyMachineBlockFrequencyInfo.h"
4343
#include "llvm/CodeGen/MachineBasicBlock.h"
44+
#include "llvm/CodeGen/MachineBlockHashInfo.h"
4445
#include "llvm/CodeGen/MachineBranchProbabilityInfo.h"
4546
#include "llvm/CodeGen/MachineConstantPool.h"
4647
#include "llvm/CodeGen/MachineDominators.h"
@@ -183,6 +184,8 @@ static cl::opt<bool> PrintLatency(
183184
cl::desc("Print instruction latencies as verbose asm comments"), cl::Hidden,
184185
cl::init(false));
185186

187+
extern cl::opt<bool> EmitBBHash;
188+
186189
STATISTIC(EmittedInsts, "Number of machine instrs printed");
187190

188191
char AsmPrinter::ID = 0;
@@ -473,6 +476,8 @@ void AsmPrinter::getAnalysisUsage(AnalysisUsage &AU) const {
473476
AU.addRequired<GCModuleInfo>();
474477
AU.addRequired<LazyMachineBlockFrequencyInfoPass>();
475478
AU.addRequired<MachineBranchProbabilityInfoWrapperPass>();
479+
if (EmitBBHash)
480+
AU.addRequired<MachineBlockHashInfo>();
476481
}
477482

478483
bool AsmPrinter::doInitialization(Module &M) {
@@ -1438,7 +1443,7 @@ getBBAddrMapFeature(const MachineFunction &MF, int NumMBBSectionRanges,
14381443
MF.hasBBSections() && NumMBBSectionRanges > 1,
14391444
static_cast<bool>(BBAddrMapSkipEmitBBEntries),
14401445
HasCalls,
1441-
false};
1446+
static_cast<bool>(EmitBBHash)};
14421447
}
14431448

14441449
void AsmPrinter::emitBBAddrMapSection(const MachineFunction &MF) {
@@ -1497,6 +1502,9 @@ void AsmPrinter::emitBBAddrMapSection(const MachineFunction &MF) {
14971502
PrevMBBEndSymbol = MBBSymbol;
14981503
}
14991504

1505+
auto MBHI =
1506+
Features.BBHash ? &getAnalysis<MachineBlockHashInfo>() : nullptr;
1507+
15001508
if (!Features.OmitBBEntries) {
15011509
OutStreamer->AddComment("BB id");
15021510
// Emit the BB ID for this basic block.
@@ -1524,6 +1532,10 @@ void AsmPrinter::emitBBAddrMapSection(const MachineFunction &MF) {
15241532
emitLabelDifferenceAsULEB128(MBB.getEndSymbol(), CurrentLabel);
15251533
// Emit the Metadata.
15261534
OutStreamer->emitULEB128IntValue(getBBAddrMapMetadata(MBB));
1535+
// Emit the Hash.
1536+
if (MBHI) {
1537+
OutStreamer->emitInt64(MBHI->getMBBHash(MBB));
1538+
}
15271539
}
15281540
PrevMBBEndSymbol = MBB.getEndSymbol();
15291541
}

0 commit comments

Comments
 (0)