Skip to content

Commit

Permalink
core: Add common signal dispatch system and use for on-demand TCB pat…
Browse files Browse the repository at this point in the history
…ches.
  • Loading branch information
squidbus committed Sep 10, 2024
1 parent c08c528 commit fca438b
Show file tree
Hide file tree
Showing 5 changed files with 292 additions and 269 deletions.
2 changes: 2 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -426,6 +426,8 @@ set(CORE src/core/aerolib/stubs.cpp
src/core/module.cpp
src/core/module.h
src/core/platform.h
src/core/signals.cpp
src/core/signals.h
src/core/tls.cpp
src/core/tls.h
src/core/virtual_memory.cpp
Expand Down
130 changes: 46 additions & 84 deletions src/core/cpu_patches.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,13 @@
#include <map>
#include <memory>
#include <mutex>
#include <set>
#include <Zydis/Zydis.h>
#include <xbyak/xbyak.h>
#include "common/alignment.h"
#include "common/assert.h"
#include "common/types.h"
#include "core/signals.h"
#include "core/tls.h"
#include "cpu_patches.h"

Expand Down Expand Up @@ -537,7 +539,7 @@ static bool FilterRosetta2Only(const ZydisDecodedOperand*) {
return ret;
}

#endif // __APPLE__
#else // __APPLE__

static bool FilterTcbAccess(const ZydisDecodedOperand* operands) {
const auto& dst_op = operands[0];
Expand Down Expand Up @@ -583,6 +585,8 @@ static void GenerateTcbAccess(const ZydisDecodedOperand* operands, Xbyak::CodeGe
#endif
}

#endif // __APPLE__

using PatchFilter = bool (*)(const ZydisDecodedOperand*);
using InstructionGenerator = void (*)(const ZydisDecodedOperand*, Xbyak::CodeGenerator&);
struct PatchInfo {
Expand All @@ -596,7 +600,14 @@ struct PatchInfo {
bool trampoline;
};

static const std::unordered_map<ZydisMnemonic, PatchInfo> OnDemandPatches = {
static const std::unordered_map<ZydisMnemonic, PatchInfo> Patches = {
#if defined(_WIN32)
// Windows needs a trampoline.
{ZYDIS_MNEMONIC_MOV, {FilterTcbAccess, GenerateTcbAccess, true}},
#elif !defined(__APPLE__)
{ZYDIS_MNEMONIC_MOV, {FilterTcbAccess, GenerateTcbAccess, false}},
#endif

#ifdef __APPLE__
// Patches for instruction sets not supported by Rosetta 2.
// BMI1
Expand All @@ -611,20 +622,8 @@ static const std::unordered_map<ZydisMnemonic, PatchInfo> OnDemandPatches = {
#endif
};

// TODO: Currently only illegal instruction patches are set up to be caught at runtime.
// TODO: These other patches like TCB access should be moved into the same system in the future.
static const std::unordered_map<ZydisMnemonic, PatchInfo> StartupPatches = {
#if defined(_WIN32)
// Windows needs a trampoline.
{ZYDIS_MNEMONIC_MOV, {FilterTcbAccess, GenerateTcbAccess, true}},
#elif !defined(__APPLE__)
{ZYDIS_MNEMONIC_MOV, {FilterTcbAccess, GenerateTcbAccess, false}},
#endif
};

static std::once_flag init_flag;
static ZydisDecoder instr_decoder;
static ZydisFormatter instr_formatter;

struct PatchModule {
/// Mutex controlling access to module code regions.
Expand All @@ -636,6 +635,9 @@ struct PatchModule {
/// End of the module.
u8* end;

/// Tracker for patched code locations.
std::set<u8*> patched;

/// Code generator for patching the module.
Xbyak::CodeGenerator patch_gen;

Expand All @@ -656,32 +658,27 @@ static PatchModule& GetModule(const void* ptr) {
return std::prev(upper_bound)->second;
}

static u64 TryPatch(u8* code, PatchModule& module,
const std::unordered_map<ZydisMnemonic, PatchInfo>& patches,
bool required = false) {
static bool TryPatch(void* code_address) {
auto* code = static_cast<u8*>(code_address);
auto& module = GetModule(code);

std::unique_lock lock{module.mutex};

// Return early if already patched, in case multiple threads signaled at the same time.
if (std::ranges::find(module.patched, code) != module.patched.end()) {
return true;
}

ZydisDecodedInstruction instruction;
ZydisDecodedOperand operands[ZYDIS_MAX_OPERAND_COUNT];
const auto status =
ZydisDecoderDecodeFull(&instr_decoder, code, module.end - code, &instruction, operands);
if (!ZYAN_SUCCESS(status)) {
if (required) {
UNREACHABLE_MSG("Unable to decode instruction at {}", fmt::ptr(code));
}
return 1;
}

// Assume a jmp is an existing patch, in case multiple threads signaled at the same time.
if (instruction.mnemonic == ZYDIS_MNEMONIC_JMP) {
if (required) {
LOG_INFO(Core, "Skipping already patched instruction at {}", fmt::ptr(code));
}
return instruction.length;
return false;
}

if (patches.contains(instruction.mnemonic)) {
const auto& patch_info = patches.at(instruction.mnemonic);
if (Patches.contains(instruction.mnemonic)) {
const auto& patch_info = Patches.at(instruction.mnemonic);
if (patch_info.filter(operands)) {
auto& patch_gen = module.patch_gen;

Expand Down Expand Up @@ -714,59 +711,34 @@ static u64 TryPatch(u8* code, PatchModule& module,
// Fill remaining space with nops.
patch_gen.nop(instruction.length - patch_size);

module.patched.insert(code);
LOG_DEBUG(Core, "Patched instruction '{}' at: {}",
ZydisMnemonicGetString(instruction.mnemonic), fmt::ptr(code));
return instruction.length;
return true;
}
}
}

if (required) {
char buffer[256];
ZydisFormatterFormatInstruction(&instr_formatter, &instruction, operands,
instruction.operand_count_visible, buffer, sizeof(buffer),
reinterpret_cast<u64>(code), ZYAN_NULL);
UNIMPLEMENTED_MSG("Encountered instruction at {} without patch: {}", fmt::ptr(code),
buffer);
}

return instruction.length;
return false;
}

#if defined(_WIN32)
static LONG WINAPI SignalHandler(EXCEPTION_POINTERS* pExp) noexcept {
const u32 ec = pExp->ExceptionRecord->ExceptionCode;
if (ec == EXCEPTION_ILLEGAL_INSTRUCTION) {
auto* code = reinterpret_cast<u8*>(pExp->ExceptionRecord->ExceptionAddress);
auto& module = GetModule(code);
TryPatch(code, module, OnDemandPatches, true);
return EXCEPTION_CONTINUE_EXECUTION;
}
return EXCEPTION_CONTINUE_SEARCH;
static bool PatchesAccessViolationHandler(void* code_address, void* fault_address, bool is_write) {
return TryPatch(code_address);
}
#else
static void SignalHandler(int sig, siginfo_t* info, void* raw_context) {
auto* code = static_cast<u8*>(info->si_addr);
auto& module = GetModule(code);
TryPatch(code, module, OnDemandPatches, true);

static bool PatchesIllegalInstructionHandler(void* code_address) {
return TryPatch(code_address);
}
#endif

static void PatchesInit() {
ZydisDecoderInit(&instr_decoder, ZYDIS_MACHINE_MODE_LONG_64, ZYDIS_STACK_WIDTH_64);
ZydisFormatterInit(&instr_formatter, ZYDIS_FORMATTER_STYLE_INTEL);

if (!OnDemandPatches.empty()) {
#if defined(_WIN32)
ASSERT_MSG(AddVectoredExceptionHandler(0, SignalHandler),
"Failed to register code patching exception handler.");
#else
constexpr struct sigaction action {
.sa_flags = SA_SIGINFO | SA_ONSTACK, .sa_sigaction = SignalHandler, .sa_mask = 0,
};
ASSERT_MSG(sigaction(SIGILL, &action, nullptr) == 0,
"Failed to register code patching signal handler.");
#endif
if (!Patches.empty()) {
auto* signals = Signals::Instance();
// Should be called last.
constexpr auto priority = std::numeric_limits<u32>::max();
signals->RegisterAccessViolationHandler(PatchesAccessViolationHandler, priority);
signals->RegisterIllegalInstructionHandler(PatchesIllegalInstructionHandler, priority);
}
}

Expand All @@ -782,24 +754,14 @@ void RegisterPatchModule(void* module_ptr, u64 module_size, void* trampoline_are
}

void PrePatchInstructions(u64 segment_addr, u64 segment_size) {
auto& module = GetModule(reinterpret_cast<void*>(segment_addr));

if (!StartupPatches.empty()) {
u8* code = reinterpret_cast<u8*>(segment_addr);
u8* end = code + segment_size;
while (code < end) {
code += TryPatch(code, module, StartupPatches);
}
}

#ifdef __APPLE__
// HACK: For some reason patching in the signal handler at the start of a page does not work
// under Rosetta 2. Patch any instructions at the start of a page ahead of time.
if (!OnDemandPatches.empty()) {
u8* code_page = reinterpret_cast<u8*>(Common::AlignUp(segment_addr, 0x1000));
u8* end_page = code_page + Common::AlignUp(segment_size, 0x1000);
if (!Patches.empty()) {
auto* code_page = reinterpret_cast<u8*>(Common::AlignUp(segment_addr, 0x1000));
const auto* end_page = code_page + Common::AlignUp(segment_size, 0x1000);
while (code_page < end_page) {
TryPatch(code_page, module, OnDemandPatches);
TryPatch(code_page);
code_page += 0x1000;
}
}
Expand Down
Loading

0 comments on commit fca438b

Please sign in to comment.