From a71551720fc5d3e935450fc73e23344a46d6912b Mon Sep 17 00:00:00 2001 From: Baptiste Wicht Date: Thu, 15 Sep 2016 22:30:43 +0200 Subject: [PATCH] Allow to send TCP packets --- kernel/include/net/tcp_layer.hpp | 4 +- kernel/src/net/network.cpp | 24 +++++++ kernel/src/net/tcp_layer.cpp | 100 ++++++++++++++++++++++++---- programs/nc/src/main.cpp | 27 +++++++- tlib/include/tlib/net_constants.hpp | 8 +++ 5 files changed, 148 insertions(+), 15 deletions(-) diff --git a/kernel/include/net/tcp_layer.hpp b/kernel/include/net/tcp_layer.hpp index 982a7509..780d7c9d 100644 --- a/kernel/include/net/tcp_layer.hpp +++ b/kernel/include/net/tcp_layer.hpp @@ -33,8 +33,10 @@ struct header { void decode(network::interface_descriptor& interface, network::ethernet::packet& packet); std::expected prepare_packet(network::interface_descriptor& interface, network::ip::address target_ip, size_t source, size_t target, size_t payload_size); -std::expected prepare_packet(char* buffer, network::interface_descriptor& interface, network::ip::address target_ip, size_t source, size_t target, size_t payload_size); +std::expected prepare_packet(char* buffer, network::interface_descriptor& interface, network::socket& socket, size_t payload_size); + void finalize_packet(network::interface_descriptor& interface, network::ethernet::packet& p); +void finalize_packet(network::interface_descriptor& interface, network::socket& socket, network::ethernet::packet& p); std::expected connect(network::socket& socket, network::interface_descriptor& interface); std::expected disconnect(network::socket& socket, network::interface_descriptor& interface); diff --git a/kernel/src/net/network.cpp b/kernel/src/net/network.cpp index c5940f8b..61d4cf08 100644 --- a/kernel/src/net/network.cpp +++ b/kernel/src/net/network.cpp @@ -285,6 +285,11 @@ std::tuple network::prepare_packet(socket_fd_t socket_fd, void* auto& socket = scheduler::get_socket(socket_fd); + // Make sure stream sockets are connected + if(socket.type == socket_type::STREAM && !socket.connected){ + return {-std::ERROR_SOCKET_NOT_CONNECTED, 0}; + } + auto return_from_packet = [&socket](std::expected& packet) -> std::tuple { if (packet) { auto fd = socket.register_packet(*packet); @@ -312,6 +317,14 @@ std::tuple network::prepare_packet(socket_fd_t socket_fd, void* return return_from_packet(packet); } + case network::socket_protocol::TCP: { + auto descriptor = static_cast(desc); + auto& interface = select_interface(socket.server_address); + auto packet = network::tcp::prepare_packet(buffer, interface, socket, descriptor->payload_size); + + return return_from_packet(packet); + } + case network::socket_protocol::DNS: { auto descriptor = static_cast(desc); auto& interface = select_interface(descriptor->target_ip); @@ -341,6 +354,11 @@ std::expected network::finalize_packet(socket_fd_t socket_fd, size_t packe return std::make_unexpected(std::ERROR_SOCKET_INVALID_PACKET_FD); } + // Make sure stream sockets are connected + if(socket.type == socket_type::STREAM && !socket.connected){ + return std::make_unexpected(-std::ERROR_SOCKET_NOT_CONNECTED); + } + auto& packet = socket.get_packet(packet_fd); auto& interface = network::interface(packet.interface); @@ -351,6 +369,12 @@ std::expected network::finalize_packet(socket_fd_t socket_fd, size_t packe return std::make_expected(); + case network::socket_protocol::TCP: + network::tcp::finalize_packet(interface, socket, packet); + socket.erase_packet(packet_fd); + + return std::make_expected(); + case network::socket_protocol::DNS: network::dns::finalize_packet(interface, packet); socket.erase_packet(packet_fd); diff --git a/kernel/src/net/tcp_layer.cpp b/kernel/src/net/tcp_layer.cpp index fd87d7d8..d996c621 100644 --- a/kernel/src/net/tcp_layer.cpp +++ b/kernel/src/net/tcp_layer.cpp @@ -52,10 +52,12 @@ void compute_checksum(network::ethernet::packet& packet) { auto* ip_header = reinterpret_cast(packet.payload + packet.tag(1)); auto* tcp_header = reinterpret_cast(packet.payload + packet.index); + auto tcp_len = switch_endian_16(ip_header->total_len) - sizeof(network::ip::header); + tcp_header->checksum = 0; // Accumulate the Payload - auto sum = network::checksum_add_bytes(packet.payload + packet.index, 20); // TODO What is the length + auto sum = network::checksum_add_bytes(packet.payload + packet.index, tcp_len); // Accumulate the IP addresses sum += network::checksum_add_bytes(&ip_header->source_ip, 8); @@ -63,8 +65,8 @@ void compute_checksum(network::ethernet::packet& packet) { // Accumulate the IP Protocol sum += ip_header->protocol; - // Accumulate the UDP length - sum += 20; + // Accumulate the TCP length + sum += tcp_len; // Complete the 1-complement sum tcp_header->checksum = switch_endian_16(network::checksum_finalize_nz(sum)); @@ -80,7 +82,7 @@ uint16_t get_default_flags() { return flags; } -void prepare_packet(network::ethernet::packet& packet, size_t source, size_t target, size_t payload_size) { +void prepare_packet(network::ethernet::packet& packet, size_t source, size_t target) { packet.tag(2, packet.index); // Set the TCP header @@ -92,8 +94,6 @@ void prepare_packet(network::ethernet::packet& packet, size_t source, size_t tar tcp_header->window_size = 1024; packet.index += sizeof(network::tcp::header); - - //TODO } } //end of anonymous namespace @@ -112,8 +112,6 @@ void network::tcp::decode(network::interface_descriptor& /*interface*/, network: logging::logf(logging::log_level::TRACE, "tcp: Source Port %u \n", size_t(source_port)); logging::logf(logging::log_level::TRACE, "tcp: Target Port %u \n", size_t(target_port)); - logging::logf(logging::log_level::TRACE, "tcp: Seq Number %u \n", size_t(tcp_header->sequence_number)); - logging::logf(logging::log_level::TRACE, "tcp: Ack Number %u \n", size_t(tcp_header->ack_number)); logging::logf(logging::log_level::TRACE, "tcp: Seq Number %u \n", size_t(sequence)); logging::logf(logging::log_level::TRACE, "tcp: Ack Number %u \n", size_t(ack)); @@ -147,18 +145,33 @@ std::expected network::tcp::prepare_packet(network::i auto packet = network::ip::prepare_packet(interface, sizeof(header) + payload_size, target_ip, 0x06); if (packet) { - ::prepare_packet(*packet, source, target, payload_size); + ::prepare_packet(*packet, source, target); } return packet; } -std::expected network::tcp::prepare_packet(char* buffer, network::interface_descriptor& interface, network::ip::address target_ip, size_t source, size_t target, size_t payload_size) { +std::expected network::tcp::prepare_packet(char* buffer, network::interface_descriptor& interface, network::socket& socket, size_t payload_size){ + auto target_ip = socket.server_address; + // Ask the IP layer to craft a packet auto packet = network::ip::prepare_packet(buffer, interface, sizeof(header) + payload_size, target_ip, 0x06); if (packet) { - ::prepare_packet(*packet, source, target, payload_size); + auto source = socket.local_port; + auto target = socket.server_port; + + ::prepare_packet(*packet, source, target); + + auto* tcp_header = reinterpret_cast(packet->payload + packet->tag(2)); + + auto flags = get_default_flags(); + (flag_psh(&flags)) = 1; + (flag_ack(&flags)) = 1; + tcp_header->flags = switch_endian_16(flags); + + tcp_header->sequence_number = switch_endian_32(socket.seq_number); + tcp_header->ack_number = switch_endian_32(socket.ack_number); } return packet; @@ -174,6 +187,69 @@ void network::tcp::finalize_packet(network::interface_descriptor& interface, net network::ip::finalize_packet(interface, p); } +void network::tcp::finalize_packet(network::interface_descriptor& interface, network::socket& socket, network::ethernet::packet& p){ + p.index -= sizeof(header); + + // Compute the checksum + compute_checksum(p); + + auto source = socket.local_port; + auto target = socket.server_port; + + // TODO Wait for ACK or resend + + auto& listener = listeners.emplace_back(target, source); + + listener.active = true; + + // Give the packet to the IP layer for finalization + network::ip::finalize_packet(interface, p); + + uint32_t seq; + uint32_t ack; + + while (true) { + // TODO Need a timeout + listener.sem.acquire(); + auto received_packet = listener.packets.pop(); + + auto* tcp_header = reinterpret_cast(received_packet.payload + received_packet.index); + auto flags = switch_endian_16(tcp_header->flags); + + logging::logf(logging::log_level::TRACE, "tcp: Received answer\n"); + + if (*flag_ack(&flags)) { + logging::logf(logging::log_level::TRACE, "tcp: Received ACK\n"); + + seq = switch_endian_32(tcp_header->sequence_number); + ack = switch_endian_32(tcp_header->ack_number); + + delete[] received_packet.payload; + + break; + } + + delete[] received_packet.payload; + } + + listener.active = false; + + socket.seq_number = ack; + socket.ack_number = seq; + + auto end = listeners.end(); + auto it = listeners.begin(); + + while (it != end) { + if (&(*it) == &listener) { + listeners.erase(it); + break; + } + + ++it; + } +} + std::expected network::tcp::connect(network::socket& sock, network::interface_descriptor& interface) { auto target_ip = sock.server_address; auto source = sock.local_port; @@ -302,7 +378,7 @@ std::expected network::tcp::disconnect(network::socket& sock, network::int listener.active = true; - logging::logf(logging::log_level::TRACE, "tcp: Send FIN\n"); + logging::logf(logging::log_level::TRACE, "tcp: Send FIN/ACK\n"); tcp::finalize_packet(interface, *packet); uint32_t seq; diff --git a/programs/nc/src/main.cpp b/programs/nc/src/main.cpp index 22df41cc..7dcd5b86 100644 --- a/programs/nc/src/main.cpp +++ b/programs/nc/src/main.cpp @@ -42,7 +42,30 @@ int main(int argc, char* argv[]) { sock.listen(true); if (!sock) { - tlib::printf("nc(2): socket error: %s\n", std::error_message(sock.error())); + tlib::printf("nc: socket error: %s\n", std::error_message(sock.error())); + return 1; + } + + tlib::tcp::packet_descriptor desc; + desc.payload_size = 4; + + auto packet = sock.prepare_packet(&desc); + + if (!sock) { + tlib::printf("nc: socket error: %s\n", std::error_message(sock.error())); + return 1; + } + + auto* payload = reinterpret_cast(packet.payload + packet.index); + payload[0] = 'T'; + payload[1] = 'H'; + payload[2] = 'O'; + payload[3] = 'R'; + + sock.finalize_packet(packet); + + if (!sock) { + tlib::printf("nc: socket error: %s\n", std::error_message(sock.error())); return 1; } @@ -51,7 +74,7 @@ int main(int argc, char* argv[]) { sock.listen(false); if (!sock) { - tlib::printf("nc(3): socket error: %s\n", std::error_message(sock.error())); + tlib::printf("nc: socket error: %s\n", std::error_message(sock.error())); return 1; } diff --git a/tlib/include/tlib/net_constants.hpp b/tlib/include/tlib/net_constants.hpp index 04c2d182..2a83355b 100644 --- a/tlib/include/tlib/net_constants.hpp +++ b/tlib/include/tlib/net_constants.hpp @@ -108,6 +108,14 @@ struct packet_descriptor { } // end of dns namespace +namespace tcp { + +struct packet_descriptor { + size_t payload_size; +}; + +} // end of tcp namespace + enum class socket_domain : size_t { AF_INET };