#include <stdbool.h>
#include <pthread.h>
#include <fcntl.h>
#include <sys/mman.h>
#include <unistd.h>
#include <CoreFoundation/CoreFoundation.h>
#include <AssertMacros.h>
#include <Security/SecureTransportPriv.h>
#include <Security/SecureTransport.h>
#include <Security/SecPolicy.h>
#include <Security/SecTrust.h>
#include <Security/SecIdentity.h>
#include <Security/SecIdentityPriv.h>
#include <Security/SecCertificatePriv.h>
#include <Security/SecKeyPriv.h>
#include <Security/SecItem.h>
#include <Security/SecRandom.h>
#include <string.h>
#include <sys/types.h>
#include <sys/socket.h>
#include <errno.h>
#include <stdlib.h>
#include <mach/mach_time.h>
#if TARGET_OS_IPHONE
#include <Security/SecRSAKey.h>
#endif
#include "ssl_regressions.h"
#include "ssl-utils.h"
#include <tls_stream_parser.h>
/*
 * Per-endpoint state for one side (client or server) of a test connection.
 * One instance is passed to SecureTransport as the SSLConnectionRef and is
 * therefore what SocketRead/SocketWrite receive back as their first argument.
 */
typedef struct {
/* SecureTransport context driving this endpoint. */
SSLContextRef st;
/* true for the server side; only the server's writes are fed to the parser. */
bool is_server;
/* File descriptor this endpoint reads from and writes to. */
int comm;
/* Certificate array handed to ssl_test_handle_create (NULL for the client in tests()). */
CFArrayRef certs;
/* Number of application-data records seen by process() on the outgoing stream. */
int write_counter;
/* Stream parser that calls process() for the data passed to it in SocketWrite. */
tls_stream_parser_t parser;
/* Bytes per SSLWrite on the server; the client reads twice this amount. */
size_t write_size;
} ssl_test_handle;
#pragma mark -
#pragma mark SecureTransport support
/*
 * Debug helper: flip the "#if 0" to "#if 1" to get a hex dump of every buffer
 * written to the socket. When disabled, hexdump() expands to nothing so the
 * call sites cost nothing.
 */
#if 0
static void hexdump(const char *s, const uint8_t *bytes, size_t len) {
    size_t ix;
    /* %zu is the portable conversion for size_t, and %p requires a void pointer. */
    printf("socket %s(%p, %zu)\n", s, (const void *)bytes, len);
    for (ix = 0; ix < len; ++ix) {
        /* Start a new output line every 16 bytes. */
        if (!(ix % 16))
            printf("\n");
        printf("%02X ", bytes[ix]);
    }
    printf("\n");
}
#else
#define hexdump(string, bytes, len)
#endif
/*
 * SecureTransport write callback.
 *
 * h      - the ssl_test_handle registered with SSLSetConnection.
 * data   - bytes to send.
 * length - in: number of bytes to send; out: number of bytes actually sent.
 *
 * On the server side the outgoing bytes are first handed to the stream
 * parser so that process() can count the records being emitted. The call
 * then loops until everything has been written to the descriptor.
 * Returns errSecSuccess, or -36 if write() fails (other than EAGAIN/EINTR)
 * or writes nothing.
 */
static OSStatus SocketWrite(SSLConnectionRef h, const void *data, size_t *length)
{
    ssl_test_handle *handle = (ssl_test_handle *)h;
    const int fd = handle->comm;
    uint8_t *cursor = (uint8_t *)data;
    size_t remaining = *length;

    if (handle->is_server) {
        tls_buffer outgoing;
        outgoing.data = cursor;
        outgoing.length = remaining;
        tls_stream_parser_parse(handle->parser, outgoing);
    }

    for (;;) {
        ssize_t written;

        hexdump("write", cursor, remaining);
        written = write((int)fd, cursor, remaining);

        /* Transient failure: try the same write again. */
        if ((written < 0) && (errno == EAGAIN || errno == EINTR))
            continue;

        /* Hard error, or nothing could be written. */
        if (written <= 0)
            return -36;

        remaining -= written;
        cursor += written;
        if (remaining == 0)
            break;
    }

    *length = *length - remaining;
    return errSecSuccess;
}
/*
 * SecureTransport read callback.
 *
 * h      - the ssl_test_handle registered with SSLSetConnection.
 * data   - destination buffer.
 * length - in: number of bytes requested; out: number of bytes read.
 *
 * Loops until the full request has been satisfied. Returns errSecSuccess,
 * or -36 if read() fails (other than EAGAIN/EINTR) or returns 0.
 */
static OSStatus SocketRead(SSLConnectionRef h, void *data, size_t *length)
{
    const ssl_test_handle *handle=h;
    const int fd = handle->comm;
    uint8_t *cursor = (uint8_t *)data;
    size_t remaining = *length;

    for (;;) {
        ssize_t got;

        got = read((int)fd, cursor, remaining);

        /* Transient failure: try the same read again. */
        if ((got < 0) && (errno == EAGAIN || errno == EINTR))
            continue;

        /* Hard error, or no data returned. */
        if (got <= 0)
            return -36;

        remaining -= got;
        cursor += got;
        if (remaining == 0)
            break;
    }

    /* Sanity check kept from the original; the loop only exits with nothing left. */
    if (remaining != 0)
        printf("Something went wrong here... len=%d\n", (int)remaining);

    *length = *length - remaining;
    return errSecSuccess;
}
/*
 * Stream-parser callback, invoked with each record the parser hands over.
 * ctx is the ssl_test_handle passed to tls_stream_parser_create. Bumps the
 * handle's write_counter when the first byte of the record is the
 * application-data record type. Always returns 0.
 */
static int process(tls_stream_parser_ctx_t ctx, tls_buffer record)
{
    ssl_test_handle *owner = (ssl_test_handle *)ctx;

    if (record.data[0] != tls_record_type_AppData)
        return 0;

    owner->write_counter++;
    return 0;
}
/*
 * Thread body for one endpoint of the test connection.
 *
 * arg is the endpoint's ssl_test_handle. The thread runs the handshake to
 * completion (tolerating exactly one errSSLServerAuthCompleted break), then:
 *   - server: fills a buffer with random bytes and issues two SSLWrite calls
 *     of write_size bytes each;
 *   - client: reads until 2 * write_size bytes have been received.
 * In every case the thread closes and releases the SSL context and closes
 * the socket before exiting. The final OSStatus is the thread's exit value.
 */
static void *securetransport_ssl_thread(void *arg)
{
    ssl_test_handle * ssl = (ssl_test_handle *)arg;
    SSLContextRef ctx = ssl->st;
    unsigned char ibuf[90000], obuf[45000];
    bool got_server_auth = false;
    OSStatus ortn;

    for (;;) {
        ortn = SSLHandshake(ctx);
        if (ortn == errSSLServerAuthCompleted) {
            /* The break-on-server-auth pause may only happen once. */
            require_string(!got_server_auth, out, "second server auth");
            got_server_auth = true;
            continue;
        }
        if (ortn != errSSLWouldBlock)
            break;
    }
    require_noerr_action_quiet(ortn, out,
        fprintf(stderr, "Fell out of SSLHandshake with error: %d\n", (int)ortn));

    if (!ssl->is_server) {
        size_t len;
        size_t olen;
        unsigned char *p;

        /* Drain both server writes. */
        for (p = ibuf, len = ssl->write_size*2; len; len -= olen, p += olen) {
            require_noerr(ortn = SSLRead(ctx, p, len, &olen), out);
        }
    } else {
        size_t len;
        int burst;

        require_action(errSecSuccess==SecRandomCopyBytes(kSecRandomDefault, ssl->write_size, obuf), out, ortn = -1);

        /* Two identical writes of write_size bytes each. */
        for (burst = 0; burst < 2; burst++) {
            require_noerr(ortn = SSLWrite(ctx, obuf, ssl->write_size, &len), out);
            require_action(len == ssl->write_size, out, ortn = -1);
        }
    }

out:
    SSLClose(ctx);
    CFRelease(ctx);
    close(ssl->comm);
    pthread_exit((void *)(intptr_t)ortn);
    return NULL;
}
/*
 * Frees a handle returned by ssl_test_handle_create, destroying its stream
 * parser first when one is present. NULL is accepted and ignored. The SSL
 * context and the socket are not touched here.
 */
static void
ssl_test_handle_destroy(ssl_test_handle *handle)
{
    if (handle == NULL)
        return;

    if (handle->parser != NULL)
        tls_stream_parser_destroy(handle->parser);

    free(handle);
}
/*
 * Allocates and configures one endpoint of a test connection.
 *
 * server - true to create a server-side context, false for a client.
 * comm   - descriptor the endpoint will read from / write to.
 * certs  - certificate array installed with SSLSetCertificate when server
 *          is true; only stored in the handle otherwise.
 *
 * The context is wired to SocketRead/SocketWrite with the new handle as its
 * connection ref, break-on-server-auth is enabled and certificate
 * verification is disabled. A stream parser bound to process() is attached
 * so outgoing records can be counted.
 *
 * Returns the new handle, or NULL on any failure (nothing is leaked in that
 * case). The caller owns the handle; the context stored in handle->st must
 * be released separately.
 */
static ssl_test_handle *
ssl_test_handle_create(bool server, int comm, CFArrayRef certs)
{
    ssl_test_handle *handle = calloc(1, sizeof(ssl_test_handle));
    SSLContextRef ctx = SSLCreateContext(kCFAllocatorDefault, server?kSSLServerSide:kSSLClientSide, kSSLStreamType);

    require(handle, out);
    require(ctx, out);

    require_noerr(SSLSetIOFuncs(ctx,
        (SSLReadFunc)SocketRead, (SSLWriteFunc)SocketWrite), out);
    require_noerr(SSLSetConnection(ctx, (SSLConnectionRef)handle), out);

    if (server) {
        require_noerr(SSLSetCertificate(ctx, certs), out);
    }

    require_noerr(SSLSetSessionOption(ctx,
        kSSLSessionOptionBreakOnServerAuth, true), out);

    require_noerr(SSLSetEnableCertVerify(ctx, false), out);

    /*
     * SocketWrite hands handle->parser straight to tls_stream_parser_parse,
     * so a handle without a parser must never escape from here.
     */
    handle->parser = tls_stream_parser_create(handle, process);
    require(handle->parser, out);

    handle->is_server = server;
    handle->comm = comm;
    handle->certs = certs;
    handle->st = ctx;
    handle->write_counter = 0;

    return handle;

out:
    /* free(NULL) is a no-op, so no guard is needed for the handle. */
    free(handle);
    if (ctx) CFRelease(ctx);
    return NULL;
}
/*
 * Cipher suites iterated by tests(); each iteration enables exactly one of
 * them on the client.
 */
static SSLCipherSuite ciphers[] = {
TLS_RSA_WITH_AES_128_CBC_SHA,
};
/* Number of entries in ciphers[]. */
static int nciphers = sizeof(ciphers)/sizeof(ciphers[0]);
/*
 * Protocol versions iterated by tests(); each is applied as the client's
 * maximum protocol version. Record splitting is only expected for entries
 * that compare <= kTLSProtocol1.
 */
static SSLProtocol versions[] = {
kSSLProtocol3,
kTLSProtocol1,
kTLSProtocol11,
kTLSProtocol12,
};
/* Number of entries in versions[]. */
static int nversions = sizeof(versions)/sizeof(versions[0]);
/*
 * Test matrix of write sizes and expected record counts, as consumed by
 * tests():
 *   [0] number of bytes passed to each of the server's two SSLWrite calls,
 *   [1] application-data records expected in total when splitting is NOT expected,
 *   [2] application-data records expected in total when splitting IS expected.
 * NOTE(review): the values cluster around multiples of 16384, presumably the
 * maximum record payload size - confirm against the TLS implementation.
 */
static size_t wsizes[][3] = {
{ 1, 2, 2 },
{ 2, 2, 3 },
{ 3, 2, 3 },
{ 4, 2, 3 },
{ 16384, 2, 3 },
{ 16385, 4, 4 },
{ 16386, 4, 5 },
{ 16387, 4, 5 },
{ 16388, 4, 5 },
{ 32768, 4, 5 },
{ 32769, 6, 6 },
{ 32770, 6, 7 },
{ 32771, 6, 7 },
{ 32772, 6, 7 },
};
/* Number of rows in wsizes[]. */
static int nwsizes = sizeof(wsizes)/sizeof(wsizes[0]);
/*
 * Runs the record-splitting test matrix.
 *
 * For every combination of cipher, client max protocol version, write size
 * and split mode (s == 0: option left at its default, s == 1: one-byte
 * record option set to true, s == 2: set to false) a client and a server
 * are connected over a socketpair and run on their own threads. After both
 * threads finish, the number of application-data records the server emitted
 * is compared with the expectation from wsizes[].
 *
 * Ownership: once an endpoint's thread has been started, that thread
 * releases the endpoint's SSL context and closes its socket. Anything that
 * was never handed to a thread is cleaned up at the "out" label here.
 */
static void
tests(void)
{
    pthread_t client_thread, server_thread;
    CFArrayRef server_certs = server_chain();
    ok(server_certs, "got server certs");

    int i,j,k,s;

    for(i=0; i<nciphers; i++)
    for(j=0; j<nversions; j++)
    for(k=0; k<nwsizes; k++)
    for(s=0; s<3; s++)
    {
        /* Everything the cleanup path inspects is initialised before the first jump to "out". */
        ssl_test_handle *server = NULL, *client = NULL;
        bool client_started = false, server_started = false;
        void *client_ret = NULL, *server_ret = NULL;
        int sp[2];

        if (socketpair(AF_UNIX, SOCK_STREAM, 0, sp)) exit(errno);
        fcntl(sp[0], F_SETNOSIGPIPE, 1);
        fcntl(sp[1], F_SETNOSIGPIPE, 1);

        server = ssl_test_handle_create(true , sp[0], server_certs);
        client = ssl_test_handle_create(false, sp[1], NULL);

        /* Validate both handles before touching any of their fields. */
        require(client, out);
        require(server, out);

        server->write_size = wsizes[k][0];
        client->write_size = wsizes[k][0];

        require_noerr(SSLSetProtocolVersionMin(server->st, kSSLProtocol3), out);
        require_noerr(SSLSetProtocolVersionMax(client->st, versions[j]), out);
        require_noerr(SSLSetEnabledCiphers(client->st, &ciphers[i], 1), out);

        if(s) {
            require_noerr(SSLSetSessionOption(server->st, kSSLSessionOptionSendOneByteRecord, (s==1)?true:false), out);
        }

        require_noerr(pthread_create(&client_thread, NULL, securetransport_ssl_thread, client), out);
        client_started = true;
        require_noerr(pthread_create(&server_thread, NULL, securetransport_ssl_thread, server), out);
        server_started = true;

        /* Join through a real void * and convert afterwards, rather than aliasing an intptr_t. */
        pthread_join(client_thread, &client_ret);
        pthread_join(server_thread, &server_ret);

        intptr_t client_err = (intptr_t)client_ret;
        intptr_t server_err = (intptr_t)server_ret;

        ok(!server_err, "Server error = %ld", (long)server_err);
        ok(!client_err, "Client error = %ld", (long)client_err);

        bool expected_split = (i==0) && (s!=2) && (versions[j]<=kTLSProtocol1);
        int expected_count = (int)(expected_split ? wsizes[k][2]: wsizes[k][1]);

        is(server->write_counter, expected_count, "wrong number of data records");

    out:
        /*
         * A started thread has already released its context and closed its
         * socket. For an endpoint that never got a thread, do it here so the
         * early-exit paths do not leak contexts or descriptors.
         */
        if (!server_started) {
            if (server) CFRelease(server->st);
            close(sp[0]);
        }
        if (!client_started) {
            if (client) CFRelease(client->st);
            close(sp[1]);
        } else if (!server_started) {
            /*
             * Only the client thread is running. Its peer socket was closed
             * just above, so it will fail out of the handshake; wait for it
             * before freeing the handle it is still using.
             */
            pthread_join(client_thread, NULL);
        }

        ssl_test_handle_destroy(client);
        ssl_test_handle_destroy(server);
    }

    CFReleaseNull(server_certs);
}
/*
 * Test entry point. Declares the plan - one check for the server
 * certificates plus three checks for each of the
 * cipher x version x write-size x split-mode combinations - and then runs
 * the matrix. argc/argv are unused. Always returns 0.
 */
int ssl_48_split(int argc, char *const *argv)
{
    /* Three split modes per (cipher, version, write size) triple. */
    const int combinations = nciphers*nversions*nwsizes*3;
    /* Each combination reports server status, client status and record count. */
    const int checks_per_combination = 3;

    plan_tests(1 + combinations * checks_per_combination);

    tests();

    return 0;
}