#include "testParams.h"
#include <stdlib.h>
#include <stdio.h>
#include <Security/SecureTransport.h>
#include "ioSockThr.h"
#include "testutil.h"
#include <security_utilities/threading.h>
#include <utilLib/common.h>
/*
 * Compile-time configuration.
 *   DEFAULT_GETMSG  - request that sslPing() hands to doSslPing() to send
 *                     after the handshake.
 *   DEFAULT_PORT    - port used by every entry of the public host table.
 *   LOCALHOST_RANGE - nonzero selects the localhost:1300-1307 table instead
 *                     of the public hosts.
 *   ALLOW_ANY_ROOT  - nonzero makes doSslPing() call SSLSetAllowsAnyRoot().
 * LOCALHOST_RANGE and ALLOW_ANY_ROOT are tested with #if, so they must stay
 * preprocessor macros.
 */
#define DEFAULT_GETMSG		"GET / HTTP/1.0\r\n\r\n"
#define DEFAULT_PORT		443
#define LOCALHOST_RANGE		0
#define ALLOW_ANY_ROOT		0

/* One candidate server: host name and TCP port. */
typedef struct {
	const char		*hostName;
	unsigned short	port;
} sslHostDef;

#if LOCALHOST_RANGE
/* Local servers, one per port 1300..1307. */
static const sslHostDef knownSslHosts[] = {
	{ "localhost", 1300 },
	{ "localhost", 1301 },
	{ "localhost", 1302 },
	{ "localhost", 1303 },
	{ "localhost", 1304 },
	{ "localhost", 1305 },
	{ "localhost", 1306 },
	{ "localhost", 1307 }
};
#else
/* Public HTTPS servers; entry 0 is the one used when only one host is wanted. */
static const sslHostDef knownSslHosts[] = {
	{ "www.amazon.com",                    DEFAULT_PORT },
	{ "store.apple.com",                   DEFAULT_PORT },
	{ "www.thawte.com",                    DEFAULT_PORT },
	{ "account.authorize.net",             DEFAULT_PORT },
	{ "gmail.google.com",                  DEFAULT_PORT },
	{ "digitalid.verisign.com",            DEFAULT_PORT },
	{ "www.firstamlink.com",               DEFAULT_PORT },
	{ "remote.harpercollins.com",          DEFAULT_PORT },
	{ "mbanxonlinebanking.harrisbank.com", DEFAULT_PORT }
};
#endif

/* Number of entries in knownSslHosts, computed from the array itself. */
#define NUM_KNOWN_HOSTS (sizeof(knownSslHosts) / sizeof(knownSslHosts[0]))
/* Nonzero: sslPing() waits for Return after each iteration (tested with #if). */
#define DO_PAUSE 0

/*
 * Options filled in once by sslPingInit() from testParams->testOpts.
 *   initFlag      - nonzero once sslPingInit() has completed.
 *   globalTryProt - protocol forced by option '2', '3' or 't';
 *                   kSSLProtocolUnknown means choose randomly per connection.
 *   globalProtStr - printable name for globalTryProt; NULL if none was forced.
 *   justOneHost   - cleared by option 'm' so each connection picks a random
 *                   entry of knownSslHosts instead of entry 0.
 *   resumeEnable  - set by option 'r' so doSslPing() calls SSLSetPeerID().
 */
static int			initFlag = 0;
static SSLProtocol	globalTryProt = kSSLProtocolUnknown;
static const char	*globalProtStr = NULL;
static bool			justOneHost = true;
static bool			resumeEnable = false;
/*
 * Parse testParams->testOpts once and store the results in the file-scope
 * option globals. Recognized option characters:
 *   '2' - SSL2 only     '3' - SSL3 only     't' - TLS1 only
 *   'm' - pick a random entry of knownSslHosts for every connection
 *   'r' - pass the peer ID to SSLSetPeerID() in doSslPing()
 * Any other character is ignored. After the first call, subsequent calls
 * return immediately.
 *
 * Returns 0 in all cases.
 */
int sslPingInit(TestParams *testParams)
{
	if(initFlag) {
		return 0;
	}
	if(testParams->testOpts == NULL) {
		initFlag = 1;
		return 0;
	}
	for(const char *opt = testParams->testOpts; *opt != '\0'; opt++) {
		switch(*opt) {
			case '2':
				globalTryProt = kSSLProtocol2;
				globalProtStr = "SSL2";
				break;
			case '3':
				globalTryProt = kSSLProtocol3Only;
				globalProtStr = "SSL3";
				break;
			case 't':
				globalTryProt = kTLSProtocol1Only;
				globalProtStr = "TLS1";
				break;
			case 'm':
				justOneHost = 0;
				break;
			case 'r':
				resumeEnable = 1;
				break;
			default:
				break;
		}
	}
	/*
	 * globalProtStr is still NULL when testOpts contained only 'm'/'r'.
	 * Passing NULL to a %s conversion is undefined behavior, so only
	 * announce a forced protocol when one was actually selected.
	 */
	if(!testParams->quiet && (globalProtStr != NULL)) {
		printf("...sslPing using %s only\n", globalProtStr);
	}
	initFlag = 1;
	return 0;
}
/*
 * Held by doSslPing() only around its MakeServerConnection() call, so
 * connection setup is serialized while the SSL session itself is not.
 * NOTE(review): presumably because MakeServerConnection() is not safe to
 * call concurrently -- confirm.
 */
static Mutex connectLock;
#define ENABLE_SSL2 0
static const char *selectHostAndProt(
unsigned short &port,
SSLProtocol &tryProt,
const char *&protStr)
{
unsigned char r[2];
appGetRandomBytes(r, 2);
if(globalTryProt != kSSLProtocolUnknown) {
tryProt = globalTryProt;
protStr = globalProtStr;
}
else {
unsigned modulo = ENABLE_SSL2 ? 5 : 4;
switch(r[0] % modulo) {
case 0:
tryProt = kSSLProtocol3;
protStr = "SSL3";
break;
case 1:
tryProt = kSSLProtocol3Only;
protStr = "SSL3Only";
break;
case 2:
tryProt = kTLSProtocol1;
protStr = "TLS1";
break;
case 3:
tryProt = kTLSProtocol1Only;
protStr = "TLS1Only";
break;
case 4:
tryProt = kSSLProtocol2;
protStr = "SSL2";
break;
default:
printf("Huh?\n");
exit(1);
}
}
const sslHostDef *hostDef;
if(justOneHost) {
hostDef = &knownSslHosts[0];
}
else {
hostDef = &(knownSslHosts[r[1] % NUM_KNOWN_HOSTS]);
}
port = hostDef->port;
return hostDef->hostName;
}
/* Size of the stack buffer filled by each SSLRead() in doSslPing(). */
#define RCV_BUF_SIZE 256

/*
 * Open one SSL connection to hostName:port using tryVersion, send getMsg,
 * read the reply, then close.
 *
 *   allowExpired   - call SSLSetAllowsExpiredCerts(ctx, true).
 *   keepConnected  - keep reading until SSLRead() reports an error or close;
 *                    when false, stop after the first read that returns data.
 *   requireNotify  - when false, errSSLClosedNoNotify counts as success.
 *   negVersion,
 *   negCipher      - out: preset to kSSLProtocolUnknown /
 *                    SSL_NULL_WITH_NULL_NULL, then queried from the context
 *                    after the handshake loop (even if the handshake failed).
 *
 * Returns noErr on success, otherwise the first failing status.
 * Note: all locals are declared before the first `goto cleanup` because a
 * C++ goto may not jump past an initialized declaration.
 */
static OSStatus doSslPing(
	SSLProtocol tryVersion,
	const char *hostName, unsigned short port,
	const char *getMsg, CSSM_BOOL allowExpired,
	CSSM_BOOL keepConnected,
	CSSM_BOOL requireNotify, SSLProtocol *negVersion, SSLCipherSuite *negCipher)
{
	PeerSpec		peerId;
	otSocket		sock = 0;
	OSStatus		ortn;
	SSLContextRef	ctx = NULL;
	size_t			actLen;
	uint8			rcvBuf[RCV_BUF_SIZE];

	*negVersion = kSSLProtocolUnknown;
	*negCipher = SSL_NULL_WITH_NULL_NULL;

	/* Only connection establishment is serialized across threads. */
	connectLock.lock();
	ortn = MakeServerConnection(hostName, port, &sock, &peerId);
	connectLock.unlock();
	if(ortn) {
		printf("MakeServerConnection(%s) returned %d; aborting\n",
			hostName, (int)ortn);
		return ortn;
	}

	ortn = SSLNewContext(false, &ctx);
	if(ortn) {
		printSslErrStr("SSLNewContext", ortn);
		goto cleanup;
	}
	ortn = SSLSetIOFuncs(ctx, SocketRead, SocketWrite);
	if(ortn) {
		printSslErrStr("SSLSetIOFuncs", ortn);
		goto cleanup;
	}
	ortn = SSLSetProtocolVersion(ctx, tryVersion);
	if(ortn) {
		printSslErrStr("SSLSetProtocolVersion", ortn);
		goto cleanup;
	}
	ortn = SSLSetConnection(ctx, (SSLConnectionRef)sock);
	if(ortn) {
		printSslErrStr("SSLSetConnection", ortn);
		goto cleanup;
	}
	if(resumeEnable) {
		ortn = SSLSetPeerID(ctx, &peerId, sizeof(PeerSpec));
		if(ortn) {
			printSslErrStr("SSLSetPeerID", ortn);
			goto cleanup;
		}
	}
	if(allowExpired) {
		ortn = SSLSetAllowsExpiredCerts(ctx, true);
		if(ortn) {
			printSslErrStr("SSLSetAllowExpiredCerts", ortn);
			goto cleanup;
		}
	}
#if ALLOW_ANY_ROOT
	ortn = SSLSetAllowsAnyRoot(ctx, true);
	if(ortn) {
		printSslErrStr("SSLSetAllowAnyRoot", ortn);
		goto cleanup;
	}
#endif

	do {
		ortn = SSLHandshake(ctx);
	} while(ortn == errSSLWouldBlock);

	/* Report what was negotiated even when the handshake failed. */
	SSLGetNegotiatedCipher(ctx, negCipher);
	SSLGetNegotiatedProtocolVersion(ctx, negVersion);
	if(ortn) {
		printf("\n");
		goto cleanup;
	}

	/* Check the write status; it must not be silently overwritten by SSLRead(). */
	ortn = SSLWrite(ctx, getMsg, strlen(getMsg), &actLen);
	if(ortn) {
		printSslErrStr("SSLWrite", ortn);
		goto cleanup;
	}

	for(;;) {
		actLen = 0;
		ortn = SSLRead(ctx, rcvBuf, RCV_BUF_SIZE, &actLen);
		if(ortn == errSSLWouldBlock) {
			ortn = noErr;
		}
		/*
		 * Stop on any real error or close in both modes; otherwise a peer
		 * that closes (or fails) before sending data would leave the
		 * !keepConnected case spinning here forever.
		 */
		if(ortn != noErr) {
			break;
		}
		if(!keepConnected && (actLen > 0)) {
			break;
		}
	}

	/* A graceful close is success; a close without notify is too unless required. */
	if(ortn == errSSLClosedGraceful) {
		ortn = noErr;
	}
	if((ortn == errSSLClosedNoNotify) && !requireNotify) {
		ortn = noErr;
	}
	if(ortn == noErr) {
		ortn = SSLClose(ctx);
	}

cleanup:
	if(sock) {
		endpointShutdown(sock);
	}
	if(ctx) {
		SSLDisposeContext(ctx);
	}
	return ortn;
}
/*
 * Run testParams->numLoops connections via doSslPing(), each to a host and
 * protocol chosen by selectHostAndProt(). Stops at the first failure.
 *
 * Returns 0 when every iteration succeeds (including when numLoops is 0),
 * otherwise the failing OSStatus cast to int.
 */
int sslPing(TestParams *testParams)
{
	/* Initialized so that zero iterations yields a defined success result. */
	OSStatus err = noErr;

	for(unsigned loopNum = 0; loopNum < testParams->numLoops; loopNum++) {
		SSLProtocol		negVersion;
		SSLProtocol		tryVersion;
		SSLCipherSuite	negCipher;
		unsigned short	port;
		const char		*protStr;
		const char		*hostName;

		if(!testParams->quiet) {
			printChar(testParams->progressChar);
		}
		hostName = selectHostAndProt(port, tryVersion, protStr);
		if(testParams->verbose) {
			printf("\nConnecting to host %s with %s...",
				hostName, protStr);
			fflush(stdout);
		}
		err = doSslPing(tryVersion,
			hostName,
			port,
			DEFAULT_GETMSG,
			CSSM_FALSE, CSSM_FALSE, CSSM_FALSE, &negVersion,
			&negCipher);
		if(err) {
			printf("sslPing error (%d)\n", (int)err);
			break;
		}
		if(testParams->verbose) {
			switch(negVersion) {
				case kSSLProtocol2:
					printf("negVersion = SSL2\n");
					break;
				case kSSLProtocol3:
					printf("negVersion = SSL3\n");
					break;
				case kTLSProtocol1:
					printf("negVersion = TLS1\n");
					break;
				default:
					printf("unknown negVersion! (%d)\n",
						(int)negVersion);
					break;
			}
		}
#if DO_PAUSE
		fpurge(stdin);
		printf("Hit CR to proceed: ");
		getchar();
#endif
	}
	return (int)err;
}