#include "lib.h"
#include "array.h"
#include "ioloop.h"
#include "istream.h"
#include "istream-tee.h"
#include "ostream.h"
#include "lmtp-client.h"
#include "lmtp-proxy.h"
/* Maximum protocol line length. NOTE(review): not referenced by the code
   visible in this file - confirm whether it is still needed. */
#define LMTP_MAX_LINE_LEN 1024
/* Idle timeout for the client's DATA input. The timer is armed in
   lmtp_proxy_start() and reset on every read attempt in
   lmtp_proxy_data_read(); when it fires, lmtp_proxy_data_input_timeout()
   fails all backend connections. */
#define LMTP_PROXY_DATA_INPUT_TIMEOUT_MSECS (1000*60)
/* One recipient (RCPT TO) forwarded to a backend connection. */
struct lmtp_proxy_recipient {
/* backend connection the recipient was added to */
struct lmtp_proxy_connection *conn;
/* recipient address, copied into the proxy pool */
const char *address;
/* latest backend reply line: first the RCPT TO reply, later replaced by
   the DATA reply. Sent to the client (with CRLF appended) by
   lmtp_proxy_send_data_replies(). */
const char *reply;
/* backend rejected RCPT TO; reply is final and no DATA reply is awaited */
unsigned int rcpt_to_failed:1;
/* the DATA reply for this recipient has been stored in reply */
unsigned int data_reply_received:1;
};
/* One backend connection. It is shared by all recipients whose settings
   match in lmtp_proxy_get_connection(). */
struct lmtp_proxy_connection {
/* owning proxy */
struct lmtp_proxy *proxy;
/* copy of the settings this connection was opened with */
struct lmtp_proxy_settings set;
struct lmtp_client *client;
/* this connection's tee child of the DATA input. It is NULL before
   lmtp_proxy_start() and is released again in lmtp_conn_finish(). */
struct istream *data_input;
/* lmtp_conn_finish() has run for this connection */
unsigned int finished:1;
/* lmtp_client_connect_tcp() failed; lmtp_proxy_add_rcpt() returns -1 */
unsigned int failed:1;
};
/* State for proxying one LMTP transaction to one or more backends. */
struct lmtp_proxy {
/* pool owning the proxy, its connections, recipients and copied strings */
pool_t pool;
/* MAIL FROM value and our hostname, passed to each new lmtp_client */
const char *mail_from, *my_hostname;
const char *dns_client_socket_path;
/* backend connections */
ARRAY_DEFINE(connections, struct lmtp_proxy_connection *);
/* recipients in the order they were added; DATA replies go to the
   client in this order */
ARRAY_DEFINE(rcpt_to, struct lmtp_proxy_recipient *);
/* index of the first recipient whose DATA reply hasn't been sent yet */
unsigned int next_data_reply_idx;
/* to: output-stall timeout (max_timeout_msecs);
   to_data_idle: DATA input idle timeout;
   to_finish: zero-delay timeout that invokes finish_callback */
struct timeout *to, *to_data_idle, *to_finish;
/* read watcher on the DATA input fd */
struct io *io;
/* data_input: the proxy's own tee child, which it reads and skips;
   orig_data_input: stream given to lmtp_proxy_start() (no reference
   is taken on it) */
struct istream *data_input, *orig_data_input;
/* client output stream; referenced in lmtp_proxy_init() */
struct ostream *client_output;
/* tee splitting orig_data_input between the proxy and the connections */
struct tee_istream *tee_data_input;
/* largest timeout_msecs of any newly created connection */
unsigned int max_timeout_msecs;
lmtp_proxy_finish_callback_t *finish_callback;
void *finish_context;
/* finish_callback has been called */
unsigned int finished:1;
/* the DATA input idle timeout fired; passed to finish_callback */
unsigned int input_timeout:1;
/* lmtp_proxy_data_input() is running; prevents re-entering it */
unsigned int handling_data_input:1;
};
/* Forward declarations: lmtp_conn_finish() is handed to lmtp_client_init()
   and lmtp_proxy_data_input() is used as an io callback by functions that
   appear before their definitions. */
static void lmtp_conn_finish(void *context);
static void lmtp_proxy_data_input(struct lmtp_proxy *proxy);
/* Create a new proxy that will write its replies to client_output.
   A reference to client_output is held until lmtp_proxy_deinit(). */
struct lmtp_proxy *
lmtp_proxy_init(const char *my_hostname, const char *dns_client_socket_path,
		struct ostream *client_output)
{
	/* every allocation made for the proxy lives in this pool */
	pool_t mem = pool_alloconly_create("lmtp proxy", 1024);
	struct lmtp_proxy *new_proxy = p_new(mem, struct lmtp_proxy, 1);

	new_proxy->pool = mem;
	new_proxy->my_hostname = p_strdup(mem, my_hostname);
	new_proxy->dns_client_socket_path =
		p_strdup(mem, dns_client_socket_path);

	o_stream_ref(client_output);
	new_proxy->client_output = client_output;

	i_array_init(&new_proxy->rcpt_to, 32);
	i_array_init(&new_proxy->connections, 32);
	return new_proxy;
}
static void lmtp_proxy_connections_deinit(struct lmtp_proxy *proxy)
{
struct lmtp_proxy_connection *const *conns;
array_foreach(&proxy->connections, conns) {
struct lmtp_proxy_connection *conn = *conns;
lmtp_client_deinit(&conn->client);
}
}
/* Destroy the proxy and everything it owns, setting *_proxy to NULL. */
void lmtp_proxy_deinit(struct lmtp_proxy **_proxy)
{
struct lmtp_proxy *proxy = *_proxy;
*_proxy = NULL;
/* Backend clients go first, while client_output and the timeouts still
   exist. NOTE(review): lmtp_client_deinit() presumably invokes
   lmtp_conn_finish(), which can write replies to client_output and
   schedule to_finish - that's why those are released below; confirm
   against lmtp-client. */
lmtp_proxy_connections_deinit(proxy);
/* the proxy's own tee child; orig_data_input isn't owned by the proxy */
if (proxy->data_input != NULL)
i_stream_unref(&proxy->data_input);
/* drops the reference taken in lmtp_proxy_init() */
if (proxy->client_output != NULL)
o_stream_unref(&proxy->client_output);
if (proxy->to_data_idle != NULL)
timeout_remove(&proxy->to_data_idle);
if (proxy->to_finish != NULL)
timeout_remove(&proxy->to_finish);
if (proxy->to != NULL)
timeout_remove(&proxy->to);
if (proxy->io != NULL)
io_remove(&proxy->io);
array_free(&proxy->rcpt_to);
array_free(&proxy->connections);
/* frees the proxy struct itself along with all pool allocations */
pool_unref(&proxy->pool);
}
/* Set the MAIL FROM address used for backend connections. The value is
   copied into each lmtp_client's settings when the connection is created
   in lmtp_proxy_get_connection(), so set it before adding recipients. */
void lmtp_proxy_mail_from(struct lmtp_proxy *proxy, const char *value)
{
	const char *sender = p_strdup(proxy->pool, value);

	proxy->mail_from = sender;
}
static struct lmtp_proxy_connection *
lmtp_proxy_get_connection(struct lmtp_proxy *proxy,
const struct lmtp_proxy_settings *set)
{
struct lmtp_proxy_connection *const *conns, *conn;
struct lmtp_client_settings client_set;
i_assert(set->timeout_msecs > 0);
array_foreach(&proxy->connections, conns) {
conn = *conns;
if (conn->set.port == set->port &&
strcmp(conn->set.host, set->host) == 0)
return conn;
}
memset(&client_set, 0, sizeof(client_set));
client_set.mail_from = proxy->mail_from;
client_set.my_hostname = proxy->my_hostname;
client_set.dns_client_socket_path = proxy->dns_client_socket_path;
conn = p_new(proxy->pool, struct lmtp_proxy_connection, 1);
conn->proxy = proxy;
conn->set.host = p_strdup(proxy->pool, set->host);
conn->set.port = set->port;
conn->set.timeout_msecs = set->timeout_msecs;
array_append(&proxy->connections, &conn, 1);
conn->client = lmtp_client_init(&client_set, lmtp_conn_finish, conn);
if (lmtp_client_connect_tcp(conn->client, set->protocol,
conn->set.host, conn->set.port) < 0)
conn->failed = TRUE;
if (proxy->max_timeout_msecs < set->timeout_msecs)
proxy->max_timeout_msecs = set->timeout_msecs;
return conn;
}
/* Send the client, in recipient order, every reply that is ready. Sending
   stops at the first recipient that is still waiting for its DATA reply.
   Recipients whose RCPT TO failed count as ready, and their stored reply
   is sent.
   Returns TRUE once replies for all recipients have been sent. */
static bool lmtp_proxy_send_data_replies(struct lmtp_proxy *proxy)
{
	struct lmtp_proxy_recipient *const *rcpts;
	unsigned int pos, total;

	rcpts = array_get(&proxy->rcpt_to, &total);
	pos = proxy->next_data_reply_idx;

	o_stream_cork(proxy->client_output);
	while (pos < total &&
	       (rcpts[pos]->rcpt_to_failed || rcpts[pos]->data_reply_received)) {
		const char *line = t_strconcat(rcpts[pos]->reply, "\r\n", NULL);

		o_stream_send_str(proxy->client_output, line);
		pos++;
	}
	o_stream_uncork(proxy->client_output);

	proxy->next_data_reply_idx = pos;
	return pos == total;
}
static void lmtp_proxy_finish_timeout(struct lmtp_proxy *proxy)
{
i_assert(!proxy->finished);
timeout_remove(&proxy->to_finish);
proxy->finished = TRUE;
proxy->finish_callback(proxy->input_timeout, proxy->finish_context);
}
/* Schedule lmtp_proxy_finish_timeout(). Calls made while it is already
   scheduled do nothing. */
static void lmtp_proxy_finish(struct lmtp_proxy *proxy)
{
	if (proxy->to_finish != NULL)
		return;

	proxy->to_finish = timeout_add(0, lmtp_proxy_finish_timeout, proxy);
}
static void lmtp_proxy_try_finish(struct lmtp_proxy *proxy)
{
if (proxy->finish_callback == NULL) {
return;
}
if (lmtp_proxy_send_data_replies(proxy) &&
(proxy->data_input == NULL ||
proxy->data_input->eof ||
proxy->data_input->stream_errno != 0 ||
proxy->input_timeout))
lmtp_proxy_finish(proxy);
}
/* lmtp_client finish callback (registered in lmtp_proxy_get_connection()).
   Marks the connection finished, releases its tee child of the DATA input
   and checks whether the whole proxy can finish now. */
static void lmtp_conn_finish(void *context)
{
	struct lmtp_proxy_connection *finished_conn = context;
	struct lmtp_proxy *owner = finished_conn->proxy;

	finished_conn->finished = TRUE;
	if (finished_conn->data_input != NULL)
		i_stream_unref(&finished_conn->data_input);

	lmtp_proxy_try_finish(owner);
}
/* Fail every backend connection with a temporary-failure line mentioning
   reason and the client's current state. If that didn't schedule
   finishing, (re)start watching the DATA input fd.
   NOTE(review): callers reach this only after DATA input EOF or a read
   error, so that io_add() branch looks unreachable in practice - confirm. */
static void lmtp_proxy_fail_all(struct lmtp_proxy *proxy, const char *reason)
{
	struct lmtp_proxy_connection *const *conn_list;
	unsigned int idx, conn_count;

	conn_list = array_get(&proxy->connections, &conn_count);
	for (idx = 0; idx < conn_count; idx++) {
		struct lmtp_client *client = conn_list[idx]->client;
		const char *fail_line =
			t_strdup_printf(ERRSTR_TEMP_REMOTE_FAILURE
				" (%s while waiting for reply to %s)", reason,
				lmtp_client_state_to_string(client));

		lmtp_client_fail(client, fail_line);
	}

	if (proxy->to_finish == NULL && proxy->io == NULL) {
		proxy->io = io_add(i_stream_get_fd(proxy->data_input),
				   IO_READ, lmtp_proxy_data_input, proxy);
	}
}
static void lmtp_proxy_data_input_timeout(struct lmtp_proxy *proxy)
{
struct lmtp_proxy_connection *const *conns;
unsigned int i, count;
proxy->input_timeout = TRUE;
i_stream_close(proxy->orig_data_input);
conns = array_get(&proxy->connections, &count);
for (i = 0; i < count; i++) {
lmtp_client_fail(conns[i]->client, ERRSTR_TEMP_REMOTE_FAILURE
" (timeout in DATA input)");
}
if (proxy->to_finish == NULL) {
lmtp_proxy_finish(proxy);
}
}
/* RCPT TO reply callback for one recipient: store the backend's reply
   and whether it rejected the recipient. Must be called at most once. */
static void
lmtp_proxy_conn_rcpt_to(bool success, const char *reply, void *context)
{
	struct lmtp_proxy_recipient *recipient = context;
	pool_t proxy_pool = recipient->conn->proxy->pool;

	i_assert(recipient->reply == NULL);

	recipient->rcpt_to_failed = !success;
	recipient->reply = p_strdup(proxy_pool, reply);
}
/* DATA reply callback for one recipient. It replaces the stored RCPT TO
   reply with the DATA reply (success or failure alike) and lets the proxy
   forward any replies that are now ready. */
static void
lmtp_proxy_conn_data(bool success ATTR_UNUSED, const char *reply, void *context)
{
	struct lmtp_proxy_recipient *recipient = context;
	struct lmtp_proxy *owner = recipient->conn->proxy;

	/* only recipients that passed RCPT TO may get a DATA reply */
	i_assert(!recipient->rcpt_to_failed);
	i_assert(recipient->reply != NULL);

	recipient->reply = p_strdup(owner->pool, reply);
	recipient->data_reply_received = TRUE;

	lmtp_proxy_try_finish(owner);
}
/* Add a recipient to be proxied to the backend described by set.
   Returns 0 on success. Returns -1 if connecting to that backend failed;
   in that case the recipient is not added. */
int lmtp_proxy_add_rcpt(struct lmtp_proxy *proxy, const char *address,
			const struct lmtp_proxy_settings *set)
{
	struct lmtp_proxy_connection *backend =
		lmtp_proxy_get_connection(proxy, set);
	struct lmtp_proxy_recipient *recipient;

	if (backend->failed)
		return -1;

	recipient = p_new(proxy->pool, struct lmtp_proxy_recipient, 1);
	recipient->conn = backend;
	recipient->address = p_strdup(proxy->pool, address);
	array_append(&proxy->rcpt_to, &recipient, 1);

	lmtp_client_add_rcpt(backend->client, address, lmtp_proxy_conn_rcpt_to,
			     lmtp_proxy_conn_data, recipient);
	return 0;
}
/* Return the smallest read offset among connections whose DATA input
   still has bytes left. Returns (uoff_t)-1 if there is no such
   connection. */
static uoff_t lmtp_proxy_find_lowest_offset(struct lmtp_proxy *proxy)
{
	struct lmtp_proxy_connection *const *conns;
	uoff_t lowest = (uoff_t)-1;

	array_foreach(&proxy->connections, conns) {
		struct istream *input = (*conns)->data_input;

		if (input == NULL || input->v_offset >= lowest)
			continue;
		if (i_stream_have_bytes_left(input))
			lowest = input->v_offset;
	}
	return lowest;
}
/* Fail the connection(s) that have read the least of the DATA input.
   Returns FALSE, doing nothing, if no connection has DATA input left
   to send. */
static bool lmtp_proxy_disconnect_hanging_output(struct lmtp_proxy *proxy)
{
	struct lmtp_proxy_connection *const *conns;
	uoff_t min_offset = lmtp_proxy_find_lowest_offset(proxy);

	if (min_offset == (uoff_t)-1)
		return FALSE;

	array_foreach(&proxy->connections, conns) {
		struct lmtp_proxy_connection *slow = *conns;
		const char *errstr;
		size_t size;

		if (slow->data_input == NULL ||
		    slow->data_input->v_offset != min_offset)
			continue;

		(void)i_stream_get_data(slow->data_input, &size);
		errstr = t_strdup_printf(ERRSTR_TEMP_REMOTE_FAILURE
			" (DATA output stalled for %u secs, "
			"%"PRIuUOFF_T"B sent, %"PRIuSIZE_T"B buffered)",
			proxy->max_timeout_msecs/1000, min_offset, size);
		lmtp_client_fail(slow->client, errstr);
	}
	return TRUE;
}
/* Output-stall timeout. If some connection still has DATA left to send,
   drop the slowest one(s) and resume reading input. Otherwise all DATA
   input has been consumed, and every backend is failed. */
static void lmtp_proxy_output_timeout(struct lmtp_proxy *proxy)
{
	timeout_remove(&proxy->to);

	if (!lmtp_proxy_disconnect_hanging_output(proxy)) {
		i_assert(proxy->data_input->eof);
		lmtp_proxy_fail_all(proxy, "timeout");
		return;
	}
	lmtp_proxy_data_input(proxy);
}
/* Stop watching the DATA input fd and, unless it is already running,
   start the output-stall timeout (max_timeout_msecs). */
static void lmtp_proxy_wait_for_output(struct lmtp_proxy *proxy)
{
	if (proxy->io != NULL)
		io_remove(&proxy->io);

	if (proxy->to != NULL)
		return;
	proxy->to = timeout_add(proxy->max_timeout_msecs,
				lmtp_proxy_output_timeout, proxy);
}
static void proxy_send_more(struct lmtp_proxy *proxy)
{
struct lmtp_proxy_connection *const *conns;
array_foreach(&proxy->connections, conns)
lmtp_client_send_more((*conns)->client);
}
/* Read more DATA input through the proxy's own tee child and skip it.
   The backends' tee children carry the actual data.
   Returns TRUE if something was read and the caller should loop again.
   Returns FALSE after arranging to be called back later (io watcher or
   output timeout), or once the input has ended or failed. */
static bool lmtp_proxy_data_read(struct lmtp_proxy *proxy)
{
	size_t size;

	/* to_data_idle is removed once the input reaches EOF, but this
	   function may still be called afterwards (output timeout,
	   data-sent callback), so it can be NULL here. */
	if (proxy->to_data_idle != NULL)
		timeout_reset(proxy->to_data_idle);

	switch (i_stream_read(proxy->data_input)) {
	case 0:
		if (!tee_i_stream_child_is_waiting(proxy->data_input)) {
			/* nothing new read - wait for more client input */
			if (proxy->io == NULL) {
				proxy->io = io_add(i_stream_get_fd(proxy->data_input),
						   IO_READ,
						   lmtp_proxy_data_input, proxy);
			}
			return FALSE;
		}
		/* fall through - presumably another tee child (a backend)
		   hasn't consumed the buffered data yet, so wait for it */
	case -2:
		/* buffer full: some backend is slower than the input */
		lmtp_proxy_wait_for_output(proxy);
		return FALSE;
	case -1:
		if (proxy->data_input->stream_errno != 0)
			lmtp_proxy_fail_all(proxy, "disconnect");
		else {
			/* All DATA input has been read, so the input idle
			   timeout no longer applies. Leaving it armed would
			   fail the backends with "timeout in DATA input" while
			   we're only waiting for their replies. */
			if (proxy->to_data_idle != NULL)
				timeout_remove(&proxy->to_data_idle);
			/* make sure the backends see the EOF */
			proxy_send_more(proxy);
			/* now wait for replies (bounded by the output timeout) */
			lmtp_proxy_wait_for_output(proxy);
			lmtp_proxy_try_finish(proxy);
		}
		return FALSE;
	default:
		/* progress: cancel any pending output-stall timeout */
		if (proxy->to != NULL)
			timeout_remove(&proxy->to);
		(void)i_stream_get_data(proxy->data_input, &size);
		i_stream_skip(proxy->data_input, size);
		return TRUE;
	}
}
/* Pump DATA input: push pending data to the backends, read more, and
   repeat while reads make progress. Must not be re-entered. */
static void lmtp_proxy_data_input(struct lmtp_proxy *proxy)
{
	i_assert(!proxy->handling_data_input);
	proxy->handling_data_input = TRUE;

	for (;;) {
		proxy_send_more(proxy);
		if (!lmtp_proxy_data_read(proxy))
			break;
	}

	proxy->handling_data_input = FALSE;
}
/* lmtp_client data-output callback. Resume pumping the input if we were
   waiting for output (proxy->to is set) and aren't already inside
   lmtp_proxy_data_input(). */
static void lmtp_proxy_more_data_sent(void *context)
{
	struct lmtp_proxy *owner = context;

	if (owner->to == NULL || owner->handling_data_input)
		return;
	lmtp_proxy_data_input(owner);
}
/* Start proxying DATA.
   data_input is split with a tee: the proxy reads one child itself, and
   each connection that hasn't finished gets its own child plus header.
   callback(input_timeout, context) is invoked once everything is done. */
void lmtp_proxy_start(struct lmtp_proxy *proxy, struct istream *data_input,
		      const char *header,
		      lmtp_proxy_finish_callback_t *callback, void *context)
{
	struct lmtp_proxy_connection *const *conns;

	proxy->finish_callback = callback;
	proxy->finish_context = context;

	proxy->orig_data_input = data_input;
	proxy->tee_data_input = tee_i_stream_create(data_input);
	proxy->data_input = tee_i_stream_create_child(proxy->tee_data_input);
	proxy->to_data_idle = timeout_add(LMTP_PROXY_DATA_INPUT_TIMEOUT_MSECS,
					  lmtp_proxy_data_input_timeout, proxy);

	array_foreach(&proxy->connections, conns) {
		struct lmtp_proxy_connection *backend = *conns;

		if (!backend->finished) {
			lmtp_client_set_data_output_callback(backend->client,
				lmtp_proxy_more_data_sent, proxy);
			backend->data_input =
				tee_i_stream_create_child(proxy->tee_data_input);
			lmtp_client_set_data_header(backend->client, header);
			lmtp_client_send(backend->client, backend->data_input);
		}
	}

	lmtp_proxy_data_input(proxy);
}