mirror of
https://github.com/wichtounet/thor-os.git
synced 2025-08-04 01:36:10 -04:00
Delete all TCP work to the TCP layer
This commit is contained in:
parent
f55de06866
commit
9d5da4769d
@ -67,6 +67,8 @@ size_t number_of_interfaces();
|
|||||||
|
|
||||||
interface_descriptor& interface(size_t index);
|
interface_descriptor& interface(size_t index);
|
||||||
|
|
||||||
|
interface_descriptor& select_interface(network::ip::address address);
|
||||||
|
|
||||||
/*!
|
/*!
|
||||||
* \brief Open a new socket
|
* \brief Open a new socket
|
||||||
* \param domain The socket domain
|
* \param domain The socket domain
|
||||||
|
@ -12,6 +12,7 @@
|
|||||||
#include <vector.hpp>
|
#include <vector.hpp>
|
||||||
#include <algorithms.hpp>
|
#include <algorithms.hpp>
|
||||||
#include <circular_buffer.hpp>
|
#include <circular_buffer.hpp>
|
||||||
|
#include <type_traits.hpp>
|
||||||
|
|
||||||
#include "tlib/net_constants.hpp"
|
#include "tlib/net_constants.hpp"
|
||||||
|
|
||||||
@ -24,19 +25,16 @@
|
|||||||
namespace network {
|
namespace network {
|
||||||
|
|
||||||
struct socket {
|
struct socket {
|
||||||
size_t id;
|
size_t id; ///< The socket file descriptor
|
||||||
socket_domain domain;
|
socket_domain domain;
|
||||||
socket_type type;
|
socket_type type;
|
||||||
socket_protocol protocol;
|
socket_protocol protocol;
|
||||||
size_t next_fd;
|
size_t next_fd;
|
||||||
bool listen;
|
bool listen;
|
||||||
|
|
||||||
bool connected;
|
void* data = nullptr;
|
||||||
uint32_t local_port;
|
|
||||||
uint32_t server_port;
|
uint32_t local_port; //TODO This should not be here since it belongs to UDP
|
||||||
ip::address server_address;
|
|
||||||
uint32_t ack_number;
|
|
||||||
uint32_t seq_number;
|
|
||||||
|
|
||||||
std::vector<network::ethernet::packet> packets;
|
std::vector<network::ethernet::packet> packets;
|
||||||
|
|
||||||
@ -90,6 +88,18 @@ struct socket {
|
|||||||
return packet.fd == fd;
|
return packet.fd == fd;
|
||||||
}), packets.end());
|
}), packets.end());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template<typename T>
|
||||||
|
T& get_data(){
|
||||||
|
thor_assert(data);
|
||||||
|
return *reinterpret_cast<T*>(data);
|
||||||
|
}
|
||||||
|
|
||||||
|
template<typename T>
|
||||||
|
std::add_const_t<T>& get_data() const {
|
||||||
|
thor_assert(data);
|
||||||
|
return *reinterpret_cast<std::add_const_t<T>*>(data);
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
} // end of network namespace
|
} // end of network namespace
|
||||||
|
@ -22,13 +22,13 @@ namespace tcp {
|
|||||||
void decode(network::interface_descriptor& interface, network::ethernet::packet& packet);
|
void decode(network::interface_descriptor& interface, network::ethernet::packet& packet);
|
||||||
|
|
||||||
std::expected<network::ethernet::packet> prepare_packet(network::interface_descriptor& interface, network::ip::address target_ip, size_t source, size_t target, size_t payload_size);
|
std::expected<network::ethernet::packet> prepare_packet(network::interface_descriptor& interface, network::ip::address target_ip, size_t source, size_t target, size_t payload_size);
|
||||||
std::expected<network::ethernet::packet> prepare_packet(char* buffer, network::interface_descriptor& interface, network::socket& socket, size_t payload_size);
|
std::expected<network::ethernet::packet> prepare_packet(char* buffer, network::socket& socket, size_t payload_size);
|
||||||
|
|
||||||
void finalize_packet(network::interface_descriptor& interface, network::ethernet::packet& p);
|
void finalize_packet(network::interface_descriptor& interface, network::ethernet::packet& p);
|
||||||
void finalize_packet(network::interface_descriptor& interface, network::socket& socket, network::ethernet::packet& p);
|
void finalize_packet(network::interface_descriptor& interface, network::socket& socket, network::ethernet::packet& p);
|
||||||
|
|
||||||
std::expected<void> connect(network::socket& socket, network::interface_descriptor& interface);
|
std::expected<void> connect(network::socket& socket, network::interface_descriptor& interface, size_t local_port, size_t server_port, network::ip::address server);
|
||||||
std::expected<void> disconnect(network::socket& socket, network::interface_descriptor& interface);
|
std::expected<void> disconnect(network::socket& socket);
|
||||||
|
|
||||||
} // end of tcp namespace
|
} // end of tcp namespace
|
||||||
|
|
||||||
|
@ -74,26 +74,6 @@ void tx_thread(void* data){
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
network::interface_descriptor& select_interface(network::ip::address address){
|
|
||||||
if(address == network::ip::make_address(127, 0, 0, 1)){
|
|
||||||
for(auto& interface : interfaces){
|
|
||||||
if(interface.enabled && interface.is_loopback()){
|
|
||||||
return interface;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Otherwise return the first enabled interface
|
|
||||||
|
|
||||||
for(auto& interface : interfaces){
|
|
||||||
if(interface.enabled){
|
|
||||||
return interface;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
thor_unreachable("network: Should never happen");
|
|
||||||
}
|
|
||||||
|
|
||||||
void sysfs_publish(const network::interface_descriptor& interface){
|
void sysfs_publish(const network::interface_descriptor& interface){
|
||||||
auto p = path("/net") / interface.name;
|
auto p = path("/net") / interface.name;
|
||||||
|
|
||||||
@ -232,6 +212,27 @@ network::interface_descriptor& network::interface(size_t index){
|
|||||||
return interfaces[index];
|
return interfaces[index];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
network::interface_descriptor& network::select_interface(network::ip::address address){
|
||||||
|
if(address == network::ip::make_address(127, 0, 0, 1)){
|
||||||
|
for(auto& interface : interfaces){
|
||||||
|
if(interface.enabled && interface.is_loopback()){
|
||||||
|
return interface;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Otherwise return the first enabled interface
|
||||||
|
|
||||||
|
for(auto& interface : interfaces){
|
||||||
|
if(interface.enabled){
|
||||||
|
return interface;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
thor_unreachable("network: Should never happen");
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
std::expected<network::socket_fd_t> network::open(network::socket_domain domain, network::socket_type type, network::socket_protocol protocol){
|
std::expected<network::socket_fd_t> network::open(network::socket_domain domain, network::socket_type type, network::socket_protocol protocol){
|
||||||
// Make sure the socket domain is valid
|
// Make sure the socket domain is valid
|
||||||
if(domain != socket_domain::AF_INET){
|
if(domain != socket_domain::AF_INET){
|
||||||
@ -258,15 +259,7 @@ std::expected<network::socket_fd_t> network::open(network::socket_domain domain,
|
|||||||
return std::make_expected_from_error<network::socket_fd_t>(std::ERROR_SOCKET_INVALID_TYPE_PROTOCOL);
|
return std::make_expected_from_error<network::socket_fd_t>(std::ERROR_SOCKET_INVALID_TYPE_PROTOCOL);
|
||||||
}
|
}
|
||||||
|
|
||||||
auto socket_fd = scheduler::register_new_socket(domain, type, protocol);
|
return scheduler::register_new_socket(domain, type, protocol);
|
||||||
|
|
||||||
// Initialize TCP connection values
|
|
||||||
auto& socket = scheduler::get_socket(socket_fd);
|
|
||||||
socket.connected = false;
|
|
||||||
socket.local_port = 0;
|
|
||||||
socket.server_port = 0;
|
|
||||||
|
|
||||||
return socket_fd;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void network::close(size_t fd){
|
void network::close(size_t fd){
|
||||||
@ -286,11 +279,6 @@ std::tuple<size_t, size_t> network::prepare_packet(socket_fd_t socket_fd, void*
|
|||||||
|
|
||||||
auto& socket = scheduler::get_socket(socket_fd);
|
auto& socket = scheduler::get_socket(socket_fd);
|
||||||
|
|
||||||
// Make sure stream sockets are connected
|
|
||||||
if(socket.type == socket_type::STREAM && !socket.connected){
|
|
||||||
return {-std::ERROR_SOCKET_NOT_CONNECTED, 0};
|
|
||||||
}
|
|
||||||
|
|
||||||
auto return_from_packet = [&socket](std::expected<network::ethernet::packet>& packet) -> std::tuple<size_t, size_t> {
|
auto return_from_packet = [&socket](std::expected<network::ethernet::packet>& packet) -> std::tuple<size_t, size_t> {
|
||||||
if (packet) {
|
if (packet) {
|
||||||
auto fd = socket.register_packet(*packet);
|
auto fd = socket.register_packet(*packet);
|
||||||
@ -320,8 +308,7 @@ std::tuple<size_t, size_t> network::prepare_packet(socket_fd_t socket_fd, void*
|
|||||||
|
|
||||||
case network::socket_protocol::TCP: {
|
case network::socket_protocol::TCP: {
|
||||||
auto descriptor = static_cast<network::tcp::packet_descriptor*>(desc);
|
auto descriptor = static_cast<network::tcp::packet_descriptor*>(desc);
|
||||||
auto& interface = select_interface(socket.server_address);
|
auto packet = network::tcp::prepare_packet(buffer, socket, descriptor->payload_size);
|
||||||
auto packet = network::tcp::prepare_packet(buffer, interface, socket, descriptor->payload_size);
|
|
||||||
|
|
||||||
return return_from_packet(packet);
|
return return_from_packet(packet);
|
||||||
}
|
}
|
||||||
@ -355,11 +342,6 @@ std::expected<void> network::finalize_packet(socket_fd_t socket_fd, size_t packe
|
|||||||
return std::make_unexpected<void>(std::ERROR_SOCKET_INVALID_PACKET_FD);
|
return std::make_unexpected<void>(std::ERROR_SOCKET_INVALID_PACKET_FD);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Make sure stream sockets are connected
|
|
||||||
if(socket.type == socket_type::STREAM && !socket.connected){
|
|
||||||
return std::make_unexpected<void>(-std::ERROR_SOCKET_NOT_CONNECTED);
|
|
||||||
}
|
|
||||||
|
|
||||||
auto& packet = socket.get_packet(packet_fd);
|
auto& packet = socket.get_packet(packet_fd);
|
||||||
auto& interface = network::interface(packet.interface);
|
auto& interface = network::interface(packet.interface);
|
||||||
|
|
||||||
@ -427,25 +409,21 @@ std::expected<size_t> network::connect(socket_fd_t socket_fd, network::ip::addre
|
|||||||
return std::make_unexpected<size_t>(std::ERROR_SOCKET_INVALID_TYPE);
|
return std::make_unexpected<size_t>(std::ERROR_SOCKET_INVALID_TYPE);
|
||||||
}
|
}
|
||||||
|
|
||||||
socket.local_port = local_port++;
|
auto selected_port = local_port++;
|
||||||
socket.server_port = port;
|
|
||||||
socket.server_address = server;
|
|
||||||
|
|
||||||
logging::logf(logging::log_level::TRACE, "network: %u stream socket %u was assigned port %u\n", scheduler::get_pid(), socket_fd, socket.local_port);
|
logging::logf(logging::log_level::TRACE, "network: %u stream socket %u was assigned port %u\n", scheduler::get_pid(), socket_fd, socket.local_port);
|
||||||
|
|
||||||
if(socket.protocol == socket_protocol::TCP){
|
if(socket.protocol == socket_protocol::TCP){
|
||||||
auto connection = network::tcp::connect(socket, select_interface(server));
|
auto connection = network::tcp::connect(socket, select_interface(server), selected_port, port, server);
|
||||||
|
|
||||||
if(connection){
|
if(!connection){
|
||||||
socket.connected = true;
|
|
||||||
} else {
|
|
||||||
return std::make_unexpected<size_t>(connection.error());
|
return std::make_unexpected<size_t>(connection.error());
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
return std::make_unexpected<size_t>(std::ERROR_SOCKET_INVALID_TYPE_PROTOCOL);
|
return std::make_unexpected<size_t>(std::ERROR_SOCKET_INVALID_TYPE_PROTOCOL);
|
||||||
}
|
}
|
||||||
|
|
||||||
return std::make_expected<size_t>(socket.local_port);
|
return std::make_expected<size_t>(selected_port);
|
||||||
}
|
}
|
||||||
|
|
||||||
std::expected<void> network::disconnect(socket_fd_t socket_fd){
|
std::expected<void> network::disconnect(socket_fd_t socket_fd){
|
||||||
@ -459,18 +437,12 @@ std::expected<void> network::disconnect(socket_fd_t socket_fd){
|
|||||||
return std::make_unexpected<void>(std::ERROR_SOCKET_INVALID_TYPE);
|
return std::make_unexpected<void>(std::ERROR_SOCKET_INVALID_TYPE);
|
||||||
}
|
}
|
||||||
|
|
||||||
if(!socket.connected){
|
|
||||||
return std::make_unexpected<void>(std::ERROR_SOCKET_NOT_CONNECTED);
|
|
||||||
}
|
|
||||||
|
|
||||||
logging::logf(logging::log_level::TRACE, "network: %u disconnect from stream socket %u\n", scheduler::get_pid(), socket_fd);
|
logging::logf(logging::log_level::TRACE, "network: %u disconnect from stream socket %u\n", scheduler::get_pid(), socket_fd);
|
||||||
|
|
||||||
if(socket.protocol == socket_protocol::TCP){
|
if(socket.protocol == socket_protocol::TCP){
|
||||||
auto disconnection = network::tcp::disconnect(socket, select_interface(socket.server_address));
|
auto disconnection = network::tcp::disconnect(socket);
|
||||||
|
|
||||||
if(disconnection){
|
if(!disconnection){
|
||||||
socket.connected = false;
|
|
||||||
} else {
|
|
||||||
return std::make_unexpected<void>(disconnection.error());
|
return std::make_unexpected<void>(disconnection.error());
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
@ -567,30 +539,10 @@ void network::propagate_packet(const ethernet::packet& packet, socket_protocol p
|
|||||||
propagate = true;
|
propagate = true;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else if(socket.type == socket_type::STREAM){
|
|
||||||
if(socket.protocol == protocol && socket.connected){
|
|
||||||
auto local_port = socket.local_port;
|
|
||||||
auto server_port = socket.server_port;
|
|
||||||
|
|
||||||
auto tcp_index = packet.tag(2);
|
|
||||||
auto* tcp_header = reinterpret_cast<network::tcp::header*>(packet.payload + tcp_index);
|
|
||||||
auto source_port = switch_endian_16(tcp_header->source_port);
|
|
||||||
auto target_port = switch_endian_16(tcp_header->target_port);
|
|
||||||
auto flags = switch_endian_16(tcp_header->flags);
|
|
||||||
|
|
||||||
using flag_psh = std::bit_field<uint16_t, uint8_t, 3, 1>;
|
|
||||||
|
|
||||||
// Don't propagate ack
|
|
||||||
if (*flag_psh(&flags)) {
|
|
||||||
logging::logf(logging::log_level::TRACE, "network: propagate on socket %u\n", socket.id);
|
|
||||||
|
|
||||||
if(local_port == target_port && server_port == source_port){
|
|
||||||
propagate = true;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Note: Stream sockets are responsible for propagation
|
||||||
|
|
||||||
if (propagate) {
|
if (propagate) {
|
||||||
auto copy = packet;
|
auto copy = packet;
|
||||||
copy.payload = new char[copy.payload_size];
|
copy.payload = new char[copy.payload_size];
|
||||||
|
@ -40,17 +40,20 @@ using flag_syn = std::bit_field<uint16_t, uint8_t, 1, 1>;
|
|||||||
using flag_fin = std::bit_field<uint16_t, uint8_t, 0, 1>;
|
using flag_fin = std::bit_field<uint16_t, uint8_t, 0, 1>;
|
||||||
|
|
||||||
struct tcp_connection {
|
struct tcp_connection {
|
||||||
size_t source_port; ///< The source port of the connection
|
size_t local_port; ///< The local source port
|
||||||
size_t target_port; ///< The target port of the connection
|
size_t server_port; ///< The server port
|
||||||
|
network::ip::address server_address; ///< The server address
|
||||||
|
|
||||||
std::atomic<bool> listening; ///< Indicates if a kernel thread is listening on this connection
|
std::atomic<bool> listening; ///< Indicates if a kernel thread is listening on this connection
|
||||||
condition_variable queue; ///< The listening queue
|
condition_variable queue; ///< The listening queue
|
||||||
circular_buffer<network::ethernet::packet, 8> packets; ///< The packets for the listening queue
|
circular_buffer<network::ethernet::packet, 8> packets; ///< The packets for the listening queue
|
||||||
|
|
||||||
tcp_connection(size_t source_port, size_t target_port)
|
bool connected = false;
|
||||||
: source_port(source_port), target_port(target_port), listening(false) {
|
|
||||||
//Nothing else to init
|
uint32_t ack_number = 0; ///< The next ack number
|
||||||
}
|
uint32_t seq_number = 0; ///< The next sequence number
|
||||||
|
|
||||||
|
network::socket* socket = nullptr;
|
||||||
};
|
};
|
||||||
|
|
||||||
// The lock used to protect the list of connections
|
// The lock used to protect the list of connections
|
||||||
@ -59,12 +62,12 @@ rw_lock connections_lock;
|
|||||||
// Note: We need a list to not invalidate the values during insertions
|
// Note: We need a list to not invalidate the values during insertions
|
||||||
std::list<tcp_connection> connections;
|
std::list<tcp_connection> connections;
|
||||||
|
|
||||||
tcp_connection* get_connection(size_t source_port, size_t target_port) {
|
tcp_connection* get_connection_for_packet(size_t source_port, size_t target_port) {
|
||||||
auto lock = connections_lock.reader_lock();
|
auto lock = connections_lock.reader_lock();
|
||||||
std::lock_guard<reader_rw_lock> l(lock);
|
std::lock_guard<reader_rw_lock> l(lock);
|
||||||
|
|
||||||
for(auto& connection : connections){
|
for(auto& connection : connections){
|
||||||
if (connection.source_port == source_port && connection.target_port == target_port) {
|
if (connection.server_port == source_port && connection.local_port == target_port) {
|
||||||
return &connection;
|
return &connection;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -72,11 +75,15 @@ tcp_connection* get_connection(size_t source_port, size_t target_port) {
|
|||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
tcp_connection& create_connection(size_t target, size_t source){
|
tcp_connection& create_connection(){
|
||||||
auto lock = connections_lock.writer_lock();
|
auto lock = connections_lock.writer_lock();
|
||||||
std::lock_guard<writer_rw_lock> l(lock);
|
std::lock_guard<writer_rw_lock> l(lock);
|
||||||
|
|
||||||
return connections.emplace_back(target, source);
|
auto& connection = connections.emplace_back();
|
||||||
|
|
||||||
|
connection.listening = false;
|
||||||
|
|
||||||
|
return connection;
|
||||||
}
|
}
|
||||||
|
|
||||||
void remove_connection(tcp_connection& connection){
|
void remove_connection(tcp_connection& connection){
|
||||||
@ -178,14 +185,22 @@ void network::tcp::decode(network::interface_descriptor& interface, network::eth
|
|||||||
|
|
||||||
auto flags = switch_endian_16(tcp_header->flags);
|
auto flags = switch_endian_16(tcp_header->flags);
|
||||||
|
|
||||||
|
auto next_seq = ack;
|
||||||
|
auto next_ack = seq + tcp_payload_len(packet);;
|
||||||
|
|
||||||
|
auto connection_ptr = get_connection_for_packet(source_port, target_port);
|
||||||
|
|
||||||
|
if(connection_ptr){
|
||||||
|
auto& connection = *connection_ptr;
|
||||||
|
|
||||||
|
// Update the connection status
|
||||||
|
|
||||||
|
connection.seq_number = next_seq;
|
||||||
|
connection.ack_number = next_ack;
|
||||||
|
|
||||||
// Propagate to kernel connections
|
// Propagate to kernel connections
|
||||||
|
|
||||||
{
|
if (connection.listening.load()) {
|
||||||
auto lock = connections_lock.reader_lock();
|
|
||||||
std::lock_guard<reader_rw_lock> l(lock);
|
|
||||||
|
|
||||||
for (auto& connection : connections) {
|
|
||||||
if (connection.listening.load() && connection.source_port == source_port && connection.target_port == target_port) {
|
|
||||||
auto copy = packet;
|
auto copy = packet;
|
||||||
copy.payload = new char[copy.payload_size];
|
copy.payload = new char[copy.payload_size];
|
||||||
std::copy_n(packet.payload, packet.payload_size, copy.payload);
|
std::copy_n(packet.payload, packet.payload_size, copy.payload);
|
||||||
@ -193,20 +208,29 @@ void network::tcp::decode(network::interface_descriptor& interface, network::eth
|
|||||||
connection.packets.push(copy);
|
connection.packets.push(copy);
|
||||||
connection.queue.notify_one();
|
connection.queue.notify_one();
|
||||||
}
|
}
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
auto seq_number = ack;
|
// Propagate to the kernel socket
|
||||||
auto ack_number = seq + tcp_payload_len(packet);
|
|
||||||
|
|
||||||
//TODO socket.seq_number = ack;
|
if (*flag_psh(&flags) && connection.socket) {
|
||||||
//TODO socket.ack_number = seq + tcp_payload_len(packet);
|
auto& socket = *connection.socket;
|
||||||
|
|
||||||
packet.index += sizeof(header);
|
packet.index += sizeof(header);
|
||||||
|
|
||||||
network::propagate_packet(packet, network::socket_protocol::TCP);
|
if (socket.listen) {
|
||||||
|
auto copy = packet;
|
||||||
|
copy.payload = new char[copy.payload_size];
|
||||||
|
std::copy_n(packet.payload, packet.payload_size, copy.payload);
|
||||||
|
|
||||||
|
socket.listen_packets.push(copy);
|
||||||
|
socket.listen_queue.notify_one();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
logging::logf(logging::log_level::DEBUG, "tcp: Received packet for which there are no connection\n");
|
||||||
|
}
|
||||||
|
|
||||||
|
// Acknowledge if necessary
|
||||||
|
|
||||||
// A push needs to be acknowledged
|
|
||||||
if (*flag_psh(&flags)) {
|
if (*flag_psh(&flags)) {
|
||||||
auto p = tcp::prepare_packet(interface, switch_endian_32(ip_header->source_ip), target_port, source_port, 0);
|
auto p = tcp::prepare_packet(interface, switch_endian_32(ip_header->source_ip), target_port, source_port, 0);
|
||||||
|
|
||||||
@ -217,12 +241,12 @@ void network::tcp::decode(network::interface_descriptor& interface, network::eth
|
|||||||
|
|
||||||
auto* ack_tcp_header = reinterpret_cast<header*>(p->payload + p->tag(2));
|
auto* ack_tcp_header = reinterpret_cast<header*>(p->payload + p->tag(2));
|
||||||
|
|
||||||
ack_tcp_header->sequence_number = switch_endian_32(seq_number);
|
ack_tcp_header->sequence_number = switch_endian_32(next_seq);
|
||||||
ack_tcp_header->ack_number = switch_endian_32(ack_number);
|
ack_tcp_header->ack_number = switch_endian_32(next_ack);
|
||||||
|
|
||||||
auto flags = get_default_flags();
|
auto ack_flags = get_default_flags();
|
||||||
(flag_ack(&flags)) = 1;
|
(flag_ack(&ack_flags)) = 1;
|
||||||
ack_tcp_header->flags = switch_endian_16(flags);
|
ack_tcp_header->flags = switch_endian_16(ack_flags);
|
||||||
|
|
||||||
tcp::finalize_packet(interface, *p);
|
tcp::finalize_packet(interface, *p);
|
||||||
}
|
}
|
||||||
@ -239,15 +263,23 @@ std::expected<network::ethernet::packet> network::tcp::prepare_packet(network::i
|
|||||||
return packet;
|
return packet;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::expected<network::ethernet::packet> network::tcp::prepare_packet(char* buffer, network::interface_descriptor& interface, network::socket& socket, size_t payload_size) {
|
std::expected<network::ethernet::packet> network::tcp::prepare_packet(char* buffer, network::socket& socket, size_t payload_size) {
|
||||||
auto target_ip = socket.server_address;
|
auto& connection = socket.get_data<tcp_connection>();
|
||||||
|
|
||||||
|
// Make sure stream sockets are connected
|
||||||
|
if(!connection.connected){
|
||||||
|
return std::make_unexpected<network::ethernet::packet>(std::ERROR_SOCKET_NOT_CONNECTED);
|
||||||
|
}
|
||||||
|
|
||||||
|
auto target_ip = connection.server_address;
|
||||||
|
auto& interface = network::select_interface(target_ip);
|
||||||
|
|
||||||
// Ask the IP layer to craft a packet
|
// Ask the IP layer to craft a packet
|
||||||
auto packet = network::ip::prepare_packet(buffer, interface, sizeof(header) + payload_size, target_ip, 0x06);
|
auto packet = network::ip::prepare_packet(buffer, interface, sizeof(header) + payload_size, target_ip, 0x06);
|
||||||
|
|
||||||
if (packet) {
|
if (packet) {
|
||||||
auto source = socket.local_port;
|
auto source = connection.local_port;
|
||||||
auto target = socket.server_port;
|
auto target = connection.server_port;
|
||||||
|
|
||||||
::prepare_packet(*packet, source, target);
|
::prepare_packet(*packet, source, target);
|
||||||
|
|
||||||
@ -258,8 +290,8 @@ std::expected<network::ethernet::packet> network::tcp::prepare_packet(char* buff
|
|||||||
(flag_ack(&flags)) = 1;
|
(flag_ack(&flags)) = 1;
|
||||||
tcp_header->flags = switch_endian_16(flags);
|
tcp_header->flags = switch_endian_16(flags);
|
||||||
|
|
||||||
tcp_header->sequence_number = switch_endian_32(socket.seq_number);
|
tcp_header->sequence_number = switch_endian_32(connection.seq_number);
|
||||||
tcp_header->ack_number = switch_endian_32(socket.ack_number);
|
tcp_header->ack_number = switch_endian_32(connection.ack_number);
|
||||||
}
|
}
|
||||||
|
|
||||||
return packet;
|
return packet;
|
||||||
@ -286,18 +318,14 @@ void network::tcp::finalize_packet(network::interface_descriptor& interface, net
|
|||||||
return; //TODO Fail
|
return; //TODO Fail
|
||||||
}
|
}
|
||||||
|
|
||||||
auto source = socket.local_port;
|
auto& connection = socket.get_data<tcp_connection>();
|
||||||
auto target = socket.server_port;
|
|
||||||
|
|
||||||
auto connection_ptr = get_connection(target, source);
|
// Make sure stream sockets are connected
|
||||||
|
if(!connection.connected){
|
||||||
if (!connection_ptr) {
|
//TODO return std::make_unexpected<void>(std::ERROR_SOCKET_NOT_CONNECTED);
|
||||||
logging::logf(logging::log_level::ERROR, "tcp: Unable to find connection!\n");
|
return;
|
||||||
return; //TODO Fail
|
|
||||||
}
|
}
|
||||||
|
|
||||||
auto& connection = *connection_ptr;
|
|
||||||
|
|
||||||
connection.listening = true;
|
connection.listening = true;
|
||||||
|
|
||||||
uint32_t seq = 0;
|
uint32_t seq = 0;
|
||||||
@ -357,22 +385,31 @@ void network::tcp::finalize_packet(network::interface_descriptor& interface, net
|
|||||||
|
|
||||||
if(received){
|
if(received){
|
||||||
// Set the future sequence and acknowledgement numbers
|
// Set the future sequence and acknowledgement numbers
|
||||||
socket.seq_number = ack;
|
connection.seq_number = ack;
|
||||||
socket.ack_number = seq;
|
connection.ack_number = seq;
|
||||||
} else {
|
} else {
|
||||||
//TODO We need to be able to make finalize fail!
|
//TODO We need to be able to make finalize fail!
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
std::expected<void> network::tcp::connect(network::socket& sock, network::interface_descriptor& interface) {
|
std::expected<void> network::tcp::connect(network::socket& sock, network::interface_descriptor& interface, size_t local_port, size_t server_port, network::ip::address server) {
|
||||||
auto target_ip = sock.server_address;
|
// Create the connection
|
||||||
auto source = sock.local_port;
|
|
||||||
auto target = sock.server_port;
|
|
||||||
|
|
||||||
sock.seq_number = 0;
|
auto& connection = create_connection();
|
||||||
sock.ack_number = 0;
|
|
||||||
|
|
||||||
auto packet = tcp::prepare_packet(interface, target_ip, source, target, 0);
|
connection.local_port = local_port;
|
||||||
|
connection.server_port = server_port;
|
||||||
|
connection.server_address = server;
|
||||||
|
|
||||||
|
// Link the socket and connection
|
||||||
|
sock.data = &connection;
|
||||||
|
connection.socket = &sock;
|
||||||
|
|
||||||
|
// Prepare the SYN packet
|
||||||
|
|
||||||
|
auto target_ip = connection.server_address;
|
||||||
|
|
||||||
|
auto packet = tcp::prepare_packet(interface, target_ip, local_port, server_port, 0);
|
||||||
|
|
||||||
if (!packet) {
|
if (!packet) {
|
||||||
return std::make_unexpected<void>(packet.error());
|
return std::make_unexpected<void>(packet.error());
|
||||||
@ -380,16 +417,12 @@ std::expected<void> network::tcp::connect(network::socket& sock, network::interf
|
|||||||
|
|
||||||
auto* tcp_header = reinterpret_cast<header*>(packet->payload + packet->tag(2));
|
auto* tcp_header = reinterpret_cast<header*>(packet->payload + packet->tag(2));
|
||||||
|
|
||||||
tcp_header->sequence_number = 0;
|
|
||||||
tcp_header->ack_number = 0;
|
|
||||||
|
|
||||||
auto flags = get_default_flags();
|
auto flags = get_default_flags();
|
||||||
(flag_syn(&flags)) = 1;
|
(flag_syn(&flags)) = 1;
|
||||||
tcp_header->flags = switch_endian_16(flags);
|
tcp_header->flags = switch_endian_16(flags);
|
||||||
|
|
||||||
// Create the connection
|
tcp_header->sequence_number = connection.seq_number;
|
||||||
|
tcp_header->ack_number = connection.ack_number;
|
||||||
auto& connection = create_connection(target, source);
|
|
||||||
|
|
||||||
connection.listening = true;
|
connection.listening = true;
|
||||||
|
|
||||||
@ -425,13 +458,13 @@ std::expected<void> network::tcp::connect(network::socket& sock, network::interf
|
|||||||
|
|
||||||
connection.listening = false;
|
connection.listening = false;
|
||||||
|
|
||||||
sock.seq_number = ack;
|
connection.seq_number = ack;
|
||||||
sock.ack_number = seq + 1;
|
connection.ack_number = seq + 1;
|
||||||
|
|
||||||
// At this point we have received the SYN/ACK, only remains to ACK
|
// At this point we have received the SYN/ACK, only remains to ACK
|
||||||
|
|
||||||
{
|
{
|
||||||
auto packet = tcp::prepare_packet(interface, target_ip, source, target, 0);
|
auto packet = tcp::prepare_packet(interface, target_ip, connection.local_port, connection.server_port, 0);
|
||||||
|
|
||||||
if (!packet) {
|
if (!packet) {
|
||||||
return std::make_unexpected<void>(packet.error());
|
return std::make_unexpected<void>(packet.error());
|
||||||
@ -439,8 +472,8 @@ std::expected<void> network::tcp::connect(network::socket& sock, network::interf
|
|||||||
|
|
||||||
auto* tcp_header = reinterpret_cast<header*>(packet->payload + packet->tag(2));
|
auto* tcp_header = reinterpret_cast<header*>(packet->payload + packet->tag(2));
|
||||||
|
|
||||||
tcp_header->sequence_number = switch_endian_32(sock.seq_number);
|
tcp_header->sequence_number = switch_endian_32(connection.seq_number);
|
||||||
tcp_header->ack_number = switch_endian_32(sock.ack_number);
|
tcp_header->ack_number = switch_endian_32(connection.ack_number);
|
||||||
|
|
||||||
auto flags = get_default_flags();
|
auto flags = get_default_flags();
|
||||||
(flag_ack(&flags)) = 1;
|
(flag_ack(&flags)) = 1;
|
||||||
@ -450,13 +483,25 @@ std::expected<void> network::tcp::connect(network::socket& sock, network::interf
|
|||||||
tcp::finalize_packet(interface, *packet);
|
tcp::finalize_packet(interface, *packet);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Mark the connection as connected
|
||||||
|
|
||||||
|
connection.connected = true;
|
||||||
|
|
||||||
return {};
|
return {};
|
||||||
}
|
}
|
||||||
|
|
||||||
std::expected<void> network::tcp::disconnect(network::socket& sock, network::interface_descriptor& interface) {
|
std::expected<void> network::tcp::disconnect(network::socket& sock) {
|
||||||
auto target_ip = sock.server_address;
|
auto& connection = sock.get_data<tcp_connection>();
|
||||||
auto source = sock.local_port;
|
|
||||||
auto target = sock.server_port;
|
if(!connection.connected){
|
||||||
|
return std::make_unexpected<void>(std::ERROR_SOCKET_NOT_CONNECTED);
|
||||||
|
}
|
||||||
|
|
||||||
|
auto target_ip = connection.server_address;
|
||||||
|
auto source = connection.local_port;
|
||||||
|
auto target = connection.server_port;
|
||||||
|
|
||||||
|
auto& interface = network::select_interface(target_ip);
|
||||||
|
|
||||||
auto packet = tcp::prepare_packet(interface, target_ip, source, target, 0);
|
auto packet = tcp::prepare_packet(interface, target_ip, source, target, 0);
|
||||||
|
|
||||||
@ -466,23 +511,14 @@ std::expected<void> network::tcp::disconnect(network::socket& sock, network::int
|
|||||||
|
|
||||||
auto* tcp_header = reinterpret_cast<header*>(packet->payload + packet->tag(2));
|
auto* tcp_header = reinterpret_cast<header*>(packet->payload + packet->tag(2));
|
||||||
|
|
||||||
tcp_header->sequence_number = switch_endian_32(sock.seq_number);
|
tcp_header->sequence_number = switch_endian_32(connection.seq_number);
|
||||||
tcp_header->ack_number = switch_endian_32(sock.ack_number);
|
tcp_header->ack_number = switch_endian_32(connection.ack_number);
|
||||||
|
|
||||||
auto flags = get_default_flags();
|
auto flags = get_default_flags();
|
||||||
(flag_fin(&flags)) = 1;
|
(flag_fin(&flags)) = 1;
|
||||||
(flag_ack(&flags)) = 1;
|
(flag_ack(&flags)) = 1;
|
||||||
tcp_header->flags = switch_endian_16(flags);
|
tcp_header->flags = switch_endian_16(flags);
|
||||||
|
|
||||||
auto connection_ptr = get_connection(target, source);
|
|
||||||
|
|
||||||
if (!connection_ptr) {
|
|
||||||
logging::logf(logging::log_level::ERROR, "tcp: Unable to find connection!\n");
|
|
||||||
return std::make_unexpected<void>(std::ERROR_SOCKET_INVALID_CONNECTION);
|
|
||||||
}
|
|
||||||
|
|
||||||
auto& connection = *connection_ptr;
|
|
||||||
|
|
||||||
connection.listening = true;
|
connection.listening = true;
|
||||||
|
|
||||||
logging::logf(logging::log_level::TRACE, "tcp: Send FIN/ACK\n");
|
logging::logf(logging::log_level::TRACE, "tcp: Send FIN/ACK\n");
|
||||||
@ -529,8 +565,8 @@ std::expected<void> network::tcp::disconnect(network::socket& sock, network::int
|
|||||||
|
|
||||||
connection.listening = false;
|
connection.listening = false;
|
||||||
|
|
||||||
sock.seq_number = ack;
|
connection.seq_number = ack;
|
||||||
sock.ack_number = seq + 1;
|
connection.ack_number = seq + 1;
|
||||||
|
|
||||||
// At this point we have received the FIN/ACK, only remains to ACK
|
// At this point we have received the FIN/ACK, only remains to ACK
|
||||||
|
|
||||||
@ -543,8 +579,8 @@ std::expected<void> network::tcp::disconnect(network::socket& sock, network::int
|
|||||||
|
|
||||||
auto* tcp_header = reinterpret_cast<header*>(packet->payload + packet->tag(2));
|
auto* tcp_header = reinterpret_cast<header*>(packet->payload + packet->tag(2));
|
||||||
|
|
||||||
tcp_header->sequence_number = switch_endian_32(sock.seq_number);
|
tcp_header->sequence_number = switch_endian_32(connection.seq_number);
|
||||||
tcp_header->ack_number = switch_endian_32(sock.ack_number);
|
tcp_header->ack_number = switch_endian_32(connection.ack_number);
|
||||||
|
|
||||||
auto flags = get_default_flags();
|
auto flags = get_default_flags();
|
||||||
(flag_ack(&flags)) = 1;
|
(flag_ack(&flags)) = 1;
|
||||||
@ -554,6 +590,10 @@ std::expected<void> network::tcp::disconnect(network::socket& sock, network::int
|
|||||||
tcp::finalize_packet(interface, *packet);
|
tcp::finalize_packet(interface, *packet);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Mark the connection as connected
|
||||||
|
|
||||||
|
connection.connected = false;
|
||||||
|
|
||||||
remove_connection(connection);
|
remove_connection(connection);
|
||||||
|
|
||||||
return {};
|
return {};
|
||||||
|
Loading…
x
Reference in New Issue
Block a user