/* test-gss-server.c */


/*
 * gssServerSample.c - gssSample server program
 *
 * Copyright 2004-2005 Massachusetts Institute of Technology.
 * All Rights Reserved.
 *
 * WITHIN THAT CONSTRAINT, permission to use, copy, modify, and
 * distribute this software and its documentation for any purpose and
 * without fee is hereby granted, provided that the above copyright
 * notice appear in all copies and that both that copyright notice and
 * this permission notice appear in supporting documentation, and that
 * the name of M.I.T. not be used in advertising or publicity pertaining
 * to distribution of the software without specific, written prior
 * permission.  Furthermore if you modify this software you must label
 * your software as modified software and not distribute it in such a
 * fashion that it might be confused with the original M.I.T. software.
 * M.I.T. makes no representations about the suitability of
 * this software for any purpose.  It is provided "as is" without express
 * or implied warranty.
 */


#include <sys/types.h>
#include <stdio.h>
#include <stdlib.h>
#include <unistd.h>
#include <string.h>
#include <errno.h>
#include <sys/socket.h>
#include <netinet/in.h>
#include <Kerberos/Kerberos.h>
#include "test-gss-common.h"

/*
 * Optional service name supplied with "--sname" on the command line (set in
 * main).  When non-NULL, ServicePrincipalIsValidForService() rejects service
 * principals whose first component does not match it; when NULL, no service
 * name check is performed.
 */
const char *gServiceName = NULL;

/* --------------------------------------------------------------------------- */

/*
 * SetupListeningSocket - create a TCP socket bound to all local IPv4
 * addresses on inPort and put it into the listening state.
 *
 *   inPort - TCP port in host byte order; must fit in 16 bits (0..65535).
 *   outFD  - on success, receives the listening descriptor.  The caller owns
 *            it and must close() it.  Left untouched on failure.
 *
 * Returns 0 on success or an errno value on failure (also reported through
 * printError).  The socket is closed on every failure path.
 */
static int SetupListeningSocket (int inPort, int *outFD)
{
    int err = 0;
    int fd = -1;

    /* htons() would silently truncate a value outside the 16-bit port range */
    if ((inPort < 0) || (inPort > 65535)) { err = EINVAL; }
    if (outFD == NULL)                    { err = EINVAL; }

    if (!err) {
        fd = socket (AF_INET, SOCK_STREAM, 0);
        if (fd < 0) { err = errno; }
    }

    if (!err) {
        struct sockaddr_in saddr;

        /* Zero the whole address, including the sin_zero padding, so no
         * uninitialized stack bytes are handed to bind(). */
        memset (&saddr, 0, sizeof (saddr));
        saddr.sin_len = sizeof (saddr);   /* BSD/macOS-specific length field */
        saddr.sin_family = AF_INET;
        saddr.sin_port = htons ((in_port_t) inPort);
        saddr.sin_addr.s_addr = htonl (INADDR_ANY);

        if (bind (fd, (struct sockaddr *) &saddr, sizeof (saddr)) < 0) { err = errno; }
    }

    if (!err) {
        if (listen (fd, 5) < 0) { err = errno; }
    }

    if (!err) {
        printf ("listening on port %d\n", inPort);
        *outFD = fd;
        fd = -1; /* ownership passed to the caller; only close on error */
    } else {
        printError (err, "SetupListeningSocket failed");
    }

    if (fd >= 0) { close (fd); }

    return err;
}

/* --------------------------------------------------------------------------- */

/*
 * Authenticate - run the server side of GSS-API security context
 * establishment over inSocket.
 *
 *   inSocket   - connected socket to the client.
 *   outContext - on success, receives the established context; the caller
 *                must release it with gss_delete_sec_context().  Left
 *                untouched on failure.
 *
 * Returns 0 on success, otherwise a non-zero error: EINVAL for bad
 * arguments, an error from ReadToken/WriteToken, or the GSS minor status
 * (falling back to the major status) from gss_accept_sec_context.  On
 * failure any partially established context is deleted.
 */
static int Authenticate (int inSocket, gss_ctx_id_t *outContext)
{
    int err = 0;
    OM_uint32 majorStatus = GSS_S_CONTINUE_NEEDED;
    OM_uint32 minorStatus = 0;
    gss_ctx_id_t context = GSS_C_NO_CONTEXT;

    char *inputTokenBuffer = NULL;
    size_t inputTokenBufferLength = 0;

    if (inSocket    <  0)    { err = EINVAL; }
    if (outContext  == NULL) { err = EINVAL; }

    /*
     * The main authentication loop:
     *
     * GSS is a multimechanism API.  The number of packet exchanges required to authenticate
     * varies between mechanisms.  As a result, we need to loop reading "input tokens" from
     * the client, calling gss_accept_sec_context on the "input tokens" and sending the resulting
     * "output tokens" back to the client until we get GSS_S_COMPLETE or an error.
     *
     * When we are done, save the client principal so we can make authorization checks.
     */

    while (!err && (majorStatus != GSS_S_COMPLETE)) {
        /* Release the previous round's input token (free(NULL) is a no-op) */
        free (inputTokenBuffer);
        inputTokenBuffer = NULL;  /* don't double-free */

        err = ReadToken (inSocket, &inputTokenBuffer, &inputTokenBufferLength);

        if (!err) {
            gss_buffer_desc inputToken = { 0, NULL };  /* token received from the client */
            gss_buffer_desc outputToken = { 0, NULL }; /* token to send to the client */

            inputToken.value = inputTokenBuffer;
            inputToken.length = inputTokenBufferLength;

            /*
             * accept_sec_context does the actual work of taking the client's request and
             * generating an appropriate reply.  Note that we pass GSS_C_NO_CREDENTIAL for
             * the service principal.  This causes the server to accept any service principal
             * in the server's keytab, which enables you to support multihomed hosts by having
             * one key in the keytab for each host identity the server responds on.
             *
             * However, since we may have more keys in the keytab than we want the server
             * to actually use, we will need to check which service principal the client used
             * after authentication succeeds.  See ServicePrincipalIsValidForService() for
             * where you would put these checks.  We don't check here since if we stopped
             * responding in the middle of the authentication negotiation, the client
             * would get an EOF, and the user wouldn't know what went wrong.
             */

            printf ("Calling gss_accept_sec_context...\n");
            majorStatus = gss_accept_sec_context (&minorStatus, &context, GSS_C_NO_CREDENTIAL,
                                                  &inputToken, GSS_C_NO_CHANNEL_BINDINGS, NULL /* client_name */,
                                                  NULL /* mech_types */, &outputToken, NULL /* req_flags */,
                                                  NULL /* time_rec */, NULL /* delegated_cred_handle */);

            if ((outputToken.length > 0) && (outputToken.value != NULL)) {
                /* Use a separate minor status for the release so the one reported by
                 * gss_accept_sec_context is still intact for the error check below. */
                OM_uint32 releaseMinorStatus = 0;

                /* Send the output token to the client (even on error) */
                err = WriteToken (inSocket, outputToken.value, outputToken.length);

                gss_release_buffer (&releaseMinorStatus, &outputToken);
            }
        }

        if ((majorStatus != GSS_S_COMPLETE) && (majorStatus != GSS_S_CONTINUE_NEEDED)) {
            printGSSErrors ("gss_accept_sec_context", majorStatus, minorStatus);
            err = minorStatus ? minorStatus : majorStatus;
        }
    }

    /* The token read in the final iteration is not freed inside the loop */
    free (inputTokenBuffer);

    if (!err) {
        *outContext = context;
    } else {
        printError (err, "Authenticate failed");

        /* Don't leak a context that was only partially established */
        if (context != GSS_C_NO_CONTEXT) {
            OM_uint32 deleteMinorStatus = 0;
            gss_delete_sec_context (&deleteMinorStatus, &context, GSS_C_NO_BUFFER);
        }
    }

    return err;
}

/* --------------------------------------------------------------------------- */

static int ServicePrincipalIsValidForService (const char *inServicePrincipal)
{
    int err = 0;
    krb5_context context = NULL;
    krb5_principal principal = NULL;
    
    if (inServicePrincipal == NULL) { err = EINVAL; }
    
    if (!err) {
        err = krb5_init_context (&context);
    }
    
    if (!err) {
        err = krb5_parse_name (context, inServicePrincipal, &principal);
    }
    
    if (!err) {
        /* 
         * Here is where we check to see if the service principal the client used is valid.
         * Typically we would just check that the first component is the service name.
         * Here we check only if the server was started with the service name option. 
         */
        if ((gServiceName != NULL) && (strcmp (gServiceName, krb5_princ_name (context, principal)->data) != 0)) {
            err = KRB5KRB_AP_WRONG_PRINC;
        }
    }
    
    if (principal != NULL) { krb5_free_principal (context, principal); }
    if (context   != NULL) { krb5_free_context (context); }
    
    return err;
}


/* --------------------------------------------------------------------------- */

/*
 * ClientPrincipalIsAuthorizedForService - decide whether an authenticated
 * client principal may use this service.
 *
 *   inClientPrincipal - display form of the client principal.
 *
 * Returns 0 when the client is authorized, otherwise non-zero: EINVAL for a
 * NULL argument, or the krb5 error from context setup or name parsing.  This
 * sample applies no policy beyond requiring that the name parses.
 */
static int ClientPrincipalIsAuthorizedForService (const char *inClientPrincipal)
{
    krb5_context krbContext = NULL;
    krb5_principal clientPrinc = NULL;
    int result;

    if (inClientPrincipal == NULL) {
        return EINVAL;
    }

    result = krb5_init_context (&krbContext);
    if (result == 0) {
        result = krb5_parse_name (krbContext, inClientPrincipal, &clientPrinc);
    }

    /*
     * Policy hook: this is where the server would decide whether the client
     * principal may use the service.  A real check should look at both the
     * name and the realm, because with cross-realm shared keys a user from
     * another realm may be contacting the service; most sites only admit
     * specific individuals from other realms.  This sample accepts every
     * principal that parsed successfully, so result is left unchanged.
     */

    if (clientPrinc != NULL) { krb5_free_principal (krbContext, clientPrinc); }
    if (krbContext  != NULL) { krb5_free_context (krbContext); }

    return result;
}

/* --------------------------------------------------------------------------- */

static int Authorize (gss_ctx_id_t *inContext, int *outAuthorized, int *outAuthorizationError)
{
    int err = 0;
    OM_uint32 majorStatus;
    OM_uint32 minorStatus = 0;
    gss_name_t clientName = NULL;
    gss_name_t serviceName = NULL;
    char *clientPrincipal = NULL;
    char *servicePrincipal = NULL;

    if (outAuthorized         == NULL) { err = EINVAL; }
    if (outAuthorizationError == NULL) { err = EINVAL; }
    
    if (!err) {
        /* Get the client and service principals used to authenticate */
        majorStatus = gss_inquire_context (&minorStatus, *inContext, &clientName, &serviceName, 
                                           NULL, NULL, NULL, NULL, NULL);
        if (majorStatus != GSS_S_COMPLETE) { err = minorStatus ? minorStatus : majorStatus; }
    }
    
    if (!err) {
        /* Pull the client principal string out of the gss name */
        gss_buffer_desc nameToken;
        
        majorStatus = gss_display_name (&minorStatus, clientName, &nameToken, NULL);
        if (majorStatus != GSS_S_COMPLETE) { err = minorStatus ? minorStatus : majorStatus; }
        
        if (!err) {
            clientPrincipal = malloc (nameToken.length + 1);
            if (clientPrincipal == NULL) { err = ENOMEM; }
        }
        
        if (!err) {
            memcpy (clientPrincipal, nameToken.value, nameToken.length);
            clientPrincipal[nameToken.length] = '\0';
        }        

        if (nameToken.value != NULL) { gss_release_buffer (&minorStatus, &nameToken); }
    }
    
    if (!err) {
        /* Pull the service principal string out of the gss name */
        gss_buffer_desc nameToken;
        
        majorStatus = gss_display_name (&minorStatus, serviceName, &nameToken, NULL);
        if (majorStatus != GSS_S_COMPLETE) { err = minorStatus ? minorStatus : majorStatus; }
        
        if (!err) {
            servicePrincipal = malloc (nameToken.length + 1);
            if (servicePrincipal == NULL) { err = ENOMEM; }
        }
        
        if (!err) {
            memcpy (servicePrincipal, nameToken.value, nameToken.length);
            servicePrincipal[nameToken.length] = '\0';
        }        

        if (nameToken.value != NULL) { gss_release_buffer (&minorStatus, &nameToken); }
    }
    
    if (!err) {
        int authorizationErr = ServicePrincipalIsValidForService (servicePrincipal);
        
        if (!authorizationErr) {
            authorizationErr = ClientPrincipalIsAuthorizedForService (clientPrincipal);
        }
        
        printf ("'%s' is%s authorized for service '%s'\n", 
                    clientPrincipal, authorizationErr ? " NOT" : "", servicePrincipal);            
        
        *outAuthorized = !authorizationErr;
        *outAuthorizationError = authorizationErr;
    }
    
    if (clientPrincipal  == NULL) { free (clientPrincipal); }
    if (servicePrincipal == NULL) { free (servicePrincipal); }

    return err; 
}

/* --------------------------------------------------------------------------- */

/*
 * Usage - print the command-line synopsis to stderr, naming the program from
 * argv[0], then terminate the process with exit status 1.  Never returns.
 */
static void Usage (const char *argv[])
{
    const char *programName = argv[0];

    fprintf (stderr, "Usage: %s [--port portNumber] [--sname serviceName]\n",
             programName);
    exit (1);
}

/* --------------------------------------------------------------------------- */

/*
 * main - parse arguments, listen on the configured port, and serve a single
 * client: authenticate it with GSS-API, authorize it, send an encrypted
 * status string, then exit.
 *
 * Options:
 *   --port portNumber   TCP port to listen on (1..65535; strtol base 0, so
 *                       hex/octal prefixes are accepted).
 *   --sname serviceName require the client to use a service principal whose
 *                       first component is serviceName.
 *
 * Invalid arguments (err == EINVAL) print the usage text and exit with status 1.
 * Otherwise returns 0 on success or -1 on a server failure.
 */
int main (int argc, const char *argv[])
{
    int err = 0;
    OM_uint32 minorStatus;
    int port = kDefaultPort;
    int listenFD = -1;
    gss_ctx_id_t gssContext = GSS_C_NO_CONTEXT;
    gss_buffer_desc outputToken = { 0, NULL };
    int i = 0;

    for (i = 1; (i < argc) && !err; i++) {
        if ((strcmp (argv[i], "--port") == 0) && (i < (argc - 1))) {
            const char *portString = argv[++i];
            char *end = NULL;
            long value = 0;

            /*
             * strtol returns 0 for unparseable input without necessarily
             * setting errno, so validate via endptr and check the range
             * explicitly rather than testing for 0.
             */
            errno = 0;
            value = strtol (portString, &end, 0);
            if ((errno != 0) || (end == portString) || (*end != '\0') ||
                (value <= 0) || (value > 65535)) {
                err = EINVAL;
            } else {
                port = (int) value;
            }
        } else if ((strcmp(argv[i], "--sname") == 0) && (i < (argc - 1))) {
            gServiceName = argv[++i];
        } else {
            err = EINVAL;
        }
    }

    if (!err) {
        printf ("%s: Starting up...\n", argv[0]);

        err = SetupListeningSocket (port, &listenFD);
    }

    if (!err) {
        int connectionErr = 0;
        int connectionFD = -1;
        int authorized = 0;
        int authorizationError = 0;

        /* Try again if a signal interrupts accept() before a client connects */
        do {
            connectionFD = accept (listenFD, NULL, NULL);
        } while ((connectionFD < 0) && (errno == EINTR));

        if (connectionFD < 0) { err = errno; }

        /* Only talk to the client if we actually have a connection */
        if (!err) {
            printf ("Accepting new connection...\n");
            connectionErr = Authenticate (connectionFD, &gssContext);

            if (!connectionErr) {
                connectionErr = Authorize (&gssContext, &authorized, &authorizationError);
            }

            if (!connectionErr) {
                char buffer[1024];
                memset (buffer, 0, sizeof (buffer));

                /*
                 * Here is where your protocol would go.  This sample server just
                 * writes a nul terminated string to the client telling whether
                 * it was authorized.
                 */
                if (authorized) {
                    snprintf (buffer, sizeof (buffer), "SUCCESS!");
                } else {
                    snprintf (buffer, sizeof(buffer),  "FAILURE! %s (err = %d)",
                              error_message (authorizationError), authorizationError);
                }
                connectionErr = WriteEncryptedToken (connectionFD, gssContext, buffer, strlen (buffer) + 1);
            }

            if (connectionErr) {
                printError (connectionErr, "Connection failed");
            }
        }

        if (connectionFD >= 0) { printf ("Closing connection.\n"); close (connectionFD); }
    }

    if (err) {
        if (err == EINVAL) {
            Usage (argv);
        } else {
            printError (err, "Server failed");
        }
    }

    if (listenFD          >= 0)                { close (listenFD); }
    if (gssContext        != GSS_C_NO_CONTEXT) { gss_delete_sec_context (&minorStatus, &gssContext, &outputToken); }
    if (outputToken.value != NULL)             { gss_release_buffer (&minorStatus, &outputToken); }

    return err ? -1 : 0;
}