#include "config.h"
#include "Compression.h"
#include "CheckedArithmetic.h"
#if USE(ZLIB) && !COMPILER(MSVC)
#include <string.h>
#include <zlib.h>
namespace WTF {
static void* zAlloc(void*, uint32_t count, uint32_t size)
{
CheckedSize allocSize = count;
allocSize *= size;
if (allocSize.hasOverflowed())
return Z_NULL;
void* result = 0;
if (tryFastMalloc(allocSize.unsafeGet()).getValue(result))
return result;
return Z_NULL;
}
// zlib deallocation hook; counterpart of zAlloc, so memory goes back to the
// fastMalloc heap. The opaque pointer zlib passes is not used.
static void zFree(void* /* opaque */, void* data)
{
    fastFree(data);
}
std::unique_ptr<GenericCompressedData> GenericCompressedData::create(const uint8_t* data, size_t dataLength)
{
enum { MinimumSize = sizeof(GenericCompressedData) * 8 };
if (!data || dataLength < MinimumSize)
return nullptr;
z_stream stream;
memset(&stream, 0, sizeof(stream));
stream.zalloc = zAlloc;
stream.zfree = zFree;
stream.data_type = Z_BINARY;
stream.opaque = Z_NULL;
stream.avail_in = dataLength;
stream.next_in = const_cast<uint8_t*>(data);
size_t currentOffset = OBJECT_OFFSETOF(GenericCompressedData, m_data);
size_t currentCapacity = fastMallocGoodSize(MinimumSize);
Bytef* compressedData = static_cast<Bytef*>(fastMalloc(currentCapacity));
memset(compressedData, 0, sizeof(GenericCompressedData));
stream.next_out = compressedData + currentOffset;
stream.avail_out = currentCapacity - currentOffset;
deflateInit(&stream, Z_BEST_COMPRESSION);
while (true) {
int deflateResult = deflate(&stream, Z_FINISH);
if (deflateResult == Z_OK || !stream.avail_out) {
size_t newCapacity = 0;
currentCapacity -= stream.avail_out;
if (!stream.avail_in)
newCapacity = currentCapacity + 8;
else {
size_t compressedContent = stream.next_in - data;
double expectedSize = static_cast<double>(dataLength) * compressedContent / currentCapacity;
newCapacity = std::max(static_cast<size_t>(expectedSize + 8), currentCapacity + 8);
}
newCapacity = fastMallocGoodSize(newCapacity);
if (newCapacity >= dataLength)
goto fail;
compressedData = static_cast<Bytef*>(fastRealloc(compressedData, newCapacity));
currentOffset = currentCapacity - stream.avail_out;
stream.next_out = compressedData + currentOffset;
stream.avail_out = newCapacity - currentCapacity;
currentCapacity = newCapacity;
continue;
}
if (deflateResult == Z_STREAM_END) {
ASSERT(!stream.avail_in);
break;
}
ASSERT_NOT_REACHED();
fail:
deflateEnd(&stream);
fastFree(compressedData);
return nullptr;
}
deflateEnd(&stream);
static int64_t totalCompressed = 0;
static int64_t totalInput = 0;
totalCompressed += currentCapacity;
totalInput += dataLength;
return std::unique_ptr<GenericCompressedData>(new (compressedData) GenericCompressedData(dataLength, stream.total_out));
}
bool GenericCompressedData::decompress(uint8_t* destination, size_t bufferSize, size_t* decompressedByteCount)
{
if (decompressedByteCount)
*decompressedByteCount = 0;
z_stream stream;
memset(&stream, 0, sizeof(stream));
stream.zalloc = zAlloc;
stream.zfree = zFree;
stream.data_type = Z_BINARY;
stream.opaque = Z_NULL;
stream.next_out = destination;
stream.avail_out = bufferSize;
stream.next_in = m_data;
stream.avail_in = compressedSize();
if (inflateInit(&stream) != Z_OK) {
ASSERT_NOT_REACHED();
return false;
}
int inflateResult = inflate(&stream, Z_FINISH);
inflateEnd(&stream);
ASSERT(stream.total_out <= bufferSize);
if (decompressedByteCount)
*decompressedByteCount = stream.total_out;
if (inflateResult != Z_STREAM_END) {
ASSERT_NOT_REACHED();
return false;
}
return true;
}
}
#else
namespace WTF {
std::unique_ptr<GenericCompressedData> GenericCompressedData::create(const uint8_t*, size_t)
{
    // Built without zlib support: compression is unavailable, so callers always
    // get an empty pointer and must keep the uncompressed data.
    return { };
}
bool GenericCompressedData::decompress(uint8_t*, size_t, size_t*)
{
    // Built without zlib support: decompression always reports failure.
    constexpr bool decompressionAvailable = false;
    return decompressionAvailable;
}
}
#endif