CurlDownload.cpp   [plain text]


/*
 * Copyright (C) 2013 Apple Inc.  All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions
 * are met:
 * 1. Redistributions of source code must retain the above copyright
 *    notice, this list of conditions and the following disclaimer.
 * 2. Redistributions in binary form must reproduce the above copyright
 *    notice, this list of conditions and the following disclaimer in the
 *    documentation and/or other materials provided with the distribution.
 *
 * THIS SOFTWARE IS PROVIDED BY APPLE INC. ``AS IS'' AND ANY
 * EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
 * PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL APPLE INC. OR
 * CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
 * EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
 * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
 * PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
 * OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
 */

#include "config.h"

#if USE(CURL)

#include "CurlDownload.h"

#include "HTTPHeaderNames.h"
#include "HTTPParsers.h"
#include "ResourceHandleManager.h"
#include "ResourceRequest.h"
#include <wtf/MainThread.h>
#include <wtf/text/CString.h>

using namespace WebCore;

namespace WebCore {

// CurlDownloadManager -------------------------------------------------------------------

// Initializes libcurl's global state and creates the multi handle that the
// download thread drives. The worker thread itself is started lazily by add().
CurlDownloadManager::CurlDownloadManager()
    : m_threadId(0)
    , m_curlMultiHandle(nullptr)
    , m_runThread(false)
{
    curl_global_init(CURL_GLOBAL_ALL);
    m_curlMultiHandle = curl_multi_init();
}

// Joins the worker thread first so nothing touches the multi handle while it and
// libcurl's global state are being torn down.
CurlDownloadManager::~CurlDownloadManager()
{
    stopThread();

    curl_multi_cleanup(m_curlMultiHandle);
    curl_global_cleanup();
}

// Queues an easy handle for the download thread and makes sure that thread is
// running. The handle is attached to the multi handle on the worker's next
// updateHandleList() pass. Always returns true.
bool CurlDownloadManager::add(CURL* curlHandle)
{
    {
        // Only the queue append needs the lock; starting the thread must happen
        // outside it.
        LockHolder pendingLock(m_mutex);
        m_pendingHandleList.append(curlHandle);
    }

    startThreadIfNeeded();
    return true;
}

// Queues an easy handle for removal; the download thread detaches it on its next
// updateHandleList() pass. Always returns true.
bool CurlDownloadManager::remove(CURL* curlHandle)
{
    {
        LockHolder removalLock(m_mutex);
        m_removedHandleList.append(curlHandle);
    }
    return true;
}

// Number of easy handles currently attached to the multi handle.
int CurlDownloadManager::getActiveDownloadCount() const
{
    LockHolder countLock(m_mutex);
    int activeCount = m_activeHandleList.size();
    return activeCount;
}

// Number of easy handles queued by add() but not yet attached by the download thread.
int CurlDownloadManager::getPendingDownloadCount() const
{
    LockHolder countLock(m_mutex);
    int pendingCount = m_pendingHandleList.size();
    return pendingCount;
}

// Spawns the download thread unless the run flag says one is already running.
void CurlDownloadManager::startThreadIfNeeded()
{
    if (runThread())
        return;

    // A previous worker may have cleared the run flag and exited on its own
    // (stopThreadIfIdle); join it before spawning its replacement.
    if (m_threadId)
        waitForThreadCompletion(m_threadId);

    setRunThread(true);
    m_threadId = createThread(downloadThread, this, "downloadThread");
}

// Clears the run flag and, if a worker was ever started, waits for it to exit.
void CurlDownloadManager::stopThread()
{
    setRunThread(false);

    if (!m_threadId)
        return;

    waitForThreadCompletion(m_threadId);
    m_threadId = 0;
}

void CurlDownloadManager::stopThreadIfIdle()
{
    if (!getActiveDownloadCount() && !getPendingDownloadCount())
        setRunThread(false);
}

void CurlDownloadManager::updateHandleList()
{
    LockHolder locker(m_mutex);

    // Remove curl easy handles from multi list 
    int size = m_removedHandleList.size();
    for (int i = 0; i < size; i++) {
        removeFromCurl(m_removedHandleList[0]);
        m_removedHandleList.remove(0);
    }

    // Add pending curl easy handles to multi list 
    size = m_pendingHandleList.size();
    for (int i = 0; i < size; i++) {
        addToCurl(m_pendingHandleList[0]);
        m_pendingHandleList.remove(0);
    }
}

// Attaches an easy handle to the multi handle and records it as active.
// Returns false (and records nothing) if curl_multi_add_handle fails.
bool CurlDownloadManager::addToCurl(CURL* curlHandle)
{
    if (curl_multi_add_handle(m_curlMultiHandle, curlHandle) != CURLM_OK)
        return false;

    m_activeHandleList.append(curlHandle);
    return true;
}

// Detaches an active easy handle from the multi handle and destroys it with
// curl_easy_cleanup.
// Returns true when the handle was not active (nothing to do) or was removed and
// destroyed. Returns false if curl_multi_remove_handle failed; the handle then
// stays in the active list.
// m_activeHandleList is also read under m_mutex by getActiveDownloadCount(), so
// callers should hold m_mutex.
bool CurlDownloadManager::removeFromCurl(CURL* curlHandle)
{
    // Vector::find returns size_t; compare against notFound instead of narrowing to int.
    size_t handlePos = m_activeHandleList.find(curlHandle);
    if (handlePos == notFound)
        return true;

    if (curl_multi_remove_handle(m_curlMultiHandle, curlHandle) != CURLM_OK)
        return false;

    m_activeHandleList.remove(handlePos);
    curl_easy_cleanup(curlHandle);
    return true;
}

void CurlDownloadManager::downloadThread(void* data)
{
    CurlDownloadManager* downloadManager = reinterpret_cast<CurlDownloadManager*>(data);

    while (downloadManager->runThread()) {

        downloadManager->updateHandleList();

        // Retry 'select' if it was interrupted by a process signal.
        int rc = 0;
        do {
            fd_set fdread;
            fd_set fdwrite;
            fd_set fdexcep;

            int maxfd = 0;

            const int selectTimeoutMS = 5;

            struct timeval timeout;
            timeout.tv_sec = 0;
            timeout.tv_usec = selectTimeoutMS * 1000; // select waits microseconds

            FD_ZERO(&fdread);
            FD_ZERO(&fdwrite);
            FD_ZERO(&fdexcep);
            curl_multi_fdset(downloadManager->getMultiHandle(), &fdread, &fdwrite, &fdexcep, &maxfd);
            // When the 3 file descriptors are empty, winsock will return -1
            // and bail out, stopping the file download. So make sure we
            // have valid file descriptors before calling select.
            if (maxfd >= 0)
                rc = ::select(maxfd + 1, &fdread, &fdwrite, &fdexcep, &timeout);
        } while (rc == -1 && errno == EINTR);

        int activeDownloadCount = 0;
        while (curl_multi_perform(downloadManager->getMultiHandle(), &activeDownloadCount) == CURLM_CALL_MULTI_PERFORM) { }

        int messagesInQueue = 0;
        CURLMsg* msg = curl_multi_info_read(downloadManager->getMultiHandle(), &messagesInQueue);

        if (!msg)
            continue;

        CurlDownload* download = 0;
        CURLcode err = curl_easy_getinfo(msg->easy_handle, CURLINFO_PRIVATE, &download);

        if (msg->msg == CURLMSG_DONE) {
            if (download) {
                if (msg->data.result == CURLE_OK) {
                    callOnMainThread([download] {
                        download->didFinish();
                        download->deref(); // This matches the ref() in CurlDownload::start().
                    });
                } else {
                    callOnMainThread([download] {
                        download->didFail();
                        download->deref(); // This matches the ref() in CurlDownload::start().
                    });
                }
            }
            downloadManager->removeFromCurl(msg->easy_handle);
        }

        downloadManager->stopThreadIfIdle();
    }
}

// CurlDownload --------------------------------------------------------------------------

// Single download manager shared by every CurlDownload instance (start() and
// cancel() forward to it). It has static storage duration; its destructor stops
// and joins the download thread.
CurlDownloadManager CurlDownload::m_downloadManager;

CurlDownload::CurlDownload()
    : m_curlHandle(nullptr)
    , m_customHeaders(nullptr)
    , m_url(nullptr)
    , m_tempHandle(invalidPlatformFileHandle)
    , m_deletesFileUponFailure(false)
    , m_listener(nullptr)
{
}

// Frees the URL copy and the custom header list. Then closes the temporary file
// and moves it to m_destination; the move is skipped when no destination is set.
CurlDownload::~CurlDownload()
{
    {
        LockHolder ownedDataLock(m_mutex);

        if (m_customHeaders)
            curl_slist_free_all(m_customHeaders);

        if (m_url)
            fastFree(m_url);
    }

    // Both helpers take m_mutex themselves, so they run after the lock above is released.
    closeFile();
    moveFileToDestination();
}

// Creates the easy handle for url, configures it, and records the listener.
// Does nothing when listener is null; m_curlHandle then stays unset.
void CurlDownload::init(CurlDownloadListener* listener, const URL& url)
{
    if (!listener)
        return;

    LockHolder configLock(m_mutex);

    m_curlHandle = curl_easy_init();

    // Heap copy of the URL, released by the destructor.
    m_url = fastStrDup(url.string().latin1().data());

    // CURLOPT_PRIVATE lets the download thread map the easy handle back to this
    // object. Body data and header lines go to the static callbacks, with this
    // object as their user data.
    curl_easy_setopt(m_curlHandle, CURLOPT_URL, m_url);
    curl_easy_setopt(m_curlHandle, CURLOPT_PRIVATE, this);
    curl_easy_setopt(m_curlHandle, CURLOPT_WRITEFUNCTION, writeCallback);
    curl_easy_setopt(m_curlHandle, CURLOPT_WRITEDATA, this);
    curl_easy_setopt(m_curlHandle, CURLOPT_HEADERFUNCTION, headerCallback);
    curl_easy_setopt(m_curlHandle, CURLOPT_WRITEHEADER, this);
    curl_easy_setopt(m_curlHandle, CURLOPT_FOLLOWLOCATION, 1);
    curl_easy_setopt(m_curlHandle, CURLOPT_MAXREDIRS, 10);
    curl_easy_setopt(m_curlHandle, CURLOPT_HTTPAUTH, CURLAUTH_ANY);

    // Optional CA bundle override taken from the environment.
    if (const char* caBundlePath = getenv("CURL_CA_BUNDLE_PATH"))
        curl_easy_setopt(m_curlHandle, CURLOPT_CAINFO, caBundlePath);

    if (CURLSH* shareHandle = ResourceHandleManager::sharedInstance()->getCurlShareHandle())
        curl_easy_setopt(m_curlHandle, CURLOPT_SHARE, shareHandle);

    m_listener = listener;
}

// Initializes the download from an existing request: downloads request.url() and
// forwards the request's HTTP header fields. The ResourceHandle and
// ResourceResponse arguments are not used. Does nothing when listener is null.
void CurlDownload::init(CurlDownloadListener* listener, ResourceHandle*, const ResourceRequest& request, const ResourceResponse&)
{
    if (!listener)
        return;

    init(listener, URL(ParsedURLString, request.url()));
    addHeaders(request);
}

// Hands this download's easy handle to the shared download manager.
// Returns false, without taking a reference, when init() never created a handle
// (for example because it was given a null listener). Previously a null handle
// was queued and the ref() below was never balanced, because the download thread
// only derefs on a completion message.
bool CurlDownload::start()
{
    if (!m_curlHandle)
        return false;

    ref(); // CurlDownloadManager::downloadThread will call deref when the download has finished.
    return m_downloadManager.add(m_curlHandle);
}

bool CurlDownload::cancel()
{
    return m_downloadManager.remove(m_curlHandle);
}

// Path of the temporary file the payload is written to (empty until data arrives).
String CurlDownload::getTempPath() const
{
    LockHolder readLock(m_mutex);
    String tempPath = m_tempPath;
    return tempPath;
}

// The URL passed to init(), as stored in m_url.
String CurlDownload::getUrl() const
{
    LockHolder readLock(m_mutex);
    String url(m_url);
    return url;
}

// Returns a copy of the response assembled from the received headers.
ResourceResponse CurlDownload::getResponse() const
{
    LockHolder readLock(m_mutex);
    ResourceResponse response = m_response;
    return response;
}

// Closes the temporary file if it is open and marks the handle invalid, so a
// second call does nothing.
void CurlDownload::closeFile()
{
    LockHolder fileLock(m_mutex);

    if (m_tempHandle == invalidPlatformFileHandle)
        return;

    WebCore::closeFile(m_tempHandle);
    m_tempHandle = invalidPlatformFileHandle;
}

// Moves the temporary file to m_destination using the Win32 MoveFileEx API,
// replacing any existing file and allowing a copy across volumes. Does nothing
// when no destination has been set. The result of MoveFileEx is not checked.
void CurlDownload::moveFileToDestination()
{
    LockHolder fileLock(m_mutex);

    if (!m_destination.isEmpty()) {
        auto sourcePath = m_tempPath.charactersWithNullTermination();
        auto targetPath = m_destination.charactersWithNullTermination();
        ::MoveFileEx(sourcePath.data(), targetPath.data(), MOVEFILE_COPY_ALLOWED | MOVEFILE_REPLACE_EXISTING);
    }
}

// Appends a chunk of payload to the temporary file. The file is created on the
// first call. Called from didReceiveData(), which holds m_mutex.
void CurlDownload::writeDataToFile(const char* data, int size)
{
    if (m_tempPath.isEmpty())
        m_tempPath = openTemporaryFile("download", m_tempHandle);

    // If the temporary file could not be opened, the chunk is dropped.
    if (m_tempHandle == invalidPlatformFileHandle)
        return;

    writeToFile(m_tempHandle, data, size);
}

void CurlDownload::addHeaders(const ResourceRequest& request)
{
    LockHolder locker(m_mutex);

    if (request.httpHeaderFields().size() > 0) {
        struct curl_slist* headers = 0;

        HTTPHeaderMap customHeaders = request.httpHeaderFields();
        HTTPHeaderMap::const_iterator end = customHeaders.end();
        for (HTTPHeaderMap::const_iterator it = customHeaders.begin(); it != end; ++it) {
            const String& value = it->value;
            String headerString(it->key);
            if (value.isEmpty())
                // Insert the ; to tell curl that this header has an empty value.
                headerString.append(";");
            else {
                headerString.append(": ");
                headerString.append(value);
            }
            CString headerLatin1 = headerString.latin1();
            headers = curl_slist_append(headers, headerLatin1.data());
        }

        if (headers) {
            curl_easy_setopt(m_curlHandle, CURLOPT_HTTPHEADER, headers);
            m_customHeaders = headers;
        }
    }
}

void CurlDownload::didReceiveHeader(const String& header)
{
    LockHolder locker(m_mutex);

    if (header == "\r\n" || header == "\n") {

        long httpCode = 0;
        CURLcode err = curl_easy_getinfo(m_curlHandle, CURLINFO_RESPONSE_CODE, &httpCode);

        if (httpCode >= 200 && httpCode < 300) {
            URL url = getCurlEffectiveURL(m_curlHandle);
            callOnMainThread([this, url = url.isolatedCopy(), protectedThis = makeRef(*this)] {
                m_response.setURL(url);
                m_response.setMimeType(extractMIMETypeFromMediaType(m_response.httpHeaderField(HTTPHeaderName::ContentType)));
                m_response.setTextEncodingName(extractCharsetFromMediaType(m_response.httpHeaderField(HTTPHeaderName::ContentType)));

                didReceiveResponse();
            });
        }
    } else {
        callOnMainThread([this, header = header.isolatedCopy(), protectedThis = makeRef(*this)] {
            int splitPos = header.find(":");
            if (splitPos != -1)
                m_response.setHTTPHeaderField(header.left(splitPos), header.substring(splitPos + 1).stripWhiteSpace());
        });
    }
}

void CurlDownload::didReceiveData(void* data, int size)
{
    LockHolder locker(m_mutex);

    RefPtr<CurlDownload> protectedThis(this);

    callOnMainThread([this, size, protectedThis] {
        didReceiveDataOfLength(size);
    });

    writeDataToFile(static_cast<const char*>(data), size);
}

// Forwards the response-received notification to the listener, if any.
void CurlDownload::didReceiveResponse()
{
    if (!m_listener)
        return;

    m_listener->didReceiveResponse();
}

// Forwards a progress notification (bytes received in one chunk) to the listener, if any.
void CurlDownload::didReceiveDataOfLength(int size)
{
    if (!m_listener)
        return;

    m_listener->didReceiveDataOfLength(size);
}

// Completion handler posted to the main thread by the download thread. It closes
// the temporary file, moves it to its destination and then notifies the listener.
void CurlDownload::didFinish()
{
    closeFile();
    moveFileToDestination();

    if (!m_listener)
        return;

    m_listener->didFinish();
}

// Failure handler posted to the main thread by the download thread. It closes
// the temporary file and deletes it if m_deletesFileUponFailure is set. Then it
// notifies the listener.
void CurlDownload::didFail()
{
    closeFile();

    {
        LockHolder locker(m_mutex);

        if (m_deletesFileUponFailure)
            deleteFile(m_tempPath);
    }

    // Notify the listener without holding m_mutex, as didFinish() does. The lock
    // used to be held across this call, so a listener calling back into
    // getTempPath()/getUrl()/getResponse() would re-enter it.
    // NOTE(review): assumes m_mutex is a non-recursive WTF::Lock, so re-entry
    // would self-deadlock — confirm against the header.
    if (m_listener)
        m_listener->didFail();
}

size_t CurlDownload::writeCallback(void* ptr, size_t size, size_t nmemb, void* data)
{
    size_t totalSize = size * nmemb;
    CurlDownload* download = reinterpret_cast<CurlDownload*>(data);

    if (download)
        download->didReceiveData(ptr, totalSize);

    return totalSize;
}

size_t CurlDownload::headerCallback(char* ptr, size_t size, size_t nmemb, void* data)
{
    size_t totalSize = size * nmemb;
    CurlDownload* download = reinterpret_cast<CurlDownload*>(data);

    String header(static_cast<const char*>(ptr), totalSize);

    if (download)
        download->didReceiveHeader(header);

    return totalSize;
}

// Static trampoline: calls didFinish() on download when it is non-null.
void CurlDownload::downloadFinishedCallback(CurlDownload* download)
{
    if (!download)
        return;

    download->didFinish();
}

// Static trampoline: calls didFail() on download when it is non-null.
void CurlDownload::downloadFailedCallback(CurlDownload* download)
{
    if (!download)
        return;

    download->didFail();
}

// Static trampoline: forwards a progress notification to download when it is non-null.
void CurlDownload::receivedDataCallback(CurlDownload* download, int size)
{
    if (!download)
        return;

    download->didReceiveDataOfLength(size);
}

// Static trampoline: forwards the response-received notification to download when it is non-null.
void CurlDownload::receivedResponseCallback(CurlDownload* download)
{
    if (!download)
        return;

    download->didReceiveResponse();
}

}

#endif