#include <config.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <sasl/sasl.h>
#include <sasl/saslutil.h>
#include <syslog.h>
#include "xmalloc.h"
#include "prot.h"
#include "imap_err.h"
#include "saslclient.h"
/*
 * Simple SASL callback serving SASL_CB_USER, SASL_CB_AUTHNAME and
 * SASL_CB_LANGUAGE.
 *
 * Both the user and authentication names are answered with the string
 * passed as the callback context; no language preference is given (NULL).
 * When len is non-NULL it receives the length of the returned string,
 * or 0 when the result is NULL.
 *
 * Returns SASL_OK, or SASL_BADPARAM for a NULL result pointer or an
 * unsupported callback id.
 */
static int mysasl_simple_cb(void *context, int id, const char **result,
                            unsigned int *len)
{
    const char *value;

    if (result == NULL)
        return SASL_BADPARAM;

    if (id == SASL_CB_USER || id == SASL_CB_AUTHNAME) {
        value = (const char *) context;
    } else if (id == SASL_CB_LANGUAGE) {
        value = NULL;
    } else {
        return SASL_BADPARAM;
    }

    *result = value;
    if (len != NULL)
        *len = (value != NULL) ? strlen(value) : 0;

    return SASL_OK;
}
static int mysasl_getrealm_cb(void *context, int id,
const char **availrealms __attribute__((unused)),
const char **result)
{
if (id != SASL_CB_GETREALM || !result) {
return SASL_BADPARAM;
}
*result = (char *) context;
return SASL_OK;
}
/*
 * SASL_CB_PASS callback: hands back the sasl_secret_t stored as the
 * callback context (built by mysasl_callbacks()).
 *
 * Returns SASL_OK, or SASL_BADPARAM if conn or result is NULL or the
 * callback id is not SASL_CB_PASS.
 */
static int mysasl_getsecret_cb(sasl_conn_t *conn,
                               void *context,
                               int id,
                               sasl_secret_t **result)
{
    int valid = conn != NULL && result != NULL && id == SASL_CB_PASS;

    if (!valid)
        return SASL_BADPARAM;

    *result = context;
    return SASL_OK;
}
sasl_callback_t *mysasl_callbacks(const char *username,
const char *authname,
const char *realm,
const char *password)
{
sasl_callback_t *ret = xmalloc(5 * sizeof(sasl_callback_t));
int n = 0;
if (username) {
ret[n].id = SASL_CB_USER;
ret[n].proc = &mysasl_simple_cb;
ret[n].context = (char *) username;
n++;
}
if (authname) {
ret[n].id = SASL_CB_AUTHNAME;
ret[n].proc = &mysasl_simple_cb;
ret[n].context = (char *) authname;
n++;
}
if (realm) {
ret[n].id = SASL_CB_GETREALM;
ret[n].proc = &mysasl_getrealm_cb;
ret[n].context = (char *) realm;
n++;
}
if (password) {
sasl_secret_t *secret;
size_t len = strlen(password);
secret = (sasl_secret_t *)xmalloc(sizeof(sasl_secret_t) + len);
if(!secret) {
free(ret);
return NULL;
}
strcpy((char *) secret->data, password);
secret->len = len;
ret[n].id = SASL_CB_PASS;
ret[n].proc = &mysasl_getsecret_cb;
ret[n].context = secret;
n++;
}
ret[n].id = SASL_CB_LIST_END;
ret[n].proc = NULL;
ret[n].context = NULL;
return ret;
}
/*
 * Release a callback table built by mysasl_callbacks().
 *
 * Every SASL_CB_PASS entry owns a heap-allocated sasl_secret_t holding the
 * plaintext password.  The secret is wiped before being freed, and then
 * the table itself is freed.  A NULL table is accepted and ignored.
 */
void free_callbacks(sasl_callback_t *in)
{
    int i;

    if (!in) return;

    for (i = 0; in[i].id != SASL_CB_LIST_END; i++) {
        if (in[i].id == SASL_CB_PASS && in[i].context) {
            sasl_secret_t *secret = in[i].context;
            /* Scrub the password so it does not linger in freed heap
             * memory.  Writing through a volatile pointer keeps the
             * compiler from discarding these stores as dead before free(). */
            volatile unsigned char *wipe = secret->data;
            size_t n;

            for (n = 0; n < secret->len; n++)
                wipe[n] = 0;

            free(secret);
        }
    }

    free(in);
}
/* Output limit passed to sasl_encode64()/sasl_decode64() in saslclient().
 * NOTE(review): appears sized for the base64 form of a ~16 KiB SASL
 * token — confirm before changing. */
#define BASE64_BUF_SIZE 21848
/* Server line buffer: base64 payload plus slack for the protocol prefix.
 * Parenthesized so the macro expands safely inside larger expressions. */
#define AUTH_BUF_SIZE (BASE64_BUF_SIZE + 50)
/*
 * Run the client side of a SASL authentication exchange over a
 * line-oriented protocol described by sasl_cmd.
 *
 * conn        - SASL connection; sasl_client_start()/step() are called on it
 * sasl_cmd    - protocol description: command verb, maxlen (maximum command
 *               line length when sending an initial response; 0 means no
 *               initial response is requested), quote (quote the mechanism
 *               name and send the first client data as a literal),
 *               ok/fail/cont/cancel strings and an optional parse_success
 *               hook.  When cont is NULL, challenges are expected as
 *               literals ("{n}" followed by a data line).
 * mechlist    - mechanisms offered by the server
 * pin, pout   - server input / output streams
 * sasl_result - if non-NULL, receives the final SASL result code
 * status      - if non-NULL, receives a human-readable reason; it may point
 *               into a static buffer that the next call overwrites
 *
 * Returns 0 on success, IMAP_SASL_FAIL if authentication failed, or
 * IMAP_SASL_PROTERR on EOF from the server.
 *
 * Not reentrant: server lines are read into a function-static buffer.
 */
int saslclient(sasl_conn_t *conn, struct sasl_cmd_t *sasl_cmd,
               const char *mechlist,
               struct protstream *pin, struct protstream *pout,
               int *sasl_result, const char **status)
{
    static char buf[AUTH_BUF_SIZE+1];
    char *base64, *serverin;
    unsigned int serverinlen = 0;
    const char *mech, *clientout = NULL;
    unsigned int clientoutlen = 0;
    char cmdbuf[40];
    int cmdlen;
    int sendliteral = sasl_cmd->quote;
    int r;

    if (status) *status = NULL;

    r = sasl_client_start(conn, mechlist, NULL,
                          /* only ask for an initial response if the
                           * protocol can carry one */
                          sasl_cmd->maxlen ? &clientout : NULL,
                          &clientoutlen, &mech);

    if (r != SASL_OK && r != SASL_CONTINUE) {
        if (sasl_result) *sasl_result = r;
        if (status) *status = sasl_errdetail(conn);
        return IMAP_SASL_FAIL;
    }

    /* Build the auth command; refuse rather than overflow or truncate it. */
    if (sasl_cmd->quote)
        cmdlen = snprintf(cmdbuf, sizeof(cmdbuf), "%s \"%s\"",
                          sasl_cmd->cmd, mech);
    else
        cmdlen = snprintf(cmdbuf, sizeof(cmdbuf), "%s %s",
                          sasl_cmd->cmd, mech);
    if (cmdlen < 0 || (size_t) cmdlen >= sizeof(cmdbuf)) {
        if (sasl_result) *sasl_result = SASL_BUFOVER;
        if (status) *status = "authentication command too long";
        return IMAP_SASL_FAIL;
    }
    prot_printf(pout, "%s", cmdbuf);

    if (!clientout) goto noinitresp;    /* no initial response */

    if (!clientoutlen) {
        /* zero-length initial response is sent as "=" */
        prot_printf(pout, " =");
        clientout = NULL;
    }
    else if (!sendliteral &&
             ((size_t) cmdlen + (((size_t) clientoutlen + 2) / 3) * 4 + 3
              > sasl_cmd->maxlen)) {
        /* The base64-encoded initial response (plus " " and CRLF) would
         * exceed the command line limit: send the bare command and keep
         * clientout pending until the server issues a challenge. */
        goto noinitresp;
    }
    else {
        prot_printf(pout, " ");
    }

    do {
        base64 = buf;
        *base64 = '\0';

        if (clientout) {
            r = sasl_encode64(clientout, clientoutlen,
                              base64, BASE64_BUF_SIZE, NULL);
            /* NOTE(review): r is overwritten by the response handling
             * below, so an encoding failure is not reported; what gets
             * sent then depends on what sasl_encode64 leaves in buf. */
            clientout = NULL;
        }

        if (sendliteral) {
            /* base64 is bounded by BASE64_BUF_SIZE, so it fits in an int */
            prot_printf(pout, "{%d+}\r\n", (int) strlen(base64));
            prot_flush(pout);
        }
        prot_printf(pout, "%s", base64);

      noinitresp:
        prot_printf(pout, "\r\n");
        prot_flush(pout);

        /* read and classify the server's reply */
        if (!prot_fgets(buf, AUTH_BUF_SIZE, pin)) {
            if (sasl_result) *sasl_result = SASL_FAIL;
            if (status) *status = "EOF from server";
            return IMAP_SASL_PROTERR;
        }

        base64 = NULL;
        if (!strncasecmp(buf, sasl_cmd->ok, strlen(sasl_cmd->ok))) {
            /* success; the hook may extract final server data */
            if (sasl_cmd->parse_success)
                base64 = sasl_cmd->parse_success(buf, status);
            if (!base64 && status) *status = buf + strlen(sasl_cmd->ok);
            r = SASL_OK;
        }
        else if (!strncasecmp(buf, sasl_cmd->fail, strlen(sasl_cmd->fail))) {
            if (status) *status = buf + strlen(sasl_cmd->fail);
            r = SASL_BADAUTH;
            break;
        }
        else if (sasl_cmd->cont &&
                 !strncasecmp(buf, sasl_cmd->cont, strlen(sasl_cmd->cont))) {
            /* challenge follows the continuation prefix */
            base64 = buf + strlen(sasl_cmd->cont);
        }
        else if (!sasl_cmd->cont && buf[0] == '{') {
            /* Challenge sent as a literal; its data is read as the next
             * line.  NOTE(review): the announced literal size is not
             * checked against what is actually read. */
            if (!prot_fgets(buf, AUTH_BUF_SIZE, pin)) {
                if (sasl_result) *sasl_result = SASL_FAIL;
                if (status) *status = "EOF from server";
                return IMAP_SASL_PROTERR;
            }
            base64 = buf;
        }
        else {
            if (status) *status = buf;
            r = SASL_BADPROT;
        }

        if (base64) {
            size_t b64len = strlen(base64);

            /* strip the trailing CRLF without stepping before the string */
            if (b64len > 0 && base64[b64len - 1] == '\n')
                base64[--b64len] = '\0';
            if (b64len > 0 && base64[b64len - 1] == '\r')
                base64[--b64len] = '\0';

            /* decoded data is written back into buf */
            serverin = buf;
            r = sasl_decode64(base64, (unsigned) b64len,
                              serverin, BASE64_BUF_SIZE, &serverinlen);

            /* An empty challenge while a deferred initial response is still
             * pending just means "send it now": skip the step and let the
             * loop transmit clientout. */
            if (r == SASL_OK &&
                (serverinlen || !clientout)) {
                r = sasl_client_step(conn, serverin, serverinlen, NULL,
                                     &clientout, &clientoutlen);
            }
        }

        if (r != SASL_OK && r != SASL_CONTINUE) {
            /* abort the exchange on the server side */
            prot_printf(pout, "%s\r\n", sasl_cmd->cancel);
            prot_flush(pout);
        }

        /* later client data goes as literals when there is no
         * continuation prefix */
        sendliteral = !sasl_cmd->cont;
    } while (r == SASL_CONTINUE || (r == SASL_OK && clientout));

    if (sasl_result) *sasl_result = r;

    return (r == SASL_OK ? 0 : IMAP_SASL_FAIL);
}