/* handshake_server_test.c */
#include <errno.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>
#include <sys/types.h>
#include <sys/socket.h>
#include <netinet/in.h>
#include <arpa/inet.h>
#include <netdb.h>
#include <AssertMacros.h>
#include <CommonCrypto/CommonRandomSPI.h>
#define CCRNGSTATE() ccDRBGGetRngState()
int handshake_server_test(void);
#include <tls_handshake.h>
#include <tls_record.h>
#include <tls_stream_parser.h>
#include <tls_ciphersuites.h>
#include <Security/SecCertificate.h>
#include <Security/SecKeyPriv.h>
#include <Security/SecIdentity.h>
#include "appleSession.h"
#include "secCrypto.h"
#define test_printf(x,...)
#define DEBUG_ONLY __attribute__((unused))
/*
 * Parameters for one server test run; read by test_server() and
 * init_connection().
 */
typedef struct {
int port;                       /* TCP port passed to SocketListen() */
bool dtls;                      /* passed to tls_record_create() and tls_handshake_create() */
int protocol_min;               /* if non-zero, set as the handshake's minimum protocol version */
int protocol_max;               /* if non-zero, set as the handshake's maximum protocol version */
const uint16_t *ciphersuites;   /* suites enabled when num_ciphersuites != 0 */
int num_ciphersuites;           /* number of entries in ciphersuites; 0 keeps the handshake defaults */
bool allow_resumption;          /* passed to tls_handshake_set_resumption() */
SSLCertificate certs;           /* NOTE(review): not read anywhere in this file */
const char *request;            /* NOTE(review): not read anywhere in this file */
} tls_test_case;
/*
 * Per-connection state, heap-allocated by init_connection(). A pointer to it
 * is the context handed to every handshake callback (tls_handshake_set_callbacks)
 * and to the stream-parser callback (tls_stream_parser_create).
 */
typedef struct {
int sock;                       /* accepted client socket; encrypted records are send()-ed here */
tls_record_t rec;               /* record layer used by tls_record_encrypt/decrypt */
tls_handshake_t hdsk;           /* handshake object created with tls_handshake_create(dtls, true) */
tls_stream_parser_t parser;     /* fed raw socket bytes; calls tls_stream_parser_process() per record */
tls_test_case *test;            /* NOTE(review): never assigned or read in this file */
int err;                        /* result of the most recent tls_stream_parser_process() call */
int read_ready_received;        /* number of "read ready" notifications from the ready callback */
int write_ready_received;       /* number of "write ready" notifications from the ready callback */
int certificate_requested;      /* NOTE(review): not used in this file */
dispatch_semaphore_t connection_done; /* NOTE(review): not used in this file */
} myFilterCtx_t;
/*
 * Allocate `len` bytes of heap storage into `buf`.
 *
 * Returns errSecSuccess on success, with buf->length == len.
 * On failure returns errSecAllocate and leaves `buf` empty
 * (data == NULL, length == 0), so it is always safe to pass to mySSLFree().
 */
static
int mySSLAlloc(tls_buffer *buf, size_t len)
{
    buf->data = malloc(len);
    if (!buf->data) {
        /* Don't leave a stale length describing memory that doesn't exist. */
        buf->length = 0;
        return errSecAllocate;
    }
    buf->length = len;
    return errSecSuccess;
}
/*
 * Release the storage owned by `buf` and reset it to an empty buffer
 * (data == NULL, length == 0). Also fine when buf->data is already NULL.
 */
static void mySSLFree(tls_buffer *buf)
{
    free(buf->data);    /* free(NULL) is a no-op, so no guard is needed */
    buf->length = 0;
    buf->data = NULL;
}
/*
 * Handshake "write" callback: encrypt `data` as one record of type
 * `content_type` with the connection's record layer, then send the whole
 * encrypted record on the connection's socket.
 *
 * Returns 0 on success, the error from mySSLAlloc()/tls_record_encrypt()
 * if encryption fails, or -errno if send() fails.
 */
static int tls_handshake_write_callback(tls_handshake_ctx_t ctx, const tls_buffer data, uint8_t content_type)
{
    int err;
    myFilterCtx_t *myCtx = (myFilterCtx_t *)ctx;
    tls_buffer encrypted = {0, }, out;

    test_printf("%s: %p (rec.len=%zu)\n", __FUNCTION__, myCtx, data.length);
    err = mySSLAlloc(&encrypted, tls_record_encrypted_size(myCtx->rec, content_type, data.length));
    require_noerr(err, fail);
    err = tls_record_encrypt(myCtx->rec, data, content_type, &encrypted);
    require_noerr(err, fail);
    test_printf("%s: %p Writing %zu encrypted bytes\n", __FUNCTION__, myCtx, encrypted.length);

    /* send() may accept only part of the buffer; keep going until all of it is written. */
    out = encrypted;
    while (out.length) {
        ssize_t nwr = send(myCtx->sock, out.data, out.length, 0);
        if (nwr < 0) {
            if (errno == EINTR)
                continue;   /* interrupted before anything was sent: just retry */
            err = -errno;   /* capture errno before printf() can clobber it */
            printf("Error writing %zu bytes to socket : %d (%s)\n", out.length, err, strerror(-err));
            goto fail;
        }
        out.data += nwr;
        out.length -= (size_t)nwr;
    }
fail:
    mySSLFree(&encrypted);
    return err;
}
/*
 * Protocol-negotiation payloads naming "http/1.1"
 * (0x68 0x74 0x74 0x70 0x2f 0x31 0x2e 0x31 = 'h','t','t','p','/','1','.','1').
 * npnHttpData is only referenced inside the `#if 0` block of
 * tls_handshake_message_callback, and alpnData is not referenced elsewhere in
 * this source, hence the __unused markers.
 */
/* Leading 0x08 is the length of the 8-byte name that follows.
 * NOTE(review): looks like a single length-prefixed ALPN name entry — confirm the expected wire format. */
__unused static uint8_t alpn_http_1_1[] = {0x08, 0x68, 0x74, 0x74, 0x70, 0x2f, 0x31, 0x2e, 0x31};
__unused static tls_buffer alpnData = {
.data = alpn_http_1_1,
.length = sizeof(alpn_http_1_1),
};
/* Same name without a length prefix. */
__unused static uint8_t npn_http_1_1[] = {0x68, 0x74, 0x74, 0x70, 0x2f, 0x31, 0x2e, 0x31};
__unused static tls_buffer npnHttpData = {
.data = npn_http_1_1,
.length = sizeof(npn_http_1_1),
};
/* Declared before its first use below (it was previously declared only after
 * tls_handshake_message_callback, i.e. an implicit declaration at the call). */
int tls_set_peer_pubkey(tls_handshake_t hdsk, const SSLCertificate *certchain);

/*
 * Handshake "message" callback, invoked with each handshake message event.
 *
 * - tls_handshake_message_certificate: sets the peer trust result to
 *   tls_handshake_trust_ok without evaluating the chain, then passes the
 *   peer certificate chain to tls_set_peer_pubkey().
 * - tls_handshake_message_client_hello: currently no-op; an NPN/ALPN
 *   experiment is kept below but compiled out.
 *
 * Returns 0, or the first non-zero result of the calls above.
 */
static int
tls_handshake_message_callback(tls_handshake_ctx_t ctx, tls_handshake_message_t event)
{
    myFilterCtx_t *myCtx = (myFilterCtx_t *)ctx;
    int err = 0;

    printf("%s: %p event = %d\n", __FUNCTION__, myCtx, event);
    switch(event) {
    case tls_handshake_message_certificate:
    {
        /* NOTE(review): peer trust is accepted unconditionally — only acceptable in a test server. */
        require_noerr((err = tls_handshake_set_peer_trust(myCtx->hdsk, tls_handshake_trust_ok)), errOut);
        const SSLCertificate *cert = tls_handshake_get_peer_certificates(myCtx->hdsk);
        require_noerr((err = tls_set_peer_pubkey(myCtx->hdsk, cert)), errOut);
        break;
    }
    case tls_handshake_message_client_hello:
#if 0
    {
        const tls_buffer *npnData;
        npnData = tls_handshake_get_peer_npn_data(myCtx->hdsk);
        if(npnData) {
            printf("NPN Data = %p, %zd\n", npnData->data, npnData->length);
            require_noerr((err=tls_handshake_set_npn_data(myCtx->hdsk, npnHttpData)), errOut);
        }
        npnData = tls_handshake_get_peer_alpn_data(myCtx->hdsk);
        if(npnData) {
            printf("ALPN Data = %p, %zd\n", npnData->data, npnData->length);
        }
    }
#endif
        break;
    default:
        break;
    }
errOut:
    return err;
}
/*
 * Handshake "ready" callback: counts how often the read or write direction is
 * reported ready in the connection context. Notifications with
 * ready == false are ignored.
 */
static void
tls_handshake_ready_callback(tls_handshake_ctx_t ctx, bool write, bool ready)
{
    myFilterCtx_t *conn = (myFilterCtx_t *)ctx;

    test_printf("%s: %s ready=%d\n", __FUNCTION__, write?"write":"read", ready);
    if (!ready)
        return;

    int *counter = write ? &conn->write_ready_received
                         : &conn->read_ready_received;
    *counter += 1;
}
/*
 * Handshake "set_retransmit_timer" callback. This test server does not
 * implement retransmission timers, so every request is answered with
 * errSecUnimplemented.
 */
static int
tls_handshake_set_retransmit_timer_callback(tls_handshake_ctx_t ctx, int attempt)
{
    const int result = errSecUnimplemented;
    myFilterCtx_t DEBUG_ONLY *conn = (myFilterCtx_t *)ctx;

    test_printf("%s: %p attempt=%d\n", __FUNCTION__, conn, attempt);
    return result;
}
/*
 * Handshake "init_pending_cipher" callback: forwards the selected cipher
 * suite, role flag and key material to tls_record_init_pending_ciphers() on
 * the connection's record layer and returns its result.
 */
static
int mySSLRecordInitPendingCiphersFunc(tls_handshake_ctx_t ref,
                                      uint16_t selectedCipher,
                                      bool server,
                                      tls_buffer key)
{
    /* `ref` is the context pointer, so it must be printed with %p, not %s. */
    test_printf("%s: %p, cipher=%04x, server=%d\n", __FUNCTION__, ref, selectedCipher, server);
    myFilterCtx_t *c = (myFilterCtx_t *)ref;
    return tls_record_init_pending_ciphers(c->rec, selectedCipher, server, key);
}
/*
 * Handshake "advance_write_cipher" callback: forwards to
 * tls_record_advance_write_cipher() on the connection's record layer.
 */
static
int mySSLRecordAdvanceWriteCipherFunc(tls_handshake_ctx_t ref)
{
    /* `ref` is the context pointer, so it must be printed with %p, not %s. */
    test_printf("%s: %p\n", __FUNCTION__, ref);
    myFilterCtx_t *c = (myFilterCtx_t *)ref;
    return tls_record_advance_write_cipher(c->rec);
}
/*
 * Handshake "rollback_write_cipher" callback: forwards to
 * tls_record_rollback_write_cipher() on the connection's record layer.
 */
static
int mySSLRecordRollbackWriteCipherFunc(tls_handshake_ctx_t ref)
{
    /* `ref` is the context pointer, so it must be printed with %p, not %s. */
    test_printf("%s: %p\n", __FUNCTION__, ref);
    myFilterCtx_t *c = (myFilterCtx_t *)ref;
    return tls_record_rollback_write_cipher(c->rec);
}
/*
 * Handshake "advance_read_cipher" callback: forwards to
 * tls_record_advance_read_cipher() on the connection's record layer.
 */
static
int mySSLRecordAdvanceReadCipherFunc(tls_handshake_ctx_t ref)
{
    /* `ref` is the context pointer, so it must be printed with %p, not %s. */
    test_printf("%s: %p\n", __FUNCTION__, ref);
    myFilterCtx_t *c = (myFilterCtx_t *)ref;
    return tls_record_advance_read_cipher(c->rec);
}
/*
 * Handshake "set_protocol_version" callback: forwards the negotiated
 * protocol version to tls_record_set_protocol_version() on the connection's
 * record layer.
 */
static
int mySSLRecordSetProtocolVersionFunc(tls_handshake_ctx_t ref,
                                      tls_protocol_version protocolVersion)
{
    /* `ref` is the context pointer, so it must be printed with %p, not %s. */
    test_printf("%s: %p, pv=%04x\n", __FUNCTION__, ref, protocolVersion);
    myFilterCtx_t *c = (myFilterCtx_t *)ref;
    return tls_record_set_protocol_version(c->rec, protocolVersion);
}
/*
 * Handshake "save_session_data" callback: stores `sessionData` under
 * `sessionKey` via sslAddSession() and returns its result.
 */
static int
tls_handshake_save_session_data_callback(tls_handshake_ctx_t ctx, tls_buffer sessionKey, tls_buffer sessionData)
{
    myFilterCtx_t DEBUG_ONLY *myCtx = (myFilterCtx_t *)ctx;
    test_printf("%s:%p\n", __FUNCTION__, myCtx);
    /* The key is a byte buffer that is not guaranteed to be NUL-terminated,
     * so log it as pointer + length rather than with %s. */
    test_printf("key=[%p,%zu] data=[%p,%zu]\n",
                sessionKey.data, sessionKey.length, sessionData.data, sessionData.length);
    return sslAddSession(sessionKey, sessionData, 0);
}
/*
 * Handshake "load_session_data" callback: looks up `sessionKey` via
 * sslGetSession(), which fills `*sessionData`, and returns its result.
 */
static int
tls_handshake_load_session_data_callback(tls_handshake_ctx_t ctx, tls_buffer sessionKey, tls_buffer *sessionData)
{
    myFilterCtx_t DEBUG_ONLY *myCtx = (myFilterCtx_t *)ctx;
    test_printf("%s:%p\n", __FUNCTION__, myCtx);
    int err = sslGetSession(sessionKey, sessionData);
    /* The key is a byte buffer that is not guaranteed to be NUL-terminated,
     * so log it as pointer + length rather than with %s. */
    test_printf("key=[%p,%zu] data=[%p,%zu], err=%d\n",
                sessionKey.data, sessionKey.length, sessionData->data, sessionData->length, err);
    return err;
}
/*
 * Handshake "delete_session_data" callback: removes the cache entry stored
 * under `sessionKey` via sslDeleteSession() and returns its result.
 */
static int
tls_handshake_delete_session_data_callback(tls_handshake_ctx_t ctx, tls_buffer sessionKey)
{
    int status;
    myFilterCtx_t DEBUG_ONLY *conn = (myFilterCtx_t *)ctx;

    test_printf("%s:%p\n", __FUNCTION__, conn);
    status = sslDeleteSession(sessionKey);
    return status;
}
/*
 * Handshake "delete_all_sessions" callback: clears the session cache via
 * sslCleanupSession() and returns its result.
 */
static int
tls_handshake_delete_all_sessions_callback(tls_handshake_ctx_t ctx)
{
    int status;
    myFilterCtx_t DEBUG_ONLY *conn = (myFilterCtx_t *)ctx;

    test_printf("%s:%p\n", __FUNCTION__, conn);
    status = sslCleanupSession();
    return status;
}
/*
 * Callback table installed by init_connection() through
 * tls_handshake_set_callbacks(); the context each callback receives is the
 * connection's myFilterCtx_t.
 */
static
tls_handshake_callbacks_t tls_handshake_callbacks = {
    /* Transport and handshake progress. */
    .write                 = tls_handshake_write_callback,
    .message               = tls_handshake_message_callback,
    .ready                 = tls_handshake_ready_callback,
    .set_retransmit_timer  = tls_handshake_set_retransmit_timer_callback,

    /* Record-layer cipher and protocol-version management. */
    .set_protocol_version  = mySSLRecordSetProtocolVersionFunc,
    .init_pending_cipher   = mySSLRecordInitPendingCiphersFunc,
    .advance_read_cipher   = mySSLRecordAdvanceReadCipherFunc,
    .advance_write_cipher  = mySSLRecordAdvanceWriteCipherFunc,
    .rollback_write_cipher = mySSLRecordRollbackWriteCipherFunc,

    /* Session cache. */
    .save_session_data     = tls_handshake_save_session_data_callback,
    .load_session_data     = tls_handshake_load_session_data_callback,
    .delete_session_data   = tls_handshake_delete_session_data_callback,
    .delete_all_sessions   = tls_handshake_delete_all_sessions_callback,
};
#include <errno.h>
/*
 * Create a TCP/IPv4 socket bound to INADDR_ANY:`port` and put it into the
 * listening state with a backlog of 5.
 *
 * Returns the listening descriptor, or -errno on failure. The socket is
 * closed on every failure path, so no descriptor is leaked.
 */
static int SocketListen(int port)
{
    struct sockaddr_in sa;
    int sock;
    int saved_errno;

    if ((sock = socket(AF_INET, SOCK_STREAM, 0)) == -1) {
        saved_errno = errno;    /* perror() may itself modify errno */
        perror("socket");
        return -saved_errno;
    }

    memset(&sa, 0, sizeof(sa));
    sa.sin_family = AF_INET;
    sa.sin_port = htons((uint16_t)port);
    sa.sin_addr.s_addr = htonl(INADDR_ANY);

    if (bind(sock, (struct sockaddr *)&sa, sizeof(sa)) == -1) {
        saved_errno = errno;
        perror("bind");
        goto fail;
    }
    if (listen(sock, 5) == -1) {
        saved_errno = errno;
        perror("listen");
        goto fail;
    }
    return sock;

fail:
    close(sock);
    return -saved_errno;
}
/*
 * Stream-parser callback, invoked once per complete record received on a
 * connection. It decrypts the record with the connection's record layer, then:
 *  - passes non-application-data content to tls_handshake_process();
 *  - prints application data to stdout.
 *
 * The result is stored in c->err and also returned (errSecAllocate if the
 * output buffer cannot be allocated).
 */
static
int tls_stream_parser_process(tls_stream_parser_ctx_t ctx, tls_buffer record)
{
    int err = errSecAllocate;
    myFilterCtx_t *c = (myFilterCtx_t *)ctx;
    tls_buffer out = {0, };
    uint8_t content_type;

    printf("%s: %p, len = %zu\n", __FUNCTION__, ctx, record.length);
    size_t dlen = tls_record_decrypted_size(c->rec, record.length);
    /* Same sizing as before (one spare byte).
     * NOTE(review): tls_record_decrypt may treat out.length as capacity — confirm before shrinking. */
    mySSLAlloc(&out, dlen+1);
    require(out.data, fail);
    require_noerr((err=tls_record_decrypt(c->rec, record, &out, &content_type)), fail);

    if(content_type!=tls_record_type_AppData) {
        printf("%s: %p, processing protocol message of type %d, len=%zu\n", __FUNCTION__, ctx, content_type, out.length);
        require_noerr_quiet((err = tls_handshake_process(c->hdsk, out, content_type)), fail);
    } else {
        /* The counter only ever increments, so zero means no read-ready yet
         * (the old `< 0` test could never fire). */
        if(c->read_ready_received == 0)
            printf("Received data before read_ready\n");
        printf("%s: %p, received data record, len = %zu\n", __FUNCTION__, ctx, out.length);
        /* Print with an explicit length instead of writing a NUL at
         * out.data[out.length], which is out of bounds if decryption filled
         * the whole buffer. */
        printf("DATA: %.*s\n", (int)out.length, (const char *)out.data);
    }
fail:
    mySSLFree(&out);
    c->err = err;
    return err;
}
/*
 * Allocate and configure the per-connection context for accepted socket `fd`.
 * It creates the record layer, a server-side handshake (tls_handshake_create(..., true))
 * and a stream parser, installs tls_handshake_callbacks, and applies the
 * cipher-suite, protocol-version, resumption and identity settings from `test`.
 *
 * On success stores the context in *pc and returns 0. On failure returns the
 * error and releases everything created so far; *pc is left untouched.
 * Aborts if no RNG state is available.
 */
static
int init_connection(myFilterCtx_t **pc, int fd, tls_test_case *test)
{
    int err = errSecAllocate;
    myFilterCtx_t *c = NULL;

    require((c = calloc(1, sizeof(*c))), fail);
    c->sock = fd;

    struct ccrng_state *rng = CCRNGSTATE();
    if (!rng) {
        abort();
    }

    require((c->rec = tls_record_create(test->dtls, rng)), fail);
    require((c->hdsk = tls_handshake_create(test->dtls, true)), fail);
    require((c->parser = tls_stream_parser_create(c, tls_stream_parser_process)), fail);
    require_noerr((err=tls_handshake_set_callbacks(c->hdsk, &tls_handshake_callbacks, c)), fail);

    /* Zero values mean "keep the handshake's defaults". */
    if(test->num_ciphersuites)
        require_noerr((err=tls_handshake_set_ciphersuites(c->hdsk, test->ciphersuites, test->num_ciphersuites)), fail);
    if(test->protocol_min)
        require_noerr((err=tls_handshake_set_min_protocol_version(c->hdsk, test->protocol_min)), fail);
    if(test->protocol_max)
        require_noerr((err=tls_handshake_set_max_protocol_version(c->hdsk, test->protocol_max)), fail);
    require_noerr((err=tls_handshake_set_resumption(c->hdsk,test->allow_resumption)), fail);
    require_noerr((err=tls_handshake_set_identity(c->hdsk, &server_cert, server_key)), fail);

    *pc = c;
    return 0;

fail:
    /* Undo whatever was created before the failure (previously all leaked). */
    if (c) {
        if(c->hdsk) tls_handshake_destroy(c->hdsk);
        if(c->rec) tls_record_destroy(c->rec);
        if(c->parser) tls_stream_parser_destroy(c->parser);
        free(c);
    }
    return err;
}
/*
 * Destroy everything init_connection() created for `c`: the handshake, the
 * record layer, the stream parser, and the context allocation itself
 * (previously leaked). `c` must not be used afterwards. NULL is ignored.
 */
static
void clean_connection(myFilterCtx_t *c)
{
    if (!c)
        return;
    if(c->hdsk) tls_handshake_destroy(c->hdsk);
    if(c->rec) tls_record_destroy(c->rec);
    if(c->parser) tls_stream_parser_destroy(c->parser);
    free(c);
}
static int test_server(tls_test_case *test)
{
dispatch_queue_t read_queue = NULL;
dispatch_source_t socket_source = NULL;
int server_sock = SocketListen(test->port);
if(server_sock<0)
return server_sock;
printf("Listening to socket: %d\n", server_sock);
require((read_queue = dispatch_queue_create("server read queue", DISPATCH_QUEUE_SERIAL)), fail);
require((socket_source = dispatch_source_create(DISPATCH_SOURCE_TYPE_READ, (uintptr_t) server_sock, 0, read_queue)), fail);
dispatch_source_set_cancel_handler(socket_source, ^{
close(server_sock);
});
dispatch_source_set_event_handler(socket_source, ^{
int err;
struct sockaddr my_sock;
socklen_t my_socklen;
int fd;
dispatch_source_t fd_source;
myFilterCtx_t *conn;
fd = accept(server_sock, &my_sock, &my_socklen);
require((fd>=0), connect_fail);
printf("A to socket: %d\n", server_sock);
fd_source = dispatch_source_create(DISPATCH_SOURCE_TYPE_READ, (uintptr_t)fd, 0, read_queue);
require_noerr((err=init_connection(&conn, fd, test)), connect_fail);
printf("Created connection, fd=%d, conn=%p\n", fd, conn);
dispatch_source_set_cancel_handler(fd_source, ^{
printf("Cancelling connection, fd=%d, conn=%p\n", fd, conn);
close(fd);
clean_connection(conn);
});
dispatch_source_set_event_handler(fd_source, ^{
ssize_t nr;
int err = -1;
tls_buffer readbuffer = {0,};
unsigned long data = dispatch_source_get_data(fd_source);
test_printf("[%d] source event data = %lu\n", fd, data);
if(data==0) {
test_printf("EOF? Socket closed ?\n");
err = -1;
goto done;
}
require_noerr(mySSLAlloc(&readbuffer, data),done);
nr = recv(fd, readbuffer.data, readbuffer.length, 0);
require(nr>0, done);
readbuffer.length = nr;
printf("recvd %zd bytes, parse it\n", nr);
require_noerr_quiet((err=tls_stream_parser_parse(conn->parser, readbuffer)), done);
done:
test_printf("done, err=%d\n", err);
mySSLFree(&readbuffer);
if(err) {
printf("Error while parsing incoming data, fd=%d, err = %d\n", fd, err);
printf("Cancelling connection\n");
dispatch_source_cancel(fd_source);
}
});
dispatch_resume(fd_source);
return;
connect_fail:
printf("new connection failed\n");
return;
});
dispatch_resume(socket_source);
printf("Main server thread is now spinning...\n");
while(1);
fail:
if(read_queue)
dispatch_release(read_queue);
if(socket_source) {
dispatch_source_cancel(socket_source);
dispatch_release(socket_source);
}
return 0;
}
/*
 * Cipher suites enabled on the test server: handshake_server_test() stores
 * this table in tls_test_case.ciphersuites, and init_connection() passes it to
 * tls_handshake_set_ciphersuites().
 * NOTE(review): includes NULL, RC4, 3DES and anonymous-DH suites, which must
 * never be offered in production; presumably listed on purpose so the test
 * exercises them — confirm.
 */
uint16_t server_ciphers[] = {
SSL_RSA_WITH_NULL_MD5,
SSL_RSA_WITH_NULL_SHA,
TLS_RSA_WITH_NULL_SHA256,
SSL_RSA_WITH_RC4_128_MD5,
SSL_RSA_WITH_RC4_128_SHA,
SSL_RSA_WITH_3DES_EDE_CBC_SHA,
TLS_RSA_WITH_AES_128_CBC_SHA,
TLS_RSA_WITH_AES_128_CBC_SHA256,
TLS_RSA_WITH_AES_256_CBC_SHA,
TLS_RSA_WITH_AES_256_CBC_SHA256,
SSL_DHE_RSA_WITH_3DES_EDE_CBC_SHA,
TLS_DHE_RSA_WITH_AES_128_CBC_SHA,
TLS_DHE_RSA_WITH_AES_128_CBC_SHA256,
TLS_DHE_RSA_WITH_AES_256_CBC_SHA,
TLS_DHE_RSA_WITH_AES_256_CBC_SHA256,
SSL_DH_anon_WITH_RC4_128_MD5,
SSL_DH_anon_WITH_3DES_EDE_CBC_SHA,
TLS_DH_anon_WITH_AES_128_CBC_SHA,
TLS_DH_anon_WITH_AES_128_CBC_SHA256,
TLS_DH_anon_WITH_AES_256_CBC_SHA,
TLS_DH_anon_WITH_AES_256_CBC_SHA256,
TLS_ECDHE_RSA_WITH_RC4_128_SHA,
TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA,
TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA,
TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256,
TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA,
TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA384,
};
/* Number of entries in server_ciphers, computed from the array itself. */
int num_server_ciphers = sizeof(server_ciphers)/sizeof(server_ciphers[0]);
int handshake_server_test(void);
int handshake_server_test(void)
{
int err;
require_noerr((err=init_server_keys()), fail);
tls_test_case test0;
memset(&test0, 0, sizeof(test0));
test0.port=10443;
test0.num_ciphersuites = num_server_ciphers;
test0.ciphersuites = server_ciphers;
printf("***** Testing case: test0 -- %d\n", test0.port);
err = test_server(&test0);
printf("***** Tested case: test0 -- %d -- err =%d\n", test0.port, err);
fail:
clean_server_keys();
return err;
}