/*
 * Copyright (C) 2019 Igalia, S.L.
 *
 * This library is free software; you can redistribute it and/or
 * modify it under the terms of the GNU Library General Public
 * License as published by the Free Software Foundation; either
 * version 2 of the License, or (at your option) any later version.
 *
 * This library is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
 * Library General Public License for more details.
 *
 * You should have received a copy of the GNU Library General Public License
 * along with this library; see the file COPYING.LIB. If not, write to
 * the Free Software Foundation, Inc., 51 Franklin Street, Fifth Floor,
 * Boston, MA 02110-1301, USA.
 */
|
|
|
|
#include "config.h"
|
|
#include "SocketConnection.h"
|
|
|
|
#include <cstring>
|
|
#include <gio/gio.h>
|
|
#include <wtf/ByteOrder.h>
|
|
#include <wtf/CheckedArithmetic.h>
|
|
#include <wtf/FastMalloc.h>
|
|
#include <wtf/RunLoop.h>
|
|
|
|
namespace WTF {
|
|
|
|
// Initial capacity of the read and write buffers, the step by which the read
// buffer grows, and the capacity both buffers are shrunk back to once drained.
static constexpr unsigned defaultBufferSize = 4096;
|
|
|
|
// Wraps |connection|: pre-sizes both buffers, switches the underlying socket to
// non-blocking mode and starts watching it for incoming data on the current run loop.
// |messageHandlers| and |userData| are stored and later used to dispatch received messages.
SocketConnection::SocketConnection(GRefPtr<GSocketConnection>&& connection, const MessageHandlers& messageHandlers, gpointer userData)
    : m_connection(WTFMove(connection))
    , m_messageHandlers(messageHandlers)
    , m_userData(userData)
{
    relaxAdoptionRequirement();

    m_readBuffer.reserveInitialCapacity(defaultBufferSize);
    m_writeBuffer.reserveInitialCapacity(defaultBufferSize);

    auto* socket = g_socket_connection_get_socket(m_connection.get());
    g_socket_set_blocking(socket, FALSE);

    // The callback holds a reference to this object for as long as the monitor keeps it.
    auto onSocketEvent = [this, protectedThis = makeRef(*this)](GIOCondition condition) -> gboolean {
        if (isClosed())
            return G_SOURCE_REMOVE;

        // Hang-up, error or invalid descriptor: the peer is gone, tear the connection down.
        const bool socketFailed = condition & (G_IO_HUP | G_IO_ERR | G_IO_NVAL);
        if (socketFailed) {
            didClose();
            return G_SOURCE_REMOVE;
        }

        ASSERT(condition & G_IO_IN);
        return read();
    };
    m_readMonitor.start(socket, G_IO_IN, RunLoop::current(), WTFMove(onSocketEvent));
}
|
|
|
|
// Nothing to release explicitly; members clean up after themselves.
SocketConnection::~SocketConnection() = default;
|
|
|
|
bool SocketConnection::read()
|
|
{
|
|
while (true) {
|
|
size_t previousBufferSize = m_readBuffer.size();
|
|
if (m_readBuffer.capacity() - previousBufferSize <= 0)
|
|
m_readBuffer.reserveCapacity(m_readBuffer.capacity() + defaultBufferSize);
|
|
m_readBuffer.grow(m_readBuffer.capacity());
|
|
GUniqueOutPtr<GError> error;
|
|
auto bytesRead = g_socket_receive(g_socket_connection_get_socket(m_connection.get()), m_readBuffer.data() + previousBufferSize, m_readBuffer.size() - previousBufferSize, nullptr, &error.outPtr());
|
|
if (bytesRead == -1) {
|
|
if (g_error_matches(error.get(), G_IO_ERROR, G_IO_ERROR_WOULD_BLOCK)) {
|
|
m_readBuffer.shrink(previousBufferSize);
|
|
break;
|
|
}
|
|
|
|
g_warning("Error reading from socket connection: %s\n", error->message);
|
|
didClose();
|
|
return G_SOURCE_REMOVE;
|
|
}
|
|
|
|
if (!bytesRead) {
|
|
didClose();
|
|
return G_SOURCE_REMOVE;
|
|
}
|
|
|
|
m_readBuffer.shrink(previousBufferSize + bytesRead);
|
|
|
|
while (readMessage()) { }
|
|
if (isClosed())
|
|
return G_SOURCE_REMOVE;
|
|
}
|
|
return G_SOURCE_CONTINUE;
|
|
}
|
|
|
|
// Bits stored in the one-byte flags field of a message header.
enum {
    // Marks a message whose sender is a little-endian host.
    ByteOrderLittleEndian = 1 << 0
};

// Storage type of the flags field as it travels on the wire.
using MessageFlags = uint8_t;
|
|
|
|
// Returns true when the byte order recorded in |flags| differs from the byte order
// this code was compiled for, i.e. when the message payload needs to be byte-swapped.
static inline bool messageIsByteSwapped(MessageFlags flags)
{
    const bool senderIsLittleEndian = flags & ByteOrderLittleEndian;
#if G_BYTE_ORDER == G_LITTLE_ENDIAN
    return !senderIsLittleEndian;
#else
    return senderIsLittleEndian;
#endif
}
|
|
|
|
bool SocketConnection::readMessage()
|
|
{
|
|
if (m_readBuffer.size() < sizeof(uint32_t))
|
|
return false;
|
|
|
|
auto* messageData = m_readBuffer.data();
|
|
uint32_t bodySizeHeader;
|
|
memcpy(&bodySizeHeader, messageData, sizeof(uint32_t));
|
|
messageData += sizeof(uint32_t);
|
|
bodySizeHeader = ntohl(bodySizeHeader);
|
|
Checked<size_t> bodySize = bodySizeHeader;
|
|
MessageFlags flags;
|
|
memcpy(&flags, messageData, sizeof(MessageFlags));
|
|
messageData += sizeof(MessageFlags);
|
|
auto messageSize = sizeof(uint32_t) + sizeof(MessageFlags) + bodySize;
|
|
if (m_readBuffer.size() < messageSize)
|
|
return false;
|
|
|
|
Checked<size_t> messageNameLength = strlen(messageData);
|
|
messageNameLength++;
|
|
if (m_readBuffer.size() < messageNameLength) {
|
|
ASSERT_NOT_REACHED();
|
|
return false;
|
|
}
|
|
|
|
const auto it = m_messageHandlers.find(messageData);
|
|
if (it != m_messageHandlers.end()) {
|
|
messageData += messageNameLength.value();
|
|
GRefPtr<GVariant> parameters;
|
|
if (!it->value.first.isNull()) {
|
|
GUniquePtr<GVariantType> variantType(g_variant_type_new(it->value.first.data()));
|
|
size_t parametersSize = bodySize.value() - messageNameLength.value();
|
|
// g_variant_new_from_data() requires the memory to be properly aligned for the type being loaded,
|
|
// but it's not possible to know the alignment because g_variant_type_info_query() is not public API.
|
|
// Since GLib 2.60 g_variant_new_from_data() already checks the alignment and reallocates the buffer
|
|
// in aligned memory only if needed. For older versions we can simply ensure the memory is 8 aligned.
|
|
#if GLIB_CHECK_VERSION(2, 60, 0)
|
|
parameters = g_variant_new_from_data(variantType.get(), messageData, parametersSize, FALSE, nullptr, nullptr);
|
|
#else
|
|
auto* alignedMemory = fastAlignedMalloc(8, parametersSize);
|
|
memcpy(alignedMemory, messageData, parametersSize);
|
|
GRefPtr<GBytes> bytes = g_bytes_new_with_free_func(alignedMemory, parametersSize, [](gpointer data) {
|
|
fastAlignedFree(data);
|
|
}, alignedMemory);
|
|
parameters = g_variant_new_from_bytes(variantType.get(), bytes.get(), FALSE);
|
|
#endif
|
|
if (messageIsByteSwapped(flags))
|
|
parameters = adoptGRef(g_variant_byteswap(parameters.get()));
|
|
}
|
|
it->value.second(*this, parameters.get(), m_userData);
|
|
if (isClosed())
|
|
return false;
|
|
}
|
|
|
|
if (m_readBuffer.size() > messageSize) {
|
|
std::memmove(m_readBuffer.data(), m_readBuffer.data() + messageSize.value(), m_readBuffer.size() - messageSize.value());
|
|
m_readBuffer.shrink(m_readBuffer.size() - messageSize.value());
|
|
} else
|
|
m_readBuffer.shrink(0);
|
|
|
|
if (m_readBuffer.size() < defaultBufferSize)
|
|
m_readBuffer.shrinkCapacity(defaultBufferSize);
|
|
|
|
return true;
|
|
}
|
|
|
|
void SocketConnection::sendMessage(const char* messageName, GVariant* parameters)
|
|
{
|
|
GRefPtr<GVariant> adoptedParameters = parameters;
|
|
size_t parametersSize = parameters ? g_variant_get_size(parameters) : 0;
|
|
CheckedSize messageNameLength = strlen(messageName);
|
|
messageNameLength++;
|
|
if (UNLIKELY(messageNameLength.hasOverflowed())) {
|
|
g_warning("Trying to send message with invalid too long name");
|
|
return;
|
|
}
|
|
CheckedUint32 bodySize = messageNameLength + parametersSize;
|
|
if (UNLIKELY(bodySize.hasOverflowed())) {
|
|
g_warning("Trying to send message '%s' with invalid too long body", messageName);
|
|
return;
|
|
}
|
|
size_t previousBufferSize = m_writeBuffer.size();
|
|
m_writeBuffer.grow(previousBufferSize + sizeof(uint32_t) + sizeof(MessageFlags) + bodySize.value());
|
|
|
|
auto* messageData = m_writeBuffer.data() + previousBufferSize;
|
|
uint32_t bodySizeHeader = htonl(bodySize.value());
|
|
memcpy(messageData, &bodySizeHeader, sizeof(uint32_t));
|
|
messageData += sizeof(uint32_t);
|
|
MessageFlags flags = 0;
|
|
#if G_BYTE_ORDER == G_LITTLE_ENDIAN
|
|
flags |= ByteOrderLittleEndian;
|
|
#endif
|
|
memcpy(messageData, &flags, sizeof(MessageFlags));
|
|
messageData += sizeof(MessageFlags);
|
|
memcpy(messageData, messageName, messageNameLength);
|
|
messageData += messageNameLength.value();
|
|
if (parameters)
|
|
memcpy(messageData, g_variant_get_data(parameters), parametersSize);
|
|
|
|
write();
|
|
}
|
|
|
|
void SocketConnection::write()
|
|
{
|
|
if (isClosed())
|
|
return;
|
|
|
|
GUniqueOutPtr<GError> error;
|
|
auto bytesWritten = g_socket_send(g_socket_connection_get_socket(m_connection.get()), m_writeBuffer.data(), m_writeBuffer.size(), nullptr, &error.outPtr());
|
|
if (bytesWritten == -1) {
|
|
if (g_error_matches(error.get(), G_IO_ERROR, G_IO_ERROR_WOULD_BLOCK)) {
|
|
waitForSocketWritability();
|
|
return;
|
|
}
|
|
|
|
g_warning("Error sending message on socket connection: %s\n", error->message);
|
|
didClose();
|
|
return;
|
|
}
|
|
|
|
if (m_writeBuffer.size() > static_cast<size_t>(bytesWritten)) {
|
|
std::memmove(m_writeBuffer.data(), m_writeBuffer.data() + bytesWritten, m_writeBuffer.size() - bytesWritten);
|
|
m_writeBuffer.shrink(m_writeBuffer.size() - bytesWritten);
|
|
} else
|
|
m_writeBuffer.shrink(0);
|
|
|
|
if (m_writeBuffer.size() < defaultBufferSize)
|
|
m_writeBuffer.shrinkCapacity(defaultBufferSize);
|
|
|
|
if (!m_writeBuffer.isEmpty())
|
|
waitForSocketWritability();
|
|
}
|
|
|
|
void SocketConnection::waitForSocketWritability()
|
|
{
|
|
if (m_writeMonitor.isActive())
|
|
return;
|
|
|
|
m_writeMonitor.start(g_socket_connection_get_socket(m_connection.get()), G_IO_OUT, RunLoop::current(), [this, protectedThis = makeRef(*this)] (GIOCondition condition) -> gboolean {
|
|
if (condition & G_IO_OUT) {
|
|
// We can't stop the monitor from this lambda, because stop destroys the lambda.
|
|
RunLoop::current().dispatch([this, protectedThis] {
|
|
m_writeMonitor.stop();
|
|
write();
|
|
});
|
|
}
|
|
return G_SOURCE_REMOVE;
|
|
});
|
|
}
|
|
|
|
void SocketConnection::close()
|
|
{
|
|
m_readMonitor.stop();
|
|
m_writeMonitor.stop();
|
|
m_connection = nullptr;
|
|
}
|
|
|
|
void SocketConnection::didClose()
|
|
{
|
|
if (isClosed())
|
|
return;
|
|
|
|
close();
|
|
ASSERT(m_messageHandlers.contains("DidClose"));
|
|
m_messageHandlers.get("DidClose").second(*this, nullptr, m_userData);
|
|
}
|
|
|
|
} // namespace WTF
|