#include <string.h>
#include <errno.h> // system call error numbers
#include <unistd.h> // for select call
#include <stdlib.h> // for calloc()
#include <poll.h>
#include <sys/time.h> // for struct timeval
#include <libkern/OSAtomic.h>
#include <sys/socket.h>
#include <netdb.h>
#include "DSCThread.h" // for GetCurThreadRunState()
#include "DSTCPEndpoint.h"
#ifdef DSSERVERTCP
#include "CLog.h"
#else
#define DbgLog(...)
#endif
#include "SharedConsts.h" // for sComData
#include "DirServicesConst.h"
#include "DSTCPEndian.h"
#include "DSSwapUtils.h"
// Class-wide message counter; SendMessage() advances it with OSAtomicIncrement32() and stamps
// the resulting value into each outgoing proxy message (fMsgID).
int32_t DSTCPEndpoint::mMessageID = 0;
// Fixed parameter block that the constructor installs in fParamBlock and that ProcessData()
// hands to cdsaDhGenerateKeyPair().
// NOTE(review): the bytes read like a DER SEQUENCE (0x30) holding an OBJECT IDENTIFIER (0x06)
// followed by a nested SEQUENCE of two INTEGERs (0x02) -- one large value and the value 2 --
// i.e. presumably the Diffie-Hellman prime and generator. Confirm before editing; both peers
// would have to agree on these bytes.
static uint8 paramBlob[] = { \
0x30, 0x52, 0x06, 0x08, 0x2a, 0x86, 0x48, 0x86, 0xf7, 0x0d, 0x01, 0x03, 0x30, 0x46, 0x02, 0x41,\
0x00, 0xa0, 0xd4, 0x42, 0xd5, 0x68, 0x08, 0x94, 0xc9, 0xef, 0xb7, 0x18, 0x9c, 0x0b, 0x72, 0x53,\
0xac, 0x8a, 0x7b, 0xc2, 0x40, 0x17, 0x96, 0x29, 0xd1, 0xf2, 0x96, 0xe8, 0x2b, 0x4e, 0x48, 0xaf,\
0x59, 0xbe, 0x29, 0xc4, 0x9b, 0x52, 0xda, 0x05, 0x18, 0x29, 0x73, 0xff, 0xd5, 0x26, 0x47, 0x53,\
0x54, 0x79, 0xf4, 0x39, 0x96, 0x6f, 0x61, 0x5e, 0xe6, 0xfc, 0x92, 0x7d, 0xf4, 0x20, 0x6e, 0xa9,\
0xa3, 0x02, 0x01, 0x02 };
#pragma mark **** Class Methods ****
#pragma mark **** Instance Methods ****
// Constructs an endpoint around an (optionally already open) socket descriptor.
//
//   inOpenTimeout  seconds ConnectTo() keeps retrying connect()
//   inRWTimeout    seconds DoTCPRecvFrom() waits in select(); also stored as the default timeout
//   inSocket       descriptor to adopt as the connection socket
//
// The key-exchange state machine starts in eKeyStateAcceptClientKey. All key material and the
// crypto bookkeeping members are given defined values before the CSP attach is attempted, so
// that a failed attach never leaves fParamBlock, fcspHandle or fChallengeValue indeterminate
// for the destructor or ProcessData() to read.
DSTCPEndpoint::DSTCPEndpoint ( const UInt32 inOpenTimeout,
                               const UInt32 inRWTimeout,
                               int inSocket ) :
    mRemoteHostIPAddr (0),
    mConnectFD (inSocket),
    mWeHaveClosed (false),
    mOpenTimeout (inOpenTimeout),
    mRWTimeout (inRWTimeout),
    mDefaultTimeout(inRWTimeout),
    fKeyState(eKeyStateAcceptClientKey)
{
    memset( &mMySockAddr, 0, sizeof(mMySockAddr) );
    mRemoteHostIPString[0] = '\0';
    memset( &mRemoteSockAddr, 0, sizeof(mRemoteSockAddr) );

    bzero( &fPrivateKey, sizeof(fPrivateKey) );
    bzero( &fPublicKey, sizeof(fPublicKey) );
    bzero( &fDerivedKey, sizeof(fDerivedKey) );

    // Defined defaults in case the attach below fails.
    fChallengeValue = 0;
    fcspHandle = 0;
    fParamBlock.Data = NULL;
    fParamBlock.Length = 0;

    if ( cdsaCspAttach(&fcspHandle) == CSSM_OK )
    {
        // Only publish the key-generation parameters once a CSP is available.
        fParamBlock.Data = paramBlob;
        fParamBlock.Length = sizeof(paramBlob);
    }
}
// Tears the endpoint down: closes the connection socket unless CloseConnection() already did,
// then releases the three keys and detaches from the CSP.
DSTCPEndpoint::~DSTCPEndpoint ( void )
{
    // Nothing thrown while closing may leave the destructor.
    try
    {
        if ( !mWeHaveClosed )
            DoTCPCloseSocket( mConnectFD );
    }
    catch ( ... )
    {
        // intentionally ignored
    }

    // Keys are released before the handle they were created with is detached.
    cdsaFreeKey( fcspHandle, &fPrivateKey );
    cdsaFreeKey( fcspHandle, &fPublicKey );
    cdsaFreeKey( fcspHandle, &fDerivedKey );
    cdsaCspDetach( fcspHandle );
}
// Opens a TCP socket and connects it to the address carried by inAddrInfo.
//
// Returns:
//   eDSNoErr           connected; mConnectFD holds the socket and mRemoteSockAddr the peer address
//   eDSIPUnreachable   connect() was refused
//   eDSServerTimeout   no connect() succeeded within mOpenTimeout seconds
//   eDSTCPSendError    bad argument, socket() failure, any other connect() error, or the
//                      SO_NOSIGPIPE option could not be applied
//
// Descriptor 0 is treated as "no socket" by the rest of this class, so if socket() hands back 0
// the whole attempt is repeated to obtain a non-zero descriptor, and the socket on descriptor 0
// is closed once the replacement has connected.
SInt32 DSTCPEndpoint::ConnectTo ( struct addrinfo *inAddrInfo )
{
    int     err             = eDSNoErr;
    int     result          = -1;
    int     sockfd          = -1;
    time_t  timesUp;
    int     rc              = eDSNoErr;
    bool    releaseZeroFD   = false;

    if ( (inAddrInfo == NULL) || (inAddrInfo->ai_addr == NULL) )
    {
        return( eDSTCPSendError );
    }
    struct sockaddr *serverAddr = inAddrInfo->ai_addr;

    // Bounded snapshot of the target address, used for the log lines below. The socket is
    // opened as AF_INET (see DoTCPOpenSocket), so the address is read as a sockaddr_in.
    struct sockaddr_in peerAddr;
    size_t peerLen = serverAddr->sa_len;
    if ( peerLen > sizeof(peerAddr) )
    {
        peerLen = sizeof(peerAddr);
    }
    memset( &peerAddr, 0, sizeof(peerAddr) );
    memcpy( &peerAddr, serverAddr, peerLen );

    do {
        sockfd = DoTCPOpenSocket();
        if ( sockfd < 0 )
        {
            return( eDSTCPSendError );
        }
        mConnectFD = sockfd;

        // Nothing has connected on this descriptor yet. Without this reset a zero-length
        // retry window (or a previous pass on descriptor 0) would be reported as success.
        result = -1;

        timesUp = ::time(NULL) + mOpenTimeout;
        while ( ::time(NULL) < timesUp )
        {
            result = ::connect( mConnectFD, serverAddr, serverAddr->sa_len );
            if ( result == -1 )
            {
                err = errno;
                switch ( err )
                {
                    case ETIMEDOUT:
                        // keep trying until the overall deadline passes
                        continue;

                    case ECONNREFUSED:
                        LOG2( kStdErr, "ConnectTo: connect() error: %d, %s", err, strerror(err) );
                        return( eDSIPUnreachable );

                    default:
                        LOG2( kStdErr, "ConnectTo: connect() error: %d, %s", err, strerror(err) );
                        return( eDSTCPSendError );
                }
            }
            else
            {
                if ( (sockfd != 0) && (releaseZeroFD) )
                {
                    // A usable non-zero descriptor is connected; drop the one parked on 0.
                    int rcSock = ::close( 0 );
                    if ( rcSock == -1 )
                    {
                        err = errno;
#ifdef DSSERVERTCP
                        DbgLog( kLogTCPEndpoint, "DoTCPCloseSocket: close() on unused socket 0 failed with error %d: %s", err, strerror(err) );
#else
                        LOG2( kStdErr, "DoTCPCloseSocket: close() on unused socket 0 failed with error %d: %s", err, strerror(err) );
#endif
                    }
                    else
                    {
#ifdef DSSERVERTCP
                        DbgLog( kLogTCPEndpoint, "DoTCPCloseSocket: close() on unused socket 0" );
#else
                        LOG( kStdErr, "DoTCPCloseSocket: close() on unused socket 0" );
#endif
                    }
                }
                break;
            }
        }

        if ( sockfd == 0 )
        {
            releaseZeroFD = true;
        }
    } while ( sockfd == 0 );

    if ( result == 0 )
    {
        // Record the peer address itself (not the bytes of the local pointer variable),
        // never copying more than either side holds.
        size_t copyLen = serverAddr->sa_len;
        if ( copyLen > sizeof(mRemoteSockAddr) )
        {
            copyLen = sizeof(mRemoteSockAddr);
        }
        memset( &mRemoteSockAddr, 0, sizeof(mRemoteSockAddr) );
        memcpy( &mRemoteSockAddr, serverAddr, copyLen );

        rc = this->SetSocketOption( mConnectFD, SO_NOSIGPIPE );
        if ( rc != 0 )
        {
            return( eDSTCPSendError );
        }
        LOG2( kStdErr, "Established TCP connection to %d on port %d.", (int) ntohl( peerAddr.sin_addr.s_addr ), (int) ntohs( peerAddr.sin_port ) );
        return( eDSNoErr );
    }
    else
    {
        LOG2( kStdErr, "Unable to connect to %d on port %d.", (int) ntohl( peerAddr.sin_addr.s_addr ), (int) ntohs( peerAddr.sin_port ) );
        return( eDSServerTimeout );
    }
}
// Copies the stored remote-host address string (mRemoteHostIPString) into the caller's buffer.
//
//   ioBuffer     destination; ignored when NULL
//   inBufferLen  capacity of ioBuffer in bytes; ignored when not positive
//
// The result is always NUL-terminated, truncating the string if the buffer is too small.
void DSTCPEndpoint::GetReverseAddressString ( char *ioBuffer,
                                              const int inBufferLen) const
{
    // A non-positive length would otherwise be converted to a huge size_t by strncpy().
    if ( (ioBuffer == NULL) || (inBufferLen <= 0) )
    {
        return;
    }

    ::strncpy( ioBuffer, mRemoteHostIPString, inBufferLen );
    // strncpy() does not terminate when the source fills the whole buffer.
    ioBuffer[ inBufferLen - 1 ] = '\0';
}
// Reports whether the connection socket still looks usable, using a non-blocking poll().
//
// Returns false when there is no socket, when poll() itself fails, or when poll() reports the
// descriptor as hung up, in error, or invalid; true otherwise.
Boolean DSTCPEndpoint::Connected ( void ) const
{
    struct pollfd   fdToPoll;
    int             result;

    // CloseConnection() records a closed socket as 0 and the other socket helpers in this
    // class treat any descriptor <= 0 as "none", so do not poll descriptor 0 here either.
    if ( mConnectFD <= 0 )
        return false;

    fdToPoll.fd = mConnectFD;
    fdToPoll.events = POLLSTANDARD;
    fdToPoll.revents = 0;

    result = poll( &fdToPoll, 1, 0 );   // timeout 0: just sample the current state
    if ( result == -1 )
        return false;

    // POLLNVAL (descriptor not open) and POLLERR (pending socket error) mean the connection
    // is no more usable than one that has hung up.
    return ( (fdToPoll.revents & (POLLHUP | POLLERR | POLLNVAL)) == 0 );
}
// Closes the connection socket if one is open. On a successful close the descriptor is reset
// to 0 and the endpoint is flagged as closed so the destructor will not close it again; if the
// close fails the members are left untouched.
void DSTCPEndpoint::CloseConnection ( void )
{
    // Descriptors <= 0 mean there is nothing to close.
    if ( mConnectFD <= 0 )
        return;

    const int closeStatus = this->DoTCPCloseSocket( mConnectFD );
    if ( closeStatus != eDSNoErr )
        return;

    mConnectFD = 0;
    mWeHaveClosed = true;
}
// Creates an IPv4 stream socket.
//
// Returns the new descriptor, or -1 when socket() fails (the failure is logged with errno).
// errno is consulted only after a failed call: its value after a successful socket() is
// whatever an earlier call left behind, so logging it would report errors that never happened.
int DSTCPEndpoint::DoTCPOpenSocket (void)
{
    int sockfd;

#ifdef DSSERVERTCP
    DbgLog( kLogTCPEndpoint, "Open socket." );
#else
    LOG( kStdErr, "Open socket." );
#endif

    sockfd = ::socket( AF_INET, SOCK_STREAM, 0 );
    if ( sockfd == -1 )
    {
        const int err = errno;
#ifdef DSSERVERTCP
        DbgLog( kLogTCPEndpoint, "DoTCPOpenSocket: socket() error %d: %s", err, strerror(err) );
#else
        LOG2( kStdErr, "DoTCPOpenSocket: Unable to open a socket with error %d: %s", err, strerror(err) );
#endif
    }

    return( sockfd );
}
// Enables (sets to 1) a SOL_SOCKET-level boolean option on the connection socket.
//
//   inSocket        must be the endpoint's own socket (mConnectFD); 0 is accepted as a no-op
//   inSocketOption  option name passed to setsockopt()
//
// Returns 0 on success or when inSocket is 0, and -1 when inSocket is not the endpoint's
// socket or when setsockopt() fails. The failure is propagated because ConnectTo() checks this
// result; previously a failed setsockopt() was logged and then reported as success.
int DSTCPEndpoint::SetSocketOption ( const int inSocket, const int inSocketOption )
{
    int         rc  = 0;
    int         val = 1;
    socklen_t   len = sizeof(val);

    if ( inSocket != 0 )
    {
        if ( inSocket != mConnectFD )
        {
#ifdef DSSERVERTCP
            ErrLog( kLogTCPEndpoint, "SetSocketOption: invalid socket: %d", inSocket );
#else
            LOG1( kStdErr, "SetSocketOption: invalid socket: %d", inSocket );
#endif
            return( -1 );
        }

        rc = ::setsockopt( inSocket, SOL_SOCKET, inSocketOption, &val, len );
        if ( rc != 0 )
        {
            const int err = errno;
#ifdef DSSERVERTCP
            DbgLog( kLogError, "Unable to set socket option: Message: \"%s\", Error: %d", strerror(err), err );
#else
            LOG2( kStdErr, "Unable to set socket option: Message: \"%s\", Error: %d", strerror(err), err );
#endif
        }
    }

    return( rc );
}
int DSTCPEndpoint::DoTCPCloseSocket ( const int inSockFD )
{
int err = eDSNoErr;
int rc = 0;
if ( inSockFD <= 0 )
{
return( eDSNoErr );
}
#ifdef DSSERVERTCP
DbgLog( kLogTCPEndpoint, "Close socket." );
#endif
rc = ::close( inSockFD );
if ( rc == -1 )
{
err = errno;
#ifdef DSSERVERTCP
DbgLog( kLogTCPEndpoint, "DoTCPCloseSocket: close() on socket %d failed with error %d: %s", inSockFD, err, strerror(err) );
#else
LOG3( kStdErr, "DoTCPCloseSocket: close() on socket %d failed with error %d: %s", inSockFD, err, strerror(err) );
#endif
}
return( err );
}
// Waits up to mRWTimeout seconds for the connection socket to become readable, then reads
// inBufferSize bytes into ioBuffer with MSG_WAITALL.
//
// Returns the number of bytes received.
//
// Throws (as SInt32):
//   kTimeoutError        select() timed out
//   eDSTCPReceiveError   unusable descriptor, select()/recvfrom() failure, the deadline passed
//                        while retrying after signals, or the peer closed the connection
//
// errno is captured immediately after each system call, because later calls (gettimeofday,
// logging) are free to overwrite it before it is examined.
UInt32 DSTCPEndpoint::DoTCPRecvFrom ( void *ioBuffer, const UInt32 inBufferSize )
{
    int             rc;
    int             err         = 0;
    int             bytesRead   = 0;
    bool            interrupted = false;
    fd_set          readSet;
    struct timeval  tvTimeout;
    struct timeval  tvTimeoutTime;

    // FD_SET() on a negative or too-large descriptor writes outside the fd_set.
    if ( (mConnectFD < 0) || (mConnectFD >= FD_SETSIZE) )
    {
#ifdef DSSERVERTCP
        DbgLog( kLogTCPEndpoint, "DoTCPRecvFrom: socket descriptor %d cannot be used with select()", mConnectFD );
#else
        LOG1( kStdErr, "DoTCPRecvFrom: socket descriptor %d cannot be used with select()", mConnectFD );
#endif
        throw( (SInt32)eDSTCPReceiveError );
    }

    tvTimeout.tv_sec = mRWTimeout;
    tvTimeout.tv_usec = 0;

    // Absolute deadline, used to shrink the timeout when select() is interrupted.
    ::gettimeofday( &tvTimeoutTime, NULL );
    tvTimeoutTime.tv_sec += mRWTimeout;

    do {
        FD_ZERO( &readSet );
        FD_SET( mConnectFD, &readSet );

        rc = ::select( mConnectFD+1, &readSet, NULL, NULL, &tvTimeout );
        err = ( rc == -1 ) ? errno : 0;
        interrupted = ( (rc == -1) && (err == EINTR) );

        if ( interrupted )
        {
            // Retry with whatever time is left before the deadline.
            struct timeval tvNow;
            ::gettimeofday( &tvNow, NULL );
            timersub( &tvTimeoutTime, &tvNow, &tvTimeout );
            if ( tvTimeout.tv_sec < 0 )
            {
#ifdef DSSERVERTCP
                DbgLog( kLogTCPEndpoint, "DoTCPRecvFrom: connection timeout?" );
#else
                LOG( kStdErr, "DoTCPRecvFrom: connection timeout?" );
#endif
                throw( (SInt32)eDSTCPReceiveError );
            }
        }
    } while ( interrupted );

    if ( rc == 0 )
    {
#ifdef DSSERVERTCP
        DbgLog( kLogTCPEndpoint, "DoTCPRecvFrom: timed out waiting for response." );
#else
        LOG( kStdErr, "DoTCPRecvFrom: timed out waiting for response." );
#endif
        throw( (SInt32)kTimeoutError );
    }
    else if ( rc == -1 )
    {
#ifdef DSSERVERTCP
        DbgLog( kLogTCPEndpoint, "DoTCPRecvFrom: select() error %d: %s", err, strerror(err) );
#else
        LOG2( kStdErr, "DoTCPRecvFrom: select() error %d: %s", err, strerror(err) );
#endif
        throw( (SInt32)eDSTCPReceiveError );
    }
    else if ( FD_ISSET(mConnectFD, &readSet) )
    {
        // A signal arriving before any data is a retry, exactly as it is for select() above.
        do
        {
            bytesRead = ::recvfrom( mConnectFD, ioBuffer, inBufferSize, MSG_WAITALL, NULL, NULL );
            err = ( bytesRead == -1 ) ? errno : 0;
        } while ( (bytesRead == -1) && ((err == EAGAIN) || (err == EINTR)) );

        if ( bytesRead == 0 )
        {
#ifdef DSSERVERTCP
            DbgLog( kLogTCPEndpoint, "DoTCPRecvFrom: connection closed by peer - error is %d", err );
#else
            LOG1( kStdErr, "DoTCPRecvFrom: connection closed by peer - error is %d", err );
#endif
            throw( (SInt32)eDSTCPReceiveError );
        }
        else if ( bytesRead == -1 )
        {
#ifdef DSSERVERTCP
            DbgLog( kLogTCPEndpoint, "DoTCPRecvFrom: recvfrom error %d: %s", err, strerror(err) );
#else
            LOG2( kStdErr, "DoTCPRecvFrom: recvfrom error %d: %s", err, strerror(err) );
#endif
            throw( (SInt32)eDSTCPReceiveError );
        }
        else
        {
#ifdef DSSERVERTCP
            DbgLog( kLogTCPEndpoint, "DoTCPRecvFrom(): received %d bytes with endpoint %ld and connectFD %d", bytesRead, (long)this, mConnectFD );
#else
            LOG3( kStdErr, "DoTCPRecvFrom(): received %d bytes with endpoint %ld and connectFD %d", bytesRead, (long)this, mConnectFD );
#endif
        }
    }

    return( (UInt32)bytesRead );
}
// Reads the frame header of the next message: the 4-byte "DSPX" tag followed by a 4-byte
// network-order payload length.
//
//   inStripLeadZeroes  when true, zero bytes at the start of the tag are skipped and further
//                      bytes are read one at a time until a full tag's worth of non-zero bytes
//                      has been collected
//   outBuffLen         (required, non-NULL) receives the payload length; it is set to 0 on
//                      entry, so it is 0 whenever no valid header was read
//
// Returns eDSTCPReceiveError when a read fails or the scratch buffer cannot be allocated, and
// eDSNoErr otherwise. Note that eDSNoErr with *outBuffLen == 0 is also what a tag mismatch or
// a short length read produces; callers must check the length.
SInt32 DSTCPEndpoint::SyncToMessageBody(const Boolean inStripLeadZeroes, UInt32 *outBuffLen)
{
    UInt32  index       = 0;
    UInt32  readBytes   = 0;
    UInt32  newLen      = 0;
    UInt32  curIndex    = kDSTCPEndpointMessageTagSize;
    char   *ourBuffer   = NULL;
    UInt32  buffLen     = 0;
    SInt32  result      = eDSNoErr;

    // Never leave a caller's stale length in place when no header is found.
    *outBuffLen = 0;

    ourBuffer = (char *) calloc( kDSTCPEndpointMaxMessageSize, 1 );
    if ( ourBuffer == NULL )
    {
        return eDSTCPReceiveError;
    }

    // Step 1: read a tag-sized chunk.
    try
    {
        readBytes = DoTCPRecvFrom( ourBuffer, kDSTCPEndpointMessageTagSize );
        if ( readBytes != kDSTCPEndpointMessageTagSize )
        {
            free( ourBuffer );
            *outBuffLen = 0;
#ifdef DSSERVERTCP
            DbgLog( kLogTCPEndpoint, "SyncToMessageBody: attempted read of %d bytes failed with %d bytes read", kDSTCPEndpointMessageTagSize, readBytes );
#else
            LOG2( kStdErr, "SyncToMessageBody: attempted read of %d bytes failed with %d bytes read", kDSTCPEndpointMessageTagSize, readBytes );
#endif
            return eDSTCPReceiveError;
        }
    }
    catch( SInt32 err )
    {
        free( ourBuffer );
#ifdef DSSERVERTCP
        DbgLog( kLogTCPEndpoint, "SyncToMessageBody: attempted read of %d bytes failed in DoTCPRecvFrom with error %d", kDSTCPEndpointMessageTagSize, err );
#else
        LOG2( kStdErr, "SyncToMessageBody: attempted read of %d bytes failed in DoTCPRecvFrom with error %d", kDSTCPEndpointMessageTagSize, err );
#endif
        return eDSTCPReceiveError;
    }

    // Step 2: optionally re-align on the tag by discounting leading zero bytes and pulling in
    // replacement bytes one at a time.
    if ( inStripLeadZeroes )
    {
        for ( index = 0; (index < kDSTCPEndpointMessageTagSize) && (ourBuffer[index] == 0x00); index++ )
        {
            readBytes--;
        }

        try
        {
            while ( (readBytes < kDSTCPEndpointMessageTagSize) && (curIndex < kDSTCPEndpointMaxMessageSize) )
            {
                newLen = DoTCPRecvFrom( ourBuffer+curIndex, 1 );
                if ( newLen != 1 )
                {
                    free( ourBuffer );
                    *outBuffLen = 0;
#ifdef DSSERVERTCP
                    DbgLog( kLogTCPEndpoint, "SyncToMessageBody: align frame by skipping leading zeroes - attempted read of one byte failed with %d bytes read", newLen );
#else
                    LOG1( kStdErr, "SyncToMessageBody: align frame by skipping leading zeroes - attempted read of one byte failed with %d bytes read", newLen );
#endif
                    return eDSTCPReceiveError;
                }
                if ( ourBuffer[curIndex] != 0x00 )
                {
                    readBytes++;
                }
                curIndex++;
            }
        }
        catch( SInt32 err )
        {
            free( ourBuffer );
            // "%l" is not a valid conversion; the error code is a 32-bit integer.
#ifdef DSSERVERTCP
            DbgLog( kLogTCPEndpoint, "SyncToMessageBody: align frame by skipping leading zeroes - failed in DoTCPRecvFrom with error %d", err );
#else
            LOG1( kStdErr, "SyncToMessageBody: align frame by skipping leading zeroes - failed in DoTCPRecvFrom with error %d", err );
#endif
            return eDSTCPReceiveError;
        }
    }

    // Step 3: with the tag in the last tag-sized window of the buffer, read the length word.
    if ( (readBytes == kDSTCPEndpointMessageTagSize) && (strncmp(ourBuffer+curIndex-kDSTCPEndpointMessageTagSize,"DSPX",kDSTCPEndpointMessageTagSize) == 0) )
    {
        try
        {
            newLen = DoTCPRecvFrom( &buffLen, 4 );
            if ( newLen != 4 )
            {
#ifdef DSSERVERTCP
                DbgLog( kLogTCPEndpoint, "SyncToMessageBody: get the buffer length - attempted read of four bytes failed with %d bytes read", newLen );
#else
                LOG1( kStdErr, "SyncToMessageBody: get the buffer length - attempted read of four bytes failed with %d bytes read", newLen );
#endif
                *outBuffLen = 0;
            }
            else
            {
                *outBuffLen = ntohl( buffLen );
            }
        }
        catch( SInt32 err )
        {
            free( ourBuffer );
#ifdef DSSERVERTCP
            DbgLog( kLogTCPEndpoint, "SyncToMessageBody: get the buffer length - failed in DoTCPRecvFrom with error %d", err );
#else
            LOG1( kStdErr, "SyncToMessageBody: get the buffer length - failed in DoTCPRecvFrom with error %d", err );
#endif
            return eDSTCPReceiveError;
        }
    }

    free( ourBuffer );
    return result;
}
// Frames inBuffer as "DSPX" + network-order length + payload and writes the whole frame to the
// connection socket, resuming after partial sends.
//
//   inBuffer  payload bytes; may be NULL only when inLength is 0
//   inLength  payload size in bytes
//
// Returns eDSNoErr once every byte has been handed to send(), or eDSTCPSendError on a bad
// argument, allocation failure, or a send() error other than EINTR/EAGAIN.
SInt32 DSTCPEndpoint::SendBuffer ( void *inBuffer, UInt32 inLength )
{
    SInt32      result      = eDSNoErr;
    uint32_t    offset      = 0;

    if ( (inBuffer == NULL) && (inLength != 0) )
    {
        return eDSTCPSendError;
    }

    // sizeof("DSPX") includes the literal's terminating NUL, so the frame is one byte longer
    // than tag + length + payload, and calloc() leaves that last byte zero.
    // NOTE(review): SyncToMessageBody() skips leading zero bytes when asked to, which
    // presumably is what absorbs this trailing byte on the receiving side -- confirm with the
    // peer implementation before changing the frame size.
    UInt32  sendBuffLen = sizeof("DSPX") + sizeof(UInt32) + inLength;
    char   *sendBuff    = (char *) calloc( sendBuffLen, sizeof(char) );
    if ( sendBuff == NULL )
    {
        return eDSTCPSendError;
    }

    const UInt32 netLength = htonl( inLength );
    bcopy( "DSPX", sendBuff, kDSTCPEndpointMessageTagSize );
    memcpy( sendBuff + kDSTCPEndpointMessageTagSize, &netLength, sizeof(netLength) );
    if ( inLength != 0 )
    {
        bcopy( inBuffer, sendBuff + kDSTCPEndpointMessageTagSize + sizeof(UInt32), inLength );
    }

    do
    {
        ssize_t sentBytes = send( mConnectFD, sendBuff + offset, sendBuffLen - offset, 0 );
        if ( sentBytes < 0 )
        {
            switch ( errno )
            {
                case EINTR:
                case EAGAIN:
                    // transient: wait for writability below and try again
                    break;
                default:
                    DSFree( sendBuff );
                    return eDSTCPSendError;
            }
        }
        else
        {
            offset += sentBytes;
        }

        if ( offset < sendBuffLen )
        {
            // More to send: wait (up to 10 seconds per round) for the socket to accept data.
            fd_set          writeSet;
            struct timeval  tvTimeout = { 10, 0 };
            FD_ZERO( &writeSet );
            FD_SET( mConnectFD, &writeSet );
            select( mConnectFD+1, NULL, &writeSet, NULL, &tvTimeout );
            continue;
        }
        break;
    } while ( 1 );

    DSFree( sendBuff );
    return result;
}
// Converts inMsg to its proxy (wire) form, stamps it with the remote address, port and a fresh
// message id, byte-swaps it unless type.msgt_translate is 2, runs it through ProcessData() in
// the encrypt direction and sends the result as one frame.
//
// Returns the SendBuffer() result on the normal path; eDSTCPSendError when the proxy copy
// cannot be created (NULL inMsg or allocation failure); or the ProcessData() error when the
// message could not be processed -- in that case nothing is sent, rather than an empty frame.
SInt32 DSTCPEndpoint::SendMessage( sComData *inMsg )
{
    UInt32          messageSize = 0;
    sComProxyData  *inProxyMsg  = nil;
    SInt32          sendResult  = eDSNoErr;
    void           *outBuffer   = NULL;
    UInt32          outLength   = 0;

    inProxyMsg = AllocToProxyStruct( (sComData *)inMsg );
    if ( inProxyMsg == nil )
    {
        return eDSTCPSendError;
    }

    // Only the bytes actually in use travel on the wire.
    inProxyMsg->fDataSize = inProxyMsg->fDataLength;
    messageSize = sizeof(sComProxyData) + inProxyMsg->fDataLength;

    inProxyMsg->fIPAddress = mRemoteHostIPAddr;
    inProxyMsg->fPID = ntohs( mRemoteSockAddr.sin_port );
    inProxyMsg->fMsgID = OSAtomicIncrement32( &mMessageID );

    if ( inProxyMsg->type.msgt_translate != 2 )
    {
        SwapProxyMessage( inProxyMsg, kDSSwapHostToNetworkOrder );
    }

    sendResult = ProcessData( true, inProxyMsg, messageSize, outBuffer, outLength );
    if ( sendResult == eDSNoErr )
    {
        sendResult = SendBuffer( outBuffer, outLength );
    }

    DSFree( inProxyMsg );
    DSFree( outBuffer );
    return sendResult;
}
// Receives one framed message, runs it through ProcessData() in the decrypt direction,
// converts it back from proxy (wire) form and hands the new sComData to the caller.
//
//   outMsg  receives a newly allocated message on success; left untouched otherwise
//
// Returns eDSNoErr, eDSTCPReceiveError (read or allocation failure) or eDSCorruptBuffer (the
// received message is too small for the sizes it declares). When SyncToMessageBody() reports a
// zero length, eDSNoErr is returned without touching *outMsg.
//
// If ProcessData() produces no output, the raw received bytes are used as the message.
// NOTE(review): that fallback accepts data that was not decrypted -- confirm it is still wanted.
SInt32 DSTCPEndpoint::GetReplyMessage( sComData **outMsg )
{
    SInt32          siResult    = eDSNoErr;
    UInt32          buffLen     = 0;
    UInt32          readBytes   = 0;
    void           *inBuffer    = nil;
    UInt32          inLength    = 0;
    sComProxyData  *outProxyMsg = nil;

    siResult = SyncToMessageBody( true, &inLength );
    if ( (siResult == eDSNoErr) && (inLength != 0) )
    {
        inBuffer = (void *)calloc( 1, inLength );
        if ( inBuffer == nil )
        {
            siResult = eDSTCPReceiveError;
        }
        else
        {
            try
            {
                readBytes = DoTCPRecvFrom( inBuffer, inLength );
                if ( readBytes != inLength )
                {
                    LOG( kStdErr, "GetServerReply: Couldn't read entire message block" );
                    siResult = eDSTCPReceiveError;
                }
                else
                {
                    void *tmpOutMsg = nil;
                    ProcessData( false, inBuffer, inLength, tmpOutMsg, buffLen );
                    outProxyMsg = (sComProxyData *)tmpOutMsg;
                    if ( buffLen == 0 )
                    {
                        // No processed output: fall back to the bytes as received.
                        free( outProxyMsg );
                        outProxyMsg = (sComProxyData *)inBuffer;
                        inBuffer = nil;
                        buffLen = inLength;
                    }
                }
            }
            catch( SInt32 err )
            {
                siResult = eDSTCPReceiveError;
            }
        }
    }

    if ( inBuffer != nil )
    {
        free( inBuffer );
        inBuffer = nil;
    }

    if ( outProxyMsg != nil )
    {
        // AllocFromProxyStruct() copies kObjSize + fDataSize bytes starting at obj, so the
        // buffer must hold at least that much. Both the buffer length and fDataSize come from
        // the network; check them before reading the header or copying.
        const UInt32 headerLen = (UInt32)( (char *)outProxyMsg->obj - (char *)outProxyMsg ) + (UInt32)kObjSize;

        if ( buffLen < headerLen )
        {
            siResult = eDSCorruptBuffer;
        }
        else
        {
            if ( outProxyMsg->type.msgt_translate != 2 )
            {
                SwapProxyMessage( outProxyMsg, kDSSwapNetworkToHostOrder );
            }

            if ( outProxyMsg->fDataSize > (buffLen - headerLen) )
            {
                siResult = eDSCorruptBuffer;
            }
            else
            {
                *outMsg = AllocFromProxyStruct( outProxyMsg );
            }
        }

        free( outProxyMsg );
        outProxyMsg = nil;
    }

    return( siResult );
}
// Drives the client side of the key negotiation: starting from eKeyStateSendPublicKey it
// alternates ProcessData() -> SendBuffer() -> read reply until ProcessData() moves the state
// machine to eKeyStateValidKey or a step fails.
//
// Returns the result of the last step: eDSNoErr when the key was validated, otherwise the
// failing step's error (eDSCorruptBuffer for an empty or short reply, eDSTCPReceiveError for a
// receive or allocation failure). Receive failures thrown by DoTCPRecvFrom() are converted to
// a return code here, as the other readers in this class do, so no buffer is leaked and no
// exception escapes.
SInt32 DSTCPEndpoint::ClientNegotiateKey( void )
{
    SInt32  result;
    void   *recvBuff    = NULL;
    UInt32  recvBuffLen = 0;
    void   *sendBuff    = NULL;
    UInt32  sendBuffLen = 0;

    fKeyState = eKeyStateSendPublicKey;

    do
    {
        result = ProcessData( true, recvBuff, recvBuffLen, sendBuff, sendBuffLen );
        DSFree( recvBuff );

        if ( fKeyState == eKeyStateValidKey )
            break;

        if ( result == eDSNoErr )
        {
            result = SendBuffer( sendBuff, sendBuffLen );
            DSFree( sendBuff );
        }

        if ( result == eDSNoErr )
        {
            // Do not let the previous round's length survive a header that was not found.
            recvBuffLen = 0;
            result = SyncToMessageBody( true, &recvBuffLen );
        }

        if ( result == eDSNoErr )
        {
            if ( recvBuffLen == 0 )
            {
                // No frame header was read; there is nothing valid to process.
                result = eDSCorruptBuffer;
            }
            else
            {
                recvBuff = (UInt8 *) calloc( recvBuffLen, sizeof(char) );
                if ( recvBuff == NULL )
                {
                    result = eDSTCPReceiveError;
                }
                else
                {
                    try
                    {
                        UInt32 readBytes = DoTCPRecvFrom( recvBuff, recvBuffLen );
                        if ( readBytes != recvBuffLen )
                        {
                            result = eDSCorruptBuffer;
                        }
                    }
                    catch( SInt32 err )
                    {
                        result = eDSTCPReceiveError;
                    }
                }
            }
        }
    } while ( result == eDSNoErr );

    DSFree( sendBuff );
    DSFree( recvBuff );
    return result;
}
// Performs one server-side step of the key negotiation: feeds the received bytes to
// ProcessData() and, if that succeeds and produced a reply, sends the reply as one frame.
// Returns the ProcessData() result, or the SendBuffer() result when a reply was sent.
SInt32 DSTCPEndpoint::ServerNegotiateKey( void *dataBuff, UInt32 dataBuffLen )
{
    void   *replyData   = NULL;
    UInt32  replyLen    = 0;

    SInt32 status = ProcessData( true, dataBuff, dataBuffLen, replyData, replyLen );

    const bool haveReply = ( (status == eDSNoErr) && (replyLen > 0) );
    if ( haveReply )
    {
        status = SendBuffer( replyData, replyLen );
    }

    // The reply buffer is released whether or not it was sent.
    DSFree( replyData );
    return status;
}
// Builds a newly allocated proxy (wire) copy of inDataMsg.
//
// Copies type, fMsgID, fDataSize, fDataLength and kObjSize + fDataSize bytes starting at obj,
// then shifts every non-zero obj[].offset down by the size difference between sComData and
// sComProxyData.
//
// Returns NULL when inDataMsg is NULL or the allocation fails; otherwise the caller owns the
// returned block and must free it.
sComProxyData* DSTCPEndpoint::AllocToProxyStruct ( sComData *inDataMsg )
{
    sComProxyData  *outProxyDataMsg = nil;
    int             objIndex;

    if ( inDataMsg != nil )
    {
        outProxyDataMsg = (sComProxyData *)calloc( 1, sizeof(sComProxyData) + inDataMsg->fDataSize );
        if ( outProxyDataMsg == nil )
        {
            return ( nil );
        }

        outProxyDataMsg->type = inDataMsg->type;
        outProxyDataMsg->fMsgID = inDataMsg->fMsgID;
        outProxyDataMsg->fDataSize = inDataMsg->fDataSize;
        outProxyDataMsg->fDataLength = inDataMsg->fDataLength;
        bcopy( inDataMsg->obj, outProxyDataMsg->obj, kObjSize + inDataMsg->fDataSize );

        // Offset 0 means "unused" and is left alone.
        for ( objIndex = 0; objIndex < 10; objIndex++ )
        {
            if ( outProxyDataMsg->obj[ objIndex ].offset != 0 )
            {
                outProxyDataMsg->obj[ objIndex ].offset -= sizeof(sComData) - sizeof(sComProxyData);
            }
        }
    }

    return ( outProxyDataMsg );
}
// Builds a newly allocated sComData from a proxy (wire) message.
//
// Copies type, fMsgID, fPID, fDataSize, fDataLength and kObjSize + fDataSize bytes starting at
// obj, shifts every non-zero obj[].offset up by the size difference between sComData and
// sComProxyData, and sets both fUID and fEffectiveUID to (uid_t) -2.
//
// Returns NULL when inProxyDataMsg is NULL or the allocation fails; otherwise the caller owns
// the returned block and must free it.
sComData* DSTCPEndpoint::AllocFromProxyStruct ( sComProxyData *inProxyDataMsg )
{
    sComData   *outDataMsg = nil;
    int         objIndex;

    if ( inProxyDataMsg != nil )
    {
        outDataMsg = (sComData *)calloc( 1, sizeof(sComData) + inProxyDataMsg->fDataSize );
        if ( outDataMsg == nil )
        {
            return ( nil );
        }

        outDataMsg->type = inProxyDataMsg->type;
        outDataMsg->fMsgID = inProxyDataMsg->fMsgID;
        outDataMsg->fPID = inProxyDataMsg->fPID;
        outDataMsg->fDataSize = inProxyDataMsg->fDataSize;
        outDataMsg->fDataLength = inProxyDataMsg->fDataLength;
        bcopy( inProxyDataMsg->obj, outDataMsg->obj, kObjSize + inProxyDataMsg->fDataSize );

        // Offset 0 means "unused" and is left alone.
        for ( objIndex = 0; objIndex < 10; objIndex++ )
        {
            if ( outDataMsg->obj[ objIndex ].offset != 0 )
            {
                outDataMsg->obj[ objIndex ].offset += sizeof(sComData) - sizeof(sComProxyData);
            }
        }

        // The user ids carried on the wire are not used; both are forced to -2.
        outDataMsg->fUID = outDataMsg->fEffectiveUID = (uid_t) -2;
    }

    return ( outDataMsg );
}
// Key-negotiation state machine and, once a key is in place, the encrypt/decrypt step.
// What is done with inBuffer depends on fKeyState:
//
//   eKeyStateSendPublicKey      generate a key pair; output = auth tag + public key
//   eKeyStateGenerateChallenge  derive the shared key from the peer's public key; output = an
//                               encrypted random challenge (the expected answer is challenge+1)
//   eKeyStateAcceptResponse     decrypt the peer's answer and compare it with the expected one
//   eKeyStateAcceptClientKey    check the auth tag, generate a key pair, derive the shared key;
//                               output = our public key
//   eKeyStateGenerateResponse   decrypt the challenge, add 1, output = the re-encrypted value
//   eKeyStateValidKey           encrypt (bEncrypt true) or decrypt (false) inBuffer
//
// Every negotiation state advances fKeyState whether or not its step succeeded, except
// eKeyStateAcceptResponse, which only moves to eKeyStateValidKey on a correct answer.
//
//   outBuffer / outBufferLen  set only when a step produces output; the caller owns and frees
//                             outBuffer. In eKeyStateValidKey outBufferLen is first reset to 0.
//
// Returns eDSNoErr when the step for the current state succeeded, eDSCorruptBuffer otherwise.
SInt32 DSTCPEndpoint::ProcessData( bool bEncrypt, void *inBuffer, UInt32 inBufferLen, void *&outBuffer, UInt32 &outBufferLen )
{
    SInt32      result      = eDSCorruptBuffer;
    CSSM_DATA   plainText   = { 0, NULL };
    CSSM_DATA   cipherText  = { 0, NULL };

    switch ( fKeyState )
    {
        case eKeyStateSendPublicKey:
            if ( cdsaDhGenerateKeyPair(fcspHandle, &fPublicKey, &fPrivateKey, DH_KEY_SIZE, &fParamBlock, NULL) == CSSM_OK )
            {
                outBufferLen = sizeof(FourCharCode) + fPublicKey.KeyData.Length;
                char *tempPtr = (char *) calloc( 1, outBufferLen );
                if ( tempPtr != NULL )
                {
                    *((FourCharCode *) tempPtr) = htonl( DSTCPAuthTag );
                    memcpy( tempPtr + sizeof(FourCharCode), fPublicKey.KeyData.Data, fPublicKey.KeyData.Length );
                    outBuffer = tempPtr;
                    result = eDSNoErr;
                }
                else
                {
                    outBufferLen = 0;
                }
            }
            DbgLog( kLogDebug, "DSTCPEndpointProcessData - Send Public Key - generate key pair - %s",
                    (result == eDSNoErr ? "succeeded" : "failed") );
            fKeyState = eKeyStateGenerateChallenge;
            break;

        case eKeyStateGenerateChallenge:
            if ( cdsaDhKeyExchange(fcspHandle, &fPrivateKey, inBuffer, inBufferLen, &fDerivedKey, DERIVE_KEY_SIZE, DERIVE_KEY_ALG) == CSSM_OK )
            {
                fChallengeValue = arc4random();
                uint32_t temp = htonl( fChallengeValue );
                plainText.Data = (uint8_t *) &temp;
                plainText.Length = sizeof(temp);
                if ( cdsaEncrypt(fcspHandle, &fDerivedKey, &plainText, &cipherText) == CSSM_OK )
                {
                    outBuffer = cipherText.Data;
                    outBufferLen = cipherText.Length;
                    result = eDSNoErr;
                }
                // The peer must answer with challenge + 1.
                fChallengeValue++;
            }
            DbgLog( kLogDebug, "DSTCPEndpointProcessData - Generate Challenge - challenge creation - %s",
                    (result == eDSNoErr ? "succeeded" : "failed") );
            fKeyState = eKeyStateAcceptResponse;
            break;

        case eKeyStateAcceptResponse:
            cipherText.Data = (uint8_t *) inBuffer;
            cipherText.Length = inBufferLen;
            plainText.Data = NULL;
            plainText.Length = 0;
            if ( cdsaDecrypt(fcspHandle, &fDerivedKey, &cipherText, &plainText) == CSSM_OK )
            {
                if ( plainText.Data != NULL && plainText.Length == sizeof(uint32_t) && fChallengeValue == ntohl(*((uint32_t*) plainText.Data)) )
                {
                    fKeyState = eKeyStateValidKey;
                    result = eDSNoErr;
                }
                DSFree( plainText.Data );
            }
            DbgLog ( kLogDebug, "DSTCPEndpointProcessData - Accept Response - response was %s",
                     (result == eDSNoErr ? "correct" : "incorrect") );
            break;

        case eKeyStateAcceptClientKey:
            if ( inBufferLen > sizeof(FourCharCode) )
            {
                char *tempPtr = (char *) inBuffer;
                if ( DSTCPAuthTag == ntohl(*((FourCharCode *) tempPtr)) )
                {
                    // Skip the tag; the rest is the client's public key.
                    tempPtr += sizeof(FourCharCode);
                    inBufferLen -= sizeof(FourCharCode);
                    if ( cdsaDhGenerateKeyPair(fcspHandle, &fPublicKey, &fPrivateKey, DH_KEY_SIZE, &fParamBlock, NULL) == CSSM_OK )
                    {
                        if ( cdsaDhKeyExchange(fcspHandle, &fPrivateKey, tempPtr, inBufferLen, &fDerivedKey, DERIVE_KEY_SIZE,
                                               DERIVE_KEY_ALG) == CSSM_OK )
                        {
                            outBufferLen = fPublicKey.KeyData.Length;
                            outBuffer = calloc( outBufferLen, sizeof(char) );
                            if ( outBuffer != NULL )
                            {
                                bcopy( fPublicKey.KeyData.Data, outBuffer, outBufferLen );
                                result = eDSNoErr;
                            }
                            else
                            {
                                outBufferLen = 0;
                            }
                        }
                    }
                }
            }
            DbgLog( kLogDebug, "DSTCPEndpointProcessData - Accept Client Key - %s", (result == eDSNoErr ? "succeed" : "failed") );
            fKeyState = eKeyStateGenerateResponse;
            break;

        case eKeyStateGenerateResponse:
            if ( inBufferLen != 0 )
            {
                cipherText.Data = (uint8_t *) inBuffer;
                cipherText.Length = inBufferLen;
                if ( cdsaDecrypt(fcspHandle, &fDerivedKey, &cipherText, &plainText) == CSSM_OK )
                {
                    if ( plainText.Data != NULL && plainText.Length == 4 )
                    {
                        uint32_t temp = ntohl( *((uint32_t *) plainText.Data) ) + 1;
                        (*(uint32_t *) plainText.Data) = htonl( temp );
                        // cipherText pointed at the caller's input; clear it before it
                        // receives the encrypt output.
                        cipherText.Data = NULL;
                        cipherText.Length = 0;
                        if ( cdsaEncrypt(fcspHandle, &fDerivedKey, &plainText, &cipherText) == CSSM_OK )
                        {
                            outBuffer = cipherText.Data;
                            outBufferLen = cipherText.Length;
                            result = eDSNoErr;
                        }
                    }
                    // Release the decrypted bytes whenever decrypt succeeded, as
                    // eKeyStateAcceptResponse does; freeing only on the length-4 path leaked
                    // any plaintext of another length.
                    DSFree ( plainText.Data );
                }
            }
            DbgLog( kLogDebug, "DSTCPEndpointProcessData - Generate Response - %s", (result == eDSNoErr ? "succeed" : "failed") );
            fKeyState = eKeyStateValidKey;
            break;

        case eKeyStateValidKey:
            outBufferLen = 0;
            if ( fDerivedKey.KeyData.Data != NULL )
            {
                if ( bEncrypt == true )
                {
                    plainText.Data = (uint8_t *)inBuffer;
                    plainText.Length = inBufferLen;
                    if ( cdsaEncrypt(fcspHandle, &fDerivedKey, &plainText, &cipherText) == CSSM_OK )
                    {
                        outBuffer = cipherText.Data;
                        outBufferLen = cipherText.Length;
                        DbgLog( kLogDebug, "DSTCPEndpointProcessData - Encrypted data - length %d", outBufferLen );
                        result = eDSNoErr;
                    }
                }
                else
                {
                    cipherText.Data = (uint8_t *) inBuffer;
                    cipherText.Length = inBufferLen;
                    if ( cdsaDecrypt(fcspHandle, &fDerivedKey, &cipherText, &plainText) == CSSM_OK )
                    {
                        outBuffer = plainText.Data;
                        outBufferLen = plainText.Length;
                        DbgLog( kLogDebug, "DSTCPEndpointProcessData - Decrypted data - length %d", outBufferLen );
                        result = eDSNoErr;
                    }
                }
            }
            break;
    }

    return result;
}