#include "sx.h"
#ifdef HAVE_SSL
#include "ssl.h"
#include <CoreServices/CoreServices.h>
#include <Security/SecKeychain.h>
#include <Security/SecKeychainItem.h>
/*
 * Notify callback attached to the <proceed/> buffer that _sx_ssl_process()
 * queues on the server side (presumably invoked once that buffer has been
 * written). Resets the stream and re-initialises it as a server stream with
 * SX_SSL_WRAPPER added to its flags. `arg` is unused.
 */
static void _sx_ssl_starttls_notify_proceed(sx_t s, void *arg) {
    int wrapped_flags;

    (void) arg;

    _sx_debug(ZONE, "preparing for starttls");

    /* the reset must happen first: the flags are taken from the stream afterwards */
    _sx_reset(s);

    wrapped_flags = s->flags | SX_SSL_WRAPPER;
    sx_server_init(s, wrapped_flags);
}
/*
 * Process hook for STARTTLS negotiation elements in the uri_TLS namespace.
 *
 * Returns 1 without touching the nad when it is not ours: the server is not
 * offering STARTTLS, or the element is outside uri_TLS. Otherwise the nad is
 * freed here and 0 is returned.
 *
 *   server: <starttls/> queues <proceed/>. The stream is restarted with
 *           SX_SSL_WRAPPER from that buffer's notify callback.
 *   client: <proceed/> restarts the client stream with SX_SSL_WRAPPER.
 *           <failure/> raises SX_ERR_STARTTLS_FAILURE and drops any pemfile
 *           stashed by sx_ssl_client_starttls().
 * Any other element in the namespace is logged and dropped.
 */
static int _sx_ssl_process(sx_t s, sx_plugin_t p, nad_t nad) {
    int flags;
    char *ns = NULL, *to = NULL, *from = NULL, *version = NULL;
    sx_error_t sxe;

    /* servers only care if they are actually offering starttls */
    if(s->type == type_SERVER && !(s->flags & SX_SSL_STARTTLS_OFFER))
        return 1;

    /* only elements in the TLS namespace are ours */
    if(NAD_ENS(nad, 0) < 0 || NAD_NURI_L(nad, NAD_ENS(nad, 0)) != strlen(uri_TLS) || strncmp(NAD_NURI(nad, NAD_ENS(nad, 0)), uri_TLS, strlen(uri_TLS)) != 0)
        return 1;

    /*
     * Element names are length-delimited (not NUL-terminated), so both the
     * length AND the bytes must match: "&&", never "||".
     */
    if(s->type == type_SERVER) {
        if(NAD_ENAME_L(nad, 0) == 8 && strncmp(NAD_ENAME(nad, 0), "starttls", 8) == 0) {
            nad_free(nad);

            if(s->ssf > 0) {
                _sx_debug(ZONE, "starttls requested on already encrypted channel, dropping packet");
                return 0;
            }

            _sx_debug(ZONE, "starttls requested, setting up");

            /* 16 bytes of "<proceed xmlns='" + 3 bytes of "'/>" around uri_TLS */
            jqueue_push(s->wbufq, _sx_buffer_new("<proceed xmlns='" uri_TLS "'/>", strlen(uri_TLS) + 19, _sx_ssl_starttls_notify_proceed, NULL), 0);
            s->want_write = 1;
            return 0;
        }
    }
    else if(s->type == type_CLIENT) {
        if(NAD_ENAME_L(nad, 0) == 7 && strncmp(NAD_ENAME(nad, 0), "proceed", 7) == 0) {
            nad_free(nad);

            /* copy the stream parameters: _sx_reset() may release the originals */
            flags = s->flags;
            if(s->ns != NULL) ns = strdup(s->ns);
            if(s->req_to != NULL) to = strdup(s->req_to);
            if(s->req_from != NULL) from = strdup(s->req_from);
            if(s->req_version != NULL) version = strdup(s->req_version);

            _sx_reset(s);

            _sx_debug(ZONE, "server ready for ssl, starting");

            sx_client_init(s, flags | SX_SSL_WRAPPER, ns, to, from, version);

            free(ns);
            free(to);
            free(from);
            free(version);

            return 0;
        }

        if(NAD_ENAME_L(nad, 0) == 7 && strncmp(NAD_ENAME(nad, 0), "failure", 7) == 0) {
            nad_free(nad);

            /* drop the pemfile stashed by sx_ssl_client_starttls() and clear the
             * slot, so later hooks don't see a dangling pointer */
            if(s->plugin_data[p->index] != NULL) {
                free(s->plugin_data[p->index]);
                s->plugin_data[p->index] = NULL;
            }

            _sx_debug(ZONE, "server can't handle ssl, business as usual");

            _sx_gen_error(sxe, SX_ERR_STARTTLS_FAILURE, "STARTTLS failure", "Server was unable to prepare for the TLS handshake");
            _sx_event(s, event_ERROR, (void *) &sxe);

            return 0;
        }
    }

    _sx_debug(ZONE, "unknown starttls namespace element '%.*s', dropping packet", NAD_ENAME_L(nad, 0), NAD_ENAME(nad, 0));
    nad_free(nad);
    return 0;
}
/*
 * Features hook: adds a <starttls/> element in the uri_TLS namespace to the
 * stream features nad. It is added only while s->state has not passed
 * state_STREAM, the channel is not yet encrypted (ssf <= 0), and the stream
 * has SX_SSL_STARTTLS_OFFER set. SX_SSL_STARTTLS_REQUIRE also appends
 * <required/> at depth 2 (presumably as a child of <starttls/>).
 */
static void _sx_ssl_features(sx_t s, sx_plugin_t p, nad_t nad) {
    int tls_ns;
    int can_offer = s->state <= state_STREAM
                 && s->ssf <= 0
                 && (s->flags & SX_SSL_STARTTLS_OFFER);

    if(!can_offer)
        return;

    _sx_debug(ZONE, "offering starttls");

    tls_ns = nad_add_namespace(nad, uri_TLS, NULL);
    nad_append_elem(nad, tls_ns, "starttls", 1);

    if((s->flags & SX_SSL_STARTTLS_REQUIRE) != 0)
        nad_append_elem(nad, tls_ns, "required", 2);
}
/*
 * Advances the TLS handshake as far as currently buffered data allows.
 * Uses SSL_connect() on client streams and SSL_accept() otherwise.
 *
 * Returns:
 *   1  handshake complete. s->ssf is set to the negotiated cipher bits on
 *      the call that completes it.
 *   0  the previous attempt wanted input and the read BIO is still empty.
 *  -1  fatal handshake error: last_state becomes SX_SSL_STATE_ERROR, an
 *      SX_ERR_SSL event and a stream error are raised, and the stream is
 *      closed.
 */
static int _sx_ssl_handshake(sx_t s, _sx_ssl_conn_t sc) {
    int rc, ssl_err;
    char *reason;
    sx_error_t sxe;

    for(;;) {
        if(SSL_is_init_finished(sc->ssl))
            return 1;

        _sx_debug(ZONE, "secure channel not established, handshake in progress");

        /* the last attempt needed more input and nothing new has arrived */
        if(sc->last_state == SX_SSL_STATE_WANT_READ && BIO_pending(sc->rbio) == 0)
            return 0;

        rc = (s->type == type_CLIENT) ? SSL_connect(sc->ssl) : SSL_accept(sc->ssl);

        if(rc == 1) {
            _sx_debug(ZONE, "secure channel established");
            sc->last_state = SX_SSL_STATE_NONE;
            s->ssf = SSL_get_cipher_bits(sc->ssl, NULL);
            _sx_debug(ZONE, "using cipher %s (%d bits)", SSL_get_cipher_name(sc->ssl), s->ssf);
            return 1;
        }

        /* any other positive value: just go round again */
        if(rc > 1)
            continue;

        ssl_err = SSL_get_error(sc->ssl, rc);
        switch(ssl_err) {
        case SSL_ERROR_WANT_READ:
            sc->last_state = SX_SSL_STATE_WANT_READ;
            break;

        case SSL_ERROR_WANT_WRITE:
            sc->last_state = SX_SSL_STATE_WANT_WRITE;
            break;

        default:
            sc->last_state = SX_SSL_STATE_ERROR;

            reason = ERR_error_string(ERR_get_error(), NULL);
            _sx_debug(ZONE, "openssl error: %s", reason);

            _sx_gen_error(sxe, SX_ERR_SSL, "SSL handshake error", reason);
            _sx_event(s, event_ERROR, (void *) &sxe);
            _sx_error(s, stream_err_INTERNAL_SERVER_ERROR, reason);

            _sx_close(s);
            return -1;
        }
    }
}
/*
 * Write-side I/O hook.
 *
 * 1. Any outgoing plaintext in buf is copied onto sc->wq together with its
 *    notify callback, and buf is emptied.
 * 2. The handshake is advanced. Once it is complete, one queued buffer is
 *    passed to SSL_write() per call.
 * 3. Any ciphertext waiting in the write BIO (handshake or application data)
 *    is copied into buf for the transport. If a queued buffer was written
 *    successfully, its notify moves to buf and the queued copy is freed.
 *
 * Returns 1 normally, or -2 on a fatal SSL error or a clean remote close.
 */
static int _sx_ssl_wio(sx_t s, sx_plugin_t p, sx_buf_t buf) {
    _sx_ssl_conn_t sc = (_sx_ssl_conn_t) s->plugin_data[p->index];
    int est, ret, err;
    sx_buf_t wbuf;
    char *errstring;
    sx_error_t sxe;

    if(sc->last_state == SX_SSL_STATE_ERROR)
        return -2;

    _sx_debug(ZONE, "in _sx_ssl_wio");

    if(buf->len > 0) {
        _sx_debug(ZONE, "queueing buffer for write");

        jqueue_push(sc->wq, _sx_buffer_new(buf->data, buf->len, buf->notify, buf->notify_arg), 0);
        _sx_buffer_clear(buf);

        /* the notify now travels with the queued copy */
        buf->notify = NULL;
        buf->notify_arg = NULL;
    }

    est = _sx_ssl_handshake(s, sc);
    if(est < 0)
        return -2;

    wbuf = NULL;
    if(est > 0 && jqueue_size(sc->wq) > 0) {
        _sx_debug(ZONE, "preparing queued buffer for write");

        wbuf = jqueue_pull(sc->wq);

        ret = SSL_write(sc->ssl, wbuf->data, wbuf->len);
        if(ret <= 0) {
            _sx_debug(ZONE, "write failed, requeuing buffer");

            /* requeue one priority above the current front (intended to make it
             * the next one pulled; NOTE(review): confirm jqueue ordering) */
            jqueue_push(sc->wq, wbuf, (sc->wq->front != NULL) ? sc->wq->front->priority + 1 : 0);

            /* the queue owns wbuf again, so it must not be freed (or its notify
             * stolen) by the write-BIO flush below */
            wbuf = NULL;

            err = SSL_get_error(sc->ssl, ret);

            if(err == SSL_ERROR_ZERO_RETURN) {
                /* peer closed the TLS session cleanly: close once, and don't
                 * report an SSL error or send a stream error */
                sc->last_state = SX_SSL_STATE_ERROR;
                _sx_close(s);
                return -2;
            }

            if(err == SSL_ERROR_WANT_READ) {
                _sx_debug(ZONE, "renegotiation started");
                sc->last_state = SX_SSL_STATE_WANT_READ;
            }
            else {
                sc->last_state = SX_SSL_STATE_ERROR;

                errstring = ERR_error_string(ERR_get_error(), NULL);
                _sx_debug(ZONE, "openssl error: %s", errstring);

                _sx_gen_error(sxe, SX_ERR_SSL, "SSL handshake error", errstring);
                _sx_event(s, event_ERROR, (void *) &sxe);
                _sx_error(s, stream_err_INTERNAL_SERVER_ERROR, errstring);

                _sx_close(s);
                return -2;
            }
        }
    }

    /* hand any ciphertext (handshake or data) back to the transport */
    if(BIO_pending(sc->wbio) > 0) {
        int bytes_pending = BIO_pending(sc->wbio);

        assert(buf->len == 0);
        _sx_buffer_alloc_margin(buf, 0, bytes_pending);
        BIO_read(sc->wbio, buf->data, bytes_pending);
        buf->len += bytes_pending;

        /* only a successfully written buffer reaches here non-NULL */
        if(wbuf != NULL) {
            buf->notify = wbuf->notify;
            buf->notify_arg = wbuf->notify_arg;
            _sx_buffer_free(wbuf);
        }

        _sx_debug(ZONE, "prepared %d ssl bytes for write", buf->len);
    }

    if(sc->last_state == SX_SSL_STATE_WANT_READ || sc->last_state == SX_SSL_STATE_NONE)
        s->want_read = 1;

    return 1;
}
/*
 * Read-side I/O hook.
 *
 * Ciphertext in buf is moved into the SSL read BIO and buf is emptied. The
 * handshake is advanced. Once it is complete, SSL_read() is called while
 * SSL has buffered plaintext or the read BIO still holds ciphertext, and the
 * decrypted bytes are appended to buf.
 *
 * Returns 1 if buf now holds plaintext, 0 if none is available yet, and -1 on
 * a fatal error or a closed connection.
 */
static int _sx_ssl_rio(sx_t s, sx_plugin_t p, sx_buf_t buf) {
    _sx_ssl_conn_t sc = (_sx_ssl_conn_t) s->plugin_data[p->index];
    int est, ret, err, pending;
    char *errstring;
    sx_error_t sxe;

    if(sc->last_state == SX_SSL_STATE_ERROR)
        return -1;

    _sx_debug(ZONE, "in _sx_ssl_rio");

    if(buf->len > 0) {
        _sx_debug(ZONE, "loading %d bytes into ssl read buffer", buf->len);

        BIO_write(sc->rbio, buf->data, buf->len);
        _sx_buffer_clear(buf);
    }

    est = _sx_ssl_handshake(s, sc);
    if(est < 0)
        return -1;

    if(est > 0) {
        /* Plaintext never exceeds the ciphertext it came from, so sizing the
         * read by BIO_pending() leaves enough room. Any excess stays in
         * SSL_pending() and is picked up on the next iteration. */
        while((pending = SSL_pending(sc->ssl)) > 0 || (pending = BIO_pending(sc->rbio)) > 0) {
            _sx_buffer_alloc_margin(buf, 0, pending);

            ret = SSL_read(sc->ssl, &(buf->data[buf->len]), pending);

            if(ret == 0) {
                /* SSL_get_shutdown() returns a bitmask: test the bit, since
                 * SSL_SENT_SHUTDOWN may also be set */
                if(SSL_get_shutdown(sc->ssl) & SSL_RECEIVED_SHUTDOWN) {
                    _sx_close(s);
                    break;
                }

                err = SSL_get_error(sc->ssl, ret);

                _sx_buffer_clear(buf);

                if(err == SSL_ERROR_ZERO_RETURN)
                    _sx_close(s);

                return -1;
            }

            if(ret < 0) {
                err = SSL_get_error(sc->ssl, ret);

                /* partial record: wait for more ciphertext */
                if(err == SSL_ERROR_WANT_READ) {
                    sc->last_state = SX_SSL_STATE_WANT_READ;
                    break;
                }

                _sx_buffer_clear(buf);

                sc->last_state = SX_SSL_STATE_ERROR;

                errstring = ERR_error_string(ERR_get_error(), NULL);
                _sx_debug(ZONE, "openssl error: %s", errstring);

                _sx_gen_error(sxe, SX_ERR_SSL, "SSL handshake error", errstring);
                _sx_event(s, event_ERROR, (void *) &sxe);
                _sx_error(s, stream_err_INTERNAL_SERVER_ERROR, errstring);

                _sx_close(s);
                return -1;
            }

            buf->len += ret;
        }
    }

    /* handshake replies or queued plaintext may now be ready to go out */
    if(BIO_pending(sc->wbio) > 0 || (est > 0 && jqueue_size(sc->wq) > 0))
        s->want_write = 1;

    if(sc->last_state == SX_SSL_STATE_WANT_READ || sc->last_state == SX_SSL_STATE_NONE)
        s->want_read = 1;

    return (buf->len == 0) ? 0 : 1;
}
/*
 * Client init hook. For streams flagged SX_SSL_WRAPPER that are not yet
 * encrypted, it creates an SSL object in connect state on two memory BIOs,
 * stores it in the plugin slot and chains this plugin into the stream's I/O.
 *
 * sx_ssl_client_starttls() may have stashed an alternate pemfile string in
 * the plugin slot. If so, its certificate and private key are loaded into
 * this connection. The string is always freed here.
 *
 * On any failure the plugin is not chained and the slot is left NULL.
 */
static void _sx_ssl_client(sx_t s, sx_plugin_t p) {
    _sx_ssl_conn_t sc;
    char *pemfile;

    if(!(s->flags & SX_SSL_WRAPPER) || s->ssf > 0)
        return;

    _sx_debug(ZONE, "preparing for ssl connect for %d", s->tag);

    /* take ownership of the optional pemfile first, so every exit path frees it */
    pemfile = s->plugin_data[p->index];
    s->plugin_data[p->index] = NULL;

    sc = (_sx_ssl_conn_t) calloc(1, sizeof(struct _sx_ssl_conn_st));
    if(sc == NULL) {
        _sx_debug(ZONE, "couldn't allocate ssl connection state");
        free(pemfile);
        return;
    }

    sc->rbio = BIO_new(BIO_s_mem());
    sc->wbio = BIO_new(BIO_s_mem());
    sc->ssl = SSL_new((SSL_CTX *) p->private);
    if(sc->rbio == NULL || sc->wbio == NULL || sc->ssl == NULL) {
        _sx_debug(ZONE, "couldn't create ssl connection objects");
        /* nothing is attached yet, so each object is released separately */
        if(sc->ssl != NULL) SSL_free(sc->ssl);
        if(sc->rbio != NULL) BIO_free(sc->rbio);
        if(sc->wbio != NULL) BIO_free(sc->wbio);
        goto fail;
    }

    /* from here on the SSL object owns both BIOs; SSL_free() releases them */
    SSL_set_bio(sc->ssl, sc->rbio, sc->wbio);
    SSL_set_connect_state(sc->ssl);

    if(pemfile != NULL) {
        if(SSL_use_certificate_file(sc->ssl, pemfile, SSL_FILETYPE_PEM) != 1) {
            _sx_debug(ZONE, "couldn't load alternate certificate from %s", pemfile);
            goto fail_ssl;
        }

        if(SSL_use_PrivateKey_file(sc->ssl, pemfile, SSL_FILETYPE_PEM) != 1) {
            _sx_debug(ZONE, "couldn't load alternate private key from %s", pemfile);
            goto fail_ssl;
        }

        if(SSL_check_private_key(sc->ssl) != 1) {
            _sx_debug(ZONE, "private key does not match certificate public key");
            goto fail_ssl;
        }

        _sx_debug(ZONE, "loaded alternate pemfile %s", pemfile);
    }

    free(pemfile);

    sc->wq = jqueue_new();

    s->plugin_data[p->index] = (void *) sc;
    _sx_chain_io_plugin(s, p);
    return;

fail_ssl:
    SSL_free(sc->ssl);
fail:
    free(sc);
    free(pemfile);
}
/*
 * Server init hook. For streams flagged SX_SSL_WRAPPER that are not yet
 * encrypted, it creates an SSL object in accept state on two memory BIOs,
 * stores it in the plugin slot and chains this plugin into the stream's I/O.
 * If allocation fails, the plugin is not chained.
 */
static void _sx_ssl_server(sx_t s, sx_plugin_t p) {
    _sx_ssl_conn_t sc;

    if(!(s->flags & SX_SSL_WRAPPER) || s->ssf > 0)
        return;

    _sx_debug(ZONE, "preparing for ssl accept for %d", s->tag);

    sc = (_sx_ssl_conn_t) calloc(1, sizeof(struct _sx_ssl_conn_st));
    if(sc == NULL) {
        _sx_debug(ZONE, "couldn't allocate ssl connection state");
        return;
    }

    sc->rbio = BIO_new(BIO_s_mem());
    sc->wbio = BIO_new(BIO_s_mem());
    sc->ssl = SSL_new((SSL_CTX *) p->private);
    if(sc->rbio == NULL || sc->wbio == NULL || sc->ssl == NULL) {
        _sx_debug(ZONE, "couldn't create ssl connection objects");
        /* nothing is attached yet, so each object is released separately */
        if(sc->ssl != NULL) SSL_free(sc->ssl);
        if(sc->rbio != NULL) BIO_free(sc->rbio);
        if(sc->wbio != NULL) BIO_free(sc->wbio);
        free(sc);
        return;
    }

    /* the SSL object now owns both BIOs; SSL_free() releases them */
    SSL_set_bio(sc->ssl, sc->rbio, sc->wbio);
    SSL_set_accept_state(sc->ssl);

    sc->wq = jqueue_new();

    s->plugin_data[p->index] = (void *) sc;
    _sx_chain_io_plugin(s, p);
}
/*
 * Free hook: releases the per-connection state in s->plugin_data[p->index]
 * and clears the slot.
 *
 * For type_NONE streams only the slot allocation itself is freed.
 * NOTE(review): a client that called sx_ssl_client_starttls() with a pemfile
 * but never received <proceed/>/<failure/> holds a bare pemfile string in
 * this slot, not a _sx_ssl_conn_t. Confirm that such a stream cannot reach
 * the full teardown below.
 */
static void _sx_ssl_free(sx_t s, sx_plugin_t p) {
    _sx_ssl_conn_t sc = (_sx_ssl_conn_t) s->plugin_data[p->index];
    sx_buf_t buf;

    if(sc == NULL)
        return;

    log_debug(ZONE, "cleaning up conn state");

    /* clear the slot so nothing can reach the state freed below */
    s->plugin_data[p->index] = NULL;

    if(s->type == type_NONE) {
        free(sc);
        return;
    }

    free(sc->external_id);

    /* SSL_free() also frees the read/write BIOs attached with SSL_set_bio() */
    if(sc->ssl != NULL)
        SSL_free(sc->ssl);

    /* drop any plaintext still waiting to be encrypted */
    while((buf = jqueue_pull(sc->wq)) != NULL)
        _sx_buffer_free(buf);
    jqueue_free(sc->wq);

    free(sc);
}
/*
 * Unload hook: frees the SSL_CTX that sx_ssl_init() stored in p->private.
 */
static void _sx_ssl_unload(sx_plugin_t p) {
    SSL_CTX *shared_ctx = (SSL_CTX *) p->private;

    SSL_CTX_free(shared_ctx);
}
#ifdef __APPLE__
/* defined later in this file; declared here because its address is taken below */
int apple_password_callback(char *inBuf, int inSize, int in_rwflag, void *inUserData);

/*
 * Installs apple_password_callback as ctx's default PEM password callback.
 * Its userdata is a label: the basename of pemfile with the last extension
 * stripped (e.g. "/a/b/server.pem" -> "server").
 *
 * Returns the heap-allocated label now referenced by ctx. The caller must keep
 * it alive as long as ctx and free it if ctx is discarded. Returns NULL
 * (nothing installed) if the label cannot be built or is empty.
 */
static char *_sx_ssl_apple_passwd_setup(SSL_CTX *ctx, const char *pemfile) {
    const char *base = strrchr(pemfile, '/');
    char *label, *dot;

    _sx_debug(ZONE, "Adding Apple-custom SSL password callback");

    label = strdup(base != NULL ? base + 1 : pemfile);
    if(label == NULL) {
        _sx_debug(ZONE, "Could not set custom callback for %s", pemfile);
        return NULL;
    }

    dot = strrchr(label, '.');
    if(dot != NULL)
        *dot = '\0';

    if(label[0] == '\0') {
        _sx_debug(ZONE, "Could not set custom callback for %s", pemfile);
        free(label);
        return NULL;
    }

    SSL_CTX_set_default_passwd_cb_userdata(ctx, (void *) label);
    SSL_CTX_set_default_passwd_cb(ctx, &apple_password_callback);
    _sx_debug(ZONE, "Apple-custom SSL password callback enabled for %s", label);

    return label;
}
#endif

/*
 * Plugin initialiser. Reads two char * arguments from args:
 *   pemfile - PEM file holding the certificate and private key (required;
 *             NULL makes init fail)
 *   cachain - optional certificate chain file (may be NULL)
 *
 * Builds an SSL_CTX, loads and cross-checks the certificate and key, then
 * installs the plugin hooks. Returns 0 on success. Returns 1 on failure, or
 * if p->private is already set.
 */
int sx_ssl_init(sx_env_t env, sx_plugin_t p, va_list args) {
    char *pemfile, *cachain;
    char *passwd_label = NULL;  /* Apple only: password-callback userdata referenced by ctx */
    SSL_CTX *ctx;
    int ret;

    _sx_debug(ZONE, "initialising ssl plugin");

    pemfile = va_arg(args, char *);
    if(pemfile == NULL)
        return 1;

    if(p->private != NULL)
        return 1;

    cachain = va_arg(args, char *);

    SSL_library_init();
    SSL_load_error_strings();

    ctx = SSL_CTX_new(SSLv23_method());
    if(ctx == NULL) {
        _sx_debug(ZONE, "ssl context creation failed");
        return 1;
    }

    ret = SSL_CTX_use_certificate_file(ctx, pemfile, SSL_FILETYPE_PEM);
    if(ret != 1) {
        _sx_debug(ZONE, "couldn't load certificate from %s", pemfile);
        SSL_CTX_free(ctx);
        return 1;
    }

#ifdef __APPLE__
    /* must be in place before the (possibly encrypted) private key is read */
    passwd_label = _sx_ssl_apple_passwd_setup(ctx, pemfile);
#endif

    ret = SSL_CTX_use_PrivateKey_file(ctx, pemfile, SSL_FILETYPE_PEM);
    if(ret != 1) {
        _sx_debug(ZONE, "couldn't load private key from %s", pemfile);
        goto fail;
    }

    if(cachain != NULL) {
        /* NOTE(review): SSL_CTX_use_certificate_chain_file() treats the first
         * certificate in the file as this server's own certificate, replacing
         * the one loaded above. If cachain starts with a CA certificate, the
         * key check below will fail. Confirm this is intended, vs.
         * SSL_CTX_load_verify_locations(). */
        ret = SSL_CTX_use_certificate_chain_file(ctx, cachain);
        if(ret != 1)
            _sx_debug(ZONE, "WARNING: couldn't load CA chain");
    }

    ret = SSL_CTX_check_private_key(ctx);
    if(ret != 1) {
        _sx_debug(ZONE, "private key does not match certificate public key");
        goto fail;
    }

    _sx_debug(ZONE, "ssl context initialised; certificate and key loaded from %s", pemfile);

    p->magic = SX_SSL_MAGIC;

    p->private = (void *) ctx;

    p->unload = _sx_ssl_unload;

    p->client = _sx_ssl_client;
    p->server = _sx_ssl_server;
    p->rio = _sx_ssl_rio;
    p->wio = _sx_ssl_wio;
    p->features = _sx_ssl_features;
    p->process = _sx_ssl_process;
    p->free = _sx_ssl_free;

    return 0;

fail:
    SSL_CTX_free(ctx);
    /* OpenSSL never frees password-callback userdata, so release it here */
    free(passwd_label);
    return 1;
}
/*
 * Starts STARTTLS negotiation on an established client stream by queuing
 * <starttls xmlns='uri_TLS'/> and requesting a write.
 *
 * pemfile, if non-NULL, is copied into the plugin slot. _sx_ssl_client()
 * later loads it as this connection's certificate and key and frees it.
 *
 * Returns 0 if the request was queued. Returns 1 if the stream is not a
 * client in state_STREAM, or if it is already encrypted.
 */
int sx_ssl_client_starttls(sx_plugin_t p, sx_t s, char *pemfile) {
    /* compare with NULL: casting a pointer to int truncates it on LP64 */
    assert(p != NULL);
    assert(s != NULL);

    if(s->type != type_CLIENT || s->state != state_STREAM) {
        _sx_debug(ZONE, "wrong conn type or state for client starttls");
        return 1;
    }

    if(s->ssf > 0) {
        _sx_debug(ZONE, "encrypted channel already established");
        return 1;
    }

    _sx_debug(ZONE, "initiating starttls sequence");

    if(pemfile != NULL)
        s->plugin_data[p->index] = (void *) strdup(pemfile);

    /* 17 bytes of "<starttls xmlns='" + 3 bytes of "'/>" around uri_TLS */
    jqueue_push(s->wbufq, _sx_buffer_new("<starttls xmlns='" uri_TLS "'/>", strlen(uri_TLS) + 20, NULL, NULL), 0);
    s->want_write = 1;

    _sx_event(s, event_WANT_WRITE, NULL);

    return 0;
}
#ifdef __APPLE__
/*
 * OpenSSL PEM password callback (pem_password_cb) for macOS.
 *
 * Switches the keychain preference domain to the system domain, then looks up
 * a generic password with service "certificateManager" and account = the
 * label passed as inUserData. The password is copied into inBuf (capacity
 * inSize bytes) and NUL-terminated when there is room.
 *
 * Returns the password length, or 0 on any failure (bad arguments, keychain
 * error, item not found, or a password longer than inSize).
 */
int apple_password_callback(char *inBuf, int inSize, int in_rwflag, void *inUserData)
{
    OSStatus status = noErr;
    void *pwdBuf = NULL;
    UInt32 pwdLen = 0;
    char *service = "certificateManager";
    const char *label = (const char *)inUserData;
    size_t len;

    /* validate before strlen(): strlen(NULL) is undefined behaviour */
    if(inBuf == NULL || label == NULL || inSize <= 0)
    {
        _sx_debug(ZONE, "Invalid arguments in callback");
        return 0;
    }

    len = strlen(label);
    if(len == 0 || len >= FILENAME_MAX)
    {
        _sx_debug(ZONE, "Invalid arguments in callback");
        return 0;
    }

    /* NOTE: this changes the process-wide default keychain domain */
    status = SecKeychainSetPreferenceDomain(kSecPreferencesDomainSystem);
    if(status != noErr)
    {
        _sx_debug(ZONE, "SecKeychainSetPreferenceDomain returned status: %d", status);
        return 0;
    }

    status = SecKeychainFindGenericPassword(NULL, strlen(service), service, len, label, &pwdLen, &pwdBuf, NULL);
    if(status == noErr && pwdBuf != NULL)
    {
        if(pwdLen > (UInt32)inSize)
        {
            _sx_debug(ZONE, "Invalid buffer size callback (size:%d, len:%d)", inSize, (int)pwdLen);
            pwdLen = 0;
        }

        if(pwdLen > 0)
            memcpy(inBuf, pwdBuf, pwdLen);

        /* OpenSSL uses the returned length; only terminate when a byte is
         * left, otherwise this would write one past inBuf */
        if(pwdLen < (UInt32)inSize)
            inBuf[pwdLen] = 0;

        SecKeychainItemFreeContent(NULL, pwdBuf);
        return (int)pwdLen;
    }

    if(status == errSecNotAvailable)
        _sx_debug(ZONE, "SecKeychainFindGenericPassword: No keychain is available");
    else if(status == errSecItemNotFound)
        _sx_debug(ZONE, "SecKeychainFindGenericPassword: Requested key not in system keychain");
    else if(status != noErr)
        _sx_debug(ZONE, "SecKeychainFindGenericPassword returned status %d", status);

    return 0;
}
#endif
#endif