diff --git a/MemoryModule.c b/MemoryModule.c index cabaabb..d8d3f3a 100644 --- a/MemoryModule.c +++ b/MemoryModule.c @@ -56,6 +56,8 @@ typedef struct { BOOL initialized; BOOL isDLL; BOOL isRelocated; + CustomAllocFunc alloc; + CustomFreeFunc free; CustomLoadLibraryFunc loadLibrary; CustomGetProcAddressFunc getProcAddress; CustomFreeLibraryFunc freeLibrary; @@ -115,10 +117,11 @@ CopySections(const unsigned char *data, size_t size, PIMAGE_NT_HEADERS old_heade // uninitialized data section_size = old_headers->OptionalHeader.SectionAlignment; if (section_size > 0) { - dest = (unsigned char *)VirtualAlloc(codeBase + section->VirtualAddress, + dest = (unsigned char *)module->alloc(codeBase + section->VirtualAddress, section_size, MEM_COMMIT, - PAGE_READWRITE); + PAGE_READWRITE, + module->userdata); if (dest == NULL) { return FALSE; } @@ -139,10 +142,11 @@ CopySections(const unsigned char *data, size_t size, PIMAGE_NT_HEADERS old_heade } // commit memory block and copy data from dll - dest = (unsigned char *)VirtualAlloc(codeBase + section->VirtualAddress, + dest = (unsigned char *)module->alloc(codeBase + section->VirtualAddress, section->SizeOfRawData, MEM_COMMIT, - PAGE_READWRITE); + PAGE_READWRITE, + module->userdata); if (dest == NULL) { return FALSE; } @@ -202,7 +206,7 @@ FinalizeSection(PMEMORYMODULE module, PSECTIONFINALIZEDATA sectionData) { (sectionData->size % module->pageSize) == 0) ) { // Only allowed to decommit whole pages - VirtualFree(sectionData->address, sectionData->size, MEM_DECOMMIT); + module->free(sectionData->address, sectionData->size, MEM_DECOMMIT, module->userdata); } return TRUE; } @@ -429,6 +433,18 @@ BuildImportTable(PMEMORYMODULE module) return result; } +LPVOID MemoryDefaultAlloc(LPVOID address, SIZE_T size, DWORD allocationType, DWORD protect, void* userdata) +{ + UNREFERENCED_PARAMETER(userdata); + return VirtualAlloc(address, size, allocationType, protect); +} + +BOOL MemoryDefaultFree(LPVOID lpAddress, SIZE_T dwSize, DWORD dwFreeType, void* userdata) +{ + UNREFERENCED_PARAMETER(userdata); + return VirtualFree(lpAddress, dwSize, dwFreeType); +} + HCUSTOMMODULE MemoryDefaultLoadLibrary(LPCSTR filename, void *userdata) { HMODULE result; @@ -455,10 +471,12 @@ void MemoryDefaultFreeLibrary(HCUSTOMMODULE module, void *userdata) HMEMORYMODULE MemoryLoadLibrary(const void *data, size_t size) { - return MemoryLoadLibraryEx(data, size, MemoryDefaultLoadLibrary, MemoryDefaultGetProcAddress, MemoryDefaultFreeLibrary, NULL); + return MemoryLoadLibraryEx(data, size, MemoryDefaultAlloc, MemoryDefaultFree, MemoryDefaultLoadLibrary, MemoryDefaultGetProcAddress, MemoryDefaultFreeLibrary, NULL); } HMEMORYMODULE MemoryLoadLibraryEx(const void *data, size_t size, + CustomAllocFunc allocMemory, + CustomFreeFunc freeMemory, CustomLoadLibraryFunc loadLibrary, CustomGetProcAddressFunc getProcAddress, CustomFreeLibraryFunc freeLibrary, @@ -535,17 +553,19 @@ HMEMORYMODULE MemoryLoadLibraryEx(const void *data, size_t size, // reserve memory for image of library // XXX: is it correct to commit the complete memory region at once? // calling DllEntry raises an exception if we don't... - code = (unsigned char *)VirtualAlloc((LPVOID)(old_header->OptionalHeader.ImageBase), + code = (unsigned char *)allocMemory((LPVOID)(old_header->OptionalHeader.ImageBase), alignedImageSize, MEM_RESERVE | MEM_COMMIT, - PAGE_READWRITE); + PAGE_READWRITE, + userdata); if (code == NULL) { // try to allocate memory at arbitrary position - code = (unsigned char *)VirtualAlloc(NULL, + code = (unsigned char *)allocMemory(NULL, alignedImageSize, MEM_RESERVE | MEM_COMMIT, - PAGE_READWRITE); + PAGE_READWRITE, + userdata); if (code == NULL) { SetLastError(ERROR_OUTOFMEMORY); return NULL; @@ -554,13 +574,15 @@ HMEMORYMODULE MemoryLoadLibraryEx(const void *data, size_t size, result = (PMEMORYMODULE)HeapAlloc(GetProcessHeap(), HEAP_ZERO_MEMORY, sizeof(MEMORYMODULE)); if (result == NULL) { - VirtualFree(code, 0, MEM_RELEASE); + freeMemory(code, 0, MEM_RELEASE, userdata); SetLastError(ERROR_OUTOFMEMORY); return NULL; } result->codeBase = code; result->isDLL = (old_header->FileHeader.Characteristics & IMAGE_FILE_DLL) != 0; + result->alloc = allocMemory; + result->free = freeMemory; result->loadLibrary = loadLibrary; result->getProcAddress = getProcAddress; result->freeLibrary = freeLibrary; @@ -572,10 +594,11 @@ HMEMORYMODULE MemoryLoadLibraryEx(const void *data, size_t size, } // commit memory for headers - headers = (unsigned char *)VirtualAlloc(code, + headers = (unsigned char *)allocMemory(code, old_header->OptionalHeader.SizeOfHeaders, MEM_COMMIT, - PAGE_READWRITE); + PAGE_READWRITE, + userdata); // copy PE header to code memcpy(headers, dos_header, old_header->OptionalHeader.SizeOfHeaders); @@ -724,7 +747,7 @@ void MemoryFreeLibrary(HMEMORYMODULE mod) if (module->codeBase != NULL) { // release memory of library - VirtualFree(module->codeBase, 0, MEM_RELEASE); + module->free(module->codeBase, 0, MEM_RELEASE, module->userdata); } HeapFree(GetProcessHeap(), 0, module); diff --git a/MemoryModule.h b/MemoryModule.h index d8f3938..a728f6b 100644 --- a/MemoryModule.h +++ b/MemoryModule.h @@ -39,6 +39,8 @@ typedef void *HCUSTOMMODULE; extern "C" { #endif +typedef LPVOID (*CustomAllocFunc)(LPVOID, SIZE_T, DWORD, DWORD, void*); +typedef BOOL (*CustomFreeFunc)(LPVOID, SIZE_T, DWORD, void*); typedef HCUSTOMMODULE (*CustomLoadLibraryFunc)(LPCSTR, void *); typedef FARPROC (*CustomGetProcAddressFunc)(HCUSTOMMODULE, LPCSTR, void *); typedef void (*CustomFreeLibraryFunc)(HCUSTOMMODULE, void *); @@ -58,6 +60,8 @@ HMEMORYMODULE MemoryLoadLibrary(const void *, size_t); * Dependencies will be resolved using passed callback methods. */ HMEMORYMODULE MemoryLoadLibraryEx(const void *, size_t, + CustomAllocFunc, + CustomFreeFunc, CustomLoadLibraryFunc, CustomGetProcAddressFunc, CustomFreeLibraryFunc, @@ -117,6 +121,22 @@ int MemoryLoadString(HMEMORYMODULE, UINT, LPTSTR, int); */ int MemoryLoadStringEx(HMEMORYMODULE, UINT, LPTSTR, int, WORD); +/** +* Default implementation of CustomAllocFunc that calls VirtualAlloc +* internally to allocate memory for a library +* +* This is the default as used by MemoryLoadLibrary. +*/ +LPVOID MemoryDefaultAlloc(LPVOID, SIZE_T, DWORD, DWORD, void *); + +/** +* Default implementation of CustomFreeFunc that calls VirtualFree +* internally to free the memory used by a library +* +* This is the default as used by MemoryLoadLibrary. +*/ +BOOL MemoryDefaultFree(LPVOID, SIZE_T, DWORD, void *); + /** * Default implementation of CustomLoadLibraryFunc that calls LoadLibraryA * internally to load an additional libary. diff --git a/example/DllLoader/DllLoader.cpp b/example/DllLoader/DllLoader.cpp index 4ff1a20..973f5b4 100644 --- a/example/DllLoader/DllLoader.cpp +++ b/example/DllLoader/DllLoader.cpp @@ -46,12 +46,48 @@ void LoadFromFile(void) FreeLibrary(handle); } +void* ReadLibrary(long* pSize) { + long read; + void* result; + FILE* fp; + + fp = _tfopen(DLL_FILE, _T("rb")); + if (fp == NULL) + { + _tprintf(_T("Can't open DLL file \"%s\"."), DLL_FILE); + return NULL; + } + + fseek(fp, 0, SEEK_END); + *pSize = ftell(fp); + if (*pSize < 0) + { + fclose(fp); + return NULL; + } + + result = (unsigned char *)malloc(*pSize); + if (result == NULL) + { + return NULL; + } + + fseek(fp, 0, SEEK_SET); + read = fread(result, 1, *pSize, fp); + fclose(fp); + if (read != static_cast(*pSize)) + { + free(result); + return NULL; + } + + return result; +} + void LoadFromMemory(void) { - FILE *fp; - unsigned char *data=NULL; + void *data; long size; - size_t read; HMEMORYMODULE handle; addNumberProc addNumber; HMEMORYRSRC resourceInfo; @@ -59,23 +95,12 @@ void LoadFromMemory(void) LPVOID resourceData; TCHAR buffer[100]; - fp = _tfopen(DLL_FILE, _T("rb")); - if (fp == NULL) + data = ReadLibrary(&size); + if (data == NULL) { - _tprintf(_T("Can't open DLL file \"%s\"."), DLL_FILE); - goto exit; + return; } - fseek(fp, 0, SEEK_END); - size = ftell(fp); - assert(size >= 0); - data = (unsigned char *)malloc(size); - assert(data != NULL); - fseek(fp, 0, SEEK_SET); - read = fread(data, 1, size, fp); - assert(read == static_cast(size)); - fclose(fp); - handle = MemoryLoadLibrary(data, size); if (handle == NULL) { @@ -105,11 +130,161 @@ void LoadFromMemory(void) free(data); } +#define MAX_CALLS 20 + +struct CallList { + int current_alloc_call, current_free_call; + CustomAllocFunc alloc_calls[MAX_CALLS]; + CustomFreeFunc free_calls[MAX_CALLS]; +}; + +LPVOID MemoryFailingAlloc(LPVOID address, SIZE_T size, DWORD allocationType, DWORD protect, void* userdata) +{ + UNREFERENCED_PARAMETER(address); + UNREFERENCED_PARAMETER(size); + UNREFERENCED_PARAMETER(allocationType); + UNREFERENCED_PARAMETER(protect); + UNREFERENCED_PARAMETER(userdata); + return NULL; +} + +LPVOID MemoryMockAlloc(LPVOID address, SIZE_T size, DWORD allocationType, DWORD protect, void* userdata) +{ + CallList* calls = (CallList*)userdata; + CustomAllocFunc current_func = calls->alloc_calls[calls->current_alloc_call++]; + assert(current_func != NULL); + return current_func(address, size, allocationType, protect, NULL); +} + +BOOL MemoryMockFree(LPVOID lpAddress, SIZE_T dwSize, DWORD dwFreeType, void* userdata) +{ + CallList* calls = (CallList*)userdata; + CustomFreeFunc current_func = calls->free_calls[calls->current_free_call++]; + assert(current_func != NULL); + return current_func(lpAddress, dwSize, dwFreeType, NULL); +} + +void InitFuncs(void** funcs, va_list args) { + for (int i = 0; ; i++) { + assert(i < MAX_CALLS); + funcs[i] = va_arg(args, void*); + if (funcs[i] == NULL) break; + } +} + +void InitAllocFuncs(CallList* calls, ...) { + va_list args; + va_start(args, calls); + InitFuncs((void**)calls->alloc_calls, args); + va_end(args); + calls->current_alloc_call = 0; +} + +void InitFreeFuncs(CallList* calls, ...) { + va_list args; + va_start(args, calls); + InitFuncs((void**)calls->free_calls, args); + va_end(args); + calls->current_free_call = 0; +} + +void InitFreeFunc(CallList* calls, CustomFreeFunc freeFunc) { + for (int i = 0; i < MAX_CALLS; i++) { + calls->free_calls[i] = freeFunc; + } + calls->current_free_call = 0; +} + +void TestFailingAllocation(void *data, long size) { + CallList expected_calls; + HMEMORYMODULE handle; + + InitAllocFuncs(&expected_calls, MemoryFailingAlloc, MemoryFailingAlloc, NULL); + InitFreeFuncs(&expected_calls, NULL); + + handle = MemoryLoadLibraryEx( + data, size, MemoryMockAlloc, MemoryMockFree, MemoryDefaultLoadLibrary, + MemoryDefaultGetProcAddress, MemoryDefaultFreeLibrary, &expected_calls); + + assert(handle == NULL); + assert(GetLastError() == ERROR_OUTOFMEMORY); + assert(expected_calls.current_free_call == 0); + + MemoryFreeLibrary(handle); + assert(expected_calls.current_free_call == 0); +} + +void TestCleanupAfterFailingAllocation(void *data, long size) { + CallList expected_calls; + HMEMORYMODULE handle; + int free_calls_after_loading; + + InitAllocFuncs(&expected_calls, + MemoryDefaultAlloc, + MemoryDefaultAlloc, + MemoryDefaultAlloc, + MemoryDefaultAlloc, + MemoryFailingAlloc, + NULL); + InitFreeFuncs(&expected_calls, MemoryDefaultFree, NULL); + + handle = MemoryLoadLibraryEx( + data, size, MemoryMockAlloc, MemoryMockFree, MemoryDefaultLoadLibrary, + MemoryDefaultGetProcAddress, MemoryDefaultFreeLibrary, &expected_calls); + + free_calls_after_loading = expected_calls.current_free_call; + + MemoryFreeLibrary(handle); + assert(expected_calls.current_free_call == free_calls_after_loading); +} + +void TestFreeAfterDefaultAlloc(void *data, long size) { + CallList expected_calls; + HMEMORYMODULE handle; + int free_calls_after_loading; + + // Note: free might get called internally multiple times + InitFreeFunc(&expected_calls, MemoryDefaultFree); + + handle = MemoryLoadLibraryEx( + data, size, MemoryDefaultAlloc, MemoryMockFree, MemoryDefaultLoadLibrary, + MemoryDefaultGetProcAddress, MemoryDefaultFreeLibrary, &expected_calls); + + assert(handle != NULL); + free_calls_after_loading = expected_calls.current_free_call; + + MemoryFreeLibrary(handle); + assert(expected_calls.current_free_call == free_calls_after_loading + 1); +} + +void TestCustomAllocAndFree(void) +{ + void *data; + long size; + + data = ReadLibrary(&size); + if (data == NULL) + { + return; + } + + _tprintf(_T("Test MemoryLoadLibraryEx after initially failing allocation function\n")); + TestFailingAllocation(data, size); + _tprintf(_T("Test cleanup after MemoryLoadLibraryEx with failing allocation function\n")); + TestCleanupAfterFailingAllocation(data, size); + _tprintf(_T("Test custom free function after MemoryLoadLibraryEx\n")); + TestFreeAfterDefaultAlloc(data, size); + + free(data); +} + int main() { LoadFromFile(); printf("\n\n"); LoadFromMemory(); + printf("\n\n"); + TestCustomAllocAndFree(); return 0; }