#include <Security/SecureTransportPriv.h>
#include <string.h>
#include <netinet/in.h>
#include <arpa/inet.h>
#include <stdint.h>
#include <stdlib.h>
#include <stdio.h>
#include <assert.h>
#include <net/kext_net.h>
#include "tlssocket.h"
#include "tlsnke.h"
#include <AssertMacros.h>
#include <errno.h>
/*
 * SSLRecordFuncs.read callback: receive one record from the TLS NKE socket.
 *
 * Peeks one byte to detect readiness, sizes the receive buffer with SO_NREAD,
 * then receives the payload with recvmsg().  The record's content type and
 * protocol version are taken from the SCM_TLS_HEADER control message that
 * accompanies the data.
 *
 * On success rec->contents.data is a malloc()ed buffer owned by the caller
 * (released through TLSSocket_Free) and 0 is returned.  Otherwise returns
 * errSSLRecordWouldBlock when no data is available, errSSLRecordInternal on
 * an internal/protocol failure, or the errno value from a failed
 * recv()/recvmsg().  On any failure after allocation the buffer is freed and
 * rec->contents is reset to NULL/0.
 */
static
int TLSSocket_Read(SSLRecordContextRef ref,
                   SSLRecord *rec)
{
    /* The record context carries the socket descriptor, not a real pointer. */
    int socket = (int)(intptr_t)ref;
    int rc;
    int ret;
    ssize_t sz;
    struct sockaddr_in client_addr;
    int avail = 0;
    socklen_t avail_size;
    struct cmsghdr *cmsg;
    tls_record_hdr_t hdr;
    struct msghdr msg;
    struct iovec iov;
    /*
     * The union makes the control buffer correctly aligned for struct
     * cmsghdr; a bare uint8_t array carries no such guarantee.  The extra
     * 1024 bytes of slack match the original sizing.
     */
    union {
        struct cmsghdr align;
        uint8_t buf[CMSG_SPACE(sizeof(*hdr)) + 1024];
    } cbuf;
    char b;

    rc = (int)recv(socket, &b, 1, MSG_PEEK);
    if (rc == -1) {
        int err = errno;    /* save before perror() can clobber it */
        if (err == EAGAIN)
            return errSSLRecordWouldBlock;
        perror("recv");
        return err;
    }
    /*
     * NOTE(review): rc == 0 (peer closed) falls through; SO_NREAD is then
     * expected to report 0, which is surfaced as errSSLRecordWouldBlock.
     * Confirm callers do not spin on that.
     */

    avail_size = sizeof(avail);
    rc = getsockopt(socket, SOL_SOCKET, SO_NREAD, &avail, &avail_size);
    check_noerr(rc);
    check(avail_size == sizeof(avail));
    if (rc || (avail_size != sizeof(avail)))
        return errSSLRecordInternal;
    if (avail <= 0)
        return errSSLRecordWouldBlock;

    rec->contents.data = malloc((size_t)avail);
    if (rec->contents.data == NULL) {
        rec->contents.length = 0;
        return errSSLRecordInternal;
    }
    rec->contents.length = (size_t)avail;

    memset(&msg, 0, sizeof(msg));
    iov.iov_base = rec->contents.data;
    iov.iov_len = rec->contents.length;
    msg.msg_name = &client_addr;
    msg.msg_namelen = sizeof(client_addr);
    msg.msg_iov = &iov;
    msg.msg_iovlen = 1;
    msg.msg_control = cbuf.buf;
    msg.msg_controllen = sizeof(cbuf.buf);

    sz = recvmsg(socket, &msg, 0);
    if (sz < 0) {
        int err = errno;
        perror("recvmsg");
        ret = (err == EAGAIN) ? errSSLRecordWouldBlock : err;
        goto fail;
    }
    check(sz == avail);
    rec->contents.length = (size_t)sz;

    cmsg = CMSG_FIRSTHDR(&msg);
    check(cmsg);
    if (cmsg == NULL) {
        /* Without the header the content type is unknown: the record is unusable. */
        ret = errSSLRecordInternal;
        goto fail;
    }
    check(cmsg->cmsg_type == SCM_TLS_HEADER);
    check(cmsg->cmsg_level == SOL_SOCKET);
    check(cmsg->cmsg_len == CMSG_LEN(sizeof(*hdr)));
    /* Enforce the header shape before dereferencing it, not just log it. */
    if (cmsg->cmsg_type != SCM_TLS_HEADER
        || cmsg->cmsg_level != SOL_SOCKET
        || cmsg->cmsg_len < CMSG_LEN(sizeof(*hdr))) {
        ret = errSSLRecordInternal;
        goto fail;
    }

    hdr = (tls_record_hdr_t)CMSG_DATA(cmsg);
    rec->contentType = hdr->content_type;
    rec->protocolVersion = hdr->protocol_version;

    if (rec->contentType == SSL_RecordTypeChangeCipher) {
        printf("%s: Received ChangeCipherSpec message\n", __FUNCTION__);
    }
    return 0;

fail:
    free(rec->contents.data);
    rec->contents.data = NULL;
    rec->contents.length = 0;
    return ret;
}
/*
 * SSLRecordFuncs.free callback: release the payload buffer of a record that
 * was obtained from TLSSocket_Read (which malloc()s it).  The record context
 * is not needed.  Always reports success.
 */
static
int TLSSocket_Free(SSLRecordContextRef ref,
                   SSLRecord rec)
{
    void *payload = rec.contents.data;

    (void)ref;
    free(payload);   /* free(NULL) is a harmless no-op */
    return 0;
}
/*
 * SSLRecordFuncs.write callback: send one record through the TLS NKE socket.
 *
 * The payload is passed as the message data.  The record's content type and
 * protocol version travel in an SCM_TLS_HEADER control message.
 *
 * Returns 0 when the whole record was accepted.  Returns -1 if sendmsg()
 * fails; errno is preserved across the perror() diagnostic.  Returns
 * errSSLRecordInternal if fewer bytes than the record length were accepted,
 * because silently dropping the tail would corrupt the record stream.
 */
static
int TLSSocket_Write(SSLRecordContextRef ref,
                    SSLRecord rec)
{
    /* The record context carries the socket descriptor, not a real pointer. */
    int socket = (int)(intptr_t)ref;
    ssize_t sz;
    struct msghdr msg;
    struct iovec iov;
    tls_record_hdr_t hdr;
    struct cmsghdr *cmsg;
    /* The union guarantees alignment suitable for struct cmsghdr. */
    union {
        struct cmsghdr align;
        uint8_t buf[CMSG_SPACE(sizeof(*hdr))];
    } cbuf;

    if (rec.contentType == SSL_RecordTypeChangeCipher) {
        printf("%s: Sending ChangeCipherSpec message\n", __FUNCTION__);
    }

    /* Zero everything so no uninitialized padding/flags reach the kernel. */
    memset(&msg, 0, sizeof(msg));
    memset(&cbuf, 0, sizeof(cbuf));

    iov.iov_base = rec.contents.data;
    iov.iov_len = rec.contents.length;
    msg.msg_name = NULL;
    msg.msg_namelen = 0;
    msg.msg_iov = &iov;
    msg.msg_iovlen = 1;
    msg.msg_control = cbuf.buf;
    msg.msg_controllen = sizeof(cbuf.buf);

    /* msg_controllen >= sizeof(struct cmsghdr), so this is non-NULL. */
    cmsg = CMSG_FIRSTHDR(&msg);
    cmsg->cmsg_level = SOL_SOCKET;
    cmsg->cmsg_type = SCM_TLS_HEADER;
    cmsg->cmsg_len = CMSG_LEN(sizeof(*hdr));
    hdr = (tls_record_hdr_t)CMSG_DATA(cmsg);
    hdr->content_type = rec.contentType;
    hdr->protocol_version = rec.protocolVersion;

    sz = sendmsg(socket, &msg, 0);
    if (sz < 0) {
        int err = errno;
        perror("sendmsg");
        errno = err;
        return -1;
    }
    check((size_t)sz == rec.contents.length);
    if ((size_t)sz < rec.contents.length)
        return errSSLRecordInternal;    /* partial record: treat as failure */
    return 0;
}
/*
 * SSLRecordFuncs.initPendingCiphers callback: hand the negotiated cipher and
 * key material to the TLS NKE via the SO_TLS_INIT_CIPHER socket option.
 *
 * Option payload layout (as built below):
 *   byte 0    : selectedCipher high byte
 *   byte 1    : selectedCipher low byte
 *   byte 2    : server flag (0/1)
 *   bytes 3.. : key.data (key.length bytes)
 *
 * Returns the setsockopt() result (0 on success, -1 with errno set on
 * failure), or errSSLRecordInternal if the temporary buffer cannot be
 * allocated.
 */
static
int TLSSocket_InitPendingCiphers(SSLRecordContextRef ref,
                                 uint16_t selectedCipher,
                                 bool server,
                                 SSLBuffer key)
{
    /* The record context carries the socket descriptor, not a real pointer. */
    int socket = (int)(intptr_t)ref;
    int rc;
    size_t buf_len = (size_t)key.length + 3;
    uint8_t *buf;
    volatile uint8_t *scrub;
    size_t i;

    buf = malloc(buf_len);
    if (buf == NULL)
        return errSSLRecordInternal;

    buf[0] = (uint8_t)(selectedCipher >> 8);
    buf[1] = (uint8_t)(selectedCipher & 0xff);
    buf[2] = server;
    if (key.length)
        memcpy(buf + 3, key.data, key.length);

    /* %zu matches size_t; the old %ld was a format/type mismatch. */
    printf("%s: cipher=%04x, keylen=%zu\n", __FUNCTION__, selectedCipher, (size_t)key.length);
    rc = setsockopt(socket, SOL_SOCKET, SO_TLS_INIT_CIPHER, buf, (socklen_t)buf_len);
    printf("%s: rc=%d\n", __FUNCTION__, rc);

    /*
     * Scrub the key material before returning the memory to the allocator.
     * Writes go through a volatile pointer so they cannot be elided.
     */
    scrub = buf;
    for (i = 0; i < buf_len; i++)
        scrub[i] = 0;
    free(buf);
    return rc;
}
/*
 * SSLRecordFuncs.advanceWriteCipher callback: issue the option-less
 * SO_TLS_ADVANCE_WRITE_CIPHER request on the socket.  Presumably this makes
 * the pending write cipher current; the semantics live in the NKE.
 * Logs and returns the setsockopt() result.
 */
static
int TLSSocket_AdvanceWriteCipher(SSLRecordContextRef ref)
{
    const int fd = (int)ref;
    const int result = setsockopt(fd, SOL_SOCKET, SO_TLS_ADVANCE_WRITE_CIPHER, NULL, 0);

    printf("%s: rc=%d\n", __FUNCTION__, result);
    return result;
}
/*
 * SSLRecordFuncs.rollbackWriteCipher callback: issue the option-less
 * SO_TLS_ROLLBACK_WRITE_CIPHER request on the socket (cipher-state handling
 * is done by the NKE).  Logs and returns the setsockopt() result.
 */
static
int TLSSocket_RollbackWriteCipher(SSLRecordContextRef ref)
{
    const int fd = (int)ref;
    const int result = setsockopt(fd, SOL_SOCKET, SO_TLS_ROLLBACK_WRITE_CIPHER, NULL, 0);

    printf("%s: rc=%d\n", __FUNCTION__, result);
    return result;
}
/*
 * SSLRecordFuncs.advanceReadCipher callback: issue the option-less
 * SO_TLS_ADVANCE_READ_CIPHER request on the socket (cipher-state handling is
 * done by the NKE).  Logs and returns the setsockopt() result.
 */
static
int TLSSocket_AdvanceReadCipher(SSLRecordContextRef ref)
{
    const int fd = (int)ref;
    const int result = setsockopt(fd, SOL_SOCKET, SO_TLS_ADVANCE_READ_CIPHER, NULL, 0);

    printf("%s: rc=%d\n", __FUNCTION__, result);
    return result;
}
/*
 * SSLRecordFuncs.setProtocolVersion callback: pass the protocol version value
 * to the NKE through the SO_TLS_PROTOCOL_VERSION socket option.  The raw
 * SSLProtocolVersion object is copied, so the option size is
 * sizeof(SSLProtocolVersion).  Logs and returns the setsockopt() result.
 */
static
int TLSSocket_SetProtocolVersion(SSLRecordContextRef ref,
                                 SSLProtocolVersion protocolVersion)
{
    const int fd = (int)ref;
    const socklen_t optlen = sizeof(protocolVersion);
    const int result = setsockopt(fd, SOL_SOCKET, SO_TLS_PROTOCOL_VERSION,
                                  &protocolVersion, optlen);

    printf("%s: rc=%d\n", __FUNCTION__, result);
    return result;
}
/*
 * SSLRecordFuncs.serviceWriteQueue callback: issue the option-less
 * SO_TLS_SERVICE_WRITE_QUEUE request on the socket and return the
 * setsockopt() result unchanged (no logging).
 */
static
int TLSSocket_ServiceWriteQueue(SSLRecordContextRef ref)
{
    return setsockopt((int)ref, SOL_SOCKET, SO_TLS_SERVICE_WRITE_QUEUE, NULL, 0);
}
/*
 * Record-layer callback table that routes SecureTransport record operations to
 * the TLS NKE socket.  The record context passed to every callback is the
 * socket descriptor.
 */
const struct SSLRecordFuncs TLSSocket_Funcs = {
    /* record I/O */
    .read                = TLSSocket_Read,
    .write               = TLSSocket_Write,
    .free                = TLSSocket_Free,
    .serviceWriteQueue   = TLSSocket_ServiceWriteQueue,
    /* cipher / protocol state */
    .initPendingCiphers  = TLSSocket_InitPendingCiphers,
    .advanceWriteCipher  = TLSSocket_AdvanceWriteCipher,
    .rollbackWriteCipher = TLSSocket_RollbackWriteCipher,
    .advanceReadCipher   = TLSSocket_AdvanceReadCipher,
    .setProtocolVersion  = TLSSocket_SetProtocolVersion,
};
/*
 * Attach the TLS NKE (handle TLS_HANDLE_IP4) to `socket` via SO_NKE, then
 * query the NKE-assigned handle with SO_TLS_HANDLE.
 *
 * Returns the handle on success.  Returns the non-zero setsockopt()/getsockopt()
 * result on failure.  Returns -1 with errno = EINVAL if the kernel reports an
 * unexpected option size.
 *
 * NOTE(review): a negative return presumably signals failure to callers;
 * confirm that valid handles are never negative.
 */
int TLSSocket_Attach(int socket)
{
    struct so_nke so_tlsnke;
    int rc;
    int handle = -1;
    socklen_t len;

    memset(&so_tlsnke, 0, sizeof(so_tlsnke));
    so_tlsnke.nke_handle = TLS_HANDLE_IP4;
    rc = setsockopt(socket, SOL_SOCKET, SO_NKE, &so_tlsnke, sizeof(so_tlsnke));
    if (rc)
        return rc;

    len = sizeof(handle);
    rc = getsockopt(socket, SOL_SOCKET, SO_TLS_HANDLE, &handle, &len);
    if (rc)
        return rc;

    assert(len == sizeof(handle));
    /* assert() vanishes under NDEBUG; never return an unvalidated handle. */
    if (len != sizeof(handle)) {
        errno = EINVAL;
        return -1;
    }
    return handle;
}