#include "sys_defs.h"
#ifdef HAS_PGSQL
#include <sys/socket.h>
#include <netinet/in.h>
#include <arpa/inet.h>
#include <netdb.h>
#include <stdio.h>
#include <string.h>
#include <stdlib.h>
#include <syslog.h>
#include <time.h>
#include <postgres_ext.h>
#include <libpq-fe.h>
#include "dict.h"
#include "msg.h"
#include "mymalloc.h"
#include "argv.h"
#include "vstring.h"
#include "split_at.h"
#include "find_inet.h"
#include "myrand.h"
#include "events.h"
#include "cfg_parser.h"
#include "dict_pgsql.h"
/* Host status values stored in HOST.stat. */
#define STATACTIVE (1<<0)
#define STATFAIL (1<<1)
#define STATUNTRIED (1<<2)
/* Host connection types stored in HOST.type (chosen in host_init()). */
#define TYPEUNIX (1<<0)
#define TYPEINET (1<<1)
/* Bounds the connection-attempt loop in dict_pgsql_get_active(). */
#define RETRY_CONN_MAX 100
/* Added to time() when a host fails; the host is skipped until then. */
#define RETRY_CONN_INTV 60
/* Idle timer armed after a successful query; on expiry the connection is
 * closed. NOTE(review): unit is whatever event_request_timer() uses,
 * presumably seconds. */
#define IDLE_CONN_INTV 60
/* One configured database server and its connection state. */
typedef struct {
PGconn *db; /* open connection, or null when not connected */
char *hostname; /* host specification exactly as configured */
char *name; /* host (or path) with any "unix:"/"inet:" prefix and ":port" removed */
char *port; /* text after the last ':' in name; may be null */
unsigned type; /* TYPEUNIX or TYPEINET */
unsigned stat; /* STATACTIVE, STATFAIL or STATUNTRIED */
time_t ts; /* while STATFAIL: host is skipped until time() passes this */
} HOST;
/* The set of configured servers for one dictionary. */
typedef struct {
int len_hosts; /* number of entries in db_hosts */
HOST **db_hosts;
} PLPGSQL;
/* Settings read from the dictionary's configuration file. */
typedef struct {
CFG_PARSER *parser;
char *username;
char *password;
char *dbname;
char *table; /* table/select_field/where_field/additional_conditions are */
char *query; /* only read when neither query nor select_function is set; */
char *select_function; /* otherwise they are null */
char *select_field;
char *where_field;
char *additional_conditions;
char **hostnames; /* len_hosts copies of the configured host specs */
int len_hosts;
} PGSQL_NAME;
/* Dictionary instance; the generic DICT must stay the first member so a
 * DICT * can be cast back to DICT_PGSQL *. */
typedef struct {
DICT dict;
PLPGSQL *pldb;
PGSQL_NAME *name;
} DICT_PGSQL;
/* Result type returned by the query layer. */
#define PGSQL_RES PGresult
/* Forward declarations. */
static PLPGSQL *plpgsql_init(char *hostnames[], int);
static PGSQL_RES *plpgsql_query(PLPGSQL *, const char *, char *, char *, char *);
static void plpgsql_dealloc(PLPGSQL *);
static void plpgsql_close_host(HOST *);
static void plpgsql_down_host(HOST *);
static void plpgsql_connect_single(HOST *, char *, char *, char *);
static const char *dict_pgsql_lookup(DICT *, const char *);
DICT *dict_pgsql_open(const char *, int, int);
static void dict_pgsql_close(DICT *);
static PGSQL_NAME *pgsqlname_parse(const char *);
static HOST *host_init(const char *);
/*
 * pgsql_escape_string - quote a lookup key for use inside a '...' SQL
 * string literal.
 *
 * new: output buffer. The caller must supply at least 2 * len + 1 bytes:
 *      every input byte expands to at most two output bytes, and a NUL
 *      terminator is appended.
 * old: bytes to escape.
 * len: number of bytes of old to process.
 *
 * Single quotes are doubled ('') rather than written as \' because the
 * doubled form is a quote character in PostgreSQL whether or not the
 * server treats backslash as an escape (standard_conforming_strings).
 * Backslashes are doubled so that a key containing '\' cannot consume the
 * escape that follows it and terminate the literal early.
 */
static void pgsql_escape_string(char *new, const char *old, unsigned int len)
{
    unsigned int x,
            y;

    for (x = 0, y = 0; x < len; x++, y++) {
	switch (old[x]) {
	case '\n':
	    new[y++] = '\\';
	    new[y] = 'n';
	    break;
	case '\r':
	    new[y++] = '\\';
	    new[y] = 'r';
	    break;
	case '\'':
	    /* Doubling works in both backslash and standard-string modes. */
	    new[y++] = '\'';
	    new[y] = '\'';
	    break;
	case '\\':
	    /* Previously copied through unescaped: SQL injection via "\'". */
	    new[y++] = '\\';
	    new[y] = '\\';
	    break;
	case '"':
	    new[y++] = '\\';
	    new[y] = '"';
	    break;
	case 0:
	    new[y++] = '\\';
	    new[y] = '0';
	    break;
	default:
	    new[y] = old[x];
	    break;
	}
    }
    new[y] = 0;
}
/*
 * dict_pgsql_expand_filter - append "filter" to "out", expanding these
 * directives from "value":
 *
 *	%s	the whole value
 *	%u	the part before the last '@' (the whole value if there is no '@')
 *	%d	the part after the last '@' (nothing if there is no '@')
 *
 * Any other character following '%' is reported with msg_warn() and is
 * dropped together with the '%'. All other characters are copied as-is.
 */
static void dict_pgsql_expand_filter(char *filter, char *value, VSTRING *out)
{
    const char *myname = "dict_pgsql_expand_filter";
    char   *limit = filter + strlen(filter);
    char   *cp;
    char   *at;

    for (cp = filter; cp < limit; cp++) {
	if (*cp != '%') {
	    vstring_strncat(out, cp, 1);
	    continue;
	}
	/* Step onto the directive character; the loop increment skips it. */
	cp++;
	at = strrchr(value, '@');
	if (*cp == 's') {
	    vstring_strcat(out, value);
	} else if (*cp == 'u') {
	    if (at != 0)
		vstring_strncat(out, value, at - value);
	    else
		vstring_strcat(out, value);
	} else if (*cp == 'd') {
	    if (at != 0)
		vstring_strcat(out, at + 1);
	} else {
	    msg_warn
		("%s: Invalid filter substitution format '%%%c'!",
		 myname, *cp);
	}
    }
}
/*
 * dict_pgsql_lookup - look up "name" in the database.
 *
 * The SQL text is built from, in order of preference: select_function
 * ("select f('name')"), the configured query template (with %s/%u/%d
 * expanded), or the table/select_field/where_field form. The key is
 * escaped with pgsql_escape_string() first.
 *
 * Returns every field of every result row joined with ',', in a static
 * buffer that is overwritten by the next call. Returns a null pointer:
 *  - with dict_errno = DICT_ERR_RETRY when no server answered;
 *  - with dict_errno = 0 when no rows came back, or when a select_function
 *    call returned a single NULL value.
 */
static const char *dict_pgsql_lookup(DICT *dict, const char *name)
{
    PGSQL_RES *query_res;
    DICT_PGSQL *dict_pgsql;
    PLPGSQL *pldb;
    static VSTRING *result;
    static VSTRING *query;
    int     i,
            j,
            numrows;
    size_t  name_len;
    char   *name_escaped = 0;
    int     isFunctionCall;
    int     numcols;

    dict_pgsql = (DICT_PGSQL *) dict;
    pldb = dict_pgsql->pldb;

    /*
     * The query buffer is kept across calls, like "result" below. The old
     * code declared it static yet allocated and freed it on every lookup.
     */
    if (query == 0)
	query = vstring_alloc(24);
    /* dict_pgsql_expand_filter() appends, so start from an empty string. */
    vstring_strcpy(query, "");

    /* Escaping at most doubles the length; +1 for the terminator. */
    name_len = strlen(name);
    if ((name_escaped = (char *) mymalloc(2 * name_len + 1)) == NULL) {
	msg_fatal("dict_pgsql_lookup: out of memory.");
    }
    pgsql_escape_string(name_escaped, name, (unsigned int) name_len);

    isFunctionCall = (dict_pgsql->name->select_function != NULL);
    if (isFunctionCall) {
	vstring_sprintf(query, "select %s('%s')",
			dict_pgsql->name->select_function,
			name_escaped);
    } else if (dict_pgsql->name->query) {
	dict_pgsql_expand_filter(dict_pgsql->name->query, name_escaped, query);
    } else {
	vstring_sprintf(query, "select %s from %s where %s = '%s' %s",
			dict_pgsql->name->select_field,
			dict_pgsql->name->table,
			dict_pgsql->name->where_field,
			name_escaped,
			dict_pgsql->name->additional_conditions);
    }
    if (msg_verbose)
	msg_info("dict_pgsql_lookup using sql query: %s", vstring_str(query));
    myfree(name_escaped);

    if ((query_res = plpgsql_query(pldb,
				   vstring_str(query),
				   dict_pgsql->name->dbname,
				   dict_pgsql->name->username,
				   dict_pgsql->name->password)) == 0) {
	dict_errno = DICT_ERR_RETRY;
	return 0;
    }
    dict_errno = 0;

    numrows = PQntuples(query_res);
    if (msg_verbose)
	msg_info("dict_pgsql_lookup: retrieved %d rows", numrows);
    if (numrows == 0) {
	PQclear(query_res);
	return 0;
    }
    numcols = PQnfields(query_res);

    /* A function returning SQL NULL means "not found". */
    if (numcols == 1 && numrows == 1 && isFunctionCall) {
	if (PQgetisnull(query_res, 0, 0) == 1) {
	    PQclear(query_res);
	    return 0;
	}
    }
    if (result == 0)
	result = vstring_alloc(10);
    vstring_strcpy(result, "");
    for (i = 0; i < numrows; i++) {
	if (i > 0)
	    vstring_strcat(result, ",");
	for (j = 0; j < numcols; j++) {
	    if (j > 0)
		vstring_strcat(result, ",");
	    vstring_strcat(result, PQgetvalue(query_res, i, j));
	    if (msg_verbose > 1)
		msg_info("dict_pgsql_lookup: retrieved field: %d: %s", j, PQgetvalue(query_res, i, j));
	}
    }
    PQclear(query_res);
    return vstring_str(result);
}
/*
 * dict_pgsql_check_stat - is "host" a candidate for selection?
 *
 * A host qualifies when:
 *  - its status shares a bit with "stat";
 *  - "type" is zero, or the host's type shares a bit with "type";
 *  - it is not a failed host (exactly STATFAIL) whose retry time stamp is
 *    set and has not yet passed the current time "t".
 *
 * Returns 1 for a candidate, 0 otherwise.
 */
static int dict_pgsql_check_stat(HOST *host, unsigned stat, unsigned type,
				         time_t t)
{
    int     stat_matches = (host->stat & stat) != 0;
    int     type_matches = (type == 0) || ((host->type & type) != 0);
    int     in_backoff = host->stat == STATFAIL
	&& host->ts > 0 && host->ts >= t;

    return (stat_matches && type_matches && !in_backoff);
}
/*
 * dict_pgsql_find_host - pick a random host that passes
 * dict_pgsql_check_stat(stat, type), or return a null pointer if none does.
 *
 * The old scaling 1 + (count-1) * myrand() / RAND_MAX reached "count" only
 * when myrand() returned exactly RAND_MAX, so the last eligible host was
 * practically never chosen. Scaling by count / (RAND_MAX + 1.0) maps the
 * random value uniformly onto 1..count. (Assumes myrand() returns
 * 0..RAND_MAX, as the original RAND_MAX scaling implies.)
 */
static HOST *dict_pgsql_find_host(PLPGSQL *PLDB, unsigned stat, unsigned type)
{
    time_t  t;
    int     count = 0;
    int     idx;
    int     i;

    t = time((time_t *) 0);

    /* First pass: count the eligible hosts. */
    for (i = 0; i < PLDB->len_hosts; i++) {
	if (dict_pgsql_check_stat(PLDB->db_hosts[i], stat, type, t))
	    count++;
    }
    if (count) {
	/* 1-based index into the eligible hosts, uniform over 1..count. */
	idx = (count > 1) ?
	    1 + (int) (count * (double) myrand() / (1.0 + RAND_MAX)) : 1;

	/* Second pass: return the idx-th eligible host. */
	for (i = 0; i < PLDB->len_hosts; i++) {
	    if (dict_pgsql_check_stat(PLDB->db_hosts[i], stat, type, t) &&
		--idx == 0)
		return PLDB->db_hosts[i];
	}
    }
    return 0;
}
/*
 * dict_pgsql_get_active - return a host with an open connection.
 *
 * An already-active host is preferred, UNIX-domain before TCP. Otherwise
 * untried or failed hosts (failed ones only after their retry time has
 * passed) are connected to one at a time, again UNIX-domain first. The
 * first one that becomes STATACTIVE is returned. At most RETRY_CONN_MAX - 1
 * connection attempts are made. Returns a null pointer if no host could be
 * connected.
 */
static HOST *dict_pgsql_get_active(PLPGSQL *PLDB, char *dbname,
				           char *username, char *password)
{
    const char *myname = "dict_pgsql_get_active";
    HOST   *host;
    int     attempts_left;

    host = dict_pgsql_find_host(PLDB, STATACTIVE, TYPEUNIX);
    if (host == NULL)
	host = dict_pgsql_find_host(PLDB, STATACTIVE, TYPEINET);
    if (host != NULL) {
	if (msg_verbose)
	    msg_info("%s: found active connection to host %s", myname,
		     host->hostname);
	return host;
    }

    for (attempts_left = RETRY_CONN_MAX - 1; attempts_left > 0; attempts_left--) {
	host = dict_pgsql_find_host(PLDB, STATUNTRIED | STATFAIL, TYPEUNIX);
	if (host == NULL)
	    host = dict_pgsql_find_host(PLDB, STATUNTRIED | STATFAIL, TYPEINET);
	if (host == NULL)
	    break;
	if (msg_verbose)
	    msg_info("%s: attempting to connect to host %s", myname,
		     host->hostname);
	plpgsql_connect_single(host, dbname, username, password);
	if (host->stat == STATACTIVE)
	    return host;
    }
    return 0;
}
/*
 * dict_pgsql_event - idle-connection timer callback (armed after each
 * successful query). "context" is the HOST to act on. If that host still
 * has an open connection, it is closed with plpgsql_close_host().
 */
static void dict_pgsql_event(int unused_event, char *context)
{
    HOST   *idle_host = (HOST *) context;

    if (idle_host->db == 0)
	return;
    plpgsql_close_host(idle_host);
}
/*
 * plpgsql_query - run "query" on one of the configured servers.
 *
 * Hosts are obtained from dict_pgsql_get_active() until one gives a
 * successful result (PGRES_TUPLES_OK or PGRES_COMMAND_OK). A host is
 * marked down and the next host is tried when either:
 *  - PQexec() returns no result at all, or
 *  - the query failed and the connection is no longer CONNECTION_OK.
 * A query that fails on a healthy connection (for example an error in the
 * configured SQL) is not retried on other hosts.
 *
 * PQexec() reports most failures through the result status rather than a
 * null result. Without the status check, a lost connection or an SQL error
 * looked like an empty result set, so the caller reported "not found"
 * instead of a temporary failure.
 *
 * On success the host's idle-connection timer is (re)armed and the result
 * is returned; the caller owns it and must PQclear() it. Returns a null
 * pointer on failure.
 */
static PGSQL_RES *plpgsql_query(PLPGSQL *PLDB,
				        const char *query,
				        char *dbname,
				        char *username,
				        char *password)
{
    HOST   *host;
    PGSQL_RES *res = 0;
    int     ok;

    while ((host = dict_pgsql_get_active(PLDB, dbname, username, password)) != NULL) {
	if ((res = PQexec(host->db, query)) == 0) {
	    msg_warn("pgsql query failed: %s", PQerrorMessage(host->db));
	    plpgsql_down_host(host);
	    continue;
	}
	switch (PQresultStatus(res)) {
	case PGRES_TUPLES_OK:
	case PGRES_COMMAND_OK:
	    ok = 1;
	    break;
	default:
	    ok = 0;
	    break;
	}
	if (ok) {
	    if (msg_verbose)
		msg_info("dict_pgsql: successful query from host %s", host->hostname);
	    event_request_timer(dict_pgsql_event, (char *) host, IDLE_CONN_INTV);
	    return res;
	}
	msg_warn("pgsql query failed: %s", PQresultErrorMessage(res));
	PQclear(res);
	res = 0;
	if (PQstatus(host->db) != CONNECTION_OK) {
	    /* Broken connection: give up on this host, try another one. */
	    plpgsql_down_host(host);
	    continue;
	}
	/* Query-level error on a healthy connection: keep it, report failure. */
	event_request_timer(dict_pgsql_event, (char *) host, IDLE_CONN_INTV);
	break;
    }
    return res;
}
/*
 * plpgsql_connect_single - open a connection to one server.
 *
 * On success the host becomes STATACTIVE. On failure a warning is logged
 * and plpgsql_down_host() is called. That releases any partial connection
 * and sets the time stamp before which the host will not be retried.
 */
static void plpgsql_connect_single(HOST *host, char *dbname, char *username, char *password)
{
    host->db = PQsetdbLogin(host->name, host->port, NULL, NULL,
			    dbname, username, password);

    /* A null handle and a handle in a bad state are treated the same. */
    if (host->db == NULL || PQstatus(host->db) != CONNECTION_OK) {
	msg_warn("connect to pgsql server %s: %s",
		 host->hostname, PQerrorMessage(host->db));
	plpgsql_down_host(host);
	return;
    }
    if (msg_verbose)
	msg_info("dict_pgsql: successful connection to host %s",
		 host->hostname);
    host->stat = STATACTIVE;
}
/*
 * plpgsql_close_host - close the host's connection, if any, and reset it
 * to STATUNTRIED. Unlike plpgsql_down_host(), no retry delay is imposed.
 */
static void plpgsql_close_host(HOST *host)
{
    if (host->db != 0) {
	PQfinish(host->db);
	host->db = 0;
    }
    host->stat = STATUNTRIED;
}
/*
 * plpgsql_down_host - mark a host as failed.
 *
 * Closes any connection and sets the status to STATFAIL. The retry time
 * stamp is set RETRY_CONN_INTV past the current time, and any pending
 * idle-connection timer for this host is cancelled.
 */
static void plpgsql_down_host(HOST *host)
{
    if (host->db != 0) {
	PQfinish(host->db);
	host->db = 0;
    }
    host->stat = STATFAIL;
    host->ts = time((time_t *) 0) + RETRY_CONN_INTV;
    event_cancel_timer(dict_pgsql_event, (char *) host);
}
/*
 * dict_pgsql_open - open a PostgreSQL dictionary.
 *
 * name:       path of the configuration file parsed by pgsqlname_parse().
 * open_flags: must be O_RDONLY; anything else is fatal.
 * dict_flags: dictionary flags; DICT_FLAG_FIXED is added.
 *
 * Returns the generic DICT handle. Fails fatally if the host pool cannot
 * be initialized.
 */
DICT   *dict_pgsql_open(const char *name, int open_flags, int dict_flags)
{
    DICT_PGSQL *dict_pgsql;

    if (open_flags != O_RDONLY)
	msg_fatal("%s:%s map requires O_RDONLY access mode",
		  DICT_TYPE_PGSQL, name);
    dict_pgsql = (DICT_PGSQL *) dict_alloc(DICT_TYPE_PGSQL, name,
					   sizeof(DICT_PGSQL));
    dict_pgsql->dict.lookup = dict_pgsql_lookup;
    dict_pgsql->dict.close = dict_pgsql_close;
    dict_pgsql->name = pgsqlname_parse(name);
    dict_pgsql->pldb = plpgsql_init(dict_pgsql->name->hostnames,
				    dict_pgsql->name->len_hosts);
    dict_pgsql->dict.flags = dict_flags | DICT_FLAG_FIXED;
    /* Spelling fixed; no trailing newline, like every other msg_* call here. */
    if (dict_pgsql->pldb == NULL)
	msg_fatal("couldn't initialize pldb!");
    return &dict_pgsql->dict;
}
static PGSQL_NAME *pgsqlname_parse(const char *pgsqlcf)
{
const char *myname = "pgsqlname_parse";
int i;
char *hosts;
PGSQL_NAME *name = (PGSQL_NAME *) mymalloc(sizeof(PGSQL_NAME));
ARGV *hosts_argv;
name->parser = cfg_parser_alloc(pgsqlcf);
name->username = cfg_get_str(name->parser, "user", "", 0, 0);
name->password = cfg_get_str(name->parser, "password", "", 0, 0);
name->dbname = cfg_get_str(name->parser, "dbname", "", 1, 0);
name->select_function = cfg_get_str(name->parser, "select_function",
NULL, 0, 0);
name->query = cfg_get_str(name->parser, "query", NULL, 0, 0);
if (name->select_function == 0 && name->query == 0) {
name->table = cfg_get_str(name->parser, "table", "", 1, 0);
name->select_field = cfg_get_str(name->parser, "select_field",
"", 1, 0);
name->where_field = cfg_get_str(name->parser, "where_field",
"", 1, 0);
name->additional_conditions = cfg_get_str(name->parser,
"additional_conditions",
"", 0, 0);
} else {
name->table = 0;
name->select_field = 0;
name->where_field = 0;
name->additional_conditions = 0;
}
hosts = cfg_get_str(name->parser, "hosts", "", 0, 0);
hosts_argv = argv_split(hosts, " ,\t\r\n");
if (hosts_argv->argc == 0) {
if (msg_verbose)
msg_info("%s: %s: no hostnames specified, defaulting to 'localhost'",
myname, pgsqlcf);
argv_add(hosts_argv, "localhost", ARGV_END);
argv_terminate(hosts_argv);
}
name->len_hosts = hosts_argv->argc;
name->hostnames = (char **) mymalloc((sizeof(char *)) * name->len_hosts);
i = 0;
for (i = 0; hosts_argv->argv[i] != NULL; i++) {
name->hostnames[i] = mystrdup(hosts_argv->argv[i]);
if (msg_verbose)
msg_info("%s: %s: adding host '%s' to list of pgsql server hosts",
myname, pgsqlcf, name->hostnames[i]);
}
myfree(hosts);
argv_free(hosts_argv);
return name;
}
/*
 * plpgsql_init - build the host pool for "len_hosts" host specifications.
 *
 * Each entry is created with host_init() and starts out STATUNTRIED.
 * Returns the new pool, or a null pointer if the host array could not be
 * allocated. In that case the partially built pool is released instead of
 * leaked.
 */
static PLPGSQL *plpgsql_init(char *hostnames[], int len_hosts)
{
    PLPGSQL *PLDB;
    int     i;

    if ((PLDB = (PLPGSQL *) mymalloc(sizeof(PLPGSQL))) == NULL) {
	msg_fatal("mymalloc of pldb failed");
    }
    PLDB->len_hosts = len_hosts;
    if ((PLDB->db_hosts = (HOST **) mymalloc(sizeof(HOST *) * len_hosts)) == NULL) {
	/* Previously PLDB itself was leaked on this path. */
	myfree((char *) PLDB);
	return NULL;
    }
    for (i = 0; i < len_hosts; i++) {
	PLDB->db_hosts[i] = host_init(hostnames[i]);
    }
    return PLDB;
}
/*
 * host_init - create a HOST from one configured host specification.
 *
 * An optional "unix:" or "inet:" prefix is removed. Text after the last
 * ':' becomes the port (via split_at_right()). A remaining name that is
 * empty or starts with '/' is classified TYPEUNIX; anything else is
 * TYPEINET. The host starts STATUNTRIED, with no connection.
 */
static HOST *host_init(const char *hostname)
{
    const char *myname = "pgsql host_init";
    HOST   *host = (HOST *) mymalloc(sizeof(HOST));
    const char *spec = hostname;

    host->db = 0;
    host->hostname = mystrdup(hostname);
    host->stat = STATUNTRIED;
    host->ts = 0;

    if (strncmp(spec, "unix:", 5) == 0 || strncmp(spec, "inet:", 5) == 0)
	spec += 5;
    host->name = mystrdup(spec);
    host->port = split_at_right(host->name, ':');

    host->type = (host->name[0] == 0 || host->name[0] == '/') ?
	TYPEUNIX : TYPEINET;

    if (msg_verbose > 1)
	msg_info("%s: host=%s, port=%s, type=%s", myname, host->name,
		 host->port ? host->port : "",
		 host->type == TYPEUNIX ? "unix" : "inet");
    return host;
}
/*
 * dict_pgsql_close - release a dictionary opened by dict_pgsql_open().
 *
 * Frees, in order:
 *  - the host pool;
 *  - the config parser;
 *  - the always-present settings;
 *  - whichever optional query settings are non-null;
 *  - the host name copies;
 *  - the dictionary itself.
 */
static void dict_pgsql_close(DICT *dict)
{
    DICT_PGSQL *dict_pgsql = (DICT_PGSQL *) dict;
    PGSQL_NAME *cfg = dict_pgsql->name;
    char   *optional[6];
    int     n;

    plpgsql_dealloc(dict_pgsql->pldb);
    cfg_parser_free(cfg->parser);
    myfree(cfg->username);
    myfree(cfg->password);
    myfree(cfg->dbname);

    /* These may be null, depending on which query form was configured. */
    optional[0] = cfg->table;
    optional[1] = cfg->query;
    optional[2] = cfg->select_function;
    optional[3] = cfg->select_field;
    optional[4] = cfg->where_field;
    optional[5] = cfg->additional_conditions;
    for (n = 0; n < (int) (sizeof(optional) / sizeof(optional[0])); n++)
	if (optional[n] != 0)
	    myfree(optional[n]);

    for (n = 0; n < cfg->len_hosts; n++)
	myfree(cfg->hostnames[n]);
    myfree((char *) cfg->hostnames);
    myfree((char *) cfg);
    dict_free(dict);
}
/*
 * plpgsql_dealloc - destroy a host pool.
 *
 * For each host: cancel its idle-connection timer, close any open
 * connection, and free its strings and the host record. Then free the
 * pool itself.
 */
static void plpgsql_dealloc(PLPGSQL *PLDB)
{
    HOST   *host;
    int     n;

    for (n = 0; n < PLDB->len_hosts; n++) {
	host = PLDB->db_hosts[n];
	event_cancel_timer(dict_pgsql_event, (char *) host);
	if (host->db)
	    PQfinish(host->db);
	myfree(host->hostname);
	myfree(host->name);
	myfree((char *) host);
    }
    myfree((char *) PLDB->db_hosts);
    myfree((char *) PLDB);
}
#endif