interception_win.cc [plain text]
#ifdef _WIN32
#include "interception.h"
#include <windows.h>
namespace __interception {
// Fills the first `sz` bytes at `p` with `(char)value`.
// Hand-rolled rather than calling the CRT memset -- presumably so this file
// does not depend on the C runtime it may be intercepting; confirm before
// replacing it with the library call.
static void _memset(void *p, int value, size_t sz) {
  char *cursor = (char *)p;
  const char fill = (char)value;
  for (size_t remaining = sz; remaining != 0; --remaining)
    *cursor++ = fill;
}
// Copies `sz` bytes from `src` to `dst`, front to back. The regions are
// expected not to overlap (no memmove-style direction handling is done).
// `src` is only read, so it is const-qualified; non-const callers still work.
static void _memcpy(void *dst, const void *src, size_t sz) {
  char *dst_c = (char *)dst;
  const char *src_c = (const char *)src;
  for (size_t i = 0; i < sz; ++i)
    dst_c[i] = src_c[i];
}
// Writes a 5-byte x86 `jmp rel32` (opcode E9) at `jmp_from` that transfers
// control to `to`. The displacement is relative to the end of the jmp
// (jmp_from + 5).
//
// The displacement is stored as exactly four little-endian bytes, written one
// at a time. Storing it through a ptrdiff_t* would write 8 bytes where
// ptrdiff_t is 64-bit, clobbering the 4 bytes after the instruction. It would
// also be a misaligned, type-punned store.
// NOTE(review): the caller must ensure the distance fits in a signed 32-bit
// displacement; that is not checked here.
static void WriteJumpInstruction(char *jmp_from, char *to) {
  unsigned int rel32 = (unsigned int)(to - jmp_from - 5);
  jmp_from[0] = '\xE9';
  for (int i = 0; i < 4; ++i)
    jmp_from[1 + i] = (char)((rel32 >> (8 * i)) & 0xFFu);
}
// Hands out `size` bytes of executable memory for a trampoline.
//
// All trampolines are carved sequentially out of one POOL_SIZE-byte region.
// It is allocated with VirtualAlloc (PAGE_EXECUTE_READWRITE) on the first call
// and pre-filled with 0xCC (x86 `int3`). Chunks are never released.
// Returns NULL if the pool cannot be allocated or lacks `size` free bytes.
// No locking is done around the static pool state.
static char *GetMemoryForTrampoline(size_t size) {
  const size_t POOL_SIZE = 1024;
  static char *pool = NULL;
  static size_t pool_used = 0;
  if (!pool) {
    pool = (char *)VirtualAlloc(NULL, POOL_SIZE, MEM_RESERVE | MEM_COMMIT,
                                PAGE_EXECUTE_READWRITE);
    if (!pool)
      return NULL;
    _memset(pool, 0xCC, POOL_SIZE);
  }
  // pool_used never exceeds POOL_SIZE, so this subtraction cannot underflow.
  // Unlike `pool_used + size > POOL_SIZE`, it cannot wrap for a huge `size`.
  if (size > POOL_SIZE - pool_used)
    return NULL;
  char *ret = pool + pool_used;
  pool_used += size;
  return ret;
}
// Decodes whole x86-32 instructions at the start of `code` until at least
// `size` bytes are covered, and returns the number of bytes decoded (>= size).
// OverrideFunction uses this so the prologue bytes it copies into a trampoline
// never end in the middle of an instruction.
//
// Only a fixed set of common prologue encodings is recognized. On any other
// byte pattern it calls __debugbreak() and, if execution continues, returns 0
// to signal failure. (With size == 0 it also returns 0.)
//
// Multi-byte patterns are matched by loading 2 or 4 bytes at `code + cursor`
// as an integer. The case constants therefore assume a little-endian load:
// 0xFF8B matches the byte sequence 8B FF.
// NOTE(review): a copied `E9 rel32` (jmp) is measured but not relocated. Once
// copied verbatim into a trampoline, its target would be wrong.
// NOTE(review): the 2/4-byte loads may read a few bytes past the instruction
// being decoded; presumably safe because `code` points into a mapped function.
static size_t RoundUpToInstrBoundary(size_t size, char *code) {
size_t cursor = 0;
while (cursor < size) {
// Instructions whose length is fixed by the first (opcode) byte alone.
switch (code[cursor]) {
// push ecx/edx/ebx/esp/ebp/esi/edi (51..57), pop ebp (5D): 1 byte.
case '\x51': case '\x52': case '\x53': case '\x54': case '\x55': case '\x56': case '\x57': case '\x5D': cursor++;
continue;
// push imm8: 6A ib, 2 bytes.
case '\x6A': cursor += 2;
continue;
// jmp rel32: E9 cd, 5 bytes.
case '\xE9': cursor += 5;
continue;
}
// Opcode + ModR/M pairs, read as a little-endian 16-bit value.
// 8B FF mov edi,edi / 8B EC mov ebp,esp / 33 C0 xor eax,eax: 2 bytes.
switch (*(unsigned short*)(code + cursor)) { case 0xFF8B: case 0xEC8B: case 0xC033: cursor += 2;
continue;
// 8B 45 mov eax,[ebp+d8] / 8B 5D mov ebx,[ebp+d8] / 83 EC sub esp,imm8 /
// FF 75 push [ebp+d8]: 3 bytes.
case 0x458B: case 0x5D8B: case 0xEC83: case 0x75FF: cursor += 3;
continue;
// F7 C1 test ecx,imm32 / FF 25 jmp [disp32]: 6 bytes.
case 0xC1F7: case 0x25FF: cursor += 6;
continue;
// 83 3D cmp dword [disp32],imm8: 7 bytes.
case 0x3D83: cursor += 7;
continue;
}
// Opcode + ModR/M + SIB triples (low 3 bytes of a little-endian 32-bit load):
// mov al/eax/ecx/edx/esi/edi, [esp+d8] -> 8A|8B xx 24 d8, 4 bytes.
switch (0x00FFFFFF & *(unsigned int*)(code + cursor)) {
case 0x24448A: case 0x24448B: case 0x244C8B: case 0x24548B: case 0x24748B: case 0x247C8B: cursor += 4;
continue;
}
// Unrecognized instruction: stop in the debugger, then report failure.
__debugbreak();
return 0;
}
return cursor;
}
// Redirects calls to the function at `old_func` to `new_func` (x86-32 only).
// It overwrites the first bytes of old_func with a 5-byte `jmp rel32`.
//
// If `orig_old_func` is non-NULL, a trampoline is built first so the original
// function stays callable:
//   - The whole instructions covering old_func's first 5 bytes are copied into
//     executable pool memory.
//   - A jmp back to the remainder of old_func follows the copy.
//   - *orig_old_func receives the trampoline address. It is assigned before
//     old_func is patched, presumably so that a call routed to new_func right
//     after patching can already reach the original code.
// Any leftover bytes of those instructions after the 5-byte jmp are filled
// with 0xCC.
//
// Returns false if:
//   - the prologue cannot be decoded,
//   - the trampoline pool is exhausted, or
//   - changing page protection fails.
// A failure to *restore* the original protection is also reported as false,
// even though the jmp has already been written at that point.
bool OverrideFunction(uptr old_func, uptr new_func, uptr *orig_old_func) {
#ifdef _WIN64
#error OverrideFunction is not yet supported on x64
#endif
  char *old_bytes = (char *)old_func;
  size_t head = 5;  // Length of the jmp rel32 written over old_func.
  if (orig_old_func) {
    // Extend head to a whole-instruction boundary so the copy does not split
    // an instruction.
    head = RoundUpToInstrBoundary(head, old_bytes);
    if (!head)
      return false;
    // Trampoline layout: [copied prologue (head bytes)][jmp old_func + head].
    char *trampoline = GetMemoryForTrampoline(head + 5);
    if (!trampoline)
      return false;
    _memcpy(trampoline, old_bytes, head);
    WriteJumpInstruction(trampoline + head, old_bytes + head);
    // Code was just generated; make sure the CPU will not run stale bytes.
    FlushInstructionCache(GetCurrentProcess(), trampoline, head + 5);
    *orig_old_func = (uptr)trampoline;
  }

  DWORD old_prot, unused_prot;
  if (!VirtualProtect((void *)old_bytes, head, PAGE_EXECUTE_READWRITE,
                      &old_prot))
    return false;
  WriteJumpInstruction(old_bytes, (char *)new_func);
  _memset(old_bytes + 5, 0xCC, head - 5);
  // Windows requires FlushInstructionCache after modifying code in memory.
  // Best effort: the patch is already in place, so the result is not checked.
  FlushInstructionCache(GetCurrentProcess(), old_bytes, head);
  if (!VirtualProtect((void *)old_bytes, head, old_prot, &unused_prot))
    return false;
  return true;
}
static const void **InterestingDLLsAvailable() {
const char *InterestingDLLs[] = {"kernel32.dll",
"msvcr110.dll", "msvcr120.dll", NULL};
static void *result[ARRAY_SIZE(InterestingDLLs)] = { 0 };
if (!result[0]) {
for (size_t i = 0, j = 0; InterestingDLLs[i]; ++i) {
if (HMODULE h = GetModuleHandleA(InterestingDLLs[i]))
result[j++] = (void *)h;
}
}
return (const void **)&result[0];
}
// Looks up the export `func_name` in each module returned by
// InterestingDLLsAvailable(), in order, and stores the first non-NULL
// GetProcAddress result in *func_addr.
// Returns true on success. On failure it returns false with *func_addr == 0.
static bool GetFunctionAddressInDLLs(const char *func_name, uptr *func_addr) {
  *func_addr = 0;
  for (const void **dll = InterestingDLLsAvailable(); *dll; ++dll) {
    *func_addr = (uptr)GetProcAddress((HMODULE)*dll, func_name);
    if (*func_addr != 0)
      return true;
  }
  return false;
}
bool OverrideFunction(const char *name, uptr new_func, uptr *orig_old_func) {
uptr orig_func;
if (!GetFunctionAddressInDLLs(name, &orig_func))
return false;
return OverrideFunction(orig_func, new_func, orig_old_func);
}
}
#endif // _WIN32