From 5be0a1a22b70b7816de5c93505e07426e67bb290 Mon Sep 17 00:00:00 2001 From: Baptiste Wicht Date: Tue, 13 Sep 2016 22:10:08 +0200 Subject: [PATCH] Prepare TCP connect --- kernel/include/net/network.hpp | 9 +++++++++ kernel/src/net/network.cpp | 31 ++++++++++++++++++++++++++++--- kernel/src/system_calls.cpp | 13 +++++++++++++ programs/nc/src/main.cpp | 10 ++++++++++ tlib/include/tlib/net.hpp | 3 +++ tlib/src/net.cpp | 27 +++++++++++++++++++++++++++ 6 files changed, 90 insertions(+), 3 deletions(-) diff --git a/kernel/include/net/network.hpp b/kernel/include/net/network.hpp index 51e5ec34..9611f000 100644 --- a/kernel/include/net/network.hpp +++ b/kernel/include/net/network.hpp @@ -112,6 +112,15 @@ std::expected listen(socket_fd_t socket_fd, bool listen); */ std::expected client_bind(socket_fd_t socket_fd); +/*! + * \brief Bind a socket stream as a client (bind a local random port) + * \param socket_fd The file descriptor of the packet + * \param server The ip address of the server + * \param port The port of the server + * \return the allocated port on success and a negative error code otherwise + */ +std::expected connect(socket_fd_t socket_fd, network::ip::address address, size_t port); + /*! * \brief Wait for a packet * \param socket_fd The file descriptor of the packet diff --git a/kernel/src/net/network.cpp b/kernel/src/net/network.cpp index 6a5d4947..10c8b33c 100644 --- a/kernel/src/net/network.cpp +++ b/kernel/src/net/network.cpp @@ -237,12 +237,12 @@ std::expected network::open(network::socket_domain domain, } // Make sure the socket type is valid - if(type != socket_type::RAW && type != socket_type::DGRAM){ + if(type != socket_type::RAW && type != socket_type::DGRAM && type != socket_type::STREAM){ return std::make_expected_from_error(std::ERROR_SOCKET_INVALID_TYPE); } // Make sure the socket protocol is valid - if(protocol != socket_protocol::ICMP && protocol != socket_protocol::DNS){ + if(protocol != socket_protocol::ICMP && protocol != socket_protocol::DNS && protocol != socket_protocol::TCP){ return std::make_expected_from_error(std::ERROR_SOCKET_INVALID_PROTOCOL); } @@ -251,6 +251,11 @@ std::expected network::open(network::socket_domain domain, return std::make_expected_from_error(std::ERROR_SOCKET_INVALID_TYPE_PROTOCOL); } + // Make sure the socket protocol is valid for the given socket type + if(type == socket_type::STREAM && !(protocol == socket_protocol::TCP)){ + return std::make_expected_from_error(std::ERROR_SOCKET_INVALID_TYPE_PROTOCOL); + } + return scheduler::register_new_socket(domain, type, protocol); } @@ -372,7 +377,27 @@ std::expected network::client_bind(socket_fd_t socket_fd){ socket.local_port = local_port++; - logging::logf(logging::log_level::TRACE, "network: %u socket %u was assigned port %u\n", scheduler::get_pid(), socket_fd, socket.local_port); + logging::logf(logging::log_level::TRACE, "network: %u datagram socket %u was assigned port %u\n", scheduler::get_pid(), socket_fd, socket.local_port); + + return std::make_expected(socket.local_port); +} + +std::expected network::connect(socket_fd_t socket_fd, network::ip::address server, size_t port){ + if(!scheduler::has_socket(socket_fd)){ + return std::make_unexpected(std::ERROR_SOCKET_INVALID_FD); + } + + auto& socket = scheduler::get_socket(socket_fd); + + if(socket.type != socket_type::STREAM){ + return std::make_unexpected(std::ERROR_SOCKET_INVALID_TYPE); + } + + socket.local_port = local_port++; + + logging::logf(logging::log_level::TRACE, "network: %u stream socket %u was assigned port %u\n", scheduler::get_pid(), socket_fd, socket.local_port); + + //TODO TCP connect return std::make_expected(socket.local_port); } diff --git a/kernel/src/system_calls.cpp b/kernel/src/system_calls.cpp index b890b1a2..54d85a6c 100644 --- a/kernel/src/system_calls.cpp +++ b/kernel/src/system_calls.cpp @@ -424,6 +424,15 @@ void sc_client_bind(interrupt::syscall_regs* regs){ regs->rax = expected_to_i64(status); } +void sc_connect(interrupt::syscall_regs* regs){ + auto socket_fd = regs->rbx; + auto ip = regs->rcx; + auto port = regs->rdx; + + auto status = network::connect(socket_fd, ip, port); + regs->rax = expected_to_i64(status); +} + void sc_wait_for_packet(interrupt::syscall_regs* regs){ auto socket_fd = regs->rbx; auto user_buffer = reinterpret_cast(regs->rcx); @@ -692,6 +701,10 @@ void system_call_entry(interrupt::syscall_regs* regs){ sc_client_bind(regs); break; + case 0x3008: + sc_connect(regs); + break; + // Special system calls case 0x6666: diff --git a/programs/nc/src/main.cpp b/programs/nc/src/main.cpp index fe3bdca1..1cc08987 100644 --- a/programs/nc/src/main.cpp +++ b/programs/nc/src/main.cpp @@ -27,8 +27,18 @@ int main(int argc, char* argv[]) { std::string port_str(argv[2]); auto port = std::atoui(port_str); + auto ip_parts = std::split(server, '.'); + + if (ip_parts.size() != 4) { + tlib::print_line("Invalid address IP for the server"); + return 1; + } + + auto server_ip = tlib::ip::make_address(std::atoui(ip_parts[0]), std::atoui(ip_parts[1]), std::atoui(ip_parts[2]), std::atoui(ip_parts[3])); + tlib::socket sock(tlib::socket_domain::AF_INET, tlib::socket_type::STREAM, tlib::socket_protocol::TCP); + sock.connect(server_ip, port); sock.listen(true); if (!sock) { diff --git a/tlib/include/tlib/net.hpp b/tlib/include/tlib/net.hpp index 373c48df..8f14ddb8 100644 --- a/tlib/include/tlib/net.hpp +++ b/tlib/include/tlib/net.hpp @@ -42,6 +42,7 @@ std::expected prepare_packet(size_t socket_fd, void* desc); std::expected finalize_packet(size_t socket_fd, const packet& p); std::expected listen(size_t socket_fd, bool l); std::expected client_bind(size_t socket_fd); +std::expected connect(size_t socket_fd, tlib::ip::address server, size_t port); std::expected wait_for_packet(size_t socket_fd); std::expected wait_for_packet(size_t socket_fd, size_t ms); @@ -83,6 +84,8 @@ struct socket { */ void client_bind(); + void connect(tlib::ip::address server, size_t port); + void listen(bool l); packet prepare_packet(void* desc); diff --git a/tlib/src/net.cpp b/tlib/src/net.cpp index fef43a4b..9f6f738f 100644 --- a/tlib/src/net.cpp +++ b/tlib/src/net.cpp @@ -120,6 +120,20 @@ std::expected tlib::client_bind(size_t socket_fd) { } } +std::expected tlib::connect(size_t socket_fd, tlib::ip::address server, size_t port) { + int64_t code; + asm volatile("mov rax, 0x3008; mov rbx, %[socket]; mov rcx, %[ip]; mov rdx, %[port]; int 50; mov %[code], rax" + : [code] "=m"(code) + : [socket] "g" (socket_fd), [ip] "g" (size_t(server.raw_address)), [port] "g" (port) + : "rax", "rbx", "rcx", "rdx"); + + if (code < 0) { + return std::make_unexpected(-code); + } else { + return std::make_expected(code); + } +} + std::expected tlib::wait_for_packet(size_t socket_fd) { auto buffer = malloc(2048); @@ -225,6 +239,19 @@ void tlib::socket::client_bind() { local_port = *status; } +void tlib::socket::connect(tlib::ip::address server, size_t port) { + if (!good() || !open()) { + return; + } + + auto status = tlib::connect(fd, server, port); + if (!status) { + error_code = status.error(); + } + + local_port = *status; +} + tlib::packet tlib::socket::prepare_packet(void* desc) { if (!good() || !open()) { return tlib::packet();