Skip to content

Commit f8af029

Browse files
authored
Continued Type Analysis Updates (rust-lang#410)
* More non-strict type handling * Merge global constants
1 parent 06f7d22 commit f8af029

File tree

6 files changed

+132
-27
lines changed

6 files changed

+132
-27
lines changed

Diff for: enzyme/Enzyme/TypeAnalysis/ConcreteType.h

+7
Original file line numberDiff line numberDiff line change
@@ -283,6 +283,13 @@ class ConcreteType {
283283
/// BaseType::Unknown
284284
bool operator&=(const ConcreteType CT) { return andIn(CT); }
285285

286+
/// Keep only mappings where the type is not an `Anything`
287+
ConcreteType PurgeAnything() const {
288+
if (SubTypeEnum == BaseType::Anything)
289+
return BaseType::Unknown;
290+
return *this;
291+
}
292+
286293
/// Set this to the logical `binop` of itself and RHS, using the Binop Op,
287294
/// returning true if this was changed.
288295
/// This function will error on an invalid type combination

Diff for: enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp

+41-24
Original file line numberDiff line numberDiff line change
@@ -299,26 +299,30 @@ void getConstantAnalysis(Constant *Val, TypeAnalyzer &TA,
299299
int Off = (int)ai.getLimitedValue();
300300

301301
getConstantAnalysis(Op, TA, analysis);
302-
auto mid = analysis[Op].ShiftIndices(DL, /*init offset*/ 0,
303-
/*maxSize*/ ObjSize,
304-
/*addOffset*/ Off);
305-
302+
auto mid = analysis[Op];
306303
if (TA.fntypeinfo.Function->getParent()
307304
->getDataLayout()
308305
.getTypeSizeInBits(CA->getType()) >= 16) {
309306
mid.ReplaceIntWithAnything();
310307
}
311308

312-
Result |= mid;
309+
Result |= mid.ShiftIndices(DL, /*init offset*/ 0,
310+
/*maxSize*/ ObjSize,
311+
/*addOffset*/ Off);
313312
}
313+
Result = Result.CanonicalizeValue(
314+
(TA.fntypeinfo.Function->getParent()->getDataLayout().getTypeSizeInBits(
315+
CA->getType()) +
316+
7) /
317+
8,
318+
DL);
314319
return;
315320
}
316321

317322
// Type of an sequence is the aggregation of
318323
// the subtypes
319324
if (auto CD = dyn_cast<ConstantDataSequential>(Val)) {
320325
TypeTree &Result = analysis[Val];
321-
322326
for (unsigned i = 0, size = CD->getNumElements(); i < size; ++i) {
323327
assert(TA.fntypeinfo.Function);
324328
auto Op = CD->getElementAsConstant(i);
@@ -349,18 +353,24 @@ void getConstantAnalysis(Constant *Val, TypeAnalyzer &TA,
349353
int Off = (int)ai.getLimitedValue();
350354

351355
getConstantAnalysis(Op, TA, analysis);
352-
auto mid = analysis[Op].ShiftIndices(DL, /*init offset*/ 0,
353-
/*maxSize*/ ObjSize,
354-
/*addOffset*/ Off);
355-
356+
auto mid = analysis[Op];
356357
if (TA.fntypeinfo.Function->getParent()
357358
->getDataLayout()
358359
.getTypeSizeInBits(CD->getType()) >= 16) {
359360
mid.ReplaceIntWithAnything();
360361
}
362+
Result |= mid.ShiftIndices(DL, /*init offset*/ 0,
363+
/*maxSize*/ ObjSize,
364+
/*addOffset*/ Off);
361365

362366
Result |= mid;
363367
}
368+
Result = Result.CanonicalizeValue(
369+
(TA.fntypeinfo.Function->getParent()->getDataLayout().getTypeSizeInBits(
370+
CD->getType()) +
371+
7) /
372+
8,
373+
DL);
364374
return;
365375
}
366376

@@ -849,12 +859,12 @@ void TypeAnalyzer::considerTBAA() {
849859
(DL.getTypeSizeInBits(SI->getValueOperand()->getType()) + 7) / 8;
850860
updateAnalysis(SI->getPointerOperand(),
851861
vdptr
862+
// Don't propagate "Anything" into ptr
863+
.PurgeAnything()
852864
// Cut off any values outside of store
853865
.ShiftIndices(DL, /*init offset*/ 0,
854866
/*max size*/ StoreSize,
855867
/*new offset*/ 0)
856-
// Don't propagate "Anything" into ptr
857-
.PurgeAnything()
858868
.Only(-1),
859869
SI);
860870
TypeTree req = vdptr.Only(-1);
@@ -863,12 +873,12 @@ void TypeAnalyzer::considerTBAA() {
863873
auto LoadSize = (DL.getTypeSizeInBits(LI->getType()) + 7) / 8;
864874
updateAnalysis(LI->getPointerOperand(),
865875
vdptr
876+
// Don't propagate "Anything" into ptr
877+
.PurgeAnything()
866878
// Cut off any values outside of load
867879
.ShiftIndices(DL, /*init offset*/ 0,
868880
/*max size*/ LoadSize,
869881
/*new offset*/ 0)
870-
// Don't propagate "Anything" into ptr
871-
.PurgeAnything()
872882
.Only(-1),
873883
LI);
874884
TypeTree req = vdptr.Only(-1);
@@ -1125,12 +1135,12 @@ void TypeAnalyzer::visitCmpInst(CmpInst &cmp) {
11251135
if (direction & UP) {
11261136
updateAnalysis(
11271137
cmp.getOperand(0),
1128-
TypeTree(getAnalysis(cmp.getOperand(1)).Data0().PurgeAnything()[{}])
1138+
TypeTree(getAnalysis(cmp.getOperand(1)).Inner0().PurgeAnything())
11291139
.Only(-1),
11301140
&cmp);
11311141
updateAnalysis(
11321142
cmp.getOperand(1),
1133-
TypeTree(getAnalysis(cmp.getOperand(0)).Data0().PurgeAnything()[{}])
1143+
TypeTree(getAnalysis(cmp.getOperand(0)).Inner0().PurgeAnything())
11341144
.Only(-1),
11351145
&cmp);
11361146
}
@@ -1194,8 +1204,9 @@ void TypeAnalyzer::visitStoreInst(StoreInst &I) {
11941204
// Only propagate mappings in range that aren't "Anything" into the pointer
11951205
auto ptr = TypeTree(BaseType::Pointer);
11961206
auto purged = getAnalysis(I.getValueOperand())
1207+
.PurgeAnything()
11971208
.ShiftIndices(DL, /*start*/ 0, StoreSize, /*addOffset*/ 0)
1198-
.PurgeAnything();
1209+
.ReplaceMinus();
11991210
ptr |= purged;
12001211

12011212
if (direction & UP) {
@@ -2460,13 +2471,13 @@ void TypeAnalyzer::visitMemTransferCommon(llvm::CallInst &MTI) {
24602471

24612472
auto &dl = MTI.getParent()->getParent()->getParent()->getDataLayout();
24622473
TypeTree res = getAnalysis(MTI.getArgOperand(0))
2474+
.PurgeAnything()
24632475
.Data0()
2464-
.ShiftIndices(dl, 0, sz, 0)
2465-
.PurgeAnything();
2476+
.ShiftIndices(dl, 0, sz, 0);
24662477
TypeTree res2 = getAnalysis(MTI.getArgOperand(1))
2478+
.PurgeAnything()
24672479
.Data0()
2468-
.ShiftIndices(dl, 0, sz, 0)
2469-
.PurgeAnything();
2480+
.ShiftIndices(dl, 0, sz, 0);
24702481

24712482
bool Legal = true;
24722483
res.checkedOrIn(res2, /*PointerIntSame*/ false, Legal);
@@ -3772,11 +3783,11 @@ void TypeAnalyzer::visitCallInst(CallInst &call) {
37723783

37733784
auto &dl = call.getParent()->getParent()->getParent()->getDataLayout();
37743785
TypeTree res = getAnalysis(call.getArgOperand(0))
3786+
.PurgeAnything()
37753787
.Data0()
3776-
.ShiftIndices(dl, 0, sz, 0)
3777-
.PurgeAnything();
3788+
.ShiftIndices(dl, 0, sz, 0);
37783789
TypeTree res2 =
3779-
getAnalysis(&call).Data0().ShiftIndices(dl, 0, sz, 0).PurgeAnything();
3790+
getAnalysis(&call).PurgeAnything().Data0().ShiftIndices(dl, 0, sz, 0);
37803791

37813792
res.orIn(res2, /*PointerIntSame*/ false);
37823793
res.insert({}, BaseType::Pointer);
@@ -4524,6 +4535,9 @@ void TypeAnalyzer::visitIPOCall(CallInst &call, Function &fn) {
45244535

45254536
TypeResults STR = interprocedural.analyzeFunction(typeInfo);
45264537

4538+
if (EnzymePrintType)
4539+
llvm::errs() << " ending IPO of " << call << "\n";
4540+
45274541
if (hasUp) {
45284542
auto a = fn.arg_begin();
45294543
#if LLVM_VERSION_MAJOR >= 14
@@ -4533,6 +4547,9 @@ void TypeAnalyzer::visitIPOCall(CallInst &call, Function &fn) {
45334547
#endif
45344548
{
45354549
auto dt = STR.query(a);
4550+
if (EnzymePrintType)
4551+
llvm::errs() << " updating " << *arg << " = " << dt.str()
4552+
<< " via IPO of " << call << " arg " << *a << "\n";
45364553
updateAnalysis(arg, dt, &call);
45374554
++a;
45384555
}

Diff for: enzyme/Enzyme/TypeAnalysis/TypeTree.h

+15
Original file line numberDiff line numberDiff line change
@@ -826,6 +826,21 @@ class TypeTree : public std::enable_shared_from_this<TypeTree> {
826826
return dat;
827827
}
828828

829+
/// Replace -1 with 0
830+
TypeTree ReplaceMinus() const {
831+
TypeTree dat;
832+
for (const auto pair : mapping) {
833+
if (pair.second == ConcreteType(BaseType::Anything))
834+
continue;
835+
std::vector<int> nex = pair.first;
836+
for (auto &v : nex)
837+
if (v == -1)
838+
v = 0;
839+
dat.insert(nex, pair.second);
840+
}
841+
return dat;
842+
}
843+
829844
/// Replace all integer subtypes with anything
830845
void ReplaceIntWithAnything() {
831846
for (auto &pair : mapping) {

Diff for: enzyme/test/TypeAnalysis/loadptrptr.ll

+3-3
Original file line numberDiff line numberDiff line change
@@ -23,11 +23,11 @@ entry:
2323
!5 = !{!"Simple C++ TBAA"}
2424

2525
; CHECK: caller - {} |{}:{}
26-
; CHECK-NEXT: i64* %q: {[-1]:Pointer, [-1,0]:Pointer, [-1,0,-1]:Integer}
26+
; CHECK-NEXT: i64* %q: {[-1]:Pointer, [-1,0]:Pointer, [-1,0,0]:Integer, [-1,0,1]:Integer, [-1,0,2]:Integer, [-1,0,3]:Integer, [-1,0,4]:Integer, [-1,0,5]:Integer, [-1,0,6]:Integer, [-1,0,7]:Integer}
2727
; CHECK-NEXT: entry
28-
; CHECK-NEXT: %p = bitcast i64* %q to i64**: {[-1]:Pointer, [-1,0]:Pointer, [-1,0,-1]:Integer}
28+
; CHECK-NEXT: %p = bitcast i64* %q to i64**: {[-1]:Pointer, [-1,0]:Pointer, [-1,0,0]:Integer, [-1,0,1]:Integer, [-1,0,2]:Integer, [-1,0,3]:Integer, [-1,0,4]:Integer, [-1,0,5]:Integer, [-1,0,6]:Integer, [-1,0,7]:Integer}
2929
; CHECK-NEXT: %x = alloca i64, align 8: {[-1]:Pointer, [-1,-1]:Integer}
3030
; CHECK-NEXT: store i64 132, i64* %x, align 8: {}
3131
; CHECK-NEXT: store i64* %x, i64** %p, align 8: {}
32-
; CHECK-NEXT: %ld = load i64*, i64** %p, align 8: {[-1]:Pointer, [-1,-1]:Integer}
32+
; CHECK-NEXT: %ld = load i64*, i64** %p, align 8: {[-1]:Pointer, [-1,0]:Integer, [-1,1]:Integer, [-1,2]:Integer, [-1,3]:Integer, [-1,4]:Integer, [-1,5]:Integer, [-1,6]:Integer, [-1,7]:Integer}
3333
; CHECK-NEXT: ret void: {}

Diff for: enzyme/test/TypeAnalysis/mergeglob.ll

+14
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
; RUN: %opt < %s %loadEnzyme -print-type-analysis -type-analysis-func=callee -o /dev/null | FileCheck %s
2+
3+
@ptr = internal constant [4 x double] [double 3.14000e+00, double 3.14000e+00, double 3.14000e+00, double 3.14000e+00], align 8
4+
5+
define void @callee() {
6+
entry:
7+
%self = getelementptr inbounds [4 x double], [4 x double]* @ptr, i32 0, i32 0
8+
ret void
9+
}
10+
11+
; CHECK: callee - {} |
12+
; CHECK-NEXT: entry
13+
; CHECK-NEXT: %self = getelementptr inbounds [4 x double], [4 x double]* @ptr, i32 0, i32 0: {[-1]:Pointer, [-1,-1]:Float@double}
14+
; CHECK-NEXT: ret void: {}

Diff for: enzyme/test/TypeAnalysis/transf.ll

+52
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
; RUN: %opt < %s %loadEnzyme -print-type-analysis -type-analysis-func=caller -o /dev/null | FileCheck %s
2+
3+
define void @caller(double* %p) {
4+
entry:
5+
%q = alloca [2 x double], align 8
6+
%gep0 = getelementptr [2 x double], [2 x double]* %q, i32 0, i32 0
7+
%gep = getelementptr [2 x double], [2 x double]* %q, i32 0, i32 1
8+
store double 0.000000e+00, double* %gep, align 8, !tbaa !2
9+
%o = call double* @max(double* %p, double* %gep0, i1 true)
10+
ret void
11+
}
12+
13+
define double* @max(double* %a, double* %b, i1 %cmp) {
14+
entry:
15+
%retval = alloca double*, align 8
16+
%av = load double, double* %a, align 8
17+
%bv = load double, double* %b, align 8
18+
br i1 %cmp, label %if.then, label %if.end
19+
20+
if.then:
21+
store double* %a, double** %retval, align 8
22+
br label %return
23+
24+
if.end:
25+
store double* %b, double** %retval, align 8
26+
br label %return
27+
28+
return:
29+
%res = load double*, double** %retval, align 8
30+
ret double* %res
31+
}
32+
33+
!llvm.module.flags = !{!0}
34+
!llvm.ident = !{!1}
35+
36+
!0 = !{i32 1, !"wchar_size", i32 4}
37+
!1 = !{!"clang version 7.1.0 "}
38+
!2 = !{!3, !3, i64 0, i64 8}
39+
!3 = !{!4, i64 8, !"long"}
40+
!4 = !{!5, i64 1, !"omnipotent char"}
41+
!5 = !{!"Simple C++ TBAA"}
42+
43+
44+
; CHECK: caller - {} |{[-1]:Pointer, [-1,-1]:Float@double}:{}
45+
; CHECK-NEXT: double* %p: {[-1]:Pointer, [-1,-1]:Float@double}
46+
; CHECK-NEXT: entry
47+
; CHECK-NEXT: %q = alloca [2 x double], align 8: {[-1]:Pointer, [-1,0]:Float@double, [-1,8]:Integer, [-1,9]:Integer, [-1,10]:Integer, [-1,11]:Integer, [-1,12]:Integer, [-1,13]:Integer, [-1,14]:Integer, [-1,15]:Integer}
48+
; CHECK-NEXT: %gep0 = getelementptr [2 x double], [2 x double]* %q, i32 0, i32 0: {[-1]:Pointer, [-1,0]:Float@double}
49+
; CHECK-NEXT: %gep = getelementptr [2 x double], [2 x double]* %q, i32 0, i32 1: {[-1]:Pointer, [-1,0]:Integer, [-1,1]:Integer, [-1,2]:Integer, [-1,3]:Integer, [-1,4]:Integer, [-1,5]:Integer, [-1,6]:Integer, [-1,7]:Integer}
50+
; CHECK-NEXT: store double 0.000000e+00, double* %gep, align 8, !tbaa !2: {}
51+
; CHECK-NEXT: %o = call double* @max(double* %p, double* %gep0, i1 true): {[-1]:Pointer, [-1,0]:Float@double}
52+
; CHECK-NEXT: ret void: {}

0 commit comments

Comments
 (0)