#include <Security/cssmalloc.h>
#include <Security/memutils.h>
#include <Security/globalizer.h>
#include <Security/trackingallocator.h>
#include <stdlib.h>
#include <errno.h>
using LowLevelMemoryUtilities::alignof;
using LowLevelMemoryUtilities::increment;
using LowLevelMemoryUtilities::alignUp;
extern "C" size_t malloc_size(void *);
//
// Two allocator objects compare equal only when they are the very same object.
//
bool CssmAllocator::operator == (const CssmAllocator &alloc) const throw()
{
	const CssmAllocator *other = &alloc;
	return other == this;
}
//
// Out-of-line (empty) destructor for the allocator base; nothing to release here.
//
CssmAllocator::~CssmAllocator()
{ }
//
// Allocator backed directly by the C library heap (::malloc/::realloc/::free).
// Allocation failures are reported by throwing std::bad_alloc.
//
struct DefaultCssmAllocator : CssmAllocator {		// struct => public inheritance
	void *malloc(size_t size) throw(std::bad_alloc);
	void *realloc(void *addr, size_t size) throw(std::bad_alloc);
	void free(void *addr) throw();
};
//
// Variant of the default heap allocator that clears memory before giving it
// back to the heap; only release paths (realloc/free) are customized.
//
struct SensitiveCssmAllocator : DefaultCssmAllocator {	// struct => public inheritance
	void *realloc(void *addr, size_t size) throw(std::bad_alloc);
	void free(void *addr) throw();
};
//
// Bundle of the two built-in allocator instances handed out by
// CssmAllocator::standard(). Members are constructed in declaration order.
//
struct DefaultAllocators {
DefaultCssmAllocator standard;		// returned for the `normal` request
SensitiveCssmAllocator sensitive;	// returned for the `sensitive` request
};
// Lazily-created holder for the allocator bundle; accessed as defaultAllocators().
// NOTE(review): construction/lifetime semantics come from ModuleNexus (not shown here).
static ModuleNexus<DefaultAllocators> defaultAllocators;
//
// Map an allocator request code to one of the process-wide default allocators.
// Unknown request codes raise CSSM_ERRCODE_MEMORY_ERROR.
//
CssmAllocator &CssmAllocator::standard(uint32 request)
{
	if (request == normal)
		return defaultAllocators().standard;
	if (request == sensitive)
		return defaultAllocators().sensitive;
	CssmError::throwMe(CSSM_ERRCODE_MEMORY_ERROR);
}
//
// Allocate `size` bytes from the C heap; a NULL result becomes std::bad_alloc.
//
void *DefaultCssmAllocator::malloc(size_t size) throw(std::bad_alloc)
{
	void *block = ::malloc(size);
	if (block == NULL)
		throw std::bad_alloc();
	return block;
}
//
// Return a block to the C heap (NULL is passed through to ::free unchanged).
//
void DefaultCssmAllocator::free(void *addr) throw()
{ ::free(addr); }
//
// Resize a C-heap block via ::realloc; a NULL result becomes std::bad_alloc.
//
void *DefaultCssmAllocator::realloc(void *addr, size_t newSize) throw(std::bad_alloc)
{
	void *resized = ::realloc(addr, newSize);
	if (resized == NULL)
		throw std::bad_alloc();
	return resized;
}
//
// Clear an entire heap block (its full malloc_size, not just the requested
// size) and then return it to the heap. Freeing NULL is a no-op.
//
void SensitiveCssmAllocator::free(void *addr) throw()
{
	if (addr == NULL)
		return;		// nothing to scrub; ::free(NULL) would do nothing anyway

	// Write through a volatile pointer: a plain memset immediately followed by
	// free() is a dead store that optimizing compilers are allowed to delete,
	// which would silently leave the sensitive contents in the heap.
	volatile unsigned char *p = static_cast<volatile unsigned char *>(addr);
	for (size_t remaining = malloc_size(addr); remaining > 0; --remaining)
		*p++ = 0;

	DefaultCssmAllocator::free(addr);
}
//
// Resize a block that may hold sensitive data without leaving any copy of it
// behind in freed heap memory.
//
// Returns the new block; throws std::bad_alloc (with `addr` left untouched and
// still owned by the caller) if the new block cannot be allocated.
//
void *SensitiveCssmAllocator::realloc(void *addr, size_t newSize) throw(std::bad_alloc)
{
	// No existing block: nothing to scrub, behave exactly like realloc(NULL, n).
	if (addr == NULL)
		return DefaultCssmAllocator::realloc(addr, newSize);

	// Do not hand the old block to ::realloc: it may relocate the data and
	// release the old block without clearing it, leaking the secret contents
	// into free heap memory. Instead allocate first, copy, then scrub-and-free.
	void *newAddr = DefaultCssmAllocator::malloc(newSize);
	size_t oldSize = malloc_size(addr);
	memcpy(newAddr, addr, newSize < oldSize ? newSize : oldSize);

	// Release the old block through this class's free so it is cleared first.
	SensitiveCssmAllocator::free(addr);
	return newAddr;
}
//
// On destruction, hand every block still recorded in mAllocSet back to the
// wrapped allocator.
//
TrackingAllocator::~TrackingAllocator()
{
	for (AllocSet::iterator it = mAllocSet.begin(); it != mAllocSet.end(); ++it)
		mAllocator.free(*it);
}
// Forward an allocation request to the wrapped memory-function table.
void *CssmMemoryFunctionsAllocator::malloc(size_t size) throw(std::bad_alloc)
{
	void *block = functions.malloc(size);
	return block;
}
// Forward a release request to the wrapped memory-function table.
void CssmMemoryFunctionsAllocator::free(void *addr) throw()
{
	functions.free(addr);
}
// Forward a resize request to the wrapped memory-function table.
void *CssmMemoryFunctionsAllocator::realloc(void *addr, size_t size) throw(std::bad_alloc)
{
	void *resized = functions.realloc(addr, size);
	return resized;
}
//
// Build a memory-function table whose four callbacks relay to `alloc`.
// The allocator's address is stored in AllocRef; presumably the caller passes
// it back as the `ref` argument of each relay (confirm against CSSM usage).
//
CssmAllocatorMemoryFunctions::CssmAllocatorMemoryFunctions(CssmAllocator &alloc)
{
	malloc_func = relayMalloc;
	realloc_func = relayRealloc;
	free_func = relayFree;
	calloc_func = relayCalloc;
	AllocRef = &alloc;
}
void *CssmAllocatorMemoryFunctions::relayMalloc(size_t size, void *ref) throw(std::bad_alloc)
{ return allocator(ref).malloc(size); }
void CssmAllocatorMemoryFunctions::relayFree(void *mem, void *ref) throw()
{ allocator(ref).free(mem); }
void *CssmAllocatorMemoryFunctions::relayRealloc(void *mem, size_t size, void *ref) throw(std::bad_alloc)
{ return allocator(ref).realloc(mem, size); }
//
// calloc callback: allocate `count * size` zero-filled bytes from the
// allocator identified by `ref`.
//
// Throws std::bad_alloc if the total byte count would overflow size_t.
// If the underlying allocator returns NULL instead of throwing, NULL is
// returned to the caller.
//
void *CssmAllocatorMemoryFunctions::relayCalloc(uint32 count, size_t size, void *ref) throw(std::bad_alloc)
{
	// Guard the multiplication: a wrapped product would yield a block far
	// smaller than the caller believes it owns (heap overflow on first use).
	if (size != 0 && count > size_t(-1) / size)
		throw std::bad_alloc();
	size_t total = size * count;

	void *mem = allocator(ref).malloc(total);
	if (mem != NULL)	// not every CssmAllocator throws on failure
		memset(mem, 0, total);
	return mem;
}
void *CssmHeap::operator new (size_t size, CssmAllocator *alloc) throw(std::bad_alloc)
{
if (alloc == NULL)
alloc = &CssmAllocator::standard();
size = alignUp(size, alignof<CssmAllocator *>());
size_t totalSize = size + sizeof(CssmAllocator *);
void *addr = alloc->malloc(totalSize);
*(CssmAllocator **)increment(addr, size) = alloc;
return addr;
}
//
// Allocator-argument operator delete: release `addr` through `alloc`.
// `alloc` may legitimately be NULL here. operator new accepts NULL as "use the
// standard allocator", and when this is invoked as the placement counterpart
// (constructor threw) it receives the same argument. Apply the same fallback
// rather than dereferencing NULL.
// `size` is unused.
//
void CssmHeap::operator delete (void *addr, size_t size, CssmAllocator *alloc) throw()
{
	if (alloc == NULL)
		alloc = &CssmAllocator::standard();
	alloc->free(addr);
}
void CssmHeap::operator delete (void *addr, size_t size) throw()
{
void *end = increment(addr, alignUp(size, alignof<CssmAllocator *>()));
(*(CssmAllocator **)end)->free(addr);
}