#include "sslContext.h"
#include "sslHandshake.h"
#include "sslMemory.h"
#include "sslAlertMessage.h"
#include "sslDebug.h"
#include "sslUtils.h"
#include "sslDigests.h"
#include "appleCdsa.h"
#include <string.h>
#include <assert.h>
OSStatus
SSLEncodeCertificate(SSLRecord &certificate, SSLContext *ctx)
{
    /*
     * Build a Certificate handshake message from ctx->localCert.
     *
     * Layout written into certificate.contents:
     *   1 byte   SSL_HdskCert
     *   3 bytes  message length  (list length + 3)
     *   3 bytes  certificate_list length
     *   entries  3-byte DER length followed by the DER bytes
     *
     * Entries are emitted in the reverse of ctx->localCert list order: the
     * head of the list is the last entry in the message.
     * On the client side with a local cert present, the context is marked as
     * having sent its certificate.
     */
    OSStatus        err;
    UInt32          listLength = 0;
    SSLCertificate  *node;
    UInt8           *writePtr;
    UInt8           *listStart;
    UInt8           *listEnd;

    assert((ctx->negProtocolVersion == SSL_Version_3_0) ||
           (ctx->negProtocolVersion == TLS_Version_1_0));
    assert((ctx->localCert != NULL) || (ctx->negProtocolVersion == TLS_Version_1_0));

    /* Each list entry costs a 3-byte length prefix plus its DER bytes. */
    for (node = ctx->localCert; node != NULL; node = node->next)
        listLength += 3 + node->derCert.length;

    certificate.contentType = SSL_RecordTypeHandshake;
    certificate.protocolVersion = ctx->negProtocolVersion;

    /* 7 = handshake type (1) + message length (3) + list length (3). */
    err = SSLAllocBuffer(certificate.contents, listLength + 7, ctx);
    if (err != 0)
        return err;

    writePtr = certificate.contents.data;
    *writePtr++ = SSL_HdskCert;
    writePtr = SSLEncodeInt(writePtr, listLength + 3, 3);
    listStart = SSLEncodeInt(writePtr, listLength, 3);
    listEnd = listStart + listLength;

    /*
     * Fill the list from its end backwards while walking the linked list
     * forwards; this yields reverse list order in a single pass.
     */
    writePtr = listEnd;
    for (node = ctx->localCert; node != NULL; node = node->next) {
        writePtr -= node->derCert.length;
        memcpy(writePtr, node->derCert.data, node->derCert.length);
        writePtr -= 3;
        SSLEncodeInt(writePtr, node->derCert.length, 3);
    }
    assert(writePtr == listStart);
    assert(listEnd == certificate.contents.data + certificate.contents.length);

    if ((ctx->protocolSide == SSL_ClientSide) && (ctx->localCert != NULL)) {
        ctx->certSent = 1;
        assert(ctx->clientCertState == kSSLClientCertRequested);
        assert(ctx->certRequested);
        ctx->clientCertState = kSSLClientCertSent;
    }
    return noErr;
}
OSStatus
SSLProcessCertificate(SSLBuffer message, SSLContext *ctx)
{
    /*
     * Parse a peer Certificate handshake message body and install the chain.
     *
     * message : a 3-byte certificate_list length followed by entries of
     *           (3-byte length, DER certificate). Untrusted peer input.
     * ctx     : each parsed certificate is prepended to ctx->peerCert, so the
     *           first certificate in the message ends up at the list tail.
     *           Certificates parsed before a later entry fails remain owned
     *           by ctx->peerCert.
     *
     * An empty list is accepted on the server side unless ctx->clientAuth is
     * kAlwaysAuthenticate; otherwise a fatal alert is sent. A non-empty chain
     * is checked with sslVerifyCertChain, and the public key is taken from the
     * tail of ctx->peerCert into ctx->peerPubKey / ctx->peerPubKeyCsp.
     *
     * Returns noErr; errSSLProtocol on malformed lengths; memFullErr or the
     * SSLAllocBuffer error on allocation failure; errSSLXCertChainInvalid for
     * a disallowed empty chain; or the sslVerifyCertChain / sslPubKeyFromCert
     * error.
     */
    OSStatus        err;
    UInt32          listLen, certLen;
    UInt8           *p;
    SSLCertificate  *cert;

    /* Must hold the 3-byte list length before we can decode it. */
    if (message.length < 3) {
        sslErrorLog("SSLProcessCertificate: length decode error 1\n");
        return errSSLProtocol;
    }
    p = message.data;
    listLen = SSLDecodeInt(p,3);
    p += 3;
    if (listLen + 3 != message.length) {
        sslErrorLog("SSLProcessCertificate: length decode error 1\n");
        return errSSLProtocol;
    }
    while (listLen > 0)
    {
        /* Each entry needs its own 3-byte length prefix inside the list. */
        if (listLen < 3) {
            sslErrorLog("SSLProcessCertificate: length decode error 2\n");
            return errSSLProtocol;
        }
        certLen = SSLDecodeInt(p,3);
        p += 3;
        if (listLen < certLen + 3) {
            sslErrorLog("SSLProcessCertificate: length decode error 2\n");
            return errSSLProtocol;
        }
        cert = (SSLCertificate *)sslMalloc(sizeof(SSLCertificate));
        if(cert == NULL) {
            return memFullErr;
        }
        if ((err = SSLAllocBuffer(cert->derCert, certLen, ctx)) != 0)
        {   sslFree(cert);
            return err;
        }
        memcpy(cert->derCert.data, p, certLen);
        p += certLen;
        cert->next = ctx->peerCert;     /* prepend: message order is reversed */
        ctx->peerCert = cert;
        listLen -= 3+certLen;
    }
    assert(p == message.data + message.length && listLen == 0);

    if (ctx->peerCert == 0) {
        /* Empty chain: tolerated only when the server does not require client auth. */
        if((ctx->protocolSide == SSL_ServerSide) &&
           (ctx->clientAuth != kAlwaysAuthenticate)) {
            return noErr;
        }
        else {
            AlertDescription desc;
            if(ctx->negProtocolVersion == SSL_Version_3_0) {
                desc = SSL_AlertBadCert;
            }
            else {
                desc = SSL_AlertCertUnknown;
            }
            SSLFatalSessionAlert(desc, ctx);
            return errSSLXCertChainInvalid;
        }
    }

    if((err = sslVerifyCertChain(ctx, *ctx->peerCert)) != 0) {
        AlertDescription desc;
        switch(err) {
            case errSSLUnknownRootCert:
            case errSSLNoRootCert:
                desc = SSL_AlertUnknownCA;
                break;
            case errSSLCertExpired:
            case errSSLCertNotYetValid:
                desc = SSL_AlertCertExpired;
                break;
            case errSSLXCertChainInvalid:
            default:
                desc = SSL_AlertCertUnknown;
                break;
        }
        SSLFatalSessionAlert(desc, ctx);
        return err;
    }

    /* The list tail is the first certificate the peer sent (its own cert). */
    cert = ctx->peerCert;
    while (cert->next != 0)
        cert = cert->next;
    if ((err = sslPubKeyFromCert(ctx,
                                 cert->derCert,
                                 &ctx->peerPubKey,
                                 &ctx->peerPubKeyCsp)) != 0)
        return err;
    return noErr;
}
OSStatus
SSLEncodeCertificateRequest(SSLRecord &request, SSLContext *ctx)
{
    /*
     * Build a CertificateRequest handshake message (server side only).
     *
     * Layout written into request.contents:
     *   1 byte   SSL_HdskCertRequest
     *   3 bytes  body length
     *   1 byte   number of certificate types (always 1 here)
     *   1 byte   certificate type 1 (rsa_sign in the SSL3/TLS spec)
     *   2 bytes  DN list length
     *   entries  2-byte DN length followed by the DER-encoded DN,
     *            in ctx->acceptableDNList order
     */
    OSStatus    err;
    UInt32      dnBytes = 0;
    UInt32      bodyLen;
    UInt8       *out;
    DNListElem  *elem;

    assert(ctx->protocolSide == SSL_ServerSide);

    for (elem = ctx->acceptableDNList; elem != NULL; elem = elem->next)
        dnBytes += 2 + elem->derDN.length;

    /* type count (1) + single type (1) + DN list length (2) + DN entries */
    bodyLen = 1 + 1 + 2 + dnBytes;

    request.contentType = SSL_RecordTypeHandshake;
    assert((ctx->negProtocolVersion == SSL_Version_3_0) ||
           (ctx->negProtocolVersion == TLS_Version_1_0));
    request.protocolVersion = ctx->negProtocolVersion;

    /* +4 for the handshake type byte and 3-byte body length. */
    err = SSLAllocBuffer(request.contents, bodyLen + 4, ctx);
    if (err != 0)
        return err;

    out = request.contents.data;
    *out++ = SSL_HdskCertRequest;
    out = SSLEncodeInt(out, bodyLen, 3);
    *out++ = 1;     /* one certificate type follows */
    *out++ = 1;     /* certificate type 1 */
    out = SSLEncodeInt(out, dnBytes, 2);
    for (elem = ctx->acceptableDNList; elem != NULL; elem = elem->next) {
        out = SSLEncodeInt(out, elem->derDN.length, 2);
        memcpy(out, elem->derDN.data, elem->derDN.length);
        out += elem->derDN.length;
    }
    assert(out == request.contents.data + request.contents.length);
    return noErr;
}
OSStatus
SSLProcessCertificateRequest(SSLBuffer message, SSLContext *ctx)
{
    /*
     * Parse a CertificateRequest body received from the server.
     *
     * Only the certificate-type list is examined: if any offered type equals
     * 1, ctx->x509Requested is set. Parsing of the DN list is disabled
     * (#if 0 below); the length checks only guarantee room for the type list
     * plus a 2-byte DN list length.
     *
     * Returns noErr, or errSSLProtocol if the body is too short or offers no
     * certificate types.
     */
    UInt8       *charPtr;
    unsigned    typeCount;
    const UInt8 *typesEnd;

    /* Smallest legal body: type count (1) + DN list length (2). */
    if (message.length < 3) {
        sslErrorLog("SSLProcessCertificateRequest: length decode error 1\n");
        return errSSLProtocol;
    }
    charPtr = message.data;
    typeCount = *charPtr++;
    if ((typeCount == 0) || (message.length < typeCount + 3)) {
        sslErrorLog("SSLProcessCertificateRequest: length decode error 2\n");
        return errSSLProtocol;
    }

    /* Leaves charPtr just past the type list, where the DN list starts. */
    for (typesEnd = charPtr + typeCount; charPtr != typesEnd; ++charPtr) {
        if (*charPtr == 1)
            ctx->x509Requested = 1;
    }

#if 0
    unsigned	dnListLen;
    unsigned	dnLen;
    SSLBuffer	dnBuf;
    DNListElem	*dn;
    OSStatus	err;
    dnListLen = SSLDecodeInt(charPtr, 2);
    charPtr += 2;
    if (message.length != 3 + typeCount + dnListLen) {
        sslErrorLog("SSLProcessCertificateRequest: length decode error 3\n");
        return errSSLProtocol;
    }
    while (dnListLen > 0)
    {   if (dnListLen < 2) {
            sslErrorLog("SSLProcessCertificateRequest: dnListLen error 1\n");
            return errSSLProtocol;
        }
        dnLen = SSLDecodeInt(charPtr, 2);
        charPtr += 2;
        if (dnListLen < 2 + dnLen) {
            sslErrorLog("SSLProcessCertificateRequest: dnListLen error 2\n");
            return errSSLProtocol;
        }
        if ((err = SSLAllocBuffer(dnBuf, sizeof(DNListElem), ctx)) != 0)
            return err;
        dn = (DNListElem*)dnBuf.data;
        if ((err = SSLAllocBuffer(dn->derDN, dnLen, ctx)) != 0)
        {   SSLFreeBuffer(dnBuf, ctx);
            return err;
        }
        memcpy(dn->derDN.data, charPtr, dnLen);
        charPtr += dnLen;
        dn->next = ctx->acceptableDNList;
        ctx->acceptableDNList = dn;
        dnListLen -= 2 + dnLen;
    }
    assert(charPtr == message.data + message.length);
#endif
    return noErr;
}
OSStatus
SSLEncodeCertificateVerify(SSLRecord &certVerify, SSLContext *ctx)
{
    /*
     * Build a CertificateVerify handshake message signed with
     * ctx->signingPrivKeyRef.
     *
     * The running SHA-1 and MD5 handshake hash states are cloned (so the live
     * states are not finalized), combined via ctx->sslTslCalls->computeCertVfyMac
     * into a 36-byte buffer, and that buffer is signed with sslRawSign.
     *
     * Layout written into certVerify.contents:
     *   1 byte   SSL_HdskCertVerify
     *   3 bytes  body length (signature length + 2)
     *   2 bytes  signature length
     *   N bytes  signature, N = key length in bytes
     *
     * certVerify.contents.data is zeroed on entry. It is not freed here on
     * failure; presumably the caller owns it (confirm against caller).
     * Returns noErr or the first failing call's error.
     */
    OSStatus        err;
    UInt8           hashData[36];   /* SHA-1 (20) + MD5 (16) digest output */
    SSLBuffer       hashDataBuf, shaMsgState, md5MsgState;
    UInt32          len;
    UInt32          outputLen;
    const CSSM_KEY  *cssmKey;

    certVerify.contents.data = 0;
    /*
     * Zero the clone buffers up front so the shared cleanup path can free
     * them unconditionally, even if we fail before they are populated.
     */
    shaMsgState.data = 0;
    shaMsgState.length = 0;
    md5MsgState.data = 0;
    md5MsgState.length = 0;
    hashDataBuf.data = hashData;
    hashDataBuf.length = sizeof(hashData);

    if ((err = CloneHashState(SSLHashSHA1, ctx->shaState, shaMsgState, ctx)) != 0)
        goto fail;
    if ((err = CloneHashState(SSLHashMD5, ctx->md5State, md5MsgState, ctx)) != 0)
        goto fail;
    assert(ctx->sslTslCalls != NULL);
    if ((err = ctx->sslTslCalls->computeCertVfyMac(ctx, hashDataBuf,
                                                   shaMsgState, md5MsgState)) != 0)
        goto fail;

    assert(ctx->signingPrivKeyRef != NULL);
    err = SecKeyGetCSSMKey(ctx->signingPrivKeyRef, &cssmKey);
    if(err) {
        sslErrorLog("SSLEncodeCertificateVerify: SecKeyGetCSSMKey err %d\n", (int)err);
        /* Go through cleanup so the cloned hash states are released. */
        goto fail;
    }
    len = sslKeyLengthInBytes(cssmKey);

    certVerify.contentType = SSL_RecordTypeHandshake;
    assert((ctx->negProtocolVersion == SSL_Version_3_0) ||
           (ctx->negProtocolVersion == TLS_Version_1_0));
    certVerify.protocolVersion = ctx->negProtocolVersion;
    /* 6 = handshake type (1) + body length (3) + signature length (2). */
    if ((err = SSLAllocBuffer(certVerify.contents, len + 6, ctx)) != 0)
        goto fail;

    certVerify.contents.data[0] = SSL_HdskCertVerify;
    SSLEncodeInt(certVerify.contents.data+1, len+2, 3);
    SSLEncodeInt(certVerify.contents.data+4, len, 2);

    err = sslRawSign(ctx,
                     ctx->signingPrivKeyRef,
                     hashData, sizeof(hashData),
                     certVerify.contents.data+6, len, &outputLen);
    if(err) {
        goto fail;
    }
    assert(outputLen == len);

    err = noErr;

fail:
    SSLFreeBuffer(shaMsgState, ctx);
    SSLFreeBuffer(md5MsgState, ctx);
    return err;
}
OSStatus
SSLProcessCertificateVerify(SSLBuffer message, SSLContext *ctx)
{
    /*
     * Verify a peer CertificateVerify message against ctx->peerPubKey.
     *
     * message : 2-byte signature length followed by exactly that many
     *           signature bytes; the signature length must equal the peer
     *           key length in bytes.
     *
     * The running SHA-1 / MD5 handshake states are cloned, combined by
     * ctx->sslTslCalls->computeCertVfyMac into a 36-byte buffer, and the
     * signature is checked with sslRawVerify. A verification failure sends a
     * fatal SSL_AlertDecryptError.
     *
     * Returns noErr, errSSLProtocol for malformed input or a missing peer
     * key, or the error from the hashing / verification calls.
     */
    OSStatus        err;
    UInt8           hashData[36];   /* SHA-1 (20) + MD5 (16) digest output */
    UInt16          signatureLen;
    SSLBuffer       hashDataBuf, shaMsgState, md5MsgState;
    unsigned int    publicModulusLen;

    /* Zeroed so the cleanup path can free them even if cloning never ran. */
    shaMsgState.data = 0;
    md5MsgState.data = 0;

    if (message.length < 2) {
        sslErrorLog("SSLProcessCertificateVerify: msg len error\n");
        return errSSLProtocol;
    }
    signatureLen = (UInt16)SSLDecodeInt(message.data, 2);
    if (message.length != (unsigned)(2 + signatureLen)) {
        sslErrorLog("SSLProcessCertificateVerify: sig len error 1\n");
        return errSSLProtocol;
    }

    /*
     * An assert() vanishes under NDEBUG. SSLProcessCertificate can return
     * noErr without setting peerPubKey (server side, empty client chain), so
     * refuse instead of dereferencing NULL.
     */
    if (ctx->peerPubKey == NULL) {
        sslErrorLog("SSLProcessCertificateVerify: no peer public key\n");
        return errSSLProtocol;
    }
    publicModulusLen = sslKeyLengthInBytes(ctx->peerPubKey);
    if (signatureLen != publicModulusLen) {
        sslErrorLog("SSLProcessCertificateVerify: sig len error 2\n");
        return errSSLProtocol;
    }

    hashDataBuf.data = hashData;
    hashDataBuf.length = sizeof(hashData);
    if ((err = CloneHashState(SSLHashSHA1, ctx->shaState, shaMsgState, ctx)) != 0)
        goto fail;
    if ((err = CloneHashState(SSLHashMD5, ctx->md5State, md5MsgState, ctx)) != 0)
        goto fail;
    assert(ctx->sslTslCalls != NULL);
    if ((err = ctx->sslTslCalls->computeCertVfyMac(ctx, hashDataBuf,
                                                   shaMsgState, md5MsgState)) != 0)
        goto fail;

    err = sslRawVerify(ctx,
                       ctx->peerPubKey,
                       ctx->peerPubKeyCsp, hashData, sizeof(hashData),
                       message.data + 2, signatureLen);
    if(err) {
        SSLFatalSessionAlert(SSL_AlertDecryptError, ctx);
        goto fail;
    }
    err = noErr;

fail:
    SSLFreeBuffer(shaMsgState, ctx);
    SSLFreeBuffer(md5MsgState, ctx);
    return err;
}