diff --git a/kernel/include/net/network.hpp b/kernel/include/net/network.hpp index ac9e127a..0f796e52 100644 --- a/kernel/include/net/network.hpp +++ b/kernel/include/net/network.hpp @@ -67,6 +67,8 @@ size_t number_of_interfaces(); interface_descriptor& interface(size_t index); +interface_descriptor& select_interface(network::ip::address address); + /*! * \brief Open a new socket * \param domain The socket domain diff --git a/kernel/include/net/socket.hpp b/kernel/include/net/socket.hpp index 413e7a2f..13ca1e0e 100644 --- a/kernel/include/net/socket.hpp +++ b/kernel/include/net/socket.hpp @@ -12,6 +12,7 @@ #include #include #include +#include #include "tlib/net_constants.hpp" @@ -24,19 +25,16 @@ namespace network { struct socket { - size_t id; + size_t id; ///< The socket file descriptor socket_domain domain; socket_type type; socket_protocol protocol; size_t next_fd; bool listen; - bool connected; - uint32_t local_port; - uint32_t server_port; - ip::address server_address; - uint32_t ack_number; - uint32_t seq_number; + void* data = nullptr; + + uint32_t local_port; //TODO This should not be here since it belongs to UDP std::vector packets; @@ -90,6 +88,18 @@ struct socket { return packet.fd == fd; }), packets.end()); } + + template + T& get_data(){ + thor_assert(data); + return *reinterpret_cast(data); + } + + template + std::add_const_t& get_data() const { + thor_assert(data); + return *reinterpret_cast*>(data); + } }; } // end of network namespace diff --git a/kernel/include/net/tcp_layer.hpp b/kernel/include/net/tcp_layer.hpp index a031078d..2ccb7b9b 100644 --- a/kernel/include/net/tcp_layer.hpp +++ b/kernel/include/net/tcp_layer.hpp @@ -22,13 +22,13 @@ namespace tcp { void decode(network::interface_descriptor& interface, network::ethernet::packet& packet); std::expected prepare_packet(network::interface_descriptor& interface, network::ip::address target_ip, size_t source, size_t target, size_t payload_size); -std::expected prepare_packet(char* buffer, network::interface_descriptor& interface, network::socket& socket, size_t payload_size); +std::expected prepare_packet(char* buffer, network::socket& socket, size_t payload_size); void finalize_packet(network::interface_descriptor& interface, network::ethernet::packet& p); void finalize_packet(network::interface_descriptor& interface, network::socket& socket, network::ethernet::packet& p); -std::expected connect(network::socket& socket, network::interface_descriptor& interface); -std::expected disconnect(network::socket& socket, network::interface_descriptor& interface); +std::expected connect(network::socket& socket, network::interface_descriptor& interface, size_t local_port, size_t server_port, network::ip::address server); +std::expected disconnect(network::socket& socket); } // end of tcp namespace diff --git a/kernel/src/net/network.cpp b/kernel/src/net/network.cpp index f3208e34..384daeac 100644 --- a/kernel/src/net/network.cpp +++ b/kernel/src/net/network.cpp @@ -74,26 +74,6 @@ void tx_thread(void* data){ } } -network::interface_descriptor& select_interface(network::ip::address address){ - if(address == network::ip::make_address(127, 0, 0, 1)){ - for(auto& interface : interfaces){ - if(interface.enabled && interface.is_loopback()){ - return interface; - } - } - } - - // Otherwise return the first enabled interface - - for(auto& interface : interfaces){ - if(interface.enabled){ - return interface; - } - } - - thor_unreachable("network: Should never happen"); -} - void sysfs_publish(const network::interface_descriptor& interface){ auto p = path("/net") / interface.name; @@ -232,6 +212,27 @@ network::interface_descriptor& network::interface(size_t index){ return interfaces[index]; } +network::interface_descriptor& network::select_interface(network::ip::address address){ + if(address == network::ip::make_address(127, 0, 0, 1)){ + for(auto& interface : interfaces){ + if(interface.enabled && interface.is_loopback()){ + return interface; + } + } + } + + // Otherwise return the first enabled interface + + for(auto& interface : interfaces){ + if(interface.enabled){ + return interface; + } + } + + thor_unreachable("network: Should never happen"); +} + + std::expected network::open(network::socket_domain domain, network::socket_type type, network::socket_protocol protocol){ // Make sure the socket domain is valid if(domain != socket_domain::AF_INET){ @@ -258,15 +259,7 @@ std::expected network::open(network::socket_domain domain, return std::make_expected_from_error(std::ERROR_SOCKET_INVALID_TYPE_PROTOCOL); } - auto socket_fd = scheduler::register_new_socket(domain, type, protocol); - - // Initialize TCP connection values - auto& socket = scheduler::get_socket(socket_fd); - socket.connected = false; - socket.local_port = 0; - socket.server_port = 0; - - return socket_fd; + return scheduler::register_new_socket(domain, type, protocol); } void network::close(size_t fd){ @@ -286,11 +279,6 @@ std::tuple network::prepare_packet(socket_fd_t socket_fd, void* auto& socket = scheduler::get_socket(socket_fd); - // Make sure stream sockets are connected - if(socket.type == socket_type::STREAM && !socket.connected){ - return {-std::ERROR_SOCKET_NOT_CONNECTED, 0}; - } - auto return_from_packet = [&socket](std::expected& packet) -> std::tuple { if (packet) { auto fd = socket.register_packet(*packet); @@ -320,8 +308,7 @@ std::tuple network::prepare_packet(socket_fd_t socket_fd, void* case network::socket_protocol::TCP: { auto descriptor = static_cast(desc); - auto& interface = select_interface(socket.server_address); - auto packet = network::tcp::prepare_packet(buffer, interface, socket, descriptor->payload_size); + auto packet = network::tcp::prepare_packet(buffer, socket, descriptor->payload_size); return return_from_packet(packet); } @@ -355,11 +342,6 @@ std::expected network::finalize_packet(socket_fd_t socket_fd, size_t packe return std::make_unexpected(std::ERROR_SOCKET_INVALID_PACKET_FD); } - // Make sure stream sockets are connected - if(socket.type == socket_type::STREAM && !socket.connected){ - return std::make_unexpected(-std::ERROR_SOCKET_NOT_CONNECTED); - } - auto& packet = socket.get_packet(packet_fd); auto& interface = network::interface(packet.interface); @@ -427,25 +409,21 @@ std::expected network::connect(socket_fd_t socket_fd, network::ip::addre return std::make_unexpected(std::ERROR_SOCKET_INVALID_TYPE); } - socket.local_port = local_port++; - socket.server_port = port; - socket.server_address = server; + auto selected_port = local_port++; logging::logf(logging::log_level::TRACE, "network: %u stream socket %u was assigned port %u\n", scheduler::get_pid(), socket_fd, socket.local_port); if(socket.protocol == socket_protocol::TCP){ - auto connection = network::tcp::connect(socket, select_interface(server)); + auto connection = network::tcp::connect(socket, select_interface(server), selected_port, port, server); - if(connection){ - socket.connected = true; - } else { + if(!connection){ return std::make_unexpected(connection.error()); } } else { return std::make_unexpected(std::ERROR_SOCKET_INVALID_TYPE_PROTOCOL); } - return std::make_expected(socket.local_port); + return std::make_expected(selected_port); } std::expected network::disconnect(socket_fd_t socket_fd){ @@ -459,18 +437,12 @@ std::expected network::disconnect(socket_fd_t socket_fd){ return std::make_unexpected(std::ERROR_SOCKET_INVALID_TYPE); } - if(!socket.connected){ - return std::make_unexpected(std::ERROR_SOCKET_NOT_CONNECTED); - } - logging::logf(logging::log_level::TRACE, "network: %u disconnect from stream socket %u\n", scheduler::get_pid(), socket_fd); if(socket.protocol == socket_protocol::TCP){ - auto disconnection = network::tcp::disconnect(socket, select_interface(socket.server_address)); + auto disconnection = network::tcp::disconnect(socket); - if(disconnection){ - socket.connected = false; - } else { + if(!disconnection){ return std::make_unexpected(disconnection.error()); } } else { @@ -567,30 +539,10 @@ void network::propagate_packet(const ethernet::packet& packet, socket_protocol p propagate = true; } } - } else if(socket.type == socket_type::STREAM){ - if(socket.protocol == protocol && socket.connected){ - auto local_port = socket.local_port; - auto server_port = socket.server_port; - - auto tcp_index = packet.tag(2); - auto* tcp_header = reinterpret_cast(packet.payload + tcp_index); - auto source_port = switch_endian_16(tcp_header->source_port); - auto target_port = switch_endian_16(tcp_header->target_port); - auto flags = switch_endian_16(tcp_header->flags); - - using flag_psh = std::bit_field; - - // Don't propagate ack - if (*flag_psh(&flags)) { - logging::logf(logging::log_level::TRACE, "network: propagate on socket %u\n", socket.id); - - if(local_port == target_port && server_port == source_port){ - propagate = true; - } - } - } } + // Note: Stream sockets are responsible for propagation + if (propagate) { auto copy = packet; copy.payload = new char[copy.payload_size]; diff --git a/kernel/src/net/tcp_layer.cpp b/kernel/src/net/tcp_layer.cpp index c631ff46..efb26a04 100644 --- a/kernel/src/net/tcp_layer.cpp +++ b/kernel/src/net/tcp_layer.cpp @@ -40,17 +40,20 @@ using flag_syn = std::bit_field; using flag_fin = std::bit_field; struct tcp_connection { - size_t source_port; ///< The source port of the connection - size_t target_port; ///< The target port of the connection + size_t local_port; ///< The local source port + size_t server_port; ///< The server port + network::ip::address server_address; ///< The server address std::atomic listening; ///< Indicates if a kernel thread is listening on this connection condition_variable queue; ///< The listening queue circular_buffer packets; ///< The packets for the listening queue - tcp_connection(size_t source_port, size_t target_port) - : source_port(source_port), target_port(target_port), listening(false) { - //Nothing else to init - } + bool connected = false; + + uint32_t ack_number = 0; ///< The next ack number + uint32_t seq_number = 0; ///< The next sequence number + + network::socket* socket = nullptr; }; // The lock used to protect the list of connections @@ -59,12 +62,12 @@ rw_lock connections_lock; // Note: We need a list to not invalidate the values during insertions std::list connections; -tcp_connection* get_connection(size_t source_port, size_t target_port) { +tcp_connection* get_connection_for_packet(size_t source_port, size_t target_port) { auto lock = connections_lock.reader_lock(); std::lock_guard l(lock); for(auto& connection : connections){ - if (connection.source_port == source_port && connection.target_port == target_port) { + if (connection.server_port == source_port && connection.local_port == target_port) { return &connection; } } @@ -72,11 +75,15 @@ tcp_connection* get_connection(size_t source_port, size_t target_port) { return nullptr; } -tcp_connection& create_connection(size_t target, size_t source){ +tcp_connection& create_connection(){ auto lock = connections_lock.writer_lock(); std::lock_guard l(lock); - return connections.emplace_back(target, source); + auto& connection = connections.emplace_back(); + + connection.listening = false; + + return connection; } void remove_connection(tcp_connection& connection){ @@ -178,35 +185,52 @@ void network::tcp::decode(network::interface_descriptor& interface, network::eth auto flags = switch_endian_16(tcp_header->flags); - // Propagate to kernel connections + auto next_seq = ack; + auto next_ack = seq + tcp_payload_len(packet);; - { - auto lock = connections_lock.reader_lock(); - std::lock_guard l(lock); + auto connection_ptr = get_connection_for_packet(source_port, target_port); - for (auto& connection : connections) { - if (connection.listening.load() && connection.source_port == source_port && connection.target_port == target_port) { + if(connection_ptr){ + auto& connection = *connection_ptr; + + // Update the connection status + + connection.seq_number = next_seq; + connection.ack_number = next_ack; + + // Propagate to kernel connections + + if (connection.listening.load()) { + auto copy = packet; + copy.payload = new char[copy.payload_size]; + std::copy_n(packet.payload, packet.payload_size, copy.payload); + + connection.packets.push(copy); + connection.queue.notify_one(); + } + + // Propagate to the kernel socket + + if (*flag_psh(&flags) && connection.socket) { + auto& socket = *connection.socket; + + packet.index += sizeof(header); + + if (socket.listen) { auto copy = packet; copy.payload = new char[copy.payload_size]; std::copy_n(packet.payload, packet.payload_size, copy.payload); - connection.packets.push(copy); - connection.queue.notify_one(); + socket.listen_packets.push(copy); + socket.listen_queue.notify_one(); } } + } else { + logging::logf(logging::log_level::DEBUG, "tcp: Received packet for which there are no connection\n"); } - auto seq_number = ack; - auto ack_number = seq + tcp_payload_len(packet); + // Acknowledge if necessary - //TODO socket.seq_number = ack; - //TODO socket.ack_number = seq + tcp_payload_len(packet); - - packet.index += sizeof(header); - - network::propagate_packet(packet, network::socket_protocol::TCP); - - // A push needs to be acknowledged if (*flag_psh(&flags)) { auto p = tcp::prepare_packet(interface, switch_endian_32(ip_header->source_ip), target_port, source_port, 0); @@ -217,12 +241,12 @@ void network::tcp::decode(network::interface_descriptor& interface, network::eth auto* ack_tcp_header = reinterpret_cast(p->payload + p->tag(2)); - ack_tcp_header->sequence_number = switch_endian_32(seq_number); - ack_tcp_header->ack_number = switch_endian_32(ack_number); + ack_tcp_header->sequence_number = switch_endian_32(next_seq); + ack_tcp_header->ack_number = switch_endian_32(next_ack); - auto flags = get_default_flags(); - (flag_ack(&flags)) = 1; - ack_tcp_header->flags = switch_endian_16(flags); + auto ack_flags = get_default_flags(); + (flag_ack(&ack_flags)) = 1; + ack_tcp_header->flags = switch_endian_16(ack_flags); tcp::finalize_packet(interface, *p); } @@ -239,15 +263,23 @@ std::expected network::tcp::prepare_packet(network::i return packet; } -std::expected network::tcp::prepare_packet(char* buffer, network::interface_descriptor& interface, network::socket& socket, size_t payload_size) { - auto target_ip = socket.server_address; +std::expected network::tcp::prepare_packet(char* buffer, network::socket& socket, size_t payload_size) { + auto& connection = socket.get_data(); + + // Make sure stream sockets are connected + if(!connection.connected){ + return std::make_unexpected(std::ERROR_SOCKET_NOT_CONNECTED); + } + + auto target_ip = connection.server_address; + auto& interface = network::select_interface(target_ip); // Ask the IP layer to craft a packet auto packet = network::ip::prepare_packet(buffer, interface, sizeof(header) + payload_size, target_ip, 0x06); if (packet) { - auto source = socket.local_port; - auto target = socket.server_port; + auto source = connection.local_port; + auto target = connection.server_port; ::prepare_packet(*packet, source, target); @@ -258,8 +290,8 @@ std::expected network::tcp::prepare_packet(char* buff (flag_ack(&flags)) = 1; tcp_header->flags = switch_endian_16(flags); - tcp_header->sequence_number = switch_endian_32(socket.seq_number); - tcp_header->ack_number = switch_endian_32(socket.ack_number); + tcp_header->sequence_number = switch_endian_32(connection.seq_number); + tcp_header->ack_number = switch_endian_32(connection.ack_number); } return packet; @@ -286,18 +318,14 @@ void network::tcp::finalize_packet(network::interface_descriptor& interface, net return; //TODO Fail } - auto source = socket.local_port; - auto target = socket.server_port; + auto& connection = socket.get_data(); - auto connection_ptr = get_connection(target, source); - - if (!connection_ptr) { - logging::logf(logging::log_level::ERROR, "tcp: Unable to find connection!\n"); - return; //TODO Fail + // Make sure stream sockets are connected + if(!connection.connected){ + //TODO return std::make_unexpected(std::ERROR_SOCKET_NOT_CONNECTED); + return; } - auto& connection = *connection_ptr; - connection.listening = true; uint32_t seq = 0; @@ -357,22 +385,31 @@ void network::tcp::finalize_packet(network::interface_descriptor& interface, net if(received){ // Set the future sequence and acknowledgement numbers - socket.seq_number = ack; - socket.ack_number = seq; + connection.seq_number = ack; + connection.ack_number = seq; } else { //TODO We need to be able to make finalize fail! } } -std::expected network::tcp::connect(network::socket& sock, network::interface_descriptor& interface) { - auto target_ip = sock.server_address; - auto source = sock.local_port; - auto target = sock.server_port; +std::expected network::tcp::connect(network::socket& sock, network::interface_descriptor& interface, size_t local_port, size_t server_port, network::ip::address server) { + // Create the connection - sock.seq_number = 0; - sock.ack_number = 0; + auto& connection = create_connection(); - auto packet = tcp::prepare_packet(interface, target_ip, source, target, 0); + connection.local_port = local_port; + connection.server_port = server_port; + connection.server_address = server; + + // Link the socket and connection + sock.data = &connection; + connection.socket = &sock; + + // Prepare the SYN packet + + auto target_ip = connection.server_address; + + auto packet = tcp::prepare_packet(interface, target_ip, local_port, server_port, 0); if (!packet) { return std::make_unexpected(packet.error()); @@ -380,16 +417,12 @@ std::expected network::tcp::connect(network::socket& sock, network::interf auto* tcp_header = reinterpret_cast(packet->payload + packet->tag(2)); - tcp_header->sequence_number = 0; - tcp_header->ack_number = 0; - auto flags = get_default_flags(); (flag_syn(&flags)) = 1; tcp_header->flags = switch_endian_16(flags); - // Create the connection - - auto& connection = create_connection(target, source); + tcp_header->sequence_number = connection.seq_number; + tcp_header->ack_number = connection.ack_number; connection.listening = true; @@ -425,13 +458,13 @@ std::expected network::tcp::connect(network::socket& sock, network::interf connection.listening = false; - sock.seq_number = ack; - sock.ack_number = seq + 1; + connection.seq_number = ack; + connection.ack_number = seq + 1; // At this point we have received the SYN/ACK, only remains to ACK { - auto packet = tcp::prepare_packet(interface, target_ip, source, target, 0); + auto packet = tcp::prepare_packet(interface, target_ip, connection.local_port, connection.server_port, 0); if (!packet) { return std::make_unexpected(packet.error()); @@ -439,8 +472,8 @@ std::expected network::tcp::connect(network::socket& sock, network::interf auto* tcp_header = reinterpret_cast(packet->payload + packet->tag(2)); - tcp_header->sequence_number = switch_endian_32(sock.seq_number); - tcp_header->ack_number = switch_endian_32(sock.ack_number); + tcp_header->sequence_number = switch_endian_32(connection.seq_number); + tcp_header->ack_number = switch_endian_32(connection.ack_number); auto flags = get_default_flags(); (flag_ack(&flags)) = 1; @@ -450,13 +483,25 @@ std::expected network::tcp::connect(network::socket& sock, network::interf tcp::finalize_packet(interface, *packet); } + // Mark the connection as connected + + connection.connected = true; + return {}; } -std::expected network::tcp::disconnect(network::socket& sock, network::interface_descriptor& interface) { - auto target_ip = sock.server_address; - auto source = sock.local_port; - auto target = sock.server_port; +std::expected network::tcp::disconnect(network::socket& sock) { + auto& connection = sock.get_data(); + + if(!connection.connected){ + return std::make_unexpected(std::ERROR_SOCKET_NOT_CONNECTED); + } + + auto target_ip = connection.server_address; + auto source = connection.local_port; + auto target = connection.server_port; + + auto& interface = network::select_interface(target_ip); auto packet = tcp::prepare_packet(interface, target_ip, source, target, 0); @@ -466,23 +511,14 @@ std::expected network::tcp::disconnect(network::socket& sock, network::int auto* tcp_header = reinterpret_cast(packet->payload + packet->tag(2)); - tcp_header->sequence_number = switch_endian_32(sock.seq_number); - tcp_header->ack_number = switch_endian_32(sock.ack_number); + tcp_header->sequence_number = switch_endian_32(connection.seq_number); + tcp_header->ack_number = switch_endian_32(connection.ack_number); auto flags = get_default_flags(); (flag_fin(&flags)) = 1; (flag_ack(&flags)) = 1; tcp_header->flags = switch_endian_16(flags); - auto connection_ptr = get_connection(target, source); - - if (!connection_ptr) { - logging::logf(logging::log_level::ERROR, "tcp: Unable to find connection!\n"); - return std::make_unexpected(std::ERROR_SOCKET_INVALID_CONNECTION); - } - - auto& connection = *connection_ptr; - connection.listening = true; logging::logf(logging::log_level::TRACE, "tcp: Send FIN/ACK\n"); @@ -529,8 +565,8 @@ std::expected network::tcp::disconnect(network::socket& sock, network::int connection.listening = false; - sock.seq_number = ack; - sock.ack_number = seq + 1; + connection.seq_number = ack; + connection.ack_number = seq + 1; // At this point we have received the FIN/ACK, only remains to ACK @@ -543,8 +579,8 @@ std::expected network::tcp::disconnect(network::socket& sock, network::int auto* tcp_header = reinterpret_cast(packet->payload + packet->tag(2)); - tcp_header->sequence_number = switch_endian_32(sock.seq_number); - tcp_header->ack_number = switch_endian_32(sock.ack_number); + tcp_header->sequence_number = switch_endian_32(connection.seq_number); + tcp_header->ack_number = switch_endian_32(connection.ack_number); auto flags = get_default_flags(); (flag_ack(&flags)) = 1; @@ -554,6 +590,10 @@ std::expected network::tcp::disconnect(network::socket& sock, network::int tcp::finalize_packet(interface, *packet); } + // Mark the connection as connected + + connection.connected = false; + remove_connection(connection); return {};