#include "s2s.h"
static int _in_sx_callback(sx_t s, sx_event_t e, void *data, void *arg);
static void _in_result(conn_t in, nad_t nad);
static void _in_verify(conn_t in, nad_t nad);
static void _in_packet(conn_t in, nad_t nad);
/**
 * mio callback for incoming s2s connections.
 *
 * For action_ACCEPT, `arg` is the s2s_t instance and `data` is the peer IP
 * string. A new conn_t is allocated and registered with mio (mio_app) as the
 * arg for every later action on that fd. For all other actions, `arg` is
 * that conn_t.
 *
 * @return READ/WRITE: the result of sx_can_read/sx_can_write (0 when a READ
 *         finds the peer has closed). ACCEPT: -1 if the connection state
 *         cannot be allocated. Otherwise 0.
 */
int in_mio_callback(mio_t m, mio_action_t a, int fd, void *data, void *arg) {
    conn_t in = (conn_t) arg;
    s2s_t s2s;
    struct sockaddr_storage sa;
    int namelen = sizeof(sa), port, nbytes = 0;
    char ipport[INET6_ADDRSTRLEN + 17];

    switch(a) {
        case action_READ:
            log_debug(ZONE, "read action on fd %d", fd);

            /* Readable with nothing queued means the peer closed the socket.
             * Only trust nbytes if the ioctl succeeded; otherwise let sx
             * discover the condition through its own read. */
            if(ioctl(fd, FIONREAD, &nbytes) == 0 && nbytes == 0) {
                sx_kill(in->s);
                return 0;
            }

            return sx_can_read(in->s);

        case action_WRITE:
            log_debug(ZONE, "write action on fd %d", fd);

            return sx_can_write(in->s);

        case action_CLOSE:
            log_debug(ZONE, "close action on fd %d", fd);

            log_write(in->s2s->log, LOG_NOTICE, "[%d] [%s, port=%d] disconnect", fd, in->ip, in->port);

            /* hand the sx stream to the s2s dead queue (presumably released
             * from there later) rather than freeing it here */
            jqueue_push(in->s2s->dead, (void *) in->s, 0);

            /* online streams are indexed by stream id, not-yet-open ones by "ip/port" */
            if(in->online)
                xhash_zap(in->s2s->in, in->key);
            else {
                snprintf(ipport, sizeof(ipport), "%s/%d", in->ip, in->port);
                xhash_zap(in->s2s->in_accept, ipport);
            }

            xhash_free(in->states);
            xhash_free(in->states_time);

            /* routes is never created on the accept path below, so it may
             * still be NULL from the calloc */
            if(in->routes != NULL)
                xhash_free(in->routes);

            free(in->key);
            free(in);

            break;

        case action_ACCEPT:
            s2s = (s2s_t) arg;

            log_debug(ZONE, "accept action on fd %d", fd);

            /* don't read an uninitialised sockaddr if getpeername fails */
            if(getpeername(fd, (struct sockaddr *) &sa, &namelen) == 0)
                port = j_inet_getport(&sa);
            else
                port = 0;

            log_write(s2s->log, LOG_NOTICE, "[%d] [%s, port=%d] incoming connection", fd, (char *) data, port);

            in = (conn_t) calloc(1, sizeof(struct conn_st));
            if(in == NULL) {
                log_write(s2s->log, LOG_NOTICE, "[%d] [%s, port=%d] out of memory, rejecting connection", fd, (char *) data, port);
                /* NOTE(review): assumes a non-zero return makes mio reject
                 * (close) the freshly accepted fd - confirm against mio */
                return -1;
            }

            in->s2s = s2s;

            /* NOTE(review): assumes in->ip is large enough for any IP string
             * mio hands us (e.g. INET6_ADDRSTRLEN) - confirm in s2s.h */
            strcpy(in->ip, (char *) data);
            in->port = port;

            in->states = xhash_new(101);
            in->states_time = xhash_new(101);

            in->fd = fd;

            in->init_time = time(NULL);

            in->s = sx_new(s2s->sx_env, in->fd, _in_sx_callback, (void *) in);
            mio_app(m, in->fd, in_mio_callback, (void *) in);

            /* track the connection by "ip/port" until its stream opens */
            snprintf(ipport, sizeof(ipport), "%s/%d", in->ip, in->port);
            xhash_put(s2s->in_accept, pstrdup(xhash_pool(s2s->in_accept), ipport), (void *) in);

#ifdef HAVE_SSL
            sx_server_init(in->s, S2S_DB_HEADER | ((s2s->local_pemfile != NULL) ?
                           SX_SSL_STARTTLS_OFFER : 0) );
#else
            sx_server_init(in->s, S2S_DB_HEADER);
#endif
            break;

        default:
            break;
    }

    return 0;
}
/**
 * sx callback for an incoming s2s stream.
 *
 * `arg` is the conn_t registered in in_mio_callback. The callback:
 *   - services sx I/O requests (WANT_READ/WANT_WRITE/READ/WRITE) on the fd;
 *   - logs stream errors;
 *   - records the stream id once the stream opens;
 *   - dispatches received packets. Dialback result/verify packets go to
 *     _in_result/_in_verify, message/presence/iq stanzas with to and from go
 *     to _in_packet, and everything else is dropped.
 *
 * @return event_READ/event_WRITE: the byte count (0 if the call would block,
 *         -1 after killing the stream on error or EOF). Otherwise 0.
 */
static int _in_sx_callback(sx_t s, sx_event_t e, void *data, void *arg) {
    conn_t in = (conn_t) arg;
    sx_buf_t buf = (sx_buf_t) data;
    int len;
    sx_error_t *sxe;
    nad_t nad;
    char ipport[INET6_ADDRSTRLEN + 17];
    jid_t from;
    int attr;

    switch(e) {
        case event_WANT_READ:
            log_debug(ZONE, "want read");
            mio_read(in->s2s->mio, in->fd);
            break;

        case event_WANT_WRITE:
            log_debug(ZONE, "want write");
            mio_write(in->s2s->mio, in->fd);
            break;

        case event_READ:
            log_debug(ZONE, "reading from %d", in->fd);

            len = recv(in->fd, buf->data, buf->len, 0);

            if(len < 0) {
                /* transient conditions: report nothing read, try again later */
                if(errno == EWOULDBLOCK || errno == EINTR || errno == EAGAIN) {
                    buf->len = 0;
                    return 0;
                }

                log_write(in->s2s->log, LOG_NOTICE, "[%d] [%s, port=%d] read error: %s (%d)", in->fd, in->ip, in->port, strerror(errno), errno);

                sx_kill(s);
                return -1;
            }

            if(len == 0) {
                /* peer performed an orderly shutdown */
                sx_kill(s);
                return -1;
            }

            log_debug(ZONE, "read %d bytes", len);

            buf->len = len;
            return len;

        case event_WRITE:
            log_debug(ZONE, "writing to %d", in->fd);

            len = send(in->fd, buf->data, buf->len, 0);
            if(len >= 0) {
                log_debug(ZONE, "%d bytes written", len);
                return len;
            }

            if(errno == EWOULDBLOCK || errno == EINTR || errno == EAGAIN)
                return 0;

            log_write(in->s2s->log, LOG_NOTICE, "[%d] [%s, port=%d] write error: %s (%d)", in->fd, in->ip, in->port, strerror(errno), errno);

            sx_kill(s);
            return -1;

        case event_ERROR:
            sxe = (sx_error_t *) data;
            log_write(in->s2s->log, LOG_NOTICE, "[%d] [%s, port=%d] error: %s (%s)", in->fd, in->ip, in->port, sxe->generic, sxe->specific);
            break;

        case event_STREAM:
        case event_OPEN:
            log_debug(ZONE, "STREAM or OPEN event from %s port %d (id %s)", in->ip, in->port, s->id);

            /* first open, or the stream id changed (e.g. a restarted stream) */
            if((!in->online) || (strcmp(in->key, s->id) != 0)) {
                log_write(in->s2s->log, LOG_NOTICE, "[%d] [%s, port=%d] incoming stream online (id %s)", in->fd, in->ip, in->port, s->id);

                in->online = 1;

                /* drop the index entry for the previous stream id */
                if(in->key != NULL) {
                    log_debug(ZONE,"adding new SSL stream id %s for stream id %s", s->id, in->key);
                    xhash_zap(in->s2s->in, in->key);
                }
                free(in->key);

                /* re-index: by stream id in s2s->in, no longer by ip/port */
                in->key = strdup(s->id);
                xhash_put(in->s2s->in, in->key, (void *) in);

                snprintf(ipport, sizeof(ipport), "%s/%d", in->ip, in->port);
                xhash_zap(in->s2s->in_accept, ipport);
            }
            break;

        case event_PACKET:
            nad = (nad_t) data;

            in->last_packet = time(NULL);

            /* Dialback packets (only once TLS is up, if TLS is required).
             * NAD_ENS is negative for an unqualified element; it must not be
             * used to index the namespace table in that case. */
            if(NAD_ENS(nad, 0) >= 0 &&
               NAD_NURI_L(nad, NAD_ENS(nad, 0)) == strlen(uri_DIALBACK) &&
               strncmp(uri_DIALBACK, NAD_NURI(nad, NAD_ENS(nad, 0)), strlen(uri_DIALBACK)) == 0 &&
               (in->s2s->require_tls == 0 || s->ssf > 0)) {
                if(NAD_ENAME_L(nad, 0) == 6) {
                    if(strncmp("result", NAD_ENAME(nad, 0), 6) == 0) {
                        _in_result(in, nad);
                        return 0;
                    }
                    if(strncmp("verify", NAD_ENAME(nad, 0), 6) == 0) {
                        _in_verify(in, nad);
                        return 0;
                    }
                }

                log_debug(ZONE, "unknown dialback packet, dropping it");
                nad_free(nad);
                return 0;
            }

            /* Otherwise only accept client/server-namespaced message,
             * presence or iq elements that carry both to and from. */
            if(!(
                NAD_ENS(nad, 0) >= 0 &&
                ((NAD_NURI_L(nad, NAD_ENS(nad, 0)) == strlen(uri_CLIENT) && strncmp(uri_CLIENT, NAD_NURI(nad, NAD_ENS(nad, 0)), strlen(uri_CLIENT)) == 0) ||
                 (NAD_NURI_L(nad, NAD_ENS(nad, 0)) == strlen(uri_SERVER) && strncmp(uri_SERVER, NAD_NURI(nad, NAD_ENS(nad, 0)), strlen(uri_SERVER)) == 0)) && (
                (NAD_ENAME_L(nad, 0) == 7 && strncmp("message", NAD_ENAME(nad, 0), 7) == 0) ||
                (NAD_ENAME_L(nad, 0) == 8 && strncmp("presence", NAD_ENAME(nad, 0), 8) == 0) ||
                (NAD_ENAME_L(nad, 0) == 2 && strncmp("iq", NAD_ENAME(nad, 0), 2) == 0)
                ) &&
                nad_find_attr(nad, 0, -1, "to", NULL) >= 0 && nad_find_attr(nad, 0, -1, "from", NULL) >= 0
            )) {
                log_debug(ZONE, "they sent us a non-jabber looking packet, dropping it");
                nad_free(nad);
                return 0;
            }

            attr = nad_find_attr(nad, 0, -1, "from", NULL);
            if(attr < 0 || (from = jid_new(in->s2s->pc, NAD_AVAL(nad, attr), NAD_AVAL_L(nad, attr))) == NULL) {
                log_debug(ZONE, "missing or invalid from on incoming packet, attr is %d", attr);
                nad_free(nad);
                /* this function returns int; a bare "return;" is invalid here */
                return 0;
            }

            if(in->s2s->enable_whitelist > 0 && (_s2s_domain_in_whitelist(in->s2s, from->domain) == 0)) {
                log_write(in->s2s->log, LOG_NOTICE, "received a packet not from a whitelisted domain, dropping it");
                jid_free(from);
                nad_free(nad);
                return 0;
            }

            jid_free(from);

            _in_packet(in, nad);
            return 0;

        case event_CLOSED:
            mio_close(in->s2s->mio, in->fd);
            break;

        default:
            break;
    }

    return 0;
}
/**
 * Handle an incoming dialback <result/> packet on connection `in`. Such a
 * packet asks to authorise the route keyed by to->domain and from->domain.
 *
 * If that route is already conn_VALID on this connection, the packet is
 * swapped to/from, marked type="valid", stripped of its cdata and written
 * back. Otherwise the route is marked conn_INPROGRESS (with a timestamp in
 * states_time), and a <verify/> carrying the supplied key is built and
 * handed to out_packet() addressed to the sender's domain.
 *
 * Takes ownership of `nad`: it is either written to the stream or freed.
 */
static void _in_result(conn_t in, nad_t nad) {
    int attr, ns;
    jid_t from, to;
    char *rkey;
    nad_t verify;
    pkt_t pkt;
    time_t now;

    attr = nad_find_attr(nad, 0, -1, "from", NULL);
    if(attr < 0 || (from = jid_new(in->s2s->pc, NAD_AVAL(nad, attr), NAD_AVAL_L(nad, attr))) == NULL) {
        log_debug(ZONE, "missing or invalid from on db result packet");
        nad_free(nad);
        return;
    }

    attr = nad_find_attr(nad, 0, -1, "to", NULL);
    if(attr < 0 || (to = jid_new(in->s2s->pc, NAD_AVAL(nad, attr), NAD_AVAL_L(nad, attr))) == NULL) {
        log_debug(ZONE, "missing or invalid to on db result packet");
        jid_free(from);
        nad_free(nad);
        return;
    }

    rkey = s2s_route_key(NULL, to->domain, from->domain);

    log_write(in->s2s->log, LOG_NOTICE, "[%d] [%s, port=%d] received dialback auth request for route '%s'", in->fd, in->ip, in->port, rkey);

    /* route already authorised on this connection: answer immediately */
    if((conn_state_t) xhash_get(in->states, rkey) == conn_VALID) {
        log_write(in->s2s->log, LOG_NOTICE, "[%d] [%s, port=%d] route '%s' is already valid: sending valid", in->fd, in->ip, in->port, rkey);

        stanza_tofrom(nad, 0);
        nad_set_attr(nad, 0, -1, "type", "valid", 5);

        /* drop the key (element cdata) from the reply */
        nad->elems[0].icdata = nad->elems[0].itail = -1;
        nad->elems[0].lcdata = nad->elems[0].ltail = 0;

        sx_nad_write(in->s, nad);

        free(rkey);
        jid_free(from);
        jid_free(to);
        return;
    }

    if(NAD_CDATA_L(nad, 0) <= 0) {
        /* the format has three conversions; don't pass the route key as an extra argument */
        log_write(in->s2s->log, LOG_NOTICE, "[%d] [%s, port=%d] no dialback key given with db result packet", in->fd, in->ip, in->port);
        free(rkey);
        nad_free(nad);
        jid_free(from);
        jid_free(to);
        return;
    }

    /* allocate before changing any route state, so a failure leaves the
     * route untouched instead of stuck in conn_INPROGRESS */
    pkt = (pkt_t) calloc(1, sizeof(struct pkt_st));
    if(pkt == NULL) {
        log_write(in->s2s->log, LOG_NOTICE, "[%d] [%s, port=%d] out of memory handling db result for route '%s'", in->fd, in->ip, in->port, rkey);
        free(rkey);
        nad_free(nad);
        jid_free(from);
        jid_free(to);
        return;
    }

    log_debug(ZONE, "requesting verification for route %s", rkey);

    xhash_put(in->states, pstrdup(xhash_pool(in->states), rkey), (void *) conn_INPROGRESS);

    now = time(NULL);
    xhash_put(in->states_time, pstrdup(xhash_pool(in->states_time), rkey), (void *) now);

    free(rkey);

    /* <verify to=from-domain from=to-domain id=stream-id>key</verify> */
    verify = nad_new(in->s2s->router->nad_cache);
    ns = nad_add_namespace(verify, uri_DIALBACK, "db");

    nad_append_elem(verify, ns, "verify", 0);
    nad_append_attr(verify, -1, "to", from->domain);
    nad_append_attr(verify, -1, "from", to->domain);
    nad_append_attr(verify, -1, "id", in->s->id);
    nad_append_cdata(verify, NAD_CDATA(nad, 0), NAD_CDATA_L(nad, 0), 1);

    pkt->nad = verify;
    pkt->to = from;
    pkt->from = to;
    pkt->db = 1;

    /* pkt (and from/to through it) are not freed here; presumably
     * out_packet takes ownership - confirm in out.c */
    out_packet(in->s2s, pkt);

    nad_free(nad);
}
/**
 * Handle an incoming dialback <verify/> packet on connection `in`.
 *
 * The expected key is recomputed with s2s_db_key() from the local secret,
 * the sender's domain and the packet's id attribute. The packet's cdata is
 * compared against it, and the packet is sent back (to/from swapped, cdata
 * stripped) with type="valid" or type="invalid".
 *
 * Packets missing from/to/id or cdata are dropped. Takes ownership of
 * `nad`: it is either written to the stream or freed.
 */
static void _in_verify(conn_t in, nad_t nad) {
    int attr;
    jid_t from, to;
    char *id, *dbkey, *type;

    attr = nad_find_attr(nad, 0, -1, "from", NULL);
    if(attr < 0 || (from = jid_new(in->s2s->pc, NAD_AVAL(nad, attr), NAD_AVAL_L(nad, attr))) == NULL) {
        log_debug(ZONE, "missing or invalid from on db verify packet");
        nad_free(nad);
        return;
    }

    attr = nad_find_attr(nad, 0, -1, "to", NULL);
    if(attr < 0 || (to = jid_new(in->s2s->pc, NAD_AVAL(nad, attr), NAD_AVAL_L(nad, attr))) == NULL) {
        log_debug(ZONE, "missing or invalid to on db verify packet");
        jid_free(from);
        nad_free(nad);
        return;
    }

    attr = nad_find_attr(nad, 0, -1, "id", NULL);
    if(attr < 0) {
        log_debug(ZONE, "missing id on db verify packet");
        jid_free(from);
        jid_free(to);
        nad_free(nad);
        return;
    }

    if(NAD_CDATA_L(nad, 0) <= 0) {
        log_debug(ZONE, "no cdata on db verify packet");
        jid_free(from);
        jid_free(to);
        nad_free(nad);
        return;
    }

    /* the id attribute value is not NUL-terminated inside the nad; copy it out */
    id = (char *) malloc(sizeof(char) * (NAD_AVAL_L(nad, attr) + 1));
    if(id == NULL) {
        log_write(in->s2s->log, LOG_NOTICE, "[%d] [%s, port=%d] out of memory handling db verify packet", in->fd, in->ip, in->port);
        jid_free(from);
        jid_free(to);
        nad_free(nad);
        return;
    }
    snprintf(id, NAD_AVAL_L(nad, attr) + 1, "%.*s", NAD_AVAL_L(nad, attr), NAD_AVAL(nad, attr));

    /* recompute the key we would have issued for this domain/stream id */
    dbkey = s2s_db_key(NULL, in->s2s->local_secret, from->domain, id);

    if(NAD_CDATA_L(nad, 0) == strlen(dbkey) && strncmp(dbkey, NAD_CDATA(nad, 0), NAD_CDATA_L(nad, 0)) == 0) {
        log_debug(ZONE, "valid dialback key %s, verify succeeded", dbkey);
        type = "valid";
    } else {
        log_debug(ZONE, "invalid dialback key %s, verify failed", dbkey);
        type = "invalid";
    }

    log_write(in->s2s->log, LOG_NOTICE, "[%d] [%s, port=%d] checking dialback verification from %s: sending %s", in->fd, in->ip, in->port, from->domain, type);

    log_debug(ZONE, "letting them know");

    /* reply on the same packet: swap addressing, set type, drop the key */
    stanza_tofrom(nad, 0);
    nad_set_attr(nad, 0, -1, "type", type, 0);
    nad->elems[0].icdata = nad->elems[0].itail = -1;
    nad->elems[0].lcdata = nad->elems[0].ltail = 0;

    sx_nad_write(in->s, nad);

    free(dbkey);
    free(id);

    jid_free(from);
    jid_free(to);
}
/**
 * Forward a stanza received on connection `in` to the router.
 *
 * The stanza is dropped unless the route keyed by to->domain and
 * from->domain is conn_VALID in in->states. Otherwise the top element is
 * rewritten as follows:
 *   - any uri_SERVER namespace is unlinked from its namespace chain;
 *   - it is placed in uri_CLIENT (declared if absent);
 *   - it is wrapped in a uri_COMPONENT <route/> with to=to->domain and
 *     from=the s2s id.
 * The result is then written to the router stream.
 *
 * Takes ownership of `nad`: it is either written to the router or freed.
 */
static void _in_packet(conn_t in, nad_t nad) {
    int attr, ns, sns;
    jid_t from, to;
    char *rkey;

    attr = nad_find_attr(nad, 0, -1, "from", NULL);
    if(attr < 0 || (from = jid_new(in->s2s->pc, NAD_AVAL(nad, attr), NAD_AVAL_L(nad, attr))) == NULL) {
        log_debug(ZONE, "missing or invalid from on incoming packet");
        nad_free(nad);
        return;
    }

    attr = nad_find_attr(nad, 0, -1, "to", NULL);
    if(attr < 0 || (to = jid_new(in->s2s->pc, NAD_AVAL(nad, attr), NAD_AVAL_L(nad, attr))) == NULL) {
        log_debug(ZONE, "missing or invalid to on incoming packet");
        jid_free(from);
        nad_free(nad);
        return;
    }

    rkey = s2s_route_key(NULL, to->domain, from->domain);

    log_debug(ZONE, "received packet from %s for %s", in->key, rkey);

    /* only routes authorised via dialback may carry stanzas */
    if((conn_state_t) xhash_get(in->states, rkey) != conn_VALID) {
        log_write(in->s2s->log, LOG_NOTICE, "[%d] [%s, port=%d] dropping packet on unvalidated route: '%s'", in->fd, in->ip, in->port, rkey);
        free(rkey);
        nad_free(nad);
        jid_free(from);
        jid_free(to);
        return;
    }

    free(rkey);

    log_debug(ZONE, "incoming packet on valid route, preparing it for the router");

    /* unlink the server namespace from the top element's namespace chain */
    ns = nad_find_namespace(nad, 0, uri_SERVER, NULL);
    if(ns >= 0) {
        if(nad->elems[0].ns == ns)
            nad->elems[0].ns = nad->nss[nad->elems[0].ns].next;
        else {
            for(sns = nad->elems[0].ns; sns >= 0 && nad->nss[sns].next != ns; sns = nad->nss[sns].next);

            /* only splice if the predecessor was found; with sns == -1 the
             * write below would index nss out of bounds */
            if(sns >= 0)
                nad->nss[sns].next = nad->nss[ns].next;
        }
    }

    /* make sure the element is in the client namespace */
    ns = nad_find_namespace(nad, 0, uri_CLIENT, NULL);
    if(ns < 0) {
        ns = nad_add_namespace(nad, uri_CLIENT, NULL);
        nad->scope = -1;
        nad->nss[ns].next = nad->elems[0].ns;
        nad->elems[0].ns = ns;
    }
    nad->elems[0].my_ns = ns;

    /* wrap it in a component <route/> addressed to the destination domain */
    ns = nad_add_namespace(nad, uri_COMPONENT, "comp");
    nad_wrap_elem(nad, 0, ns, "route");

    nad_set_attr(nad, 0, -1, "to", to->domain, 0);
    nad_set_attr(nad, 0, -1, "from", in->s2s->id, 0);

    log_debug(ZONE, "sending packet to %s", to->domain);

    sx_nad_write(in->s2s->router, nad);

    jid_free(from);
    jid_free(to);
}