diff --git a/lifter/OperandUtils.cpp b/lifter/OperandUtils.cpp index 6cfa1b3..09c8425 100644 --- a/lifter/OperandUtils.cpp +++ b/lifter/OperandUtils.cpp @@ -78,25 +78,11 @@ SimplifyQuery lifterClass::createSimplifyQuery(Instruction* Inst) { // updateDomTree(*fnc); // auto DT = getDomTree(); auto DL = fnc->getParent()->getDataLayout(); - static TargetLibraryInfoImpl TLIImpl( + static llvm::TargetLibraryInfoImpl TLIImpl( Triple(fnc->getParent()->getTargetTriple())); - static TargetLibraryInfo TLI(TLIImpl); - if (BIlist.size() != BIlistsize) { - BIlistsize = BIlist.size(); - DC = new DomConditionCache(); - - for (auto BI : BIlist) { - - DC->registerBranch(BI); - SmallVector Affected; - findAffectedValues(BI->getCondition(), Affected); - for (auto affectedvalues : Affected) { - printvalue(affectedvalues); - } - } - } + static llvm::TargetLibraryInfo TLI(TLIImpl); - SimplifyQuery SQ(DL, &TLI, DT, nullptr, Inst, true, true, DC); + SimplifyQuery SQ(DL, &TLI, DT, nullptr, Inst, true, true, nullptr); return SQ; } @@ -197,9 +183,6 @@ Value* lifterClass::doPatternMatching(Instruction::BinaryOps const I, // if a is 0, select B // if a is -1, select C // then... ? - printvalue(A); - printvalue(B); - printvalue(C); if (auto X_inst = dyn_cast(A)) { auto possible_condition = analyzeValueKnownBits(X_inst, X_inst); @@ -224,7 +207,6 @@ Value* lifterClass::doPatternMatching(Instruction::BinaryOps const I, if (isXAndNotX(op0, op1, X) || isXAndNotX(op1, op0, X)) { auto possibleSimplifyand = ConstantInt::get(op1->getType(), 0); - printvalue(possibleSimplifyand); return possibleSimplifyand; } // ~X & ~X @@ -246,13 +228,11 @@ Value* lifterClass::doPatternMatching(Instruction::BinaryOps const I, if (isXorNotX(op0, op1, X) || isXorNotX(op1, op0, X)) { auto possibleSimplify = ConstantInt::get(op1->getType(), -1); - printvalue(possibleSimplify); return possibleSimplify; } if (match(op0, m_Specific(op1))) { auto possibleSimplify = ConstantInt::get(op1->getType(), 0); - printvalue(possibleSimplify); return possibleSimplify; } @@ -278,7 +258,6 @@ Value* lifterClass::doPatternMatching(Instruction::BinaryOps const I, auto handleNotAOrB = [&](Value* A, Value* B) -> Value* { if (match(A, m_Not(m_Value(C))) && match(B, m_Constant(constant_v))) { // ~(~a | b) -> a & ~b - printvalue(C); return createAndNot(C, constant_v, "not-PConst-"); } return nullptr; @@ -287,7 +266,6 @@ Value* lifterClass::doPatternMatching(Instruction::BinaryOps const I, auto handleAOrBci = [&](Value* A, Value* B) -> Value* { if (match(A, m_Value(C)) && match(B, m_Constant(constant_v))) { // ~(a | b(ci)) -> ~a & ~b - printvalue(C); return createAndFolder( createXorFolder(C, Constant::getAllOnesValue(C->getType()), @@ -302,8 +280,6 @@ Value* lifterClass::doPatternMatching(Instruction::BinaryOps const I, auto handleNotAOrNotB = [&](Value* A, Value* B) -> Value* { if (match(A, m_Not(m_Value(C))) && match(B, m_Not(m_Value(D)))) { // ~(~a | ~b) -> a & b - printvalue(C); - printvalue(D); return createAndFolder(C, D, "not-P1-"); } return nullptr; @@ -350,6 +326,7 @@ Value* lifterClass::doPatternMatching(Instruction::BinaryOps const I, } KnownBits lifterClass::analyzeValueKnownBits(Value* value, Instruction* ctxI) { + if (auto v_inst = dyn_cast(value)) { // Use find() to check if v_inst exists in the map auto it = assumptions.find(v_inst); @@ -403,29 +380,24 @@ Value* simplifyValue(Value* v, const DataLayout& DL) { return vsimplified; } - if (inst->getOpcode() == Instruction::Add) { - auto testsimp = (simplifyBinOp(inst->getOpcode(), inst->getOperand(0), - inst->getOperand(1), SQ)); - if (testsimp) - printvalue(testsimp); - } return v; } -Value* lifterClass::getOrCreate(const InstructionKey& key, const Twine& Name) { - auto it = cache.find(key); - if (it != cache.end()) { - return it->second; +inline bool isCast(uint8_t opcode) { + return Instruction::Trunc <= opcode && opcode <= Instruction::AddrSpaceCast; +}; + +Value* lifterClass::getOrCreate(const InstructionKey& key, uint8_t opcode, + const Twine& Name) { + auto it = cache.lookup(opcode, key); + if (it) { + return it; } Value* newInstruction = nullptr; - if (key.cast == 0) { - printvalue2(key.opcode); - printvalue2(key.cast); - printvalue(key.operand1); - printvalue(key.operand2); + if (isCast(opcode) == 0) { // Binary instruction if (auto select_inst = dyn_cast(key.operand1)) { printvalue2( @@ -433,9 +405,9 @@ Value* lifterClass::getOrCreate(const InstructionKey& key, const Twine& Name) { if (isa(key.operand2)) return createSelectFolder( select_inst->getCondition(), - builder.CreateBinOp(static_cast(key.opcode), + builder.CreateBinOp(static_cast(opcode), select_inst->getTrueValue(), key.operand2), - builder.CreateBinOp(static_cast(key.opcode), + builder.CreateBinOp(static_cast(opcode), select_inst->getFalseValue(), key.operand2), "lola-"); } @@ -446,9 +418,9 @@ Value* lifterClass::getOrCreate(const InstructionKey& key, const Twine& Name) { if (isa(key.operand1)) return createSelectFolder( select_inst->getCondition(), - builder.CreateBinOp(static_cast(key.opcode), + builder.CreateBinOp(static_cast(opcode), key.operand1, select_inst->getTrueValue()), - builder.CreateBinOp(static_cast(key.opcode), + builder.CreateBinOp(static_cast(opcode), key.operand1, select_inst->getFalseValue()), "lolb-"); } @@ -460,12 +432,10 @@ Value* lifterClass::getOrCreate(const InstructionKey& key, const Twine& Name) { // if inversed return createSelectFolder( select_inst->getCondition(), - builder.CreateBinOp( - static_cast(key.opcode), lhs1, - select_inst->getTrueValue()), - builder.CreateBinOp( - static_cast(key.opcode), rhs1, - select_inst->getFalseValue()), + builder.CreateBinOp(static_cast(opcode), + lhs1, select_inst->getTrueValue()), + builder.CreateBinOp(static_cast(opcode), + rhs1, select_inst->getFalseValue()), "lol2-"); } @@ -477,12 +447,10 @@ Value* lifterClass::getOrCreate(const InstructionKey& key, const Twine& Name) { // if inversed return createSelectFolder( select_inst->getCondition(), - builder.CreateBinOp( - static_cast(key.opcode), lhs1, - select_inst->getTrueValue()), - builder.CreateBinOp( - static_cast(key.opcode), rhs1, - select_inst->getFalseValue()), + builder.CreateBinOp(static_cast(opcode), + lhs1, select_inst->getTrueValue()), + builder.CreateBinOp(static_cast(opcode), + rhs1, select_inst->getFalseValue()), "lol2-"); } @@ -494,12 +462,10 @@ Value* lifterClass::getOrCreate(const InstructionKey& key, const Twine& Name) { // if inversed return createSelectFolder( select_inst->getCondition(), - builder.CreateBinOp( - static_cast(key.opcode), - select_inst->getTrueValue(), lhs), - builder.CreateBinOp( - static_cast(key.opcode), - select_inst->getFalseValue(), rhs), + builder.CreateBinOp(static_cast(opcode), + select_inst->getTrueValue(), lhs), + builder.CreateBinOp(static_cast(opcode), + select_inst->getFalseValue(), rhs), "lol2-"); } else if (match(key.operand2, m_ZExtOrSExtOrSelf( @@ -509,37 +475,34 @@ Value* lifterClass::getOrCreate(const InstructionKey& key, const Twine& Name) { // if inversed return createSelectFolder( select_inst->getCondition(), - builder.CreateBinOp( - static_cast(key.opcode), - select_inst->getTrueValue(), lhs), - builder.CreateBinOp( - static_cast(key.opcode), - select_inst->getFalseValue(), rhs), + builder.CreateBinOp(static_cast(opcode), + select_inst->getTrueValue(), lhs), + builder.CreateBinOp(static_cast(opcode), + select_inst->getFalseValue(), rhs), "lol2-"); } newInstruction = - builder.CreateBinOp(static_cast(key.opcode), + builder.CreateBinOp(static_cast(opcode), key.operand1, key.operand2, Name); - } else if (key.cast) { + } else if (isCast(opcode)) { // Cast instruction - switch (key.opcode) { + switch (opcode) { case Instruction::Trunc: case Instruction::ZExt: case Instruction::SExt: - printvalue(key.operand1); if (auto select_inst = dyn_cast(key.operand1)) { return createSelectFolder( select_inst->getCondition(), - builder.CreateCast(static_cast(key.opcode), + builder.CreateCast(static_cast(opcode), select_inst->getTrueValue(), key.destType), - builder.CreateCast(static_cast(key.opcode), + builder.CreateCast(static_cast(opcode), select_inst->getFalseValue(), key.destType), "lol-"); } newInstruction = - builder.CreateCast(static_cast(key.opcode), + builder.CreateCast(static_cast(opcode), key.operand1, key.destType); break; // Add other cast operations as needed @@ -548,7 +511,7 @@ Value* lifterClass::getOrCreate(const InstructionKey& key, const Twine& Name) { } } - cache[key] = newInstruction; + cache.insert(opcode, key, newInstruction); return newInstruction; } @@ -558,11 +521,11 @@ Value* lifterClass::createInstruction(unsigned opcode, Value* operand1, InstructionKey key; if (destType) - key = InstructionKey(opcode, operand1, destType); + key = InstructionKey(operand1, destType); else - key = InstructionKey(opcode, operand1, operand2); + key = InstructionKey(operand1, operand2); - Value* newValue = getOrCreate(key, Name); + Value* newValue = getOrCreate(key, opcode, Name); return simplifyValue( newValue, @@ -598,10 +561,14 @@ Value* lifterClass::createSelectFolder(Value* C, Value* True, Value* False, return inst; } -KnownBits computeKnownBitsFromOperation(const KnownBits& vv1, - const KnownBits& vv2, +KnownBits computeKnownBitsFromOperation(KnownBits& vv1, KnownBits& vv2, Instruction::BinaryOps opcode) { - + if (vv1.getBitWidth() > vv2.getBitWidth()) { + vv2 = vv2.zext(vv1.getBitWidth()); + } + if (vv2.getBitWidth() > vv1.getBitWidth()) { + vv1 = vv1.zext(vv2.getBitWidth()); + } if (opcode >= Instruction::Shl && opcode <= Instruction::AShr) { auto ugt_result = KnownBits::ugt( vv2, @@ -672,8 +639,7 @@ KnownBits computeKnownBitsFromOperation(const KnownBits& vv1, } default: - outs() << "\n : " << opcode; - outs().flush(); + std::cout << "\n : " << opcode; UNREACHABLE("Unsupported operation in calculatePossibleValues.\n"); break; } @@ -798,8 +764,7 @@ Value* lifterClass::folderBinOps(Value* LHS, Value* RHS, const Twine& Name, if (ConstantInt* RHSConst = dyn_cast(RHS)) { if (RHSConst->isZero()) return LHS; - - if (RHSConst->getZExtValue() > LHS->getType()->getIntegerBitWidth()) { + if (RHSConst->getZExtValue() >= LHS->getType()->getIntegerBitWidth()) { return builder.getIntN(LHS->getType()->getIntegerBitWidth(), 0); } } @@ -869,11 +834,10 @@ Value* lifterClass::folderBinOps(Value* LHS, Value* RHS, const Twine& Name, } } // this part analyses if we can simplify the instruction - if (auto simplifiedByPM = doPatternMatching(opcode, LHS, RHS)) { - return simplifiedByPM; - } - - auto inst = createInstruction(opcode, LHS, RHS, nullptr, Name); + Value* inst; + inst = doPatternMatching(opcode, LHS, RHS); + if (!inst) + inst = createInstruction(opcode, LHS, RHS, nullptr, Name); // knownbits is recursive, and goes back 5 instructions, ideally it would be // not recursive and store the info for all values @@ -883,8 +847,6 @@ Value* lifterClass::folderBinOps(Value* LHS, Value* RHS, const Twine& Name, // road auto LHSKB = analyzeValueKnownBits(LHS, dyn_cast(inst)); auto RHSKB = analyzeValueKnownBits(RHS, dyn_cast(inst)); - printvalue2(LHSKB); - printvalue2(RHSKB); auto computedBits = computeKnownBitsFromOperation(LHSKB, RHSKB, opcode); if (computedBits.isConstant() && !computedBits.hasConflict()) { @@ -1006,8 +968,9 @@ std::optional foldKnownBits(CmpInst::Predicate P, const KnownBits& LHS, return nullopt; } -Value* ICMPPatternMatcher(IRBuilder<>& builder, CmpInst::Predicate P, - Value* LHS, Value* RHS, const Twine& Name) { +Value* ICMPPatternMatcher(IRBuilder& builder, + CmpInst::Predicate P, Value* LHS, Value* RHS, + const Twine& Name) { if (auto SI = dyn_cast(LHS)) { if (P == CmpInst::ICMP_EQ && RHS == SI->getTrueValue()) @@ -1136,7 +1099,6 @@ void lifterClass::Init_Flags() { auto one = ConstantInt::getSigned(Type::getInt1Ty(context), 1); auto two = ConstantInt::getSigned(Type::getInt1Ty(context), 2); - FlagList.resize(FLAGS_END); FlagList[FLAG_CF].set(zero); FlagList[FLAG_PF].set(zero); FlagList[FLAG_AF].set(zero); @@ -1154,8 +1116,9 @@ void lifterClass::Init_Flags() { Value* lifterClass::setFlag(const Flag flag, Value* newValue) { LLVMContext& context = builder.getContext(); newValue = createTruncFolder(newValue, Type::getInt1Ty(context)); - printvalue2((int32_t)flag) printvalue(newValue); - if (flag == FLAG_RESERVED1 || flag == FLAG_RESERVED5 || flag == FLAG_IF) + // printvalue2((int32_t)flag) printvalue(newValue); + if (flag == FLAG_RESERVED1 || flag == FLAG_RESERVED5 || flag == FLAG_IF || + flag == FLAG_DF) return nullptr; FlagList[flag].set(newValue); // Set the new value directly @@ -1165,7 +1128,8 @@ Value* lifterClass::setFlag(const Flag flag, Value* newValue) { void lifterClass::setFlag(const Flag flag, std::function calculation) { // If the flag is one of the reserved ones, do not modify - if (flag == FLAG_RESERVED1 || flag == FLAG_RESERVED5 || flag == FLAG_IF) + if (flag == FLAG_RESERVED1 || flag == FLAG_RESERVED5 || flag == FLAG_IF || + flag == FLAG_DF) return; // lazy calculation @@ -1181,18 +1145,14 @@ Value* lifterClass::getFlag(const Flag flag) { return ConstantInt::getSigned(Type::getInt1Ty(context), 0); } -// destroy these functions below -RegisterManager& lifterClass::getRegisters() { return Registers; } -void lifterClass::setRegisters(RegisterManager newRegisters) { - Registers = newRegisters; -} - +// ?? Value* memoryAlloc; Value* TEB; void initMemoryAlloc(Value* allocArg) { memoryAlloc = allocArg; } Value* getMemory() { return memoryAlloc; } +// ?? -void lifterClass::InitRegisters(Function* function, ZyanU64 rip) { +void lifterClass::InitRegisters(Function* function, const ZyanU64 rip) { // rsp // rsp_unaligned = %rsp % 16 @@ -1275,9 +1235,9 @@ Value* lifterClass::GetRFLAGSValue() { int shiftAmount = flag; Value* shiftedFlagValue = createShlFolder( - createZExtFolder(flagValue, Type::getInt64Ty(context), "createrflag1"), + createZExtFolder(flagValue, Type::getInt64Ty(context), "createrflag1-"), ConstantInt::get(Type::getInt64Ty(context), shiftAmount), - "createrflag2"); + "createrflag2-"); rflags = createOrFolder(rflags, shiftedFlagValue, "creatingrflag"); } return rflags; @@ -1285,6 +1245,11 @@ Value* lifterClass::GetRFLAGSValue() { Value* lifterClass::GetRegisterValue(const ZydisRegister key) { + if (key == ZYDIS_REGISTER_RIP) { + return ConstantInt::getSigned(Type::getInt64Ty(builder.getContext()), + blockInfo.runtime_address); + } + if (key == ZYDIS_REGISTER_AH || key == ZYDIS_REGISTER_CH || key == ZYDIS_REGISTER_DH || key == ZYDIS_REGISTER_BH) { return GetValueFromHighByteRegister(key); @@ -1344,23 +1309,19 @@ Value* lifterClass::SetValueToSubRegister_8b(const ZydisRegister reg, fullRegisterValue = createZExtOrTruncFolder(fullRegisterValue, Type::getInt64Ty(context)); - uint64_t mask = 0xFFFFFFFFFFFFFFFFULL; - if (reg == ZYDIS_REGISTER_AH || reg == ZYDIS_REGISTER_CH || - reg == ZYDIS_REGISTER_DH || reg == ZYDIS_REGISTER_BH) { - mask = 0xFFFFFFFFFFFF00FFULL; - } else { - mask = 0xFFFFFFFFFFFFFF00ULL; - } - - Value* maskValue = ConstantInt::get(Type::getInt64Ty(context), mask); Value* extendedValue = createZExtFolder(value, Type::getInt64Ty(context), "extendedValue"); + bool isHighByteReg = (reg == ZYDIS_REGISTER_AH || reg == ZYDIS_REGISTER_CH || + reg == ZYDIS_REGISTER_DH || reg == ZYDIS_REGISTER_BH); + + uint64_t mask = isHighByteReg ? 0xFFFFFFFFFFFF00FFULL : 0xFFFFFFFFFFFFFF00ULL; + + Value* maskValue = ConstantInt::get(Type::getInt64Ty(context), mask); Value* maskedFullReg = createAndFolder(fullRegisterValue, maskValue, "maskedreg"); - if (reg == ZYDIS_REGISTER_AH || reg == ZYDIS_REGISTER_CH || - reg == ZYDIS_REGISTER_DH || reg == ZYDIS_REGISTER_BH) { + if (isHighByteReg) { extendedValue = createShlFolder(extendedValue, 8, "shiftedValue"); } @@ -1469,7 +1430,8 @@ Value* lifterClass::GetEffectiveAddress(const ZydisDecodedOperand& op, Type::getIntNTy(context, possiblesize)); } -Value* ConvertIntToPTR(IRBuilder<>& builder, Value* effectiveAddress) { +Value* ConvertIntToPTR(IRBuilder& builder, + Value* effectiveAddress) { LLVMContext& context = builder.getContext(); std::vector indices; diff --git a/lifter/lifter.cpp b/lifter/lifter.cpp index 4d29362..b9af1ce 100644 --- a/lifter/lifter.cpp +++ b/lifter/lifter.cpp @@ -127,7 +127,10 @@ void InitFunction_and_LiftInstructions(const ZyanU64 runtime_address, function_name.c_str(), lifting_module); const string block_name = "entry"; auto bb = llvm::BasicBlock::Create(context, block_name.c_str(), function); - llvm::IRBuilder<> builder = llvm::IRBuilder<>(bb); + + InstSimplifyFolder Folder(lifting_module.getDataLayout()); + llvm::IRBuilder builder = + llvm::IRBuilder(bb, Folder); // auto RegisterList = InitRegisters(builder, function, runtime_address);