Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

handle masked_gather in late-gc-lowering #34583

Merged
merged 2 commits into from
Feb 6, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 62 additions & 22 deletions src/llvm-late-gc-lowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -474,22 +474,50 @@ static std::pair<Value*,int> FindBaseValue(const State &S, Value *V, bool UseCac
CurrentV = EEI->getVectorOperand();
}
else if (auto LI = dyn_cast<LoadInst>(CurrentV)) {
if (auto PtrT = dyn_cast<PointerType>(LI->getType())) {
if (auto PtrT = dyn_cast<PointerType>(LI->getType()->getScalarType())) {
if (PtrT->getAddressSpace() == AddressSpace::Loaded) {
CurrentV = LI->getPointerOperand();
fld_idx = -1;
if (!isSpecialPtr(CurrentV->getType())) {
// Special case to bypass the check below.
// This could really be anything, but it's not loaded
// from a tracked pointer, so it doesn't matter what
// it is.
return std::make_pair(CurrentV, fld_idx);
// it is--just pick something simple.
CurrentV = ConstantPointerNull::get(Type::getInt8PtrTy(V->getContext()));
}
continue;
}
}
// In general a load terminates a walk
break;
}
else if (auto II = dyn_cast<IntrinsicInst>(CurrentV)) {
// Some intrinsics behave like LoadInst followed by a SelectInst
// This should never happen in a derived addrspace (since those cannot be stored to memory)
// so we don't need to lift these operations, but we do need to check if it's loaded and continue walking the base pointer
if (II->getIntrinsicID() == Intrinsic::masked_load ||
II->getIntrinsicID() == Intrinsic::masked_gather) {
if (auto PtrT = dyn_cast<PointerType>(II->getType()->getVectorElementType())) {
if (PtrT->getAddressSpace() == AddressSpace::Loaded) {
assert(isa<UndefValue>(II->getOperand(3)) && "unimplemented");
CurrentV = II->getOperand(0);
if (II->getIntrinsicID() == Intrinsic::masked_load) {
fld_idx = -1;
if (!isSpecialPtr(CurrentV->getType())) {
CurrentV = ConstantPointerNull::get(Type::getInt8PtrTy(V->getContext()));
}
} else {
if (!isSpecialPtr(CurrentV->getType()->getVectorElementType())) {
CurrentV = ConstantPointerNull::get(Type::getInt8PtrTy(V->getContext()));
fld_idx = -1;
}
}
continue;
}
}
// In general a load terminates a walk
break;
}
}
else {
break;
}
Expand Down Expand Up @@ -1296,6 +1324,23 @@ State LateLowerGCFrame::LocalScan(Function &F) {
II->getIntrinsicID() == Intrinsic::lifetime_end) {
continue;
}
if (II->getIntrinsicID() == Intrinsic::masked_load ||
II->getIntrinsicID() == Intrinsic::masked_gather) {
if (auto PtrT = dyn_cast<PointerType>(II->getType()->getVectorElementType())) {
if (isSpecialPtr(PtrT)) {
// LLVM sometimes tries to materialize these operations with undefined pointers in our non-integral address space.
// Hopefully LLVM didn't already propagate that information and poison our users. Set those to NULL now.
Value *passthru = II->getArgOperand(3);
if (isa<UndefValue>(passthru)) {
II->setArgOperand(3, Constant::getNullValue(passthru->getType()));
}
if (PtrT->getAddressSpace() == AddressSpace::Loaded) {
// These are not real defs
continue;
}
}
}
}
}
auto callee = CI->getCalledFunction();
if (callee && callee == typeof_func) {
Expand Down Expand Up @@ -1388,10 +1433,11 @@ State LateLowerGCFrame::LocalScan(Function &F) {
// of this object to uses of the object we're loading
// from.
SmallVector<int, 1> RefinedPtr{};
Type *Ty = LI->getType()->getScalarType();
if (isLoadFromImmut(LI) && isSpecialPtr(LI->getPointerOperand()->getType())) {
RefinedPtr.push_back(Number(S, LI->getPointerOperand()));
} else if (LI->getType()->isPointerTy() &&
isSpecialPtr(LI->getType()) &&
isSpecialPtr(Ty) &&
LooksLikeFrameRef(LI->getPointerOperand())) {
// Loads from a jlcall argument array
RefinedPtr.push_back(-1);
Expand All @@ -1401,8 +1447,7 @@ State LateLowerGCFrame::LocalScan(Function &F) {
// we know that the object is a constant as well and doesn't need rooting.
RefinedPtr.push_back(-2);
}
if (!LI->getType()->isPointerTy() ||
LI->getType()->getPointerAddressSpace() != AddressSpace::Loaded) {
if (!Ty->isPointerTy() || Ty->getPointerAddressSpace() != AddressSpace::Loaded) {
MaybeNoteDef(S, BBS, LI, BBS.Safepoints, std::move(RefinedPtr));
}
NoteOperandUses(S, BBS, I);
Expand Down Expand Up @@ -1970,21 +2015,16 @@ bool LateLowerGCFrame::CleanupIR(Function &F, State *S) {
for (BasicBlock &BB : F) {
for (auto it = BB.begin(); it != BB.end();) {
Instruction *I = &*it;
if (isa<LoadInst>(I) || isa<StoreInst>(I)) {
// strip all constant alias information, as it might depend on the gc having
// preserved a gc root, which stops being true after this pass (#32215)
// we'd like to call RewriteStatepointsForGC::stripNonValidData here, but
// that function asserts that the GC strategy must be named either "statepoint-example" or "coreclr",
// while we don't give a name to our GC in the IR, and C++ scope rules prohibit us from using it,
// so instead we reimplement it here badly
if (I->getMetadata(LLVMContext::MD_invariant_load))
I->setMetadata(LLVMContext::MD_invariant_load, NULL);
if (MDNode *TBAA = I->getMetadata(LLVMContext::MD_tbaa)) {
if (TBAA->getNumOperands() == 4 && isTBAA(TBAA, {"jtbaa_const"})) {
MDNode *MutableTBAA = createMutableTBAAAccessTag(TBAA);
if (MutableTBAA != TBAA)
I->setMetadata(LLVMContext::MD_tbaa, MutableTBAA);
}
// strip all constant alias information, as it might depend on the gc having
// preserved a gc root, which stops being true after this pass (#32215)
// similar to RewriteStatepointsForGC::stripNonValidData, but less aggressive
if (I->getMetadata(LLVMContext::MD_invariant_load))
I->setMetadata(LLVMContext::MD_invariant_load, NULL);
if (MDNode *TBAA = I->getMetadata(LLVMContext::MD_tbaa)) {
if (TBAA->getNumOperands() == 4 && isTBAA(TBAA, {"jtbaa_const"})) {
MDNode *MutableTBAA = createMutableTBAAAccessTag(TBAA);
if (MutableTBAA != TBAA)
I->setMetadata(LLVMContext::MD_tbaa, MutableTBAA);
}
}
auto *CI = dyn_cast<CallInst>(&*it);
Expand Down
61 changes: 61 additions & 0 deletions test/llvmpasses/gcroots.ll
Original file line number Diff line number Diff line change
Expand Up @@ -642,6 +642,67 @@ top:
ret i8 %add
}

; Test case: a plain (unmasked) vector load of two addrspace(13) pointers taken
; from memory reached through %obj1, with lane 0 extracted and dereferenced
; after a call to @jl_safepoint.
; The FileCheck directives inside the function expect a 3-slot %gcframe and a
; store of %obj1 into slot index 2 of that frame.
define i8 @vector_arrayptrs() {
; CHECK-LABEL: @vector_arrayptrs
; CHECK: %gcframe = alloca %jl_value_t addrspace(10)*, i32 3
; CHECK: [[GEP0:%.*]] = getelementptr %jl_value_t addrspace(10)*, %jl_value_t addrspace(10)** %gcframe, i32 2
; CHECK: store %jl_value_t addrspace(10)* %obj1, %jl_value_t addrspace(10)** [[GEP0]]
;
top:
%ptls = call %jl_value_t*** @julia.ptls_states()
; %obj1 is the value the directives above expect to see stored into the frame.
%obj1 = call %jl_value_t addrspace(10) *@alloc()
; Re-type %obj1: addrspace(10) -> addrspace(11), then view it as a pointer to a
; <2 x i8 addrspace(13)*> vector.
%decayed = addrspacecast %jl_value_t addrspace(10) *%obj1 to %jl_value_t addrspace(11) *
%arrayptrptr = bitcast %jl_value_t addrspace(11) *%decayed to <2 x i8 addrspace(13)*> addrspace(11)*
; Ordinary vector load of both pointer lanes; only lane 0 is used afterwards.
%arrayptrs = load <2 x i8 addrspace(13)*>, <2 x i8 addrspace(13)*> addrspace(11)* %arrayptrptr, align 16
%arrayptr = extractelement <2 x i8 addrspace(13)*> %arrayptrs, i32 0
; %arrayptr is used after this call, so a value derived from %obj1 is live
; across it.
call void @jl_safepoint()
%val = load i8, i8 addrspace(13)* %arrayptr
ret i8 %val
}

; Masked-load intrinsic used by @masked_arrayptrs. Operands, in order: pointer
; to the vector, i32 alignment, <2 x i1> mask, passthru vector.
declare <2 x i8 addrspace(13)*> @llvm.masked.load.v2p13i8.p11v2p13i8 (<2 x i8 addrspace(13)*> addrspace(11)*, i32, <2 x i1>, <2 x i8 addrspace(13)*>)

; Test case: same shape as a plain vector load of addrspace(13) pointers out of
; %obj1, but the load is done through @llvm.masked.load with an undef passthru.
; The FileCheck directives inside the function expect:
;   - a 3-slot %gcframe,
;   - the masked-load call to appear with its passthru operand rewritten from
;     undef to zeroinitializer,
;   - a store of %obj1 into slot index 2 of the frame.
define i8 @masked_arrayptrs() {
; CHECK-LABEL: @masked_arrayptrs
; CHECK: %gcframe = alloca %jl_value_t addrspace(10)*, i32 3
; CHECK: %arrayptrs = call <2 x i8 addrspace(13)*> @llvm.masked.load.v2p13i8.p11v2p13i8(<2 x i8 addrspace(13)*> addrspace(11)* %arrayptrptr, i32 16, <2 x i1> <i1 true, i1 false>, <2 x i8 addrspace(13)*> zeroinitializer)
; CHECK: [[GEP0:%.*]] = getelementptr %jl_value_t addrspace(10)*, %jl_value_t addrspace(10)** %gcframe, i32 2
; CHECK: store %jl_value_t addrspace(10)* %obj1, %jl_value_t addrspace(10)** [[GEP0]]
;
top:
%ptls = call %jl_value_t*** @julia.ptls_states()
; %obj1 is the value the directives above expect to see stored into the frame.
%obj1 = call %jl_value_t addrspace(10) *@alloc()
; Re-type %obj1: addrspace(10) -> addrspace(11), then view it as a pointer to a
; <2 x i8 addrspace(13)*> vector.
%decayed = addrspacecast %jl_value_t addrspace(10) *%obj1 to %jl_value_t addrspace(11) *
%arrayptrptr = bitcast %jl_value_t addrspace(11) *%decayed to <2 x i8 addrspace(13)*> addrspace(11)*
; Mask is <true, false>: lane 0 is enabled, lane 1 is disabled. The passthru
; operand is undef in the input.
%arrayptrs = call <2 x i8 addrspace(13)*> @llvm.masked.load.v2p13i8.p11v2p13i8(<2 x i8 addrspace(13)*> addrspace(11)* %arrayptrptr, i32 16, <2 x i1> <i1 true, i1 false>, <2 x i8 addrspace(13)*> undef)
; Lane 0 is the lane enabled by the mask.
%arrayptr = extractelement <2 x i8 addrspace(13)*> %arrayptrs, i32 0
; %arrayptr is used after this call, so a value derived from %obj1 is live
; across it.
call void @jl_safepoint()
%val = load i8, i8 addrspace(13)* %arrayptr
ret i8 %val
}

; Masked-gather intrinsic used by @gather_arrayptrs. Operands, in order: vector
; of pointers to gather from, i32 alignment, <2 x i1> mask, passthru vector.
declare <2 x i8 addrspace(13)*> @llvm.masked.gather.v2p13i8.v2p11p13i8 (<2 x i8 addrspace(13)* addrspace(11)*>, i32, <2 x i1>, <2 x i8 addrspace(13)*>)

; Test case: addrspace(13) pointers are fetched with @llvm.masked.gather from a
; vector of addresses whose lane 0 is derived from %obj1, using an undef
; passthru.
; The FileCheck directives inside the function expect:
;   - a 3-slot %gcframe,
;   - the gather call to appear with its passthru operand rewritten from undef
;     to zeroinitializer,
;   - a store of %obj1 into slot index 2 of the frame.
define i8 @gather_arrayptrs() {
; CHECK-LABEL: @gather_arrayptrs
; CHECK: %gcframe = alloca %jl_value_t addrspace(10)*, i32 3
; CHECK: %arrayptrs = call <2 x i8 addrspace(13)*> @llvm.masked.gather.v2p13i8.v2p11p13i8(<2 x i8 addrspace(13)* addrspace(11)*> %arrayptrptrs, i32 16, <2 x i1> <i1 true, i1 false>, <2 x i8 addrspace(13)*> zeroinitializer)
; CHECK: [[GEP0:%.*]] = getelementptr %jl_value_t addrspace(10)*, %jl_value_t addrspace(10)** %gcframe, i32 2
; CHECK: store %jl_value_t addrspace(10)* %obj1, %jl_value_t addrspace(10)** [[GEP0]]
;
top:
%ptls = call %jl_value_t*** @julia.ptls_states()
; %obj1 is the value the directives above expect to see stored into the frame.
%obj1 = call %jl_value_t addrspace(10) *@alloc()
; Re-type %obj1: addrspace(10) -> addrspace(11), then view it as a pointer to a
; single i8 addrspace(13)* slot.
%decayed = addrspacecast %jl_value_t addrspace(10) *%obj1 to %jl_value_t addrspace(11)*
%arrayptrptr = bitcast %jl_value_t addrspace(11) *%decayed to i8 addrspace(13)* addrspace(11)*
; Build the address vector: lane 0 is %arrayptrptr, lane 1 keeps the
; zeroinitializer (null) value.
%arrayptrptrs = insertelement <2 x i8 addrspace(13)* addrspace(11)*> zeroinitializer, i8 addrspace(13)* addrspace(11)* %arrayptrptr, i32 0
; Mask is <true, false>: lane 0 is enabled, lane 1 (the null address) is
; disabled. The passthru operand is undef in the input.
%arrayptrs = call <2 x i8 addrspace(13)*> @llvm.masked.gather.v2p13i8.v2p11p13i8(<2 x i8 addrspace(13)* addrspace(11)*> %arrayptrptrs, i32 16, <2 x i1> <i1 true, i1 false>, <2 x i8 addrspace(13)*> undef)
; Lane 0 is the lane enabled by the mask.
%arrayptr = extractelement <2 x i8 addrspace(13)*> %arrayptrs, i32 0
; %arrayptr is used after this call, so a value derived from %obj1 is live
; across it.
call void @jl_safepoint()
%val = load i8, i8 addrspace(13)* %arrayptr
ret i8 %val
}

!0 = !{!"jtbaa"}
!1 = !{!"jtbaa_const", !0, i64 0}
!2 = !{!1, !1, i64 0, i64 1}