#include <Security/cssmapple.h>
#include <open_ssl/openssl/bn.h>
#include <PBKDF2/pbkdDigest.h>
#include "pkcs12Derive.h"
#include <string.h>
#include <stdlib.h>
#include <stdio.h>
#include <assert.h>
#include <SecurityNssAsn1/SecNssCoder.h>
/*
 * Diversifier ("ID") byte for PKCS #12 key derivation. p12PbeGen() fills
 * its D block with copies of this value, so each purpose below yields
 * independent output from the same password/salt/iteration count.
 */
enum P12_PBE_ID {
	PBE_ID_Key = 1,		/* cipher key material */
	PBE_ID_IV  = 2,		/* initialization vector */
	PBE_ID_MAC = 3		/* MAC key */
};
#if 0
/*
 * Disabled: thin wrappers around the CSSM digest-context API, kept for
 * reference. p12PbeGen() uses the DigestCtx* routines instead. Nothing
 * between #if 0 and #endif is compiled.
 */
typedef CSSM_CC_HANDLE HashHand;

/* Create a digest context for alg on cspHand. Returns 0 on failure. */
static HashHand hashCreate(CSSM_CSP_HANDLE cspHand,
	CSSM_ALGORITHMS alg)
{
	CSSM_CC_HANDLE hashHand;
	CSSM_RETURN crtn = CSSM_CSP_CreateDigestContext(cspHand,
		alg,
		&hashHand);
	if(crtn) {
		printf("CSSM_CSP_CreateDigestContext error\n");
		return 0;
	}
	return hashHand;
}

/* Begin a staged digest on hand. */
static CSSM_RETURN hashInit(HashHand hand)
{
	return CSSM_DigestDataInit(hand);
}

/* Feed bufLen bytes of buf into the staged digest. */
static CSSM_RETURN hashUpdate(HashHand hand,
	const unsigned char *buf,
	unsigned bufLen)
{
	const CSSM_DATA cdata = {bufLen, (uint8 *)buf};
	return CSSM_DigestDataUpdate(hand, &cdata, 1);
}

/*
 * Finish the staged digest into digest. *digestLen supplies the buffer
 * length. (The original code cast the pointer itself to uint32 as the
 * length instead of dereferencing it.)
 * NOTE(review): if callers expect *digestLen to be updated with the
 * actual digest size, copy cdata.Length back here - confirm.
 */
static CSSM_RETURN hashFinal(HashHand hand,
	unsigned char *digest, unsigned *digestLen)
{
	CSSM_DATA cdata = {*digestLen, digest};
	return CSSM_DigestDataFinal(hand, &cdata);
}

/* Release the digest context. */
static CSSM_RETURN hashDone(HashHand hand)
{
	return CSSM_DeleteContext(hand);
}
#endif
/*
 * Produce outLen bytes made of back-to-back copies of inStr (inStrLen
 * bytes each), truncating the final copy as needed. p12PbeGen() uses this
 * to build S (salt), P (password) and B (digest) strings.
 *
 *   inStr, inStrLen - the string to repeat
 *   coder           - allocator used when outStr is NULL
 *   outLen          - number of output bytes to produce
 *   outStr          - optional caller buffer of at least outLen bytes;
 *                     if NULL, outLen bytes are allocated from coder
 *
 * Returns the output buffer (outStr or the newly allocated one).
 * If inStrLen is 0 there is nothing to repeat, so the output is
 * zero-filled. Previously this case looped forever because each pass
 * moved zero bytes.
 */
static unsigned char *p12StrCat(
	const unsigned char *inStr,
	unsigned inStrLen,
	SecNssCoder &coder,
	unsigned outLen,
	unsigned char *outStr = NULL)
{
	if(outStr == NULL) {
		outStr = (unsigned char *)coder.malloc(outLen);
	}
	if(inStrLen == 0) {
		if(outLen) {
			memset(outStr, 0, outLen);
		}
		return outStr;
	}
	unsigned char *outp = outStr;
	unsigned toMove = outLen;
	while(toMove) {
		/* each pass copies one (possibly truncated) copy from the start of inStr */
		unsigned thisMove = (inStrLen < toMove) ? inStrLen : toMove;
		memmove(outp, inStr, thisMove);
		toMove -= thisMove;
		outp += thisMove;
	}
	return outStr;
}
/*
 * PKCS #12 password-based derivation of pseudorandom bytes
 * (RFC 7292, Appendix B.2).
 *
 *   pwd       - password bytes, used exactly as given.
 *               NOTE(review): PKCS #12 defines the password as a
 *               null-terminated BMPString. No conversion happens here, so
 *               presumably callers pass it already encoded - confirm.
 *   salt      - saltLen bytes of salt (may be NULL when saltLen is 0)
 *   iterCount - r, total hash applications per output block (the first
 *               hash of D||I counts as one)
 *   pbeId     - diversifier byte placed in every byte of D
 *   hashAlg   - CSSM_ALGID_SHA1 or CSSM_ALGID_MD5
 *   coder     - allocator for all temporary buffers
 *   outbuf    - receives outbufLen derived bytes
 *
 * Returns CSSM_OK, CSSMERR_CSP_INVALID_ALGORITHM for any other hashAlg,
 * or the error from DigestCtxInit(). outbuf is written only on success.
 * Intermediate buffers are zeroed before returning.
 */
static CSSM_RETURN p12PbeGen(
	const CSSM_DATA &pwd,
	const uint8 *salt,
	unsigned saltLen,
	unsigned iterCount,
	P12_PBE_ID pbeId,
	CSSM_ALGORITHMS hashAlg,
	SecNssCoder &coder,
	uint8 *outbuf,
	unsigned outbufLen)
{
	CSSM_RETURN ourRtn = CSSM_OK;
	unsigned unipassLen = pwd.Length;
	unsigned char *unipass = pwd.Data;
	unsigned p12_r = iterCount;		/* iteration count */
	unsigned p12_n = outbufLen;		/* output bytes wanted */
	unsigned p12_u;					/* digest output size, bytes */
	unsigned p12_v;					/* digest block size, bytes */
	unsigned char *p12_P = NULL;
	unsigned char *p12_S = NULL;
	CSSM_BOOL isSha1 = CSSM_TRUE;

	switch(hashAlg) {
		case CSSM_ALGID_MD5:
			p12_u = kMD5DigestSize;
			p12_v = kMD5BlockSize;
			isSha1 = CSSM_FALSE;
			break;
		case CSSM_ALGID_SHA1:
			p12_u = kSHA1DigestSize;
			p12_v = kSHA1BlockSize;
			break;
		default:
			return CSSMERR_CSP_INVALID_ALGORITHM;
	}

	/* D: v copies of the ID byte */
	unsigned char *p12_D = (unsigned char *)coder.malloc(p12_v);
	memset(p12_D, (unsigned char)pbeId, p12_v);

	/* S: salt repeated to a multiple of v bytes (empty when no salt) */
	unsigned p12_Slen = p12_v * ((saltLen + p12_v - 1) / p12_v);
	if(p12_Slen) {
		p12_S = p12StrCat(salt, saltLen, coder, p12_Slen);
	}

	/* P: password repeated to a multiple of v bytes (empty when no password) */
	unsigned p12_Plen = p12_v * ((unipassLen + p12_v - 1) / p12_v);
	if(p12_Plen) {
		p12_P = p12StrCat(unipass, unipassLen, coder, p12_Plen);
	}

	/* I = S || P */
	unsigned Ilen = p12_Slen + p12_Plen;
	unsigned char *p12_I = (unsigned char *)coder.malloc(Ilen);
	if(p12_Slen) {
		memmove(p12_I, p12_S, p12_Slen);
	}
	if(p12_Plen) {
		memmove(p12_I + p12_Slen, p12_P, p12_Plen);
	}

	/* c = ceil(n / u) digest-sized output blocks A_1 .. A_c */
	unsigned p12_c = (p12_n + p12_u - 1) / p12_u;
	unsigned char *p12_A = (unsigned char *)coder.malloc(p12_c * p12_u);

	DigestCtx ourDigest;
	DigestCtx *hashHand = &ourDigest;
	/*
	 * Clear the whole context, not just sizeof(pointer) bytes.
	 * DigestCtxFree() below runs even when the loop never executes
	 * (outbufLen == 0), in which case DigestCtxInit() is never called.
	 */
	memset(hashHand, 0, sizeof(*hashHand));

	/*
	 * B is v bytes. The extra byte gives BN_bn2bin() room when
	 * I_j + B + 1 carries into a (v+1)th byte.
	 */
	unsigned char *p12_B = (unsigned char *)coder.malloc(p12_v + 1);
	BIGNUM *Ij = BN_new();
	BIGNUM *Bpl1 = BN_new();

	for(unsigned p12_i = 0; p12_i < p12_c; p12_i++) {
		unsigned char *p12_AsubI = p12_A + (p12_i * p12_u);

		/* A_i = H^r(D || I) */
		ourRtn = DigestCtxInit(hashHand, isSha1);
		if(ourRtn) break;
		DigestCtxUpdate(hashHand, p12_D, p12_v);
		DigestCtxUpdate(hashHand, p12_I, Ilen);
		DigestCtxFinal(hashHand, p12_AsubI);
		for(unsigned iter = 1; iter < p12_r; iter++) {
			ourRtn = DigestCtxInit(hashHand, isSha1);
			if(ourRtn) break;
			DigestCtxUpdate(hashHand, p12_AsubI, p12_u);
			DigestCtxFinal(hashHand, p12_AsubI);
		}
		/* stop on an inner-loop failure; otherwise the next DigestCtxInit() would overwrite it */
		if(ourRtn) break;

		/* B = A_i repeated to v bytes; Bpl1 = B + 1 */
		p12StrCat(p12_AsubI, p12_u, coder, p12_v, p12_B);
		BN_bin2bn(p12_B, p12_v, Bpl1);
		BN_add_word(Bpl1, 1);

		/* I_j = (I_j + B + 1) mod 2^(8v) for each v-byte block I_j of I */
		for(unsigned j = 0; j < Ilen; j += p12_v) {
			BN_bin2bn(p12_I + j, p12_v, Ij);
			BN_add(Ij, Ij, Bpl1);
			unsigned Ijlen = BN_num_bytes(Ij);
			if(Ijlen > p12_v) {
				/* carried into byte v+1 (both addends < 2^(8v)): drop the top byte */
				BN_bn2bin(Ij, p12_B);
				memcpy(p12_I + j, p12_B + 1, p12_v);
			}
			else {
				/* left-pad with zeroes to exactly v bytes */
				memset(p12_I + j, 0, p12_v - Ijlen);
				BN_bn2bin(Ij, p12_I + j + p12_v - Ijlen);
			}
		}
	}

	if((ourRtn == CSSM_OK) && outbufLen) {
		memmove(outbuf, p12_A, outbufLen);
	}

	/* scrub password-derived intermediates */
	if(p12_D) {
		memset(p12_D, 0, p12_v);
	}
	if(p12_S) {
		memset(p12_S, 0, p12_Slen);
	}
	if(p12_P) {
		memset(p12_P, 0, p12_Plen);
	}
	if(p12_I) {
		memset(p12_I, 0, Ilen);
	}
	if(p12_A) {
		memset(p12_A, 0, p12_c * p12_u);
	}
	if(p12_B) {
		memset(p12_B, 0, p12_v + 1);
	}
	DigestCtxFree(hashHand);
	BN_free(Bpl1);
	BN_free(Ij);
	return ourRtn;
}
void DeriveKey_PKCS12 (
const Context &context,
const CssmData &Param, CSSM_DATA *keyData) {
CSSM_DATA pwd = {0, NULL};
CssmCryptoData *cryptData =
context.get<CssmCryptoData>(CSSM_ATTRIBUTE_SEED);
if(cryptData) {
pwd.Length = cryptData->Param.Length;
pwd.Data = cryptData->Param.Data;
}
uint32 saltLen = 0;
uint8 *salt = NULL;
CssmData *csalt = context.get<CssmData>(CSSM_ATTRIBUTE_SALT);
if(csalt) {
salt = csalt->Data;
saltLen = csalt->Length;
}
uint32 iterCount = context.getInt(CSSM_ATTRIBUTE_ITERATION_COUNT,
CSSMERR_CSP_MISSING_ATTR_ITERATION_COUNT);
if(iterCount == 0) {
CssmError::throwMe(CSSMERR_CSP_INVALID_ATTR_ITERATION_COUNT);
}
P12_PBE_ID pbeId = PBE_ID_Key;
switch(context.algorithm()) {
case CSSM_ALGID_PKCS12_PBE_ENCR:
pbeId = PBE_ID_Key;
break;
case CSSM_ALGID_PKCS12_PBE_MAC:
pbeId = PBE_ID_MAC;
break;
default:
assert(0);
CssmError::throwMe(CSSMERR_CSP_INTERNAL_ERROR);
}
SecNssCoder tmpCoder;
CSSM_RETURN crtn = p12PbeGen(pwd,
salt, saltLen,
iterCount,
pbeId,
CSSM_ALGID_SHA1, tmpCoder,
keyData->Data,
keyData->Length);
if(crtn) {
CssmError::throwMe(crtn);
}
if(Param.Data) {
crtn = p12PbeGen(pwd,
salt, saltLen,
iterCount,
PBE_ID_IV,
CSSM_ALGID_SHA1, tmpCoder,
Param.Data,
Param.Length);
if(crtn) {
CssmError::throwMe(crtn);
}
}
}