Allow to send TCP packets

This commit is contained in:
Baptiste Wicht 2016-09-15 22:30:43 +02:00
parent 91333c808d
commit a71551720f
5 changed files with 148 additions and 15 deletions

View File

@ -33,8 +33,10 @@ struct header {
void decode(network::interface_descriptor& interface, network::ethernet::packet& packet);
std::expected<network::ethernet::packet> prepare_packet(network::interface_descriptor& interface, network::ip::address target_ip, size_t source, size_t target, size_t payload_size);
std::expected<network::ethernet::packet> prepare_packet(char* buffer, network::interface_descriptor& interface, network::ip::address target_ip, size_t source, size_t target, size_t payload_size);
std::expected<network::ethernet::packet> prepare_packet(char* buffer, network::interface_descriptor& interface, network::socket& socket, size_t payload_size);
void finalize_packet(network::interface_descriptor& interface, network::ethernet::packet& p);
void finalize_packet(network::interface_descriptor& interface, network::socket& socket, network::ethernet::packet& p);
std::expected<void> connect(network::socket& socket, network::interface_descriptor& interface);
std::expected<void> disconnect(network::socket& socket, network::interface_descriptor& interface);

View File

@ -285,6 +285,11 @@ std::tuple<size_t, size_t> network::prepare_packet(socket_fd_t socket_fd, void*
auto& socket = scheduler::get_socket(socket_fd);
// Make sure stream sockets are connected
if(socket.type == socket_type::STREAM && !socket.connected){
return {-std::ERROR_SOCKET_NOT_CONNECTED, 0};
}
auto return_from_packet = [&socket](std::expected<network::ethernet::packet>& packet) -> std::tuple<size_t, size_t> {
if (packet) {
auto fd = socket.register_packet(*packet);
@ -312,6 +317,14 @@ std::tuple<size_t, size_t> network::prepare_packet(socket_fd_t socket_fd, void*
return return_from_packet(packet);
}
case network::socket_protocol::TCP: {
auto descriptor = static_cast<network::tcp::packet_descriptor*>(desc);
auto& interface = select_interface(socket.server_address);
auto packet = network::tcp::prepare_packet(buffer, interface, socket, descriptor->payload_size);
return return_from_packet(packet);
}
case network::socket_protocol::DNS: {
auto descriptor = static_cast<network::dns::packet_descriptor*>(desc);
auto& interface = select_interface(descriptor->target_ip);
@ -341,6 +354,11 @@ std::expected<void> network::finalize_packet(socket_fd_t socket_fd, size_t packe
return std::make_unexpected<void>(std::ERROR_SOCKET_INVALID_PACKET_FD);
}
// Make sure stream sockets are connected
if(socket.type == socket_type::STREAM && !socket.connected){
return std::make_unexpected<void>(-std::ERROR_SOCKET_NOT_CONNECTED);
}
auto& packet = socket.get_packet(packet_fd);
auto& interface = network::interface(packet.interface);
@ -351,6 +369,12 @@ std::expected<void> network::finalize_packet(socket_fd_t socket_fd, size_t packe
return std::make_expected();
case network::socket_protocol::TCP:
network::tcp::finalize_packet(interface, socket, packet);
socket.erase_packet(packet_fd);
return std::make_expected();
case network::socket_protocol::DNS:
network::dns::finalize_packet(interface, packet);
socket.erase_packet(packet_fd);

View File

@ -52,10 +52,12 @@ void compute_checksum(network::ethernet::packet& packet) {
auto* ip_header = reinterpret_cast<network::ip::header*>(packet.payload + packet.tag(1));
auto* tcp_header = reinterpret_cast<network::tcp::header*>(packet.payload + packet.index);
auto tcp_len = switch_endian_16(ip_header->total_len) - sizeof(network::ip::header);
tcp_header->checksum = 0;
// Accumulate the Payload
auto sum = network::checksum_add_bytes(packet.payload + packet.index, 20); // TODO What is the length
auto sum = network::checksum_add_bytes(packet.payload + packet.index, tcp_len);
// Accumulate the IP addresses
sum += network::checksum_add_bytes(&ip_header->source_ip, 8);
@ -63,8 +65,8 @@ void compute_checksum(network::ethernet::packet& packet) {
// Accumulate the IP Protocol
sum += ip_header->protocol;
// Accumulate the UDP length
sum += 20;
// Accumulate the TCP length
sum += tcp_len;
// Complete the 1-complement sum
tcp_header->checksum = switch_endian_16(network::checksum_finalize_nz(sum));
@ -80,7 +82,7 @@ uint16_t get_default_flags() {
return flags;
}
void prepare_packet(network::ethernet::packet& packet, size_t source, size_t target, size_t payload_size) {
void prepare_packet(network::ethernet::packet& packet, size_t source, size_t target) {
packet.tag(2, packet.index);
// Set the TCP header
@ -92,8 +94,6 @@ void prepare_packet(network::ethernet::packet& packet, size_t source, size_t tar
tcp_header->window_size = 1024;
packet.index += sizeof(network::tcp::header);
//TODO
}
} //end of anonymous namespace
@ -112,8 +112,6 @@ void network::tcp::decode(network::interface_descriptor& /*interface*/, network:
logging::logf(logging::log_level::TRACE, "tcp: Source Port %u \n", size_t(source_port));
logging::logf(logging::log_level::TRACE, "tcp: Target Port %u \n", size_t(target_port));
logging::logf(logging::log_level::TRACE, "tcp: Seq Number %u \n", size_t(tcp_header->sequence_number));
logging::logf(logging::log_level::TRACE, "tcp: Ack Number %u \n", size_t(tcp_header->ack_number));
logging::logf(logging::log_level::TRACE, "tcp: Seq Number %u \n", size_t(sequence));
logging::logf(logging::log_level::TRACE, "tcp: Ack Number %u \n", size_t(ack));
@ -147,18 +145,33 @@ std::expected<network::ethernet::packet> network::tcp::prepare_packet(network::i
auto packet = network::ip::prepare_packet(interface, sizeof(header) + payload_size, target_ip, 0x06);
if (packet) {
::prepare_packet(*packet, source, target, payload_size);
::prepare_packet(*packet, source, target);
}
return packet;
}
std::expected<network::ethernet::packet> network::tcp::prepare_packet(char* buffer, network::interface_descriptor& interface, network::ip::address target_ip, size_t source, size_t target, size_t payload_size) {
std::expected<network::ethernet::packet> network::tcp::prepare_packet(char* buffer, network::interface_descriptor& interface, network::socket& socket, size_t payload_size){
auto target_ip = socket.server_address;
// Ask the IP layer to craft a packet
auto packet = network::ip::prepare_packet(buffer, interface, sizeof(header) + payload_size, target_ip, 0x06);
if (packet) {
::prepare_packet(*packet, source, target, payload_size);
auto source = socket.local_port;
auto target = socket.server_port;
::prepare_packet(*packet, source, target);
auto* tcp_header = reinterpret_cast<header*>(packet->payload + packet->tag(2));
auto flags = get_default_flags();
(flag_psh(&flags)) = 1;
(flag_ack(&flags)) = 1;
tcp_header->flags = switch_endian_16(flags);
tcp_header->sequence_number = switch_endian_32(socket.seq_number);
tcp_header->ack_number = switch_endian_32(socket.ack_number);
}
return packet;
@ -174,6 +187,69 @@ void network::tcp::finalize_packet(network::interface_descriptor& interface, net
network::ip::finalize_packet(interface, p);
}
void network::tcp::finalize_packet(network::interface_descriptor& interface, network::socket& socket, network::ethernet::packet& p){
p.index -= sizeof(header);
// Compute the checksum
compute_checksum(p);
auto source = socket.local_port;
auto target = socket.server_port;
// TODO Wait for ACK or resend
auto& listener = listeners.emplace_back(target, source);
listener.active = true;
// Give the packet to the IP layer for finalization
network::ip::finalize_packet(interface, p);
uint32_t seq;
uint32_t ack;
while (true) {
// TODO Need a timeout
listener.sem.acquire();
auto received_packet = listener.packets.pop();
auto* tcp_header = reinterpret_cast<header*>(received_packet.payload + received_packet.index);
auto flags = switch_endian_16(tcp_header->flags);
logging::logf(logging::log_level::TRACE, "tcp: Received answer\n");
if (*flag_ack(&flags)) {
logging::logf(logging::log_level::TRACE, "tcp: Received ACK\n");
seq = switch_endian_32(tcp_header->sequence_number);
ack = switch_endian_32(tcp_header->ack_number);
delete[] received_packet.payload;
break;
}
delete[] received_packet.payload;
}
listener.active = false;
socket.seq_number = ack;
socket.ack_number = seq;
auto end = listeners.end();
auto it = listeners.begin();
while (it != end) {
if (&(*it) == &listener) {
listeners.erase(it);
break;
}
++it;
}
}
std::expected<void> network::tcp::connect(network::socket& sock, network::interface_descriptor& interface) {
auto target_ip = sock.server_address;
auto source = sock.local_port;
@ -302,7 +378,7 @@ std::expected<void> network::tcp::disconnect(network::socket& sock, network::int
listener.active = true;
logging::logf(logging::log_level::TRACE, "tcp: Send FIN\n");
logging::logf(logging::log_level::TRACE, "tcp: Send FIN/ACK\n");
tcp::finalize_packet(interface, *packet);
uint32_t seq;

View File

@ -42,7 +42,30 @@ int main(int argc, char* argv[]) {
sock.listen(true);
if (!sock) {
tlib::printf("nc(2): socket error: %s\n", std::error_message(sock.error()));
tlib::printf("nc: socket error: %s\n", std::error_message(sock.error()));
return 1;
}
tlib::tcp::packet_descriptor desc;
desc.payload_size = 4;
auto packet = sock.prepare_packet(&desc);
if (!sock) {
tlib::printf("nc: socket error: %s\n", std::error_message(sock.error()));
return 1;
}
auto* payload = reinterpret_cast<char*>(packet.payload + packet.index);
payload[0] = 'T';
payload[1] = 'H';
payload[2] = 'O';
payload[3] = 'R';
sock.finalize_packet(packet);
if (!sock) {
tlib::printf("nc: socket error: %s\n", std::error_message(sock.error()));
return 1;
}
@ -51,7 +74,7 @@ int main(int argc, char* argv[]) {
sock.listen(false);
if (!sock) {
tlib::printf("nc(3): socket error: %s\n", std::error_message(sock.error()));
tlib::printf("nc: socket error: %s\n", std::error_message(sock.error()));
return 1;
}

View File

@ -108,6 +108,14 @@ struct packet_descriptor {
} // end of dns namespace
namespace tcp {
struct packet_descriptor {
size_t payload_size;
};
} // end of tcp namespace
enum class socket_domain : size_t {
AF_INET
};