#include "tls_handshake_priv.h"
#include "sslSession.h"
#include "sslMemory.h"
#include "sslUtils.h"
#include "sslDebug.h"
#include "sslCipherSpecs.h"
#include "sslAlertMessage.h"
#include <assert.h>
#include <string.h>
#include <stddef.h>
/*
 * Fixed-layout header of a serialized resumable session. SSLAddSessionData
 * writes this struct byte-for-byte at the start of the session blob and the
 * readers below cast the blob back to it, so field order, types and padding
 * are the on-cache format: do not reorder or resize fields without
 * invalidating already-cached sessions.
 *
 * The variable-length payload that follows the header is, in order:
 *   session ticket         (ticketLen bytes)
 *   OCSP response          (ocspResponseLen bytes)
 *   encoded SCT list       (sctListLen bytes, 2-byte length prefixes)
 *   encoded peer certs     (certListLen bytes, 3-byte length prefixes)
 */
typedef struct
{
    size_t sessionIDLen;                      /* number of valid bytes in sessionID */
    uint8_t sessionID[32];                    /* session ID, or SESSION_TICKET_ID for ticket-only sessions */
    tls_protocol_version negProtocolVersion;  /* protocol version negotiated for the session */
    tls_protocol_version reqProtocolVersion;  /* ctx->clientReqProtocol at save time */
    uint16_t cipherSuite;                     /* cipher suite selected for the session */
    uint16_t padding;                         /* explicit padding, always written as 0 */
    uint8_t masterSecret[48];
    size_t ticketLen;
    size_t ocspResponseLen;
    size_t sctListLen;
    size_t certListLen;
    bool sessExtMSSet;                        /* ctx->extMSEnabled && ctx->extMSReceived at save time */
    /*
     * ISO C99 flexible array member (replaces the GNU-only `data[0]`; the
     * offset and sizeof are unchanged for a uint8_t tail). Because of
     * trailing padding, `data` may begin before sizeof(ResumableSession).
     */
    uint8_t data[];
} ResumableSession;
/* Placeholder written into ResumableSession.sessionID by SSLAddSessionData
 * when the session has a ticket but no session ID. */
#define SESSION_TICKET_ID "SESSION-TICKET"
static bool
SSLSessionDataCheck(const tls_buffer sessionData)
{
ResumableSession *session;
if(sessionData.length < sizeof(ResumableSession))
return false;
session = (ResumableSession *)sessionData.data;
return (sessionData.length == (sizeof(ResumableSession) + session->ticketLen + session->ocspResponseLen + session->sctListLen + session->certListLen));
}
int
SSLAddSessionData(const tls_handshake_t ctx)
{ int err;
size_t sessionDataLen;
tls_buffer sessionData;
ResumableSession *session;
size_t sctListLen;
size_t certListLen;
uint8_t *p;
if (!ctx->sessionID.data && !ctx->sessionTicket.data)
return errSSLSessionNotFound;
sctListLen = SSLEncodedBufferListSize(ctx->sct_list, 2);
certListLen = SSLEncodedBufferListSize((tls_buffer_list_t *)ctx->peerCert, 3);
sessionDataLen = sizeof(ResumableSession) + ctx->sessionTicket.length
+ ctx->ocsp_response.length + sctListLen + certListLen;
if ((err = SSLAllocBuffer(&sessionData, sessionDataLen)))
return err;
session = (ResumableSession*)sessionData.data;
if(ctx->sessionID.data==NULL) {
session->sessionIDLen = strlen(SESSION_TICKET_ID);
memcpy(session->sessionID, SESSION_TICKET_ID, session->sessionIDLen);
} else {
session->sessionIDLen = ctx->sessionID.length;
memcpy(session->sessionID, ctx->sessionID.data, session->sessionIDLen);
}
session->negProtocolVersion = ctx->negProtocolVersion;
session->reqProtocolVersion = ctx->clientReqProtocol;
session->cipherSuite = ctx->selectedCipher;
memcpy(session->masterSecret, ctx->masterSecret, 48);
session->ticketLen = ctx->sessionTicket.length;
session->ocspResponseLen = ctx->ocsp_response.length;
session->sctListLen = sctListLen;
session->certListLen = certListLen;
session->padding = 0;
p = session->data;
memcpy(p, ctx->sessionTicket.data, ctx->sessionTicket.length); p += ctx->sessionTicket.length;
memcpy(p, ctx->ocsp_response.data, ctx->ocsp_response.length); p += ctx->ocsp_response.length;
p = SSLEncodeBufferList(ctx->sct_list, 2, p);
p = SSLEncodeBufferList((tls_buffer_list_t *)ctx->peerCert, 3, p);
if(!SSLSessionDataCheck(sessionData)) {
SSLFreeBuffer(&sessionData);
return errSSLInternal;
}
if (ctx->extMSEnabled && ctx->extMSReceived)
session->sessExtMSSet = true;
else
session->sessExtMSSet = false;
if(ctx->isServer)
err = ctx->callbacks->save_session_data(ctx->callback_ctx, ctx->sessionID, sessionData);
else
err = ctx->callbacks->save_session_data(ctx->callback_ctx, ctx->peerID, sessionData);
SSLFreeBuffer(&sessionData);
return err;
}
/*
 * Ask the session cache to drop the entry keyed by ctx->sessionID.
 *
 * Returns errSSLSessionNotFound when no session ID is set; otherwise
 * returns whatever the delete_session_data callback returns.
 *
 * NOTE(review): SSLAddSessionData keys client-side entries by ctx->peerID,
 * but this always deletes by ctx->sessionID — confirm the client cache
 * expects that asymmetry.
 */
int
SSLDeleteSessionData(const tls_handshake_t ctx)
{
    if (ctx->sessionID.data == NULL) {
        return errSSLSessionNotFound;
    }

    return ctx->callbacks->delete_session_data(ctx->callback_ctx, ctx->sessionID);
}
/*
 * Point *ticket at the session ticket stored in a serialized session.
 *
 * No copy is made: ticket->data aliases the first ticketLen bytes of the
 * payload inside sessionData, so it is only valid while sessionData is.
 * Returns errSSLInternal if the blob fails SSLSessionDataCheck.
 */
int
SSLRetrieveSessionTicket(
    const tls_buffer sessionData,
    tls_buffer *ticket)
{
    if (SSLSessionDataCheck(sessionData)) {
        ResumableSession *hdr = (ResumableSession *)sessionData.data;

        ticket->length = hdr->ticketLen;
        ticket->data = hdr->data;
        return errSSLSuccess;
    }

    return errSSLInternal;
}
/*
 * Point *identifier at the session ID stored in a serialized session.
 *
 * No copy is made: identifier->data aliases the sessionID field inside
 * sessionData. For ticket-only sessions this is the SESSION_TICKET_ID
 * placeholder. Returns errSSLInternal if the blob fails SSLSessionDataCheck.
 */
int
SSLRetrieveSessionID(
    const tls_buffer sessionData,
    tls_buffer *identifier)
{
    if (SSLSessionDataCheck(sessionData)) {
        ResumableSession *hdr = (ResumableSession *)sessionData.data;

        identifier->length = hdr->sessionIDLen;
        identifier->data = hdr->sessionID;
        return errSSLSuccess;
    }

    return errSSLInternal;
}
/*
 * Server side: decide whether a cached session can be resumed for the
 * session ID proposed by the client.
 *
 * To be resumable, the cached session must:
 *   - match the proposed session ID exactly;
 *   - use the currently negotiated protocol version;
 *   - have a cipher suite that is both enabled locally and requested by
 *     the client;
 *   - agree with the client's extended-master-secret (extMS) state.
 *
 * Returns:
 *   0                      resumable; ctx->selectedCipher is set from the
 *                          cache and InitCipherSpecParams is called
 *   errSSLInternal         malformed blob
 *   errSSLFatalAlert       cached session used extMS but the client did not
 *                          send it; a handshake_failure alert is sent
 *   errSSLSessionNotFound  any other mismatch
 */
int SSLServerValidateSessionData(const tls_buffer sessionData, tls_handshake_t ctx)
{
    ResumableSession *cached;

    if (!SSLSessionDataCheck(sessionData))
        return errSSLInternal;
    cached = (ResumableSession *)sessionData.data;

    require(cached->sessionIDLen == ctx->proposedSessionID.length, notResumable);
    require(memcmp(cached->sessionID, ctx->proposedSessionID.data,
                   ctx->proposedSessionID.length) == 0, notResumable);
    require(cached->negProtocolVersion == ctx->negProtocolVersion, notResumable);
    require(cipherSuiteInSet(cached->cipherSuite, ctx->enabledCipherSuites,
                             ctx->numEnabledCipherSuites), notResumable);
    require(cipherSuiteInSet(cached->cipherSuite, ctx->requestedCipherSuites,
                             ctx->numRequestedCipherSuites), notResumable);

    /* Cached session used extMS, client dropped it: hard failure. */
    if (cached->sessExtMSSet && !ctx->extMSReceived) {
        SSLFatalSessionAlert(SSL_AlertHandshakeFail, ctx);
        return errSSLFatalAlert;
    }
    /* Client offers extMS for a session created without it: do not resume. */
    if (!cached->sessExtMSSet && ctx->extMSReceived)
        goto notResumable;

    ctx->selectedCipher = cached->cipherSuite;
    InitCipherSpecParams(ctx);
    return 0;

notResumable:
    return errSSLSessionNotFound;
}
/*
 * Client side, before sending ClientHello: check that a cached session is
 * still usable with the current configuration.
 *
 * The checks are:
 *   - ctx->maxProtocolVersion must not exceed the version requested when
 *     the session was saved;
 *   - the negotiated version must lie within
 *     [minProtocolVersion, maxProtocolVersion];
 *   - the cached cipher suite must still be enabled.
 *
 * Returns 0 if all checks pass, errSSLInternal for a malformed blob, and
 * errSSLSessionNotFound otherwise.
 */
int SSLClientValidateSessionDataBefore(const tls_buffer sessionData, tls_handshake_t ctx)
{
    ResumableSession *cached;

    if (!SSLSessionDataCheck(sessionData))
        return errSSLInternal;
    cached = (ResumableSession *)sessionData.data;

    require(ctx->maxProtocolVersion <= cached->reqProtocolVersion, unusable);
    require(cached->negProtocolVersion <= ctx->maxProtocolVersion, unusable);
    require(cached->negProtocolVersion >= ctx->minProtocolVersion, unusable);
    require(cipherSuiteInSet(cached->cipherSuite, ctx->enabledCipherSuites,
                             ctx->numEnabledCipherSuites), unusable);

    return 0;

unusable:
    return errSSLSessionNotFound;
}
/*
 * Client side, after ServerHello accepted resumption: the server's choices
 * must match what was cached.
 *
 * The protocol version, the cipher suite and the extended-master-secret
 * state must all be the same as when the session was saved.
 *
 * Returns 0 on a match, errSSLInternal for a malformed blob, and
 * errSSLProtocol on any mismatch.
 */
int SSLClientValidateSessionDataAfter(const tls_buffer sessionData, tls_handshake_t ctx)
{
    ResumableSession *cached;

    if (!SSLSessionDataCheck(sessionData))
        return errSSLInternal;
    cached = (ResumableSession *)sessionData.data;

    require(cached->negProtocolVersion == ctx->negProtocolVersion, mismatch);
    require(cached->cipherSuite == ctx->selectedCipher, mismatch);
    require(cached->sessExtMSSet == ctx->extMSReceived, mismatch);

    return 0;

mismatch:
    return errSSLProtocol;
}
/*
 * Restore resumable state from a serialized session into ctx.
 *
 * This restores the master secret, the stapled OCSP response, the SCT list
 * and the peer certificate chain. Any existing OCSP response, SCT list and
 * peer certificates in ctx are released first.
 *
 * ctx->selectedCipher and ctx->negProtocolVersion are expected to already
 * match the cached session. This is only enforced by assert(), i.e. in
 * builds without NDEBUG.
 *
 * Returns errSSLSuccess, errSSLInternal for a malformed blob, or the error
 * from copying or decoding one of the payload sections.
 */
int
SSLInstallSessionFromData(const tls_buffer sessionData, tls_handshake_t ctx)
{
    int err;
    ResumableSession *session;
    uint8_t *p;

    if (!SSLSessionDataCheck(sessionData))
        return errSSLInternal;

    session = (ResumableSession *)sessionData.data;

    assert(ctx->selectedCipher == session->cipherSuite);
    assert(ctx->negProtocolVersion == session->negProtocolVersion);

    memcpy(ctx->masterSecret, session->masterSecret, sizeof(session->masterSecret));

    /* Skip the ticket (exposed via SSLRetrieveSessionTicket); OCSP follows. */
    p = session->data + session->ticketLen;

    SSLFreeBuffer(&ctx->ocsp_response);
    ctx->ocsp_response_received = false;
    if (session->ocspResponseLen) {
        /* Only report a received response once the copy actually succeeded. */
        if ((err = SSLCopyBufferFromData(p, session->ocspResponseLen, &ctx->ocsp_response)))
            return err;
        ctx->ocsp_response_received = true;
    }
    p += session->ocspResponseLen;

    tls_free_buffer_list(ctx->sct_list);
    ctx->sct_list = NULL;
    if (session->sctListLen) {
        if ((err = SSLDecodeBufferList(p, session->sctListLen, 2, &ctx->sct_list)))
            return err;
    }
    p += session->sctListLen;

    SSLFreeCertificates(ctx->peerCert);
    ctx->peerCert = NULL;
    if (session->certListLen) {
        if ((err = SSLDecodeBufferList(p, session->certListLen, 3,
                                       (tls_buffer_list_t **)&ctx->peerCert)))
            return err;
    }

    return errSSLSuccess;
}