From 29b5a777fb4ef3b9fab2fef4e993bddff59b8e62 Mon Sep 17 00:00:00 2001 From: Baptiste Wicht Date: Sat, 17 Sep 2016 16:21:20 +0200 Subject: [PATCH] Guard the TCP connections datastructure --- kernel/src/net/tcp_layer.cpp | 95 ++++++++++++++++++++++-------------- tlib/include/tlib/errors.hpp | 3 ++ 2 files changed, 61 insertions(+), 37 deletions(-) diff --git a/kernel/src/net/tcp_layer.cpp b/kernel/src/net/tcp_layer.cpp index 99a6af2a..59312efd 100644 --- a/kernel/src/net/tcp_layer.cpp +++ b/kernel/src/net/tcp_layer.cpp @@ -10,11 +10,14 @@ #include #include "conc/condition_variable.hpp" +#include "conc/rw_lock.hpp" #include "net/tcp_layer.hpp" #include "net/dns_layer.hpp" #include "net/checksum.hpp" +#include "tlib/errors.hpp" + #include "kernel_utils.hpp" #include "circular_buffer.hpp" #include "timer.hpp" @@ -37,12 +40,12 @@ using flag_syn = std::bit_field; using flag_fin = std::bit_field; struct tcp_connection { - size_t source_port; - size_t target_port; + size_t source_port; ///< The source port of the connection + size_t target_port; ///< The target port of the connection - std::atomic listening; - condition_variable queue; - circular_buffer packets; + std::atomic listening; ///< Indicates if a kernel thread is listening on this connection + condition_variable queue; ///< The listening queue + circular_buffer packets; ///< The packets for the listening queue tcp_connection(size_t source_port, size_t target_port) : source_port(source_port), target_port(target_port), listening(false) { @@ -50,24 +53,47 @@ struct tcp_connection { } }; +// The lock used to protect the list of connections +rw_lock connections_lock; + // Note: We need a list to not invalidate the values during insertions std::list connections; tcp_connection* get_connection(size_t source_port, size_t target_port) { + auto lock = connections_lock.reader_lock(); + std::lock_guard l(lock); + + for(auto& connection : connections){ + if (connection.source_port == source_port && connection.target_port == target_port) { + return &connection; + } + } + + return nullptr; +} + +tcp_connection& create_connection(size_t target, size_t source){ + auto lock = connections_lock.writer_lock(); + std::lock_guard l(lock); + + return connections.emplace_back(target, source); +} + +void remove_connection(tcp_connection& connection){ + auto lock = connections_lock.writer_lock(); + std::lock_guard l(lock); + auto end = connections.end(); auto it = connections.begin(); while (it != end) { - auto& connection = *it; - - if (connection.source_port == source_port && connection.target_port == target_port) { - return &connection; + if (&(*it) == &connection) { + connections.erase(it); + return; } ++it; } - - return nullptr; } void compute_checksum(network::ethernet::packet& packet) { @@ -154,22 +180,20 @@ void network::tcp::decode(network::interface_descriptor& interface, network::eth // Propagate to kernel connections - auto end = connections.end(); - auto it = connections.begin(); + { + auto lock = connections_lock.reader_lock(); + std::lock_guard l(lock); - while (it != end) { - auto& connection = *it; + for (auto& connection : connections) { + if (connection.listening.load() && connection.source_port == source_port && connection.target_port == target_port) { + auto copy = packet; + copy.payload = new char[copy.payload_size]; + std::copy_n(packet.payload, packet.payload_size, copy.payload); - if (connection.listening.load() && connection.source_port == source_port && connection.target_port == target_port) { - auto copy = packet; - copy.payload = new char[copy.payload_size]; - std::copy_n(packet.payload, packet.payload_size, copy.payload); - - connection.packets.push(copy); - connection.queue.notify_one(); + connection.packets.push(copy); + connection.queue.notify_one(); + } } - - ++it; } auto seq_number = ack; @@ -365,7 +389,7 @@ std::expected network::tcp::connect(network::socket& sock, network::interf // Create the connection - auto& connection = connections.emplace_back(target, source); + auto& connection = create_connection(target, source); connection.listening = true; @@ -450,7 +474,14 @@ std::expected network::tcp::disconnect(network::socket& sock, network::int (flag_ack(&flags)) = 1; tcp_header->flags = switch_endian_16(flags); - auto& connection = connections.emplace_back(target, source); + auto connection_ptr = get_connection(target, source); + + if (!connection_ptr) { + logging::logf(logging::log_level::ERROR, "tcp: Unable to find connection!\n"); + return std::make_unexpected(std::ERROR_SOCKET_INVALID_CONNECTION); + } + + auto& connection = *connection_ptr; connection.listening = true; @@ -523,17 +554,7 @@ std::expected network::tcp::disconnect(network::socket& sock, network::int tcp::finalize_packet(interface, *packet); } - auto end = connections.end(); - auto it = connections.begin(); - - while (it != end) { - if (&(*it) == &connection) { - connections.erase(it); - break; - } - - ++it; - } + remove_connection(connection); return {}; } diff --git a/tlib/include/tlib/errors.hpp b/tlib/include/tlib/errors.hpp index 4f3d0e82..91c74f12 100644 --- a/tlib/include/tlib/errors.hpp +++ b/tlib/include/tlib/errors.hpp @@ -43,6 +43,7 @@ constexpr const size_t ERROR_SOCKET_TIMEOUT = 28; constexpr const size_t ERROR_SOCKET_INVALID_PACKET_DESCRIPTOR = 29; constexpr const size_t ERROR_SOCKET_INVALID_TYPE_PROTOCOL = 30; constexpr const size_t ERROR_SOCKET_NOT_CONNECTED = 31; +constexpr const size_t ERROR_SOCKET_INVALID_CONNECTION = 32; inline const char* error_message(size_t error){ switch(error){ @@ -108,6 +109,8 @@ inline const char* error_message(size_t error){ return "The socket protocol is not vaild with this type"; case ERROR_SOCKET_NOT_CONNECTED: return "The socket is not connected"; + case ERROR_SOCKET_INVALID_CONNECTION: + return "Issue with the internal connection"; default: return "Unknonwn error"; }