/* sslHandshakeHello.cpp */
#include "sslContext.h"
#include "sslHandshake.h"
#include "sslMemory.h"
#include "sslSession.h"
#include "sslUtils.h"
#include "sslDebug.h"
#include "appleCdsa.h"
#include "sslDigests.h"
#include "cipherSpecs.h"
#include <string.h>
/* Forward declaration; writes the 32-byte hello random (4-byte time + 28 random bytes) at p. */
static OSStatus SSLEncodeRandom(unsigned char *p, SSLContext *ctx);
/*
 * When nonzero, SSLEncodeServerHello never sends an empty session ID: if it
 * has no session ID of its own, it sends SSL_NULL_ID_LEN random bytes instead.
 * NOTE(review): the name suggests a workaround for a client (IE) that
 * mishandles an empty session ID -- confirm against project history.
 */
#define SSL_IE_NULL_RESUME_BUG 1
#if SSL_IE_NULL_RESUME_BUG
/* Number of random bytes sent as a substitute session ID. */
#define SSL_NULL_ID_LEN 32
#endif
/*
 * Build a ServerHello handshake message in serverHello.contents.
 *
 * Layout written (42 + sessionIDLen bytes):
 *   1  handshake type (SSL_HdskServerHello)
 *   3  body length (38 + sessionIDLen)
 *   2  negotiated protocol version
 *   32 server random (also saved in ctx->serverRandom)
 *   1  session ID length
 *   n  session ID
 *   2  selected cipher suite
 *   1  compression method (always 0)
 *
 * With SSL_IE_NULL_RESUME_BUG enabled, SSL_NULL_ID_LEN random bytes are sent
 * as the session ID whenever the context holds no usable session ID.
 *
 * Returns noErr on success. On failure after the buffer was allocated,
 * serverHello.contents is freed before returning the error.
 */
OSStatus
SSLEncodeServerHello(SSLRecord &serverHello, SSLContext *ctx)
{
    OSStatus    err;
    UInt8       *charPtr;
    int         sessionIDLen = 0;
    bool        haveStoredID;

    if (ctx->sessionID.data != 0)
        sessionIDLen = (UInt8)ctx->sessionID.length;

    /*
     * Decide once whether the stored ID is what gets sent. A non-NULL data
     * pointer with a zero (or truncated-to-zero) length must NOT be treated
     * as a stored ID: copying SSL_NULL_ID_LEN bytes from it would read past
     * the end of that buffer.
     */
    haveStoredID = (sessionIDLen > 0);

#if SSL_IE_NULL_RESUME_BUG
    if (!haveStoredID) {
        sessionIDLen = SSL_NULL_ID_LEN;
    }
#endif

    assert(ctx->negProtocolVersion != SSL_Version_Undetermined);
    assert(ctx->negProtocolVersion != SSL_Version_2_0);
    sslLogNegotiateDebug("===SSL3 server: sending version %d_%d",
        ctx->negProtocolVersion >> 8, ctx->negProtocolVersion & 0xff);
    sslLogNegotiateDebug("...sessionIDLen = %d", sessionIDLen);

    serverHello.protocolVersion = ctx->negProtocolVersion;
    serverHello.contentType = SSL_RecordTypeHandshake;
    if ((err = SSLAllocBuffer(serverHello.contents, 42 + sessionIDLen, ctx)) != 0)
        return err;

    charPtr = serverHello.contents.data;
    *charPtr++ = SSL_HdskServerHello;
    charPtr = SSLEncodeInt(charPtr, 38 + sessionIDLen, 3);
    charPtr = SSLEncodeInt(charPtr, serverHello.protocolVersion, 2);

    if ((err = SSLEncodeRandom(charPtr, ctx)) != 0) {
        /* Same cleanup the ClientHello encoder performs on this failure. */
        SSLFreeBuffer(serverHello.contents, ctx);
        return err;
    }
    memcpy(ctx->serverRandom, charPtr, SSL_CLIENT_SRVR_RAND_SIZE);
    charPtr += SSL_CLIENT_SRVR_RAND_SIZE;

    *(charPtr++) = (UInt8)sessionIDLen;
#if SSL_IE_NULL_RESUME_BUG
    if (haveStoredID) {
        memcpy(charPtr, ctx->sessionID.data, sessionIDLen);
    }
    else {
        /* No usable ID: send random bytes rather than an empty session ID. */
        SSLBuffer rb;
        rb.data = charPtr;
        rb.length = SSL_NULL_ID_LEN;
        if ((err = sslRand(ctx, &rb)) != 0) {
            SSLFreeBuffer(serverHello.contents, ctx);
            return err;
        }
    }
#else
    if (haveStoredID)
        memcpy(charPtr, ctx->sessionID.data, sessionIDLen);
#endif
    charPtr += sessionIDLen;

    charPtr = SSLEncodeInt(charPtr, ctx->selectedCipher, 2);
    *(charPtr++) = 0;       /* compression method: none */

    sslLogNegotiateDebug("ssl3: server specifying cipherSuite 0x%lx",
        (unsigned long)ctx->selectedCipher);

    assert(charPtr == serverHello.contents.data + serverHello.contents.length);
    return noErr;
}
/*
 * Parse a ServerHello body (client side only).
 *
 * Expected layout: 2-byte version, 32-byte server random, 1-byte session ID
 * length, session ID, 2-byte cipher suite, 1-byte compression method.
 * message.length must therefore be exactly 38 + session ID length, and in
 * the range 38..70.
 *
 * On success the context holds the negotiated version and matching callouts,
 * the server random, the selected cipher, and (when ctx->peerID is set and
 * the allocation succeeds) a copy of the server's session ID.
 *
 * Returns errSSLProtocol for a length mismatch, errSSLNegotiation for an
 * unsupported negotiated version, unimpErr for a nonzero compression method,
 * or whatever sslVerifyProtVersion / FindCipherSpec report.
 */
OSStatus
SSLProcessServerHello(SSLBuffer message, SSLContext *ctx)
{
    OSStatus            err;
    SSLProtocolVersion  peerVersion;
    SSLProtocolVersion  agreedVersion;
    unsigned int        idLen;
    UInt8               *cursor;

    assert(ctx->protocolSide == SSL_ClientSide);

    if (!(message.length >= 38 && message.length <= 70)) {
        sslErrorLog("SSLProcessServerHello: msg len error\n");
        return errSSLProtocol;
    }

    cursor = message.data;

    /* Protocol version */
    peerVersion = (SSLProtocolVersion)SSLDecodeInt(cursor, 2);
    cursor += 2;
    err = sslVerifyProtVersion(ctx, peerVersion, &agreedVersion);
    if (err != 0) {
        return err;
    }
    ctx->negProtocolVersion = agreedVersion;
    if (agreedVersion == SSL_Version_3_0) {
        ctx->sslTslCalls = &Ssl3Callouts;
    }
    else if (agreedVersion == TLS_Version_1_0) {
        ctx->sslTslCalls = &Tls1Callouts;
    }
    else {
        return errSSLNegotiation;
    }
    sslLogNegotiateDebug("===SSL3 client: negVersion is %d_%d",
        (agreedVersion >> 8) & 0xff, agreedVersion & 0xff);

    /* Server random */
    memcpy(ctx->serverRandom, cursor, 32);
    cursor += 32;

    /* Session ID */
    idLen = *cursor++;
    if (message.length != 38 + idLen) {
        sslErrorLog("SSLProcessServerHello: msg len error 2\n");
        return errSSLProtocol;
    }
    if (idLen > 0 && ctx->peerID.data != 0) {
        /* An allocation failure is not reported; the ID is simply not kept. */
        if (SSLAllocBuffer(ctx->sessionID, idLen, ctx) == 0) {
            memcpy(ctx->sessionID.data, cursor, idLen);
        }
    }
    cursor += idLen;

    /* Cipher suite */
    ctx->selectedCipher = (UInt16)SSLDecodeInt(cursor, 2);
    sslLogNegotiateDebug("===ssl3: server requests cipherKind %d",
        (unsigned)ctx->selectedCipher);
    cursor += 2;
    err = FindCipherSpec(ctx);
    if (err != 0) {
        return err;
    }

    /* Compression method: only 0 is accepted */
    if (*cursor++ != 0) {
        return unimpErr;
    }

    assert(cursor == message.data + message.length);
    return noErr;
}
/*
 * Build a ClientHello handshake message in clientHello.contents and reset
 * the handshake message hashes (client side only).
 *
 * Layout written (4-byte handshake header + `length` body bytes):
 *   1  handshake type (SSL_HdskClientHello)
 *   3  body length
 *   2  highest protocol version this client supports
 *   32 client random (also saved in ctx->clientRandom)
 *   1  session ID length
 *   n  session ID retrieved from ctx->resumableSession, if any
 *   2  cipher suite list length in bytes
 *   2* one entry per ctx->validCipherSpecs element
 *   1  number of compression methods (1)
 *   1  compression method 0
 *
 * The session ID length field is a single byte and the protocol caps a
 * session ID at 32 bytes; a retrieved ID longer than that cannot be encoded
 * and is dropped, so the hello goes out with an empty session ID.
 *
 * Returns noErr on success. On any failure the temporary session ID buffer
 * is released, and clientHello.contents is freed if it had been allocated.
 */
OSStatus
SSLEncodeClientHello(SSLRecord &clientHello, SSLContext *ctx)
{
    enum { kMaxSessionIDLen = 32 };

    unsigned        length, i;
    OSStatus        err;
    unsigned char   *p;
    SSLBuffer       sessionIdentifier;
    UInt16          sessionIDLen = 0;
    bool            contentsAllocated = false;

    assert(ctx->protocolSide == SSL_ClientSide);

    /* Known-empty state so the cleanup path can tell whether to free it. */
    sessionIdentifier.data = NULL;
    sessionIdentifier.length = 0;

    if (ctx->resumableSession.data != 0) {
        if ((err = SSLRetrieveSessionID(ctx->resumableSession,
                &sessionIdentifier, ctx)) != 0) {
            return err;
        }
        if (sessionIdentifier.length > kMaxSessionIDLen) {
            /* Would be truncated by the 1-byte length field; don't offer it. */
            sslErrorLog("SSLEncodeClientHello: cached session ID too long; not resuming\n");
            SSLFreeBuffer(sessionIdentifier, ctx);
            sessionIdentifier.data = NULL;
            sessionIdentifier.length = 0;
        }
        else {
            sessionIDLen = (UInt16)sessionIdentifier.length;
        }
    }

    length = 39 + 2*(ctx->numValidCipherSpecs) + sessionIDLen;

    err = sslGetMaxProtVersion(ctx, &clientHello.protocolVersion);
    if (err) {
        goto fail;
    }

    clientHello.contentType = SSL_RecordTypeHandshake;
    if ((err = SSLAllocBuffer(clientHello.contents, length + 4, ctx)) != 0) {
        goto fail;
    }
    contentsAllocated = true;

    p = clientHello.contents.data;
    *p++ = SSL_HdskClientHello;
    p = SSLEncodeInt(p, length, 3);
    p = SSLEncodeInt(p, clientHello.protocolVersion, 2);
    sslLogNegotiateDebug("===SSL3 client: proclaiming max protocol "
        "%d_%d capable ONLY",
        clientHello.protocolVersion >> 8, clientHello.protocolVersion & 0xff);

    if ((err = SSLEncodeRandom(p, ctx)) != 0) {
        goto fail;
    }
    memcpy(ctx->clientRandom, p, SSL_CLIENT_SRVR_RAND_SIZE);
    p += 32;

    *p++ = (UInt8)sessionIDLen;
    if (sessionIDLen > 0) {
        memcpy(p, sessionIdentifier.data, sessionIDLen);
    }
    p += sessionIDLen;

    /* The retrieved ID is no longer needed, whatever its length was. */
    if (sessionIdentifier.data != NULL) {
        err = SSLFreeBuffer(sessionIdentifier, ctx);
        sessionIdentifier.data = NULL;      /* never free it twice */
        sessionIdentifier.length = 0;
        if (err != 0) {
            goto fail;
        }
    }

    p = SSLEncodeInt(p, 2*(ctx->numValidCipherSpecs), 2);
    for (i = 0; i < ctx->numValidCipherSpecs; ++i) {
        p = SSLEncodeInt(p, ctx->validCipherSpecs[i].cipherSpec, 2);
    }

    *p++ = 1;       /* one compression method ... */
    *p++ = 0;       /* ... which is "none" */

    assert(p == clientHello.contents.data + clientHello.contents.length);

    if ((err = SSLInitMessageHashes(ctx)) != 0) {
        goto fail;
    }
    return noErr;

fail:
    if (contentsAllocated) {
        SSLFreeBuffer(clientHello.contents, ctx);
    }
    if (sessionIdentifier.data != NULL) {
        SSLFreeBuffer(sessionIdentifier, ctx);
    }
    return err;
}
/*
 * Parse a ClientHello body and pick the protocol version and cipher suite.
 *
 * Expected layout: 2-byte requested version, 32-byte client random, 1-byte
 * session ID length, session ID, 2-byte cipher list length, cipher list,
 * 1-byte compression method count, compression methods.
 *
 * The first cipher in the client's list that also appears in
 * ctx->validCipherSpecs is selected. The client's session ID is copied into
 * ctx->sessionID when ctx->peerID is set and the allocation succeeds. The
 * handshake message hashes are reset before returning.
 *
 * Returns errSSLProtocol for inconsistent lengths, errSSLNegotiation when no
 * version or cipher can be agreed on, or whatever sslVerifyProtVersion,
 * FindCipherSpec or SSLInitMessageHashes report.
 */
OSStatus
SSLProcessClientHello(SSLBuffer message, SSLContext *ctx)
{
    OSStatus            err;
    SSLProtocolVersion  negVersion;
    UInt16              listBytes, remaining, offered, chosen;
    UInt8               idLen, numCompMethods;
    UInt8               *cursor;
    unsigned            idx;
    bool                badList;

    if (message.length < 41) {
        sslErrorLog("SSLProcessClientHello: msg len error 1\n");
        return errSSLProtocol;
    }
    cursor = message.data;

    /* Requested protocol version */
    ctx->clientReqProtocol = (SSLProtocolVersion)SSLDecodeInt(cursor, 2);
    cursor += 2;
    err = sslVerifyProtVersion(ctx, ctx->clientReqProtocol, &negVersion);
    if (err != 0) {
        return err;
    }
    if (negVersion == SSL_Version_3_0) {
        ctx->sslTslCalls = &Ssl3Callouts;
    }
    else if (negVersion == TLS_Version_1_0) {
        ctx->sslTslCalls = &Tls1Callouts;
    }
    else {
        return errSSLNegotiation;
    }
    ctx->negProtocolVersion = negVersion;
    sslLogNegotiateDebug("===SSL3 server: negVersion is %d_%d",
        negVersion >> 8, negVersion & 0xff);

    /* Client random */
    memcpy(ctx->clientRandom, cursor, SSL_CLIENT_SRVR_RAND_SIZE);
    cursor += 32;

    /* Session ID */
    idLen = *(cursor++);
    if (message.length < (unsigned)(41 + idLen)) {
        sslErrorLog("SSLProcessClientHello: msg len error 2\n");
        return errSSLProtocol;
    }
    if (idLen > 0 && ctx->peerID.data != 0) {
        /* An allocation failure is not reported; the ID is simply not kept. */
        if (SSLAllocBuffer(ctx->sessionID, idLen, ctx) == 0) {
            memcpy(ctx->sessionID.data, cursor, idLen);
        }
    }
    cursor += idLen;

    /* Cipher suite list: even, non-empty, and fully inside the message */
    listBytes = (UInt16)SSLDecodeInt(cursor, 2);
    cursor += 2;
    badList = ((listBytes & 1) != 0) ||
              (listBytes < 2) ||
              (message.length < (unsigned)(39 + idLen + listBytes));
    if (badList) {
        sslErrorLog("SSLProcessClientHello: msg len error 3\n");
        return errSSLProtocol;
    }

    /* Walk the client's list in order; stop at the first suite we support. */
    chosen = 0xFFFF;
    remaining = listBytes / 2;
    while (remaining > 0 && chosen == 0xFFFF) {
        remaining--;
        offered = (UInt16)SSLDecodeInt(cursor, 2);
        cursor += 2;
        for (idx = 0; idx < ctx->numValidCipherSpecs; idx++) {
            if (ctx->validCipherSpecs[idx].cipherSpec == offered) {
                chosen = offered;
                break;
            }
        }
    }
    if (chosen == 0xFFFF) {
        return errSSLNegotiation;
    }

    /* Skip whatever part of the list was not examined. */
    cursor += 2 * remaining;

    ctx->selectedCipher = chosen;
    err = FindCipherSpec(ctx);
    if (err != 0) {
        return err;
    }
    sslLogNegotiateDebug("ssl3 server: selecting cipherKind 0x%x", (unsigned)ctx->selectedCipher);

    /* Compression methods: at least one, and all inside the message */
    numCompMethods = *(cursor++);
    if ((numCompMethods < 1) ||
        (message.length <
            (unsigned)(38 + idLen + listBytes + numCompMethods))) {
        sslErrorLog("SSLProcessClientHello: msg len error 4\n");
        return errSSLProtocol;
    }

    err = SSLInitMessageHashes(ctx);
    if (err != 0) {
        return err;
    }
    return noErr;
}
/*
 * Write the 32-byte hello random at p: a 4-byte timestamp obtained from
 * sslTime, followed by 28 bytes from sslRand.
 *
 * Returns noErr on success, or the error from sslTime / sslRand.
 */
static OSStatus
SSLEncodeRandom(unsigned char *p, SSLContext *ctx)
{
    UInt32      now;
    SSLBuffer   tail;
    OSStatus    err;

    err = sslTime(&now);
    if (err != 0) {
        return err;
    }
    SSLEncodeInt(p, now, 4);

    /* Remaining 28 bytes are random. */
    tail.data = p + 4;
    tail.length = 28;
    err = sslRand(ctx, &tail);
    if (err != 0) {
        return err;
    }
    return noErr;
}
/*
 * Reset the running handshake hashes: close the SHA1 and MD5 states held in
 * the context, then ready both again (in that order).
 *
 * Stops at the first failing step and returns its error; noErr otherwise.
 */
OSStatus
SSLInitMessageHashes(SSLContext *ctx)
{
    OSStatus err;

    err = CloseHash(SSLHashSHA1, ctx->shaState, ctx);
    if (err != 0) {
        return err;
    }
    err = CloseHash(SSLHashMD5, ctx->md5State, ctx);
    if (err != 0) {
        return err;
    }
    err = ReadyHash(SSLHashSHA1, ctx->shaState, ctx);
    if (err != 0) {
        return err;
    }
    err = ReadyHash(SSLHashMD5, ctx->md5State, ctx);
    if (err != 0) {
        return err;
    }
    return noErr;
}