diff --git a/kernel/src/net/tcp_layer.cpp b/kernel/src/net/tcp_layer.cpp index 1361be9e..edfd76e1 100644 --- a/kernel/src/net/tcp_layer.cpp +++ b/kernel/src/net/tcp_layer.cpp @@ -65,9 +65,9 @@ network::connection_handler connections; void compute_checksum(network::ethernet::packet& packet) { auto* ip_header = reinterpret_cast(packet.payload + packet.tag(1)); - auto* tcp_header = reinterpret_cast(packet.payload + packet.index); + auto* tcp_header = reinterpret_cast(packet.payload + packet.tag(2)); - auto tcp_len = switch_endian_16(ip_header->total_len) - sizeof(network::ip::header); + auto tcp_len = switch_endian_16(ip_header->total_len) - (ip_header->version_ihl & 0xF) * 4; tcp_header->checksum = 0; @@ -87,10 +87,12 @@ void compute_checksum(network::ethernet::packet& packet) { tcp_header->checksum = switch_endian_16(network::checksum_finalize_nz(sum)); } +constexpr size_t default_tcp_header_length = 20; + uint16_t get_default_flags() { uint16_t flags = 0; // By default - (flag_data_offset(&flags)) = 5; // No options + (flag_data_offset(&flags)) = default_tcp_header_length / 4; // No options return flags; } @@ -106,7 +108,7 @@ void prepare_packet(network::ethernet::packet& packet, size_t source, size_t tar tcp_header->target_port = switch_endian_16(target); tcp_header->window_size = 1024; - packet.index += sizeof(network::tcp::header); + packet.index += default_tcp_header_length; } size_t tcp_payload_len(const network::ethernet::packet& packet){ @@ -126,7 +128,7 @@ size_t tcp_payload_len(const network::ethernet::packet& packet){ // This is used for raw answer std::expected kernel_prepare_packet(network::interface_descriptor& interface, network::ip::address target_ip, size_t source, size_t target, size_t payload_size) { // Ask the IP layer to craft a packet - network::ip::packet_descriptor desc{sizeof(network::tcp::header) + payload_size, target_ip, 0x06}; + network::ip::packet_descriptor desc{payload_size + default_tcp_header_length, target_ip, 0x06}; auto packet = network::ip::kernel_prepare_packet(interface, desc); if (packet) { @@ -143,7 +145,7 @@ std::expected kernel_prepare_packet(network::interfac auto server_port = connection.server_port; // Ask the IP layer to craft a packet - network::ip::packet_descriptor desc{sizeof(network::tcp::header) + payload_size, target_ip, 0x06}; + network::ip::packet_descriptor desc{payload_size + default_tcp_header_length, target_ip, 0x06}; auto packet = network::ip::kernel_prepare_packet(interface, desc); if (packet) { @@ -160,7 +162,11 @@ std::expected kernel_prepare_packet(network::interfac // finalize without waiting for ACK std::expected finalize_packet_direct(network::interface_descriptor& interface, network::ethernet::packet& p) { - p.index -= sizeof(network::tcp::header); + auto* tcp_header = reinterpret_cast(p.payload + p.tag(2)); + + auto flags = switch_endian_16(tcp_header->flags); + + p.index -= *flag_data_offset(&flags) * 4; // Compute the checksum compute_checksum(p); @@ -224,7 +230,7 @@ void network::tcp::decode(network::interface_descriptor& interface, network::eth if (*flag_psh(&flags) && connection.socket) { auto& socket = *connection.socket; - packet.index += sizeof(header); + packet.index += *flag_data_offset(&flags) * 4; if (socket.listen) { auto copy = packet; @@ -364,7 +370,7 @@ std::expected network::tcp::user_prepare_packet(char* auto& interface = network::select_interface(target_ip); // Ask the IP layer to craft a packet - network::ip::packet_descriptor desc{sizeof(header) + descriptor->payload_size, target_ip, 0x06}; + network::ip::packet_descriptor desc{descriptor->payload_size + default_tcp_header_length, target_ip, 0x06}; auto packet = network::ip::user_prepare_packet(buffer, interface, &desc); if (packet) { @@ -388,7 +394,11 @@ std::expected network::tcp::user_prepare_packet(char* } std::expected network::tcp::finalize_packet(network::interface_descriptor& interface, network::socket& socket, network::ethernet::packet& p) { - p.index -= sizeof(header); + auto* tcp_header = reinterpret_cast(p.payload + p.tag(2)); + + auto flags = switch_endian_16(tcp_header->flags); + + p.index -= *flag_data_offset(&flags) * 4; // Compute the checksum compute_checksum(p);