#include "tls_handshake_priv.h"
#include "sslHandshake.h"
#include "sslMemory.h"
#include "sslAlertMessage.h"
#include "sslDebug.h"
#include "sslUtils.h"
#include "sslDigests.h"
#include "sslCrypto.h"
#include <string.h>
#include <assert.h>
/*
 * Release a certificate chain (and every element linked from it).
 * The cast below relies on SSLCertificate being layout-compatible with
 * tls_buffer_list_t (NOTE(review): assumed from the cast; confirm in the header).
 * Returns whatever tls_free_buffer_list returns.
 */
int SSLFreeCertificates(SSLCertificate *certs)
{
    tls_buffer_list_t *asList = (tls_buffer_list_t *)certs;

    return tls_free_buffer_list(asList);
}
/*
 * Release a list of distinguished names (and every element linked from it).
 * Like SSLFreeCertificates, this reuses the generic buffer-list destructor;
 * NOTE(review): assumes DNListElem is layout-compatible with tls_buffer_list_t.
 * Returns whatever tls_free_buffer_list returns.
 */
int SSLFreeDNList(DNListElem *dn)
{
    tls_buffer_list_t *asList = (tls_buffer_list_t *)dn;

    return tls_free_buffer_list(asList);
}
/*
 * Build a Certificate handshake message into *certificate.
 *
 * The server always sends ctx->localCert; a client sends it only when client
 * auth was negotiated (negAuthType != None). Otherwise, an empty certificate
 * list is sent. Each entry is a 24-bit length followed by the DER bytes, and
 * the whole list is preceded by its own 24-bit length.
 *
 * Returns errSSLSuccess, or the error from SSLAllocBuffer.
 */
int
SSLEncodeCertificate(tls_buffer *certificate, tls_handshake_t ctx)
{
    assert(ctx->negProtocolVersion >= tls_protocol_version_SSL_3);
    assert((ctx->localCert != NULL) || (ctx->negProtocolVersion >= tls_protocol_version_TLS_1_0));

    SSLCertificate *chain = NULL;
    if (ctx->isServer || ctx->negAuthType != tls_client_auth_type_None)
        chain = ctx->localCert;

    /* First pass: size the certificate_list body. */
    size_t listLength = 0;
    int numCerts = 0;
    for (SSLCertificate *c = chain; c != NULL; c = c->next) {
        listLength += 3 + c->derCert.length;   /* 3-byte length prefix + DER */
        numCerts++;
    }

    int headerLen = SSLHandshakeHeaderSize(ctx);
    int err = SSLAllocBuffer(certificate, listLength + headerLen + 3);
    if (err)
        return err;

    /* Second pass: emit header, list length, then each length-prefixed cert. */
    uint8_t *out = SSLEncodeHandshakeHeader(ctx, certificate, SSL_HdskCert, listLength + 3);
    out = SSLEncodeSize(out, listLength, 3);
    for (SSLCertificate *c = chain; c != NULL; c = c->next) {
        out = SSLEncodeSize(out, c->derCert.length, 3);
        memcpy(out, c->derCert.data, c->derCert.length);
        out += c->derCert.length;
    }
    assert(out == certificate->data + certificate->length);

    /* A client answering a CertificateRequest records that it has now responded. */
    if (!ctx->isServer && ctx->negAuthType != tls_client_auth_type_None) {
        ctx->certSent = 1;
        assert(ctx->clientCertState == kSSLClientCertRequested);
        assert(ctx->certRequested);
        ctx->clientCertState = kSSLClientCertSent;
    }

    if (numCerts == 0)
        sslCertDebug("...sending empty cert msg");

    return errSSLSuccess;
}
/*
 * Compare two certificate chains element by element (length and DER bytes).
 * Returns true only if both chains have the same number of certificates and
 * each pair is byte-identical. If either argument is NULL on entry, it
 * returns false, even when both are NULL.
 */
static bool
CertificateChainEqual(SSLCertificate *cert1, SSLCertificate *cert2)
{
    SSLCertificate *a = cert1;
    SSLCertificate *b = cert2;

    for (;;) {
        /* One chain ran out before the other (or started empty). */
        if (a == NULL || b == NULL)
            return false;
        if (a->derCert.length != b->derCert.length)
            return false;
        if (memcmp(a->derCert.data, b->derCert.data, a->derCert.length) != 0)
            return false;

        a = a->next;
        b = b->next;
        if (a == NULL && b == NULL)
            return true;
    }
}
static const char base64_chars[] =
    "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";

/*
 * Write every certificate in a chain to the given debug scope in PEM style.
 * For each certificate, it emits a "cert: N" index line and the BEGIN marker.
 * Next come the DER bytes base64-encoded in lines of 64 characters, with '='
 * padding on the final group, and finally the END marker.
 */
static void
debug_log_chain(const char *scope, SSLCertificate *cert)
{
    size_t index = 0;

    for (; cert != NULL; cert = cert->next) {
        const uint8_t *der = cert->derCert.data;
        size_t derLen = cert->derCert.length;
        char line[65];      /* 64 encoded characters + NUL */
        size_t used = 0;

        ssl_secdebug(scope, "cert: %lu", (unsigned long)index);
        index++;
        ssl_secdebug(scope, "-----BEGIN CERTIFICATE-----");

        /* Encode 3 input bytes (24 bits) into 4 output characters per step. */
        for (size_t pos = 0; pos < derLen; pos += 3) {
            size_t avail = derLen - pos;    /* >= 1 inside the loop */
            uint32_t group = (uint32_t)der[pos] << 16;

            if (avail > 1)
                group |= (uint32_t)der[pos + 1] << 8;
            if (avail > 2)
                group |= (uint32_t)der[pos + 2];

            line[used++] = base64_chars[(group >> 18) & 0x3f];
            line[used++] = base64_chars[(group >> 12) & 0x3f];
            line[used++] = (avail > 1) ? base64_chars[(group >> 6) & 0x3f] : '=';
            line[used++] = (avail > 2) ? base64_chars[group & 0x3f] : '=';

            /* Flush a full 64-character line. */
            if (used == sizeof(line) - 1) {
                line[used] = '\0';
                ssl_secdebug(scope, "%s", line);
                used = 0;
            }
            assert(used < sizeof(line) - 1);
        }

        /* Flush the trailing partial line, if any. */
        if (used != 0) {
            line[used] = '\0';
            ssl_secdebug(scope, "%s", line);
        }
        ssl_secdebug(scope, "-----END CERTIFICATE-----");
    }
}
/*
 * Parse a peer's Certificate handshake message body and store the decoded
 * chain in ctx->peerCert, replacing and freeing any previously stored chain.
 *
 * If a chain was already stored and ctx->allowServerIdentityChange is false,
 * the new chain must be identical to the old one. Otherwise, the function
 * fails with errSSLProtocol.
 *
 * Returns 0 on success, errSSLProtocol on malformed lengths or a forbidden
 * identity change, or the error from SSLDecodeBufferList.
 */
int
SSLProcessCertificate(tls_buffer message, tls_handshake_t ctx)
{
    SSLCertificate *newChain;
    UInt8 *cursor;
    size_t chainLen;
    int err;

    /* The body starts with a 24-bit certificate_list length. */
    if (message.length < 3) {
        sslErrorLog("SSLProcessCertificate: message length decode error\n");
        return errSSLProtocol;
    }
    cursor = message.data;
    chainLen = SSLDecodeInt(cursor, 3);
    cursor += 3;

    /* The list must account for the rest of the message exactly. */
    if (chainLen + 3 != message.length) {
        sslErrorLog("SSLProcessCertificate: length decode error 1\n");
        return errSSLProtocol;
    }

    err = SSLDecodeBufferList(cursor, chainLen, 3, (tls_buffer_list_t **)&newChain);
    if (err)
        return err;
    cursor += chainLen;

    /* On renegotiation the peer must present the same chain unless changes are allowed. */
    if (!ctx->allowServerIdentityChange && ctx->peerCert != NULL &&
        !CertificateChainEqual(ctx->peerCert, newChain)) {
        sslErrorLog("Illegal server identity change during renegotiation\n");
        SSLFreeCertificates(newChain);
        return errSSLProtocol;
    }

    /* Only dump the chain the first time one is received. */
    if (ctx->peerCert == NULL && __ssl_debug_enabled("sslLogNegotiateDebug"))
        debug_log_chain("sslLogNegotiateDebug", newChain);

    /* ctx takes ownership of the newly decoded chain. */
    SSLFreeCertificates(ctx->peerCert);
    ctx->peerCert = newChain;

    assert(cursor == message.data + message.length);
    return err;
}
/*
 * Build a CertificateStatus handshake message (server side) carrying the
 * stapled OCSP response in ctx->ocsp_response.
 *
 * Layout after the handshake header: status_type (1 byte, SSL_CST_Ocsp),
 * 24-bit response length, response bytes.
 *
 * Returns 0 on success, errSSLInternal if no OCSP response is available, or
 * the error from SSLAllocBuffer.
 */
int
SSLEncodeCertificateStatus(tls_buffer *status, tls_handshake_t ctx)
{
    assert(ctx->isServer);
    assert(ctx->ocsp_enabled && ctx->ocsp_peer_enabled);

    size_t respLen = ctx->ocsp_response.length;
    if (respLen == 0)
        return errSSLInternal;

    size_t bodyLen = 1 + 3 + respLen;
    int headerLen = SSLHandshakeHeaderSize(ctx);
    int err = SSLAllocBuffer(status, bodyLen + headerLen);
    if (err)
        return err;

    uint8_t *out = SSLEncodeHandshakeHeader(ctx, status, SSL_HdskCertificateStatus, bodyLen);
    *out++ = SSL_CST_Ocsp;
    out = SSLEncodeSize(out, respLen, 3);
    memcpy(out, ctx->ocsp_response.data, respLen);

    return 0;
}
/*
 * Parse a server's CertificateStatus handshake message body (client side).
 *
 * Layout: status_type (1 byte), then for SSL_CST_Ocsp a 24-bit response
 * length followed by exactly that many response bytes.
 * Status types other than SSL_CST_Ocsp are ignored (noErr is returned and
 * nothing is stored).
 *
 * On success, the OCSP response is copied into ctx->ocsp_response, any
 * previous response is freed, and ctx->ocsp_response_received is set.
 *
 * Returns noErr for ignored status types, errSSLProtocol for malformed
 * lengths, otherwise the result of SSLCopyBufferFromData.
 */
int
SSLProcessCertificateStatus(tls_buffer message, tls_handshake_t ctx)
{
    uint8_t status_type;
    uint8_t *p = message.data;

    assert(!ctx->isServer);

    if (message.length < 1) {
        sslErrorLog("SSLProcessCertificateStatus: message length decode error (1)\n");
        return errSSLProtocol;
    }
    status_type = *p++;
    if (status_type != SSL_CST_Ocsp) {
        return noErr;
    }

    /* status_type (1) plus the 24-bit response length (3) must all be present
       before SSLDecodeSize reads p[0..2]; checking only for 3 bytes allowed a
       one-byte over-read on a 3-byte message. */
    if (message.length < 4) {
        sslErrorLog("SSLProcessCertificateStatus: message length decode error (2)\n");
        return errSSLProtocol;
    }
    size_t OCSPResponseLen = SSLDecodeSize(p, 3);
    p += 3;
    if (OCSPResponseLen == 0) {
        sslErrorLog("SSLProcessCertificateStatus: message length decode error (3)\n");
        return errSSLProtocol;
    }
    /* The response must fill the remainder of the message exactly. */
    if (OCSPResponseLen + 4 != message.length) {
        sslErrorLog("SSLProcessCertificateStatus: message length decode error (4)\n");
        return errSSLProtocol;
    }

    ctx->ocsp_response_received = true;
    SSLFreeBuffer(&ctx->ocsp_response);
    return SSLCopyBufferFromData(p, OCSPResponseLen, &ctx->ocsp_response);
}
/*
 * Build a CertificateRequest handshake message (server side).
 *
 * Body layout: certificate_types (count byte 2, then RSASign and ECDSASign).
 * For TLS 1.2-like versions, a 2-byte-length-prefixed list of our
 * (hash, signature) pairs follows. Last comes a 2-byte-length-prefixed list of
 * acceptable CA distinguished names, each itself prefixed by a 2-byte length.
 *
 * Returns errSSLSuccess, or the error from SSLAllocBuffer.
 */
int
SSLEncodeCertificateRequest(tls_buffer *request, tls_handshake_t ctx)
{
    assert(ctx->isServer);

    /* 2-byte vector length + 2 bytes per (hash, signature) pair. */
    size_t sigAlgsLen = sslVersionIsLikeTls12(ctx) ? 2 + 2 * ctx->numLocalSigAlgs : 0;

    size_t namesLen = 0;
    for (DNListElem *name = ctx->acceptableDNList; name != NULL; name = name->next)
        namesLen += 2 + name->derDN.length;

    /* types count (1) + two types (2) + sigalgs + DN list length (2) + DNs */
    size_t bodyLen = 1 + 2 + sigAlgsLen + 2 + namesLen;

    assert(ctx->negProtocolVersion >= tls_protocol_version_SSL_3);
    int headerLen = SSLHandshakeHeaderSize(ctx);
    int err = SSLAllocBuffer(request, bodyLen + headerLen);
    if (err)
        return err;

    UInt8 *out = SSLEncodeHandshakeHeader(ctx, request, SSL_HdskCertRequest, bodyLen);

    *out++ = 2;
    *out++ = tls_client_auth_type_RSASign;
    *out++ = tls_client_auth_type_ECDSASign;

    if (sigAlgsLen != 0) {
        out = SSLEncodeSize(out, sigAlgsLen - 2, 2);
        for (int k = 0; k < ctx->numLocalSigAlgs; k++) {
            out = SSLEncodeInt(out, ctx->localSigAlgs[k].hash, 1);
            out = SSLEncodeInt(out, ctx->localSigAlgs[k].signature, 1);
        }
    }

    out = SSLEncodeSize(out, namesLen, 2);
    for (DNListElem *name = ctx->acceptableDNList; name != NULL; name = name->next) {
        out = SSLEncodeSize(out, name->derDN.length, 2);
        memcpy(out, name->derDN.data, name->derDN.length);
        out += name->derDN.length;
    }

    assert(out == request->data + request->length);
    return errSSLSuccess;
}
/*
 * Parse a server's CertificateRequest handshake message body (client side).
 *
 * Stores into ctx:
 *   - clientAuthTypes / numAuthTypes: the requested certificate types;
 *   - peerSigAlgs / numPeerSigAlgs (TLS 1.2-like only): the server's
 *     (hash, signature) pairs;
 *   - acceptableDNList: the CA distinguished names. Note that the list is
 *     built by prepending, so it ends up in reverse wire order.
 *
 * If an allocation fails, the corresponding count is reset to 0, so that
 * ctx is never left with a non-zero count paired with a NULL array.
 *
 * Returns errSSLSuccess, errSSLProtocol on malformed lengths,
 * errSSLInternal on allocation failure, or an SSLAllocBuffer error.
 */
int
SSLProcessCertificateRequest(tls_buffer message, tls_handshake_t ctx)
{
    unsigned i;
    unsigned typeCount;
    unsigned shListLen = 0;
    UInt8 *charPtr;
    unsigned dnListLen;
    unsigned dnLen;
    tls_buffer dnBuf;
    DNListElem *dn;
    int err;
    /* Fixed-size fields: types count (1) + DN list length (2), plus the
       signature_algorithms vector length (2) for TLS 1.2-like versions. */
    unsigned minLen = (sslVersionIsLikeTls12(ctx)) ? 5 : 3;

    if (message.length < minLen) {
        sslErrorLog("SSLProcessCertificateRequest: length decode error 1\n");
        return errSSLProtocol;
    }
    charPtr = message.data;
    typeCount = *charPtr++;
    if ((typeCount < 1) || (message.length < minLen + typeCount)) {
        sslErrorLog("SSLProcessCertificateRequest: length decode error 2\n");
        return errSSLProtocol;
    }

    sslFree(ctx->clientAuthTypes);
    ctx->numAuthTypes = typeCount;
    ctx->clientAuthTypes = (tls_client_auth_type *)
        sslMalloc(ctx->numAuthTypes * sizeof(tls_client_auth_type));
    if (ctx->clientAuthTypes == NULL) {
        /* Don't leave a stale count describing a NULL array. */
        ctx->numAuthTypes = 0;
        return errSSLInternal;
    }
    for (i = 0; i < ctx->numAuthTypes; i++) {
        sslLogNegotiateDebug("===Server specifies authType %d", (int)(*charPtr));
        ctx->clientAuthTypes[i] = (tls_client_auth_type)(*charPtr++);
    }

    if (sslVersionIsLikeTls12(ctx)) {
        shListLen = SSLDecodeInt(charPtr, 2);
        charPtr += 2;
        if ((shListLen < 2) || (message.length < minLen + typeCount + shListLen)) {
            sslErrorLog("SSLProcessCertificateRequest: length decode error 3\n");
            return errSSLProtocol;
        }
        /* Each entry is a (hash, signature) byte pair. */
        if (shListLen & 1) {
            sslErrorLog("SSLProcessCertificateRequest: signAlg len odd\n");
            return errSSLProtocol;
        }
        sslFree(ctx->peerSigAlgs);
        ctx->numPeerSigAlgs = shListLen / 2;
        ctx->peerSigAlgs = (tls_signature_and_hash_algorithm *)
            sslMalloc((ctx->numPeerSigAlgs) * sizeof(tls_signature_and_hash_algorithm));
        if (ctx->peerSigAlgs == NULL) {
            /* Callers such as FindCertSigAlg iterate numPeerSigAlgs entries. */
            ctx->numPeerSigAlgs = 0;
            return errSSLInternal;
        }
        for (i = 0; i < ctx->numPeerSigAlgs; i++) {
            ctx->peerSigAlgs[i].hash = *charPtr++;
            ctx->peerSigAlgs[i].signature = *charPtr++;
            sslLogNegotiateDebug("===Server specifies sigAlg %d %d",
                                 ctx->peerSigAlgs[i].hash,
                                 ctx->peerSigAlgs[i].signature);
        }
    }

    SSLFreeDNList(ctx->acceptableDNList);
    ctx->acceptableDNList = NULL;
    dnListLen = SSLDecodeInt(charPtr, 2);
    charPtr += 2;
    /* The DN list must fill the remainder of the message exactly. */
    if (message.length != minLen + typeCount + shListLen + dnListLen) {
        sslErrorLog("SSLProcessCertificateRequest: length decode error 3\n");
        return errSSLProtocol;
    }

    while (dnListLen > 0) {
        if (dnListLen < 2) {
            sslErrorLog("SSLProcessCertificateRequest: dnListLen error 1\n");
            return errSSLProtocol;
        }
        dnLen = SSLDecodeInt(charPtr, 2);
        charPtr += 2;
        if (dnListLen < 2 + dnLen) {
            sslErrorLog("SSLProcessCertificateRequest: dnListLen error 2\n");
            return errSSLProtocol;
        }
        if ((err = SSLAllocBuffer(&dnBuf, sizeof(DNListElem))))
            return err;
        dn = (DNListElem *)dnBuf.data;
        if ((err = SSLAllocBuffer(&dn->derDN, dnLen))) {
            SSLFreeBuffer(&dnBuf);
            return err;
        }
        memcpy(dn->derDN.data, charPtr, dnLen);
        charPtr += dnLen;
        /* Prepend: the stored list is in reverse wire order. */
        dn->next = ctx->acceptableDNList;
        ctx->acceptableDNList = dn;
        dnListLen -= 2 + dnLen;
    }

    assert(charPtr == message.data + message.length);
    return errSSLSuccess;
}
/*
 * Client side, TLS 1.2 only: choose the hash to pair with alg->signature for
 * CertificateVerify.
 *
 * Local preferences are scanned in order. For each local entry whose signature
 * type matches, alg->hash is set to that entry's hash. The function succeeds
 * as soon as the peer also advertised the same (hash, signature) pair.
 *
 * Returns errSSLSuccess when a common pair is found. Returns errSSLProtocol if
 * none is found; in that case alg->hash may still have been overwritten by the
 * last matching local entry. Returns errSSLInternal if either list is empty.
 */
static
int FindCertSigAlg(tls_handshake_t ctx, tls_signature_and_hash_algorithm *alg)
{
    assert(!ctx->isServer);
    assert(ctx->negProtocolVersion >= tls_protocol_version_TLS_1_2);
    assert(!ctx->isDTLS);

    if (ctx->numPeerSigAlgs == 0 || ctx->numLocalSigAlgs == 0) {
        assert(0);
        return errSSLInternal;
    }

    for (int li = 0; li < ctx->numLocalSigAlgs; li++) {
        if (ctx->localSigAlgs[li].signature != alg->signature)
            continue;

        alg->hash = ctx->localSigAlgs[li].hash;

        for (int pi = 0; pi < ctx->numPeerSigAlgs; pi++) {
            if (ctx->peerSigAlgs[pi].signature == alg->signature &&
                ctx->peerSigAlgs[pi].hash == alg->hash)
                return errSSLSuccess;
        }
    }
    return errSSLProtocol;
}
/*
 * Build a CertificateVerify handshake message into *certVerify, signed with
 * ctx->signingPrivKeyRef (RSA or ECDSA keys only).
 *
 * Message body: for TLS 1.2-like versions, hash and signature algorithm bytes
 * come first; then a 2-byte signature length and the signature bytes.
 * The buffer is allocated for the maximum signature size. certVerify->length
 * is then trimmed to the actual signature length, and the handshake header is
 * written last, once the body length is known.
 *
 * Returns errSSLSuccess or an error code. NOTE(review): on failures after
 * SSLAllocBuffer, certVerify still holds the allocation; presumably the
 * caller frees it — confirm.
 */
int
SSLEncodeCertificateVerify(tls_buffer *certVerify, tls_handshake_t ctx)
{ int err;
UInt8 hashData[SSL_MAX_DIGEST_LEN];
tls_buffer hashDataBuf;
size_t len;
size_t outputLen;
UInt8 *charPtr;
int head;
size_t maxSigLen;
/* Cleared before any failure path so the output buffer has NULL data if we bail out early. */
certVerify->data = 0;
hashDataBuf.data = hashData;
hashDataBuf.length = SSL_MAX_DIGEST_LEN;
assert(ctx->signingPrivKeyRef != NULL);
err = sslGetMaxSigSize(ctx->signingPrivKeyRef, &maxSigLen);
if(err) {
goto fail;
}
/* hash stays 0 unless FindCertSigAlg picks one (TLS 1.2-like versions only). */
tls_signature_and_hash_algorithm sigAlg = {0,};
switch(ctx->signingPrivKeyRef->desc.type) {
case tls_private_key_type_rsa:
sigAlg.signature = tls_signature_algorithm_RSA;
break;
case tls_private_key_type_ecdsa:
sigAlg.signature = tls_signature_algorithm_ECDSA;
/* ECDSA client auth is rejected for SSL 3.0 and below. */
if (ctx->negProtocolVersion <= tls_protocol_version_SSL_3) {
return errSSLInternal;
}
break;
default:
assert(0);
return errSSLInternal;
}
assert(ctx->negProtocolVersion >= tls_protocol_version_SSL_3);
head = SSLHandshakeHeaderSize(ctx);
/* Upper bound: header + 2-byte signature length + largest possible signature. */
outputLen = maxSigLen + head + 2;
if (sslVersionIsLikeTls12(ctx)) {
err=FindCertSigAlg(ctx, &sigAlg);
if(err)
goto fail;
/* Room for the hash/signature algorithm bytes; remember the choice on ctx. */
outputLen += 2;
ctx->certSigAlg = sigAlg; }
assert(ctx->sslTslCalls != NULL);
/* Digest to be signed; presumably the handshake transcript hash for sigAlg.hash — see computeCertVfyMac. */
if ((err = ctx->sslTslCalls->computeCertVfyMac(ctx, &hashDataBuf, sigAlg.hash)) != 0)
goto fail;
if ((err = SSLAllocBuffer(certVerify, outputLen)) != 0)
goto fail;
/* Body starts after the header, which is filled in at the end. */
charPtr = certVerify->data+head;
if (sslVersionIsLikeTls12(ctx))
{
*charPtr++ = sigAlg.hash;
*charPtr++ = sigAlg.signature;
switch (sigAlg.hash) {
case tls_hash_algorithm_SHA512:
case tls_hash_algorithm_SHA384:
case tls_hash_algorithm_SHA256:
case tls_hash_algorithm_SHA1:
break;
default:
sslErrorLog("SSLEncodeCertificateVerify: unsupported signature hash algorithm (%d)\n",
sigAlg.hash);
assert(0); err=errSSLInternal;
goto fail;
}
/* Signature goes after the 2-byte length slot; outputLen becomes the actual signature length. */
if (sigAlg.signature == tls_signature_algorithm_RSA) {
err = sslRsaSign(ctx->signingPrivKeyRef,
sigAlg.hash,
hashData,
hashDataBuf.length,
charPtr+2,
maxSigLen,
&outputLen);
} else {
err = sslEcdsaSign(ctx->signingPrivKeyRef,
hashData,
hashDataBuf.length,
charPtr+2,
maxSigLen,
&outputLen);
}
/* body = alg bytes (2) + sig length (2) + signature */
len=outputLen+2+2;
} else {
err = sslRawSign(ctx->signingPrivKeyRef,
hashData, hashDataBuf.length, charPtr+2, maxSigLen, &outputLen);
/* body = sig length (2) + signature */
len = outputLen+2;
}
if(err) {
sslErrorLog("SSLEncodeCertificateVerify: unable to sign data (error %d)\n", (int)err);
goto fail;
}
/* Trim the over-allocated buffer to the real size, then back-fill the signature length and header. */
certVerify->length = len + head;
SSLEncodeSize(charPtr, outputLen, 2);
SSLEncodeHandshakeHeader(ctx, certVerify, SSL_HdskCertVerify, len);
err = errSSLSuccess;
fail:
return err;
}
/*
 * Verify a peer's CertificateVerify handshake message against ctx->peerPubKey.
 *
 * Body layout: for TLS 1.2-like versions, hash and signature algorithm bytes
 * come first; then a 2-byte signature length and the signature bytes.
 * Bytes after the signature are not examined.
 *
 * Bounds are checked against the count of unconsumed bytes instead of by
 * comparing pointers advanced past the end of the buffer. Forming such a
 * pointer is undefined behavior in C, and compilers may fold those checks.
 *
 * Returns errSSLSuccess, errSSLProtocol for a truncated message, the error
 * from computeCertVfyMac, or the verify error. On verify failure, a fatal
 * decrypt_error alert is also sent.
 */
int
SSLProcessCertificateVerify(tls_buffer message, tls_handshake_t ctx)
{
    int err;
    UInt8 hashData[SSL_MAX_DIGEST_LEN];
    size_t signatureLen;
    tls_buffer hashDataBuf;
    uint8_t *charPtr = message.data;
    size_t remaining = message.length;   /* bytes not yet consumed */
    tls_signature_and_hash_algorithm sigAlg = {0,};

    if (sslVersionIsLikeTls12(ctx)) {
        if (remaining < 2) {
            sslErrorLog("SSLProcessCertificateVerify: msg len error 1\n");
            return errSSLProtocol;
        }
        sigAlg.hash = *charPtr++;
        sigAlg.signature = *charPtr++;
        remaining -= 2;
    }

    if (remaining < 2) {
        sslErrorLog("SSLProcessCertificateVerify: msg len error\n");
        return errSSLProtocol;
    }
    signatureLen = SSLDecodeSize(charPtr, 2);
    charPtr += 2;
    remaining -= 2;

    if (signatureLen > remaining) {
        sslErrorLog("SSLProcessCertificateVerify: sig len error 1\n");
        return errSSLProtocol;
    }

    hashDataBuf.data = hashData;
    hashDataBuf.length = SSL_MAX_DIGEST_LEN;
    assert(ctx->sslTslCalls != NULL);
    if ((err = ctx->sslTslCalls->computeCertVfyMac(ctx, &hashDataBuf, sigAlg.hash)) != 0)
        goto fail;

    if (sslVersionIsLikeTls12(ctx)) {
        /* RSA uses the hash-aware verifier; any other signature type goes to raw verify. */
        if (sigAlg.signature == tls_signature_algorithm_RSA) {
            err = sslRsaVerify(&ctx->peerPubKey,
                               sigAlg.hash,
                               hashData,
                               hashDataBuf.length,
                               charPtr,
                               signatureLen);
        } else {
            err = sslRawVerify(&ctx->peerPubKey,
                               hashData,
                               hashDataBuf.length,
                               charPtr,
                               signatureLen);
        }
    } else {
        err = sslRawVerify(&ctx->peerPubKey,
                           hashData, hashDataBuf.length,
                           charPtr, signatureLen);
    }

    if (err) {
        SSLFatalSessionAlert(SSL_AlertDecryptError, ctx);
        goto fail;
    }
    err = errSSLSuccess;
fail:
    return err;
}