More TCP server support

This commit is contained in:
Baptiste Wicht 2016-09-26 17:34:20 +02:00
parent 029af0ba36
commit 56badb6843
No known key found for this signature in database
GPG Key ID: C5566B6C7F884532
9 changed files with 450 additions and 32 deletions

View File

@ -37,6 +37,24 @@ struct connection_handler {
return nullptr;
}
template<typename Functor>
void for_each_connection_for_packet(size_t source_port, size_t target_port, Functor fun){
auto lock = connections_lock.reader_lock();
std::lock_guard<reader_rw_lock> l(lock);
for (auto& connection : connections) {
if(connection.server){
if (connection.server_port == target_port) {
fun(connection);
}
} else {
if (connection.server_port == source_port && connection.local_port == target_port) {
fun(connection);
}
}
}
}
connection_type& create_connection() {
auto lock = connections_lock.writer_lock();
std::lock_guard<writer_rw_lock> l(lock);

View File

@ -232,6 +232,21 @@ std::expected<size_t> connect(socket_fd_t socket_fd, network::ip::address addres
*/
std::expected<void> server_start(socket_fd_t socket_fd, network::ip::address address, size_t port);
/*!
* \brief Wait for a connection
* \param socket_fd The file descriptor of the packet
* \return the allocated port on success and a negative error code otherwise
*/
std::expected<size_t> accept(socket_fd_t socket_fd);
/*!
* \brief Wait for a connection
* \param socket_fd The file descriptor of the packet
* \param ms The timeout
* \return the allocated port on success and a negative error code otherwise
*/
std::expected<size_t> accept(socket_fd_t socket_fd, size_t ms);
/*!
* \brief Disconnect from a socket stream
* \param socket_fd The file descriptor of the packet

View File

@ -56,6 +56,8 @@ std::expected<size_t> receive(char* buffer, network::socket& socket, size_t n);
std::expected<size_t> receive(char* buffer, network::socket& socket, size_t n, size_t ms);
std::expected<size_t> connect(network::socket& socket, network::interface_descriptor& interface, size_t server_port, network::ip::address server);
std::expected<size_t> accept(network::socket& socket);
std::expected<size_t> accept(network::socket& socket, size_t ms);
std::expected<void> server_start(network::socket& socket, size_t server_port, network::ip::address server);
std::expected<void> disconnect(network::socket& socket);

View File

@ -783,6 +783,46 @@ std::expected<void> network::server_start(socket_fd_t socket_fd, network::ip::ad
}
}
std::expected<size_t> network::accept(socket_fd_t socket_fd){
if(!scheduler::has_socket(socket_fd)){
return std::make_unexpected<size_t>(std::ERROR_SOCKET_INVALID_FD);
}
auto& socket = scheduler::get_socket(socket_fd);
if(socket.type != socket_type::STREAM){
return std::make_unexpected<size_t>(std::ERROR_SOCKET_INVALID_TYPE);
}
switch(stream_protocol(socket.protocol)){
case socket_protocol::TCP:
return network::tcp::accept(socket);
default:
return std::make_unexpected<size_t>(std::ERROR_SOCKET_INVALID_TYPE_PROTOCOL);
}
}
std::expected<size_t> network::accept(socket_fd_t socket_fd, size_t ms){
if(!scheduler::has_socket(socket_fd)){
return std::make_unexpected<size_t>(std::ERROR_SOCKET_INVALID_FD);
}
auto& socket = scheduler::get_socket(socket_fd);
if(socket.type != socket_type::STREAM){
return std::make_unexpected<size_t>(std::ERROR_SOCKET_INVALID_TYPE);
}
switch(stream_protocol(socket.protocol)){
case socket_protocol::TCP:
return network::tcp::accept(socket, ms);
default:
return std::make_unexpected<size_t>(std::ERROR_SOCKET_INVALID_TYPE_PROTOCOL);
}
}
std::expected<void> network::disconnect(socket_fd_t socket_fd){
if(!scheduler::has_socket(socket_fd)){
return std::make_unexpected<void>(std::ERROR_SOCKET_INVALID_FD);

View File

@ -192,31 +192,27 @@ void network::tcp::decode(network::interface_descriptor& interface, network::eth
auto* ip_header = reinterpret_cast<network::ip::header*>(packet.payload + packet.tag(1));
auto* tcp_header = reinterpret_cast<network::tcp::header*>(packet.payload + packet.index);
logging::logf(logging::log_level::TRACE, "tcp: Start TCP packet handling\n");
logging::logf(logging::log_level::TRACE, "tcp:decode: Start TCP packet handling\n");
auto source_port = switch_endian_16(tcp_header->source_port);
auto target_port = switch_endian_16(tcp_header->target_port);
auto seq = switch_endian_32(tcp_header->sequence_number);
auto ack = switch_endian_32(tcp_header->ack_number);
logging::logf(logging::log_level::TRACE, "tcp: Source Port %u \n", size_t(source_port));
logging::logf(logging::log_level::TRACE, "tcp: Target Port %u \n", size_t(target_port));
logging::logf(logging::log_level::TRACE, "tcp: Seq Number %u \n", size_t(seq));
logging::logf(logging::log_level::TRACE, "tcp: Ack Number %u \n", size_t(ack));
logging::logf(logging::log_level::TRACE, "tcp:decode: Source Port %u \n", size_t(source_port));
logging::logf(logging::log_level::TRACE, "tcp:decode: Target Port %u \n", size_t(target_port));
logging::logf(logging::log_level::TRACE, "tcp:decode: Seq Number %u \n", size_t(seq));
logging::logf(logging::log_level::TRACE, "tcp:decode: Ack Number %u \n", size_t(ack));
auto flags = switch_endian_16(tcp_header->flags);
auto next_seq = ack;
auto next_ack = seq + tcp_payload_len(packet);
logging::logf(logging::log_level::TRACE, "tcp: Next Seq Number %u \n", size_t(next_seq));
logging::logf(logging::log_level::TRACE, "tcp: Next Ack Number %u \n", size_t(next_ack));
auto connection_ptr = connections.get_connection_for_packet(source_port, target_port);
if(connection_ptr){
auto& connection = *connection_ptr;
logging::logf(logging::log_level::TRACE, "tcp:decode: Next Seq Number %u \n", size_t(next_seq));
logging::logf(logging::log_level::TRACE, "tcp:decode: Next Ack Number %u \n", size_t(next_ack));
connections.for_each_connection_for_packet(source_port, target_port, [&](tcp_connection& connection) {
// Update the connection status
connection.seq_number = next_seq;
@ -249,9 +245,7 @@ void network::tcp::decode(network::interface_descriptor& interface, network::eth
socket.listen_queue.notify_one();
}
}
} else {
logging::logf(logging::log_level::DEBUG, "tcp: Received packet for which there are no connection\n");
}
});
// Acknowledge if necessary
@ -259,7 +253,7 @@ void network::tcp::decode(network::interface_descriptor& interface, network::eth
auto p = kernel_prepare_packet(interface, switch_endian_32(ip_header->source_ip), target_port, source_port, 0);
if (!p) {
logging::logf(logging::log_level::ERROR, "tcp: Impossible to prepare TCP packet for ACK\n");
logging::logf(logging::log_level::ERROR, "tcp:decode: Impossible to prepare TCP packet for ACK\n");
return;
}
@ -284,7 +278,7 @@ std::expected<void> network::tcp::send(char* target_buffer, network::socket& soc
return std::make_unexpected<void>(std::ERROR_SOCKET_NOT_CONNECTED);
}
logging::logf(logging::log_level::ERROR, "tcp: Send %s(%u)\n", buffer, n);
logging::logf(logging::log_level::TRACE, "tcp:send: Send %s(%u)\n", buffer, n);
network::tcp::packet_descriptor desc{n};
auto packet = user_prepare_packet(target_buffer, socket, &desc);
@ -462,9 +456,14 @@ std::expected<void> network::tcp::finalize_packet(network::interface_descriptor&
auto flags = switch_endian_16(tcp_header->flags);
bool correct_ack = false;
if(*flag_syn(&source_flags)){
if(*flag_syn(&source_flags) && *flag_ack(&source_flags)){
// SYN/ACK should be acknowledge with ACK
correct_ack = *flag_ack(&flags);
} else if(*flag_syn(&source_flags)){
// SYN should be acknowledge with SYN/ACK
correct_ack = *flag_syn(&flags) && *flag_ack(&flags);
} else {
// Other packets should be acknowledge with ACK
correct_ack = *flag_ack(&flags);
}
@ -472,7 +471,7 @@ std::expected<void> network::tcp::finalize_packet(network::interface_descriptor&
//the sent packet
if (correct_ack) {
logging::logf(logging::log_level::TRACE, "tcp: Received ACK\n");
logging::logf(logging::log_level::TRACE, "tcp:finalize: Received ACK\n");
delete[] received_packet.payload;
@ -480,7 +479,7 @@ std::expected<void> network::tcp::finalize_packet(network::interface_descriptor&
break;
} else {
logging::logf(logging::log_level::TRACE, "tcp: Received unrelated answer\n");
logging::logf(logging::log_level::TRACE, "tcp:finalize: Received unrelated answer\n");
}
delete[] received_packet.payload;
@ -512,6 +511,8 @@ std::expected<void> network::tcp::finalize_packet(network::interface_descriptor&
}
std::expected<size_t> network::tcp::connect(network::socket& sock, network::interface_descriptor& interface, size_t server_port, network::ip::address server) {
logging::logf(logging::log_level::TRACE, "tcp:connect: Start\n");
// Create the connection
auto& connection = connections.create_connection();
@ -538,7 +539,7 @@ std::expected<size_t> network::tcp::connect(network::socket& sock, network::inte
(flag_syn(&flags)) = 1;
tcp_header->flags = switch_endian_16(flags);
logging::logf(logging::log_level::TRACE, "tcp: Send SYN\n");
logging::logf(logging::log_level::TRACE, "tcp:connect: Send SYN\n");
auto status = tcp::finalize_packet(interface, sock, *packet);
@ -548,6 +549,8 @@ std::expected<size_t> network::tcp::connect(network::socket& sock, network::inte
// The SYN/ACK is ensured by finalize_packet
logging::logf(logging::log_level::TRACE, "tcp:connect: Received SYN/ACK\n");
// At this point we have received the SYN/ACK, only remains to ACK
{
@ -563,7 +566,8 @@ std::expected<size_t> network::tcp::connect(network::socket& sock, network::inte
(flag_ack(&flags)) = 1;
tcp_header->flags = switch_endian_16(flags);
logging::logf(logging::log_level::TRACE, "tcp: Send ACK\n");
logging::logf(logging::log_level::TRACE, "tcp:connect: Send ACK\n");
finalize_packet_direct(interface, *packet);
}
@ -571,9 +575,125 @@ std::expected<size_t> network::tcp::connect(network::socket& sock, network::inte
connection.connected = true;
logging::logf(logging::log_level::TRACE, "tcp:connect: Done\n");
return connection.local_port;
}
std::expected<size_t> network::tcp::accept(network::socket& socket){
auto& connection = socket.get_connection_data<tcp_connection>();
if(!connection.connected){
return std::make_unexpected<size_t>(std::ERROR_SOCKET_NOT_CONNECTED);
}
// 1. Wait for SYN
connection.listening = true;
logging::logf(logging::log_level::TRACE, "tcp:accept: wait for connection\n");
uint32_t ack = 0;
uint32_t seq = 0;
uint16_t source_port = 0;
uint16_t target_port = 0;
uint32_t source_address = 0;
while (true) {
if(connection.packets.empty()){
connection.queue.wait();
}
auto received_packet = connection.packets.pop();
auto* tcp_header = reinterpret_cast<header*>(received_packet.payload + received_packet.index);
auto flags = switch_endian_16(tcp_header->flags);
if (*flag_syn(&flags)) {
seq = switch_endian_32(tcp_header->sequence_number);
ack = switch_endian_32(tcp_header->ack_number);
source_port = switch_endian_16(tcp_header->source_port);
target_port = switch_endian_16(tcp_header->target_port);
auto* ip_header = reinterpret_cast<network::ip::header*>(received_packet.payload + received_packet.tag(1));
source_address = ip_header->source_ip;
delete[] received_packet.payload;
break;
}
delete[] received_packet.payload;
}
logging::logf(logging::log_level::TRACE, "tcp:accept: received SYN\n");
connection.listening = false;
// Set the future sequence and acknowledgement numbers
connection.seq_number = ack;
connection.ack_number = seq + 1;
// 2. Prepare the child connection
auto child_fd = scheduler::register_new_socket(socket.domain, socket.type, socket.protocol);
auto& child_sock = scheduler::get_socket(child_fd);
// Create the connection
auto& child_connection = connections.create_connection();
child_connection.local_port = target_port;
child_connection.server_port = source_port;
child_connection.server_address = source_address;
// Link the socket and connection
child_sock.connection_data = &child_connection;
child_connection.socket = &child_sock;
child_connection.connected = true;
auto& interface = network::select_interface(source_address);
// 3. Send SYN/ACK
{
auto packet = kernel_prepare_packet(interface, child_connection, 0);
if (!packet) {
return std::make_unexpected<size_t>(packet.error());
}
auto* tcp_header = reinterpret_cast<header*>(packet->payload + packet->tag(2));
auto flags = get_default_flags();
(flag_syn(&flags)) = 1;
(flag_ack(&flags)) = 1;
tcp_header->flags = switch_endian_16(flags);
logging::logf(logging::log_level::TRACE, "tcp:accept: Send SYN/ACK\n");
auto status = tcp::finalize_packet(interface, child_sock, *packet);
if(!status){
return std::make_unexpected<size_t, size_t>(status.error());
}
}
// The ACK is enforced by finalize_packet
logging::logf(logging::log_level::TRACE, "tcp:accept: Done\n");
return {child_fd};
}
std::expected<size_t> network::tcp::accept(network::socket& socket, size_t ms){
}
std::expected<void> network::tcp::server_start(network::socket& sock, size_t server_port, network::ip::address server) {
// Create the connection
@ -595,7 +715,7 @@ std::expected<void> network::tcp::server_start(network::socket& sock, size_t ser
}
std::expected<void> network::tcp::disconnect(network::socket& sock) {
logging::logf(logging::log_level::TRACE, "tcp: Disconnect\n");
logging::logf(logging::log_level::TRACE, "tcp:disconnect: Disconnect\n");
auto& connection = sock.get_connection_data<tcp_connection>();
@ -621,7 +741,7 @@ std::expected<void> network::tcp::disconnect(network::socket& sock) {
connection.listening = true;
logging::logf(logging::log_level::TRACE, "tcp: Send FIN/ACK\n");
logging::logf(logging::log_level::TRACE, "tcp:disconnect: Send FIN/ACK\n");
bool rec_fin_ack = false;
bool rec_ack = false;
@ -711,7 +831,7 @@ std::expected<void> network::tcp::disconnect(network::socket& sock) {
// If we received an ACK, we must wait for a FIN/ACK from the server now
if(rec_ack){
logging::logf(logging::log_level::TRACE, "tcp: Received ACK waiting for FIN/ACK\n");
logging::logf(logging::log_level::TRACE, "tcp:disconnect: Received ACK waiting for FIN/ACK\n");
received = false;
@ -761,9 +881,9 @@ std::expected<void> network::tcp::disconnect(network::socket& sock) {
connection.seq_number = ack;
connection.ack_number = seq + 1;
logging::logf(logging::log_level::TRACE, "tcp: Received FIN/ACK waiting for ACK\n");
logging::logf(logging::log_level::TRACE, "tcp:disconnect: Received FIN/ACK waiting for ACK\n");
} else if(rec_fin_ack) {
logging::logf(logging::log_level::TRACE, "tcp: Received FIN/ACK directly waiting for ACK\n");
logging::logf(logging::log_level::TRACE, "tcp:disconnect: Received FIN/ACK directly waiting for ACK\n");
}
// Stop listening

View File

@ -506,13 +506,28 @@ void sc_connect(interrupt::syscall_regs* regs){
regs->rax = expected_to_i64(status);
}
void sc_server_start(interrupt::syscall_regs* regs){
void sc_server_start(interrupt::syscall_regs* regs) {
auto socket_fd = regs->rbx;
auto ip = regs->rcx;
auto port = regs->rdx;
auto ip = regs->rcx;
auto port = regs->rdx;
auto status = network::server_start(socket_fd, ip, port);
regs->rax = expected_to_i64(status);
regs->rax = expected_to_i64(status);
}
void sc_accept(interrupt::syscall_regs* regs) {
auto socket_fd = regs->rbx;
auto status = network::accept(socket_fd);
regs->rax = expected_to_i64(status);
}
void sc_accept_timeout(interrupt::syscall_regs* regs) {
auto socket_fd = regs->rbx;
auto ms = regs->rcx;
auto status = network::accept(socket_fd, ms);
regs->rax = expected_to_i64(status);
}
void sc_dns_server(interrupt::syscall_regs* regs){
@ -830,6 +845,14 @@ void system_call_entry(interrupt::syscall_regs* regs){
sc_dns_server(regs);
break;
case 0x3016:
sc_accept(regs);
break;
case 0x3017:
sc_accept_timeout(regs);
break;
// Special system calls
case 0x6666:

View File

@ -179,6 +179,79 @@ int netcat_tcp_server(const tlib::ip::address& local, size_t port){
return 1;
}
auto child = sock.accept();
if (!sock) {
tlib::printf("nc: accept error: %s\n", std::error_message(sock.error()));
return 1;
}
if (!child) {
tlib::printf("nc: accept error: %s\n", std::error_message(sock.error()));
return 1;
}
child.listen(true);
tlib::printf("nc: Received connection\n");
if (!child) {
tlib::printf("nc: listen error: %s\n", std::error_message(sock.error()));
return 1;
}
// Listen for packets from the client
char message_buffer[2049];
auto before = tlib::ms_time();
auto after = before;
while (true) {
// Make sure we don't wait for more than the timeout
if (after > before + server_timeout_ms) {
break;
}
auto remaining = server_timeout_ms - (after - before);
tlib::printf("nc: Wait for message\n");
auto size = child.receive(message_buffer, 2048, remaining);
if (!sock) {
if (sock.error() == std::ERROR_SOCKET_TIMEOUT) {
sock.clear();
break;
}
tlib::printf("nc: receive error: %s\n", std::error_message(sock.error()));
return 1;
} else {
message_buffer[size] = '\0';
tlib::print(message_buffer);
tlib::printf("nc: Send response\n");
sock.send(message_buffer, size);
if (!sock) {
tlib::printf("nc: send error: %s\n", std::error_message(sock.error()));
return 1;
}
}
after = tlib::ms_time();
}
tlib::printf("nc: done... disconnecting\n");
child.listen(false);
if (!child) {
tlib::printf("nc: listen error: %s\n", std::error_message(sock.error()));
return 1;
}
return 0;
}
@ -196,7 +269,7 @@ int netcat_udp_server(const tlib::ip::address& local, size_t port){
return 1;
}
// Listen for packets from the server
// Listen for packets from the client
char message_buffer[2049];

View File

@ -197,6 +197,20 @@ std::expected<size_t> connect(size_t socket_fd, tlib::ip::address server, size_t
*/
std::expected<void> server_start(size_t socket_fd, tlib::ip::address server, size_t port);
/*!
* \brief Wait for a incoming connection
* \param socket_fd The socket file descriptor
* \return a socket of the incoming connection
*/
std::expected<size_t> accept(size_t socket_fd);
/*!
* \brief Wait for a incoming connection
* \param socket_fd The socket file descriptor
* \return a socket of the incoming connection
*/
std::expected<size_t> accept(size_t socket_fd, size_t ms);
/*!
* \brief Disconnect from destination from the datagram socket
* \param socket_fd The socket file descriptor
@ -225,8 +239,12 @@ std::expected<packet> wait_for_packet(size_t socket_fd, size_t ms);
* This is easier to use than the free functions for sockets.
*/
struct socket {
socket();
socket(socket_domain domain, socket_type type, socket_protocol protocol);
socket(socket&& rhs);
socket& operator=(socket&& rhs);
/*!
* \brief Destruct the socket and release all acquired connections
*/
@ -316,6 +334,20 @@ struct socket {
*/
void server_start(tlib::ip::address server, size_t port);
/*!
* \brief Wait for a incoming connection
* \param socket_fd The socket file descriptor
* \return a socket of the incoming connection
*/
socket accept();
/*!
* \brief Wait for a incoming connection
* \param socket_fd The socket file descriptor
* \return a socket of the incoming connection
*/
socket accept(size_t ms);
/*!
* \brief Disconnnect from the server (stream socket)
*/

View File

@ -298,6 +298,34 @@ std::expected<void> tlib::server_start(size_t socket_fd, tlib::ip::address serve
}
}
std::expected<size_t> tlib::accept(size_t socket_fd) {
int64_t code;
asm volatile("mov rax, 0x3016; mov rbx, %[socket]; int 50; mov %[code], rax"
: [code] "=m"(code)
: [socket] "g"(socket_fd)
: "rax", "rbx");
if (code < 0) {
return std::make_unexpected<size_t, size_t>(-code);
} else {
return code;
}
}
std::expected<size_t> tlib::accept(size_t socket_fd, size_t ms) {
int64_t code;
asm volatile("mov rax, 0x3017; mov rbx, %[socket]; mov rcx, %[ms]; int 50; mov %[code], rax"
: [code] "=m"(code)
: [socket] "g"(socket_fd), [ms] "g"(ms)
: "rax", "rbx", "rcx");
if (code < 0) {
return std::make_unexpected<size_t, size_t>(-code);
} else {
return code;
}
}
std::expected<void> tlib::disconnect(size_t socket_fd) {
int64_t code;
asm volatile("mov rax, 0x3009; mov rbx, %[socket]; int 50; mov %[code], rax"
@ -354,6 +382,10 @@ std::expected<tlib::packet> tlib::wait_for_packet(size_t socket_fd, size_t ms) {
}
}
tlib::socket::socket() : fd(0), error_code(0) {
// Nothing else to init
}
tlib::socket::socket(socket_domain domain, socket_type type, socket_protocol protocol)
: domain(domain), type(type), protocol(protocol), fd(0), error_code(0) {
auto open_status = tlib::socket_open(domain, type, protocol);
@ -365,6 +397,27 @@ tlib::socket::socket(socket_domain domain, socket_type type, socket_protocol pro
}
}
tlib::socket::socket(tlib::socket&& rhs)
: domain(rhs.domain), type(rhs.type), protocol(rhs.protocol), fd(rhs.fd), error_code(rhs.error_code), _connected(rhs._connected), _bound(rhs._bound) {
rhs.fd = 0;
}
tlib::socket& tlib::socket::operator=(tlib::socket&& rhs){
if(this != &rhs){
this->domain = rhs.domain;
this->type = rhs.type;
this->protocol = rhs.protocol;
this->fd = rhs.fd;
this->error_code = rhs.error_code;
this->_connected = rhs._connected;
this->_bound = rhs._bound;
rhs.fd = 0;
}
return *this;
}
tlib::socket::~socket() {
if (connected()) {
disconnect();
@ -520,6 +573,48 @@ void tlib::socket::server_start(tlib::ip::address server, size_t port) {
}
}
tlib::socket tlib::socket::accept() {
if (!good() || !open()) {
return {};
}
auto status = tlib::accept(fd);
if (status) {
tlib::socket sock;
sock.fd = *status;
sock.domain = domain;
sock.type = type;
sock.protocol = protocol;
sock._connected = true;
sock._bound = false;
sock.error_code = 0;
return std::move(sock);
} else {
error_code = status.error();
}
return {};
}
tlib::socket tlib::socket::accept(size_t ms) {
if (!good() || !open()) {
return {};
}
auto status = tlib::accept(fd, ms);
if (status) {
//TODO
} else {
error_code = status.error();
}
return {};
}
void tlib::socket::disconnect() {
if (!good() || !open()) {
return;