diff --git a/kernel/src/net/tcp_layer.cpp b/kernel/src/net/tcp_layer.cpp index 8f84c54a..efc39f7b 100644 --- a/kernel/src/net/tcp_layer.cpp +++ b/kernel/src/net/tcp_layer.cpp @@ -16,9 +16,13 @@ #include "kernel_utils.hpp" #include "circular_buffer.hpp" +#include "timer.hpp" namespace { +static constexpr size_t timeout_ms = 1000; +static constexpr size_t max_tries = 5; + using flag_data_offset = std::bit_field; using flag_reserved = std::bit_field; using flag_ns = std::bit_field; @@ -207,6 +211,11 @@ void network::tcp::finalize_packet(network::interface_descriptor& interface, net // Compute the checksum compute_checksum(p); + if (!p.user) { + logging::logf(logging::log_level::ERROR, "tcp: Function uniquely implemented for user packets!\n"); + return; //TODO Fail + } + auto source = socket.local_port; auto target = socket.server_port; @@ -214,47 +223,75 @@ void network::tcp::finalize_packet(network::interface_descriptor& interface, net if (!listener_ptr) { logging::logf(logging::log_level::ERROR, "tcp: Unable to find listener!\n"); - return; + return; //TODO Fail } auto& listener = *listener_ptr; listener.active = true; - // Give the packet to the IP layer for finalization - network::ip::finalize_packet(interface, p); + uint32_t seq = 0; + uint32_t ack = 0; - uint32_t seq; - uint32_t ack; + bool received = false; - while (true) { - // TODO Need a timeout - listener.queue.sleep(); - auto received_packet = listener.packets.pop(); + for(size_t t = 0; t < max_tries; ++t){ + // Give the packet to the IP layer for finalization + network::ip::finalize_packet(interface, p); - auto* tcp_header = reinterpret_cast(received_packet.payload + received_packet.index); - auto flags = switch_endian_16(tcp_header->flags); + auto before = timer::milliseconds(); + auto after = before; - logging::logf(logging::log_level::TRACE, "tcp: Received answer\n"); + while(true){ + // Make sure we don't wait for more than the timeout + if (after > before + timeout_ms) { + break; + } - if (*flag_ack(&flags)) { - logging::logf(logging::log_level::TRACE, "tcp: Received ACK\n"); + auto remaining = timeout_ms - (after - before); - seq = switch_endian_32(tcp_header->sequence_number); - ack = switch_endian_32(tcp_header->ack_number); + if(listener.queue.sleep(remaining)){ + auto received_packet = listener.packets.pop(); - delete[] received_packet.payload; + auto* tcp_header = reinterpret_cast(received_packet.payload + received_packet.index); + auto flags = switch_endian_16(tcp_header->flags); + if (*flag_ack(&flags)) { + logging::logf(logging::log_level::TRACE, "tcp: Received ACK\n"); + + seq = switch_endian_32(tcp_header->sequence_number); + ack = switch_endian_32(tcp_header->ack_number); + + delete[] received_packet.payload; + + received = true; + + break; + } else { + logging::logf(logging::log_level::TRACE, "tcp: Received unrelated answer\n"); + } + + delete[] received_packet.payload; + } + } + + if(received){ break; } - delete[] received_packet.payload; + after = timer::milliseconds(); } + // Stop listening listener.active = false; - socket.seq_number = ack; - socket.ack_number = seq; + if(received){ + // Set the future sequence and acknowledgement numbers + socket.seq_number = ack; + socket.ack_number = seq; + } else { + //TODO We need to be able to make finalize fail! + } } std::expected network::tcp::connect(network::socket& sock, network::interface_descriptor& interface) {