Skip to content

Commit

Permalink
squash! handle masked_gather in late-gc-lowering
Browse files Browse the repository at this point in the history
[GCLowering] handle vectorized loads

Vectorized loads can come in 3 variants:
- scalar
- masked
- gather

All work about the same--if we run into a Loaded pointer, we need to
follow back the base pointer and track it. And otherwise they're just
normal def/use statements. But we also need to fix up after an LLVM
mistake and put a valid value in the pass-through slot.
  • Loading branch information
vtjnash committed Feb 5, 2020
1 parent 87821c9 commit 05a743a
Show file tree
Hide file tree
Showing 2 changed files with 100 additions and 15 deletions.
54 changes: 39 additions & 15 deletions src/llvm-late-gc-lowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -474,15 +474,15 @@ static std::pair<Value*,int> FindBaseValue(const State &S, Value *V, bool UseCac
CurrentV = EEI->getVectorOperand();
}
else if (auto LI = dyn_cast<LoadInst>(CurrentV)) {
if (auto PtrT = dyn_cast<PointerType>(LI->getType())) {
if (auto PtrT = dyn_cast<PointerType>(LI->getType()->getScalarType())) {
if (PtrT->getAddressSpace() == AddressSpace::Loaded) {
CurrentV = LI->getPointerOperand();
fld_idx = -1;
if (!isSpecialPtr(CurrentV->getType())) {
// Special case to bypass the check below.
// This could really be anything, but it's not loaded
// from a tracked pointer, so it doesn't matter what
// it is.
return std::make_pair(CurrentV, fld_idx);
// it is--just pick something simple.
CurrentV = ConstantPointerNull::get(Type::getInt8PtrTy(V->getContext()));
}
continue;
}
Expand All @@ -491,18 +491,25 @@ static std::pair<Value*,int> FindBaseValue(const State &S, Value *V, bool UseCac
break;
}
else if (auto II = dyn_cast<IntrinsicInst>(CurrentV)) {
// Some intrinsics behave like LoadInst
// Some intrinsics behave like LoadInst followed by a SelectInst
// This should never happen in a derived addrspace (since those cannot be stored to memory)
// so we don't need to lift these operations, but we do need to check if it's loaded and continue walking the base pointer
if (II->getIntrinsicID() == Intrinsic::masked_load ||
II->getIntrinsicID() == Intrinsic::masked_gather) {
if (auto PtrT = dyn_cast<PointerType>(II->getType())) {
if (auto PtrT = dyn_cast<PointerType>(II->getType()->getVectorElementType())) {
if (PtrT->getAddressSpace() == AddressSpace::Loaded) {
assert(isa<UndefValue>(II->getOperand(3)) && "unimplemented");
CurrentV = II->getOperand(0);
if (!isSpecialPtr(CurrentV->getType())) {
// Special case to bypass the check below.
// This could really be anything, but it's not loaded
// from a tracked pointer, so it doesn't matter what
// it is.
return std::make_pair(CurrentV, fld_idx);
if (II->getIntrinsicID() == Intrinsic::masked_load) {
fld_idx = -1;
if (!isSpecialPtr(CurrentV->getType())) {
CurrentV = ConstantPointerNull::get(Type::getInt8PtrTy(V->getContext()));
}
} else {
if (!isSpecialPtr(CurrentV->getType()->getVectorElementType())) {
CurrentV = ConstantPointerNull::get(Type::getInt8PtrTy(V->getContext()));
fld_idx = -1;
}
}
continue;
}
Expand Down Expand Up @@ -1317,6 +1324,23 @@ State LateLowerGCFrame::LocalScan(Function &F) {
II->getIntrinsicID() == Intrinsic::lifetime_end) {
continue;
}
if (II->getIntrinsicID() == Intrinsic::masked_load ||
II->getIntrinsicID() == Intrinsic::masked_gather) {
if (auto PtrT = dyn_cast<PointerType>(II->getType()->getVectorElementType())) {
if (isSpecialPtr(PtrT)) {
// LLVM sometimes tries to materialize these operations with undefined pointers in our non-integral address space.
// Hopefully LLVM didn't already propagate that information and poison our users. Set those to NULL now.
Value *passthru = II->getArgOperand(3);
if (isa<UndefValue>(passthru)) {
II->setArgOperand(3, Constant::getNullValue(passthru->getType()));
}
if (PtrT->getAddressSpace() == AddressSpace::Loaded) {
// These are not real defs
continue;
}
}
}
}
}
auto callee = CI->getCalledFunction();
if (callee && callee == typeof_func) {
Expand Down Expand Up @@ -1409,10 +1433,11 @@ State LateLowerGCFrame::LocalScan(Function &F) {
// of this object to uses of the object we're loading
// from.
SmallVector<int, 1> RefinedPtr{};
Type *Ty = LI->getType()->getScalarType();
if (isLoadFromImmut(LI) && isSpecialPtr(LI->getPointerOperand()->getType())) {
RefinedPtr.push_back(Number(S, LI->getPointerOperand()));
} else if (LI->getType()->isPointerTy() &&
isSpecialPtr(LI->getType()) &&
isSpecialPtr(Ty) &&
LooksLikeFrameRef(LI->getPointerOperand())) {
// Loads from a jlcall argument array
RefinedPtr.push_back(-1);
Expand All @@ -1422,8 +1447,7 @@ State LateLowerGCFrame::LocalScan(Function &F) {
// we know that the object is a constant as well and doesn't need rooting.
RefinedPtr.push_back(-2);
}
if (!LI->getType()->isPointerTy() ||
LI->getType()->getPointerAddressSpace() != AddressSpace::Loaded) {
if (!Ty->isPointerTy() || Ty->getPointerAddressSpace() != AddressSpace::Loaded) {
MaybeNoteDef(S, BBS, LI, BBS.Safepoints, std::move(RefinedPtr));
}
NoteOperandUses(S, BBS, I);
Expand Down
61 changes: 61 additions & 0 deletions test/llvmpasses/gcroots.ll
Original file line number Diff line number Diff line change
Expand Up @@ -642,6 +642,67 @@ top:
ret i8 %add
}

define i8 @vector_arrayptrs() {
; CHECK-LABEL: @vector_arrayptrs
; CHECK: %gcframe = alloca %jl_value_t addrspace(10)*, i32 3
; CHECK: [[GEP0:%.*]] = getelementptr %jl_value_t addrspace(10)*, %jl_value_t addrspace(10)** %gcframe, i32 2
; CHECK: store %jl_value_t addrspace(10)* %obj1, %jl_value_t addrspace(10)** [[GEP0]]
;
top:
%ptls = call %jl_value_t*** @julia.ptls_states()
%obj1 = call %jl_value_t addrspace(10) *@alloc()
%decayed = addrspacecast %jl_value_t addrspace(10) *%obj1 to %jl_value_t addrspace(11) *
%arrayptrptr = bitcast %jl_value_t addrspace(11) *%decayed to <2 x i8 addrspace(13)*> addrspace(11)*
%arrayptrs = load <2 x i8 addrspace(13)*>, <2 x i8 addrspace(13)*> addrspace(11)* %arrayptrptr, align 16
%arrayptr = extractelement <2 x i8 addrspace(13)*> %arrayptrs, i32 0
call void @jl_safepoint()
%val = load i8, i8 addrspace(13)* %arrayptr
ret i8 %val
}

declare <2 x i8 addrspace(13)*> @llvm.masked.load.v2p13i8.p11v2p13i8 (<2 x i8 addrspace(13)*> addrspace(11)*, i32, <2 x i1>, <2 x i8 addrspace(13)*>)

define i8 @masked_arrayptrs() {
; CHECK-LABEL: @masked_arrayptrs
; CHECK: %gcframe = alloca %jl_value_t addrspace(10)*, i32 3
; CHECK: %arrayptrs = call <2 x i8 addrspace(13)*> @llvm.masked.load.v2p13i8.p11v2p13i8(<2 x i8 addrspace(13)*> addrspace(11)* %arrayptrptr, i32 16, <2 x i1> <i1 true, i1 false>, <2 x i8 addrspace(13)*> zeroinitializer)
; CHECK: [[GEP0:%.*]] = getelementptr %jl_value_t addrspace(10)*, %jl_value_t addrspace(10)** %gcframe, i32 2
; CHECK: store %jl_value_t addrspace(10)* %obj1, %jl_value_t addrspace(10)** [[GEP0]]
;
top:
%ptls = call %jl_value_t*** @julia.ptls_states()
%obj1 = call %jl_value_t addrspace(10) *@alloc()
%decayed = addrspacecast %jl_value_t addrspace(10) *%obj1 to %jl_value_t addrspace(11) *
%arrayptrptr = bitcast %jl_value_t addrspace(11) *%decayed to <2 x i8 addrspace(13)*> addrspace(11)*
%arrayptrs = call <2 x i8 addrspace(13)*> @llvm.masked.load.v2p13i8.p11v2p13i8(<2 x i8 addrspace(13)*> addrspace(11)* %arrayptrptr, i32 16, <2 x i1> <i1 true, i1 false>, <2 x i8 addrspace(13)*> undef)
%arrayptr = extractelement <2 x i8 addrspace(13)*> %arrayptrs, i32 0
call void @jl_safepoint()
%val = load i8, i8 addrspace(13)* %arrayptr
ret i8 %val
}

declare <2 x i8 addrspace(13)*> @llvm.masked.gather.v2p13i8.v2p11p13i8 (<2 x i8 addrspace(13)* addrspace(11)*>, i32, <2 x i1>, <2 x i8 addrspace(13)*>)

define i8 @gather_arrayptrs() {
; CHECK-LABEL: @gather_arrayptrs
; CHECK: %gcframe = alloca %jl_value_t addrspace(10)*, i32 3
; CHECK: %arrayptrs = call <2 x i8 addrspace(13)*> @llvm.masked.gather.v2p13i8.v2p11p13i8(<2 x i8 addrspace(13)* addrspace(11)*> %arrayptrptrs, i32 16, <2 x i1> <i1 true, i1 false>, <2 x i8 addrspace(13)*> zeroinitializer)
; CHECK: [[GEP0:%.*]] = getelementptr %jl_value_t addrspace(10)*, %jl_value_t addrspace(10)** %gcframe, i32 2
; CHECK: store %jl_value_t addrspace(10)* %obj1, %jl_value_t addrspace(10)** [[GEP0]]
;
top:
%ptls = call %jl_value_t*** @julia.ptls_states()
%obj1 = call %jl_value_t addrspace(10) *@alloc()
%decayed = addrspacecast %jl_value_t addrspace(10) *%obj1 to %jl_value_t addrspace(11)*
%arrayptrptr = bitcast %jl_value_t addrspace(11) *%decayed to i8 addrspace(13)* addrspace(11)*
%arrayptrptrs = insertelement <2 x i8 addrspace(13)* addrspace(11)*> zeroinitializer, i8 addrspace(13)* addrspace(11)* %arrayptrptr, i32 0
%arrayptrs = call <2 x i8 addrspace(13)*> @llvm.masked.gather.v2p13i8.v2p11p13i8(<2 x i8 addrspace(13)* addrspace(11)*> %arrayptrptrs, i32 16, <2 x i1> <i1 true, i1 false>, <2 x i8 addrspace(13)*> undef)
%arrayptr = extractelement <2 x i8 addrspace(13)*> %arrayptrs, i32 0
call void @jl_safepoint()
%val = load i8, i8 addrspace(13)* %arrayptr
ret i8 %val
}

!0 = !{!"jtbaa"}
!1 = !{!"jtbaa_const", !0, i64 0}
!2 = !{!1, !1, i64 0, i64 1}

0 comments on commit 05a743a

Please sign in to comment.