diff --git a/enzyme/Enzyme/AdjointGenerator.h b/enzyme/Enzyme/AdjointGenerator.h index f6e5bcc529d3..3ce4f6165215 100644 --- a/enzyme/Enzyme/AdjointGenerator.h +++ b/enzyme/Enzyme/AdjointGenerator.h @@ -2747,6 +2747,47 @@ class AdjointGenerator auto &DL = gutils->newFunc->getParent()->getDataLayout(); auto vd = TR.query(MS.getOperand(0)).Data0().ShiftIndices(DL, 0, size, 0); + if (!vd.isKnownPastPointer()) { + // If unknown type results, and zeroing known undef allocation, consider + // integers + if (auto CI = dyn_cast(MS.getOperand(1))) + if (CI->isZero()) { + auto root = getBaseObject(MS.getOperand(0)); + bool writtenTo = false; + if (isa(root) || isAllocationCall(root, gutils->TLI)) { + Instruction *cur = MS.getPrevNode(); + while (cur) { + if (cur == root) + break; + if (auto MCI = dyn_cast(MS.getOperand(2))) { + if (auto II = dyn_cast(cur)) { + // If the start of the lifetime for more memory than being + // memset, its valid. + if (II->getIntrinsicID() == Intrinsic::lifetime_start) { + if (getBaseObject(II->getOperand(1)) == root) { + if (auto CI2 = dyn_cast(II->getOperand(0))) { + if (MCI->getValue().ult(CI2->getValue())) + break; + } + } + } + } + } + if (cur->mayWriteToMemory()) { + writtenTo = true; + break; + } + cur = cur->getPrevNode(); + } + + if (!writtenTo) { + vd = TypeTree(BaseType::Pointer); + vd.insert({-1}, BaseType::Integer); + } + } + } + } + if (!vd.isKnownPastPointer()) { // If unknown type results, consider the intersection of all incoming. if (isa(MS.getOperand(0)) || isa(MS.getOperand(0))) {