diff --git a/gematria/datasets/find_accessed_addrs_exegesis.cc b/gematria/datasets/find_accessed_addrs_exegesis.cc index 861be66..cbedc2e 100644 --- a/gematria/datasets/find_accessed_addrs_exegesis.cc +++ b/gematria/datasets/find_accessed_addrs_exegesis.cc @@ -37,6 +37,7 @@ #include "llvm/MC/MCRegisterInfo.h" #include "llvm/MC/TargetRegistry.h" #include "llvm/Support/Error.h" +#include "llvm/Support/ErrorHandling.h" #include "llvm/lib/Target/X86/MCTargetDesc/X86MCTargetDesc.h" #include "llvm/tools/llvm-exegesis/lib/BenchmarkCode.h" #include "llvm/tools/llvm-exegesis/lib/BenchmarkResult.h" @@ -134,21 +135,30 @@ Expected ExegesisAnnotator::findAccessedAddrs( BenchCode.Key.MemoryValues["memdef1"] = MemVal; const llvm::MCRegisterInfo &MRI = State.getRegInfo(); + std::vector UsedRegisters = gematria::getUsedRegisters( + *DisInstructions, State.getRegInfo(), State.getInstrInfo()); - for (unsigned i = 0; - i < MRI.getRegClass(X86::GR64_NOREX2RegClassID).getNumRegs(); ++i) { - RegisterValue RegVal; - RegVal.Register = - MRI.getRegClass(X86::GR64_NOREX2RegClassID).getRegister(i); - RegVal.Value = APInt(64, kInitialRegVal); - BenchCode.Key.RegisterInitialValues.push_back(RegVal); - } + for (unsigned RegisterIndex : UsedRegisters) { + RegisterAndValue *NewRegisterValue = MemAnnotations.add_initial_registers(); + NewRegisterValue->set_register_name( + State.getRegInfo().getName(RegisterIndex)); - for (unsigned i = 0; i < MRI.getRegClass(X86::VR128RegClassID).getNumRegs(); - ++i) { RegisterValue RegVal; - RegVal.Register = MRI.getRegClass(X86::VR128RegClassID).getRegister(i); - RegVal.Value = APInt(128, kInitialRegVal); + RegVal.Register = RegisterIndex; + + if (MRI.getRegClass(X86::GR64_NOREX2RegClassID).contains(RegisterIndex)) { + RegVal.Value = APInt(64, kInitialRegVal); + NewRegisterValue->set_register_value(kInitialRegVal); + } else if (MRI.getRegClass(X86::VR128RegClassID).contains(RegisterIndex)) { + RegVal.Value = APInt(128, kInitialRegVal); + NewRegisterValue->set_register_value(kInitialRegVal); + } else if (RegisterIndex == X86::EFLAGS) { + RegVal.Value = 0; + NewRegisterValue->set_register_value(0); + } else { + report_fatal_error("Found unhandled register case for used register."); + } + BenchCode.Key.RegisterInitialValues.push_back(RegVal); } @@ -211,19 +221,8 @@ Expected ExegesisAnnotator::findAccessedAddrs( MemAnnotations.add_accessed_blocks(Mapping.Address); } - std::vector UsedRegisters = gematria::getUsedRegisters( - *DisInstructions, State.getRegInfo(), State.getInstrInfo()); - MemAnnotations.mutable_initial_registers()->Reserve(UsedRegisters.size()); - for (const unsigned UsedRegister : UsedRegisters) { - RegisterAndValue *new_register_value = - MemAnnotations.add_initial_registers(); - new_register_value->set_register_name( - State.getRegInfo().getName(UsedRegister)); - new_register_value->set_register_value(kInitialRegVal); - } - std::optional LoopRegister = gematria::getUnusedGPRegister( *DisInstructions, State.getRegInfo(), State.getInstrInfo()); diff --git a/gematria/datasets/find_accessed_addrs_exegesis_test.cc b/gematria/datasets/find_accessed_addrs_exegesis_test.cc index b8a97a8..d93d03e 100644 --- a/gematria/datasets/find_accessed_addrs_exegesis_test.cc +++ b/gematria/datasets/find_accessed_addrs_exegesis_test.cc @@ -160,5 +160,16 @@ TEST_F(FindAccessedAddrsExegesisTest, DISABLED_QuitMaxAnnotationAttempts) { ASSERT_FALSE(static_cast(AddrsOrErr)); } +TEST_F(FindAccessedAddrsExegesisTest, MovsqImplictDfUse) { + // Test that we can successfully find the accessed addrs for a movsq + // instruction, which makes things more complicated by explicitly using + // the df register. We do not care about the specific addresses in this + // case. + auto AddrsOrErr = FindAccessedAddrsExegesis(R"asm( + movsq + )asm"); + ASSERT_TRUE(static_cast(AddrsOrErr)); +} + } // namespace } // namespace gematria