#include "mp.h"
/* Forward declarations for the file-local helpers defined further down. */
static void mp_outmem(void);
static void *_mp_malloc(size_t);
static void *_mp_calloc(size_t, size_t);
static void *_mp_realloc(void*, size_t);
static void _mp_free(void*);
/*
 * Currently installed allocator hooks.  Each one starts out pointing at
 * the matching default wrapper declared above; the mp_set_*() functions
 * in this file replace them.
 */
static mp_malloc_fun __mp_malloc = _mp_malloc;
static mp_calloc_fun __mp_calloc = _mp_calloc;
static mp_realloc_fun __mp_realloc = _mp_realloc;
static mp_free_fun __mp_free = _mp_free;
/*
 * Allocation-failure handler: writes a diagnostic to stderr and ends the
 * process with exit status 1.  Does not return.
 */
static void
mp_outmem(void)
{
    fputs("out of memory in MP library.\n", stderr);
    exit(1);
}
/* Default allocation back-end: hands the request straight to malloc(). */
static void *
_mp_malloc(size_t size)
{
    void *block = malloc(size);

    return block;
}
/*
 * Allocates `size` bytes through the installed malloc hook.
 *
 * A NULL result from the hook is treated as out-of-memory and handed to
 * mp_outmem(), which terminates the process, so callers never have to
 * check the returned pointer.
 */
void *
mp_malloc(size_t size)
{
    void *block;

    block = __mp_malloc(size);
    if (!block)
        mp_outmem();
    return block;
}
/*
 * Installs `fun` as the hook used by mp_malloc() and returns the hook
 * that was in place before.  `fun` is stored as given, without checks.
 */
mp_malloc_fun
mp_set_malloc(mp_malloc_fun fun)
{
    mp_malloc_fun previous;

    previous = __mp_malloc;
    __mp_malloc = fun;
    return previous;
}
/* Default zeroing allocation back-end: forwards to calloc(). */
static void *
_mp_calloc(size_t nmemb, size_t size)
{
    void *block = calloc(nmemb, size);

    return block;
}
/*
 * Allocates an array of `nmemb` elements of `size` bytes each through the
 * installed calloc hook.
 *
 * A NULL result from the hook is passed to mp_outmem(), which terminates
 * the process, so the returned pointer is never NULL.
 */
void *
mp_calloc(size_t nmemb, size_t size)
{
    void *block;

    block = __mp_calloc(nmemb, size);
    if (!block)
        mp_outmem();
    return block;
}
/*
 * Installs `fun` as the hook used by mp_calloc() and returns the hook
 * that was in place before.  `fun` is stored as given, without checks.
 */
mp_calloc_fun
mp_set_calloc(mp_calloc_fun fun)
{
    mp_calloc_fun previous;

    previous = __mp_calloc;
    __mp_calloc = fun;
    return previous;
}
/* Default resize back-end: forwards to realloc(). */
static void *
_mp_realloc(void *old, size_t size)
{
    void *block = realloc(old, size);

    return block;
}
/*
 * Resizes the block `old` to `size` bytes through the installed realloc
 * hook.
 *
 * A NULL result from the hook is passed to mp_outmem(), which terminates
 * the process, so the returned pointer is never NULL.
 */
void *
mp_realloc(void *old, size_t size)
{
    void *block;

    block = __mp_realloc(old, size);
    if (!block)
        mp_outmem();
    return block;
}
/*
 * Installs `fun` as the hook used by mp_realloc() and returns the hook
 * that was in place before.  `fun` is stored as given, without checks.
 */
mp_realloc_fun
mp_set_realloc(mp_realloc_fun fun)
{
    mp_realloc_fun previous;

    previous = __mp_realloc;
    __mp_realloc = fun;
    return previous;
}
/* Default release back-end: forwards to free(). */
static void
_mp_free(void *pointer)
{
    free(pointer);
    return;
}
/* Releases `pointer` through the installed free hook. */
void
mp_free(void *pointer)
{
    __mp_free(pointer);
}
/*
 * Installs `fun` as the hook used by mp_free() and returns the hook that
 * was in place before.  `fun` is stored as given, without checks.
 */
mp_free_fun
mp_set_free(mp_free_fun fun)
{
    mp_free_fun previous;

    previous = __mp_free;
    __mp_free = fun;
    return previous;
}
/*
 * mp_add - add two unsigned multi-word numbers.
 *
 * Words are stored least significant first.  The sum of op1 (len1 words)
 * and op2 (len2 words) is written to rop and the number of words stored
 * is returned.  Both operands must have at least one word.
 *
 * rop may be the same buffer as the longer operand; in that case only the
 * carry is propagated through the upper words instead of copying them.
 * A carry out of the top word appends one extra word, so rop needs room
 * for max(len1, len2) + 1 words.
 */
long
mp_add(BNS *rop, BNS *op1, BNS *op2, BNI len1, BNI len2)
{
    BNI value;      /* double-width intermediate sum */
    BNS carry;      /* carry out of the previous word */
    long size;      /* index of the next word to produce */

    /* make op1 the longer operand */
    if (len1 < len2)
        MP_SWAP(op1, op2, len1, len2);

    /*
     * Every sum is formed in BNI.  Without the widening cast the addition
     * happens in the promoted type of BNS, which wraps and loses the carry
     * when BNS is as wide as int.
     */
    value = (BNI)op1[0] + op2[0];
    rop[0] = (BNS)value;
    carry = (BNS)(value >> BNSBITS);

    /* words present in both operands */
    for (size = 1; size < len2; size++) {
        value = (BNI)op1[size] + op2[size] + carry;
        rop[size] = (BNS)value;
        carry = (BNS)(value >> BNSBITS);
    }

    if (rop != op1) {
        /* copy the remaining words of the longer operand, adding carry */
        for (; size < len1; size++) {
            value = (BNI)op1[size] + carry;
            rop[size] = (BNS)value;
            carry = (BNS)(value >> BNSBITS);
        }
    }
    else {
        /* in place: the upper words are already there, just ripple carry */
        for (; carry && size < len1; size++) {
            value = (BNI)op1[size] + carry;
            rop[size] = (BNS)value;
            carry = (BNS)(value >> BNSBITS);
        }
        size = len1;
    }

    /* a carry out of the top word grows the result by one word */
    if (carry)
        rop[size++] = carry;

    return (size);
}
/*
 * mp_sub - subtract two unsigned multi-word numbers.
 *
 * Computes op1 - op2 into rop (words least significant first) and returns
 * the number of words of the result after dropping leading zero words
 * (never less than 1).  rop may be the same buffer as op1.
 *
 * No operand swap or sign handling is done here.
 * NOTE(review): this appears to require len1 >= len2 and op1 >= op2 in
 * value; confirm against callers.
 *
 * When op1 and op2 are the same pointer the result is set to zero
 * immediately, without looking at the lengths.
 */
long
mp_sub(BNS *rop, BNS *op1, BNS *op2, BNI len1, BNI len2)
{
    long svalue;    /* signed intermediate difference */
    BNS carry;      /* borrow taken from the next word (0 or 1) */
    long size;      /* index of the next word to produce */

    if (op1 == op2) {
        rop[0] = 0;
        return (1);
    }

    /*
     * Every difference is formed in long so that a borrow shows up as a
     * negative value.  Without the widening cast the subtraction happens
     * in the promoted type of BNS, which wraps to a positive value and
     * hides the borrow when BNS is as wide as int.  This assumes long can
     * represent every BNS value.
     */
    svalue = (long)op1[0] - (long)op2[0];
    rop[0] = (BNS)svalue;
    carry = svalue < 0;

    /* words present in both operands */
    for (size = 1; size < len2; size++) {
        svalue = (long)op1[size] - (long)op2[size] - (long)carry;
        rop[size] = (BNS)svalue;
        carry = svalue < 0;
    }

    if (rop != op1) {
        /* copy the remaining words of op1, applying the borrow */
        for (; size < len1; size++) {
            svalue = (long)op1[size] - (long)carry;
            rop[size] = (BNS)svalue;
            carry = svalue < 0;
        }
    }
    else {
        /* in place: only ripple the borrow through the upper words */
        for (; carry && size < len1; size++) {
            svalue = (long)op1[size] - (long)carry;
            rop[size] = (BNS)svalue;
            carry = svalue < 0;
        }
        size = len1;
    }

    /* drop leading zero words, keeping at least one */
    while (size > 1 && rop[size - 1] == 0)
        --size;

    return (size);
}
/*
 * mp_lshift - shift a multi-word number left by `shift` bits.
 *
 * op holds `len` words, least significant first.  The shifted value is
 * written to rop and the number of words of the result is returned; rop
 * must have room for that many words.
 *
 * The shift is split into a whole-word part and a bit part.  The low
 * words introduced by the whole-word part are zero-filled at the end.
 * NOTE(review): with a non-zero word shift, rop == op looks unsafe in the
 * bit-shift path because source words are overwritten before being read.
 */
long
mp_lshift(BNS *rop, BNS *op, BNI len, long shift)
{
    BNI nwords = shift / BNSBITS;   /* whole-word part of the shift */
    BNI nbits = shift % BNSBITS;    /* remaining bit part */
    long total = len + nwords;      /* result length in words */
    long k;

    if (!nbits) {
        /* word-aligned shift: just move the words up */
        memmove(rop + total - len, op, sizeof(BNS) * len);
    }
    else {
        BNS upper, lower;
        BNI mask;
        int spill;

        /*
         * Locate the highest set bit of the top word, counting from 1 at
         * the mask CARRY >> 1 (presumably the most significant bit of a
         * word), to decide whether the shift spills into one more word.
         */
        k = 1;
        for (mask = CARRY >> 1; mask; mask >>= 1) {
            if (op[len - 1] & mask)
                break;
            k++;
        }
        spill = (nbits + (BNSBITS - k)) / BNSBITS;
        total += spill;

        /* shift from the low word upwards, carrying bits between words */
        lower = upper = op[0];
        rop[nwords] = lower << nbits;
        for (k = 1; k < len; k++) {
            upper = op[k];
            rop[nwords + k] = upper << nbits | (lower >> (BNSBITS - nbits));
            lower = upper;
        }

        /* bits pushed out of the top source word form the extra word */
        if (spill)
            rop[total - 1] = upper >> (BNSBITS - nbits);
    }

    if (nwords)
        memset(rop, '\0', sizeof(BNS) * nwords);

    return (total);
}
/*
 * mp_rshift - shift a multi-word number right by `shift` bits.
 *
 * op holds `len` words, least significant first.  The shifted value is
 * written to rop and the number of words of the result is returned.  A
 * result of zero is reported as the single word 0 with length 1; this
 * includes shifts that remove every word of op.
 *
 * The shift is split into a whole-word part (low words are dropped) and a
 * bit part.  Bits shifted out at the bottom are discarded.
 * NOTE(review): with a non-zero word shift, rop == op looks unsafe in the
 * bit-shift path because source words are overwritten before being read.
 */
long
mp_rshift(BNS *rop, BNS *op, BNI len, long shift)
{
    int adj = 0;            /* 1 when the top result word becomes zero */
    long i, size;
    BNI words, bits;        /* whole-word and bit parts of the shift */

    words = shift / BNSBITS;
    bits = shift % BNSBITS;

    /*
     * Nothing is left once all words are shifted out.  Handling it here
     * keeps `size` positive, so the indexing below stays inside both
     * buffers.
     */
    if (words >= len) {
        rop[0] = 0;
        return (1);
    }
    size = len - words;

    if (bits) {
        BNS hi, lo;
        BNI carry;

        /*
         * Count the leading zero bits of the top word, scanning down from
         * the mask CARRY >> 1 (presumably the most significant bit of a
         * word), to find out whether the shift empties that word.
         */
        for (i = 0, carry = CARRY >> 1; carry; i++, carry >>= 1)
            if (op[len - 1] & carry)
                break;
        adj = (bits + i) / BNSBITS;
        if (size - adj == 0) {
            rop[0] = 0;
            return (1);
        }

        /* shift from the top word downwards, carrying bits between words */
        hi = op[words + size - 1];
        rop[size - 1] = hi >> bits;
        for (i = size - 2; i >= 0; i--) {
            lo = op[words + i];
            rop[i] = (lo >> bits) | (hi << (BNSBITS - bits));
            hi = lo;
        }
        /*
         * The low `bits` bits of the lowest kept word fall off the end and
         * must not be merged back: the upper bits of rop[0] already belong
         * to the next source word.
         */
    }
    else
        memmove(rop, op + len - size, size * sizeof(BNS));

    return (size - adj);
}
/*
 * mp_base_mul - schoolbook multiplication of two multi-word numbers.
 *
 * Multiplies op1 (len1 words) by op2 (len2 words), words least
 * significant first, into rop and returns the result length in words
 * (len1 + len2, minus one if the top word is zero).
 *
 * The row for op1[0] is stored directly; every later row is accumulated
 * on top of what rop already holds.  Rows whose op1 word is zero are
 * skipped without touching rop.
 * NOTE(review): because skipped rows leave rop untouched and later rows
 * read it, rop presumably has to be zero-filled by the caller; confirm.
 */
long
mp_base_mul(BNS *rop, BNS *op1, BNS *op2, BNI len1, BNI len2)
{
    long total = len1 + len2;   /* maximum result length */
    long row, col;
    BNI acc;                    /* double-width partial product */
    BNS high;                   /* upper half carried into the next word */
    BNS *out;

    /* first row: plain store, nothing to accumulate yet */
    if (op1[0]) {
        acc = (BNI)(op1[0]) * op2[0];
        rop[0] = acc;
        high = (BNS)(acc >> BNSBITS);
        for (col = 1; col < len2; col++) {
            acc = (BNI)(op1[0]) * op2[col] + high;
            rop[col] = acc;
            high = (BNS)(acc >> BNSBITS);
        }
        rop[col] = high;
    }

    /* remaining rows: add each partial product at its word offset */
    for (row = 1; row < len1; row++) {
        if (!op1[row])
            continue;

        out = rop + row;
        acc = (BNI)(op1[row]) * op2[0] + out[0];
        out[0] = acc;
        high = (BNS)(acc >> BNSBITS);
        for (col = 1; col < len2; col++) {
            acc = (BNI)(op1[row]) * op2[col] + out[col] + high;
            out[col] = acc;
            high = (BNS)(acc >> BNSBITS);
        }
        out[col] = high;
    }

    /* at most one leading zero word is dropped */
    if (total > 1 && rop[total - 1] == 0)
        --total;

    return (total);
}
/*
 * mp_karatsuba_mul - multiply two multi-word numbers by splitting each
 * operand into a low and a high part (Karatsuba scheme).
 *
 * op1 (len1 words) and op2 (len2 words) are stored least significant
 * word first.  With a0/a1 the low/high parts of op1 and b0/b1 those of
 * op2, the function computes (a0 + a1) * (b0 + b1), a0 * b0 and a1 * b1
 * and combines them into rop.  Returns the result length in words.
 *
 * The low words of rop are only ever added into, never stored, so rop
 * is expected to arrive zero-filled.
 * NOTE(review): confirm that every caller passes a zeroed rop.
 */
long
mp_karatsuba_mul(BNS *rop, BNS *op1, BNS *op2, BNI len1, BNI len2)
{
/* x: split point in words, half of the longer operand rounded up */
BNI x;
/* la0/la1: lengths of the low/high part of op1; lb0/lb1: same for op2 */
BNI la0, la1, lb0, lb1;
/* t, u: scratch buffers, reused for the sums and then for the products */
BNS *t;
BNS *u;
/* r: window into rop where the current term is accumulated */
BNS *r;
long xlen, tlen, ulen;
if (len1 >= len2)
x = (len1 + 1) >> 1;
else
x = (len2 + 1) >> 1;
la0 = x;
la1 = len1 - x;
lb0 = x;
lb1 = len2 - x;
/* t is sized for the later a0 * b0 product */
tlen = la0 + lb0;
t = mp_malloc(sizeof(BNS) * tlen);
/* u is sized for the larger of its two uses */
if (la1 + lb1 < lb0 + lb1 + 1)
ulen = lb0 + lb1 + 1;
else
ulen = la1 + lb1;
u = mp_malloc(sizeof(BNS) * ulen);
/* t = a0 + a1, u = b0 + b1 */
tlen = mp_add(t, op1, op1 + x, la0, la1);
ulen = mp_add(u, op2, op2 + x, lb0, lb1);
/* middle term (a0 + a1) * (b0 + b1), placed x words up in rop */
r = rop + x;
xlen = mp_mul(r, t, u, tlen, ulen);
/* t = a0 * b0; the buffer is cleared first because it is reused */
tlen = la0 + lb0;
memset(t, '\0', sizeof(BNS) * tlen);
tlen = mp_mul(t, op1, op2, la0, lb0);
/* u = a1 * b1; cleared first for the same reason */
ulen = la1 + lb1;
memset(u, '\0', sizeof(BNS) * ulen);
ulen = mp_mul(u, op1 + x, op2 + x, la1, lb1);
/* middle term -= a0 * b0 and -= a1 * b1 */
xlen = mp_sub(r, r, t, xlen, tlen);
xlen = mp_sub(r, r, u, xlen, ulen);
/* add a1 * b1 at word offset 2x */
r = rop + (x << 1);
xlen = len1 + len2;
xlen = mp_add(r, r, u, xlen, ulen);
/* add a0 * b0 at word offset 0 */
xlen = mp_add(rop, rop, t, xlen, tlen);
mp_free(t);
mp_free(u);
/* at most one leading zero word is dropped */
if (xlen > 1 && rop[xlen - 1] == 0)
--xlen;
return (xlen);
}
/*
 * mp_toom_mul - multiply two multi-word numbers by splitting each operand
 * into three pieces (Toom-3 style).
 *
 * op1 (len1 words) and op2 (len2 words) are stored least significant
 * word first.  Each operand is cut into two pieces of x words and a top
 * piece holding the rest.  Writing u0, u1, u2 for the pieces of op1 and
 * v0, v1, v2 for those of op2 (low to high), five products are formed:
 *
 *   a = u0 * v0
 *   b = (4*u0 + 2*u1 + u2) * (4*v0 + 2*v1 + v2)
 *   c = (u0 + u1 + u2) * (v0 + v1 + v2)
 *   d = (u0 + 2*u1 + 4*u2) * (v0 + 2*v1 + 4*v2)
 *   e = u2 * v2
 *
 * b, c and d are then reduced in place to the three middle coefficients,
 * and a, b, c, d, e are added into rop at word offsets 0, x, 2x, 3x, 4x.
 * Returns the result length in words.
 *
 * rop is also used as scratch space during the reduction and is cleared
 * before the final assembly.
 */
long
mp_toom_mul(BNS *rop, BNS *op1, BNS *op2, BNI len1, BNI len2)
{
long size, xsize, i;
BNI value;
BNS carry;
/* x: length in words of the two lower pieces of each operand */
BNI x;
/* l1, l2: lengths of the top pieces of op1 and op2 */
BNI l1, l2;
BNI al, bl, cl, dl, el, Ul[3], Vl[3];
BNS *a, *b, *c, *d, *e, *U[3], *V[3];
x = (len1 + len2 + 4) / 6;
l1 = len1 - (x << 1);
l2 = len2 - (x << 1);
U[0] = mp_malloc(sizeof(BNS) * (x + 2));
V[0] = mp_malloc(sizeof(BNS) * (x + 2));
U[1] = mp_malloc(sizeof(BNS) * (x + 1));
V[1] = mp_malloc(sizeof(BNS) * (x + 1));
U[2] = mp_malloc(sizeof(BNS) * (x + 2));
V[2] = mp_malloc(sizeof(BNS) * (x + 2));
/* U[1] temporarily holds 2*u1, shared by the next two sums */
Ul[1] = mp_lshift(U[1], op1 + x, x, 1);
/* U[0] = 4*u0 + 2*u1 + u2 */
Ul[0] = mp_lshift(U[0], op1, x, 2);
Ul[0] = mp_add(U[0], U[0], U[1], Ul[0], Ul[1]);
Ul[0] = mp_add(U[0], U[0], op1 + x + x, Ul[0], l1);
/* U[2] = 4*u2 + 2*u1 + u0 */
Ul[2] = mp_lshift(U[2], op1 + x + x, l1, 2);
Ul[2] = mp_add(U[2], U[2], U[1], Ul[2], Ul[1]);
Ul[2] = mp_add(U[2], U[2], op1, Ul[2], x);
/* U[1] = u0 + u1 + u2 */
Ul[1] = mp_add(U[1], op1, op1 + x, x, x);
Ul[1] = mp_add(U[1], U[1], op1 + x + x, Ul[1], l1);
/* the same three sums for op2 */
Vl[1] = mp_lshift(V[1], op2 + x, x, 1);
Vl[0] = mp_lshift(V[0], op2, x, 2);
Vl[0] = mp_add(V[0], V[0], V[1], Vl[0], Vl[1]);
Vl[0] = mp_add(V[0], V[0], op2 + x + x, Vl[0], l2);
Vl[2] = mp_lshift(V[2], op2 + x + x, l2, 2);
Vl[2] = mp_add(V[2], V[2], V[1], Vl[2], Vl[1]);
Vl[2] = mp_add(V[2], V[2], op2, Vl[2], x);
Vl[1] = mp_add(V[1], op2, op2 + x, x, x);
Vl[1] = mp_add(V[1], V[1], op2 + x + x, Vl[1], l2);
/*
 * b = U[0] * V[0]
 * NOTE(review): b, c and d are sized by the PRODUCT of the operand
 * lengths, while a and e below use the sum; verify this is intended.
 */
b = mp_calloc(1, sizeof(BNS) * (Ul[0] * Vl[0]));
bl = mp_mul(b, U[0], V[0], Ul[0], Vl[0]);
mp_free(U[0]);
mp_free(V[0]);
/* c = U[1] * V[1] */
c = mp_calloc(1, sizeof(BNS) * (Ul[1] * Vl[1]));
cl = mp_mul(c, U[1], V[1], Ul[1], Vl[1]);
mp_free(U[1]);
mp_free(V[1]);
/* d = U[2] * V[2] */
d = mp_calloc(1, sizeof(BNS) * (Ul[2] * Vl[2]));
dl = mp_mul(d, U[2], V[2], Ul[2], Vl[2]);
mp_free(U[2]);
mp_free(V[2]);
/* a = u0 * v0 */
a = mp_calloc(1, sizeof(BNS) * (x + x));
al = mp_mul(a, op1, op2, x, x);
/* e = u2 * v2 */
e = mp_calloc(1, sizeof(BNS) * (l1 + l2));
el = mp_mul(e, op1 + x + x, op2 + x + x, l1, l2);
/* b -= 16*a (rop is scratch for 16*a), then b -= e */
size = mp_lshift(rop, a, al, 4);
bl = mp_sub(b, b, rop, bl, size);
bl = mp_sub(b, b, e, bl, el);
/* c -= a, c -= e */
cl = mp_sub(c, c, a, cl, al);
cl = mp_sub(c, c, e, cl, el);
/* d -= a, then d -= 16*e (rop is scratch for 16*e) */
dl = mp_sub(d, d, a, dl, al);
size = mp_lshift(rop, e, el, 4);
dl = mp_sub(d, d, rop, dl, size);
/* rop = b + d */
size = mp_add(rop, b, d, bl, dl);
/* rop -= 8*c; 8*c is built in the space just past the current rop value */
xsize = mp_lshift(rop + size, c, cl, 3);
size = mp_sub(rop, rop, rop + size, size, xsize);
/* rop = 2*c - rop; 2*c is again built just past the current rop value */
xsize = mp_lshift(rop + size, c, cl, 1);
size = mp_sub(rop, rop + size, rop, xsize, size);
/* b = b/2 - c */
bl = mp_rshift(b, b, bl, 1);
bl = mp_sub(b, b, c, bl, cl);
/* c = rop/2, then b -= c */
cl = mp_rshift(c, rop, size, 1);
bl = mp_sub(b, b, c, bl, cl);
/* b /= 3: word-by-word long division from the top word down */
i = bl - 1;
value = b[i];
b[i] = value / 3;
for (--i; i >= 0; i--) {
/* the remainder of the previous word becomes the high part of this one */
carry = value % 3;
value = ((BNI)carry << BNSBITS) + b[i];
b[i] = (BNS)(value / 3);
}
/* d = (d/2 - b - rop) / 4 */
dl = mp_rshift(d, d, dl, 1);
dl = mp_sub(d, d, b, dl, bl);
dl = mp_sub(d, d, rop, dl, size);
dl = mp_rshift(d, d, dl, 2);
/* scratch use of rop is over; clear it and assemble the result */
memset(rop, '\0', sizeof(BNS) * (len1 + len2));
/* add e at word offset 4x */
i = x * 4;
xsize = (len1 + len2) - i;
size = mp_add(rop + i, rop + i, e, xsize, el) + i;
/* add d at word offset 3x */
i = x * 3;
xsize = size - i;
size = mp_add(rop + i, rop + i, d, xsize, dl) + i;
/* add c at word offset 2x */
i = x * 2;
xsize = size - i;
size = mp_add(rop + i, rop + i, c, xsize, cl) + i;
/* add b at word offset x */
i = x;
xsize = size - i;
size = mp_add(rop + i, rop + i, b, xsize, bl) + i;
/* add a at word offset 0 */
size = mp_add(rop, rop, a, size, al);
mp_free(e);
mp_free(d);
mp_free(c);
mp_free(b);
mp_free(a);
/* at most one leading zero word is dropped */
if (size > 1 && rop[size - 1] == 0)
--size;
return (size);
}
/*
 * mp_mul - multiply two multi-word numbers, choosing the algorithm by
 * operand size.
 *
 * op1 (len1 words) and op2 (len2 words) are stored least significant
 * word first; the product is written to rop and its length in words is
 * returned.  The operands are swapped so that op1 is the longer one.
 *
 *  - either length below KARATSUBA: mp_base_mul()
 *  - both below TOOM and op2 longer than half of op1: mp_karatsuba_mul()
 *  - both at least TOOM and of matching size class: mp_toom_mul()
 *  - otherwise op1 is cut into chunks of len2 words, each chunk is
 *    multiplied by op2 recursively and added into rop at its offset.
 *
 * NOTE(review): the chunked path adds into words of rop it has not
 * written itself, so rop presumably has to arrive zero-filled; confirm.
 */
long
mp_mul(BNS *rop, BNS *op1, BNS *op2, BNI len1, BNI len2)
{
    if (len1 < len2)
        MP_SWAP(op1, op2, len1, len2);

    if (len1 < KARATSUBA || len2 < KARATSUBA)
        return (mp_base_mul(rop, op1, op2, len1, len2));
    else if (len1 < TOOM && len2 < TOOM && len2 > ((len1 + 1) >> 1))
        return (mp_karatsuba_mul(rop, op1, op2, len1, len2));
    else if (len1 >= TOOM && len2 >= TOOM && (len2 + 2) / 3 == (len1 + 2) / 3)
        return (mp_toom_mul(rop, op1, op2, len1, len2));
    else {
        long xsize, psize, isize;
        BNS *ptr;       /* scratch buffer for one chunk product */

        isize = 0;                  /* word offset of the current chunk */
        xsize = len1 + len2;        /* maximum result length */

        /* the first chunk goes directly into rop */
        mp_mul(rop, op1, op2, len2, len2);
        len1 -= len2;
        op1 += len2;

        /* scratch space for the largest chunk product still to come */
        if (len1 > len2)
            ptr = mp_calloc(1, sizeof(BNS) * (len2 + len2));
        else
            ptr = mp_calloc(1, sizeof(BNS) * (len1 + len2));

        /* full chunks of len2 words */
        while (len1 >= len2) {
            isize += len2;
            psize = mp_mul(ptr, op1, op2, len2, len2);
            mp_add(rop + isize, rop + isize, ptr, xsize - isize, psize);
            len1 -= len2;
            op1 += len2;

            /* clear what the next product will use */
            memset(ptr, '\0', sizeof(BNS) * (MIN(len1, len2) + len2));
        }

        /* trailing chunk shorter than len2 */
        if (len1) {
            isize += len2;
            psize = mp_mul(ptr, op2, op1, len2, len1);
            /*
             * Only xsize - isize words of rop remain above this offset;
             * using that as the length keeps mp_add's carry propagation
             * inside rop.
             */
            mp_add(rop + isize, rop + isize, ptr, xsize - isize, psize);
        }

        /* at most one leading zero word is dropped */
        if (rop[xsize - 1] == 0)
            --xsize;
        mp_free(ptr);

        return (xsize);
    }
}