#include "tls_handshake_priv.h"
#include "sslSession.h"
#include "sslMemory.h"
#include "sslUtils.h"
#include "sslDebug.h"
#include "sslCipherSpecs.h"
#include <assert.h>
#include <string.h>
#include <stddef.h>
/*
 * In-memory image of a cached session. SSLAddSessionData() builds one of
 * these and passes it to the save_session_data callback; the retrieve,
 * validate and install functions in this file cast the cached bytes back
 * to this type, so writer and readers must agree on this exact layout.
 *
 * The fixed header (everything before `data`) is followed by:
 *   - ticketLen bytes of session ticket, then
 *   - certCount entries, each a 4-byte length (written with SSLEncodeInt)
 *     followed by that many bytes of DER certificate.
 * Buffer sizes are computed with offsetof(ResumableSession, data), so the
 * one-element `data` array only marks where the trailer starts.
 */
typedef struct
{
size_t sessionIDLen;                        /* number of valid bytes in sessionID */
uint8_t sessionID[32];                      /* session ID, or SESSION_TICKET_ID for ticket-only sessions */
tls_protocol_version negProtocolVersion;    /* protocol version negotiated for this session */
tls_protocol_version reqProtocolVersion;    /* ctx->clientReqProtocol when the session was saved */
uint16_t cipherSuite;                       /* ctx->selectedCipher when the session was saved */
uint16_t padding;                           /* written as 0 */
uint8_t masterSecret[48];
size_t ticketLen;                           /* ticket bytes at the start of data[] */
size_t certCount;                           /* length-prefixed peer certificates after the ticket */
uint8_t data[1];                            /* start of the variable-length trailer */
} ResumableSession;
/* Placeholder stored in sessionID when the session has only a ticket. */
#define SESSION_TICKET_ID "SESSION-TICKET"
int
SSLAddSessionData(const tls_handshake_t ctx)
{ int err;
size_t sessionDataLen;
tls_buffer sessionData;
ResumableSession *session;
size_t certCount;
SSLCertificate *cert;
uint8_t *certDest;
if (!ctx->sessionID.data && !ctx->sessionTicket.data)
return errSSLSessionNotFound;
sessionDataLen = offsetof(ResumableSession, data) + ctx->sessionTicket.length;
cert = ctx->peerCert;
certCount = 0;
while (cert)
{ ++certCount;
sessionDataLen += 4 + cert->derCert.length;
cert = cert->next;
}
if ((err = SSLAllocBuffer(&sessionData, sessionDataLen)))
return err;
session = (ResumableSession*)sessionData.data;
if(ctx->sessionID.data==NULL) {
session->sessionIDLen = strlen(SESSION_TICKET_ID);
memcpy(session->sessionID, SESSION_TICKET_ID, session->sessionIDLen);
} else {
session->sessionIDLen = ctx->sessionID.length;
memcpy(session->sessionID, ctx->sessionID.data, session->sessionIDLen);
}
session->negProtocolVersion = ctx->negProtocolVersion;
session->reqProtocolVersion = ctx->clientReqProtocol;
session->cipherSuite = ctx->selectedCipher;
memcpy(session->masterSecret, ctx->masterSecret, 48);
session->certCount = certCount;
session->padding = 0;
session->ticketLen = ctx->sessionTicket.length;
memcpy(session->data, ctx->sessionTicket.data,ctx->sessionTicket.length);
certDest = session->data + session->ticketLen;
cert = ctx->peerCert;
while (cert)
{ certDest = SSLEncodeInt(certDest, cert->derCert.length, 4);
memcpy(certDest, cert->derCert.data, cert->derCert.length);
certDest += cert->derCert.length;
cert = cert->next;
}
if(ctx->isServer)
err = ctx->callbacks->save_session_data(ctx->callback_ctx, ctx->sessionID, sessionData);
else
err = ctx->callbacks->save_session_data(ctx->callback_ctx, ctx->peerID, sessionData);
SSLFreeBuffer(&sessionData);
return err;
}
/*
 * Ask the host's session cache to drop the entry keyed by ctx->sessionID.
 *
 * Returns errSSLSessionNotFound when no session ID is set; otherwise the
 * delete_session_data callback's status is passed through unchanged.
 *
 * NOTE(review): SSLAddSessionData keys client-side entries by ctx->peerID,
 * but this always passes ctx->sessionID. Confirm the client's delete
 * callback accounts for that.
 */
int
SSLDeleteSessionData(const tls_handshake_t ctx)
{
    if (ctx->sessionID.data == NULL)
        return errSSLSessionNotFound;

    return ctx->callbacks->delete_session_data(ctx->callback_ctx, ctx->sessionID);
}
/*
 * Point `ticket` at the session ticket stored in a ResumableSession image.
 * No copy is made: ticket->data aliases sessionData.data and is only valid
 * as long as that buffer is.
 *
 * Returns errSSLSessionNotFound, leaving *ticket untouched, when the buffer
 * is too short for the fixed header plus the ticket length it claims.
 */
int
SSLRetrieveSessionTicket(
    const tls_buffer sessionData,
    tls_buffer *ticket)
{
    const size_t headerLen = offsetof(ResumableSession, data);
    ResumableSession *session;

    if (sessionData.data == NULL || sessionData.length < headerLen)
        return errSSLSessionNotFound;

    session = (ResumableSession *)sessionData.data;
    /* Written as a subtraction so a huge ticketLen cannot wrap the comparison. */
    if (session->ticketLen > sessionData.length - headerLen)
        return errSSLSessionNotFound;

    ticket->data = session->data;
    ticket->length = session->ticketLen;
    return errSSLSuccess;
}
/*
 * Point `identifier` at the session ID stored in a ResumableSession image.
 * For ticket-only sessions this is the SESSION_TICKET_ID marker.
 * No copy is made: identifier->data aliases sessionData.data.
 *
 * Returns errSSLSessionNotFound, leaving *identifier untouched, in either
 * of these cases:
 *   - the buffer is shorter than the fixed header;
 *   - the stored ID length exceeds the 32-byte sessionID slot.
 */
int
SSLRetrieveSessionID(
    const tls_buffer sessionData,
    tls_buffer *identifier)
{
    ResumableSession *session;

    if (sessionData.data == NULL ||
        sessionData.length < offsetof(ResumableSession, data))
        return errSSLSessionNotFound;

    session = (ResumableSession *)sessionData.data;
    if (session->sessionIDLen > sizeof(session->sessionID))
        return errSSLSessionNotFound;

    identifier->data = session->sessionID;
    identifier->length = session->sessionIDLen;
    return errSSLSuccess;
}
/*
 * Server-side resumption check. Returns 0 when the cached session in
 * sessionData may be resumed for the current handshake, otherwise
 * errSSLSessionNotFound.
 *
 * The cached entry must satisfy all of the following:
 *   - it has the same session ID the client proposed;
 *   - it has the protocol version already negotiated on ctx;
 *   - its cipher suite is in both the locally enabled list and the
 *     client-requested list.
 * On success the cached cipher suite becomes ctx->selectedCipher and
 * InitCipherSpecParams(ctx) is called.
 *
 * NOTE(review): sessionData.length is not checked against the
 * ResumableSession header size; this trusts the cache to return an image
 * written by SSLAddSessionData.
 */
int SSLServerValidateSessionData(const tls_buffer sessionData, tls_handshake_t ctx)
{
ResumableSession *session = (ResumableSession *)sessionData.data;
/* Each require() guard below jumps to `out` when its condition is false. */
require(session->sessionIDLen == ctx->proposedSessionID.length, out);
/* Safe length for memcmp: equal to sessionIDLen per the check above. */
require(memcmp(session->sessionID, ctx->proposedSessionID.data, ctx->proposedSessionID.length) == 0, out);
require(session->negProtocolVersion == ctx->negProtocolVersion, out);
require(cipherSuiteInSet(session->cipherSuite, ctx->enabledCipherSuites, ctx->numEnabledCipherSuites), out);
require(cipherSuiteInSet(session->cipherSuite, ctx->requestedCipherSuites, ctx->numRequestedCipherSuites), out);
/* Accepted: adopt the cached suite and derive cipher spec parameters from it. */
ctx->selectedCipher = session->cipherSuite;
InitCipherSpecParams(ctx);
return 0;
out:
return errSSLSessionNotFound;
}
/*
 * Client-side compatibility check of a cached session against ctx's
 * current configuration. Returns 0 when the entry is usable, otherwise
 * errSSLSessionNotFound.
 *
 * Requirements:
 *   - ctx->maxProtocolVersion is no higher than the version the client
 *     requested when the session was saved (reqProtocolVersion);
 *   - the session's negotiated version lies within
 *     [minProtocolVersion, maxProtocolVersion];
 *   - its cipher suite is still enabled.
 *
 * NOTE(review): versions are compared numerically with <= / >=. Confirm
 * that this ordering is valid for every tls_protocol_version this context
 * can use (e.g. DTLS version codes, if they do not increase numerically).
 */
int SSLClientValidateSessionDataBefore(const tls_buffer sessionData, tls_handshake_t ctx)
{
ResumableSession *session = (ResumableSession *)sessionData.data;
/* Each require() guard below jumps to `out` when its condition is false. */
require(ctx->maxProtocolVersion <= session->reqProtocolVersion, out);
require(session->negProtocolVersion <= ctx->maxProtocolVersion, out);
require(session->negProtocolVersion >= ctx->minProtocolVersion, out);
require(cipherSuiteInSet(session->cipherSuite, ctx->enabledCipherSuites, ctx->numEnabledCipherSuites), out);
return 0;
out:
return errSSLSessionNotFound;
}
/*
 * Client-side consistency check of a cached session against what has been
 * negotiated on ctx. Returns 0 when ctx->negProtocolVersion and
 * ctx->selectedCipher both match the cached entry.
 *
 * Unlike the other validation functions in this file, a mismatch here is
 * reported as errSSLProtocol rather than errSSLSessionNotFound.
 */
int SSLClientValidateSessionDataAfter(const tls_buffer sessionData, tls_handshake_t ctx)
{
ResumableSession *session = (ResumableSession *)sessionData.data;
/* Each require() guard below jumps to `out` when its condition is false. */
require(session->negProtocolVersion == ctx->negProtocolVersion, out);
require(session->cipherSuite == ctx->selectedCipher, out);
return 0;
out:
return errSSLProtocol;
}
/*
 * Restore resumable-session state from a ResumableSession image into ctx:
 * the master secret and, when the image carries any certificates, the peer
 * certificate chain (replacing and freeing the previous ctx->peerCert).
 *
 * Preconditions (asserted): ctx->selectedCipher and ctx->negProtocolVersion
 * already match the cached entry.
 *
 * The whole image is validated before anything is allocated or installed.
 * A truncated or corrupt entry is rejected with errSSLSessionNotFound
 * instead of being read past its end. The chain is built separately and
 * committed only on success. On allocation failure the partial chain is
 * released and ctx is left unmodified.
 */
int
SSLInstallSessionFromData(const tls_buffer sessionData, tls_handshake_t ctx)
{
    const size_t headerLen = offsetof(ResumableSession, data);
    ResumableSession *session;
    uint8_t *cursor;
    size_t remaining;
    size_t certCount;
    size_t certLen;
    SSLCertificate *head = NULL;
    SSLCertificate *tail = NULL;
    SSLCertificate *cert;
    int err;

    if (sessionData.data == NULL || sessionData.length < headerLen)
        return errSSLSessionNotFound;
    session = (ResumableSession *)sessionData.data;

    remaining = sessionData.length - headerLen;
    if (session->ticketLen > remaining)
        return errSSLSessionNotFound;
    remaining -= session->ticketLen;

    assert(ctx->selectedCipher == session->cipherSuite);
    assert(ctx->negProtocolVersion == session->negProtocolVersion);

    /* Pass 1: every 4-byte length prefix and its DER bytes must lie inside the buffer. */
    cursor = session->data + session->ticketLen;
    for (certCount = session->certCount; certCount > 0; certCount--) {
        if (remaining < 4)
            return errSSLSessionNotFound;
        certLen = SSLDecodeInt(cursor, 4);
        cursor += 4;
        remaining -= 4;
        if (certLen > remaining)
            return errSSLSessionNotFound;
        cursor += certLen;
        remaining -= certLen;
    }

    /* Pass 2: copy the (now known-good) certificates into a private chain. */
    cursor = session->data + session->ticketLen;
    for (certCount = session->certCount; certCount > 0; certCount--) {
        cert = (SSLCertificate *)sslMalloc(sizeof(SSLCertificate));
        if (cert == NULL) {
            err = errSSLAllocate;
            goto fail;
        }
        cert->next = NULL;
        certLen = SSLDecodeInt(cursor, 4);
        cursor += 4;
        if ((err = SSLAllocBuffer(&cert->derCert, certLen))) {
            sslFree(cert);
            goto fail;
        }
        memcpy(cert->derCert.data, cursor, certLen);
        cursor += certLen;

        if (tail == NULL)
            head = cert;
        else
            tail->next = cert;
        tail = cert;
    }

    /* Commit. As before, an image with no certificates leaves ctx->peerCert alone. */
    memcpy(ctx->masterSecret, session->masterSecret, sizeof(session->masterSecret));
    if (head != NULL) {
        SSLFreeCertificates(ctx->peerCert);
        ctx->peerCert = head;
    }
    return errSSLSuccess;

fail:
    if (head != NULL)
        SSLFreeCertificates(head);
    return err;
}