#include "sslMemory.h"
#include "sslContext.h"
#include "sslRecord.h"
#include "sslAlertMessage.h"
#include "sslSession.h"
#include "ssl2.h"
#include "sslDebug.h"
#include "cipherSpecs.h"
#include "sslUtils.h"
#include <CoreServices/../Frameworks/CarbonCore.framework/Headers/MacErrors.h>
#include <assert.h>
#include <string.h>
#ifndef NDEBUG
/*
 * Debug-only trace of one SSLRead/SSLWrite call: the operation name, the
 * number of bytes requested, the number actually moved, and the status.
 *
 * UInt32 and OSStatus are not guaranteed to be long-sized, so the values
 * are widened explicitly to match the %lu / %ld conversions; passing them
 * unconverted is a format/argument type mismatch on such targets.
 */
static inline void sslIoTrace(
    const char  *op,
    UInt32      req,
    UInt32      moved,
    OSStatus    stat)
{
    sslLogRecordIo("===%s: req %4lu moved %4lu status %ld\n",
        op, (unsigned long)req, (unsigned long)moved, (long)stat);
}
#else
/* Release builds: expands to a harmless expression statement. */
#define sslIoTrace(op, req, moved, stat) ((void)0)
#endif
/* File-local helpers, defined later in this file. */
static OSStatus SSLProcessProtocolMessage(SSLRecord &rec, SSLContext *ctx);
static OSStatus SSLHandshakeProceed(SSLContext *ctx);
static OSStatus SSLInitConnection(SSLContext *ctx);
static OSStatus SSLServiceWriteQueue(SSLContext *ctx);
/*
 * SSLWrite: encrypt and send dataLength bytes of application data.
 *
 * Writes are only accepted once the handshake has reached a Ready state:
 *   - graceful close  -> errSSLClosedGraceful
 *   - error close     -> errSSLClosedAbort
 *   - any other state -> badReqErr
 * The data is split into records of at most MAX_RECORD_LENGTH bytes, each
 * passed to the negotiated protocol's writeRecord(), and the outbound queue
 * is then flushed.
 *
 * *bytesWritten reports how many bytes were handed to writeRecord(), even
 * when a later step fails. Those records sit on the write queue and may
 * still be transmitted by a later flush.
 *
 * Any error other than errSSLWouldBlock / errSSLClosedGraceful moves the
 * context to SSL_HdskStateErrorClose.
 */
OSStatus
SSLWrite(
    SSLContext  *ctx,
    const void  *data,
    UInt32      dataLength,
    UInt32      *bytesWritten)
{
    OSStatus    err;
    SSLRecord   rec;
    UInt32      dataLen, processed;

    if ((ctx == NULL) || (bytesWritten == NULL)) {
        return paramErr;
    }
    if ((data == NULL) && (dataLength != 0)) {
        return paramErr;
    }
    dataLen = dataLength;
    processed = 0;
    *bytesWritten = 0;

    switch (ctx->state) {
        case SSL_HdskStateServerReady:
        case SSL_HdskStateClientReady:
            break;
        case SSL_HdskStateGracefulClose:
            err = errSSLClosedGraceful;
            goto abort;
        case SSL_HdskStateErrorClose:
            err = errSSLClosedAbort;
            goto abort;
        default:
            sslIoTrace("SSLWrite", dataLength, 0, badReqErr);
            return badReqErr;
    }

    /* Make sure the write side has keys before sending application data. */
    err = noErr;
    while (ctx->writeCipher.ready == 0) {
        if ((err = SSLHandshakeProceed(ctx)) != 0)
            goto exit;
    }

    /*
     * Drain records left over from an earlier call before queueing more.
     * errSSLWouldBlock passes through `exit` without a state change (same
     * as before). A hard failure now closes the session, matching the
     * final flush below.
     */
    if ((err = SSLServiceWriteQueue(ctx)) != 0)
        goto exit;

    while (dataLen > 0) {
        rec.contentType = SSL_RecordTypeAppData;
        rec.protocolVersion = ctx->negProtocolVersion;
        rec.contents.data = ((UInt8 *)data) + processed;
        if (dataLen < MAX_RECORD_LENGTH)
            rec.contents.length = dataLen;
        else
            rec.contents.length = MAX_RECORD_LENGTH;

        assert(ctx->sslTslCalls != NULL);
        if ((err = ctx->sslTslCalls->writeRecord(rec, ctx)) != 0)
            goto exit;
        processed += rec.contents.length;
        dataLen -= rec.contents.length;
        /* Keep the caller's count in step with what is already queued. */
        *bytesWritten = processed;
    }

    /* Everything has been queued; push as much as possible to the peer. */
    err = SSLServiceWriteQueue(ctx);

exit:
    if (err != 0 && err != errSSLWouldBlock && err != errSSLClosedGraceful) {
        sslErrorLog("SSLWrite: going to state errorCLose due to err %d\n",
            (int)err);
        SSLChangeHdskState(ctx, SSL_HdskStateErrorClose);
    }
abort:
    sslIoTrace("SSLWrite", dataLength, *bytesWritten, err);
    return err;
}
/*
 * SSLRead: return up to dataLength bytes of decrypted application data.
 *
 * Behavior:
 *   - Closed sessions return their terminal status (graceful, abort, or
 *     no-notify) without I/O.
 *   - Otherwise the handshake is driven until the read cipher is ready, and
 *     pending output is flushed; errSSLWouldBlock from that flush is ignored.
 *   - Data left over from a previous oversized record is returned first.
 *   - If room remains, at most one more record is read. Application (or
 *     SSLv2) data is copied out, and any excess is kept in
 *     ctx->receivedDataBuffer for the next call.
 *   - Other record types are dispatched to SSLProcessProtocolMessage.
 *
 * *processed may be nonzero even when an error such as errSSLWouldBlock is
 * returned. Errors other than would-block / graceful / no-notify move the
 * context to SSL_HdskStateErrorClose.
 */
OSStatus
SSLRead (
    SSLContext  *ctx,
    void        *data,
    UInt32      dataLength,
    UInt32      *processed)
{
    OSStatus    err;
    UInt8       *charPtr;
    UInt32      bufSize, remaining, count;
    SSLRecord   rec;

    if ((ctx == NULL) || (processed == NULL)) {
        return paramErr;
    }
    if ((data == NULL) && (dataLength != 0)) {
        return paramErr;
    }
    bufSize = dataLength;
    *processed = 0;

    switch (ctx->state) {
        case SSL_HdskStateGracefulClose:
            err = errSSLClosedGraceful;
            goto abort;
        case SSL_HdskStateErrorClose:
            err = errSSLClosedAbort;
            goto abort;
        case SSL_HdskStateNoNotifyClose:
            err = errSSLClosedNoNotify;
            goto abort;
        default:
            break;
    }

    /* Drive the handshake until the read side has keys. */
    err = noErr;
    while (ctx->readCipher.ready == 0) {
        if ((err = SSLHandshakeProceed(ctx)) != 0) {
            goto exit;
        }
    }

    /* Opportunistically flush pending output; would-block is not fatal here. */
    if ((err = SSLServiceWriteQueue(ctx)) != 0) {
        if (err != errSSLWouldBlock) {
            goto exit;
        }
        err = noErr;
    }

    remaining = bufSize;
    charPtr = (UInt8 *)data;

    /* Serve leftover bytes from a previously buffered record first. */
    if (ctx->receivedDataBuffer.data) {
        count = ctx->receivedDataBuffer.length - ctx->receivedDataPos;
        if (count > bufSize)
            count = bufSize;
        if (count > 0) {
            memcpy(data, ctx->receivedDataBuffer.data + ctx->receivedDataPos, count);
        }
        remaining -= count;
        charPtr += count;
        *processed += count;
        ctx->receivedDataPos += count;
    }

    assert(ctx->receivedDataPos <= ctx->receivedDataBuffer.length);
    assert(*processed + remaining == bufSize);
    assert(charPtr == ((UInt8 *)data) + *processed);

    /* Release the leftover buffer once it has been fully consumed. */
    if (ctx->receivedDataBuffer.data != 0 &&
        ctx->receivedDataPos >= ctx->receivedDataBuffer.length) {
        SSLFreeBuffer(ctx->receivedDataBuffer, ctx);
        ctx->receivedDataBuffer.data = 0;
        ctx->receivedDataPos = 0;
    }

    /*
     * Read at most one more record. Looping here until the request is
     * satisfied would hang with non-blocking I/O.
     */
    if (remaining > 0 && ctx->state != SSL_HdskStateGracefulClose) {
        assert(ctx->receivedDataBuffer.data == 0);
        if ((err = SSLReadRecord(rec, ctx)) != 0) {
            goto exit;
        }
        if (rec.contentType == SSL_RecordTypeAppData ||
            rec.contentType == SSL_RecordTypeV2_0) {
            if (rec.contents.length <= remaining) {
                /* Whole record fits: copy it out and free it. */
                memcpy(charPtr, rec.contents.data, rec.contents.length);
                remaining -= rec.contents.length;
                charPtr += rec.contents.length;
                *processed += rec.contents.length;
                if ((err = SSLFreeBuffer(rec.contents, ctx)) != 0) {
                    goto exit;
                }
            }
            else {
                /* Partial copy: the context takes ownership of the record. */
                memcpy(charPtr, rec.contents.data, remaining);
                charPtr += remaining;
                *processed += remaining;
                ctx->receivedDataBuffer = rec.contents;
                ctx->receivedDataPos = remaining;
                remaining = 0;
            }
        }
        else {
            if ((err = SSLProcessProtocolMessage(rec, ctx)) != 0) {
                /*
                 * Free the record before bailing out, as SSLHandshakeProceed
                 * does; the original leaked it here. The processing error
                 * takes precedence over the free status.
                 */
                SSLFreeBuffer(rec.contents, ctx);
                goto exit;
            }
            if ((err = SSLFreeBuffer(rec.contents, ctx)) != 0) {
                goto exit;
            }
        }
    }

    err = noErr;

exit:
    switch (err) {
        case noErr:
        case errSSLWouldBlock:
        case errSSLClosedGraceful:
        case errSSLClosedNoNotify:
            break;
        default:
            sslErrorLog("SSLRead: going to state errorClose due to err %d\n",
                (int)err);
            SSLChangeHdskState(ctx, SSL_HdskStateErrorClose);
            break;
    }
abort:
    sslIoTrace("SSLRead ", dataLength, *processed, err);
    return err;
}
#if SSL_DEBUG
#include "appleCdsa.h"
#endif
/*
 * SSLHandshake: run the handshake until both the read and the write cipher
 * are ready, then flush any queued output.
 *
 * A session that is already gracefully or abortively closed returns
 * errSSLClosedGraceful / errSSLClosedAbort immediately. The valid cipher
 * spec array is built on first use. Any error from that setup, from a
 * handshake step, or from the final flush is returned unchanged.
 */
OSStatus
SSLHandshake(SSLContext *ctx)
{
    OSStatus status;

    if (ctx == NULL)
        return paramErr;

    switch (ctx->state) {
        case SSL_HdskStateGracefulClose:
            return errSSLClosedGraceful;
        case SSL_HdskStateErrorClose:
            return errSSLClosedAbort;
        default:
            break;
    }

    if (ctx->validCipherSpecs == NULL) {
        status = sslBuildCipherSpecArray(ctx);
        if (status)
            return status;
    }

    /* Keep stepping until neither direction is still waiting for keys. */
    while (!(ctx->readCipher.ready && ctx->writeCipher.ready)) {
        status = SSLHandshakeProceed(ctx);
        if (status != 0)
            return status;
    }

    status = SSLServiceWriteQueue(ctx);
    if (status != 0)
        return status;
    return noErr;
}
/*
 * SSLHandshakeProceed: advance the handshake by one inbound record.
 *
 * Steps:
 *   1. On the very first call (state Uninit), initialize the connection.
 *   2. Flush pending output.
 *   3. Read one record and dispatch it to SSLProcessProtocolMessage.
 *
 * The record's buffer is released on both the success and the failure path.
 * A processing error takes precedence over the result of that release.
 */
static OSStatus
SSLHandshakeProceed(SSLContext *ctx)
{
    OSStatus    status;
    SSLRecord   inbound;

    if (ctx->state == SSL_HdskStateUninit) {
        status = SSLInitConnection(ctx);
        if (status != 0)
            return status;
    }

    status = SSLServiceWriteQueue(ctx);
    if (status != 0)
        return status;

    assert(ctx->readCipher.ready == 0);

    status = SSLReadRecord(inbound, ctx);
    if (status != 0)
        return status;

    status = SSLProcessProtocolMessage(inbound, ctx);
    if (status != 0) {
        /* Deliberately ignore the free result; report the processing error. */
        SSLFreeBuffer(inbound.contents, ctx);
        return status;
    }

    /* noErr is 0, so returning the free status directly matches the original. */
    return SSLFreeBuffer(inbound.contents, ctx);
}
/*
 * SSLInitConnection: first-time setup of a connection.
 *
 * Steps:
 *   1. Move to the client- or server-side Uninit state.
 *   2. If a peer ID is set, try to fetch a cached session for resumption.
 *   3. If a cached session exists, look up its protocol version. When that
 *      version is disabled on this context, the cached session is dropped.
 *   4. A client whose write cipher is not yet ready starts the handshake:
 *      SSL2AdvanceHandshake(SSL2_MsgKickstart) when SSLv2 is enabled and no
 *      SSLv3/TLS session is being resumed, otherwise
 *      SSLAdvanceHandshake(SSL_HdskHelloRequest).
 *
 * Returns errSSLInternal for an unrecognized cached protocol version;
 * otherwise the status of the version lookup or of the handshake kick-off.
 */
static OSStatus
SSLInitConnection(SSLContext *ctx)
{
    OSStatus    err = noErr;
    Boolean     resumingSsl3Family = false;

    if (ctx->protocolSide == SSL_ClientSide) {
        SSLChangeHdskState(ctx, SSL_HdskStateClientUninit);
    } else {
        assert(ctx->protocolSide == SSL_ServerSide);
        SSLChangeHdskState(ctx, SSL_HdskStateServerUninit);
    }

    if (ctx->peerID.data != 0) {
        SSLGetSessionData(&ctx->resumableSession, ctx);
    }

    if (ctx->resumableSession.data != 0) {
        SSLProtocolVersion  savedVersion;
        Boolean             versionAllowed;

        err = SSLRetrieveSessionProtocolVersion(ctx->resumableSession,
            &savedVersion, ctx);
        if (err != 0) {
            return err;
        }

        if (savedVersion == SSL_Version_2_0) {
            versionAllowed = ctx->versionSsl2Enable;
        } else if (savedVersion == SSL_Version_3_0) {
            versionAllowed = ctx->versionSsl3Enable;
            resumingSsl3Family = true;
        } else if (savedVersion == TLS_Version_1_0) {
            versionAllowed = ctx->versionTls1Enable;
            resumingSsl3Family = true;
        } else {
            assert(0);
            return errSSLInternal;
        }

        if (versionAllowed) {
            sslLogResumSessDebug("===attempting to resume session");
        } else {
            sslLogResumSessDebug("===Resumable session protocol mismatch");
            SSLFreeBuffer(ctx->resumableSession, ctx);
            resumingSsl3Family = false;
        }
    }

    /* Only a client that has not yet keyed its write side sends the opener. */
    if (ctx->state == SSL_HdskStateClientUninit && ctx->writeCipher.ready == 0) {
        assert(ctx->negProtocolVersion == SSL_Version_Undetermined);
        if (ctx->versionSsl2Enable && !resumingSsl3Family)
            err = SSL2AdvanceHandshake(SSL2_MsgKickstart, ctx);
        else
            err = SSLAdvanceHandshake(SSL_HdskHelloRequest, ctx);
    }
    return err;
}
/*
 * SSLServiceWriteQueue: push queued outbound records to the transport.
 *
 * Records on ctx->recordWriteQueue are written in order. rec->sent tracks
 * partial progress, so a later call resumes where this one stopped. A fully
 * sent record is unlinked from the queue and both its data buffer and the
 * WaitingRecord itself are freed.
 *
 * Stops at the first sslIoWrite error and returns it. A failure to free a
 * buffer is returned in preference; the data-buffer failure wins over the
 * record failure.
 */
static OSStatus
SSLServiceWriteQueue(SSLContext *ctx)
{
    OSStatus        werr = noErr;
    WaitingRecord   *rec;

    while (!werr && ((rec = ctx->recordWriteQueue) != 0)) {
        SSLBuffer   buf;
        /* Reset each pass so a failed write can never re-credit old bytes. */
        UInt32      written = 0;

        buf.data = rec->data.data + rec->sent;
        buf.length = rec->data.length - rec->sent;
        werr = sslIoWrite(buf, &written, ctx);
        rec->sent += written;

        if (rec->sent >= rec->data.length) {
            SSLBuffer   recBuf;
            OSStatus    dataErr, recErr;

            assert(rec->sent == rec->data.length);
            dataErr = SSLFreeBuffer(rec->data, ctx);
            assert(dataErr == 0);

            recBuf.data = (UInt8 *)rec;
            recBuf.length = sizeof(WaitingRecord);
            /* Unlink before freeing: rec->next must be read first. */
            ctx->recordWriteQueue = rec->next;
            recErr = SSLFreeBuffer(recBuf, ctx);
            assert(recErr == 0);

            /* Don't let the second free's status mask the first failure. */
            if (dataErr != 0)
                return dataErr;
            if (recErr != 0)
                return recErr;
        }
    }
    return werr;
}
/*
 * SSLProcessProtocolMessage: route a received record to the handler for its
 * content type and return that handler's status:
 *   - handshake      -> SSLProcessHandshakeRecord
 *   - alert          -> SSLProcessAlert
 *   - change cipher  -> SSLProcessChangeCipherSpec
 *   - SSLv2          -> SSL2ProcessMessage
 * Any other content type yields errSSLProtocol.
 */
static OSStatus
SSLProcessProtocolMessage(SSLRecord &rec, SSLContext *ctx)
{
    switch (rec.contentType) {
        case SSL_RecordTypeHandshake:
            sslLogRxProtocolDebug("Handshake");
            return SSLProcessHandshakeRecord(rec, ctx);
        case SSL_RecordTypeAlert:
            sslLogRxProtocolDebug("Alert");
            return SSLProcessAlert(rec, ctx);
        case SSL_RecordTypeChangeCipher:
            sslLogRxProtocolDebug("ChangeCipher");
            return SSLProcessChangeCipherSpec(rec, ctx);
        case SSL_RecordTypeV2_0:
            sslLogRxProtocolDebug("RecordTypeV2_0");
            return SSL2ProcessMessage(rec, ctx);
        default:
            sslLogRxProtocolDebug("Bad msg");
            return errSSLProtocol;
    }
}
/*
 * SSLClose: orderly shutdown of the session.
 *
 * Steps:
 *   1. If the negotiated version is SSLv3 or later, send a warning-level
 *      close_notify alert.
 *   2. If that succeeded (or was skipped), flush queued output.
 *   3. Mark the context as gracefully closed, regardless of the outcome.
 *
 * An ioErr from either step is reported as noErr; other errors are returned.
 */
OSStatus
SSLClose(SSLContext *ctx)
{
    OSStatus status;

    if (ctx == NULL)
        return paramErr;

    status = noErr;
    if (ctx->negProtocolVersion >= SSL_Version_3_0)
        status = SSLSendAlert(SSL_AlertLevelWarning, SSL_AlertCloseNotify, ctx);
    if (status == noErr)
        status = SSLServiceWriteQueue(ctx);

    SSLChangeHdskState(ctx, SSL_HdskStateGracefulClose);

    if (status == ioErr)
        return noErr;
    return status;
}
/*
 * SSLGetBufferedReadSize: report how many decrypted application-data bytes
 * are held in ctx->receivedDataBuffer. These are left over from a record
 * larger than an earlier SSLRead request, and the next SSLRead returns them
 * first.
 *
 * Returns paramErr if ctx or bufSize is NULL, otherwise noErr with
 * *bufSize set (0 when nothing is buffered).
 */
OSStatus
SSLGetBufferedReadSize(SSLContextRef ctx,
    size_t *bufSize)
{
    /* Validate the out-pointer too, as SSLRead/SSLWrite do for theirs. */
    if ((ctx == NULL) || (bufSize == NULL)) {
        return paramErr;
    }
    if (ctx->receivedDataBuffer.data == NULL) {
        *bufSize = 0;
    }
    else {
        assert(ctx->receivedDataBuffer.length >= ctx->receivedDataPos);
        *bufSize = ctx->receivedDataBuffer.length - ctx->receivedDataPos;
    }
    return noErr;
}