From 853ab4d421607a01f6b752afeb868da740c87a25 Mon Sep 17 00:00:00 2001 From: Baptiste Wicht Date: Wed, 5 Oct 2016 20:05:01 +0200 Subject: [PATCH] Accept with timeout --- kernel/src/net/tcp_layer.cpp | 129 ++++++++++++++++++++++++++++++++++- 1 file changed, 128 insertions(+), 1 deletion(-) diff --git a/kernel/src/net/tcp_layer.cpp b/kernel/src/net/tcp_layer.cpp index 86bc3e4d..fb54e167 100644 --- a/kernel/src/net/tcp_layer.cpp +++ b/kernel/src/net/tcp_layer.cpp @@ -688,7 +688,134 @@ std::expected network::tcp::layer::accept(network::socket& socket){ } std::expected network::tcp::layer::accept(network::socket& socket, size_t ms){ - //TODO + auto& connection = socket.get_connection_data(); + + if(!connection.connected){ + return std::make_unexpected(std::ERROR_SOCKET_NOT_CONNECTED); + } + + // 1. Wait for SYN + + connection.listening = true; + + logging::logf(logging::log_level::TRACE, "tcp:accept: wait for connection\n"); + + uint32_t ack = 0; + uint32_t seq = 0; + + uint16_t source_port = 0; + uint16_t target_port = 0; + uint32_t source_address = 0; + + auto before = timer::milliseconds(); + auto after = before; + + while (true) { + // Make sure we don't wait for more than the timeout + if (after > before + ms) { + return std::make_unexpected(std::ERROR_SOCKET_TIMEOUT); + } + + auto remaining = ms - (after - before); + + if (connection.packets.empty()) { + if (!connection.queue.wait_for(remaining)) { + return std::make_unexpected(std::ERROR_SOCKET_TIMEOUT); + } + } + + auto received_packet = connection.packets.top(); + connection.packets.pop(); + + auto* tcp_header = reinterpret_cast(received_packet->payload + received_packet->index); + auto flags = switch_endian_16(tcp_header->flags); + + if (*flag_syn(&flags)) { + seq = switch_endian_32(tcp_header->sequence_number); + ack = switch_endian_32(tcp_header->ack_number); + + source_port = switch_endian_16(tcp_header->source_port); + target_port = switch_endian_16(tcp_header->target_port); + + auto* ip_header = reinterpret_cast(received_packet->payload + received_packet->tag(1)); + + source_address = switch_endian_32(ip_header->source_ip); + + break; + } + + after = timer::milliseconds(); + } + + logging::logf(logging::log_level::TRACE, "tcp:accept: received SYN from %h\n", source_address); + + connection.listening = false; + + // Set the future sequence and acknowledgement numbers + connection.seq_number = ack; + connection.ack_number = seq + 1; + + // 2. Prepare the child connection + + auto child_fd = scheduler::register_new_socket(socket.domain, socket.type, socket.protocol); + auto& child_sock = scheduler::get_socket(child_fd); + + logging::logf(logging::log_level::TRACE, "tcp:accept: Register new socket %u\n", child_fd); + + // Create the connection + + auto& child_connection = connections.create_connection(); + + child_connection.child = true; + + child_connection.local_port = target_port; + child_connection.server_port = source_port; + child_connection.server_address = source_address; + + // Link the socket and connection + child_sock.connection_data = &child_connection; + child_connection.socket = &child_sock; + + // Child connection numbers + child_connection.seq_number = connection.seq_number; + child_connection.ack_number = connection.ack_number; + + child_connection.connected = true; + + auto& interface = network::select_interface(source_address); + + // 3. Send SYN/ACK + + { + auto p = kernel_prepare_packet(interface, child_connection, 0); + + if (!p) { + return std::make_unexpected(p.error()); + } + + auto& packet = *p; + + auto* tcp_header = reinterpret_cast(packet->payload + packet->tag(2)); + + auto flags = get_default_flags(); + (flag_ack(&flags)) = 1; + (flag_syn(&flags)) = 1; + tcp_header->flags = switch_endian_16(flags); + + logging::logf(logging::log_level::TRACE, "tcp:accept: Send SYN/ACK %h\n", size_t(flags)); + + auto status = finalize_packet(interface, child_sock, packet); + + if(!status){ + return std::make_unexpected(status.error()); + } + } + + // The ACK is enforced by finalize_packet + + logging::logf(logging::log_level::TRACE, "tcp:accept: Done\n"); + + return {child_fd}; } std::expected network::tcp::layer::server_start(network::socket& sock, size_t server_port, network::ip::address server) {