#ifndef _SSLTRSPT_H_
#include "ssltrspt.h"
#endif
#ifndef _SSLALLOC_H_
#include "sslalloc.h"
#endif
#ifndef _SSLCTX_H_
#include "sslctx.h"
#endif
#ifndef _SSLREC_H_
#include "sslrec.h"
#endif
#ifndef _SSLALERT_H_
#include "sslalert.h"
#endif
#ifndef _SSLSESS_H_
#include "sslsess.h"
#endif
#ifndef _SSL2_H_
#include "ssl2.h"
#endif
#ifndef _APPLE_GLUE_H_
#include "appleGlue.h"
#endif
#ifndef _SSL_DEBUG_H_
#include "sslDebug.h"
#endif
#ifndef _CIPHER_SPECS_H_
#include "cipherSpecs.h"
#endif
#include <CoreServices/../Frameworks/CarbonCore.framework/Headers/MacErrors.h>
#include <assert.h>
#include <stdio.h>
#include <string.h>
/*
 * Optional tracing of SSLRead()/SSLWrite() results.
 *
 * Set SSL_IO_TRACE to 1 to print one line per call. When it is 0,
 * sslIoTrace() expands to a no-op expression and its arguments are not
 * evaluated.
 */
#define SSL_IO_TRACE 0
#if SSL_IO_TRACE
/*
 * Print the operation name, the requested byte count, the number of bytes
 * actually moved, and the resulting status.
 *
 * The integer arguments are widened explicitly so the conversion specifiers
 * match whatever underlying type UInt32/OSStatus have on the target.
 */
static void sslIoTrace(
	const char	*op,
	UInt32		req,
	UInt32		moved,
	OSStatus	stat)
{
	printf("===%s: req %4lu moved %4lu status %ld\n",
		op, (unsigned long)req, (unsigned long)moved, (long)stat);
}
#else
#define sslIoTrace(op, req, moved, stat)	((void)0)
#endif
/* Forward declarations of the file-local helpers defined further below. */
static SSLErr SSLInitConnection(SSLContext *ctx);
static SSLErr SSLHandshakeProceed(SSLContext *ctx);
static SSLErr SSLServiceWriteQueue(SSLContext *ctx);
static SSLErr SSLProcessProtocolMessage(SSLRecord rec, SSLContext *ctx);
/*
 * Send dataLength bytes of application data over the session.
 *
 * The data is split into records of at most MAX_RECORD_LENGTH bytes and
 * handed to the protocol's writeRecord() hook. The write queue is then
 * flushed.
 *
 * ctx          - the session; must not be NULL.
 * data         - bytes to send; may be NULL only when dataLength is 0.
 * dataLength   - number of bytes at data.
 * bytesWritten - receives the number of bytes accepted; must not be NULL.
 *                It is set only after every record has been passed to
 *                writeRecord().
 *
 * Returns paramErr for bad arguments. Returns badReqErr when the session is
 * not in a handshake-complete state. Otherwise returns the OSStatus mapping
 * of the SSLErr result. Errors other than would-block and graceful close
 * that occur while writing move the session to SSLErrorClose.
 */
OSStatus
SSLWrite(
	SSLContext		*ctx,
	const void *	data,
	UInt32			dataLength,
	UInt32			*bytesWritten)
{
	SSLErr		err;
	SSLRecord	rec;
	UInt32		dataLen, processed;

	if((ctx == NULL) || (bytesWritten == NULL)) {
		return paramErr;
	}
	*bytesWritten = 0;
	/* a NULL buffer is only acceptable when there is nothing to send */
	if((data == NULL) && (dataLength != 0)) {
		return paramErr;
	}
	dataLen = dataLength;
	processed = 0;

	switch(ctx->state) {
		case SSLGracefulClose:
			err = SSLConnectionClosedGraceful;
			goto abort;
		case SSLErrorClose:
			err = SSLConnectionClosedError;
			goto abort;
		case HandshakeServerReady:
		case HandshakeClientReady:
			break;
		default:
			/* handshake not complete (or closed without notify) */
			sslIoTrace("SSLWrite", dataLength, 0, badReqErr);
			return badReqErr;
	}

	err = SSLNoErr;
	/* keep driving the handshake until the write cipher is usable */
	while (ctx->writeCipher.ready == 0) {
		if ((err = SSLHandshakeProceed(ctx)) != 0) {
			goto exit;
		}
	}

	/* try to drain previously queued records before adding more */
	if ((err = SSLServiceWriteQueue(ctx)) != 0) {
		goto abort;
	}

	/* chop the caller's data into records and queue each one */
	processed = 0;
	while (dataLen > 0) {
		rec.contentType = SSL_application_data;
		rec.protocolVersion = ctx->negProtocolVersion;
		rec.contents.data = ((UInt8*)data) + processed;
		if (dataLen < MAX_RECORD_LENGTH) {
			rec.contents.length = dataLen;
		}
		else {
			rec.contents.length = MAX_RECORD_LENGTH;
		}
		assert(ctx->sslTslCalls != NULL);
		if (ERR(err = ctx->sslTslCalls->writeRecord(rec, ctx)) != 0) {
			goto exit;
		}
		processed += rec.contents.length;
		dataLen -= rec.contents.length;
	}

	/* every byte has been handed to the record layer */
	*bytesWritten = processed;
	if (ERR(err = SSLServiceWriteQueue(ctx)) != 0) {
		goto exit;
	}
	err = SSLNoErr;

exit:
	if (err != 0 && err != SSLWouldBlockErr && err != SSLConnectionClosedGraceful) {
		dprintf1("SSLWrite: going to state errorCLose due to err %d\n",
			err);
		SSLChangeHdskState(ctx, SSLErrorClose);
	}
abort:
	sslIoTrace("SSLWrite", dataLength, *bytesWritten, sslErrToOsStatus(err));
	return sslErrToOsStatus(err);
}
/*
 * Read up to dataLength bytes of application data into data.
 *
 * Bytes left over from a previously received record are delivered first.
 * If the caller's buffer still has room, at most one further record is
 * read:
 *   - an application-data (or SSLv2) record is copied out, and any surplus
 *     that does not fit is kept in ctx->receivedDataBuffer for later calls;
 *   - any other record is passed to SSLProcessProtocolMessage() and
 *     delivers no bytes in this call.
 * If the read cipher is not ready yet, the handshake is driven forward
 * first.
 *
 * ctx        - the session; must not be NULL.
 * data       - destination buffer; may be NULL only when dataLength is 0.
 * dataLength - capacity of data in bytes.
 * processed  - receives the number of bytes copied; must not be NULL.
 *
 * Returns paramErr for bad arguments. Otherwise returns the OSStatus
 * mapping of the SSLErr result. Errors other than would-block, graceful
 * close and close-without-notify move the session to SSLErrorClose.
 */
OSStatus
SSLRead (
	SSLContext		*ctx,
	void *			data,
	UInt32			dataLength,
	UInt32			*processed)
{
	SSLErr		err;
	UInt8		*progress;
	UInt32		bufSize, remaining, count;
	SSLRecord	rec;

	if((ctx == NULL) || (processed == NULL)) {
		return paramErr;
	}
	bufSize = dataLength;
	*processed = 0;
	/* a NULL buffer is only acceptable when the caller wants no bytes */
	if((data == NULL) && (dataLength != 0)) {
		return paramErr;
	}

	switch(ctx->state) {
		case SSLGracefulClose:
			err = SSLConnectionClosedGraceful;
			goto abort;
		case SSLErrorClose:
			err = SSLConnectionClosedError;
			goto abort;
		case SSLNoNotifyClose:
			err = SSLConnectionClosedNoNotify;
			goto abort;
		default:
			break;
	}

	err = SSLNoErr;
	/* finish the handshake before application data can be decrypted */
	while (ctx->readCipher.ready == 0) {
		if (ERR(err = SSLHandshakeProceed(ctx)) != 0) {
			goto exit;
		}
	}

	/* flush pending writes; a would-block here is not fatal for a read */
	if (ERR(err = SSLServiceWriteQueue(ctx)) != 0) {
		if (err != SSLWouldBlockErr) {
			goto exit;
		}
		err = SSLNoErr;
	}

	remaining = bufSize;
	progress = (UInt8*)data;

	/* deliver leftovers from a previously received record first */
	if (ctx->receivedDataBuffer.data && (remaining > 0)) {
		count = ctx->receivedDataBuffer.length - ctx->receivedDataPos;
		if (count > bufSize) {
			count = bufSize;
		}
		memcpy(data, ctx->receivedDataBuffer.data + ctx->receivedDataPos, count);
		remaining -= count;
		progress += count;
		*processed += count;
		ctx->receivedDataPos += count;
	}

	CASSERT(ctx->receivedDataPos <= ctx->receivedDataBuffer.length);
	CASSERT(*processed + remaining == bufSize);
	CASSERT(progress == ((UInt8*)data) + *processed);

	/* release the leftover buffer once it has been fully consumed */
	if (ctx->receivedDataBuffer.data != 0 &&
			ctx->receivedDataPos >= ctx->receivedDataBuffer.length) {
		SSLFreeBuffer(&ctx->receivedDataBuffer, &ctx->sysCtx);
		ctx->receivedDataBuffer.data = 0;
		ctx->receivedDataPos = 0;
	}

	/* room left: read (at most) one more record from the peer */
	if (remaining > 0 && ctx->state != SSLGracefulClose) {
		CASSERT(ctx->receivedDataBuffer.data == 0);
		if (ERR(err = SSLReadRecord(&rec, ctx)) != 0) {
			goto exit;
		}

		if (rec.contentType == SSL_application_data ||
				rec.contentType == SSL_version_2_0_record) {
			if (rec.contents.length <= remaining) {
				/* the whole record fits: copy it and free it */
				memcpy(progress, rec.contents.data, rec.contents.length);
				remaining -= rec.contents.length;
				progress += rec.contents.length;
				*processed += rec.contents.length;
				if (ERR(err = SSLFreeBuffer(&rec.contents, &ctx->sysCtx)) != 0) {
					goto exit;
				}
			}
			else {
				/* copy what fits; keep the rest for the next call */
				memcpy(progress, rec.contents.data, remaining);
				progress += remaining;
				*processed += remaining;
				ctx->receivedDataBuffer = rec.contents;
				ctx->receivedDataPos = remaining;
				remaining = 0;
			}
		}
		else {
			/* handshake, alert or change-cipher-spec record */
			if (ERR(err = SSLProcessProtocolMessage(rec, ctx)) != 0) {
				/*
				 * The record buffer is still ours; release it before
				 * propagating the error, as SSLHandshakeProceed() does.
				 */
				SSLFreeBuffer(&rec.contents, &ctx->sysCtx);
				goto exit;
			}
			if (ERR(err = SSLFreeBuffer(&rec.contents, &ctx->sysCtx)) != 0) {
				goto exit;
			}
		}
	}

	err = SSLNoErr;

exit:
	switch(err) {
		case SSLNoErr:
		case SSLWouldBlockErr:
		case SSLConnectionClosedGraceful:
		case SSLConnectionClosedNoNotify:
			break;
		default:
			dprintf1("SSLRead: going to state errorClose due to err %d\n",
				err);
			SSLChangeHdskState(ctx, SSLErrorClose);
			break;
	}
abort:
	sslIoTrace("SSLRead ", dataLength, *processed, sslErrToOsStatus(err));
	return sslErrToOsStatus(err);
}
#if SSL_DEBUG
#include "appleCdsa.h"
#endif
/*
 * Drive the handshake until both the read and write ciphers are ready,
 * then flush any records still sitting in the write queue.
 *
 * A server-side context must already have a local certificate, a signing
 * key pair, and a signing CSP. If no cipher-spec array has been set up, one
 * is built here.
 *
 * Returns paramErr for a NULL ctx and badReqErr for an insufficiently
 * initialized server. Otherwise returns the OSStatus mapping of the SSLErr
 * result (noErr on success).
 */
OSStatus
SSLHandshake(SSLContext *ctx)
{
	SSLErr	err;

	if(ctx == NULL) {
		return paramErr;
	}
	if (ctx->state == SSLGracefulClose) {
		return sslErrToOsStatus(SSLConnectionClosedGraceful);
	}
	if (ctx->state == SSLErrorClose) {
		return sslErrToOsStatus(SSLConnectionClosedError);
	}

	if(ctx->protocolSide == SSL_ServerSide) {
		/* a server cannot proceed without its certificate and signing keys */
		if((ctx->localCert == NULL) ||
		   (ctx->signingPrivKey == NULL) ||
		   (ctx->signingPubKey == NULL) ||
		   (ctx->signingKeyCsp == 0)) {
			errorLog0("SSLHandshake: insufficient init\n");
			return badReqErr;
		}
	}

	if(ctx->validCipherSpecs == NULL) {
		/* no cipher-spec array yet; build one */
		err = sslBuildCipherSpecArray(ctx);
		if(err) {
			/* err is an SSLErr: map it like every other exit of this API */
			return sslErrToOsStatus(err);
		}
	}

	err = SSLNoErr;
	while (ctx->readCipher.ready == 0 || ctx->writeCipher.ready == 0) {
		if (ERR(err = SSLHandshakeProceed(ctx)) != 0) {
			return sslErrToOsStatus(err);
		}
	}

	/* push out anything the final handshake step queued */
	if ((err = SSLServiceWriteQueue(ctx)) != 0) {
		return sslErrToOsStatus(err);
	}
	return noErr;
}
/*
 * Advance the handshake by one incoming record.
 *
 * On first use (state SSLUninitialized) the connection is initialized.
 * Then the write queue is flushed, one record is read, and that record is
 * passed to SSLProcessProtocolMessage(). The record buffer is released on
 * both the success and the failure path of the processing step.
 *
 * Returns the first non-zero SSLErr encountered, else SSLNoErr.
 */
static SSLErr
SSLHandshakeProceed(SSLContext *ctx)
{
	SSLErr		err;
	SSLRecord	rec;

	/* lazily start the handshake */
	if (ctx->state == SSLUninitialized) {
		if (ERR(err = SSLInitConnection(ctx)) != 0) {
			return err;
		}
	}

	/* send whatever is already queued before waiting on the peer */
	if (ERR(err = SSLServiceWriteQueue(ctx)) != 0) {
		return err;
	}

	CASSERT(ctx->readCipher.ready == 0);

	if (ERR(err = SSLReadRecord(&rec, ctx)) != 0) {
		return err;
	}

	if (ERR(err = SSLProcessProtocolMessage(rec, ctx)) != 0) {
		/* processing failed: drop the record, report the processing error */
		SSLFreeBuffer(&rec.contents, &ctx->sysCtx);
		return err;
	}

	return (ERR(err = SSLFreeBuffer(&rec.contents, &ctx->sysCtx)) != 0)
		? err
		: SSLNoErr;
}
/*
 * Put a fresh context into its side's "Uninit" handshake state and, for a
 * client, send the opening handshake message.
 *
 * If ctx->peerID is set, a cached session is looked up (lookup failure is
 * ignored). If a resumable session is available, its protocol version is
 * checked against ctx->maxProtocolVersion:
 *   - a newer version causes the cached session to be discarded;
 *   - otherwise a client adopts that version for this connection.
 *
 * The first message sent depends on ctx->negProtocolVersion: an SSLv2
 * kickstart for undetermined/2.0-hello/2.0, or an SSLv3/TLS hello request
 * otherwise.
 *
 * Returns the first non-zero SSLErr encountered, else SSLNoErr.
 */
static SSLErr
SSLInitConnection(SSLContext *ctx)
{
	SSLErr	err;

	if (ctx->protocolSide != SSL_ClientSide) {
		CASSERT(ctx->protocolSide == SSL_ServerSide);
		SSLChangeHdskState(ctx, HandshakeServerUninit);
	}
	else {
		SSLChangeHdskState(ctx, HandshakeClientUninit);
	}

	/* best-effort session-cache lookup; a miss is not an error */
	if (ctx->peerID.data != 0) {
		ERR(SSLGetSessionData(&ctx->resumableSession, ctx));
	}

	if (ctx->resumableSession.data != 0) {
		SSLProtocolVersion savedVersion;

		if (ERR(err = SSLRetrieveSessionProtocolVersion(ctx->resumableSession,
				&savedVersion, ctx)) != 0) {
			return err;
		}
		if (savedVersion <= ctx->maxProtocolVersion) {
			SSLLogResumSess("===attempting to resume session\n");
			if (ctx->protocolSide == SSL_ClientSide) {
				ctx->negProtocolVersion = savedVersion;
			}
		}
		else {
			/* cached session uses a version we may no longer speak */
			SSLLogResumSess("===Resumable session protocol mismatch\n");
			SSLFreeBuffer(&ctx->resumableSession, &ctx->sysCtx);
		}
	}

	/* only a client that has not started writing sends the first message */
	if (ctx->state != HandshakeClientUninit || ctx->writeCipher.ready != 0) {
		return SSLNoErr;
	}

	switch (ctx->negProtocolVersion) {
		case SSL_Version_Undetermined:
		case SSL_Version_3_0_With_2_0_Hello:
		case SSL_Version_2_0:
			if (ERR(err = SSL2AdvanceHandshake(ssl2_mt_kickstart_handshake, ctx)) != 0) {
				return err;
			}
			break;
		case SSL_Version_3_0_Only:
		case SSL_Version_3_0:
		case TLS_Version_1_0_Only:
		case TLS_Version_1_0:
			if (ERR(err = SSLAdvanceHandshake(SSL_hello_request, ctx)) != 0) {
				return err;
			}
			break;
		default:
			sslPanic("Bad protocol version");
			break;
	}
	return SSLNoErr;
}
/*
 * Push queued outgoing records to the transport via sslIoWrite().
 *
 * Records are written from the head of ctx->recordWriteQueue. A fully sent
 * record is unlinked, and both its data buffer and its WaitingRecord node
 * are freed. The loop stops when the queue is empty or sslIoWrite()
 * reports an error; a partially sent record keeps its progress in
 * rec->sent.
 *
 * Returns a non-zero SSLFreeBuffer() result if freeing fails. Otherwise
 * returns the last sslIoWrite() status (SSLNoErr when the queue drained).
 */
static SSLErr
SSLServiceWriteQueue(SSLContext *ctx)
{
	SSLErr			err = SSLNoErr;
	SSLErr			werr = SSLNoErr;
	UInt32			written = 0;
	SSLBuffer		chunk;
	SSLBuffer		node;
	WaitingRecord	*rec;

	for (rec = ctx->recordWriteQueue;
	     !werr && rec != 0;
	     rec = ctx->recordWriteQueue) {
		/* send whatever part of the head record has not gone out yet */
		chunk.data = rec->data.data + rec->sent;
		chunk.length = rec->data.length - rec->sent;
		werr = sslIoWrite(chunk, &written, ctx);
		rec->sent += written;

		if (rec->sent >= rec->data.length) {
			/* head record fully transmitted: unlink and free it */
			CASSERT(rec->sent == rec->data.length);
			CASSERT(err == 0);
			err = SSLFreeBuffer(&rec->data, &ctx->sysCtx);
			CASSERT(err == 0);
			node.data = (UInt8*)rec;
			node.length = sizeof(WaitingRecord);
			ctx->recordWriteQueue = rec->next;
			err = SSLFreeBuffer(&node, &ctx->sysCtx);
			CASSERT(err == 0);
		}

		if (ERR(err)) {
			return err;
		}
		CASSERT(ctx->recordWriteQueue == 0 || ctx->recordWriteQueue->sent == 0);
	}
	return werr;
}
#if !LOG_RX_PROTOCOL
/* Receive-side protocol-message logging is compiled out. */
#define sslLogRxProto(msgType)
#else
/* Print the kind of each protocol (non-application) record received. */
static void sslLogRxProto(const char *msgType)
{
	printf("---received protoMsg %s\n", msgType);
}
#endif
/*
 * Route a non-application-data record to the handler for its content type:
 * handshake, alert, change_cipher_spec or SSLv2 message.
 *
 * The record is passed by value; this function does not free its buffer.
 *
 * Returns the handler's SSLErr, or SSLProtocolErr for an unknown content
 * type.
 */
static SSLErr
SSLProcessProtocolMessage(SSLRecord rec, SSLContext *ctx)
{
	SSLErr	err;

	switch (rec.contentType) {
		case SSL_handshake:
			sslLogRxProto("SSL_handshake");
			ERR(err = SSLProcessHandshakeRecord(rec, ctx));
			return err;
		case SSL_alert:
			sslLogRxProto("SSL_alert");
			ERR(err = SSLProcessAlert(rec, ctx));
			return err;
		case SSL_change_cipher_spec:
			sslLogRxProto("SSL_change_cipher_spec");
			ERR(err = SSLProcessChangeCipherSpec(rec, ctx));
			return err;
		case SSL_version_2_0_record:
			sslLogRxProto("SSL_version_2_0_record");
			ERR(err = SSL2ProcessMessage(rec, ctx));
			return err;
		default:
			sslLogRxProto("Bad msg");
			return ERR(SSLProtocolErr);
	}
}
/*
 * Close the session gracefully.
 *
 * When the negotiated version is SSL_Version_3_0 or later, a warning-level
 * close_notify alert is queued. If that succeeds, the write queue is
 * flushed. The context always ends in SSLGracefulClose.
 *
 * An SSLIOErr from sending/flushing is deliberately not reported.
 *
 * Returns paramErr for a NULL ctx, otherwise the OSStatus mapping of the
 * result.
 */
OSStatus
SSLClose(SSLContext *ctx)
{
	SSLErr	err = SSLNoErr;

	if (ctx == NULL) {
		return paramErr;
	}

	if (ctx->negProtocolVersion >= SSL_Version_3_0) {
		ERR(err = SSLSendAlert(alert_warning, alert_close_notify, ctx));
	}
	if (err == 0) {
		ERR(err = SSLServiceWriteQueue(ctx));
	}
	SSLChangeHdskState(ctx, SSLGracefulClose);

	/* a transport failure while saying goodbye is not surfaced */
	return sslErrToOsStatus((err == SSLIOErr) ? SSLNoErr : err);
}
/*
 * Report how many already-decrypted application-data bytes are buffered in
 * ctx->receivedDataBuffer. These bytes can be returned by SSLRead() without
 * reading another record.
 *
 * ctx     - the session; must not be NULL.
 * bufSize - receives the buffered byte count; must not be NULL.
 *
 * Returns paramErr if either pointer is NULL, otherwise noErr.
 */
OSStatus
SSLGetBufferedReadSize(SSLContextRef ctx,
	size_t *bufSize)
{
	if((ctx == NULL) || (bufSize == NULL)) {
		return paramErr;
	}
	if(ctx->receivedDataBuffer.data == NULL) {
		*bufSize = 0;
	}
	else {
		CASSERT(ctx->receivedDataBuffer.length >= ctx->receivedDataPos);
		*bufSize = ctx->receivedDataBuffer.length - ctx->receivedDataPos;
	}
	return noErr;
}