diff --git a/kernel/src/net/tcp_layer.cpp b/kernel/src/net/tcp_layer.cpp index 231f05aa..2fc725d5 100644 --- a/kernel/src/net/tcp_layer.cpp +++ b/kernel/src/net/tcp_layer.cpp @@ -6,12 +6,16 @@ //======================================================================= #include +#include +#include +#include #include "net/tcp_layer.hpp" #include "net/dns_layer.hpp" #include "net/checksum.hpp" #include "kernel_utils.hpp" +#include "circular_buffer.hpp" namespace { @@ -27,6 +31,22 @@ using flag_rst = std::bit_field; using flag_syn = std::bit_field; using flag_fin = std::bit_field; +struct tcp_listener { + size_t source_port; + size_t target_port; + std::atomic active; + semaphore sem; + circular_buffer packets; + + tcp_listener(size_t source_port, size_t target_port) + : source_port(source_port), target_port(target_port), active(false) { + sem.init(0); + } +}; + +// Note: We need a list to not invalidate the values during insertions +std::list listeners; + void compute_checksum(network::ethernet::packet& packet){ auto* ip_header = reinterpret_cast(packet.payload + packet.tag(1)); auto* tcp_header = reinterpret_cast(packet.payload + packet.index); @@ -70,14 +90,14 @@ void prepare_packet(network::ethernet::packet& packet, size_t source, size_t tar tcp_header->target_port = switch_endian_16(target); tcp_header->window_size = 1024; - //TODO - packet.index += sizeof(network::tcp::header); + + //TODO } } //end of anonymous namespace -void network::tcp::decode(network::interface_descriptor& interface, network::ethernet::packet& packet){ +void network::tcp::decode(network::interface_descriptor& /*interface*/, network::ethernet::packet& packet){ packet.tag(2, packet.index); auto* tcp_header = reinterpret_cast(packet.payload + packet.index); @@ -86,9 +106,36 @@ void network::tcp::decode(network::interface_descriptor& interface, network::eth auto source_port = switch_endian_16(tcp_header->source_port); auto target_port = switch_endian_16(tcp_header->target_port); + auto sequence = switch_endian_32(tcp_header->sequence_number); + auto ack = switch_endian_32(tcp_header->ack_number); + + logging::logf(logging::log_level::TRACE, "tcp: Source Port %u \n", size_t(source_port)); + logging::logf(logging::log_level::TRACE, "tcp: Target Port %u \n", size_t(target_port)); + logging::logf(logging::log_level::TRACE, "tcp: Seq Number %u \n", size_t(tcp_header->sequence_number)); + logging::logf(logging::log_level::TRACE, "tcp: Ack Number %u \n", size_t(tcp_header->ack_number)); + logging::logf(logging::log_level::TRACE, "tcp: Seq Number %u \n", size_t(sequence)); + logging::logf(logging::log_level::TRACE, "tcp: Ack Number %u \n", size_t(ack)); + + // Propagate to kernel listeners + + auto end = listeners.end(); + auto it = listeners.begin(); + + while(it != end){ + auto& listener = *it; + + if(listener.active.load() && listener.source_port == source_port && listener.target_port == target_port){ + auto copy = packet; + copy.payload = new char[copy.payload_size]; + std::copy_n(packet.payload, packet.payload_size, copy.payload); + + listener.packets.push(copy); + listener.sem.release(); + } + + ++it; + } - logging::logf(logging::log_level::TRACE, "tcp: Source Port %h \n", source_port); - logging::logf(logging::log_level::TRACE, "tcp: Target Port %h \n", target_port); packet.index += sizeof(header); @@ -143,7 +190,66 @@ std::expected network::tcp::connect(network::interface_descriptor& interfa (flag_syn(&flags)) = 1; tcp_header->flags = switch_endian_16(flags); + // Create the listener + + auto& listener = listeners.emplace_back(target, source); + + listener.active = true; + + logging::logf(logging::log_level::TRACE, "tcp: Send SYN\n"); tcp::finalize_packet(interface, *packet); + uint32_t seq; + uint32_t ack; + + while(true){ + // TODO Need a timeout + listener.sem.acquire(); + auto received_packet = listener.packets.pop(); + + auto* tcp_header = reinterpret_cast(received_packet.payload + received_packet.index); + auto flags = switch_endian_16(tcp_header->flags); + + logging::logf(logging::log_level::TRACE, "tcp: Received answer\n"); + + if(*flag_syn(&flags) && *flag_ack(&flags)){ + logging::logf(logging::log_level::TRACE, "tcp: Received SYN/ACK\n"); + + seq = switch_endian_32(tcp_header->sequence_number); + ack = switch_endian_32(tcp_header->ack_number); + + //TODO Release packet + break; + } + + //TODO Release packet + } + + listener.active = false; + + // At this point we have received the SYN/ACK, only remains to ACK + + { + auto packet = tcp::prepare_packet(interface, target_ip, source, target, 0); + + if (!packet) { + return std::make_unexpected(packet.error()); + } + + auto* tcp_header = reinterpret_cast(packet->payload + packet->tag(2)); + + tcp_header->sequence_number = switch_endian_32(ack); + tcp_header->ack_number = switch_endian_32(seq + 1); + + auto flags = get_default_flags(); + (flag_ack(&flags)) = 1; + tcp_header->flags = switch_endian_16(flags); + + logging::logf(logging::log_level::TRACE, "tcp: Send ACK\n"); + tcp::finalize_packet(interface, *packet); + } + + //TODO Remove the listener + return {}; }