#include "ssl.h"
#include "sslMemory.h"
#include "sslContext.h"
#include "sslRecord.h"
#include "sslAlertMessage.h"
#include "sslSession.h"
#include "sslDebug.h"
#include "sslCipherSpecs.h"
#include "sslUtils.h"
#include <assert.h>
#include <string.h>
#include <utilities/SecIOFormat.h>
#ifndef NDEBUG
/*
 * Emit one record-I/O trace line for SSLRead/SSLWrite.
 *
 * op    - label of the calling operation.
 * req   - number of bytes the caller asked to transfer.
 * moved - number of bytes actually transferred.
 * stat  - status code being returned to the caller.
 *
 * size_t is not guaranteed to be unsigned long, so the values are cast
 * explicitly to match the %lu conversion.
 */
static inline void sslIoTrace(
    const char *op,
    size_t req,
    size_t moved,
    OSStatus stat)
{
    sslLogRecordIo("===%s: req %4lu moved %4lu status %d",
        op, (unsigned long)req, (unsigned long)moved, (int)stat);
}
#else
/* Tracing compiled out in release builds. */
#define sslIoTrace(op, req, moved, stat)
#endif
/* Defined elsewhere. In SSLWrite, a value of 2 suppresses the one-byte
 * record split until some application data has already been written. */
extern int kSplitDefaultValue;
/* Dispatch a non-application-data record to its sub-protocol handler. */
static OSStatus SSLProcessProtocolMessage(SSLRecord *rec, SSLContext *ctx);
/* Advance the handshake by at most one received record. */
static OSStatus SSLHandshakeProceed(SSLContext *ctx);
/* Leave the Uninit state and, on the client, send the first handshake message. */
static OSStatus SSLInitConnection(SSLContext *ctx);
/*
 * Report whether application data may be sent before the handshake has
 * completed ("False Start") for the negotiated parameters.
 *
 * All three conditions must hold:
 *   - the symmetric cipher is AES-128/256 (CBC or GCM) or RC4-128;
 *   - the key exchange is ephemeral (EC)DHE (ECDSA, RSA or DSS variants);
 *   - client authentication is none, or a DSS/RSA/ECDSA signature.
 */
static Boolean isFalseStartAllowed(SSLContext *ctx)
{
    Boolean symmetricOk;
    Boolean keyExchangeOk;
    Boolean clientAuthOk;

    switch (sslCipherSuiteGetSymmetricCipherAlgorithm(ctx->selectedCipher)) {
    case SSL_CipherAlgorithmAES_128_CBC:
    case SSL_CipherAlgorithmAES_128_GCM:
    case SSL_CipherAlgorithmAES_256_CBC:
    case SSL_CipherAlgorithmAES_256_GCM:
    case SSL_CipherAlgorithmRC4_128:
        symmetricOk = true;
        break;
    default:
        symmetricOk = false;
        break;
    }

    switch (sslCipherSuiteGetKeyExchangeMethod(ctx->selectedCipher)) {
    case SSL_ECDHE_ECDSA:
    case SSL_ECDHE_RSA:
    case SSL_DHE_RSA:
    case SSL_DHE_DSS:
        keyExchangeOk = true;
        break;
    default:
        keyExchangeOk = false;
        break;
    }

    switch (ctx->negAuthType) {
    case SSLClientAuthNone:
    case SSLClientAuth_DSSSign:
    case SSLClientAuth_RSASign:
    case SSLClientAuth_ECDSASign:
        clientAuthOk = true;
        break;
    default:
        clientAuthOk = false;
        break;
    }

    return symmetricOk && keyExchangeOk && clientAuthOk;
}
/*
 * SSLWrite: encrypt and queue application data for the peer.
 *
 * Drives the handshake until the session is in a Ready state, or until
 * False Start permits early data. It then flushes queued records and splits
 * `data` into application-data records of at most MAX_RECORD_LENGTH bytes.
 *
 * ctx          - session context; must not be NULL.
 * data         - bytes to send; may be NULL only when dataLength is 0.
 * dataLength   - number of bytes in data.
 * bytesWritten - out; set to 0 on entry and updated only after every
 *                record for `data` has been written.
 *
 * Returns errSecParam for bad arguments. Returns errSecBadReq when ctx->state
 * orders before SSL_HdskStateServerHello. Returns errSSLClosedGraceful or
 * errSSLClosedAbort for a closed session. Otherwise returns the status of the
 * handshake or record layer. Errors not listed at `exit:` move the session to
 * SSL_HdskStateErrorClose.
 */
OSStatus
SSLWrite(
    SSLContext *ctx,
    const void * data,
    size_t dataLength,
    size_t *bytesWritten)
{
    OSStatus err;
    SSLRecord rec;
    size_t dataLen, processed;
    Boolean split;

    sslLogRecordIo("SSLWrite top");
    if((ctx == NULL) || (bytesWritten == NULL)) {
        return errSecParam;
    }
    /* A NULL buffer is only acceptable for an empty write; otherwise the
     * record loop below would build records that point at NULL. */
    if((data == NULL) && (dataLength != 0)) {
        return errSecParam;
    }
    dataLen = dataLength;
    processed = 0;
    *bytesWritten = 0;

    switch(ctx->state) {
    case SSL_HdskStateGracefulClose:
        err = errSSLClosedGraceful;
        goto abort;
    case SSL_HdskStateErrorClose:
        err = errSSLClosedAbort;
        goto abort;
    case SSL_HdskStateServerReady:
    case SSL_HdskStateClientReady:
        break;
    default:
        /* Writing is refused until the handshake has reached ServerHello. */
        if(ctx->state < SSL_HdskStateServerHello) {
            sslIoTrace("SSLWrite", dataLength, 0, errSecBadReq);
            return errSecBadReq;
        }
        break;
    }

    err = errSecSuccess;
    /* Complete the handshake unless we are already connected, or False Start
     * allows application data once the write cipher is ready. */
    while (!(
        (ctx->state==SSL_HdskStateServerReady) ||
        (ctx->state==SSL_HdskStateClientReady) ||
        (ctx->writeCipher_ready && ctx->falseStartEnabled && isFalseStartAllowed(ctx))
    ))
    {
        if ((err = SSLHandshakeProceed(ctx)) != 0)
            goto exit;
    }

    /* Flush anything already queued before adding application data. */
    if ((err = SSLServiceWriteQueue(ctx)) != 0)
        goto abort;

    /* One-byte record splitting for TLS 1.0 and earlier, applied to ciphers
     * whose enum value lies strictly between RC4_128 and AES_128_GCM
     * (presumably the CBC block ciphers). kSplitDefaultValue == 2 skips the
     * split until some application data has already been written. */
    split = (ctx->oneByteRecordEnable && ctx->negProtocolVersion <= TLS_Version_1_0);
    if (split) {
        SSL_CipherAlgorithm cipherAlg = sslCipherSuiteGetSymmetricCipherAlgorithm(ctx->selectedCipher);
        split = (cipherAlg > SSL_CipherAlgorithmRC4_128 &&
                 cipherAlg < SSL_CipherAlgorithmAES_128_GCM);
        if (split) {
            split = (kSplitDefaultValue == 2 && !ctx->wroteAppData) ? false : true;
        }
    }

    processed = 0;
    while (dataLen > 0)
    {
        rec.contentType = SSL_RecordTypeAppData;
        rec.protocolVersion = ctx->negProtocolVersion;
        rec.contents.data = ((uint8_t *)data) + processed;
        if (processed == 0 && split)
            rec.contents.length = 1;
        else if (dataLen < MAX_RECORD_LENGTH)
            rec.contents.length = dataLen;
        else
            rec.contents.length = MAX_RECORD_LENGTH;
        if ((err = SSLWriteRecord(rec, ctx)) != 0)
            goto exit;
        processed += rec.contents.length;
        dataLen -= rec.contents.length;
        ctx->wroteAppData = 1;
    }

    *bytesWritten = processed;
    if ((err = SSLServiceWriteQueue(ctx)) == 0) {
        err = errSecSuccess;
    }

exit:
    /* Only non-recoverable errors tear the session down. */
    switch(err) {
    case errSecSuccess:
    case errSSLWouldBlock:
    case errSSLUnexpectedRecord:
    case errSSLServerAuthCompleted:
    case errSSLClientCertRequested:
    case errSSLClosedGraceful:
        break;
    default:
        sslErrorLog("SSLWrite: going to state errorClose due to err %d\n",
            (int)err);
        SSLChangeHdskState(ctx, SSL_HdskStateErrorClose);
        break;
    }
abort:
    sslIoTrace("SSLWrite", dataLength, *bytesWritten, err);
    return err;
}
/*
 * SSLRead: copy decrypted application data into the caller's buffer.
 *
 * Drives the handshake until the read cipher is ready and flushes pending
 * output (a blocked write queue is tolerated). It first returns bytes left
 * over from a previously received record. If space remains, it reads at most
 * one more record per pass. Application data that does not fit is kept in
 * ctx->receivedDataBuffer for the next call. Non-application records are
 * handed to SSLProcessProtocolMessage.
 *
 * ctx        - session context; must not be NULL.
 * data       - destination buffer; must not be NULL.
 * dataLength - capacity of data.
 * processed  - out; number of bytes stored into data.
 *
 * A pass is retried when it succeeded without delivering any bytes (and
 * dataLength is nonzero), or when it ended in errSSLUnexpectedRecord.
 * Retries append after bytes already delivered, so earlier data is never
 * overwritten. Errors not listed in the switch below move the session to
 * SSL_HdskStateErrorClose.
 */
OSStatus
SSLRead (
    SSLContext *ctx,
    void * data,
    size_t dataLength,
    size_t *processed)
{
    OSStatus err;
    uint8_t *charPtr;
    size_t bufSize, remaining, count;
    SSLRecord rec;

    sslLogRecordIo("SSLRead top");
    if((ctx == NULL) || (data == NULL) || (processed == NULL)) {
        return errSecParam;
    }
    bufSize = dataLength;
    *processed = 0;

readRetry:
    switch(ctx->state) {
    case SSL_HdskStateGracefulClose:
        err = errSSLClosedGraceful;
        goto abort;
    case SSL_HdskStateErrorClose:
        err = errSSLClosedAbort;
        goto abort;
    case SSL_HdskStateNoNotifyClose:
        err = errSSLClosedNoNotify;
        goto abort;
    default:
        break;
    }

    err = errSecSuccess;
    /* Run the handshake until the read side has keys. */
    while (ctx->readCipher_ready == 0) {
        if ((err = SSLHandshakeProceed(ctx)) != 0) {
            goto exit;
        }
    }

    /* Flush pending output; a blocked write queue must not prevent reading. */
    if ((err = SSLServiceWriteQueue(ctx)) != 0) {
        if (err != errSSLWouldBlock) {
            goto exit;
        }
        err = errSecSuccess;
    }

    /* Resume after any bytes delivered by an earlier pass (non-zero only on
     * a retry); restarting at the buffer start would overwrite them. */
    remaining = bufSize - *processed;
    charPtr = (uint8_t *)data + *processed;

    /* First hand out bytes left over from a previously received record. */
    if (ctx->receivedDataBuffer.data)
    {
        count = ctx->receivedDataBuffer.length - ctx->receivedDataPos;
        if (count > remaining)
            count = remaining;
        memcpy(charPtr, ctx->receivedDataBuffer.data + ctx->receivedDataPos, count);
        remaining -= count;
        charPtr += count;
        *processed += count;
        ctx->receivedDataPos += count;
    }

    assert(ctx->receivedDataPos <= ctx->receivedDataBuffer.length);
    assert(*processed + remaining == bufSize);
    assert(charPtr == ((uint8_t *)data) + *processed);

    /* Release the leftover buffer once it has been fully consumed. */
    if (ctx->receivedDataBuffer.data != 0 &&
        ctx->receivedDataPos >= ctx->receivedDataBuffer.length)
    {
        SSLFreeBuffer(&ctx->receivedDataBuffer);
        ctx->receivedDataBuffer.data = 0;
        ctx->receivedDataPos = 0;
    }

    if (remaining > 0 && ctx->state != SSL_HdskStateGracefulClose)
    {
        assert(ctx->receivedDataBuffer.data == 0);
        if ((err = SSLReadRecord(&rec, ctx)) != 0) {
            goto exit;
        }
        if (rec.contentType == SSL_RecordTypeAppData ||
            rec.contentType == SSL_RecordTypeV2_0)
        {
            if (rec.contents.length <= remaining)
            {
                /* Whole record fits: copy it and release it. */
                memcpy(charPtr, rec.contents.data, rec.contents.length);
                remaining -= rec.contents.length;
                charPtr += rec.contents.length;
                *processed += rec.contents.length;
                if ((err = SSLFreeRecord(rec, ctx))) {
                    goto exit;
                }
            }
            else
            {
                /* Partial copy; keep the record as the leftover buffer. */
                memcpy(charPtr, rec.contents.data, remaining);
                charPtr += remaining;
                *processed += remaining;
                ctx->receivedDataBuffer = rec.contents;
                ctx->receivedDataPos = remaining;
                remaining = 0;
            }
        }
        else {
            if ((err = SSLProcessProtocolMessage(&rec, ctx)) != 0) {
                /* Free the record on failure too, as SSLHandshakeProceed
                 * does; otherwise each failed or retried record leaks. */
                SSLFreeRecord(rec, ctx);
                goto exit;
            }
            if ((err = SSLFreeRecord(rec, ctx))) {
                goto exit;
            }
        }
    }

    err = errSecSuccess;

exit:
    if(((err == errSecSuccess) && (*processed == 0) && dataLength) || (err == errSSLUnexpectedRecord)) {
        sslLogNegotiateDebug("SSLRead recursion");
        goto readRetry;
    }
    /* Only non-recoverable errors tear the session down. */
    switch(err) {
    case errSecSuccess:
    case errSSLWouldBlock:
    case errSSLUnexpectedRecord:
    case errSSLServerAuthCompleted:
    case errSSLClientCertRequested:
    case errSSLClosedGraceful:
    case errSSLClosedNoNotify:
        break;
    default:
        sslErrorLog("SSLRead: going to state errorClose due to err %d\n",
            (int)err);
        SSLChangeHdskState(ctx, SSL_HdskStateErrorClose);
        break;
    }
abort:
    sslIoTrace("SSLRead ", dataLength, *processed, err);
    return err;
}
#if SSL_DEBUG
#include "sslCrypto.h"
#endif
/*
 * SSLHandshake: run the handshake until both read and write ciphers are ready.
 *
 * Refuses closed sessions and builds the cipher-suite list on first use. For
 * DTLS it retransmits when the timeout deadline has passed (the retransmit
 * result is not checked). errSSLUnexpectedRecord from a handshake step is
 * treated as "keep going"; any other error is returned immediately. Finally
 * the write queue is flushed and its status returned.
 */
OSStatus
SSLHandshake(SSLContext *ctx)
{
    OSStatus status;

    if (ctx == NULL)
        return errSecParam;

    switch (ctx->state) {
    case SSL_HdskStateGracefulClose:
        return errSSLClosedGraceful;
    case SSL_HdskStateErrorClose:
        return errSSLClosedAbort;
    default:
        break;
    }

#if SSL_ECDSA_HACK
    ctx->versionSsl2Enable = false;
#endif

    if (ctx->validCipherSuites == NULL) {
        status = sslBuildCipherSuiteArray(ctx);
        if (status)
            return status;
    }

    if (ctx->isDTLS && ctx->timeout_deadline < CFAbsoluteTimeGetCurrent())
        DTLSRetransmit(ctx);

    while (!(ctx->readCipher_ready && ctx->writeCipher_ready)) {
        status = SSLHandshakeProceed(ctx);
        if (status != 0 && status != errSSLUnexpectedRecord)
            return status;
    }

    status = SSLServiceWriteQueue(ctx);
    if (status != 0)
        return status;
    return errSecSuccess;
}
/*
 * Advance the handshake by one step.
 *
 * Pending one-shot signals (server auth done, cert requested, client auth
 * done) are reported first, each exactly once. Otherwise the connection is
 * initialised if still Uninit. A client in the ClientCert state is advanced
 * past ServerHelloDone. Queued output is flushed, then one record is read,
 * processed and freed. The record is freed on a processing error as well.
 */
static OSStatus
SSLHandshakeProceed(SSLContext *ctx)
{
    OSStatus status;
    SSLRecord record;

    if (ctx->signalServerAuth) {
        ctx->signalServerAuth = false;
        return errSSLServerAuthCompleted;
    }
    if (ctx->signalCertRequest) {
        ctx->signalCertRequest = false;
        return errSSLClientCertRequested;
    }
    if (ctx->signalClientAuth) {
        ctx->signalClientAuth = false;
        return errSSLClientAuthCompleted;
    }

    if (ctx->state == SSL_HdskStateUninit) {
        status = SSLInitConnection(ctx);
        if (status != 0)
            return status;
    }

    if (ctx->protocolSide == kSSLClientSide &&
        ctx->state == SSL_HdskStateClientCert) {
        status = SSLAdvanceHandshake(SSL_HdskServerHelloDone, ctx);
        if (status != 0)
            return status;
    }

    status = SSLServiceWriteQueue(ctx);
    if (status != 0)
        return status;

    assert(ctx->readCipher_ready == 0);

    status = SSLReadRecord(&record, ctx);
    if (status != 0)
        return status;

    status = SSLProcessProtocolMessage(&record, ctx);
    if (status != 0) {
        SSLFreeRecord(record, ctx);
        return status;
    }

    /* errSecSuccess is 0, so a clean free yields success. */
    return SSLFreeRecord(record, ctx);
}
/*
 * Move the context out of the Uninit state and, for a client, start the
 * handshake.
 *
 * If a peer ID is set, a cached session is looked up. It is kept only if its
 * protocol version lies within the configured min/max range. For DTLS the
 * comparisons are inverted, presumably because DTLS version numbers decrease
 * as the protocol version increases (confirm against the version constants).
 * Out-of-range sessions are discarded.
 */
static OSStatus
SSLInitConnection(SSLContext *ctx)
{
    OSStatus status;
    Boolean cachedV3OrTls1;

    if (ctx->protocolSide == kSSLClientSide) {
        SSLChangeHdskState(ctx, SSL_HdskStateClientUninit);
    } else {
        assert(ctx->protocolSide == kSSLServerSide);
        SSLChangeHdskState(ctx, SSL_HdskStateServerUninit);
    }

    if (ctx->peerID.data != 0) {
        SSLGetSessionData(&ctx->resumableSession, ctx);
    }

    cachedV3OrTls1 = ctx->isDTLS;
    if (ctx->resumableSession.data != 0) {
        SSLProtocolVersion savedVersion;
        Boolean versionInRange;

        status = SSLRetrieveSessionProtocolVersion(ctx->resumableSession,
                                                   &savedVersion, ctx);
        if (status != 0)
            return status;

        if (ctx->isDTLS) {
            versionInRange = savedVersion <= ctx->minProtocolVersion &&
                             savedVersion >= ctx->maxProtocolVersion;
        } else {
            versionInRange = savedVersion >= ctx->minProtocolVersion &&
                             savedVersion <= ctx->maxProtocolVersion;
        }

        if (versionInRange) {
            cachedV3OrTls1 = savedVersion != SSL_Version_2_0;
            sslLogResumSessDebug("===attempting to resume session");
        } else {
            sslLogResumSessDebug("===Resumable session protocol mismatch");
            SSLFreeBuffer(&ctx->resumableSession);
        }
    }

    /* Only an unstarted client sends the opening handshake message. */
    if (ctx->state != SSL_HdskStateClientUninit || ctx->writeCipher_ready != 0)
        return errSecSuccess;

    assert(ctx->negProtocolVersion == SSL_Version_Undetermined);
#if ENABLE_SSLV2
    if (!cachedV3OrTls1)
        return SSL2AdvanceHandshake(SSL2_MsgKickstart, ctx);
#endif
    return SSLAdvanceHandshake(SSL_HdskHelloRequest, ctx);
}
/*
 * Route a non-application-data record to the handler for its content type:
 * handshake (DTLS or TLS variant), alert, change-cipher-spec, and SSLv2
 * messages when ENABLE_SSLV2 is set. Any other content type yields
 * errSSLProtocol.
 */
static OSStatus
SSLProcessProtocolMessage(SSLRecord *rec, SSLContext *ctx)
{
    switch (rec->contentType) {
    case SSL_RecordTypeHandshake:
        sslLogRxProtocolDebug("Handshake");
        if (ctx->isDTLS) {
            return DTLSProcessHandshakeRecord(*rec, ctx);
        }
        return SSLProcessHandshakeRecord(*rec, ctx);

    case SSL_RecordTypeAlert:
        sslLogRxProtocolDebug("Alert");
        return SSLProcessAlert(*rec, ctx);

    case SSL_RecordTypeChangeCipher:
        sslLogRxProtocolDebug("ChangeCipher");
        return SSLProcessChangeCipherSpec(*rec, ctx);

#if ENABLE_SSLV2
    case SSL_RecordTypeV2_0:
        sslLogRxProtocolDebug("RecordTypeV2_0");
        return SSL2ProcessMessage(rec, ctx);
#endif

    default:
        sslLogRxProtocolDebug("Bad msg");
        return errSSLProtocol;
    }
}
/*
 * SSLClose: shut the session down gracefully.
 *
 * For negotiated versions >= SSL 3.0 a warning-level close_notify alert is
 * sent, and on success the write queue is flushed. The context always moves
 * to SSL_HdskStateGracefulClose. An errSecIO status is reported as success.
 */
OSStatus
SSLClose(SSLContext *ctx)
{
    OSStatus status;

    sslHdskStateDebug("SSLClose");
    if (ctx == NULL)
        return errSecParam;

    status = errSecSuccess;
    if (ctx->negProtocolVersion >= SSL_Version_3_0)
        status = SSLSendAlert(SSL_AlertLevelWarning, SSL_AlertCloseNotify, ctx);
    if (status == errSecSuccess)
        status = SSLServiceWriteQueue(ctx);

    SSLChangeHdskState(ctx, SSL_HdskStateGracefulClose);

    return (status == errSecIO) ? errSecSuccess : status;
}
/*
 * SSLGetBufferedReadSize: report how many decrypted application-data bytes
 * are buffered in ctx and can be returned by SSLRead.
 *
 * ctx     - session context; must not be NULL.
 * bufSize - out; number of buffered bytes (0 when nothing is buffered);
 *           must not be NULL.
 *
 * Returns errSecParam on a NULL argument, otherwise errSecSuccess.
 */
OSStatus
SSLGetBufferedReadSize(SSLContextRef ctx,
    size_t *bufSize)
{
    /* Validate the out-parameter as SSLRead/SSLWrite do for theirs. */
    if((ctx == NULL) || (bufSize == NULL)) {
        return errSecParam;
    }
    if(ctx->receivedDataBuffer.data == NULL) {
        *bufSize = 0;
    }
    else {
        assert(ctx->receivedDataBuffer.length >= ctx->receivedDataPos);
        *bufSize = ctx->receivedDataBuffer.length - ctx->receivedDataPos;
    }
    return errSecSuccess;
}