#ifndef _H_CSSMALLOC
#define _H_CSSMALLOC
#include <security_utilities/alloc.h>
#include <Security/cssm.h>
#include <cstring>
namespace Security
{
//
// C++ view of a CSSM_MEMORY_FUNCS table. The member functions call through the
// table's function pointers, always passing AllocRef as the context argument.
//
class CssmMemoryFunctions : public PodWrapper<CssmMemoryFunctions, CSSM_MEMORY_FUNCS> {
public:
	// Copy an existing C-level function table (implicit conversion is intended).
	CssmMemoryFunctions(const CSSM_MEMORY_FUNCS &funcs)
	{ static_cast<CSSM_MEMORY_FUNCS &>(*this) = funcs; }

	// Empty table: every field is value-initialized (null pointers), so copying
	// or comparing a default-constructed object never reads indeterminate memory.
	CssmMemoryFunctions()
	{ static_cast<CSSM_MEMORY_FUNCS &>(*this) = CSSM_MEMORY_FUNCS(); }

	// malloc/realloc/calloc throw std::bad_alloc when the underlying function
	// returns NULL; free forwards to free_func unconditionally.
	void *malloc(size_t size) const throw(std::bad_alloc);
	void free(void *mem) const throw() { free_func(mem, AllocRef); }
	void *realloc(void *mem, size_t size) const throw(std::bad_alloc);
	void *calloc(uint32 count, size_t size) const throw(std::bad_alloc);

	// Bytewise equality of the underlying C struct. The comparison is sized by
	// the POD itself, not by sizeof(*this), so it cannot read past 'other'.
	bool operator == (const CSSM_MEMORY_FUNCS &other) const throw()
	{
		return !memcmp(static_cast<const CSSM_MEMORY_FUNCS *>(this), &other,
			sizeof(CSSM_MEMORY_FUNCS));
	}
};
//
// Allocate 'size' bytes through malloc_func (with AllocRef as context).
// Never returns NULL: a NULL result from the callback becomes std::bad_alloc.
//
inline void *CssmMemoryFunctions::malloc(size_t size) const throw(std::bad_alloc)
{
	void *block = malloc_func(size, AllocRef);
	if (block == NULL)
		throw std::bad_alloc();
	return block;
}
//
// Allocate 'count' elements of 'size' bytes through calloc_func (with
// AllocRef as context). A NULL result from the callback becomes std::bad_alloc.
//
inline void *CssmMemoryFunctions::calloc(uint32 count, size_t size) const throw(std::bad_alloc)
{
	void *block = calloc_func(count, size, AllocRef);
	if (block == NULL)
		throw std::bad_alloc();
	return block;
}
//
// Resize 'mem' to 'size' bytes through realloc_func (with AllocRef as context).
// A NULL result from the callback becomes std::bad_alloc.
// NOTE(review): if realloc_func follows C realloc and may free 'mem' and return
// NULL for size 0, this throws although 'mem' is gone — confirm the callback's
// contract for zero sizes.
//
inline void *CssmMemoryFunctions::realloc(void *mem, size_t size) const throw(std::bad_alloc)
{
	void *block = realloc_func(mem, size, AllocRef);
	if (block == NULL)
		throw std::bad_alloc();
	return block;
}
//
// An Allocator whose malloc/free/realloc are implemented on top of a fixed
// CssmMemoryFunctions table (the method bodies are defined out of line and
// are not shown here).
//
class CssmMemoryFunctionsAllocator : public Allocator {
	// Function table captured at construction; never changes afterwards.
	const CssmMemoryFunctions functions;

public:
	CssmMemoryFunctionsAllocator(const CssmMemoryFunctions &memFuncs)
		: functions(memFuncs)
	{ }

	void *malloc(size_t size) throw(std::bad_alloc);
	void free(void *addr) throw();
	void *realloc(void *addr, size_t size) throw(std::bad_alloc);

	// Read-only access to the captured function table.
	operator const CssmMemoryFunctions & () const throw()
	{
		return functions;
	}
};
//
// A CSSM_MEMORY_FUNCS table intended to relay to a C++ Allocator. The static
// relay* functions have the C callback shapes; allocator(ref) turns a ref
// pointer back into an Allocator. Presumably the Allocator& constructor
// (defined out of line) stores &alloc in AllocRef — verify in its definition.
//
class CssmAllocatorMemoryFunctions : public CssmMemoryFunctions {
public:
	CssmAllocatorMemoryFunctions(Allocator &alloc);

	// Unbound table: only AllocRef is set here (to NULL); the remaining fields
	// come from the CssmMemoryFunctions default constructor.
	CssmAllocatorMemoryFunctions()
	{
		AllocRef = NULL;
	}

private:
	static void *relayMalloc(size_t size, void *ref) throw(std::bad_alloc);
	static void relayFree(void *mem, void *ref) throw();
	static void *relayRealloc(void *mem, size_t size, void *ref) throw(std::bad_alloc);
	static void *relayCalloc(uint32 count, size_t size, void *ref) throw(std::bad_alloc);

	// Interpret an opaque ref pointer as the Allocator it points to.
	static Allocator &allocator(void *ref) throw()
	{
		Allocator *target = static_cast<Allocator *>(ref);
		return *target;
	}
};
//
// Grows a caller-owned (count, array) pair in place. The pair is bound by
// reference, for example to the count and pointer fields of a C output structure.
// - The constructor resets both fields to empty.
// - operator+= appends one element using allocator.realloc.
// - The destructor frees the array through the same allocator.
//
// NOTE(review): the implicit copy constructor copies the reference members,
// so two copies would both free the same array. Do not copy instances.
//
template <class Base, class Wrapper = Base>
class CssmVector {
public:
	CssmVector(uint32 &cnt, Base * &vec, Allocator &alloc = Allocator::standard())
		: count(cnt), vector(reinterpret_cast<Wrapper * &>(vec)),
		  allocator(alloc)
	{
		count = 0;
		vector = NULL;
	}

	// Free the array and null the caller's pointer so it is not left dangling.
	// The count is reset only if an array was still attached. If the caller
	// detached it beforehand by nulling 'vector', the count is left untouched.
	~CssmVector()
	{
		const bool attached = (vector != NULL);
		allocator.free(vector);
		vector = NULL;
		if (attached)
			count = 0;
	}

	uint32 &count;			// bound to the caller's element count
	Wrapper * &vector;		// bound to the caller's array pointer, viewed as Wrapper *
	Allocator &allocator;	// used for every realloc and for the final free

public:
	// Element access; the index is checked only by assert.
	Wrapper &operator [] (uint32 ix)
	{ assert(ix < count); return vector[ix]; }
	const Wrapper &operator [] (uint32 ix) const
	{ assert(ix < count); return vector[ix]; }

	// Append one element, growing the array by exactly one slot.
	void operator += (const Wrapper &add)
	{
		// Copy first: 'add' may refer to an element of this very array
		// (e.g. v += v[0]), and the realloc below may relocate that storage.
		const Wrapper item = add;
		// Compute the byte size in size_t so that count + 1 cannot wrap in 32 bits.
		vector = reinterpret_cast<Wrapper *>(
			allocator.realloc(vector, (size_t(count) + 1) * sizeof(Wrapper)));
		vector[count++] = item;
	}
};
}
#endif //_H_CSSMALLOC