SocketConnection.cpp   [plain text]


/*
 * Copyright (C) 2019 Igalia, S.L.
 *
 *  This library is free software; you can redistribute it and/or
 *  modify it under the terms of the GNU Library General Public
 *  License as published by the Free Software Foundation; either
 *  version 2 of the License, or (at your option) any later version.
 *
 *  This library is distributed in the hope that it will be useful,
 *  but WITHOUT ANY WARRANTY; without even the implied warranty of
 *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
 *  Library General Public License for more details.
 *
 *  You should have received a copy of the GNU Library General Public License
 *  along with this library; see the file COPYING.LIB.  If not, write to
 *  the Free Software Foundation, Inc., 51 Franklin Street, Fifth Floor,
 *  Boston, MA 02110-1301, USA.
 */

#include "config.h"
#include "SocketConnection.h"

#include <cstring>
#include <gio/gio.h>
#include <wtf/ByteOrder.h>
#include <wtf/CheckedArithmetic.h>
#include <wtf/FastMalloc.h>
#include <wtf/RunLoop.h>

namespace WTF {

// Initial (and minimum retained) capacity, in bytes, of the read and write buffers.
static constexpr unsigned defaultBufferSize = 4096;

// Wraps |connection|, switches its socket to non-blocking mode and starts watching it for incoming
// data. Messages are dispatched to |messageHandlers|, which receive |userData| on every call.
SocketConnection::SocketConnection(GRefPtr<GSocketConnection>&& connection, const MessageHandlers& messageHandlers, gpointer userData)
    : m_connection(WTFMove(connection))
    , m_messageHandlers(messageHandlers)
    , m_userData(userData)
{
    relaxAdoptionRequirement();

    m_readBuffer.reserveInitialCapacity(defaultBufferSize);
    m_writeBuffer.reserveInitialCapacity(defaultBufferSize);

    GSocket* socket = g_socket_connection_get_socket(m_connection.get());
    g_socket_set_blocking(socket, FALSE);

    // The callback holds a reference to this object for as long as the read monitor keeps it.
    auto onSocketCondition = [this, protectedThis = makeRef(*this)](GIOCondition condition) -> gboolean {
        if (isClosed())
            return G_SOURCE_REMOVE;

        // Hang-up, error or invalid descriptor: the connection is gone.
        if (condition & (G_IO_HUP | G_IO_ERR | G_IO_NVAL)) {
            didClose();
            return G_SOURCE_REMOVE;
        }

        ASSERT(condition & G_IO_IN);
        return read();
    };
    m_readMonitor.start(socket, G_IO_IN, RunLoop::current(), WTFMove(onSocketCondition));
}

// Members (monitors, buffers, connection) release their resources through their own destructors.
SocketConnection::~SocketConnection() = default;

bool SocketConnection::read()
{
    while (true) {
        size_t previousBufferSize = m_readBuffer.size();
        if (m_readBuffer.capacity() - previousBufferSize <= 0)
            m_readBuffer.reserveCapacity(m_readBuffer.capacity() + defaultBufferSize);
        m_readBuffer.grow(m_readBuffer.capacity());
        GUniqueOutPtr<GError> error;
        auto bytesRead = g_socket_receive(g_socket_connection_get_socket(m_connection.get()), m_readBuffer.data() + previousBufferSize, m_readBuffer.size() - previousBufferSize, nullptr, &error.outPtr());
        if (bytesRead == -1) {
            if (g_error_matches(error.get(), G_IO_ERROR, G_IO_ERROR_WOULD_BLOCK)) {
                m_readBuffer.shrink(previousBufferSize);
                break;
            }

            g_warning("Error reading from socket connection: %s\n", error->message);
            didClose();
            return G_SOURCE_REMOVE;
        }

        if (!bytesRead) {
            didClose();
            return G_SOURCE_REMOVE;
        }

        m_readBuffer.shrink(previousBufferSize + bytesRead);

        while (readMessage()) { }
        if (isClosed())
            return G_SOURCE_REMOVE;
    }
    return G_SOURCE_CONTINUE;
}

// Type of the one-byte flags field that follows the body size in every message header.
using MessageFlags = uint8_t;

// Bits of MessageFlags.
enum {
    // Set by the sender when the serialized parameters are in little-endian byte order.
    ByteOrderLittleEndian = 1 << 0
};

// Returns whether a message carrying |flags| was serialized in the opposite byte order from this
// host's, in which case its parameters have to go through g_variant_byteswap().
static inline bool messageIsByteSwapped(MessageFlags flags)
{
#if G_BYTE_ORDER == G_LITTLE_ENDIAN
    constexpr bool hostIsLittleEndian = true;
#else
    constexpr bool hostIsLittleEndian = false;
#endif
    bool senderIsLittleEndian = flags & ByteOrderLittleEndian;
    return senderIsLittleEndian != hostIsLittleEndian;
}

bool SocketConnection::readMessage()
{
    if (m_readBuffer.size() < sizeof(uint32_t))
        return false;

    auto* messageData = m_readBuffer.data();
    uint32_t bodySizeHeader;
    memcpy(&bodySizeHeader, messageData, sizeof(uint32_t));
    messageData += sizeof(uint32_t);
    bodySizeHeader = ntohl(bodySizeHeader);
    Checked<size_t> bodySize = bodySizeHeader;
    MessageFlags flags;
    memcpy(&flags, messageData, sizeof(MessageFlags));
    messageData += sizeof(MessageFlags);
    auto messageSize = sizeof(uint32_t) + sizeof(MessageFlags) + bodySize;
    if (m_readBuffer.size() < messageSize.unsafeGet())
        return false;

    Checked<size_t> messageNameLength = strlen(messageData);
    messageNameLength++;
    if (m_readBuffer.size() < messageNameLength.unsafeGet()) {
        ASSERT_NOT_REACHED();
        return false;
    }

    const auto it = m_messageHandlers.find(messageData);
    if (it != m_messageHandlers.end()) {
        messageData += messageNameLength.unsafeGet();
        GRefPtr<GVariant> parameters;
        if (!it->value.first.isNull()) {
            GUniquePtr<GVariantType> variantType(g_variant_type_new(it->value.first.data()));
            size_t parametersSize = bodySize.unsafeGet() - messageNameLength.unsafeGet();
            // g_variant_new_from_data() requires the memory to be properly aligned for the type being loaded,
            // but it's not possible to know the alignment because g_variant_type_info_query() is not public API.
            // Since GLib 2.60 g_variant_new_from_data() already checks the alignment and reallocates the buffer
            // in aligned memory only if needed. For older versions we can simply ensure the memory is 8 aligned.
#if GLIB_CHECK_VERSION(2, 60, 0)
            parameters = g_variant_new_from_data(variantType.get(), messageData, parametersSize, FALSE, nullptr, nullptr);
#else
            auto* alignedMemory = fastAlignedMalloc(8, parametersSize);
            memcpy(alignedMemory, messageData, parametersSize);
            GRefPtr<GBytes> bytes = g_bytes_new_with_free_func(alignedMemory, parametersSize, [](gpointer data) {
                fastAlignedFree(data);
            }, alignedMemory);
            parameters = g_variant_new_from_bytes(variantType.get(), bytes.get(), FALSE);
#endif
            if (messageIsByteSwapped(flags))
                parameters = adoptGRef(g_variant_byteswap(parameters.get()));
        }
        it->value.second(*this, parameters.get(), m_userData);
        if (isClosed())
            return false;
    }

    if (m_readBuffer.size() > messageSize.unsafeGet()) {
        std::memmove(m_readBuffer.data(), m_readBuffer.data() + messageSize.unsafeGet(), m_readBuffer.size() - messageSize.unsafeGet());
        m_readBuffer.shrink(m_readBuffer.size() - messageSize.unsafeGet());
    } else
        m_readBuffer.shrink(0);

    if (m_readBuffer.size() < defaultBufferSize)
        m_readBuffer.shrinkCapacity(defaultBufferSize);

    return true;
}

void SocketConnection::sendMessage(const char* messageName, GVariant* parameters)
{
    GRefPtr<GVariant> adoptedParameters = parameters;
    size_t parametersSize = parameters ? g_variant_get_size(parameters) : 0;
    CheckedSize messageNameLength = strlen(messageName);
    messageNameLength++;
    if (UNLIKELY(messageNameLength.hasOverflowed())) {
        g_warning("Trying to send message with invalid too long name");
        return;
    }
    Checked<uint32_t, RecordOverflow> bodySize = messageNameLength + parametersSize;
    if (UNLIKELY(bodySize.hasOverflowed())) {
        g_warning("Trying to send message '%s' with invalid too long body", messageName);
        return;
    }
    size_t previousBufferSize = m_writeBuffer.size();
    m_writeBuffer.grow(previousBufferSize + sizeof(uint32_t) + sizeof(MessageFlags) + bodySize.unsafeGet());

    auto* messageData = m_writeBuffer.data() + previousBufferSize;
    uint32_t bodySizeHeader = htonl(bodySize.unsafeGet());
    memcpy(messageData, &bodySizeHeader, sizeof(uint32_t));
    messageData += sizeof(uint32_t);
    MessageFlags flags = 0;
#if G_BYTE_ORDER == G_LITTLE_ENDIAN
    flags |= ByteOrderLittleEndian;
#endif
    memcpy(messageData, &flags, sizeof(MessageFlags));
    messageData += sizeof(MessageFlags);
    memcpy(messageData, messageName, messageNameLength.unsafeGet());
    messageData += messageNameLength.unsafeGet();
    if (parameters)
        memcpy(messageData, g_variant_get_data(parameters), parametersSize);

    write();
}

void SocketConnection::write()
{
    if (isClosed())
        return;

    GUniqueOutPtr<GError> error;
    auto bytesWritten = g_socket_send(g_socket_connection_get_socket(m_connection.get()), m_writeBuffer.data(), m_writeBuffer.size(), nullptr, &error.outPtr());
    if (bytesWritten == -1) {
        if (g_error_matches(error.get(), G_IO_ERROR, G_IO_ERROR_WOULD_BLOCK)) {
            waitForSocketWritability();
            return;
        }

        g_warning("Error sending message on socket connection: %s\n", error->message);
        didClose();
        return;
    }

    if (m_writeBuffer.size() > static_cast<size_t>(bytesWritten)) {
        std::memmove(m_writeBuffer.data(), m_writeBuffer.data() + bytesWritten, m_writeBuffer.size() - bytesWritten);
        m_writeBuffer.shrink(m_writeBuffer.size() - bytesWritten);
    } else
        m_writeBuffer.shrink(0);

    if (m_writeBuffer.size() < defaultBufferSize)
        m_writeBuffer.shrinkCapacity(defaultBufferSize);

    if (!m_writeBuffer.isEmpty())
        waitForSocketWritability();
}

// Arranges for write() to be retried once the socket reports G_IO_OUT. At most one write monitor
// is installed at a time; while one is pending, the buffered data is retried when it fires.
void SocketConnection::waitForSocketWritability()
{
    if (!m_writeMonitor.isActive()) {
        auto* socket = g_socket_connection_get_socket(m_connection.get());
        m_writeMonitor.start(socket, G_IO_OUT, RunLoop::current(), [this, protectedThis = makeRef(*this)](GIOCondition condition) -> gboolean {
            if (!(condition & G_IO_OUT))
                return G_SOURCE_REMOVE;

            // Stopping the monitor destroys this lambda, so it can't be done from here: defer both
            // stop() and the retried write() to a later run loop iteration.
            RunLoop::current().dispatch([this, protectedThis] {
                m_writeMonitor.stop();
                write();
            });
            return G_SOURCE_REMOVE;
        });
    }
}

// Stops both socket monitors and drops the underlying connection. Unlike didClose(), this does not
// invoke the "DidClose" message handler.
// NOTE(review): isClosed() presumably reports true once m_connection is null — confirm in the header.
// The monitors' callbacks capture a reference to this object (protectedThis), and stopping a monitor
// destroys its callback, so keep the statement order in mind when changing this function.
void SocketConnection::close()
{
    m_readMonitor.stop();
    m_writeMonitor.stop();
    m_connection = nullptr;
}

// Closes the connection and notifies the mandatory "DidClose" message handler (called with null
// parameters). Does nothing if the connection is already closed, so the handler is not invoked twice
// for the same close.
void SocketConnection::didClose()
{
    if (!isClosed()) {
        close();
        ASSERT(m_messageHandlers.contains("DidClose"));
        m_messageHandlers.get("DidClose").second(*this, nullptr, m_userData);
    }
}

} // namespace WTF