#include <sys_defs.h>
#include <stdlib.h>
#include <string.h>
#ifdef STRCASECMP_IN_STRINGS_H
#include <strings.h>
#endif
#include <msg.h>
#include <mymalloc.h>
#include <stringops.h>
#include <split_at.h>
#include <name_mask.h>
#include <mail_params.h>
#include <string_list.h>
#include <maps.h>
#include "lmtp.h"
#include "lmtp_sasl.h"
#ifdef USE_SASL_AUTH
static NAME_MASK lmtp_sasl_sec_mask[] = {
"noplaintext", SASL_SEC_NOPLAINTEXT,
"noactive", SASL_SEC_NOACTIVE,
"nodictionary", SASL_SEC_NODICTIONARY,
"noanonymous", SASL_SEC_NOANONYMOUS,
#if SASL_VERSION_MAJOR >= 2
"mutual_auth", SASL_SEC_MUTUAL_AUTH,
#endif
0,
};
/*
 * Shorthand for the text of a VSTRING.
 */
#define STR(x) vstring_str(x)

/*
 * Cyrus SASL v1 and v2 differ in log priority names and in the argument
 * lists of sasl_client_new(), sasl_client_start() and sasl_decode64().
 * These wrappers accept the union of both argument lists and pass on only
 * the arguments that the installed library version takes.
 */
#if SASL_VERSION_MAJOR >= 2
#define SASL_CLIENT_NEW(srv, fqdn, lport, rport, prompt, secflags, pconn) \
	sasl_client_new(srv, fqdn, lport, rport, prompt, secflags, pconn)
#define SASL_CLIENT_START(conn, mechlst, secret, prompt, clout, cllen, mech) \
	sasl_client_start(conn, mechlst, prompt, clout, cllen, mech)
#define SASL_DECODE64(in, inlen, out, outmaxlen, outlen) \
	sasl_decode64(in, inlen, out, outmaxlen, outlen)
#else
/* v1 names for the log priorities used by lmtp_sasl_log(). */
#define SASL_LOG_WARN SASL_LOG_WARNING
#define SASL_LOG_NOTE SASL_LOG_INFO
/* v1 has no local/remote address arguments. */
#define SASL_CLIENT_NEW(srv, fqdn, lport, rport, prompt, secflags, pconn) \
	sasl_client_new(srv, fqdn, prompt, secflags, pconn)
#define SASL_CLIENT_START(conn, mechlst, secret, prompt, clout, cllen, mech) \
	sasl_client_start(conn, mechlst, secret, prompt, clout, cllen, mech)
/* v1 has no output-buffer size argument. */
#define SASL_DECODE64(in, inlen, out, outmaxlen, outlen) \
	sasl_decode64(in, inlen, out, outlen)
#endif

/*
 * Password table. Opened once by lmtp_sasl_initialize() and searched by
 * lmtp_sasl_passwd_lookup().
 */
static MAPS *lmtp_sasl_passwd_map;
/*
 * lmtp_sasl_log - SASL_CB_LOG callback
 *
 * Routes messages from the SASL library to Postfix logging:
 * - errors, warnings and (SASL v2) failures become warnings;
 * - informational notes are logged only in verbose mode;
 * - any other priority is dropped.
 * Always returns SASL_OK.
 */
static int lmtp_sasl_log(void *unused_context, int priority,
                         const char *message)
{
    switch (priority) {
#if SASL_VERSION_MAJOR >= 2
    case SASL_LOG_FAIL:
        msg_warn("SASL authentication failure: %s", message);
        break;
#endif
    case SASL_LOG_ERR:
    case SASL_LOG_WARN:
        msg_warn("SASL authentication problem: %s", message);
        break;
    case SASL_LOG_NOTE:
        if (msg_verbose)
            msg_info("SASL authentication info: %s", message);
        break;
    default:
        /* Other priorities are intentionally not logged. */
        break;
    }
    return (SASL_OK);
}
static int lmtp_sasl_get_user(void *context, int unused_id, const char **result,
unsigned *len)
{
char *myname = "lmtp_sasl_get_user";
LMTP_STATE *state = (LMTP_STATE *) context;
if (msg_verbose)
msg_info("%s: %s", myname, state->sasl_username);
if (state->sasl_passwd == 0)
msg_panic("%s: no username looked up", myname);
*result = state->sasl_username;
if (len)
*len = strlen(state->sasl_username);
return (SASL_OK);
}
/*
 * lmtp_sasl_get_passwd - SASL_CB_PASS callback
 *
 * Hands the SASL library the password found by lmtp_sasl_passwd_lookup()
 * as a counted, null-terminated sasl_secret_t. The secret is allocated
 * with plain malloc(), not mymalloc(). NOTE(review): this presumably
 * matches how the SASL library releases it; confirm ownership for the
 * SASL version in use.
 *
 * Returns:
 * - SASL_BADPARAM for a null connection or result pointer, or an
 *   unexpected callback id;
 * - SASL_NOMEM when allocation fails;
 * - SASL_OK otherwise.
 * Panics if no password was looked up.
 */
static int lmtp_sasl_get_passwd(sasl_conn_t *conn, void *context,
                                int id, sasl_secret_t **psecret)
{
    char   *myname = "lmtp_sasl_get_passwd";
    LMTP_STATE *state = (LMTP_STATE *) context;
    size_t  len;

    /*
     * Sanity checks come first, so the verbose log below never formats a
     * null password pointer with %s.
     */
    if (!conn || !psecret || id != SASL_CB_PASS)
        return (SASL_BADPARAM);
    if (state->sasl_passwd == 0)
        msg_panic("%s: no password looked up", myname);

    if (msg_verbose)
        msg_info("%s: %s", myname, state->sasl_passwd);

    /*
     * Convert the password into a counted string. The allocation reserves
     * only len bytes beyond sizeof(sasl_secret_t), yet len + 1 bytes are
     * copied. This assumes the struct already declares one data byte,
     * which then holds the null terminator. TODO confirm against sasl.h.
     */
    len = strlen(state->sasl_passwd);
    if ((*psecret = (sasl_secret_t *) malloc(sizeof(sasl_secret_t) + len)) == 0)
        return (SASL_NOMEM);
    (*psecret)->len = len;
    memcpy((*psecret)->data, state->sasl_passwd, len + 1);
    return (SASL_OK);
}
/*
 * lmtp_sasl_passwd_lookup - find SASL credentials for this session
 *
 * Searches the password table first by the session's host name, then by
 * the request's next-hop destination. The matching value is split at the
 * first ':' into state->sasl_username and state->sasl_passwd. Both are
 * newly allocated with mystrdup(). The password is "" when there is no
 * ':' in the value.
 *
 * Returns 1 when credentials were found and 0 otherwise. Panics if
 * lmtp_sasl_initialize() has not opened the table.
 */
int     lmtp_sasl_passwd_lookup(LMTP_STATE *state)
{
    char   *myname = "lmtp_sasl_passwd_lookup";
    const char *entry;
    char   *after_colon;

    if (lmtp_sasl_passwd_map == 0)
        msg_panic("%s: passwd map not initialized", myname);

    /* The next-hop lookup runs only if the host lookup found nothing. */
    entry = maps_find(lmtp_sasl_passwd_map, state->session->host, 0);
    if (entry == 0)
        entry = maps_find(lmtp_sasl_passwd_map, state->request->nexthop, 0);

    if (entry == 0) {
        if (msg_verbose)
            msg_info("%s: host `%s' no auth info found",
                     myname, state->session->host);
        return (0);
    }

    /* split_at() terminates the username in place at the ':'. */
    state->sasl_username = mystrdup(entry);
    after_colon = split_at(state->sasl_username, ':');
    state->sasl_passwd = mystrdup(after_colon != 0 ? after_colon : "");
    if (msg_verbose)
        msg_info("%s: host `%s' user `%s' pass `%s'",
                 myname, state->session->host,
                 state->sasl_username, state->sasl_passwd);
    return (1);
}
/*
 * lmtp_sasl_initialize - one-time SASL client setup
 *
 * Setup steps:
 * - opens the password table named by the lmtp SASL password parameter;
 * - initializes the SASL client library, with lmtp_sasl_log() as its
 *   logging callback.
 *
 * Failure modes:
 * - a second call panics;
 * - an empty table name or a library initialization failure is fatal.
 */
void    lmtp_sasl_initialize(void)
{
    static sasl_callback_t callbacks[] = {
        {SASL_CB_LOG, &lmtp_sasl_log, 0},
        {SASL_CB_LIST_END, 0, 0}
    };
    int     init_status;

    if (lmtp_sasl_passwd_map != 0)
        msg_panic("lmtp_sasl_initialize: repeated call");
    if (var_lmtp_sasl_passwd[0] == '\0')
        msg_fatal("specify a password table via the `%s' configuration parameter",
                  VAR_LMTP_SASL_PASSWD);

    lmtp_sasl_passwd_map = maps_create("lmtp_sasl_passwd",
                                       var_lmtp_sasl_passwd, DICT_FLAG_LOCK);

    init_status = sasl_client_init(callbacks);
    if (init_status != SASL_OK)
        msg_fatal("SASL library initialization");
}
/*
 * lmtp_sasl_connect - reset per-session SASL state
 *
 * Nulls every SASL-related field of "state". lmtp_sasl_cleanup() releases
 * only the fields that are non-null, so this must run before any of them
 * is used.
 */
void    lmtp_sasl_connect(LMTP_STATE *state)
{
    /* Credentials and the mechanism list. */
    state->sasl_username = 0;
    state->sasl_passwd = 0;
    state->sasl_mechanism_list = 0;

    /* SASL library connection handle and its callback table. */
    state->sasl_conn = 0;
    state->sasl_callbacks = 0;

    /* Base64 scratch buffers used by lmtp_sasl_authenticate(). */
    state->sasl_encoded = 0;
    state->sasl_decoded = 0;
}
/*
 * lmtp_sasl_start - create the per-session SASL client
 *
 * Setup steps:
 * - installs a private copy of the user/authname/password callback
 *   table, with each entry's context set to "state";
 * - opens a SASL client connection for the session's host;
 * - applies security properties. The flags come from the configured
 *   keyword list "sasl_opts_val" (reported as "sasl_opts_name" in errors),
 *   and the maximum SSF is 1;
 * - allocates the base64 scratch buffers.
 *
 * Any SASL library failure is fatal.
 */
void    lmtp_sasl_start(LMTP_STATE *state, const char *sasl_opts_name,
                        const char *sasl_opts_val)
{
    static sasl_callback_t callbacks[] = {
        {SASL_CB_USER, &lmtp_sasl_get_user, 0},
        {SASL_CB_AUTHNAME, &lmtp_sasl_get_user, 0},
        {SASL_CB_PASS, &lmtp_sasl_get_passwd, 0},
        {SASL_CB_LIST_END, 0, 0}
    };
    sasl_security_properties_t sec_props;
    int     idx;

#define NULL_SECFLAGS		0
#define NULL_SERVER_ADDR	((char *) 0)
#define NULL_CLIENT_ADDR	((char *) 0)

    if (msg_verbose)
        msg_info("starting new SASL client");

    /*
     * Each session gets its own callback table, so the callbacks can find
     * this session's credentials through their context pointer.
     */
    state->sasl_callbacks = (sasl_callback_t *) mymalloc(sizeof(callbacks));
    memcpy((char *) state->sasl_callbacks, callbacks, sizeof(callbacks));
    for (idx = 0; state->sasl_callbacks[idx].id != SASL_CB_LIST_END; idx++)
        state->sasl_callbacks[idx].context = (void *) state;

    if (SASL_CLIENT_NEW("lmtp", state->session->host,
                        NULL_CLIENT_ADDR, NULL_SERVER_ADDR,
                        state->sasl_callbacks, NULL_SECFLAGS,
                        (sasl_conn_t **) &state->sasl_conn) != SASL_OK)
        msg_fatal("per-session SASL client initialization");

    /* Security properties: start from all-zero, then set each field explicitly. */
    memset(&sec_props, 0, sizeof(sec_props));
    sec_props.min_ssf = 0;
    sec_props.max_ssf = 1;
    sec_props.security_flags = name_mask(sasl_opts_name, lmtp_sasl_sec_mask,
                                         sasl_opts_val);
    sec_props.maxbufsize = 0;
    sec_props.property_names = 0;
    sec_props.property_values = 0;
    if (sasl_setprop(state->sasl_conn, SASL_SEC_PROPS, &sec_props) != SASL_OK)
        msg_fatal("set per-session SASL security properties");

    state->sasl_encoded = vstring_alloc(10);
    state->sasl_decoded = vstring_alloc(10);
}
/*
 * Length of the base64 encoding of n bytes, excluding the null terminator.
 */
#define ENCODE64_LENGTH(n)	((((n) + 2) / 3) * 4)

/*
 * lmtp_sasl_encode - base64-encode a SASL client response into
 * state->sasl_encoded as a null-terminated string. An encoder failure is
 * a program bug and panics on behalf of "myname". The caller keeps
 * ownership of "data".
 */
static void lmtp_sasl_encode(LMTP_STATE *state, const char *myname,
                             const char *data, unsigned len)
{
    unsigned enc_length = ENCODE64_LENGTH(len) + 1;	/* + null byte */
    unsigned enc_length_out;

    VSTRING_SPACE(state->sasl_encoded, enc_length);
    if (sasl_encode64(data, len, STR(state->sasl_encoded), enc_length,
                      &enc_length_out) != SASL_OK)
        msg_panic("%s: sasl_encode64 botch", myname);
}

/*
 * lmtp_sasl_cancel - abort an AUTH exchange in which the server is waiting
 * for our response. It sends the lone "*" line that the SMTP AUTH
 * extension defines for cancellation and discards the server's reply.
 * This keeps the session in sync when we give up locally.
 */
static void lmtp_sasl_cancel(LMTP_STATE *state)
{
    lmtp_chat_cmd(state, "*");
    (void) lmtp_chat_resp(state);
}

/*
 * lmtp_sasl_authenticate - run the AUTH exchange with the server
 *
 * Lets the SASL library choose a mechanism from state->sasl_mechanism_list
 * and sends AUTH, with an initial response when the library produced one.
 * It then answers each 3xx challenge until the server sends a final reply.
 *
 * Returns:
 *  1 when the server accepts (2xx).
 *  0 when the server rejects; "why" holds the server's reply.
 * -1 when authentication cannot proceed locally (SASL library failure or
 *    malformed challenge); "why" holds the reason. A local failure in the
 *    middle of the exchange first cancels the AUTH command.
 */
int     lmtp_sasl_authenticate(LMTP_STATE *state, VSTRING *why)
{
    char   *myname = "lmtp_sasl_authenticate";
    unsigned dec_length;
#if SASL_VERSION_MAJOR >= 2
    const char *clientout;
#else
    char   *clientout;
#endif
    unsigned clientoutlen;
    unsigned serverinlen;
    LMTP_RESP *resp;
    const char *mechanism;
    int     result;
    char   *line;

#define NO_SASL_SECRET		0
#define NO_SASL_INTERACTION	0
#define NO_SASL_LANGLIST	((const char *) 0)
#define NO_SASL_OUTLANG		((const char **) 0)

    if (msg_verbose)
        msg_info("%s: %s: SASL mechanisms %s",
                 myname, state->session->namaddr, state->sasl_mechanism_list);

    /*
     * Start the exchange. No AUTH command has been sent yet, so a failure
     * here leaves the session untouched.
     */
    result = SASL_CLIENT_START((sasl_conn_t *) state->sasl_conn,
                               state->sasl_mechanism_list,
                               NO_SASL_SECRET, NO_SASL_INTERACTION,
                               &clientout, &clientoutlen, &mechanism);
    if (result != SASL_OK && result != SASL_CONTINUE) {
        vstring_sprintf(why, "cannot SASL authenticate to server %s: %s",
                        state->session->namaddr,
                        sasl_errstring(result, NO_SASL_LANGLIST,
                                       NO_SASL_OUTLANG));
        return (-1);
    }

    if (clientoutlen > 0) {
        if (msg_verbose)
            msg_info("%s: %s: uncoded initial reply: %.*s",
                     myname, state->session->namaddr,
                     (int) clientoutlen, clientout);
        lmtp_sasl_encode(state, myname, clientout, clientoutlen);
#if SASL_VERSION_MAJOR < 2
        free(clientout);			/* SASL v1 output is ours to free */
#endif
        lmtp_chat_cmd(state, "AUTH %s %s", mechanism, STR(state->sasl_encoded));
    } else {
        lmtp_chat_cmd(state, "AUTH %s", mechanism);
    }

    /*
     * Answer server challenges (3xx replies) until a final reply arrives.
     */
    while ((resp = lmtp_chat_resp(state))->code / 100 == 3) {
        /* Skip the reply code; the rest of the line is the base64 challenge. */
        line = resp->str;
        (void) mystrtok(&line, "- \t\n");
        serverinlen = strlen(line);

        /*
         * Decoded data is never longer than its base64 input. Reserve one
         * extra byte for the null terminator the decoder may append.
         */
        VSTRING_SPACE(state->sasl_decoded, serverinlen + 1);
        if (SASL_DECODE64(line, serverinlen, STR(state->sasl_decoded),
                          serverinlen + 1, &dec_length) != SASL_OK) {
            vstring_sprintf(why, "malformed SASL challenge from server %s",
                            state->session->namaddr);
            lmtp_sasl_cancel(state);
            return (-1);
        }
        if (msg_verbose)
            msg_info("%s: %s: decoded challenge: %.*s",
                     myname, state->session->namaddr,
                     (int) dec_length, STR(state->sasl_decoded));

        result = sasl_client_step((sasl_conn_t *) state->sasl_conn,
                                  STR(state->sasl_decoded), dec_length,
                                  NO_SASL_INTERACTION, &clientout, &clientoutlen);

        /*
         * On failure the library may leave clientout/clientoutlen unchanged,
         * i.e. pointing at the previous response. Under SASL v1 that memory
         * has already been freed. Stop here instead of sending or freeing it
         * again.
         */
        if (result != SASL_OK && result != SASL_CONTINUE) {
            vstring_sprintf(why, "SASL authentication failed to server %s: %s",
                            state->session->namaddr,
                            sasl_errstring(result, NO_SASL_LANGLIST,
                                           NO_SASL_OUTLANG));
            lmtp_sasl_cancel(state);
            return (-1);
        }

        if (clientoutlen > 0) {
            if (msg_verbose)
                msg_info("%s: %s: uncoded client response %.*s",
                         myname, state->session->namaddr,
                         (int) clientoutlen, clientout);
            lmtp_sasl_encode(state, myname, clientout, clientoutlen);
#if SASL_VERSION_MAJOR < 2
            free(clientout);
#endif
        } else {
            /* An empty response is sent as an empty line. */
            vstring_strcpy(state->sasl_encoded, "");
        }
        lmtp_chat_cmd(state, "%s", STR(state->sasl_encoded));
    }

    if (resp->code / 100 != 2) {
        vstring_sprintf(why, "SASL authentication failed; server %s said: %s",
                        state->session->namaddr, resp->str);
        return (0);
    }
    return (1);
}
/*
 * lmtp_sasl_cleanup - release per-session SASL resources
 *
 * Frees every non-null SASL field of "state" and resets it to null.
 * A repeated call, or a call after lmtp_sasl_connect() alone, is
 * therefore harmless.
 */
void    lmtp_sasl_cleanup(LMTP_STATE *state)
{
    if (state->sasl_username) {
        myfree(state->sasl_username);
        state->sasl_username = 0;
    }
    if (state->sasl_passwd) {
        myfree(state->sasl_passwd);
        state->sasl_passwd = 0;
    }
    if (state->sasl_mechanism_list) {
        myfree(state->sasl_mechanism_list);
        state->sasl_mechanism_list = 0;
    }
    if (state->sasl_conn) {
        if (msg_verbose)
            msg_info("disposing SASL state information");
        /* Same cast as the SASL_CLIENT_NEW() call in lmtp_sasl_start(). */
        sasl_dispose((sasl_conn_t **) &state->sasl_conn);
        /* Do not rely on the library to null the handle; prevents a double dispose. */
        state->sasl_conn = 0;
    }
    if (state->sasl_callbacks) {
        myfree((char *) state->sasl_callbacks);
        state->sasl_callbacks = 0;
    }
    if (state->sasl_encoded) {
        vstring_free(state->sasl_encoded);
        state->sasl_encoded = 0;
    }
    if (state->sasl_decoded) {
        vstring_free(state->sasl_decoded);
        state->sasl_decoded = 0;
    }
}
#endif