#include "config.h"
#include "Module.h"
#include <WebCore/DelayLoadedModulesEnumerator.h>
#include <WebCore/ImportedFunctionsEnumerator.h>
#include <WebCore/ImportedModulesEnumerator.h>
#include <shlwapi.h>
using namespace WebCore;
namespace WebKit {
bool Module::load()
{
ASSERT(!::PathIsRelativeW(m_path.charactersWithNullTermination()));
m_module = ::LoadLibraryExW(m_path.charactersWithNullTermination(), 0, LOAD_WITH_ALTERED_SEARCH_PATH);
return m_module;
}
// Releases the module handle stored in m_module (if any) via FreeLibrary and
// resets m_module to 0. Calling this when nothing is loaded does nothing.
void Module::unload()
{
    if (m_module) {
        ::FreeLibrary(m_module);
        m_module = 0;
    }
}
static void memcpyToReadOnlyMemory(void* destination, const void* source, size_t size)
{
DWORD originalProtection;
if (!::VirtualProtect(destination, size, PAGE_READWRITE, &originalProtection))
return;
memcpy(destination, source, size);
::VirtualProtect(destination, size, originalProtection, &originalProtection);
}
// Advances |modules| to the first imported module whose name equals
// |importDLLName| (case-insensitive, via _stricmp). It then searches that module's
// imported functions for one named exactly |importFunctionName| (case-sensitive;
// entries without a name are skipped).
// Returns the address of the matching function-pointer slot, or 0 when either
// the module or the function is not found. Only the first matching module is
// searched. |modules| is advanced as a side effect.
static const void* const* findFunctionPointerAddress(ImportedModulesEnumeratorBase& modules, const char* importDLLName, const char* importFunctionName)
{
    // Phase 1: locate the module.
    bool foundModule = false;
    while (!foundModule && !modules.isAtEnd()) {
        if (!_stricmp(importDLLName, modules.currentModuleName()))
            foundModule = true;
        else
            modules.next();
    }
    if (!foundModule)
        return 0;

    // Phase 2: locate the function within that module's imports.
    ImportedFunctionsEnumerator functions = modules.functionsEnumerator();
    while (!functions.isAtEnd()) {
        const char* candidateName = functions.currentFunctionName();
        if (candidateName && !strcmp(importFunctionName, candidateName))
            return functions.addressOfCurrentFunctionPointer();
        functions.next();
    }
    return 0;
}
// Finds the import slot for |importFunctionName| from |importDLLName| in the PE
// image of |module|. The regular import table is searched first; the
// delay-load import table is consulted only if that search returns 0.
// Returns the slot address, or 0 if neither table contains the import.
static const void* const* findFunctionPointerAddress(HMODULE module, const char* importDLLName, const char* importFunctionName)
{
    PEImage image(module);

    ImportedModulesEnumerator importedModules(image);
    const void* const* slotAddress = findFunctionPointerAddress(importedModules, importDLLName, importFunctionName);
    if (!slotAddress) {
        DelayLoadedModulesEnumerator delayLoadedModules(image);
        slotAddress = findFunctionPointerAddress(delayLoadedModules, importDLLName, importFunctionName);
    }
    return slotAddress;
}
void Module::installIATHook(const char* importDLLName, const char* importFunctionName, const void* hookFunction)
{
if (!m_module)
return;
const void* const* functionPointerAddress = findFunctionPointerAddress(m_module, importDLLName, importFunctionName);
if (!functionPointerAddress || *functionPointerAddress == hookFunction)
return;
memcpyToReadOnlyMemory(const_cast<const void**>(functionPointerAddress), &hookFunction, sizeof(hookFunction));
}
// Returns the address of the export |functionName| from the loaded module, as
// reported by GetProcAddress, or 0 if no module is currently loaded.
void* Module::platformFunctionPointer(const char* functionName) const
{
    if (!m_module)
        return 0;

    // GetProcAddress yields a FARPROC, which is a *function* pointer. Standard C++
    // has no implicit conversion from a function pointer to void* (that only
    // compiled as an MSVC extension). Use an explicit reinterpret_cast, which
    // C++11 permits as conditionally-supported and which Windows supports.
    return reinterpret_cast<void*>(::GetProcAddress(m_module, functionName));
}
}