diff --git a/kernel/include/net/network.hpp b/kernel/include/net/network.hpp index a1fe10c2..6c62a3d5 100644 --- a/kernel/include/net/network.hpp +++ b/kernel/include/net/network.hpp @@ -172,6 +172,20 @@ std::expected client_bind(socket_fd_t socket_fd, network::ip::address ad */ std::expected client_bind(socket_fd_t socket_fd, network::ip::address address, size_t port); +/*! + * \brief Bind a socket datagram as a server + * \param socket_fd The file descriptor of the packet + * \return noting or an error code otherwise + */ +std::expected server_bind(socket_fd_t socket_fd, network::ip::address address); + +/*! + * \brief Bind a socket datagram as a server + * \param socket_fd The file descriptor of the packet + * \return noting or an error code otherwise + */ +std::expected server_bind(socket_fd_t socket_fd, network::ip::address address, size_t port); + /*! * \brief Unbind a socket datagram as a client * \param socket_fd The file descriptor of the packet diff --git a/kernel/include/net/udp_layer.hpp b/kernel/include/net/udp_layer.hpp index f8db3e83..059a7270 100644 --- a/kernel/include/net/udp_layer.hpp +++ b/kernel/include/net/udp_layer.hpp @@ -59,6 +59,8 @@ std::expected user_prepare_packet(char* buffer, netwo std::expected finalize_packet(network::interface_descriptor& interface, network::ethernet::packet& p); std::expected client_bind(network::socket& socket, size_t server_port, network::ip::address server); +std::expected server_bind(network::socket& socket, size_t server_port, network::ip::address server); + std::expected client_unbind(network::socket& socket); std::expected receive(char* buffer, network::socket& socket, size_t n); diff --git a/kernel/src/net/network.cpp b/kernel/src/net/network.cpp index 0e42aebc..431317ca 100644 --- a/kernel/src/net/network.cpp +++ b/kernel/src/net/network.cpp @@ -142,6 +142,16 @@ void sysfs_publish(network::interface_descriptor& interface){ } } +size_t datagram_port(network::socket_protocol protocol){ + switch(protocol){ + case network::socket_protocol::DNS: + return 53; + + default: + return 0; + } +} + network::socket_protocol datagram_protocol(network::socket_protocol protocol){ switch(protocol){ case network::socket_protocol::DNS: @@ -522,9 +532,15 @@ std::expected network::client_bind(socket_fd_t socket_fd, network::ip::a return std::make_unexpected(std::ERROR_SOCKET_INVALID_TYPE); } + auto port = datagram_port(socket.protocol); + + if(!port){ + return std::make_unexpected(std::ERROR_SOCKET_INVALID_TYPE_PROTOCOL); + } + switch(datagram_protocol(socket.protocol)){ case socket_protocol::UDP: - return network::udp::client_bind(socket, /* TODO PORT */ 53, address); + return network::udp::client_bind(socket, port, address); default: return std::make_unexpected(std::ERROR_SOCKET_INVALID_TYPE_PROTOCOL); @@ -551,6 +567,52 @@ std::expected network::client_bind(socket_fd_t socket_fd, network::ip::a } } +std::expected network::server_bind(socket_fd_t socket_fd, network::ip::address address){ + if(!scheduler::has_socket(socket_fd)){ + return std::make_unexpected(std::ERROR_SOCKET_INVALID_FD); + } + + auto& socket = scheduler::get_socket(socket_fd); + + if(socket.type != socket_type::DGRAM){ + return std::make_unexpected(std::ERROR_SOCKET_INVALID_TYPE); + } + + auto port = datagram_port(socket.protocol); + + if(!port){ + return std::make_unexpected(std::ERROR_SOCKET_INVALID_TYPE_PROTOCOL); + } + + switch(datagram_protocol(socket.protocol)){ + case socket_protocol::UDP: + return network::udp::server_bind(socket, port, address); + + default: + return std::make_unexpected(std::ERROR_SOCKET_INVALID_TYPE_PROTOCOL); + } +} + +std::expected network::server_bind(socket_fd_t socket_fd, network::ip::address address, size_t port){ + if(!scheduler::has_socket(socket_fd)){ + return std::make_unexpected(std::ERROR_SOCKET_INVALID_FD); + } + + auto& socket = scheduler::get_socket(socket_fd); + + if(socket.type != socket_type::DGRAM){ + return std::make_unexpected(std::ERROR_SOCKET_INVALID_TYPE); + } + + switch(datagram_protocol(socket.protocol)){ + case socket_protocol::UDP: + return network::udp::server_bind(socket, port, address); + + default: + return std::make_unexpected(std::ERROR_SOCKET_INVALID_TYPE_PROTOCOL); + } +} + std::expected network::client_unbind(socket_fd_t socket_fd){ if(!scheduler::has_socket(socket_fd)){ return std::make_unexpected(std::ERROR_SOCKET_INVALID_FD); diff --git a/kernel/src/net/udp_layer.cpp b/kernel/src/net/udp_layer.cpp index 11fe4ae6..fce66f95 100644 --- a/kernel/src/net/udp_layer.cpp +++ b/kernel/src/net/udp_layer.cpp @@ -26,6 +26,7 @@ struct udp_connection { network::ip::address server_address; ///< The server address bool connected = false; + bool server = false; network::socket* socket = nullptr; }; @@ -169,6 +170,26 @@ std::expected network::udp::client_bind(network::socket& sock, size_t se return {connection.local_port}; } +std::expected network::udp::server_bind(network::socket& sock, size_t server_port, network::ip::address server){ + // Create the connection + + auto& connection = connections.create_connection(); + + connection.server_port = server_port; + connection.server_address = server; + connection.server = true; + + // Link the socket and connection + sock.connection_data = &connection; + connection.socket = &sock; + + // Mark the connection as connected + + connection.connected = true; + + return {}; +} + std::expected network::udp::client_unbind(network::socket& sock){ auto& connection = sock.get_connection_data(); diff --git a/kernel/src/system_calls.cpp b/kernel/src/system_calls.cpp index 42abef01..f8194ff1 100644 --- a/kernel/src/system_calls.cpp +++ b/kernel/src/system_calls.cpp @@ -436,6 +436,23 @@ void sc_client_bind_port(interrupt::syscall_regs* regs){ regs->rax = expected_to_i64(status); } +void sc_server_bind(interrupt::syscall_regs* regs) { + auto socket_fd = regs->rbx; + auto local_ip = regs->rcx; + + auto status = network::server_bind(socket_fd, local_ip); + regs->rax = expected_to_i64(status); +} + +void sc_server_bind_port(interrupt::syscall_regs* regs) { + auto socket_fd = regs->rbx; + auto local_ip = regs->rcx; + auto port = regs->rdx; + + auto status = network::server_bind(socket_fd, local_ip, port); + regs->rax = expected_to_i64(status); +} + void sc_client_unbind(interrupt::syscall_regs* regs){ auto socket_fd = regs->rbx; @@ -731,6 +748,14 @@ void system_call_entry(interrupt::syscall_regs* regs){ sc_client_bind_port(regs); break; + case 0x300E: + sc_server_bind(regs); + break; + + case 0x300F: + sc_server_bind_port(regs); + break; + // Special system calls case 0x6666: diff --git a/programs/nc/src/main.cpp b/programs/nc/src/main.cpp index accf6740..8c84db45 100644 --- a/programs/nc/src/main.cpp +++ b/programs/nc/src/main.cpp @@ -164,11 +164,63 @@ int netcat_udp_client(const tlib::ip::address& server, size_t port){ return 0; } -int netcat_udp_server(const tlib::ip::address& local, size_t port){ +int netcat_tcp_server(const tlib::ip::address& local, size_t port){ return 0; } -int netcat_tcp_server(const tlib::ip::address& local, size_t port){ +int netcat_udp_server(const tlib::ip::address& local, size_t port){ + auto ip_str = ip_to_str(local); + tlib::printf("netcat UDP server %s:%u\n", ip_str.c_str(), port); + + tlib::socket sock(tlib::socket_domain::AF_INET, tlib::socket_type::DGRAM, tlib::socket_protocol::UDP); + + sock.server_bind(local, port); + sock.listen(true); + + if (!sock) { + tlib::printf("nc: listen error: %s\n", std::error_message(sock.error())); + return 1; + } + + // Listen for packets from the server + + char message_buffer[2049]; + + auto before = tlib::ms_time(); + auto after = before; + + while (true) { + // Make sure we don't wait for more than the timeout + if (after > before + timeout_ms) { + break; + } + + auto remaining = timeout_ms - (after - before); + + auto size = sock.receive(message_buffer, 2048, remaining); + if (!sock) { + if (sock.error() == std::ERROR_SOCKET_TIMEOUT) { + sock.clear(); + break; + } + + tlib::printf("nc: receive error: %s\n", std::error_message(sock.error())); + return 1; + } else { + message_buffer[size] = '\0'; + tlib::print(message_buffer); + } + + after = tlib::ms_time(); + } + + sock.listen(false); + + if (!sock) { + tlib::printf("nc: listen error: %s\n", std::error_message(sock.error())); + return 1; + } + return 0; } diff --git a/tlib/include/tlib/net.hpp b/tlib/include/tlib/net.hpp index 97aacfc5..aa6a97cc 100644 --- a/tlib/include/tlib/net.hpp +++ b/tlib/include/tlib/net.hpp @@ -130,6 +130,23 @@ std::expected client_bind(size_t socket_fd, tlib::ip::address server); */ std::expected client_bind(size_t socket_fd, tlib::ip::address server, size_t port); +/*! + * \brief Bind a source to the datagram socket + * \param socket_fd The socket file descriptor + * \param server The server address + * \return the local port, or an error + */ +std::expected server_bind(size_t socket_fd, tlib::ip::address local); + +/*! + * \brief Bind a source to the datagram socket + * \param socket_fd The socket file descriptor + * \param local The local address + * \param port The listening port + * \return the local port, or an error + */ +std::expected server_bind(size_t socket_fd, tlib::ip::address local, size_t port); + /*! * \brief Unbind from destination from the datagram socket * \param socket_fd The socket file descriptor @@ -234,6 +251,18 @@ struct socket { */ void client_bind(tlib::ip::address server, size_t port); + /*! + * \brief Bind the socket as a server + * \param local The IP address + */ + void server_bind(tlib::ip::address local); + + /*! + * \brief Bind the socket as a server + * \param local The IP address + */ + void server_bind(tlib::ip::address local, size_t port); + /*! * \brief Unbind the client socket */ diff --git a/tlib/src/net.cpp b/tlib/src/net.cpp index 235d487d..9cfdec4d 100644 --- a/tlib/src/net.cpp +++ b/tlib/src/net.cpp @@ -168,6 +168,34 @@ std::expected tlib::client_bind(size_t socket_fd, tlib::ip::address serv } } +std::expected tlib::server_bind(size_t socket_fd, tlib::ip::address server) { + int64_t code; + asm volatile("mov rax, 0x300E; mov rbx, %[socket]; mov rcx, %[ip]; int 50; mov %[code], rax" + : [code] "=m"(code) + : [socket] "g"(socket_fd), [ip] "g" (size_t(server.raw_address)) + : "rax", "rbx", "rcx"); + + if (code < 0) { + return std::make_unexpected(-code); + } else { + return {}; + } +} + +std::expected tlib::server_bind(size_t socket_fd, tlib::ip::address server, size_t port) { + int64_t code; + asm volatile("mov rax, 0x300F; mov rbx, %[socket]; mov rcx, %[ip]; mov rdx, %[port]; int 50; mov %[code], rax" + : [code] "=m"(code) + : [socket] "g"(socket_fd), [ip] "g" (size_t(server.raw_address)), [port] "g" (port) + : "rax", "rbx", "rcx"); + + if (code < 0) { + return std::make_unexpected(-code); + } else { + return {}; + } +} + std::expected tlib::client_unbind(size_t socket_fd) { int64_t code; asm volatile("mov rax, 0x300A; mov rbx, %[socket]; int 50; mov %[code], rax" @@ -344,6 +372,34 @@ void tlib::socket::client_bind(tlib::ip::address server, size_t port) { } } +void tlib::socket::server_bind(tlib::ip::address server) { + if (!good() || !open()) { + return; + } + + auto status = tlib::server_bind(fd, server); + if (status) { + _bound = true; + } else { + _bound = false; + error_code = status.error(); + } +} + +void tlib::socket::server_bind(tlib::ip::address server, size_t port) { + if (!good() || !open()) { + return; + } + + auto status = tlib::server_bind(fd, server, port); + if (status) { + _bound = true; + } else { + error_code = status.error(); + _bound = false; + } +} + void tlib::socket::client_unbind() { if (!good() || !open()) { return;