#include "ssl2.h"
#include "sslRecord.h"
#include "sslMemory.h"
#include "sslContext.h"
#include "sslAlertMessage.h"
#include "sslDebug.h"
#include "sslUtils.h"
#include "sslDigests.h"
#include <stdio.h>
#include <string.h>
/* File-local helpers for the SSL 2.0 record layer; defined further below. */
static OSStatus SSL2DecryptRecord(SSLBuffer &payload, SSLContext *ctx);
static OSStatus SSL2VerifyMAC(SSLBuffer &content, UInt8 *compareMAC,
	SSLContext *ctx);
static OSStatus SSL2CalculateMAC(SSLBuffer &secret, SSLBuffer &content,
	UInt32 seqNo, const HashReference &hash, SSLBuffer &mac,
	SSLContext *ctx);
/*
 * SSL2ReadRecord - read one SSL 2.0 record from the transport into rec.
 *
 * Bytes are accumulated in ctx->partialReadBuffer across calls, with the
 * number already buffered kept in ctx->amountRead. When sslIoRead returns
 * errSSLWouldBlock, the partial count is saved and the error returned so a
 * later call can resume.
 *
 * Once the whole record is buffered, it is decrypted and MAC-checked by
 * SSL2DecryptRecord and the trailing padding is stripped. The read sequence
 * number is then incremented, and the plaintext is copied into a newly
 * allocated rec.contents. rec.contentType and rec.protocolVersion are set
 * to the SSL 2.0 values.
 *
 * Returns:
 *   noErr                  a record was read into rec.
 *   errSSLProtocol         the header or padding is malformed, or the
 *                          connection already negotiated SSL 3.0 / TLS 1.0.
 *   errSSLInternal         ctx->negProtocolVersion has an unknown value.
 *   errSSLClosedGraceful   sslIoRead returned ioErr with no bytes buffered.
 *   other                  errors from I/O, allocation, or decryption.
 */
OSStatus
SSL2ReadRecord(SSLRecord &rec, SSLContext *ctx)
{
	OSStatus err;
	UInt32 len, contentLen;
	int padding, headerSize;
	UInt8 *charPtr;
	SSLBuffer readData, cipherFragment;

	switch (ctx->negProtocolVersion) {
	case SSL_Version_Undetermined:
	case SSL_Version_2_0:
		break;
	case SSL_Version_3_0:
	case TLS_Version_1_0:
		/* An SSL 2.0 record is not acceptable once SSL3/TLS1 was negotiated. */
		SSLFatalSessionAlert(SSL_AlertUnexpectedMsg, ctx);
		return errSSLProtocol;
	default:
		sslErrorLog("bad protocolVersion in ctx->protocolVersion");
		return errSSLInternal;
	}

	/* Ensure the read buffer can hold at least the (up to) 3-byte header. */
	if (!ctx->partialReadBuffer.data || ctx->partialReadBuffer.length < 3) {
		if (ctx->partialReadBuffer.data) {
			if ((err = SSLFreeBuffer(ctx->partialReadBuffer, ctx)) != 0) {
				SSL2SendError(SSL2_ErrNoCipher, ctx);
				return err;
			}
		}
		if ((err = SSLAllocBuffer(ctx->partialReadBuffer, DEFAULT_BUFFER_SIZE, ctx)) != 0) {
			SSL2SendError(SSL2_ErrNoCipher, ctx);
			return err;
		}
	}

	/*
	 * Always read 3 bytes first. That covers either header form, because a
	 * record carries at least one content byte (contentLen == 0 is rejected
	 * below).
	 */
	if (ctx->amountRead < 3) {
		readData.length = 3 - ctx->amountRead;
		readData.data = ctx->partialReadBuffer.data + ctx->amountRead;
		len = readData.length;
		err = sslIoRead(readData, &len, ctx);
		if (err != 0) {
			/* Remember partial progress so the next call resumes here. */
			if (err == errSSLWouldBlock)
				ctx->amountRead += len;
			/* An ioErr with nothing buffered is reported as a graceful close. */
			if (err == ioErr && ctx->amountRead == 0)
				err = errSSLClosedGraceful;
			return err;
		}
		ctx->amountRead += len;
	}

	rec.contentType = SSL_RecordTypeV2_0;
	rec.protocolVersion = SSL_Version_2_0;

	/* Parse the header. charPtr is only valid until the buffer is reallocated. */
	charPtr = ctx->partialReadBuffer.data;
	if ((charPtr[0] & 0x80) != 0) {
		/* 2-byte header: 15-bit length, no padding. */
		headerSize = 2;
		contentLen = ((charPtr[0] & 0x7F) << 8) | charPtr[1];
		padding = 0;
	}
	else if ((charPtr[0] & 0x40) != 0) {
		/* 3-byte header with bit 0x40 set is not supported here. */
		return errSSLProtocol;
	}
	else {
		/* 3-byte header: 14-bit length plus a padding byte count. */
		headerSize = 3;
		contentLen = ((charPtr[0] & 0x3F) << 8) | charPtr[1];
		padding = charPtr[2];
	}
	/* contentLen is at most 15 bits by construction; only zero is invalid. */
	if (contentLen == 0) {
		return errSSLProtocol;
	}

	/* Grow the buffer if the whole record won't fit (5 + contentLen >= headerSize + contentLen). */
	if (ctx->partialReadBuffer.length < headerSize + contentLen) {
		if ((err = SSLReallocBuffer(ctx->partialReadBuffer, 5 + contentLen, ctx)) != 0)
			return err;
	}

	/* Read the rest of the record body. */
	if (ctx->amountRead < headerSize + contentLen) {
		readData.length = headerSize + contentLen - ctx->amountRead;
		readData.data = ctx->partialReadBuffer.data + ctx->amountRead;
		len = readData.length;
		err = sslIoRead(readData, &len, ctx);
		if (err != 0) {
			if (err == errSSLWouldBlock)
				ctx->amountRead += len;
			return err;
		}
		ctx->amountRead += len;
	}

	cipherFragment.data = ctx->partialReadBuffer.data + headerSize;
	cipherFragment.length = contentLen;
	if ((err = SSL2DecryptRecord(cipherFragment, ctx)) != 0)
		return err;

	/*
	 * padding comes straight from the peer's header. Reject values larger
	 * than the remaining (post-MAC) fragment. Otherwise the unsigned length
	 * wraps and drives an enormous allocation and an out-of-bounds memcpy.
	 */
	if ((UInt32)padding > cipherFragment.length)
		return errSSLProtocol;
	cipherFragment.length -= padding;

	IncrementUInt64(&ctx->readCipher.sequenceNum);
	if ((err = SSLAllocBuffer(rec.contents, cipherFragment.length, ctx)) != 0)
		return err;
	memcpy(rec.contents.data, cipherFragment.data, cipherFragment.length);
	ctx->amountRead = 0;
	return noErr;
}
/*
 * SSL2WriteRecord - build one SSL 2.0 record from rec.contents and append
 * it to the tail of ctx->recordWriteQueue.
 *
 * Record layout:
 *   - header: 2 bytes when the write cipher's blockSize is 0; otherwise
 *     3 bytes, the third being the padding count;
 *   - MAC (digestSize bytes);
 *   - rec.contents;
 *   - padding bytes, each holding the padding count.
 *
 * The MAC covers contents + padding. It is keyed with the write cipher's
 * macSecret and uses the low 32 bits of the write sequence number.
 * MAC + contents + padding are then encrypted in place. On success the
 * write sequence number is incremented.
 *
 * Returns noErr, or errSSLInternal if the payload would not fit the
 * header's length field. It also returns any error from allocation, MAC
 * computation, or encryption. On failure nothing is queued and the
 * buffers allocated here are freed.
 */
OSStatus
SSL2WriteRecord(SSLRecord &rec, SSLContext *ctx)
{
	OSStatus err;
	int padding, i, headerSize;
	UInt32 totalPayload, maxPayload;
	WaitingRecord *out, *queue;
	SSLBuffer buf, content, payload, secret, mac;
	UInt8 *charPtr;
	UInt16 payloadSize, blockSize;

	assert(rec.contents.length < 16384);

	/*
	 * Size the record before allocating anything. The arithmetic is done in
	 * 32 bits so an oversize record is detected instead of silently
	 * truncated to UInt16.
	 */
	totalPayload = rec.contents.length + ctx->writeCipher.macRef->hash->digestSize;
	blockSize = ctx->writeCipher.symCipher->blockSize;
	if (blockSize > 0) {
		padding = blockSize - (totalPayload % blockSize);
		if (padding == blockSize)
			padding = 0;
		totalPayload += padding;
		headerSize = 3;
		maxPayload = 0x3FFF;	/* 14-bit length; 0x4000 would set bit 0x40 of byte 0 */
	}
	else {
		padding = 0;
		headerSize = 2;
		maxPayload = 0x7FFF;	/* 15-bit length */
	}
	if (totalPayload > maxPayload) {
		sslErrorLog("SSL2WriteRecord: record too large\n");
		return errSSLInternal;
	}
	payloadSize = (UInt16)totalPayload;

	out = 0;
	if ((err = SSLAllocBuffer(buf, sizeof(WaitingRecord), ctx)) != 0)
		return err;
	out = (WaitingRecord*)buf.data;
	out->next = 0;
	out->sent = 0;
	out->data.data = 0;	/* so the fail path can free it unconditionally */
	if ((err = SSLAllocBuffer(out->data, headerSize + payloadSize, ctx)) != 0)
		goto fail;

	charPtr = out->data.data;
	if (headerSize == 2)
		charPtr = SSLEncodeInt(charPtr, payloadSize | 0x8000, 2);
	else {
		charPtr = SSLEncodeInt(charPtr, payloadSize, 2);
		*charPtr++ = padding;
	}

	/* payload = MAC || contents || padding; the MAC leads the payload. */
	payload.data = charPtr;
	payload.length = payloadSize;
	mac.data = charPtr;
	mac.length = ctx->writeCipher.macRef->hash->digestSize;
	charPtr += mac.length;

	/* The MAC is computed over contents plus padding. */
	content.data = charPtr;
	content.length = rec.contents.length + padding;
	memcpy(charPtr, rec.contents.data, rec.contents.length);
	charPtr += rec.contents.length;
	i = padding;
	while (i--)
		*charPtr++ = padding;
	assert(charPtr == out->data.data + out->data.length);

	secret.data = ctx->writeCipher.macSecret;
	secret.length = ctx->writeCipher.symCipher->keySize;
	if (mac.length > 0) {
		if ((err = SSL2CalculateMAC(secret, content,
				ctx->writeCipher.sequenceNum.low,
				*ctx->writeCipher.macRef->hash, mac, ctx)) != 0)
			goto fail;
	}

	/* Encrypt MAC + contents + padding in place. */
	if ((err = ctx->writeCipher.symCipher->encrypt(payload,
			payload,
			&ctx->writeCipher,
			ctx)) != 0)
		goto fail;

	/* Append to the tail of the pending-write queue. */
	if (ctx->recordWriteQueue == 0)
		ctx->recordWriteQueue = out;
	else {
		queue = ctx->recordWriteQueue;
		while (queue->next != 0)
			queue = queue->next;
		queue->next = out;
	}

	IncrementUInt64(&ctx->writeCipher.sequenceNum);
	return noErr;

fail:
	/* Pass ctx like every other SSLFreeBuffer call (previously passed 0). */
	SSLFreeBuffer(out->data, ctx);
	buf.data = (UInt8*)out;
	buf.length = sizeof(WaitingRecord);
	SSLFreeBuffer(buf, ctx);
	return err;
}
/*
 * SSL2DecryptRecord - decrypt an SSL 2.0 record body in place, then verify
 * and strip its MAC.
 *
 * payload: on entry, the encrypted record body (MAC followed by data).
 *          On success, when the read cipher has a non-zero digest size,
 *          it is narrowed to the bytes that follow the MAC.
 *
 * Returns errSSLProtocol if the body is not a whole number of cipher blocks
 * or is too short to contain the MAC. Otherwise returns the decrypt or
 * MAC-verification result.
 */
static OSStatus
SSL2DecryptRecord(SSLBuffer &payload, SSLContext *ctx)
{
	OSStatus err;
	SSLBuffer content;
	UInt32 digestSize = ctx->readCipher.macRef->hash->digestSize;

	if (ctx->readCipher.symCipher->blockSize > 0 &&
	    payload.length % ctx->readCipher.symCipher->blockSize != 0)
		return errSSLProtocol;

	/*
	 * A body shorter than the MAC is malformed. Without this check,
	 * content.length below would wrap to a huge value, and SSL2VerifyMAC
	 * would hash far past the end of the buffer.
	 */
	if (payload.length < digestSize)
		return errSSLProtocol;

	if ((err = ctx->readCipher.symCipher->decrypt(payload,
			payload,
			&ctx->readCipher,
			ctx)) != 0)
		return err;

	if (digestSize > 0) {
		content.data = payload.data + digestSize;
		content.length = payload.length - digestSize;
		if ((err = SSL2VerifyMAC(content, payload.data, ctx)) != 0)
			return err;
		payload = content;
	}
	return noErr;
}
/* Debug switch: when 1, a MAC mismatch is logged but the record is accepted. */
#define IGNORE_MAC_FAILURE 0

/*
 * SSL2VerifyMAC - recompute the MAC of `content` and compare it against the
 * digestSize bytes at compareMAC.
 *
 * The MAC is keyed with the read cipher's macSecret and uses the low 32 bits
 * of the read sequence number. Every byte is examined whatever the position
 * of the first mismatch, so the comparison time does not reveal how many
 * leading MAC bytes were correct. memcmp gives no such guarantee.
 *
 * Returns noErr on match, or errSSLProtocol on mismatch (noErr if
 * IGNORE_MAC_FAILURE is set). Also returns any error from SSL2CalculateMAC.
 */
static OSStatus
SSL2VerifyMAC(SSLBuffer &content, UInt8 *compareMAC, SSLContext *ctx)
{
	OSStatus err;
	UInt8 calculatedMAC[SSL_MAX_DIGEST_LEN];
	SSLBuffer secret, mac;
	UInt8 diff = 0;
	UInt32 i;

	secret.data = ctx->readCipher.macSecret;
	secret.length = ctx->readCipher.symCipher->keySize;
	mac.data = calculatedMAC;
	mac.length = ctx->readCipher.macRef->hash->digestSize;
	if ((err = SSL2CalculateMAC(secret, content, ctx->readCipher.sequenceNum.low,
			*ctx->readCipher.macRef->hash, mac, ctx)) != 0)
		return err;

	/* Constant-time comparison: accumulate differences, no early exit. */
	for (i = 0; i < mac.length; i++)
		diff |= (UInt8)(mac.data[i] ^ compareMAC[i]);

	if (diff != 0) {
		sslErrorLog("SSL2VerifyMAC: Mac verify failure\n");
#if IGNORE_MAC_FAILURE
		return noErr;
#else
		return errSSLProtocol;
#endif
	}
	return noErr;
}
/* Debug switch: when 1, SSL2CalculateMAC dumps its MAC inputs/output to stdout. */
#define LOG_MAC_DATA 0
#if LOG_MAC_DATA
/*
 * Print "<field>: " followed by data's bytes as uppercase hex, with a space
 * after every fourth byte, then a newline. `field` is const because callers
 * pass string literals, which cannot bind to a plain char * in C++11.
 */
static void logMacData(
	const char *field,
	SSLBuffer *data)
{
	size_t i;

	printf("%s: ", field);
	for (i = 0; i < data->length; i++) {
		printf("%02X", (unsigned)data->data[i]);
		if ((i % 4) == 3) {
			printf(" ");
		}
	}
	printf("\n");
}
#else
#define logMacData(f, d)
#endif
/*
 * SSL2CalculateMAC - compute hash(secret || content || seqNo) into mac.
 *
 * seqNo is written into a 4-byte field with SSLEncodeInt before hashing.
 * The hash context is created with ReadyHash and always released before
 * returning, except when ReadyHash itself fails.
 *
 * Returns noErr, or the first error from ReadyHash or the hash's
 * update/final callbacks.
 */
static OSStatus
SSL2CalculateMAC(
	SSLBuffer &secret,
	SSLBuffer &content,
	UInt32 seqNo,
	const HashReference &hash,
	SSLBuffer &mac,
	SSLContext *ctx)
{
	UInt8 seqBytes[4];
	SSLBuffer seqBuf;
	SSLBuffer digestCtx;
	OSStatus status;

	SSLEncodeInt(seqBytes, seqNo, 4);
	seqBuf.data = seqBytes;
	seqBuf.length = 4;

	digestCtx.data = 0;
	status = ReadyHash(hash, digestCtx, ctx);
	if (status != 0)
		return status;

	/* Each step runs only while every previous one succeeded. */
	status = hash.update(digestCtx, secret);
	if (status == 0)
		status = hash.update(digestCtx, content);
	if (status == 0)
		status = hash.update(digestCtx, seqBuf);
	if (status == 0)
		status = hash.final(digestCtx, mac);

	if (status == 0) {
		logMacData("secret ", &secret);
		logMacData("seqData", &seqBuf);
		logMacData("mac    ", &mac);
		status = noErr;
	}

	SSLFreeBuffer(digestCtx, ctx);
	return status;
}
/*
 * SSL2SendError - queue an SSL 2.0 error message carrying `error`.
 *
 * The message is 3 bytes: SSL2_MsgError followed by the error code encoded
 * in 2 bytes with SSLEncodeInt. It is wrapped in an SSL 2.0 record and
 * handed to SSL2WriteRecord, whose result is returned.
 */
OSStatus
SSL2SendError(SSL2ErrorCode error, SSLContext *ctx)
{
	UInt8 msg[3];
	SSLRecord errRec;

	msg[0] = SSL2_MsgError;
	SSLEncodeInt(&msg[1], error, 2);

	errRec.contentType = SSL_RecordTypeV2_0;
	errRec.protocolVersion = SSL_Version_2_0;
	errRec.contents.data = msg;
	errRec.contents.length = sizeof(msg);

	return SSL2WriteRecord(errRec, ctx);
}