Complete TCP Handshake

This commit is contained in:
Baptiste Wicht 2016-09-14 22:08:51 +02:00
parent 5fd3959eee
commit 958aaebf72

View File

@ -6,12 +6,16 @@
//======================================================================= //=======================================================================
#include <bit_field.hpp> #include <bit_field.hpp>
#include <atomic.hpp>
#include <list.hpp>
#include <semaphore.hpp>
#include "net/tcp_layer.hpp" #include "net/tcp_layer.hpp"
#include "net/dns_layer.hpp" #include "net/dns_layer.hpp"
#include "net/checksum.hpp" #include "net/checksum.hpp"
#include "kernel_utils.hpp" #include "kernel_utils.hpp"
#include "circular_buffer.hpp"
namespace { namespace {
@ -27,6 +31,22 @@ using flag_rst = std::bit_field<uint16_t, uint8_t, 2, 1>;
using flag_syn = std::bit_field<uint16_t, uint8_t, 1, 1>; using flag_syn = std::bit_field<uint16_t, uint8_t, 1, 1>;
using flag_fin = std::bit_field<uint16_t, uint8_t, 0, 1>; using flag_fin = std::bit_field<uint16_t, uint8_t, 0, 1>;
struct tcp_listener {
size_t source_port;
size_t target_port;
std::atomic<bool> active;
semaphore sem;
circular_buffer<network::ethernet::packet, 8> packets;
tcp_listener(size_t source_port, size_t target_port)
: source_port(source_port), target_port(target_port), active(false) {
sem.init(0);
}
};
// Note: We need a list to not invalidate the values during insertions
std::list<tcp_listener> listeners;
void compute_checksum(network::ethernet::packet& packet){ void compute_checksum(network::ethernet::packet& packet){
auto* ip_header = reinterpret_cast<network::ip::header*>(packet.payload + packet.tag(1)); auto* ip_header = reinterpret_cast<network::ip::header*>(packet.payload + packet.tag(1));
auto* tcp_header = reinterpret_cast<network::tcp::header*>(packet.payload + packet.index); auto* tcp_header = reinterpret_cast<network::tcp::header*>(packet.payload + packet.index);
@ -70,14 +90,14 @@ void prepare_packet(network::ethernet::packet& packet, size_t source, size_t tar
tcp_header->target_port = switch_endian_16(target); tcp_header->target_port = switch_endian_16(target);
tcp_header->window_size = 1024; tcp_header->window_size = 1024;
//TODO
packet.index += sizeof(network::tcp::header); packet.index += sizeof(network::tcp::header);
//TODO
} }
} //end of anonymous namespace } //end of anonymous namespace
void network::tcp::decode(network::interface_descriptor& interface, network::ethernet::packet& packet){ void network::tcp::decode(network::interface_descriptor& /*interface*/, network::ethernet::packet& packet){
packet.tag(2, packet.index); packet.tag(2, packet.index);
auto* tcp_header = reinterpret_cast<header*>(packet.payload + packet.index); auto* tcp_header = reinterpret_cast<header*>(packet.payload + packet.index);
@ -86,9 +106,36 @@ void network::tcp::decode(network::interface_descriptor& interface, network::eth
auto source_port = switch_endian_16(tcp_header->source_port); auto source_port = switch_endian_16(tcp_header->source_port);
auto target_port = switch_endian_16(tcp_header->target_port); auto target_port = switch_endian_16(tcp_header->target_port);
auto sequence = switch_endian_32(tcp_header->sequence_number);
auto ack = switch_endian_32(tcp_header->ack_number);
logging::logf(logging::log_level::TRACE, "tcp: Source Port %u \n", size_t(source_port));
logging::logf(logging::log_level::TRACE, "tcp: Target Port %u \n", size_t(target_port));
logging::logf(logging::log_level::TRACE, "tcp: Seq Number %u \n", size_t(tcp_header->sequence_number));
logging::logf(logging::log_level::TRACE, "tcp: Ack Number %u \n", size_t(tcp_header->ack_number));
logging::logf(logging::log_level::TRACE, "tcp: Seq Number %u \n", size_t(sequence));
logging::logf(logging::log_level::TRACE, "tcp: Ack Number %u \n", size_t(ack));
// Propagate to kernel listeners
auto end = listeners.end();
auto it = listeners.begin();
while(it != end){
auto& listener = *it;
if(listener.active.load() && listener.source_port == source_port && listener.target_port == target_port){
auto copy = packet;
copy.payload = new char[copy.payload_size];
std::copy_n(packet.payload, packet.payload_size, copy.payload);
listener.packets.push(copy);
listener.sem.release();
}
++it;
}
logging::logf(logging::log_level::TRACE, "tcp: Source Port %h \n", source_port);
logging::logf(logging::log_level::TRACE, "tcp: Target Port %h \n", target_port);
packet.index += sizeof(header); packet.index += sizeof(header);
@ -143,7 +190,66 @@ std::expected<void> network::tcp::connect(network::interface_descriptor& interfa
(flag_syn(&flags)) = 1; (flag_syn(&flags)) = 1;
tcp_header->flags = switch_endian_16(flags); tcp_header->flags = switch_endian_16(flags);
// Create the listener
auto& listener = listeners.emplace_back(target, source);
listener.active = true;
logging::logf(logging::log_level::TRACE, "tcp: Send SYN\n");
tcp::finalize_packet(interface, *packet); tcp::finalize_packet(interface, *packet);
uint32_t seq;
uint32_t ack;
while(true){
// TODO Need a timeout
listener.sem.acquire();
auto received_packet = listener.packets.pop();
auto* tcp_header = reinterpret_cast<header*>(received_packet.payload + received_packet.index);
auto flags = switch_endian_16(tcp_header->flags);
logging::logf(logging::log_level::TRACE, "tcp: Received answer\n");
if(*flag_syn(&flags) && *flag_ack(&flags)){
logging::logf(logging::log_level::TRACE, "tcp: Received SYN/ACK\n");
seq = switch_endian_32(tcp_header->sequence_number);
ack = switch_endian_32(tcp_header->ack_number);
//TODO Release packet
break;
}
//TODO Release packet
}
listener.active = false;
// At this point we have received the SYN/ACK, only remains to ACK
{
auto packet = tcp::prepare_packet(interface, target_ip, source, target, 0);
if (!packet) {
return std::make_unexpected<void>(packet.error());
}
auto* tcp_header = reinterpret_cast<header*>(packet->payload + packet->tag(2));
tcp_header->sequence_number = switch_endian_32(ack);
tcp_header->ack_number = switch_endian_32(seq + 1);
auto flags = get_default_flags();
(flag_ack(&flags)) = 1;
tcp_header->flags = switch_endian_16(flags);
logging::logf(logging::log_level::TRACE, "tcp: Send ACK\n");
tcp::finalize_packet(interface, *packet);
}
//TODO Remove the listener
return {}; return {};
} }