#include "config.h"
#include "CurlRequest.h"
#if USE(CURL)
#include "CurlRequestClient.h"
#include "CurlRequestScheduler.h"
#include "MIMETypeRegistry.h"
#include "ResourceError.h"
#include "SharedBuffer.h"
#include <wtf/Language.h>
#include <wtf/MainThread.h>
namespace WebCore {
// Runs |task| on the main thread: synchronously when the caller is already there, otherwise
// posted with callOnMainThread(). Forward-declared because callClient() uses it before the
// definition appears in this file.
static void runOnMainThread(Function<void()>&& task);
CurlRequest::CurlRequest(const ResourceRequest&request, CurlRequestClient* client, bool shouldSuspend, bool enableMultipart)
: m_request(request.isolatedCopy())
, m_client(client)
, m_shouldSuspend(shouldSuspend)
, m_enableMultipart(enableMultipart)
, m_formDataStream(m_request.httpBody())
{
ASSERT(isMainThread());
}
// Detaches the client. Callbacks queued through callClient() check m_client when they run,
// so nothing is delivered after this point. Main thread only.
void CurlRequest::invalidateClient()
{
    ASSERT(isMainThread());

    m_client = nullptr;
}
// Stores HTTP credentials for setupTransfer(). Isolated copies are kept so the strings do not
// share storage with the caller's main-thread strings.
void CurlRequest::setUserPass(const String& user, const String& password)
{
    ASSERT(isMainThread());

    m_user = user.isolatedCopy();
    m_password = password.isolatedCopy();
}
void CurlRequest::start(bool isSyncRequest)
{
ASSERT(isMainThread());
m_isSyncRequest = isSyncRequest;
auto url = m_request.url().isolatedCopy();
if (!m_isSyncRequest) {
if (url.isLocalFile())
invokeDidReceiveResponseForFile(url);
else
startWithJobManager();
} else {
retain();
if (url.isLocalFile())
invokeDidReceiveResponseForFile(url);
setupTransfer();
CURLcode resultCode = m_curlHandle->perform();
didCompleteTransfer(resultCode);
release();
}
}
// Registers this request with the shared CurlRequestScheduler (main thread only).
void CurlRequest::startWithJobManager()
{
    ASSERT(isMainThread());

    auto& scheduler = CurlContext::singleton().scheduler();
    scheduler.add(this);
}
// Cancels the request from the main thread. It is a no-op once the request has completed or
// was already cancelled. When needToInvokeDidCancelTransfer() holds, this function runs
// didCancelTransfer() itself: inline for synchronous requests, via the worker thread
// otherwise. In the other asynchronous cases it asks the scheduler to cancel. In every case
// the client is detached afterwards, so queued callClient() tasks stop reaching it.
void CurlRequest::cancel()
{
    ASSERT(isMainThread());

    if (isCompletedOrCancelled())
        return;

    m_cancelled = true;

    if (m_isSyncRequest) {
        // The synchronous transfer runs on this very thread, so tear down directly.
        if (needToInvokeDidCancelTransfer())
            didCancelTransfer();
    } else if (needToInvokeDidCancelTransfer()) {
        runOnWorkerThreadIfRequired([this, protectedThis = makeRef(*this)]() {
            didCancelTransfer();
        });
    } else
        CurlContext::singleton().scheduler().cancel(this);

    invalidateClient();
}
// Requests a pause of the transfer at the request level (main thread only).
void CurlRequest::suspend()
{
    ASSERT(isMainThread());

    setRequestPaused(true);
}
// Lifts a request-level pause set by suspend() (main thread only).
void CurlRequest::resume()
{
    ASSERT(isMainThread());

    setRequestPaused(false);
}
// Delivers |task| to the client on the main thread. The request keeps itself alive until the
// task runs. The client is re-checked at that moment and the task is skipped if it was
// detached in the meantime; the client is kept referenced for the duration of the call.
void CurlRequest::callClient(Function<void(CurlRequest&, CurlRequestClient&)>&& task)
{
    runOnMainThread([this, protectedThis = makeRef(*this), callback = WTFMove(task)]() mutable {
        if (!m_client)
            return;
        callback(*this, makeRef(*m_client));
    });
}
// Executes |task| on the main thread: immediately when already there, otherwise queued.
static void runOnMainThread(Function<void()>&& task)
{
    if (!isMainThread()) {
        callOnMainThread(WTFMove(task));
        return;
    }

    task();
}
// Runs |task| on the scheduler's worker thread when called from the main thread for an
// asynchronous request. In every other case (already off the main thread, or a synchronous
// request) the task runs inline.
void CurlRequest::runOnWorkerThreadIfRequired(Function<void()>&& task)
{
    if (!isMainThread() || m_isSyncRequest) {
        task();
        return;
    }

    CurlContext::singleton().scheduler().callOnWorkerThread(WTFMove(task));
}
// Creates and configures the libcurl handle for m_request. It sets the URL, the request
// headers (plus Accept-Language), the HTTP method and body, optional credentials, the
// header/body callbacks and the timeout. When m_shouldSuspend is set, the request starts
// paused. Returns the underlying CURL* handle.
//
// NOTE(review): no TLS options (peer/host verification, CA path, client certificates) are
// configured in this function; confirm that CurlHandle's defaults are the intended ones.
CURL* CurlRequest::setupTransfer()
{
    auto httpHeaderFields = m_request.httpHeaderFields();
    appendAcceptLanguageHeader(httpHeaderFields);

    m_curlHandle = std::make_unique<CurlHandle>();

    m_curlHandle->setUrl(m_request.url());
    m_curlHandle->appendRequestHeaders(httpHeaderFields);

    const auto& method = m_request.httpMethod();
    if (method == "GET")
        m_curlHandle->enableHttpGetRequest();
    else if (method == "POST")
        setupPOST(m_request);
    else if (method == "PUT")
        setupPUT(m_request);
    else if (method == "HEAD")
        m_curlHandle->enableHttpHeadRequest();
    else {
        // Any other method is sent as a custom request string; its body, if any, is uploaded
        // the same way as a PUT body.
        m_curlHandle->setHttpCustomRequest(method);
        setupPUT(m_request);
    }

    if (!m_user.isEmpty() || !m_password.isEmpty()) {
        m_curlHandle->enableHttpAuthentication(CURLAUTH_ANY);
        m_curlHandle->setHttpAuthUserPass(m_user, m_password);
    }

    m_curlHandle->setHeaderCallbackFunction(didReceiveHeaderCallback, this);
    m_curlHandle->setWriteCallbackFunction(didReceiveDataCallback, this);

    m_curlHandle->setTimeout(Seconds(m_request.timeoutInterval()));

    if (m_shouldSuspend)
        setRequestPaused(true);

    return m_curlHandle->handle();
}
// libcurl read-callback body: fills |buffer| (blockSize * numberOfBlocks bytes) with the next
// slice of the request body and reports upload progress to the client.
// It returns CURL_READFUNC_ABORT when the request is finished or cancelled, when the buffer
// geometry is empty or its size would overflow size_t, or when the form-data stream fails.
// Otherwise it returns the number of bytes written into |buffer|.
size_t CurlRequest::willSendData(char* buffer, size_t blockSize, size_t numberOfBlocks)
{
    // Short-circuit order matters: numberOfBlocks is known non-zero before it is used as a divisor.
    if (isCompletedOrCancelled()
        || !blockSize
        || !numberOfBlocks
        || blockSize > (std::numeric_limits<size_t>::max() / numberOfBlocks))
        return CURL_READFUNC_ABORT;

    auto bytesRead = m_formDataStream.read(buffer, blockSize * numberOfBlocks);
    if (!bytesRead)
        return CURL_READFUNC_ABORT;

    auto sentSoFar = m_formDataStream.totalReadSize();
    auto totalToSend = m_formDataStream.totalSize();
    callClient([sentSoFar, totalToSend](CurlRequest& request, CurlRequestClient& client) {
        client.curlDidSendData(request, sentSoFar, totalToSend);
    });

    return *bytesRead;
}
// libcurl header-callback body. It receives one raw line at a time: a status line, a header
// field, or the blank line that ends a header block. Lines are accumulated into m_response.
//
// On the terminating blank line, response metadata (status, connect code, length, proxy URL,
// auth, HTTP version, load metrics, certificate info) is read from the curl handle and a
// multipart parser is created when enabled. Interim 1xx responses (e.g. "100 Continue") are
// discarded so that only the final response is recorded. This function does not notify the
// client itself.
//
// Returns the number of bytes consumed, or 0 once the request is completed or cancelled.
size_t CurlRequest::didReceiveHeader(String&& header)
{
    static const auto emptyLineCRLF = "\r\n";
    static const auto emptyLineLF = "\n";

    if (isCompletedOrCancelled())
        return 0;

    // A previous header block was already completed and this line starts another one, so begin
    // again with a fresh response.
    if (m_didReceiveResponse) {
        m_didReceiveResponse = false;
        m_response = CurlResponse { };
        m_multipartHandle = nullptr;
    }

    auto receiveBytes = static_cast<size_t>(header.length());

    // Anything other than a blank line belongs to the current header block.
    // HTTP mandates CRLF, but a bare LF is accepted for compatibility.
    if ((header != emptyLineCRLF) && (header != emptyLineLF)) {
        m_response.headers.append(WTFMove(header));
        return receiveBytes;
    }

    long statusCode = 0;
    if (auto code = m_curlHandle->getResponseCode())
        statusCode = *code;

    // 1xx responses are interim and the final response follows. Treating one as final would run
    // the bookkeeping below (proxy auth, metrics, multipart detection) against it, so drop its
    // headers and keep waiting for the real response.
    if (statusCode >= 100 && statusCode < 200) {
        m_response = CurlResponse { };
        return receiveBytes;
    }

    long httpConnectCode = 0;
    if (auto code = m_curlHandle->getHttpConnectCode())
        httpConnectCode = *code;

    m_didReceiveResponse = true;

    m_response.url = m_request.url();
    m_response.statusCode = statusCode;
    m_response.httpConnectCode = httpConnectCode;

    if (auto length = m_curlHandle->getContentLength())
        m_response.expectedContentLength = *length;

    if (auto proxyUrl = m_curlHandle->getProxyUrl())
        m_response.proxyUrl = URL(URL(), *proxyUrl);

    if (auto auth = m_curlHandle->getHttpAuthAvail())
        m_response.availableHttpAuth = *auth;

    if (auto auth = m_curlHandle->getProxyAuthAvail())
        m_response.availableProxyAuth = *auth;

    if (auto version = m_curlHandle->getHttpVersion())
        m_response.httpVersion = *version;

    if (auto metrics = m_curlHandle->getNetworkLoadMetrics())
        m_networkLoadMetrics = *metrics;

    // Remember which proxy auth methods the proxy offered so the shared context can reuse them.
    if (m_response.availableProxyAuth)
        CurlContext::singleton().setProxyAuthMethod(m_response.availableProxyAuth);

    if (auto info = m_curlHandle->certificateInfo())
        m_certificateInfo = *info;

    if (m_enableMultipart)
        m_multipartHandle = CurlMultipartHandle::createIfNeeded(*this, m_response);

    return receiveBytes;
}
// libcurl write-callback body: handles one chunk of response body data.
//
// The first chunk after a completed header block triggers the response notification:
//  - Asynchronous requests mark the callback side as paused, notify the client with
//    Action::ReceiveData and return CURL_WRITEFUNC_PAUSE without consuming the chunk.
//    completeDidReceiveResponse() later lifts the pause via setCallbackPaused(false).
//    NOTE(review): this relies on libcurl re-delivering the paused chunk once the handle is
//    unpaused (see the CURLOPT_WRITEFUNCTION docs); confirm.
//  - Synchronous requests notify with Action::None and process the chunk right away.
//
// Every chunk is mirrored to the download file when that is enabled. Non-empty chunks then go
// to the multipart parser if one exists, otherwise straight to the client.
// Returns the number of bytes consumed; 0 once the request is completed or cancelled.
size_t CurlRequest::didReceiveData(Ref<SharedBuffer>&& buffer)
{
    if (isCompletedOrCancelled())
        return 0;

    if (needToInvokeDidReceiveResponse()) {
        if (!m_isSyncRequest) {
            // Hold further data until the client has handled the response.
            setCallbackPaused(true);
            invokeDidReceiveResponse(m_response, Action::ReceiveData);
            // Returning CURL_WRITEFUNC_PAUSE pauses the handle inside libcurl, so record that
            // state here; pausedStatusChanged() compares against it.
            updateHandlePauseState(true);
            return CURL_WRITEFUNC_PAUSE;
        }
        // Synchronous: no pause, continue processing this chunk immediately.
        invokeDidReceiveResponse(m_response, Action::None);
    }

    auto receiveBytes = buffer->size();

    writeDataToDownloadFileIfEnabled(buffer);

    if (receiveBytes) {
        if (m_multipartHandle)
            m_multipartHandle->didReceiveData(buffer);
        else {
            callClient([buffer = WTFMove(buffer)](CurlRequest& request, CurlRequestClient& client) mutable {
                client.curlDidReceiveBuffer(request, WTFMove(buffer));
            });
        }
    }

    return receiveBytes;
}
// Called by the multipart parser when a new part's headers are complete. Each part is surfaced
// to the client as its own response. It starts from an isolated copy of the outer response,
// replaces the header list with the part's |headers| and resets expectedContentLength to 0.
// Ignored once the request is completed or cancelled.
void CurlRequest::didReceiveHeaderFromMultipart(const Vector<String>& headers)
{
    if (isCompletedOrCancelled())
        return;

    CurlResponse response = m_response.isolatedCopy();
    response.expectedContentLength = 0;
    response.headers.clear();

    // Bind by reference: copying each String here only to copy it again into append() is wasted work.
    for (const auto& header : headers)
        response.headers.append(header);

    invokeDidReceiveResponse(response, Action::None);
}
// Called by the multipart parser with body bytes of the current part. Non-empty chunks are
// forwarded to the client. Empty chunks, and anything arriving after completion or
// cancellation, are dropped.
void CurlRequest::didReceiveDataFromMultipart(Ref<SharedBuffer>&& buffer)
{
    if (isCompletedOrCancelled() || !buffer->size())
        return;

    callClient([chunk = WTFMove(buffer)](CurlRequest& request, CurlRequestClient& client) mutable {
        client.curlDidReceiveBuffer(request, WTFMove(chunk));
    });
}
// Handles the end of the libcurl transfer with status |result|.
//  - Cancelled: only the curl handle is dropped; the client is not notified.
//  - Response not yet delivered to the client: |result| is stashed and the response is sent
//    with Action::FinishTransfer. completeDidReceiveResponse() calls back into this function
//    with the stashed code.
//  - CURLE_OK: the multipart parser is flushed, final metrics are captured, the transfer is
//    finalized and the client gets curlDidComplete().
//  - Any other code: a ResourceError is built (classified as Timeout only for
//    CURLE_OPERATION_TIMEDOUT with a positive timeout interval) and enriched with SSL errors.
//    Certificate info is captured, the transfer is finalized and the client gets
//    curlDidFailWithError().
void CurlRequest::didCompleteTransfer(CURLcode result)
{
    if (m_cancelled) {
        m_curlHandle = nullptr;
        return;
    }

    if (needToInvokeDidReceiveResponse()) {
        m_finishedResultCode = result;
        invokeDidReceiveResponse(m_response, Action::FinishTransfer);
        return;
    }

    if (result == CURLE_OK) {
        if (m_multipartHandle)
            m_multipartHandle->didComplete();

        // Read from the handle before finalizeTransfer() releases it.
        if (auto metrics = m_curlHandle->getNetworkLoadMetrics())
            m_networkLoadMetrics = *metrics;

        finalizeTransfer();
        callClient([](CurlRequest& request, CurlRequestClient& client) {
            client.curlDidComplete(request);
        });
    } else {
        auto type = (result == CURLE_OPERATION_TIMEDOUT && m_request.timeoutInterval() > 0.0) ? ResourceError::Type::Timeout : ResourceError::Type::General;
        auto resourceError = ResourceError::httpError(result, m_request.url(), type);
        // SSL errors and certificate info must also be read before finalizeTransfer() drops the handle.
        if (auto sslErrors = m_curlHandle->sslErrors())
            resourceError.setSslErrors(sslErrors);
        if (auto info = m_curlHandle->certificateInfo())
            m_certificateInfo = *info;

        finalizeTransfer();
        callClient([error = resourceError.isolatedCopy()](CurlRequest& request, CurlRequestClient& client) {
            client.curlDidFailWithError(request, error);
        });
    }
}
// Teardown after cancellation. First release everything tied to the transfer, then delete
// the partially written download file, which a normal completion would keep.
void CurlRequest::didCancelTransfer()
{
    finalizeTransfer();

    cleanupDownloadFile();
}
// Releases the per-transfer state. It closes the download file handle (the file on disk is
// left in place; cleanupDownloadFile() deletes it), cleans the form-data stream, and drops the
// multipart parser and the curl handle.
void CurlRequest::finalizeTransfer()
{
    closeDownloadFile();
    m_formDataStream.clean();
    m_multipartHandle = nullptr;
    m_curlHandle = nullptr;
}
// Adds one Accept-Language value to |header| for each entry returned by
// userPreferredLanguages(), in that order.
void CurlRequest::appendAcceptLanguageHeader(HTTPHeaderMap& header)
{
    auto languages = userPreferredLanguages();
    for (auto& preferredLanguage : languages)
        header.add(HTTPHeaderName::AcceptLanguage, preferredLanguage);
}
// Configures the handle for an upload-style request (PUT, and custom methods) and removes
// any "Expect" request header. The streaming body is wired up only when the form-data stream
// has elements. |request| is currently unused.
void CurlRequest::setupPUT(ResourceRequest& request)
{
    m_curlHandle->enableHttpPutRequest();
    m_curlHandle->removeRequestHeader("Expect");

    if (m_formDataStream.elementSize())
        setupSendData(true);
}
// Configures the handle for a POST. With no form elements nothing more is set. A single
// element is passed in one piece via setPostFields() when it flattens to non-empty data.
// Several elements are streamed through the read callback. |request| is currently unused.
void CurlRequest::setupPOST(ResourceRequest& request)
{
    m_curlHandle->enableHttpPostRequest();

    switch (m_formDataStream.elementSize()) {
    case 0:
        return;
    case 1: {
        const auto* flattened = m_formDataStream.getPostData();
        if (flattened && flattened->size())
            m_curlHandle->setPostFields(flattened->data(), flattened->size());
        return;
    }
    default:
        setupSendData(false);
        return;
    }
}
// Prepares a streamed request body. With a known size, that size is announced (as the upload
// size for PUT-style requests, as the POST field size otherwise). Without one, chunked transfer
// encoding is requested. Either way the body is pulled through willSendDataCallback.
void CurlRequest::setupSendData(bool forPutMethod)
{
    if (!m_formDataStream.shouldUseChunkTransfer()) {
        auto bodySize = static_cast<curl_off_t>(m_formDataStream.totalSize());
        if (forPutMethod)
            m_curlHandle->setInFileSizeLarge(bodySize);
        else
            m_curlHandle->setPostFieldLarge(bodySize);
    } else
        m_curlHandle->appendRequestHeader("Transfer-Encoding: chunked");

    m_curlHandle->setReadCallbackFunction(willSendDataCallback, this);
}
// Synthesizes a response for a local file URL: status 200 and a Content-Type looked up from
// the path. Synchronous requests notify with Action::None. Asynchronous requests notify from
// the worker thread with Action::StartTransfer, so the real transfer begins only after the
// client calls completeDidReceiveResponse().
void CurlRequest::invokeDidReceiveResponseForFile(URL& url)
{
    ASSERT(isMainThread());
    ASSERT(url.isLocalFile());

    m_response.url = url;
    m_response.statusCode = 200;

    auto mimeType = MIMETypeRegistry::getMIMETypeForPath(m_response.url.path());
    m_response.headers.append(String("Content-Type: " + mimeType));

    if (m_isSyncRequest) {
        invokeDidReceiveResponse(m_response, Action::None);
        return;
    }

    runOnWorkerThreadIfRequired([this, protectedThis = makeRef(*this)]() {
        invokeDidReceiveResponse(m_response, Action::StartTransfer);
    });
}
void CurlRequest::invokeDidReceiveResponse(const CurlResponse& response, Action behaviorAfterInvoke)
{
ASSERT(!m_didNotifyResponse || m_multipartHandle);
m_didNotifyResponse = true;
m_actionAfterInvoke = behaviorAfterInvoke;
callClient([response = response.isolatedCopy()](CurlRequest& request, CurlRequestClient& client) {
client.curlDidReceiveResponse(request, response);
});
}
// Called on the main thread once the client has handled a response. It performs the action
// recorded by invokeDidReceiveResponse(): resume the paused write callback, start a
// file-scheme transfer, or finish a transfer that ended before the response was acknowledged.
void CurlRequest::completeDidReceiveResponse()
{
    ASSERT(isMainThread());
    ASSERT(m_didNotifyResponse);
    ASSERT(!m_didReturnFromNotify || m_multipartHandle);

    if (isCancelled())
        return;

    // StartTransfer is the only action that applies to a request without a live transfer.
    if (m_actionAfterInvoke != Action::StartTransfer && isCompleted())
        return;

    m_didReturnFromNotify = true;

    switch (m_actionAfterInvoke) {
    case Action::ReceiveData:
        // didReceiveData() paused the callback side; let data flow again.
        setCallbackPaused(false);
        break;
    case Action::StartTransfer:
        // File-scheme requests begin their actual transfer only now.
        startWithJobManager();
        break;
    case Action::FinishTransfer:
        if (m_isSyncRequest)
            didCompleteTransfer(m_finishedResultCode);
        else {
            runOnWorkerThreadIfRequired([this, protectedThis = makeRef(*this), finishedResultCode = m_finishedResultCode]() {
                didCompleteTransfer(finishedResultCode);
            });
        }
        break;
    default:
        break;
    }
}
// Sets the request-level pause flag (also mirrored into m_shouldSuspend) under
// m_pauseStateMutex. pausedStatusChanged() is called, outside the lock, only when the
// combined shouldBePaused() state actually flips.
void CurlRequest::setRequestPaused(bool paused)
{
    bool pauseStateFlipped = false;
    {
        LockHolder pauseLock(m_pauseStateMutex);
        auto wasPaused = shouldBePaused();
        m_shouldSuspend = m_isPausedOfRequest = paused;
        pauseStateFlipped = shouldBePaused() != wasPaused;
    }

    if (pauseStateFlipped)
        pausedStatusChanged();
}
// Sets the callback-level pause flag under m_pauseStateMutex. Only an un-pause that flips the
// combined state is pushed to the handle. Pausing from the write callback takes effect through
// didReceiveData() returning CURL_WRITEFUNC_PAUSE instead.
void CurlRequest::setCallbackPaused(bool paused)
{
    bool shouldNotify = false;
    {
        LockHolder pauseLock(m_pauseStateMutex);
        auto wasPaused = shouldBePaused();
        m_isPausedOfCallback = paused;
        shouldNotify = !paused && shouldBePaused() != wasPaused;
    }

    if (shouldNotify)
        pausedStatusChanged();
}
void CurlRequest::invokeCancel()
{
runOnMainThread([this, protectedThis = makeRef(*this)]() {
cancel();
});
}
// Pushes the desired pause state (shouldBePaused()) down to the curl handle.
// The handle is touched only inside runOnWorkerThreadIfRequired(). That means the scheduler's
// worker thread when called from the main thread for an asynchronous request, and inline
// otherwise.
void CurlRequest::pausedStatusChanged()
{
    if (isCompletedOrCancelled())
        return;

    runOnWorkerThreadIfRequired([this, protectedThis = makeRef(*this)]() {
        // Re-check: the request may have finished or been cancelled before this task ran.
        if (isCompletedOrCancelled())
            return;

        bool needCancel { false };
        {
            LockHolder lock(m_pauseStateMutex);
            bool paused = shouldBePaused();
            // Nothing to do if the handle is already in the desired state.
            if (isHandlePaused() == paused)
                return;

            auto error = m_curlHandle->pause(paused ? CURLPAUSE_ALL : CURLPAUSE_CONT);
            // Only record the new state if libcurl accepted it.
            if (error == CURLE_OK)
                updateHandlePauseState(paused);
            // A failed resume cancels the request (presumably because the transfer could
            // otherwise never progress); a failed pause is left as is.
            needCancel = (error != CURLE_OK && !paused);
        }

        // Cancel outside the lock; invokeCancel() forwards to cancel() on the main thread.
        if (needCancel)
            invokeCancel();
    });
}
// Records whether the curl handle itself is currently paused. Only valid off the main thread,
// or on it for synchronous requests.
void CurlRequest::updateHandlePauseState(bool paused)
{
    ASSERT(!isMainThread() || m_isSyncRequest);

    m_isHandlePaused = paused;
}
// Returns the last recorded pause state of the curl handle. Same threading rule as
// updateHandlePauseState().
bool CurlRequest::isHandlePaused() const
{
    ASSERT(!isMainThread() || m_isSyncRequest);

    return m_isHandlePaused;
}
// Turns on mirroring of received body data into a temporary file.
void CurlRequest::enableDownloadToFile()
{
    LockHolder downloadLock(m_downloadMutex);

    m_isEnabledDownloadToFile = true;
}
// Returns the temporary download file path. It is assigned by
// writeDataToDownloadFileIfEnabled() and cleared by cleanupDownloadFile().
// NOTE(review): the returned reference outlives the lock, while both of those functions
// reassign m_downloadFilePath under m_downloadMutex. Returning by value would need a matching
// header change.
const String& CurlRequest::getDownloadedFilePath()
{
    LockHolder downloadLock(m_downloadMutex);

    return m_downloadFilePath;
}
// Appends |buffer| to the temporary download file when download-to-file is enabled. The file
// is opened on the first call that finds no recorded path.
//
// The whole operation runs under m_downloadMutex. closeDownloadFile() closes and invalidates
// m_downloadFileHandle under the same lock, so the handle must not be read or written outside
// it; otherwise a concurrent close could invalidate the descriptor mid-write.
void CurlRequest::writeDataToDownloadFileIfEnabled(const SharedBuffer& buffer)
{
    LockHolder locker(m_downloadMutex);

    if (!m_isEnabledDownloadToFile)
        return;

    if (m_downloadFilePath.isEmpty())
        m_downloadFilePath = FileSystem::openTemporaryFile("download", m_downloadFileHandle);

    if (m_downloadFileHandle != FileSystem::invalidPlatformFileHandle)
        FileSystem::writeToFile(m_downloadFileHandle, buffer.data(), buffer.size());
}
// Closes the download file handle, if open, and marks it invalid. The file itself stays on disk.
void CurlRequest::closeDownloadFile()
{
    LockHolder downloadLock(m_downloadMutex);

    if (m_downloadFileHandle != FileSystem::invalidPlatformFileHandle) {
        FileSystem::closeFile(m_downloadFileHandle);
        m_downloadFileHandle = FileSystem::invalidPlatformFileHandle;
    }
}
// Deletes the temporary download file, if one was recorded, and forgets its path.
void CurlRequest::cleanupDownloadFile()
{
    LockHolder downloadLock(m_downloadMutex);

    if (m_downloadFilePath.isEmpty())
        return;

    FileSystem::deleteFile(m_downloadFilePath);
    m_downloadFilePath = String();
}
// Static libcurl read-callback trampoline; |userData| is the CurlRequest registered in setupSendData().
size_t CurlRequest::willSendDataCallback(char* ptr, size_t blockSize, size_t numberOfBlocks, void* userData)
{
    auto* request = static_cast<CurlRequest*>(userData);
    return request->willSendData(ptr, blockSize, numberOfBlocks);
}
// Static libcurl header-callback trampoline. It wraps the raw header bytes in a String and
// forwards them to the CurlRequest registered in setupTransfer().
size_t CurlRequest::didReceiveHeaderCallback(char* ptr, size_t blockSize, size_t numberOfBlocks, void* userData)
{
    auto* request = static_cast<CurlRequest*>(userData);
    String headerLine(ptr, blockSize * numberOfBlocks);
    return request->didReceiveHeader(WTFMove(headerLine));
}
// Static libcurl write-callback trampoline. It copies the received bytes into a SharedBuffer
// and forwards them to the CurlRequest registered in setupTransfer().
size_t CurlRequest::didReceiveDataCallback(char* ptr, size_t blockSize, size_t numberOfBlocks, void* userData)
{
    auto* request = static_cast<CurlRequest*>(userData);
    auto chunk = SharedBuffer::create(ptr, blockSize * numberOfBlocks);
    return request->didReceiveData(WTFMove(chunk));
}
}
#endif