#include <Security/Security.h>
#include "ssl-utils.h"
#include <stdlib.h>
#include <sys/types.h>
#include <sys/socket.h>
#include <netinet/in.h>
#include <arpa/inet.h>
#include <stdio.h>
#include <errno.h>
#include <unistd.h>
#include <string.h>
#include <fcntl.h>
#include <time.h>
#include "tlssocket.h"
#define SERVER "10.0.2.1"
#define PORT 23232
#define BUFLEN 128
#define COUNT 10
#if 0
static void dumppacket(const unsigned char *data, unsigned long len)
{
    /*
     * Hex-dump `len` bytes of `data` to stdout, 16 bytes per row; each row
     * starts with its byte offset, and a final newline is always emitted.
     */
    for (unsigned long off = 0; off < len; off++) {
        unsigned long col = off & 0xf;

        if (col == 0) {
            printf("%04lx :", off);
        }
        printf(" %02x", data[off]);
        if (col == 0xf) {
            printf("\n");
        }
    }
    printf("\n");
}
#endif
/* Time at which sslOutputDot() last printed a dot; 0 means "never". */
static time_t lastTime = (time_t)0;
/* Minimum number of seconds between two progress dots. */
#define TIME_INTERVAL 3
/*
 * Print a progress dot on stdout, rate-limited to at most one dot every
 * TIME_INTERVAL seconds. Called while a non-blocking operation is being
 * retried so the user can see that the program is still alive.
 */
static void sslOutputDot(void)
{
    time_t thisTime = time(NULL);

    if ((thisTime - lastTime) >= TIME_INTERVAL) {
        printf(".");
        fflush(stdout); /* stdout may be line-buffered; show the dot now */
        lastTime = thisTime;
    }
}
/*
 * Report a failed call on stdout in the form "*** <op>: <status>", where
 * <status> is the numeric OSStatus value.
 */
static void printSslErrStr(const char *op, OSStatus err)
{
    long status = (long)err;

    printf("*** %s: %ld\n", op, status);
}
#define MTU 2048
int dtls_client(const char *hostname, int bypass);
int dtls_client(const char *hostname, int bypass)
{
int fd;
int tlsfd;
struct sockaddr_in sa;
printf("Running dtls_client test with hostname=%s, bypass=%d\n", hostname, bypass);
if ((fd=socket(AF_INET, SOCK_DGRAM, 0))==-1) {
perror("socket");
exit(-1);
}
memset((char *) &sa, 0, sizeof(sa));
sa.sin_family = AF_INET;
sa.sin_port = htons(PORT);
if (inet_aton(hostname, &sa.sin_addr)==0) {
fprintf(stderr, "inet_aton() failed\n");
exit(1);
}
if(connect(fd, (struct sockaddr *)&sa, sizeof(sa))==-1)
{
perror("connect");
return errno;
}
fcntl(fd, F_SETFL, O_NONBLOCK);
SSLRecordContextRef c=(intptr_t)fd;
OSStatus ortn;
SSLContextRef ctx = NULL;
SSLClientCertificateState certState;
SSLCipherSuite negCipher;
SSLProtocol negVersion;
ctx = SSLCreateContextWithRecordFuncs(kCFAllocatorDefault, kSSLClientSide, kSSLDatagramType, &TLSSocket_Funcs);
if(!ctx) {
printSslErrStr("SSLCreateContextWithRecordFuncs", -1);
return -1;
}
printf("Attaching filter\n");
ortn = TLSSocket_Attach(fd);
if(ortn) {
printSslErrStr("TLSSocket_Attach", ortn);
return ortn;
}
if(bypass) {
tlsfd = open("/dev/tlsnke", O_RDWR);
if(tlsfd<0) {
perror("opening tlsnke dev");
exit(-1);
}
}
ortn = SSLSetRecordContext(ctx, c);
if(ortn) {
printSslErrStr("SSLSetRecordContext", ortn);
return ortn;
}
ortn = SSLSetMaxDatagramRecordSize(ctx, 600);
if(ortn) {
printSslErrStr("SSLSetMaxDatagramRecordSize", ortn);
return ortn;
}
ortn = SSLSetEnableCertVerify(ctx, false);
if(ortn) {
printSslErrStr("SSLSetEnableCertVerify", ortn);
return ortn;
}
ortn = SSLSetCertificate(ctx, server_chain());
if(ortn) {
printSslErrStr("SSLSetCertificate", ortn);
return ortn;
}
printf("Handshake...\n");
do {
ortn = SSLHandshake(ctx);
if(ortn == errSSLWouldBlock) {
sslOutputDot();
}
} while (ortn == errSSLWouldBlock);
SSLGetClientCertificateState(ctx, &certState);
SSLGetNegotiatedCipher(ctx, &negCipher);
SSLGetNegotiatedProtocolVersion(ctx, &negVersion);
int count;
size_t len;
ssize_t sreadLen, swriteLen;
size_t readLen, writeLen;
char buffer[BUFLEN];
count = 0;
while(count<COUNT) {
int timeout = 10000;
snprintf(buffer, BUFLEN, "Message %d", count);
len = strlen(buffer);
if(bypass) {
swriteLen=write(tlsfd, buffer, len);
if(swriteLen<0) {
perror("write to tlsfd");
break;
}
writeLen=swriteLen;
} else {
ortn=SSLWrite(ctx, buffer, len, &writeLen);
if(ortn) {
printSslErrStr("SSLWrite", ortn);
break;
}
}
printf("Wrote %lu bytes\n", writeLen);
count++;
if(bypass) {
do {
sreadLen=read(tlsfd, buffer, BUFLEN);
} while((sreadLen==-1) && (errno==EAGAIN) && (timeout--));
if((sreadLen==-1) && (errno==EAGAIN)) {
printf("Read timeout...\n");
continue;
}
if(sreadLen<0) {
perror("read from tlsfd");
break;
}
readLen=sreadLen;
}
else {
do {
ortn=SSLRead(ctx, buffer, BUFLEN, &readLen);
} while((ortn==errSSLWouldBlock) && (timeout--));
if(ortn==errSSLWouldBlock) {
printf("SSLRead timeout...\n");
continue;
}
if(ortn) {
printSslErrStr("SSLRead", ortn);
break;
}
}
buffer[readLen]=0;
printf("Received %lu bytes: %s\n", readLen, buffer);
}
SSLClose(ctx);
SSLDisposeContext(ctx);
return ortn;
}