diff --git a/kernel/include/net/socket.hpp b/kernel/include/net/socket.hpp index f2443350..cc719afe 100644 --- a/kernel/include/net/socket.hpp +++ b/kernel/include/net/socket.hpp @@ -29,7 +29,9 @@ struct socket { socket_protocol protocol; size_t next_fd; bool listen; - size_t local_port; + bool connected; + uint32_t local_port; + uint32_t server_port; std::vector packets; @@ -38,7 +40,7 @@ struct socket { socket(){} socket(size_t id, socket_domain domain, socket_type type, socket_protocol protocol, size_t next_fd, bool listen) - : id(id), domain(domain), type(type), protocol(protocol), next_fd(next_fd), listen(listen), local_port(0) {} + : id(id), domain(domain), type(type), protocol(protocol), next_fd(next_fd), listen(listen) {} void invalidate(){ id = 0xFFFFFFFF; diff --git a/kernel/src/net/network.cpp b/kernel/src/net/network.cpp index d58b1ea3..e898517c 100644 --- a/kernel/src/net/network.cpp +++ b/kernel/src/net/network.cpp @@ -257,7 +257,15 @@ std::expected network::open(network::socket_domain domain, return std::make_expected_from_error(std::ERROR_SOCKET_INVALID_TYPE_PROTOCOL); } - return scheduler::register_new_socket(domain, type, protocol); + auto socket_fd = scheduler::register_new_socket(domain, type, protocol); + + // Initialize TCP connection values + auto& socket = scheduler::get_socket(socket_fd); + socket.connected = false; + socket.local_port = 0; + socket.server_port = 0; + + return socket_fd; } void network::close(size_t fd){ @@ -394,12 +402,19 @@ std::expected network::connect(socket_fd_t socket_fd, network::ip::addre return std::make_unexpected(std::ERROR_SOCKET_INVALID_TYPE); } - socket.local_port = local_port++; + socket.local_port = local_port++; + socket.server_port = port; logging::logf(logging::log_level::TRACE, "network: %u stream socket %u was assigned port %u\n", scheduler::get_pid(), socket_fd, socket.local_port); if(socket.protocol == socket_protocol::TCP){ - network::tcp::connect(select_interface(server), server, socket.local_port, port); + auto connection = network::tcp::connect(select_interface(server), server, socket.local_port, port); + + if(connection){ + socket.connected = true; + } else { + return std::make_unexpected(connection.error()); + } } else { return std::make_unexpected(std::ERROR_SOCKET_INVALID_TYPE_PROTOCOL); }