diff --git a/src/llvm-late-gc-lowering.cpp b/src/llvm-late-gc-lowering.cpp index 6c0bc5ede0f7e..05ec7d24f8f7b 100644 --- a/src/llvm-late-gc-lowering.cpp +++ b/src/llvm-late-gc-lowering.cpp @@ -295,6 +295,8 @@ struct State { // Those values that - if live out from our parent basic block - are live // at this safepoint. std::vector> LiveIfLiveOut; + // The set of values that are kept alive by the callee. + std::vector> CalleeRoots; // We don't bother doing liveness on Allocas that were not mem2reg'ed. // they just get directly sunk into the root array. std::vector Allocas; @@ -359,7 +361,7 @@ struct LateLowerGCFrame: public FunctionPass, private JuliaPassContext { void NoteUseChain(State &S, BBState &BBS, User *TheUser); SmallVector GetPHIRefinements(PHINode *phi, State &S); void FixUpRefinements(ArrayRef PHINumbers, State &S); - void RefineLiveSet(BitVector &LS, State &S); + void RefineLiveSet(BitVector &LS, State &S, const std::vector &CalleeRoots); Value *EmitTagPtr(IRBuilder<> &builder, Type *T, Value *V); Value *EmitLoadTag(IRBuilder<> &builder, Value *V); }; @@ -1002,7 +1004,7 @@ void LateLowerGCFrame::MaybeNoteDef(State &S, BBState &BBS, Value *Def, const st } } -static int NoteSafepoint(State &S, BBState &BBS, CallInst *CI) { +static int NoteSafepoint(State &S, BBState &BBS, CallInst *CI, std::vector CalleeRoots) { int Number = ++S.MaxSafepointNumber; S.SafepointNumbering[CI] = Number; S.ReverseSafepointNumbering.push_back(CI); @@ -1012,6 +1014,7 @@ static int NoteSafepoint(State &S, BBState &BBS, CallInst *CI) { // computation) S.LiveSets.push_back(BBS.UpExposedUses); S.LiveIfLiveOut.push_back(std::vector{}); + S.CalleeRoots.push_back(std::move(CalleeRoots)); return Number; } @@ -1515,7 +1518,25 @@ State LateLowerGCFrame::LocalScan(Function &F) { // Intrinsics are never safepoints. continue; } - int SafepointNumber = NoteSafepoint(S, BBS, CI); + std::vector CalleeRoots; + for (Use &U : CI->arg_operands()) { + // Find all callee rooted arguments. 
+ // Record them instead of simply removing them from live values here + // since they can be useful during refinement + // (e.g. to remove roots of objects that are refined to these) + Value *V = U; + if (isa(V) || !isa(V->getType()) || + getValueAddrSpace(V) != AddressSpace::CalleeRooted) + continue; + V = V->stripPointerCasts(); + if (!isTrackedValue(V)) + continue; + auto Num = Number(S, V); + if (Num < 0) + continue; + CalleeRoots.push_back(Num); + } + int SafepointNumber = NoteSafepoint(S, BBS, CI, std::move(CalleeRoots)); BBS.HasSafepoint = true; BBS.TopmostSafepoint = SafepointNumber; BBS.Safepoints.push_back(SafepointNumber); @@ -1845,12 +1866,18 @@ JL_USED_FUNC static void dumpSafepointsForBBName(Function &F, State &S, const ch } } -void LateLowerGCFrame::RefineLiveSet(BitVector &LS, State &S) +void LateLowerGCFrame::RefineLiveSet(BitVector &LS, State &S, const std::vector &CalleeRoots) { BitVector FullLS(S.MaxPtrNumber + 1, false); FullLS |= LS; // First expand the live set according to the refinement map // so that we can see all the values that are effectively live. + for (auto Num: CalleeRoots) { + // For callee rooted values, they are all kept alive at the safepoint. + // Make sure they are marked (even though they probably are already) + // so that other values can be refined to them. + FullLS[Num] = 1; + } bool changed; do { changed = false; @@ -1891,6 +1918,11 @@ void LateLowerGCFrame::RefineLiveSet(BitVector &LS, State &S) LS[Idx] = 0; } } + for (auto Num: CalleeRoots) { + // Now unmark all values that are rooted by the callee after + // refining other values to them. + LS[Num] = 0; + } } void LateLowerGCFrame::ComputeLiveSets(State &S) { @@ -1909,7 +1941,7 @@ void LateLowerGCFrame::ComputeLiveSets(State &S) { if (HasBitSet(BBS.LiveOut, Live)) LS[Live] = 1; } - RefineLiveSet(LS, S); + RefineLiveSet(LS, S, S.CalleeRoots[idx]); // If the function has GC preserves, figure out whether we need to // add in any extra live values.
if (!S.GCPreserves.empty()) { diff --git a/test/llvmpasses/late-lower-gc.ll b/test/llvmpasses/late-lower-gc.ll index 00b0e64f05664..115d703f65f92 100644 --- a/test/llvmpasses/late-lower-gc.ll +++ b/test/llvmpasses/late-lower-gc.ll @@ -8,6 +8,7 @@ declare {}*** @julia.ptls_states() declare void @jl_safepoint() declare {} addrspace(10)* @jl_apply_generic({} addrspace(10)*, {} addrspace(10)**, i32) declare noalias nonnull {} addrspace(10)* @julia.gc_alloc_obj(i8*, i64, {} addrspace(10)*) +declare i32 @rooting_callee({} addrspace(12)*, {} addrspace(12)*) define void @gc_frame_lowering(i64 %a, i64 %b) { top: @@ -74,6 +75,24 @@ top: ret void } +define i32 @callee_root({} addrspace(10)* %v0, {} addrspace(10)* %v1) { +top: +; CHECK-LABEL: @callee_root +; CHECK-NOT: @julia.new_gc_frame + %v2 = call {}*** @julia.ptls_states() + %v3 = bitcast {} addrspace(10)* %v0 to {} addrspace(10)* addrspace(10)* + %v4 = addrspacecast {} addrspace(10)* addrspace(10)* %v3 to {} addrspace(10)* addrspace(11)* + %v5 = load atomic {} addrspace(10)*, {} addrspace(10)* addrspace(11)* %v4 unordered, align 8 + %v6 = bitcast {} addrspace(10)* %v1 to {} addrspace(10)* addrspace(10)* + %v7 = addrspacecast {} addrspace(10)* addrspace(10)* %v6 to {} addrspace(10)* addrspace(11)* + %v8 = load atomic {} addrspace(10)*, {} addrspace(10)* addrspace(11)* %v7 unordered, align 8 + %v9 = addrspacecast {} addrspace(10)* %v5 to {} addrspace(12)* + %v10 = addrspacecast {} addrspace(10)* %v8 to {} addrspace(12)* + %v11 = call i32 @rooting_callee({} addrspace(12)* %v9, {} addrspace(12)* %v10) + ret i32 %v11 +; CHECK: ret i32 +} + !0 = !{i64 0, i64 23} !1 = !{} !2 = distinct !{!2}