/* tls1RecordCallouts.c - TLS 1.x record-layer callouts (record decryption, MAC setup/teardown, MAC computation). */
#include "tls_record.h"
#include "sslMemory.h"
#include "sslDebug.h"
#include "sslUtils.h"
#include <AssertMacros.h>
#include <string.h>
/*
 * No TLS-specific record writer is needed: the Tls1RecordCallouts table
 * below uses ssl3WriteRecord for outgoing records. (A permanently disabled
 * "#if 0" stub that only returned errSecUnimplemented used to live here.)
 */
/*
 * Decrypt and authenticate one incoming TLS 1.x record, in place.
 *
 * type    - record content type; passed through to SSLVerifyMac.
 * payload - on entry, the encrypted fragment, which is decrypted in place.
 *           When the function gets past the length checks, *payload is
 *           overwritten on return to describe only the plaintext content
 *           (leading IV block, padding and MAC stripped). This happens even
 *           when a padding/MAC error is reported.
 * ctx     - record-layer context; reads ctx->readCipher and
 *           ctx->negProtocolVersion.
 *
 * Returns:
 *   0                           on success
 *   errSSLRecordRecordOverflow  block-cipher ciphertext that is empty or not
 *                               a whole number of blocks
 *   errSSLRecordDecryptionFail  the cipher's decrypt callback failed
 *   errSSLRecordClosedAbort     record too short to hold the IV/MAC
 *   errSSLRecordBadRecordMac    bad padding and/or bad MAC
 */
static int tls1DecryptRecord(
    uint8_t type,
    SSLBuffer *payload,
    struct SSLRecordInternalContext *ctx)
{
    int err;
    SSLBuffer content;
    size_t blockSize = ctx->readCipher.symCipher->params->blockSize;
    size_t digestSize = ctx->readCipher.macRef->hash->digestSize;

    /*
     * Block-cipher ciphertext must be a non-empty whole number of blocks.
     * The empty case is rejected explicitly. With a zero-length MAC it would
     * otherwise survive the underflow check below and the padding check
     * would read payload->data[-1].
     */
    if (blockSize > 0 &&
        (payload->length == 0 || (payload->length % blockSize) != 0)) {
        return errSSLRecordRecordOverflow;
    }

    /* Decrypt in place; the cipher's own error code is collapsed on purpose. */
    err = ctx->readCipher.symCipher->c.cipher.decrypt(payload->data,
            payload->data, payload->length, ctx->readCipher.cipherCtx);
    if (err != 0) {
        return errSSLRecordDecryptionFail;
    }

    /*
     * From TLS 1.1 on, a block-cipher record starts with one IV block that is
     * not part of the content. The trailing MAC is excluded in every case.
     */
    if ((ctx->negProtocolVersion >= TLS_Version_1_1) && (blockSize > 0)) {
        content.data = payload->data + blockSize;
        content.length = payload->length - (digestSize + blockSize);
    } else {
        content.data = payload->data;
        content.length = payload->length - digestSize;
    }

    /*
     * This check relies on SSLBuffer.length being unsigned. A record shorter
     * than IV + MAC wraps the subtraction above to a value larger than the
     * payload itself.
     */
    if (content.length > payload->length) {
        return errSSLRecordClosedAbort;
    }

    err = 0;
    if (blockSize > 0) {
        /*
         * The last byte is the padding length. Padding plus that length byte
         * must fit in the content, and every padding byte must equal the
         * length. A failure is only recorded, so the MAC below is still
         * verified (NOTE(review): presumably to avoid a padding-oracle
         * timing signal).
         */
        uint8_t padSize = payload->data[payload->length - 1];
        size_t padTotal = (size_t)padSize + 1;

        if (padTotal <= content.length) {
            const uint8_t *padChars = payload->data + payload->length - padTotal;
            const uint8_t *padEnd = payload->data + payload->length - 1;

            content.length -= padTotal;
            while (padChars < padEnd) {
                if (*padChars++ != padSize) {
                    err = errSSLRecordBadRecordMac;
                }
            }
        } else {
            err = errSSLRecordBadRecordMac;
        }
    }

    /* The MAC sits right after the content; skip verification for a null MAC. */
    if (digestSize > 0) {
        if (SSLVerifyMac(type, &content,
                         content.data + content.length, ctx) != 0) {
            err = errSSLRecordBadRecordMac;
        }
    }

    *payload = content;
    return err;
}
/*
 * (Re)build the HMAC context of a cipher context from its MAC secret.
 *
 * Any existing cipherCtx->macCtx.hmacCtx is freed first. A new one is
 * obtained from macRef->hmac->alloc(), using cipherCtx->macSecret and
 * macRef->hmac->macSize. cipherCtx->macSecret is then zeroed regardless of
 * alloc()'s result.
 *
 * Returns whatever alloc() returned.
 * NOTE(review): check() (AssertMacros.h) only reports; it does not stop a
 * NULL macRef/hmac from being dereferenced below.
 */
static int tls1InitMac (
    CipherContext *cipherCtx) {
    const HMACReference *hmacRef;
    int status;

    check(cipherCtx);
    check(cipherCtx->macRef != NULL);
    hmacRef = cipherCtx->macRef->hmac;
    check(hmacRef != NULL);

    /* Drop a context left over from an earlier key, if there is one. */
    if (cipherCtx->macCtx.hmacCtx) {
        hmacRef->free(cipherCtx->macCtx.hmacCtx);
        cipherCtx->macCtx.hmacCtx = NULL;
    }

    status = hmacRef->alloc(hmacRef,
                            cipherCtx->macSecret,
                            cipherCtx->macRef->hmac->macSize,
                            &cipherCtx->macCtx.hmacCtx);

    /* Wipe the raw MAC secret after alloc(), whatever its status. */
    memset(cipherCtx->macSecret, 0, sizeof(cipherCtx->macSecret));

    return status;
}
/*
 * Release the HMAC context owned by a cipher context, if any.
 *
 * Nothing happens when the cipher context has no MAC reference. After a
 * release, cipherCtx->macCtx.hmacCtx is reset to NULL so a repeated call
 * is harmless. Always returns 0.
 */
static int tls1FreeMac (
    CipherContext *cipherCtx)
{
    if (cipherCtx->macRef != NULL) {
        check(cipherCtx->macRef->hmac != NULL);

        if (cipherCtx->macCtx.hmacCtx) {
            cipherCtx->macRef->hmac->free(cipherCtx->macCtx.hmacCtx);
            cipherCtx->macCtx.hmacCtx = NULL;
        }
    }
    return 0;
}
/*
 * Size of the MAC pseudo-header hashed before the record data:
 * 8-byte sequence number, 1-byte content type, 2-byte protocol version,
 * 2-byte data length.
 */
#define HDR_LENGTH (8 + 1 + 2 + 2)

/*
 * Compute the record MAC over (pseudo-header || data) with the cipher
 * context's HMAC context, writing the result into mac.data.
 *
 * type      - record content type, placed in the pseudo-header.
 * data      - bytes to authenticate; only the low 16 bits of data.length
 *             are encoded in the header.
 * mac       - destination buffer. mac.length is passed to final() as the
 *             available size. Because 'mac' is passed by value, the length
 *             final() reports cannot be returned to the caller.
 * cipherCtx - supplies macRef->hmac and an already-initialised
 *             macCtx.hmacCtx.
 * seqNo     - record sequence number, encoded via SSLEncodeUInt64.
 * ctx       - supplies negProtocolVersion for the header.
 *
 * Returns 0 on success, or the first non-zero status from the HMAC
 * init/update/final callbacks.
 */
static int tls1ComputeMac (
    uint8_t type,
    SSLBuffer data,
    SSLBuffer mac, CipherContext *cipherCtx, sslUint64 seqNo,
    struct SSLRecordInternalContext *ctx)
{
    uint8_t hdr[HDR_LENGTH];
    uint8_t *p;
    HMACContextRef hmacCtx;
    const HMACReference *hmac;
    size_t macLength;
    int serr;

    check(cipherCtx != NULL);
    check(cipherCtx->macRef != NULL);
    hmac = cipherCtx->macRef->hmac;
    check(hmac != NULL);
    hmacCtx = cipherCtx->macCtx.hmacCtx;

    serr = hmac->init(hmacCtx);
    if (serr) {
        return serr;
    }

    /* Pseudo-header; version and length are written high byte first. */
    p = SSLEncodeUInt64(hdr, seqNo);
    *p++ = type;
    *p++ = (uint8_t)(ctx->negProtocolVersion >> 8);
    *p++ = (uint8_t)(ctx->negProtocolVersion & 0xff);
    *p++ = (uint8_t)(data.length >> 8);
    *p   = (uint8_t)(data.length & 0xff);

    serr = hmac->update(hmacCtx, hdr, sizeof(hdr));
    if (serr) {
        return serr;
    }

    serr = hmac->update(hmacCtx, data.data, data.length);
    if (serr) {
        return serr;
    }

    /*
     * The digest is written through mac.data into the caller's buffer. The
     * updated macLength is intentionally not stored back into the by-value
     * 'mac'; doing so would have no effect outside this function.
     */
    macLength = mac.length;
    return hmac->final(hmacCtx, mac.data, &macLength);
}
/*
 * Record-layer callout table for TLS 1.x. This is a positional initializer,
 * so the entry order must match the field order of SslRecordCallouts
 * (declared outside this file).
 */
const SslRecordCallouts Tls1RecordCallouts = {
tls1DecryptRecord,   /* decrypt, check padding and verify MAC of an incoming record */
ssl3WriteRecord,     /* outgoing records reuse the SSLv3 writer (defined elsewhere) */
tls1InitMac,         /* (re)create the HMAC context from the MAC secret */
tls1FreeMac,         /* release the HMAC context */
tls1ComputeMac,      /* HMAC over sequence-number/type/version/length header + data */
};