diff --git a/kernel/include/net/network.hpp b/kernel/include/net/network.hpp index ccf78b54..6fb49fa9 100644 --- a/kernel/include/net/network.hpp +++ b/kernel/include/net/network.hpp @@ -100,6 +100,13 @@ std::expected finalize_packet(socket_fd_t socket_fd, size_t packet_fd); */ std::expected listen(socket_fd_t socket_fd, bool listen); +/*! + * \brief Bind a socket datagram as a client (bind a local random port) + * \param socket_fd The file descriptor of the packet + * \return the allocated port on success and a negative error code otherwise + */ +std::expected client_bind(socket_fd_t socket_fd); + /*! * \brief Wait for a packet * \param socket_fd The file descriptor of the packet diff --git a/kernel/include/net/socket.hpp b/kernel/include/net/socket.hpp index d9f7b005..f2443350 100644 --- a/kernel/include/net/socket.hpp +++ b/kernel/include/net/socket.hpp @@ -29,6 +29,7 @@ struct socket { socket_protocol protocol; size_t next_fd; bool listen; + size_t local_port; std::vector packets; @@ -37,7 +38,7 @@ struct socket { socket(){} socket(size_t id, socket_domain domain, socket_type type, socket_protocol protocol, size_t next_fd, bool listen) - : id(id), domain(domain), type(type), protocol(protocol), next_fd(next_fd), listen(listen) {} + : id(id), domain(domain), type(type), protocol(protocol), next_fd(next_fd), listen(listen), local_port(0) {} void invalidate(){ id = 0xFFFFFFFF; diff --git a/kernel/src/net/network.cpp b/kernel/src/net/network.cpp index f1d66c31..cc5accdd 100644 --- a/kernel/src/net/network.cpp +++ b/kernel/src/net/network.cpp @@ -28,6 +28,9 @@ namespace { +// TODO need to be atomic! +size_t local_port = 1234; + std::vector interfaces; void rx_thread(void* data){ @@ -211,18 +214,26 @@ network::interface_descriptor& network::interface(size_t index){ } std::expected network::open(network::socket_domain domain, network::socket_type type, network::socket_protocol protocol){ + // Make sure the socket domain is valid if(domain != socket_domain::AF_INET){ return std::make_expected_from_error(std::ERROR_SOCKET_INVALID_DOMAIN); } - if(type != socket_type::RAW){ + // Make sure the socket type is valid + if(type != socket_type::RAW && type != socket_type::DGRAM){ return std::make_expected_from_error(std::ERROR_SOCKET_INVALID_TYPE); } + // Make sure the socket protocol is valid if(protocol != socket_protocol::ICMP && protocol != socket_protocol::DNS){ return std::make_expected_from_error(std::ERROR_SOCKET_INVALID_PROTOCOL); } + // Make sure the socket protocol is valid for the given socket type + if(type == socket_type::DGRAM && !(protocol == socket_protocol::DNS)){ + return std::make_expected_from_error(std::ERROR_SOCKET_INVALID_TYPE_PROTOCOL); + } + return scheduler::register_new_socket(domain, type, protocol); } @@ -322,6 +333,24 @@ std::expected network::listen(socket_fd_t socket_fd, bool listen){ return std::make_expected(); } +std::expected network::client_bind(socket_fd_t socket_fd){ + if(!scheduler::has_socket(socket_fd)){ + return std::make_unexpected(std::ERROR_SOCKET_INVALID_FD); + } + + auto& socket = scheduler::get_socket(socket_fd); + + if(socket.type != socket_type::DGRAM){ + return std::make_unexpected(std::ERROR_SOCKET_INVALID_TYPE); + } + + socket.local_port = local_port++; + + logging::logf(logging::log_level::TRACE, "network: %u socket %u was assigned port %u\n", scheduler::get_pid(), socket_fd, socket.local_port); + + return std::make_expected(socket.local_port); +} + std::expected network::wait_for_packet(char* buffer, socket_fd_t socket_fd){ if(!scheduler::has_socket(socket_fd)){ return std::make_unexpected(std::ERROR_SOCKET_INVALID_FD); diff --git a/kernel/src/system_calls.cpp b/kernel/src/system_calls.cpp index 5c110bb8..b890b1a2 100644 --- a/kernel/src/system_calls.cpp +++ b/kernel/src/system_calls.cpp @@ -417,6 +417,13 @@ void sc_listen(interrupt::syscall_regs* regs){ regs->rax = expected_to_i64(status); } +void sc_client_bind(interrupt::syscall_regs* regs){ + auto socket_fd = regs->rbx; + + auto status = network::client_bind(socket_fd); + regs->rax = expected_to_i64(status); +} + void sc_wait_for_packet(interrupt::syscall_regs* regs){ auto socket_fd = regs->rbx; auto user_buffer = reinterpret_cast(regs->rcx); @@ -681,6 +688,10 @@ void system_call_entry(interrupt::syscall_regs* regs){ sc_wait_for_packet_ms(regs); break; + case 0x3007: + sc_client_bind(regs); + break; + // Special system calls case 0x6666: diff --git a/programs/nslookup/src/main.cpp b/programs/nslookup/src/main.cpp index bc310ca5..c041221c 100644 --- a/programs/nslookup/src/main.cpp +++ b/programs/nslookup/src/main.cpp @@ -37,13 +37,9 @@ int main(int argc, char* argv[]) { std::string domain(argv[1]); - tlib::socket sock(tlib::socket_domain::AF_INET, tlib::socket_type::RAW, tlib::socket_protocol::DNS); - - if (!sock) { - tlib::printf("nslookup: socket error: %s\n", std::error_message(sock.error())); - return 1; - } + tlib::socket sock(tlib::socket_domain::AF_INET, tlib::socket_type::DGRAM, tlib::socket_protocol::DNS); + sock.client_bind(); sock.listen(true); if (!sock) { diff --git a/tlib/include/tlib/errors.hpp b/tlib/include/tlib/errors.hpp index 7c559d20..e657e175 100644 --- a/tlib/include/tlib/errors.hpp +++ b/tlib/include/tlib/errors.hpp @@ -41,6 +41,7 @@ constexpr const size_t ERROR_SOCKET_INVALID_PACKET_FD = 26; constexpr const size_t ERROR_SOCKET_NOT_LISTEN = 27; constexpr const size_t ERROR_SOCKET_TIMEOUT = 28; constexpr const size_t ERROR_SOCKET_INVALID_PACKET_DESCRIPTOR = 29; +constexpr const size_t ERROR_SOCKET_INVALID_TYPE_PROTOCOL = 30; inline const char* error_message(size_t error){ switch(error){ @@ -102,6 +103,8 @@ inline const char* error_message(size_t error){ return "Network timeout"; case ERROR_SOCKET_INVALID_PACKET_DESCRIPTOR: return "The packet descriptor for the packet to send is invalid"; + case ERROR_SOCKET_INVALID_TYPE_PROTOCOL: + return "The socket protocol is not vaild with this type"; default: return "Unknonwn error"; } diff --git a/tlib/include/tlib/net.hpp b/tlib/include/tlib/net.hpp index 5aa30a2f..373c48df 100644 --- a/tlib/include/tlib/net.hpp +++ b/tlib/include/tlib/net.hpp @@ -41,6 +41,7 @@ void socket_close(size_t socket_fd); std::expected prepare_packet(size_t socket_fd, void* desc); std::expected finalize_packet(size_t socket_fd, const packet& p); std::expected listen(size_t socket_fd, bool l); +std::expected client_bind(size_t socket_fd); std::expected wait_for_packet(size_t socket_fd); std::expected wait_for_packet(size_t socket_fd, size_t ms); @@ -77,6 +78,11 @@ struct socket { */ void clear(); + /*! + * \brief Bind the socket as a client + */ + void client_bind(); + void listen(bool l); packet prepare_packet(void* desc); @@ -90,6 +96,7 @@ private: socket_protocol protocol; ///< The socket protocol size_t fd; ///< The socket file descriptor size_t error_code; ///< The error code + size_t local_port; ///< The local port }; inline uint16_t switch_endian_16(uint16_t nb) { diff --git a/tlib/include/tlib/net_constants.hpp b/tlib/include/tlib/net_constants.hpp index 5057d1b0..0ecdb2e1 100644 --- a/tlib/include/tlib/net_constants.hpp +++ b/tlib/include/tlib/net_constants.hpp @@ -113,7 +113,8 @@ enum class socket_domain : size_t { }; enum class socket_type : size_t { - RAW + RAW, + DGRAM }; enum class socket_protocol : size_t { diff --git a/tlib/src/net.cpp b/tlib/src/net.cpp index 5aeab1ff..fef43a4b 100644 --- a/tlib/src/net.cpp +++ b/tlib/src/net.cpp @@ -36,7 +36,7 @@ tlib::packet::~packet() { std::expected tlib::socket_open(socket_domain domain, socket_type type, socket_protocol protocol) { int64_t fd; - asm volatile("mov rax, 0x3000; mov rbx, %[type]; mov rcx, %[type]; mov rdx, %[protocol]; int 50; mov %[fd], rax" + asm volatile("mov rax, 0x3000; mov rbx, %[domain]; mov rcx, %[type]; mov rdx, %[protocol]; int 50; mov %[fd], rax" : [fd] "=m"(fd) : [domain] "g"(static_cast(domain)), [type] "g"(static_cast(type)), [protocol] "g"(static_cast(protocol)) : "rax", "rbx", "rcx", "rdx"); @@ -106,6 +106,20 @@ std::expected tlib::listen(size_t socket_fd, bool l) { } } +std::expected tlib::client_bind(size_t socket_fd) { + int64_t code; + asm volatile("mov rax, 0x3007; mov rbx, %[socket]; int 50; mov %[code], rax" + : [code] "=m"(code) + : [socket] "g"(socket_fd) + : "rax", "rbx", "rcx"); + + if (code < 0) { + return std::make_unexpected(-code); + } else { + return std::make_expected(code); + } +} + std::expected tlib::wait_for_packet(size_t socket_fd) { auto buffer = malloc(2048); @@ -157,6 +171,8 @@ tlib::socket::socket(socket_domain domain, socket_type type, socket_protocol pro } else { error_code = open_status.error(); } + + local_port = 0; } tlib::socket::~socket() { @@ -196,6 +212,19 @@ void tlib::socket::listen(bool l) { } } +void tlib::socket::client_bind() { + if (!good() || !open()) { + return; + } + + auto status = tlib::client_bind(fd); + if (!status) { + error_code = status.error(); + } + + local_port = *status; +} + tlib::packet tlib::socket::prepare_packet(void* desc) { if (!good() || !open()) { return tlib::packet();