#include "ssl.h"
#include "sslRecord.h"
#include "sslMemory.h"
#include "cryptType.h"
#include "sslContext.h"
#include "sslAlertMessage.h"
#include "sslDebug.h"
#include "sslUtils.h"
#include "sslDigests.h"
#include <string.h>
#include <assert.h>
#define SSL_ALLOW_UNNOTICED_DISCONNECT 1
/*
 * SSLReadRecord - read one complete record from the transport and decrypt it.
 *
 * Received bytes accumulate in ctx->partialReadBuffer across calls, and
 * ctx->amountRead counts how many bytes of the current record are buffered.
 * A call that returns errSSLWouldBlock can therefore be repeated later and
 * resumes where it left off.
 *
 * Header layout as parsed below:
 *   TLS : type(1) version(2) length(2)                 -> head = 5
 *   DTLS: type(1) version(2) epoch/sequence(8) length(2) -> head = 13
 *
 * rec  receives contentType, protocolVersion and contents. contents is a
 *      buffer newly allocated with SSLAllocBuffer holding the decrypted
 *      fragment (presumably freed by the caller - TODO confirm).
 * ctx  connection context; its read buffer, read cipher and handshake state
 *      are read and updated.
 *
 * Returns noErr on success. Returns errSSLWouldBlock when more transport data
 * is needed, or when a DTLS record was dropped (see skipit). Otherwise returns
 * an error code. For most transport and allocation failures a fatal alert is
 * sent before returning.
 */
OSStatus
SSLReadRecord(SSLRecord *rec, SSLContext *ctx)
{
    OSStatus    err;
    size_t      len, contentLen;
    UInt8       *charPtr;
    SSLBuffer   readData, cipherFragment;
    size_t      head = 5;   /* TLS header: type(1) version(2) length(2) */
    int         skipit = 0; /* set when a DTLS record is to be dropped */

#if ENABLE_DTLS
    /* DTLS inserts an 8-byte epoch/sequence-number field into the header. */
    if (ctx->isDTLS) {
        head += 8;
    }
#endif

    /* Ensure the reassembly buffer exists and can hold at least a header. */
    if (!ctx->partialReadBuffer.data || ctx->partialReadBuffer.length < head) {
        if (ctx->partialReadBuffer.data) {
            if ((err = SSLFreeBuffer(&ctx->partialReadBuffer, ctx)) != 0) {
                SSLFatalSessionAlert(SSL_AlertInternalError, ctx);
                return err;
            }
        }
        if ((err = SSLAllocBuffer(&ctx->partialReadBuffer,
                                  DEFAULT_BUFFER_SIZE, ctx)) != 0) {
            SSLFatalSessionAlert(SSL_AlertInternalError, ctx);
            return err;
        }
    }

    /*
     * Before a protocol version is negotiated, fetch the first byte on its
     * own. If the client sees the connection fail with nothing received
     * while waiting for ServerHello, that is reported as
     * errSSLConnectionRefused rather than errSSLClosedAbort.
     */
    if (ctx->negProtocolVersion == SSL_Version_Undetermined) {
        if (ctx->amountRead < 1) {
            readData.length = 1 - ctx->amountRead;
            readData.data = ctx->partialReadBuffer.data + ctx->amountRead;
            len = readData.length;
            err = sslIoRead(readData, &len, ctx);
            if (err != 0) {
                if (err == errSSLWouldBlock) {
                    ctx->amountRead += len;
                    return err;
                }
                err = errSSLClosedAbort;
                if ((ctx->protocolSide == kSSLClientSide) &&
                    (ctx->amountRead == 0) &&
                    (len == 0)) {
                    switch (ctx->state) {
                        case SSL_HdskStateServerHello:
                        case SSL_HdskStateServerHelloUnknownVersion:
                            sslHdskStateDebug("Server dropped initial connection\n");
                            err = errSSLConnectionRefused;
                            break;
                        default:
                            break;
                    }
                }
                SSLFatalSessionAlert(SSL_AlertCloseNotify, ctx);
                return err;
            }
            ctx->amountRead += len;
        }
    }

    /* Read the rest of the record header. */
    if (ctx->amountRead < head) {
        readData.length = head - ctx->amountRead;
        readData.data = ctx->partialReadBuffer.data + ctx->amountRead;
        len = readData.length;
        err = sslIoRead(readData, &len, ctx);
        if (err != 0) {
            switch (err) {
                case errSSLWouldBlock:
                    ctx->amountRead += len;
                    break;
#if SSL_ALLOW_UNNOTICED_DISCONNECT
                case errSSLClosedGraceful:
                    /*
                     * A clean close between records of an established
                     * session is reported as errSSLClosedNoNotify, without
                     * sending an alert.
                     */
                    if ((ctx->amountRead == 0) &&
                        (len == 0) &&
                        (ctx->state == SSL_HdskStateClientReady)) {
                        SSLChangeHdskState(ctx, SSL_HdskStateNoNotifyClose);
                        err = errSSLClosedNoNotify;
                        break;
                    }
                    err = errSSLClosedAbort;
                    /* fallthrough - close mid-record is an abort */
#endif
                default:
                    SSLFatalSessionAlert(SSL_AlertCloseNotify, ctx);
                    break;
            }
            return err;
        }
        ctx->amountRead += len;
    }

    assert(ctx->amountRead >= head);

    /* Parse the header. */
    charPtr = ctx->partialReadBuffer.data;
    rec->contentType = *charPtr++;
    if (rec->contentType < SSL_RecordTypeV3_Smallest ||
        rec->contentType > SSL_RecordTypeV3_Largest) {
        /* NOTE(review): no alert is sent here; confirm the caller does. */
        return errSSLProtocol;
    }
    rec->protocolVersion = (SSLProtocolVersion)SSLDecodeInt(charPtr, 2);
    charPtr += 2;

#if ENABLE_DTLS
    /*
     * Parse the epoch/sequence field exactly when head includes it, i.e.
     * keyed on ctx->isDTLS and not on the peer-supplied version field.
     * Otherwise a DTLS record with another version would have its length
     * read from the sequence bytes. Likewise, a TLS record claiming
     * DTLS 1.0 would be parsed past the received header.
     */
    if (ctx->isDTLS) {
        sslUint64 seqNum;

        SSLDecodeUInt64(charPtr, 8, &seqNum);
        charPtr += 8;
        sslLogRecordIo("Read DTLS Record %08x_%08x (seq is: %08x_%08x)",
                       seqNum.high, seqNum.low,
                       ctx->readCipher.sequenceNum.high,
                       ctx->readCipher.sequenceNum.low);
        /*
         * When the upper bits (high >> 8) differ from ours, the record is
         * dropped further down. Otherwise the peer's sequence number is
         * adopted into readCipher.sequenceNum, which SSLVerifyMac reads.
         * NOTE(review): the DTLS epoch is the top 16 bits, so >> 16 may have
         * been intended. >> 8 also compares the top byte of the 48-bit
         * sequence number - confirm.
         */
        if ((seqNum.high >> 8) != (ctx->readCipher.sequenceNum.high >> 8)) {
            skipit = 1;
        } else {
            ctx->readCipher.sequenceNum.high = seqNum.high;
            ctx->readCipher.sequenceNum.low = seqNum.low;
        }
    }
#endif

    contentLen = SSLDecodeInt(charPtr, 2);
    charPtr += 2;
    /* 2^14 + 2048 is the TLSCiphertext length limit (RFC 5246 6.2.3). */
    if (contentLen > (16384 + 2048)) {
        SSLFatalSessionAlert(SSL_AlertRecordOverflow, ctx);
        return errSSLProtocol;
    }
    /*
     * A record that will be processed must be at least one MAC long.
     * NOTE(review): this is a peer error reported with an InternalError
     * alert; confirm that is the intended alert.
     */
    if (!skipit && contentLen < ctx->readCipher.macRef->hash->digestSize) {
        SSLFatalSessionAlert(SSL_AlertInternalError, ctx);
        return errSSLClosedAbort;
    }

    /* Grow the buffer so it can hold the full record. */
    if (ctx->partialReadBuffer.length < head + contentLen) {
        if ((err = SSLReallocBuffer(&ctx->partialReadBuffer,
                                    head + contentLen, ctx)) != 0) {
            SSLFatalSessionAlert(SSL_AlertInternalError, ctx);
            return err;
        }
    }

    /* Read the record body. */
    if (ctx->amountRead < head + contentLen) {
        readData.length = head + contentLen - ctx->amountRead;
        readData.data = ctx->partialReadBuffer.data + ctx->amountRead;
        len = readData.length;
        err = sslIoRead(readData, &len, ctx);
        if (err != 0) {
            if (err == errSSLWouldBlock) {
                ctx->amountRead += len;
            } else {
                SSLFatalSessionAlert(SSL_AlertCloseNotify, ctx);
            }
            return err;
        }
        ctx->amountRead += len;
    }

    assert(ctx->amountRead >= head + contentLen);

    cipherFragment.data = ctx->partialReadBuffer.data + head;
    cipherFragment.length = contentLen;
    /* The whole record is buffered; the next call starts a fresh record. */
    ctx->amountRead = 0;

    if (skipit) {
        /* Drop the record and ask the DTLS layer to retransmit. */
        DTLSRetransmit(ctx);
        return errSSLWouldBlock;
    }

    assert(ctx->sslTslCalls != NULL);
    if ((err = ctx->sslTslCalls->decryptRecord(rec->contentType,
                                               &cipherFragment, ctx)) != 0) {
        return err;
    }
    IncrementUInt64(&ctx->readCipher.sequenceNum);

    /* Hand the caller its own copy of the decrypted fragment. */
    if ((err = SSLAllocBuffer(&rec->contents, cipherFragment.length, ctx)) != 0) {
        SSLFatalSessionAlert(SSL_AlertInternalError, ctx);
        return err;
    }
    memcpy(rec->contents.data, cipherFragment.data, cipherFragment.length);
    return noErr;
}
/*
 * SSLVerifyMac - recompute the MAC of a record and compare it with the one
 * received from the peer.
 *
 * type        record content type, passed through to computeMac
 * data        record payload the MAC is computed over
 * compareMAC  received MAC; must hold at least the read cipher's
 *             digestSize bytes
 * ctx         connection context; ctx->readCipher supplies the MAC algorithm
 *             and the current sequence number
 *
 * Returns noErr when the MACs match, errSSLProtocol on mismatch, or any error
 * returned by computeMac.
 *
 * The comparison examines every byte regardless of where the MACs first
 * differ. An early-exit compare such as memcmp can leak, through response
 * timing, how many leading bytes of a forged MAC were correct.
 */
OSStatus SSLVerifyMac(
    UInt8 type,
    SSLBuffer *data,
    UInt8 *compareMAC,
    SSLContext *ctx)
{
    OSStatus err;
    UInt8 macData[SSL_MAX_DIGEST_LEN];
    SSLBuffer mac;
    UInt8 diff = 0;
    size_t i;

    mac.data = macData;
    mac.length = ctx->readCipher.macRef->hash->digestSize;
    /* mac.length bytes of the fixed-size macData are used below. */
    assert(mac.length <= sizeof(macData));

    assert(ctx->sslTslCalls != NULL);
    if ((err = ctx->sslTslCalls->computeMac(type,
            *data,
            mac,
            &ctx->readCipher,
            ctx->readCipher.sequenceNum,
            ctx)) != 0)
        return err;

    /* Constant-time comparison: OR together the XOR of every byte pair. */
    for (i = 0; i < mac.length; i++) {
        diff |= (UInt8)(mac.data[i] ^ compareMAC[i]);
    }
    if (diff != 0) {
        sslErrorLog("ssl3VerifyMac: Mac verify failure\n");
        return errSSLProtocol;
    }
    return noErr;
}