diff --git a/riru/src/main/cpp/CMakeLists.txt b/riru/src/main/cpp/CMakeLists.txt index 42cd2f18..ddc9e0ec 100644 --- a/riru/src/main/cpp/CMakeLists.txt +++ b/riru/src/main/cpp/CMakeLists.txt @@ -55,7 +55,8 @@ add_library(utils STATIC util/pmparser.c util/selinux.cpp util/socket.cpp - util/tinynew.cpp) + util/tinynew.cpp + util/plt.c) add_library(riru SHARED main.cpp diff --git a/riru/src/main/cpp/main.cpp b/riru/src/main/cpp/main.cpp index e33bde00..0e19aa5d 100644 --- a/riru/src/main/cpp/main.cpp +++ b/riru/src/main/cpp/main.cpp @@ -1,5 +1,7 @@ #include #include +#include +#include #include "misc.h" #include "jni_native_method.h" #include "logging.h" @@ -14,9 +16,16 @@ static int sdkLevel; static int previewSdkLevel; static char androidVersionName[PROP_VALUE_MAX + 1]; +static bool useTableOverride = false; -static JNINativeMethod *onRegisterZygote( - JNIEnv *env, const char *className, const JNINativeMethod *methods, int numMethods) { +using GetJniNativeInterface_t = const JNINativeInterface *(); +using SetTableOverride_t = void(JNINativeInterface *); +using RegisterNatives_t = jint(JNIEnv *, jclass, const JNINativeMethod *, jint); + +static SetTableOverride_t *setTableOverride = nullptr; +static RegisterNatives_t *old_RegisterNatives = nullptr; + +static JNINativeMethod *onRegisterZygote(const char *className, const JNINativeMethod *methods, int numMethods) { auto *newMethods = new JNINativeMethod[numMethods]; memcpy(newMethods, methods, sizeof(JNINativeMethod) * numMethods); @@ -115,8 +124,7 @@ static JNINativeMethod *onRegisterZygote( return newMethods; } -static JNINativeMethod *onRegisterSystemProperties( - JNIEnv *env, const char *className, const JNINativeMethod *methods, int numMethods) { +static JNINativeMethod *onRegisterSystemProperties(const char *className, const JNINativeMethod *methods, int numMethods) { auto *newMethods = new JNINativeMethod[numMethods]; memcpy(newMethods, methods, sizeof(JNINativeMethod) * numMethods); @@ -144,12 +152,25 @@ static JNINativeMethod *onRegisterSystemProperties( return newMethods; } +static JNINativeMethod *handleRegisterNative(const char *className, const JNINativeMethod *methods, int numMethods) { + if (strcmp("com/android/internal/os/Zygote", className) == 0) { + return onRegisterZygote(className, methods, numMethods); + } else if (strcmp("android/os/SystemProperties", className) == 0) { + // hook android.os.SystemProperties#native_set to prevent a critical problem on Android 9 + // see comment of SystemProperties_set in jni_native_method.cpp for detail + return onRegisterSystemProperties(className, methods, numMethods); + } else { + return nullptr; + } +} + #define XHOOK_REGISTER(PATH_REGEX, NAME) \ if (xhook_register(PATH_REGEX, #NAME, (void*) new_##NAME, (void **) &old_##NAME) != 0) \ LOGE("failed to register hook " #NAME "."); \ #define NEW_FUNC_DEF(ret, func, ...) \ - static ret (*old_##func)(__VA_ARGS__); \ + using func##_t = ret(__VA_ARGS__); \ + static func##_t *old_##func; \ static ret new_##func(__VA_ARGS__) NEW_FUNC_DEF(int, jniRegisterNativeMethods, JNIEnv *env, const char *className, @@ -158,17 +179,8 @@ NEW_FUNC_DEF(int, jniRegisterNativeMethods, JNIEnv *env, const char *className, LOGD("jniRegisterNativeMethods %s", className); - JNINativeMethod *newMethods = nullptr; - if (strcmp("com/android/internal/os/Zygote", className) == 0) { - newMethods = onRegisterZygote(env, className, methods, numMethods); - } else if (strcmp("android/os/SystemProperties", className) == 0) { - // hook android.os.SystemProperties#native_set to prevent a critical problem on Android 9 - // see comment of SystemProperties_set in jni_native_method.cpp for detail - newMethods = onRegisterSystemProperties(env, className, methods, numMethods); - } - - int res = old_jniRegisterNativeMethods(env, className, newMethods ? newMethods : methods, - numMethods); + JNINativeMethod *newMethods = handleRegisterNative(className, methods, numMethods); + int res = old_jniRegisterNativeMethods(env, className, newMethods ? newMethods : methods, numMethods); /*if (!newMethods) { NativeMethod::jniRegisterNativeMethodsPost(env, className, methods, numMethods); }*/ @@ -176,19 +188,65 @@ NEW_FUNC_DEF(int, jniRegisterNativeMethods, JNIEnv *env, const char *className, return res; } -void restore_replaced_func(JNIEnv *env) { - xhook_register(".*\\libandroid_runtime.so$", "jniRegisterNativeMethods", - (void *) old_jniRegisterNativeMethods, - nullptr); - if (xhook_refresh(0) == 0) { - xhook_clear(); - LOGD("hook removed"); +static jclass zygoteClass; +static jclass systemPropertiesClass; + +static void prepareClassesForRegisterNativeHook(JNIEnv *env) { + static bool called = false; + if (called) return; + called = true; + + auto _zygoteClass = env->FindClass("com/android/internal/os/Zygote"); + auto _systemPropertiesClass = env->FindClass("android/os/SystemProperties"); + + // There are checks that enforces no local refs exists during Runtime::Start, make them global ref + zygoteClass = (jclass) env->NewGlobalRef(_zygoteClass); + systemPropertiesClass = (jclass) env->NewGlobalRef(_systemPropertiesClass); + + env->DeleteLocalRef(_zygoteClass); + env->DeleteLocalRef(_systemPropertiesClass); +} + +static int new_RegisterNative(JNIEnv *env, jclass cls, const JNINativeMethod *methods, jint numMethods) { + prepareClassesForRegisterNativeHook(env); + + const char *className; + if (zygoteClass != nullptr && env->IsSameObject(zygoteClass, cls)) { + className = "com/android/internal/os/Zygote"; + LOGD("RegisterNative %s", className); + env->DeleteGlobalRef(zygoteClass); + } else if (systemPropertiesClass != nullptr && env->IsSameObject(systemPropertiesClass, cls)) { + className = "android/os/SystemProperties"; + LOGD("RegisterNative %s", className); + env->DeleteGlobalRef(systemPropertiesClass); + } else { + className = ""; } -#define restoreMethod(cls, method) \ - if (JNI::cls::method != nullptr) { \ - old_jniRegisterNativeMethods(env, JNI::cls::classname, JNI::cls::method, 1); \ - delete JNI::cls::method; \ + JNINativeMethod *newMethods = handleRegisterNative(className, methods, numMethods); + auto res = old_RegisterNatives(env, cls, newMethods ? newMethods : methods, numMethods); + delete newMethods; + return res; +} + +#define restoreMethod(_cls, method) \ + if (JNI::_cls::method != nullptr) { \ + if (old_jniRegisterNativeMethods) \ + old_jniRegisterNativeMethods(env, JNI::_cls::classname, JNI::_cls::method, 1); \ + delete JNI::_cls::method; \ + } + +void restore_replaced_func(JNIEnv *env) { + if (useTableOverride) { + setTableOverride(nullptr); + } else { + xhook_register(".*\\libandroid_runtime.so$", "jniRegisterNativeMethods", + (void *) old_jniRegisterNativeMethods, + nullptr); + if (xhook_refresh(0) == 0) { + xhook_clear(); + LOGD("hook removed"); + } } restoreMethod(Zygote, nativeForkAndSpecialize) @@ -252,6 +310,34 @@ void constructor() { LOGE("failed to refresh hook"); } + useTableOverride = old_jniRegisterNativeMethods == nullptr; + + if (useTableOverride) { + LOGI("no jniRegisterNativeMethods"); + + auto *GetJniNativeInterface = (GetJniNativeInterface_t *) plt_dlsym("_ZN3art21GetJniNativeInterfaceEv", nullptr); + setTableOverride = (SetTableOverride_t *) plt_dlsym("_ZN3art9JNIEnvExt16SetTableOverrideEPK18JNINativeInterface", nullptr); + + if (setTableOverride != nullptr && GetJniNativeInterface != nullptr) { + auto functions = GetJniNativeInterface(); + auto new_JNINativeInterface = new JNINativeInterface(); + memcpy(new_JNINativeInterface, functions, sizeof(JNINativeInterface)); + old_RegisterNatives = functions->RegisterNatives; + new_JNINativeInterface->RegisterNatives = new_RegisterNative; + + setTableOverride(new_JNINativeInterface); + LOGI("override table installed"); + } else { + if (GetJniNativeInterface == nullptr) LOGE("cannot find GetJniNativeInterface"); + if (setTableOverride == nullptr) LOGE("cannot find setTableOverride"); + } + + auto handle = dlopen("libnativehelper.so", 0); + if (handle) { + old_jniRegisterNativeMethods = (jniRegisterNativeMethods_t *) dlsym(handle, "jniRegisterNativeMethods"); + } + } + load_modules(); Status::WriteSelfAndModules(); diff --git a/riru/src/main/cpp/util/plt.c b/riru/src/main/cpp/util/plt.c new file mode 100644 index 00000000..3f48c8fd --- /dev/null +++ b/riru/src/main/cpp/util/plt.c @@ -0,0 +1,377 @@ +#include +#include +#include +#include +#include +#include "plt.h" + +/* + * reference: https://android.googlesource.com/platform/bionic/+/master/linker/linker_soinfo.cpp + */ +static uint32_t gnu_hash(const uint8_t *name) { + uint32_t h = 5381; + + while (*name) { + h += (h << 5) + *name++; + } + return h; +} + +static uint32_t elf_hash(const uint8_t *name) { + uint32_t h = 0, g; + + while (*name) { + h = (h << 4) + *name++; + g = h & 0xf0000000; + h ^= g; + h ^= g >> 24; + } + + return h; +} + +static ElfW(Dyn) *find_dyn_by_tag(ElfW(Dyn) *dyn, ElfW(Sxword) tag) { +while (dyn->d_tag != DT_NULL) { +if (dyn->d_tag == tag) { +return dyn; +} +++dyn; +} +return NULL; +} + +static inline bool is_global(ElfW(Sym) *sym) { + unsigned char stb = ELF_ST_BIND(sym->st_info); + if (stb == STB_GLOBAL || stb == STB_WEAK) { + return sym->st_shndx != SHN_UNDEF; + } else { + return false; + } +} + +static ElfW(Addr) * +find_symbol(struct dl_phdr_info *info, ElfW(Dyn) *base_addr, const char *symbol) { +ElfW(Dyn) *dyn; + +dyn = find_dyn_by_tag(base_addr, DT_SYMTAB); +ElfW(Sym) *dynsym = (ElfW(Sym) *) (info->dlpi_addr + dyn->d_un.d_ptr); + +dyn = find_dyn_by_tag(base_addr, DT_STRTAB); +char *dynstr = (char *) (info->dlpi_addr + dyn->d_un.d_ptr); + +dyn = find_dyn_by_tag(base_addr, DT_GNU_HASH); +if (dyn != NULL) { +ElfW(Word) *dt_gnu_hash = (ElfW(Word) *) (info->dlpi_addr + dyn->d_un.d_ptr); +size_t gnu_nbucket_ = dt_gnu_hash[0]; +uint32_t gnu_maskwords_ = dt_gnu_hash[2]; +uint32_t gnu_shift2_ = dt_gnu_hash[3]; +ElfW(Addr) *gnu_bloom_filter_ = (ElfW(Addr) *) (dt_gnu_hash + 4); +uint32_t *gnu_bucket_ = (uint32_t *) (gnu_bloom_filter_ + gnu_maskwords_); +uint32_t *gnu_chain_ = gnu_bucket_ + gnu_nbucket_ - dt_gnu_hash[1]; + +--gnu_maskwords_; + +uint32_t hash = gnu_hash((uint8_t *) symbol); +uint32_t h2 = hash >> gnu_shift2_; + +uint32_t bloom_mask_bits = sizeof(ElfW(Addr)) * 8; +uint32_t word_num = (hash / bloom_mask_bits) & gnu_maskwords_; +ElfW(Addr) bloom_word = gnu_bloom_filter_[word_num]; + +if ((1 & (bloom_word >> (hash % bloom_mask_bits)) & +(bloom_word >> (h2 % bloom_mask_bits))) == 0) { +return NULL; +} + +uint32_t n = gnu_bucket_[hash % gnu_nbucket_]; + +if (n == 0) { +return NULL; +} + +do { +ElfW(Sym) *sym = dynsym + n; +if (((gnu_chain_[n] ^ hash) >> 1) == 0 +&& is_global(sym) +&& strcmp(dynstr + sym->st_name, symbol) == 0) { +ElfW(Addr) *symbol_sym = (ElfW(Addr) *) (info->dlpi_addr + sym->st_value); +#ifdef DEBUG_PLT +LOGI("found %s(gnu+%u) in %s, %p", symbol, n, info->dlpi_name, symbol_sym); +#endif +return symbol_sym; +} +} while ((gnu_chain_[n++] & 1) == 0); + +return NULL; +} + +dyn = find_dyn_by_tag(base_addr, DT_HASH); +if (dyn != NULL) { +ElfW(Word) *dt_hash = (ElfW(Word) *) (info->dlpi_addr + dyn->d_un.d_ptr); +size_t nbucket_ = dt_hash[0]; +uint32_t *bucket_ = dt_hash + 2; +uint32_t *chain_ = bucket_ + nbucket_; + +uint32_t hash = elf_hash((uint8_t *) (symbol)); +for (uint32_t n = bucket_[hash % nbucket_]; n != 0; n = chain_[n]) { +ElfW(Sym) *sym = dynsym + n; +if (is_global(sym) && +strcmp(dynstr + sym->st_name, symbol) == 0) { +ElfW(Addr) *symbol_sym = (ElfW(Addr) *) (info->dlpi_addr + sym->st_value); +#ifdef DEBUG_PLT +LOGI("found %s(elf+%u) in %s, %p", symbol, n, info->dlpi_name, symbol_sym); +#endif +return symbol_sym; +} +} + +return NULL; +} + +return NULL; +} + +#if defined(__LP64__) +#define Elf_Rela ElfW(Rela) +#define ELF_R_SYM ELF64_R_SYM +#else +#define Elf_Rela ElfW(Rel) +#define ELF_R_SYM ELF32_R_SYM +#endif + +#ifdef DEBUG_PLT +#if defined(__x86_64__) +#define R_JUMP_SLOT R_X86_64_JUMP_SLOT +#define ELF_R_TYPE ELF64_R_TYPE +#elif defined(__i386__) +#define R_JUMP_SLOT R_386_JMP_SLOT +#define ELF_R_TYPE ELF32_R_TYPE +#elif defined(__arm__) +#define R_JUMP_SLOT R_ARM_JUMP_SLOT +#define ELF_R_TYPE ELF32_R_TYPE +#elif defined(__aarch64__) +#define R_JUMP_SLOT R_AARCH64_JUMP_SLOT +#define ELF_R_TYPE ELF64_R_TYPE +#else +#error unsupported OS +#endif +#endif + +static ElfW(Addr) *find_plt(struct dl_phdr_info *info, ElfW(Dyn) *base_addr, const char *symbol) { +ElfW(Dyn) *dyn = find_dyn_by_tag(base_addr, DT_JMPREL); +if (dyn == NULL) { +return NULL; +} +Elf_Rela *dynplt = (Elf_Rela *) (info->dlpi_addr + dyn->d_un.d_ptr); + +dyn = find_dyn_by_tag(base_addr, DT_SYMTAB); +ElfW(Sym) *dynsym = (ElfW(Sym) *) (info->dlpi_addr + dyn->d_un.d_ptr); + +dyn = find_dyn_by_tag(base_addr, DT_STRTAB); +char *dynstr = (char *) (info->dlpi_addr + dyn->d_un.d_ptr); + +dyn = find_dyn_by_tag(base_addr, DT_PLTRELSZ); +if (dyn == NULL) { +return NULL; +} +size_t count = dyn->d_un.d_val / sizeof(Elf_Rela); + +for (size_t i = 0; i < count; ++i) { +Elf_Rela *plt = dynplt + i; +#ifdef DEBUG_PLT +if (ELF_R_TYPE(plt->r_info) != R_JUMP_SLOT) { + LOGW("invalid type for plt+%zu in %s", i, info->dlpi_name); + continue; + } +#endif +size_t idx = ELF_R_SYM(plt->r_info); +idx = dynsym[idx].st_name; +if (strcmp(dynstr + idx, symbol) == 0) { +ElfW(Addr) *symbol_plt = (ElfW(Addr) *) (info->dlpi_addr + plt->r_offset); +#ifdef DEBUG_PLT +ElfW(Addr) *symbol_plt_value = (ElfW(Addr) *) *symbol_plt; + LOGI("found %s(plt+%zu) in %s, %p -> %p", symbol, i, info->dlpi_name, symbol_plt, + symbol_plt_value); +#endif +return symbol_plt; +} +} + +return NULL; +} + +static inline bool isso(const char *str) { + if (str == NULL) { + return false; + } + const char *dot = strrchr(str, '.'); + return dot != NULL + && *++dot == 's' + && *++dot == 'o' + && (*++dot == '\0' || *dot == '\r' || *dot == '\n'); +} + +static inline bool isSystem(const char *str) { + return str != NULL + && *str == '/' + && *++str == 's' + && *++str == 'y' + && *++str == 's' + && *++str == 't' + && *++str == 'e' + && *++str == 'm' + && *++str == '/'; +} + +static inline bool isVendor(const char *str) { + return str != NULL + && *str == '/' + && *++str == 'v' + && *++str == 'e' + && *++str == 'n' + && *++str == 'd' + && *++str == 'o' + && *++str == 'r' + && *++str == '/'; +} + +static inline bool isOem(const char *str) { + return str != NULL + && *str == '/' + && *++str == 'o' + && *++str == 'e' + && *++str == 'm' + && *++str == '/'; +} + +static inline bool isThirdParty(const char *str) { + if (isSystem(str) || isVendor(str) || isOem(str)) { + return false; + } else { + return true; + } +} + +static inline bool should_check_plt(Symbol *symbol, struct dl_phdr_info *info) { + const char *path = info->dlpi_name; + if (symbol->check & PLT_CHECK_PLT_ALL) { + return true; + } else if (symbol->check & PLT_CHECK_PLT_APP) { + return *path != '/' || isThirdParty(path); + } else { + return false; + } +} + +static int callback(struct dl_phdr_info *info, __unused size_t size, void *data) { + if (!isso(info->dlpi_name)) { +#ifdef DEBUG_PLT + LOGW("ignore non-so: %s", info->dlpi_name); +#endif + return 0; + } + Symbol *symbol = (Symbol *) data; +#if 0 + LOGI("Name: \"%s\" (%d segments)", info->dlpi_name, info->dlpi_phnum); +#endif + ++symbol->total; + for (ElfW(Half) phdr_idx = 0; phdr_idx < info->dlpi_phnum; ++phdr_idx) { + ElfW(Phdr) phdr = info->dlpi_phdr[phdr_idx]; + if (phdr.p_type != PT_DYNAMIC) { + continue; + } + ElfW(Dyn) *base_addr = (ElfW(Dyn) *) (info->dlpi_addr + phdr.p_vaddr); + ElfW(Addr) *addr; + addr = should_check_plt(symbol, info) ? find_plt(info, base_addr, symbol->symbol_name) : NULL; + if (addr != NULL) { + if (symbol->symbol_plt != NULL) { + ElfW(Addr) *addr_value = (ElfW(Addr) *) *addr; + ElfW(Addr) *symbol_plt_value = (ElfW(Addr) *) *symbol->symbol_plt; + if (addr_value != symbol_plt_value) { +#ifdef DEBUG_PLT + LOGW("%s, plt %p -> %p != %p", symbol->symbol_name, addr, addr_value, + symbol_plt_value); +#endif + return 1; + } + } + symbol->symbol_plt = addr; + if (symbol->check & PLT_CHECK_NAME) { + if (symbol->size == 0) { + symbol->size = 1; + symbol->names = calloc(1, sizeof(char *)); + } else { + ++symbol->size; + symbol->names = realloc(symbol->names, symbol->size * sizeof(char *)); + } +#ifdef DEBUG_PLT + LOGI("[%d]: %s", symbol->size - 1, info->dlpi_name); +#endif + symbol->names[symbol->size - 1] = strdup(info->dlpi_name); + } + } + addr = find_symbol(info, base_addr, symbol->symbol_name); + if (addr != NULL) { + symbol->symbol_sym = addr; + if (symbol->check == PLT_CHECK_SYM_ONE) { + return PLT_CHECK_SYM_ONE; + } + } + if (symbol->symbol_plt != NULL && symbol->symbol_sym != NULL) { + ElfW(Addr) *symbol_plt_value = (ElfW(Addr) *) *symbol->symbol_plt; + // stop if unmatch + if (symbol_plt_value != symbol->symbol_sym) { +#ifdef DEBUG_PLT + LOGW("%s, plt: %p -> %p != %p", symbol->symbol_name, symbol->symbol_plt, + symbol_plt_value, symbol->symbol_sym); +#endif + return 1; + } + } + } + return 0; +} + +void *plt_dlsym(const char *name, size_t *total) { + Symbol symbol; + memset(&symbol, 0, sizeof(Symbol)); + if (total == NULL) { + symbol.check = PLT_CHECK_SYM_ONE; + } + symbol.symbol_name = name; + dl_iterate_phdr_symbol(&symbol); + if (total != NULL) { + *total = symbol.total; + } + return symbol.symbol_sym; +} + +int dl_iterate_phdr_symbol(Symbol *symbol) { + int result; +#ifdef DEBUG_PLT + LOGI("start dl_iterate_phdr: %s", symbol->symbol_name); +#endif +#if __ANDROID_API__ >= 21 || !defined(__arm__) + result = dl_iterate_phdr(callback, symbol); +#else + int (*dl_iterate_phdr)(int (*)(struct dl_phdr_info *, size_t, void *), void *); + dl_iterate_phdr = dlsym(RTLD_NEXT, "dl_iterate_phdr"); + if (dl_iterate_phdr != NULL) { + result = dl_iterate_phdr(callback, symbol); + } else { + result = 0; + void *handle = dlopen("libdl.so", RTLD_NOW); + dl_iterate_phdr = dlsym(handle, "dl_iterate_phdr"); + if (dl_iterate_phdr != NULL) { + result = dl_iterate_phdr(callback, symbol); + } else { + LOGW("cannot dlsym dl_iterate_phdr"); + } + dlclose(handle); + } +#endif +#ifdef DEBUG_PLT + LOGI("complete dl_iterate_phdr: %s", symbol->symbol_name); +#endif + return result; +} \ No newline at end of file diff --git a/riru/src/main/cpp/util/plt.h b/riru/src/main/cpp/util/plt.h new file mode 100644 index 00000000..4ab3a784 --- /dev/null +++ b/riru/src/main/cpp/util/plt.h @@ -0,0 +1,32 @@ +#pragma once + +#include +#include +#include + +#ifdef __cplusplus +extern "C" { +#endif + +#define PLT_CHECK_PLT_APP ((unsigned short) 0x1u) +#define PLT_CHECK_PLT_ALL ((unsigned short) 0x2u) +#define PLT_CHECK_NAME ((unsigned short) 0x4u) +#define PLT_CHECK_SYM_ONE ((unsigned short) 0x8u) + +typedef struct Symbol { + unsigned short check; + unsigned short size; + size_t total; + ElfW(Addr) *symbol_plt; + ElfW(Addr) *symbol_sym; + const char *symbol_name; + char **names; +} Symbol; + +int dl_iterate_phdr_symbol(Symbol *symbol); + +void *plt_dlsym(const char *name, size_t *total); + +#ifdef __cplusplus +} +#endif \ No newline at end of file