#ifndef _SSL_H_
#include "ssl.h"
#endif
#ifndef _SSLCTX_H_
#include "sslctx.h"
#endif
#ifndef _SSLSESS_H_
#include "sslsess.h"
#endif
#ifndef _SSLALLOC_H_
#include "sslalloc.h"
#endif
#ifndef _SSLUTIL_H_
#include "sslutil.h"
#endif
#ifndef _SSL_DEBUG_H_
#include "sslDebug.h"
#endif
#ifndef _CIPHER_SPECS_H_
#include "cipherSpecs.h"
#endif
#include "appleSession.h"
#include <assert.h>
#include <string.h>
#include <stddef.h>
/*
 * Flat record describing one resumable session. SSLAddSessionData() builds
 * it in a single allocated buffer and the other functions in this file read
 * it back by casting the buffer's data pointer to this type.
 *
 * Everything before 'certs' is a fixed-size header; its size is taken as
 * offsetof(ResumableSession, certs).
 */
typedef struct
/* Number of valid bytes in sessionID (copied from ctx->sessionID.length). */
{ int sessionIDLen;
/* Session identifier bytes; only the first sessionIDLen are meaningful. */
UInt8 sessionID[32];
/* Copied from ctx->negProtocolVersion when the record is built. */
SSLProtocolVersion protocolVersion;
/* Copied from ctx->selectedCipher when the record is built. */
UInt16 cipherSuite;
/* Always written as 0 by SSLAddSessionData(). */
UInt16 padding;
/* Copied from ctx->masterSecret (48 bytes). */
UInt8 masterSecret[48];
/* Number of certificate entries that follow in 'certs'. */
int certCount;
/*
 * Start of the variable-length certificate area. The [1] is only a
 * placeholder: the real data is certCount entries, each a 4-byte length
 * (written with SSLEncodeInt) followed by that many certificate bytes.
 */
UInt8 certs[1];
} ResumableSession;
/*
 * Build a ResumableSession record from the current state of ctx and pass it
 * to sslAddSession(), keyed by ctx->peerID.
 *
 * Record layout: the fixed header (everything before 'certs'), then for each
 * certificate in the ctx->peerCert list a 4-byte length written with
 * SSLEncodeInt() followed by the certificate's derCert bytes.
 *
 * The record is assembled in a temporary buffer that is released before
 * returning, whatever sslAddSession() reports.
 *
 * Returns:
 *   SSLSessionNotFoundErr - ctx has no peer ID, or ctx->sessionID is longer
 *                           than the fixed sessionID field of the record, so
 *                           no valid record can be built;
 *   otherwise the result of SSLAllocBuffer() on allocation failure, or the
 *   result of sslAddSession().
 */
SSLErr
SSLAddSessionData(const SSLContext *ctx)
{
    SSLErr              err;
    uint32              recordLen;
    SSLBuffer           record;
    ResumableSession    *session;
    int                 certCount;
    SSLCertificate      *cert;
    uint8               *certDest;

    if (ctx->peerID.data == 0)
        return SSLSessionNotFoundErr;

    /*
     * The session ID is copied into a fixed-size field below. Refuse to
     * build a record for an ID that does not fit instead of writing past
     * the field and storing a length that readers would then trust.
     * (sizeof does not evaluate its operand, so 'session' is not read here.)
     */
    if (ctx->sessionID.length > sizeof(session->sessionID))
        return SSLSessionNotFoundErr;

    /* Total size: fixed header plus (4-byte length + data) per certificate. */
    recordLen = offsetof(ResumableSession, certs);
    certCount = 0;
    for (cert = ctx->peerCert; cert != 0; cert = cert->next) {
        ++certCount;
        recordLen += 4 + cert->derCert.length;
    }

    if ((err = SSLAllocBuffer(&record, recordLen, &ctx->sysCtx)) != 0)
        return err;

    session = (ResumableSession*)record.data;

    session->sessionIDLen = ctx->sessionID.length;
    /* An empty ID may come with a null data pointer; do not hand that to memcpy. */
    if (ctx->sessionID.length > 0)
        memcpy(session->sessionID, ctx->sessionID.data, ctx->sessionID.length);
    session->protocolVersion = ctx->negProtocolVersion;
    session->cipherSuite = ctx->selectedCipher;
    session->padding = 0;
    memcpy(session->masterSecret, ctx->masterSecret, sizeof(session->masterSecret));
    session->certCount = certCount;

    /* Append the certificates in list order. */
    certDest = session->certs;
    for (cert = ctx->peerCert; cert != 0; cert = cert->next) {
        certDest = SSLEncodeInt(certDest, cert->derCert.length, 4);
        memcpy(certDest, cert->derCert.data, cert->derCert.length);
        certDest += cert->derCert.length;
    }

    err = sslAddSession(ctx->peerID, record);
    SSLFreeBuffer(&record, &ctx->sysCtx);

    return err;
}
/*
 * Look up the cached session record for ctx->peerID via sslGetSession().
 *
 * sessionData->data is cleared before the lookup; if it is still null
 * afterwards the session is reported as not found, regardless of what
 * sslGetSession() returned. Otherwise the lookup's own result is returned.
 *
 * Returns SSLSessionNotFoundErr (through ERR) when ctx has no peer ID or the
 * lookup produced no data.
 */
SSLErr
SSLGetSessionData(SSLBuffer *sessionData, const SSLContext *ctx)
{
    SSLErr lookupStatus;

    if (!ctx->peerID.data) {
        return ERR(SSLSessionNotFoundErr);
    }

    /* Start from a known-empty result so a miss can be detected below. */
    sessionData->data = 0;
    lookupStatus = sslGetSession(ctx->peerID, sessionData);

    if (sessionData->data != 0) {
        return lookupStatus;
    }
    return ERR(SSLSessionNotFoundErr);
}
/*
 * Ask sslDeleteSession() to drop the cached session keyed by ctx->peerID.
 *
 * Returns SSLSessionNotFoundErr when ctx has no peer ID; otherwise the
 * result of sslDeleteSession().
 */
SSLErr
SSLDeleteSessionData(const SSLContext *ctx)
{
    if (!ctx->peerID.data) {
        return SSLSessionNotFoundErr;
    }
    return sslDeleteSession(ctx->peerID);
}
/*
 * Copy the session ID out of a stored session record.
 *
 * sessionData.data is interpreted as a ResumableSession. 'identifier' is
 * allocated with SSLAllocBuffer() to sessionIDLen bytes and filled from the
 * record's sessionID field.
 *
 * Returns the SSLAllocBuffer() error on allocation failure, else SSLNoErr.
 *
 * NOTE(review): sessionIDLen is read from the record and used as-is; there
 * is no check here that it lies within the 32-byte sessionID field.
 */
SSLErr
SSLRetrieveSessionID(
    const SSLBuffer sessionData,
    SSLBuffer *identifier,
    const SSLContext *ctx)
{
    const ResumableSession *record = (const ResumableSession *)sessionData.data;
    SSLErr allocStatus = SSLAllocBuffer(identifier, record->sessionIDLen, &ctx->sysCtx);

    if (allocStatus != 0) {
        return allocStatus;
    }

    memcpy(identifier->data, record->sessionID, record->sessionIDLen);
    return SSLNoErr;
}
/*
 * Read the protocol version stored in a session record.
 *
 * sessionData.data is interpreted as a ResumableSession and its
 * protocolVersion field is written to *version. ctx is not used.
 *
 * Always returns SSLNoErr.
 */
SSLErr
SSLRetrieveSessionProtocolVersion(
    const SSLBuffer sessionData,
    SSLProtocolVersion *version,
    const SSLContext *ctx)
{
    const ResumableSession *record = (const ResumableSession *)sessionData.data;

    *version = record->protocolVersion;
    return SSLNoErr;
}
/*
 * When nonzero, a non-SSL2 resume whose selected cipher differs from the one
 * stored in the session record is only logged; when zero it is rejected with
 * SSLProtocolErr.
 */
#define ALLOW_CIPHERSPEC_CHANGE 1

/*
 * Install the state held in a stored session record into ctx: reconcile the
 * cipher suite, look up the cipher spec via FindCipherSpec(), copy the
 * master secret, and rebuild the peer certificate list in ctx->peerCert.
 *
 * sessionData.data is interpreted as a ResumableSession followed by
 * certCount entries of (4-byte length, certificate bytes).
 *
 * The record is checked against sessionData.length before ctx is touched,
 * so a truncated or inconsistent record is rejected without partially
 * installing it.
 * NOTE(review): this relies on sessionData.length being the real size of
 * the record as handed back by the session cache - confirm against
 * sslGetSession().
 *
 * Returns:
 *   SSLProtocolErr - record is missing, too short, has a negative
 *                    certificate count or a certificate length that runs
 *                    past the end of the record; or the cipher suite changed
 *                    where that is not permitted;
 *   SSLMemoryErr   - sslMalloc() failed for a certificate node;
 *   error from FindCipherSpec() or SSLAllocBuffer();
 *   SSLNoErr on success.
 *
 * If an allocation fails while rebuilding the certificate list, the
 * certificates installed so far remain linked from ctx->peerCert.
 */
SSLErr
SSLInstallSessionFromData(const SSLBuffer sessionData, SSLContext *ctx)
{
    SSLErr              err;
    ResumableSession    *session;
    uint8               *storedCertProgress;
    SSLCertificate      *cert, *lastCert;
    int                 certCount;
    uint32              certLen;
    uint32              bytesLeft;

    /* The fixed header must be fully present before any field is read. */
    if (sessionData.data == 0 ||
        sessionData.length < offsetof(ResumableSession, certs)) {
        return SSLProtocolErr;
    }
    session = (ResumableSession*)sessionData.data;

    /*
     * Validation pass over the certificate area: every 4-byte length and the
     * data it announces must lie inside the record. Nothing in ctx is
     * modified until this has succeeded.
     */
    if (session->certCount < 0) {
        return SSLProtocolErr;
    }
    storedCertProgress = session->certs;
    bytesLeft = sessionData.length - offsetof(ResumableSession, certs);
    for (certCount = session->certCount; certCount > 0; certCount--) {
        if (bytesLeft < 4) {
            return SSLProtocolErr;
        }
        certLen = SSLDecodeInt(storedCertProgress, 4);
        bytesLeft -= 4;
        if (certLen > bytesLeft) {
            return SSLProtocolErr;
        }
        storedCertProgress += 4 + certLen;
        bytesLeft -= certLen;
    }

    CASSERT(ctx->negProtocolVersion == session->protocolVersion);

    if(ctx->negProtocolVersion == SSL_Version_2_0) {
        if(ctx->protocolSide == SSL_ClientSide) {
            /* Client side: no cipher selected yet (asserted); adopt the stored one. */
            assert(ctx->selectedCipher == 0);
            ctx->selectedCipher = session->cipherSuite;
        }
        else {
            /* Other side: the already-selected cipher must match the stored one. */
            if(ctx->selectedCipher != session->cipherSuite) {
                errorLog2("+++SSL2: CipherSpec change from %d to %d on session "
                    "resume\n",
                    session->cipherSuite, ctx->selectedCipher);
                return SSLProtocolErr;
            }
        }
    }
    else {
        /* Any other version: a cipher has already been selected (asserted). */
        assert(ctx->selectedCipher != 0);
        if(ctx->selectedCipher != session->cipherSuite) {
            #if ALLOW_CIPHERSPEC_CHANGE
            dprintf2("+++WARNING: CipherSpec change from %d to %d on session resume\n",
                session->cipherSuite, ctx->selectedCipher);
            #else
            errorLog2("+++SSL: CipherSpec change from %d to %d on session resume\n",
                session->cipherSuite, ctx->selectedCipher);
            return SSLProtocolErr;
            #endif
        }
    }

    if ((err = FindCipherSpec(ctx)) != 0) {
        return err;
    }
    memcpy(ctx->masterSecret, session->masterSecret, 48);

    /* Rebuild the certificate list in stored order; lengths were checked above. */
    lastCert = 0;
    storedCertProgress = session->certs;
    certCount = session->certCount;

    while (certCount--)
    {
        cert = (SSLCertificate *)sslMalloc(sizeof(SSLCertificate));
        if(cert == NULL) {
            return SSLMemoryErr;
        }
        cert->next = 0;
        certLen = SSLDecodeInt(storedCertProgress, 4);
        storedCertProgress += 4;
        if ((err = SSLAllocBuffer(&cert->derCert, certLen, &ctx->sysCtx)) != 0)
        {
            sslFree(cert);
            return err;
        }
        memcpy(cert->derCert.data, storedCertProgress, certLen);
        storedCertProgress += certLen;

        /* First node becomes the list head; later nodes are appended. */
        if (lastCert == 0)
            ctx->peerCert = cert;
        else
            lastCert->next = cert;
        lastCert = cert;
    }

    return SSLNoErr;
}