#include <stdio.h>
#include "tlsCallbacks.h"
#include "sslContext.h"
#include "sslCrypto.h"
#include "sslDebug.h"
#include "sslMemory.h"
#include <Security/SecCertificate.h>
#include <Security/SecCertificatePriv.h>
#include "utilities/SecCFRelease.h"
#include <tls_helpers.h>
#include <tls_cache.h>
/*
 * Handshake-layer output hook.
 *
 * Wraps `data` in an SSLRecord tagged with `content_type` and hands it to the
 * record layer's write function. Returns whatever recFuncs->write returns.
 */
static
int tls_handshake_write_callback(tls_handshake_ctx_t ctx, const SSLBuffer data, uint8_t content_type)
{
    SSLContext *myCtx = (SSLContext *)ctx;

    /* Only read d[0] when the buffer actually has a first byte. */
    sslDebugLog("%p (rec.len=%zd, ct=%d, d[0]=%d)\n", myCtx, data.length, content_type,
                data.length ? data.data[0] : 0);

    /* Designated initializer: any SSLRecord fields not named here (if the
     * struct has others) are zeroed instead of being left indeterminate. */
    SSLRecord rec = { .contents = data, .contentType = content_type };
    return myCtx->recFuncs->write(myCtx->recCtx, rec);
}
/*
 * Handshake-message notification from the handshake engine.
 *
 * Updates SSLContext state for the message just processed. Returns 0 to let the
 * handshake continue, or an error code:
 *  - certificate_request (client side only, asserted): records the peer's
 *    acceptable client auth types. If breakOnCertRequest is set and no local
 *    certificate is configured, flags signalCertRequest and returns
 *    errSSLClientCertRequested.
 *  - client_hello: captures the peer's signature algorithms. Returns
 *    errSSLClientHelloReceived if breakOnClientHello is set.
 *  - server_hello: marks serverHelloReceived. Reports peer ALPN data to
 *    alpnFunc when both are present, otherwise reports peer NPN data to npnFunc
 *    (when present). Also captures the peer's signature algorithms.
 *  - certificate: sets the peer public key; the server side also verifies the
 *    peer certificate here.
 *  - server_hello_done: verifies the peer certificate.
 *  - NPN_encrypted_extension: reports peer NPN data to npnFunc (when present).
 */
static int
tls_handshake_message_callback(tls_handshake_ctx_t ctx, tls_handshake_message_t event)
{
    SSLContext *myCtx = (SSLContext *)ctx;
    const tls_buffer *npn_data;
    const tls_buffer *alpn_data;
    int err = 0;

    sslDebugLog("%p, message = %d\n", ctx, event);

    switch(event) {
        case tls_handshake_message_certificate_request:
            assert(myCtx->protocolSide == kSSLClientSide);
            myCtx->clientCertState = kSSLClientCertRequested;
            myCtx->clientAuthTypes = tls_handshake_get_peer_acceptable_client_auth_type(myCtx->hdsk, &myCtx->numAuthTypes);
            if (myCtx->breakOnCertRequest && (myCtx->localCertArray == NULL)) {
                myCtx->signalCertRequest = true;
                err = errSSLClientCertRequested;
            }
            break;

        case tls_handshake_message_client_hello:
            myCtx->peerSigAlgs = tls_handshake_get_peer_signature_algorithms(myCtx->hdsk, &myCtx->numPeerSigAlgs);
            if (myCtx->breakOnClientHello) {
                err = errSSLClientHelloReceived;
            }
            break;

        case tls_handshake_message_server_hello:
            myCtx->serverHelloReceived = true;
            alpn_data = tls_handshake_get_peer_alpn_data(myCtx->hdsk);
            if (alpn_data && myCtx->alpnFunc != NULL) {
                myCtx->alpnFunc(myCtx, myCtx->alpnFuncInfo, alpn_data->data, alpn_data->length);
            } else {
                npn_data = tls_handshake_get_peer_npn_data(myCtx->hdsk);
                /* alpnFunc is NULL-checked above, so npnFunc may presumably be
                 * unset as well; never call through a NULL function pointer. */
                if (npn_data && myCtx->npnFunc != NULL) {
                    myCtx->npnFunc(myCtx, myCtx->npnFuncInfo, npn_data->data, npn_data->length);
                }
            }
            myCtx->peerSigAlgs = tls_handshake_get_peer_signature_algorithms(myCtx->hdsk, &myCtx->numPeerSigAlgs);
            break;

        case tls_handshake_message_certificate:
            err = tls_helper_set_peer_pubkey(myCtx->hdsk);
            /* Server side verifies the client certificate as soon as it arrives;
             * the client side waits for server_hello_done (below). */
            if (!err && (myCtx->protocolSide == kSSLServerSide)) {
                err = tls_verify_peer_cert(myCtx);
            }
            break;

        case tls_handshake_message_server_hello_done:
            err = tls_verify_peer_cert(myCtx);
            break;

        case tls_handshake_message_NPN_encrypted_extension:
            npn_data = tls_handshake_get_peer_npn_data(myCtx->hdsk);
            if (npn_data && myCtx->npnFunc != NULL) {
                myCtx->npnFunc(myCtx, myCtx->npnFuncInfo, npn_data->data, npn_data->length);
            }
            break;

        case tls_handshake_message_certificate_status:
            break;

        default:
            break;
    }
    return err;
}
/*
 * Cipher readiness notification from the handshake engine.
 *
 * write == true  : only writeCipher_ready is updated.
 * write == false : readCipher_ready is updated, and the handshake state moves
 *                  to SSL_HdskStateReady (ready) or SSL_HdskStatePending (not ready).
 */
static void
tls_handshake_ready_callback(tls_handshake_ctx_t ctx, bool write, bool ready)
{
    SSLContext *sslCtx = (SSLContext *)ctx;

    sslDebugLog("%p %s ready=%d\n", sslCtx, write ? "write" : "read", ready);

    if (write) {
        sslCtx->writeCipher_ready = ready ? 1 : 0;
        return;
    }

    sslCtx->readCipher_ready = ready ? 1 : 0;
    SSLChangeHdskState(sslCtx, ready ? SSL_HdskStateReady : SSL_HdskStatePending);
}
/*
 * Retransmission timer control.
 *
 * attempt == 0 clears the deadline (timeout_deadline = 0). Any other value sets
 * the deadline to now + timeout_duration * 2^(attempt-1), i.e. the delay doubles
 * with each attempt. Always returns 0.
 */
static int
tls_handshake_set_retransmit_timer_callback(tls_handshake_ctx_t ctx, int attempt)
{
    SSLContext *myCtx = (SSLContext *)ctx;
    sslDebugLog("%p attempt=%d\n", ctx, attempt);

    if (attempt == 0) {
        myCtx->timeout_deadline = 0;
        return 0;
    }

    /* `1 << n` is undefined for n < 0 and for n >= the width of int, and
     * 1 << 31 overflows a 32-bit int. Computing attempt - 1 directly would also
     * overflow for INT_MIN. Clamp the exponent to [0, 30]; attempts 1..31
     * produce exactly the same backoff as before. */
    int shift;
    if (attempt < 1) {
        shift = 0;
    } else if (attempt > 31) {
        shift = 30;
    } else {
        shift = attempt - 1;
    }

    myCtx->timeout_deadline = CFAbsoluteTimeGetCurrent() + ((1 << shift) * myCtx->timeout_duration);
    return 0;
}
/*
 * Forward the selected cipher suite, side flag and key material to the record
 * layer's initPendingCiphers hook. Returns the record layer's result unchanged.
 */
static int
tls_handshake_init_pending_cipher_callback(tls_handshake_ctx_t ctx,
                                           uint16_t selectedCipher,
                                           bool server,
                                           SSLBuffer key)
{
    SSLContext *sslCtx = (SSLContext *)ctx;

    sslDebugLog("%p, cipher=%04x, server=%d\n", ctx, selectedCipher, server);

    return sslCtx->recFuncs->initPendingCiphers(sslCtx->recCtx, selectedCipher, server, key);
}
/*
 * Switch the record layer to the new write cipher.
 *
 * Before advancing, tells the record layer whether to use one-byte-record
 * sending. It is on only when oneByteRecordEnable is set AND the negotiated
 * version is <= TLS 1.0 (presumably the 1/n-1 record split used for CBC on
 * old protocol versions — confirm). Returns the result of advanceWriteCipher.
 */
static int
tls_handshake_advance_write_callback(tls_handshake_ctx_t ctx)
{
    SSLContext *sslCtx = (SSLContext *)ctx;
    sslDebugLog("%p\n", sslCtx);

    bool sendOneByteRecord = false;
    if (sslCtx->oneByteRecordEnable) {
        sendOneByteRecord = (sslCtx->negProtocolVersion <= TLS_Version_1_0);
    }
    sslCtx->recFuncs->setOption(sslCtx->recCtx, kSSLRecordOptionSendOneByteRecord, sendOneByteRecord);

    return sslCtx->recFuncs->advanceWriteCipher(sslCtx->recCtx);
}
/*
 * Ask the record layer to roll back its write cipher; returns the record
 * layer's result unchanged.
 */
static int
tls_handshake_rollback_write_callback(tls_handshake_ctx_t ctx)
{
    SSLContext *sslCtx = (SSLContext *)ctx;

    sslDebugLog("%p\n", sslCtx);

    return sslCtx->recFuncs->rollbackWriteCipher(sslCtx->recCtx);
}
/*
 * Ask the record layer to switch to the new read cipher; returns the record
 * layer's result unchanged.
 */
static int
tls_handshake_advance_read_cipher_callback(tls_handshake_ctx_t ctx)
{
    SSLContext *sslCtx = (SSLContext *)ctx;

    sslDebugLog("%p\n", sslCtx);

    return sslCtx->recFuncs->advanceReadCipher(sslCtx->recCtx);
}
/*
 * Store the negotiated protocol version on the context (negProtocolVersion),
 * then pass it on to the record layer. Returns the record layer's result.
 */
static int
tls_handshake_set_protocol_version_callback(tls_handshake_ctx_t ctx,
                                            tls_protocol_version protocolVersion)
{
    SSLContext *sslCtx = (SSLContext *)ctx;

    sslCtx->negProtocolVersion = protocolVersion;

    return sslCtx->recFuncs->setProtocolVersion(sslCtx->recCtx, protocolVersion);
}
/*
 * Build the session-cache key for this context: the bytes returned by
 * SSLGetSessionConfigurationIdentifier() followed by the bytes of *sessionKey.
 *
 * On success returns errSecSuccess. *outputBuffer then holds a malloc'd buffer
 * that the caller must free(). On failure returns the error from
 * SSLGetSessionConfigurationIdentifier() or errSecAllocate, and *outputBuffer
 * is left as { NULL, 0 }.
 */
static int
_buildConfigurationSpecificSessionCacheKey(SSLContext *myCtx, SSLBuffer *sessionKey, SSLBuffer *outputBuffer)
{
    SSLBuffer configurationBuffer = { .data = NULL, .length = 0 };

    /* Never leave a length paired with a NULL pointer for the caller. */
    outputBuffer->data = NULL;
    outputBuffer->length = 0;

    int err = SSLGetSessionConfigurationIdentifier(myCtx, &configurationBuffer);
    if (err != errSecSuccess) {
        return err;
    }

    size_t total = configurationBuffer.length + sessionKey->length;
    if (total < configurationBuffer.length) {
        /* size_t addition wrapped around. */
        free(configurationBuffer.data);
        return errSecAllocate;
    }

    /* malloc(0) may legitimately return NULL. Request at least one byte so an
     * empty key is not misreported as an allocation failure. */
    uint8_t *key = malloc(total != 0 ? total : 1);
    if (key == NULL) {
        free(configurationBuffer.data);
        return errSecAllocate;
    }

    /* memcpy with a NULL source is undefined even for a zero-byte copy. */
    if (configurationBuffer.length != 0) {
        memcpy(key, configurationBuffer.data, configurationBuffer.length);
    }
    if (sessionKey->length != 0) {
        memcpy(key + configurationBuffer.length, sessionKey->data, sessionKey->length);
    }
    free(configurationBuffer.data);

    outputBuffer->data = key;
    outputBuffer->length = total;
    return errSecSuccess;
}
/*
 * Store sessionData in the context's session cache.
 *
 * The key is the configuration-specific key (configuration identifier +
 * sessionKey). The entry is saved with the context's sessionCacheTimeout.
 * Returns errSSLSessionNotFound when no cache is attached. Otherwise returns
 * the key-construction error or the result of tls_cache_save_session_data().
 */
static int
tls_handshake_save_session_data_callback(tls_handshake_ctx_t ctx, SSLBuffer sessionKey, SSLBuffer sessionData)
{
    SSLContext *sslCtx = (SSLContext *)ctx;

    if (!sslCtx->cache) {
        return errSSLSessionNotFound;
    }

    SSLBuffer cacheKey = { .data = NULL, .length = 0 };
    int status = _buildConfigurationSpecificSessionCacheKey(sslCtx, &sessionKey, &cacheKey);
    if (status == errSecSuccess) {
        sslDebugLog("%s: %p, key len=%zd, k[0]=%02x, data len=%zd\n", __FUNCTION__, sslCtx, cacheKey.length, cacheKey.data[0], sessionData.length);
        status = tls_cache_save_session_data(sslCtx->cache, &cacheKey, &sessionData, sslCtx->sessionCacheTimeout);
        free(cacheKey.data);
    }
    return status;
}
/*
 * Look up a cached session for sessionKey.
 *
 * Any session loaded by a previous call is released first. On lookup,
 * *sessionData is set to alias myCtx->resumableSession; the context keeps
 * ownership of it. Returns errSSLSessionNotFound when no cache is attached.
 * Otherwise returns the key-construction error or the result of
 * tls_cache_load_session_data(). *sessionData is not written on the two
 * early-return paths.
 */
static int
tls_handshake_load_session_data_callback(tls_handshake_ctx_t ctx, SSLBuffer sessionKey, SSLBuffer *sessionData)
{
    SSLContext *myCtx = (SSLContext *)ctx;

    /* The buffer handed out by the previous call aliases resumableSession, so
     * releasing it here also invalidates that earlier *sessionData. */
    SSLFreeBuffer(&myCtx->resumableSession);

    if (myCtx->cache == NULL) {
        return errSSLSessionNotFound;
    }

    SSLBuffer configurationSpecificKey = { .data = NULL, .length = 0 };
    int err = _buildConfigurationSpecificSessionCacheKey(myCtx, &sessionKey, &configurationSpecificKey);
    if (err != errSecSuccess) {
        return err;
    }

    err = tls_cache_load_session_data(myCtx->cache, &configurationSpecificKey, &myCtx->resumableSession);

    /* Publish the result before logging. *sessionData is an output parameter,
     * so its length is stale/uninitialized until this assignment. */
    *sessionData = myCtx->resumableSession;
    sslDebugLog("%p, key len=%zd, data len=%zd, err=%d\n", ctx, configurationSpecificKey.length, sessionData->length, err);

    free(configurationSpecificKey.data);
    return err;
}
/*
 * Remove the cached session for sessionKey.
 *
 * The save and load callbacks store and look up entries under the
 * configuration-specific key (configuration identifier + sessionKey). The same
 * key is built here, since deleting by the raw sessionKey would not match
 * those entries (assuming tls_cache compares keys byte-for-byte — confirm).
 * Returns errSSLSessionNotFound when no cache is attached. Otherwise returns
 * the key-construction error or the result of tls_cache_delete_session_data().
 */
static int
tls_handshake_delete_session_data_callback(tls_handshake_ctx_t ctx, SSLBuffer sessionKey)
{
    SSLContext *myCtx = (SSLContext *)ctx;

    /* Only read k[0] when the key actually has a first byte. */
    sslDebugLog("%p, key len=%zd k[0]=%02x\n", ctx, sessionKey.length,
                sessionKey.length ? sessionKey.data[0] : 0);

    if (myCtx->cache == NULL) {
        return errSSLSessionNotFound;
    }

    SSLBuffer configurationSpecificKey = { .data = NULL, .length = 0 };
    int err = _buildConfigurationSpecificSessionCacheKey(myCtx, &sessionKey, &configurationSpecificKey);
    if (err != errSecSuccess) {
        return err;
    }

    err = tls_cache_delete_session_data(myCtx->cache, &configurationSpecificKey);
    free(configurationSpecificKey.data);
    return err;
}
/*
 * Empty the context's session cache, if one is attached. Always returns 0.
 */
static int
tls_handshake_delete_all_sessions_callback(tls_handshake_ctx_t ctx)
{
    SSLContext *sslCtx = (SSLContext *)ctx;

    sslDebugLog("%p\n", ctx);

    if (sslCtx->cache == NULL) {
        return 0;
    }
    tls_cache_empty(sslCtx->cache);
    return 0;
}
/*
 * Callback table connecting the handshake engine to SSLContext.
 * Left non-const and externally visible: other translation units refer to it
 * by this name.
 */
tls_handshake_callbacks_t tls_handshake_callbacks = {
    /* Record output and handshake events */
    .write                 = tls_handshake_write_callback,
    .message               = tls_handshake_message_callback,
    .ready                 = tls_handshake_ready_callback,
    .set_retransmit_timer  = tls_handshake_set_retransmit_timer_callback,

    /* Record-layer cipher and version management */
    .init_pending_cipher   = tls_handshake_init_pending_cipher_callback,
    .advance_write_cipher  = tls_handshake_advance_write_callback,
    .rollback_write_cipher = tls_handshake_rollback_write_callback,
    .advance_read_cipher   = tls_handshake_advance_read_cipher_callback,
    .set_protocol_version  = tls_handshake_set_protocol_version_callback,

    /* Session cache */
    .save_session_data     = tls_handshake_save_session_data_callback,
    .load_session_data     = tls_handshake_load_session_data_callback,
    .delete_session_data   = tls_handshake_delete_session_data_callback,
    .delete_all_sessions   = tls_handshake_delete_all_sessions_callback,
};