Files
ReshadePluginsCore/external/safetyhook/src/os.windows.cpp

343 lines
10 KiB
C++
Raw Normal View History

#include <map>
#include <memory>
#include <mutex>
#include "safetyhook/common.hpp"
#include "safetyhook/utility.hpp"
#if SAFETYHOOK_OS_WINDOWS
#define NOMINMAX
#if __has_include(<Windows.h>)
#include <Windows.h>
#elif __has_include(<windows.h>)
#include <windows.h>
#else
#error "Windows.h not found"
#endif
#include "safetyhook/os.hpp"
namespace safetyhook {
// Reserves and commits `size` bytes of virtual memory with the requested access.
//
// address: preferred base address, forwarded unchanged to VirtualAlloc.
// size:    number of bytes to allocate.
// access:  one of VM_ACCESS_R / VM_ACCESS_RW / VM_ACCESS_RX / VM_ACCESS_RWX.
//
// Returns the allocated base address, or OsError::FAILED_TO_ALLOCATE when `access` is not one of
// the four supported combinations or VirtualAlloc returns null.
std::expected<uint8_t*, OsError> vm_allocate(uint8_t* address, size_t size, VmAccess access) {
    // Translates a VmAccess into a Win32 page-protection constant; 0 means "unsupported".
    // (None of the PAGE_* constants used here is 0, so 0 is a safe sentinel.)
    const auto to_page_protection = [](VmAccess wanted) -> DWORD {
        if (wanted == VM_ACCESS_R) {
            return PAGE_READONLY;
        }
        if (wanted == VM_ACCESS_RW) {
            return PAGE_READWRITE;
        }
        if (wanted == VM_ACCESS_RX) {
            return PAGE_EXECUTE_READ;
        }
        if (wanted == VM_ACCESS_RWX) {
            return PAGE_EXECUTE_READWRITE;
        }
        return 0;
    };

    const DWORD page_protection = to_page_protection(access);

    if (page_protection == 0) {
        return std::unexpected{OsError::FAILED_TO_ALLOCATE};
    }

    if (auto* memory = VirtualAlloc(address, size, MEM_COMMIT | MEM_RESERVE, page_protection); memory != nullptr) {
        return static_cast<uint8_t*>(memory);
    }

    return std::unexpected{OsError::FAILED_TO_ALLOCATE};
}
void vm_free(uint8_t* address) {
VirtualFree(address, 0, MEM_RELEASE);
}
// Changes the protection of [address, address + size) to the given portable access value.
//
// Maps `access` to the matching Win32 PAGE_* constant and forwards to the raw-flag overload.
// Returns the previous protection flags, or OsError::FAILED_TO_PROTECT when `access` is not one of
// VM_ACCESS_R / VM_ACCESS_RW / VM_ACCESS_RX / VM_ACCESS_RWX or the underlying call fails.
std::expected<uint32_t, OsError> vm_protect(uint8_t* address, size_t size, VmAccess access) {
    // 0 marks an unsupported access value; none of the PAGE_* constants below is 0.
    const DWORD page_protection = access == VM_ACCESS_R     ? PAGE_READONLY
                                  : access == VM_ACCESS_RW  ? PAGE_READWRITE
                                  : access == VM_ACCESS_RX  ? PAGE_EXECUTE_READ
                                  : access == VM_ACCESS_RWX ? PAGE_EXECUTE_READWRITE
                                                            : 0;

    if (page_protection == 0) {
        return std::unexpected{OsError::FAILED_TO_PROTECT};
    }

    return vm_protect(address, size, page_protection);
}
// Changes the protection of [address, address + size) to the raw Win32 flags in `protect`.
//
// Returns the protection flags that were in effect before the call, or
// OsError::FAILED_TO_PROTECT when VirtualProtect reports failure.
std::expected<uint32_t, OsError> vm_protect(uint8_t* address, size_t size, uint32_t protect) {
    DWORD previous_protect = 0;
    const BOOL changed = VirtualProtect(address, size, protect, &previous_protect);

    if (changed != FALSE) {
        return previous_protect;
    }

    return std::unexpected{OsError::FAILED_TO_PROTECT};
}
// Describes the virtual memory region that contains `address`.
//
// Returns OsError::FAILED_TO_QUERY when VirtualQuery returns 0. Otherwise the result holds:
//   address - the allocation base reported by VirtualQuery (AllocationBase),
//   size    - the size of the queried region (RegionSize),
//   access  - read/write/execute flags derived from the region's protection,
//   is_free - whether the region's state is MEM_FREE.
std::expected<VmBasicInfo, OsError> vm_query(uint8_t* address) {
    MEMORY_BASIC_INFORMATION region{};

    if (VirtualQuery(address, &region, sizeof(region)) == 0) {
        return std::unexpected{OsError::FAILED_TO_QUERY};
    }

    // Protection constants that imply each kind of access.
    constexpr DWORD readable_mask = PAGE_READONLY | PAGE_READWRITE | PAGE_EXECUTE_READ | PAGE_EXECUTE_READWRITE;
    constexpr DWORD writable_mask = PAGE_READWRITE | PAGE_EXECUTE_READWRITE;
    constexpr DWORD executable_mask = PAGE_EXECUTE | PAGE_EXECUTE_READ | PAGE_EXECUTE_READWRITE;

    VmBasicInfo description{};
    description.address = static_cast<uint8_t*>(region.AllocationBase);
    description.size = region.RegionSize;
    description.access.read = (region.Protect & readable_mask) != 0;
    description.access.write = (region.Protect & writable_mask) != 0;
    description.access.execute = (region.Protect & executable_mask) != 0;
    description.is_free = region.State == MEM_FREE;

    return description;
}
// Returns true when IsBadReadPtr reports that `size` bytes starting at `address` can be read.
// NOTE(review): Microsoft documents IsBadReadPtr as obsolete; the answer is only a snapshot.
bool vm_is_readable(uint8_t* address, size_t size) {
    const bool read_would_fault = IsBadReadPtr(address, size) != FALSE;
    return !read_would_fault;
}
// Returns true when IsBadWritePtr reports that `size` bytes starting at `address` can be written.
// NOTE(review): Microsoft documents IsBadWritePtr as obsolete; the answer is only a snapshot.
bool vm_is_writable(uint8_t* address, size_t size) {
    const bool write_would_fault = IsBadWritePtr(address, size) != FALSE;
    return !write_would_fault;
}
// Reports whether `address` lies in executable memory.
//
// Fast path: if the address belongs to a loaded module with valid DOS/NT headers and falls inside
// one of its sections, the section's IMAGE_SCN_MEM_EXECUTE characteristic decides the answer.
// In every other case the answer comes from vm_query (false when the query fails).
bool vm_is_executable(uint8_t* address) {
    // Heavier fallback: ask the OS for the page protection of `address`.
    const auto query_page_protection = [address]() -> bool {
        return vm_query(address).value_or(VmBasicInfo{}).access.execute;
    };

    // Look up the module containing `address` without touching its reference count.
    constexpr DWORD lookup_flags =
        GET_MODULE_HANDLE_EX_FLAG_FROM_ADDRESS | GET_MODULE_HANDLE_EX_FLAG_UNCHANGED_REFCOUNT;
    HMODULE owner{};
    const BOOL found = GetModuleHandleEx(lookup_flags, reinterpret_cast<LPTSTR>(address), &owner);

    if (!found || owner == nullptr) {
        return query_page_protection();
    }

    // The module handle is the image base; validate the PE headers before trusting them.
    const auto* base = reinterpret_cast<uint8_t*>(owner);
    const auto* dos_header = reinterpret_cast<const IMAGE_DOS_HEADER*>(base);

    if (dos_header->e_magic != IMAGE_DOS_SIGNATURE) {
        return query_page_protection();
    }

    const auto* nt_headers = reinterpret_cast<const IMAGE_NT_HEADERS*>(base + dos_header->e_lfanew);

    if (nt_headers->Signature != IMAGE_NT_SIGNATURE) {
        return query_page_protection();
    }

    // Find the section that contains `address` and report its execute flag.
    const auto* first_section = IMAGE_FIRST_SECTION(nt_headers);
    const auto* sections_end = first_section + nt_headers->FileHeader.NumberOfSections;

    for (const auto* current = first_section; current != sections_end; ++current) {
        const auto* section_begin = base + current->VirtualAddress;
        const auto* section_end = section_begin + current->Misc.VirtualSize;

        if (address < section_begin || address >= section_end) {
            continue;
        }

        return (current->Characteristics & IMAGE_SCN_MEM_EXECUTE) != 0;
    }

    // Inside the module but outside every section.
    return query_page_protection();
}
// Collects the page size, allocation granularity and application address range
// reported by GetSystemInfo.
SystemInfo system_info() {
    SYSTEM_INFO native{};
    GetSystemInfo(&native);

    SystemInfo summary{};
    summary.min_address = static_cast<uint8_t*>(native.lpMinimumApplicationAddress);
    summary.max_address = static_cast<uint8_t*>(native.lpMaximumApplicationAddress);
    summary.allocation_granularity = native.dwAllocationGranularity;
    summary.page_size = native.dwPageSize;

    return summary;
}
// One registered trap: a `len`-byte range at `from` whose instruction pointers are
// redirected to the same offset within the range at `to`.
struct TrapInfo {
// `from` aligned down to a 0x1000 boundary (set by TrapManager::add_trap).
uint8_t* from_page_start;
// `from + len` aligned up to a 0x1000 boundary; used as an exclusive upper bound.
uint8_t* from_page_end;
// First byte of the source range.
uint8_t* from;
// `to` aligned down to a 0x1000 boundary.
uint8_t* to_page_start;
// `to + len` aligned up to a 0x1000 boundary; used as an exclusive upper bound.
uint8_t* to_page_end;
// First byte of the destination range.
uint8_t* to;
// Length in bytes of both ranges.
size_t len;
};
class TrapManager final {
public:
static std::mutex mutex;
static std::unique_ptr<TrapManager> instance;
static bool is_destructed;
TrapManager() { m_trap_veh = AddVectoredExceptionHandler(1, trap_handler); }
~TrapManager() {
if (m_trap_veh != nullptr) {
RemoveVectoredExceptionHandler(m_trap_veh);
}
is_destructed = true;
}
TrapInfo* find_trap(uint8_t* address) {
auto search = std::find_if(m_traps.begin(), m_traps.end(), [address](auto& trap) {
return address >= trap.second.from && address < trap.second.from + trap.second.len;
});
if (search == m_traps.end()) {
return nullptr;
}
return &search->second;
}
TrapInfo* find_trap_page(uint8_t* address) {
auto search = std::find_if(m_traps.begin(), m_traps.end(), [address](auto& trap) {
return address >= trap.second.from_page_start && address < trap.second.from_page_end;
});
if (search != m_traps.end()) {
return &search->second;
}
search = std::find_if(m_traps.begin(), m_traps.end(), [address](auto& trap) {
return address >= trap.second.to_page_start && address < trap.second.to_page_end;
});
if (search != m_traps.end()) {
return &search->second;
}
return nullptr;
}
void add_trap(uint8_t* from, uint8_t* to, size_t len) {
TrapInfo info{};
info.from_page_start = align_down(from, 0x1000);
info.from_page_end = align_up(from + len, 0x1000);
info.from = from;
info.to_page_start = align_down(to, 0x1000);
info.to_page_end = align_up(to + len, 0x1000);
info.to = to;
info.len = len;
m_traps.insert_or_assign(from, std::move(info));
}
private:
std::map<uint8_t*, TrapInfo> m_traps;
PVOID m_trap_veh{};
static LONG CALLBACK trap_handler(PEXCEPTION_POINTERS exp) {
auto exception_code = exp->ExceptionRecord->ExceptionCode;
if (exception_code != EXCEPTION_ACCESS_VIOLATION) {
return EXCEPTION_CONTINUE_SEARCH;
}
std::scoped_lock lock{mutex};
auto* faulting_address = reinterpret_cast<uint8_t*>(exp->ExceptionRecord->ExceptionInformation[1]);
auto* trap = instance->find_trap(faulting_address);
if (trap == nullptr) {
if (instance->find_trap_page(faulting_address) != nullptr) {
return EXCEPTION_CONTINUE_EXECUTION;
} else {
return EXCEPTION_CONTINUE_SEARCH;
}
}
auto* ctx = exp->ContextRecord;
for (size_t i = 0; i < trap->len; i++) {
fix_ip(ctx, trap->from + i, trap->to + i);
}
return EXCEPTION_CONTINUE_EXECUTION;
}
};
// Out-of-class definitions of TrapManager's static members.
// Locked by trap_threads while registering a trap and by the exception handler while looking one up.
std::mutex TrapManager::mutex;
// The single manager; created by trap_threads the first time a trap is registered.
std::unique_ptr<TrapManager> TrapManager::instance;
// Set to true by ~TrapManager; trap_threads stops registering traps once it is set.
bool TrapManager::is_destructed = false;
// Intentionally empty. Only the function's address is used: trap_threads passes it to
// VirtualQuery to find the allocation that contains this code.
void find_me() {}
// Held by trap_threads for the whole change-protection / run / restore-protection sequence,
// so that two threads never re-protect memory at the same time.
static std::mutex virtual_protect_mutex;
void trap_threads(uint8_t* from, uint8_t* to, size_t len, const std::function<void()>& run_fn) {
MEMORY_BASIC_INFORMATION find_me_mbi{};
MEMORY_BASIC_INFORMATION from_mbi{};
MEMORY_BASIC_INFORMATION to_mbi{};
VirtualQuery(reinterpret_cast<void*>(find_me), &find_me_mbi, sizeof(find_me_mbi));
VirtualQuery(from, &from_mbi, sizeof(from_mbi));
VirtualQuery(to, &to_mbi, sizeof(to_mbi));
auto new_protect = PAGE_READWRITE;
if (from_mbi.AllocationBase == find_me_mbi.AllocationBase || to_mbi.AllocationBase == find_me_mbi.AllocationBase) {
new_protect = PAGE_EXECUTE_READWRITE;
}
auto si = system_info();
auto* from_page_start = align_down(from, si.page_size);
auto* from_page_end = align_up(from + len, si.page_size);
auto* vp_start = reinterpret_cast<uint8_t*>(&VirtualProtect);
auto* vp_end = vp_start + 0x20;
if (!(from_page_end < vp_start || vp_end < from_page_start)) {
new_protect = PAGE_EXECUTE_READWRITE;
}
if (!TrapManager::is_destructed) {
std::scoped_lock lock{TrapManager::mutex};
if (TrapManager::instance == nullptr) {
TrapManager::instance = std::make_unique<TrapManager>();
}
TrapManager::instance->add_trap(from, to, len);
}
// Make sure we aren't working on a different address in the same memory page on a different thread.
std::scoped_lock vp_lock{virtual_protect_mutex};
DWORD from_protect;
DWORD to_protect;
VirtualProtect(from, len, new_protect, &from_protect);
VirtualProtect(to, len, new_protect, &to_protect);
if (run_fn) {
run_fn();
}
VirtualProtect(to, len, to_protect, &to_protect);
VirtualProtect(from, len, from_protect, &from_protect);
}
// Redirects a thread context's instruction pointer: if it currently equals `old_ip`,
// it is replaced by `new_ip`; otherwise the context is left with the same value.
//
// thread_ctx is reinterpreted as a Win32 CONTEXT*. Rip is used on x86-64 and Eip on x86-32.
void fix_ip(ThreadContext thread_ctx, uint8_t* old_ip, uint8_t* new_ip) {
    auto* context = reinterpret_cast<CONTEXT*>(thread_ctx);

#if SAFETYHOOK_ARCH_X86_64
    auto& instruction_pointer = context->Rip;
#elif SAFETYHOOK_ARCH_X86_32
    auto& instruction_pointer = context->Eip;
#endif

    const auto stale_ip = reinterpret_cast<uintptr_t>(old_ip);

    if (instruction_pointer == stale_ip) {
        instruction_pointer = reinterpret_cast<uintptr_t>(new_ip);
    }
}
} // namespace safetyhook
#endif