#include "config.h"
#include <stdio.h>
#include <gssapi.h>
#include <gssapi_scram.h>
#include <gssapi_spi.h>
#include <err.h>
#include <roken.h>
#include <hex.h>
#include <getarg.h>
#include "test_common.h"
static char *user_string = NULL;
static char *password_string = NULL;
#ifdef ENABLE_SCRAM
#include <heimscram.h>
static unsigned int giterations = 1000;
static heim_scram_data gsalt = {
.data = rk_UNCONST("salt"),
.length = 4
};
static int
param(void *ctx,
const heim_scram_data *user,
heim_scram_data *salt,
unsigned int *iteration,
heim_scram_data *servernonce)
{
if (user->length != strlen(user_string) && memcmp(user->data, user_string, user->length) != 0)
return ENOENT;
*iteration = giterations;
salt->data = malloc(gsalt.length);
memcpy(salt->data, gsalt.data, gsalt.length);
salt->length = gsalt.length;
servernonce->data = NULL;
servernonce->length = 0;
return 0;
}
static int
calculate(void *ctx,
heim_scram_method method,
const heim_scram_data *user,
const heim_scram_data *c1,
const heim_scram_data *s1,
const heim_scram_data *c2noproof,
const heim_scram_data *proof,
heim_scram_data *server,
heim_scram_data *sessionKey)
{
heim_scram_data client_key, client_key2, stored_key, server_key, clientSig;
int ret;
memset(&client_key2, 0, sizeof(client_key2));
ret = heim_scram_stored_key(method,
password_string, giterations, &gsalt,
&client_key, &stored_key, &server_key);
if (ret)
return ret;
ret = heim_scram_generate(method, &stored_key, &server_key,
c1, s1, c2noproof, &clientSig, server);
heim_scram_data_free(&server_key);
if (ret)
goto out;
ret = heim_scram_validate_client_signature(method,
&stored_key,
&clientSig,
proof,
&client_key2);
if (ret)
goto out;
if (client_key2.length != client_key.length ||
memcmp(client_key.data, client_key2.data, client_key.length) != 0) {
ret = EINVAL;
goto out;
}
ret = heim_scram_session_key(method,
&stored_key,
&client_key,
c1, s1, c2noproof, sessionKey);
if (ret)
goto out;
out:
heim_scram_data_free(&stored_key);
heim_scram_data_free(&client_key);
heim_scram_data_free(&client_key2);
return ret;
}
static struct heim_scram_server server_proc = {
.version = SCRAM_SERVER_VERSION_1,
.param = param,
.calculate = calculate
};
static gss_cred_id_t client_cred = GSS_C_NO_CREDENTIAL;
static void
ac_complete(void *ctx, OM_uint32 major, gss_status_id_t status,
gss_cred_id_t cred, gss_OID_set oids, OM_uint32 time_rec)
{
OM_uint32 junk;
if (major) {
fprintf(stderr, "error: %d", (int)major);
gss_release_cred(&junk, &cred);
goto out;
}
client_cred = cred;
out:
gss_release_oid_set(&junk, &oids);
}
static int
test_scram(const char *test_name, const char *user, const char *password)
{
gss_name_t target = GSS_C_NO_NAME;
OM_uint32 maj_stat, min_stat;
gss_ctx_id_t ctx = GSS_C_NO_CONTEXT;
gss_buffer_desc input, output, output2;
int ret;
heim_scram *scram = NULL;
heim_scram_data in, out;
gss_auth_identity_desc identity;
memset(&identity, 0, sizeof(identity));
identity.username = rk_UNCONST(user);
identity.realm = "";
identity.password = rk_UNCONST(password);
maj_stat = gss_acquire_cred_ex_f(NULL,
GSS_C_NO_NAME,
0,
GSS_C_INDEFINITE,
GSS_SCRAM_MECHANISM,
GSS_C_INITIATE,
&identity,
NULL,
ac_complete);
if (maj_stat)
errx(1, "gss_acquire_cred_ex_f: %d", (int)maj_stat);
if (client_cred == GSS_C_NO_CREDENTIAL)
errx(1, "gss_acquire_cred_ex_f");
maj_stat = gss_init_sec_context(&min_stat, client_cred, &ctx,
target, GSS_SCRAM_MECHANISM,
0, 0, NULL,
GSS_C_NO_BUFFER, NULL,
&output, NULL, NULL);
if (maj_stat != GSS_S_CONTINUE_NEEDED)
errx(1, "accept_sec_context %s %s", test_name,
gssapi_err(maj_stat, min_stat, GSS_C_NO_OID));
if (output.length == 0)
errx(1, "output.length == 0");
maj_stat = gss_decapsulate_token(&output, GSS_SCRAM_MECHANISM, &output2);
if (maj_stat)
errx(1, "decapsulate token");
in.length = output2.length;
in.data = output2.value;
ret = heim_scram_server1(&in, NULL, HEIM_SCRAM_DIGEST_SHA1, &server_proc, NULL, &scram, &out);
if (ret)
errx(1, "heim_scram_server1");
gss_release_buffer(&min_stat, &output);
input.length = out.length;
input.value = out.data;
maj_stat = gss_init_sec_context(&min_stat, client_cred, &ctx,
target, GSS_SCRAM_MECHANISM,
0, 0, NULL,
&input, NULL,
&output, NULL, NULL);
if (maj_stat != GSS_S_CONTINUE_NEEDED) {
warnx("accept_sec_context v1 2 %s",
gssapi_err(maj_stat, min_stat, GSS_C_NO_OID));
return 1;
}
in.length = output.length;
in.data = output.value;
ret = heim_scram_server2(&in, scram, &out);
if (ret)
errx(1, "heim_scram_server2");
gss_release_buffer(&min_stat, &output);
input.length = out.length;
input.value = out.data;
maj_stat = gss_init_sec_context(&min_stat, client_cred, &ctx,
target, GSS_SCRAM_MECHANISM,
0, 0, NULL,
&input, NULL,
&output, NULL, NULL);
if (maj_stat != GSS_S_COMPLETE) {
warnx("accept_sec_context v1 2 %s",
gssapi_err(maj_stat, min_stat, GSS_C_NO_OID));
return 1;
}
heim_scram_free(scram);
printf("done: %s\n", test_name);
return 0;
}
#endif
static int version_flag = 0;
static int help_flag = 0;
static struct getargs args[] = {
{"user", 0, arg_string, &user_string, "user name", "user" },
{"password",0, arg_string, &password_string, "password", "password" },
{"version", 0, arg_flag, &version_flag, "print version", NULL },
{"help", 0, arg_flag, &help_flag, NULL, NULL }
};
static void
usage (int ret)
{
arg_printusage (args, sizeof(args)/sizeof(*args),
NULL, "");
exit (ret);
}
int
main(int argc, char **argv)
{
int ret = 0, optidx = 0;
setprogname(argv[0]);
if(getarg(args, sizeof(args) / sizeof(args[0]), argc, argv, &optidx))
usage(1);
if (help_flag)
usage(0);
if(version_flag){
print_version(NULL);
exit(0);
}
if (user_string == NULL)
errx(1, "no username");
if (password_string == NULL)
errx(1, "no password");
#ifdef ENABLE_SCRAM
ret += test_scram("scram", user_string, password_string);
#endif
return (ret != 0) ? 1 : 0;
}