#include #include #include #include "safetyhook/common.hpp" #include "safetyhook/utility.hpp" #if SAFETYHOOK_OS_WINDOWS #define NOMINMAX #if __has_include() #include #elif __has_include() #include #else #error "Windows.h not found" #endif #include "safetyhook/os.hpp" namespace safetyhook { std::expected vm_allocate(uint8_t* address, size_t size, VmAccess access) { DWORD protect = 0; if (access == VM_ACCESS_R) { protect = PAGE_READONLY; } else if (access == VM_ACCESS_RW) { protect = PAGE_READWRITE; } else if (access == VM_ACCESS_RX) { protect = PAGE_EXECUTE_READ; } else if (access == VM_ACCESS_RWX) { protect = PAGE_EXECUTE_READWRITE; } else { return std::unexpected{OsError::FAILED_TO_ALLOCATE}; } auto* result = VirtualAlloc(address, size, MEM_COMMIT | MEM_RESERVE, protect); if (result == nullptr) { return std::unexpected{OsError::FAILED_TO_ALLOCATE}; } return static_cast(result); } void vm_free(uint8_t* address) { VirtualFree(address, 0, MEM_RELEASE); } std::expected vm_protect(uint8_t* address, size_t size, VmAccess access) { DWORD protect = 0; if (access == VM_ACCESS_R) { protect = PAGE_READONLY; } else if (access == VM_ACCESS_RW) { protect = PAGE_READWRITE; } else if (access == VM_ACCESS_RX) { protect = PAGE_EXECUTE_READ; } else if (access == VM_ACCESS_RWX) { protect = PAGE_EXECUTE_READWRITE; } else { return std::unexpected{OsError::FAILED_TO_PROTECT}; } return vm_protect(address, size, protect); } std::expected vm_protect(uint8_t* address, size_t size, uint32_t protect) { DWORD old_protect = 0; if (VirtualProtect(address, size, protect, &old_protect) == FALSE) { return std::unexpected{OsError::FAILED_TO_PROTECT}; } return old_protect; } std::expected vm_query(uint8_t* address) { MEMORY_BASIC_INFORMATION mbi{}; auto result = VirtualQuery(address, &mbi, sizeof(mbi)); if (result == 0) { return std::unexpected{OsError::FAILED_TO_QUERY}; } VmAccess access{}; access.read = (mbi.Protect & (PAGE_READONLY | PAGE_READWRITE | PAGE_EXECUTE_READ | PAGE_EXECUTE_READWRITE)) != 0; access.write = (mbi.Protect & (PAGE_READWRITE | PAGE_EXECUTE_READWRITE)) != 0; access.execute = (mbi.Protect & (PAGE_EXECUTE | PAGE_EXECUTE_READ | PAGE_EXECUTE_READWRITE)) != 0; VmBasicInfo info{}; info.address = static_cast(mbi.AllocationBase); info.size = mbi.RegionSize; info.access = access; info.is_free = mbi.State == MEM_FREE; return info; } bool vm_is_readable(uint8_t* address, size_t size) { return IsBadReadPtr(address, size) == FALSE; } bool vm_is_writable(uint8_t* address, size_t size) { return IsBadWritePtr(address, size) == FALSE; } bool vm_is_executable(uint8_t* address) { // Check if the address is in a valid module allowing us to potentially skip a heavier memory query. HMODULE image{}; if (!GetModuleHandleEx(GET_MODULE_HANDLE_EX_FLAG_FROM_ADDRESS | GET_MODULE_HANDLE_EX_FLAG_UNCHANGED_REFCOUNT, reinterpret_cast(address), &image) || image == nullptr) { return vm_query(address).value_or(VmBasicInfo{}).access.execute; } // Just check if the section is executable. const auto* image_base = reinterpret_cast(image); const auto* dos_hdr = reinterpret_cast(image_base); if (dos_hdr->e_magic != IMAGE_DOS_SIGNATURE) { return vm_query(address).value_or(VmBasicInfo{}).access.execute; } const auto* nt_hdr = reinterpret_cast(image_base + dos_hdr->e_lfanew); if (nt_hdr->Signature != IMAGE_NT_SIGNATURE) { return vm_query(address).value_or(VmBasicInfo{}).access.execute; } const auto* section = IMAGE_FIRST_SECTION(nt_hdr); for (auto i = 0; i < nt_hdr->FileHeader.NumberOfSections; ++i, ++section) { if (address >= image_base + section->VirtualAddress && address < image_base + section->VirtualAddress + section->Misc.VirtualSize) { return (section->Characteristics & IMAGE_SCN_MEM_EXECUTE) != 0; } } return vm_query(address).value_or(VmBasicInfo{}).access.execute; } SystemInfo system_info() { SystemInfo info{}; SYSTEM_INFO si{}; GetSystemInfo(&si); info.page_size = si.dwPageSize; info.allocation_granularity = si.dwAllocationGranularity; info.min_address = static_cast(si.lpMinimumApplicationAddress); info.max_address = static_cast(si.lpMaximumApplicationAddress); return info; } struct TrapInfo { uint8_t* from_page_start; uint8_t* from_page_end; uint8_t* from; uint8_t* to_page_start; uint8_t* to_page_end; uint8_t* to; size_t len; }; class TrapManager final { public: static std::mutex mutex; static std::unique_ptr instance; static bool is_destructed; TrapManager() { m_trap_veh = AddVectoredExceptionHandler(1, trap_handler); } ~TrapManager() { if (m_trap_veh != nullptr) { RemoveVectoredExceptionHandler(m_trap_veh); } is_destructed = true; } TrapInfo* find_trap(uint8_t* address) { auto search = std::find_if(m_traps.begin(), m_traps.end(), [address](auto& trap) { return address >= trap.second.from && address < trap.second.from + trap.second.len; }); if (search == m_traps.end()) { return nullptr; } return &search->second; } TrapInfo* find_trap_page(uint8_t* address) { auto search = std::find_if(m_traps.begin(), m_traps.end(), [address](auto& trap) { return address >= trap.second.from_page_start && address < trap.second.from_page_end; }); if (search != m_traps.end()) { return &search->second; } search = std::find_if(m_traps.begin(), m_traps.end(), [address](auto& trap) { return address >= trap.second.to_page_start && address < trap.second.to_page_end; }); if (search != m_traps.end()) { return &search->second; } return nullptr; } void add_trap(uint8_t* from, uint8_t* to, size_t len) { TrapInfo info{}; info.from_page_start = align_down(from, 0x1000); info.from_page_end = align_up(from + len, 0x1000); info.from = from; info.to_page_start = align_down(to, 0x1000); info.to_page_end = align_up(to + len, 0x1000); info.to = to; info.len = len; m_traps.insert_or_assign(from, std::move(info)); } private: std::map m_traps; PVOID m_trap_veh{}; static LONG CALLBACK trap_handler(PEXCEPTION_POINTERS exp) { auto exception_code = exp->ExceptionRecord->ExceptionCode; if (exception_code != EXCEPTION_ACCESS_VIOLATION) { return EXCEPTION_CONTINUE_SEARCH; } std::scoped_lock lock{mutex}; auto* faulting_address = reinterpret_cast(exp->ExceptionRecord->ExceptionInformation[1]); auto* trap = instance->find_trap(faulting_address); if (trap == nullptr) { if (instance->find_trap_page(faulting_address) != nullptr) { return EXCEPTION_CONTINUE_EXECUTION; } else { return EXCEPTION_CONTINUE_SEARCH; } } auto* ctx = exp->ContextRecord; for (size_t i = 0; i < trap->len; i++) { fix_ip(ctx, trap->from + i, trap->to + i); } return EXCEPTION_CONTINUE_EXECUTION; } }; std::mutex TrapManager::mutex; std::unique_ptr TrapManager::instance; bool TrapManager::is_destructed = false; void find_me() { } static std::mutex virtual_protect_mutex; void trap_threads(uint8_t* from, uint8_t* to, size_t len, const std::function& run_fn) { MEMORY_BASIC_INFORMATION find_me_mbi{}; MEMORY_BASIC_INFORMATION from_mbi{}; MEMORY_BASIC_INFORMATION to_mbi{}; VirtualQuery(reinterpret_cast(find_me), &find_me_mbi, sizeof(find_me_mbi)); VirtualQuery(from, &from_mbi, sizeof(from_mbi)); VirtualQuery(to, &to_mbi, sizeof(to_mbi)); auto new_protect = PAGE_READWRITE; if (from_mbi.AllocationBase == find_me_mbi.AllocationBase || to_mbi.AllocationBase == find_me_mbi.AllocationBase) { new_protect = PAGE_EXECUTE_READWRITE; } auto si = system_info(); auto* from_page_start = align_down(from, si.page_size); auto* from_page_end = align_up(from + len, si.page_size); auto* vp_start = reinterpret_cast(&VirtualProtect); auto* vp_end = vp_start + 0x20; if (!(from_page_end < vp_start || vp_end < from_page_start)) { new_protect = PAGE_EXECUTE_READWRITE; } if (!TrapManager::is_destructed) { std::scoped_lock lock{TrapManager::mutex}; if (TrapManager::instance == nullptr) { TrapManager::instance = std::make_unique(); } TrapManager::instance->add_trap(from, to, len); } // Make sure we aren't working on a different address in the same memory page on a different thread. std::scoped_lock vp_lock{virtual_protect_mutex}; DWORD from_protect; DWORD to_protect; VirtualProtect(from, len, new_protect, &from_protect); VirtualProtect(to, len, new_protect, &to_protect); if (run_fn) { run_fn(); } VirtualProtect(to, len, to_protect, &to_protect); VirtualProtect(from, len, from_protect, &from_protect); } void fix_ip(ThreadContext thread_ctx, uint8_t* old_ip, uint8_t* new_ip) { auto* ctx = reinterpret_cast(thread_ctx); #if SAFETYHOOK_ARCH_X86_64 auto ip = ctx->Rip; #elif SAFETYHOOK_ARCH_X86_32 auto ip = ctx->Eip; #endif if (ip == reinterpret_cast(old_ip)) { ip = reinterpret_cast(new_ip); } #if SAFETYHOOK_ARCH_X86_64 ctx->Rip = ip; #elif SAFETYHOOK_ARCH_X86_32 ctx->Eip = ip; #endif } } // namespace safetyhook #endif