#ifndef _SECURITY_REFCOUNT_H_
#define _SECURITY_REFCOUNT_H_
#include <security_utilities/threading.h>
#include <libkern/OSAtomic.h>
namespace Security {
//
// Reference-count tracing hooks.
//
// With DEBUG_REFCOUNTS set, RCDEBUG_CREATE() and RCDEBUG(kind, n) hand a
// message tagged "refcount" to secinfo(); those expansions mention `this`,
// so they can only be used inside non-static member functions.
// Without DEBUG_REFCOUNTS both macros expand to nothing, which means their
// arguments are never evaluated - do not put side effects in them.
//
#if DEBUG_REFCOUNTS
# define RCDEBUG_CREATE()   secinfo("refcount", "%p: CREATE", this)
# define RCDEBUG(kind, n)   secinfo("refcount", "%p: %s: %d", this, #kind, n)
#else
# define RCDEBUG_CREATE()
// Standard variadic form (instead of the GNU-only `name...` spelling) so the
// disabled macro swallows whatever the enabled one would have been given.
# define RCDEBUG(kind, ...)
#endif
//
// RefCount is a base class that gives an object an intrusive reference count.
// The count starts at zero and is changed only through ref() and unref(),
// which are protected and reachable by RefPointer (declared a friend below).
// Both use the OSAtomic 32-bit increment/decrement primitives.
//
class RefCount {
public:
    // A freshly created object has no references yet.
    RefCount() : mRefCount(0) { RCDEBUG_CREATE(); }

    // The count records how many pointers refer to *this* object; it is not
    // part of the object's value. A copy is a new object that nothing refers
    // to yet, so it starts at zero instead of inheriting the source's count.
    RefCount(const RefCount &) : mRefCount(0) { RCDEBUG_CREATE(); }

    // Assigning a new value to an object must not disturb the references
    // that are already held on it, so the count is deliberately left alone.
    RefCount &operator = (const RefCount &) { return *this; }

protected:
    template <class T> friend class RefPointer;

    // Atomically add one reference.
    void ref() const
    {
        // Trace the value produced by the atomic operation itself rather
        // than re-reading the counter, which another thread may have changed.
        const int32_t count = OSAtomicIncrement32(&mRefCount);
        (void)count;    // unused when DEBUG_REFCOUNTS tracing is compiled out
        RCDEBUG(UP, count);
    }

    // Atomically drop one reference and return the count that remains.
    // RefPointer deletes the object when this returns 0.
    unsigned int unref() const
    {
        const int32_t count = OSAtomicDecrement32(&mRefCount);
        RCDEBUG(DOWN, count);
        return count;
    }

    // Plain (non-atomic-operation) read of the counter. As the name says,
    // this is meant for debug output only, not for making decisions.
    unsigned int refCountForDebuggingOnly() const { return mRefCount; }

private:
    // mutable so that ref()/unref() can be const member functions.
    volatile mutable int32_t mRefCount;
};
//
// RefPointer<T> is a smart pointer for objects that expose ref()/unref() the
// way RefCount does. Every non-null RefPointer holds exactly one reference on
// its target; the RefPointer whose unref() call returns 0 deletes the target.
//
// release() and setPointer() hold this instance's mMutex (through StLock)
// while they change ptr. The constructors and the accessors do not take it.
//
template <class T>
class RefPointer {
    // Other instantiations need access to ptr for the converting copy/assign.
    template <class Sub> friend class RefPointer;

public:
    // Null pointer; holds no reference.
    RefPointer() : ptr(0) {}

    // Share p's target, taking one more reference on it (if any).
    RefPointer(const RefPointer& p) : ptr(p.ptr) { if (p.ptr) p.ptr->ref(); }

    // Adopt a raw pointer, taking a reference on it (if non-null).
    RefPointer(T *p) : ptr(p) { if (p) p->ref(); }

    // Converting copy from a RefPointer to a type whose pointer converts to T*.
    template <class Sub>
    RefPointer(const RefPointer<Sub>& p) : ptr(p.ptr) { if (p.ptr) p.ptr->ref(); }

    // Give up this pointer's reference.
    ~RefPointer() { release(); }

    // Assignment: reference the new target, then drop the old one.
    RefPointer& operator = (const RefPointer& p)    { setPointer(p.ptr); return *this; }
    RefPointer& operator = (T * p)                  { setPointer(p); return *this; }
    template <class Sub>
    RefPointer& operator = (const RefPointer<Sub>& p) { setPointer(p.ptr); return *this; }

    // Access to the raw pointer. None of these change the reference count.
    T* get () const             { _check(); return ptr; }
    operator T * () const       { _check(); return ptr; }
    T * operator -> () const    { _check(); return ptr; }
    T & operator * () const     { _check(); return *ptr; }

protected:
    // Drop the reference held through ptr (if any), deleting the target when
    // that was the last reference. Both callers in this class hold mMutex.
    void release_internal()
    {
        if (ptr)
        {
            if (ptr->unref() == 0)
                delete ptr;
            // The reference is gone whether or not the target was deleted, so
            // ptr must not keep pointing at it: a later release() or the
            // destructor would otherwise unref the same target a second time.
            ptr = NULL;
        }
    }

    // Locked wrapper around release_internal().
    void release()
    {
        StLock<Mutex> mutexLock(mMutex);
        release_internal();
    }

    // Retarget this pointer at p (which may be null).
    void setPointer(T *p)
    {
        StLock<Mutex> mutexLock(mMutex);
        // Take the new reference before dropping the old one, so that
        // assigning the current target to itself cannot delete it in between.
        if (p)
        {
            p->ref();
        }
        release_internal();
        ptr = p;
    }

    // Hook called by every accessor; currently does nothing.
    void _check() const { }

    T *ptr;         // target, or null
    Mutex mMutex;   // held by release() and setPointer()
};
//
// Ordering of RefPointers.
// Two non-null pointers are ordered by comparing their targets with T's own
// operator <. A null pointer sorts before every non-null pointer, and two
// null pointers are equivalent.
//
template <class T>
bool operator <(const RefPointer<T> &r1, const RefPointer<T> &r2)
{
    T *p1 = r1.get(), *p2 = r2.get();
    if (p1 && p2)
        return *p1 < *p2;
    // At least one side is null. Spell the rule out instead of applying the
    // built-in < to a null and an unrelated object pointer, whose result the
    // language leaves unspecified: only "null < non-null" is true.
    return !p1 && p2;
}
//
// Equality of RefPointers.
// When both sides point at something, the targets are compared with T's
// operator ==. Otherwise the raw pointers are compared, so two null pointers
// are equal and a null pointer never equals a non-null one.
//
template <class T>
bool operator ==(const RefPointer<T> &r1, const RefPointer<T> &r2)
{
    T *lhs = r1.get();
    T *rhs = r2.get();
    if (lhs && rhs)
        return *lhs == *rhs;
    return lhs == rhs;
}
//
// Inequality of RefPointers.
// If either side is null the raw pointers are compared; otherwise the
// targets are compared with T's operator !=.
//
template <class T>
bool operator !=(const RefPointer<T> &r1, const RefPointer<T> &r2)
{
    T *lhs = r1.get();
    T *rhs = r2.get();
    if (!lhs || !rhs)
        return lhs != rhs;
    return *lhs != *rhs;
}
}
#endif // !_SECURITY_REFCOUNT_H_