mirror of
https://github.com/wichtounet/thor-os.git
synced 2025-09-22 11:07:51 -04:00
Guard the TCP connections datastructure
This commit is contained in:
parent
bfa88d4218
commit
29b5a777fb
@ -10,11 +10,14 @@
|
|||||||
#include <list.hpp>
|
#include <list.hpp>
|
||||||
|
|
||||||
#include "conc/condition_variable.hpp"
|
#include "conc/condition_variable.hpp"
|
||||||
|
#include "conc/rw_lock.hpp"
|
||||||
|
|
||||||
#include "net/tcp_layer.hpp"
|
#include "net/tcp_layer.hpp"
|
||||||
#include "net/dns_layer.hpp"
|
#include "net/dns_layer.hpp"
|
||||||
#include "net/checksum.hpp"
|
#include "net/checksum.hpp"
|
||||||
|
|
||||||
|
#include "tlib/errors.hpp"
|
||||||
|
|
||||||
#include "kernel_utils.hpp"
|
#include "kernel_utils.hpp"
|
||||||
#include "circular_buffer.hpp"
|
#include "circular_buffer.hpp"
|
||||||
#include "timer.hpp"
|
#include "timer.hpp"
|
||||||
@ -37,12 +40,12 @@ using flag_syn = std::bit_field<uint16_t, uint8_t, 1, 1>;
|
|||||||
using flag_fin = std::bit_field<uint16_t, uint8_t, 0, 1>;
|
using flag_fin = std::bit_field<uint16_t, uint8_t, 0, 1>;
|
||||||
|
|
||||||
struct tcp_connection {
|
struct tcp_connection {
|
||||||
size_t source_port;
|
size_t source_port; ///< The source port of the connection
|
||||||
size_t target_port;
|
size_t target_port; ///< The target port of the connection
|
||||||
|
|
||||||
std::atomic<bool> listening;
|
std::atomic<bool> listening; ///< Indicates if a kernel thread is listening on this connection
|
||||||
condition_variable queue;
|
condition_variable queue; ///< The listening queue
|
||||||
circular_buffer<network::ethernet::packet, 8> packets;
|
circular_buffer<network::ethernet::packet, 8> packets; ///< The packets for the listening queue
|
||||||
|
|
||||||
tcp_connection(size_t source_port, size_t target_port)
|
tcp_connection(size_t source_port, size_t target_port)
|
||||||
: source_port(source_port), target_port(target_port), listening(false) {
|
: source_port(source_port), target_port(target_port), listening(false) {
|
||||||
@ -50,24 +53,47 @@ struct tcp_connection {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// The lock used to protect the list of connections
|
||||||
|
rw_lock connections_lock;
|
||||||
|
|
||||||
// Note: We need a list to not invalidate the values during insertions
|
// Note: We need a list to not invalidate the values during insertions
|
||||||
std::list<tcp_connection> connections;
|
std::list<tcp_connection> connections;
|
||||||
|
|
||||||
tcp_connection* get_connection(size_t source_port, size_t target_port) {
|
tcp_connection* get_connection(size_t source_port, size_t target_port) {
|
||||||
|
auto lock = connections_lock.reader_lock();
|
||||||
|
std::lock_guard<reader_rw_lock> l(lock);
|
||||||
|
|
||||||
|
for(auto& connection : connections){
|
||||||
|
if (connection.source_port == source_port && connection.target_port == target_port) {
|
||||||
|
return &connection;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
tcp_connection& create_connection(size_t target, size_t source){
|
||||||
|
auto lock = connections_lock.writer_lock();
|
||||||
|
std::lock_guard<writer_rw_lock> l(lock);
|
||||||
|
|
||||||
|
return connections.emplace_back(target, source);
|
||||||
|
}
|
||||||
|
|
||||||
|
void remove_connection(tcp_connection& connection){
|
||||||
|
auto lock = connections_lock.writer_lock();
|
||||||
|
std::lock_guard<writer_rw_lock> l(lock);
|
||||||
|
|
||||||
auto end = connections.end();
|
auto end = connections.end();
|
||||||
auto it = connections.begin();
|
auto it = connections.begin();
|
||||||
|
|
||||||
while (it != end) {
|
while (it != end) {
|
||||||
auto& connection = *it;
|
if (&(*it) == &connection) {
|
||||||
|
connections.erase(it);
|
||||||
if (connection.source_port == source_port && connection.target_port == target_port) {
|
return;
|
||||||
return &connection;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
++it;
|
++it;
|
||||||
}
|
}
|
||||||
|
|
||||||
return nullptr;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void compute_checksum(network::ethernet::packet& packet) {
|
void compute_checksum(network::ethernet::packet& packet) {
|
||||||
@ -154,22 +180,20 @@ void network::tcp::decode(network::interface_descriptor& interface, network::eth
|
|||||||
|
|
||||||
// Propagate to kernel connections
|
// Propagate to kernel connections
|
||||||
|
|
||||||
auto end = connections.end();
|
{
|
||||||
auto it = connections.begin();
|
auto lock = connections_lock.reader_lock();
|
||||||
|
std::lock_guard<reader_rw_lock> l(lock);
|
||||||
|
|
||||||
while (it != end) {
|
for (auto& connection : connections) {
|
||||||
auto& connection = *it;
|
if (connection.listening.load() && connection.source_port == source_port && connection.target_port == target_port) {
|
||||||
|
auto copy = packet;
|
||||||
|
copy.payload = new char[copy.payload_size];
|
||||||
|
std::copy_n(packet.payload, packet.payload_size, copy.payload);
|
||||||
|
|
||||||
if (connection.listening.load() && connection.source_port == source_port && connection.target_port == target_port) {
|
connection.packets.push(copy);
|
||||||
auto copy = packet;
|
connection.queue.notify_one();
|
||||||
copy.payload = new char[copy.payload_size];
|
}
|
||||||
std::copy_n(packet.payload, packet.payload_size, copy.payload);
|
|
||||||
|
|
||||||
connection.packets.push(copy);
|
|
||||||
connection.queue.notify_one();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
++it;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
auto seq_number = ack;
|
auto seq_number = ack;
|
||||||
@ -365,7 +389,7 @@ std::expected<void> network::tcp::connect(network::socket& sock, network::interf
|
|||||||
|
|
||||||
// Create the connection
|
// Create the connection
|
||||||
|
|
||||||
auto& connection = connections.emplace_back(target, source);
|
auto& connection = create_connection(target, source);
|
||||||
|
|
||||||
connection.listening = true;
|
connection.listening = true;
|
||||||
|
|
||||||
@ -450,7 +474,14 @@ std::expected<void> network::tcp::disconnect(network::socket& sock, network::int
|
|||||||
(flag_ack(&flags)) = 1;
|
(flag_ack(&flags)) = 1;
|
||||||
tcp_header->flags = switch_endian_16(flags);
|
tcp_header->flags = switch_endian_16(flags);
|
||||||
|
|
||||||
auto& connection = connections.emplace_back(target, source);
|
auto connection_ptr = get_connection(target, source);
|
||||||
|
|
||||||
|
if (!connection_ptr) {
|
||||||
|
logging::logf(logging::log_level::ERROR, "tcp: Unable to find connection!\n");
|
||||||
|
return std::make_unexpected<void>(std::ERROR_SOCKET_INVALID_CONNECTION);
|
||||||
|
}
|
||||||
|
|
||||||
|
auto& connection = *connection_ptr;
|
||||||
|
|
||||||
connection.listening = true;
|
connection.listening = true;
|
||||||
|
|
||||||
@ -523,17 +554,7 @@ std::expected<void> network::tcp::disconnect(network::socket& sock, network::int
|
|||||||
tcp::finalize_packet(interface, *packet);
|
tcp::finalize_packet(interface, *packet);
|
||||||
}
|
}
|
||||||
|
|
||||||
auto end = connections.end();
|
remove_connection(connection);
|
||||||
auto it = connections.begin();
|
|
||||||
|
|
||||||
while (it != end) {
|
|
||||||
if (&(*it) == &connection) {
|
|
||||||
connections.erase(it);
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
|
|
||||||
++it;
|
|
||||||
}
|
|
||||||
|
|
||||||
return {};
|
return {};
|
||||||
}
|
}
|
||||||
|
@ -43,6 +43,7 @@ constexpr const size_t ERROR_SOCKET_TIMEOUT = 28;
|
|||||||
constexpr const size_t ERROR_SOCKET_INVALID_PACKET_DESCRIPTOR = 29;
|
constexpr const size_t ERROR_SOCKET_INVALID_PACKET_DESCRIPTOR = 29;
|
||||||
constexpr const size_t ERROR_SOCKET_INVALID_TYPE_PROTOCOL = 30;
|
constexpr const size_t ERROR_SOCKET_INVALID_TYPE_PROTOCOL = 30;
|
||||||
constexpr const size_t ERROR_SOCKET_NOT_CONNECTED = 31;
|
constexpr const size_t ERROR_SOCKET_NOT_CONNECTED = 31;
|
||||||
|
constexpr const size_t ERROR_SOCKET_INVALID_CONNECTION = 32;
|
||||||
|
|
||||||
inline const char* error_message(size_t error){
|
inline const char* error_message(size_t error){
|
||||||
switch(error){
|
switch(error){
|
||||||
@ -108,6 +109,8 @@ inline const char* error_message(size_t error){
|
|||||||
return "The socket protocol is not vaild with this type";
|
return "The socket protocol is not vaild with this type";
|
||||||
case ERROR_SOCKET_NOT_CONNECTED:
|
case ERROR_SOCKET_NOT_CONNECTED:
|
||||||
return "The socket is not connected";
|
return "The socket is not connected";
|
||||||
|
case ERROR_SOCKET_INVALID_CONNECTION:
|
||||||
|
return "Issue with the internal connection";
|
||||||
default:
|
default:
|
||||||
return "Unknonwn error";
|
return "Unknonwn error";
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user