From caf8d97abf2e41eda1ef276b0f20ca14a75e6214 Mon Sep 17 00:00:00 2001 From: Baptiste Wicht Date: Tue, 20 Sep 2016 20:32:26 +0200 Subject: [PATCH] Fix TCP disconnection --- kernel/src/net/tcp_layer.cpp | 158 +++++++++++++++++++++++++++++++++-- 1 file changed, 151 insertions(+), 7 deletions(-) diff --git a/kernel/src/net/tcp_layer.cpp b/kernel/src/net/tcp_layer.cpp index 9f30d0a4..e4d6328b 100644 --- a/kernel/src/net/tcp_layer.cpp +++ b/kernel/src/net/tcp_layer.cpp @@ -59,6 +59,9 @@ struct tcp_connection { tcp_connection(){ listening = false; } + + tcp_connection(const tcp_connection& rhs) = delete; + tcp_connection& operator=(const tcp_connection& rhs) = delete; }; network::connection_handler connections; @@ -104,9 +107,10 @@ void prepare_packet(network::ethernet::packet& packet, size_t source, size_t tar auto* tcp_header = reinterpret_cast(packet.payload + packet.index); - tcp_header->source_port = switch_endian_16(source); - tcp_header->target_port = switch_endian_16(target); - tcp_header->window_size = switch_endian_16(1024); + tcp_header->source_port = switch_endian_16(source); + tcp_header->target_port = switch_endian_16(target); + tcp_header->window_size = switch_endian_16(1024); + tcp_header->urgent_pointer = 0; packet.index += default_tcp_header_length; } @@ -605,13 +609,153 @@ std::expected network::tcp::disconnect(network::socket& sock) { logging::logf(logging::log_level::TRACE, "tcp: Send FIN/ACK\n"); - auto status = tcp::finalize_packet(interface, sock, *packet); + bool rec_fin_ack = false; + bool rec_ack = false; - if(!status){ - return status; + uint32_t seq = 0; + uint32_t ack = 0; + + bool received = false; + + for(size_t t = 0; t < max_tries; ++t){ + // Give the packet to the IP layer for finalization + auto copy = *packet; + + copy.payload = new char[packet->payload_size]; + + std::copy_n(packet->payload, packet->payload_size, copy.payload); + + auto result = finalize_packet_direct(interface, copy); + + if(!result){ + delete[] copy.payload; + delete[] packet->payload; + + return result; + } + + auto before = timer::milliseconds(); + auto after = before; + + while(true){ + // Make sure we don't wait for more than the timeout + if (after > before + timeout_ms) { + break; + } + + auto remaining = timeout_ms - (after - before); + + if(connection.queue.wait_for(remaining)){ + auto received_packet = connection.packets.pop(); + + auto* tcp_header = reinterpret_cast(received_packet.payload + received_packet.index); + auto flags = switch_endian_16(tcp_header->flags); + + bool correct_ack = false; + if (*flag_fin(&flags) && *flag_ack(&flags)) { + rec_fin_ack = true; + correct_ack = true; + } else if (*flag_ack(&flags)) { + rec_ack = true; + correct_ack = true; + } + + if (correct_ack) { + seq = switch_endian_32(tcp_header->sequence_number); + ack = switch_endian_32(tcp_header->ack_number); + + delete[] received_packet.payload; + + received = true; + + break; + } + + delete[] received_packet.payload; + } else { + break; + } + } + + if(received){ + break; + } + + after = timer::milliseconds(); } - // At this point we have received the FIN/ACK, only remains to ACK + // Release the memory of the original memory since it was copied + delete[] packet->payload; + + if(!received){ + return std::make_unexpected(std::ERROR_SOCKET_TCP_ERROR); + } + + // Set the future sequence and acknowledgement numbers + connection.seq_number = ack; + connection.ack_number = seq + 1; + + // If we received an ACK, we must wait for a FIN/ACK from the server now + if(rec_ack){ + logging::logf(logging::log_level::TRACE, "tcp: Received ACK waiting for FIN/ACK\n"); + + received = false; + + auto before = timer::milliseconds(); + auto after = before; + + while(true){ + // Make sure we don't wait for more than the timeout + if (after > before + timeout_ms) { + break; + } + + auto remaining = timeout_ms - (after - before); + + if(connection.packets.empty()){ + if(!connection.queue.wait_for(remaining)){ + break; + } + } + + auto received_packet = connection.packets.pop(); + + auto* tcp_header = reinterpret_cast(received_packet.payload + received_packet.index); + auto flags = switch_endian_16(tcp_header->flags); + + bool correct_ack = *flag_fin(&flags) && *flag_ack(&flags); + + if (correct_ack) { + seq = switch_endian_32(tcp_header->sequence_number); + ack = switch_endian_32(tcp_header->ack_number); + + delete[] received_packet.payload; + + received = true; + + break; + } + + delete[] received_packet.payload; + } + + if(!received){ + return std::make_unexpected(std::ERROR_SOCKET_TCP_ERROR); + } + + // Set the future sequence and acknowledgement numbers + connection.seq_number = ack; + connection.ack_number = seq + 1; + + logging::logf(logging::log_level::TRACE, "tcp: Received FIN/ACK waiting for ACK\n"); + } else { + logging::logf(logging::log_level::TRACE, "tcp: Received FIN/ACK directly waiting for ACK\n"); + } + + // Stop listening + connection.listening = false; + + // Finally we send the ACK for the FIN/ACK { auto packet = kernel_prepare_packet(interface, connection, 0);